In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np


In [None]:
class critic(nn.Module):
  """Label-conditioned critic made of strided convolutions.

  The class label is embedded into a single ``img_size x img_size`` plane that
  is stacked onto the image as an extra channel. Five stride-2 convolutions
  then shrink the input; for a 64x64 image the result is one score per sample
  with shape ``(N, 1, 1, 1)``.

  Args:
    channel_imgs: number of channels of the input images.
    num_classes: number of distinct labels the embedding can look up.
    img_size: side length of the (square) input images.
  """

  def __init__(self,channel_imgs,num_classes,img_size):
    super().__init__()
    self.img_size=img_size

    widths=[64,128,256,512]
    # First stage has no normalisation; the "+1" channel is the label plane.
    layers=[
      nn.Conv2d(channel_imgs+1,widths[0],kernel_size=4,stride=2,padding=1),
      nn.LeakyReLU(0.2),
    ]
    # Three identical conv -> instance norm -> leaky ReLU stages.
    for c_in,c_out in zip(widths[:-1],widths[1:]):
      layers.extend([
        nn.Conv2d(c_in,c_out,kernel_size=4,stride=2,padding=1),
        nn.InstanceNorm2d(c_out),
        nn.LeakyReLU(0.2),
      ])
    # Un-padded final conv collapses the remaining 4x4 map to a single value.
    layers.append(nn.Conv2d(widths[-1],1,kernel_size=4,stride=2,padding=0))

    self.critic=nn.Sequential(*layers)
    self.embed=nn.Embedding(num_classes,img_size*img_size)

  def forward(self,x,labels):
    """Score images ``x`` conditioned on their integer ``labels``."""
    n_samples=labels.shape[0]
    label_plane=self.embed(labels).view(n_samples,1,self.img_size,self.img_size)
    return self.critic(torch.cat([x,label_plane],dim=1))
  

In [None]:
class Generator(nn.Module):
  """Label-conditioned generator built from transposed convolutions.

  The label embedding is appended to the latent vector along the channel axis.
  Starting from a 1x1 input, the first layer (kernel 4, stride 1, padding 1)
  produces a 2x2 map and each of the five stride-2 layers doubles it, giving a
  64x64 output squashed to [-1, 1] by the final ``Tanh``.

  Args:
    z_dim: number of channels of the latent input.
    channels_img: number of channels of the generated images.
    num_classes: number of distinct labels the embedding can look up.
    img_size: stored on the instance; the layer stack itself does not read it.
    embed_size: length of the label embedding vector.
  """

  def __init__(self,z_dim,channels_img,num_classes,img_size,embed_size):
    super().__init__()
    self.img_size=img_size

    def up_block(c_in,c_out,stride):
      # Bias-free transposed conv followed by batch norm and ReLU.
      return [
        nn.ConvTranspose2d(c_in,c_out,kernel_size=4,stride=stride,padding=1,bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(),
      ]

    layers=up_block(z_dim+embed_size,1024,stride=1)
    for c_in,c_out in ((1024,512),(512,256),(256,128),(128,64)):
      layers+=up_block(c_in,c_out,stride=2)
    # Output layer keeps its bias and has no normalisation.
    layers+=[
      nn.ConvTranspose2d(64,channels_img,kernel_size=4,stride=2,padding=1),
      nn.Tanh(),
    ]

    self.gen=nn.Sequential(*layers)
    self.embed=nn.Embedding(num_classes,embed_size)

  def forward(self,x,labels):
    """Generate images from latent ``x`` conditioned on integer ``labels``."""
    # Add two trailing singleton axes so the embedding lines up with x.
    label_vec=self.embed(labels)[:,:,None,None]
    return self.gen(torch.cat([x,label_vec],dim=1))

In [None]:
def initialize_weights(model):
  """Re-initialise convolution and batch-norm weights of ``model`` in place.

  Convolution and transposed-convolution kernels are drawn from N(0, 0.02).
  Batch-norm scale factors are drawn from N(1, 0.02) so that a freshly
  initialised norm layer is close to the identity scaling instead of
  multiplying its activations by a value near zero.

  Args:
    model: any ``nn.Module``; all of its sub-modules are visited.
  """
  for m in model.modules():
    if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d)):
      nn.init.normal_(m.weight,0.0,0.02)
    # Norm layers built with affine=False have no weight to initialise.
    elif isinstance(m,nn.BatchNorm2d) and m.weight is not None:
      nn.init.normal_(m.weight,1.0,0.02)


