In [1]:
import torch 
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

import torchinfo
from DCGAN2 import Generator, Discriminator, initialize_weights
from ModelTrainer2 import train_models

from typing import List, Tuple
from pathlib import Path
from tqdm import tqdm

In [2]:
# Device-agnostic setup: run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Kept as a plain string ('cuda' / 'cpu'), which both `.to(...)` and `torch.load(map_location=...)` accept.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# Hyperparameters and run configuration

# Fix torch's global RNG so runs are repeatable. DataLoader(shuffle=True) draws its
# shuffle seed from this RNG. Presumably initialize_weights and the noise sampling
# inside train_models do too — confirm in DCGAN2 / ModelTrainer2.
# torch.manual_seed seeds every device (CPU and CUDA). cuDNN kernels may still be
# non-deterministic unless torch.backends.cudnn.deterministic is also enabled.
SEED = 42
torch.manual_seed(SEED)

LEARNING_RATE = 2e-4          # Adam learning rate, shared by both optimizers
BATCH_SIZE = 32
IMAGE_SIZE = (64, 64)         # (height, width) of the images fed to the networks
CHANNELS = 3                  # image channels (RGB)
Z_DIM = 256                   # passed to Generator as latent_channels
NUM_EPOCHS = 1
DISC_HIDDEN = 128             # passed to Discriminator as hidden_channels
GEN_HIDDEN = 64               # passed to Generator as hidden_channels
GENERATOR_SAVE_PATH = 'Models/celebal_third_generator.pth'
DISCRIMINATOR_SAVE_PATH = 'Models/celebal_third_discriminator.pth'
RESULT_PATH = 'Results/Train 2'   # handed to train_models as result_path

In [4]:
# Setting up the transforms
#
# The dataset appears to be CelebA (202,599 images), whose aligned images are not
# square (178x218). Resizing straight to an (H, W) tuple would squash the faces.
# Instead, scale the shorter edge and then centre-crop to the exact target size.
# This is the same pipeline as the reference PyTorch DCGAN tutorial.
# NOTE(review): checkpoints trained with the old squashing resize will see slightly
# different inputs when resumed with this pipeline.

input_transform = transforms.Compose([
    # Shorter edge -> max(IMAGE_SIZE), aspect ratio preserved, so the crop always fits.
    transforms.Resize(size=max(IMAGE_SIZE)),
    # Exact (height, width) = IMAGE_SIZE output.
    transforms.CenterCrop(size=IMAGE_SIZE),
    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # Map [0, 1] -> [-1, 1] per channel. Presumably this matches a tanh generator output — confirm in DCGAN2.
    transforms.Normalize(mean=[0.5]*CHANNELS, std=[0.5]*CHANNELS)
])

In [5]:
# Training data
#
# ImageFolder expects celebal_data/ to hold at least one class sub-directory of images.
# The labels it yields are unused by a GAN.
# drop_last is left at its default (False), so the final batch is smaller than BATCH_SIZE.
# train_models has to cope with that partial batch.

train_data = datasets.ImageFolder(
    root='celebal_data',
    transform=input_transform,
)
dataloader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

num_images, num_batches = len(train_data), len(dataloader)
num_images, num_batches

(202599, 6332)

In [6]:
# Creating model instances
#
# Each network is built, moved to `device` (nn.Module.to returns the same module),
# and then passed to the project's initialize_weights from DCGAN2.
# The build -> init order per network is intentional. Each step may consume random
# numbers, so reordering would change the initial weights under a fixed seed.

generator = Generator(
    latent_channels=Z_DIM,
    hidden_channels=GEN_HIDDEN,
    img_channels=CHANNELS,
)
generator = generator.to(device)
initialize_weights(generator)

discriminator = Discriminator(
    in_channels=CHANNELS,
    hidden_channels=DISC_HIDDEN,
)
discriminator = discriminator.to(device)
initialize_weights(discriminator)

In [7]:
# Loss function and optimizers
#
# BCEWithLogitsLoss applies the sigmoid itself. The discriminator is therefore
# presumably expected to output raw logits — confirm in DCGAN2.
# Both networks use Adam with beta1 = 0.5, the value used in the DCGAN paper.

ADAM_BETAS = (0.5, 0.999)

loss_fn = nn.BCEWithLogitsLoss()
gen_opt, disc_opt = (
    torch.optim.Adam(params=net.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
    for net in (generator, discriminator)
)

In [8]:
# Saving and loading models
#
# Resume from existing checkpoints when they exist. Otherwise keep the freshly
# initialised weights from the model-creation cell.
# NOTE(review): only model weights are checkpointed. Optimizer state (Adam moments)
# restarts from scratch on every resume.

def load_checkpoint_if_exists(model: nn.Module, path: str, name: str) -> bool:
    """Load a saved ``state_dict`` into ``model`` in place if ``path`` is a file.

    ``load_state_dict`` copies into the existing Parameter objects. Optimizers
    created before this call therefore keep pointing at the right tensors.

    Args:
        model: network to restore.
        path: file written by ``torch.save(model.state_dict(), ...)``.
        name: label used in the progress message (e.g. "generator").

    Returns:
        True if a checkpoint was loaded, False if none was found.
    """
    if not Path(path).is_file():
        print(f"A {name} does not exist in the specified path... Creating the model and training it for the specified epochs")
        return False

    # map_location remaps tensors onto the current device. Without it, a checkpoint
    # saved on CUDA cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (or pass weights_only=True on torch >= 1.13).
    state_dict = torch.load(f=path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"A {name} already exists... Loading that model and training it for the specified epochs")
    return True

load_checkpoint_if_exists(generator, GENERATOR_SAVE_PATH, "generator")
load_checkpoint_if_exists(discriminator, DISCRIMINATOR_SAVE_PATH, "discriminator");

A generator does not exist in the specified path... Creating the model and training it for the specified epochs
A discriminator does not exist in the specified path... Creating the model and training it for the specified epochs


In [9]:
# Training the model
#
# Runs the project's training loop (ModelTrainer2.train_models) for NUM_EPOCHS.
# train_models also receives the checkpoint and result paths, so it presumably writes
# checkpoints / sample images itself — confirm in ModelTrainer2.
# The returned loss history is kept in `history` rather than only in the cell output.

train_config = dict(
    BATCH_SIZE=BATCH_SIZE,
    Z_DIM=Z_DIM,
    NUM_EPOCHS=NUM_EPOCHS,
    device=device,
    gen_path=GENERATOR_SAVE_PATH,
    disc_path=DISCRIMINATOR_SAVE_PATH,
    result_path=RESULT_PATH,
)

history = train_models(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    loss_fn=loss_fn,
    gen_optimizer=gen_opt,
    disc_optimizer=disc_opt,
    **train_config,
)
history

Epoch [1/1] : 100%|██████████| 6332/6332 [3:56:17<00:00,  2.24s/it, Gen Batch Loss=1.19, Gen Loss=1.97, Disc Batch Loss=0.616, Disc Loss=0.558, Real=0.392, Fake=0.215]    


{'Generator Loss': [1.9661598781774448],
 'Discriminator Loss': [0.5575076717129848]}

In [10]:
# Saving models
#
# Persist the final weights as state_dicts, the format the loading cell expects.
# Create each parent directory first: torch.save cannot create missing directories,
# so saving into Models/ fails on a fresh checkout where the directory does not exist.

for model, save_path in ((generator, GENERATOR_SAVE_PATH),
                         (discriminator, DISCRIMINATOR_SAVE_PATH)):
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(obj=model.state_dict(), f=save_path)