In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import seaborn as sns
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import optim
# from model import Discriminator, Generator, initialize_weights
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
class Discriminator(nn.Module):
  """Convolutional critic mapping an image batch to per-sample probabilities.

  Args:
    channels_img: number of channels in the input images.
    features_d: base channel width; the four conv stages use 1x, 2x, 4x
      and 8x this width before a final single-channel conv + Sigmoid.

  For a 64x64 input every conv has stride 2, so the spatial size goes
  64 -> 32 -> 16 -> 8 -> 4 -> 1 and the output shape is (N, 1, 1, 1).
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    self.channels_img = channels_img
    self.features_d = features_d
    width = features_d

    # Input stage: plain conv + LeakyReLU, no BatchNorm.
    layers = [
        nn.Conv2d(channels_img, width, 4, 2, 1),
        nn.LeakyReLU(0.2),
    ]
    # Three downsampling stages, each doubling the channel count.
    for mult in (1, 2, 4):
      layers.append(self._block(width * mult, width * mult * 2, 4, 2, 1))
    # Collapse to one probability per sample.
    layers += [
        nn.Conv2d(width * 8, 1, kernel_size=4, stride=2, padding=0),
        nn.Sigmoid(),
    ]
    self.disc = nn.Sequential(*layers)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Conv2d (no bias, BatchNorm supplies the shift) -> BatchNorm2d -> LeakyReLU(0.2)."""
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
    return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

  def forward(self, x):
    """Return sigmoid scores for the image batch ``x``."""
    return self.disc(x)


class Generator(nn.Module):
  """Transposed-conv generator mapping latent noise to images in [-1, 1].

  Args:
    z_dim: number of channels of the latent input (the notebook samples
      noise shaped (N, z_dim, 1, 1)).
    channels_img: number of channels in the generated images.
    feature_g: base channel width; the upsampling stages use 16x, 8x, 4x
      and 2x this width.

  From a 1x1 latent the spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64.
  """

  def __init__(self, z_dim, channels_img, feature_g):
    super(Generator, self).__init__()
    self.gen = nn.Sequential(
        # Input: N x z_dim x 1 x 1 -> N x feature_g*16 x 4 x 4
        self._block(z_dim, feature_g * 16, 4, 1, 0),
        # -> 8 x 8
        self._block(feature_g * 16, feature_g * 8, 4, 2, 1),
        # -> 16 x 16
        self._block(feature_g * 8, feature_g * 4, 4, 2, 1),
        # -> 32 x 32
        self._block(feature_g * 4, feature_g * 2, 4, 2, 1),
        nn.ConvTranspose2d(
            feature_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        ),
        # Output: N x channels_img x 64 x 64
        nn.Tanh(),
    )

  def _block(self, in_channels, out_channel, kernel_size, stride, padding):
    """ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU upsampling stage.

    BatchNorm mirrors the discriminator's blocks. The ConvTranspose bias is
    dropped because BatchNorm's shift makes it redundant.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        ),
        nn.BatchNorm2d(out_channel),
        # nn.ReLU's only argument is `inplace`; the former ReLU(0.2) silently meant inplace=True.
        nn.ReLU(),
    )

  def forward(self, x):
    """Generate images from latent tensor ``x``."""
    return self.gen(x)

def inititalize_weights(model):
  """Apply DCGAN-style weight initialization in place to every submodule.

  - Conv2d / ConvTranspose2d weights ~ N(0, 0.02); conv biases, where
    present, keep PyTorch's default init.
  - BatchNorm2d weight (scale) ~ N(1, 0.02) and bias = 0. Centering the
    scale at 1 rather than 0 keeps normalized activations from starting out
    near zero.

  Args:
    model: an ``nn.Module`` whose submodules are initialized in place.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
      nn.init.normal_(m.weight.data, 1.0, 0.02)
      nn.init.constant_(m.bias.data, 0.0)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
initialize_weights = inititalize_weights

