# Imports

In [None]:
import random
import gc
import warnings
import json
import os
from tqdm import tqdm
import logging
import PIL
import pandas as pd
from PIL import Image
import pickle
from glob import glob
import random

from skimage import io, transform
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import h5py
from mpl_toolkits.axes_grid1 import ImageGrid
from tqdm import tqdm
import idx2numpy
from numpyencoder import NumpyEncoder
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.pyplot import imshow, imsave

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as vutils
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import MNIST
from torch.autograd import Variable
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torchsummary
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import datasets, models, transforms

warnings.filterwarnings("ignore")

<h1>Diffusion</h1>

<h3>Supporting Functions and Hyperparameters</h3>

In [None]:
## Diffusion hyperparameters (module-level globals read by the loaders, models and scripts below)
BATCH_SIZE =  100  # batch size passed to every DataLoader built in this section
IMG_SIZE = 64  # loader transforms resize images to IMG_SIZE x IMG_SIZE
device = 'cuda:1'  # default torch device for Diffusion / UNet; hard-coded, no CPU fallback here


def save_object(obj, filename):
    """Pickle ``obj`` into ``filename`` using the highest pickle protocol.

    The file is opened in binary write mode, so an existing file is replaced.
    """
    with open(filename, mode='wb') as sink:
        pickle.Pickler(sink, pickle.HIGHEST_PROTOCOL).dump(obj)

def load_object(filename):
    """Return the object unpickled from ``filename``.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(filename, mode='rb') as source:
        return pickle.load(source)

class CelebA_Dataset(Dataset):
    """Image-only dataset: a list of file names that all live in one directory.

    Args:
        data_list: sequence of image file names (relative to ``img_dir``).
        img_dir: directory the file names are joined onto.
        transform: optional callable applied to the loaded RGB PIL image.
        target_transform: stored but not used by this class.
    """

    def __init__(self, data_list , img_dir, transform=None, target_transform=None):
        self.img_titles = data_list
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of image file names held by the dataset."""
        return len(self.img_titles)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if a transform is set."""
        file_name = self.img_titles[idx]
        picture = PIL.Image.open(os.path.join(self.img_dir, file_name)).convert('RGB')
        return self.transform(picture) if self.transform else picture

class Bitmoji_Dataset(Dataset):
    """Image-only dataset over a list of complete image paths.

    Args:
        path_list: sequence of image file paths.
        transform: optional callable applied to the loaded RGB PIL image.
        target_transform: stored but not used by this class.
    """

    def __init__(self, path_list , transform=None, target_transform=None):
        self.path_list = path_list
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of image paths held by the dataset."""
        return len(self.path_list)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if a transform is set."""
        picture = PIL.Image.open(self.path_list[idx]).convert('RGB')
        return self.transform(picture) if self.transform else picture

class Bitmoji_Categorical_Dataset(Dataset):
    """Labelled image dataset over ``(image_path, label)`` pairs.

    Args:
        path_list: sequence of two-element entries ``(image_path, label)``.
        transform: optional callable applied to the loaded RGB PIL image.
        target_transform: stored but not used by this class.
    """

    def __init__(self, path_list , transform=None, target_transform=None):
        self.path_list = path_list
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of ``(image_path, label)`` entries."""
        return len(self.path_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a ``torch.long`` tensor."""
        location, raw_label = self.path_list[idx]
        target = torch.tensor(raw_label, dtype=torch.long)
        picture = PIL.Image.open(location).convert('RGB')
        if self.transform:
            picture = self.transform(picture)
        return picture, target

def plot_images(images):
    """Display a batch of images side by side in a single matplotlib figure.

    The items of ``images`` are concatenated along their last axis, then
    ``permute(1, 2, 0)`` moves the first axis last before plotting
    (presumably channels-first to channels-last -- confirm with callers).
    """
    plt.figure(figsize=(32, 32))
    strip = torch.cat(list(images.cpu()), dim=-1)
    canvas = torch.cat([strip], dim=-2)
    plt.imshow(canvas.permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    """Arrange a batch of images into a grid and write it to ``path``.

    Args:
        images: batch tensor accepted by ``torchvision.utils.make_grid``; the
            grid is converted with ``Image.fromarray``, so callers in this file
            pass uint8 tensors.
        path: output file path; the format is chosen by PIL from the extension.
        **kwargs: forwarded to ``torchvision.utils.make_grid``. ``nrow``
            defaults to 10 (the value that used to be hard-coded), so existing
            calls behave as before.
    """
    # Previously **kwargs was accepted but silently dropped.
    kwargs.setdefault('nrow', 10)
    grid = torchvision.utils.make_grid(images, **kwargs)
    # make_grid returns (C, H, W); PIL wants (H, W, C) on the CPU.
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    Image.fromarray(ndarr).save(path)


def get_CelebA_data():
    """Build shuffled train / validation / test DataLoaders for CelebA.

    The partition CSV assigns every ``image_id`` a ``partition`` value; rows
    with 0, 1 and 2 become the train, validation and test sets respectively.
    Every image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, converted to a tensor
    and normalised with mean 0.5 / std 0.5 per channel.

    Returns:
        ``(train_loader, val_loader, test_loader)``, each yielding image
        batches of size ``BATCH_SIZE``.
    """
    img_dir = '/home/hiren/Apoorv Pandey/ADRL/Ass1/img_align_celeba/img_align_celeba'
    data = pd.read_csv('/home/hiren/Apoorv Pandey/ADRL/Ass1/list_eval_partition.csv')

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    loaders = []
    for partition in (0, 1, 2):
        titles = data[data['partition'] == partition]['image_id'].to_list()
        # Fixes: the validation loader used to wrap the *training* dataset, and
        # the test dataset had no transform (PIL images cannot be batched by
        # the default collate function).
        dataset = CelebA_Dataset(titles, img_dir, transform)
        loaders.append(DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True))

    train_loader, val_loader, test_loader = loaders
    return train_loader,val_loader,test_loader


def get_Celeb_A_categorical_data():
    """Build train / validation loaders of CelebA images labelled by attribute.

    The 10 most frequent attributes of the attribute CSV are selected; every
    image having one of them (value ``1``) contributes an entry
    ``[image_path, class_index]`` where ``class_index`` is 0..9 and 0 is the
    most frequent attribute. An image with several selected attributes appears
    once per attribute.

    Returns:
        ``(train_loader, val_loader)`` over an 80/20 split of the entries.
    """
    img_dir = '/home/hiren/Apoorv Pandey/ADRL/Ass1/img_align_celeba/img_align_celeba'
    celeb_attrs = pd.read_csv('/home/hiren/Apoorv Pandey/ADRL/Ass1/list_attr_celeba.csv')

    attr_names = [col for col in celeb_attrs.columns if col != 'image_id']
    # Count how often each attribute is present (value 1).
    counts = (celeb_attrs[attr_names] == 1).sum(axis=0).to_numpy()
    # argsort is ascending; reverse it so index 0 is the MOST frequent
    # attribute (the previous code picked the 10 least frequent ones).
    top_10_attrs = np.argsort(counts)[::-1][:10]

    path_list = []
    for class_index, attr_pos in enumerate(top_10_attrs):
        has_attr = celeb_attrs[attr_names[attr_pos]] == 1
        image_ids = celeb_attrs[has_attr]['image_id'].values.tolist()
        # The dataset class opens paths as-is, so join the directory here.
        path_list.extend([os.path.join(img_dir, p), class_index] for p in image_ids)

    # Entries are grouped by class; shuffle (fixed seed, reproducible) so the
    # 80/20 split does not leave whole classes out of the training set.
    random.Random(0).shuffle(path_list)

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    split = int(0.8*len(path_list))
    train_path_list,val_path_list = path_list[:split],path_list[split:]

    # Bitmoji_Categorical_Dataset takes (path_list, transform); passing the
    # image directory as second argument used to put a string in the
    # transform slot.
    train_dataset = Bitmoji_Categorical_Dataset(train_path_list, transform)
    val_dataset = Bitmoji_Categorical_Dataset(val_path_list, transform)

    train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
    val_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=True)
    return train_loader,val_loader
    

def get_Bitmoji_data():
    """Build shuffled train / validation DataLoaders over the Bitmoji PNG files.

    The first 80% of the (sorted) file list is used for training, the rest for
    validation. Images are resized to ``IMG_SIZE`` x ``IMG_SIZE``, converted to
    tensors and normalised with mean 0.5 / std 0.5 per channel.

    Returns:
        ``(train_loader, val_loader)``, each yielding image batches of size
        ``BATCH_SIZE``.
    """
    # The file imports the function (`from glob import glob`), so the old
    # `glob.glob(...)` call raised AttributeError. Sorting makes the split
    # reproducible, since glob returns files in arbitrary order.
    image_path_list = sorted(glob('/home/hiren/adrl/1/bitmojis/*.png'))
    split = int(0.8*len(image_path_list))
    train_images_path,val_images_path = image_path_list[:split],image_path_list[split:]

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = Bitmoji_Dataset(train_images_path,transform)
    val_dataset = Bitmoji_Dataset(val_images_path,transform)

    train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
    val_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=True)

    return train_loader,val_loader

<h2>Diffusion Model</h2>

In [None]:
class Diffusion:
    """Forward noising and reverse sampling for a denoising diffusion model.

    Uses a linear beta schedule of ``noise_steps`` values between
    ``beta_start`` and ``beta_end``; ``alpha = 1 - beta`` and ``alpha_hat`` is
    the cumulative product of ``alpha``.

    Args:
        noise_steps: number of diffusion steps in the schedule.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        img_size: side length of the square images produced by sampling.
        device: torch device the schedule and the samples are placed on.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device=device):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of length ``noise_steps``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t,mode = 'train'):
        """Noise ``x`` to the levels given by the per-sample step indices ``t``.

        Args:
            x: batch of images, indexed as 4-D (batch first).
            t: 1-D long tensor of step indices, one per batch element.
            mode: ``'train'`` (default) always adds Gaussian noise. ``'eval'``
                returns the batch unchanged with zero noise when ``t[0]`` is
                the last schedule step (only the first element is checked) and
                adds Gaussian noise otherwise.

        Returns:
            ``(x_t, epsilon)``: the noised batch and the noise that was used.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]

        keep_clean = mode == 'eval' and bool(t[0] == self.noise_steps - 1)
        if keep_clean:
            ## Done so that initial image is not corrupted while running denoising steps
            epsilon = torch.zeros_like(x)
            sqrt_alpha_hat = 1.
            sqrt_one_minus_alpha_hat = 0.
        else:
            # Previously epsilon was only defined in 'eval' mode, so the
            # default 'train' mode raised NameError.
            epsilon = torch.randn_like(x)

        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        """Draw ``n`` random step indices uniformly from ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def _reverse_process(self, model, n, steps):
        """Run the reverse update over ``steps`` and return ``n`` uint8 images.

        ``steps`` is an iterable of step indices in the order they are applied.
        The model is put in eval mode for sampling and back in train mode after.
        """
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(steps, position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is added on the final step (i == 1).
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] to uint8 [0, 255].
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

    def sample(self, model, n):
        """Generate ``n`` images by visiting every step from ``noise_steps - 1`` down to 1.

        Returns:
            uint8 tensor of shape ``(n, 3, img_size, img_size)``.
        """
        return self._reverse_process(model, n, reversed(range(1, self.noise_steps)))

    def sample_at_intervals(self,model,n,t=1000):
        """Generate ``n`` images while visiting only about ``t`` of the schedule steps.

        Steps are visited with stride ``noise_steps // t`` (at least 1), always
        ending at step 1. Each visited step still applies the single-step
        update with that step's own alpha / beta values.

        Args:
            model: noise-prediction network called as ``model(x, t)``.
            n: number of images to generate.
            t: requested number of sampling steps; must be >= 1.

        Returns:
            uint8 tensor of shape ``(n, 3, img_size, img_size)``.

        Raises:
            ValueError: if ``t`` is smaller than 1.
        """
        if t < 1:
            raise ValueError(f't must be >= 1, got {t}')
        # Clamp to 1 so that t > noise_steps no longer yields a zero stride
        # (which made range() raise).
        stepsize = max(1, self.noise_steps // t)
        return self._reverse_process(model, n, reversed(range(1, self.noise_steps, stepsize)))

<h3>Supporting Models</h3>

In [None]:
class Convolutional_Block(nn.Module):
    """Two 3x3 convolutions, each followed by single-group GroupNorm and GELU.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        mid_ch: channels between the two convolutions; falls back to
            ``out_ch`` when not given (or falsy).
        residual: when true, the input is added to the block output and a
            final GELU is applied.
    """

    def __init__(self,in_ch,out_ch,mid_ch=None,residual = False):
        super().__init__()
        self.residual = residual
        hidden = mid_ch if mid_ch else out_ch
        layers = [
            nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_ch),
            nn.GELU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self,x):
        """Apply the block; with ``residual`` set, return ``GELU(x + conv(x))``."""
        out = self.conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)


class EncoderLayer(nn.Module):
    """Downsampling stage: 2x max-pool, two conv blocks, plus a time-embedding bias.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced feature map.
        emb_dim: size of the embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        downsample = nn.MaxPool2d(2)
        refine = Convolutional_Block(in_channels, in_channels, residual=True)
        project = Convolutional_Block(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(downsample, refine, project)

        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected embedding ``t`` at every position."""
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]
        # (batch, out_channels) -> (batch, out_channels, H, W)
        time_bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return features + time_bias

class DecoderLayer(nn.Module):
    """Upsampling stage: 2x bilinear upsample, skip concatenation, two conv blocks.

    Args:
        in_channels: channels after concatenating the skip tensor with the
            upsampled input.
        out_channels: channels of the produced feature map.
        emb_dim: size of the embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        refine = Convolutional_Block(in_channels, in_channels, residual=True)
        project = Convolutional_Block(in_channels, out_channels, in_channels // 2)
        self.conv = nn.Sequential(refine, project)

        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample ``x``, concatenate ``skip_x`` in front on the channel axis, convolve, add the embedding."""
        upsampled = self.up(x)
        merged = torch.cat([skip_x, upsampled], dim=1)
        features = self.conv(merged)
        height, width = features.shape[-2], features.shape[-1]
        # (batch, out_channels) -> (batch, out_channels, H, W)
        time_bias = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return features + time_bias