In [None]:
# Training configuration.
batch_size=64
c=5  # critic updates performed per generator update
epochs=5
lambda_gp=10  # weight of the gradient-penalty term in the critic loss
image_size=64
z_dim=100
learning_rate=1e-4
channel_img=1
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device="cuda" if torch.cuda.is_available() else "cpu"
num_classes=10
gen_embedding=100

In [None]:
# Pre-processing applied to every image: resize to image_size, convert to a
# tensor, then normalise each channel with mean 0.5 and std 0.5.
# The classes are reached through `torchvision.transforms` because this cell
# rebinds the bare name `transforms` to the pipeline; reading the alias here
# would make the cell fail with AttributeError on a second run.
transforms=torchvision.transforms.Compose([
  torchvision.transforms.Resize(image_size),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize([0.5]*channel_img,[0.5]*channel_img)
])


In [None]:
# MNIST training split. The pre-processing pipeline must be passed as
# `transform`; without it the dataset yields PIL images, which the DataLoader
# cannot collate into tensor batches.
# `torchvision.datasets` is spelled out because this cell rebinds the bare name
# `datasets`, so going through the alias would break a second run of the cell.
datasets=torchvision.datasets.MNIST(
  root='./data',
  download=True,
  train=True,
  transform=transforms,
)

In [None]:
# Shuffled mini-batches of `batch_size` samples drawn from `datasets`.
loader=DataLoader(
  dataset=datasets,
  batch_size=batch_size,
  shuffle=True,
)

In [None]:
def gradient_penulty(critic,real,fake,labels,device="cpu"):
  """Return the gradient penalty of ``critic`` on random real/fake mixes.

  Each sample is a convex combination of its real and fake image with one
  random coefficient per sample. The penalty is the batch mean of
  ``(||d critic / d mix||_2 - 1) ** 2`` and stays differentiable so it can be
  added to the critic loss.

  Args:
    critic: callable ``critic(images, labels)`` returning one score per sample.
    real: 4-D batch of real images ``(N, C, H, W)``.
    fake: batch of generated images with the same shape as ``real``.
    labels: labels forwarded unchanged to ``critic``.
    device: device on which the mixing coefficients are created.

  Returns:
    Scalar tensor holding the penalty.
  """
  batch_size=real.shape[0]
  # One coefficient per sample, broadcast over channels and pixels.
  alpha=torch.rand((batch_size,1,1,1),device=device)
  interpolated_images=real*alpha+fake*(1-alpha)
  # autograd.grad needs the mix to be part of the graph. When neither input
  # requires grad (e.g. a detached fake batch) the mix is a plain leaf tensor
  # and must be marked explicitly, otherwise autograd.grad raises.
  if not interpolated_images.requires_grad:
    interpolated_images.requires_grad_(True)

  mixed_scores=critic(interpolated_images,labels)

  gradient=torch.autograd.grad(
    inputs=interpolated_images,
    outputs=mixed_scores,
    grad_outputs=torch.ones_like(mixed_scores),
    create_graph=True,
    retain_graph=True
  )[0]
  # Flatten per sample; reshape also copes with non-contiguous gradients.
  gradient=gradient.reshape(gradient.shape[0],-1)
  gradient_norm=gradient.norm(2,dim=1)
  return torch.mean((gradient_norm-1)**2)

