In [1]:
# Implement a Cycle-GAN for the pair of MNIST-USPS datasets. Use the
# output of the converted target in the source classifier and report the result
# on adaptation.

In [2]:
#in this script we will first implement a cycle gan
#and then we will train a resnet50 based classifier on the source dataset
#then using cyclegan we will translate the target dataset to source domain
#and then use the trained clssifier to classify the translated target dataset

In [3]:
#the cycle gan follows the implementation from the paper: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Zhu et al.

In [4]:
experiment_name = 'mnist_cycle_adaptation'
version = 'v0'

#concat experiment name and version to get experiment id
experiment_id = experiment_name + '_' + version

model_path = 'saved_models/cycle_gan'

results_path = 'Results/cycle_gan'

In [5]:
#GPU name
#
GPU_NAME = 'cuda:1'

In [6]:
import torchvision.utils as vutils

In [7]:
#neceassary imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable, Function
# from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
import torch.backends.cudnn as cudnn

import numpy as np

#import utils
import os
import itertools
import time
import copy
import random
import math


In [8]:
#imports for visualizations
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [9]:
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import read_image
from torchsummary import summary
#import tenserboard
from torch.utils.tensorboard import SummaryWriter

#initialize tensorboard writer
#create writer for tensorboard
writer = SummaryWriter(f'runs/'+experiment_id)

2022-11-10 13:01:34.174401: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [10]:
#enable cudnn
cudnn.benchmark = True
#cuda cache clear
torch.cuda.empty_cache()

#set random seed
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)



In [11]:
#device
device = torch.device(GPU_NAME if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=1)

In [27]:
#defining the hyperparameters

BATCH_SIZE = 128
EPOCHS = 1
NUM_EPOCHS_PRETRAINING = 5

LOAD_MODEL = False

#WHGAN parameters
NUM_EPOCHS_GAN = 1
CRITIC_ITERATIONS = 5
LEARNING_RATE_GAN = 1e-4
LAMBDA_GP = 10
CYCLE_LOSS_WEIGHT =1


IMAGE_SIZE = 224
CHANNELS_IMG = 1
A_Channels = 1
B_Channels = 1

NUM_CLASSES = 10


#parameters of ADAM optimizer
LEARNING_RATE = 0.001
BETA_1 = 0.9
BETA_2 = 0.999

#parameters of SGD optimizer with momentum
MOMENTUM = 0.9




##### utility functions #####

In [28]:
def generate_images(a, b, ab_gen, ba_gen, samples_path = results_path, epoch=0):
    ab_gen.eval()
    ba_gen.eval()

    b_fake = ab_gen(a)
    a_fake = ba_gen(b)

    a_imgs = torch.zeros((a.shape[0] * 2, 3, a.shape[2], a.shape[3]))
    b_imgs = torch.zeros((b.shape[0] * 2, 3, b.shape[2], b.shape[3]))

    even_idx = torch.arange(start=0, end=a.shape[0] * 2, step=2)
    odd_idx = torch.arange(start=1, end=a.shape[0] * 2, step=2)

    a_imgs[even_idx] = a.cpu()
    a_imgs[odd_idx] = b_fake.cpu()

    b_imgs[even_idx] = b.cpu()
    b_imgs[odd_idx] = a_fake.cpu()

    rows = math.ceil((a.shape[0] * 2) ** 0.5)
    a_imgs_ = vutils.make_grid(a_imgs, normalize=True, nrow=rows)
    b_imgs_ = vutils.make_grid(b_imgs, normalize=True, nrow=rows)

    vutils.save_image(a_imgs_, os.path.join(samples_path, 'a2b_' + str(epoch) + '.png'))
    vutils.save_image(b_imgs_, os.path.join(samples_path, 'b2a_' + str(epoch) + '.png'))

    #plot the images on tensorboard
    writer.add_image('a2b', a_imgs_, epoch)
    writer.add_image('b2a', b_imgs_, epoch)

    

## Cycle-GAN

### Models

##### the model architectures for CycleGan was inspired from github repo :  https://github.com/s-chh/Pytorch-CycleGAN-Digits

In [14]:
#we will have both Generator and Discriminator inspired from Resnet50 based architecture

In [15]:
#we define TWO main architectural blocks of the generator and discriminator: Convolutional Block and Residual Block

