In [1]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset # 
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [40]:
# Notebook-wide settings read by the dataset, model and training cells.

# -- image geometry --
image_size = 32    # images are resized to image_size x image_size
channel_dim = 1    # number of image channels

# -- input normalisation --
# With mean 0 and std 1, Normalize computes (x - 0) / 1 and leaves pixel values
# unchanged. The original note here says the paper uses a [-1, 1] range.
mean = 0.0
std = 1.0

# -- model sizes --
z_dim = 100        # length of the noise vector
num_classes = 10   # number of class labels

In [3]:
# MNIST training split. Each image is resized to image_size, converted to a
# tensor and normalised with the configured mean/std. The files are downloaded
# into the root directory when they are not already there.
mnist = datasets.MNIST(
    root='/content/sample_data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([mean], [std]),
    ]),
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/train-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/train-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw



In [4]:
# Shuffled mini-batches of 64 samples drawn from the MNIST training set.
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)

In [29]:
# Sanity check: show the first batch only (index, labels, label shape, image shape).
for i, (imgs, labels) in enumerate(dataloader):
  print(i, labels, labels.shape, imgs.shape, sep="\n")
  break

0
tensor([4, 3, 0, 4, 4, 3, 1, 8, 9, 2, 6, 6, 6, 9, 2, 6, 9, 4, 0, 6, 6, 3, 9, 5,
        6, 6, 4, 4, 9, 6, 3, 4, 5, 5, 9, 1, 6, 1, 3, 1, 6, 3, 4, 7, 4, 4, 5, 2,
        6, 6, 1, 7, 7, 5, 0, 6, 0, 1, 8, 0, 6, 5, 6, 0])
torch.Size([64])
torch.Size([64, 1, 32, 32])


In [27]:
# Scratch check of the generator input: embed 64 random class ids, then
# concatenate the embeddings with 64 noise vectors of length 100.
embedding = nn.Embedding(10, 10)
# Named `label_ids` rather than `input`: a global called `input` would shadow
# the builtin input() for every later cell in the kernel session.
label_ids = torch.randint(0, 10, (64,))
print(label_ids)
out = embedding(label_ids)
print(out.shape)
z = torch.from_numpy(np.random.normal(0, 1, (64, 100))).float()
gen_input = torch.cat((z, out), 1)
print(gen_input.shape)

tensor([8, 9, 9, 2, 3, 3, 8, 7, 0, 0, 2, 2, 2, 1, 8, 0, 9, 6, 0, 7, 1, 1, 3, 4,
        6, 7, 7, 0, 1, 9, 9, 7, 2, 3, 6, 8, 4, 6, 2, 0, 6, 7, 4, 6, 7, 4, 5, 2,
        0, 9, 7, 9, 4, 6, 4, 5, 8, 3, 4, 9, 2, 2, 9, 9])
torch.Size([64, 10])
torch.Size([64, 110])


In [75]:
class generator(nn.Module):
  """Conditional generator: turns a noise vector and a class label into an image.

  The label is embedded, concatenated with the noise vector, projected by a
  linear layer to a 128-channel feature map of side ``image_size // 4`` and then
  upsampled twice (x2 each time). The final ``Tanh`` bounds the output to [-1, 1].

  Layer sizes are taken from the notebook-level settings ``z_dim``,
  ``num_classes``, ``image_size`` and ``channel_dim``. ``image_size`` should be a
  multiple of 4 so the two x2 upsampling steps land exactly on ``image_size``.
  """

  def __init__(self,):
    super(generator, self).__init__()
    # Two x2 upsampling stages follow, so start at a quarter of the target side
    # (8x8 for 32x32 images).
    self.init_size = image_size // 4
    # cGAN modification: one learned vector of length num_classes per class.
    self.embedding = nn.Embedding(num_classes, num_classes)
    # Input is noise (z_dim) concatenated with the label embedding (num_classes).
    self.l1 = nn.Linear(z_dim + num_classes, 128 * self.init_size ** 2)
    self.conv = nn.Sequential(
        nn.BatchNorm2d(128),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(128, 128, 3, stride=1, padding=1),
        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # so the original `BatchNorm2d(128, 0.8)` sets eps=0.8. Kept as-is to
        # preserve behaviour; possibly `momentum` was intended - confirm.
        nn.BatchNorm2d(128, eps=0.8),
        nn.ReLU(),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(128, 64, 3, stride=1, padding=1),
        nn.BatchNorm2d(64, eps=0.8),
        nn.ReLU(),
        nn.Conv2d(64, channel_dim, 3, stride=1, padding=1),
        nn.Tanh()
    )

  def forward(self, z, labels):
    """Generate one image per (noise vector, label) pair.

    Args:
      z: float tensor of shape (batch, z_dim).
      labels: integer class ids of shape (batch,), used to index the embedding.

    Returns:
      Tensor of shape (batch, channel_dim, 4 * init_size, 4 * init_size) with
      values in [-1, 1].
    """
    # The embedding table is a registered parameter of this module, so it
    # receives gradients through the concatenation below.
    label_embed = self.embedding(labels)
    gen_input = torch.cat((z, label_embed), 1)  # (batch, z_dim + num_classes)
    out = self.l1(gen_input)
    out = out.view(out.shape[0], 128, self.init_size, self.init_size)
    out = self.conv(out)
    return out

