In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision
import torchvision.datasets as dataset
import torchvision.transforms as transform

import matplotlib.pyplot as plt

In [None]:
# Load the Fashion-MNIST training split and wrap it in a batched loader.

# ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 / std 0.5
# then maps them to [-1, 1], the same range as the generator's Tanh output.
transforms = transform.Compose([
    transform.ToTensor(),
    transform.Normalize((0.5,), (0.5,)),
])
data = dataset.FashionMNIST('./', train=True, download=True, transform=transforms)
# Reshuffle every epoch so mini-batches are not replayed in one fixed order.
train_data = DataLoader(data, batch_size=64, shuffle=True)


In [None]:
# Pull a single batch to sanity-check the image and label tensor shapes.
first_batch = next(iter(train_data))
d, l = first_batch
d.shape, l.shape

In [None]:
class Generator(nn.Module):
    """Label-conditioned generator built from transposed convolutions.

    A noise vector is concatenated with a learned label embedding, reshaped to
    a 1x1 spatial map, and upsampled 1x1 -> 7x7 -> 14x14 -> 28x28.

    Args:
        img_channel: Number of channels in the generated image.
        random_noise_channel: Length of the noise vector fed to ``forward``.
        num_class: Number of label classes; also the embedding width.
    """

    def __init__(self, img_channel, random_noise_channel, num_class):
        super().__init__()

        # Kept for backward compatibility only: nothing in this class reads it,
        # and actual placement is decided by whoever calls ``.to(...)``.
        self.device = "cuda"

        self.emb_vector = nn.Embedding(num_embeddings=num_class, embedding_dim=num_class)
        # Channel count of the 1x1 input map: noise plus label embedding.
        self.input_dim = random_noise_channel + num_class

        self.g_theta = nn.Sequential(
            # 1x1 -> 7x7
            nn.ConvTranspose2d(self.input_dim, 512, 7, 1, 0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),

            # 7x7 -> 14x14
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),

            # 14x14 -> 28x28
            nn.ConvTranspose2d(256, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),

            # channel reduction (spatial size unchanged)
            nn.ConvTranspose2d(64, img_channel, 3, 1, 1),
            nn.Tanh()  # output: [-1, 1]
        )

    def forward(self, x, label):
        """Generate images for noise ``x`` conditioned on ``label``.

        Args:
            x: Noise tensor of shape (B, random_noise_channel).
            label: Integer class indices of shape (B,).

        Returns:
            Tensor of shape (B, img_channel, 28, 28) with values in [-1, 1].
        """
        embeddings = self.emb_vector(label)
        g_input = torch.cat([x, embeddings], dim=1)
        # Use the configured width instead of a hard-coded 160 so any
        # noise/class combination works, not only 150 + 10.
        g_input = g_input.view(x.shape[0], self.input_dim, 1, 1)

        return self.g_theta(g_input)