In [None]:
# Build both networks on the target device, then apply the custom weight init.
# NOTE: the second assignment rebinds `critic` from the class to its instance,
# so this cell cannot be run again without re-running the class definition.
gen=Generator(
  z_dim=z_dim,
  channels_img=channel_img,
  num_classes=num_classes,
  img_size=image_size,
  embed_size=gen_embedding,
).to(device)
critic=critic(
  channel_imgs=channel_img,
  num_classes=num_classes,
  img_size=image_size,
).to(device)
initialize_weights(gen)
initialize_weights(critic)


In [None]:
# One Adam optimizer per network, both with the same learning rate and betas.
gen_optim=optim.Adam(
  params=gen.parameters(),
  lr=learning_rate,
  betas=(0.0,0.9),
)
critic_optim=optim.Adam(
  params=critic.parameters(),
  lr=learning_rate,
  betas=(0.0,0.9),
)


In [None]:
# Fixed batch of 32 latent vectors, reused for the preview images of each epoch.
fixed_noise=torch.randn((32,z_dim,1,1)).to(device)
# Switch both networks to training mode.
gen.train()
critic.train()

In [None]:
def show_images(real_images,fake_images,epoch):
  """Plot the first eight real and generated images side by side.

  Args:
    real_images: batch of real images, indexed along its first axis.
    fake_images: batch of generated images, indexed along its first axis.
    epoch: value shown in both panel titles.
  """
  fig,axs=plt.subplots(1,2,figsize=(12,6))
  panels=(("Real images",real_images),("Fake images",fake_images))
  for ax,(title,batch) in zip(axs,panels):
    grid=torchvision.utils.make_grid(batch[:8],normalize=True)
    # make_grid returns (C, H, W) while imshow expects (H, W, C).
    ax.imshow(grid.detach().permute(1,2,0).cpu().numpy())
    ax.set_title(f'{title} (Epoch {epoch})')
    ax.axis('off')
  plt.show()
  # Release the figure so repeated per-epoch calls do not accumulate memory.
  plt.close(fig)

In [None]:
from tqdm import tqdm

# The generator is label-conditioned, so the fixed preview noise needs labels
# too: one per preview sample, cycling through the classes.
fixed_labels=(torch.arange(fixed_noise.shape[0])%num_classes).to(device)

for epoch in range(epochs):
  with tqdm(total=len(loader),desc=f"Epoch {epoch+1}/{epochs}") as pbar:
    for batch_idx,(real,labels) in enumerate(loader):
      real=real.to(device)
      labels=labels.to(device)
      cur_batch_size=real.shape[0]

      # Critic phase: `c` critic updates for every generator update.
      for _ in range(c):
        # Gaussian noise, matching the distribution used for fixed_noise.
        noise=torch.randn((cur_batch_size,z_dim,1,1)).to(device)
        fake=gen(noise,labels)
        critic_real=critic(real,labels).reshape(-1)
        critic_fake=critic(fake,labels).reshape(-1)
        gp=gradient_penulty(critic,real,fake,labels,device)
        # The loss is built from the critic SCORES (not the images): maximise
        # score(real) - score(fake), i.e. minimise its negative, plus the
        # weighted gradient penalty.
        critic_loss=-(torch.mean(critic_real)-torch.mean(critic_fake))+lambda_gp*gp
        critic_optim.zero_grad()
        # Keep the graph: the last `fake` is scored again in the generator step.
        critic_loss.backward(retain_graph=True)
        critic_optim.step()

      # Generator phase: raise the critic's score on the generated batch.
      gen_fake=critic(fake,labels).reshape(-1)
      gen_loss=-torch.mean(gen_fake)
      gen_optim.zero_grad()
      gen_loss.backward()
      gen_optim.step()

      pbar.update(1)
      pbar.set_postfix(critic_loss=critic_loss.item(),Generator_loss=gen_loss.item())

  # Preview after each epoch, using the last real batch for comparison.
  with torch.no_grad():
    fake_images=gen(fixed_noise,fixed_labels)
  show_images(real,fake_images,epoch+1)