########################################
TODOS:
- Mapping Network
- Generator Network
- Discriminator Network
- Augmentation/RGBBlock (style vector to RGB)



########################################


In [23]:
import torch
import torchsummary
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2

import os
import math
import glob

In [None]:
# ------------------------------------------------------------------
# SETTINGS: global run configuration shared by the cells below
# ------------------------------------------------------------------

# Device: use the selected CUDA card when one is available, else the CPU.
CUDA_DEVICE_NUM = 0
if torch.cuda.is_available():
    DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}')
else:
    DEVICE = torch.device('cpu')
print('Device:', DEVICE)

# Hyperparameters
RANDOM_SEED = 42
GENERATOR_LEARNING_RATE = 0.0002
DISCRIMINATOR_LEARNING_RATE = 0.0002

# Training schedule
NUM_EPOCHS = 200
BATCH_SIZE = 128

# Image geometry
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
IMAGE_CHANNELS = 3

In [None]:
from torchvision import datasets, transforms
import matplotlib.pyplot as plt # for plotting
import numpy as np
import torch
import math
import os
import glob

# Preprocessing pipeline: convert to a PIL image, resize to 256x256, convert
# back to a tensor. The augmentation / normalisation steps are disabled.
transform = transforms.Compose([
            # transforms.RandomHorizontalFlip(), # Flip the data horizontally
            transforms.ToPILImage(),
            transforms.Resize((256, 256)),
            # transforms.RandomAdjustSharpness(0.25),
            # transforms.RandomHorizontalFlip(0.5),
            # transforms.RandomVerticalFlip(0.5),
            transforms.ToTensor(),
            # transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
        ])

# Collect the path of every .jpg found (recursively) under the data root.
# NOTE(review): the backslash separators make `root` Windows-specific.
root = '..\\data\\'
data_dirs = glob.glob(root + '**/*.jpg', recursive=True)
data_dirs = np.array(data_dirs)
data_dirs = data_dirs.flatten()

# Prepend `root` to every path (element-wise string concatenation).
# NOTE(review): glob.glob() already returns paths that begin with `root`, so
# the prefix ends up in each path twice -- confirm this is intended.
roots = np.array([root] * len(data_dirs))
data_dirs = np.core.defchararray.add(roots, data_dirs)


# data_list = np.array(data.imgs)

# Sizes of the train/val/test splits. With the factors below every file goes
# to the training split and the validation and test splits are empty.
train_split = math.floor(1 * len(data_dirs))
val_split = math.ceil(0.00 * len(data_dirs))
test_split = val_split * 0

# print(train_split + val_split + test_split, len(data_arr))

# Make sure the splits account for every file
assert train_split + val_split + test_split == len(data_dirs)

# Split the dataset randomly; the fixed seed keeps the split reproducible
generator = torch.Generator().manual_seed(1)
train, val, test = torch.utils.data.random_split(data_dirs, [train_split, val_split, test_split], generator=generator)

# Bare expression: presumably shown as the notebook cell's output (split sizes)
len(train), len(val), len(test)

In [None]:
import cv2

class GenData(torch.utils.data.Dataset):
    '''
        Dataset that lazily loads RGB images from a sequence of file paths.

        Every item is the pair ``(image, image)``: the same object is returned
        twice so the dataset can be consumed as ``for images, labels in loader``.
    '''

    def __init__(self, in_data, transform=None):
        '''
            Args:
                in_data: indexable sequence of image file paths.
                transform: optional pair ``(input_transform, transform)``.
                    ``input_transform`` is applied to each loaded image, but
                    only when the second entry is truthy (the second entry
                    acts purely as an on/off switch). ``None`` disables
                    transformation and yields the raw RGB arrays.
        '''
        self.input_dirs = in_data
        self.labels = []

        # The default ``None`` used to crash on ``transform[0]``; treat it as
        # "no transforms at all".
        if transform is None:
            transform = (None, None)
        self.input_transform = transform[0]
        self.transform = transform[1]

        # Not used inside this class; kept as an attribute for external users.
        self.resize_transform = transforms.Compose([
                                    transforms.ToPILImage(),
                                    transforms.Resize(128)
                                ])

    def __len__(self):
        '''Number of image paths held by the dataset.'''
        return len(self.input_dirs)

    def __getitem__(self, idx):
        '''
            Load the image at position ``idx``.

            Returns:
                ``(image, image)``; the image is transformed when transforms
                are enabled, otherwise it is the raw RGB array from OpenCV.

            Raises:
                FileNotFoundError: if OpenCV cannot read the file.
        '''
        input_dir = self.input_dirs[idx]

        # cv2.imread signals failure by returning None rather than raising,
        # so check explicitly to get an error that names the offending path.
        inputs = cv2.imread(str(input_dir))
        if inputs is None:
            raise FileNotFoundError(f'Could not read image: {input_dir}')

        # OpenCV loads images as BGR; convert to RGB.
        inputs = cv2.cvtColor(inputs, cv2.COLOR_BGR2RGB)

        # The second transform pass of the earlier version is gone: its
        # result was computed and then thrown away.
        if self.transform and self.input_transform is not None:
            inputs = self.input_transform(inputs)

        # Always a 2-tuple, so both code paths unpack the same way.
        return inputs, inputs
        