In [None]:
class Discriminator(nn.Module):
    """Label-conditioned discriminator returning a real/fake probability.

    The label embedding is broadcast to one constant plane per embedding
    dimension and stacked onto the image as extra channels. The first
    convolution takes ``1 + num_class`` channels, so images must be
    single-channel; the final Linear layer takes 128*7*7 features, i.e.
    28x28 inputs after the two stride-2 convolutions.

    Args:
        num_class: Number of label classes; also the embedding width.
    """

    def __init__(self, num_class):
        super().__init__()

        self.emb_vector = nn.Embedding(num_class, num_class)

        layers = [
            # image + label planes -> 64 feature maps, spatial size halved
            nn.Conv2d(in_channels=1 + num_class, out_channels=64,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(),

            # 64 -> 128 feature maps, spatial size halved again
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(),

            # classifier head producing one probability per sample
            nn.Flatten(),
            nn.Linear(in_features=128 * 7 * 7, out_features=1),
            nn.Sigmoid(),
        ]
        self.D_w = nn.Sequential(*layers)

    def forward(self, x, label):
        """Score images ``x`` given their class ``label``.

        Args:
            x: Image tensor of shape (B, 1, H, W).
            label: Integer class indices of shape (B,).

        Returns:
            Tensor of shape (B, 1) with sigmoid outputs in [0, 1].
        """
        height, width = x.size(2), x.size(3)

        # (B, C) -> (B, C, 1, 1) -> (B, C, H, W): one constant plane per embedding dim
        label_planes = self.emb_vector(label)[:, :, None, None]
        label_planes = label_planes.expand(-1, -1, height, width)

        # Stack the label planes onto the image along the channel axis.
        stacked = torch.cat([x, label_planes], dim=1)
        return self.D_w(stacked)


In [None]:
# Select the compute device, falling back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    # GPU index 1 is preferred, but only selected when a second GPU exists;
    # otherwise the current default GPU is used.
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
    device = "cuda"
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    device = "cpu"
    print("CUDA not available; using CPU")

In [None]:
# Hyperparameters: noise-vector length and number of image channels.
random_noise_dim, image_channel = 150, 1

In [None]:
# Build both networks for 10 label classes and move them to the selected device.
generator = Generator(
    img_channel=image_channel,
    random_noise_channel=random_noise_dim,
    num_class=10,
)
generator = generator.to(device)

discriminator = Discriminator(num_class=10)
discriminator = discriminator.to(device)

# Echo both module reprs as the cell output.
generator, discriminator

In [None]:
# Both networks are optimized with Adam using identical settings.
adam_settings = {"lr": 2e-4, "betas": (0.5, 0.999)}
optimizer_g = optim.Adam(generator.parameters(), **adam_settings)
optimizer_d = optim.Adam(discriminator.parameters(), **adam_settings)

# Binary cross-entropy applied to the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

In [None]:

def show_generated_images(epoch, generator, fixed_noise, labels):
    """Plot a grid of images generated from fixed noise and labels.

    The generator is switched to eval mode for the forward pass and then put
    back into whatever mode it was in before the call, even if the forward
    pass raises.

    Args:
        epoch: Epoch number shown in the figure title.
        generator: Module called as ``generator(fixed_noise, labels)``.
        fixed_noise: Noise batch passed to the generator.
        labels: Class labels passed to the generator.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_imgs = generator(fixed_noise, labels)
            fake_imgs = fake_imgs * 0.5 + 0.5  # De-normalize
    finally:
        # Restore the caller's mode rather than forcing train mode.
        generator.train(was_training)

    grid = torchvision.utils.make_grid(fake_imgs, nrow=8)
    plt.figure(figsize=(8, 8))
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title(f'Generated Images at Epoch {epoch}')
    plt.axis('off')
    plt.show()

In [None]:
# Fixed noise and labels passed to show_generated_images for the preview grids.
fixed_random_samples = torch.randn(64, random_noise_dim)
fixed_random_samples = fixed_random_samples.to(device)
# randint over [6, 7) can only yield 6, so every preview is conditioned on class 6.
fixed_labels = torch.randint(low=6, high=7, size=(64,))
fixed_labels = fixed_labels.to(device)

In [None]:
epochs = 50

for epoch in range(epochs):
    # The batch variables are named real_images/real_labels so the loop does
    # not overwrite the module-level `data` dataset object.
    for real_images, real_labels in train_data:
        real_images = real_images.to(device)
        real_labels = real_labels.to(device)
        batch_size = real_images.shape[0]

        # BCE targets: 1 for samples from the dataset, 0 for generator samples.
        ones = torch.ones(batch_size, 1, device=device)
        zeros = torch.zeros(batch_size, 1, device=device)

        # ---- Train Discriminator ----
        z = torch.randn(batch_size, random_noise_dim, device=device)
        fake_labels = torch.randint(0, 10, (batch_size,), device=device)

        # No graph through the generator is needed for the discriminator update.
        with torch.no_grad():
            generated_images = generator(z, fake_labels)

        loss_real = criterion(discriminator(real_images, real_labels), ones)
        loss_fake = criterion(discriminator(generated_images, fake_labels), zeros)
        disc_loss = loss_real + loss_fake

        optimizer_d.zero_grad()
        disc_loss.backward()
        optimizer_d.step()

        # ---- Train Generator: 5 updates per discriminator update ----
        for _ in range(5):
            z = torch.randn(batch_size, random_noise_dim, device=device)
            fake_labels = torch.randint(0, 10, (batch_size,), device=device)

            generated_images = generator(z, fake_labels)

            # Label fakes as real so the generator is rewarded for fooling the discriminator.
            g_loss = criterion(discriminator(generated_images, fake_labels), ones)

            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

    # The reported losses are those of the last batch of the epoch.
    print(f"Epoch [{epoch+1}/{epochs}], D_loss: {disc_loss.item():.4f}, G_loss: {g_loss.item():.4f}")

    # Every 10th epoch, also render a preview grid from the fixed noise/labels.
    if (epoch + 1) % 10 == 0:
        show_generated_images(epoch + 1, generator, fixed_random_samples, fixed_labels)
        torch.cuda.empty_cache()
         
