In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [2]:
train_dataset = torchvision.datasets.MNIST(
    root = ".",
    download = True,
    transform = transforms.ToTensor()
)

In [3]:
batch_size = 128
criterion = nn.BCEWithLogitsLoss()
z_dim = 64
n_epochs = 100
device = "cpu"
lr = 0.001
data_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

# Generator

In [4]:
def gen_block(in_features, out_features):
    return nn.Sequential(
        nn.Linear(in_features,out_features),
        nn.BatchNorm1d(out_features),
        nn.ReLU(inplace = True)
    )

In [5]:
class Generator(nn.Module):
    
    def __init__(self, img_dim = 28*28, z_dim = 10, hidden_dim = 128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            gen_block(img_dim,hidden_dim),
            gen_block(hidden_dim, hidden_dim*2),
            gen_block(hidden_dim*2, hidden_dim*4),
            gen_block(hidden_dim*4, hidden_dim*8),
            nn.Linear(hidden_dim*8, img_dim),
            nn.Sigmoid()
        )
        
    def forward(self, noise):
        return self.gen(noise)
    
    

In [6]:
def get_noise(n_samples,z_dim, device):
    return torch.randn(n_samples, z_dim, device = device)

# Discriminator

In [7]:
def disc_block(inp_dim, out_dim):
    model = nn.Sequential(
        nn.Linear(inp_dim, out_dim),
        nn.LeakyReLU(0.1, inplace = True)
    )
    return model

In [8]:
class Discriminator(nn.Module):
    def __init__(self, img_dim = 28*28, hidden_dim = 128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            disc_block(img_dim, hidden_dim*4),
            disc_block(hidden_dim*4, hidden_dim*2),
            disc_block(hidden_dim*2, hidden_dim),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self,img):
        return self.disc(self, img)

# Training


In [9]:
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)
gen_optim = torch.optim.Adam(gen.parameters(), lr = lr)
disc_optim = torch.optim.Adam(disc.parameters(), lr = lr)

In [10]:
def gen_loss(z_dim, n_img, criterion, gen, disc, real, device):
    fake_noise = get_noise(num_images, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake)
    gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
    return gen_loss

In [11]:
def disc_loss(z_dim, n_img, criterion, gen, disc, real, device):
    fake_noise = get_noise(n_img, z_dim, device=device)
    fake = gen(fake_noise)
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    return disc_loss

In [12]:
from tqdm.auto import tqdm
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True 
gen_loss = False
error = False
for epoch in range(n_epochs):
  
    for real, _ in tqdm(data_loader):
        cur_batch_size = len(real)
        real = real.view(cur_batch_size, -1).to(device)
        disc_optim.zero_grad()
        disc_loss = disc_loss(z_dim, cur_batch_size, criterion, gen, disc, real,  device)
        disc_loss.backward(retain_graph=True)
        disc_optim.step()
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()
        gen_optim.zero_grad()
        gen_loss = gen_loss(z_dim, cur_batch_size, criterion, gen, disc, real,  device)
        gen_loss.backward()
        gen_optim.step()

  0%|          | 0/469 [00:00<?, ?it/s]

TypeError: forward() takes 2 positional arguments but 3 were given