<a href="https://colab.research.google.com/github/Olivia-Feldman/NUGAN-DISTGAN/blob/Olivia/DIST_GAN2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import time
import matplotlib.pyplot as plt



In [38]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5), std=(0.5))])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=False)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [40]:

def visualize_results(gan,recon_images):

      samples = (recon_images + 1) / 2
      samples = samples.clamp(0,1)
      samples = samples.reshape(recon_images.size(0),1,28,28)
      samples = samples.cpu().data.numpy()
      plt.figure(figsize=((1,5)))
      fig,ax = plt.subplots(1,5)
      for i in range(5):
        s=ax[i].imshow(np.squeeze(samples[i,]))
        s=ax[i].get_xaxis().set_visible(False)
        s=ax[i].get_yaxis().set_visible(False)
      s=plt.show()

In [41]:
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

In [49]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder,self).__init__()
  
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), 
            nn.Linear(64, 12), 
            nn.ReLU(True),
            nn.Linear(12, 2))

    def forward(self, x):
       
        x = self.encoder(x)
        #x = x.view(x.size(0),-1)
     
        
        return x

In [50]:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim, input_size=28, base_size=128):
        super(Generator, self).__init__()  
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = 28
        self.base_size = base_size

        self.fc1 = nn.Linear(self.input_dim, self.base_size)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, self.input_size* self.input_size )
        initialize_weights(self)                    
    
    # forward method
    def forward(self, x): 
       # x = x.view(-1, self.input_size * self.input_size)
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = x.view(x.size(0),-1)
        return torch.tanh(self.fc4(x))



class Discriminator(nn.Module):
  def __init__(self, input_dim, output_dim, input_size=28, base_size=128):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.base_size = base_size



        self.fc1 = nn.Linear( self.input_size* self.input_size,self.base_size)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, output_dim)

        initialize_weights(self)

     # forward method
  def forward(self, x):
       # x = x.view(-1, self.input_size * self.input_size)
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
       # x = x.view(-1, self.input_size * self.input_size)
      
        return torch.sigmoid(self.fc4(x))

# DIST-GAN Network 

---

1. Autoencoder optimization $\min_{A}  \Big({\rm BCE}\left((x_{\rm real}),\underbrace{X_r{\rm recon}}_{1}\right) + {\rm BCE}\left(G(\underbrace{x_{\rm fake}}_{G[z]}),\underbrace{X_f{\rm fake}}_0\right)\Big)
$


2. Generator optimization $\min_{G}  ~{\rm BCE}\left(D(\underbrace{X_f{\rm fake}}_{G[z]}),\underbrace{X_r{\rm real}}_1\right)$

<font color=red>The generator tries to assign $D(x_{\rm fake}) \rightarrow 1$ </font>.




In [94]:

from torch.autograd import grad as torch_grad
class GAN():
    def __init__(self,params):
        # parameters
        self.epoch = params['max_epochs']
        self.sample_num = 100
        self.batch_size = params['base_size']
        self.input_size = 28
        self.z_dim = params['z_dim']
        self.base_size = params['base_size']

     
        self.lamda_p = 1.0     # regularization term of gradient penalty
        self.lamda_r = 1.0    # autoencoders regularization term  
       
        
        
        # load dataset
        self.data_loader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=self.batch_size, 
                                               shuffle=True)
        data = self.data_loader.__iter__().__next__()[0]

        #print(data.shape[0])

        # initialization of the generator and discriminator and autoencoder 
        self.A = autoencoder().cuda()
        self.G = Generator(input_dim=self.z_dim, output_dim=data.shape[0], input_size=self.input_size,base_size=self.base_size).cuda()
        self.D = Discriminator(input_dim=data.shape[0], output_dim=1, input_size=self.input_size,base_size=self.base_size).cuda()
      

        self.A_optimizer =  optim.Adam(self.A.parameters(), lr=params['lr_g'], betas=(params['beta1'], params['beta2']),eps=1e-09)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=params['lr_g'], betas=(params['beta1'], params['beta2']),eps=1e-09)
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=params['lr_g'], betas=(params['beta1'], params['beta2']),eps=1e-09)
        
        # initialization of the loss function Hinge Embedding loss ( from paper )
       
        self.BCE = nn.BCEWithLogitsLoss().cuda()
     
        
        # Gettng a batch of noise to generate the fake data
        self.sample_z_ = torch.rand((self.batch_size, self.z_dim)).cuda()
        
