In [None]:
from torch.utils import data
import numpy as np
import  pytorch_fid_wrapper as pfw
from PIL import Image
import os
import glob
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import handshape_datasets as hd
from torchvision import transforms
from torch.utils.data import DataLoader 
from PIL import Image
from matplotlib import cm
import pickle

In [None]:
class CustomDataset(data.Dataset):
    '''
    In-memory dataset built from an (images, labels) pair of NumPy arrays.

    Values:
    dataset: indexable pair; dataset[0] is a 4-D image array whose last axis is
        moved to position 1 (permute(0, 3, 1, 2)), dataset[1] holds the labels
    transform: optional callable applied to each image tensor when it is fetched
    '''

    def __init__(self, dataset, transform=None):
        super().__init__()
        images, targets = dataset[0], dataset[1]
        # Zero-copy conversion; the permute returns a view of the same memory.
        self.x = torch.from_numpy(images).permute(0, 3, 1, 2)
        self.y = torch.from_numpy(targets)
        self.len = images.shape[0]
        self.transform = transform

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        image = self.x[index]
        if self.transform:
            image = self.transform(image)
        label = self.y[index]
        return image, label

In [None]:
def load_dataset():
    '''
    Load the 'PugeaultASL_A' handshape dataset and split it 80/20 into
    class-stratified train and test subsets.

    Classes with fewer than 40 images are dropped, the remaining labels are
    re-indexed to 0..n_classes-1, and pixel values are mapped with
    (x - 127.5) / 127.5 (i.e. to [-1, 1] for 8-bit input). The shape of one
    image is printed as a side effect.

    Returns:
        n_classes, x_train, y_train, x_test, y_test
    '''
    # Named `dataset` rather than `data` so the `torch.utils.data` import is not shadowed.
    dataset = hd.load('PugeaultASL_A')
    images = dataset[0]
    labels = dataset[1]['y']

    # Keep only classes with enough samples. Counting through np.unique works for
    # any label values, not just labels that happen to be exactly 0..n-1.
    good_min = 40
    class_ids, class_counts = np.unique(labels, return_counts=True)
    good_classes = class_ids[class_counts >= good_min]

    keep = np.isin(labels, good_classes)
    x = images[keep]
    img_shape = x[0].shape
    print(img_shape)

    # Re-index the surviving labels to a contiguous 0..n_classes-1 range.
    kept_ids, y = np.unique(labels[keep], return_inverse=True)
    n_classes = len(kept_ids)

    # Scale to the range [-1, 1]
    x = (x.astype('float32') - 127.5) / 127.5

    x_train, x_test, y_train, y_test = train_test_split(
        x, y, train_size=0.8, test_size=0.2, stratify=y)

    return n_classes, x_train, y_train, x_test, y_test

In [None]:
def load_dataset_with_subject(subject_test=1):
    '''
    Load the 'PugeaultASL_A' handshape dataset and split it by subject
    (leave-one-subject-out).

    Classes with fewer than 40 images are dropped; labels and subject ids are
    re-indexed to contiguous ranges starting at 0. All samples of subject
    `subject_test` (re-indexed id) form the test set, the rest form the train
    set. Both sets are shuffled and mapped with (x - 127.5) / 127.5.

    Values:
    subject_test: re-indexed id of the subject held out for testing

    Returns:
        n_classes, x_train, y_train, x_test, y_test
    '''
    # Named `dataset` rather than `data` so the `torch.utils.data` import is not shadowed.
    dataset = hd.load('PugeaultASL_A')
    images = dataset[0]
    labels = dataset[1]['y']
    subjects = dataset[1]['subjects']

    # Keep only classes with enough samples. Counting through np.unique works for
    # any label values, not just labels that happen to be exactly 0..n-1.
    good_min = 40
    class_ids, class_counts = np.unique(labels, return_counts=True)
    keep = np.isin(labels, class_ids[class_counts >= good_min])

    x = images[keep]
    # Re-index labels and subjects to contiguous ranges starting at 0.
    kept_ids, y = np.unique(labels[keep], return_inverse=True)
    _, s = np.unique(subjects[keep], return_inverse=True)
    n_classes = len(kept_ids)

    is_test = np.equal(subject_test, s)
    x_train, y_train = x[~is_test], y[~is_test]
    x_test, y_test = x[is_test], y[is_test]

    shuffler = np.random.permutation(x_train.shape[0])
    x_train = x_train[shuffler]
    y_train = y_train[shuffler]

    shuffler_test = np.random.permutation(x_test.shape[0])
    x_test = x_test[shuffler_test]
    y_test = y_test[shuffler_test]

    # Scale to the range [-1, 1]
    x_train = (x_train.astype('float32') - 127.5) / 127.5
    x_test = (x_test.astype('float32') - 127.5) / 127.5

    return n_classes, x_train, y_train, x_test, y_test

In [None]:
def load_lsa16_rotated(
        path="/home/willys/tesis/Data-augmentation-using-GANs/datasets/rgb_black_background/"):
    '''
    Load every '*.png' image found directly under `path` and split the result
    80/20 into class-stratified train and test subsets.

    The label is parsed from the file name: the integer before the first '_'
    minus one. Images are shuffled and mapped with (x - 127.5) / 127.5.

    Values:
    path: directory that contains the PNG files

    Returns:
        n_classes, x_train, y_train, x_test, y_test

    Raises:
        ValueError: if no PNG file is found under `path`
    '''
    images = []
    labels = []
    for filename in glob.glob(os.path.join(path, '*.png')):
        # glob already returns the full path; the context manager closes the file
        # handle as soon as the pixels have been copied into the array.
        with Image.open(filename) as image:
            images.append(np.asarray(image))
        image_name = os.path.basename(filename)
        labels.append(int(image_name.split('_')[0]) - 1)

    if not images:
        raise ValueError('No PNG images found in {!r}'.format(path))

    images = np.asarray(images)
    labels = np.asarray(labels)
    n_classes = len(np.unique(labels))

    # Shuffle the dataset
    shuffler = np.random.permutation(images.shape[0])
    images = images[shuffler]
    labels = labels[shuffler]

    # Scale to the range [-1, 1]
    images = (images.astype('float32') - 127.5) / 127.5

    x_train, x_test, y_train, y_test = train_test_split(
        images, labels, train_size=0.8, test_size=0.2, stratify=labels)

    return n_classes, x_train, y_train, x_test, y_test

