In [9]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [11]:
class Generator(nn.Module):
    def __init__(self, noise_dim, hidden_dim, img_size, num_channels, num_classes=10, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed = nn.Embedding(num_classes, noise_dim)
        self.fc1 = nn.Linear(noise_dim*2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, img_size*img_size*num_channels)

    def forward(self, x, y):
        leak = 1e-2

        out = torch.cat([x, self.embed(y)], dim=1)
        
        out = self.fc1(out)
        out = F.leaky_relu(out, leak)
        
        out = self.fc2(out)
        out = F.tanh(out)
        
        return out
    
x = torch.zeros(10, 100).to(device)
y = torch.arange(10).to(device)
model = Generator(100, 256, 28, 1).to(device)
output = model(x, y)
output.shape    

torch.Size([10, 784])

In [12]:
class Discriminator(nn.Module):
    def __init__(self, img_size, num_channels, hidden_dim, noise_dim, num_classes=10, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embed = nn.Embedding(num_classes, img_size*img_size*num_channels)
        self.fc1 = nn.Linear(img_size*img_size*num_channels*2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, noise_dim)
        self.fc3 = nn.Linear(noise_dim, 1)

    def forward(self, x, y):
        leak = 1e-2
        
        out = torch.cat([x, self.embed(y)], dim=1)
        
        out = self.fc1(out)
        out = F.leaky_relu(out, leak)

        out = self.fc2(out)
        out = F.leaky_relu(out, leak)

        out = self.fc3(out)

        return out
    
x = torch.zeros(10, 28*28*1).to(device)
y = torch.arange(10).to(device)
model = Discriminator(28, 1, 256, 100).to(device)
output = model(x, y)
output.shape

torch.Size([10, 1])

In [13]:
img_size = 64
num_channels = 1
noise_dim = 100
hidden_dim = 256
num_classes = 10
model_d = Discriminator(img_size, num_channels, hidden_dim, noise_dim, num_classes).to(device)
model_g = Generator(noise_dim, hidden_dim, img_size, num_channels, num_classes).to(device)

lr = 3e-4
momentum = 0.9
optimizer_d = optim.SGD(model_d.parameters(), lr=lr, momentum=momentum)
optimizer_g = optim.SGD(model_g.parameters(), lr=lr, momentum=momentum)

loss_fn = nn.BCEWithLogitsLoss()

transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(num_channels)], [0.5 for _ in range(num_channels)]
        )
    ]
)

train_data = datasets.MNIST(root='mnist/train', train=True, transform=transform, download=False)
val_data = datasets.MNIST(root='mnist/val', train=False, transform=transform, download=False)

writer_fake = SummaryWriter('logs/fake')
writer_real = SummaryWriter('logs/real')

fixed_noise = torch.randn((num_classes, 100)).to(device)
fixed_label = torch.arange(num_classes).to(device)

In [14]:
def calc_disc_loss(model_d, model_g, real, label, noise, loss_fn, is_train):
    with torch.set_grad_enabled(is_train):
        real_arg = model_d(real, label)
        loss_real = loss_fn(real_arg, torch.ones_like(real_arg))
        
        fake = model_g(noise, label)
        fake_arg = model_d(fake, label)
        loss_fake = loss_fn(fake_arg, torch.zeros_like(fake_arg))

        loss_d = loss_fake + loss_real

    return loss_d

def calc_gen_loss(model_d, model_g, noise, label, loss_fn, is_train):
    with torch.set_grad_enabled(is_train):
        fake = model_g(noise, label)
        fake_arg = model_d(fake, label)
        loss_g = loss_fn(fake_arg, torch.ones_like(fake_arg))

    return loss_g