In [16]:
#we define the general convolutional block, which is used in both generator and discriminator
#convolutional block is a function which takes in the number of input channels, number of output channels, kernel size, stride, padding and whether to use batch normalization or not and transpose convolution or not
#and returns a nn.Sequential object which contains the convolutional block
def ConvolutionalBlock(in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, use_batchnorm=True, use_transpose=False):
    block = []
    if use_transpose:
        block.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias = not use_batchnorm))
    else:
        block.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias = not use_batchnorm))
    if use_batchnorm:
        block.append(nn.BatchNorm2d(out_channels))
        #return
    return nn.Sequential(*block)


In [17]:
#we definr residual block as a class
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv_block1 = ConvolutionalBlock(channels, channels, kernel_size = 3, stride = 1, padding = 1, use_batchnorm = True)
        self.conv_block2 = ConvolutionalBlock(channels, channels, kernel_size = 3, stride = 1, padding = 1, use_batchnorm = True)
        #def call
    def __call__(self, x):
        x = F.relu(self.conv_block1(x))
        return x + self.conv_block2(x)

In [18]:
# function fror initializing the weights of the model
def weights_init(model):
    for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)
    

##### generator

In [19]:

class Generator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, conv_dim=64):
        super(Generator, self).__init__()
        self.conv1 = ConvolutionalBlock(in_channels, conv_dim, kernel_size = 5, stride = 1, padding = 2, use_batchnorm = True)
        self.conv2 = ConvolutionalBlock(conv_dim, conv_dim*2, kernel_size = 3, stride = 2, padding = 1, use_batchnorm = True)
        self.conv3 = ConvolutionalBlock(conv_dim*2, conv_dim*4, kernel_size = 3, stride = 2, padding = 1, use_batchnorm = True)
        self.residual_block1 = ResidualBlock(conv_dim*4)
        self.trans_conv1 = ConvolutionalBlock(conv_dim*4, conv_dim*2, kernel_size = 3, stride = 1, padding = 1, use_batchnorm = True, use_transpose = True)
        self.trans_conv2 = ConvolutionalBlock(conv_dim*2, conv_dim, kernel_size = 3, stride = 2, padding = 1, use_batchnorm = True, use_transpose = True)
        self.conv4 = ConvolutionalBlock(conv_dim, out_channels, kernel_size = 5, stride = 1, padding = 2, use_batchnorm = False)

        #initialize weights
        self.apply(weights_init)

        #forward
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.residual_block1(x))
        x = F.relu(self.trans_conv1(x))
        x = F.relu(self.trans_conv2(x))
        x = torch.tanh(self.conv4(x))
        return x
        

##### discriminator

In [20]:
class Discriminator(nn.Module):
    def __init__(self, channels=1, conv_dim=64):
        super(Discriminator, self).__init__()
        self.conv1 = ConvolutionalBlock(channels, conv_dim, use_batchnorm = False)
        self.conv2 = ConvolutionalBlock(conv_dim, conv_dim*2)
        self.conv3 = ConvolutionalBlock(conv_dim*2, conv_dim*4)
        self.conv4 = ConvolutionalBlock(conv_dim*4, 1, kernel_size = 3, stride = 1, padding = 1, use_batchnorm = False)

        #initialize weights
        self.apply(weights_init)
    #forward

    def forward(self, x):
        alpha = 0.2
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = self.conv4(x)
        x = x.reshape([x.shape[0], -1]).mean(1)
        return x

#### creating networks

In [21]:
#first create generators
#we create generators for both the domains
#first from source domain to target domain
generator_source_to_target = Generator(in_channels=A_Channels, out_channels=B_Channels)
#then from target domain to source domain
generator_target_to_source = Generator(in_channels=A_Channels, out_channels=B_Channels)

#now define the discriminators
#first for the source domain
discriminator_source = Discriminator(channels=A_Channels)
#then for the target domain
discriminator_target = Discriminator(channels=B_Channels)

#### dataloaders

In [22]:
#define transforms
#define the transform for the dataset
mean = np.array([0.5])
std = np.array([0.5])
transform = transforms.Compose(
    [
  
    # if torch tensor then leave as it is, else convert to tensor
    transforms.Lambda(lambda x: x if isinstance(x, torch.Tensor) else transforms.functional.to_tensor(x)),
    #

    #resize to 224x224
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),


    # normalize
    transforms.Normalize(mean, std),
    ]
)

#### source data - MNIST
#### target data - USPS

