Run on Google Colab:     <a href="https://colab.research.google.com/github/elyas1376/GAN/blob/main/GAN.ipynb" target="_blank" rel="noreferrer noopener"><img src="https://camo.githubusercontent.com/84f0493939e0c4de4e6dbe113251b4bfb5353e57134ffd9fcab6b8714514d4d1/68747470733a2f2f636f6c61622e72657365617263682e676f6f676c652e636f6d2f6173736574732f636f6c61622d62616467652e737667" alt="Open In Colab" data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" style="max-width: 100%;"></a>

In [60]:
#@title Imports
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os
from torchvision import transforms as trfms
from torchvision.datasets import FashionMNIST
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import sys
import os
import time

In [43]:
#@title Creating Transformer to transform train img data
# Preprocessing applied to every training image: convert to a float tensor,
# then compute (x - 0.5) / 0.5 so pixel values land in [-1, 1], the same range
# as the generator's Tanh output.
transformers = trfms.Compose(
    [
        trfms.ToTensor(),
        trfms.Normalize(mean=0.5, std=0.5),
    ]
)

In [41]:
#@title Downloading dataset and creating dataloader
# FashionMNIST training split stored under the current directory
# (download=True lets torchvision fetch it when it is not already there).
dataset = FashionMNIST(root="./", train=True, transform=transformers, download=True)

batch_size = 128

# Shuffled mini-batches, loaded by two worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [5]:
#@title random tensor to IMG
# Sanity check for save_image: 104 random 3x28x28 tensors written as one grid.
t = torch.randn(104, 3, 28, 28)
save_image(t, "test.png")

In [39]:
#@title Discriminator
class Discriminator(nn.Module):
  """Convolutional discriminator for single-channel 28x28 images.

  forward() maps a batch of shape (N, 1, 28, 28) to scores of shape (N, 1).
  The scores are raw, unbounded logits intended for nn.BCEWithLogitsLoss,
  which applies the sigmoid itself; therefore the last Linear layer is
  deliberately not followed by any activation.
  """

  def __init__(self):
    super().__init__()

    # Two unpadded 3x3 convolutions with stride 1: 28x28 -> 26x26 -> 24x24.
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1),
        nn.LeakyReLU(0.2),
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1),
        nn.LeakyReLU(0.2),
    )

    self.fc = nn.Sequential(
        nn.Linear(in_features=32 * 24 * 24, out_features=256),
        nn.LeakyReLU(0.2),
        nn.Linear(in_features=256, out_features=32),
        nn.LeakyReLU(0.2),
        # Raw logit output. A trailing LeakyReLU here would shrink every
        # negative logit by 5x before BCEWithLogitsLoss sees it.
        nn.Linear(in_features=32, out_features=1),
    )

  def forward(self, x):
    """Return one logit per image: (N, 1, 28, 28) -> (N, 1)."""
    x = self.conv(x)
    # flatten(1) keeps the batch dimension intact; view(-1, ...) would
    # silently re-split the batch if the spatial size were ever different.
    x = torch.flatten(x, start_dim=1)
    return self.fc(x)





# Smoke test: score 20 random single-channel 28x28 inputs with a fresh model.
d = Discriminator()
t = torch.randn(20, 1, 28, 28)
# Calling the module runs forward(); with no hooks registered this is the
# same as d.forward(t).
img = d(t)  # shape (20, 1): one score per input
save_image(img, "dis_test.png")

In [40]:
#@title Generator
class Generator(nn.Module):
  """Fully-connected + transposed-convolution generator for 1x28x28 images.

  Args:
    latent_dim: length of each input noise vector.

  forward() maps noise of shape (N, latent_dim) to images of shape
  (N, 1, 28, 28) with values in [-1, 1] (Tanh output).
  """

  def __init__(self, latent_dim):
    super().__init__()
    self.latent_dim = latent_dim

    # Three Linear -> LeakyReLU -> BatchNorm1d stages widening the features
    # latent_dim -> 1280 -> 2560 -> 5760, where 5760 = 10 * 24 * 24.
    stages = []
    width_in = self.latent_dim
    for width_out in (1280, 2560, 5760):
      stages.extend([
          nn.Linear(in_features=width_in, out_features=width_out),
          nn.LeakyReLU(0.2),
          # NOTE(review): in PyTorch, momentum is the weight given to the
          # current batch statistic, so 0.7 makes the running stats follow
          # recent batches closely. Confirm this value is intended.
          nn.BatchNorm1d(num_features=width_out, momentum=0.7),
      ])
      width_in = width_out
    self.fully_connected = nn.Sequential(*stages)

    # Unpadded 3x3 transposed convolutions with stride 1 grow the feature map
    # 24x24 -> 26x26 -> 28x28 while reducing channels 10 -> 3 -> 1.
    self.conv_transpose = nn.Sequential(
        nn.ConvTranspose2d(in_channels=10, out_channels=3, kernel_size=3, stride=1),
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=3, stride=1),
        nn.Tanh(),
    )

  def forward(self, x):
    """Map noise (N, latent_dim) to images (N, 1, 28, 28)."""
    features = self.fully_connected(x)
    feature_maps = features.view(-1, 10, 24, 24)
    return self.conv_transpose(feature_maps)


