In [1]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [6]:
latent_dim=100

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        def block(input_dim, output_dim, normalize=True):
            layers=[nn.Linear(input_dim, output_dim)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_dim, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model=nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 1*28*28),
            nn.Tanh()
        )
    
    def forward(self, z):
        img=self.model(z)
        img=img.view(img.size(0), 1, 28, 28)
        return img

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model=nn.Sequential(
            nn.Linear(1*28*28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        
    def forward(self, img):
        flattened=img.view(img.size(0), -1)
        output=self.model(flattened)
        return output

In [4]:
transforms_train=transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset=datasets.MNIST(root="./dataset", train=True, download=True, transform=transforms_train)
dataloader=torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset\MNIST\raw\train-images-idx3-ubyte.gz


100.1%

Extracting ./dataset\MNIST\raw\train-images-idx3-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset\MNIST\raw\train-labels-idx1-ubyte.gz


113.5%

Extracting ./dataset\MNIST\raw\train-labels-idx1-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset\MNIST\raw\t10k-images-idx3-ubyte.gz


100.4%

Extracting ./dataset\MNIST\raw\t10k-images-idx3-ubyte.gz to ./dataset\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./dataset\MNIST\raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
generator=Generator()
discriminator=Discriminator()

generator.cuda()
discriminator.cuda()

adversarial_loss=nn.BCELoss()
adversarial_loss.cuda()

lr=0.0002

optimizer_G=torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D=torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

In [8]:
import time

n_epochs=200
sample_interval=2000
start_time=time.time()

In [9]:
for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real=torch.cuda.FloatTensor(imgs.size(0), 1).fill_(1.0)
        fake=torch.cuda.FloatTensor(imgs.size(0), 1).fill_(0.0)
        
        real_imgs=imgs.cuda()
        
        optimizer_G.zero_grad()
        
        z=torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).cuda()
        
        generated_imgs=generator(z)
        
        g_loss=adversarial_loss(discriminator(generated_imgs), real)
        
        g_loss.backward()
        optimizer_G.step()
        
        optimizer_D.zero_grad()
        
        real_loss=adversarial_loss(discriminator(real_imgs), real)
        fake_loss=adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss=(real_loss+fake_loss)/2
        
        d_loss.backward()
        optimizer_D.step()
        
        done=epoch+len(dataloader)+i
        if done%sample_interval==0:
            save_image(generated_imgs.data[:25], f"{done}.png", nrow=5, normalize=True)
    print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time() - start_time:.2f}s]")

[Epoch 0/200] [D loss: 0.568325] [G loss: 1.300308] [Elapsed time: 402.57s]
[Epoch 1/200] [D loss: 0.383768] [G loss: 1.107510] [Elapsed time: 406.21s]
[Epoch 2/200] [D loss: 0.340844] [G loss: 1.020312] [Elapsed time: 409.96s]
[Epoch 3/200] [D loss: 0.426620] [G loss: 0.705824] [Elapsed time: 413.61s]
[Epoch 4/200] [D loss: 1.076822] [G loss: 4.125616] [Elapsed time: 417.29s]
[Epoch 5/200] [D loss: 0.315911] [G loss: 1.278444] [Elapsed time: 420.99s]
[Epoch 6/200] [D loss: 0.344296] [G loss: 0.972125] [Elapsed time: 424.53s]
[Epoch 7/200] [D loss: 0.214985] [G loss: 2.805456] [Elapsed time: 428.12s]
[Epoch 8/200] [D loss: 0.230796] [G loss: 2.108602] [Elapsed time: 431.69s]
[Epoch 9/200] [D loss: 0.218235] [G loss: 1.487318] [Elapsed time: 435.19s]
[Epoch 10/200] [D loss: 0.181039] [G loss: 1.606486] [Elapsed time: 438.71s]
[Epoch 11/200] [D loss: 0.141854] [G loss: 2.605847] [Elapsed time: 442.22s]
[Epoch 12/200] [D loss: 0.633585] [G loss: 0.429155] [Elapsed time: 445.97s]
[Epoch 13

[Epoch 107/200] [D loss: 0.228679] [G loss: 2.447406] [Elapsed time: 778.49s]
[Epoch 108/200] [D loss: 0.222144] [G loss: 2.149158] [Elapsed time: 782.03s]
[Epoch 109/200] [D loss: 0.186704] [G loss: 2.022593] [Elapsed time: 785.56s]
[Epoch 110/200] [D loss: 0.313482] [G loss: 3.884405] [Elapsed time: 789.33s]
[Epoch 111/200] [D loss: 0.224471] [G loss: 2.478915] [Elapsed time: 792.82s]
[Epoch 112/200] [D loss: 0.309097] [G loss: 1.596032] [Elapsed time: 796.42s]
[Epoch 113/200] [D loss: 0.360946] [G loss: 2.419624] [Elapsed time: 800.02s]
[Epoch 114/200] [D loss: 0.331765] [G loss: 1.712710] [Elapsed time: 803.62s]
[Epoch 115/200] [D loss: 0.245333] [G loss: 1.804303] [Elapsed time: 807.31s]
[Epoch 116/200] [D loss: 0.320033] [G loss: 3.188808] [Elapsed time: 811.03s]
[Epoch 117/200] [D loss: 0.293326] [G loss: 3.838867] [Elapsed time: 814.69s]
[Epoch 118/200] [D loss: 0.193630] [G loss: 2.305561] [Elapsed time: 818.24s]
[Epoch 119/200] [D loss: 0.188222] [G loss: 2.689086] [Elapsed t

In [11]:
from IPython.display import Image

Image('92000.png')

FileNotFoundError: No such file or directory: '92000.png'

FileNotFoundError: No such file or directory: '92000.png'

<IPython.core.display.Image object>