In [12]:
import torch
import torch.nn as nn 
import numpy as np
import torchvision
from torchvision.utils import make_grid
from tqdm import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset

In [2]:
from sklearn.datasets import fetch_openml
# Fetch version 1 of the 'mnist_784' dataset from OpenML.
# as_frame=False asks scikit-learn for NumPy arrays rather than pandas objects.
mnist = fetch_openml(name='mnist_784', version=1, as_frame=False)

In [31]:
# View the samples as (N, channels=1, height=28, width=28) image arrays.
dataset_x = np.reshape(mnist.data, (-1, 1, 28, 28))
# Cast the targets to 64-bit integers so they can be wrapped in a tensor.
dataset_y = mnist.target.astype(np.int64)

In [32]:
# Copy the NumPy image and label arrays into torch tensors (dtypes are inherited
# from the arrays; the training loop casts images to float itself).
x_train, y_train = torch.tensor(dataset_x), torch.tensor(dataset_y)


In [9]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Hyperparameters -------------------------------------------------------
LATENT_DIM = 64       # length of the noise vector fed to the generator
IN_CHANNELS = 1       # number of image channels
IM_SIZE = (28, 28)    # image (height, width)
BATCH_SIZE = 64       # samples per optimisation step
NUM_EPOCHS = 55       # full passes over the training set
NROWS = 15            # not used in the code shown here (presumably sample-grid rows -- TODO confirm)

# DEFINING THE GENERATOR CLASS

In [10]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    The network is a stack of ``Linear`` layers with widths
    ``LATENT_DIM -> 128 -> 256 -> 512 -> H*W*C``. Every hidden layer is
    followed by ``BatchNorm1d`` and ``LeakyReLU``; the output layer has no
    normalisation and ends in ``Tanh``, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        self.latent_dim = LATENT_DIM
        self.img_size = IM_SIZE
        self.channels = IN_CHANNELS

        # One LeakyReLU instance is shared by all hidden blocks (it has no state).
        hidden_act = nn.LeakyReLU()
        flat_pixels = self.img_size[0] * self.img_size[1] * self.channels
        widths = [self.latent_dim, 128, 256, 512, flat_pixels]
        final_idx = len(widths) - 2  # index of the output block

        blocks = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            linear = nn.Linear(fan_in, fan_out)
            if idx == final_idx:
                # Output block: no normalisation, squash with tanh.
                norm, act = nn.Identity(), nn.Tanh()
            else:
                norm, act = nn.BatchNorm1d(fan_out), hidden_act
            blocks.append(nn.Sequential(linear, norm, act))
        self.layers = nn.ModuleList(blocks)

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: tensor whose first dimension is the batch size and which can be
                reshaped to ``(-1, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, channels, img_size[0], img_size[1])``.
        """
        n_samples = z.shape[0]
        hidden = z.reshape(-1, self.latent_dim)
        for block in self.layers:
            hidden = block(hidden)
        return hidden.reshape(n_samples, self.channels, self.img_size[0], self.img_size[1])
    
        
                


# DEFINING THE DISCRIMINATOR CLASS

In [11]:
class Discriminator(nn.Module):
    """Fully connected discriminator producing one raw logit per image.

    The network is a stack of ``Linear`` layers with widths
    ``H*W*C -> 512 -> 256 -> 128 -> 1``. Every hidden layer is followed by
    ``LayerNorm`` and ``LeakyReLU``; the output layer is a bare ``Linear``
    (no normalisation, no activation), so the result is an unbounded logit.
    """

    def __init__(self):
        super().__init__()
        self.img_size = IM_SIZE
        self.channels = IN_CHANNELS

        # One LeakyReLU instance is shared by all hidden blocks (it has no state).
        hidden_act = nn.LeakyReLU()
        flat_pixels = self.img_size[0] * self.img_size[1] * self.channels
        widths = [flat_pixels, 512, 256, 128, 1]
        final_idx = len(widths) - 2  # index of the output block

        blocks = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            linear = nn.Linear(fan_in, fan_out)
            if idx == final_idx:
                # Output block: leave the logit untouched.
                norm, act = nn.Identity(), nn.Identity()
            else:
                norm, act = nn.LayerNorm(fan_out), hidden_act
            blocks.append(nn.Sequential(linear, norm, act))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: tensor that can be reshaped to ``(-1, H*W*C)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding raw (pre-sigmoid) logits.
        """
        flat_pixels = self.img_size[0] * self.img_size[1] * self.channels
        hidden = x.reshape(-1, flat_pixels)
        for block in self.layers:
            hidden = block(hidden)
        return hidden
        

