In [1]:
import torch 
import torchvision
import torch.nn as nn # All nn modules
import torch.optim as optim # Optimisation algo like sdg, adam
import torchvision.datasets as datasets #has standard datasets that we can call upon
import torchvision.transforms as transforms # for data augmantation
from torch.utils.data import DataLoader # gives easier data managment and options like minibatches
from torch.utils.tensorboard import SummaryWriter #for printing results on tenserboard
#from model_units import Discriminator,Generator

In [2]:
class Discriminator(nn.Module):
    def __init__(self,channels_img,features_d):
        super(Discriminator,self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels_img,features_d,kernel_size=4,stride=2,padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d,features_d*2,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(features_d*2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d*2,features_d*4,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(features_d*4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d*4,features_d*8,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(features_d*8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d*8,1,kernel_size=4,stride=2,padding=0),
            nn.Sigmoid())
    def forward(self,X):
        return self.net(X)

In [3]:
class Generator(nn.Module):
    def __init__(self,channels_noise,channels_img,features_g):
        super(Generator,self).__init__()
        self.net = nn.Sequential(
            # N* channels_noise *1 *1
            nn.ConvTranspose2d(channels_noise,features_g*16,kernel_size=4,stride=1,padding=0),
            nn.BatchNorm2d(features_g*16),
            nn.ReLU(),
            # N* features_g*16 *4 *4
            nn.ConvTranspose2d(features_g*16,features_g*8,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(features_g*8),
            nn.ReLU(),    
            # N * features*8 * 8 * 8
            nn.ConvTranspose2d(features_g*8,features_g*4,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(features_g*4),
            nn.ReLU(),
            # N * features_g*4 * 16 * 16
            nn.ConvTranspose2d(features_g*4,features_g*2,kernel_size=4,stride=2,padding=1),
            nn.BatchNorm2d(features_g*2),
            nn.ReLU(),
            # N * features_g*2 * 32 * 32
            nn.ConvTranspose2d(features_g*2,channels_img,kernel_size=4,stride=2,padding=1),
            nn.Tanh()
            # N * channels_img * 64 * 64
            )
    def forward(self,X):
        return self.net(X)

In [17]:
# Hyperparamtres
lr = 0.0002
batch_size = 64
img_size = 64 # in mnist we have 28*28 so we will need to resize them to 64*64
channels_img = 1
channels_noise = 256
num_epochs = 25
features_d = 16
features_g = 16

In [5]:
my_transforms = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,),(0.5,)),])

In [6]:
dataset = datasets.MNIST(root="dataset/",train=True,transform=my_transforms,download=True)

In [7]:
dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
net_D = Discriminator(channels_img,features_d).to(device)
net_G = Generator(channels_noise,channels_img,features_g).to(device)

In [10]:
#seting up optimizers for G and D
optimizer_D = optim.Adam(net_D.parameters(),lr=lr, betas=(0.5,0.999))
optimizer_G = optim.Adam(net_G.parameters(),lr=lr,betas=(0.5,0.999))

In [11]:
net_D.train()
net_G.train()