class UNet(nn.Module):
    """U-shaped network called as ``model(x, t)`` with a sinusoidal step encoding.

    Three encoder stages, a three-block bottleneck and three decoder stages
    with skip connections, followed by a 1x1 output convolution.

    Args:
        c_in: input image channels.
        c_out: output channels.
        time_dim: size of the step encoding fed to every encoder/decoder stage.
        device: device on which the encoding frequencies are created.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device=device):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # Encoder path.
        self.inc = Convolutional_Block(c_in, 64)
        self.down1 = EncoderLayer(64, 128)
        self.down2 = EncoderLayer(128, 256)
        self.down3 = EncoderLayer(256, 256)

        # Bottleneck.
        self.bot1 = Convolutional_Block(256, 512)
        self.bot2 = Convolutional_Block(512, 512)
        self.bot3 = Convolutional_Block(512, 256)

        # Decoder path; input widths include the concatenated skip tensors.
        self.up1 = DecoderLayer(512, 128)
        self.up2 = DecoderLayer(256, 64)
        self.up3 = DecoderLayer(128, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Return the sin/cos encoding of ``t`` with ``channels`` features.

        ``t`` is expected as a ``(batch, 1)`` float tensor (that is what
        ``forward`` passes); the sine half comes first, then the cosine half.
        """
        exponents = torch.arange(0, channels, 2, device=self.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        half = channels // 2
        sin_part = torch.sin(t.repeat(1, half) * inv_freq)
        cos_part = torch.cos(t.repeat(1, half) * inv_freq)
        return torch.cat([sin_part, cos_part], dim=-1)

    def forward(self, x, t):
        """Run the network on images ``x`` at step indices ``t`` (1-D tensor)."""
        step_emb = self.pos_encoding(t.unsqueeze(-1).type(torch.float), self.time_dim)

        skip1 = self.inc(x)
        skip2 = self.down1(skip1, step_emb)
        skip3 = self.down2(skip2, step_emb)
        bottom = self.down3(skip3, step_emb)

        bottom = self.bot3(self.bot2(self.bot1(bottom)))

        decoded = self.up1(bottom, skip3, step_emb)
        decoded = self.up2(decoded, skip2, step_emb)
        decoded = self.up3(decoded, skip1, step_emb)
        return self.outc(decoded)

<h3>Training</h3>

In [None]:
def train(data_name,num_epochs=100):
    """Train an unconditional diffusion UNet on CelebA or Bitmoji.

    Each step noises a batch to random step indices, lets the model predict
    the noise and minimises the MSE between true and predicted noise. After
    every epoch 100 images are sampled and saved as a grid, and the model
    weights are saved (the checkpoint file is overwritten each epoch).

    Args:
        data_name: ``'CelebA'`` selects the CelebA loader; any other value
            selects the Bitmoji loader. Also used in log / output paths.
        num_epochs: number of passes over the training loader.

    NOTE: the schedule here uses ``noise_steps=500``; code that later samples
    from the saved checkpoint should use the same value.
    """
    if data_name == 'CelebA':
        train_loader,_,_ = get_CelebA_data()
    else:
        train_loader,_ = get_Bitmoji_data()
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    mse = nn.MSELoss()
    diffusion = Diffusion(noise_steps = 500,img_size=IMG_SIZE)
    writer = SummaryWriter(f'runs/{data_name}/Diffusion/T=500')
    steps_per_epoch = len(train_loader)

    # Per-dataset sample directory. For 'CelebA' this is the same path as
    # before; other datasets no longer overwrite the CelebA epoch images.
    plot_dir = f'/home/hiren/Apoorv Pandey/ADRL/Ass2/{data_name}_Plots_Diffusion'
    os.makedirs(plot_dir, exist_ok=True)

    try:
        for epoch in range(num_epochs):
            pbar = tqdm(train_loader)
            for i, images in enumerate(pbar):
                torch.cuda.empty_cache()
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)

                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.set_postfix(MSE=loss.item())
                writer.add_scalar("MSE", loss.item(), global_step=epoch * steps_per_epoch + i)

            sampled_images = diffusion.sample(model, n=100)
            save_images(sampled_images, f'{plot_dir}/Epoch_{epoch}.jpg')
            torch.save(model.state_dict(),f'/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model_{data_name}.pt' )
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

In [None]:
# Train one unconditional diffusion model per dataset (100 epochs each); the runs execute back to back.
train('CelebA',100) ## Training on Celeb A Dataset
train('Bitmoji',100) ## Training on Bitmoji Dataset

<h3>FID Calculations</h3>

In [None]:
def get_FID(model,path,timesteps = 1000,data_name='CelebA',num_images=1000):
    """Compute the FID between generated images and real training images.

    Generated images are cached: if ``path`` exists they are loaded from it,
    otherwise ``num_images`` images are sampled with
    ``Diffusion.sample_at_intervals`` and pickled to ``path``.

    Args:
        model: trained noise-prediction network.
        path: pickle file used as cache for the generated uint8 images.
            It is loaded with ``pickle``; only point this at trusted files.
        timesteps: number of sampling steps requested from
            ``sample_at_intervals``.
        data_name: ``'CelebA'`` uses the CelebA training loader; any other
            value uses the Bitmoji training loader.
        num_images: number of generated and (at most) real images to compare.

    Returns:
        The value returned by ``FrechetInceptionDistance.compute()``.

    NOTE(review): the ``Diffusion`` object here uses the default
    ``noise_steps``; confirm it matches the schedule the model was trained with.
    """
    if not os.path.exists(path):
        diffusion = Diffusion(img_size=IMG_SIZE,device=device)
        sampled_images = diffusion.sample_at_intervals(model, num_images,timesteps).type(torch.uint8).to('cpu')
        save_object(sampled_images,path)
    else:
        sampled_images = load_object(path)

    if data_name=='CelebA':
        real_loader,_,_ = get_CelebA_data()
    else:
        real_loader,_ = get_Bitmoji_data()

    # Walk ONE iterator until enough images are collected. The old code called
    # next(iter(real_loader)) repeatedly (a fresh shuffled iterator per batch,
    # so batches could overlap) and then reshaped to a hard-coded
    # (1000, 3, 64, 64), which only worked for BATCH_SIZE == 100.
    batches, collected = [], 0
    for batch in real_loader:
        batches.append(batch.to('cpu'))
        collected += batch.shape[0]
        if collected >= num_images:
            break
    real_images = torch.cat(batches, dim=0)[:num_images]

    # Loader output is normalised to [-1, 1]; the metric is fed uint8 [0, 255].
    real_images = (real_images.clamp(-1, 1) + 1) / 2
    real_images = (real_images * 255).type(torch.uint8)

    sampled_images = sampled_images.type(torch.uint8).to('cpu')
    fid = FrechetInceptionDistance(feature=2048)
    fid.update(real_images, real=True)
    fid.update(sampled_images, real=False)
    fid_score = fid.compute()
    return fid_score

<h3>Generating images</h3>

In [None]:
def generate_noisy_images(model,real_path,fake_path):
    """Sample 10 images and save them together with re-noised versions.

    Ten images are generated and written to ``real_path``. Each of them is
    then noised to the ten highest steps of the schedule
    (``noise_steps - 1`` down to ``noise_steps - 10``) and the resulting
    10 x 10 grid (one row per image, one column per noise level) is written to
    ``fake_path``.

    Args:
        model: trained noise-prediction network.
        real_path: output path for the grid of generated images.
        fake_path: output path for the grid of noised images.
    """
    diffusion = Diffusion(img_size=IMG_SIZE,device=device)
    random_10_samples = diffusion.sample(model, n=10).to('cpu')
    save_images(random_10_samples,real_path)

    # sample() returns uint8 in [0, 255]; map back to the [-1, 1] range the
    # noising formula (and the clamp below) work in. Previously the raw
    # 0..255 values were noised directly.
    clean = (random_10_samples.type(torch.float) / 255.) * 2. - 1.
    clean = clean.to(device)

    num_images, num_levels = clean.shape[0], 10
    image_shape = tuple(clean.shape[1:])
    noisy_images = torch.empty((num_images, num_levels) + image_shape, dtype = torch.uint8)
    for i in range(num_levels):
        t = (torch.ones(num_images)*(diffusion.noise_steps - 1 - i)).long().to(device)
        # NOTE(review): 'eval' mode is used because it is the only mode the
        # original noise_images could run in (and it leaves the image
        # untouched at the very last step); confirm this is the intended view.
        noisy_image, _ = diffusion.noise_images(clean, t, mode='eval')
        noisy_image = (noisy_image.clamp(-1, 1) + 1) / 2
        noisy_image = (noisy_image * 255).type(torch.uint8)
        noisy_images[:, i] = noisy_image.to('cpu')

    noisy_images = noisy_images.view((-1,) + image_shape)
    save_images(noisy_images,fake_path)

<h3>Computing Celeb A FID Scores</h3>

In [None]:

# Load the trained CelebA model and compute FID for 1000 / 500 / 100 sampling steps.
model = UNet().to(device)
# NOTE(review): train() saves to 'Diffusion_Model_{data_name}.pt'; confirm 'Diffusion_Model.pt' is the intended CelebA checkpoint.
model.load_state_dict(torch.load('/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model.pt'))
# Each call caches its generated images in the given pickle file and reuses it if present.
celebA_fid_1000 = get_FID(model,'./Fake_CelebA.pkl',1000)
celebA_fid_500 = get_FID(model,'./Fake_CelebA_500.pkl',500)
celebA_fid_100 = get_FID(model,'./Fake_CelebA_100.pkl',100)
# Append (mode 'a') the scores to the shared results file.
with open('./FID_Scores_Q1.txt','a') as f:
    print(f'Time Steps: 1000 Celeb A  : {celebA_fid_1000}',file = f)
    print(f'Time Steps: 500 Celeb A  : {celebA_fid_500}',file = f)
    print(f'Time Steps: 100 Celeb A  : {celebA_fid_100}',file = f)

<h3>Computing Bitmoji  FID Scores</h3>

In [None]:
# Load the trained Bitmoji model and compute FID for 1000 / 500 / 100 sampling steps.
model = UNet().to(device)
model.load_state_dict(torch.load('/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model_Bitmoji.pt'))
# Each call caches its generated images in the given pickle file and reuses it if present.
bitmoji_fid_1000 = get_FID(model,'./Fake_Bitmoji.pkl',1000,'Bitmoji')
bitmoji_fid_500 = get_FID(model,'./Fake_Bitmoji_500.pkl',500,'Bitmoji')
bitmoji_fid_100 = get_FID(model,'./Fake_Bitmoji_100.pkl',100,'Bitmoji')
# Append (mode 'a') the scores to the same results file used for CelebA.
with open('./FID_Scores_Q1.txt','a') as f:
    print(f'Time Steps: 1000 Bitmoji : {bitmoji_fid_1000}',file = f)
    print(f'Time Steps: 500 Bitmoji : {bitmoji_fid_500}',file = f)
    print(f'Time Steps: 100 Bitmoji : {bitmoji_fid_100}',file = f)

<h3>Sampling 100 generated celeb A images</h3>

In [None]:
# Take the first 100 cached CelebA generations for each sampling-step setting
# (the pickle files are the caches written by get_FID) ...
sampled_images_1000 = load_object('./Fake_CelebA.pkl')[0:100]
sampled_images_500 = load_object('./Fake_CelebA_500.pkl')[0:100]
sampled_images_100 = load_object('./Fake_CelebA_100.pkl')[0:100]

# ... and write each set as one image grid.
save_images(sampled_images_1000,'./Fake_CelebA_1000.png')
save_images(sampled_images_500,'./Fake_CelebA_500.png')
save_images(sampled_images_100,'./Fake_CelebA_100.png')

<h3>Sampling 100 generated BitMoji images</h3>

In [None]:
# Take the first 100 cached Bitmoji generations for each sampling-step setting
# (the pickle files are the caches written by get_FID) ...
sampled_images_1000 = load_object('./Fake_Bitmoji.pkl')[0:100]
sampled_images_500 = load_object('./Fake_Bitmoji_500.pkl')[0:100]
sampled_images_100 = load_object('./Fake_Bitmoji_100.pkl')[0:100]

# ... and write each set as one image grid.
save_images(sampled_images_1000,'./Fake_Bitmoji_1000.png')
save_images(sampled_images_500,'./Fake_Bitmoji_500.png')
save_images(sampled_images_100,'./Fake_Bitmoji_100.png')

<h3>Generate Denoising Images for Celeb A and Bitmoji</h3>

In [None]:
# CelebA: load the checkpoint, then save 10 generated images and their noised versions.
model = UNet().to(device)
model.load_state_dict(torch.load('/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model.pt'))

generate_noisy_images(model,'./CelebA_10_images.jpg','./CelebA_noisy_Images.jpg')


# Bitmoji: same procedure with the Bitmoji checkpoint.
model = UNet().to(device)
model.load_state_dict(torch.load('/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model_Bitmoji.pt'))

generate_noisy_images(model,'./BItmoji_10_images.jpg','./Bitmoji_noisy_Images.jpg')

<h2>Classifier Guidance</h2>

<h3>Hyperparameters and Supporting functions</h3>

In [None]:
# Hyperparameters for the conditional (guided) diffusion section.
# These REBIND the module-level globals defined earlier, so anything run after
# this cell sees the new values (e.g. BATCH_SIZE becomes 64 instead of 100).
device = "cuda:1" if torch.cuda.is_available() else "cpu"  # falls back to CPU when CUDA is unavailable
IMG_SIZE = 64  # loader transforms resize images to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 64  # batch size used by the loaders built below
T = 1500  # not referenced in the visible part of this file

