In [7]:
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [8]:
# DCGAN discriminator: a strided-convolution classifier that maps an image to a
# single real/fake probability. Hidden stages follow the
# conv -> batchnorm -> LeakyReLU(0.2) pattern (see `block`); the input stage has
# no batchnorm, and the output stage is a plain conv followed by a sigmoid.
# Spatial sizes in the comments below assume a 64x64 input (image_size = 64).
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super().__init__()
        # Channel widths of the four downsampling stages: d, 2d, 4d, 8d.
        widths = [features_d * (2 ** i) for i in range(4)]
        self.disc = nn.Sequential(
            # 64x64 -> 32x32 (no batchnorm on the input stage)
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 32x32 -> 16x16 -> 8x8 -> 4x4
            *[
                self.block(c_in, c_out, 4, 2, 1)
                for c_in, c_out in zip(widths, widths[1:])
            ],
            # 4x4 -> 1x1 single logit, squashed into a probability
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Downsampling stage: bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

        The conv has no bias because the following BatchNorm2d subtracts the
        per-channel mean, which would cancel a constant bias anyway.
        """
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Return D(x): shape (N, 1, 1, 1) for a 64x64 input, values in (0, 1)."""
        return self.disc(x)

In [9]:
# DCGAN generator: upsamples a latent noise tensor into an image with
# transposed convolutions. Hidden stages follow the
# convtranspose -> batchnorm -> ReLU pattern (see `_block`); the output stage is a
# plain transposed conv followed by tanh, so pixels land in (-1, 1), the same
# range the data transform (Normalize with mean 0.5, std 0.5) maps real images to.
class Generator(nn.Module):
    def __init__(self, z_dim, img_channels, features_g):
        super().__init__()
        # Channel widths shrink as the spatial size grows: 16g, 8g, 4g, 2g.
        widths = [features_g * mult for mult in (16, 8, 4, 2)]
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> 4x4
            self._block(z_dim, widths[0], 4, 1, 0),
            # 4x4 -> 8x8 -> 16x16 -> 32x32
            *[
                self._block(c_in, c_out, 4, 2, 1)
                for c_in, c_out in zip(widths, widths[1:])
            ],
            # 32x32 -> 64x64 image
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Upsampling stage: bias-free ConvTranspose2d -> BatchNorm2d -> ReLU."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Map noise shaped (N, z_dim, 1, 1) to images shaped (N, img_channels, 64, 64)."""
        return self.net(x)

In [10]:
# DCGAN weight initialisation: conv / transposed-conv weights ~ N(0, std) and
# BatchNorm scale ~ N(1, std) with zero shift, so every BatchNorm layer starts
# close to an identity affine transform.
def initialize_weights(model, std=0.02):
    """Re-initialise ``model``'s conv and batchnorm layers in place (DCGAN scheme).

    Args:
        model: any ``nn.Module``. Every submodule yielded by ``model.modules()``
            (including ``model`` itself) that is a Conv2d, ConvTranspose2d or
            BatchNorm2d is re-initialised.
        std: standard deviation of the normal draws (default 0.02, the DCGAN value).

    Returns:
        None. Parameters are overwritten in place; conv biases, where present,
        are left as PyTorch initialised them.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, std)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # Centre gamma on 1, not 0: gamma ~ N(0, 0.02) would multiply every
            # normalised activation by a value of order 0.02, collapsing the signal.
            # (affine=False BatchNorm has no weight/bias, hence the None check.)
            nn.init.normal_(m.weight, 1.0, std)
            nn.init.zeros_(m.bias)

In [None]:
# Training hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 2e-4            # Adam learning rate (paired with betas=(0.5, 0.999) below)
batch_size = 128
image_size = 64      # MNIST's 28x28 digits are resized to 64x64 to fit the nets
channels_img = 1     # MNIST is grayscale
noise_dim = 100      # length of the latent vector z
num_epochs = 5
features_disc = 64
features_gen = 64
seed = 42            # fixes weight init, noise draws and shuffling for repeatable runs

# Seed before anything random happens (model construction, init, DataLoader
# shuffling). cuDNN kernels may still be nondeterministic on GPU.
torch.manual_seed(seed)

# Data: scale pixels to [-1, 1] so real images share the generator's tanh range.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(channels_img)], [0.5 for _ in range(channels_img)]),
])

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Models, weight init, optimizers and loss
disc = Discriminator(channels_img, features_disc).to(device)
gen = Generator(noise_dim, channels_img, features_gen).to(device)
initialize_weights(disc)
initialize_weights(gen)

opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()  # Discriminator already ends in a sigmoid

# NOTE(review): not used in the cells shown; presumably a fixed batch of z for
# visualising generator progress across epochs.
fixed_noise = torch.randn(32, noise_dim, 1, 1).to(device)

gen.train()
disc.train()

In [None]:
# Training loop: one discriminator step, then one (non-saturating) generator
# step per batch.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # Size the noise by the *actual* batch: the last batch of an epoch is
        # smaller than batch_size (60000 % 128 != 0), and fakes should match it.
        cur_batch_size = real.size(0)
        noise = torch.randn(cur_batch_size, noise_dim, 1, 1, device=device)
        fake = gen(noise)

        # Train discriminator: maximise log(D(real)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the D update must not backprop into G. This also means the
        # graph through `gen` survives for the G step without retain_graph=True.
        disc_fake = disc(fake.detach()).reshape(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train generator: maximise log(D(G(z))), re-scoring the same fakes
        # with the freshly updated discriminator.
        output = disc(fake).reshape(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Print losses (implicit string concatenation, so no stray
        # indentation ends up inside the message)
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )