In [None]:
# Python Packages
from torch.utils.data import DataLoader

import numpy as np

import warnings
import random
import torch
import os

# Local Modules
from utilities.utils import train_model_progressive, plot_loss
from utilities.datasets import SID_dataset
from models.QStormernew import PASTormer
from models.QStormer import QStormer
from models.Restormer import Restormer 

In [None]:
# --- Run configuration -------------------------------------------------------
# Suppress every warning category globally to keep training logs compact.
warnings.filterwarnings('ignore')

# Allow cuDNN to auto-tune kernels for the fixed input size.
# NOTE(review): benchmark mode may select non-deterministic algorithms, so the
# seeding below does not guarantee bit-exact reproducibility on GPU.
torch.backends.cudnn.benchmark = True

# Seed every RNG source used in the project with the same value, in order:
# Python stdlib, NumPy (legacy global state), torch CPU, then all CUDA devices.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(42)

# Use the first visible GPU when available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length passed to SID_dataset and to the trainer as the image size.
IM_SIZE = 256

In [None]:
# --- Datasets & loaders -------------------------------------------------------
# Each split is a pair of folders: model inputs (x/) and their targets (y/).
input_path = "data/training/x/"
label_path = "data/training/y/"

valid_input_path = "data/validation/x/"
valid_label_path = "data/validation/y/"

batch_size = 8

# Pinned host memory only speeds up host->GPU copies; it is pointless on CPU.
pin_memory = DEVICE.type == 'cuda'

dataset_train = SID_dataset(input_path, label_path, IM_SIZE)
dataset_valid = SID_dataset(valid_input_path, valid_label_path, IM_SIZE)
train_loader = DataLoader(
    dataset_train, batch_size=batch_size, num_workers=4, shuffle=True, pin_memory=pin_memory)
# Validation is iterated one sample at a time in a fixed order. Shuffling here
# gains nothing and would make the sampler draw from torch's global RNG each
# epoch, perturbing the seeded random stream used by training.
valid_loader = DataLoader(
    dataset_valid, batch_size=1, num_workers=4, shuffle=False, pin_memory=pin_memory)

In [None]:
# --- Model -------------------------------------------------------------------
# Build the network directly on the selected device.
qnet = PASTormer().to(DEVICE)

# Count every scalar in every parameter tensor (trainable and frozen alike).
param_net = sum(map(torch.numel, qnet.parameters()))

# NOTE(review): the network is a PASTormer but is reported as 'QStormer';
# the same label names its output folder later, so it is left unchanged here.
print('QStormer:\t\t' + str(param_net))

In [None]:
# --- Training setup ----------------------------------------------------------
# Each entry is a [network, name] pair. The name is also used for the
# trained_models/<name>/ directory created before that model is trained.
models = [
    [qnet, 'QStormer'],
]

# Keyword arguments forwarded verbatim to train_model_progressive.
# NOTE(review): the plural keys ('train_loaders', 'valid_loaders', 'im_sizes')
# hint that the trainer may also accept per-stage lists for progressive
# resizing; a single loader / size is passed here — confirm in utilities.utils.
# 'lr_min' is presumably the floor of an LR schedule — verify in the trainer.
train_dict = dict(
    train_loaders=train_loader,
    valid_loaders=valid_loader,
    device=DEVICE,
    epoch=25,
    lr=5e-4,
    lr_min=5e-8,
    im_sizes=IM_SIZE,
)

In [None]:
# --- Train -------------------------------------------------------------------
# Train every registered model in turn and collect [history, name] pairs
# (history = whatever train_model_progressive returns) for plotting below.
results = []
for model in models:
    model_name = model[1]
    # makedirs also creates the parent 'trained_models' folder and is a no-op
    # when the directory already exists (no check-then-create race).
    os.makedirs(f'trained_models/{model_name}', exist_ok=True)
    print(f'Training {model_name}')
    # The trainer receives the full [network, name] pair, as before.
    history = train_model_progressive(model, **train_dict)
    results.append([history, model_name])

In [None]:
# --- Results -----------------------------------------------------------------
# Print each model's name, then plot its loss curves; plot_loss receives the
# trainer's returned history unpacked as keyword arguments.
for history, model_name in results:
    print(model_name)
    plot_loss(**history)