In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from espcn import espcn_x4

class PairedDataset(Dataset):
    """Dataset of (low-res, high-res) image pairs matched by sorted path order.

    Files in ``lr_dir`` and ``hr_dir`` are paired by position after sorting,
    so both directories must hold the same number of images and their sort
    orders must line up.

    Args:
        lr_dir: Directory containing the low-resolution input images.
        hr_dir: Directory containing the high-resolution target images.

    Raises:
        ValueError: If the two directories contain a different number of images.
    """

    # Extensions treated as images. Anything else (e.g. ``.DS_Store`` or a
    # ``.ipynb_checkpoints`` folder) is skipped instead of crashing Image.open.
    VALID_EXTS = (".png", ".jpg", ".jpeg", ".bmp")

    def __init__(self, lr_dir, hr_dir):
        self.lr_files = self._list_images(lr_dir)
        self.hr_files = self._list_images(hr_dir)

        # Pairing is positional, so a count mismatch would silently misalign
        # samples (or raise IndexError late, deep inside a training epoch).
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR/HR image count mismatch: {len(self.lr_files)} in {lr_dir!r} "
                f"vs {len(self.hr_files)} in {hr_dir!r}"
            )

        # Both images to tensor, normalized [0,1]
        self.to_tensor = transforms.ToTensor()

    @classmethod
    def _list_images(cls, directory):
        """Return sorted full paths of image files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls.VALID_EXTS)
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image, closing the underlying file handle."""
        with Image.open(path) as img:
            # convert() loads the pixel data, so the result outlives the handle.
            return img.convert("RGB")

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(lr, hr)`` pair as float tensors in [0, 1]."""
        lr = self.to_tensor(self._load_rgb(self.lr_files[idx]))  # (3, H, W)
        hr = self.to_tensor(self._load_rgb(self.hr_files[idx]))  # (3, H, W)
        return lr, hr


In [None]:
class ImageDataset(Dataset):
    """Dataset of (low-res, high-res) image pairs matched by sorted filename order.

    Only files with an image extension are considered. The i-th LR file (after
    sorting) is paired with the i-th HR file, so both directories must contain
    the same number of images.

    Args:
        lr_dir: Directory containing the low-resolution input images.
        hr_dir: Directory containing the high-resolution target images.
        transform: Optional callable applied to both PIL images (e.g.
            ``transforms.ToTensor()``). If ``None``, PIL images are returned.

    Raises:
        ValueError: If the two directories contain a different number of images.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.transform = transform

        # Only keep valid image files; sort so the two lists line up by position.
        valid_exts = (".png", ".jpg", ".jpeg", ".bmp")
        self.lr_images = sorted(f for f in os.listdir(lr_dir) if f.lower().endswith(valid_exts))
        self.hr_images = sorted(f for f in os.listdir(hr_dir) if f.lower().endswith(valid_exts))

        # Positional pairing: a count mismatch would misalign every pair or
        # raise IndexError part-way through an epoch.
        if len(self.lr_images) != len(self.hr_images):
            raise ValueError(
                f"LR/HR image count mismatch: {len(self.lr_images)} in {lr_dir!r} "
                f"vs {len(self.hr_images)} in {hr_dir!r}"
            )

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(lr, hr)`` pair, transformed if a transform is set."""
        lr_path = os.path.join(self.lr_dir, self.lr_images[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_images[idx])

        # `with` closes the file handles; convert() loads the pixels first,
        # so the returned images stay valid after the files are closed.
        with Image.open(lr_path) as lr_img:
            lr = lr_img.convert("RGB")
        with Image.open(hr_path) as hr_img:
            hr = hr_img.convert("RGB")

        if self.transform is not None:
            lr = self.transform(lr)
            hr = self.transform(hr)

        return lr, hr

In [None]:
# Paths
lr_dir = "train_LR"
hr_dir = "train_HR"

# Seed torch's global RNG so DataLoader shuffling and weight init are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
BATCH_SIZE = 4
LEARNING_RATE = 1e-4

transform = transforms.ToTensor()  # Converts to tensor in [0,1]

dataset = ImageDataset(lr_dir, hr_dir, transform=transform)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Model: espcn_x4 comes from the local `espcn` module imported at the top
# (presumably a 4x upscaler, per its name).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = espcn_x4(in_channels=3, out_channels=3, channels=64).to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [None]:
num_epochs = 100
CHECKPOINT_DIR = "epoch_result"
CHECKPOINT_EVERY = 10  # save model weights every N epochs

# torch.save does not create parent directories. Without this, the first
# checkpoint (after CHECKPOINT_EVERY epochs of training) raises FileNotFoundError.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    n_seen = 0

    for lr_batch, hr_batch in train_loader:
        lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

        sr = model(lr_batch)              # super-resolved output
        loss = criterion(sr, hr_batch)    # pixel-wise MSE loss (mean reduction)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch doesn't skew the epoch mean.
        batch_size = lr_batch.size(0)
        epoch_loss += loss.item() * batch_size
        n_seen += batch_size

    # max(..., 1) avoids ZeroDivisionError if the loader yields no samples.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / max(n_seen, 1):.6f}")

    # Save checkpoint
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(
            model.state_dict(),
            os.path.join(CHECKPOINT_DIR, f"espcn_epoch{epoch+1}.pth"),
        )


In [None]:
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Validation data, defined here so this cell runs on a fresh kernel.
# TODO: point these at your actual validation split.
VAL_LR_DIR = "val_LR"
VAL_HR_DIR = "val_HR"
val_dataset = ImageDataset(VAL_LR_DIR, VAL_HR_DIR, transform=transforms.ToTensor())
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

model.eval()
total_psnr, total_ssim = 0.0, 0.0
n_images = 0

with torch.no_grad():
    for lr_batch, hr_batch in val_loader:
        # Clamp into [0, 1] (the ToTensor range) so data_range=1.0 is valid.
        sr_batch = torch.clamp(model(lr_batch.to(device)), 0, 1).cpu()

        # Score each image on its own: works for any batch size, and the
        # averages below are per image rather than per batch.
        for sr_img, hr_img in zip(sr_batch, hr_batch):
            sr_np = sr_img.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
            hr_np = hr_img.permute(1, 2, 0).numpy()
            total_psnr += psnr(hr_np, sr_np, data_range=1.0)
            total_ssim += ssim(hr_np, sr_np, channel_axis=2, data_range=1.0)
            n_images += 1

if n_images == 0:
    raise ValueError("Validation set is empty; check VAL_LR_DIR / VAL_HR_DIR.")

print("PSNR:", total_psnr / n_images)
print("SSIM:", total_ssim / n_images)
