In [None]:

import torch 
from torch import nn 
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torchinfo 

from models import Generator
from models import Discriminator
from trainer import train_models

from pathlib import Path

In [None]:
# Device agnostic code: pick the accelerator once; later cells move models and
# tensors onto `device`, and the bare name at the end displays the choice.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
# Hyperparameters 

START_STEP = 8
END_STEP = 8
LEARNING_RATE = 1e-4
BATCH_SIZES = [16, 16, 16, 16, 16, 8, 5, 4, 2]
LATENT_DIM = 512
LAMBDA_GP = 10
IMG_CHANNELS = 3

PROGRESSIVE_EPOCHS = [10] * len(BATCH_SIZES)
torch.manual_seed(42)
FIXED_NOISE = torch.randn(8, LATENT_DIM, 1, 1).to(device)
torch.seed()

ROOT_DIR = 'CelebaHQ'
GENERATOR_SAVE_PATH = 'Models/first_generator.pth'
DISCRIMINATOR_SAVE_PATH = 'Models/first_discriminator.pth'
RESULT_SAVE_PATH = 'Results'

In [None]:
# Initialising models: both networks share the latent width and image channel
# count from the hyperparameter cell and are moved onto the selected device.
# The generator is built first, matching the original RNG consumption order.
generator = Generator(
    in_channels=LATENT_DIM,
    out_channels=IMG_CHANNELS,
).to(device)

discriminator = Discriminator(
    latent_channels=LATENT_DIM,
    img_channels=IMG_CHANNELS,
).to(device)

In [None]:
# Optimisers and mixed-precision gradient scalers

# Adam with betas=(0.0, 0.99), identical settings for both networks.
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))

# GradScaler only has an effect on CUDA; on a CPU-only machine it warns and
# disables itself. Enabling it explicitly per device keeps CPU runs quiet while
# leaving CUDA behaviour unchanged.
# NOTE(review): newer PyTorch releases deprecate torch.cuda.amp.GradScaler in
# favour of torch.amp.GradScaler("cuda", ...); switch once the installed version allows.
use_amp_scaling = device == 'cuda'
g_scaler = torch.cuda.amp.GradScaler(enabled=use_amp_scaling)
d_scaler = torch.cuda.amp.GradScaler(enabled=use_amp_scaling)

In [None]:
def _load_checkpoint_if_present(model, path, label):
    """Restore ``model`` from a saved state_dict at ``path`` if that file exists.

    Prints "<label>) Exists" after loading, or "<label>) Created" when no
    checkpoint is found and the model keeps its fresh initialisation.

    Args:
        model: the nn.Module whose parameters are restored in place.
        path: checkpoint file path (str or Path).
        label: prefix for the status message, e.g. "1" or "2".
    """
    if not Path(path).is_file():
        print(f"{label}) Created")
        return
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only
    # one (and vice versa) instead of failing inside torch.load.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
    # consider weights_only=True if the installed PyTorch supports it.
    state_dict = torch.load(f=path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"{label}) Exists")


_load_checkpoint_if_present(generator, GENERATOR_SAVE_PATH, "1")
_load_checkpoint_if_present(discriminator, DISCRIMINATOR_SAVE_PATH, "2")

In [None]:
# Launch training. Every argument is passed by keyword, so grouping them by
# concern below does not change the call.
train_models(
    # networks, optimisers and AMP gradient scalers
    generator=generator,
    discriminator=discriminator,
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    g_scaler=g_scaler,
    d_scaler=d_scaler,
    # progressive-growing schedule
    START_STEP=START_STEP,
    END_STEP=END_STEP,
    BATCH_SIZES=BATCH_SIZES,
    PROGRESSIVE_EPOCHS=PROGRESSIVE_EPOCHS,
    # latent size, penalty weight and fixed sampling noise
    LATENT_DIM=LATENT_DIM,
    LAMBDA_GP=LAMBDA_GP,
    FIXED_NOISE=FIXED_NOISE,
    # input directory, output paths and target device
    ROOT_DIR=ROOT_DIR,
    g_path=GENERATOR_SAVE_PATH,
    d_path=DISCRIMINATOR_SAVE_PATH,
    result_path=RESULT_SAVE_PATH,
    device=device,
)