In [1]:
#in this script we will implement the adversarial discriminative domain adaptation algorithm
#the algorithm is described in the paper "Adversarial Discriminative Domain Adaptation" by Ganin et al.
#we will use MNIST as the source domain and USPS as the target domain

In [2]:
#the whole training procedure has 3 steps:
#1. train the a (feature extractor CNN + Classifier ) on the source domain
#2. train a GAN on the features extracted from the source domain and the target domain to minimize the discrepancy between the two domains
#3. test the classifier on the target domain

#1. in step 1 both the feature extractor CNN and the classifier are trainable
#2. in step 2 only the GAN Discriminator is trainable
#3. in step 3 none are trainable

In [3]:
#we will use wasserstein loss
#we will use gradient penalty for the discriminator
#we eill use Resnet50 as base model for the feature extractor CNN

In [4]:
experiment_name = 'mnist_adda'
version = 'v2'

#concat experiment name and version to get experiment id
experiment_id = experiment_name + '_' + version

model_path = 'saved_models/ADDA'

In [5]:
#GPU name
#
GPU_NAME = 'cuda:1'

In [6]:
#neceassary imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable, Function
# from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
import torch.backends.cudnn as cudnn

import numpy as np

#import utils
import os
import itertools
import time
import copy
import random
import math


In [7]:
#imports for visualizations
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [8]:
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import read_image
from torchsummary import summary
#import tenserboard
from torch.utils.tensorboard import SummaryWriter

#initialize tensorboard writer
#create writer for tensorboard
writer = SummaryWriter(f'runs/'+experiment_id)

2022-11-09 18:45:28.738842: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [9]:
#enable cudnn
cudnn.benchmark = True
#cuda cache clear
torch.cuda.empty_cache()

#set random seed
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)



In [10]:
#device
device = torch.device(GPU_NAME if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=1)

In [11]:
#defining the hyperparameters
BATCH = 50

EPOCHS = 5
NUM_EPOCHS_PRETRAINING = 5

#WHGAN parameters
NUM_EPOCHS_GAN = 100
CRITIC_ITERATIONS = 5
LEARNING_RATE_GAN = 1e-4
LAMBDA_GP = 10



IMAGE_SIZE = 224
CHANNELS_IMG = 3
NUM_CLASSES = 10


#parameters of ADAM optimizer
LEARNING_RATE = 0.001
BETA_1 = 0.9
BETA_2 = 0.999

#parameters of SGD optimizer with momentum
MOMENTUM = 0.9




#### Utility Functions


In [12]:
#write a function for making all parameters of a model non trainable or trainable based on require_grad
def freeze_unfreeze_model(model, require_grad = True):
    for param in model.parameters():
        param.requires_grad = require_grad


In [13]:
#save model
def save_model(model):
    #check if model path exists
    if not os.path.exists(model_path, name_to_save):
        os.makedirs(model_path)

    #we will save the model by the name of the experiment id 
    torch.save(model.state_dict(),  f'{model_path}/{experiment_id}+{name_to_save}.pth')

## Model

In [14]:
# we will have 3 different models : 
# 1. Feature Extractor CNN or Encoder, Source Feature Extractor CNN or Source Encoder AND Target Feature Extractor CNN or Target Encoder
# 2. Classifier
# 3. Discriminator

#### Base resnet50

In [15]:
#let us first build the feature extractor
#we input a resent50 model

#creating the model
weights = ResNet50_Weights.DEFAULT
#send weight sto gpu
# weights = weights.to(device)
#sending the model to GPU

base_resnet = resnet50(weights=weights).to(device)

In [16]:
#print model
base_resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [17]:
#we will change the first convolution layer to accept single channel image
#conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# #if CHANNELS_IMG == 1:
# if CHANNELS_IMG == 1:
#     base_resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# #change the last fully connected layer to output classes in NUM_CLASSES
# base_resnet.fc = nn.Linear(2048, NUM_CLASSES,  bias=True)

#write a function to changethe model based on number of channels and number of classes
def change_model(model, num_classes = NUM_CLASSES):
    
    model.fc = nn.Linear(2048, num_classes,  bias=True)
    return model