In [76]:
class discriminator(nn.Module):
  """Conditional discriminator: scores an (image, label) pair as real or fake.

  The label is embedded, projected to one ``image_size`` x ``image_size`` plane
  and stacked onto the image as an extra channel. Four stride-2 convolution
  blocks reduce the spatial size, and a linear layer followed by ``Sigmoid``
  produces one probability per sample.

  Layer sizes are taken from the notebook-level settings ``image_size``,
  ``channel_dim`` and ``num_classes``.
  """

  def __init__(self):
    super(discriminator, self).__init__()

    def block(in_filters, out_filters, bn=False, drop=True):
      """Return the layers of one stride-2 conv block as a list."""
      # NOTE(review): the negative slope here is 0.02 - confirm this is
      # intentional and not a typo for 0.2.
      blk = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.02, inplace=True)]
      if drop:
        blk.append(nn.Dropout2d(0.25))
      if bn:
        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # so the original `BatchNorm2d(out_filters, 0.8)` sets eps=0.8. Kept
        # as-is to preserve behaviour; possibly `momentum` was intended - confirm.
        blk.append(nn.BatchNorm2d(out_filters, eps=0.8))
      return blk

    # Four stride-2 blocks each halve the side length (32 -> 2 for 32x32 images).
    self.downsampled = image_size // 2**4

    # One extra input channel carries the label plane.
    self.conv = nn.Sequential(
        *block(channel_dim + 1, 16),
        *block(16, 32, bn=True),
        *block(32, 64, bn=True),
        *block(64, 128, bn=True)
    )

    self.last = nn.Sequential(nn.Linear(128 * self.downsampled**2, 1), nn.Sigmoid())

    # cGAN modification: embed the label, then project it to a single image-sized
    # plane. The projection is sized for exactly one channel because forward()
    # reshapes it to (batch, 1, image_size, image_size) and the first conv expects
    # channel_dim + 1 input channels.
    self.embedding = nn.Embedding(num_classes, num_classes)
    self.linear_embed = nn.Linear(num_classes, image_size * image_size)

  def forward(self, x, labels):
    """Score a batch of images conditioned on their labels.

    Args:
      x: image tensor of shape (batch, channel_dim, image_size, image_size).
      labels: integer class ids of shape (batch,), used to index the embedding.

    Returns:
      Tensor of shape (batch, 1) with values in [0, 1] (sigmoid output).
    """
    label_vec = self.embedding(labels)
    label_plane = self.linear_embed(label_vec)
    label_channel = label_plane.view(label_plane.shape[0], 1, image_size, image_size)
    stacked = torch.cat((x, label_channel), 1)
    out = self.conv(stacked)
    out = out.view(out.shape[0], 128 * self.downsampled**2)
    realness = self.last(out)
    return realness

In [77]:
def gen_loss(z, gen_img, dis, labels):
  """Generator loss: mean of -log D(G(z) | labels) over the batch.

  Args:
    z: unused; kept so existing call sites keep working.
    gen_img: batch of generated images (still attached to the generator graph).
    dis: callable ``dis(images, labels)`` returning one probability per sample.
    labels: class labels passed through to ``dis``.

  Returns:
    Scalar tensor holding the loss.
  """
  # Average over the actual batch instead of multiplying by a fixed 1/64, so a
  # batch of any other size (e.g. a shorter final batch) is scaled correctly.
  loss = -torch.mean(torch.log(dis(gen_img, labels)))
  return loss

def dis_loss(z, x, gen_img, dis, labels):
  """Discriminator loss: mean of -[log D(x | labels) + log(1 - D(G(z) | labels))].

  Args:
    z: unused; kept so existing call sites keep working.
    x: batch of real images.
    gen_img: batch of generated images (pass a detached tensor to avoid
      back-propagating into the generator).
    dis: callable ``dis(images, labels)`` returning one probability per sample.
    labels: class labels passed through to ``dis`` for both batches.

  Returns:
    Scalar tensor holding the loss.
  """
  real_score = dis(x, labels)
  fake_score = dis(gen_img, labels)
  # Average over the actual batch instead of multiplying by a fixed 1/64, so a
  # batch of any other size (e.g. a shorter final batch) is scaled correctly.
  loss = -torch.mean(torch.log(real_score) + torch.log(1 - fake_score))
  return loss