def test():
    """Smoke-check that both networks produce the expected output shapes on 64x64 data."""
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    images = torch.randn((batch, channels, height, width))
    critic = Discriminator(channels, 8)
    assert critic(images).shape == (batch, 1, 1, 1), "Discriminator test failed"

    generator = Generator(latent_dim, channels, 8)
    latent = torch.randn((batch, latent_dim, 1, 1))
    assert generator(latent).shape == (batch, channels, height, width), "Generator test failed"


test()

In [None]:
# ---- Hyperparameters ----
learning_rate = 2e-4
# Adam momentum terms used by DCGAN: with the default beta1=0.9 the DCGAN
# authors report oscillating / unstable training; 0.5 stabilizes it.
betas = (0.5, 0.999)
batch_size = 128
z_dim = 100
channels_img = 1
image_size = 64
features_disc = 64
features_gen = 64
num_epochs = 3
seed = 0

# Seed before building models / sampling noise so runs are reproducible.
torch.manual_seed(seed)

# ---- Data: MNIST resized to 64x64 and scaled to [-1, 1] to match the Tanh output ----
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(channels_img)], [0.5 for _ in range(channels_img)]
        ),
    ]
)

dataset = datasets.MNIST(root='content/data', download=True, train=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ---- Models, optimizers, loss ----
gen = Generator(z_dim, channels_img, features_gen).to(device)
disc = Discriminator(channels_img, features_disc).to(device)
inititalize_weights(gen)
inititalize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=learning_rate, betas=betas)
opt_disc = optim.Adam(disc.parameters(), lr=learning_rate, betas=betas)
criterion = nn.BCELoss()

# Fixed latent batch so TensorBoard snapshots are comparable across steps.
fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)
writer_real = SummaryWriter(log_dir='logs/real')
writer_fake = SummaryWriter(log_dir='logs/fake')
step = 0

gen.train()
disc.train();  # trailing ';' suppresses the module repr in notebook output



Discriminator(
  (disc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (6): Sigmoid()
  )
)

In [None]:
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(dataloader):
    real = real.to(device)
    # Size the noise from the actual batch: the last MNIST batch is smaller than batch_size.
    cur_batch_size = real.shape[0]
    noise = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
    fake = gen(noise)

    ## Train discriminator: max log D(x) + log(1 - D(G(z)))
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() keeps the discriminator update from backpropagating into the generator
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2

    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    ## Train generator: max log D(G(z)), i.e. label fakes as real
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    # Print losses occasionally and log image grids to TensorBoard
    if batch_idx % 100 == 0:
      # Implicit string concatenation; a backslash continuation inside the
      # literal would embed the source indentation into the message.
      print(
          f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
          f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
      )

      with torch.no_grad():
        fake_fixed = gen(fixed_noise)
        # take out (up to) 32 examples
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)

        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

      step += 1

# Ensure the final image grids reach disk before the notebook moves on.
writer_real.flush()
writer_fake.flush()

Epoch [0/3] Batch 0/469               Loss D: 0.6037, loss G: 0.9041
Epoch [0/3] Batch 100/469               Loss D: 0.0125, loss G: 4.2731
Epoch [0/3] Batch 200/469               Loss D: 0.0035, loss G: 5.5840
Epoch [0/3] Batch 300/469               Loss D: 0.0014, loss G: 6.4656
Epoch [0/3] Batch 400/469               Loss D: 0.0008, loss G: 6.9651
Epoch [1/3] Batch 0/469               Loss D: 0.1912, loss G: 4.2732
Epoch [1/3] Batch 100/469               Loss D: 0.0600, loss G: 4.9816
Epoch [1/3] Batch 200/469               Loss D: 0.0020, loss G: 7.6787
Epoch [1/3] Batch 300/469               Loss D: 0.0009, loss G: 7.8094
Epoch [1/3] Batch 400/469               Loss D: 0.1591, loss G: 4.1054
Epoch [2/3] Batch 0/469               Loss D: 0.0062, loss G: 6.9955
Epoch [2/3] Batch 100/469               Loss D: 0.0021, loss G: 7.6106
Epoch [2/3] Batch 200/469               Loss D: 0.0041, loss G: 6.8876
Epoch [2/3] Batch 300/469               Loss D: 0.0010, loss G: 7.9957
Epoch [2/3] 