In [1]:
import os
import json

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from data_utils.dataset import LocalImageDataset, get_split_indices
from model.ResNetSR import ResNetSR
from model.utils import get_device, train_model_single_epoch, validate_model_single_epoch, save_checkpoint, save_samples, CombinedLoss

In [2]:
with open("data/test_indices.json", "r") as f:
    test_indices = json.load(f)

In [3]:
INPUT_DIR = "data/resolution_128"
TARGET_DIR = "data/resolution_256"
NUM_EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 0.01

GRAD_CLIP = 1
EDGE_WEIGHT = 0.3

history = {
    "train_loss": [],
    "train_psnr": [],
    "val_loss": [],
    "val_psnr": []
}

In [4]:
num_images = len([f for f in os.listdir(INPUT_DIR) if f.endswith(('.jpg', '.png', '.jpeg', '.webp'))])
train_indices, val_indices = get_split_indices(num_images, test_indices)

train_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, train_indices), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, val_indices), batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, test_indices), batch_size=BATCH_SIZE, shuffle=False)

In [5]:
best_val_loss = float('inf')

save_ckpt = 2
checkpoint_dir = "ckpt"

n_samples = 5
samples_to_visualize = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, test_indices), batch_size=n_samples, shuffle=False)

sample_dir = "samples"

train_losses = []
val_losses = []

In [6]:
device = get_device()
model = ResNetSR(upscale_factor=2, num_res_blocks=2, num_channels=1, num_features=32)
model = model.to(device)
mae_loss = nn.L1Loss()
criterion = CombinedLoss(mae_loss, EDGE_WEIGHT)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=60)

In [7]:
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_psnr = train_model_single_epoch(model, train_loader, criterion, optimizer, device, scaler=None, grad_clip=GRAD_CLIP)
    val_loss, val_psnr = validate_model_single_epoch(model, val_loader, criterion, device)
    current_lr = scheduler.get_last_lr()[0]
    print(f"[Epoch {epoch}/{NUM_EPOCHS}]",
            f"Train Loss: {train_loss:.4f}, PSNR: {train_psnr:.2f} | "
            f"Val Loss: {val_loss:.4f}, PSNR: {val_psnr:.2f} | "
            f"lr: {round(current_lr, 5)}")

    if epoch % save_ckpt == 0:
        save_checkpoint(epoch+1, model, optimizer, history, checkpoint_dir)

        save_samples(epoch+1, model, val_loader, device, sample_dir, samples_to_visualize)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(0, model, optimizer, history, checkpoint_dir)

    scheduler.step(val_loss)

    history["train_loss"].append(train_loss)
    history["train_psnr"].append(train_psnr)
    history["val_loss"].append(val_loss)
    history["val_psnr"].append(val_psnr)

Training:   0%|          | 0/67 [00:00<?, ?it/s]

Training: 100%|██████████| 67/67 [00:22<00:00,  3.01it/s, loss=0.296, psnr=12.8]   


[Epoch 1/20] Train Loss: 28.0889, PSNR: -4.19 | Val Loss: 0.2720, PSNR: 14.26 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.24it/s, loss=0.202, psnr=16.4]


[Epoch 2/20] Train Loss: 0.2268, PSNR: 15.72 | Val Loss: 0.2053, PSNR: 16.22 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_3
Saved 5 samples at epoch 3 to samples/epoch_3_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.22it/s, loss=0.172, psnr=16.7]


[Epoch 3/20] Train Loss: 0.1887, PSNR: 16.40 | Val Loss: 0.1680, PSNR: 16.76 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:21<00:00,  3.17it/s, loss=0.142, psnr=17.4]


[Epoch 4/20] Train Loss: 0.1521, PSNR: 17.13 | Val Loss: 0.1478, PSNR: 17.23 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_5
Saved 5 samples at epoch 5 to samples/epoch_5_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.19it/s, loss=0.138, psnr=17.5]


[Epoch 5/20] Train Loss: 0.1391, PSNR: 17.62 | Val Loss: 0.1304, PSNR: 17.92 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:21<00:00,  3.18it/s, loss=0.118, psnr=18.3]


