In [None]:
import torch 
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

import torchinfo
from DCGAN import Generator, Discriminator, initialize_weights
from ModelTrainer import train_models

from typing import List, Tuple
from pathlib import Path
from tqdm import tqdm

In [None]:
# Select the compute device: 'cuda' when a CUDA GPU is available, otherwise 'cpu'.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
# Hyperparameters and output locations, grouped by what they configure.

# --- Optimisation ---
LEARNING_RATE = 2e-4  # used as `lr` for both Adam optimizers
BATCH_SIZE = 32       # batch size of the training DataLoader
NUM_EPOCHS = 5

# --- Image format ---
IMAGE_SIZE = (64, 64)  # target size handed to transforms.Resize
CHANNELS = 3           # image channels for both networks and for Normalize

# --- Network widths ---
Z_DIM = 256       # passed to Generator as latent_channels
GEN_HIDDEN = 64   # passed to Generator as hidden_channels
DISC_HIDDEN = 64  # passed to Discriminator as hidden_channels

# --- Checkpoint and result locations ---
GENERATOR_SAVE_PATH = 'Models/celebal_first_generator.pth'
DISCRIMINATOR_SAVE_PATH = 'Models/celebal_first_discriminator.pth'
RESULT_PATH = 'Results/Train 1'

In [None]:
# Preprocessing pipeline applied to every training image:
# resize to IMAGE_SIZE, convert to a tensor, then normalize each channel.

# One mean / std entry of 0.5 per image channel.
channel_mean = [0.5 for _ in range(CHANNELS)]
channel_std = [0.5 for _ in range(CHANNELS)]

preprocessing_steps = [
    transforms.Resize(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
input_transform = transforms.Compose(preprocessing_steps)

In [None]:
# Build the training dataset from the image folder and wrap it in a shuffling DataLoader.

dataset_root = 'celebal_data'
train_data = datasets.ImageFolder(dataset_root, transform=input_transform)
dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Display (number of images, number of batches per epoch) as the cell output.
(len(train_data), len(dataloader))

In [None]:
# Instantiate both networks, move each to the selected device, then apply
# the project's weight initialisation.

generator = Generator(
    latent_channels=Z_DIM,
    hidden_channels=GEN_HIDDEN,
    img_channels=CHANNELS,
)
generator = generator.to(device)
initialize_weights(generator)

discriminator = Discriminator(
    in_channels=CHANNELS,
    hidden_channels=DISC_HIDDEN,
)
discriminator = discriminator.to(device)
initialize_weights(discriminator)

In [None]:
# Loss function, optimizers and gradient scalers

# Binary cross-entropy computed on raw logits.
loss_fn = nn.BCEWithLogitsLoss()

# Both networks use Adam with the same learning rate and betas.
adam_betas = (0.5, 0.999)
gen_opt = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE, betas=adam_betas)
disc_opt = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE, betas=adam_betas)

# Gradient scaling only applies on CUDA. Enable the scalers explicitly based on
# the selected device so a CPU run gets no-op scalers instead of relying on
# GradScaler warning and disabling itself when CUDA is missing.
use_grad_scaling = str(device).startswith('cuda')
gen_scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaling)
disc_scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaling)

In [None]:
# Resume from existing checkpoints when they are present on disk


def _load_checkpoint_if_present(model, checkpoint_path, label, target_device):
    """Load `model`'s weights from `checkpoint_path` if that file exists.

    Args:
        model: module whose `load_state_dict` receives the stored state dict.
        checkpoint_path: path of the checkpoint file to look for.
        label: human-readable model name used in the printed status message.
        target_device: device onto which the stored tensors are mapped while
            loading, so a checkpoint written on a GPU machine can still be
            restored on a CPU-only machine.

    Returns:
        None. (Nothing is returned so the cell shows no stray output.)
    """
    if Path(checkpoint_path).is_file():
        state_dict = torch.load(f=checkpoint_path, map_location=target_device)
        model.load_state_dict(state_dict)
        print(f"A {label} aleady exists... Loading that model and training it for the specified epochs")
    else:
        print(f"A {label} does not exist in the specified path... Creating the model and training it for the specified epochs")


_load_checkpoint_if_present(generator, GENERATOR_SAVE_PATH, "generator", device)
_load_checkpoint_if_present(discriminator, DISCRIMINATOR_SAVE_PATH, "discriminator", device)

In [None]:
# Run training: collect every argument for train_models in one place, then call it.

training_kwargs = dict(
    # Networks and data
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    # Loss, optimizers and gradient scalers
    loss_fn=loss_fn,
    gen_optimizer=gen_opt,
    disc_optimizer=disc_opt,
    gen_scaler=gen_scaler,
    disc_scaler=disc_scaler,
    # Hyperparameters
    BATCH_SIZE=BATCH_SIZE,
    Z_DIM=Z_DIM,
    NUM_EPOCHS=NUM_EPOCHS,
    device=device,
    # Output locations
    # NOTE(review): presumably where train_models writes checkpoints and
    # results; confirm against ModelTrainer.
    gen_path=GENERATOR_SAVE_PATH,
    disc_path=DISCRIMINATOR_SAVE_PATH,
    result_path=RESULT_PATH,
)

train_models(**training_kwargs)