# Fucntion to train the GAN, where you alternate between the training of the genenator and discriminator
#--------------------------------------------------------------------------------------------------------

    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []

        # Setting up the labels for real and fake images
        self.y_real_, self.y_fake_ = torch.ones(self.batch_size,1).fill_(0.9).type(torch.float32).cuda(), torch.zeros(self.batch_size, 1).cuda()
        
        print('training start!!')

        for epoch in range(self.epoch):
            epoch_start_time = time.time()


            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break
                z_ = torch.rand((self.batch_size, self.z_dim))
                x_ = x_.view(x_.size(0), -1)
                z_ = z_.view(z_.size(0), -1)
                x_, z_ = x_.cuda() ,z_.cuda()
        
        

                #-------------Train Autoencoder & Generator to minimize reconstruction loss ------------# 
                # auto-encoders and its regularization

                torch.autograd.set_detect_anomaly(True)
                # reconstructed image
                recon = self.A(x_)
                X_r = self.G(recon.detach()) # reconstructed iamges from generator 
                print("\tX_r:", X_r.shape)
                X_f = self.G(z_.detach()) # fake images from generator 
                print("\tX_f:", X_f.shape)
                X_r, X_f = X_r, X_f

         

                ## auto-encoders and Regularization from Dist-Gan paper 

                R_loss = torch.mean(self.BCE(x_,X_r))
                #print("\tR_loss:", R_loss.shape)
                f = torch.mean(X_r - X_f) #distance between reconstructed imgs and reconstructed fake imgs 
                #print("\tf:", f.shape)
                g = torch.mean(recon - z_) * 0.15625  # distance between reconstruced imgs and noise
                #print("\tg:", g.shape)
                R_reg = torch.square(f - g)

               
                
                # train Autoencoder with R_loss 
                R_loss = R_loss + self.lamda_r * R_reg
                #print("\tR_loss", R_loss.shape)
                self.A_optimizer.zero_grad()
                R_loss.backward(retain_graph=True)
                self.A_optimizer.step()


                #train Generator with G_Loss 
                self.G_optimizer.zero_grad()
                #X_r = self.G(recon)
                #print("\tX_r", X_r.shape)
               # X_f = self.G(z_)
               # print("\tX_f", X_f.shape)
                G_loss = torch.abs(torch.mean(X_r - torch.mean(X_f)))
                #G_loss = X_r + X_f
               # print("\tG_loss", G_loss.shape)
                G_loss.backward()
                self.G_optimizer.step()
               #-----------Train Discriminator --------#######

                #Train Discriminator with Discriminator Score 
                D_real = self.D(x_.detach())
                D_recon = self.D(X_r.detach())
                D_fake = self.D(X_f.detach())


                # interpolation     
                epsilon = torch.rand(x_.size())
                epsilon = epsilon.cuda()
                interpolation = x_ * epsilon + (1-epsilon) * X_f
                d_inter = self.D(interpolation)
                # Gradien Penalty 
                gradients = torch_grad(d_inter, interpolation, grad_outputs=torch.ones(d_inter.size()).cuda())[0]
              
                slopes = torch.square(torch.mean(torch.square(gradients)))
                gp = torch.mean((slopes-1)**2)
            
 
            
                # Discriminator loss on data
                d_loss_real = torch.mean(self.BCE(self.y_real_, D_real))
                d_loss_recon = torch.mean(self.BCE(D_real, D_recon))
                d_loss_fake = torch.mean(self.BCE(self.y_fake_,D_fake))

                # loss w.r.t gradient penalty and reconstruction term 
                D_loss = (d_loss_real + d_loss_recon)*0.5 + d_loss_fake
                D_loss = D_loss + self.lamda_p * gp

          
              
                #train Discriminator 
                self.D_optimizer.zero_grad()
                D_loss.backward(retain_graph=True)
                self.D_optimizer.step()



                # 5. Set the current loss in self.train_hist['D_loss]
                self.train_hist['D_loss'].append(D_loss.item())
                
      
     
                

                #---------------Train Generator to minimize discriminator score --------#
                D_real = self.D(x_.detach())
                D_recon = self.D(X_r.detach())
                D_fake = self.D(X_f.detach())
             
                self.G_optimizer.zero_grad()
                G_loss = torch.abs(torch.mean(D_real) - torch.mean(D_fake))
                G_loss.backward(retain_graph=True)
                self.G_optimizer.step()
                # 5. Set the current loss in self.train_hist['G_loss]    
                self.train_hist['G_loss'].append(G_loss.item())

                # Print iterations and losses

                if ((iter + 1) % 50) == 0:
                  print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))
    
                  
            # Visualize results
            with torch.no_grad():
                visualize_results(self,recon_images=X_r)
        #plt.figure(figsize=(16,8))
       # s=plt.plot(gan.train_hist['D_loss'],c='b')
        #s=plt.plot(gan.train_hist['G_loss'],c='r')
        #s = plt.ylim((0,1))
        #s = plt.grid()
       # s=plt.legend(('Discriminator loss','Generator loss'))

        print("Training finished!")

In [95]:
params = {'beta1': 0.05, 'beta2': 0.999,'lr_g':0.0002,'lr_d':0.0002,'max_epochs':30}
params['z_dim'] = 2
params['base_size'] = 128

gan = GAN(params)

gan.train()

training start!!
1
	X_r: torch.Size([128, 784])
	X_f: torch.Size([128, 784])
2
	R_loss: torch.Size([])
	f: torch.Size([])
	g: torch.Size([])
3
	R_loss torch.Size([])
4
	X_r torch.Size([128, 784])
	X_f torch.Size([128, 784])
	G_loss torch.Size([])
5
6
7
8
9
10
11
1
	X_r: torch.Size([128, 784])
	X_f: torch.Size([128, 784])
2
	R_loss: torch.Size([])
	f: torch.Size([])
	g: torch.Size([])
3
	R_loss torch.Size([])
4
	X_r torch.Size([128, 784])
	X_f torch.Size([128, 784])
	G_loss torch.Size([])
5
6
7
8
9
10
11
1
	X_r: torch.Size([128, 784])
	X_f: torch.Size([128, 784])
2
	R_loss: torch.Size([])
	f: torch.Size([])
	g: torch.Size([])
3
	R_loss torch.Size([])
4
	X_r torch.Size([128, 784])
	X_f torch.Size([128, 784])
	G_loss torch.Size([])
5
6
7
8
9
10
11
1
	X_r: torch.Size([128, 784])
	X_f: torch.Size([128, 784])
2
	R_loss: torch.Size([])
	f: torch.Size([])
	g: torch.Size([])
3
	R_loss torch.Size([])
4
	X_r torch.Size([128, 784])
	X_f torch.Size([128, 784])
	G_loss torch.Size([])
5
6
7
8
9
10
11

KeyboardInterrupt: ignored