In [1]:
import os
import torch
import torch.nn as nn 
import numpy as np
import torchvision
from torchvision.utils import make_grid
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset

In [2]:
from sklearn.datasets import fetch_openml

# Fetch version 1 of the 'mnist_784' dataset from OpenML.
# as_frame=False asks for array-like data/target instead of pandas objects.
mnist = fetch_openml(
    'mnist_784',
    version=1,
    as_frame=False,
)

In [3]:
# Reshape every flat sample into a (1, 28, 28) image; -1 lets NumPy infer
# the number of samples.
dataset_x = mnist.data.reshape((-1, 1, 28, 28))

# Cast the targets to 64-bit integers.
dataset_y = mnist.target.astype(np.int64)

In [45]:
# Copy the NumPy arrays into torch tensors. No dtype is passed, so each tensor
# keeps the dtype of its source array (train() casts images with .float()).
x_train, y_train = torch.tensor(dataset_x), torch.tensor(dataset_y)


In [99]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Model / image geometry ---
LATENT_DIM = 64       # length of the noise vector fed to the generator
IN_CHANNELS = 1       # number of image channels
IM_SIZE = (28, 28)    # spatial size of each image

# --- Training schedule ---
BATCH_SIZE = 64
NUM_EPOCHS = 50

# --- Sample grids written by infer() ---
NUM_SAMPLES = 225     # images generated per saved grid
NROWS = 15            # forwarded to make_grid as `nrow`

# DEFINING THE GENERATOR CLASS

In [6]:
class Generator(nn.Module):
    """Fully connected GAN generator that maps latent vectors to images.

    The network is a stack of ``Linear -> BatchNorm1d -> LeakyReLU`` stages
    (widths 128, 256, 512) followed by a final ``Linear -> Tanh`` projection,
    so every output value lies in [-1, 1].

    Args:
        latent_dim: Length of the latent vector. ``None`` (default) uses the
            notebook-level ``LATENT_DIM``.
        img_size: Two-element spatial size of the generated image. ``None``
            (default) uses the notebook-level ``IM_SIZE``.
        channels: Number of image channels. ``None`` (default) uses the
            notebook-level ``IN_CHANNELS``.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super().__init__()
        # Explicit arguments win; otherwise fall back to the notebook config so
        # the zero-argument call `Generator()` behaves exactly as before.
        self.latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        self.img_size = IM_SIZE if img_size is None else tuple(img_size)
        self.channels = IN_CHANNELS if channels is None else channels

        out_features = self.img_size[0] * self.img_size[1] * self.channels
        layers_dim = [self.latent_dim, 128, 256, 512, out_features]
        last = len(layers_dim) - 2  # index of the output stage

        # LeakyReLU holds no parameters or buffers, so one shared instance is fine.
        activation = nn.LeakyReLU()

        # Keep the ModuleList-of-Sequential layout: checkpoint keys such as
        # "layers.0.0.weight" depend on it.
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(layers_dim[i], layers_dim[i + 1]),
                # Hidden stages are batch-normalised; the output stage is not.
                nn.Identity() if i == last else nn.BatchNorm1d(layers_dim[i + 1]),
                # Tanh on the output stage bounds pixels to [-1, 1].
                nn.Tanh() if i == last else activation,
            )
            for i in range(len(layers_dim) - 1)
        ])

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Tensor whose first dimension is the batch size and whose
                remaining elements flatten to ``latent_dim`` values per sample.

        Returns:
            Tensor of shape ``(batch, channels, img_size[0], img_size[1])``
            with values in [-1, 1].
        """
        batch_size = z.shape[0]
        out = z.reshape(-1, self.latent_dim)
        for layer in self.layers:
            out = layer(out)
        return out.reshape(batch_size, self.channels, self.img_size[0], self.img_size[1])
    
        
                


# DEFINING THE DISCRIMINATOR CLASS

