In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [None]:
def gradient_penalty(critic, labels, real, fake, device='cpu'):
    """WGAN-GP gradient penalty (Gulrajani et al., 2017, "Improved Training of WGANs").

    Samples points uniformly on straight lines between paired real and fake
    samples and penalizes the critic when the L2 norm of its gradient w.r.t.
    those points deviates from 1. This softly enforces the 1-Lipschitz
    constraint a Wasserstein critic needs, replacing weight clipping.

    Args:
        critic: callable ``critic(samples, labels)`` returning scores for the batch.
        labels: conditioning labels, passed through to ``critic`` unchanged.
        real: real samples, batch dimension first (e.g. (N, C, H, W) images).
        fake: generated samples, same shape as ``real``.
        device: device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``. It is built with
        ``create_graph=True``, so it can be back-propagated into the critic's
        parameters.
    """
    batch_size = real.shape[0]
    # One mixing weight per sample, broadcast over every non-batch dimension.
    epsilon = torch.rand(
        (batch_size,) + (1,) * (real.dim() - 1), device=device, dtype=real.dtype
    )
    # The penalty regularizes the critic only, so both endpoints are detached.
    # The interpolation is then a fresh leaf, and it must require grad for
    # autograd.grad below. This no longer depends on `fake` carrying a graph.
    interpolated_imgs = (
        real.detach() * epsilon + fake.detach() * (1 - epsilon)
    ).requires_grad_(True)

    mixed_scores = critic(interpolated_imgs, labels)

    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated_imgs,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # the penalty itself must be differentiable w.r.t. critic params
        retain_graph=True,
    )[0]

    # Per-sample L2 norm over all non-batch dimensions.
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
# WGAN-GP critic. It uses InstanceNorm instead of BatchNorm: the gradient penalty is
# defined per sample, and batch statistics would couple the samples in a batch.
class Discriminator(nn.Module):
    """Conditional critic that scores an (image, label) pair with an unbounded real number.

    The label is embedded into an ``image_size x image_size`` map and concatenated
    to the image as one extra input channel.

    Args:
        channels_img: channels of the input image.
        features_d: base width; the hidden convs use 1x, 2x, 4x and 8x this value.
        num_classes: size of the label vocabulary for the embedding table.
        image_size: side length of the square input image. The layer stack below
            (four stride-2 convs, then a 4x4 conv) reduces 64x64 to 1x1, so 64 is
            the size this layout is built for.
    """

    def __init__(self, channels_img, features_d, num_classes, image_size):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # input: (channels_img + 1) x 64 x 64 -> 32x32
            nn.Conv2d(channels_img + 1, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self.block(features_d, features_d * 2, 4, 2, 1),      # -> 16x16
            self.block(features_d * 2, features_d * 4, 4, 2, 1),  # -> 8x8
            self.block(features_d * 4, features_d * 8, 4, 2, 1),  # -> 4x4
            # -> 1x1. No Sigmoid here: the Wasserstein loss and the gradient penalty
            # operate on raw critic scores, and a sigmoid would bound and saturate them.
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )
        # Label embedding; forward() reshapes it into one extra image-sized input channel.
        self.embed = nn.Embedding(num_classes, image_size * image_size)
        self.image_size = image_size

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv (no bias; the affine InstanceNorm supplies the shift) -> InstanceNorm -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x, labels):
        """Score images ``x`` conditioned on integer ``labels`` of shape (N,).

        ``x`` must be (N, channels_img, image_size, image_size) so the label map can be
        concatenated along channels. For 64x64 inputs the result is (N, 1, 1, 1).
        """
        embedding = self.embed(labels).view(labels.shape[0], 1, self.image_size, self.image_size)
        x = torch.cat([x, embedding], dim=1)
        return self.disc(x)

In [None]:
# Conditional DCGAN-style generator: transposed convolutions upsample a 1x1
# (noise + label embedding) column into a full image.
class Generator(nn.Module):
    """Map a noise tensor plus a class label to an image in [-1, 1].

    Args:
        z_dim: channels of the input noise tensor, shaped (N, z_dim, 1, 1).
        img_channels: channels of the generated image.
        features_g: base width; the hidden blocks use 16x, 8x, 4x and 2x this value.
        num_classes: size of the label vocabulary for the embedding table.
        image_size: stored as ``self.img_size``. The layers themselves always
            upsample 1x1 -> 64x64.
        embed_size: length of the label embedding appended to the noise channels.
    """

    def __init__(self, z_dim, img_channels, features_g, num_classes, image_size, embed_size):
        super().__init__()
        widths = [z_dim + embed_size] + [features_g * k for k in (16, 8, 4, 2)]
        # First block lifts 1x1 -> 4x4 (stride 1, no padding).
        layers = [self.block(widths[0], widths[1], 4, 1, 0)]
        # Each following block doubles H and W: 4 -> 8 -> 16 -> 32.
        for c_in, c_out in zip(widths[1:-1], widths[2:]):
            layers.append(self.block(c_in, c_out, 4, 2, 1))
        # Final 32x32 -> 64x64 projection to image channels, squashed to [-1, 1].
        layers.append(nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)
        self.img_size = image_size
        self.embed = nn.Embedding(num_classes, embed_size)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Transposed conv (bias-free, BatchNorm follows) -> BatchNorm -> ReLU."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, labels):
        """Generate images from noise ``x`` (N, z_dim, 1, 1) and integer ``labels`` (N,)."""
        label_map = self.embed(labels).unsqueeze(-1).unsqueeze(-1)  # (N, embed_size, 1, 1)
        return self.net(torch.cat([x, label_map], dim=1))

