In [1]:
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from model.ESRGAN import RRDBNet, Discriminator_VGG, train_esrgan, VGGFeatureExtractor
from model.ResNetSR import ResNetSR
from data_utils.dataset import LocalImageDataset, get_split_indices
from model.utils import get_device, train_model_single_epoch, validate_model_single_epoch, save_checkpoint, save_samples, CombinedLoss

In [2]:
generator = ResNetSR(upscale_factor=2, num_res_blocks=2, num_channels=1, num_features=32)
discriminator = Discriminator_VGG()
vgg_extractor = VGGFeatureExtractor()



In [3]:
with open("data/test_indices.json", "r") as f:
    test_indices = json.load(f)
len(test_indices)

594

In [4]:
INPUT_DIR = "data/resolution_128"
TARGET_DIR = "data/resolution_256"
NUM_EPOCHS = 30
BATCH_SIZE = 64
LEARNING_RATE = 0.01

GRAD_CLIP = 1
EDGE_WEIGHT = 0.3
PSNR_WEIGHT = 0.3

history = {
    "train_loss": [],
    "train_psnr": [],
    "val_loss": [],
    "val_psnr": []
}

In [5]:
num_images = len([f for f in os.listdir(INPUT_DIR) if f.endswith(('.jpg', '.png', '.jpeg', '.webp'))])
train_indices, val_indices = get_split_indices(num_images, test_indices)

In [6]:
train_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, train_indices), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, val_indices), batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(LocalImageDataset(INPUT_DIR, TARGET_DIR, test_indices), batch_size=BATCH_SIZE, shuffle=False)

In [7]:
g_optimizer = optim.AdamW(generator.parameters(), lr=LEARNING_RATE)
d_optimizer = optim.AdamW(vgg_extractor.parameters(), lr=LEARNING_RATE)

In [8]:
train_esrgan(
    generator,
    discriminator,
    vgg_extractor,
    train_loader,
    val_loader,
    g_optimizer,
    d_optimizer,
    num_epochs=100,
    device="cpu",
)

1111


RuntimeError: Given groups=1, weight of size [32, 1, 3, 3], expected input[64, 3, 128, 128] to have 1 channels, but got 3 channels instead