In [None]:
def load_lsa16_rotated_with_subject(
        subject_test=1,
        path="/home/willys/tesis/Data-augmentation-using-GANs/datasets/rgb_black_background/"):
    '''
    Load every '*.png' image found directly under `path` and split the result
    by subject (leave-one-subject-out).

    File names are parsed as '<label>_<subject>_...': the first '_'-separated
    field is the 1-based label, the second is the subject id. Images of subject
    `subject_test` form the test set, all others the train set. Both sets are
    shuffled and mapped with (x - 127.5) / 127.5.

    Values:
    subject_test: subject id (as it appears in the file names) held out for testing
    path: directory that contains the PNG files

    Returns:
        n_classes, x_train, y_train, x_test, y_test
        (n_classes is the number of distinct labels in the train set)

    Raises:
        ValueError: if no PNG file is found under `path`
    '''
    x_train, y_train = [], []
    x_test, y_test = [], []
    filenames = glob.glob(os.path.join(path, '*.png'))
    if not filenames:
        raise ValueError('No PNG images found in {!r}'.format(path))

    for filename in filenames:
        # glob already returns the full path; the context manager closes the file
        # handle as soon as the pixels have been copied into the array.
        with Image.open(filename) as image:
            image_to_numpy = np.asarray(image)

        # Get label and subject from the file name
        fields = os.path.basename(filename).split('_')
        label = int(fields[0])
        subject = int(fields[1])

        if subject == subject_test:
            x_test.append(image_to_numpy)
            y_test.append(label - 1)
        else:
            x_train.append(image_to_numpy)
            y_train.append(label - 1)

    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)
    x_test = np.asarray(x_test)
    y_test = np.asarray(y_test)
    n_classes = len(np.unique(y_train))

    # Shuffle the train data
    shuffler_train = np.random.permutation(x_train.shape[0])
    x_train = x_train[shuffler_train]
    y_train = y_train[shuffler_train]

    # Shuffle the test data
    shuffler_test = np.random.permutation(x_test.shape[0])
    x_test = x_test[shuffler_test]
    y_test = y_test[shuffler_test]

    # Scale to the range [-1, 1]
    x_train = (x_train.astype('float32') - 127.5) / 127.5
    x_test = (x_test.astype('float32') - 127.5) / 127.5

    return n_classes, x_train, y_train, x_test, y_test

In [None]:
def orthogonal_regularization(weight):
    '''
    Function for computing the orthogonal regularization term for a given weight matrix.

    The weight is flattened to 2-D (first dimension kept, the rest merged) and
    the Frobenius norm of the off-diagonal part of W @ W^T is returned, i.e. a
    measure of how far the rows are from being mutually orthogonal.

    Values:
    weight: tensor with at least 2 dimensions (e.g. a linear or conv weight)

    Returns:
        0-dim tensor with the regularization value
    '''
    weight = weight.flatten(1)
    n_rows = weight.shape[0]
    # Gram matrix of the rows. torch.dot only accepts 1-D tensors, so a matrix
    # product is required here.
    gram = torch.mm(weight, weight.t())
    # n_rows x n_rows mask that zeroes the diagonal and keeps every off-diagonal entry.
    off_diagonal = 1 - torch.eye(n_rows, device=weight.device, dtype=weight.dtype)
    return torch.norm(gram * off_diagonal)

In [None]:
def show_tensor_images(image_tensor, num_images=16, size=(3, 32, 32), nrow=4, show=True, save=False, 
                       path=''):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.

    Values:
    image_tensor: batch of images; values are mapped with (x + 1) / 2 before plotting
    num_images: maximum number of images taken from the start of the batch
    size: unused; kept for backward compatibility with existing callers
    nrow: number of images per grid row (forwarded to make_grid)
    show: whether to call plt.show()
    save: whether to write the figure to `path` with plt.savefig
    path: output file used when `save` is True
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if save:
        plt.savefig(path)
    if show:
        plt.show()
    elif save:
        # A figure that was only saved is finished: close it so that it does not
        # stay open and receive the next call's image on top of this one.
        plt.close()

In [None]:
class ClassConditionalBatchNorm2d(nn.Module):
    '''
    ClassConditionalBatchNorm2d Class
    Batch normalization whose per-channel gain and bias are linear functions of
    a conditioning vector.
    Values:
    in_channels: the dimension of the conditioning vector (class embedding + noise), a scalar
    out_channels: the number of channels of the activation tensor to be normalized, a scalar
    '''

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def make_projection():
            # Bias-free, spectrally normalized map from the condition to one value per channel.
            return nn.utils.spectral_norm(nn.Linear(in_channels, out_channels, bias=False))

        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.class_scale_transform = make_projection()
        self.class_shift_transform = make_projection()

    def forward(self, x, y):
        x_hat = self.bn(x)
        # The gain is predicted as an offset from 1.
        gain = 1 + self.class_scale_transform(y)
        bias = self.class_shift_transform(y)
        # Append two singleton axes so the per-channel values broadcast over H and W.
        gain = gain[:, :, None, None]
        bias = bias[:, :, None, None]
        return gain * x_hat + bias

