In [3]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.autograd import grad as torch_grad
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

In [None]:
def one_hot(lbl):
    """Encode a class index as the two conditioning tensors used by the GAN.

    Args:
        lbl: class index used to index the first axis of both tensors.

    Returns:
        A tuple ``(gen_lbl, critic_lbl)``: a length-10 one-hot vector and a
        (10, 28, 28) stack of planes in which only plane ``lbl`` is all ones.
    """
    vector = torch.zeros(10)
    planes = torch.zeros(10, 28, 28)
    # Indexing the first axis sets a single scalar in the vector and a whole
    # 28x28 plane in the stack.
    for encoding in (vector, planes):
        encoding[lbl] = 1
    return vector, planes


# Image transform applied to every MNIST sample.
transform = T.ToTensor()

# target_transform=one_hot makes each sample yield (img, (gen_lbl, critic_lbl)),
# which is the structure train_epoch unpacks from the loader. Without it the
# loader would yield plain integer labels and that unpacking would fail.
dataset_train = MNIST(root='mnist_train',train=True,transform=transform,
                      target_transform=one_hot,download=True)
dataset_test = MNIST(root='datasets/MNIST_test',train=False,transform=transform,
                     target_transform=one_hot,download=True)

# Both splits are merged into a single dataset for GAN training.
concatenated_dataset = ConcatDataset([dataset_train,dataset_test])

In [None]:
# @title Generator
class Generator(nn.Module):
    """Convolutional generator: (batch, 64) input vector -> (batch, 1, 28, 28) image.

    The input is projected to 2048 features, reshaped to a 2048-channel 1x1
    map, upsampled by four transposed convolutions and reduced to one channel
    by a final 3x3 convolution. No output activation is applied.
    """

    def __init__(self):
        super().__init__()

        # Dense projection of the 64-feature input.
        self.linear0 = nn.Sequential(
            nn.Linear(64, 2048, bias=False),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(),
        )

        # Spatial size 1x1 -> 4x4.
        self.trans_conv0 = nn.Sequential(
            nn.ConvTranspose2d(2048, 512, kernel_size=(4, 4), stride=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
        )

        # 4x4 -> 8x8.
        self.trans_conv1 = nn.Sequential(
            nn.ConvTranspose2d(512, 64, kernel_size=(4, 4), stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
        )

        # 8x8 -> 14x14.
        self.trans_conv2 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
        )

        # 14x14 -> 30x30 (this stage has no normalisation layer).
        self.trans_conv3 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=2, padding=0),
            nn.LeakyReLU(),
        )

        # Unpadded 3x3 convolution trims 30x30 down to the final 28x28 image.
        self.conv0 = nn.Sequential(
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=0, dilation=1),
        )

    def forward(self, z):
        """Generate images from ``z`` of shape (batch, 64)."""
        hidden = self.linear0(z)
        # Treat every feature as a channel of a 1x1 feature map.
        batch, features = hidden.shape[0], hidden.shape[1]
        hidden = hidden.view(batch, features, 1, 1)
        stages = (
            self.trans_conv0,
            self.trans_conv1,
            self.trans_conv2,
            self.trans_conv3,
            self.conv0,
        )
        for stage in stages:
            hidden = stage(hidden)
        return hidden

In [None]:
class Critic(nn.Module):
    """Convolutional critic scoring an 11-channel input with one scalar per sample.

    Three stride-2 convolutions (each followed by LayerNorm and ReLU) reduce
    the spatial size 28 -> 14 -> 7 -> 4; the result is average-pooled to a
    512-vector and mapped to a single unbounded score.
    """

    def __init__(self):
        super().__init__()

        # The LayerNorm shapes tie this network to 14x14, 7x7 and 4x4 maps.
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels=11, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.LayerNorm([32, 14, 14]),
            nn.ReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=512, kernel_size=5, stride=2, padding=2),
            nn.LayerNorm([512, 7, 7]),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=5, stride=2, padding=2),
            nn.LayerNorm([512, 4, 4]),
            nn.ReLU(),
        )

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.linear0 = nn.Linear(512, 1)

    def forward(self, img):
        """Return a (batch, 1) score tensor for ``img``."""
        features = img
        for stage in (self.conv0, self.conv1, self.conv2):
            features = stage(features)
        # Global average pool, then drop the trailing 1x1 spatial dims.
        pooled = self.avg_pool(features).flatten(1)
        return self.linear0(pooled)

In [None]:
def generate_img(args):
    """Run the generator on the fixed input ``args['data']['constant_z']``.

    The generator is switched to eval mode for the forward pass (gradients
    disabled) and put back into train mode before returning.

    Returns:
        The generator output moved to the CPU.
    """
    generator = args['gen']['model']
    generator.eval()
    with torch.no_grad():
        fixed_input = args['data']['constant_z']
        samples = generator(fixed_input).to('cpu')
    generator.train()
    return samples

def wloss(critic_output,critic_output_gen):
    """Critic loss: mean score on generated samples minus mean score on real samples."""
    fake_mean = critic_output_gen.mean()
    real_mean = critic_output.mean()
    return fake_mean - real_mean

