In [1]:
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.activation import ReLU

In [2]:
def gradient_penalty(critic,labels,real,fake,device="cpu"):
  BATCH_SIZE,C,H,W = real.shape
  epsilon = torch.rand((BATCH_SIZE,1,1,1)).repeat(1,C,H,W).to(device)
  interpolated_images = real*epsilon+fake*(1-epsilon)
  mixed_scores = critic(interpolated_images,labels)

  gradient = torch.autograd.grad(
      inputs = interpolated_images,
      outputs = mixed_scores,
      grad_outputs = torch.ones_like(mixed_scores),
      create_graph = True,
      retain_graph = True,
  )[0]

  gradient = gradient.view(gradient.shape[0],-1)
  gradient_norm = gradient.norm(2,dim=1)
  gradient_penalty = torch.mean((gradient_norm-1)**2)
  return gradient_penalty

In [3]:
class Critic(nn.Module):
  def __init__ (self, channels_img, features_d,num_classes,img_size):
    super(Critic,self).__init__()
    self.img_size = img_size
    self.critic = nn.Sequential(
        #64x64
        nn.Conv2d(channels_img+1, features_d, kernel_size=4,stride=2, padding = 1),#32x32 
        nn.LeakyReLU(0.2),
        self._block(features_d,features_d * 2 , 4 , 2 , 1),#16x16
        self._block(features_d * 2,features_d * 4 , 4 , 2 , 1),#8x8
        self._block(features_d * 4,features_d * 8 , 4 , 2 , 1),#4x4
        nn.Conv2d(features_d*8,1, kernel_size=4,stride=2, padding = 0),#1x1
    )
    self.embed = nn.Embedding(num_classes, img_size**2)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias = False,
        ),
        nn.InstanceNorm2d(out_channels,affine =True),
        nn.LeakyReLU(0.2),
    )
  def forward(self,x,labels):
    embedding = self.embed(labels).view(labels.shape[0],1,self.img_size,self.img_size) # bugs witch channels_img??????
    x = torch.cat([x, embedding],dim = 1)
    return self.critic(x)

In [4]:
class Generator(nn.Module):
  def __init__ (self,z_dim,channels_img,features_g,num_classes,img_size,embed_size):
    super(Generator,self).__init__()
    self.img_size = img_size
    self.net = nn.Sequential(
        # N x z_dim x 1 x 1
        self._block(z_dim + embed_size,features_g * 16, 4 , 1 , 0),# N x f_g*16 x 4 x 4 
        self._block(features_g * 16,features_g * 8, 4 , 2 , 1),# 8x8
        self._block(features_g * 8,features_g * 4, 4 , 2 , 1),# 16x16
        self._block(features_g * 4,features_g * 2, 4 , 2 , 1),# 32x32
        nn.ConvTranspose2d(features_g*2,channels_img,kernel_size=4,stride = 2, padding = 1,),
        nn.Tanh(),
    )
    self.embed = nn.Embedding(num_classes,embed_size)
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias = False,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )
  def forward(self,x,labels): #how this works
    embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
    x = torch.cat([x,embedding],dim =1)
    return self.net(x)

In [5]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m,(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data,0.0,0.02)

In [6]:
def test():
  N, in_channels, H, W = 8, 3, 64, 64
  z_dim = 100

  x = torch.randn((N, in_channels, H, W))
  critic = Critic(in_channels,8)
  assert critic(x).shape == (N, 1, 1, 1), "Critic test failed"

  gen = Generator(z_dim, in_channels, 8)
  z = torch.randn((N, z_dim, 1, 1))
  assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
  print("success")

In [7]:
test()

success


In [8]:
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [9]:
# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
NUM_CLASSES = 10
GEN_EMBEDDING = 100
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATION = 5
LAMBDA_GP = 10

In [10]:
transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [11]:
dataset = datasets.MNIST(root = "dataset/",train = True, transform = transforms, download = True)
loader = DataLoader(dataset,batch_size = BATCH_SIZE,shuffle = True)

In [12]:
gen = Generator(Z_DIM, CHANNELS_IMG,FEATURES_GEN,NUM_CLASSES,IMAGE_SIZE,GEN_EMBEDDING).to(device)
critic = Critic(CHANNELS_IMG,FEATURES_DISC,NUM_CLASSES,IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(critic)

In [13]:
opt_gen = optim.Adam(gen.parameters(),lr = LEARNING_RATE,betas = (0.0,0.9))
opt_critic = optim.Adam(critic.parameters(),lr = LEARNING_RATE,betas = (0.0,0.9))

In [14]:
writen_real = SummaryWriter(f"logs/real")
writen_fake = SummaryWriter(f"logs/fake")
step = 0

In [15]:
gen.train()
critic.train()

Critic(
  (critic): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

In [None]:
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, labels) in enumerate(loader):
    real = real.to(device)
    cur_batch_size = real.shape[0]
    labels = labels.to(device)
    for _ in range(CRITIC_ITERATION):
      noise = torch.randn((cur_batch_size,Z_DIM,1,1)).to(device)
      fake = gen(noise,labels)
      critic_real = critic(real,labels).reshape(-1)
      critic_fake = critic(fake,labels).reshape(-1)
      gp = gradient_penalty(critic,labels,real,fake,device = device)
      loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))+LAMBDA_GP*gp
      critic.zero_grad()
      loss_critic.backward(retain_graph = True)
      opt_critic.step()

      output = critic(fake,labels).reshape(-1)
      loss_gen = -torch.mean(output)
      gen.zero_grad()
      loss_gen.backward()
      opt_gen.step()

    if batch_idx % 100 == 0:
        print(
            f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
              Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
        )

        with torch.no_grad():
            fake = gen(noise,labels)
            # take out (up to) 32 examples
            img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
            img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

            writen_real.add_image("Real", img_grid_real, global_step=step)
            writen_fake.add_image("Fake", img_grid_fake, global_step=step)

        step += 1

Epoch [0/5] Batch 0/938               Loss D: 65.8064, loss G: 4.8808
Epoch [0/5] Batch 100/938               Loss D: -38.8260, loss G: 46.1698
Epoch [0/5] Batch 200/938               Loss D: -5.9252, loss G: 43.5130
Epoch [0/5] Batch 300/938               Loss D: -6.8480, loss G: 36.4607
Epoch [0/5] Batch 400/938               Loss D: -6.3273, loss G: 31.1886
Epoch [0/5] Batch 500/938               Loss D: -6.5994, loss G: 34.8394
Epoch [0/5] Batch 600/938               Loss D: -7.3441, loss G: 29.7633
Epoch [0/5] Batch 700/938               Loss D: -5.7495, loss G: 32.2783
Epoch [0/5] Batch 800/938               Loss D: -4.7336, loss G: 29.8325
Epoch [0/5] Batch 900/938               Loss D: -5.6575, loss G: 29.7220
Epoch [1/5] Batch 0/938               Loss D: -3.8127, loss G: 27.6373
Epoch [1/5] Batch 100/938               Loss D: -4.3145, loss G: 27.4836
Epoch [1/5] Batch 200/938               Loss D: -3.1664, loss G: 27.0205
Epoch [1/5] Batch 300/938               Loss D: -5.9979

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs/ --port=6010