In [None]:
class AttentionBlock(nn.Module):
    '''
    AttentionBlock Class
    Self-attention over spatial positions, added to the input through a
    learnable gain that starts at zero.
    Values:
    channels: number of channels in input
    '''
    def __init__(self, channels):
        super().__init__()

        self.channels = channels

        def sn_conv1x1(c_in, c_out):
            # Bias-free, spectrally normalized 1x1 convolution.
            return nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=1, padding=0, bias=False))

        self.theta = sn_conv1x1(channels, channels // 8)
        self.phi = sn_conv1x1(channels, channels // 8)
        self.g = sn_conv1x1(channels, channels // 2)
        self.o = sn_conv1x1(channels // 2, channels)

        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, x):
        height, width = x.shape[2], x.shape[3]
        n_positions = height * width
        qk_channels = self.channels // 8
        v_channels = self.channels // 2

        # Query keeps full resolution; key and value are 2x2 max-pooled,
        # which leaves a quarter of the positions.
        query = self.theta(x)
        key = F.max_pool2d(self.phi(x), kernel_size=2)
        value = F.max_pool2d(self.g(x), kernel_size=2)

        # Flatten the spatial axes
        query = query.view(-1, qk_channels, n_positions)
        key = key.view(-1, qk_channels, n_positions // 4)
        value = value.view(-1, v_channels, n_positions // 4)

        # One softmax-normalized row of attention weights per query position
        scores = torch.bmm(query.transpose(1, 2), key)
        attention = F.softmax(scores, dim=-1)

        # Weighted sum of values, restored to the input's spatial layout
        attended = torch.bmm(value, attention.transpose(1, 2))
        attended = attended.view(-1, v_channels, height, width)
        projected = self.o(attended)

        # Apply gain and residual
        return self.gamma * projected + x

In [None]:
class GResidualBlock(nn.Module):
    '''
    GResidualBlock Class
    Upsampling residual block of the generator with class-conditional batch norm.
    Values:
    c_dim: the dimension of conditional vector [c, z], a scalar
    in_channels: the number of channels in the input, a scalar
    out_channels: the number of channels in the output, a scalar
    scale_factor: spatial upsampling factor applied on both branches
    '''

    def __init__(self, c_dim, in_channels, out_channels, scale_factor=2):
        super().__init__()

        def sn_conv(c_in, c_out, kernel_size, padding):
            return nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=padding))

        self.conv1 = sn_conv(in_channels, out_channels, 3, 1)
        self.conv2 = sn_conv(out_channels, out_channels, 3, 1)

        self.bn1 = ClassConditionalBatchNorm2d(c_dim, in_channels)
        self.bn2 = ClassConditionalBatchNorm2d(c_dim, out_channels)

        self.activation = nn.ReLU()
        self.upsample_fn = nn.Upsample(scale_factor=scale_factor) # upsample occurs in every gblock

        # A 1x1 convolution on the skip branch is only needed when the channel count changes.
        self.mixin = (in_channels != out_channels)
        if self.mixin:
            self.conv_mixin = sn_conv(in_channels, out_channels, 1, 0)

    def _main_branch(self, x, y):
        # norm -> relu -> upsample -> conv, then norm -> relu -> conv
        h = self.conv1(self.upsample_fn(self.activation(self.bn1(x, y))))
        return self.conv2(self.activation(self.bn2(h, y)))

    def _skip_branch(self, x):
        skip = self.upsample_fn(x)
        return self.conv_mixin(skip) if self.mixin else skip

    def forward(self, x, y):
        h = self._main_branch(x, y)
        return h + self._skip_branch(x)

In [None]:
class Generator(nn.Module):
    '''
    Generator Class
    Values:
    z_dim: the dimension of random noise sampled, a scalar
    shared_dim: the dimension of shared class embeddings, a scalar
    base_channels: the number of base channels, a scalar
    bottom_width: the height/width of image before it gets upsampled, a scalar
    n_classes: the number of image classes, a scalar
    '''

    def __init__(self, base_channels=96, bottom_width=4, z_dim=120, shared_dim=128, n_classes=1000):
        super().__init__()

        # One noise chunk feeds the input projection, one more feeds each of the 4 blocks.
        n_chunks = 5
        self.z_chunk_size = z_dim // n_chunks
        self.z_dim = z_dim
        self.shared_dim = shared_dim
        self.bottom_width = bottom_width

        # No spectral normalization on embeddings, which authors observe to cripple the generator
        self.shared_emb = nn.Embedding(n_classes, shared_dim)

        self.proj_z = nn.Linear(self.z_chunk_size, 16 * base_channels * bottom_width ** 2)

        # Channel widths before/after each stage: 16x -> 8x -> 4x -> 2x -> 1x base_channels.
        cond_dim = shared_dim + self.z_chunk_size
        widths = [multiplier * base_channels for multiplier in (16, 8, 4, 2, 1)]

        # Kept as a list of (residual block, attention block) pairs rather than one
        # nn.Sequential, because each stage receives its own conditioning vector.
        self.g_blocks = nn.ModuleList([
            nn.ModuleList([
                GResidualBlock(cond_dim, c_in, c_out),
                AttentionBlock(c_out),
            ])
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])
        self.proj_o = nn.Sequential(
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(base_channels, 3, kernel_size=1, padding=0)),
            nn.Tanh(),
        )

    def loss(self, disc_fake_pred):
        '''Generator loss: the negated mean discriminator score on generated images.'''
        return - disc_fake_pred.mean()

    def forward(self, z, y):
        '''
        z: random noise with size self.z_dim
        y: class embeddings with size self.shared_dim
            = NOTE =
            y should be class embeddings from self.shared_emb, not the raw class labels
        '''
        # The first noise chunk seeds the feature map; each remaining chunk is
        # concatenated with the class embedding to condition one stage.
        chunks = torch.split(z, self.z_chunk_size, dim=1)
        seed_chunk = chunks[0]
        conditions = [torch.cat([y, chunk], dim=1) for chunk in chunks[1:]]

        # Project noise and reshape to feed through generator blocks
        h = self.proj_z(seed_chunk)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Feed through generator blocks
        for stage in range(len(self.g_blocks)):
            residual_block, attention_block = self.g_blocks[stage]
            h = attention_block(residual_block(h, conditions[stage]))

        # Project to 3 RGB channels with tanh to map values to [-1, 1]
        return self.proj_o(h)

In [None]:
class DResidualBlock(nn.Module):
    '''
    DResidualBlock Class
    Residual block of the discriminator: two spectrally normalized 3x3
    convolutions on the main branch plus an (optionally projected and
    downsampled) skip connection.
    Values:
    in_channels: the number of channels in the input, a scalar
    out_channels: the number of channels in the output, a scalar
    downsample: whether to apply downsampling
    use_preactivation: whether to apply an activation function before the first convolution
    downsample_scale: kernel size of the average pooling used for downsampling
    '''

    def __init__(self, in_channels, out_channels, downsample=True, use_preactivation=False, 
                 downsample_scale=2):
        super().__init__()

        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

        self.activation = nn.ReLU()
        self.use_preactivation = use_preactivation  # apply preactivation in all except first dblock

        self.downsample = downsample    # downsample occurs in all except last dblock
        if downsample:
            self.downsample_fn = nn.AvgPool2d(downsample_scale)
        self.mixin = (in_channels != out_channels) or downsample
        if self.mixin:
            self.conv_mixin = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))

    def _residual(self, x):
        '''Skip branch: 1x1 projection and pooling, ordered according to use_preactivation.'''
        if self.use_preactivation:
            if self.mixin:
                x = self.conv_mixin(x)
            if self.downsample:
                x = self.downsample_fn(x)
        else:
            # Without preactivation, pool first so the 1x1 convolution runs on
            # the smaller feature map.
            if self.downsample:
                x = self.downsample_fn(x)
            if self.mixin:
                x = self.conv_mixin(x)
        return x

    def forward(self, x):
        # Apply preactivation if applicable (not in place: x is reused by the skip branch)
        if self.use_preactivation:
            h = F.relu(x)
        else:
            h = x

        h = self.conv1(h)
        h = self.activation(h)
        # Second convolution of the main branch. It is created in __init__ and is
        # part of the state dict; without this call its weights never receive a gradient.
        h = self.conv2(h)
        if self.downsample:
            h = self.downsample_fn(h)

        return h + self._residual(x)