base_resnet = change_model(base_resnet, num_classes = NUM_CLASSES)
base_resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [18]:
#now we will write class for the feature extractor network
#we will pass the resnet50 model as the input to the class, and will use : nn.Sequential(*list(original_model.children())[:-2]) to get the feature extractor part of the model

class ENCODER_CNN(nn.Module):
    def __init__(self, base_model):
        super(ENCODER_CNN, self).__init__()
        # self.base_model = base_model
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-3])

    def forward(self, x):
        x = self.feature_extractor(x)
        return x

In [19]:
#now create an instance of the feature extractor and print the model
source_cnn = ENCODER_CNN(base_resnet).to(device)
source_cnn

ENCODER_CNN(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0

In [20]:
#print model summary
# summary(source_cnn, (CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE))


In [21]:
#create tthe format for output of the feature extractor
#create a random vector of size (BATCH, CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE)
x = torch.randn(BATCH, CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE).to(device)
#pass the random vector through the feature extractor
x = source_cnn(x)
#check the output shape
x.shape


torch.Size([50, 1024, 14, 14])

In [22]:
#now we will store the output shape of the feature extractor
output_shape = x.shape
#make the first dimension as 1 and then remove it
output_shape = output_shape[1:]
output_shape

torch.Size([1024, 14, 14])

In [23]:
#
#now we will create the classifier, it will be same as the part of the resnet50 model after the feature extractor i.e. the last two layers of the resnet50 model
#also the number of classes will be 10 as we have 10 classes in the MNIST dataset, the number of classes is stored in NUM_CLASSES

class Classifier(nn.Module):
    def __init__(self, base_model):
        super(Classifier, self).__init__()
        self.avgpool= nn.Sequential(*list(base_model.children())[-3:-1])
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(*list(base_model.children())[-1:])
        #define a flatten layer
        

    def forward(self, x):
        #print the shape of the input
        # print("inside classifier: input shape",x.shape)
        x = self.avgpool(x)

        x = self.flatten(x)

        x = self.fc(x)
        return F.softmax(x)

In [24]:
classifier = Classifier(base_resnet).to(device)
classifier

Classifier(
  (avgpool): Sequential(
    (0): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)

In [25]:
#print summary of the classifier
#input shape is the output shape of the feature extractor
# summary(classifier, output_shape)

In [26]:
#now we will create the domain classifier: named as DomainClassifier
#it will have the rest of the resnet model after the feature extractor and the classifier and will have an additional layer at the end to output the domain label: 0 for source and 1 for target
#it will have gradient reversal layer in between the feature extractor and the classifier, i.e the first layer of the domain classifier 
#it will be exactly same as ClassClassifier except for the last layer, which is not number of classes but 2 for domain labels and sigmoid activation function instead of softmax
#it will also do same [-2:] to get the classifier part of the model, first layer be gradient reversal layer

class Discriminator(nn.Module):
    def __init__(self, base_model):
        super(Discriminator, self).__init__()
        #first layer of the domain classifier be the gradient reversal layer

        self.avgpool = nn.Sequential(*list(base_model.children())[-3:-1])
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(nn.Linear(2048, 1, bias=True))
       
        
        #now add the last output layer
        # self.domain_classifier.add_module('domain_classifier_output', nn.Linear(2048, 2))
        #change the last layer to output 2 classes
        # self.fc = nn.Linear(2048, 2 , bias=True)

        #forward
    def forward(self, x):
        x = self.avgpool(x)
        x = self.flatten(x)
        #output the domain label
        x = self.fc(x)
        
        # x = x.view(-1, 2)
        return x

        

In [27]:
#create an instance of the domain classifier
discriminator = Discriminator(base_resnet).to(device)
discriminator

Discriminator(
  (avgpool): Sequential(
    (0): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=Fal

In [28]:
#print summary of the domain classifier
# summary(discriminator, output_shape)

In [29]:
# ##create a writer and pot all the model sto tensorboard
# writer_sourcecnn = SummaryWriter('runs/plot_oh_ADDA_models_sourcecnn')
# #plot the models
# #create a dummy input
# dummy_input = torch.rand(CHANNELS_IMG, IMAGE_SIZE, IMAGE_SIZE).unsqueeze(0).to(device)
# writer_sourcecnn.add_graph(source_cnn, dummy_input)
# # writer.add_graph(class_classifier,(2048, 1, 1))
# # writer.add_graph(domain_classifier,(2048, 1, 1))
# #close
# writer_sourcecnn.close()

In [30]:
# #write classifier

# writer_classifier = SummaryWriter('runs/plot_oh_ADDA_models_classifier')
# #plot the models
# #create a dummy input
# dummy_input = torch.rand(1, 1024, 14, 14).to(device)
# writer_classifier.add_graph(classifier, dummy_input)
# # writer.add_graph(class_classifier,(2048, 1, 1))
# # writer.add_graph(domain_classifier,(2048, 1, 1))
# #close
# writer_classifier.close()

In [31]:
# #discriminator
# writer_discriminator = SummaryWriter('runs/plot_oh_ADDA_models_discriminator')
# #plot the models
# #create a dummy input: torch.Size([50, 1024, 14, 14])
# dummy_input = torch.rand(1, 1024, 14, 14).to(device)
# writer_discriminator.add_graph(discriminator, dummy_input)
# # writer.add_graph(class_classifier,(2048, 1, 1))
# # writer.add_graph(domain_classifier,(2048, 1, 1))
# #close
# writer_discriminator.close()

## Data-Processing

In [32]:
preprocess = weights.transforms()

In [33]:
#define the transform for the dataset
transform_mnist_resnet = transforms.Compose(
    [
  
    # if torch tensor then leave as it is, else convert to tensor
    transforms.Lambda(lambda x: x if isinstance(x, torch.Tensor) else transforms.functional.to_tensor(x)),
    #

    #resize to 224x224
    transforms.Resize(IMAGE_SIZE),

    #check if channels are 1, then convert to 3 channels
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),

    transforms.Lambda(lambda x: preprocess(x)),

    #if channels are 3, then make them 1
    # transforms.Lambda(lambda x: x[0].unsqueeze(0) if x.shape[0] == 3 else x),
    
    # normalize
    transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

#### Dataset

In [34]:
# for training we will use MNIST dataset in pytorch library
#for testing we will use USPS dataset

#### train data - MNIST
#### test data - USPS

In [35]:
#load train data
train_data = datasets.MNIST(root='./data/', download=True, transform=transform_mnist_resnet) 
#load train data
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH, shuffle=True, num_workers=4)

In [36]:

#load test data
#USPS dataset
test_data = datasets.USPS(root='./data/', download=True, transform=transform_mnist_resnet)
#load test data
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH, shuffle=True, num_workers=4)

In [37]:
#print the length of train and test data
print(len(train_data))
#print the shape of train data
print(train_data[0][0].shape)


60000
torch.Size([3, 224, 224])


In [38]:
#print length of test data
print(len(test_data))
#print shape of test data
print(test_data[0][0].shape)


7291
torch.Size([3, 224, 224])


In [39]:
#print number of batches in train and test data
print(len(train_loader))
print(len(test_loader))

1200
146


## Training by Adversarial Discriminative Domain Adaptation (ADDA) 

#### STEP 1: Pre- training

In [40]:
 #function to return gradient  norm
#write a function to calculate the gradient penalty
def gradient_norm(dnn, current_batch):

    BATCH_SIZE, C, H, W = current_batch.shape
    #print batch size, c,h,w
    # print("batch size, c, h, w", BATCH_SIZE, C, H, W)
    if BATCH_SIZE%2==1:
        #remove the last element
        current_batch = current_batch[:-1]
    #if batch size is 0 , then just return
    if BATCH_SIZE==0:
        return 0
    
    half_batch = int(BATCH_SIZE / 2)
    # current_batch = current_batch.to(device)
    # current_batch = Variable(current_batch, requires_grad=True)
    #we select the first half of the batch
    first_half = current_batch[:half_batch]
    #we select the second half of the batch
    second_half = current_batch[half_batch:]
    #we create a random number between 0 and 1
    # alpha = torch.rand(half_batch, 1)
    #we expand the alpha to the size of the first half of the batch
    # alpha = alpha.expand(first_half.size())
    #we create alpha as a random number between 0 and 1 which will allow us to interpolate between the first half and the second half
    
    alpha = torch.rand(half_batch, 1, 1, 1).repeat(1, C, H, W)
    #we expand the alpha to the size of the first half of the batch
    # alpha = alpha.expand(first_half.size())


    #we move alpha to the device
    alpha = alpha.to(device)
    #we interpolate between the first half and the second half
    interpolates = alpha * first_half + ((1 - alpha) * second_half)
    #we move interpolates to the device
    interpolates = interpolates.to(device)
    # interpolates = interpolates
    #we create a variable of interpolates
    interpolates = Variable(interpolates, requires_grad=True)
    #we pass interpolates through the cnn
    disc_interpolates = dnn(interpolates)
    #we calculate the gradients
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    #we calculate the gradient penalty
    # calculate gradient norm 
    gradients_norm = gradients.norm(2, dim=1)
    #mean of the gradient norm without subtracting 1 or lambda
    gradient_norm_mean = (gradients_norm **2).mean()
    #max of sqrt of the gradient norm without subtracting 1 or lambda
    # gradient_norm_max = (gradients_norm **2).max( dim=0, keepdim=True)[0]

    #delete the variables from the memory
    del first_half
    del second_half
    del alpha
    del interpolates
    del disc_interpolates
    del gradients
    del gradients_norm
    #cache the garbage
    torch.cuda.empty_cache()


    
    # gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()    #have to check this formula    / * LAMBDA
    #gradient penalty  should be max(0, gradient_penalty-1)
    #we return the gradient penalty
    return gradient_norm_mean
    # , gradient_norm_max

    


In [41]:
#in this approach, we will use the source_cnn and the classifier to train on the source data
#write a function for pretraining the source_cnn and the classifier: we will input the source_cnn and the classifier and the train_loader, and the number of epochs, and device
#we will use cross entropy loss for the classifier and ADAM optimizer for both the source_cnn and the classifier
#we will also use tensorboard to visualize the training process, and plot the loss and accuracy
#we will print the loss and accuracy after each epoch and also plot the loss and accuracy after each epoch
def pre_train(source_cnn, classifier, train_loader, epoch=NUM_EPOCHS_PRETRAINING, device=device):

    #we will store the source cnn in file at path experiment_id + temp_storage
    #and will load this ile anytime source_cnn is needed
    
    #define the loss function
    criterion = nn.CrossEntropyLoss()
    #define the optimizer for the source_cnn and the classifier in a list, and we use ADAM optimizer combined for both the source_cnn and the classifier
    optimizer = optim.Adam(list(source_cnn.parameters()) + list(classifier.parameters()), lr=LEARNING_RATE)

    #make both the source_cnn and the classifier in train mode
    source_cnn.train()
    classifier.train()

    #loop for each epoch
    epoch_tracker = 0
    total_loss = 0
    total = 0
    correct = 0
    batch_tracker = 0
    ep=0
    #we will add th loss for each batch in the epoch and then divide by the number of batches
    for ep in range(epoch):
        epoch_total = 0
        epoch_correct = 0
        epoch_total_loss = 0
        #loop for each batch
        
        for batch_idx, (data, target) in enumerate(train_loader):
            #send data to gpu
            data, target = data.to(device), target.to(device)
            #set the gradients to zero
            optimizer.zero_grad()
            #forward pass
            output = classifier(source_cnn(data))
            #calculate the loss
            loss = criterion(output, target)
            #calculate the gradients
            loss.backward()
            #update the weights

            #we get the gradient norm by sending the model as sequential of source_cnn and classifier
            #we will calculate the gradient norm
            gradient_n = gradient_norm(nn.Sequential(source_cnn, classifier), data)
            #we will add the gradient norm to the tensorboard
            writer.add_scalar('Pretraining Gradient Norm', gradient_n, batch_tracker)
            optimizer.step()
            #write the loss to tensorboard
            writer.add_scalar('Pretraining Training loss', loss, global_step=batch_tracker)

            #calculate the total loss
            total_loss += loss.item()
            #total epoch loss sum
            epoch_total_loss += loss.item()


            #calculate the accuracy
            #get the max value from the output
            _, predicted = torch.max(output.data, 1)
            #calculate the total number of labels
            temp_total = target.size(0)
            #calculate the correct predictions
            temp_correct = (predicted == target).sum().item()
            #add the total and correct predictions
            total += temp_total
            epoch_total += temp_total
            correct += temp_correct
            epoch_correct += temp_correct
            #calculate the accuracy
            epoch_accuracy = 100 * epoch_correct / epoch_total
            #write the accuracy to tensorboard
            writer.add_scalar('Pretraining Training accuracy', epoch_accuracy, global_step=batch_tracker)
            #print the loss and accuracy
            #and
            #print the gradient norm
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.2f}%\tGradient Norm: {:.6f}'.format(
                ep, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(),
                epoch_accuracy, gradient_n))
            

            
            #print the loss
            # if batch_idx % log_interval == 0:
            #     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            #         ep, batch_idx * len(data), len(train_loader.dataset),
            #         100. * batch_idx / len(train_loader), loss.item()))
                
        
            #write the epoch loss to tensorboard
            #first average the loss over the batches in the epoch
            batch_tracker += 1
        epoch_loss = total_loss / len(train_loader)

        #write the loss to tensorboard
        writer.add_scalar('Pretraining  Training - Epoch loss', epoch_loss, global_step=ep)
        #calculate the accuracy
        epoch_accuracy = 100 * correct / total
        #write the accuracy to tensorboard
        writer.add_scalar('Pretraining  Training - Epoch accuracy', epoch_accuracy, global_step=ep)

        # #save the model after every epoch, the name be experiment_id_epoch
        # #wew will save in the folder saved_models
        # torch.save(model.state_dict(), 'saved_models/'+experiment_id+'_'+str(ep)+'.pth')
        # #we will also save the optimizer
        # torch.save(optimizer.state_dict(), 'saved_models/'+experiment_id+'_'+str(ep)+'_optimizer.pth')
        #we will save the best model till now based on loss
        #check if first epoch, then save the model anyway
        if ep == 0:
            #save the model
            torch.save(source_cnn.state_dict(), 'saved_models/'+experiment_id+'_source_cnn.pth')
            torch.save(classifier.state_dict(), 'saved_models/'+experiment_id+'_classifier.pth')
            #save the optimizer
            torch.save(optimizer.state_dict(), 'saved_models/'+experiment_id+'_optimizer.pth')
            #save the loss
            best_loss = epoch_total_loss
            #save the epoch
            best_epoch = ep
        #if not first epoch, then check if the loss is less than the best loss
        else:
            #if loss is less than the best loss, then save the model
            if epoch_total_loss < best_loss:
                #save the model
                torch.save(source_cnn.state_dict(), 'saved_models/'+experiment_id+'_source_cnn.pth')
                torch.save(classifier.state_dict(), 'saved_models/'+experiment_id+'_classifier.pth')
                #save the optimizer
                torch.save(optimizer.state_dict(), 'saved_models/'+experiment_id+'_optimizer.pth')
                #save the loss
                best_loss = epoch_total_loss
                #save the epoch
                best_epoch = ep

        epoch_tracker += 1

    #print the accuracy
    total_accuracy = 100 * correct / total
    print('Accuracy: ', total_accuracy)

    #close the tensorboard writer
    writer.close()
    #save the model with name experiment_id and then the last epoch
    torch.save(source_cnn.state_dict(), 'saved_models/'+experiment_id+'_source_cnn_'+str(ep)+'.pth')
    torch.save(classifier.state_dict(), 'saved_models/'+experiment_id+'_classifier_'+str(ep)+'.pth')
    #save the optimizer
    torch.save(optimizer.state_dict(), 'saved_models/'+experiment_id+'_optimizer_'+str(ep)+'.pth')
    

    #return the models
    return source_cnn, classifier

            


    