class CelebA_Categorical_Dataset(Dataset):
    """Labelled image dataset over ``(file_name, label)`` pairs in one directory.

    Args:
        img_titles: sequence of entries whose element 0 is the image file name
            and element 1 the label.
        img_dir: directory the file names are joined onto.
        transform: optional callable applied to the loaded RGB PIL image.
        target_transform: stored but not used by this class.
    """

    def __init__(self, img_titles ,img_dir, transform=None, target_transform=None):
        self.img_titles = img_titles
        self.transform = transform
        self.img_dir = img_dir
        self.target_transform = target_transform

    def __len__(self):
        """Number of ``(file_name, label)`` entries."""
        return len(self.img_titles)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``label`` is a ``torch.long`` tensor."""
        location = os.path.join(self.img_dir, self.img_titles[idx][0])
        target = torch.tensor(self.img_titles[idx][1], dtype=torch.long)
        picture = PIL.Image.open(location).convert('RGB')
        if self.transform:
            picture = self.transform(picture)
        return picture, target

def get_Celeb_A_categorical_data():
    """Build train / validation loaders of CelebA images labelled by attribute.

    The 10 most frequent attributes of the attribute CSV are selected (their
    names are printed); every image having one of them (value ``1``)
    contributes an entry ``[image_id, class_index]`` with ``class_index`` in
    0..9, 0 being the most frequent attribute. An image with several selected
    attributes appears once per attribute.

    Side effect: the 0/1 attribute table is written to ``attrs_only.csv``.

    Returns:
        ``(train_loader, val_loader)`` over an 80/20 split of the entries.
    """
    img_dir = '/home/hiren/Apoorv Pandey/ADRL/Ass1/img_align_celeba/img_align_celeba'
    celeb_attrs = pd.read_csv('/home/hiren/Apoorv Pandey/ADRL/Ass1/list_attr_celeba.csv')

    attr_names = [col for col in celeb_attrs.columns if col != 'image_id']
    # 1 where the attribute is present, 0 otherwise.
    attrs_only = (celeb_attrs[attr_names] == 1).astype(int)
    attrs_only.to_csv('/home/hiren/Apoorv Pandey/ADRL/Ass1/attrs_only.csv')
    counts = attrs_only.sum(axis = 0, skipna = True).to_numpy()

    # argsort is ascending; reverse it so class 0 is the MOST frequent
    # attribute (the previous code picked the 10 least frequent ones).
    top_10_attrs = np.argsort(counts)[::-1][:10]
    top_10_attrs_names = [attr_names[pos] for pos in top_10_attrs]
    print(f'Selected attributes : {top_10_attrs_names}')

    path_list = []
    for class_index, name in enumerate(top_10_attrs_names):
        image_ids = celeb_attrs[celeb_attrs[name] == 1]['image_id'].values.tolist()
        path_list.extend([p, class_index] for p in image_ids)

    # Entries are grouped by class; shuffle (fixed seed, reproducible) so the
    # 80/20 split does not leave the last classes out of the training set.
    random.Random(0).shuffle(path_list)

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    split = int(0.8*len(path_list))
    train_path_list,val_path_list = path_list[:split],path_list[split:]

    train_dataset = CelebA_Categorical_Dataset(train_path_list, img_dir, transform)
    val_dataset = CelebA_Categorical_Dataset(val_path_list, img_dir, transform)

    train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
    val_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=True)
    return train_loader,val_loader

<h3>Conditional Diffusion</h3>

In [None]:
class Conditional_Diffusion:
    """Forward noising and label-conditioned reverse sampling.

    Uses a linear beta schedule of ``noise_steps`` values between
    ``beta_start`` and ``beta_end``. Tensors are placed on the module-level
    ``device``.

    Args:
        noise_steps: number of diffusion steps in the schedule.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        img_size: side length of the square images produced by ``sample``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size

        self.beta = self.noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of length ``noise_steps``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise batch ``x`` to the per-sample step indices ``t``.

        Returns:
            ``(x_t, epsilon)``: the noised batch and the Gaussian noise used.
        """
        cumulative = self.alpha_hat[t]
        signal_scale = torch.sqrt(cumulative)[:, None, None, None]
        noise_scale = torch.sqrt(1 - cumulative)[:, None, None, None]
        epsilon = torch.randn_like(x)
        noised = signal_scale * x + noise_scale * epsilon
        return noised, epsilon

    def sample_timesteps(self, n):
        """Draw ``n`` random step indices uniformly from ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n,labels,gamma=3):
        """Generate ``n`` images conditioned on ``labels``.

        Args:
            model: network called as ``model(x, t, labels)``; ``labels=None``
                requests the unconditional prediction.
            n: number of images to generate.
            labels: label tensor passed to the model.
            gamma: guidance scale. When positive, the prediction used is
                ``uncond + gamma * (cond - uncond)``; otherwise the
                conditional prediction is used as is.

        Returns:
            uint8 tensor of shape ``(n, 3, img_size, img_size)``.
        """
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(device)
            for step in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * step).long().to(device)
                eps_hat = model(x, t,labels)
                if gamma > 0:
                    eps_uncond = model(x, t, None)
                    eps_hat = eps_uncond + gamma*(eps_hat - eps_uncond)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is added on the final step (step == 1).
                noise = torch.randn_like(x) if step > 1 else torch.zeros_like(x)
                mean = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * eps_hat)
                x = mean + torch.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] to uint8 [0, 255].
        x = (x.clamp(-1, 1) + 1) / 2
        return (x * 255).type(torch.uint8)

class UNet_Conditional(nn.Module):
    """U-shaped network called as ``model(x, t, labels)`` with optional class conditioning.

    Same layout as the unconditional network (three encoder stages, a
    three-block bottleneck, three decoder stages, 1x1 output convolution). When
    ``num_classes`` is given, a label embedding of size ``time_dim`` is added
    to the step encoding.

    Args:
        c_in: input image channels.
        c_out: output channels.
        time_dim: size of the step encoding / label embedding.
        num_classes: number of classes for the label embedding; ``None``
            creates no embedding (``forward`` must then be called with
            ``labels=None``).
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes = None):
        super().__init__()
        self.time_dim = time_dim
        self.num_classes = num_classes

        # Encoder path.
        self.inc = Convolutional_Block(c_in, 64)
        self.down1 = EncoderLayer(64, 128)
        self.down2 = EncoderLayer(128, 256)
        self.down3 = EncoderLayer(256, 256)

        # Bottleneck.
        self.bot1 = Convolutional_Block(256, 512)
        self.bot2 = Convolutional_Block(512, 512)
        self.bot3 = Convolutional_Block(512, 256)

        # Decoder path; input widths include the concatenated skip tensors.
        self.up1 = DecoderLayer(512, 128)
        self.up2 = DecoderLayer(256, 64)
        self.up3 = DecoderLayer(128, 64)

        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes,time_dim)

    def pos_encoding(self, t, channels):
        """Return the sin/cos encoding of ``t`` with ``channels`` features.

        ``t`` is expected as a ``(batch, 1)`` float tensor (that is what
        ``forward`` passes); the sine half comes first, then the cosine half.
        Frequencies are created on the module-level ``device``.
        """
        exponents = torch.arange(0, channels, 2, device=device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        half = channels // 2
        sin_part = torch.sin(t.repeat(1, half) * inv_freq)
        cos_part = torch.cos(t.repeat(1, half) * inv_freq)
        return torch.cat([sin_part, cos_part], dim=-1)

    def forward(self, x, t,labels):
        """Run the network on images ``x`` at step indices ``t``; ``labels`` may be ``None``."""
        cond = self.pos_encoding(t.unsqueeze(-1).type(torch.float), self.time_dim)
        if labels is not None:
            cond = cond + self.label_emb(labels)

        skip1 = self.inc(x)
        skip2 = self.down1(skip1, cond)
        skip3 = self.down2(skip2, cond)
        bottom = self.down3(skip3, cond)

        bottom = self.bot3(self.bot2(self.bot1(bottom)))

        decoded = self.up1(bottom, skip3, cond)
        decoded = self.up2(decoded, skip2, cond)
        decoded = self.up3(decoded, skip1, cond)
        return self.outc(decoded)

<h3>Training</h3>

In [None]:
# Train the label-conditioned diffusion model on the 10-class CelebA attribute data.
train_loader,val_loader = get_Celeb_A_categorical_data()
model = UNet_Conditional(num_classes=10).to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
mse = nn.MSELoss()
diffusion = Conditional_Diffusion(img_size=IMG_SIZE)
output_file = './UNet_Conditional_logs.txt'  # plain-text loss log, appended to on every step
logger = SummaryWriter('/home/hiren/Apoorv Pandey/ADRL/Ass2/runs/CelebA_Conditional')
l = len(train_loader)  # batches per epoch, used to compute the global step
# Uncomment to resume from the checkpoint saved at the end of each epoch:
#model.load_state_dict(torch.load('/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model_CelebA_Conditional.pt'))
for epoch in range(100):
    pbar = tqdm(train_loader)
    for i, (images,labels) in enumerate(pbar):
        torch.cuda.empty_cache()
        images = images.to(device)
        labels = labels.to(device)
        #print(f'Images shape = {images.size()}')
        # One random step index per image in the batch.
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        # Drop the labels for the WHOLE batch about 10% of the time so the model
        # also learns the unconditional prediction that sample() requests with labels=None.
        if np.random.random()<0.1:  ## 10% of time use unconditional generation
            labels = None
        x_t, noise = diffusion.noise_images(images, t)
        #print(f'x_t shape = {x_t.size()} t shape = {t.size()}' )
        predicted_noise = model(x_t, t,labels)
        # Objective: MSE between the noise that was added and the model's prediction.
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(MSE=loss.item())
        logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        # The log file is re-opened in append mode on every step.
        with open(output_file, 'a') as f:
            print(f'Epoch :{epoch}  MSE : {loss.item()} global_step: {epoch * l + i}', file=f) 
    
    # End of epoch: sample 10 images for each of the 10 classes
    # (labels 0,0,...,0,1,1,...,9) and save them as one grid.
    class_labels = [[i]*10 for i in range(10)]
    class_labels = torch.tensor(class_labels,dtype = torch.long).to(device)
    class_labels = class_labels.reshape(-1)
    sampled_images = diffusion.sample(model, n=100,labels = class_labels)
    save_images(sampled_images, f'/home/hiren/Apoorv Pandey/ADRL/Ass2/CelebA_Categorical/Epoch_{epoch}.jpg')
    # Checkpoint is overwritten after every epoch.
    torch.save(model.state_dict(),'/home/hiren/Apoorv Pandey/ADRL/Ass2/Diffusion_Model_CelebA_Conditional.pt' )

<h1>Domain Adaptation</h1>

<h2>Basic ResNet</h2>

In [None]:
class uspsDataset(Dataset):
    """USPS digits read from an HDF5 file and upscaled from 16x16 to 32x32.

    The file is expected to contain the groups ``train`` and ``test``, each
    with the datasets ``data`` and ``target``.

    Args:
        path: path of the HDF5 file.
        tr: truthy selects the ``train`` group, falsy the ``test`` group.
    """

    def __init__(self, path, tr = True):
        split = 'train' if tr else 'test'
        with h5py.File(path, 'r') as hf:
            group = hf.get(split)
            # -1 infers the number of samples instead of hard-coding
            # 7291 (train) / 2007 (test).
            x = group.get('data')[:].reshape(-1, 16, 16)
            # `transform` is skimage.transform here.
            self.X = np.array([transform.resize(i, (32, 32), anti_aliasing=True) for i in x])
            self.y = group.get('target')[:]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx`` as stored (NumPy values)."""
        return self.X[idx], self.y[idx]