Generator(
  (net): Sequential(
    (0): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)

In [12]:
# selecting Loss function
criterion = nn.BCELoss()

In [13]:
real_label = 1
fake_label = 0
fixed_noise = torch.randn(64,channels_noise,1,1).to(device)

In [15]:
writer_real = SummaryWriter(f'runs/GAN_MNIST/test_real')
writer_fake = SummaryWriter(f'runs/GAN_MNIST/test_fake')    
step = 0

In [18]:
print("Starting training.....")
for epoch in range(num_epochs):
    for batch_idx,(data,targets) in enumerate(dataloader):
        data = data.to(device)
        batch_size = data.shape[0]
        #Train Discriminator: max log(D(x)) + log(1-D(G(z)))
        net_D.zero_grad()
        label = (torch.ones(batch_size)*0.9).to(device)
        output = net_D(data).reshape(-1)
        loss_D_real = criterion(output,label)
        D_x = output.mean().item()
        
        noise = torch.randn(batch_size,channels_noise,1,1).to(device)
        fake = net_G(noise)
        label = (torch.ones(batch_size)*0.1).to(device)
        output = net_D(fake.detach()).reshape(-1)
        loss_D_fake = criterion(output,label)
        
        loss_D = loss_D_real+loss_D_fake
        loss_D.backward()
        optimizer_D.step()
        # Train Generator: max log(D(G(z)))
        net_G.zero_grad()
        label = torch.ones(batch_size).to(device)
        output = net_D(fake).reshape(-1)
        loss_G = criterion(output,label)
        loss_G.backward()
        optimizer_G.step()
        
        if batch_idx % 100==0:
            print(f'Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)}\
            Loos D:{loss_D:.4f},Loss G:{loss_G:.4f}  D(x):{D_x:.4f}')            
            with torch.no_grad():
                fake = net_G(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(data[:32],normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32],normalize=True)
                writer_real.add_image("MNIST Real Images",img_grid_real,global_step=step)
                writer_fake.add_image("MNIST Fake Images",img_grid_fake,global_step=step)
                step += 1

Starting training.....
Epoch [0/25] Batch 0/938            Loos D:0.7985,Loss G:1.3535  D(x):0.7519
Epoch [0/25] Batch 100/938            Loos D:0.8802,Loss G:0.9905  D(x):0.6344
Epoch [0/25] Batch 200/938            Loos D:0.8001,Loss G:1.4022  D(x):0.7193
Epoch [0/25] Batch 300/938            Loos D:0.8257,Loss G:1.6954  D(x):0.7679
Epoch [0/25] Batch 400/938            Loos D:0.8192,Loss G:1.3863  D(x):0.6820
Epoch [0/25] Batch 500/938            Loos D:0.8753,Loss G:2.3062  D(x):0.8991
Epoch [0/25] Batch 600/938            Loos D:0.8166,Loss G:1.5446  D(x):0.6847
Epoch [0/25] Batch 700/938            Loos D:0.7613,Loss G:1.9751  D(x):0.8402
Epoch [0/25] Batch 800/938            Loos D:0.7820,Loss G:1.3144  D(x):0.7226
Epoch [0/25] Batch 900/938            Loos D:0.8312,Loss G:1.2998  D(x):0.6769
Epoch [1/25] Batch 0/938            Loos D:0.8953,Loss G:1.4509  D(x):0.7318
Epoch [1/25] Batch 100/938            Loos D:0.8301,Loss G:1.4681  D(x):0.7917
Epoch [1/25] Batch 200/938       

Epoch [10/25] Batch 400/938            Loos D:0.6798,Loss G:2.3358  D(x):0.8909
Epoch [10/25] Batch 500/938            Loos D:0.7419,Loss G:2.5387  D(x):0.9250
Epoch [10/25] Batch 600/938            Loos D:0.6994,Loss G:2.7829  D(x):0.9165
Epoch [10/25] Batch 700/938            Loos D:0.6775,Loss G:2.6414  D(x):0.8924
Epoch [10/25] Batch 800/938            Loos D:0.6805,Loss G:2.5121  D(x):0.9087
Epoch [10/25] Batch 900/938            Loos D:0.7127,Loss G:2.3761  D(x):0.8188
Epoch [11/25] Batch 0/938            Loos D:0.7169,Loss G:1.4651  D(x):0.7835
Epoch [11/25] Batch 100/938            Loos D:0.7467,Loss G:2.0088  D(x):0.8151
Epoch [11/25] Batch 200/938            Loos D:0.6959,Loss G:1.9683  D(x):0.8315
Epoch [11/25] Batch 300/938            Loos D:0.6775,Loss G:2.1752  D(x):0.8868
Epoch [11/25] Batch 400/938            Loos D:0.6992,Loss G:3.0149  D(x):0.9146
Epoch [11/25] Batch 500/938            Loos D:0.9625,Loss G:1.1759  D(x):0.5802
Epoch [11/25] Batch 600/938            Loo

Epoch [20/25] Batch 700/938            Loos D:0.6791,Loss G:1.9027  D(x):0.8408
Epoch [20/25] Batch 800/938            Loos D:0.6946,Loss G:1.9991  D(x):0.8161
Epoch [20/25] Batch 900/938            Loos D:0.6677,Loss G:2.5583  D(x):0.8616
Epoch [21/25] Batch 0/938            Loos D:0.6641,Loss G:2.6229  D(x):0.9124
Epoch [21/25] Batch 100/938            Loos D:0.6644,Loss G:2.3034  D(x):0.8706
Epoch [21/25] Batch 200/938            Loos D:0.6656,Loss G:2.2393  D(x):0.8665
Epoch [21/25] Batch 300/938            Loos D:0.6652,Loss G:2.5590  D(x):0.9025
Epoch [21/25] Batch 400/938            Loos D:0.6900,Loss G:2.9620  D(x):0.9403
Epoch [21/25] Batch 500/938            Loos D:0.6798,Loss G:2.7405  D(x):0.9001
Epoch [21/25] Batch 600/938            Loos D:0.6819,Loss G:2.5057  D(x):0.9212
Epoch [21/25] Batch 700/938            Loos D:0.7143,Loss G:1.6027  D(x):0.7753
Epoch [21/25] Batch 800/938            Loos D:0.8241,Loss G:1.7479  D(x):0.7083
Epoch [21/25] Batch 900/938            Loo