def weight_init(m):
  """Re-draw the weights of Linear and Conv layers from N(0, 0.02^2).

  Meant to be passed to ``Module.apply``. A module is selected when its class
  name contains 'Linear' or 'Conv'; every other module is left untouched.
  Only ``weight`` is re-initialised.
  """
  layer_kind = type(m).__name__
  for marker in ('Linear', 'Conv'):
    if marker in layer_kind:
      torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
  

In [79]:
# Build both networks on the target device, re-initialise their Linear/Conv
# weights, then give each network its own Adam optimiser.
gen = generator().to(device)
dis = discriminator().to(device)
gen.apply(weight_init)
dis.apply(weight_init)
optimizer_G = torch.optim.Adam(params=gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(params=dis.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [80]:

# Alternating cGAN training: one generator step, then one discriminator step,
# per mini-batch. Every 400 batches a grid of samples is written to disk.
n_epochs = 200
sample_rows = 10  # rows in the preview grid; each row holds one image per class
sample_dir = "/content/gen"
# The output directory is created up front so writing the first preview does
# not fail when it is missing.
os.makedirs(sample_dir, exist_ok=True)

for epoch in range(n_epochs):
  for i, (imgs, labels) in enumerate(dataloader):
    real_imgs = imgs.float().to(device)
    labels = labels.to(device)

    # --- generator step ---
    # One noise vector per real sample, drawn directly on the target device.
    z = torch.randn(imgs.shape[0], z_dim, device=device)
    gen_img = gen(z, labels)
    optimizer_G.zero_grad()
    gloss = gen_loss(z, gen_img, dis, labels)
    gloss.backward()
    optimizer_G.step()

    # --- discriminator step ---
    # zero_grad() also clears the gradients the generator step left on the
    # discriminator; detach() keeps this backward pass out of the generator.
    optimizer_D.zero_grad()
    dloss = dis_loss(z, real_imgs, gen_img.detach(), dis, labels)
    dloss.backward()
    optimizer_D.step()

    print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), dloss.item(), gloss.item())
        )
    batches_done = epoch * len(dataloader) + i
    if batches_done % 400 == 0:
      # Class ids 0..num_classes-1 repeated sample_rows times, so each grid
      # column shows a single class.
      label = torch.arange(num_classes, device=device).repeat(sample_rows)
      # Preview only: no autograd graph is needed. The generator is not switched
      # to eval mode here, so its BatchNorm layers still use batch statistics.
      with torch.no_grad():
        z = torch.randn(label.shape[0], z_dim, device=device)
        display_gen_img = gen(z, label)
      save_image(display_gen_img, "%s/%d.png" % (sample_dir, batches_done), nrow=num_classes, normalize=True)
    

[Epoch 22/200] [Batch 779/938] [D loss: 0.670549] [G loss: 1.191302]
[Epoch 22/200] [Batch 780/938] [D loss: 0.628193] [G loss: 0.488809]
[Epoch 22/200] [Batch 781/938] [D loss: 0.420290] [G loss: 2.401342]
[Epoch 22/200] [Batch 782/938] [D loss: 0.749749] [G loss: 1.105287]
[Epoch 22/200] [Batch 783/938] [D loss: 1.758098] [G loss: 0.715056]
[Epoch 22/200] [Batch 784/938] [D loss: 0.520198] [G loss: 3.399500]
[Epoch 22/200] [Batch 785/938] [D loss: 0.488586] [G loss: 1.738253]
[Epoch 22/200] [Batch 786/938] [D loss: 0.689723] [G loss: 1.885129]
[Epoch 22/200] [Batch 787/938] [D loss: 0.290431] [G loss: 2.632135]
[Epoch 22/200] [Batch 788/938] [D loss: 0.338310] [G loss: 2.588498]
[Epoch 22/200] [Batch 789/938] [D loss: 0.396941] [G loss: 2.177020]
[Epoch 22/200] [Batch 790/938] [D loss: 0.325816] [G loss: 1.680483]
[Epoch 22/200] [Batch 791/938] [D loss: 1.776653] [G loss: 0.239587]
[Epoch 22/200] [Batch 792/938] [D loss: 2.044701] [G loss: 2.792624]
[Epoch 22/200] [Batch 793/938] [D 

KeyboardInterrupt: ignored