In [None]:
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import PIL
import urllib
from torchvision import datasets, transforms

In [None]:
# Download the MNIST training split; each item is a (PIL.Image, int label) pair.
mnist_data = datasets.MNIST('data', train=True, download=True)
mnist_data = list(mnist_data)

# Keep a small training subset and convert its images to float tensors in
# [0, 1] (shape (1, 28, 28)) so they can be batched and fed to the models.
N_TRAIN = 5000
to_tensor = transforms.ToTensor()
mnist_train = [(to_tensor(img), label) for img, label in mnist_data[:N_TRAIN]]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 95577226.56it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 79225437.43it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 29252193.72it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 11766849.15it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [None]:
# Binary cross-entropy on probabilities in [0, 1], averaged over the batch.
criterion = nn.BCELoss(reduction='mean')

In [None]:
# PIL image / HxWxC ndarray -> CxHxW float tensor (uint8 inputs are scaled to [0, 1]).
img_to_tensor = torchvision.transforms.ToTensor()

In [None]:
class Discriminator(nn.Module):
  """MLP that scores images with 28*28 values per sample.

  Outputs one probability in [0, 1] per sample (sigmoid at the end) so the
  result can be passed directly to ``nn.BCELoss``.
  """
  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
        nn.Linear(28*28, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300, 100),
        nn.LeakyReLU(0.2),
        nn.Linear(100, 1),
        # BCELoss requires inputs in [0, 1]; feeding it raw logits raises a
        # RuntimeError, so squash the score into a probability here.
        nn.Sigmoid())

  def forward(self, x):
    """Flatten each sample and return a (batch,) tensor of probabilities."""
    # reshape (not view) so non-contiguous inputs are also accepted
    x = x.reshape(x.size(0), -1)
    out = self.model(x)
    return out.view(x.size(0))

class Generator(nn.Module):
  """MLP mapping a 100-dim noise vector to a 1x28x28 image with values in [0, 1]."""
  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
        nn.Linear(100, 300),
        nn.LeakyReLU(0.2),
        nn.Linear(300, 28*28),
        nn.Sigmoid())

  def forward(self, x):
    """Map (batch, 100) noise to a (batch, 1, 28, 28) image tensor."""
    # Return the image-shaped tensor. Re-viewing it as (batch,) is impossible
    # because every sample carries 28*28 values.
    return self.model(x).view(x.size(0), 1, 28, 28)

In [None]:
def train_discriminator(discriminator, generator, images):
  """Compute the discriminator loss on one batch of real plus generated images.

  Label convention used in this notebook: 0 = real, 1 = fake.

  Args:
    discriminator: model returning one score per sample.
    generator: model mapping (batch, 100) noise to images.
    images: batched tensor of real images, batch dimension first.

  Returns:
    (outputs, loss): discriminator outputs for the concatenated batch (real
    samples first, then fakes) and ``criterion`` evaluated against the labels.
  """
  # `images` is already a batched tensor. ToTensor() rejects tensors, and on a
  # single PIL image it would report the channel count, not the batch size.
  batch_size = images.size(0)
  device = images.device
  noise = torch.randn(batch_size, 100, device=device)
  # Detach so backpropagating the discriminator loss does not write
  # gradients into the generator's parameters.
  fake_images = generator(noise).detach()
  inputs = torch.cat([images, fake_images])
  labels = torch.cat([torch.zeros(batch_size, device=device),  # Real
                      torch.ones(batch_size, device=device)])  # Fake
  outputs = discriminator(inputs)
  loss = criterion(outputs, labels)
  return outputs, loss


In [None]:
def train_generator(discriminator, generator, images):
  """Compute the generator loss for one batch.

  Only the batch size of ``images`` is used. Every generated sample is
  labelled 0 ("real" in this notebook's convention), so the loss is low when
  the discriminator is fooled.

  Returns:
    (fake_images, loss)
  """
  n = images.size(0)
  generated = generator(torch.randn(n, 100))
  # Target "real" for all fakes: the generator is rewarded for fooling D.
  target = torch.zeros(n)
  return generated, criterion(discriminator(generated), target)


