# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import torchmetrics

# Dataset
class SuperResolutionDataset(Dataset):
    """Paired low-/high-resolution image dataset.

    Both directories are listed and sorted by filename; the i-th low-resolution
    file is paired with the i-th high-resolution file. The directories must
    therefore contain matching files in the same sorted order.

    Args:
        low_res_dir: directory holding the low-resolution input images.
        high_res_dir: directory holding the high-resolution target images.
        subset_size: if not ``None``, keep only the first ``subset_size``
            sorted files of each directory.

    Raises:
        ValueError: if (after optional subsetting) the two directories do not
            provide the same number of entries, which would otherwise silently
            mis-pair images or raise IndexError during iteration.
    """

    def __init__(self, low_res_dir, high_res_dir, subset_size=None):
        self.low_res_dir = low_res_dir
        self.high_res_dir = high_res_dir
        self.low_res_images = sorted(os.listdir(low_res_dir))
        self.high_res_images = sorted(os.listdir(high_res_dir))
        # `is not None` so that subset_size=0 really means "no images" instead
        # of falling through to the full dataset.
        if subset_size is not None:
            self.low_res_images = self.low_res_images[:subset_size]
            self.high_res_images = self.high_res_images[:subset_size]
        if len(self.low_res_images) != len(self.high_res_images):
            raise ValueError(
                f"{low_res_dir!r} has {len(self.low_res_images)} entries but "
                f"{high_res_dir!r} has {len(self.high_res_images)}; "
                "low/high-resolution images are paired by sorted filename"
            )
        self.transform = transforms.ToTensor()

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.low_res_images)

    @staticmethod
    def _load_rgb(path):
        """Load the image at ``path`` as an RGB PIL image."""
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a separate converted copy.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``(low_res, high_res)`` as tensors produced by ``ToTensor``."""
        low_res = self._load_rgb(os.path.join(self.low_res_dir, self.low_res_images[index]))
        high_res = self._load_rgb(os.path.join(self.high_res_dir, self.high_res_images[index]))
        return self.transform(low_res), self.transform(high_res)


# Shallow Feature Extractor
class ShallowFeatureExtractor(nn.Module):
    """Lift an image into feature space with one 3x3 convolution and a ReLU.

    Stride 1 and padding 1 keep the spatial size unchanged; only the channel
    count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Kept as a Sequential named `head` so state_dict keys stay head.0.*.
        self.head = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Apply conv + ReLU to ``x``."""
        features = self.head(x)
        return features

# CRRB
class CRRB(nn.Module):
    """Combine a high-resolution feature map with a lower-resolution one.

    ``f_high`` passes through conv-ReLU-conv. ``f_low`` is reflect-padded to
    even height/width, halved by a strided conv, refined by conv-ReLU, doubled
    again by a transposed conv, and finally bilinearly resized to the spatial
    size of the high branch. The two branches are summed.

    Both inputs are indexed as (N, C, H, W) and must have ``channels``
    channels. NOTE(review): reflect padding fails when an odd dimension of
    ``f_low`` has size 1.
    """

    def __init__(self, channels):
        super().__init__()
        # High-resolution branch: conv -> ReLU -> conv (size preserving).
        self.conv_high = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        )
        # k=4, s=2, p=1: the conv halves H/W and the transposed conv doubles them.
        self.downsample = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)
        self.upsample = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, padding=1)
        self.conv_low = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, f_high, f_low):
        """Return ``conv_high(f_high)`` plus the resized low-resolution branch."""
        high_feat = self.conv_high(f_high)

        # Pad bottom/right by one pixel on any odd dimension so the stride-2
        # downsampling divides evenly.
        height = f_low.shape[2]
        width = f_low.shape[3]
        padded_low = F.pad(f_low, (0, width % 2, 0, height % 2), mode='reflect')

        round_trip = self.upsample(self.conv_low(self.downsample(padded_low)))

        # Align exactly with the high branch before summing.
        low_feat = F.interpolate(
            round_trip, size=high_feat.shape[2:], mode='bilinear', align_corners=False
        )
        return high_feat + low_feat



# RAB
class RAB(nn.Module):
    """Fuse two feature maps with learned per-channel gates.

    ``f2`` is bilinearly resized to ``f1``'s spatial size. The concatenation
    of both maps is global-average-pooled and fed through a 1x1-conv
    bottleneck (2C -> C/2 -> 2C) with a sigmoid. The resulting gates are split
    in two and used to weight ``f1`` and the resized ``f2`` before summing.
    Both inputs must have ``channels`` channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        hidden = channels // 2
        self.fc = nn.Sequential(
            nn.Conv2d(2 * channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 2 * channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, f1, f2):
        """Return ``f1 * g1 + resize(f2) * g2`` with gates computed from both inputs."""
        f2_aligned = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=False)
        stacked = torch.cat([f1, f2_aligned], dim=1)
        gates = self.fc(self.pool(stacked))
        gate1, gate2 = gates.chunk(2, dim=1)
        return f1 * gate1 + f2_aligned * gate2

