In [None]:
# !pip install pycm livelossplot
# !pip install torchsummary 
# !pip install tsne_torch

# %pylab inline

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets
import torchvision.transforms as transforms
from torchsummary import summary
from torchvision.datasets import FashionMNIST, MNIST
import torch.nn.functional as F
from collections import Counter, defaultdict
from tqdm import tqdm
from tsne_torch import TorchTSNE as TSNE
from livelossplot import PlotLosses
import random

import sys
sys.path.insert(1, '..')
from models import cGAN

In [None]:
 
def set_seed(seed):
    """Seed every random number generator this notebook draws from.

    Seeds Python's `random`, NumPy's legacy global generator, and PyTorch's
    CPU and CUDA generators, then turns off cuDNN auto-tuning and cuDNN
    itself.

    Args:
        seed (int): value passed to each generator.

    Returns:
        bool: always True.
    """
    # Imported here because the notebook's import cell does not import numpy
    # (the `%pylab inline` line is commented out), so a bare `np` is undefined.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # benchmark=True would let cuDNN auto-tune and pick the fastest convolution
    # algorithm, which may differ between runs; keep the auto-tuner off.
    torch.backends.cudnn.benchmark = False
    # Disable cuDNN altogether so PyTorch's own kernels are used instead.
    torch.backends.cudnn.enabled = False

    return True



In [None]:
# Use the GPU only when PyTorch both counts at least one CUDA device and
# reports CUDA as available; otherwise fall back to the CPU.
gpu_usable = torch.cuda.device_count() > 0 and torch.cuda.is_available()
device = 'cuda' if gpu_usable else 'cpu'
print("Cuda installed! Running on GPU!" if gpu_usable else "No GPU available!")

In [None]:
# Mount Google Drive at /content/gdrive/ (Google Colab only); the checkpoint
# paths used by the training cell point into this mount.
from google.colab import drive
drive.mount('/content/gdrive/')

<br>

---

<br>

In [None]:
# Mini-batch size: read by the DataLoader, the training functions and the sampling cell.
batch_size = 100

In [None]:
# FashionMNIST training split, stored under the working directory
# (download=True fetches it when it is not already there).

# Images: converted to tensors, then mirrored left-right with probability 0.5.
image_pipeline = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
])

# Labels: each class id is wrapped in a one-element LongTensor.
label_pipeline = torchvision.transforms.Compose([
    lambda x: torch.LongTensor([x]),
])

fashion_mnist_dataset = FashionMNIST(
    "./",
    train=True,
    download=True,
    transform=image_pipeline,
    target_transform=label_pipeline,
)

In [None]:
# Display the dataset object's repr as the cell output.
fashion_mnist_dataset

In [None]:
# Keep the dataset's `class_to_idx` mapping in its own variable and display it.
class_to_idx = fashion_mnist_dataset.class_to_idx
class_to_idx

In [None]:
# Count how many training samples carry each label value.
Counter(fashion_mnist_dataset.targets.to('cpu').detach().numpy())

In [None]:
# Shape of the first raw (untransformed) image stored in the dataset.
fashion_mnist_dataset.data[0].shape

#### Plotting 10 samples from each class

In [None]:
def plot_classes(dataset, num_per_class=10):
  """Show the first `num_per_class` images of every class in a grid.

  One row per class (rows ordered by ascending label), one column per sample.
  Images are read from `dataset.data`, so the dataset's transforms are not
  applied.

  NOTE(review): uses the notebook-level name `plt`; no matplotlib import is
  visible in this notebook (the `%pylab inline` line is commented out), so
  confirm `plt` is in scope before calling.

  Args:
    dataset: object exposing `data` (image tensors) and `targets` (label
      tensors), iterated in parallel.
    num_per_class (int): maximum number of images drawn per class.

  Returns:
    None.
  """
  if num_per_class < 1:
    return

  # Collect up to `num_per_class` images for each label, in dataset order.
  per_class = defaultdict(list)
  for img, label_tensor in zip(dataset.data, dataset.targets):
    label = label_tensor.item()
    if len(per_class[label]) < num_per_class:
      per_class[label].append(img.to('cpu').detach().numpy())

  if not per_class:
    # Nothing to draw for an empty dataset.
    return

  classes = sorted(per_class)
  # Grid size follows the data instead of a fixed 10x10; 2 inches per cell
  # keeps the 20x20 figure for the 10-class / 10-sample case.
  # squeeze=False keeps `ax` two-dimensional even for a single row or column.
  _, ax = plt.subplots(len(classes), num_per_class,
                       figsize=[2 * num_per_class, 2 * len(classes)],
                       squeeze=False)
  for row, label in enumerate(classes):
    for col, img in enumerate(per_class[label]):
      ax[row, col].imshow(img.squeeze(), cmap='gray')

