# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio
import numpy as np
from MobileSR import MobileSR

class SuperResolutionDataset(Dataset):
    """Paired low-/high-resolution image dataset read from two directories.

    Files are paired by position after sorting each directory listing: the
    i-th HR entry (sorted) is matched with the i-th LR entry (sorted). Both
    directories must therefore contain the same number of entries.

    Args:
        hr_dir: Directory holding the high-resolution images.
        lr_dir: Directory holding the corresponding low-resolution images.
        transform: Optional callable applied to both images of a pair. The
            torch RNG state is replayed between the two calls, so random
            transforms that draw from torch's RNG use identical parameters
            (e.g. the same rotation angle) for the HR and the LR image.

    Raises:
        ValueError: If the two directories hold different numbers of entries,
            since positional pairing would then silently misalign pairs.
    """

    def __init__(self, hr_dir, lr_dir, transform=None):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.transform = transform
        self.hr_images = sorted(os.listdir(hr_dir))
        self.lr_images = sorted(os.listdir(lr_dir))
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError(
                f"HR/LR directory size mismatch: {hr_dir!r} has "
                f"{len(self.hr_images)} entries but {lr_dir!r} has "
                f"{len(self.lr_images)}"
            )

    def __len__(self):
        return len(self.hr_images)

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, convert it to RGB and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _paired_transform(self, lr_image, hr_image):
        """Apply ``self.transform`` to both images with the same random draws.

        The torch RNG state is captured before transforming the HR image and
        restored before transforming the LR image, so both calls consume the
        same random numbers. Transforms that use Python's ``random`` module
        instead of torch's RNG would NOT be synchronized by this.

        Returns:
            ``(lr_image, hr_image)``, unchanged when no transform is set.
        """
        if not self.transform:
            return lr_image, hr_image
        rng_state = torch.get_rng_state()
        hr_image = self.transform(hr_image)
        torch.set_rng_state(rng_state)  # replay identical random draws for LR
        lr_image = self.transform(lr_image)
        return lr_image, hr_image

    def __getitem__(self, idx):
        """Return the ``(lr_image, hr_image)`` pair at position *idx*."""
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_images[idx]))
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_images[idx]))
        return self._paired_transform(lr_image, hr_image)


def compute_psnr(outputs, hr_images, data_range=None):
    """Compute the PSNR for the outputs and ground truth high-resolution images.

    Both tensors are detached, moved to CPU and compared as NumPy arrays with
    ``skimage.metrics.peak_signal_noise_ratio``, giving one PSNR value for the
    whole batch.

    Args:
        outputs: Model predictions; same shape as ``hr_images``.
        hr_images: Ground-truth high-resolution images.
        data_range: Peak-to-peak range of the pixel values. When ``None`` it
            is inferred from the ground truth: ``2.0`` if ``hr_images``
            contains negative values (e.g. tensors normalized to [-1, 1] by
            ``Normalize(mean=0.5, std=0.5)``), otherwise ``1.0`` (e.g. plain
            ``ToTensor()`` output in [0, 1]). Pass ``255.0`` explicitly for
            8-bit-scaled images.

    Returns:
        The PSNR in dB as returned by skimage.
    """
    outputs_np = outputs.detach().cpu().numpy()
    hr_images_np = hr_images.detach().cpu().numpy()
    if data_range is None:
        # A fixed 255.0 assumes 8-bit pixels; float tensors from ToTensor()
        # (optionally followed by Normalize) lie in [0, 1] or [-1, 1], and
        # using 255 for them inflates PSNR by tens of dB.
        data_range = 2.0 if hr_images_np.min() < 0 else 1.0
    psnr = peak_signal_noise_ratio(hr_images_np, outputs_np, data_range=data_range)
    return psnr


