In [1]:
import torch 
import torch.nn as nn 
from torch.utils.data import DataLoader 
from torchvision.datasets import MNIST 
from torchvision import transforms 
import torch.nn.functional as F 
from torch.utils.tensorboard import SummaryWriter 
import torchvision


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
device

'cuda'

In [4]:
class Discriminator(nn.Module):
    def __init__(self, im_chan = 1, hidden_dim = 64, labels = 0):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.disc_block(im_chan+labels, hidden_dim, 4, 2, 1), # 32x32 
            self.disc_block(hidden_dim, hidden_dim*2, 4,2,1), # 16x 16 
            self.disc_block(hidden_dim*2, hidden_dim*4, 4,2,1), # 8x8 
            self.disc_block(hidden_dim*4, hidden_dim*8, 4,2,1), # 4x4 
            self.disc_block(hidden_dim*8, 1, 4,2,0, True), # 1x1 
        )
        
    def disc_block(self, input_channels, output_channels, kernel_size= 4, stride = 2,padding = 0 , final_layer = False):
        if not final_layer: 
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride,padding = padding, bias = False),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace = True)
            )
        else: 
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride)
            )
        
        
    def forward(self, x):
        disc_pred = self.disc(x)
        return disc_pred
        

In [5]:
class Generator(nn.Module):
    def __init__(self, z_dim = 100, hidden_dim = 32, img_channel=1, labels = 0):
        super(Generator, self).__init__()
        input_dim = z_dim + labels 
        output_dim = img_channel
        self.gen = nn.Sequential(
            self.gen_block(input_dim , hidden_dim*8, 4, 2, 0),  # 4x4 
            self.gen_block(hidden_dim*8,hidden_dim*4,  4, 2, 1), # 8x8 
            self.gen_block(hidden_dim*4,hidden_dim*2,  4, 2, 1), # 16x 16 
            self.gen_block(hidden_dim*2,hidden_dim,  4, 2, 1), # 32 x 32 
            self.gen_block(hidden_dim, output_dim, 4, 2, 1, True)
            
            
            
        
        )
        
    def gen_block(self,input_dim,output_dim, kernel_size = 4, stride = 1, padding = 0, final_layer = False):
        if not final_layer: 
            return nn.Sequential(
                nn.ConvTranspose2d(input_dim,output_dim,kernel_size, stride, padding, bias = False),
                nn.BatchNorm2d(output_dim),
                nn.LeakyReLU(0.2, inplace = True)
            )
        else: 
            return nn.Sequential(
                nn.ConvTranspose2d(input_dim,output_dim, kernel_size, stride, padding),
                nn.Tanh()
            )
            
    def forward(self, x):
        out = self.gen(x)
        return out 
            
            
        

In [6]:
z_dim = 100 
N_classes = 10
nlabels = 10 
img_size = 64
img_channel = 1
epochs = 200
batch_size = 128 
criterion = nn.BCEWithLogitsLoss()

In [7]:
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [8]:
dataset = MNIST('.', download = True, transform = transform)
loader = DataLoader(dataset, 
                   batch_size = batch_size, 
                   shuffle= True)

In [9]:
gen = Generator(z_dim = 100, hidden_dim = 32, img_channel= img_channel, labels = 10).to(device)
disc = Discriminator(im_chan = img_channel, hidden_dim = 64, labels = 10).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr = 0.0001, betas=(0, 0.9))
disc_opt = torch.optim.Adam(disc.parameters(), lr = 0.0001, betas = (0, 0.9))


def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
        
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [10]:
print(disc)

