In [23]:
import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [24]:
class Generator(nn.Module):
    def __init__(self, z_dim = 10, img_channels = 1, hidden_dim = 64):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            self.gen_block(z_dim, hidden_dim * 16, 4, 1, 0),
            self.gen_block(hidden_dim * 16, hidden_dim * 8, 4, 2, 1),
            self.gen_block(hidden_dim * 8, hidden_dim * 4, 4,  2, 1),
            self.gen_block(hidden_dim * 4, hidden_dim * 2, 4, 2, 1),
            nn.ConvTranspose2d(hidden_dim * 2, 1, kernel_size = 4, stride = 2, padding = 1),
            nn.Tanh()
        )
        
    def gen_block(self, input_channels, output_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding, bias = False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU()
        )
    
    def forward(self, noise):
        return self.gen(noise)
    

In [25]:
def get_noise(n_samples, z_dim):
    return torch.randn(n_samples, z_dim, 1, 1).to('cuda')

In [26]:
class Discriminator(nn.Module):
    def __init__(self, img_channels, hidden_dim = 64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(img_channels, hidden_dim, kernel_size = 4, stride = 2, padding = 1),
            nn.LeakyReLU(0.2),
            self.disc_block(hidden_dim, hidden_dim * 2, 4, 2, 1),
            self.disc_block(hidden_dim * 2, hidden_dim * 4, 4, 2, 1),
            self.disc_block(hidden_dim * 4, hidden_dim * 8, 4, 2, 1), # 4x4
            nn.Conv2d(hidden_dim * 8, 1, kernel_size = 4, stride = 2, padding = 0), # 1x1
            nn.Sigmoid()
        )
    
    def disc_block(self, input_channels, output_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, bias = False),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self, img):
        return self.disc(img)

In [27]:
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [28]:
criterion = nn.BCEWithLogitsLoss()
z_dim = 100
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    FashionMNIST('.', download = False, transform = transform),
    batch_size = batch_size,
    shuffle = True
)

In [29]:
gen = Generator(z_dim, 1, 64).to(device)
init_weights(gen)
gen_opt = torch.optim.Adam(gen.parameters(), lr = lr, betas = (beta_1, beta_2))

disc = Discriminator(1, 64).to(device)
init_weights(disc)
disc_opt = torch.optim.Adam(disc.parameters(), lr = lr, betas = (beta_1, beta_2))

In [None]:
def show_tensor_images(im):
    image_tensor = (im + 1) / 2 # unnormalize images
    
    

In [31]:
n_epochs = 10
step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
display_step = 500

for epoch in range(n_epochs):
    for real, _ in(dataloader):
        real = real.to(device)
        
        disc_opt.zero_grad()
        fake_noise = get_noise(batch_size, z_dim)
        fake_img = gen(fake_noise)
        disc_preds_fake = disc(fake_img.detach()).reshape(-1)
        disc_fake_loss = criterion(disc_preds_fake, torch.zeros_like(disc_preds_fake))
        disc_preds_real = disc(real).reshape(-1)
        disc_real_loss = criterion(disc_preds_real, torch.ones_like(disc_preds_real))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        
        mean_discriminator_loss += disc_loss.item() / display_step # item() returns the value of tensor as a single python number
        disc_loss.backward(retain_graph = True)
        disc_opt.step()
        
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(batch_size, z_dim)
        fake_img_2 = gen(fake_noise_2)
        disc_preds_fake = disc(fake_img_2).reshape(-1)
        gen_loss = criterion(disc_preds_fake, torch.ones_like(disc_preds_fake))
        gen_loss.backward()
        gen_opt.step()
        
        mean_generator_loss += gen_loss.item() / display_step
        
        if step % display_step == 0 and step > 0:
            print(f"Step {step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
#             show_tensor_images(fake)
#             show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        step += 1
        

Step 500: Generator loss: 0.6945296313762679, discriminator loss: 0.5042146071195603
Step 1000: Generator loss: 0.693145994305615, discriminator loss: 0.5032055307626727
Step 1500: Generator loss: 0.6931467665433926, discriminator loss: 0.5032048879861822
Step 2000: Generator loss: 0.6931469439268133, discriminator loss: 0.5032046910524366
Step 2500: Generator loss: 0.6931470544338226, discriminator loss: 0.5032046095132799
Step 3000: Generator loss: 0.6931470632553101, discriminator loss: 0.50320456576347
Step 3500: Generator loss: 0.6931470632553101, discriminator loss: 0.5032045435905427
Step 4000: Generator loss: 0.6931470576524735, discriminator loss: 0.503204541802404
Step 4500: Generator loss: 0.6931470503807104, discriminator loss: 0.5032045313119884
