In [None]:
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from model.ESRGAN import RRDBNet, Discriminator_VGG, train_esrgan
from data_utils.dataset import LocalImageDataset, get_split_indices
from model.utils import get_device, train_model_single_epoch, validate_model_single_epoch, save_checkpoint, save_samples, CombinedLoss

In [None]:
# Instantiate the networks with their default constructor arguments.
# `generator` is the model whose parameters are handed to the optimizer
# further down in this notebook.
generator = RRDBNet()
# NOTE(review): `vgg_extractor` is built from the *discriminator* class and is
# later passed as the third positional argument to `train_esrgan`. If that
# argument is meant to be a perceptual-loss feature extractor, this is probably
# the wrong class -- confirm against the signature of `train_esrgan`.
vgg_extractor = Discriminator_VGG()

In [None]:
# Read the held-out test indices that were persisted as JSON, then show how
# many there are as the cell's output.
with open("data/test_indices.json", "r") as indices_file:
    test_indices = json.loads(indices_file.read())
len(test_indices)

In [None]:
# --- Data locations ---------------------------------------------------------
INPUT_DIR = "data/resolution_128"
TARGET_DIR = "data/resolution_256"

# --- Optimisation schedule --------------------------------------------------
NUM_EPOCHS = 30
BATCH_SIZE = 64
LEARNING_RATE = 0.01
GRAD_CLIP = 1

# --- Loss-term weights ------------------------------------------------------
EDGE_WEIGHT = 0.3
PSNR_WEIGHT = 0.3

# Per-epoch metric log: one independent, initially empty list per series.
history = {
    metric: []
    for metric in ("train_loss", "train_psnr", "val_loss", "val_psnr")
}

In [None]:
train_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, train_indices), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, val_indices), batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, test_indices), batch_size=BATCH_SIZE, shuffle=False)

In [None]:
num_images = len([f for f in os.listdir(INPUT_DIR) if f.endswith(('.jpg', '.png', '.jpeg', '.webp'))])
train_indices, val_indices = get_split_indices(num_images, test_indices)

In [None]:
# Reconstruction objective: an L1 base loss wrapped by CombinedLoss together
# with the EDGE_WEIGHT coefficient.
mae_loss = nn.L1Loss()
criterion = CombinedLoss(mae_loss, EDGE_WEIGHT)

# AdamW over the generator's parameters, with a plateau scheduler that halves
# the learning rate when the monitored value stops decreasing.
# NOTE(review): patience=60 is larger than NUM_EPOCHS (30); if the scheduler is
# stepped once per epoch it will never reduce the LR -- confirm intent.
optimizer = optim.AdamW(
    generator.parameters(),
    lr=LEARNING_RATE,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=60,
)

In [None]:
# Adversarial training run.
# The original call referenced `discriminator`, `dataloader`, `g_optimizer`
# and `d_optimizer`, none of which were defined anywhere in the notebook, so
# the cell raised NameError. They are created here so the cell is runnable.
discriminator = Discriminator_VGG()

# Separate optimizers so generator and discriminator are updated independently.
g_optimizer = optim.AdamW(generator.parameters(), lr=LEARNING_RATE)
d_optimizer = optim.AdamW(discriminator.parameters(), lr=LEARNING_RATE)

# Use the GPU when one is present instead of hard-coding 'cuda', which fails
# on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): num_epochs is kept at its original literal value (100), which
# differs from the NUM_EPOCHS config constant (30) -- confirm which is wanted.
# NOTE(review): `vgg_extractor` is currently a Discriminator_VGG instance;
# verify that this is what train_esrgan expects for its third argument.
train_esrgan(
    generator,
    discriminator,
    vgg_extractor,
    train_loader,
    g_optimizer,
    d_optimizer,
    num_epochs=100,
    device=device,
)