In [42]:
#call the function
source_cnn, classifier = pre_train(source_cnn, classifier,train_loader)

  return F.softmax(x)


Accuracy:  97.50466666666667


In [43]:
#now we write a function to return accuracy, given encoder_cnn, classifier and data_loader
def get_accuracy(encoder_cnn, classifier, data_loader):
    #set the model to eval mode
    encoder_cnn.eval()
    classifier.eval()
    #set the total and correct to zero
    total = 0
    correct = 0
    #iterate over the data
    for batch_idx, (data, target) in enumerate(data_loader):
        #send data to gpu
        data, target = data.to(device), target.to(device)
        #forward pass
        output = classifier(encoder_cnn(data))
        #get the max value from the output
        _, predicted = torch.max(output.data, 1)
        #calculate the total number of labels
        temp_total = target.size(0)
        #calculate the correct predictions
        temp_correct = (predicted == target).sum().item()
        #add the total and correct predictions
        total += temp_total
        correct += temp_correct
    #calculate the accuracy
    total_accuracy = 100 * correct / total
    #return the accuracy
    #make models train mode again
    encoder_cnn.train()
    classifier.train()
    return total_accuracy

In [44]:
#now we have the source_cnn and classifier trained
#we will  now never train the source_cnn again

#### STEP 2: Adversarial Adaptation

