In [1]:
import torchvision
from torchvision.datasets import MNIST
import torchvision.transforms as tt

In [2]:
# MNIST training split. ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,))
# then computes (x - 0.5) / 0.5, mapping them to [-1, 1] -- the same range as the
# generator's final Tanh, so real and fake images are directly comparable.
data = MNIST(
    root='data',
    train=True,
    download=True,
    transform=tt.Compose([
        tt.ToTensor(),
        tt.Normalize((0.5,), (0.5,)),
    ]),
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
# Statistics used by tt.Normalize when the dataset was built.
mean = 0.5
std = 0.5


def denorm(img):
    """Invert ``tt.Normalize((mean,), (std,))``.

    Computes ``img * std + mean``, which maps normalised values in [-1, 1]
    back to the [0, 1] pixel range expected by ``save_image``.
    """
    return img * std + mean

In [4]:
from torch.utils.data import DataLoader

In [5]:
# Shuffled mini-batches of 200 images. pin_memory=True allocates batches in
# page-locked host memory, which speeds up host-to-GPU copies.
data_dl = DataLoader(
    dataset=data,
    batch_size=200,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [6]:
import torch.nn as nn

In [7]:
# Discriminator model: a flattened 28x28 image (784 values) goes in, and a single
# Sigmoid probability comes out (paired with nn.BCELoss for training).
Dis = nn.Sequential(
    nn.Linear(in_features=784, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=256, out_features=1),
    nn.Sigmoid(),
)

In [8]:
Dis  # echo the discriminator's layer structure as the cell output

Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): LeakyReLU(negative_slope=0.2)
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): LeakyReLU(negative_slope=0.2)
  (4): Linear(in_features=256, out_features=1, bias=True)
  (5): Sigmoid()
)

In [9]:
# Generator model: a 64-dimensional noise vector goes in, and 784 values come out
# (one flattened 28x28 image). The final Tanh bounds every value to (-1, 1).
Gen = nn.Sequential(
    nn.Linear(in_features=64, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=784),
    nn.Tanh(),
)

In [10]:
Gen  # echo the generator's layer structure as the cell output

Sequential(
  (0): Linear(in_features=64, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=784, bias=True)
  (5): Tanh()
)

In [11]:
import torch

In [12]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [13]:
# Move both networks' parameters to the selected device (nn.Module.to works in place).
Dis.to(device)
Gen.to(device)  # last expression in the cell, so the notebook echoes Gen's repr

Sequential(
  (0): Linear(in_features=64, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=784, bias=True)
  (5): Tanh()
)

In [14]:
# Binary cross-entropy on the Sigmoid probabilities produced by Dis.
loss_func = nn.BCELoss()

# One Adam optimizer per network, so each step only updates the network
# whose loss is being minimised.
dis_opt_func = torch.optim.Adam(params=Dis.parameters(), lr=0.0002)
gen_opt_func = torch.optim.Adam(params=Gen.parameters(), lr=0.0002)

In [15]:
def disTraining(images, latent_size=64):
    """Run one optimisation step of the discriminator and return its loss.

    The discriminator is trained to output 1 for the real ``images`` and 0 for
    an equally sized batch of images produced by ``Gen`` from Gaussian noise.

    Args:
        images: batch of real images, already flattened and moved to
            ``device`` (``fit`` reshapes them to ``(batch, 784)``).
        latent_size: length of each noise vector fed to ``Gen``; must match
            the generator's input layer (64 in this notebook).

    Returns:
        Scalar tensor: BCE loss on the fake images plus BCE loss on the real ones.
    """
    # Take the batch size from the data instead of hard-coding 200, so a
    # smaller last batch or a different DataLoader batch_size still works.
    batch_size = images.size(0)
    fake_labels = torch.zeros(batch_size, 1, device=device)
    real_labels = torch.ones(batch_size, 1, device=device)

    # Loss of the discriminator on the real images.
    preds = Dis(images)
    dis_loss_real = loss_func(preds, real_labels)

    # Loss of the discriminator on freshly generated fake images.
    # detach() stops the backward pass at the generator: only Dis is stepped
    # here, and genTraining zeroes Gen's gradients before using them, so
    # back-propagating into Gen would be wasted work.
    z = torch.randn(batch_size, latent_size, device=device)
    fake_images = Gen(z).detach()
    fake_preds_dis = Dis(fake_images)
    dis_loss_fake = loss_func(fake_preds_dis, fake_labels)

    # Total discriminator loss.
    dis_loss = dis_loss_fake + dis_loss_real

    dis_opt_func.zero_grad()
    dis_loss.backward()
    dis_opt_func.step()

    return dis_loss