In [None]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
    base_channels: the number of base channels, a scalar
    n_classes: the number of image classes, a scalar
    '''

    def __init__(self, base_channels=96, n_classes=1000):
        super().__init__()

        # For adding class-conditional evidence
        self.shared_emb = nn.utils.spectral_norm(nn.Embedding(n_classes, 8 * base_channels))

        # (in_channels, out_channels, use_preactivation) of each downsampling stage;
        # only the first stage, which receives the RGB image, skips the preactivation.
        stages = [
            (3, base_channels, False),
            (base_channels, 2 * base_channels, True),
            (2 * base_channels, 4 * base_channels, True),
            (4 * base_channels, 8 * base_channels, True),
        ]
        layers = []
        for c_in, c_out, preactivate in stages:
            layers.append(DResidualBlock(c_in, c_out, downsample=True, use_preactivation=preactivate,
                                         downsample_scale=2))
            layers.append(AttentionBlock(c_out))
        layers.append(nn.ReLU(inplace=True))
        self.d_blocks = nn.Sequential(*layers)

        self.proj_o = nn.utils.spectral_norm(nn.Linear(8 * base_channels, 1))

    def loss(self, disc_real_pred, disc_fake_pred):
        '''Hinge loss: mean relu(1 - real score) plus mean relu(1 + fake score).'''
        real_term = torch.relu(1.0 - disc_real_pred).mean()
        fake_term = torch.relu(1.0 + disc_fake_pred).mean()
        return real_term + fake_term

    def forward(self, x, y=None):
        # Sum-pool the final feature map over its spatial axes
        features = torch.sum(self.d_blocks(x), dim=[2, 3])

        # Class-unconditional output
        score = self.proj_o(features)
        if y is not None:
            # Class-conditional output: inner product of the features with the class embedding
            score = score + torch.sum(self.shared_emb(y) * features, dim=1, keepdim=True)

        return score

In [None]:
class RandomApplyEach(nn.Module):
    '''
    Apply each transform of a list independently with probability p.

    Values:
    transforms: iterable of callables, applied in order
    p: probability of applying each transform; a Python number or a tensor.
       It is a plain attribute (not a parameter or buffer), so callers may
       reassign it between calls.
    '''
    def __init__(self, transforms, p):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # Draw the random number on the device that holds p (CPU for a plain
        # number) so the module also works on machines without CUDA.
        device = self.p.device if torch.is_tensor(self.p) else None
        for t in self.transforms:
            if self.p > torch.rand(1, device=device):
                img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += '\n    p={}'.format(self.p)
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

In [None]:
def train(dataloader, base_channels, z_dim, shared_dim, n_classes,
          generator, discriminator, gen_opt, disc_opt, epochs, weights_dir, summary_fid,
          real_m=None, real_s=None, device=None):
    '''
    Train the generator/discriminator pair, running two discriminator updates
    and one generator update per batch.

    Every 500 steps the FID of the fakes produced since the previous evaluation
    is computed. The three best generators seen so far are kept under
    <weights_dir>/weights_<k>/ (state dict plus a sample image grid), and the
    FID is appended to `summary_fid`.

    Values:
    dataloader: iterable yielding (images, labels) batches
    base_channels, shared_dim, n_classes: unused; kept for backward compatibility
    z_dim: dimension of the noise vector fed to the generator
    generator, discriminator: the models to train; both must expose `loss`
    gen_opt, disc_opt: their optimizers
    epochs: number of passes over `dataloader`
    weights_dir: directory (prefix) under which checkpoints are written
    summary_fid: writable text file that receives one line per FID evaluation, or None
    real_m, real_s: FID statistics of the real data; when omitted, the notebook
        globals of the same name are used
    device: device used for training; defaults to the device of the generator's
        parameters

    Raises:
        NameError: if real_m/real_s are neither passed nor defined as globals
    '''
    if device is None:
        device = next(generator.parameters()).device
    if real_m is None or real_s is None:
        # Backward compatibility: existing cells publish the statistics as globals.
        # Fail before training starts instead of at the first FID evaluation.
        missing = [name for name in ('real_m', 'real_s') if name not in globals()]
        if missing:
            raise NameError('FID statistics not available: pass real_m/real_s or define '
                            'them as globals (missing: {})'.format(', '.join(missing)))
        real_m, real_s = globals()['real_m'], globals()['real_s']

    cur_step = 0
    # Float array: an integer array would truncate the FID values stored in it.
    min_fids = np.array([2000.0, 2000.0, 2000.0])
    augmentation_transforms = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=15, fill=0),
        transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), fill=0),
        transforms.RandomAffine(degrees=0, scale=(0.7, 1.3), fill=0),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ColorJitter(saturation=0.5),
        transforms.ColorJitter(hue=0.2),
    ]
    p = torch.tensor(0.0, device=device)
    ada_target = 0.6
    update_iteration = 8
    adjustment_size = 250000 # number of images to reach p=1
    augmentation = RandomApplyEach(augmentation_transforms, p).to(device)
    ada_buf = torch.tensor([0.0, 0.0], device=device)
    fakes = torch.tensor([], device='cpu')

    for epoch in range(epochs):
        print('##############################')
        print('#epoch: {}'.format(epoch))
        print('##############################')

        # Iterate the loader that was passed in, not a notebook global.
        for batch_ndx, sample in enumerate(dataloader):
            real, labels = sample[0], sample[1]
            batch_size = len(real)
            real = real.to(device)
            y = labels.to(device).long()
            real_augmented = augmentation(real)

            for i in range(2):
                ### Update discriminator ###
                # Zero out the discriminator gradients
                disc_opt.zero_grad()
                # Get noise corresponding to the current batch_size
                z = torch.randn(batch_size, z_dim, device=device)
                # The generator is not updated here, so no autograd graph is needed.
                with torch.no_grad():
                    fake = generator(z, generator.shared_emb(y))
                fake = augmentation(fake.detach())

                disc_fake_pred = discriminator(fake, y)
                disc_real_pred = discriminator(real_augmented, y)

                disc_loss = discriminator.loss(disc_real_pred, disc_fake_pred)
                # Each iteration runs its own forward passes, so the graph need not be retained.
                disc_loss.backward()
                disc_opt.step()

            ### Update generator ###
            # Zero out the generator gradients
            gen_opt.zero_grad()

            # Reuses the noise of the last discriminator iteration
            y_emb = generator.shared_emb(y)
            fake = generator(z, y_emb)
            fake = augmentation(fake)
            disc_fake_pred = discriminator(fake, y)
            gen_loss = generator.loss(disc_fake_pred)
            gen_loss.backward()
            gen_opt.step()

            torch.cuda.empty_cache()
            cur_step += 1
            # Detach before storing: otherwise every stored batch keeps its whole
            # autograd graph alive until the next FID evaluation.
            fakes = torch.cat((fakes, fake.detach().to('cpu')))

            if cur_step % update_iteration == 0:
                # Adaptive Data Augmentation
                pred_signs, n_pred = ada_buf
                # NOTE(review): nothing in this function adds to `ada_buf`, so n_pred is
                # always 0 and `augmentation.p` never leaves its initial value of 0.0
                # (no augmentation is ever applied). The per-step accumulation of the
                # discriminator's real-prediction signs appears to be missing. Before
                # adding it, confirm the colour transforms are valid for images in [-1, 1].
                if n_pred > 0:
                    r_t = pred_signs / n_pred

                    sign = 1 if r_t > ada_target else -1

                    augmentation.p = torch.clamp(augmentation.p + (sign * n_pred / adjustment_size), min=0, max=1)

                ada_buf = ada_buf * 0

            if cur_step % 500 == 0:
                print('===========================================================================')
                show_tensor_images(real)
                show_tensor_images(fake)
                # NOTE(review): the FID is computed on the (possibly augmented) fakes in
                # the same value range as the real statistics — confirm this is the
                # range pfw expects.
                val_fid = pfw.fid(fakes, real_m=real_m, real_s=real_s)
                fakes = torch.tensor([], device='cpu')
                print('FID: {}'.format(val_fid))
                print('augmentation p: {}'.format(augmentation.p))
                if summary_fid is not None:
                    summary_fid.write('epoch {} step {} FID {}\n'.format(epoch, cur_step, val_fid))
                    summary_fid.flush()
                if (val_fid < min_fids).any():
                    # Replace the worst of the three best checkpoints kept so far.
                    idx = min_fids.argmax()
                    min_fids[idx] = val_fid
                    weights_dir_specific = os.path.join(weights_dir, 'weights_{}'.format(idx))
                    os.makedirs(weights_dir_specific, exist_ok=True)
                    torch.save(generator.state_dict(), os.path.join(weights_dir_specific, 'gen.state_dict'))
                    path_image = os.path.join(weights_dir_specific, 'images_gen.png')
                    show_tensor_images(fake, show=False, save=True, path=path_image)
                print('===========================================================================')

In [None]:
# Leave-one-subject-out GAN training on PugeaultASL_A.
for subject in range(0,1):

    device = 'cuda'
    # Create the output directories
    var_dir = './saved_variables/PugeaultASL_A_ADA/subject_{}'.format(subject)
    os.makedirs(var_dir, exist_ok=True)
    # Trailing separator so that every path built from this prefix (by string
    # concatenation or os.path.join) lands inside the subject directory.
    weights_dir = os.path.join('generators_weights/PugeaultASL_A_ADA/subject_{}'.format(subject), '')
    os.makedirs(weights_dir, exist_ok=True)
    data_dir = 'numpy_data/PugeaultASL_A_ADA/'
    os.makedirs(data_dir, exist_ok=True)

    # Load the data, holding out `subject` for testing
    n_classes, x_train, y_train, x_test, y_test = load_dataset_with_subject(subject_test=subject)

    batch_size = 32
    # No transform: CustomDataset already yields tensors.
    dataset = CustomDataset((x_train,y_train))
    print("Creating dataset object")

    loader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=4)

    # Initialize models
    print("Creating models")
    base_channels = 96
    z_dim = 120
    shared_dim = 128
    generator = Generator(base_channels=base_channels, bottom_width=2, z_dim=z_dim, shared_dim=shared_dim, n_classes=n_classes).to(device)
    discriminator = Discriminator(base_channels=base_channels, n_classes=n_classes).to(device)

    # Initialize weights orthogonally
    for model in (generator, discriminator):
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                # spectral_norm keeps the trainable tensor in `weight_orig` and
                # recomputes `weight` from it on every forward pass, so the
                # initialization must target the underlying parameter when it exists.
                nn.init.orthogonal_(getattr(module, 'weight_orig', module.weight))

    # Initialize optimizers
    gen_opt = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.999), eps=1e-6)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0.0, 0.999), eps=1e-6)

    # Setup FID: statistics of the real training images, cached on disk per subject
    print("Calculating FID parameters")
    pfw.set_config(batch_size=batch_size, device=device)
    fid_stats_path = os.path.join(var_dir, 'fid_stats_ASL.pkl')
    if os.path.isfile(fid_stats_path):
        # pickle can execute arbitrary code: only load cache files written by this notebook.
        with open(fid_stats_path, 'rb') as f:
            real_m, real_s = pickle.load(f)
    else:
        real_m, real_s = pfw.get_stats((dataset.x))
        with open(fid_stats_path, 'wb') as f:
            pickle.dump([real_m, real_s], f)
    print("FID parameters calculated!")

    # The context manager closes the FID log even if training is interrupted.
    with open(os.path.join(weights_dir, "summary_fid.txt"), "a") as summary_fid:
        train(loader, base_channels, z_dim, shared_dim, n_classes,
              generator, discriminator, gen_opt, disc_opt, 10, weights_dir, summary_fid)

In [None]:
# Inspect the shape of the image tensor held by the current dataset
dataset.x.shape

## Training GANs with a reduced dataset

In [None]:
# Train one GAN per reduction factor on a stratified fraction of the saved LSA16 train split.
factor_reduces = [0.05, 0.01, 0.005, 0.004]
n_classes, _, _, _, _ = load_lsa16_rotated()
device = 'cuda'

data_dir = 'numpy_data/lsa_16_rotated_reduce/'
os.makedirs(data_dir, exist_ok=True)

# Load the full train split once; every factor below draws its subset from it.
x_train_full = np.load(data_dir+'x_train.npy')
y_train_full = np.load(data_dir+'y_train.npy')

for factor_reduce in factor_reduces:
    # Create the output directory
    weights_dir = 'generators_weights/lsa_16_rotated_reduce/{}/'.format(str(factor_reduce))
    os.makedirs(weights_dir, exist_ok=True)

    # Keep a class-stratified fraction `factor_reduce` of the train split
    x_train, _, y_train, _ = train_test_split(
                                            x_train_full, y_train_full, train_size=factor_reduce, 
                                            test_size=(1-factor_reduce),stratify=y_train_full)

    # No ToTensor transform: CustomDataset already yields tensors, and
    # transforms.ToTensor() raises TypeError when given a tensor.
    dataset = CustomDataset((x_train,y_train))
    batch_size = 16
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=4)

    # Initialize models
    base_channels = 96
    z_dim = 120  
    shared_dim = 128
    generator = Generator(base_channels=base_channels, bottom_width=4, z_dim=z_dim, shared_dim=shared_dim, n_classes=n_classes).to(device)
    discriminator = Discriminator(base_channels=base_channels, n_classes=n_classes).to(device)

    # Initialize weights orthogonally
    for model in (generator, discriminator):
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                # spectral_norm keeps the trainable tensor in `weight_orig` and
                # recomputes `weight` from it on every forward pass, so the
                # initialization must target the underlying parameter when it exists.
                nn.init.orthogonal_(getattr(module, 'weight_orig', module.weight))

    # Initialize optimizers
    gen_opt = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.999), eps=1e-6)
    disc_opt = torch.optim.Adam(discriminator.parameters(), lr=4e-4, betas=(0.0, 0.999), eps=1e-6)

    # FID statistics of the real images used in this run. They are recomputed for
    # every factor so that train() never picks up statistics left over from
    # another dataset or an earlier cell.
    # NOTE(review): the reference is the reduced subset itself (mirroring the
    # subject-wise training cell); confirm whether the full split is preferred.
    pfw.set_config(batch_size=batch_size, device=device)
    real_m, real_s = pfw.get_stats(dataset.x)

    # Text file that records the FIDs; closed even if training is interrupted.
    with open(weights_dir+"summary_fid.txt", "a") as summary_fid:
        train(loader, base_channels, z_dim, shared_dim, n_classes,
              generator, discriminator, gen_opt, disc_opt, 20000, weights_dir, summary_fid)

## Generating Data

In [None]:
def _stratified_fraction(x, y, fraction):
    '''Return a random, class-stratified `fraction` of the samples in (x, y).'''
    x_part, _, y_part, _ = train_test_split(
        x, y, train_size=fraction, test_size=1 - fraction, stratify=y)
    return x_part, y_part


def _shuffled(x, y):
    '''Return (x, y) reordered with one shared random permutation.'''
    shuffler = np.random.permutation(x.shape[0])
    return x[shuffler], y[shuffler]


# For each held-out subject: sample images from the three saved generators and
# build datasets that mix generated and real training images in several ratios.
for q in range(2,5):
    print('##############')
    print('#  subject {} #'.format(q))
    print('##############')
    weights_path = 'generators_weights/PugeaultASL_A_transforms/subject_{}/'.format(q)
    numpy_data_path = 'numpy_data/PugeaultASL_A_transforms/subject_{}/'.format(q)
    os.makedirs(numpy_data_path, exist_ok=True)
    total_images = 1792
    batch_size = 64
    n_classes, x_train, y_train, _, _ = load_dataset_with_subject(subject_test=q)
    # Initialize models
    base_channels = 96
    z_dim = 120  
    shared_dim = 128
    device = 'cuda'
    # NOTE(review): the generator has four 2x-upsampling stages, so bottom_width=4
    # yields 64x64 images, while the subject-wise training cell uses bottom_width=2
    # (32x32). Confirm which value the saved weights were trained with.
    generator = Generator(base_channels=base_channels, bottom_width=4, 
                          z_dim=z_dim, shared_dim=shared_dim, n_classes=n_classes).to(device)

    # 100% generated images
    x_aug_100 = []
    y_aug_100 = []

    for k in range(3):
        specific_weights_path = weights_path+('weights_{}/gen.state_dict'.format(k))
        generator.load_state_dict(torch.load(specific_weights_path, map_location=device))
        generator.eval()

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for j in range(total_images//batch_size):
                for i in range(n_classes):

                    # Generate one batch of fake images of class i
                    z = torch.randn(batch_size, z_dim, device=device)                 # Generate random noise (z)
                    y = (torch.ones(batch_size) * i).to(device).long()    # A batch of labels (y), all equal to class i
                    y_emb = generator.shared_emb(y)                                  # Retrieve class embeddings (y_emb) from generator
                    fake = generator(z, y_emb)

                    # Move to CPU and convert to channels-last NumPy arrays
                    image_unflat = fake.detach().cpu()
                    final_images = image_unflat.permute(0,2, 3, 1).numpy()

                    x_aug_100.append(final_images)
                    y_tmp = np.ones(batch_size,)* i
                    y_aug_100.append(y_tmp)

    # 100% GAN dataset. Concatenating keeps whatever image size the generator
    # produced; a hard-coded reshape would silently scramble any other size.
    x_aug_100 = np.concatenate(x_aug_100, axis=0)
    y_aug_100 = np.concatenate(y_aug_100, axis=0)
    assert x_aug_100.shape[1:] == x_train.shape[1:], (
        'generated images {} do not match real images {}'.format(x_aug_100.shape[1:], x_train.shape[1:]))

    # 50/50 dataset
    x_aug_50, y_aug_50 = _stratified_fraction(x_aug_100, y_aug_100, 0.5)
    x_aug_50 = np.concatenate((x_aug_50, x_train), axis=0)
    y_aug_50 = np.concatenate((y_aug_50, y_train), axis=0)

    # 75/25 dataset (GAN/real)
    # The real dataset has half as many samples as x_aug_100, so 25% real data
    # relative to the GAN dataset corresponds to half of the original dataset.
    x_aug_75, y_aug_75 = _stratified_fraction(x_aug_100, y_aug_100, 0.75)
    x_train_25, y_train_25 = _stratified_fraction(x_train, y_train, 0.5)
    x_aug_75 = np.concatenate((x_aug_75, x_train_25), axis=0)
    y_aug_75 = np.concatenate((y_aug_75, y_train_25), axis=0)

    # 25/75 dataset (GAN/real)
    # The real dataset has half as many samples as x_aug_100, so 25% GAN data
    # relative to the original dataset corresponds to 12.5% of x_aug_100.
    x_aug_25, y_aug_25 = _stratified_fraction(x_aug_100, y_aug_100, 0.125)
    x_train_75, y_train_75 = _stratified_fraction(x_train, y_train, 0.75)
    x_aug_25 = np.concatenate((x_aug_25, x_train_75), axis=0)
    y_aug_25 = np.concatenate((y_aug_25, y_train_75), axis=0)

    # Shuffle the 100% GAN dataset
    x_aug_100, y_aug_100 = _shuffled(x_aug_100, y_aug_100)
    print(" dataset 100% GAN shape: {}, y final:{}".format(x_aug_100.shape,y_aug_100.shape))

    # Shuffle the 50/50 dataset
    x_aug_50, y_aug_50 = _shuffled(x_aug_50, y_aug_50)
    print(" dataset 50/50 shape: {}, y final:{}".format(x_aug_50.shape,y_aug_50.shape))

    # Shuffle the 75/25 dataset (GAN/real)
    x_aug_75, y_aug_75 = _shuffled(x_aug_75, y_aug_75)
    print(" dataset 75/25 shape: {}, y final:{}".format(x_aug_75.shape,y_aug_75.shape))

    # Shuffle the 25/75 dataset (GAN/real)
    x_aug_25, y_aug_25 = _shuffled(x_aug_25, y_aug_25)
    print(" dataset 25/75 shape: {}, y final:{}".format(x_aug_25.shape,y_aug_25.shape))

    # Save the merged datasets
    np.save(numpy_data_path+'x_aug_100',x_aug_100)
    np.save(numpy_data_path+'y_aug_100',y_aug_100)
    np.save(numpy_data_path+'x_aug_50',x_aug_50)
    np.save(numpy_data_path+'y_aug_50',y_aug_50)
    np.save(numpy_data_path+'x_aug_75',x_aug_75)
    np.save(numpy_data_path+'y_aug_75',y_aug_75)
    np.save(numpy_data_path+'x_aug_25',x_aug_25)
    np.save(numpy_data_path+'y_aug_25',y_aug_25)

    # Free memory: drop the references first so the cached GPU blocks can be released
    del generator, x_train, y_train
    torch.cuda.empty_cache()

## Generating Data 2

In [None]:
reduces = [0.005]

# Settings shared by every reduction factor.
# NOTE(review): presumably these must match the configuration the saved
# generator weights were trained with -- confirm against the training cell.
base_channels = 96
z_dim = 120
shared_dim = 128
device = 'cuda'
n_generators = 3       # checkpoints weights_0 .. weights_2 are loaded in turn
max_batch_size = 16    # upper bound on images generated per forward pass

# Only the number of classes is needed from the loader and it does not depend
# on the reduction factor, so it is queried once instead of once per factor.
n_classes, _, _, _, _ = load_lsa16_rotated()


def _shuffled(x, y):
    """Return ``x`` and ``y`` reordered by one shared random permutation."""
    order = np.random.permutation(x.shape[0])
    return x[order], y[order]


for factor_reduce in reduces:
    print('##############')
    print('#  Factor_reduce {} #'.format(factor_reduce))
    print('##############')

    weights_path = 'generators_weights/lsa_16_rotated_reduce/{}/'.format(factor_reduce)
    # NOTE(review): absolute, machine-specific output path; consider deriving
    # it from a configurable data directory.
    numpy_data_path = '/media/willys/MULTIBOOT/tesis/numpy_data/lsa_16_rotated_reduce/{}/'.format(factor_reduce)
    os.makedirs(numpy_data_path, exist_ok=True)

    # Original dataset, reduced to a stratified `factor_reduce` fraction.
    x_train = np.load('numpy_data/lsa_16_rotated_reduce/x_train.npy')
    y_train = np.load('numpy_data/lsa_16_rotated_reduce/y_train.npy')
    x_train, _, y_train, _ = train_test_split(
        x_train, y_train, train_size=factor_reduce, stratify=y_train)
    print('x_train shape: {}, y_train shape{}'.format(x_train.shape, y_train.shape))

    # Images each generator must produce per class.
    total_images_of_generator = (x_train.shape[0] // n_classes) + 1

    batch_size = min(max_batch_size, total_images_of_generator)
    print('batch_size: {}'.format(batch_size))

    # Generate exactly `total_images_of_generator` images per class: full
    # batches plus one smaller final batch. The previous
    # round(total / batch_size) could round down and leave the generated pool
    # too small for the fixed-size stratified splits below.
    n_full_batches, remainder = divmod(total_images_of_generator, batch_size)
    chunk_sizes = [batch_size] * n_full_batches + ([remainder] if remainder else [])

    # Initialize model
    generator = Generator(base_channels=base_channels, bottom_width=4,
                          z_dim=z_dim, shared_dim=shared_dim, n_classes=n_classes).to(device)

    # 100% generated images, accumulated over all generators.
    x_gan_300 = []
    y_gan_300 = []

    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for k in range(n_generators):
            specific_weights_path = weights_path + ('weights_{}/gen.state_dict'.format(k))
            generator.load_state_dict(torch.load(specific_weights_path, map_location=device))
            generator.eval()

            for n_images in chunk_sizes:
                for i in range(n_classes):
                    z = torch.randn(n_images, z_dim, device=device)                   # random noise
                    y = torch.full((n_images,), i, dtype=torch.long, device=device)   # label i for the whole batch
                    y_emb = generator.shared_emb(y)                                   # class embeddings from the generator
                    fake = generator(z, y_emb)

                    # Move to CPU and reorder the axes to channels-last.
                    x_gan_300.append(fake.detach().cpu().permute(0, 2, 3, 1).numpy())
                    # Labels are kept as floats, as before.
                    y_gan_300.append(np.ones(n_images,) * i)

    # 300% generated data: stack every batch along the sample axis. This is
    # equivalent to the former reshape((-1, 64, 64, 3)) but does not hard-code
    # the image size and accepts the smaller final batch.
    x_gan_300 = np.concatenate(x_gan_300, axis=0)
    y_gan_300 = np.concatenate(y_gan_300, axis=0)
    print('shape x_gan_300: {}, y_gan_300: {}'.format(x_gan_300.shape, y_gan_300.shape))

    # 200% and 100% generated data
    x_gan_200, x_gan_100, y_gan_200, y_gan_100 = train_test_split(
        x_gan_300, y_gan_300, train_size=0.66, test_size=0.33, stratify=y_gan_300)

    generated = {
        '100': (x_gan_100, y_gan_100),
        '200': (x_gan_200, y_gan_200),
        '300': (x_gan_300, y_gan_300),
    }

    # Generated subsets sized as a fraction of the reduced real dataset.
    # The +25% variant is disabled; prepend ('25', (1, 4)) to re-enable it.
    gan_parts = {}
    for tag, (num, den) in (('50', (1, 2)), ('75', (3, 4))):
        x_part, _, y_part, _ = train_test_split(
            x_gan_100, y_gan_100,
            train_size=int(round(x_train.shape[0] * num / den)),
            stratify=y_gan_100)
        gan_parts[tag] = (x_part, y_part)
    gan_parts.update(generated)

    # Original dataset + (50%, 75%, 100%, 200%, 300%) generated data, shuffled.
    augmented = {}
    for tag, (x_gen, y_gen) in gan_parts.items():
        x_aug, y_aug = _shuffled(np.concatenate((x_gen, x_train), axis=0),
                                 np.concatenate((y_gen, y_train), axis=0))
        print("x_aug_{0} shape: {1}, y_aug_{0} shape:{2}".format(tag, x_aug.shape, y_aug.shape))
        augmented[tag] = (x_aug, y_aug)

    # Save the reduced real data, the merged datasets and the generated-only ones.
    np.save(numpy_data_path+'x_train', x_train)
    np.save(numpy_data_path+'y_train', y_train)
    for prefix, datasets in (('aug', augmented), ('gan', generated)):
        for tag, (x_set, y_set) in datasets.items():
            np.save(numpy_data_path+'x_{}_{}'.format(prefix, tag), x_set)
            np.save(numpy_data_path+'y_{}_{}'.format(prefix, tag), y_set)

    # Free memory. Only names defined in this cell are deleted (the previous
    # `del x_aug_25, y_aug_25` raised NameError unless they leaked from an
    # earlier cell), and references are dropped before empty_cache() so the
    # generator's GPU memory can actually be released.
    del generator, x_train, y_train
    del x_gan_100, y_gan_100, x_gan_200, y_gan_200, x_gan_300, y_gan_300
    del generated, gan_parts, augmented
    del x_part, y_part, x_gen, y_gen, x_aug, y_aug, x_set, y_set
    torch.cuda.empty_cache()