In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Ongoing Device : {device}')

Ongoing Device : cuda


In [3]:
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 on the single
# channel maps them to [-1, 1], the same range as the Generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data on first use, served in shuffled batches of 64.
train_dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

In [4]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent noise vectors to flattened images.

    Three hidden layers of width ``hidden_dim``, ``2 * hidden_dim`` and
    ``4 * hidden_dim`` (each followed by LeakyReLU(0.2)) feed a final linear layer
    and ``Tanh``, so every output value lies in [-1, 1].

    With the defaults the architecture is 100 -> 256 -> 512 -> 1024 -> 784, and the
    ``state_dict`` keys (``fc.0``, ``fc.2``, ``fc.4``, ``fc.6``) are unchanged, so
    previously saved checkpoints still load.

    Args:
        image_dim: Length of each flattened output image (784 = 28 * 28 for MNIST).
        hidden_dim: Width of the first hidden layer; the next two are 2x and 4x it.
        latent_dim: Length of each input noise vector.
    """

    def __init__(self, image_dim=784, hidden_dim=256, latent_dim=100):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 4, image_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to images of shape (batch, image_dim)."""
        return self.fc(x)
        

        
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring flattened images as real vs. fake.

    Three hidden layers of width ``4 * hidden_dim``, ``2 * hidden_dim`` and
    ``hidden_dim`` (each followed by LeakyReLU(0.2) and Dropout(0.3)) feed a single
    output unit with ``Sigmoid``, giving one probability per input row.

    With the defaults the architecture is 784 -> 1024 -> 512 -> 256 -> 1, and the
    ``state_dict`` keys (``fc.0``, ``fc.3``, ``fc.6``, ``fc.9``) are unchanged, so
    previously saved checkpoints still load.

    Args:
        image_dim: Length of each flattened input image (784 = 28 * 28 for MNIST).
        hidden_dim: Width of the last hidden layer; the earlier two are 4x and 2x it.
        latent_dim: Not used by this class; kept for signature compatibility.
    """

    def __init__(self, image_dim=784, hidden_dim=256, latent_dim=100):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(image_dim, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map images of shape (batch, image_dim) to probabilities of shape (batch, 1)."""
        return self.fc(x)

In [5]:
# Instantiate both networks with their default sizes and move them to the selected device.
G, D = Generator().to(device), Discriminator().to(device)

In [6]:
# Binary cross-entropy, applied to the Discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per network with identical hyper-parameters:
# lr = 2e-4, betas = (0.5, 0.999). The generator expression is consumed in
# order, so G_optimizer wraps G and D_optimizer wraps D.
G_optimizer, D_optimizer = (
    torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (G, D)
)

In [None]:
import os

import torch.autograd as autograd

# Number of epochs
num_epochs = 200
latent_dim = 100

# Generator checkpoints are written every `ckpt_every` epochs and after the last epoch.
# torch.save does not create missing parent directories, so create the target up front.
ckpt_dir = os.path.join('checkpoints', 'GAN')
ckpt_every = 50
os.makedirs(ckpt_dir, exist_ok=True)

print(f"Ongoing Device : {device}")
# Training loop
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Flatten each image to a vector for the fully connected Discriminator
        images = images.view(images.size(0), -1).to(device)
        batch_size = images.size(0)

        # Target labels: 1 for real images, 0 for generated ones (created directly on device)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        ############################
        # Train the Discriminator
        ############################
        D_optimizer.zero_grad()

        # BCE on real images (target 1)
        outputs = D(images)
        D_loss_real = criterion(outputs, real_labels)
        # D(x) of the last batch, kept for inspection only; detached so it does not hold the graph
        real_score = outputs.detach()

        # Generate fake images
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = G(z)

        # BCE on fake images (target 0); detach so this step does not backpropagate into G
        outputs = D(fake_images.detach())
        D_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs.detach()

        # Optimize the Discriminator
        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_optimizer.step()

        ############################
        # Train the Generator
        ############################
        G_optimizer.zero_grad()

        # Generate a fresh batch of fake images
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = G(z)

        # Non-saturating generator loss: score the fakes against the *real* label
        outputs = D(fake_images)
        G_loss = criterion(outputs, real_labels)

        # This backward also writes gradients into D's parameters, but only G_optimizer
        # steps here; D's gradients are cleared by D_optimizer.zero_grad() next iteration.
        G_loss.backward()
        G_optimizer.step()

    print('Epoch [{}/{}], Step [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}'
          .format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item()))

    # Periodic checkpoints plus the final epoch, which `epoch % ckpt_every` alone would skip
    if epoch % ckpt_every == 0 or epoch == num_epochs - 1:
        torch.save(G.state_dict(), os.path.join(ckpt_dir, f'ckpt_{epoch}.pth'))

Ongoing Device : cuda
Epoch [1/200], Step [938/938], D_loss: 0.2783, G_loss: 2.6963