# CRFAN Model
class CRFAN(nn.Module):
    """Progressive x8 super-resolution network.

    A shallow feature extractor is followed by three stride-2 transposed-conv
    stages (x2, x4, x8 of the input size). Each stage is refined by CRRB
    blocks that take the previous stage as the low-resolution guide. RAB
    blocks then fuse the x2/x4/x8 features, a 3x3 conv maps them back to
    ``out_channels``, and the bicubically upsampled input is added as a global
    residual. The residual addition requires ``out_channels == in_channels``.
    """

    def __init__(self, in_channels=3, out_channels=3, num_features=64):
        super().__init__()
        nf = num_features

        def make_upsampler():
            # Transposed conv with k=4, s=2, p=1 exactly doubles H and W.
            return nn.ConvTranspose2d(nf, nf, kernel_size=4, stride=2, padding=1)

        # Submodules are created in the same order and under the same names as
        # before so parameter init and state_dict layout are unchanged.
        self.shallow_feat = ShallowFeatureExtractor(in_channels, nf)

        self.up2 = make_upsampler()
        self.crrb2 = nn.Sequential(*[CRRB(nf) for _ in range(4)])

        self.up4 = make_upsampler()
        self.crrb4 = nn.Sequential(*[CRRB(nf) for _ in range(2)])

        self.up8 = make_upsampler()
        self.crrb8 = CRRB(nf)

        self.rab1 = RAB(nf)
        self.rab2 = RAB(nf)
        self.reconstruct = nn.Conv2d(nf, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Super-resolve ``x`` to 8x its spatial size."""
        feat_1x = self.shallow_feat(x)

        # CRRB blocks take two inputs, so they are applied manually rather
        # than through Sequential.__call__.
        feat_2x = self.up2(feat_1x)
        for crrb in self.crrb2:
            feat_2x = crrb(feat_2x, feat_1x)

        feat_4x = self.up4(feat_2x)
        for crrb in self.crrb4:
            feat_4x = crrb(feat_4x, feat_2x)

        feat_8x = self.crrb8(self.up8(feat_4x), feat_4x)

        fused = self.rab2(feat_8x, self.rab1(feat_4x, feat_2x))
        detail = self.reconstruct(fused)

        # Global residual: the input bicubically upsampled (x8) to the output size.
        base = F.interpolate(x, size=detail.shape[2:], mode='bicubic', align_corners=False)
        return detail + base

# Training
def train_model(model, dataloader, num_epochs=10, lr=1e-5, device='cuda', accumulation_steps=4):
    """Train ``model`` with an MSE loss, mixed precision and gradient accumulation.

    After each epoch, the mean per-batch PSNR is compared with the best seen
    so far. On improvement, the weights are saved to ``crfan_best_model.pth``.
    The weights after the final epoch are saved to
    ``crfan_model_last_epoch.pth``.

    Args:
        model: network mapping a low-resolution batch to a super-resolved one;
            its output is bicubically resized to the target size before the loss.
        dataloader: yields ``(lr_img, hr_img)`` batches and must support
            ``len()``. Pixel values are assumed to be in [0, 1]
            (PSNR/SSIM use a data range of 1.0).
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: training device. Mixed precision (autocast + GradScaler) is
            enabled only when this is a CUDA device.
        accumulation_steps: number of batches whose (averaged) gradients are
            accumulated before each optimizer step.

    Raises:
        ValueError: if ``accumulation_steps < 1`` or ``dataloader`` is empty.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader yields no batches; nothing to train on")

    # Mixed precision here is CUDA-specific (torch.cuda.amp); run plain fp32
    # on any other device.
    use_amp = torch.device(device).type == 'cuda'

    # Move the model first so the optimizer is built over parameters that
    # already live on the training device.
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ssim_metric = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    criterion = nn.MSELoss().to(device)

    # -inf so the first epoch always yields a "best" checkpoint, whatever its PSNR.
    best_psnr = float('-inf')
    best_epoch = 0

    for epoch in range(num_epochs):
        epoch_loss, epoch_psnr, epoch_ssim = 0.0, 0.0, 0.0
        optimizer.zero_grad()
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)

        for step, (lr_img, hr_img) in enumerate(pbar):
            lr_img = lr_img.to(device, non_blocking=True)
            hr_img = hr_img.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(lr_img)
                # Compare at the target's exact spatial size.
                out_resized = F.interpolate(out, size=hr_img.shape[2:], mode='bicubic')
                loss = criterion(out_resized, hr_img)

            # Metrics need no autograd graph. They are computed in fp32 outside
            # autocast, where convolution-based metrics like SSIM could otherwise
            # be cast to fp16. The loss already is the batch MSE.
            with torch.no_grad():
                mse = loss.detach().float()
                psnr = 10 * torch.log10(1.0 / mse)
                ssim_val = ssim_metric(out_resized.detach().float(), hr_img.float())

            # Divide by the size of the current accumulation group so the
            # accumulated gradient is a mean over its batches, not a sum. The
            # last group of an epoch may be shorter than accumulation_steps.
            group_start = (step // accumulation_steps) * accumulation_steps
            group_size = min(accumulation_steps, num_batches - group_start)
            scaler.scale(loss / group_size).backward()

            if (step + 1) % accumulation_steps == 0 or (step + 1) == num_batches:
                # Unscale first so clipping acts on the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            epoch_loss += loss.item()
            epoch_psnr += psnr.item()
            epoch_ssim += ssim_val.item()

            pbar.set_postfix({
                "Loss": f"{epoch_loss / (step + 1):.4f}",
                "PSNR": f"{epoch_psnr / (step + 1):.2f}",
                "SSIM": f"{epoch_ssim / (step + 1):.4f}"
            })

        avg_psnr = epoch_psnr / num_batches
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            best_epoch = epoch + 1
            torch.save(model.state_dict(), "crfan_best_model.pth")
            print(f"\n📦 Best model saved at epoch {best_epoch} with PSNR: {best_psnr:.2f} dB")

    torch.save(model.state_dict(), "crfan_model_last_epoch.pth")
    print("\n✅ Final model saved as crfan_model_last_epoch.pth")

# Main
if __name__ == "__main__":
    # Expose only physical GPU 1 to this process. This is only honoured if CUDA
    # has not been initialised yet (e.g. not by an earlier notebook cell).
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    # Let cuDNN benchmark convolution algorithms for the observed input shapes.
    torch.backends.cudnn.benchmark = True

    low_res_path = "/home/serenag/super_resolution/CalebA/lr_downsampled_images"
    high_res_path = "/home/serenag/super_resolution/CalebA/hr_upsampled_images"

    dataset = SuperResolutionDataset(low_res_dir=low_res_path, high_res_dir=high_res_path)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    model = CRFAN()
    train_model(model=model, dataloader=dataloader, num_epochs=10, lr=1e-5)
