In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import os

from BinomDataset_Colorization import BinomDataset 
from CGAP_UNET_Colorization import UN

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
import torch.utils.data as dt
from torch.utils.data import random_split
import random

# The Trainer below is configured with accelerator="gpu", so a CUDA device is
# mandatory; there is no CPU fallback path in this notebook.
if not torch.cuda.is_available():
    raise RuntimeError("GPU not found: this notebook trains with accelerator='gpu' and requires a CUDA device.")

device = torch.device("cuda:0")

print(f'Device in use: {device}')

In [None]:
def psnrToString(inp):
    """Format a PSNR value so it can be embedded in a run/file name.

    Negative values get an 'm' (minus) prefix instead of a '-' sign,
    e.g. -40 -> 'm40'; non-negative values are stringified as-is, e.g. 30 -> '30'.
    """
    prefix, magnitude = ('m', -inp) if inp < 0 else ('', inp)
    return prefix + str(magnitude)

# PSNR range handed to BinomDataset as minPSNR / maxPSNR.
minpsnr = -40
maxpsnr = 30

# Run name encodes the PSNR range, e.g. "m40to30-256x256-ffhq-colorization-full".
name = psnrToString(minpsnr) + "to" + psnrToString(maxpsnr) + "-256x256-ffhq-colorization-full"

# TODO: set CHECKPOINT_PATH to the directory for checkpoints and logs.
# An empty string means "relative to the current working directory"
# (os.path.join('', name) == name). os.makedirs('') raises FileNotFoundError,
# so only create the directory when a non-empty path is configured.
CHECKPOINT_PATH = ''
if CHECKPOINT_PATH:
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)
CHECKPOINT_PATH, name

In [None]:
maxepochs = 20
dataset_path = ''  # TODO: set to the dataset root passed to BinomDataset

# Seed every RNG the notebook imports *before* building the dataset, so any
# randomness used during construction (if the dataset uses any) is reproducible too.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

dataset = BinomDataset(root = dataset_path, windowSize = 256, minPSNR = minpsnr, maxPSNR = maxpsnr, virtSize = 1)

# 80/20 train/validation split
total_size = len(dataset)
train_size = int(0.8 * total_size)  # 80% for training
val_size = total_size - train_size  # remaining 20% for validation

# A dedicated, seeded generator keeps the split identical across runs regardless
# of how much of the global torch RNG was consumed before this point.
train_dataset, val_dataset = random_split(dataset, [train_size, val_size],
                                          generator=torch.Generator().manual_seed(seed))

In [None]:
# Shared batch size for both loaders; drop_last=True keeps every batch full-size.
batch_size = 32

# The training loader must use train_dataset, not the full dataset: otherwise every
# validation sample is also trained on and val_loss (used for checkpointing and
# early stopping) stops measuring generalization.
# NOTE(review): with num_workers > 0, numpy's global RNG is not reseeded per worker
# by PyTorch; if BinomDataset draws noise via np.random, workers may repeat samples.
train_loader = dt.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=False, num_workers=4)
val_loader = dt.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=False, num_workers=4)

# Assumes UN logs a metric named "val_loss" during validation (monitored below) — TODO confirm.
trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, name), gradient_clip_val=0.5,
                     accelerator="gpu",
                     max_epochs=maxepochs,
                     callbacks=[ModelCheckpoint(save_weights_only=False, mode="min", monitor="val_loss", every_n_epochs=1),
                                LearningRateMonitor("epoch"),
                                # NOTE(review): patience=2000 far exceeds maxepochs (one validation
                                # run per epoch by default), so early stopping never triggers.
                                EarlyStopping('val_loss', patience=2000)])

# NOTE(review): Lightning places the module on the accelerator itself during fit;
# the explicit .to(device) is redundant but harmless.
model = UN(channels = 3, levels=10, depth=7, start_filts=32,
           up_mode = 'upsample', merge_mode = 'concat').to(device)

In [None]:
# Train, then write a final checkpoint next to the run directory as "<run name>.ckpt".
final_ckpt_path = os.path.join(CHECKPOINT_PATH, name) + '.ckpt'
trainer.fit(model, train_loader, val_loader)
trainer.save_checkpoint(final_ckpt_path)