# Validation function
def validate_one_epoch(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without updating any weights.

    The model is put in eval mode and autograd is disabled. Per-batch loss and
    PSNR are weighted by the batch size, so a smaller final batch does not
    count as much as a full one.

    Args:
        model: Network to evaluate.
        dataloader: Iterable of ``(lr_images, hr_images)`` batches.
        criterion: Loss callable ``criterion(outputs, hr_images)``; assumed to
            use mean reduction (as ``nn.L1Loss`` does by default).
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(avg_loss, avg_psnr)``: sample-weighted means of the per-batch loss
        and per-batch PSNR (``compute_psnr``), or ``(nan, nan)`` if the
        dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_samples = 0

    with torch.no_grad():
        for lr_images, hr_images in tqdm(dataloader, desc="Validation"):
            lr_images, hr_images = lr_images.to(device), hr_images.to(device)
            outputs = model(lr_images)
            loss = criterion(outputs, hr_images)

            # loss.item() is a per-sample mean; scale by batch size so that
            # dividing by the total sample count gives the true dataset mean.
            batch_size = lr_images.size(0)
            total_loss += loss.item() * batch_size
            total_psnr += compute_psnr(outputs, hr_images) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        return float("nan"), float("nan")
    return total_loss / num_samples, total_psnr / num_samples


# Training function
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass of *model* over *dataloader*.

    For each ``(lr_images, hr_images)`` batch, the batch is moved to *device*
    and run through the forward pass. The loss
    ``criterion(outputs, hr_images)`` is then backpropagated and one optimizer
    step is taken. Per-batch loss and PSNR are weighted by the batch size, so
    a smaller final batch does not count as much as a full one.

    Args:
        model: Network to train (switched to train mode).
        dataloader: Iterable of ``(lr_images, hr_images)`` batches.
        criterion: Loss callable; assumed to use mean reduction (as
            ``nn.L1Loss`` does by default).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, avg_psnr)``: sample-weighted means of the per-batch loss
        and per-batch PSNR (``compute_psnr``, measured on the outputs computed
        before that batch's optimizer step), or ``(nan, nan)`` if the
        dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_psnr = 0.0
    num_samples = 0

    for lr_images, hr_images in tqdm(dataloader, desc="Training"):
        lr_images, hr_images = lr_images.to(device), hr_images.to(device)
        optimizer.zero_grad()
        outputs = model(lr_images)
        loss = criterion(outputs, hr_images)
        loss.backward()
        optimizer.step()

        # loss.item() is a per-sample mean; scale by batch size so that
        # dividing by the total sample count gives the true epoch mean.
        batch_size = lr_images.size(0)
        total_loss += loss.item() * batch_size
        total_psnr += compute_psnr(outputs, hr_images) * batch_size
        num_samples += batch_size

    if num_samples == 0:
        return float("nan"), float("nan")
    return total_loss / num_samples, total_psnr / num_samples


# Training and validation loop
def train_and_validate(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, save_folder, device):
    """Alternate training and validation for ``num_epochs`` epochs.

    Each epoch runs these steps in order:

    1. One training pass (``train_one_epoch``).
    2. One validation pass (``validate_one_epoch``).
    3. A single ``scheduler.step()``, so the scheduler advances once per epoch
       rather than once per batch.
    4. Save the model weights (``model.state_dict()``) to
       ``<save_folder>/model_epoch_<n>.pth``, where ``n`` is the 1-based epoch
       number.

    Args:
        model: Network being trained.
        train_loader: Loader of ``(lr_images, hr_images)`` training batches.
        val_loader: Loader of ``(lr_images, hr_images)`` validation batches.
        criterion: Loss function shared by training and validation.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        save_folder: Checkpoint directory; created if it does not exist.
        device: Device the batches are moved to.
    """
    os.makedirs(save_folder, exist_ok=True)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Optimization pass over the training set ...
        train_loss, train_psnr = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        print(f"Train Loss: {train_loss:.4f}, Train PSNR: {train_psnr:.2f}")

        # ... followed by a gradient-free pass over the validation set.
        val_loss, val_psnr = validate_one_epoch(
            model, val_loader, criterion, device
        )
        print(f"Val Loss: {val_loss:.4f}, Val PSNR: {val_psnr:.2f}")

        # The LR schedule is epoch-granular.
        scheduler.step()

        # Only the model weights are persisted (no optimizer/scheduler state).
        checkpoint_path = os.path.join(
            save_folder, f"model_epoch_{epoch + 1}.pth"
        )
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved model checkpoint to {checkpoint_path}")


# Define augmentation pipelines.
# Every pipeline ends in ToTensor (pixels in [0, 1]) followed by the same
# Normalize, which maps (x - 0.5) / 0.5 into [-1, 1]. The validation pipeline
# must apply the same normalization, or the model is evaluated on inputs and
# targets scaled differently from what it was trained on.
normalize_transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# p=1: always flip, so this dataset copy is a deterministic mirror image.
transform_train_aug1 = transforms.Compose([
    transforms.RandomHorizontalFlip(p=1),
    transforms.ToTensor(),
    normalize_transform,
])

transform_train_aug2 = transforms.Compose([
    transforms.RandomVerticalFlip(p=1),
    transforms.ToTensor(),
    normalize_transform,
])

# RandomRotation(90) would sample an angle uniformly from [-90, 90] degrees.
# A (90, 90) range gives an exact quarter turn about the image centre instead,
# matching the always-on flips above. The default expand=False keeps the
# canvas size, so non-square images are cropped/padded.
transform_train_aug3 = transforms.Compose([
    transforms.RandomRotation(degrees=(90, 90)),
    transforms.ToTensor(),
    normalize_transform,
])

# Validation transform (no augmentation, same normalization as training)
transform_val = transforms.Compose([
    transforms.ToTensor(),
    normalize_transform,
])

# Create datasets with different augmentation pipelines.
# All three training datasets read the same HR/LR pairs; they differ only in
# the augmentation applied, so concatenating them yields one copy per pipeline.
TRAIN_HR_DIR = "../BCCD_mask/ssr_small/train_HR"
TRAIN_LR_DIR = "../BCCD_mask/ssr_small/train_LR_bicubic/X4"
VAL_HR_DIR = "../BCCD_mask/ssr_small/val_HR"
VAL_LR_DIR = "../BCCD_mask/ssr_small/val_LR_bicubic/X4"

train_dataset1, train_dataset2, train_dataset3 = (
    SuperResolutionDataset(hr_dir=TRAIN_HR_DIR, lr_dir=TRAIN_LR_DIR, transform=augmentation)
    for augmentation in (transform_train_aug1, transform_train_aug2, transform_train_aug3)
)

combined_train_dataset = ConcatDataset([train_dataset1, train_dataset2, train_dataset3])

# Training batches are shuffled across all three augmented copies.
train_loader = DataLoader(combined_train_dataset, batch_size=4, shuffle=True)

# Validation keeps a fixed order so results are comparable between epochs.
val_dataset = SuperResolutionDataset(hr_dir=VAL_HR_DIR, lr_dir=VAL_LR_DIR, transform=transform_val)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# Model, optimizer, and training configuration.
# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Training configuration
num_epochs = 100
model = MobileSR(n_feats=40, n_heads=8, ratios=[4, 2, 2, 2, 4], upscaling_factor=4).to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))

# Halve the learning rate every LR_HALVING_ITERATIONS optimizer steps.
# NOTE(review): 200000 is assumed to be an iteration count (the usual SR
# schedule). train_and_validate steps the scheduler once per epoch, so StepLR's
# step_size is measured in epochs; passing 200000 directly would mean "halve
# every 200000 epochs", i.e. never. The count is converted to epochs here.
LR_HALVING_ITERATIONS = 200_000
lr_step_epochs = max(1, LR_HALVING_ITERATIONS // max(1, len(train_loader)))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_epochs, gamma=0.5)

# Train and validate the model
# NOTE(review): the folder name says "20epochs" but num_epochs is 100. The
# name is kept unchanged so checkpoints land where they did before; rename if
# desired.
save_folder = "20epochs_combineddataset_nopretrain_Normalization"
train_and_validate(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=num_epochs, save_folder=save_folder, device=device)