In [45]:
#now we will use WGAN like training for the target_cnn and discriminator
#the REAL DATA will come from the output of source_cnn on the source data
#the generated data will come from the output of target_cnn on the target data
#we will minimize the WGAN loss
#also we use gradient penalty

In [46]:
#defining gradient penalty
def gradient_penalty(critic, source, target, device=device):
    BATCH_SIZE, C, H, W = source.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = source * alpha + target * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [47]:
#we will define the function to train the target_cnn and discriminator
#it will take as arguments the target_cnn, discriminator, source_cnn, train_loader, test_loader, epochs
#we will alos plot the loss and accuracy, per batch and per epoch
# also we will plot the gradient penalty
def train_adapt_target(target_cnn, discriminator, source_cnn, train_loader, test_loader, epochs=NUM_EPOCHS_GAN, device=device):

    #save
    source_nn_filename = 'saved_models/ADDA/'+experiment_id + 'temp_storage' + "source_cnn.pt"
    torch.save(source_cnn.state_dict(), source_nn_filename)
    #we will use the Adam optimizer for both the target_cnn and discriminator
    # , but seperate
    #we will use the same learning rate for both
    optimizer_target_cnn = optim.Adam(target_cnn.parameters(), lr=LEARNING_RATE_GAN, betas=(0.0, 0.9) )
    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE_GAN, betas=(0.0, 0.9) )

    #make both models trainable
    target_cnn.train()
    discriminator.train()
    #now make source cnn non trainable
    #use freeze_unfreeze function
    freeze_unfreeze_model(source_cnn, False)
    source_cnn.eval()

    #define dumy variables for keeping track of accuracy, loss and iterations through the dataset
    step = 0
    epoch_tracker = 0
    batch_tracker = 0
    #define the best loss and best epoch
    best_loss = 0
    best_epoch = 0

    #loop through epochs
    ep=0
    for ep in range(epochs):
        epoch_total_loss = 0
        epoch_critics_loss = 0
        epoch_target_cnn_loss = 0
        #loop through the batches
        #loop through batches of source data and target data combined
        for batch_idx, (source_data, target_data) in enumerate(zip(train_loader, train_loader)):
            #get the source and target images and we do not need labels, nbut anyway we will get them
            source_images, source_labels = source_data
            target_images, target_labels = target_data
            #get batch size as min of source and target batch size
            batch_size = min(source_images.shape[0], target_images.shape[0])
            #make the batch size of source and target equal
            source_images = source_images[:batch_size]
            source_labels = source_labels[:batch_size]
            target_images = target_images[:batch_size]
            target_labels = target_labels[:batch_size]

            #move the data to gpu
            source_images, source_labels = source_images.to(device), source_labels.to(device)
            target_images, target_labels = target_images.to(device), target_labels.to(device)

            # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
            for _ in range(CRITIC_ITERATIONS):
                #generate the features of the target images
                target_features = target_cnn(target_images)
                #generate the features of the source images
                #load the source_cnn
                source_cnn.load_state_dict(torch.load(source_nn_filename))
                source_features = source_cnn(source_images)
                #get gradient penalty
                gp = gradient_penalty(discriminator, source_features, target_features)
                critic_source = discriminator(source_features).reshape(-1)
                critic_target = discriminator(target_features).reshape(-1)
                critic_loss = -(torch.mean(critic_source) - torch.mean(critic_target)) + LAMBDA_GP * gp
                #zero the gradients
                discriminator.zero_grad()
                #backpropagate the loss
                critic_loss.backward(retain_graph=True)
                #update the weights
                optimizer_discriminator.step()

            # Train target_cnn: min -E[critic(gen_fake)] <-> max E[critic(gen_fake)]
            critic_target = discriminator(target_features).reshape(-1)
            loss_target_cnn = -torch.mean(critic_target)
            #zero the gradients
            target_cnn.zero_grad()
            #backpropagate the loss
            loss_target_cnn.backward()
            #update the weights
            optimizer_target_cnn.step()

            #add losses to epoch losses
            epoch_total_loss += critic_loss.item() + loss_target_cnn.item()
            epoch_critics_loss += critic_loss.item()
            epoch_target_cnn_loss += loss_target_cnn.item()

            #we will plot the loss on tensorboard
            #we will plot the critic loss, target_cnn loss, gradient penalty
            writer.add_scalar('ADDA_Loss_Critic', critic_loss, global_step=batch_tracker)
            writer.add_scalar('ADDA_Loss_Target_CNN', loss_target_cnn, global_step=batch_tracker)
            writer.add_scalar('ADDA_Gradient_Penalty', gp, global_step=batch_tracker)

            #print losses after every 100 steps
            if step % 100 == 0:
                print(f"Epoch [{ep}/{epochs}] Batch {batch_idx}/{len(train_loader)} \
                      Loss D: {critic_loss:.4f}, loss G: {loss_target_cnn:.4f}, gp: {gp:.4f}")
            
                
            #increment the batch tracker
            batch_tracker += 1
                
        #print the epoch loss
        print(f"Epoch [{ep}/{epochs}] Loss D: {epoch_critics_loss:.4f}, loss G: {epoch_target_cnn_loss:.4f}")
        #add the epoch loss to tensorboard
        writer.add_scalar('ADDA_Epoch_Total_Loss', epoch_total_loss, global_step=ep)
        #critics loss
        writer.add_scalar('ADDA_Epoch_Loss_Critic', epoch_critics_loss, global_step=ep)
        #target cnn loss
        writer.add_scalar('ADDA_Epoch_Loss_Target_CNN', epoch_target_cnn_loss, global_step=ep)
        #every epoch we will save the model
        #save the model with name experiment_id and epoch
        torch.save(target_cnn.state_dict(), f"{experiment_id}_target_cnn_{ep}.pth")
        torch.save(discriminator.state_dict(), f"{experiment_id}_discriminator_{ep}.pth")
        #test the accuracy of the model on the test set
        test_accuracy = get_accuracy(target_cnn, classifier, test_loader)
        #make classifier non trainable
        freeze_unfreeze_model(classifier, False)
        classifier.eval()
        #make model trainable
        target_cnn.train()
        #print the test accuracy
        print(f"Epoch [{ep}/{epochs}] Test Accuracy: {test_accuracy:.4f}")
        #add the test accuracy to tensorboard
        writer.add_scalar('ADDA_Test_Accuracy_Target_CNN', test_accuracy, global_step=ep)  
        


    #return the target cnn and discriminator
    return target_cnn, discriminator



