In [1]:
import torch 
from torch import nn 
from torch.utils.data import DataLoader
from torchvision import transforms

from generator import Generator
from discriminator import Discriminator
from dataset import AppleOrangeData
from trainer import train_models

from pathlib import Path

In [2]:
# Device agnostic code
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# Hyperparameters

# Data
TRAIN_DIR = 'apple_orange_data/train'   # root directory handed to the dataset
BATCH_SIZE = 1                          # batch size used by the training DataLoader

# Optimisation
NUM_EPOCHS = 64
LEARNING_RATE = 2e-4                    # shared by the generator and discriminator optimizers
CYCLE_LAMBDA = 10                       # forwarded to train_models as cycle_lambda

# Output locations: every checkpoint lives in the same directory
MODEL_DIR = 'Models'
GENERATOR_G_SAVE_PATH = f'{MODEL_DIR}/generator_g.pth.tar'
GENERATOR_H_SAVE_PATH = f'{MODEL_DIR}/generator_h.pth.tar'
DISCRIMINATOR_X_SAVE_PATH = f'{MODEL_DIR}/discriminator_x.pth.tar'
DISCRIMINATOR_Y_SAVE_PATH = f'{MODEL_DIR}/discriminator_y.pth.tar'
RESULT_SAVE_PATH = 'Results/Train 1'

In [4]:
# Preprocessing / augmentation steps, applied to each image in the order listed.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize(size=(256, 256)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
input_transform = transforms.Compose(preprocessing_steps)

# Dataset rooted at TRAIN_DIR, served in shuffled batches of BATCH_SIZE.
train_data = AppleOrangeData(root_dir=TRAIN_DIR, transform=input_transform)
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Display (number of samples, number of batches per epoch).
len(train_data), len(train_dataloader)

(1019, 1019)

In [5]:
# Initialising models

# Both generators are built from the same constructor arguments,
# and likewise both discriminators.
generator_config = dict(in_channels=3, out_channels=3, num_features=64, num_residuals=9)
discriminator_config = dict(in_channels=3)

generator_G = Generator(**generator_config).to(device)
generator_H = Generator(**generator_config).to(device)
discriminator_X = Discriminator(**discriminator_config).to(device)
discriminator_Y = Discriminator(**discriminator_config).to(device)

In [6]:
# Loading model if exists

# Each network is paired with the checkpoint file it should be restored from.
# The position in this tuple (1-based) is the number shown in the status message.
checkpoints = (
    (generator_G, GENERATOR_G_SAVE_PATH),
    (generator_H, GENERATOR_H_SAVE_PATH),
    (discriminator_X, DISCRIMINATOR_X_SAVE_PATH),
    (discriminator_Y, DISCRIMINATOR_Y_SAVE_PATH),
)

for index, (model, checkpoint_path) in enumerate(checkpoints, start=1):
    model_file = Path(checkpoint_path)
    if model_file.is_file():
        # map_location remaps the stored tensors onto the selected device, so a
        # checkpoint written on a GPU can still be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(f=checkpoint_path, map_location=device))
        print(f"{index}) Exists")
    else:
        # No checkpoint on disk: the freshly initialised weights are kept.
        print(f"{index}) Created")

1) Exists
2) Exists
3) Exists
4) Exists


In [7]:
# Loss functions and optimizers

mse_loss = nn.MSELoss()
l1_loss = nn.L1Loss()

# One Adam optimizer drives both generators, another drives both discriminators;
# the two optimizers share the same settings.
adam_settings = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
gen_optimizer = torch.optim.Adam(
    params=[*generator_G.parameters(), *generator_H.parameters()],
    **adam_settings,
)
disc_optimizer = torch.optim.Adam(
    params=[*discriminator_X.parameters(), *discriminator_Y.parameters()],
    **adam_settings,
)

# Gradient scaling only applies on CUDA. Tying `enabled` to the selected device
# makes the scalers explicit no-ops on CPU instead of relying on GradScaler
# warning and disabling itself when CUDA is unavailable.
use_grad_scaling = device == 'cuda'
gen_scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaling)
disc_scaler = torch.cuda.amp.GradScaler(enabled=use_grad_scaling)

In [8]:
# Arguments for train_models, grouped by role and unpacked below in the same
# keyword order as a flat call would list them.
networks = dict(
    generator_G=generator_G,
    generator_H=generator_H,
    discriminator_X=discriminator_X,
    discriminator_Y=discriminator_Y,
)
objectives = dict(
    mse_loss=mse_loss,
    l1_loss=l1_loss,
    cycle_lambda=CYCLE_LAMBDA,
)
optimisation = dict(
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    gen_scaler=gen_scaler,
    disc_scaler=disc_scaler,
)
output_paths = dict(
    generator_G_path=GENERATOR_G_SAVE_PATH,
    generator_H_path=GENERATOR_H_SAVE_PATH,
    discriminator_X_path=DISCRIMINATOR_X_SAVE_PATH,
    discriminator_Y_path=DISCRIMINATOR_Y_SAVE_PATH,
    result_path=RESULT_SAVE_PATH,
)

train_models(
    **networks,
    dataloader=train_dataloader,
    **objectives,
    **optimisation,
    device=device,
    NUM_EPOCHS=NUM_EPOCHS,
    **output_paths,
)

Epoch [24/64]: 100%|██████████| 1019/1019 [19:54<00:00,  1.17s/it, Gen batch loss=3.51, Gen train loss=3.88, Disc batch loss=0.621, Disc train loss=0.638]
Epoch [25/64]: 100%|██████████| 1019/1019 [19:54<00:00,  1.17s/it, Gen batch loss=4.13, Gen train loss=3.89, Disc batch loss=1.05, Disc train loss=0.644] 
Epoch [26/64]: 100%|██████████| 1019/1019 [19:55<00:00,  1.17s/it, Gen batch loss=4.07, Gen train loss=3.81, Disc batch loss=0.764, Disc train loss=0.653]
Epoch [27/64]: 100%|██████████| 1019/1019 [19:51<00:00,  1.17s/it, Gen batch loss=3.17, Gen train loss=3.88, Disc batch loss=1.03, Disc train loss=0.615] 
Epoch [28/64]: 100%|██████████| 1019/1019 [19:51<00:00,  1.17s/it, Gen batch loss=3.33, Gen train loss=3.82, Disc batch loss=0.361, Disc train loss=0.619]
Epoch [29/64]: 100%|██████████| 1019/1019 [19:53<00:00,  1.17s/it, Gen batch loss=3.06, Gen train loss=3.77, Disc batch loss=1.32, Disc train loss=0.628] 
Epoch [30/64]: 100%|██████████| 1019/1019 [19:52<00:00,  1.17s/it, Gen

KeyboardInterrupt: 

In [None]:
# saving models

# Write the state dict of every network to its checkpoint path.
for model, save_path in (
    (generator_G, GENERATOR_G_SAVE_PATH),
    (generator_H, GENERATOR_H_SAVE_PATH),
    (discriminator_X, DISCRIMINATOR_X_SAVE_PATH),
    (discriminator_Y, DISCRIMINATOR_Y_SAVE_PATH),
):
    # torch.save does not create missing folders; make sure the target directory
    # exists so a first run cannot lose trained weights to a missing-path error.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(obj=model.state_dict(), f=save_path)