In [None]:
import numpy as np
import torch.nn as nn
import torch
import time
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
from collections import OrderedDict
import os
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import save_image
%matplotlib inline


############ Transform Functions ############
# Tensor -> PIL image; used to preview generated samples.
to_img = T.Compose([T.ToPILImage()])
# PIL image / ndarray -> float tensor in [0, 1].
to_tensor = T.Compose([T.ToTensor()])
# MNIST images have a single channel, so Normalize gets one mean/std value.
# (Three values cannot be broadcast onto a 1-channel tensor in current
# torchvision.)  With mean=std=0.5 the pixels are mapped from [0, 1] to
# [-1, 1], matching the Tanh output range of the generator.
load_norm = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])



############ Hyper-parameters ############
class Parser():
    """Plain container for every hyper-parameter and path used in this file.

    All values are fixed in ``__init__``; a single instance is created below
    as the module-level ``args`` object that the rest of the code reads.
    """

    def __init__(self):
        # Latent layout: noise vector + categorical code + continuous codes.
        self.z_dim = 62
        self.cat_dim = 10
        self.cont_dim = 2

        # Side length (pixels) of the square training images.
        self.input_size = 28

        # Optimisation settings.
        self.n_epoch = 30
        self.batch_size = 64
        self.lrD = 0.0002
        self.lrG = 0.0002
        self.b1 = 0.9
        self.b2 = 0.999

        # Output locations (both end with '/', callers append file names).
        self.model_path = './InfoGAN/Model/'
        self.img_path = './InfoGAN/Image/'

        # Frequencies, all measured in training steps.
        self.img_save_freq = 500
        self.model_save_freq = 5000
        # Previously assigned twice (50, then 100); 100 was the effective value.
        self.show_freq = 100


args = Parser()



############ Create Directories for saving imgs and models ############

# exist_ok=True makes this a single idempotent call per directory and avoids
# the check-then-create race of testing os.path.exists() first.
for _out_dir in (args.model_path, args.img_path):
    os.makedirs(_out_dir, exist_ok=True)
    
    
############ Check device ############    
# Run on the GPU whenever CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


############ Some useful layers ############
    