# Build one DataLoader per split
def get_data_loaders(batch_size=1):
    '''
        Wrap the train / val / test splits in ``GenData`` datasets and return
        their loaders as the tuple ``(train, val, test)``.

        Reads the module-level ``data_dirs``, ``train``, ``val``, ``test`` and
        ``transform``; ``batch_size`` is forwarded to every loader.
    '''

    def _loader_for(split):
        # Select this split's paths and apply the shared transform pair.
        dataset = GenData(data_dirs[split.indices], transform=(transform, transform))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    return _loader_for(train), _loader_for(val), _loader_for(test)

In [None]:
##########################
### Dataset
##########################


# Alternative preprocessing: centre crop to 160x160, resize to the configured
# image size, convert to a tensor, then normalise each channel with mean 0.5
# and std 0.5.
# NOTE(review): this pipeline is defined but not passed to get_data_loaders()
# below -- confirm whether it is meant to be used.
custom_transforms = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop((160, 160)),
    torchvision.transforms.Resize([IMAGE_HEIGHT, IMAGE_WIDTH]),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


# Create the three loaders with the globally configured batch size
train_loader, valid_loader, test_loader = get_data_loaders(
    batch_size=BATCH_SIZE)

In [None]:
# Checking the dataset

    
# Checking the dataset: print the shapes of the first batch of every split.
# Order matters: the training split is visited last, so `images` / `labels`
# are left bound to a training batch whenever that loader is non-empty.
for _heading, _loader in (
    ('Validation Set:\n', valid_loader),
    ('\nTesting Set:', test_loader),
    ('\nTraining Set:', train_loader),
):
    print(_heading)
    for images, labels in _loader:
        print('Image batch dimensions:', images.size())
        print('Image label dimensions:', labels.size())
        # Only the first batch is inspected
        break

In [None]:
# Show up to the first 64 images of the current `images` batch as a single grid
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Source Images")
# The transpose moves axis 0 of the grid to the end before it is handed to imshow
plt.imshow(np.transpose(torchvision.utils.make_grid(images[:64], 
                                         padding=2, normalize=False),
                        (1, 2, 0)))

In [59]:
class EqualizedWeight(nn.Module):
    '''
    Learnable weight tensor that is rescaled by a fixed constant on every use.

    The parameter is drawn from a standard normal distribution; calling the
    module returns it multiplied by ``1 / sqrt(prod(shape[1:]))``.
    '''
    def __init__(self, shape):
        super().__init__()
        # Number of elements behind each entry of the leading dimension
        fan_in = np.prod(shape[1:])
        self.constant = 1 / math.sqrt(fan_in)
        self.weight = nn.Parameter(torch.randn(shape))

    def forward(self):
        '''Return the stored weight scaled by the constant.'''
        scaled = self.weight * self.constant
        return scaled