# DEFINING THE TRAIN FUNCTION

In [None]:
def train():
    """Train the generator and discriminator adversarially on MNIST.

    Reads the module-level ``x_train``/``y_train`` tensors and the
    ``BATCH_SIZE``, ``NUM_EPOCHS``, ``LATENT_DIM`` and ``device`` settings.
    Each mini-batch performs one discriminator update followed by one
    generator update, both using ``BCEWithLogitsLoss`` on the discriminator's
    raw logits. The mean losses of every epoch are printed.

    Returns:
        tuple: ``(generator, discriminator)`` -- the trained modules, left in
        training mode. Earlier versions returned ``None``; callers that ignore
        the return value are unaffected.
    """
    mnist_dataset = TensorDataset(x_train, y_train)
    mnist_loader = DataLoader(mnist_dataset, batch_size=BATCH_SIZE, shuffle=True)

    generator = Generator().to(device)
    generator.train()

    discriminator = Discriminator().to(device)
    discriminator.train()

    optimizer_generator = Adam(generator.parameters(), lr=1E-4, betas=(0.5, 0.999))
    optimizer_discriminator = Adam(discriminator.parameters(), lr=1E-4, betas=(0.5, 0.999))
    criterion = torch.nn.BCEWithLogitsLoss()

    steps = 0  # total number of mini-batches processed so far

    for epoch_no in range(NUM_EPOCHS):
        generator_losses = []
        discriminator_losses = []

        # Labels are not needed for an unconditional GAN.
        for im, _ in tqdm(mnist_loader):
            # The generator ends in Tanh, i.e. it emits values in [-1, 1].
            # mnist_784 pixels are stored as 0..255 intensities, so rescale the
            # real images to the same range; otherwise the discriminator could
            # tell real from fake purely by magnitude.
            real_ims = im.float().to(device) / 127.5 - 1.0
            batch_size = real_ims.shape[0]

            real_label = torch.ones(batch_size, device=device)
            fake_label = torch.zeros(batch_size, device=device)

            # ---- Discriminator step: real -> 1, fake -> 0 ----
            optimizer_discriminator.zero_grad()
            fake_im_noise = torch.randn((batch_size, LATENT_DIM), device=device)
            fake_ims = generator(fake_im_noise)
            disc_real_pred = discriminator(real_ims)
            # detach() keeps this step from back-propagating into the generator.
            disc_fake_pred = discriminator(fake_ims.detach())
            disc_real_loss = criterion(disc_real_pred.reshape(-1), real_label)
            disc_fake_loss = criterion(disc_fake_pred.reshape(-1), fake_label)
            disc_loss = (disc_real_loss + disc_fake_loss) / 2
            disc_loss.backward()
            optimizer_discriminator.step()

            # ---- Generator step: make the discriminator output 1 on fakes ----
            optimizer_generator.zero_grad()
            fake_im_noise = torch.randn((batch_size, LATENT_DIM), device=device)
            fake_ims = generator(fake_im_noise)
            disc_fake_pred = discriminator(fake_ims)
            gen_loss = criterion(disc_fake_pred.reshape(-1), real_label)
            gen_loss.backward()
            optimizer_generator.step()

            generator_losses.append(gen_loss.item())
            discriminator_losses.append(disc_loss.item())
            steps += 1

        print(
            'Epoch {}/{} | steps: {} | generator loss: {:.4f} | discriminator loss: {:.4f}'.format(
                epoch_no + 1,
                NUM_EPOCHS,
                steps,
                np.mean(generator_losses),
                np.mean(discriminator_losses),
            )
        )

    return generator, discriminator
            






In [38]:
# Scratch check: iterate a small list of tuples under a tqdm progress bar.
tp = [
    (1, 2),
    (7, 0),
    (3, 6),
    (0, 4),
    (7, 9),
]

for x in tqdm(tp):
    # print() inserts its default single-space separator between the two arguments.
    print(x, " - ")

100%|██████████| 5/5 [00:00<?, ?it/s]

(1, 2)  - 
(7, 0)  - 
(3, 6)  - 
(0, 4)  - 
(7, 9)  - 