In [None]:
class mnistDataset(Dataset):
    """MNIST digits read from idx files under ``mnist/`` and resized to 32x32.

    Args:
        path: accepted but not used; the idx file locations are fixed
            relative paths (NOTE(review): confirm whether ``path`` should
            select the directory).
        tr: truthy loads the training files, falsy the ``t10k`` test files.
    """

    def __init__(self, path, tr = True):
        if tr:
            images_file = 'mnist/train-images.idx3-ubyte'
            labels_file = 'mnist/train-labels.idx1-ubyte'
        else:
            images_file = 'mnist/t10k-images.idx3-ubyte'
            labels_file = 'mnist/t10k-labels.idx1-ubyte'

        raw = idx2numpy.convert_from_file(images_file)
        # `transform` is skimage.transform here; each image is resized on its own.
        self.X = np.array([transform.resize(img, (32, 32), anti_aliasing=True) for img in raw])
        self.y = idx2numpy.convert_from_file(labels_file)

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` as stored (NumPy values)."""
        return self.X[idx], self.y[idx]

In [None]:
class clipRealData(Dataset):
    """Dataset backed by two NumPy files, ``X.npy`` and ``y.npy``, in one directory.

    Args:
        path: directory (as a string) containing ``X.npy`` and ``y.npy``.
            Both are loaded with ``allow_pickle=True``; only use trusted files.
    """

    def __init__(self, path):
        self.X, self.y = (
            np.load(path + suffix, allow_pickle=True)
            for suffix in ('/X.npy', '/y.npy')
        )

    def __len__(self):
        """Number of entries along the first axis of ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the pair ``(X[idx], y[idx])``."""
        return self.X[idx], self.y[idx]

In [None]:
class ResidualBlock(nn.Module):
    """Three-convolution block whose output is summed with a projected skip path.

    Main path: 1x1 conv (``stride``) -> 3x3 conv (stride 2) -> 1x1 conv
    (``stride``), each followed by BatchNorm and ReLU.
    Skip path: 3x3 conv (stride 2, no bias) + BatchNorm + ReLU, followed by a
    second BatchNorm (``self.bn``).
    """

    def __init__(self, in_channels, out_channels, stride = 1):
        super().__init__()

        def conv_bn_relu(n_in, kernel_size, conv_stride, padding, **conv_kwargs):
            # One conv -> BatchNorm -> ReLU stage producing `out_channels` maps.
            return nn.Sequential(
                nn.Conv2d(n_in, out_channels, kernel_size = kernel_size,
                          stride = conv_stride, padding = padding, **conv_kwargs),
                nn.BatchNorm2d(out_channels),
                nn.ReLU())

        # Attribute names and creation order are kept so state_dict keys and
        # seeded weight initialisation stay the same.
        self.conv1 = conv_bn_relu(in_channels, 1, stride, 0)
        self.conv2 = conv_bn_relu(out_channels, 3, 2, 1)
        self.conv3 = conv_bn_relu(out_channels, 1, stride, 0)
        self.skipconv = conv_bn_relu(in_channels, 3, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``main_path(x) + bn(skipconv(x))``."""
        main = self.conv1(x)
        try:
            main = self.conv2(main)
            main = self.conv3(main)
        except ValueError:
            # NOTE(review): a ValueError from conv2/conv3 is swallowed and the
            # last successfully computed tensor is used instead.
            pass
        shortcut = self.skipconv(x)
        shortcut = self.bn(shortcut)
        return main + shortcut

In [None]:
class ResNet(nn.Module):
    """Stack of ``ResidualBlock`` stages followed by a linear classifier.

    Args:
        imsize: Side length of the square input images. 64 routes the input
            through ``layer6`` and uses a 512-feature classifier; any other
            value stops after ``layer5`` and uses a 256-feature classifier.
        num_classes: Number of output classes.
        mnist: True for single-channel input, False for three-channel input.

    ``layer6`` and ``layer7`` are always created (``layer7`` is never used in
    ``forward``); they are kept so parameter names and initialisation order do
    not change.
    """

    def __init__(self, imsize, num_classes = 10, mnist = True):
        super().__init__()
        self.imsize = imsize
        self.mnist = mnist
        self.layer1 = ResidualBlock(1 if mnist else 3, 2**4)
        self.layer2 = ResidualBlock(2**4, 2**5)
        self.layer3 = ResidualBlock(2**5, 2**6)
        self.layer4 = ResidualBlock(2**6, 2**7)
        self.layer5 = ResidualBlock(2**7, 2**8)
        self.layer6 = ResidualBlock(2**8, 2**9)
        self.layer7 = ResidualBlock(2**9, 2**10)
        self.fc = nn.Linear(2**9 if imsize == 64 else 2**8, num_classes)

    def forward(self, x):
        """Return class probabilities with shape ``(batch, num_classes)``."""
        bsize = x.shape[0]
        channels = 1 if self.mnist else 3
        x = x.view(bsize, channels, self.imsize, self.imsize)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        if self.imsize == 64:
            x = self.layer6(x)
        # flatten(1) keeps the batch dimension even for a batch of one, which
        # a bare squeeze() would drop along with the spatial dimensions.
        x = x.flatten(1)
        # NOTE(review): the training cells feed this output to
        # nn.CrossEntropyLoss, which applies log-softmax again; the softmax is
        # kept here so existing results and checkpoints stay comparable.
        return F.softmax(self.fc(x), dim=1)

In [None]:
def eval(model, data, bsize=3000, path=None):
    """Classify every sample of ``data`` with ``model`` and score the predictions.

    Args:
        model: Callable returning per-class scores of shape (batch, classes).
        data: Dataset yielding ``(input, label)`` pairs.
        bsize: Batch size used for inference.
        path: Optional file that receives the metrics as JSON.

    Returns:
        ``(accuracy, macro_f1, miss_indices)`` where ``miss_indices`` holds
        the positions of the misclassified samples.

    NOTE(review): the model's train/eval mode is left untouched, so BatchNorm
    behaves according to whatever mode the caller has set.
    """
    dataloader = DataLoader(data, bsize)
    true = torch.empty(0).to(device)
    pred = torch.empty(0).to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for ip, t in dataloader:
            op = model(ip.float().to(device))
            p = torch.argmax(op, dim=1)
            true = torch.cat((true, t.to(device)), dim=0)
            pred = torch.cat((pred, p.to(device)), dim=0)
            del op, ip
            torch.cuda.empty_cache()
            gc.collect()
    true = true.cpu().numpy()
    pred = pred.cpu().numpy()

    # Compute each metric once and reuse it for both the report and the return.
    acc = accuracy_score(true, pred)
    f1 = precision_recall_fscore_support(true, pred, average='macro')[2]
    misses = np.where(pred != true)[0]
    if path:
        report = {'acc': acc, 'f1': f1, 'misses': list(misses)}
        with open(path, 'w') as outfile:
            outfile.write(json.dumps(report, indent=4, cls=NumpyEncoder))

    return acc, f1, misses

In [None]:
def eval_cycleGAN(model, gen, data, bsize=256, path=None):
    """Translate ``data`` with ``gen`` and score ``model`` on the translated inputs.

    Args:
        model: Classifier returning per-class scores of shape (batch, classes).
        gen: Generator applied to every input batch before classification.
        data: Dataset yielding ``(input, label)`` pairs.
        bsize: Batch size used for inference.
        path: Unused by this function.

    Returns:
        ``(accuracy, macro_f1)``.
    """
    # Feed the generator on the device its parameters live on instead of a
    # hard-coded .cuda(); fall back to the global `device` when the generator
    # exposes no parameters.
    gen_parameters = getattr(gen, 'parameters', None)
    first_param = next(gen_parameters(), None) if callable(gen_parameters) else None
    gen_device = first_param.device if first_param is not None else device

    dataloader = DataLoader(data, bsize)
    true = torch.empty(0).to(device)
    pred = torch.empty(0).to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for ip, t in dataloader:
            ip = gen(ip.float().to(gen_device))
            op = model(ip.float().to(device))
            p = torch.argmax(op, dim=1)
            true = torch.cat((true, t.to(device)), dim=0)
            pred = torch.cat((pred, p.to(device)), dim=0)
            del op, ip
            torch.cuda.empty_cache()
            gc.collect()
    true = true.cpu().numpy()
    pred = pred.cpu().numpy()

    return accuracy_score(true, pred), precision_recall_fscore_support(true, pred, average='macro')[2]

In [None]:
# Seed Python's and PyTorch's RNGs before the shuffling loaders are created.
seed = 369
random.seed(seed)
torch.manual_seed(seed)
print("Random Seed: ", seed)

# Settings for the USPS <-> MNIST ResNet experiments below.
params = {
    "bsize" : 5000,# Batch size during training.
    'imsize' : 32,# Spatial size of training images. All images will be resized to this size during preprocessing.
    'nepochs' : 15,#Number of training epochs.
    'lr' : 0.002,#Learning rate for optimizers
    'nclasses' : 10, #number of classes
    'save_epoch' : 20}# Save step. NOTE(review): not read by the training cells that follow.

# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")

# Train/test splits of both digit datasets; every loader (test ones included) shuffles.
uspsTrain = uspsDataset('usps/usps.h5')
uspsTrainLoader = DataLoader(uspsTrain, params['bsize'], shuffle=True)
uspsTest = uspsDataset('usps/usps.h5', False)
uspsTestLoader = DataLoader(uspsTest, params['bsize'], shuffle=True)
mnistTrain = mnistDataset('')
mnistTrainLoader = DataLoader(mnistTrain, params['bsize'], shuffle=True)
mnistTest = mnistDataset('', False)
mnistTestLoader = DataLoader(mnistTest, params['bsize'], shuffle=True)

In [None]:
# Train the ResNet on USPS; after every epoch measure accuracy on the USPS
# test split (in-domain) and on the MNIST training split (cross-domain).
model = ResNet(params['imsize'], params['nclasses'], True).to(device)
model.apply(weights_init)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
train_loss, domain_accuracy, cross_domain_accuracy = [],[],[]
for epoch in range(params['nepochs']):
    for images, labels in uspsTrainLoader:
        # Cast to float32 like the other training cells and eval() do, so the
        # input dtype does not depend on how the dataset stored the images.
        images = images.float().to(device)
        labels = labels.type(torch.LongTensor).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()
    domain_accuracy.append(eval(model, uspsTest)[0])
    cross_domain_accuracy.append(eval(model, mnistTrain)[0])
    print ('Epoch [{}/{}], Loss: {:.4f}, Test accuracy : {:.4}'
                   .format(epoch+1, params['nepochs'], loss.item(), domain_accuracy[-1]))
torch.cuda.empty_cache()
gc.collect()
# Make sure the report/checkpoint directory exists before anything is written.
os.makedirs('output', exist_ok=True)
eval(model, mnistTrain, path='output/mnistTrain.json')
eval(model, mnistTest, path='output/mnistTest.json')
torch.save({
            'model' : model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'params' : params
            }, 'output/trained_on_usps_resnet.pth')

# Release the model and optimizer before the next experiment.
del(model, optimizer)
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Training loss per optimisation step of the USPS-trained ResNet.
# The curve is labelled so plt.legend() has an entry to show.
plt.plot(np.arange(len(train_loss)), np.array(train_loss), label='Training loss')
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Iteration vs loss(resnet trained on USPS)")
plt.legend()
plt.show()

In [None]:
# Per-epoch in-domain vs cross-domain accuracy of the USPS-trained ResNet.
epochs_axis = np.arange(len(domain_accuracy))
plt.plot(epochs_axis, np.array(domain_accuracy), color='r', label='Domain Accuracy')
plt.plot(epochs_axis, np.array(cross_domain_accuracy), color='g', label='Cross Domain Accuracy')

plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Epochs vs Accuracy(ResNet trained on USPS)")
plt.legend()
plt.show()

In [None]:
# Train the ResNet on MNIST; accuracy on the MNIST test split (in-domain) and
# the USPS training split (cross-domain) is recorded after every optimisation
# step, so both accuracy lists have one entry per iteration.
model = ResNet(params['imsize'], params['nclasses'], True).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
train_loss, domain_accuracy, cross_domain_accuracy = [],[],[]
for epoch in range(params['nepochs']):
    for images, labels in mnistTrainLoader:
        images = images.float().to(device)
        labels = labels.type(torch.LongTensor).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        domain_accuracy.append(eval(model, mnistTest)[0])
        cross_domain_accuracy.append(eval(model, uspsTrain)[0])

        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print ('Epoch [{}/{}], Loss: {:.4f}, Test accuracy : {:.4}'
                   .format(epoch+1, params['nepochs'], loss.item(), domain_accuracy[-1]))

# Make sure the report/checkpoint directory exists before anything is written.
os.makedirs('output', exist_ok=True)
eval(model, uspsTrain, path='output/uspsTrain.json')
eval(model, uspsTest, path='output/uspsTest.json')
torch.save({
            'model' : model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'params' : params
            }, 'output/trained_on_mnist_resnet.pth')
# The digit datasets are not needed after this cell; drop them with the model.
del(model, optimizer, mnistTest, mnistTrain, mnistTrainLoader, mnistTestLoader, uspsTest, uspsTrain, uspsTrainLoader, uspsTestLoader)
torch.cuda.empty_cache()
gc.collect()
print(torch.cuda.memory_allocated())

In [None]:
# Training loss per optimisation step of the MNIST-trained ResNet.
# The curve is labelled so plt.legend() has an entry to show.
plt.plot(np.arange(len(train_loss)), np.array(train_loss), label='Training loss')
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Iteration vs loss(resnet trained on MNIST)")
plt.legend()
plt.show()

In [None]:
# In-domain vs cross-domain accuracy of the MNIST-trained ResNet. The MNIST
# training loop appends to both lists after every optimisation step, so the
# x-axis counts iterations rather than epochs.
steps_axis = np.arange(len(domain_accuracy))
plt.plot(steps_axis, np.array(domain_accuracy), color='r', label='Domain Accuracy')
plt.plot(steps_axis, np.array(cross_domain_accuracy), color='g', label='Cross Domain Accuracy')

plt.xlabel("Iterations")
plt.ylabel("Accuracy")
plt.title("Iterations vs Accuracy(ResNet Trained on MNIST)")
plt.legend()
plt.show()

In [None]:
# Re-seed Python's and PyTorch's RNGs for the Clipart <-> Real World experiments.
seed = 369
random.seed(seed)
torch.manual_seed(seed)
print("Random Seed: ", seed)

# Settings for the Clipart <-> Real World ResNet experiments below.
params = {
    "bsize" : 256,# Batch size during training.
    'imsize' : 64,# Spatial size of training images. All images will be resized to this size during preprocessing.
    'nepochs' : 50,#Number of training epochs.
    'lr' : 0.002,#Learning rate for optimizers
    'nclasses' : 65, #number of classes
    'save_epoch' : 20}# Save step. NOTE(review): not read by the training cells that follow.

# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")

# Each domain is read from <folder>/X.npy and <folder>/y.npy; these loaders do not shuffle.
clipData = clipRealData('Clipart')
realData = clipRealData('Real World')
realLoader = DataLoader(realData, batch_size=params['bsize'])
clipLoader = DataLoader(clipData, batch_size=params['bsize'])

In [None]:
# Train the ResNet on Clipart; accuracy on Real World (cross-domain) is
# recorded after every optimisation step.
model = ResNet(params['imsize'], params['nclasses'], False).to(device)
model.apply(weights_init)

train_loss, domain_accuracy, cross_domain_accuracy = [],[],[]
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], weight_decay = 0.001, momentum = 0.9)

for epoch in range(params['nepochs']):
    for images, labels in clipLoader:
        images = images.float().to(device)
        labels = labels.type(torch.LongTensor).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        cross_domain_accuracy.append(eval(model, realData)[0])
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print ('Epoch [{}/{}], Loss: {:.4f}'
                   .format(epoch+1, params['nepochs'], loss.item()))

# Make sure the report/checkpoint directory exists before anything is written.
os.makedirs('output', exist_ok=True)
eval(model, realData,params['bsize'], path = 'output/real.json')
torch.save({
            'model' : model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'params' : params
            }, 'output/trained_on_clip_resnet.pth')

# Release the model and optimizer before the next experiment.
del(model, optimizer)
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Training loss per optimisation step of the Clipart-trained ResNet.
# The curve is labelled so plt.legend() has an entry to show.
plt.plot(np.arange(len(train_loss)), np.array(train_loss), label='Training loss')
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Iteration vs loss(resnet trained on Clipart)")
plt.legend()
plt.show()

In [None]:
# Cross-domain accuracy per optimisation step of the Clipart-trained ResNet.
# The title now names the plotted quantities (iterations vs accuracy).
plt.plot(np.arange(len(cross_domain_accuracy)), np.array(cross_domain_accuracy),
         label='Cross Domain Accuracy')
plt.xlabel("Iterations")
plt.ylabel("Accuracy")
plt.title("Iteration vs Accuracy(resnet trained on Clipart)")
plt.legend()
plt.show()

In [None]:
# Train the ResNet on Real World; accuracy on Clipart (cross-domain) is
# recorded after every optimisation step.
model = ResNet(params['imsize'], params['nclasses'], False).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], weight_decay = 0.001, momentum = 0.9)
train_loss, domain_accuracy, cross_domain_accuracy = [],[],[]
for epoch in range(params['nepochs']):
    for images, labels in realLoader:
        images = images.float().to(device)
        labels = labels.type(torch.LongTensor).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        cross_domain_accuracy.append(eval(model, clipData)[0])
        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print ('Epoch [{}/{}], Loss: {:.4f}'
                   .format(epoch+1, params['nepochs'], loss.item()))

# Make sure the report/checkpoint directory exists before anything is written.
os.makedirs('output', exist_ok=True)
# This model was trained on Real World and is evaluated on Clipart, so it gets
# its own report and checkpoint names instead of overwriting the files written
# by the Clipart-trained run (output/real.json, trained_on_clip_resnet.pth).
eval(model, clipData, path = 'output/clip.json')
torch.save({
            'model' : model.state_dict(),
            'optimizer' : optimizer.state_dict(),
            'params' : params
            }, 'output/trained_on_real_resnet.pth')

# Release the model and optimizer.
del(model, optimizer)
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Training loss per optimisation step of the Real World-trained ResNet.
# The curve is labelled so plt.legend() has an entry to show.
plt.plot(np.arange(len(train_loss)), np.array(train_loss), label='Training loss')
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Iteration vs loss(resnet trained on Real World)")
plt.legend()
plt.show()

In [None]:
# Cross-domain accuracy per optimisation step of the Real World-trained ResNet.
# The title now names the plotted quantities (iterations vs accuracy).
plt.plot(np.arange(len(cross_domain_accuracy)), np.array(cross_domain_accuracy),
         label='Cross Domain Accuracy')
plt.xlabel("Iterations")
plt.ylabel("Accuracy")
plt.title("Iteration vs Accuracy(resnet trained on Real World)")
plt.legend()
plt.show()

<h2>DANN</h2>

<h3>Supporting Functions</h3>

In [None]:
class USPS_Dataset(Dataset):
    """USPS digits from an HDF5 file, train and test splits merged into one set.

    Images are stored as a float tensor reshaped to (N, 1, 16, 16) and labels
    as a long tensor. ``transform`` and ``target_transform`` are kept on the
    instance but are not applied by ``__getitem__``.
    """

    def __init__(self, path ,transform=None, target_transform=None):
        image_parts, label_parts = [], []
        with h5py.File(path, 'r') as hf:
            # Train first, then test, so the merged order is train followed by test.
            for split_name in ('train', 'test'):
                split = hf.get(split_name)
                image_parts.append(split.get('data')[:].reshape(-1,1,16,16))
                label_parts.append(split.get('target')[:])
        all_images = np.concatenate(image_parts, axis=0)
        all_labels = np.concatenate(label_parts, axis=0)
        self.data = torch.tensor(all_images, dtype = torch.float)
        self.labels = torch.tensor(all_labels, dtype = torch.long)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of stored images."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(image, label)`` tensors at position ``idx``."""
        image = self.data[idx]
        label = self.labels[idx]
        return image, label

class OfficeHomeDataset(Dataset):
    """Image dataset defined by a list of file paths and matching integer labels.

    Images are opened lazily in ``__getitem__`` and converted to RGB.
    ``target_transform`` is stored but not applied.
    """

    def __init__(self, img_titles ,labels, transform=None, target_transform=None):
        self.img_titles = img_titles
        self.transform = transform
        self.labels = torch.tensor(labels, dtype = torch.long)
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image paths."""
        n_images = len(self.img_titles)
        return n_images

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return it with its label."""
        label = self.labels[idx]
        image = PIL.Image.open(self.img_titles[idx]).convert('RGB')
        if not self.transform:
            return image, label
        return self.transform(image), label

def usps_dataloader(path='/home/hiren/Apoorv Pandey/ADRL/Ass2/usps.h5'):
    """Build a shuffled DataLoader over the merged USPS train+test data.

    Args:
        path: Location of the USPS HDF5 file. Defaults to the location that
            used to be hard-coded, so existing zero-argument calls are
            unaffected.

    The batch size is read from the module-level ``batch_size``.

    NOTE(review): the transform is handed to ``USPS_Dataset``, which stores it
    but does not apply it when items are fetched.
    """
    transform = transforms.Compose([transforms.Grayscale(1),transforms.Resize((16,16)),transforms.ToTensor(),\
                                    transforms.Lambda(lambda x: x.repeat(3, 1, 1) ),\
                                    transforms.Normalize((0.5,0.5,0.5),std=(0.5,0.5,0.5))])
    usps_dataset = USPS_Dataset(path, transform = transform)
    loader = DataLoader(usps_dataset, batch_size=batch_size, shuffle=True)
    return loader

def mnist_dataloader():
    """Build a shuffled DataLoader over the MNIST training split.

    The data is downloaded to ``./Mnist`` when missing. Each image is converted
    to one grayscale channel, resized to 16x16, turned into a tensor, repeated
    to three channels and normalised with mean 0.5 / std 0.5 per channel.
    The batch size is read from the module-level ``batch_size``.
    """
    to_three_channels = transforms.Lambda(lambda x: x.repeat(3, 1, 1) )
    pipeline = transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize((16,16)),
        transforms.ToTensor(),
        to_three_channels,
        transforms.Normalize((0.5,0.5,0.5),std=(0.5,0.5,0.5)),
    ])
    train_split = MNIST('./Mnist',train=True,download=True,transform=pipeline)
    return DataLoader(train_split, batch_size = batch_size, shuffle=True)

    
def office_dataloader(root='/home/hiren/Apoorv Pandey/ADRL/Ass2/OfficeHomeDataset_10072016'):
    """Build shuffled DataLoaders for the Real World and Clipart image folders.

    Args:
        root: Directory holding the ``Real World`` and ``Clipart`` folders,
            each with one sub-folder of ``*.jpg`` files per category. Defaults
            to the location that used to be hard-coded.

    Returns:
        ``(realworld_loader, clipart_loader)``.

    Raises:
        KeyError: If Clipart contains a category folder that Real World lacks.

    Differences from the previous implementation:
      * labels are zero-based indices into the sorted Real World category
        names (the old labels started at 1, which does not fit a classifier
        with exactly as many outputs as categories);
      * each category contributes exactly one label per image found (the old
        code extended the label list by the cumulative image count);
      * files are found through the ``glob`` module imported locally, because
        the file-level ``from glob import glob`` binds ``glob`` to the function.
    Only direct sub-folders are treated as categories.
    """
    import glob as glob_module

    real_world_path = os.path.join(root, 'Real World')
    clipart_path = os.path.join(root, 'Clipart')

    def category_dirs(domain_path):
        # Sorted so the category -> label mapping is reproducible across runs.
        return sorted(entry for entry in os.listdir(domain_path)
                      if os.path.isdir(os.path.join(domain_path, entry)))

    category_mapping = {category: index
                        for index, category in enumerate(category_dirs(real_world_path))}
    print(category_mapping)

    def collect(domain_path):
        # Gather image paths with exactly one label per image.
        img_list, img_labels = [], []
        for category in category_dirs(domain_path):
            label = category_mapping[category]
            found = sorted(glob_module.glob(os.path.join(domain_path, category, '*.jpg')))
            img_list.extend(found)
            img_labels.extend([label] * len(found))
        return img_list, img_labels

    real_img_list, real_img_labels = collect(real_world_path)
    clip_img_list, clip_img_labels = collect(clipart_path)

    transform = transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,0.5,0.5),
                             std=(0.5,0.5,0.5))
    ])
    real_office_dataset = OfficeHomeDataset(real_img_list,real_img_labels,transform)
    clipart_dataset = OfficeHomeDataset(clip_img_list,clip_img_labels,transform)

    realworld_loader = DataLoader(real_office_dataset,batch_size = batch_size,shuffle=True)
    clipart_loader = DataLoader(clipart_dataset,batch_size = batch_size,shuffle=True)

    return realworld_loader,clipart_loader

<h3>Models</h3>

In [None]:
# Mini-batch size read as a global by the *_dataloader helper functions.
batch_size = 100

class GradientReversalLayer(torch.autograd.Function):
    """Identity in the forward pass; negates the gradient and scales it by ``alpha`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale around for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -grad_output * ctx.alpha
        # One return value per forward input: `alpha` receives no gradient.
        return reversed_grad, None

