In [1]:
import torch 
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

import torchinfo
from WGAN import Generator, Discriminator, initialize_weights
from ModelTrainer import train_models

from typing import List, Tuple
from pathlib import Path
from tqdm import tqdm

In [2]:
# Device-agnostic setup: run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# Hyperparameters and run configuration

# Reproducibility: seed torch's global RNG (all devices) so that DataLoader shuffling
# and, presumably, the weight init / noise sampling in the WGAN and ModelTrainer modules
# repeat across Restart & Run All. Called here, not as the last line, so Jupyter does not
# display the returned Generator object.
SEED = 42
torch.manual_seed(SEED)

# Optimisation
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
NUM_EPOCHS = 4
# WGAN-GP: discriminator (critic) updates per generator update, and gradient-penalty weight
DISCRIMINATOR_ITERATIONS = 5
LAMBDA_GP = 10

# Data
IMAGE_SIZE = 64
CHANNELS = 1       # MNIST images are single-channel
NUM_CLASSES = 10   # digit classes 0-9

# Model sizes
Z_DIM = 256          # generator latent size (passed as Generator(latent_channels=...))
GEN_EMBEDDING = 256  # passed as Generator(embed_size=...)
DISC_HIDDEN = 64
GEN_HIDDEN = 64

# Checkpoint / output locations
GENERATOR_SAVE_PATH = 'Models/mnist_first_generator.pth'
DISCRIMINATOR_SAVE_PATH = 'Models/mnist_first_discriminator.pth'
RESULT_PATH = 'Results/Train 1'

In [4]:
# Preprocessing pipeline applied to every MNIST image.
input_transform = transforms.Compose(
    [
        # Integer size: the shorter side is resized to IMAGE_SIZE (MNIST digits are square).
        transforms.Resize(size=IMAGE_SIZE),
        # PIL image -> float tensor with values in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1]; one mean/std entry per image channel.
        transforms.Normalize(mean=CHANNELS * [0.5], std=CHANNELS * [0.5]),
    ]
)

In [5]:
# MNIST training split, stored in (and downloaded into, if missing) the DCGAN project's dataset folder.
train_data = datasets.MNIST(
    root='../DCGAN/dataset/',
    train=True,  # the default, spelled out for clarity
    download=True,
    transform=input_transform,
)
dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Number of training images and number of batches per epoch.
len(train_data), len(dataloader)

(60000, 938)

In [6]:
# Build both networks on `device` and apply the project's weight initialisation.
# Both take num_classes, i.e. they appear to be class-conditional.
# The construct/initialise order (generator first) is kept so RNG consumption is unchanged.
generator = Generator(
    latent_channels=Z_DIM,
    hidden_channels=GEN_HIDDEN,
    img_channels=CHANNELS,
    num_classes=NUM_CLASSES,
    img_size=IMAGE_SIZE,
    embed_size=GEN_EMBEDDING,
).to(device)
initialize_weights(generator)

discriminator = Discriminator(
    in_channels=CHANNELS,
    hidden_channels=DISC_HIDDEN,
    num_classes=NUM_CLASSES,
    img_size=IMAGE_SIZE,
).to(device)
initialize_weights(discriminator)

In [7]:
# Optimizers and AMP gradient scalers.
# No loss-function object is created here; the WGAN losses are presumably computed inside train_models.

# Adam with beta1 = 0 and beta2 = 0.9, the settings used for WGAN-GP (Gulrajani et al., 2017).
ADAM_BETAS = (0.0, 0.9)
gen_opt = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
disc_opt = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

# CUDA loss scaling only has an effect on a GPU. Enabling it per device avoids the
# "CUDA is not available. Disabling" warning on CPU-only machines.
# On CUDA this is identical to the default GradScaler().
# NOTE(review): newer PyTorch prefers torch.amp.GradScaler('cuda', ...); kept for older versions.
use_amp = device == 'cuda'
gen_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
disc_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

In [8]:
# Loading existing checkpoints: resume from saved weights when present; otherwise keep the
# freshly initialised weights from the model-creation cell.

def load_checkpoint_if_exists(model: nn.Module, path: str, name: str, map_location=None) -> bool:
    """Load the state dict stored at `path` into `model` if that file exists.

    Args:
        model: module whose parameters are overwritten in place.
        path: checkpoint file expected to contain a state dict for `model`.
        name: label used in the status message (e.g. "generator").
        map_location: forwarded to torch.load, so a checkpoint saved on one device
            can be restored on another (e.g. GPU checkpoint on a CPU-only machine).

    Returns:
        True if a checkpoint was loaded, False if none was found.
    """
    if Path(path).is_file():
        # torch.load unpickles the file; only point these paths at checkpoints you trust.
        state_dict = torch.load(f=path, map_location=map_location)
        model.load_state_dict(state_dict)
        print(f"A {name} already exists... Loading that model and training it for the specified epochs")
        return True
    print(f"A {name} does not exist in the specified path... Creating the model and training it for the specified epochs")
    return False


load_checkpoint_if_exists(generator, GENERATOR_SAVE_PATH, "generator", map_location=device)
load_checkpoint_if_exists(discriminator, DISCRIMINATOR_SAVE_PATH, "discriminator", map_location=device);

A generator aleady exists... Loading that model and training it for the specified epochs
A discriminator aleady exists... Loading that model and training it for the specified epochs


In [9]:
# Train (or continue training) both networks for NUM_EPOCHS. Checkpoint and result paths
# are handed to train_models, which presumably writes to them.
# The returned per-epoch loss history is kept in `history` for later plotting and shown below.
history = train_models(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    gen_optimizer=gen_opt,
    disc_optimizer=disc_opt,
    gen_scaler=gen_scaler,
    disc_scaler=disc_scaler,
    BATCH_SIZE=BATCH_SIZE,
    Z_DIM=Z_DIM,
    NUM_EPOCHS=NUM_EPOCHS,
    device=device,
    DISC_ITERATIONS=DISCRIMINATOR_ITERATIONS,
    LAMBDA_GP=LAMBDA_GP,
    gen_path=GENERATOR_SAVE_PATH,
    disc_path=DISCRIMINATOR_SAVE_PATH,
    result_path=RESULT_PATH,
)
history

Epoch [2/5] : 100%|██████████| 938/938 [1:33:20<00:00,  5.97s/it, Gen Batch Loss=73.4, Gen Loss=76.3, Disc Batch Loss=6.95, Disc Loss=9.17]
Epoch [3/5] : 100%|██████████| 938/938 [1:32:21<00:00,  5.91s/it, Gen Batch Loss=56.6, Gen Loss=66.9, Disc Batch Loss=9.21, Disc Loss=7.88]
Epoch [4/5] : 100%|██████████| 938/938 [1:33:57<00:00,  6.01s/it, Gen Batch Loss=48.7, Gen Loss=54.6, Disc Batch Loss=10.9, Disc Loss=7.07]
Epoch [5/5] : 100%|██████████| 938/938 [1:32:40<00:00,  5.93s/it, Gen Batch Loss=35.3, Gen Loss=38.8, Disc Batch Loss=9.67, Disc Loss=6.59]


{'Generator Loss': [76.26887828200611,
  66.90160749716037,
  54.643503237889014,
  38.82902493100685],
 'Discriminator Loss': [9.172338694143397,
  7.875285068300487,
  7.070296933401877,
  6.5862765080893215]}