In [24]:
#load train data
source_data = datasets.MNIST(root='./data/', download=True, transform=transform) 
#load train data
source_loader = torch.utils.data.DataLoader(dataset=source_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)
                                              

In [25]:

#load test data
#USPS dataset
target_data = datasets.USPS(root='./data/', download=True, transform=transform)
#load test data
target_loader = torch.utils.data.DataLoader(dataset=target_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)

#### cycle gan training 

In [23]:
#we define a funcrion which will train our cycle gan
#this function will take as input: the generators, the discriminators, the source and target dataloaders, the number of epochs, the learning rate, the weight of the cycle consistency loss, the weight of the identity loss


In [30]:
#we will write losses and accuracies to tensorboard
#we will also plot the images generated by the generators to tensorboard
def train_cycle_gan(generator_source_to_target, generator_target_to_source, discriminator_source, discriminator_target, source_loader, target_loader, num_epochs= NUM_EPOCHS_GAN):
    #define optimizers
    #we will use Adam optimizer for all the models
    generator_optimizer = torch.optim.Adam(list(generator_source_to_target.parameters()) + list(generator_target_to_source.parameters()), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)
    discriminator_optimizer = torch.optim.Adam(list(discriminator_source.parameters()) + list(discriminator_target.parameters()), lr=0.0002, betas=(0.5, 0.999), weight_decay=2e-5)

    iters_per_epoch = min(len(source_loader), len(target_loader)) 

    # Fix images for vizibility and plotting purposes to tensorboard
    source_fixed = iter(source_loader).next()[0]
    target_fixed = iter(target_loader).next()[0]

    #send the models to gpu
    generator_source_to_target = generator_source_to_target.to(device)
    generator_target_to_source = generator_target_to_source.to(device)
    discriminator_source = discriminator_source.to(device)
    discriminator_target = discriminator_target.to(device)

    #send the fixed images to gpu
    source_fixed = source_fixed.to(device)
    target_fixed = target_fixed.to(device)

    tensorboard_step = 0

    #iterate for the number of epochs
    for epoch in range(num_epochs):

        #have all models train mode
        generator_source_to_target.train()
        generator_target_to_source.train()
        discriminator_source.train()
        discriminator_target.train()

        #iterate through the datasets, we combine both by zip and enumerate
        for i, (source, target) in enumerate(zip(source_loader, target_loader)):

            #we do not need labels for the cycle gan
            #so load the images
            source_real , _ = source
            #move to gpu
            source_real = source_real.to(device)
            target_real , _ = target
            #move to gpu
            target_real = target_real.to(device)

            #fake images
            #we will generate fake images from the real images
            source_fake = generator_target_to_source(target_real)
            target_fake = generator_source_to_target(source_real)   

            #training the discriminators
            #we will train the discriminators to distinguish between real and fake images
            source_real_discriminator_output = discriminator_source(source_real)
            source_fake_discriminator_output = discriminator_source(source_fake.detach())
            #loss for the source domain
            source_discriminator_loss = (torch.mean((source_real_discriminator_output - 1)**2) + torch.mean(source_fake_discriminator_output**2))/2

            target_real_discriminator_output = discriminator_target(target_real)
            target_fake_discriminator_output = discriminator_target(target_fake.detach())
            #loss for the target domain
            target_discriminator_loss = (torch.mean((target_real_discriminator_output - 1)**2) + torch.mean(target_fake_discriminator_output**2))/2

            #sum the losses to get the total loss
            discriminator_loss = source_discriminator_loss + target_discriminator_loss

            #apply zero_grad to the optimizers
            discriminator_optimizer.zero_grad()
            #backpropagate the loss
            discriminator_loss.backward()
            #update the parameters
            discriminator_optimizer.step()

            #write the losses to tensorboard
            writer.add_scalar('BATCH_Total_Discriminator_Loss', discriminator_loss, tensorboard_step)
            writer.add_scalar('BATCH_Source_Discriminator_Loss', source_discriminator_loss, tensorboard_step)
            writer.add_scalar('BATCH_Target_Discriminator_Loss', target_discriminator_loss, tensorboard_step)



            #training the generators
            #we will train the generators to fool the discriminators
            #we will also train the generators to preserve the identity of the images
            #we will also train the generators to preserve the cycle consistency
            #we will train the generators to minimize the adversarial loss

            source_fake_generator_output = discriminator_source(source_fake)
            target_fake_generator_output = discriminator_target(target_fake)

            #adversarial loss
            source_generator_loss = torch.mean((source_fake_generator_output - 1)**2)
            target_generator_loss = torch.mean((target_fake_generator_output - 1)**2)
            #sum the losses to get the total loss
            generator_loss = source_generator_loss + target_generator_loss

            #cycle consistency loss
            source_generator_cycle_consistency_loss = (source_real - generator_target_to_source(target_fake)).abs().mean()
            target_generator_cycle_consistency_loss = (target_real - generator_source_to_target(source_fake)).abs().mean()
            #sum the losses to get the total loss
            generator_cycle_consistency_loss = source_generator_cycle_consistency_loss + target_generator_cycle_consistency_loss

            cycle_gan_loss = generator_loss + CYCLE_LOSS_WEIGHT * generator_cycle_consistency_loss

            #zero_grad the optimizers
            generator_optimizer.zero_grad()
            #backpropagate the loss
            cycle_gan_loss.backward()
            #update the parameters
            generator_optimizer.step()

            #write the losses to tensorboard
            writer.add_scalar('BATCH_Total_Generator_Loss', cycle_gan_loss, tensorboard_step)
            writer.add_scalar('BATCH_Source_Generator_Loss', source_generator_loss, tensorboard_step)
            writer.add_scalar('BATCH_Target_Generator_Loss', target_generator_loss, tensorboard_step)
            writer.add_scalar('BATCH_Cycle_Consistency_Loss', generator_cycle_consistency_loss, tensorboard_step)
            #cycle gan loss
            writer.add_scalar('BATCH_Cycle_GAN_Loss', cycle_gan_loss, tensorboard_step)

            #print the losses if the iteration is a multiple of 100
            if i % 100 == 0:
                #print all the losses
                print('Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}, Cycle Consistency Loss: {:.4f}, Cycle GAN Loss: {:.4f}'
                        .format(epoch+1, num_epochs, i+1, len(source_loader), discriminator_loss.item(), generator_loss.item(), generator_cycle_consistency_loss.item(), cycle_gan_loss.item()))
                
                    
                




            #we will update teensorboard steps
            tensorboard_step += 1

        #we will save the models after every epoch
        torch.save(generator_source_to_target.state_dict(), os.path.join(model_path, 'generator_source_to_target_{}.pth'.format(epoch+1)))
        torch.save(generator_target_to_source.state_dict(), os.path.join(model_path, 'generator_target_to_source_{}.pth'.format(epoch+1)))
        torch.save(discriminator_source.state_dict(), os.path.join(model_path, 'discriminator_source_{}.pth'.format(epoch+1)))
        torch.save(discriminator_target.state_dict(), os.path.join(model_path, 'discriminator_target_{}.pth'.format(epoch+1)))
        #write images fixed to tensorboard
        writer.add_image('Source_Real', torchvision.utils.make_grid(source_real.cpu().data, normalize=True), epoch+1)
        writer.add_image('Source_Fake', torchvision.utils.make_grid(source_fake.cpu().data, normalize=True), epoch+1)
        writer.add_image('Target_Real', torchvision.utils.make_grid(target_real.cpu().data, normalize=True), epoch+1)
        writer.add_image('Target_Fake', torchvision.utils.make_grid(target_fake.cpu().data, normalize=True), epoch+1)

        #pass arguments to the function generate images
        generate_images(source_fixed, target_fixed, generator_source_to_target, generator_target_to_source, results_path, epoch=epoch+1)

    #generate the final images
    generate_images(source_fixed, target_fixed, generator_source_to_target, generator_target_to_source, results_path)

    #return the models
    return generator_source_to_target, generator_target_to_source, discriminator_source, discriminator_target
        
   


            


      


   

In [32]:
#call the function on the models and the dataloaders and get the models back
generator_source_to_target, generator_target_to_source, discriminator_source, discriminator_target = train_cycle_gan(generator_source_to_target, generator_target_to_source, discriminator_source, discriminator_target, source_loader, target_loader)

RuntimeError: The size of tensor a (32) must match the size of tensor b (7) at non-singleton dimension 3

#### Pretraining part 

In [33]:
# in this part we will train a resnet50 based classifier on source data


#### Step 1: Initialize model with the best available weights

In [34]:
#creating the model
weights = ResNet50_Weights.DEFAULT
#send weight sto gpu
# weights = weights.to(device)
#sending the model to GPU

model = resnet50(weights=weights).to(device)

In [None]:
#print model summary
summary(model, (3, 224, 224))