In [None]:
def train(discriminator, generator, lr, epochs, train_loader, num_samples=64):
  """Train a GAN with alternating discriminator / generator Adam steps.

  Args:
    discriminator, generator: models passed to train_discriminator /
      train_generator.
    lr: learning rate for both Adam optimizers.
    epochs: number of passes over ``train_loader``.
    train_loader: iterable yielding ``(images, labels)`` batches where images
      is a batched tensor (e.g. a ``torch.utils.data.DataLoader``).
    num_samples: number of noise vectors fed to the generator after training
      to produce the returned sample batch.

  Returns:
    (Gtrain_loss, Dtrain_loss, genImgs): per-epoch mean generator and
    discriminator losses, and a list holding one detached batch of samples.
  """
  torch.manual_seed(42)
  optimizerDisc = optim.Adam(discriminator.parameters(), lr=lr)
  optimizerGen = optim.Adam(generator.parameters(), lr=lr)

  Gtrain_loss, Dtrain_loss = [], []
  genImgs = []

  start_time = time.time()
  for epoch in range(epochs):
    Gtotal_train_loss = 0.0
    Dtotal_train_loss = 0.0
    num_batches = 0
    for imgs, _ in train_loader:
      # Discriminator step. Clear gradients first: the generator's backward
      # pass also writes gradients into the discriminator's parameters.
      optimizerDisc.zero_grad()
      _, lossD = train_discriminator(discriminator, generator, imgs)
      lossD.backward()
      optimizerDisc.step()
      Dtotal_train_loss += lossD.item()

      # Generator step. Clear gradients first so anything left by the
      # discriminator step does not leak into this update.
      optimizerGen.zero_grad()
      _, lossG = train_generator(discriminator, generator, imgs)
      lossG.backward()
      optimizerGen.step()
      Gtotal_train_loss += lossG.item()
      num_batches += 1

    # True per-batch mean; guard against an empty loader.
    denom = max(num_batches, 1)
    Gtrain_loss.append(Gtotal_train_loss / denom)
    Dtrain_loss.append(Dtotal_train_loss / denom)

    print("Epoch {}: Gen Train loss: {}, Disc Train loss: {}".format(
        epoch + 1, Gtrain_loss[epoch], Dtrain_loss[epoch]))

  print('Finished Training')
  elapsed_time = time.time() - start_time
  print("Total time elapsed: {:.2f} seconds".format(elapsed_time))

  # Draw one batch of samples from the trained generator (no graph needed).
  with torch.no_grad():
    fake_imgs = generator(torch.randn(num_samples, 100))
  genImgs.append(fake_imgs)
  return Gtrain_loss, Dtrain_loss, genImgs

def plot(Gtrain_loss, Dtrain_loss):
  """Plot per-epoch generator and discriminator training losses on one chart."""
  plt.title("Gen Train Loss vs Disc Train Loss")
  epochs = range(1, len(Gtrain_loss) + 1)  # 1-based epoch numbers on the x-axis
  for series, name in ((Gtrain_loss, "Gen Train"), (Dtrain_loss, "Disc Train")):
    plt.plot(epochs, series, label=name)
  plt.xlabel("Epoch")
  plt.ylabel("Loss")
  plt.legend(loc='best')
  plt.show()

def check(dataLoader, genImgs):
  """Show a grid of real training images beside a grid of generated images.

  Args:
    dataLoader: iterable whose first item is either an ``(images, labels)``
      pair or an image batch. Images may be a (B, C, H, W) batch, a single
      (C, H, W) tensor, or a PIL image.
    genImgs: list of generated image batches (as returned by ``train``); the
      first entry is displayed.
  """
  def _to_grid(imgs):
    # Turn a PIL image, single image or batch into an (H, W, 3) array for imshow.
    if not torch.is_tensor(imgs):
      imgs = transforms.ToTensor()(imgs)
    imgs = imgs.detach().cpu()
    if imgs.dim() == 3:  # single (C, H, W) image -> batch of one
      imgs = imgs.unsqueeze(0)
    # At most an 8x8 grid; make_grid expands 1-channel images to 3 channels.
    grid = torchvision.utils.make_grid(imgs[:64], nrow=8)
    return grid.permute(1, 2, 0).numpy()

  real = next(iter(dataLoader))
  # DataLoader batches arrive as [images, labels]; keep only the images.
  real_imgs = real[0] if isinstance(real, (tuple, list)) else real

  plt.figure(figsize=(15, 15))
  plt.subplot(1, 2, 1)
  plt.axis("off")
  plt.title("Real")
  plt.imshow(_to_grid(real_imgs))

  plt.subplot(1, 2, 2)
  plt.axis("off")
  plt.title("Fake")
  plt.imshow(_to_grid(genImgs[0]))
  plt.show()

In [None]:
# Seed before construction: weight initialisation draws from torch's global
# RNG, so without this the starting models differ on every run.
torch.manual_seed(0)
generator = Generator()
discriminator = Discriminator()

In [None]:
# train() needs batched tensors, not the raw list of (image, label) pairs:
# make sure every image is a tensor, then let a DataLoader batch them.
BATCH_SIZE = 64
train_set = [(img if torch.is_tensor(img) else img_to_tensor(img), label)
             for img, label in mnist_train]
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
Gtrain_loss, Dtrain_loss, genImgs = train(discriminator, generator, 0.001, 5, train_loader)

RuntimeError: ignored