## Basic ResNet50 Classifier
class BaseClassifier(nn.Module):
    """ResNet50 (default pretrained weights) with its ``fc`` layer replaced by a two-layer head."""

    def __init__(self,num_classes=10):
        super().__init__()
        self.model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # New head: 2048 -> 512 -> num_classes.
        head = nn.Sequential(nn.Linear(2048,512),nn.ReLU(),nn.Linear(512,num_classes))
        self.model.fc = head

    def forward(self,x):
        """Return the class scores produced by the wrapped network."""
        return self.model(x)

class FeatureExtractor1(nn.Module):
    """ResNet50 (default pretrained weights) with its last child module removed."""

    def __init__(self):
        super().__init__()
        self.model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        backbone_layers = list(self.model.children())
        self.feature_extractor = nn.Sequential(*backbone_layers[:-1])

    def forward(self,x):
        """Return the extracted features with size-1 spatial dimensions squeezed out."""
        features = self.feature_extractor(x)
        # After the first squeeze the former dim 3 sits at index 2.
        return features.squeeze(2).squeeze(2)
  
class FeatureExtractor2(nn.Module):
    """ResNet50 (default pretrained weights) with its last three child modules removed."""

    def __init__(self):
        super().__init__()
        self.model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        backbone_layers = list(self.model.children())
        self.feature_extractor = nn.Sequential(*backbone_layers[:-3])

    def forward(self,x):
        """Return the extracted features with size-1 spatial dimensions squeezed out."""
        features = self.feature_extractor(x)
        # squeeze(dim) leaves a dimension untouched when its size is not 1.
        return features.squeeze(2).squeeze(2)

class FeatureExtractor3(nn.Module):
    """ResNet50 (default pretrained weights) with its last four child modules removed."""

    def __init__(self):
        super().__init__()
        self.model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        backbone_layers = list(self.model.children())
        self.feature_extractor = nn.Sequential(*backbone_layers[:-4])

    def forward(self,x):
        """Return the extracted features with size-1 spatial dimensions squeezed out."""
        features = self.feature_extractor(x)
        # squeeze(dim) leaves a dimension untouched when its size is not 1.
        return features.squeeze(2).squeeze(2)

class Classifier(nn.Module):
    """Label predictor: ``in_features`` -> 512 -> 128 -> ``num_classes`` with LeakyReLU(0.2) in between."""

    def __init__(self,in_features,num_classes=10):
        super().__init__()
        layers = [
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self,x):
        """Return the class scores for feature batch ``x``."""
        scores = self.classifier(x)
        return scores

class Discriminator(nn.Module):
    """Domain classifier that sees its input through ``GradientReversalLayer``.

    Layers: ``in_features`` -> 512 -> 128 -> ``num_classes`` with LeakyReLU(0.2) in between.
    """

    def __init__(self,in_features,num_classes=2):
        super().__init__()
        layers = [
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, num_classes),
        ]
        self.discriminator = nn.Sequential(*layers)

    def forward(self,x,alpha):
        """Return domain scores for ``x``; gradients flowing back into ``x`` are reversed and scaled by ``alpha``."""
        return self.discriminator(GradientReversalLayer.apply(x,alpha))

<h3>Training</h3>

In [None]:
### Train Basic ResNet50 
def base_classifier_train(src_loader,tgt_loader,model,optimizer,writer,checkpoint_path,text_file):
    """Train ``model`` on the source domain and track its accuracy on the target domain.

    Runs 30 epochs. Each epoch optimises cross-entropy on ``src_loader``,
    evaluates on ``tgt_loader`` without gradients, appends one summary line to
    ``text_file`` and saves ``model.state_dict()`` to ``checkpoint_path``
    whenever the target accuracy improves.

    Args:
        src_loader: Iterable of ``(inputs, labels)`` batches used for training.
        tgt_loader: Iterable of ``(inputs, labels)`` batches used for evaluation.
        model: Classifier to train.
        optimizer: Optimizer over the model's parameters.
        writer: Object with ``add_scalar(tag, value, step)``; receives the
            per-batch training loss.
        checkpoint_path: Where the best state dict is written.
        text_file: Log file opened in append mode.

    Single-channel inputs are repeated to three channels before the forward pass.
    """
    n_epochs = 30
    global_step = 0
    best_acc = 0.0
    criterion = nn.CrossEntropyLoss()
    # Run on the device the model lives on; fall back to the global `device`
    # for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    for epoch in range(n_epochs):
        running_loss, steps = 0.0, 0
        correct_src, num_samples_src = 0, 0

        model.train()
        for src_x, src_y in src_loader:
            src_x, src_y = src_x.to(run_device), src_y.to(run_device)
            if src_x.size(1)==1:
                src_x = torch.cat([src_x,src_x,src_x],dim=1)

            optimizer.zero_grad()
            out = model(src_x)
            loss = criterion(out,src_y)
            loss.backward()
            optimizer.step()

            preds = torch.argmax(out,dim=1)
            correct_src += (preds==src_y).sum().item()
            num_samples_src += src_y.size(0)
            running_loss += loss.item()
            steps += 1
            writer.add_scalar('loss',loss.item(),global_step)
            # Advance the step so every batch gets its own point on the curve
            # (the counter used to stay at 0 for the whole run).
            global_step += 1
        # Guard the averages so an empty loader does not divide by zero.
        avg_loss = running_loss/steps if steps else 0.0
        avg_accuracy_src = correct_src/num_samples_src if num_samples_src else 0.0

        model.eval()
        correct_tgt, num_samples_tgt = 0, 0
        with torch.no_grad():
            for tgt_x, tgt_y in tgt_loader:
                if tgt_x.size(1)==1:
                    tgt_x = torch.cat([tgt_x,tgt_x,tgt_x],dim=1)
                tgt_x, tgt_y = tgt_x.to(run_device), tgt_y.to(run_device)
                preds = torch.argmax(model(tgt_x),dim=1)
                correct_tgt += (preds==tgt_y).sum().item()
                num_samples_tgt += tgt_y.size(0)
        avg_accuracy_tgt = correct_tgt/num_samples_tgt if num_samples_tgt else 0.0

        with open(text_file, 'a') as f:
            print(f'Epoch :{epoch}  Src Accuracy = {avg_accuracy_src} Tgt accuracy : {avg_accuracy_tgt} loss: {avg_loss} ', file=f)

        if best_acc < avg_accuracy_tgt:
            best_acc = avg_accuracy_tgt
            torch.save(model.state_dict(),checkpoint_path)


