In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder, MNIST
from torch.autograd import Variable
from torchvision.utils import make_grid
from torchvision import transforms
from torch import autograd



In [None]:
batch_size = 64
# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then rescales
# them to [-1, 1], the same range as the generator's final Tanh.
this_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
# Pass real booleans: the strings 'True'/"True" only worked because any
# non-empty string is truthy ('False' would also have selected the train split).
mnist_train = MNIST('data', download=True, train=True, transform=this_transform)
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)


In [None]:

from matplotlib.widgets import Line2D
from torch.nn.modules.activation import LeakyReLU
class Generator(nn.Module):
  """Conditional MLP generator: (noise, class label) -> image in [-1, 1].

  The class label is embedded, concatenated with the noise vector and fed
  through an MLP whose final Tanh bounds the output to [-1, 1].

  Args:
    latent_dim: size of the noise vector passed as ``x`` (default 128).
    num_classes: number of distinct labels the embedding accepts (default 10).
    embed_dim: size of the label embedding (default 16).
    img_size: side length of the square output image (default 28).

  With the defaults the module is identical to the original 128+16 -> 784
  network, and the parameter names/order (``label_embedding_layer``,
  ``net.0`` ... ``net.8``) are unchanged, so existing state_dicts still load.
  """
  def __init__(self, latent_dim=128, num_classes=10, embed_dim=16, img_size=28):
    super().__init__()
    self.img_size = img_size
    self.label_embedding_layer = nn.Embedding(num_classes, embed_dim)

    self.net = nn.Sequential(
        nn.Linear(latent_dim + embed_dim, 256),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Linear(256, 512),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Linear(512, 1024),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Linear(1024, 2048),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Linear(2048, img_size * img_size),
        nn.Tanh()
    )

  def forward(self, x, label):
    """Generate images.

    Args:
      x: noise of shape ``(batch, latent_dim)``.
      label: integer class ids of shape ``(batch,)``.

    Returns:
      Tensor of shape ``(batch, img_size, img_size)`` with values in [-1, 1].
    """
    embedded_label = self.label_embedding_layer(label)
    inp = torch.cat([x, embedded_label], dim=1)
    out = self.net(inp)
    return out.view(out.size(0), self.img_size, self.img_size)

  