In [48]:
#we will now train the target cnn and discriminator
#let us create target cnn with same weights as trained source cnn
# target_cnn = copy.deepcopy(source_cnn)
#but both be different and changes in one will not affect the other
#create target cnn as same weights as source cnn
target_cnn = copy.deepcopy(source_cnn)
#make target cnn trainable
freeze_unfreeze_model(target_cnn, True)
#make source cnn non trainable
freeze_unfreeze_model(source_cnn, False)
#make discriminator trainable
freeze_unfreeze_model(discriminator, True)

In [49]:
#get accuracy of source cnn on source data
source_cnn_accuracy = get_accuracy(source_cnn, classifier, train_loader)
print(f"Source CNN Accuracy on Source Data: {source_cnn_accuracy:.4f}")

  return F.softmax(x)


Source CNN Accuracy on Source Data: 98.7517


In [50]:
#get accuracy of source cnn on target data
target_cnn_accuracy = get_accuracy(source_cnn, classifier, test_loader)
#print accuracy
print(f"Source CNN Accuracy on Target Data: {target_cnn_accuracy:.4f}")

  return F.softmax(x)


Source CNN Accuracy on Target Data: 74.9966


In [51]:
#get accuracy of target cnn on source data
target_cnn_accuracy_source = get_accuracy(target_cnn, classifier, train_loader)
#print accuracy
print(f"Target CNN Accuracy on Source Data: {target_cnn_accuracy_source:.4f}")

  return F.softmax(x)


