In [1]:
import torch, os
import torch.nn as nn
from tqdm.auto import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
torch.manual_seed(0) # Set for our testing purposes, please do not change!
from torchvision.utils import save_image

  warn(f"Failed to load image Python extension: {e}")


In [2]:
class Generator(nn.Module):
    """DCGAN-style generator: (B, input_dim, 1, 1) noise -> (B, output_chan, 28, 28) image.

    Four unpadded ConvTranspose2d stages grow the spatial size
    1 -> 3 -> 6 -> 13 -> 28. Every stage except the last is followed by
    BatchNorm2d + ReLU; the last is followed by Tanh, so outputs lie in [-1, 1].

    Args:
        input_dim: channels of the (B, input_dim, 1, 1) input tensor.
        output_chan: channels of the generated image.
        hidden_dim: base width; the hidden stages use 4x, 2x and 1x this many channels.
    """

    def __init__(self, input_dim=10, output_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        self.gen = nn.Sequential(
            # B x input_dim x 1 x 1 -> B x (hidden_dim*4) x 3 x 3
            self.gen_block(input_dim, hidden_dim * 4, 3, 2),
            # -> B x (hidden_dim*2) x 6 x 6
            self.gen_block(hidden_dim * 4, hidden_dim * 2, 4, 1),
            # -> B x hidden_dim x 13 x 13
            self.gen_block(hidden_dim * 2, hidden_dim, 3, 2),
            # -> B x output_chan x 28 x 28
            self.gen_block(hidden_dim, output_chan, 4, 2, final_layer=True),
        )

    def gen_block(self, input_channels, output_channels, kernel_size, stride, final_layer=False):
        """Return one upsampling stage as an nn.Sequential.

        Hidden stage: ConvTranspose2d -> BatchNorm2d -> ReLU(inplace).
        Final stage:  ConvTranspose2d -> Tanh.
        """
        upsample = nn.ConvTranspose2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True))

    def forward(self, x):
        """Run the noise tensor x through every stage and return the generated batch."""
        return self.gen(x)
    

In [3]:
class Disc(nn.Module):
    """Convolutional discriminator for 28x28 inputs.

    Three stride-2, unpadded 4x4 convolutions shrink the spatial size
    28 -> 13 -> 5 -> 1. The result is one raw (pre-sigmoid) score per sample,
    with shape (B, 1, 1, 1). No sigmoid is applied, so the output is meant for a
    logits-based loss such as ``nn.BCEWithLogitsLoss``.

    Args:
        image_chan: number of input channels (image channels plus any
            concatenated one-hot label channels).
        hidden_dim: channel width of the first block; the second block uses
            ``hidden_dim * 2``.
    """

    def __init__(self, image_chan=1, hidden_dim=64):
        super(Disc, self).__init__()
        self.disc = nn.Sequential(  # input: B x image_chan x 28 x 28
            self.critic_block(image_chan, hidden_dim, 4, 2),  # -> B x hidden_dim x 13 x 13
            self.critic_block(hidden_dim, hidden_dim * 2, 4, 2),  # -> B x (hidden_dim*2) x 5 x 5
            self.critic_block(hidden_dim * 2, 1, 4, 2, final_layer=True),  # -> B x 1 x 1 x 1
        )

    def critic_block(self, input_channels, output_channels, kernel_size, stride, final_layer=False):
        """Build one downsampling stage as an nn.Sequential.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2). The final
        stage is a bare Conv2d, so the network emits unbounded logits.
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride),
        )

    def forward(self, x):
        """Return raw scores of shape (B, 1, 1, 1) for x of shape (B, image_chan, 28, 28)."""
        return self.disc(x)

In [4]:
def initilalize_weights(model, std=0.02):
    """Initialise the conv and BatchNorm layers of ``model`` in place (DCGAN convention).

    * Conv2d / ConvTranspose2d weights ~ N(0, std).
    * BatchNorm2d weights (gamma) ~ N(1, std) and biases (beta) = 0, so each
      BatchNorm layer starts close to an identity affine map. A zero-mean gamma
      would scale every normalised activation down to roughly ``std``.

    Args:
        model: any ``nn.Module``; walked recursively via ``model.modules()``.
        std: standard deviation of the normal draws (default 0.02).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, std)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # BatchNorm2d(affine=False) has no weight/bias to initialise.
            nn.init.normal_(m.weight.data, 1.0, std)
            nn.init.constant_(m.bias.data, 0.0)