class Discriminator(nn.Module):
  """Conditional MLP discriminator: (image, class label) -> P(image is real).

  The flattened image is concatenated with a learned label embedding and fed
  through an MLP ending in a Sigmoid, giving one probability per sample.

  Args:
    num_classes: number of distinct labels the embedding accepts (default 10).
    embed_dim: size of the label embedding (default 16).
    img_size: side length of the square input image (default 28).

  With the defaults the module is identical to the original 784+16 -> 1
  network, and parameter names/order are unchanged.
  """
  def __init__(self, num_classes=10, embed_dim=16, img_size=28):
    super().__init__()
    self.img_pixels = img_size * img_size

    self.label_embedding_layer = nn.Embedding(num_classes, embed_dim)

    self.net = nn.Sequential(
        nn.Linear(self.img_pixels + embed_dim, 1024),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Dropout(0.3),
        nn.Linear(1024, 2048),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Dropout(0.3),
        nn.Linear(2048, 512),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Dropout(0.3),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Dropout(0.3),
        nn.Linear(256, 1),
        nn.Sigmoid()
    )

  def forward(self, x, label):
    """Score images.

    Args:
      x: images with ``img_size * img_size`` elements per sample, e.g.
        ``(batch, 1, 28, 28)`` from the data loader or ``(batch, 28, 28)``
        from the generator.
      label: integer class ids of shape ``(batch,)``.

    Returns:
      Tensor of shape ``(batch,)`` with probabilities in [0, 1].
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    x = x.reshape(x.size(0), self.img_pixels)
    embedded_label = self.label_embedding_layer(label)
    inp = torch.cat([x, embedded_label], dim=1)
    out = self.net(inp)
    # Squeeze only the feature dim: a bare squeeze() turns a batch of one into
    # a 0-d tensor that no longer matches the (1,)-shaped BCE target.
    return out.squeeze(1)




In [None]:
def steps_G(batch_size, net_G, net_D, optimizer_G, criteria, latent_dim=128, num_classes=10):
  """Run one generator update of the conditional GAN.

  Samples noise and random class labels, generates fake images, and updates
  ``net_G`` so that ``net_D`` scores those images as real (target 1).

  Args:
    batch_size: number of fake samples to generate.
    net_G: generator, called as ``net_G(noise, labels)``.
    net_D: discriminator, called as ``net_D(images, labels)``; its output is
      compared against a ``(batch_size,)`` target by ``criteria``.
    optimizer_G: optimizer over ``net_G``'s parameters.
    criteria: loss called as ``criteria(pred, target)``, e.g. ``nn.BCELoss()``.
    latent_dim: noise vector size; must match the generator input (default 128).
    num_classes: labels are drawn uniformly from ``[0, num_classes)``.

  Returns:
    The generator loss as a Python number.
  """
  # Create inputs on the generator's device instead of hard-coding .cuda().
  device = next(net_G.parameters()).device
  optimizer_G.zero_grad()

  noise_ip = torch.randn(batch_size, latent_dim, device=device)
  fake_label = torch.randint(0, num_classes, (batch_size,), device=device)
  fake_image = net_G(noise_ip, fake_label)
  fake_pred = net_D(fake_image, fake_label)
  # Target 1 ("real"): the generator is rewarded for fooling the discriminator.
  # backward() also writes gradients into net_D's parameters; they are not
  # applied here and must be cleared before the next discriminator update.
  loss_G = criteria(fake_pred, torch.ones(batch_size, device=device))
  loss_G.backward()
  optimizer_G.step()
  return loss_G.item()


def step_D(batch_size, net_G, net_D, optimizer_D, criteria, real_image, real_label, latent_dim=128, num_classes=10):
  """Run one discriminator update of the conditional GAN.

  The discriminator is trained to output 1 for ``(real_image, real_label)``
  and 0 for freshly generated fakes with random labels.

  Args:
    batch_size: number of samples; must equal ``real_image.shape[0]``.
    net_G: generator, called as ``net_G(noise, labels)``; not updated here.
    net_D: discriminator, called as ``net_D(images, labels)``.
    optimizer_D: optimizer over ``net_D``'s parameters.
    criteria: loss called as ``criteria(pred, target)``, e.g. ``nn.BCELoss()``.
    real_image: batch of real images, on the same device as the models.
    real_label: integer class ids for ``real_image``.
    latent_dim: noise vector size; must match the generator input (default 128).
    num_classes: fake labels are drawn uniformly from ``[0, num_classes)``.

  Returns:
    The combined (real + fake) discriminator loss as a Python number.
  """
  device = real_image.device
  # Without this, gradients from every previous D step -- and those that the
  # generator step's backward pass leaves in net_D -- keep accumulating.
  optimizer_D.zero_grad()

  # Step 1: real images should be classified as real (target 1).
  real_pred = net_D(real_image, real_label)
  real_loss = criteria(real_pred, torch.ones(batch_size, device=device))

  # Step 2: generated images should be classified as fake (target 0).
  noise_ip = torch.randn(batch_size, latent_dim, device=device)
  fake_label = torch.randint(0, num_classes, (batch_size,), device=device)
  # Only D is trained here, so don't build (or backprop through) G's graph.
  with torch.no_grad():
    fake_image = net_G(noise_ip, fake_label)
  fake_pred = net_D(fake_image, fake_label)
  fake_loss = criteria(fake_pred, torch.zeros(batch_size, device=device))

  loss_D = real_loss + fake_loss
  loss_D.backward()
  optimizer_D.step()
  return loss_D.item()





In [None]:
from tensorboardX import SummaryWriter

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

criteria = nn.BCELoss()
net_G = Generator().to(device)
net_D = Discriminator().to(device)
optimizer_G = torch.optim.AdamW(net_G.parameters(), lr = 0.0001)
optimizer_D = torch.optim.AdamW(net_D.parameters(), lr = 0.001)

epochs = 50
display_epoch = 10   # print losses / log a sample grid every `display_epoch` epochs
critics = 5          # discriminator updates per generator update (same real batch each time)
writer = SummaryWriter()

# One TensorBoard step per batch; logging against `epoch` made every batch of
# an epoch overwrite the same x-coordinate.
global_step = 0
for epoch in range(epochs):
  print("Running epoch no: {}".format(epoch))
  net_G.train()

  for image, label in data_loader:
    real_image = image.to(device)
    label = label.to(device)

    d_loss = 0.0
    for _ in range(critics):
      d_loss += step_D(real_image.shape[0], net_G, net_D, optimizer_D, criteria, real_image, label)
    d_loss /= critics  # report the mean critic loss (both when logging and printing)

    g_loss = steps_G(real_image.shape[0], net_G, net_D, optimizer_G, criteria)
    writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': d_loss}, global_step)
    global_step += 1

  # Once per display epoch (plus the final epoch) -- previously this ran
  # inside the batch loop, i.e. on every batch of a display epoch.
  if epoch % display_epoch == 0 or epoch == epochs - 1:
    print(" Generator Loss: {}   Discriminator Loss: {}".format(g_loss, d_loss))
    net_G.eval()
    with torch.no_grad():
      z = torch.randn(9, 128, device=device)
      labels = torch.arange(9, device=device)
      sample_images = net_G(z, labels).unsqueeze(1)
    grid = make_grid(sample_images, nrow=3, normalize=True)
    writer.add_image('sample_image', grid, epoch)

writer.close()
print("Completed")

RuntimeError: ignored

In [None]:
# Persist the trained generator's weights (its state_dict, not the module object).
torch.save(obj=net_G.state_dict(), f='generator_state.pt')

In [None]:
import os

# CUDA_LAUNCH_BLOCKING must be an *environment variable* (assigning a Python
# variable of that name does nothing). It is only honoured if set before the
# CUDA context is created, so for real debugging restart the kernel and run
# this cell before any .cuda()/.to("cuda") call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Sanity-check device placement on the default GPU ("cuda:1" assumed a second
# GPU exists); falls back to the CPU when no GPU is present.
sanity_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.tensor([1.0, 2.0]).to(sanity_device)

RuntimeError: ignored

In [None]:
# Draw one sample per digit (0-9) from fresh noise instead of reusing the
# `z`/`labels` left over from the training loop, which only covered 0-8.
net_G.eval()
with torch.no_grad():
  sample_device = next(net_G.parameters()).device
  z = torch.randn(10, 128, device=sample_device)
  labels = torch.arange(10, device=sample_device)
  # (10, 28, 28) -> (10, 1, 28, 28) for make_grid; move to CPU for plotting.
  images = net_G(z, labels).unsqueeze(1).cpu()

RuntimeError: ignored

In [None]:
from torchvision.utils import make_grid
grid = make_grid(images, nrow=10, normalize=True)
fig, ax = plt.subplots(figsize=(10,10))
# make_grid replicates single-channel images to 3 channels, and matplotlib
# ignores `cmap` for RGB data -- plot one channel so cmap='binary' takes effect.
# detach().cpu() also makes this work when `images` lives on the GPU.
ax.imshow(grid[0].detach().cpu().numpy(), cmap='binary')
ax.set_title('Generated samples (one per digit)')
ax.axis('off');