Target CNN Accuracy on Source Data: 98.7517


In [52]:
#get accuracy of target cnn on target data
target_cnn_accuracy_target = get_accuracy(target_cnn, classifier, test_loader)
#print accuracy
print(f"Target CNN Accuracy on Target Data: {target_cnn_accuracy_target:.4f}")

  return F.softmax(x)


Target CNN Accuracy on Target Data: 74.9966


In [53]:

#train the target cnn and discriminator
target_cnn, discriminator = train_adapt_target(target_cnn, discriminator, source_cnn, train_loader, test_loader)


Epoch [0/100] Batch 0/1200                       Loss D: 9.7407, loss G: -0.4197, gp: 0.9844
Epoch [0/100] Batch 1/1200                       Loss D: 9.7744, loss G: -0.5794, gp: 0.9832
Epoch [0/100] Batch 2/1200                       Loss D: 9.7840, loss G: -0.7735, gp: 0.9824
Epoch [0/100] Batch 3/1200                       Loss D: 9.6882, loss G: -0.8442, gp: 0.9806
Epoch [0/100] Batch 4/1200                       Loss D: 9.7207, loss G: -1.0027, gp: 0.9793
Epoch [0/100] Batch 5/1200                       Loss D: 9.7015, loss G: -1.1043, gp: 0.9765
Epoch [0/100] Batch 6/1200                       Loss D: 9.7313, loss G: -1.1699, gp: 0.9771
Epoch [0/100] Batch 7/1200                       Loss D: 9.6760, loss G: -1.2930, gp: 0.9737
Epoch [0/100] Batch 8/1200                       Loss D: 9.8270, loss G: -1.6139, gp: 0.9736
Epoch [0/100] Batch 9/1200                       Loss D: 9.6413, loss G: -1.6509, gp: 0.9703
Epoch [0/100] Batch 10/1200                       Loss D: 9.5651, loss

  return F.softmax(x)