In [None]:
def train(src_loader,tgt_loader,F,C,D,optimizer,writer,checkpoint_path,text_file):
    """Jointly train feature extractor, label classifier and domain discriminator (DANN).

    Runs 30 epochs over paired source/target batches. The classifier is
    trained on source labels only; the discriminator is trained to tell
    source (label 1) from target (label 0) features and receives ``alpha``,
    which ramps from 0 towards 1 over training.

    Args:
        src_loader: Labelled source-domain batches ``(inputs, labels)``.
        tgt_loader: Target-domain batches ``(inputs, labels)``; the labels are
            only used to report accuracy.
        F: Feature extractor module.
        C: Label classifier applied to the features.
        D: Domain discriminator called as ``D(features, alpha)``.
        optimizer: Optimizer over the parameters of ``F``, ``C`` and ``D``.
        writer: Object with ``add_scalar(tag, value, step)``.
        checkpoint_path: Where the best checkpoint is written.
        text_file: Log file opened in append mode.

    Raises:
        ValueError: If either loader yields no batches.

    The checkpoint is a dict with keys ``'F'``, ``'C'``, ``'D'`` and
    ``'opt'``, each holding the corresponding ``state_dict()``. (Previously
    ``'C'`` held the un-called ``state_dict`` method and ``'opt'`` the
    optimizer object itself.)
    """
    n_epochs = 30
    step = 0
    best_acc = 0.0
    criterion_c = nn.CrossEntropyLoss()
    criterion_d = nn.CrossEntropyLoss()
    # Run on the device the feature extractor lives on; fall back to the
    # global `device` for an extractor without parameters.
    first_param = next(F.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    # zip() stops at the shorter loader, so that length defines an epoch and
    # keeps the schedule variable p inside [0, 1).
    steps_per_epoch = min(len(src_loader), len(tgt_loader))
    if steps_per_epoch == 0:
        raise ValueError('src_loader and tgt_loader must both yield at least one batch')
    total_steps = n_epochs * steps_per_epoch

    for epoch in range(n_epochs):
        start_steps = epoch * steps_per_epoch
        c_loss_sum, d_loss_sum, n_batches = 0.0, 0.0, 0
        correct_tgt, seen_tgt = 0, 0
        for idx,(src_data,tgt_data) in enumerate(zip(src_loader,tgt_loader)):
            F.train()
            C.train()
            D.train()

            src_x,src_y = src_data
            tgt_x,tgt_y = tgt_data
            src_batch_len = src_x.size(0)
            tgt_batch_len = tgt_x.size(0)
            # Repeat single-channel inputs to three channels.
            if src_x.size(1)==1:
                src_x = torch.cat([src_x,src_x,src_x],dim=1)
            if tgt_x.size(1)==1:
                tgt_x = torch.cat([tgt_x,tgt_x,tgt_x],dim=1)
            src_x ,src_y= src_x.to(run_device),src_y.to(run_device)
            tgt_x,tgt_y = tgt_x.to(run_device),tgt_y.to(run_device)
            p = float(idx + start_steps) / total_steps
            alpha = float(2. / (1. + np.exp(-10 * p)) - 1)

            # One forward pass for both domains, then split the features again.
            combined_x = torch.cat([src_x,tgt_x],dim=0)
            combined_features = F(combined_x)
            if combined_features.dim()>2:
                combined_features = combined_features.reshape(src_batch_len+tgt_batch_len,-1)
            src_features = combined_features[:src_batch_len]
            tgt_features = combined_features[src_batch_len:]

            class_pred_src = C(src_features)
            c_loss = criterion_c(class_pred_src,src_y)

            src_labels = torch.ones(src_batch_len,dtype=torch.long).to(run_device)
            tgt_labels = torch.zeros(tgt_batch_len,dtype=torch.long).to(run_device)
            combined_labels = torch.cat([src_labels,tgt_labels],dim=0)
            disc_pred = D(combined_features,alpha)
            d_loss = criterion_d(disc_pred,combined_labels)

            total_loss = d_loss + c_loss
            total_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Both criteria return batch means, so epoch averages are taken
            # over the number of batches.
            c_loss_sum += c_loss.item()
            d_loss_sum += d_loss.item()
            n_batches += 1

            writer.add_scalar('C loss',c_loss.item(),step)
            writer.add_scalar('D loss',d_loss.item(),step)
            writer.add_scalar('Total Loss',total_loss.item(),step)

            with torch.no_grad():
                F.eval()
                C.eval()

                # Target accuracy on the features computed before this step's update.
                class_pred_tgt = C(tgt_features)
                pred_tgt = torch.argmax(class_pred_tgt,dim=1)

                correct = (pred_tgt==tgt_y).sum().item()
                writer.add_scalar('Accuracy',correct/pred_tgt.size(0),step)
                correct_tgt += correct
                seen_tgt += tgt_y.size(0)

            step += 1

        avg_acc = correct_tgt/seen_tgt if seen_tgt else 0.0
        avg_c_loss = c_loss_sum/n_batches
        avg_d_loss = d_loss_sum/n_batches

        with open(text_file, 'a') as f:
            print(f'Epoch :{epoch}  accuracy : {avg_acc} C_loss: {avg_c_loss} d_loss :{avg_d_loss} ', file=f)

        if best_acc < avg_acc:
            best_acc = avg_acc
            # Build the checkpoint only when it is actually written.
            state = {'F':F.state_dict(),'C':C.state_dict(),'D':D.state_dict(),'opt':optimizer.state_dict()}
            torch.save(state,checkpoint_path)

In [None]:
# Data loaders shared by the baseline and DANN runs below.
usps_loader = usps_dataloader()
mnist_loader = mnist_dataloader()
realworld_loader,clipart_loader = office_dataloader()

# Source-only baselines: train on the first loader, evaluate on the second.
writer = SummaryWriter('runs/Base/USPS-MNIST')
model = BaseClassifier().to(device)
optimizer = torch.optim.AdamW(list(model.parameters()))
base_classifier_train(usps_loader,mnist_loader,model,optimizer,writer,'./Base_USPS-MNIST.pt','./Base_USPS-MNIST.txt')

writer = SummaryWriter('runs/Base/MNIST-USPS')
model = BaseClassifier().to(device)
optimizer = torch.optim.AdamW(list(model.parameters()))
base_classifier_train(mnist_loader,usps_loader,model,optimizer,writer,'./Base_MNIST-USPS.pt','./Base_MNIST-USPS.txt')

# The two 65-class baselines use the Real World / Clipart loaders; they were
# previously fed the USPS/MNIST digit loaders by mistake.
writer = SummaryWriter('runs/Base/Real-Clipart')
model = BaseClassifier(num_classes=65).to(device)
optimizer = torch.optim.AdamW(list(model.parameters()))
base_classifier_train(realworld_loader,clipart_loader,model,optimizer,writer,'./Base_Real-Clipart.pt','./Base_Real-Clipart.txt')

writer = SummaryWriter('runs/Base/Clipart-Real')
model = BaseClassifier(num_classes=65).to(device)
optimizer = torch.optim.AdamW(list(model.parameters()))
base_classifier_train(clipart_loader,realworld_loader,model,optimizer,writer,'./Base_Clipart-Real.pt','./Base_Clipart-Real.txt')

# DANN, USPS (source) -> MNIST (target): one run per feature extractor variant.
writer1 = SummaryWriter('runs/DANN/USPS-MNIST/1')
writer2 = SummaryWriter('runs/DANN/USPS-MNIST/2')
writer3 = SummaryWriter('runs/DANN/USPS-MNIST/3')
F1 = FeatureExtractor1().to(device)
F2 = FeatureExtractor2().to(device)
F3 = FeatureExtractor3().to(device)
# The first argument of each head is the flattened feature size expected from
# the matching extractor — TODO confirm against the loaders' image size.
C1= Classifier(2048,10).to(device)
C2 = Classifier(1024,10).to(device)
C3 = Classifier(2048,10).to(device)
D1 = Discriminator(2048,2).to(device)
D2 = Discriminator(1024,2).to(device)
D3= Discriminator(2048,2).to(device)

# Each run optimises extractor, classifier and discriminator with one AdamW instance.
optimizer1 = torch.optim.AdamW(list(F1.parameters()) +list(C1.parameters()) +list(D1.parameters()))
train(usps_loader,mnist_loader,F1,C1,D1,optimizer1,writer1,'./DANN_USPS-MNIST_1.pt','./USPS-MNIST_1.txt')

optimizer2 = torch.optim.AdamW(list(F2.parameters()) +list(C2.parameters()) +list(D2.parameters()))
train(usps_loader,mnist_loader,F2,C2,D2,optimizer2,writer2,'./DANN_USPS-MNIST_2.pt','./USPS-MNIST_2.txt')

optimizer3 = torch.optim.AdamW(list(F3.parameters()) +list(C3.parameters()) +list(D3.parameters()))
train(usps_loader,mnist_loader,F3,C3,D3,optimizer3,writer3,'./DANN_USPS-MNIST_3.pt','./USPS-MNIST_3.txt')


# DANN, MNIST (source) -> USPS (target): fresh modules, one run per feature extractor variant.
writer1 = SummaryWriter('runs/DANN/MNIST-USPS/1')
writer2 = SummaryWriter('runs/DANN/MNIST-USPS/2')
writer3 = SummaryWriter('runs/DANN/MNIST-USPS/3')
F1 = FeatureExtractor1().to(device)
F2 = FeatureExtractor2().to(device)
F3 = FeatureExtractor3().to(device)
# The first argument of each head is the flattened feature size expected from
# the matching extractor — TODO confirm against the loaders' image size.
C1= Classifier(2048,10).to(device)
C2 = Classifier(1024,10).to(device)
C3 = Classifier(2048,10).to(device)
D1 = Discriminator(2048,2).to(device)
D2 = Discriminator(1024,2).to(device)
D3= Discriminator(2048,2).to(device)

# Each run optimises extractor, classifier and discriminator with one AdamW instance.
optimizer1 = torch.optim.AdamW(list(F1.parameters()) +list(C1.parameters()) +list(D1.parameters()))
train(mnist_loader,usps_loader,F1,C1,D1,optimizer1,writer1,'./DANN_mnist_USPS_1.pt','./MNIST-USPS_1.txt')

optimizer2 = torch.optim.AdamW(list(F2.parameters()) +list(C2.parameters()) +list(D2.parameters()))
train(mnist_loader,usps_loader,F2,C2,D2,optimizer2,writer2,'./DANN_mnist_USPS_2.pt','./MNIST-USPS_2.txt')

optimizer3 = torch.optim.AdamW(list(F3.parameters()) +list(C3.parameters()) +list(D3.parameters()))
train(mnist_loader,usps_loader,F3,C3,D3,optimizer3,writer3,'./DANN_mnist_USPS_3.pt','./MNIST-USPS_3.txt')


# DANN, Real World (source) -> Clipart (target): 65 classes, one run per feature extractor variant.
writer1 = SummaryWriter('runs/DANN/Office/R_C/1')
writer2 = SummaryWriter('runs/DANN/Office/R_C/2')
writer3 = SummaryWriter('runs/DANN/Office/R_C/3')
F1 = FeatureExtractor1().to(device)
F2 = FeatureExtractor2().to(device)
F3 = FeatureExtractor3().to(device)
# The first argument of each head is the flattened feature size expected from
# the matching extractor — TODO confirm against the loaders' image size.
C1= Classifier(2048,65).to(device)
C2 = Classifier(4096,65).to(device)
C3 = Classifier(8192,65).to(device)
D1 = Discriminator(2048,2).to(device)
D2 = Discriminator(4096,2).to(device)
D3= Discriminator(8192,2).to(device)


# Each run optimises extractor, classifier and discriminator with one AdamW instance.
optimizer1 = torch.optim.AdamW(list(F1.parameters()) +list(C1.parameters()) +list(D1.parameters()))
train(realworld_loader,clipart_loader,F1,C1,D1,optimizer1,writer1,'./DANN_Office_RC_1.pt','./Real-Clip_1.txt')
optimizer2 = torch.optim.AdamW(list(F2.parameters()) +list(C2.parameters()) +list(D2.parameters()))
train(realworld_loader,clipart_loader,F2,C2,D2,optimizer2,writer2,'./DANN_Office_RC_2.pt','./Real-Clip_2.txt')
optimizer3 = torch.optim.AdamW(list(F3.parameters()) +list(C3.parameters()) +list(D3.parameters()))
train(realworld_loader,clipart_loader,F3,C3,D3,optimizer3,writer3,'./DANN_Office_RC_3.pt','./Real-Clip_3.txt')

# DANN, Clipart (source) -> Real World (target): 65 classes, one run per feature extractor variant.
writer1 = SummaryWriter('runs/DANN/Office/C_R/1')
writer2 = SummaryWriter('runs/DANN/Office/C_R/2')
writer3 = SummaryWriter('runs/DANN/Office/C_R/3')
F1 = FeatureExtractor1().to(device)
F2 = FeatureExtractor2().to(device)
F3 = FeatureExtractor3().to(device)
# The first argument of each head is the flattened feature size expected from
# the matching extractor — TODO confirm against the loaders' image size.
C1= Classifier(2048,65).to(device)
C2 = Classifier(4096,65).to(device)
C3 = Classifier(8192,65).to(device)
D1 = Discriminator(2048,2).to(device)
D2 = Discriminator(4096,2).to(device)
D3= Discriminator(8192,2).to(device)


# Each run optimises extractor, classifier and discriminator with one AdamW instance.
optimizer1 = torch.optim.AdamW(list(F1.parameters()) +list(C1.parameters()) +list(D1.parameters()))
train(clipart_loader,realworld_loader,F1,C1,D1,optimizer1,writer1,'./DANN_Office_CR_1.pt','./Clip-Real_1.txt')

optimizer2 = torch.optim.AdamW(list(F2.parameters()) +list(C2.parameters()) +list(D2.parameters()))
train(clipart_loader,realworld_loader,F2,C2,D2,optimizer2,writer2,'./DANN_Office_CR_2.pt','./Clip-Real_2.txt')

optimizer3 = torch.optim.AdamW(list(F3.parameters()) +list(C3.parameters()) +list(D3.parameters()))
train(clipart_loader,realworld_loader,F3,C3,D3,optimizer3,writer3,'./DANN_Office_CR_3.pt','./Clip-Real_3.txt')


<h2>ADDA</h2>

<h3>Hyperparameters and supporting functions</h3>

In [None]:
# ADDA settings. Where these are consumed is not visible in this part of the
# notebook; the descriptions below follow the variable names — TODO confirm.
device = 'cuda:0'  # device string; note there is no CPU fallback here
batch_size = 100
num_epochs = 100
d_learning_rate = 1e-4  # presumably the discriminator learning rate
c_learning_rate = 1e-4  # presumably the classifier/encoder learning rate
beta1 = 0.5  # presumably optimizer beta coefficients
beta2 = 0.9
lambda_gp = 10  # presumably a gradient-penalty weight

# NOTE(review): this only binds a Python variable; it does not set the
# CUDA_LAUNCH_BLOCKING environment variable.
CUDA_LAUNCH_BLOCKING=1

def make_variable(tensor):
    """Wrap *tensor* in a Variable that does not track gradients.

    When CUDA is available the tensor is first moved to the current default
    CUDA device; otherwise it stays where it is.
    """
    placed = tensor.cuda() if torch.cuda.is_available() else tensor
    return Variable(placed, requires_grad=False)

def init_random_seed(manual_seed):
    """Seed Python's ``random`` and PyTorch (CPU and, if available, all CUDA devices).

    If *manual_seed* is None a seed is drawn from ``random.randint(1, 10000)``;
    the seed that is used is printed.
    """
    seed = random.randint(1, 10000) if manual_seed is None else manual_seed
    print("use random seed: {}".format(seed))

    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
        
def save_model(model,path):
    """Serialize *model*'s ``state_dict()`` to *path* using ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, path)

<h3>Models</h3>

In [None]:
class Encoder(nn.Module):
    """ResNet-50 based feature encoder for ADDA."""

    def __init__(self):
        super().__init__()
        # torchvision ResNet-50 initialised with its default pretrained weights.
        self.model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # All children of the ResNet except the last one; the layers are shared with
        # self.model, not copied.
        backbone_layers = list(self.model.children())
        self.feature_extractor = nn.Sequential(*backbone_layers[:-1])

    def forward(self,x):
        """Return the backbone features of *x* with axes 2 and 3 squeezed away."""
        features = self.feature_extractor(x)
        features = features.squeeze(2)
        return features.squeeze(2)

class Classifier(nn.Module):
    """Three-layer MLP head: in_features -> 512 -> 128 -> num_classes (LeakyReLU 0.2 between)."""

    def __init__(self,in_features=2048,num_classes=10):
        super().__init__()
        widths = [in_features, 512, 128]
        layers = []
        # Hidden layers, each followed by a LeakyReLU with negative slope 0.2.
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2))
        # Output layer producing raw (un-normalised) class scores.
        layers.append(nn.Linear(128, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self,x):
        """Return the class scores for the feature batch *x*."""
        scores = self.classifier(x)
        return scores

class Discriminator(nn.Module):
    """MLP domain discriminator: in_features -> 512 -> 128 -> num_classes log-probabilities."""

    def __init__(self,in_features=2048,num_classes=2):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, num_classes),
            # Normalise over the class axis explicitly. Omitting `dim` is deprecated, emits a
            # warning, and makes PyTorch guess the axis from the input rank, which picks the
            # wrong axis for inputs that are not 2-D. The last axis is the class axis because
            # the preceding Linear layer acts on the last axis.
            nn.LogSoftmax(dim=-1)
        )

    def forward(self,x):
        """Return log-probabilities over the classes for the feature batch *x*."""
        return self.discriminator(x)

<h3>Training</h3>

In [None]:
def gradient_penalty(crit, real, fake, epsilon):
    '''
    Compute the gradient penalty of a critic on interpolations of two batches.

    Parameters:
        crit: the critic model (callable on the interpolated batch)
        real: a batch of "real" samples
        fake: a batch of "fake" samples, broadcast-compatible with `real`
        epsilon: per-sample mixing proportions; the interpolated batch must require
            gradients for the differentiation below to work
    Returns:
        penalty: scalar tensor, the mean over samples of (||grad||_2 - 1)^2, where grad is the
            gradient of the critic scores with respect to the interpolated samples
    '''
    # Interpolate between the two batches.
    interpolated = real * epsilon + fake * (1 - epsilon)

    # Critic scores on the interpolations.
    scores = crit(interpolated)

    # d(scores)/d(interpolated); create_graph keeps the result differentiable so that the
    # penalty can itself be back-propagated.
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    # One row per sample, then the L2 norm of each row.
    per_sample = grads.view(len(grads), -1)
    norms = per_sample.norm(2, dim=1)

    # Mean squared deviation of the norms from 1.
    return torch.mean((norms - 1) ** 2)
    

def train_src(encoder, classifier, data_loader,writer,logfile,num_epochs = 20):
    """Train the encoder and classifier jointly on the source domain.

    Args:
        encoder: feature extractor applied to the images.
        classifier: head applied to the encoder output.
        data_loader: iterable of (images, labels) batches with class-index labels;
            single-channel images are repeated to three channels.
        writer: object with ``add_scalar(tag, value, step)``; receives the per-step loss.
        logfile: text file that a progress line is appended to every 20 steps.
        num_epochs: number of passes over ``data_loader``.

    Returns:
        The ``(encoder, classifier)`` pair after training. The final weights are also
        written to ``./ADDA_new/``.
    """
    # Make sure the output locations exist before anything is written to them.
    os.makedirs("./ADDA_new", exist_ok=True)
    log_dir = os.path.dirname(logfile)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # setup criterion and optimizer
    optimizer = optim.Adam(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=c_learning_rate,
        betas=(beta1, beta2))
    criterion = nn.CrossEntropyLoss()

    global_step = 0
    for epoch in range(num_epochs):
        # eval_src() (called every 20 epochs below) switches both modules to eval mode, so
        # training mode for Dropout and BN layers has to be restored at the start of each epoch.
        encoder.train()
        classifier.train()

        for step, (images, labels) in enumerate(data_loader):

            # grayscale batches are expanded to 3 channels for the RGB encoder
            if images.size(1)==1:
                images = torch.cat([images,images,images],dim=1)

            images = make_variable(images)
            # Flatten to a 1-D vector of class indices. Unlike squeeze_(), this keeps the batch
            # axis when the batch holds a single sample and does not modify the loader's tensor.
            labels = make_variable(labels.reshape(-1))

            # zero gradients for optimizer
            optimizer.zero_grad()

            # classification loss on the source batch
            preds = classifier(encoder(images))
            loss = criterion(preds, labels)

            # optimize source classifier
            loss.backward()
            optimizer.step()
            writer.add_scalar('Loss',loss.item(),global_step)
            global_step += 1

            # log step info
            if ((step + 1) % 20 == 0):
                with open(logfile,'a') as f:
                    print(f"Epoch [{epoch + 1}/{num_epochs}] Step [{step + 1}/{len(data_loader)}] \
                       loss={loss.item()} ",file = f)

        # periodic evaluation (on the same loader that is used for training)
        if ((epoch + 1) % 20 == 0):
            eval_src(encoder, classifier, data_loader)

        # periodic checkpoint
        if ((epoch + 1) % 100 == 0):
            save_model(encoder, f"./ADDA_new/ADDA-source-encoder-{epoch + 1}.pt")
            save_model(
                classifier, f"./ADDA_new/ADDA-source-classifier-{epoch + 1}.pt")

    # save final model
    torch.save(encoder.state_dict(),"./ADDA_new/ADDA-source-encoder-final.pt")
    torch.save(classifier.state_dict(), "./ADDA_new/ADDA-source-classifier-final.pt")

    return encoder, classifier


def eval_src(encoder, classifier, data_loader):
    """Evaluate encoder + classifier on ``data_loader`` and print mean loss and accuracy.

    Both modules are switched to eval mode and are left in eval mode on return.

    Args:
        encoder: feature extractor applied to the images.
        classifier: head applied to the encoder output.
        data_loader: iterable of (images, labels) batches with class-index labels; must
            expose ``.dataset`` so that accuracy can be normalised by the number of samples.
    """
    # set eval state for Dropout and BN layers
    encoder.eval()
    classifier.eval()

    total_loss = 0.0
    correct = 0
    criterion = nn.CrossEntropyLoss()

    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for (images, labels) in data_loader:

            # grayscale batches are expanded to 3 channels for the RGB encoder
            if images.size(1)==1:
                images = torch.cat([images,images,images],dim=1)

            images = make_variable(images)
            # Flatten to 1-D class indices so that (batch, 1)-shaped labels work with the loss
            # and do not broadcast against the predictions in the comparison below.
            labels = make_variable(labels.reshape(-1))

            preds = classifier(encoder(images))
            total_loss += criterion(preds, labels).item()

            pred_cls = preds.argmax(dim=1)
            correct += pred_cls.eq(labels).sum().item()

    # loss is averaged over batches, accuracy over samples (plain Python floats)
    loss = total_loss/len(data_loader)
    acc = correct/len(data_loader.dataset)

    print(f"Avg Loss = {loss}, Avg Accuracy = {acc}")

def train_tgt(src_encoder, tgt_encoder, critic,classifier,
              src_data_loader, tgt_data_loader,writer,logfile,num_epochs=100):
    """Adversarially adapt ``tgt_encoder`` to the target domain.

    Each step first updates the critic (source features scored as "real", target features
    as "fake", plus a gradient penalty weighted by ``lambda_gp``) and then updates the
    target encoder so that the critic assigns its features label 1. After every epoch the
    target accuracy of ``classifier(tgt_encoder(.))`` is appended to ``logfile`` and the
    best-scoring critic / target encoder are saved under ``./ADDA_new/``.

    Args:
        src_encoder: source encoder; it is never optimised here.
        tgt_encoder: encoder being adapted.
        critic: domain critic applied to encoder features.
        classifier: source classifier, only used for the per-epoch evaluation.
        src_data_loader: iterable of (images, labels) source batches (labels unused).
        tgt_data_loader: iterable of (images, labels) target batches; labels are used only
            for evaluation. Must expose ``.dataset``.
        writer: object with ``add_scalar(tag, value, step)``.
        logfile: text file receiving one accuracy line per epoch.
        num_epochs: number of epochs.

    Returns:
        ``tgt_encoder`` in its state after the last epoch (not necessarily the best one).
    """
    # NOTE(review): the critic is trained with a Wasserstein-style loss while the encoder is
    # trained with cross-entropy against label 1 on the same critic outputs -- confirm that
    # this mix is intended.

    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    tgt_encoder.train()
    critic.train()

    # setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                               lr=c_learning_rate,
                               betas=(beta1, beta2))
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=d_learning_rate,
                                  betas=(beta1,beta2))
    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
    best_acc = 0.0

    # Make sure the output locations exist before anything is written to them.
    os.makedirs("./ADDA_new", exist_ok=True)
    log_dir = os.path.dirname(logfile)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    ####################
    # 2. train network #
    ####################
    global_step = 0
    for epoch in range(num_epochs):

        # zip source and target data pair (stops at the shorter loader)
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, _), (images_tgt, _)) in data_zip:
            # the per-epoch evaluation below leaves tgt_encoder in eval mode
            critic.train()
            tgt_encoder.train()

            ###########################
            # 2.1 train discriminator #
            ###########################

            # grayscale batches are expanded to 3 channels for the RGB encoders
            if images_src.size(1)==1:
                images_src = torch.cat([images_src,images_src,images_src],dim=1)
            if images_tgt.size(1)==1:
                images_tgt = torch.cat([images_tgt,images_tgt,images_tgt],dim=1)

            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # zero gradients for optimizer
            optimizer_critic.zero_grad()

            # Only the critic is updated in this step, so the features are computed without
            # building autograd graphs through the encoders.
            with torch.no_grad():
                feat_src = src_encoder(images_src)
                feat_tgt = tgt_encoder(images_tgt)

            real_pred = critic(feat_src)
            fake_pred = critic(feat_tgt)

            # The two loaders may yield batches of different size (e.g. a short last batch);
            # interpolate only over the samples both batches have.
            n_common = min(feat_src.size(0), feat_tgt.size(0))
            # One mixing coefficient per sample with trailing singleton axes matching the
            # feature rank. A fixed (B, 1, 1, 1) shape would broadcast 2-D (batch, features)
            # inputs into a 4-D tensor instead of mixing them sample by sample.
            eps_shape = (n_common,) + (1,) * (feat_tgt.dim() - 1)
            epsilon = torch.rand(eps_shape, device=feat_tgt.device, requires_grad=True)
            gp = gradient_penalty(critic, feat_src[:n_common], feat_tgt[:n_common], epsilon)

            # Adversarial loss
            loss_critic = -torch.mean(real_pred) + torch.mean(fake_pred) + lambda_gp * gp

            loss_critic.backward()

            # optimize critic
            optimizer_critic.step()


            ############################
            # 2.2 train target encoder #
            ############################

            # zero gradients for optimizer
            optimizer_critic.zero_grad()
            optimizer_tgt.zero_grad()

            # extract target features (with gradients this time)
            feat_tgt = tgt_encoder(images_tgt)

            # predict on discriminator
            pred_tgt = critic(feat_tgt)

            # prepare fake labels: the encoder is rewarded when the critic outputs class 1
            label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())

            # compute loss for target encoder
            loss_tgt = criterion(pred_tgt, label_tgt)
            loss_tgt.backward()

            # optimize target encoder
            optimizer_tgt.step()
            writer.add_scalar('Loss Critic',loss_critic.item(),global_step)
            writer.add_scalar('Loss Tgt',loss_tgt.item(),global_step)

            global_step += 1
            #######################
            # 2.3 print step info #
            #######################
            if ((step + 1) % 20 == 0):

                print(f"Epoch [{epoch + 1}/{num_epochs}] Step [{step + 1}/{len_data_loader}]:\
                    d_loss={loss_critic.item()} g_loss={loss_tgt.item()}")

        # per-epoch evaluation of the adapted encoder on the target data
        correct = 0
        tgt_encoder.eval()
        classifier.eval()
        with torch.no_grad():
            for (images, labels) in tgt_data_loader:
                if images.size(1)==1:
                    images = torch.cat([images,images,images],dim=1)
                images = make_variable(images)
                # 1-D class indices; keeps the batch axis for single-sample batches
                labels = make_variable(labels).reshape(-1)

                preds = classifier(tgt_encoder(images))
                correct += preds.argmax(dim=1).eq(labels).sum().item()
        # plain Python float so that the log line and the comparison below use numbers
        tgt_acc = correct / len(tgt_data_loader.dataset)
        with open(logfile,'a') as f:
            print(f"Epoch : {epoch} tgt_acc = {tgt_acc}",file = f)

        #############################
        # 2.4 save model parameters #
        #############################
        if best_acc < tgt_acc:
            best_acc = tgt_acc
            torch.save(critic.state_dict(),"./ADDA_new/ADDA-critic-final.pt")
            torch.save(tgt_encoder.state_dict(),"./ADDA_new/ADDA-target-encoder-final.pt")

    return tgt_encoder

In [None]:
# ADDA on MNIST (source) -> USPS (target).
src_data_loader = mnist_dataloader()

tgt_data_loader = usps_dataloader()

# load models
src_encoder = Encoder().to(device)

# NOTE(review): 65 output classes looks like a carry-over from another experiment for a
# digit dataset -- left unchanged so that existing checkpoints still load; confirm.
src_classifier = Classifier(num_classes=65).to(device)

tgt_encoder = Encoder().to(device)

critic = Discriminator().to(device)

writer_src = SummaryWriter('runs/ADDA/MNIST-USPS/train_src')
writer_tgt = SummaryWriter('runs/ADDA/MNIST-USPS//train_tgt')

# Dataset-specific folder for logs and cached weights; created up front because the training
# functions append to the log files inside it.
ckpt_dir = './ADDA_new/MNIST-USPS'
os.makedirs(ckpt_dir, exist_ok=True)

train_tgt_logs = './ADDA_new/MNIST-USPS/train_tgt_logs.txt'
train_src_logs = './ADDA_new/MNIST-USPS/train_src_logs.txt'

src_encoder_ckpt = os.path.join(ckpt_dir, "ADDA-source-encoder-final.pt")
src_classifier_ckpt = os.path.join(ckpt_dir, "ADDA-source-classifier-final.pt")
tgt_encoder_ckpt = os.path.join(ckpt_dir, "ADDA-target-encoder-final.pt")


# Reuse cached source weights only when both files are present.
if os.path.exists(src_encoder_ckpt) and os.path.exists(src_classifier_ckpt):
    src_encoder.load_state_dict(torch.load(src_encoder_ckpt))
    src_classifier.load_state_dict(torch.load(src_classifier_ckpt))
else:
    src_encoder, src_classifier = train_src(src_encoder, src_classifier, src_data_loader,writer_src,train_src_logs)
    # train_src writes its checkpoints directly under ./ADDA_new/, not into the dataset folder
    # checked above, so store copies here; otherwise the cache is never found and the source
    # model is retrained on every run.
    torch.save(src_encoder.state_dict(), src_encoder_ckpt)
    torch.save(src_classifier.state_dict(), src_classifier_ckpt)

# eval source model
eval_src(src_encoder, src_classifier, src_data_loader)

# the target encoder starts from the trained source encoder
tgt_encoder.load_state_dict(src_encoder.state_dict())

if os.path.exists(tgt_encoder_ckpt):
    tgt_encoder.load_state_dict(torch.load(tgt_encoder_ckpt))
else:

    tgt_encoder = train_tgt(src_encoder, tgt_encoder, critic,src_classifier, src_data_loader, tgt_data_loader,\
                            writer_tgt,train_tgt_logs)
    # Same path mismatch as above: cache the returned (last-epoch) encoder where this cell
    # looks for it. The best-accuracy encoder is saved separately by train_tgt under ./ADDA_new/.
    torch.save(tgt_encoder.state_dict(), tgt_encoder_ckpt)

<h2>CycleGAN</h2>

In [None]:
def weights_init(w):
    """Initialise a single module in place; intended for ``net.apply(weights_init)``.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02); BatchNorm2d weights are
    drawn from N(1, 0.02) and their biases set to 0. Other module types are left untouched.

    The type check uses isinstance because matching lower-case 'conv' / 'bn' against the
    class name never succeeds for names such as 'Conv2d' or 'BatchNorm2d', which made the
    function a silent no-op.
    """
    if isinstance(w, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(w.weight.data, 0.0, 0.02)
    elif isinstance(w, nn.BatchNorm2d):
        # affine=False batch-norm layers have no weight/bias to initialise
        if w.weight is not None:
            nn.init.normal_(w.weight.data, 1.0, 0.02)
        if w.bias is not None:
            nn.init.constant_(w.bias.data, 0)

In [None]:
def conv_block(c_in, c_out, k_size=4, stride=2, pad=1, use_bn=True, transpose=False):
    """Build a (transposed) 2-D convolution, optionally followed by BatchNorm2d.

    The convolution carries a bias only when no batch norm follows it. For the transposed
    variant ``output_padding`` is set to ``pad``. Returns an ``nn.Sequential``.
    """
    with_bias = not use_bn
    if transpose:
        conv = nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, output_padding=pad, bias=with_bias)
    else:
        conv = nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=with_bias)

    layers = [conv, nn.BatchNorm2d(c_out)] if use_bn else [conv]
    return nn.Sequential(*layers)

class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm blocks with an additive skip connection."""

    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv1 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)
        self.conv2 = conv_block(channels, channels, k_size=3, stride=1, pad=1, use_bn=True)

    def forward(self, x):
        """Return ``h + conv2(h)`` with ``h = relu(conv1(x))``.

        Implemented as ``forward`` rather than by overriding ``__call__`` so that the block
        goes through ``nn.Module.__call__`` and forward hooks work; calling the block
        directly gives the same result as before.
        """
        # NOTE(review): the skip path carries relu(conv1(x)), not the block input x --
        # kept as in the original; confirm this is intended.
        x = F.relu(self.conv1(x))
        return x + self.conv2(x)