def reg(args,critic,critic_input,critic_input_gen):
    """Gradient-penalty regulariser evaluated on random real/fake interpolates.

    Args:
        args: config dict; reads ``args['critic']['lamda']`` and, only when
            ``critic`` is None, ``args['critic']['model']``.
        critic: callable mapping a batch of inputs to one score per sample.
            Falls back to ``args['critic']['model']`` when None.
        critic_input: batch of real critic inputs, indexed (batch, C, H, W).
        critic_input_gen: batch of generated critic inputs, same shape.

    Returns:
        Scalar tensor ``lamda * mean((||grad||_2 - 1) ** 2)``, where the norm
        is taken per sample over all non-batch dimensions.
    """
    critic_model = critic if critic is not None else args['critic']['model']

    # Detach so the penalty does not backpropagate into whatever produced the inputs.
    real = critic_input.detach()
    fake = critic_input_gen.detach()

    # One mixing coefficient per sample, sized from the actual batch rather
    # than the configured batch size so short batches also work.
    alpha = torch.rand(real.shape[0],1,1,1,device=real.device,dtype=real.dtype)
    interpolated = (alpha*real + (1-alpha)*fake).requires_grad_(True)
    pred_interpolated = critic_model(interpolated)

    # create_graph=True keeps the gradient differentiable so the penalty can
    # itself be backpropagated through the critic parameters.
    gradient = torch_grad(outputs=pred_interpolated,
                          inputs=interpolated,
                          grad_outputs=torch.ones_like(pred_interpolated),
                          create_graph=True,
                          retain_graph=True)[0]

    # Per-sample L2 norm over channels AND spatial dims. Summing over the
    # spatial dims only would penalise each channel's norm separately.
    # The epsilon keeps sqrt differentiable at zero.
    gradient_norm = torch.sqrt(torch.sum(gradient.flatten(1)**2,dim=1)+1e-12)
    reg_value = torch.mean((gradient_norm-1)**2)
    return args['critic']['lamda']*reg_value

In [None]:
def train_critic(args,imgs,noise,gen_lbls,critic_lbls):
    """Run one critic optimisation step on a batch of real and generated images.

    Args:
        args: config dict; reads the critic and generator models, the critic
            optimiser (``args['critic']['optim']['algorithm']``) and whatever
            ``reg`` reads from it.
        imgs: batch of real images.
        noise: batch of noise vectors, concatenated with ``gen_lbls`` on dim 1
            to form the generator input.
        gen_lbls: label encoding appended to the noise.
        critic_lbls: label encoding concatenated with the images on dim 1 to
            form the critic input.
    """
    critic = args['critic']['model']
    generator = args['gen']['model']
    critic_optim = args['critic']['optim']['algorithm']

    critic.train()
    generator.eval()

    # The generator is only sampled here, so no graph is built for it.
    gen_input = torch.cat((noise,gen_lbls),dim=1)
    with torch.no_grad():
        gen_output = generator(gen_input)

    critic_input_gen = torch.cat((gen_output,critic_lbls),dim=1)
    critic_input = torch.cat((imgs,critic_lbls),dim=1)

    critic_output_gen = critic(critic_input_gen)
    critic_output = critic(critic_input)

    wloss_value = wloss(critic_output,critic_output_gen)
    # reg takes (args, critic, real, fake); omitting the critic argument
    # raised a TypeError on the very first step.
    reg_value = reg(args,critic,critic_input,critic_input_gen)
    loss_value = wloss_value + reg_value
    loss_value.backward()

    critic_optim.step()
    critic_optim.zero_grad()


def train_gen(args,noise,gen_lbls,critic_lbls):
    """Run one generator optimisation step against the current critic.

    The generator input is ``noise`` concatenated with ``gen_lbls`` on dim 1;
    its output is concatenated with ``critic_lbls`` on dim 1 and scored by the
    critic. The loss is the negated mean critic score. After the generator
    optimiser steps, gradients on both the generator optimiser and the critic
    are cleared.
    """
    critic = args['critic']['model']
    generator = args['gen']['model']
    gen_optim = args['gen']['optim']['algorithm']

    critic.eval()
    generator.train()

    fake_imgs = generator(torch.cat((noise,gen_lbls),dim=1))
    fake_scores = critic(torch.cat((fake_imgs,critic_lbls),dim=1))

    generator_loss = -1*fake_scores.mean()
    generator_loss.backward()

    gen_optim.step()
    gen_optim.zero_grad()
    # The backward pass also filled the critic's gradients; discard them so
    # they do not leak into the next critic step.
    critic.zero_grad()