# Draw the sample grid for the training set.
plot_classes(fashion_mnist_dataset)

In [None]:
# Shuffled mini-batch iterator over the training set.
train_loader = DataLoader(
    fashion_mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:


# Build the generator and discriminator from the project-local `cGAN` module.
# Both names are assigned again by the training cell.
G = cGAN.Generator()
D = cGAN.Discriminator()


In [None]:
# Binary cross-entropy loss, used by both D_train and G_train.
criterion = nn.BCELoss() 

# Length of the noise vector fed to the generator.
z_dim = 100

# Learning rate passed to both Adam optimizers in the training cell.
lr = 0.0001

def D_train(G, D, D_optimizer, x, label, n_classes=10):
    """Run one discriminator update on a real batch plus a generated fake batch.

    Reads the notebook-level globals `criterion`, `device` and `z_dim`.

    Args:
        G: generator, called as `G(z, labels)`.
        D: discriminator, called as `D(images, labels)`; its output is scored
            against (N, 1) targets.
        D_optimizer: optimizer that steps D's parameters.
        x: batch of real images; flattened to (N, 28*28) before entering D.
        label: class labels for `x`, passed to D unchanged.
        n_classes (int): fake labels are drawn uniformly from 0..n_classes-1.

    Returns:
        float: real loss + fake loss for this step.
    """
    D.train()
    D_optimizer.zero_grad()

    # Real samples. Targets are sized from the batch actually received rather
    # than the global `batch_size`, so a short final batch cannot cause a
    # shape mismatch inside the loss.
    x_real = x.view(-1, 28*28).to(device)
    n_samples = x_real.size(0)

    # Smoothed target for real data: 0.9 instead of 1.
    y_real = torch.full((n_samples, 1), 0.9, device=device)
    D_real_loss = criterion(D(x_real, label), y_real)

    # Fake samples: noise plus random class labels.
    z = torch.randn(n_samples, z_dim, 1, 1, device=device)
    # torch.randint excludes its upper bound, so it must be n_classes (not
    # n_classes - 1) for the last class to be sampled.
    label_fake = torch.randint(0, n_classes, (n_samples, 1), device=device)

    # Detached: this step only updates D, so no gradient needs to flow into G.
    x_fake = G(z, label_fake).detach()

    # Smoothed target for fake data: 0.1 instead of 0.
    y_fake = torch.full((n_samples, 1), 0.1, device=device)
    D_fake_loss = criterion(D(x_fake, label_fake), y_fake)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

def G_train(G, D, G_optimizer, x, label, n_classes=10):
    """Run one generator update: produce fakes and push D's score towards "real".

    Reads the notebook-level globals `criterion`, `device` and `z_dim`.

    Args:
        G: generator, called as `G(z, labels)`.
        D: discriminator, called as `D(images, labels)`.
        G_optimizer: optimizer that steps G's parameters.
        x: current real batch; only its first dimension is read, to decide how
            many fakes to generate.
        label: unused; kept so the signature mirrors `D_train`.
        n_classes (int): fake labels are drawn uniformly from 0..n_classes-1.

    Returns:
        float: generator loss for this step.
    """
    G.train()
    G_optimizer.zero_grad()

    # Generate as many fakes as the real batch holds instead of relying on the
    # global `batch_size`.
    n_samples = x.size(0)

    # torch.randint excludes its upper bound, so it must be n_classes (not
    # n_classes - 1) for the last class to be sampled.
    random_label = torch.randint(0, n_classes, (n_samples, 1), device=device)

    # Noise input for the generator.
    z = torch.randn(n_samples, z_dim, 1, 1, device=device)

    G_output = G(z, random_label)
    D_output = D(G_output, random_label)

    # The generator is scored against the "real" target (1).
    y = torch.ones((n_samples, 1), device=device)
    G_loss = criterion(D_output, y)

    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
# Training configuration. `PlotLosses` comes from the notebook's import cell.
start_epoch = 1
load_model = False
n_epoch = 50
groups = {'Loss': ['D_Loss', 'G_Loss']}
liveloss = PlotLosses(groups=groups)

# Checkpoint locations inside the mounted Google Drive.
generator_path = "/content/gdrive/My Drive/models/Generator_50.pth"
discriminator_path = "/content/gdrive/My Drive/models/Discriminator_50.pth"

# Only the `cGAN` module is imported, so the classes must be reached through it;
# bare `Generator` / `Discriminator` are undefined in this notebook.
G = cGAN.Generator().to(device)
D = cGAN.Discriminator().to(device)

# Instantiate optimizers for G and D
G_optimizer = torch.optim.Adam(G.parameters(), lr = lr, betas = (0.5, 0.9))
D_optimizer = torch.optim.Adam(D.parameters(), lr = lr, betas = (0.5, 0.9))

# Optionally resume from the checkpoints written by the save step below.
if load_model:
  # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
  G_checkpoint = torch.load(generator_path, map_location=device)
  G.load_state_dict(G_checkpoint['model_state_dict'])
  G_optimizer.load_state_dict(G_checkpoint['optimizer_state_dict'])

  D_checkpoint = torch.load(discriminator_path, map_location=device)
  D.load_state_dict(D_checkpoint['model_state_dict'])
  D_optimizer.load_state_dict(D_checkpoint['optimizer_state_dict'])

  start_epoch = D_checkpoint['epoch']

  G.train()
  D.train()

# NOTE(review): this range covers n_epoch + 1 epochs. The sampling cell titles
# its figure with `epoch-1`, so the bounds are left as they were; confirm the
# intended epoch count.
for epoch in range(start_epoch, start_epoch+n_epoch+1):
  D_losses, G_losses = [], []
  for x, label in train_loader:
    x, label = x.to(device), label.to(device)

    # Train discriminator and generator
    D_losses.append(D_train(G, D, D_optimizer, x, label))
    G_losses.append(G_train(G, D, G_optimizer, x, label))

  # Log the epoch mean rather than only the last batch's loss
  # (max(..., 1) avoids dividing by zero if the loader is empty).
  logs = {
      'D_Loss': sum(D_losses) / max(len(D_losses), 1),
      'G_Loss': sum(G_losses) / max(len(G_losses), 1),
  }
  liveloss.update(logs)
  liveloss.draw()

  # save every 10 epochs
  if epoch % 10 == 0:
    torch.save({
            'epoch': epoch,
            'model_state_dict': G.state_dict(),
            'optimizer_state_dict': G_optimizer.state_dict(),
            'loss': logs['G_Loss'],
            }, generator_path)
    torch.save({
            'epoch': epoch,
            'model_state_dict': D.state_dict(),
            'optimizer_state_dict': D_optimizer.state_dict(),
            'loss': logs['D_Loss'],
            }, discriminator_path)

In [None]:
set_seed(0)

# To sample from a saved checkpoint instead of the in-memory generator, rebuild
# G and load `torch.load(generator_path)['model_state_dict']`; the training
# cell saves a dict, not a bare state_dict.
# NOTE(review): G is used in whatever mode the training cell left it in; call
# G.eval() first if eval-mode samples are wanted.

# Grid layout is independent of the training batch size.
n_classes, n_per_class = 10, 10

with torch.no_grad():

    # Labels cycle 0..n_classes-1, so every grid row shows each class once
    # (class k lands in column k).
    labels = torch.arange(n_classes).repeat(n_per_class).view(-1, 1).to(device)

    # One noise vector per image, drawn on the CPU and then moved to `device`.
    test_z = torch.randn(labels.size(0), z_dim, 1, 1).to(device)
    generated = G(test_z, labels)

    # save_image(generated.view(generated.size(0), 1, 28, 28), './sample_' + '.png')

# NOTE(review): `plt` is not imported anywhere in the visible notebook (the
# `%pylab inline` line is commented out); confirm it is in scope.
fig, axarr = plt.subplots(n_per_class, n_classes, figsize=(12, 12))
for ax, img in zip(axarr.flatten(), generated.view(generated.size(0), 28, 28).cpu()):
  ax.imshow(img, cmap="gray")

# Title the whole figure; plt.title would only label the last subplot.
fig.suptitle('Epoch = {:03d}'.format(epoch-1))