In [1]:
import torch 
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

import torchinfo
from DCGAN import Generator, Discriminator, initialize_weights
from ModelTrainer import train_models

from typing import List, Tuple
from pathlib import Path
from tqdm import tqdm

In [2]:
# Device-agnostic setup: use the GPU whenever PyTorch reports that CUDA is
# available, otherwise fall back to the CPU. Every later `.to(device)` call
# reads this value.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# Hyperparameters 

LEARNING_RATE = 2e-4
BATCH_SIZE = 32
IMAGE_SIZE = (64, 64)
CHANNELS = 1
Z_DIM = 100
NUM_EPOCHS = 5
DISC_HIDDEN = 64
GEN_HIDDEN = 64
GENERATOR_SAVE_PATH = 'Models/mnist_first_generator.pth'
DISCRIMINATOR_SAVE_PATH = 'Models/mnist_first_discriminator.pth'

In [4]:
# Image preprocessing pipeline, applied to every MNIST sample:
#   1. resize to IMAGE_SIZE,
#   2. convert to a float tensor (ToTensor scales uint8 PIL pixels to [0, 1]),
#   3. normalize each channel with mean = std = 0.5, i.e. x -> (x - 0.5) / 0.5,
#      which maps [0, 1] onto [-1, 1] (presumably to match the generator's
#      output range; confirm against DCGAN.Generator).

preprocessing_steps = [
    transforms.Resize(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * CHANNELS, std=[0.5] * CHANNELS),
]
input_transform = transforms.Compose(preprocessing_steps)

In [5]:
# Training data
# MNIST training split, downloaded into dataset/ on first run and preprocessed
# by `input_transform`. Labels are not needed for GAN training and are left as-is.

train_data = datasets.MNIST(root='dataset/', train=True, transform=input_transform, target_transform=None, download=True)

# drop_last=True guarantees every batch holds exactly BATCH_SIZE images.
# train_models also receives BATCH_SIZE, presumably to size the latent noise
# batch, so a short final batch would give mismatched real and fake batch sizes.
# With BATCH_SIZE=32 the 60,000 training images split into 1875 full batches and
# nothing is dropped. Other batch sizes (e.g. 64) would otherwise leave a remainder.
dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [6]:
# Creating model instances
# Each network is built, moved to `device`, then passed to DCGAN.initialize_weights.
# The generator is fully set up before the discriminator is constructed, matching
# the original ordering of RNG-consuming steps.

def _build_initialized(model_cls, **model_kwargs):
    """Instantiate `model_cls(**model_kwargs)`, move it to `device`, and apply
    `initialize_weights` to it. Returns the prepared model."""
    model = model_cls(**model_kwargs).to(device)
    initialize_weights(model)
    return model


generator = _build_initialized(Generator, latent_channels=Z_DIM, hidden_channels=GEN_HIDDEN, img_channels=CHANNELS)
discriminator = _build_initialized(Discriminator, in_channels=CHANNELS, hidden_channels=DISC_HIDDEN)

In [7]:
# Loss function and optimizers

# BCE applied to raw logits. This assumes the discriminator has no final sigmoid;
# confirm in DCGAN.Discriminator.
loss_fn = nn.BCEWithLogitsLoss()

# Adam with beta1 = 0.5 instead of the default 0.9, the setting the DCGAN paper
# (Radford et al., 2015) reports as stabilizing GAN training.
ADAM_BETAS = (0.5, 0.999)
gen_opt = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
disc_opt = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

# Mixed-precision gradient scalers (one per optimizer). Loss scaling only applies
# on CUDA. Enable it explicitly there, rather than letting GradScaler warn and
# disable itself at runtime on a CPU-only machine. CPU runs behave the same
# either way (scaling is a no-op when disabled).
USE_AMP = device == 'cuda'
gen_scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
disc_scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

In [8]:
# Saving and loading models
# Resume from existing checkpoints if present, otherwise train from freshly
# initialized weights.