In [7]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator that scores images as real or fake.

    The network is a stack of ``Linear -> LayerNorm -> LeakyReLU`` stages
    (widths 512, 256, 128) followed by a single-unit ``Linear`` output with no
    normalisation or activation, i.e. it returns one raw logit per image.

    Args:
        img_size: Two-element spatial size of the input image. ``None``
            (default) uses the notebook-level ``IM_SIZE``.
        channels: Number of image channels. ``None`` (default) uses the
            notebook-level ``IN_CHANNELS``.
    """

    def __init__(self, img_size=None, channels=None):
        super().__init__()
        # Explicit arguments win; otherwise fall back to the notebook config so
        # the zero-argument call `Discriminator()` behaves exactly as before.
        self.img_size = IM_SIZE if img_size is None else tuple(img_size)
        self.channels = IN_CHANNELS if channels is None else channels

        in_features = self.img_size[0] * self.img_size[1] * self.channels
        layers_dim = [in_features, 512, 256, 128, 1]
        last = len(layers_dim) - 2  # index of the output stage

        # LeakyReLU holds no parameters or buffers, so one shared instance is fine.
        activation = nn.LeakyReLU()

        # Keep the ModuleList-of-Sequential layout: checkpoint keys such as
        # "layers.0.0.weight" depend on it.
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(layers_dim[i], layers_dim[i + 1]),
                # Hidden stages are layer-normalised; the output stage is not.
                nn.Identity() if i == last else nn.LayerNorm(layers_dim[i + 1]),
                # The output stage stays linear so it emits an unbounded logit.
                nn.Identity() if i == last else activation,
            )
            for i in range(len(layers_dim) - 1)
        ])

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor that flattens to
                ``img_size[0] * img_size[1] * channels`` values per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` holding one raw logit per image.
        """
        out = x.reshape(-1, self.img_size[0] * self.img_size[1] * self.channels)
        for layer in self.layers:
            out = layer(out)
        return out
        

In [87]:
def add_color_channels(images):
    """Turn a batch of grayscale images into randomly tinted RGB images.

    Each image gets three independent random factors in [0.2, 1.0), one per
    colour channel. The grayscale image is multiplied by each factor and the
    three results are stacked as the channel dimension.

    Args:
        images: Grayscale batch shaped ``(batch, H, W)`` or ``(batch, 1, H, W)``.
            Values are presumably expected in [0, 1], since the result is
            clamped to that range (confirm against the caller).

    Returns:
        Tensor of shape ``(batch, 3, H, W)`` on the same device as ``images``,
        clamped to [0, 1].

    Raises:
        ValueError: If ``images`` is neither ``(batch, H, W)`` nor
            ``(batch, 1, H, W)``.
    """
    # The (batch, 1, 1) tints below only broadcast per-image against a
    # (batch, H, W) tensor. Against (batch, 1, H, W) they would silently
    # broadcast to (batch, batch, H, W), so drop the singleton channel first.
    if images.dim() == 4:
        if images.size(1) != 1:
            raise ValueError(
                "add_color_channels expects single-channel images, got shape {}".format(
                    tuple(images.shape)
                )
            )
        images = images.squeeze(1)
    elif images.dim() != 3:
        raise ValueError(
            "add_color_channels expects (batch, H, W) or (batch, 1, H, W), got shape {}".format(
                tuple(images.shape)
            )
        )

    device = images.device
    batch = images.size(0)

    # One random tint per image and per colour, in [0.2, 1.0). The
    # red -> green -> blue draw order is kept so that seeded runs on
    # (batch, H, W) input reproduce the previous results.
    red_tint = torch.rand(batch, 1, 1, device=device) * 0.8 + 0.2
    green_tint = torch.rand(batch, 1, 1, device=device) * 0.8 + 0.2
    blue_tint = torch.rand(batch, 1, 1, device=device) * 0.8 + 0.2

    # Scale the grayscale image by each tint and stack the results as the
    # channel dimension -> (batch, 3, H, W).
    colored_images = torch.stack(
        [images * red_tint, images * green_tint, images * blue_tint], dim=1
    )

    # Keep the result inside [0, 1].
    return torch.clamp(colored_images, 0, 1)

# DEFINING THE SAMPLING (INFER) AND TRAIN FUNCTIONS

In [102]:
def infer(generated_sample_count, generator):
    """Generate a grid of samples with ``generator`` and save it as a PNG.

    ``NUM_SAMPLES`` latent vectors are drawn, passed through ``generator``,
    mapped from [-1, 1] to [0, 1], tiled with ``make_grid(nrow=NROWS)`` and
    written to ``samples/<generated_sample_count>.png``.

    This function does not change the generator's train/eval mode; ``train``
    switches the generator to eval mode before calling it.

    Args:
        generated_sample_count: Value used as the output file name (without
            extension).
        generator: Model mapping ``(NUM_SAMPLES, LATENT_DIM)`` noise to images.
    """
    # Draw from the standard normal prior with torch.randn, matching the noise
    # used in train(). torch.rand would give uniform [0, 1) noise, a
    # distribution the generator is never trained on.
    fake_im_noise = torch.randn((NUM_SAMPLES, LATENT_DIM), device=device)

    # Sampling needs no gradients, even when the caller forgets no_grad().
    with torch.no_grad():
        fake_ims = generator(fake_im_noise)

    # Map the generator's [-1, 1] range to [0, 1] for image conversion.
    ims = torch.clamp(fake_ims, -1., 1.).detach().cpu()
    ims = (ims + 1) / 2

    grid = make_grid(ims, nrow=NROWS)
    img = torchvision.transforms.ToPILImage()(grid)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs('samples', exist_ok=True)
    img.save('samples/{}.png'.format(generated_sample_count))