def calc_metrics(model_d, model_g, loss_fn, train_loader, val_loader, metrics_dict, device=device):
    with torch.no_grad():
        final_str = ''
        loaders_list = [('Train', train_loader), ('Val', val_loader)]
        
        if metrics_dict == None:
            metrics_dict = {'Train': {'DiscLoss': [], 'GenLoss': []}, 'Val': {'DiscLoss': [], 'GenLoss': []}}

        for name, loader in loaders_list:
            len_data = 0
            total_loss_d = 0
            total_loss_g = 0

            for real, label in loader:
                real, label = real.to(device), label.to(device)

                batch_size = real.shape[0]
                len_data += batch_size

                real = real.view(batch_size, img_size*img_size*num_channels)
                noise = torch.randn(batch_size, noise_dim).to(device)

                loss_d = calc_disc_loss(model_d, model_g, real, label, noise, loss_fn, is_train=False)
                total_loss_d += loss_d

                loss_g = calc_gen_loss(model_d, model_g, noise, label, loss_fn, is_train=False)
                total_loss_g += loss_g
            
            disc_loss = total_loss_d/len_data
            gen_loss = total_loss_g/len_data

            final_str += ' -- {} Disc Loss: {:.5f} -- {} Gen Loss: {:.5f}'.format(name, disc_loss, name, gen_loss)
            
            metrics_dict[name]['DiscLoss'].append(disc_loss.item())
            metrics_dict[name]['GenLoss'].append(gen_loss.item()) 
    
    return final_str, metrics_dict

def visualize_tensorboard(model_g, fixed_noise, fixed_label, epoch):    
    with torch.no_grad():
        fake = model_g(fixed_noise, fixed_label).view(-1, num_channels, img_size, img_size)
        fake_images = make_grid(fake, nrow=5, normalize=True)
        writer_fake.add_image('Fake', fake_images, global_step=epoch)
    
    return None
    
def training_loop(n_epochs, disc_iter, optimizer_d, optimizer_g, model_d, model_g, train_loader, val_loader, device=device):
    metrics_dict = None
    for epoch in range(1, n_epochs+1):
        for real, label in train_loader:
            real, label = real.to(device), label.to(device)
            
            batch_size = real.shape[0]
            real = real.view(batch_size, img_size*img_size*num_channels)

            for _ in range(disc_iter):
                noise = torch.randn(batch_size, noise_dim).to(device)
                loss_d = calc_disc_loss(model_d, model_g, real, label, noise, loss_fn, is_train=True)

                optimizer_d.zero_grad()
                loss_d.backward()
                optimizer_d.step()

            noise = torch.randn(batch_size, noise_dim).to(device)
            loss_g = calc_gen_loss(model_d, model_g, noise, label, loss_fn, is_train=True)

            optimizer_g.zero_grad()
            loss_g.backward()
            optimizer_g.step()

        if epoch == 1 or epoch%2 == 0:
            final_str, metrics_dict = calc_metrics(model_d, model_g, loss_fn, train_loader, val_loader, metrics_dict, device)
            print('Epoch: {:3d}'.format(epoch) + final_str)
            
        visualize_tensorboard(model_g, fixed_noise, fixed_label, epoch)

    return model_d, model_g, metrics_dict

In [15]:
batch_size = 512
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [16]:
model_d, model_g, metrics_dict = training_loop(200, 1, optimizer_d, optimizer_g, model_d, model_g, train_loader, val_loader, device)

Epoch:   1 -- Train Disc Loss: 0.00037 -- Train Gen Loss: 0.00368 -- Val Disc Loss: 0.00038 -- Val Gen Loss: 0.00375
Epoch:   2 -- Train Disc Loss: 0.00007 -- Train Gen Loss: 0.00694 -- Val Disc Loss: 0.00007 -- Val Gen Loss: 0.00706
Epoch:   4 -- Train Disc Loss: 0.00002 -- Train Gen Loss: 0.00934 -- Val Disc Loss: 0.00002 -- Val Gen Loss: 0.00949
Epoch:   6 -- Train Disc Loss: 0.00001 -- Train Gen Loss: 0.01015 -- Val Disc Loss: 0.00001 -- Val Gen Loss: 0.01032
Epoch:   8 -- Train Disc Loss: 0.00002 -- Train Gen Loss: 0.00980 -- Val Disc Loss: 0.00002 -- Val Gen Loss: 0.00996
Epoch:  10 -- Train Disc Loss: 0.00005 -- Train Gen Loss: 0.00793 -- Val Disc Loss: 0.00005 -- Val Gen Loss: 0.00806
Epoch:  12 -- Train Disc Loss: 0.00008 -- Train Gen Loss: 0.00742 -- Val Disc Loss: 0.00008 -- Val Gen Loss: 0.00754
Epoch:  14 -- Train Disc Loss: 0.00016 -- Train Gen Loss: 0.00638 -- Val Disc Loss: 0.00016 -- Val Gen Loss: 0.00648
Epoch:  16 -- Train Disc Loss: 0.00089 -- Train Gen Loss: 0.0033