[Epoch 6/20] Train Loss: 0.1230, PSNR: 18.19 | Val Loss: 0.1189, PSNR: 18.39 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_7
Saved 5 samples at epoch 7 to samples/epoch_7_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.21it/s, loss=0.107, psnr=18.8] 


[Epoch 7/20] Train Loss: 0.1110, PSNR: 18.71 | Val Loss: 0.1053, PSNR: 18.94 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:21<00:00,  3.17it/s, loss=0.109, psnr=18.9] 


[Epoch 8/20] Train Loss: 0.1042, PSNR: 19.08 | Val Loss: 0.0959, PSNR: 19.41 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_9
Saved 5 samples at epoch 9 to samples/epoch_9_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.22it/s, loss=0.0853, psnr=20.1]


[Epoch 9/20] Train Loss: 0.0934, PSNR: 19.60 | Val Loss: 0.0910, PSNR: 19.78 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:21<00:00,  3.06it/s, loss=0.0768, psnr=20.5]


[Epoch 10/20] Train Loss: 0.0850, PSNR: 20.06 | Val Loss: 0.0816, PSNR: 20.23 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_11
Saved 5 samples at epoch 11 to samples/epoch_11_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.24it/s, loss=0.0966, psnr=19.5]


[Epoch 11/20] Train Loss: 0.0853, PSNR: 20.12 | Val Loss: 0.0982, PSNR: 19.47 | lr: 0.01


Training: 100%|██████████| 67/67 [00:21<00:00,  3.12it/s, loss=0.0784, psnr=20.6]


[Epoch 12/20] Train Loss: 0.0802, PSNR: 20.41 | Val Loss: 0.0752, PSNR: 20.73 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_13
Saved 5 samples at epoch 13 to samples/epoch_13_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:21<00:00,  3.12it/s, loss=0.0838, psnr=20.4]


[Epoch 13/20] Train Loss: 0.0742, PSNR: 20.79 | Val Loss: 0.0857, PSNR: 20.10 | lr: 0.01


Training: 100%|██████████| 67/67 [00:20<00:00,  3.21it/s, loss=0.0692, psnr=21.1]


[Epoch 14/20] Train Loss: 0.0765, PSNR: 20.70 | Val Loss: 0.0677, PSNR: 21.20 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_15
Saved 5 samples at epoch 15 to samples/epoch_15_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.20it/s, loss=0.0649, psnr=21.4]


[Epoch 15/20] Train Loss: 0.0680, PSNR: 21.22 | Val Loss: 0.0668, PSNR: 21.33 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.21it/s, loss=0.0674, psnr=21.3]


[Epoch 16/20] Train Loss: 0.0726, PSNR: 20.99 | Val Loss: 0.0652, PSNR: 21.40 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_17
Saved 5 samples at epoch 17 to samples/epoch_17_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.24it/s, loss=0.0569, psnr=22.1]


[Epoch 17/20] Train Loss: 0.0651, PSNR: 21.45 | Val Loss: 0.0637, PSNR: 21.56 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:20<00:00,  3.20it/s, loss=0.0591, psnr=21.9]


[Epoch 18/20] Train Loss: 0.0638, PSNR: 21.58 | Val Loss: 0.0611, PSNR: 21.79 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_19
Saved 5 samples at epoch 19 to samples/epoch_19_samples.png
Model checkpoint saved at ckpt/ckpt_0


Training: 100%|██████████| 67/67 [00:21<00:00,  3.14it/s, loss=32.1, psnr=-25.1] 


[Epoch 19/20] Train Loss: 1.6615, PSNR: 17.34 | Val Loss: 12.1742, PSNR: -20.29 | lr: 0.01


Training: 100%|██████████| 67/67 [00:21<00:00,  3.10it/s, loss=27.5, psnr=-20.1]   


[Epoch 20/20] Train Loss: 299586.6519, PSNR: -26.73 | Val Loss: 7.6428, PSNR: -17.63 | lr: 0.01
Model checkpoint saved at ckpt/ckpt_21
Saved 5 samples at epoch 21 to samples/epoch_21_samples.png