Epoch [0/100] Test Accuracy: 5.8703
Epoch [1/100] Batch 0/1200                       Loss D: 65.5441, loss G: -28.9736, gp: 7.4469
Epoch [1/100] Batch 1/1200                       Loss D: 89.2027, loss G: -35.5695, gp: 9.3902
Epoch [1/100] Batch 2/1200                       Loss D: 104.9088, loss G: -54.6721, gp: 8.6851
Epoch [1/100] Batch 3/1200                       Loss D: 167.8103, loss G: -44.4752, gp: 16.1114
Epoch [1/100] Batch 4/1200                       Loss D: 148.1119, loss G: -62.7957, gp: 12.1709
Epoch [1/100] Batch 5/1200                       Loss D: 183.0674, loss G: -58.6810, gp: 16.4319
Epoch [1/100] Batch 6/1200                       Loss D: 152.9984, loss G: -65.2952, gp: 13.6258
Epoch [1/100] Batch 7/1200                       Loss D: 199.0995, loss G: -83.2139, gp: 14.9473
Epoch [1/100] Batch 8/1200                       Loss D: 223.1862, loss G: -80.5721, gp: 18.6216
Epoch [1/100] Batch 9/1200                       Loss D: 257.5622, loss G: -94.0308, gp: 19.3143

In [None]:
#we will now test the target cnn
#we will test the target cnn on the test data
#the classifier will be the same as trained
#call the function to test the target cnn accuracy
get_accuracy(target_cnn, classifier, test_loader)
# print accuracy

  return F.softmax(x)


14.401316691811823

In [None]:
#get the accuracy of source cnn on source data
get_accuracy(source_cnn, classifier, train_loader)

  return F.softmax(x)


16.363333333333333

In [None]:
#get accuracy of source cnn on target data
get_accuracy(source_cnn, classifier, test_loader)

  return F.softmax(x)


7.69441777533946