In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torchinfo 

from generator import Generator
from discriminator import Discriminator
from data import ImageNetForPIXGAN
from trainer import train_models

from pathlib import Path

In [2]:
# Device selection: use the GPU when CUDA is available, otherwise fall back to the CPU.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# Training configuration (every tunable value of the notebook lives in this cell)

# Optimisation settings
LEARNING_RATE_G: float = 2e-4   # learning rate of the generator's Adam optimizer
LEARNING_RATE_D: float = 2e-4   # learning rate of the discriminator's Adam optimizer
L1_LAMBDA: int = 100            # forwarded to train_models as `l1_lambda`
BATCH_SIZE: int = 5             # batch size of the DataLoader
NUM_EPOCHS: int = 512           # forwarded to train_models as `NUM_EPOCHS`

# Output locations (relative to the notebook's working directory)
GENERATOR_SAVE_PATH: str = 'Models/abacus_generator.pth'
DISCRIMINATOR_SAVE_PATH: str = 'Models/abacus_discriminator.pth'
RESULT_SAVE_PATH: str = 'Results/Abacus'

In [4]:
# Preprocessing pipelines handed to the dataset

# Input pipeline: convert to a tensor, then convert to grayscale.
input_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Grayscale()]
)

# Target pipeline: convert to a tensor, then normalise each of the three
# channels as (x - 0.5) / 0.5.
target_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [5]:
# Build the dataset and wrap it in a shuffling DataLoader

root_dir = 'Abacus'
data = ImageNetForPIXGAN(
    root=root_dir,
    transform=input_transform,
    target_transform=target_transform,
)
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

# Display (number of samples, number of batches per epoch)
(len(data), len(dataloader))

(614, 123)

In [6]:
# Creating model instances

def initialize_weights(model):
    """Re-initialise conv and batch-norm parameters of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` kernels are drawn from N(0, 0.02).
    * ``nn.BatchNorm2d`` scale weights are drawn from N(1, 0.02) and their bias
      is set to 0. The scale multiplies the normalised activations, so it must
      be centred on 1; centring it on 0 (as for conv kernels) would squash
      every activation to roughly zero.

    Batch-norm layers built with ``affine=False`` have no learnable parameters
    and are skipped.

    Args:
        model: an ``nn.Module`` whose sub-modules are visited via ``modules()``.

    Returns:
        None. Parameters are modified in place.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight, 0.0, 0.02)
        elif isinstance(module, nn.BatchNorm2d):
            # weight/bias are None when the layer was created with affine=False
            if module.weight is not None:
                nn.init.normal_(module.weight, 1.0, 0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

# Generator: 1 input channel, 3 output channels; moved to the selected device.
generator = Generator(in_channels=1, out_channels=3)
generator = generator.to(device)
#initialize_weights(generator)

# Discriminator: constructed with 4 input channels; moved to the selected device.
discriminator = Discriminator(in_channels=4)
discriminator = discriminator.to(device)
#initialize_weights(discriminator)

In [7]:
# Loss functions, optimizers and gradient scalers

# Both criteria are handed to train_models (as `bce_loss` and `l1_loss`).
bce_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

# One Adam optimizer per network, both with betas=(0.5, 0.999).
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE_G, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE_D, betas=(0.5, 0.999))

# Gradient scaling only applies on CUDA. Enabling it explicitly from `device`
# keeps the notebook device agnostic: on a CPU-only machine an unconditional
# GradScaler() emits a warning and disables itself anyway.
g_scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
d_scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

In [8]:
# Resume from existing checkpoints when they are present

def _load_checkpoint_if_present(model, checkpoint_path, label):
    """Load a saved state dict into ``model`` if ``checkpoint_path`` is a file.

    Prints which of the two cases applied; nothing is loaded when the file is
    missing, so the model keeps its freshly constructed parameters.

    Args:
        model: module that receives the state dict via ``load_state_dict``.
        checkpoint_path: path of the checkpoint file to look for.
        label: human-readable model name used in the printed message.

    Returns:
        None.
    """
    if Path(checkpoint_path).is_file():
        # map_location remaps the stored tensors onto the current device, so a
        # checkpoint written on a CUDA machine still loads on a CPU-only one.
        state_dict = torch.load(f=checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"A {label} aleady exists... Loading that model and training it for the specified epochs...")
    else:
        print(f"A {label} does not exist in the specified path... Creating the model and training it for the specified epochs...")


_load_checkpoint_if_present(generator, GENERATOR_SAVE_PATH, "generator")
_load_checkpoint_if_present(discriminator, DISCRIMINATOR_SAVE_PATH, "discriminator")

A generator does not exist in the specified path... Creating the model and training it for the specified epochs...
A discriminator does not exist in the specified path... Creating the model and training it for the specified epochs...


In [9]:
# Run the training loop (long-running; progress is reported per epoch)

train_models(
    # networks
    generator=generator,
    discriminator=discriminator,
    # data
    dataset=data,
    dataloader=dataloader,
    # losses
    bce_loss=bce_loss,
    l1_loss=l1_loss,
    l1_lambda=L1_LAMBDA,
    # optimisation
    g_optimizer=g_optimizer,
    d_optimizer=d_optimizer,
    g_scaler=g_scaler,
    d_scaler=d_scaler,
    # run settings
    device=device,
    NUM_EPOCHS=NUM_EPOCHS,
    # output locations
    g_path=GENERATOR_SAVE_PATH,
    d_path=DISCRIMINATOR_SAVE_PATH,
    result_path=RESULT_SAVE_PATH,
)

Epoch [1/512]: 100%|██████████| 123/123 [02:20<00:00,  1.14s/it, Gen batch loss=19.3, Gen loss=24.4, Disc batch loss=0.0879, Disc loss=0.309, Real=0.889, Fake=0.041] 
Epoch [2/512]: 100%|██████████| 123/123 [02:18<00:00,  1.13s/it, Gen batch loss=21, Gen loss=20.8, Disc batch loss=0.0136, Disc loss=0.0708, Real=0.978, Fake=0.00468]  
Epoch [3/512]: 100%|██████████| 123/123 [02:16<00:00,  1.11s/it, Gen batch loss=20.8, Gen loss=19.6, Disc batch loss=0.00754, Disc loss=0.0584, Real=0.988, Fake=0.00316] 
Epoch [4/512]: 100%|██████████| 123/123 [02:15<00:00,  1.10s/it, Gen batch loss=16.2, Gen loss=19.2, Disc batch loss=0.00347, Disc loss=0.0106, Real=0.996, Fake=0.00246] 
Epoch [5/512]: 100%|██████████| 123/123 [02:12<00:00,  1.07s/it, Gen batch loss=19.6, Gen loss=19.1, Disc batch loss=0.15, Disc loss=0.0161, Real=0.843, Fake=0.0677]      
Epoch [6/512]: 100%|██████████| 123/123 [02:16<00:00,  1.11s/it, Gen batch loss=13.5, Gen loss=17.8, Disc batch loss=0.00211, Disc loss=0.0232, Real=0

KeyboardInterrupt: 

In [None]:
# Saving models

# torch.save does not create missing directories, so make sure the parent
# folders of both checkpoint paths exist before writing.
Path(GENERATOR_SAVE_PATH).parent.mkdir(parents=True, exist_ok=True)
Path(DISCRIMINATOR_SAVE_PATH).parent.mkdir(parents=True, exist_ok=True)

# Only the state dicts (parameters and buffers) are written, not the modules.
torch.save(obj=generator.state_dict(), f=GENERATOR_SAVE_PATH)
torch.save(obj=discriminator.state_dict(), f=DISCRIMINATOR_SAVE_PATH)