def _resume_if_saved(model: nn.Module, save_path: str, label: str, map_location=None) -> bool:
    """Load a saved state dict into `model` in place if `save_path` is a file.

    Args:
        model: Network whose parameters are restored via ``load_state_dict``.
        save_path: Checkpoint file, assumed to hold a ``state_dict`` (it is passed
            straight to ``load_state_dict``); confirm against ModelTrainer.train_models.
        label: Model name used in the status message (e.g. "generator").
        map_location: Forwarded to ``torch.load`` so tensors saved on one device
            (e.g. a GPU) can be loaded on another (e.g. a CPU-only machine).

    Returns:
        True if a checkpoint was loaded, False if the model starts from scratch.
    """
    if Path(save_path).is_file():
        model.load_state_dict(torch.load(f=save_path, map_location=map_location))
        print(f"A {label} already exists... Loading that model and training it for the specified epochs")
        return True
    print(f"A {label} does not exist in the specified path... Creating the model and training it for the specified epochs")
    return False


# Bind the results so the cell shows only the status messages, not a bool repr.
generator_resumed = _resume_if_saved(generator, GENERATOR_SAVE_PATH, "generator", map_location=device)
discriminator_resumed = _resume_if_saved(discriminator, DISCRIMINATOR_SAVE_PATH, "discriminator", map_location=device)

A generator does not exist in the specified path... Creating the model and training it for the specified epochs
A discriminator does not exist in the specified path... Creating the model and training it for the specified epochs


In [9]:
# Training the model
# train_models returns a dict of loss histories ('Generator Loss',
# 'Discriminator Loss'), which judging by the displayed output hold one value per
# epoch. Keep it in `history` so later cells (e.g. a loss plot) can use it
# without relying on Out[N] / `_`, which break on Restart & Run All.

history = train_models(generator=generator,
                       discriminator=discriminator,
                       dataloader=dataloader,
                       loss_fn=loss_fn,
                       gen_optimizer=gen_opt,
                       disc_optimizer=disc_opt,
                       gen_scaler=gen_scaler,
                       disc_scaler=disc_scaler,
                       BATCH_SIZE=BATCH_SIZE,
                       Z_DIM=Z_DIM,
                       NUM_EPOCHS=NUM_EPOCHS,
                       device=device,
                       gen_path=GENERATOR_SAVE_PATH,
                       disc_path=DISCRIMINATOR_SAVE_PATH)
history

Epoch [1/5] : 100%|██████████| 1875/1875 [19:34<00:00,  1.60it/s, Gen Batch Loss=2.89, Gen Loss=2.33, Disc Batch Loss=0.213, Disc Loss=0.324, Real=0.76, Fake=0.11]      
Epoch [2/5] : 100%|██████████| 1875/1875 [19:29<00:00,  1.60it/s, Gen Batch Loss=4, Gen Loss=3.5, Disc Batch Loss=0.0373, Disc Loss=0.185, Real=0.966, Fake=0.0378]      
Epoch [3/5] : 100%|██████████| 1875/1875 [19:23<00:00,  1.61it/s, Gen Batch Loss=0.321, Gen Loss=3.93, Disc Batch Loss=1.12, Disc Loss=0.164, Real=0.17, Fake=0.00728]   
Epoch [4/5] : 100%|██████████| 1875/1875 [19:26<00:00,  1.61it/s, Gen Batch Loss=1.4, Gen Loss=4.14, Disc Batch Loss=0.541, Disc Loss=0.154, Real=0.435, Fake=0.000573]  
Epoch [5/5] : 100%|██████████| 1875/1875 [19:49<00:00,  1.58it/s, Gen Batch Loss=6.04, Gen Loss=4.41, Disc Batch Loss=0.0241, Disc Loss=0.133, Real=0.969, Fake=0.0157]     


{'Generator Loss': [2.3328766955137255,
  3.497848308424155,
  3.931248655498028,
  4.143418270414074,
  4.408899548407023],
 'Discriminator Loss': [0.3243900042417149,
  0.18486669129083555,
  0.1638568659228583,
  0.15395608094533283,
  0.1331067289258043]}