# Smoke test: run a freshly initialised generator on 5 noise vectors and save
# the images. The Tanh output lies in [-1, 1] while save_image clamps to
# [0, 1], so map with (x + 1) / 2 first; otherwise every negative pixel
# would be clipped to black.
test = torch.randn(5, 64)
g = Generator(64)
with torch.no_grad():
  img = g(test)
save_image((img + 1) / 2, "gen_test.png")

In [47]:
#@title Parameters

# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyper-parameters.
latent_dim = 64
batch_size = 128

# Models on the selected device. D is constructed before G so the random
# weight initialisation consumes the RNG in a fixed order.
D = Discriminator().to(device)
G = Generator(latent_dim).to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so D's output should be a
# raw logit.
criterion = nn.BCEWithLogitsLoss()

# Adam with lr = 2e-4 and beta1 = 0.5 for both networks.
d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [None]:
#@title Train Loop
# `dt.now()` times each epoch. This name was previously used without being
# imported, which raised a NameError.
from datetime import datetime as dt

d_losses = []  # discriminator loss, one entry per batch
g_losses = []  # generator loss (from the last G step), one entry per batch

N_EPOCHS = 300
G_STEPS = 2  # generator updates per discriminator update

os.makedirs('Generated Images', exist_ok=True)

for epoch in range(N_EPOCHS):
  start = dt.now()
  for batch_i, (inputs, _) in enumerate(dataloader):
    n = inputs.size(0)
    inputs = inputs.to(device)

    # Size the labels from the actual batch so they always match the model
    # outputs, whatever `batch_size` was set to in the parameters cell.
    ones = torch.ones(n, 1, device=device)    # real labels
    zeros = torch.zeros(n, 1, device=device)  # fake labels

    ##################################
    ####### Train Discriminator ######
    ##################################

    # Real images should be scored as real.
    real_outputs = D(inputs)
    d_loss_real = criterion(real_outputs, ones)

    # Generated images should be scored as fake. detach() keeps this
    # backward pass from computing gradients for G, which are not used here.
    noise = torch.randn(n, latent_dim, device=device)
    fake_images = G(noise).detach()
    fake_outputs = D(fake_images)
    d_loss_fake = criterion(fake_outputs, zeros)

    # Gradient descent step on the averaged loss.
    d_loss = 0.5 * (d_loss_fake + d_loss_real)
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    ##################################
    ####### Train Generator ##########
    ##################################

    for _ in range(G_STEPS):
      noise = torch.randn(n, latent_dim, device=device)
      fake_images = G(noise)
      fake_outputs = D(fake_images)

      # Try to fool the discriminator: fake outputs paired with "real" labels.
      g_loss = criterion(fake_outputs, ones)

      # Only G is stepped. Gradients that land in D here are discarded
      # by zero_grad before D's next backward pass.
      d_optimizer.zero_grad()
      g_optimizer.zero_grad()
      g_loss.backward()
      g_optimizer.step()

    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

  epoch_time = dt.now() - start                   # wall-clock time of this epoch
  remaining = epoch_time * (N_EPOCHS - epoch - 1)  # naive estimate for the rest
  sys.stdout.write(
      "\r[Epoch: %d/%d] ,[Batch %d/%d] , [d_loss: %2.5f], [g_loss:%2.5f] [epoch time: %s] [ETA: %s]"
      % (
          epoch,
          N_EPOCHS,
          batch_i + 1,
          len(dataloader),
          d_loss.item(),
          g_loss.item(),
          epoch_time,
          remaining,
      )
  )
  sys.stdout.flush()

  # Save the last generated batch of the epoch, mapped from Tanh's [-1, 1]
  # to the [0, 1] range expected by save_image.
  fake_images = (fake_images.detach().reshape(-1, 1, 28, 28) + 1) / 2
  save_image(fake_images, 'Generated Images/{:03d}.png'.format(epoch))

[Epoch: 13/300] ,[Batch 469/469] , [d_loss: 0.32488], [g_loss:1.36108] [ETA : 0:00:23.441186/1:57:12.355800]

In [None]:
#@title Plotting g and d losses
# Losses are appended once per batch in the training loop, so the x axis is
# the batch (iteration) index, not the epoch.
plt.figure(figsize=(16, 8))
plt.plot(d_losses, label='d_losses')
plt.plot(g_losses, label='g_losses')
plt.xlabel('iteration (batch)')
plt.ylabel('BCE loss')
plt.title('GAN training losses')
plt.legend()
plt.show()