class Discriminator(nn.Module):
    """Convolutional discriminator that returns one score per input image.

    Args:
        params: dict with 'nc' (image channels) and 'ndf' (base number of feature maps);
            'imsize' (image side length) is read if present and defaults to 32.
    """

    def __init__(self, params):
        super(Discriminator, self).__init__()
        # Input layout used to reshape incoming batches in forward(); taken from params
        # instead of being fixed to 1 x 32 x 32.
        self.nc = params['nc']
        self.imsize = params.get('imsize', 32)

        self.conv1 = conv_block(params['nc'], params['ndf'], use_bn=False)
        self.conv2 = conv_block(params['ndf'], params['ndf'] * 2)
        self.conv3 = conv_block(params['ndf'] * 2, params['ndf'] * 4)
        self.conv4 = conv_block(params['ndf'] * 4, 1, k_size=3, stride=1, pad=1, use_bn=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a 1-D tensor with the mean of the final score map for every sample."""
        b_size = x.shape[0]
        # accept any per-sample layout (flat, missing channel axis, ...); reshape also handles
        # non-contiguous inputs
        x = x.reshape(b_size, self.nc, self.imsize, self.imsize)
        alpha = 0.2
        x = F.leaky_relu(self.conv1(x), alpha)
        x = F.leaky_relu(self.conv2(x), alpha)
        x = F.leaky_relu(self.conv3(x), alpha)
        x = self.conv4(x)
        # average the score map into one score per sample
        x = x.reshape([x.shape[0], -1]).mean(1)
        return x

class Generator(nn.Module):
    """Image-to-image generator: two stride-2 convs, a residual block, two stride-2
    transposed convs, tanh output with the same channel count as the input.

    Args:
        params: dict with 'nc' (image channels) and 'ngf' (base number of feature maps);
            'imsize' (image side length) is read if present and defaults to 32.
    """

    def __init__(self, params):
        super(Generator, self).__init__()
        # Input layout used to reshape incoming batches in forward(); taken from params
        # instead of being fixed to 1 x 32 x 32.
        self.nc = params['nc']
        self.imsize = params.get('imsize', 32)

        self.conv1 = conv_block(params['nc'], params['ngf'], k_size=5, stride=1, pad=2, use_bn=True)
        self.conv2 = conv_block(params['ngf'], params['ngf'] * 2, k_size=3, stride=2, pad=1, use_bn=True)
        self.conv3 = conv_block(params['ngf'] * 2, params['ngf'] * 4, k_size=3, stride=2, pad=1, use_bn=True)
        self.res4 = ResBlock(params['ngf'] * 4)
        self.tconv5 = conv_block(params['ngf'] * 4, params['ngf'] * 2, k_size=3, stride=2, pad=1, use_bn=True, transpose=True)
        self.tconv6 = conv_block(params['ngf'] * 2, params['ngf'], k_size=3, stride=2, pad=1, use_bn=True, transpose=True)
        self.conv7 = conv_block(params['ngf'], params['nc'], k_size=5, stride=1, pad=2, use_bn=False)

        # Initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Translate a batch of images; the output lies in [-1, 1] (tanh)."""
        b_size = x.shape[0]
        # accept any per-sample layout (flat, missing channel axis, ...); reshape also handles
        # non-contiguous inputs
        x = x.reshape(b_size, self.nc, self.imsize, self.imsize)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.res4(x))
        x = F.relu(self.tconv5(x))
        x = F.relu(self.tconv6(x))
        x = torch.tanh(self.conv7(x))
        return x