In [None]:
def train_epoch(args):
    """Iterate once over ``args['data']['loader']``, training critic and generator.

    The critic is updated on every batch; the generator is updated on every
    ``args['critic']['ncritic']``-th batch, reusing that batch's noise and
    labels. Each loader item is expected to unpack as
    ``imgs, (gen_lbls, critic_lbls)``.
    """
    z_features = args['data']['z_features']
    ncritic = args['critic']['ncritic']
    device = args['device']

    # Count batches from 1 so the generator first trains on batch `ncritic`.
    for num_iter,(imgs,(gen_lbls,critic_lbls)) in enumerate(tqdm(args['data']['loader']),start=1):

        imgs = imgs.to(device)
        gen_lbls = gen_lbls.to(device)
        critic_lbls = critic_lbls.to(device)
        # Size the noise from the actual batch so a short final batch (loader
        # built without drop_last) still concatenates with its labels.
        noise = torch.randn(imgs.shape[0],z_features).to(device)

        train_critic(args,imgs,noise,gen_lbls,critic_lbls)

        if num_iter%ncritic == 0:
            train_gen(args,noise,gen_lbls,critic_lbls)

In [None]:
def train(args):
    """Train for the epoch range in ``args['epochs']``, checkpointing every 10th epoch.

    ``args['epochs']`` is unpacked into ``range``. After each epoch both
    learning-rate schedulers are stepped. Whenever the number of the epoch
    just completed (``epoch+1``) is a multiple of 10, the generator and critic
    state dicts are saved to their configured paths and a batch of images
    generated from the fixed input is saved under ``image_samples/``.
    """
    for epoch in range(*args['epochs']):

        print(f'Epoch: {epoch}')
        train_epoch(args)
        args['gen']['schedular'].step()
        args['critic']['schedular'].step()

        # Checkpoint AFTER training so files labelled `epoch+1` really contain
        # the weights produced by that many completed epochs.
        if (epoch+1) % 10 == 0:
            print('saving generator')
            torch.save(args['gen']['model'].state_dict(),args['gen']['dir'])
            print('saving critic')
            torch.save(args['critic']['model'].state_dict(),args['critic']['dir'])

            print('saving generated images')
            img = generate_img(args)
            # torch.save does not create missing directories.
            os.makedirs('image_samples',exist_ok=True)
            img_name = f'image_samples/generated_img_sample_epoch_{epoch+1}.pt'
            torch.save(img,img_name)

In [None]:
# Central configuration. Entries set to None are filled in further down in
# this cell once the models, optimisers, schedulers and loader exist.
args = {
    'critic':{
        'model': None,
        'ncritic':5,          # critic steps per generator step (see train_epoch)
        'optim':{
            'algorithm':None,
            'lr':0.0001,
            'betas':(0.5,0.99),
        },
        'schedular':None,
        'dir':'Critic.pth',
        'load':False,         # NOTE(review): not read anywhere in the visible cells
        'lamda':10            # gradient-penalty weight (see reg)
    },
    'gen':{
        'model': None,
        'optim':{
            'algorithm':None,
            'lr':0.0001,
            'betas':(0.5,0.99)
        },
        'schedular':None,
        'dir':'Conditional_Convolutional_Generator.pth',
        'load':False          # NOTE(review): not read anywhere in the visible cells
    },
    'data':{
        'loader': None,
        'batch_size':64,
        'img_shape':[1,28,28],
        'z_features':54,      # noise size; 54 + 10 label entries = the generator's 64 inputs
        'constant_z':None,
        },
    # Fall back to the CPU so the notebook still runs without a GPU.
    'device':'cuda' if torch.cuda.is_available() else 'cpu',
    'epochs':(0,3)
}

# Fixed generator input for sample images: one row per digit, random noise
# followed by that digit's one-hot label in the last 10 columns.
num_classes = 10
constant_z = torch.randn(num_classes,args['data']['z_features']+num_classes)
constant_z[:,-num_classes:] = torch.eye(num_classes)
args['data']['constant_z'] = constant_z.to(args['device'])

args['gen']['model'] = Generator().to(args['device'])
args['critic']['model'] = Critic().to(args['device'])

args['gen']['optim']['algorithm'] = optim.Adam(args['gen']['model'].parameters(),
                        lr=args['gen']['optim']['lr'],
                        betas=args['gen']['optim']['betas'])
args['critic']['optim']['algorithm'] = optim.Adam(args['critic']['model'].parameters(),
                        lr=args['critic']['optim']['lr'],
                        betas=args['critic']['optim']['betas'])

# Both learning rates are multiplied by gamma once per epoch (stepped in train).
args['gen']['schedular'] = optim.lr_scheduler.ExponentialLR(args['gen']['optim']['algorithm'],gamma=0.9747)
args['critic']['schedular'] = optim.lr_scheduler.ExponentialLR(args['critic']['optim']['algorithm'],gamma=0.9747)

# drop_last=True keeps every batch at exactly batch_size samples.
# Pinned host memory only helps when batches are copied to a GPU.
args['data']['loader'] = DataLoader(dataset=concatenated_dataset,
                        batch_size=args['data']['batch_size'],
                        shuffle=True,drop_last=True,
                        pin_memory=(args['device'] == 'cuda'),num_workers=2)
# NOTE(review): iter_loader is not used in the visible cells; kept in case a
# later cell relies on it.
iter_loader = iter(args['data']['loader'])

In [None]:
# Start training with the configuration assembled in the `args` cell above.
train(args)