In [103]:
def train():
    """Train the GAN on MNIST, saving sample grids and checkpoints as it goes.

    Builds a shuffled ``DataLoader`` over ``(x_train, y_train)``, creates a
    fresh ``Generator`` and ``Discriminator`` on ``device``, and runs
    ``NUM_EPOCHS`` epochs. Each step updates the discriminator first, then the
    generator, both with ``BCEWithLogitsLoss``.

    Side effects:
        * every 300 optimisation steps a sample grid is written via ``infer``;
        * after every epoch both state dicts are saved to
          ``generator_ckpt.pth`` and ``discriminator_ckpt.pth`` (overwritten
          each epoch).
    """
    sample_every = 300  # optimisation steps between saved sample grids

    mnist_dataset = TensorDataset(x_train, y_train)
    mnist_loader = DataLoader(mnist_dataset, batch_size=BATCH_SIZE, shuffle=True)

    generator = Generator().to(device)
    generator.train()

    discriminator = Discriminator().to(device)
    discriminator.train()

    optimizer_generator = Adam(generator.parameters(), lr=1E-4, betas=(0.5, 0.999))
    optimizer_discriminator = Adam(discriminator.parameters(), lr=1E-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCEWithLogitsLoss()

    steps = 0
    generated_sample_count = 0

    for epoch_no in range(NUM_EPOCHS):
        print("Epoch no - ",epoch_no+1,"/",NUM_EPOCHS)
        # The class labels are not used: this is an unconditional GAN.
        for im, _ in tqdm(mnist_loader):
            # x_train is built from mnist.data without any scaling, i.e. raw
            # 0-255 pixel intensities. Rescale to [-1, 1] so real images share
            # the range of the generator's Tanh output; otherwise the
            # discriminator can separate real from fake by magnitude alone.
            real_ims = im.float().to(device) / 127.5 - 1.0
            batch_size = real_ims.shape[0]

            # Targets shared by both updates in this step.
            real_label = torch.ones((batch_size, 1), device=device)
            fake_label = torch.zeros((batch_size, 1), device=device)

            # ---- Discriminator update ----
            optimizer_discriminator.zero_grad()
            fake_im_noise = torch.randn((batch_size, LATENT_DIM), device=device)
            fake_ims = generator(fake_im_noise)
            disc_real_pred = discriminator(real_ims)
            # detach() keeps this loss from back-propagating into the generator.
            disc_fake_pred = discriminator(fake_ims.detach())
            disc_real_loss = criterion(disc_real_pred.reshape(-1), real_label.reshape(-1))
            disc_fake_loss = criterion(disc_fake_pred.reshape(-1), fake_label.reshape(-1))
            disc_loss = (disc_real_loss + disc_fake_loss) / 2
            disc_loss.backward()
            optimizer_discriminator.step()

            # ---- Generator update ----
            optimizer_generator.zero_grad()
            fake_im_noise = torch.randn((batch_size, LATENT_DIM), device=device)
            fake_ims = generator(fake_im_noise)
            disc_fake_pred = discriminator(fake_ims)
            # The generator is rewarded when fakes are classified as real.
            gen_fake_loss = criterion(disc_fake_pred.reshape(-1), real_label.reshape(-1))
            gen_fake_loss.backward()
            optimizer_generator.step()

            # ---- Periodic sample grid ----
            if (steps % sample_every) == 0:
                with torch.no_grad():
                    # eval mode makes BatchNorm use its running statistics.
                    generator.eval()
                    infer(generated_sample_count, generator)
                    generated_sample_count += 1
                    generator.train()
            steps += 1

        # Checkpoint once per epoch (files are overwritten).
        torch.save(generator.state_dict(), 'generator_ckpt.pth')
        torch.save(discriminator.state_dict(), 'discriminator_ckpt.pth')

    print('Done Training.....')

In [None]:
# Start training. This runs NUM_EPOCHS epochs, writes sample grids under
# samples/ and overwrites the generator/discriminator checkpoint files.
train()