class To_sq_Image(nn.Module):
    """Reshape a flat feature vector into a square, multi-channel feature map.

    Args:
        side: Height and width of the produced map.  When left as ``None``
            the value ``args.input_size // 4`` is looked up at call time,
            which is what the layer always used before.

    Input:  tensor of shape ``(batch, channels * side * side)``.
    Output: tensor of shape ``(batch, channels, side, side)``.
    """

    def __init__(self, side=None):
        super(To_sq_Image, self).__init__()
        self.side = side

    def forward(self, x):
        side = self.side if self.side is not None else args.input_size // 4
        # The batch dimension is taken from the tensor itself instead of
        # args.batch_size, so batches of any size are reshaped correctly.
        return x.view(x.size(0), -1, side, side)
            
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis."""

    def forward(self, x):
        batch = x.shape[0]
        flat = x.view(batch, -1)
        return flat
    
    
############ Generator ############    

class Generator(nn.Module):
    """Generator: maps (noise, continuous code, categorical code) to an image.

    The three inputs are concatenated into one latent vector, pushed through
    two fully connected layers, reshaped into a ``128 x side x side`` map
    (``side = args.input_size // 4``) and upsampled twice by stride-2
    transposed convolutions.  The final Tanh keeps the single-channel output
    in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Width of the concatenated latent vector fed to the first layer.
        latent_dim = args.z_dim + args.cat_dim + args.cont_dim
        # Each ConvTranspose2d(k=4, s=2, p=1) doubles the spatial size, so the
        # map entering them must be a quarter of the target image size.
        side = args.input_size // 4
        self.model = nn.Sequential(OrderedDict([
            ('Linear_1', nn.Linear(latent_dim, 1024)),
            ('Bn_1', nn.BatchNorm1d(1024)),
            ('L_relu_1', nn.LeakyReLU(0.2, inplace=True)),
            ('Linear_2', nn.Linear(1024, 128 * side * side)),
            ('bn_2', nn.BatchNorm1d(128 * side * side)),
            ('L_relu_2', nn.LeakyReLU(0.2, inplace=True)),
            ('To_Image', To_sq_Image()),
            ('ConvT_1', nn.ConvTranspose2d(128, 64, 4, 2, 1)),
            ('Bn_4', nn.BatchNorm2d(64)),
            ('L_relu_4', nn.LeakyReLU(0.2, inplace=True)),
            ('convT_2', nn.ConvTranspose2d(64, 1, 4, 2, 1)),
            ('Tanh', nn.Tanh()),
        ]))

    def forward(self, z, y_cont, y_cat):
        """Generate images.

        Args:
            z: Noise, shape ``(batch, args.z_dim)``.
            y_cont: Continuous codes, shape ``(batch, args.cont_dim)``.
            y_cat: Integer (long) category indices, shape ``(batch,)``.

        Returns:
            The output of the layer stack for the concatenated latent vector.
        """
        # Size and place the one-hot buffer after the actual inputs rather
        # than args.batch_size / the global device.
        y_cat_onehot = torch.zeros(y_cat.size(0), args.cat_dim,
                                   dtype=z.dtype, device=z.device)
        y_cat_onehot.scatter_(1, y_cat.unsqueeze(1), 1)

        latent = torch.cat([z, y_cont, y_cat_onehot], 1)

        return self.model(latent)
    
############ Discriminator ############
        
class Discriminator(nn.Module):
    """Discriminator with a shared trunk and a real/fake scoring head.

    ``forward`` returns both the 1024-wide trunk features and the head's
    single-value output; no sigmoid is applied inside this module.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Shared trunk: two stride-2 convolutions, then one linear layer.
        # 6272 == 128 * 7 * 7, the flattened conv output for a 28x28 input.
        trunk = OrderedDict()
        trunk['Conv_1'] = nn.Conv2d(in_channels=1, out_channels=64,
                                    kernel_size=4, stride=2, padding=1)
        trunk['L_relu_1'] = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        trunk['Conv_2'] = nn.Conv2d(in_channels=64, out_channels=128,
                                    kernel_size=4, stride=2, padding=1)
        trunk['Bn_1'] = nn.BatchNorm2d(128)
        trunk['L_relu_2'] = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        trunk['Flatten_1'] = Flatten()
        trunk['Linear_1'] = nn.Linear(in_features=6272, out_features=1024)
        trunk['Bn_2'] = nn.BatchNorm1d(1024)
        trunk['L_relu_3'] = nn.LeakyReLU(negative_slope=0.2)
        self.model = nn.Sequential(trunk)

        # Head that scores an image as real or fake (one value per sample).
        head = OrderedDict()
        head['Linear_1'] = nn.Linear(in_features=1024, out_features=512)
        head['Bn_1'] = nn.BatchNorm1d(512)
        head['L_rela_1'] = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        head['Linear_2'] = nn.Linear(in_features=512, out_features=1)
        self.classify = nn.Sequential(head)

    def forward(self, input):
        """Return ``(trunk_features, head_output)`` for a batch of images."""
        features = self.model(input)
        score = self.classify(features)
        return features, score
        

class InfoGAN(nn.Module):
    """InfoGAN wrapper: generator, discriminator, Q head, losses, optimisers.

    Calling the module on a batch of real images runs one full training
    step (discriminator update, then generator + Q update).  After a call,
    ``D_loss``, ``G_loss`` and ``info_loss`` hold the latest loss tensors
    and ``G_img`` the latest generated batch; the ``*_loss_hist`` lists
    accumulate the losses as plain Python floats.
    """

    def __init__(self,):
        super(InfoGAN, self).__init__()

        ############ Create Generator and Discriminator ############
        self.G = Generator().to(device)
        self.D = Discriminator().to(device)

        ############ Prepare Labels for Loss functions ############
        self.y_real_ = torch.ones(args.batch_size).to(device)
        self.y_fake_ = torch.zeros(args.batch_size).to(device)

        ############ Define Q Network ############
        # Reads the discriminator's 1024-wide trunk features and predicts the
        # continuous codes (first cont_dim outputs) followed by the category
        # logits (remaining cat_dim outputs).
        self.Q = nn.Sequential(OrderedDict([
            ('Linear_1', nn.Linear(1024, 512)),
            ('Bn_1', nn.BatchNorm1d(512)),
            ('L_relu_1', nn.LeakyReLU(0.2, inplace=True)),
            ('Linear_2', nn.Linear(512, args.cont_dim + args.cat_dim)),
        ])).to(device)

        ############ Define Three Loss Functions ############
        ## BCE (on logits) for classifying real vs. fake images
        self.BCE_loss = nn.BCEWithLogitsLoss().to(device)
        ## CE for the categorical code
        self.CE_loss = nn.CrossEntropyLoss().to(device)
        ## MSE for the continuous codes
        self.MSE_loss = nn.MSELoss().to(device)

        ############ Define Optimizers ############
        g_params = list(self.G.parameters()) + list(self.Q.parameters())
        self.G_optim = optim.Adam(g_params, lr=args.lrG, betas=(args.b1, args.b2))

        d_params = list(self.D.parameters())
        self.D_optim = optim.Adam(d_params, lr=args.lrD, betas=(args.b1, args.b2))

        ############ Create Noise Distribution ############
        # Uniform distribution over the categories (Categorical normalises
        # the given weights).
        self.cat_dis = torch.distributions.Categorical(torch.ones(args.cat_dim))

        ############ Create Loss hist ############
        self.G_loss_hist = []
        self.D_loss_hist = []
        self.info_loss_hist = []

    def forward(self, img):
        """Run one training step on a batch of real images.

        Args:
            img: Real images; the batch must contain ``args.batch_size``
                samples because the label tensors have that fixed length.

        Returns:
            None.  Results are exposed through the attributes described in
            the class docstring.
        """
        ###########  Train D ###########
        y = self.cat_dis.sample((args.batch_size,)).to(device)
        img = img.to(device)

        self.D_optim.zero_grad()

        # D returns raw logits; BCEWithLogitsLoss applies the sigmoid itself.
        _, D_real = self.D(img)
        D_real_loss = self.BCE_loss(D_real.view(-1), self.y_real_)

        self.get_noise()
        self.G_img = self.G(self.z, self.y_cont, y)
        # Detach so the discriminator update does not back-propagate into G.
        _, D_fake = self.D(self.G_img.detach())
        D_fake_loss = self.BCE_loss(D_fake.view(-1), self.y_fake_)

        self.D_loss = D_real_loss + D_fake_loss
        # Store a float: keeping the tensor would keep its graph alive and
        # could not be plotted directly when it lives on the GPU.
        self.D_loss_hist.append(self.D_loss.item())

        self.D_loss.backward()
        self.D_optim.step()

        ###########  Train G ###########
        self.G_optim.zero_grad()
        # G is unchanged since the fakes above were produced, so the same
        # batch (with its graph still attached) is scored by the updated D.
        shared_tensor, D_fake = self.D(self.G_img)
        self.G_loss = self.BCE_loss(D_fake.view(-1), self.y_real_)
        self.G_loss_hist.append(self.G_loss.item())

        ########### Info Loss ###########
        c = self.Q(shared_tensor)
        disc_loss = self.CE_loss(c[:, args.cont_dim:], y)
        cont_loss = self.MSE_loss(c[:, :args.cont_dim], self.y_cont)
        self.info_loss = disc_loss + cont_loss
        self.info_loss_hist.append(self.info_loss.item())

        self.G_info_loss = self.G_loss + self.info_loss
        self.G_info_loss.backward()
        self.G_optim.step()

    def get_noise(self,):
        """Sample fresh noise ``self.z`` (normal) and codes ``self.y_cont`` (uniform in [-1, 1])."""
        self.z = torch.randn(args.batch_size, args.z_dim).to(device)
        self.y_cont = torch.empty(args.batch_size, args.cont_dim).uniform_(-1, 1).to(device)

    @staticmethod
    def _save_grid(images, path, nrow):
        """Save ``images`` as a PNG grid, rescaling from [-1, 1] for display."""
        try:
            save_image(images, path, nrow=nrow, normalize=True, value_range=(-1, 1))
        except TypeError:
            # Older torchvision releases only accept the former keyword name.
            save_image(images, path, nrow=nrow, normalize=True, range=(-1, 1))

    @staticmethod
    def _checkpoint_path(step):
        """Return the checkpoint file path used for ``step``."""
        return args.model_path + 'InfoGAN_Step_' + str(step) + '.pth'

    def image_save(self, step):
        """Write three preview images for ``step`` into ``args.img_path``.

        * ``Con_Grids_S_<step>.png``: fixed noise and category, the first two
          continuous codes swept over an 8x8 grid of values in [-1, 1].
        * ``Cat_Grids_S_<step>.png``: one noise vector per row, the category
          varying across the ``args.cat_dim`` columns.
        * ``InfoGAN_Step_<step>.png``: the first 25 images of the most recent
          training batch (skipped if no training step has run yet).

        NOTE: the 8x8 sweep fills 64 rows, so ``args.batch_size`` must be at
        least 64.  Preview batches are built with ``args.batch_size`` rows.
        """
        with torch.no_grad():
            ##################### Saving continuous change imgs #####################
            sweep = torch.linspace(-1, 1, 8)
            c_con_1 = torch.zeros((args.batch_size, args.cont_dim))
            # Row i*8+j gets (sweep[i], sweep[j]).
            c_con_1[:64, 0] = sweep.repeat_interleave(8)
            c_con_1[:64, 1] = sweep.repeat(8)
            c_con_1 = c_con_1.to(device)

            # Noise is drawn from the same normal prior used during training.
            z_1 = torch.randn((1, args.z_dim)).expand(args.batch_size, args.z_dim).to(device)
            c_cat_1 = torch.ones(args.batch_size, dtype=torch.long).to(device)

            change_cont = self.G(z_1, c_con_1, c_cat_1)
            self._save_grid(change_cont,
                            args.img_path + 'Con_Grids_S_' + str(step) + '.png',
                            nrow=8)

            ##################### Saving categorical change imgs #####################
            # One noise vector per grid row, repeated across the cat_dim
            # columns, so each row shows the same sample in every category.
            n_rows = -(-args.batch_size // args.cat_dim)  # ceiling division
            z_2 = torch.randn(n_rows, args.z_dim).repeat_interleave(args.cat_dim, dim=0)
            z_2 = z_2[:args.batch_size].to(device)
            con_c = torch.zeros(args.batch_size, args.cont_dim).to(device)
            cat_c = torch.arange(args.cat_dim).repeat(n_rows)[:args.batch_size].to(device)

            change_cat = self.G(z_2, con_c, cat_c)
            # Keep only complete rows of the grid.
            n_full = (args.batch_size // args.cat_dim) * args.cat_dim
            self._save_grid(change_cat[:n_full],
                            args.img_path + 'Cat_Grids_S_' + str(step) + '.png',
                            nrow=args.cat_dim)

            ##################### Saving Training imgs #####################
            latest = getattr(self, 'G_img', None)
            if latest is not None:
                training_img_path = args.img_path + "InfoGAN_Step_" + str(step) + ".png"
                self._save_grid(latest[:25].detach(), training_img_path, nrow=5)
        print('Image saved')

    def model_save(self, step):
        """Save this module's ``state_dict`` under the key ``'InfoGAN'``."""
        torch.save({'InfoGAN': self.state_dict()}, self._checkpoint_path(step))
        print('Model saved')

    def load_step_dict(self, step):
        """Load the weights saved by ``model_save`` for ``step``."""
        # Tensors are first materialised on the CPU; load_state_dict then
        # copies them into the existing parameters.
        checkpoint = torch.load(self._checkpoint_path(step), map_location='cpu')
        self.load_state_dict(checkpoint['InfoGAN'])

    def plot_all_loss(self,):
        """Plot the three loss histories and save them to ``InfoGAN_Loss.png``."""
        fig, ax = plt.subplots(figsize=(20, 20))
        ax.plot(self.G_loss_hist, label='G_loss')
        ax.plot(self.D_loss_hist, label='D_loss')
        ax.plot(self.info_loss_hist, label='Info_loss')
        ax.set_ylabel('Loss', fontsize=15)
        ax.set_xlabel('Number of Steps', fontsize=15)
        ax.set_title('Loss', fontsize=30, fontweight="bold")
        ax.legend(loc='upper left')
        fig.savefig("InfoGAN_Loss.png")

    def num_all_params(self,):
        """Return the total number of parameter elements in this module."""
        return sum(param.nelement() for param in self.parameters())



In [None]:
# drop_last=True keeps every batch at exactly args.batch_size samples, which
# the fixed-size label tensors inside InfoGAN rely on.
dataloader = DataLoader(
    datasets.MNIST('./data/mnist', train=True, download=True, transform=load_norm),
    batch_size=args.batch_size, shuffle=True, drop_last=True)


infoGAN = InfoGAN().to(device)

all_steps = 1

############## Start Training ##############

for epoch in range(args.n_epoch):

    # The dataset labels are not used for training.
    for i, (img, _) in enumerate(dataloader):

        start_t = time.time()
        infoGAN(img)
        end_t = time.time()

        print('| Epoch [%d] | batch [%d] | D Loss: [%.4f] | G Loss: [%.4f] | Info Loss: [%.4f] | Time: %.1fs' %\
              (epoch, i, infoGAN.D_loss.item(), infoGAN.G_loss.item(), infoGAN.info_loss.item(),
               end_t - start_t))

        all_steps += 1

        # The three periodic actions are checked independently so that each
        # frequency works on its own instead of requiring the larger ones to
        # be multiples of the smaller ones.
        if all_steps % args.show_freq == 0:
            # Generator output is in [-1, 1]; map it to [0, 1] before the PIL
            # conversion so negative values do not wrap around.
            preview = (infoGAN.G_img[0].detach().cpu() + 1) / 2
            plt.figure()
            plt.imshow(to_img(preview.clamp(0, 1)), cmap='gray')
            plt.show()
        if all_steps % args.img_save_freq == 0:
            infoGAN.image_save(all_steps)
        if all_steps % args.model_save_freq == 0:
            infoGAN.model_save(all_steps)