class UpsampleBlock(nn.Module):
    '''
    Block that upsamples image in the Synthesis Network
    ie: 4x4 -> 8x8

    The input is enlarged 2x with bilinear interpolation and then smoothed,
    channel by channel, with a normalised 3x3 binomial kernel.
    '''
    def __init__(self):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # Registered as a buffer so the kernel follows .to()/.cuda()/.cpu()
        # with the module instead of being pinned to CUDA at construction.
        # persistent=False keeps the state_dict layout unchanged.
        kernel = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]], dtype=torch.float)
        kernel /= kernel.sum()
        self.register_buffer('kernel', kernel, persistent=False)

        # One pixel of reflection padding keeps the 3x3 convolution size-preserving
        self.pad = nn.ReflectionPad2d(1)
        self.conv2d = F.conv2d

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels, 2 * height, 2 * width).
        '''
        b, c = x.shape[:2]
        x = self.upsample(x)

        # Use the spatial size AFTER upsampling. Folding with the original
        # height/width would cut every upsampled image into unrelated chunks
        # before the smoothing convolution.
        up_h, up_w = x.shape[-2:]

        # Fold channels into the batch so the single-channel kernel is applied
        # to each channel independently.
        x = x.reshape(-1, 1, up_h, up_w)
        x = self.pad(x)
        x = self.conv2d(x, self.kernel)

        return x.reshape(b, c, up_h, up_w)

class ConvModDeMod(nn.Module):
    '''
    Convolution whose weights are modulated per sample by a style vector and,
    optionally, demodulated (rescaled so each output filter has unit norm).

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        kernel_size: side length of the square kernel.
        demodulate: whether to rescale the modulated weights. Defaults to
            True so the class can be built without naming the flag.
    '''
    def __init__(self, in_features, out_features, kernel_size, demodulate=True):
        super().__init__()
        self.demodulate = demodulate
        self.out_features = out_features

        # Size-preserving padding for odd kernel sizes
        self.padding = (kernel_size - 1) // 2

        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])

        # Guards the reciprocal square root against division by zero
        self.eps = 1e-8

    def forward(self, x, style_vector):
        '''
        Args:
            x: input of shape (batch, in_features, height, width).
            style_vector: per-sample channel scales of shape (batch, in_features).

        Returns:
            Tensor of shape (batch, out_features, out_height, out_width).
        '''
        b, _, h, w = x.shape

        # (batch, in) -> (batch, 1, in, 1, 1) so it broadcasts over the weights
        style_vector = style_vector[:, None, :, None, None]

        # (out, in, k, k) -> (1, out, in, k, k)
        weights = self.weight()[None, :, :, :, :]

        # Modulation step. Out-of-place on purpose: the result grows to the
        # batch size, which an in-place multiply cannot broadcast into.
        weights = weights * style_vector

        if self.demodulate:
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)

            # Out-of-place so autograd still sees the tensor used by `** 2`
            weights = weights * sigma_inv

        # Grouped-convolution trick: fold the batch into the channel axis and
        # give every sample its own group of filters.
        x = x.reshape(1, -1, h, w)
        _, _, *filter_shape = weights.shape
        weights = weights.reshape(b * self.out_features, *filter_shape)

        x = F.conv2d(x, weights, padding=self.padding, groups=b)

        # Unfold the batch again, using the spatial size the convolution produced
        return x.reshape(b, self.out_features, *x.shape[-2:])




class StyleBlock(nn.Module):
    '''
    One styled convolution step: maps the latent vector ``w`` to a style,
    applies a modulated 3x3 convolution, optionally adds scaled noise, then
    adds a per-channel bias and applies LeakyReLU(0.2).

    Args:
        w_size: length of the latent vector ``w``.
        in_features: number of input channels (also the style length).
        out_features: number of output channels.
    '''
    def __init__(self, w_size, in_features, out_features):
        super().__init__()

        # Linear map from the latent vector to one scale per input channel
        self.to_style = nn.Linear(w_size, in_features)
        # `demodulate` is passed explicitly; leaving it out fails when
        # ConvModDeMod declares it as a required argument.
        self.conv = ConvModDeMod(in_features, out_features, 3, demodulate=True)
        # Single learned scalar that weights the injected noise (starts at 0)
        self.scale_noise = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(out_features))

        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x, w, noise):
        '''
        Args:
            x: feature map passed to the modulated convolution.
            w: latent vector(s) passed through ``to_style``.
            noise: tensor broadcastable to the convolution output, or None to
                skip noise injection.

        Returns:
            The activated feature map.
        '''
        style = self.to_style(w)
        x = self.conv(x, style)

        if noise is not None:
            # The scalar is reshaped to (1, 1, 1, 1) so it broadcasts everywhere
            x = x + self.scale_noise[None, :, None, None] * noise

        # The bias is broadcast over batch and spatial axes
        x = self.activation(x + self.bias[None, :, None, None])

        return x

class Generator(nn.Module):
    '''
    Generator skeleton (work in progress): only the constructor exists and
    several members are still placeholders.

    Args:
        batch_size: leading dimension of the initial 4x4 tensor.
        w_size: latent size handed to StyleBlock; also the channel count of
            the initial 4x4 tensor.
    '''
    def __init__(self, batch_size=16, w_size=512):
        super().__init__()

        # Per-stage tables -- presumably spatial resolution and channel count
        # for each stage (TODO confirm once the stages are implemented).
        self.sizes = [4, 8, 16, 32, 64, 128, 256]
        self.features = [512, 512, 512, 512, 256, 128, 64]

        # NOTE(review): plain tensor attribute, not a Parameter or buffer, so
        # it is not trained, saved in the state_dict or moved by .to()/.cuda()
        # -- confirm whether that is intended.
        self.init_noise = torch.randn([batch_size, w_size, 4, 4])

        self.style_block = StyleBlock(w_size, self.features[0], self.features[0])
        # Placeholder: no RGB conversion layer yet
        self.to_rgb = None

        # Placeholder: `blocks` is None, so the ModuleList starts out empty
        blocks = None
        self.blocks = nn.ModuleList(blocks)

        self.upsample = UpsampleBlock()


    








In [58]:
# Smoke test: a 4x4 input should leave UpsampleBlock as 8x8.
# The input is deliberately NOT called `test`: that global holds the test
# split read by get_data_loaders(), and rebinding it to a tensor would break
# any later call to that function.
upsample_input = torch.rand([1, 3, 4, 4]).to(DEVICE)
upsample = UpsampleBlock().to(DEVICE)
out = upsample(upsample_input)
out.shape

torch.Size([1, 3, 8, 8])

In [None]:
class Discriminator(nn.Module):
    def __init__(self):

        self.features = [64, 128, 256, 512, 512, 512]
        