In [5]:
def test():
    """Smoke-test Disc and Generator output shapes on a 64-sample batch; prints "Success"."""
    batch, channels, height, width = 64, 1, 28, 28
    input_dim = 10
    image_shape = (batch, channels, height, width)

    images = torch.randn(image_shape)
    disc = Disc(1, 64)
    initilalize_weights(disc)
    assert disc(images).shape == (batch, 1, 1, 1)

    gen = Generator(input_dim=input_dim, output_chan=1, hidden_dim=64)
    z = torch.randn(batch, input_dim, 1, 1)
    initilalize_weights(gen)
    assert gen(z).shape == image_shape

    print("Success")
    

In [6]:
# Build both networks once and check their output shapes; prints "Success" on pass.
test()

Success


In [7]:
# Hyperparameter setting
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 2e-4
num_disc = 5  # NOTE(review): not read by any visible cell
LAMBDA = 10  # NOTE(review): presumably a WGAN-GP penalty weight; not read by any visible cell
batch_size = 128
image_size = 28
image_chan = 1
num_epoches = 200
hidden_dim = 64
feature_gen = 16  # NOTE(review): not read by any visible cell
z_dim = 64
display_step = 500  # NOTE(review): not read by any visible cell
# ToTensor yields [0, 1]; Normalize(mean=0.5, std=0.5) maps that to [-1, 1],
# the same range as the generator's final Tanh.
Transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * image_chan, [0.5] * image_chan),
])
## Data Loading
# Store MNIST under the current working directory. A root of "/." would be the
# filesystem root, which is usually not writable.
train_dataset = MNIST(root=".", train=True, transform=Transforms, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
num_class = len(train_loader.dataset.classes)  # number of label classes in the dataset


  return torch._C._cuda_getDeviceCount() > 0


In [8]:
# Peek at one shuffled batch to sanity-check the tensors the loader yields,
# then display the training-set size.
data, labels = next(iter(train_loader))
assert data.shape[1:] == (image_chan, image_size, image_size), data.shape
assert labels.shape == (data.shape[0],), labels.shape
len(train_loader.dataset)

60000

In [9]:
#print(p.repeat(1,1,data.shape[2], data.shape[3])[:,6,:,:])

In [10]:
def Disc_input_hot_encoded(data, labels, num_class):
    '''
    Append one-hot label channels to a batch of images for the discriminator.

    data = images, shape B*C*H*W (floating point)
    labels = integer class indices in [0, num_class), shape B
    num_class = number of classes, i.e. number of label channels to append

    Returns a tensor of shape B*(C+num_class)*H*W with data's dtype and device.
    The first C channels are `data`. Channel C+k is all ones for samples whose
    label is k and all zeros otherwise.
    '''
    B, C, H, W = data.shape
    # B -> B*num_class. Cast to data's dtype/device so the concatenation does not
    # depend on implicit int64->float promotion or on where `labels` lives.
    one_hot = F.one_hot(labels, num_class).to(dtype=data.dtype, device=data.device)
    # B*num_class -> B*num_class*1*1 -> B*num_class*H*W (broadcast view; cat copies it)
    one_hot_channels = one_hot[:, :, None, None].expand(B, num_class, H, W)
    return torch.cat((data, one_hot_channels), dim=1)  # B*(C+num_class)*H*W

In [11]:
def test_disc_encoded():
    """Check that Disc_input_hot_encoded widens a 128x1x28x28 batch to 128x11x28x28; prints "success"."""
    batch, classes = 128, 10
    images = torch.randn(batch, 1, 28, 28)
    targets = torch.randint(0, classes, (batch,))
    encoded = Disc_input_hot_encoded(images, targets, classes)
    expected_shape = (batch, 1 + classes, 28, 28)
    assert encoded.shape == expected_shape
    print("success")
    

In [12]:
# Run the one-hot channel concatenation check; prints "success" on pass.
test_disc_encoded()

success


In [13]:
# Model Instantiation
# Generator input: z_dim Gaussian channels + num_class one-hot label channels.
# Discriminator input: image channels + num_class one-hot label channels.
# Both widths come from the config cell's `hidden_dim` rather than a repeated literal 64.
# NOTE(review): output_chan = image_chan + num_class makes the generator produce
# label channels of its own. Conditioning is cleaner with output_chan=image_chan
# and the true one-hot channels appended before the discriminator. Changing it
# requires the training cell to build the discriminator inputs for fakes that way.
gen = Generator(input_dim=z_dim + num_class, output_chan=image_chan + num_class, hidden_dim=hidden_dim).to(device=device)
disc = Disc(image_chan + num_class, hidden_dim=hidden_dim).to(device=device)
initilalize_weights(gen)
initilalize_weights(disc)

In [14]:
# BCE on raw logits: Disc ends in a bare Conv2d, without a sigmoid.
criterion = nn.BCEWithLogitsLoss()
# beta1 = 0.5 instead of Adam's default 0.9. DCGAN-style training recommends this
# to damp oscillation between the two players.
adam_betas = (0.5, 0.999)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=adam_betas)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=adam_betas)