In [None]:
def initialize_weights(model):
    """Apply DCGAN-style initialization to ``model`` in place.

    Follows the PyTorch DCGAN tutorial. Conv and transposed-conv weights are drawn
    from N(0, 0.02). BatchNorm scales are drawn from N(1, 0.02) and shifts are set
    to 0, so each BN layer starts close to the identity instead of scaling its
    output to ~0. Every other module (InstanceNorm, Embedding, conv biases) keeps
    PyTorch's default initialization.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:  # skip affine=False BN
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [None]:
# Training hyperparameters. lr, the Adam betas (set below), disc_iterations and Lambda
# follow the defaults of Algorithm 1 in the WGAN-GP paper (Gulrajani et al., 2017).

device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 0
lr = 1e-4
batch_size = 64
image_size = 64       # MNIST digits are resized to 64x64 to fit the conv stacks
channels_img = 1
num_classes = 10
gen_embed = 100       # length of the generator's label embedding
noise_dim = 100
num_epochs = 5
features_disc = 64
features_gen = 64
disc_iterations = 5   # critic updates per generator update
Lambda = 10           # gradient-penalty weight

# Seed torch's RNGs (CPU and CUDA) before weights are initialized and the DataLoader
# shuffles, so weight init, noise and batch order are reproducible on Restart & Run All.
# Some CUDA kernels may still be nondeterministic.
torch.manual_seed(seed)

# Load MNIST, then build the conditional generator / critic, their weights and optimizers.

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # Map pixels from [0, 1] to [-1, 1], matching the generator's Tanh output range.
    transforms.Normalize(
        [0.5 for _ in range(channels_img)],[0.5 for _ in range(channels_img)])
])

dataset = datasets.MNIST(root="dataset/",transform=transform,download=True)
loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
disc = Discriminator(channels_img,features_disc,num_classes,image_size).to(device)
gen = Generator(noise_dim,channels_img,features_gen,num_classes,image_size,gen_embed).to(device)
initialize_weights(disc)
initialize_weights(gen)

opt_disc = optim.Adam(disc.parameters(),lr=lr,betas=(0.0,0.9))
opt_gen = optim.Adam(gen.parameters(),lr=lr,betas=(0.0,0.9))

# Fixed latent batch, presumably for sampling progress images. It is not used in the
# cells shown here, and Generator.forward also needs a matching batch of labels.
fixed_noise = torch.randn(32, noise_dim, 1, 1, device=device)

gen.train()
disc.train()

In [None]:
# Training loop: disc_iterations critic updates per generator update (WGAN-GP).
for epoch in range(num_epochs):
    for batch_idx, (real, labels) in enumerate(loader):
        real = real.to(device)
        labels = labels.to(device)
        cur_batch_size = real.shape[0]

        # Train critic: minimize -(E[D(real)] - E[D(fake)]) + Lambda * GP.
        for _ in range(disc_iterations):
            noise = torch.randn(cur_batch_size, noise_dim, 1, 1, device=device)
            fake = gen(noise, labels)
            # Detached copy, so the critic loss (including the second-order penalty term)
            # does not back-propagate through the generator. requires_grad_() keeps the
            # interpolated images differentiable inside gradient_penalty.
            fake_detached = fake.detach().requires_grad_()
            disc_real = disc(real, labels).reshape(-1)
            disc_fake = disc(fake_detached, labels).reshape(-1)
            gp = gradient_penalty(disc, labels, real, fake_detached, device=device)
            loss_disc = -(torch.mean(disc_real) - torch.mean(disc_fake)) + Lambda * gp
            opt_disc.zero_grad()
            loss_disc.backward()
            opt_disc.step()

        # Train generator on the last critic iteration's fakes, whose graph is still
        # attached to the generator: minimize -E[D(G(z, y), y)].
        output = disc(fake, labels).reshape(-1)
        lossG = -torch.mean(output)
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Print losses
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_disc.item():.4f}, loss G: {lossG.item():.4f}"
            )