In [1]:
import torch 
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

import torchinfo
from WGAN import Generator, Discriminator, initialize_weights
from ModelTrainer import train_models

from typing import List, Tuple
from pathlib import Path
from tqdm import tqdm

In [2]:
# Device-agnostic setup: use the GPU when PyTorch can see one, otherwise fall back to the CPU.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# Hyperparameters

# Optimisation
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
NUM_EPOCHS = 4

# Data: images are resized to IMAGE_SIZE and have CHANNELS colour channels
IMAGE_SIZE = (64, 64)
CHANNELS = 3

# Model sizes
Z_DIM = 256          # latent size passed to Generator(latent_channels=...)
DISC_HIDDEN = 64     # Discriminator(hidden_channels=...)
GEN_HIDDEN = 64      # Generator(hidden_channels=...)

# WGAN-GP schedule (forwarded to train_models; presumably critic steps per
# generator step and the gradient-penalty weight — confirm in ModelTrainer)
DISCRIMINATOR_ITERATIONS = 5
LAMBDA_GP = 10

# Reproducibility: seed PyTorch's RNGs (CPU and all CUDA devices) before any
# weights are initialised, data is shuffled, or noise is sampled.
SEED = 42
torch.manual_seed(SEED)

# Checkpoint / output locations, relative to the notebook's working directory
GENERATOR_SAVE_PATH = 'Models/celebal_first_generator.pth'
DISCRIMINATOR_SAVE_PATH = 'Models/celebal_first_discriminator.pth'
RESULT_PATH = 'Results/Train 1'

In [4]:
# Preprocessing pipeline: resize each image to IMAGE_SIZE, convert it to a
# tensor, then normalise every channel with (x - 0.5) / 0.5.

per_channel_half = [0.5] * CHANNELS

input_transform = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=per_channel_half, std=per_channel_half),
])

In [5]:
# Training data
# ImageFolder expects the images to live in sub-directories of DATA_ROOT.

DATA_ROOT = '../DCGAN/celebal_data'

train_data = datasets.ImageFolder(root=DATA_ROOT, transform=input_transform)
dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# (number of images, batches per epoch); drop_last is off, so the final batch may be smaller
len(train_data), len(dataloader)

(202599, 3166)

In [6]:
# Creating model instances
# Each network is built, moved to `device`, then passed to initialize_weights.
# The generator is still created first so random draws happen in the same order.

def _to_device_and_init(model):
    """Move ``model`` to ``device``, run ``initialize_weights`` on it, and return it."""
    model = model.to(device)
    initialize_weights(model)
    return model

generator = _to_device_and_init(
    Generator(latent_channels=Z_DIM, hidden_channels=GEN_HIDDEN, img_channels=CHANNELS)
)
discriminator = _to_device_and_init(
    Discriminator(in_channels=CHANNELS, hidden_channels=DISC_HIDDEN)
)

In [7]:
# Optimizers
# No loss module is created in this cell; the losses are presumably computed
# inside train_models. Both networks use Adam with the same learning rate and
# betas (beta1 = 0.0, beta2 = 0.9).

ADAM_BETAS = (0.0, 0.9)

gen_opt = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [8]:
# Saving and loading models

def load_checkpoint_if_exists(model: nn.Module, path: str, label: str) -> bool:
    """Restore ``model``'s weights from ``path`` when that file exists.

    Args:
        model: network whose ``state_dict`` is replaced in place on success.
        path: checkpoint file written earlier with a model state dict.
        label: human-readable name ("generator", "discriminator") used in the message.

    Returns:
        True if a checkpoint was found and loaded, False if the model keeps
        its freshly initialised weights.
    """
    if Path(path).is_file():
        # map_location remaps tensors saved on another device (e.g. a GPU) onto
        # the current `device`, so a CUDA checkpoint still loads on a CPU-only box.
        # NOTE: torch.load unpickles the file — only load checkpoints you trust.
        model.load_state_dict(torch.load(f=path, map_location=device))
        print(f"A {label} already exists... Loading that model and training it for the specified epochs")
        return True
    print(f"A {label} does not exist in the specified path... Creating the model and training it for the specified epochs")
    return False

# Only model weights are restored here; the Adam states in gen_opt / disc_opt start fresh.
generator_resumed = load_checkpoint_if_exists(generator, GENERATOR_SAVE_PATH, "generator")
discriminator_resumed = load_checkpoint_if_exists(discriminator, DISCRIMINATOR_SAVE_PATH, "discriminator")

A generator aleady exists... Loading that model and training it for the specified epochs
A discriminator aleady exists... Loading that model and training it for the specified epochs


In [9]:
# Training the model
# gen_path / disc_path are presumably where train_models writes checkpoints.
# Create their parent directories up front so a save after hours of training
# cannot fail on a missing folder; mkdir(exist_ok=True) is a no-op if they exist.
for checkpoint_path in (GENERATOR_SAVE_PATH, DISCRIMINATOR_SAVE_PATH):
    Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

# Keep the returned loss history in a named variable rather than relying on _ / Out[...].
training_history = train_models(generator=generator,
                                discriminator=discriminator,
                                dataloader=dataloader,
                                gen_optimizer=gen_opt,
                                disc_optimizer=disc_opt,
                                BATCH_SIZE=BATCH_SIZE,
                                Z_DIM=Z_DIM,
                                NUM_EPOCHS=NUM_EPOCHS,
                                device=device,
                                DISC_ITERATIONS=DISCRIMINATOR_ITERATIONS,
                                LAMBDA_GP=LAMBDA_GP,
                                gen_path=GENERATOR_SAVE_PATH,
                                disc_path=DISCRIMINATOR_SAVE_PATH,
                                result_path=RESULT_PATH)
training_history

Epoch [2/4] : 100%|██████████| 3166/3166 [1:50:09<00:00,  2.09s/it, Gen Batch Loss=133, Gen Loss=216, Disc Batch Loss=52.3, Disc Loss=51.1]  
Epoch [3/4] : 100%|██████████| 3166/3166 [1:49:54<00:00,  2.08s/it, Gen Batch Loss=274, Gen Loss=232, Disc Batch Loss=28.7, Disc Loss=39.7]
Epoch [4/4] : 100%|██████████| 3166/3166 [1:49:50<00:00,  2.08s/it, Gen Batch Loss=219, Gen Loss=237, Disc Batch Loss=30.6, Disc Loss=46]   
Epoch [5/4] : 100%|██████████| 3166/3166 [1:49:45<00:00,  2.08s/it, Gen Batch Loss=227, Gen Loss=207, Disc Batch Loss=29, Disc Loss=39.6]   


{'Generator Loss': [215.55980612978632,
  231.74419914433568,
  237.17214706105216,
  206.86496381343954],
 'Discriminator Loss': [51.12361939497109,
  39.68997619277077,
  45.96103652428767,
  39.63442630972654]}