In [None]:
import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Load the TensorBoard notebook extension and open a dashboard reading event files from ./logs.
# NOTE(review): no visible cell writes to logs/ (SummaryWriter is imported but never
# instantiated), so the dashboard will presumably be empty.
# The commented-out ''' lines look like a leftover toggle for disabling this cell.
#'''
%load_ext tensorboard
%tensorboard --logdir logs
#'''

In [15]:
def Noise(z_dim, labels, num_class):
    """Build conditional generator input: Gaussian noise followed by one-hot labels.

    Args:
        z_dim: number of Gaussian noise channels per sample.
        labels: 1-D integer tensor of class indices, shape (B,).
        num_class: number of classes (width of the one-hot part).

    Returns:
        Float tensor of shape (B, z_dim + num_class, 1, 1) on ``labels.device``.
        Channels [0, z_dim) are N(0, 1) samples; channels [z_dim, z_dim + num_class)
        are the one-hot encoding of ``labels``.
    """
    # Size the batch from `labels` instead of relying on a global `B`, and create
    # the noise on the labels' device so the concatenation never mixes devices.
    batch = labels.shape[0]
    noise = torch.randn(batch, z_dim, 1, 1, device=labels.device)
    one_hot_labels = F.one_hot(labels, num_class).to(noise.dtype)[:, :, None, None]
    return torch.cat((noise, one_hot_labels), 1)

In [16]:
def scale_image(image):
    """Map values from the Tanh range [-1, 1] to [0, 1], e.g. before saving images."""
    shifted = image + 1
    return shifted * 0.5

In [17]:
# Output directory for the per-epoch sample grids written during training.
# exist_ok=True avoids the check-then-create race and makes the cell idempotent.
os.makedirs("gen_image", exist_ok=True)

In [None]:
# Training Process
# NOTE(review): fixed_noise is not used below, and its last num_class channels are
# Gaussian rather than one-hot labels. Presumably meant for TensorBoard samples; kept
# (and drawn at the same point in the RNG stream) for any later use.
fixed_noise = torch.randn(batch_size, z_dim + num_class, 1, 1).to(device=device)
step = 0
gen_losses = []
disc_losses = []

gen.train()
disc.train()

for epoch in range(num_epoches):
    epoch_batches = tqdm(train_loader, desc=f"epoch {epoch + 1}/{num_epoches}", leave=False)
    for batch_idx, (real, labels) in enumerate(epoch_batches):
        B, C, H, W = real.shape
        # Sample the generator input while `labels` is still on the CPU, next to the
        # CPU noise, then move the whole tensor to the device.
        input_noise = Noise(z_dim, labels, num_class).to(device=device)
        real = real.to(device=device)
        labels = labels.to(device=device)
        # Only the first C generator channels are treated as the image. The label
        # channels the discriminator sees always come from the true labels, so the
        # generator cannot satisfy the discriminator by inventing its own.
        fake_image = gen(input_noise)[:, :C]

        # Train Discriminator
        disc_real = disc(Disc_input_hot_encoded(real, labels, num_class)).view(-1)  # prediction of real images
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))  # loss on real images
        # detach(): the discriminator update must not backpropagate into the generator,
        # which also means no graph has to be retained for the generator step.
        disc_fake = disc(Disc_input_hot_encoded(fake_image.detach(), labels, num_class)).view(-1)  # prediction of fake images
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))  # loss on fake images
        LOSS_DISC = (loss_disc_fake + loss_disc_real) / 2  # Total DISC LOSS
        disc.zero_grad()
        LOSS_DISC.backward()
        disc_opt.step()
        disc_losses.append(LOSS_DISC.item())

        # Train Generator: score the same fakes (with their graph) with the updated discriminator.
        fake_pred = disc(Disc_input_hot_encoded(fake_image, labels, num_class)).view(-1)
        LOSS_GEN = criterion(fake_pred, torch.ones_like(fake_pred))  # Generator LOSS
        gen.zero_grad()
        LOSS_GEN.backward()
        gen_opt.step()
        gen_losses.append(LOSS_GEN.item())

    # Save the last batch of this epoch's fakes (C channels each), rescaled to [0, 1].
    fake_images = fake_image.detach().reshape(-1, C, H, W)
    save_image(scale_image(fake_images), f"gen_image/{epoch+1}.png")