In [None]:
# Fix the Python and PyTorch RNG seeds for this section.
seed = 369
random.seed(seed)
torch.manual_seed(seed)
print("Random Seed: ", seed)

# CycleGAN configuration. 'lambda_cyc' and 'lambda_id' are defined here but are not
# referenced by the training loop in this section.
params = {
    "bsize" : 256,# Batch size during training.
    'imsize' : 32,# Spatial size of training images. All images will be resized to this size during preprocessing.
    'nc' : 1,# Number of channels in the training images. For coloured images this is 3.
    'nz' : 100,# Size of the Z latent vector (only used to build fixed_noise in a later cell).
    'ngf' : 64,# Size of feature maps in the generator. The depth will be multiples of this.
    'ndf' : 64, # Size of features maps in the discriminator. The depth will be multiples of this.
    'nepochs' : 50,# Number of training epochs.
    'lr' : 0.0002,# Learning rate for optimizers
    'beta1' : 0.5,# Beta1 hyperparam for Adam optimizer
    'save_epoch' : 2,# Save step.
    'lambda_cyc':10.0,
    'lambda_id':5.0}

# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")

# The two image domains (A = MNIST, B = USPS), each served by a shuffled loader.
MNISTData = mnistDataset('mnist')
USPSData = uspsDataset('usps/usps.h5')
mnistLoader = DataLoader(MNISTData, params['bsize'], shuffle=True)
uspsLoader = DataLoader(USPSData, params['bsize'], shuffle=True)
# The training loop zips both loaders, so an epoch ends with the shorter one.
iters_per_epoch = min(len(mnistLoader), len(uspsLoader))

In [None]:
# `time` is used below but is not part of the notebook's import cell; import it here so this
# cell runs on its own.
import time

# Two generators (A->B and B->A) and one discriminator per domain.
netG_AB = Generator(params).to(device)
netG_AB.apply(weights_init)
netG_BA = Generator(params).to(device)
netG_BA.apply(weights_init)
netD_A = Discriminator(params).to(device)
netD_A.apply(weights_init)
netD_B = Discriminator(params).to(device)
netD_B.apply(weights_init)

# Loss objects (the training loop in this section computes its losses inline).
criterion = nn.MSELoss()
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

fixed_noise = torch.randn(64, params['nz'], 1, 1, device=device)

real_label = 1
fake_label = 0

# One Adam optimizer for both discriminators and one for both generators.
optimizerD = optim.Adam(list(netD_A.parameters())+list(netD_B.parameters()), lr=params['lr'], betas=(params['beta1'], 0.999), weight_decay=2e-5)
optimizerG = optim.Adam(list(netG_AB.parameters())+list(netG_BA.parameters()), lr=params['lr'], betas=(params['beta1'], 0.999), weight_decay=2e-5)

img_list = []
G_losses = []
D_losses = []

prev_time = time.time()

In [None]:
# One fixed batch of images from each domain.
# next(iterator) is the Python 3 iterator protocol; the DataLoader iterator's `.next()`
# method is not available in recent PyTorch releases.
a_fixed = next(iter(mnistLoader))[0]
b_fixed = next(iter(uspsLoader))[0]

# Follow `device` (which falls back to the CPU) instead of requiring CUDA.
a_fixed = a_fixed.to(device)
b_fixed = b_fixed.to(device)

In [None]:
iters = 0

print("Starting Training Loop...")
print("-"*25)
# Per-iteration histories: cycle-reconstruction loss, generator adversarial loss and
# discriminator loss (each entry is a one-element list).
recon_loss, gan_loss, dis_loss = [],[],[]
for epoch in range(params['nepochs']):
    netG_AB.train()
    netG_BA.train()
    netD_A.train()
    netD_B.train()
    # zip stops at the shorter loader
    for i, (a_data, b_data) in enumerate(zip(mnistLoader, uspsLoader)):
        a_real, _ = a_data
        b_real, _ = b_data

        # Move the batches to the same device as the networks (works on CPU-only machines too).
        a_real, b_real = a_real.float().to(device), b_real.float().to(device)

        # Fake Images
        b_fake = netG_AB(a_real)
        a_fake = netG_BA(b_real)

        # Training discriminator: least-squares objective, real -> 1 and fake -> 0.
        # The fakes are detached so that this step does not back-propagate into the generators.
        a_real_out = netD_A(a_real)
        a_fake_out = netD_A(a_fake.detach())
        a_d_loss = (torch.mean((a_real_out - 1) ** 2) + torch.mean(a_fake_out ** 2)) / 2

        b_real_out = netD_B(b_real)
        b_fake_out = netD_B(b_fake.detach())
        b_d_loss = (torch.mean((b_real_out - 1) ** 2) + torch.mean(b_fake_out ** 2)) / 2

        optimizerD.zero_grad()
        d_loss = a_d_loss + b_d_loss
        d_loss.backward()
        optimizerD.step()

        # Training Generator: push the discriminators' scores on the fakes towards 1
        a_fake_out = netD_A(a_fake)
        b_fake_out = netD_B(b_fake)

        a_g_loss = torch.mean((a_fake_out - 1) ** 2)
        b_g_loss = torch.mean((b_fake_out - 1) ** 2)
        g_gan_loss = a_g_loss + b_g_loss

        # Cycle consistency: A -> B -> A and B -> A -> B should reproduce the input (L1)
        a_g_ctnt_loss = (a_real - netG_BA(b_fake)).abs().mean()
        b_g_ctnt_loss = (b_real - netG_AB(a_fake)).abs().mean()
        g_ctnt_loss = a_g_ctnt_loss + b_g_ctnt_loss
        recon_loss.append([g_ctnt_loss.item()])
        gan_loss.append([g_gan_loss.item()])
        dis_loss.append([d_loss.item()])
        optimizerG.zero_grad()
        g_loss = g_gan_loss + g_ctnt_loss
        g_loss.backward()
        optimizerG.step()

        if i % 10 == 0:
            print("Epoch: " + str(epoch + 1) + "/" + str(params['nepochs'])
                    + " it: " + str(i) + "/" + str(iters_per_epoch)
                    + "\ta_d_loss:" + str(round(a_d_loss.item(), 4))
                    + "\ta_g_loss:" + str(round(a_g_loss.item(), 4))
                    + "\ta_g_ctnt_loss:" + str(round(a_g_ctnt_loss.item(), 4))
                    + "\tb_d_loss:" + str(round(b_d_loss.item(), 4))
                    + "\tb_g_loss:" + str(round(b_g_loss.item(), 4))
                    + "\tb_g_ctnt_loss:" + str(round(b_g_ctnt_loss.item(), 4)))

In [None]:
# Save the final networks, optimizer states and configuration.
# This cell runs once, after training. Guarding it with `epoch % params['save_epoch'] == 0`
# would skip the save whenever the last epoch index is not a multiple of save_epoch
# (e.g. 49 with the settings of this section), so the checkpoint is written unconditionally.
torch.save({
    'generator_AB' : netG_AB.state_dict(),
    'discriminator_A' : netD_A.state_dict(),
    'generator_BA' : netG_BA.state_dict(),
    'discriminator_B' : netD_B.state_dict(),
    'optimizerG' : optimizerG.state_dict(),
    'optimizerD' : optimizerD.state_dict(),
    'params' : params
    }, 'final_model.pth')

In [None]:
# The three histories hold one value per training iteration, so the x-axis counts
# iterations (not epochs), matching the plot title.
steps = np.arange(len(recon_loss))
plt.plot(steps, np.array(recon_loss), color='r', label='Reconstruction Loss')
plt.plot(steps, np.array(gan_loss), color='g', label='Generator Loss')
plt.plot(steps, np.array(dis_loss), color='b', label='Discriminator Loss')

plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Loss vs Iterations(CycleGAN)")
plt.legend()
plt.show()

In [None]:
# Load the classifier checkpoint stored under the 'model' key of the MNIST-trained ResNet file
# and pass it, together with the A->B generator and the MNIST dataset, to eval_cycleGAN.
# NOTE(review): ResNet and eval_cycleGAN are defined elsewhere; presumably the generator output
# is scored by the classifier -- confirm which domain the classifier is expected to match.
model = ResNet(params['imsize'], 10, True).to(device)
model.load_state_dict(torch.load('output/trained_on_mnist_resnet.pth')['model'])
eval_cycleGAN(model, netG_AB, MNISTData)

In [None]:
# Same evaluation for the other direction: USPS-trained ResNet checkpoint ('model' key),
# B->A generator and the USPS dataset.
# NOTE(review): ResNet and eval_cycleGAN are defined elsewhere -- confirm the expected pairing of
# classifier, generator and dataset.
model = ResNet(params['imsize'], 10, True).to(device)
model.load_state_dict(torch.load('output/trained_on_usps_resnet.pth')['model'])
eval_cycleGAN(model, netG_BA, USPSData)

In [None]:
# Translate 50 randomly chosen USPS samples with the B->A generator and show every input next
# to its translation.
r = np.random.randint(0, len(USPSData), size=50)
# Inputs follow `device` (CPU fallback) instead of requiring CUDA.
generated_img = netG_BA(torch.tensor(USPSData[r][0]).float().to(device).abs()).cpu().detach().numpy().reshape(50, 32, 32)
fig = plt.figure(figsize=(200., 200.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(10, 10),  # 10x10 grid: 100 axes for 50 (input, output) pairs
                 axes_pad=0.1,  # pad between axes in inch.
                 )

# Drawing both images on the same axis would hide the input behind the generated image,
# so each pair uses two consecutive axes: input first, translation second.
axes = list(grid)
for i in range(50):
    axes[2 * i].imshow(USPSData[r[i]][0], cmap='gray')
    axes[2 * i + 1].imshow(np.clip(generated_img[i], 0, 1), cmap='gray')
plt.show()
