In [1]:
import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset # 
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# some parameters
image_size = 32
mean = 0.0
std = 1.0 # paper says -1,1
z_dim = 100

In [3]:
# initializing mnist dataset
mnist = datasets.MNIST(root = '/content/sample_data', train = True, transform = transforms.Compose(
            [transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize([mean], [std])]
        ), download = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/train-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/train-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/sample_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /content/sample_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/sample_data/MNIST/raw



In [4]:
# getting a dataloader
dataloader = DataLoader(dataset = mnist, batch_size = 64, shuffle = True)

In [5]:
a = torch.rand((64,100,100))
a.size(0)

64

In [6]:
class generator(nn.Module):
  def __init__(self,):
    super(generator, self).__init__()
    self.init_size = 8 # so we need to upsample twice! 8x8
    self.l1 = nn.Linear(100, 128 * self.init_size **2)
    self.conv = nn.Sequential(
        nn.BatchNorm2d(128),
        nn.Upsample(scale_factor = 2), 
        nn.Conv2d(128, 128, 3, stride=1, padding=1),
        nn.BatchNorm2d(128, 0.8),
        nn.ReLU(),
        nn.Upsample(scale_factor = 2),
        nn.Conv2d(128, 64, 3, stride=1, padding=1),
        nn.BatchNorm2d(64, 0.8),
        nn.ReLU(),
        nn.Conv2d(64, 1, 3, stride=1, padding=1),
        nn.Tanh()
    )

  def forward(self,z):
    out = self.l1(z)
    out = out.view(out.shape[0],128,self.init_size,self.init_size)
    out = self.conv(out)

    return out

In [7]:
class discriminator(nn.Module):
  def __init__(self):
    super(discriminator, self).__init__()

    def block(in_filters, out_filters, bn = False, drop = True):
      blk = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.02,inplace=True)]
      if drop:
        blk.append(nn.Dropout2d(0.25))
      if bn:
        blk.append(nn.BatchNorm2d(out_filters, 0.8))
      return blk

    self.downsampled = 2

    self.conv = nn.Sequential(
        *block(1,16),
        *block(16,32,bn = True),
        *block(32,64,bn = True),
        *block(64,128,bn = True)
    )

    self.last = nn.Sequential(nn.Linear(128*self.downsampled**2,1),nn.Sigmoid())


  def forward(self,x):
    out = self.conv(x)
    out = out.view(out.shape[0], 128*self.downsampled**2)
    realness = self.last(out)
    return realness

In [28]:
def gen_loss(z,gen_img,dis): #detach dis
  loss = (-1/64)*torch.sum(torch.log(dis(gen_img))) 
  return loss

def dis_loss(z,x,gen_img,dis):
  loss = (-1/64)*torch.sum(torch.log(dis(x)) + torch.log(1 - dis(gen_img)))
  return loss

def weight_init(m):
  classname = m.__class__.__name__
  if classname.find('Linear')!=-1:
    torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
  if classname.find('Conv')!=-1:
    torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
  

In [29]:
gen = generator()
gen.to(device)
dis = discriminator()
dis.to(device)
gen.apply(weight_init)
dis.apply(weight_init)
optimizer_G = torch.optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [10]:



for epoch in range(200):
  for i, (imgs, _) in enumerate(dataloader):
    real_imgs = imgs.float().to(device)
    z = torch.from_numpy(np.random.normal(0, 1, (imgs.shape[0], 100))).float().to(device)
    gen_img = gen(z)
    optimizer_G.zero_grad()
    gloss = gen_loss(z,gen_img,dis)
    gloss.backward()
    optimizer_G.step()

    optimizer_D.zero_grad()
    dloss = dis_loss(z,real_imgs,gen_img.detach(),dis)
    dloss.backward()
    optimizer_D.step()

    print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, 200, i, len(dataloader), dloss.item(), gloss.item())
        )
    batches_done = epoch * len(dataloader) + i
    if batches_done % 400 == 0:
      save_image(gen_img.data[:25], "/content/gen/%d.png" % batches_done, nrow=5, normalize=True)
    

[Epoch 67/200] [Batch 675/938] [D loss: 0.665829] [G loss: 1.982623]
[Epoch 67/200] [Batch 676/938] [D loss: 0.231213] [G loss: 1.647354]
[Epoch 67/200] [Batch 677/938] [D loss: 0.150161] [G loss: 1.839098]
[Epoch 67/200] [Batch 678/938] [D loss: 0.283003] [G loss: 0.723949]
[Epoch 67/200] [Batch 679/938] [D loss: 0.992412] [G loss: 1.828816]
[Epoch 67/200] [Batch 680/938] [D loss: 0.499651] [G loss: 2.308850]
[Epoch 67/200] [Batch 681/938] [D loss: 1.870312] [G loss: 0.600332]
[Epoch 67/200] [Batch 682/938] [D loss: 1.551552] [G loss: 1.701792]
[Epoch 67/200] [Batch 683/938] [D loss: 1.585146] [G loss: 2.089775]
[Epoch 67/200] [Batch 684/938] [D loss: 0.753094] [G loss: 3.597869]
[Epoch 67/200] [Batch 685/938] [D loss: 0.392669] [G loss: 1.531216]
[Epoch 67/200] [Batch 686/938] [D loss: 0.308362] [G loss: 1.611702]
[Epoch 67/200] [Batch 687/938] [D loss: 0.054551] [G loss: 3.151272]
[Epoch 67/200] [Batch 688/938] [D loss: 0.286297] [G loss: 2.686210]
[Epoch 67/200] [Batch 689/938] [D 

TypeError: ignored