In [16]:
def genTraining(batch_size=200, latent_size=64):
    """Run one optimisation step of the generator and return its loss.

    Generates ``batch_size`` fake images and scores them with ``Dis``. The loss
    is BCE against label 1 ("real"), so minimising it pushes ``Gen`` towards
    images the discriminator accepts as real.

    Args:
        batch_size: number of fake images generated for this step
            (defaults to 200, the DataLoader batch size used in this notebook).
        latent_size: length of each noise vector; must match the generator's
            input layer (64 in this notebook).

    Returns:
        Scalar tensor with the generator's BCE loss.
    """
    z = torch.randn(batch_size, latent_size, device=device)
    labels = torch.ones(batch_size, 1, device=device)

    fake_images = Gen(z)
    fake_preds_dis = Dis(fake_images)
    gen_loss = loss_func(fake_preds_dis, labels)

    # backward() also fills Dis's .grad; those values are never used because
    # disTraining calls dis_opt_func.zero_grad() before its own backward pass.
    gen_opt_func.zero_grad()
    gen_loss.backward()
    gen_opt_func.step()

    return gen_loss

In [17]:
from torchvision.utils import save_image
import os

# Fixed noise (200 vectors of length 64), reused every epoch so that the saved
# grids from different epochs show the same latent points.
random_vectors = torch.randn(200,64).to(device)


def save_after_epoch(i, out_dir='./images'):
    """Save a grid of generator samples for (0-based) epoch ``i``.

    The file is named ``saved_image_after_epoch{i+1}.png`` and written to
    ``out_dir`` (created if missing), with 20 images per row.
    """
    # Sampling only: no_grad avoids building an autograd graph for this pass.
    with torch.no_grad():
        out = Gen(random_vectors)
    # Flat 784-value vectors -> (N, 1, 28, 28) single-channel images.
    out = out.reshape(out.size(0), 1, 28, 28)

    name = f'saved_image_after_epoch{i+1}.png'

    print(f'saving {name}')

    # save_image does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    save_image(denorm(out), os.path.join(out_dir, name), nrow=20)

In [18]:
def fit(num_epochs):
    """Train ``Dis`` and ``Gen`` alternately for ``num_epochs`` passes over ``data_dl``.

    Each batch runs one discriminator step followed by one generator step.
    After every epoch, the losses of that epoch's final batch (not an epoch
    average) are printed and a sample grid is saved via ``save_after_epoch``.
    """
    for epoch in range(num_epochs):
        for batch, _labels in data_dl:
            # Flatten each image to one row so it matches Dis's 784-input layer.
            flat = batch.reshape(batch.size(0), -1).to(device)

            d_loss = disTraining(flat)
            g_loss = genTraining()
        print(f'discriminator loss: {d_loss}, generator_loss : {g_loss}')
        save_after_epoch(epoch)

In [19]:
# Train for 300 epochs. After each epoch fit() prints the last batch's losses
# and save_after_epoch writes a sample grid into ./images.
fit(300)

discriminator loss: 0.039929114282131195, generator_loss : 4.456997394561768
saving saved_image_after_epoch1.png
discriminator loss: 0.2915172278881073, generator_loss : 4.176916599273682
saving saved_image_after_epoch2.png
discriminator loss: 0.16048961877822876, generator_loss : 3.2788524627685547
saving saved_image_after_epoch3.png
discriminator loss: 0.2639872431755066, generator_loss : 4.294976234436035
saving saved_image_after_epoch4.png
discriminator loss: 0.39071834087371826, generator_loss : 3.5823564529418945
saving saved_image_after_epoch5.png
discriminator loss: 0.7913505434989929, generator_loss : 2.742218017578125
saving saved_image_after_epoch6.png
discriminator loss: 0.6514420509338379, generator_loss : 3.797415018081665
saving saved_image_after_epoch7.png
discriminator loss: 0.9534247517585754, generator_loss : 4.186120986938477
saving saved_image_after_epoch8.png
discriminator loss: 0.6885478496551514, generator_loss : 1.7760659456253052
saving saved_image_after_epoch

In [21]:
from google.colab import files
!zip -r /content/images.zip /content/images

  adding: content/images/ (stored 0%)
  adding: content/images/saved_image_after_epoch61.png (deflated 5%)
  adding: content/images/saved_image_after_epoch296.png (deflated 4%)
  adding: content/images/saved_image_after_epoch62.png (deflated 5%)
  adding: content/images/saved_image_after_epoch198.png (deflated 4%)
  adding: content/images/saved_image_after_epoch2.png (deflated 7%)
  adding: content/images/saved_image_after_epoch187.png (deflated 4%)
  adding: content/images/saved_image_after_epoch154.png (deflated 4%)
  adding: content/images/saved_image_after_epoch169.png (deflated 4%)
  adding: content/images/saved_image_after_epoch249.png (deflated 4%)
  adding: content/images/saved_image_after_epoch267.png (deflated 4%)
  adding: content/images/saved_image_after_epoch20.png (deflated 4%)
  adding: content/images/saved_image_after_epoch150.png (deflated 4%)
  adding: content/images/saved_image_after_epoch232.png (deflated 4%)
  adding: content/images/saved_image_after_epoch263.png (

In [22]:
# Persist the trained weights; each file is named after its network
# (D = discriminator, G = generator) so they can be reloaded into the right model.
torch.save(Dis.state_dict(), 'D.ckpt')
torch.save(Gen.state_dict(), 'G.ckpt')