Discriminator(
  (disc): Sequential(
    (0): Sequential(
      (0): Conv2d(11, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=

In [11]:
# for tensorboard plotting
noise = torch.randn(32, z_dim,1,1).to(device)
y_labels = torch.randint(0,9, (32,)).to(device)
labels_onehot = F.one_hot(y_labels, nlabels).view(32, -1, 1,1).to(device)
noise_embed_label = torch.cat([noise, labels_onehot], axis = 1)

fixed_noise = noise_embed_label 
writer_real = SummaryWriter(f"logs/GAN_MNIST4/real")
writer_fake = SummaryWriter(f"logs/GAN_MNIST4/fake")
CRITIC_ITERATIONS = 5 

In [12]:
def gradient_penalty(critic, real, fake,img_labels, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)
    

    # Calculate critic scores
    
    int_img_nd_labels = torch.cat([interpolated_images, img_labels], axis =1)
    mixed_scores = critic(int_img_nd_labels)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=int_img_nd_labels,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [None]:
LAMBDA_GP = 10 
gen.train()
disc.train()
step = 0 
for epoch in range(epochs): 
    for batch_idx, (true_imgs, y_labels) in enumerate(loader):
        
        true_imgs = true_imgs.to(device)
        cur_batch_size = true_imgs.shape[0]
        
        # update the critic 
        
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, z_dim,1,1).to(device)
            labels_onehot = F.one_hot(y_labels, nlabels).view(cur_batch_size, -1, 1,1).to(device)
            noise_embed_label = torch.cat([noise, labels_onehot], axis = 1)
            
            imgs_onehot = torch.ones(size = (cur_batch_size, nlabels,img_size, img_size), dtype = true_imgs.dtype, device = device)
            img_labels = imgs_onehot * labels_onehot 
            true_img_nd_labels = torch.cat([true_imgs, img_labels], axis = 1)
            
            # make the prediction with critic and generator 
            
            #print(true_img_nd_labels.shape)
            
            fake = gen(noise_embed_label)
            #raise Exception
            fake_img_nd_labels = torch.cat([fake, img_labels], axis =1)
            critic_real = disc(true_img_nd_labels).reshape(-1)
            critic_fake = disc(fake_img_nd_labels).reshape(-1)
            
            gp = gradient_penalty(disc, true_imgs, fake, img_labels, device = device)
            
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            
            disc.zero_grad()
            loss_critic.backward(retain_graph = True)
            disc_opt.step()
            #print("hi")
            
        # update the generator 
        critic_fake = disc(fake_img_nd_labels).reshape(-1)
        loss_gen = -torch.mean(critic_fake)
        gen.zero_grad()
        loss_gen.backward()
        
        gen_opt.step()
        
        #print('Done step {}'.format(step))
        
        #Print losses occasionally and print to tensorboard
        if batch_idx % 20 == 0 and batch_idx > 0:
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(true_imgs[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32,:,:,:], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        step += 1
        

Epoch [0/200] Batch 20/469                   Loss D: -4.9631, loss G: 2.7132
Epoch [0/200] Batch 40/469                   Loss D: -12.7823, loss G: 6.4920
Epoch [0/200] Batch 60/469                   Loss D: -23.8688, loss G: 11.8114
Epoch [0/200] Batch 80/469                   Loss D: -38.1554, loss G: 18.5943
Epoch [0/200] Batch 100/469                   Loss D: -54.9208, loss G: 26.8070
Epoch [0/200] Batch 120/469                   Loss D: -75.6240, loss G: 36.5842
Epoch [0/200] Batch 140/469                   Loss D: -99.2053, loss G: 47.8824
Epoch [0/200] Batch 160/469                   Loss D: -125.7252, loss G: 60.5838
Epoch [0/200] Batch 180/469                   Loss D: -155.6257, loss G: 74.8743
Epoch [0/200] Batch 200/469                   Loss D: -188.4962, loss G: 90.4659
Epoch [0/200] Batch 220/469                   Loss D: -224.4855, loss G: 107.7368
Epoch [0/200] Batch 240/469                   Loss D: -263.8954, loss G: 126.4942
Epoch [0/200] Batch 260/469             

Epoch [4/200] Batch 160/469                   Loss D: -15932.5830, loss G: 7589.7295
Epoch [4/200] Batch 180/469                   Loss D: -16232.8428, loss G: 7734.5166
Epoch [4/200] Batch 200/469                   Loss D: -16541.7070, loss G: 7881.7715
Epoch [4/200] Batch 220/469                   Loss D: -16834.3652, loss G: 8024.9629
Epoch [4/200] Batch 240/469                   Loss D: -17115.2812, loss G: 8173.7637
Epoch [4/200] Batch 260/469                   Loss D: -17424.6543, loss G: 8303.2217
Epoch [4/200] Batch 280/469                   Loss D: -17592.9492, loss G: 8363.1328
Epoch [4/200] Batch 300/469                   Loss D: -17907.7656, loss G: 8538.8809
Epoch [4/200] Batch 320/469                   Loss D: -18227.3301, loss G: 8681.7939
Epoch [4/200] Batch 340/469                   Loss D: -18565.4121, loss G: 8859.2812
Epoch [4/200] Batch 360/469                   Loss D: -18920.6934, loss G: 9025.5527
Epoch [4/200] Batch 380/469                   Loss D: -19233.5000

Epoch [8/200] Batch 260/469                   Loss D: 18.1340, loss G: -11746.9141
Epoch [8/200] Batch 280/469                   Loss D: -1002.8507, loss G: -10625.9551
Epoch [8/200] Batch 300/469                   Loss D: -407.1713, loss G: -11007.4668
Epoch [8/200] Batch 320/469                   Loss D: -1991.2924, loss G: -11287.3311
Epoch [8/200] Batch 340/469                   Loss D: -2254.4614, loss G: -10988.5908
Epoch [8/200] Batch 360/469                   Loss D: -2693.8308, loss G: -10624.8066
Epoch [8/200] Batch 380/469                   Loss D: -2994.8428, loss G: -10017.7393
Epoch [8/200] Batch 400/469                   Loss D: -3771.5732, loss G: -9205.4668
Epoch [8/200] Batch 420/469                   Loss D: -3946.7166, loss G: -8305.1426
Epoch [8/200] Batch 440/469                   Loss D: -6206.9893, loss G: -6812.7900
Epoch [8/200] Batch 460/469                   Loss D: -7504.1318, loss G: -4617.1504
Epoch [9/200] Batch 20/469                   Loss D: -10247.37

Epoch [12/200] Batch 340/469                   Loss D: -41809.2344, loss G: 14663.4619
Epoch [12/200] Batch 360/469                   Loss D: -31873.0293, loss G: 23603.8984
Epoch [12/200] Batch 380/469                   Loss D: -41971.7539, loss G: 21719.4492
Epoch [12/200] Batch 400/469                   Loss D: -12061.2188, loss G: 20668.6289
Epoch [12/200] Batch 420/469                   Loss D: -45223.6172, loss G: 22448.9961
Epoch [12/200] Batch 440/469                   Loss D: -23649.7637, loss G: 8010.6260
Epoch [12/200] Batch 460/469                   Loss D: -42448.3320, loss G: 24100.4297
Epoch [13/200] Batch 20/469                   Loss D: -34207.4766, loss G: 24353.4531
Epoch [13/200] Batch 40/469                   Loss D: -44271.9688, loss G: 22902.3105
Epoch [13/200] Batch 60/469                   Loss D: -44089.9609, loss G: 18973.0273
Epoch [13/200] Batch 80/469                   Loss D: -47006.6055, loss G: 22118.6797
Epoch [13/200] Batch 100/469                   L

Epoch [16/200] Batch 400/469                   Loss D: -80530.2578, loss G: 41550.2422
Epoch [16/200] Batch 420/469                   Loss D: -86481.3125, loss G: 41684.8711
Epoch [16/200] Batch 440/469                   Loss D: -80510.3672, loss G: 38657.3906
Epoch [16/200] Batch 460/469                   Loss D: -76582.9922, loss G: 36279.1758
Epoch [17/200] Batch 20/469                   Loss D: -40075.4766, loss G: 15179.3438
Epoch [17/200] Batch 40/469                   Loss D: -75684.0547, loss G: 36220.4531
Epoch [17/200] Batch 60/469                   Loss D: -55419.9375, loss G: 14853.4883
Epoch [17/200] Batch 80/469                   Loss D: -45827.9766, loss G: 41279.5312
Epoch [17/200] Batch 100/469                   Loss D: -73226.6641, loss G: 38152.8203
Epoch [17/200] Batch 120/469                   Loss D: -71032.9375, loss G: 42564.9492
Epoch [17/200] Batch 140/469                   Loss D: -81583.6328, loss G: 37505.6250
Epoch [17/200] Batch 160/469                   