In [None]:
import os
import glob
from PIL import Image, ImageFilter
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from tqdm import tqdm

In [None]:
# ===========================================
# Download LSDIR from Hugging Face into Colab
# ===========================================
!pip install datasets --quiet

import os
from datasets import load_dataset
from pathlib import Path
from PIL import Image

# 1) Define where to store the dataset locally
dataset_dir = "/content/LSDIR"
os.makedirs(dataset_dir, exist_ok=True)

# 2) Load LSDIR from Hugging Face
# Using "danjacobellis/LSDIR" which contains LR-HR pairs and ~85k images
dataset = load_dataset("danjacobellis/LSDIR")

# 3) Save LR-HR images locally for training
def _to_pil(value):
    """Coerce one HF image cell to a PIL image.

    Accepts an already-decoded PIL image, raw encoded bytes, a
    ``{"bytes": ..., "path": ...}`` dict (the undecoded HF Image format),
    or a filesystem path.
    """
    import io

    if isinstance(value, Image.Image):
        return value
    if isinstance(value, dict):
        value = value["bytes"] if value.get("bytes") is not None else value["path"]
    if isinstance(value, (bytes, bytearray)):
        # Image.open() treats a bytes object as a *filename*, so raw image
        # bytes must be wrapped in a file-like object first.
        return Image.open(io.BytesIO(value))
    return Image.open(value)


def save_images(split="train", ds=None):
    """Export one split of the HF LSDIR dataset to paired PNG folders.

    Writes ``<dataset_dir>/<split>/LR/x2/NNNNN.png`` and
    ``<dataset_dir>/<split>/HR/NNNNN.png`` using the same file name for both
    images of a pair, so sorted listings of the two folders line up.

    Args:
        split: name of the dataset split to export (e.g. "train").
        ds: HF ``DatasetDict`` to read from. Defaults to the global
            ``dataset`` loaded above. Pass it explicitly when re-running,
            because a later cell may rebind the global name ``dataset``.

    Returns:
        ``(lr_folder, hr_folder)`` paths.
    """
    if ds is None:
        ds = dataset

    lr_folder = os.path.join(dataset_dir, split, "LR", "x2")
    hr_folder = os.path.join(dataset_dir, split, "HR")
    os.makedirs(lr_folder, exist_ok=True)
    os.makedirs(hr_folder, exist_ok=True)

    for i, item in enumerate(ds[split]):
        # presumably 'lr'/'hr' hold images (decoded PIL, bytes, or paths) -- verify on the dataset card
        name = f"{i:05d}.png"
        _to_pil(item['lr']).save(os.path.join(lr_folder, name))
        _to_pil(item['hr']).save(os.path.join(hr_folder, name))

    return lr_folder, hr_folder

# Export both splits; each call returns the (LR folder, HR folder) it wrote to.
lr_train_folder, hr_train_folder = save_images("train")
lr_test_folder, hr_test_folder = save_images("test")

# 4) Verify
for _label, _path in (
    ("LR Train Folder:", lr_train_folder),
    ("HR Train Folder:", hr_train_folder),
    ("LR Test Folder:", lr_test_folder),
    ("HR Test Folder:", hr_test_folder),
):
    print(_label, _path)

# Peek at the first few directory entries of each training folder.
for _label, _path in (("Sample LR images:", lr_train_folder), ("Sample HR images:", hr_train_folder)):
    print(_label, os.listdir(_path)[:5])


In [None]:
# ===========================================
# Super-Resolution Full Pipeline for LSDIR ×2 (No Augmentation)
# ===========================================

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

# ----------------------------------------
# 1) Dataset class for pre-built LR-HR pairs (no augmentation)
# ----------------------------------------
class LSDIRDataset(Dataset):
    """Paired LR/HR images read from two folders (no augmentation).

    Images are paired by position after sorting each folder's image list, so
    both folders must hold the same number of images whose names sort in the
    same order (save_images() writes identical NNNNN.png names to both).
    Each item is ``(lr, hr)`` as float tensors produced by ``ToTensor``.

    Args:
        lr_folder: directory containing low-resolution images.
        hr_folder: directory containing high-resolution images.
        extensions: file suffixes (case-insensitive) treated as images; any
            other file (notes, checkpoints, sub-directories) is ignored so it
            cannot break the positional pairing.

    Raises:
        ValueError: if either folder contains no images.
        AssertionError: if the two folders contain different numbers of images.
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, lr_folder, hr_folder, extensions=IMAGE_EXTENSIONS):
        self.lr_files = self._list_images(lr_folder, extensions)
        self.hr_files = self._list_images(hr_folder, extensions)
        if len(self.lr_files) == 0 or len(self.hr_files) == 0:
            raise ValueError("No valid images found in LR or HR folder.")
        assert len(self.lr_files) == len(self.hr_files), "LR and HR folder must have same number of images"

        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _list_images(folder, extensions):
        """Sorted paths of regular files in ``folder`` whose suffix is in ``extensions``."""
        suffixes = tuple(ext.lower() for ext in extensions)
        candidates = glob.glob(os.path.join(folder, "*.*"))
        return sorted(p for p in candidates if os.path.isfile(p) and p.lower().endswith(suffixes))

    def __len__(self):
        return len(self.lr_files)

    def _load(self, path):
        """Read one image as RGB and convert it to a tensor, closing the file afterwards."""
        with Image.open(path) as img:
            return self.to_tensor(img.convert("RGB"))

    def __getitem__(self, idx):
        return self._load(self.lr_files[idx]), self._load(self.hr_files[idx])

# ----------------------------------------
# 2) Channel Attention module
# ----------------------------------------
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gating.

    Global-average-pools each channel to a ``(B, C)`` vector. The vector goes
    through a bottleneck MLP (C -> hidden -> C) ending in a sigmoid, and each
    input channel is multiplied by its resulting weight in (0, 1).

    Args:
        channels: number of input/output channels C.
        reduction: bottleneck ratio. The hidden width is ``channels // reduction``
            clamped to at least 1. Without the clamp, ``channels < reduction``
            built a zero-width bottleneck whose gate was a constant sigmoid(0).
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale each channel of ``x`` (B, C, H, W) by its learned gate."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

# ----------------------------------------
# 3) TinyESPCN Enhanced Model
# ----------------------------------------
class TinyESPCNEnhanced(nn.Module):
    """Small ESPCN-style x``scale`` super-resolution network with a bicubic skip.

    The network runs these stages in order:
    - a 7x7 conv followed by ReLU;
    - ten (3x3 conv + ReLU) stages;
    - an optional ChannelAttention;
    - a long skip add back to the 7x7 features;
    - a 3x3 conv to ``3 * scale**2`` channels;
    - PixelShuffle.

    The upsampled residual is added to a bicubic upsample of the input and
    clamped to [0, 1].

    Args:
        scale: integer upscaling factor.
        use_attention: whether to apply ChannelAttention(64) after the conv stack.
    """

    def __init__(self, scale=2, use_attention=True):
        super().__init__()
        self.scale = scale
        self.use_attention = use_attention

        width = 64
        self.conv1 = nn.Conv2d(3, width, kernel_size=7, stride=1, padding=3)
        # Despite the attribute name these stages have no per-block skip;
        # the only residual connection is the long one added in forward().
        stages = []
        for _ in range(10):
            stages.append(nn.Sequential(nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1), nn.ReLU()))
        self.res_blocks = nn.Sequential(*stages)
        if use_attention:
            self.attention = ChannelAttention(width)
        self.conv2 = nn.Conv2d(width, 3 * scale ** 2, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        """Map ``x`` of shape (B, 3, H, W) to (B, 3, H*scale, W*scale) with values in [0, 1]."""
        shallow = F.relu(self.conv1(x))
        deep = self.res_blocks(shallow)
        if self.use_attention:
            deep = self.attention(deep)
        residual = self.pixel_shuffle(self.conv2(deep + shallow))
        base = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False)
        return (residual + base).clamp(0, 1)

# ----------------------------------------
# 4) Enhanced Loss (Perceptual + Edge + Lab)
# ----------------------------------------
class EnhancedLoss(nn.Module):
    """Composite super-resolution loss: VGG19 perceptual + edge + CIE-Lab.

    loss = sum over taps i of L1(VGG_i(sr), VGG_i(hr))
           + 0.2 * (L1 Sobel-x + L1 Sobel-y + L1 Laplacian)
           + 0.1 * L1(rgb_to_lab(sr), rgb_to_lab(hr))

    VGG activations are taken from ``vgg19().features`` at indices
    ``self.layers`` on ImageNet-normalised inputs. The VGG weights are frozen.

    Args:
        device: device for the VGG network and the fixed filter/normalisation tensors.
    """
    def __init__(self, device='cuda'):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg.to(device)
        self.device = device
        self.layers = [2, 7, 12]

        # ImageNet statistics for VGG input normalisation, built once instead of every call.
        # Non-persistent buffers follow .to(...) but are kept out of state_dict().
        self.register_buffer("vgg_mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("vgg_std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1), persistent=False)

        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        laplacian = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)

        # One copy of each 3x3 kernel per RGB channel, applied depthwise (groups=3 below).
        self.register_buffer("sobel_x", sobel_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).to(device), persistent=False)
        self.register_buffer("sobel_y", sobel_y.view(1, 1, 3, 3).repeat(3, 1, 1, 1).to(device), persistent=False)
        self.register_buffer("laplacian", laplacian.view(1, 1, 3, 3).repeat(3, 1, 1, 1).to(device), persistent=False)

    def forward(self, sr, hr):
        """Return the scalar loss between ``sr`` and ``hr`` (both clamped to [0, 1] first)."""
        sr = torch.clamp(sr, 0, 1)
        hr = torch.clamp(hr, 0, 1)

        sr_f = (sr - self.vgg_mean) / self.vgg_std
        hr_f = (hr - self.vgg_mean) / self.vgg_std

        loss = 0
        last_layer = max(self.layers, default=len(self.vgg) - 1)
        for i, layer in enumerate(self.vgg):
            sr_f = layer(sr_f)
            hr_f = layer(hr_f)
            if i in self.layers:
                # Compare immediately: if the next layer is an in-place ReLU
                # (as in torchvision's VGG) it would overwrite these activations.
                loss += F.l1_loss(sr_f, hr_f)
            if i >= last_layer:
                break  # deeper VGG layers never contribute to the loss

        grad_x_sr = F.conv2d(sr, self.sobel_x, padding=1, groups=3)
        grad_y_sr = F.conv2d(sr, self.sobel_y, padding=1, groups=3)
        grad_x_hr = F.conv2d(hr, self.sobel_x, padding=1, groups=3)
        grad_y_hr = F.conv2d(hr, self.sobel_y, padding=1, groups=3)
        edge_loss = F.l1_loss(grad_x_sr, grad_x_hr) + F.l1_loss(grad_y_sr, grad_y_hr)

        lap_sr = F.conv2d(sr, self.laplacian, padding=1, groups=3)
        lap_hr = F.conv2d(hr, self.laplacian, padding=1, groups=3)
        edge_loss += F.l1_loss(lap_sr, lap_hr)

        loss += 0.2 * edge_loss

        sr_lab = rgb_to_lab(sr)
        hr_lab = rgb_to_lab(hr)
        loss += 0.1 * F.l1_loss(sr_lab, hr_lab)

        return loss

def rgb_to_lab(tensor):
    """Convert a batch of sRGB images to CIE L*a*b* (D65 white point, 2 deg observer).

    This is a differentiable PyTorch port of ``skimage.color.rgb2lab``. It uses
    the same sRGB companding, sRGB->XYZ matrix, D65 white point and CIE f(t)
    piecewise function. The previous version detached the input and went
    through NumPy on the CPU, so the Lab term of EnhancedLoss produced no
    gradient.

    Args:
        tensor: ``(B, 3, H, W)`` floating tensor of RGB values, expected in [0, 1].

    Returns:
        ``(B, 3, H, W)`` tensor of (L*, a*, b*) on the same device and dtype as
        the input.

    Raises:
        ValueError: if ``tensor`` is not 4-D with 3 channels.
    """
    if tensor.dim() != 4 or tensor.shape[1] != 3:
        raise ValueError(f"expected a (B, 3, H, W) tensor, got shape {tuple(tensor.shape)}")

    # sRGB -> linear RGB. The pow branch sees a clamped input so its gradient stays
    # finite where torch.where discards it (otherwise NaN * 0 poisons backward).
    linear = torch.where(
        tensor > 0.04045,
        ((torch.clamp(tensor, min=0.04045) + 0.055) / 1.055) ** 2.4,
        tensor / 12.92,
    )

    xyz_from_rgb = torch.tensor(
        [[0.412453, 0.357580, 0.180423],
         [0.212671, 0.715160, 0.072169],
         [0.019334, 0.119193, 0.950227]],
        dtype=linear.dtype, device=linear.device,
    )
    xyz = torch.einsum("ij,bjhw->bihw", xyz_from_rgb, linear)

    # Normalise by the D65 reference white.
    white = torch.tensor([0.95047, 1.0, 1.08883], dtype=xyz.dtype, device=xyz.device).view(1, 3, 1, 1)
    xyz = xyz / white

    # CIE f(t): cube root above the threshold, linear segment below (same NaN-safe clamp trick).
    eps = 0.008856
    f = torch.where(xyz > eps, torch.clamp(xyz, min=eps) ** (1.0 / 3.0), 7.787 * xyz + 16.0 / 116.0)
    fx, fy, fz = f[:, 0], f[:, 1], f[:, 2]

    L = 116.0 * fy - 16.0
    a = 500.0 * (fx - fy)
    b = 200.0 * (fy - fz)
    return torch.stack([L, a, b], dim=1)

# ----------------------------------------
# 5) Training function
# ----------------------------------------
def train_model(model, dataloader, epochs=50, lr=1e-3, device='cuda', criterion=None):
    """Train ``model`` in place with Adam and return it.

    Args:
        model: network mapping an LR batch to an SR batch.
        dataloader: iterable of ``(lr_img, hr_img)`` tensor batches.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: device to move the model and each batch to.
        criterion: callable ``criterion(sr, hr) -> scalar tensor``. Defaults to
            ``EnhancedLoss(device=device)``, the original behaviour. Pass a
            different loss to experiment without the VGG download.

    Returns:
        The same ``model`` object, after training.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if criterion is None:
        criterion = EnhancedLoss(device=device)
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(dataloader)
        for lr_img, hr_img in pbar:
            # non_blocking overlaps the host->GPU copy with compute when the loader pins memory.
            lr_img = lr_img.to(device, non_blocking=True)
            hr_img = hr_img.to(device, non_blocking=True)
            optimizer.zero_grad()
            sr = model(lr_img)
            loss = criterion(sr, hr_img)
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Epoch {epoch+1}/{epochs} Loss:{loss.item():.6f}")
    return model

# ----------------------------------------
# 6) Upscale, sharpen, and MSAA post-processing
# ----------------------------------------
def upscale_images(model, files, out_folder="upscaled", device='cuda', sharpen=True):
    """Super-resolve each image path in ``files`` and write the result to ``out_folder``.

    Each image is read as RGB and run through ``model`` as a single-image batch
    without gradients. The output is clamped to [0, 1] and optionally sharpened
    with an unsharp mask. It is then saved under the source file's basename.
    Leaves ``model`` in eval mode; assumes ``model`` already lives on ``device``.
    """
    os.makedirs(out_folder, exist_ok=True)
    model.eval()
    pil_to_tensor = transforms.ToTensor()
    tensor_to_pil = transforms.ToPILImage()
    unsharp = ImageFilter.UnsharpMask(radius=1.2, percent=100, threshold=1) if sharpen else None
    with torch.no_grad():
        for src in tqdm(files):
            batch = pil_to_tensor(Image.open(src).convert("RGB")).unsqueeze(0).to(device)
            restored = tensor_to_pil(model(batch).clamp(0, 1).cpu().squeeze(0))
            if unsharp is not None:
                restored = restored.filter(unsharp)
            restored.save(os.path.join(out_folder, os.path.basename(src)))

def msaa_postprocess(input_folder="upscaled", output_folder="upscaled_aa", supersample=2):
    """Soften jagged edges by blurring at a higher resolution and resampling back.

    Processes each .png/.jpg/.jpeg file (case-insensitive) in ``input_folder`` as follows:
    - LANCZOS-upsample by ``supersample``;
    - apply a Gaussian blur with radius 0.5;
    - LANCZOS-downsample to the original size;
    - save under the same basename in ``output_folder``.

    This is a post-filter that approximates supersampling anti-aliasing, not
    rendering-time MSAA.
    """
    os.makedirs(output_folder, exist_ok=True)
    image_exts = (".png", ".jpg", ".jpeg")
    candidates = [p for p in glob.glob(os.path.join(input_folder, "*.*")) if p.lower().endswith(image_exts)]
    for src in tqdm(candidates):
        picture = Image.open(src).convert("RGB")
        width, height = picture.size
        enlarged = picture.resize((width * supersample, height * supersample), Image.LANCZOS)
        softened = enlarged.filter(ImageFilter.GaussianBlur(0.5))
        restored = softened.resize((width, height), Image.LANCZOS)
        restored.save(os.path.join(output_folder, os.path.basename(src)))

# ----------------------------------------
# 7) Initialize dataset, DataLoader, and model
# ----------------------------------------
scale = 2
batch_size = 16
epochs = 50
hr_patch_size = 96  # HR crop side used for batching; the LR crop is hr_patch_size // scale
torch.manual_seed(0)  # reproducible init, shuffling and crop offsets (workers derive their seeds from this)

# Use the folder paths returned by save_images() in the download cell.
# Previously these were the *string literals* "lr_train_folder"/"hr_train_folder",
# which are not directories, so LSDIRDataset always raised "No valid images found".
lr_folder = lr_train_folder
hr_folder = hr_train_folder


def paired_random_crop_collate(batch, hr_patch=hr_patch_size, scale=scale):
    """Batch variable-size (lr, hr) pairs by cropping spatially aligned random patches.

    LSDIR images have different resolutions, so the default collate cannot stack
    them. Every pair in the batch is cropped to a common LR size ``p`` (at most
    ``hr_patch // scale``, smaller if an image is smaller). The HR crop is the
    matching region ``scale`` times larger.

    Returns:
        ``(lr_batch, hr_batch)`` of shapes (B, 3, p, p) and (B, 3, p*scale, p*scale).
    """
    p = hr_patch // scale
    for lr_t, hr_t in batch:
        p = min(p, lr_t.shape[1], lr_t.shape[2], hr_t.shape[1] // scale, hr_t.shape[2] // scale)
    if p < 1:
        raise ValueError("image too small to extract an aligned LR/HR patch")

    lr_patches, hr_patches = [], []
    for lr_t, hr_t in batch:
        max_top = min(lr_t.shape[1], hr_t.shape[1] // scale) - p
        max_left = min(lr_t.shape[2], hr_t.shape[2] // scale) - p
        top = int(torch.randint(0, max_top + 1, (1,)))
        left = int(torch.randint(0, max_left + 1, (1,)))
        lr_patches.append(lr_t[:, top:top + p, left:left + p])
        hr_patches.append(hr_t[:, top * scale:(top + p) * scale, left * scale:(left + p) * scale])
    return torch.stack(lr_patches), torch.stack(hr_patches)


# Named train_dataset so it does not shadow the Hugging Face `dataset` used by save_images().
train_dataset = LSDIRDataset(lr_folder, hr_folder)
loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=(device == "cuda"),  # pinned memory only helps host->GPU copies
    collate_fn=paired_random_crop_collate,
)
print(f"Dataset created with {len(train_dataset)} samples")

model = TinyESPCNEnhanced(scale=scale)
print("Model initialized and ready for training")

# ----------------------------------------
# 8) Start training
# ----------------------------------------
trained_model = train_model(model, loader, epochs=epochs, lr=1e-4, device=device)
print("Training complete")

In [None]:
# Persist only the learned parameters; reload with
# TinyESPCNEnhanced(scale=2).load_state_dict(torch.load(...)).
torch.save(obj=model.state_dict(), f="tiny_espcn_lsdir.pth")
print("Model weights saved successfully.")


Model weights saved successfully.


In [None]:
# Pickles the whole module object. Loading it requires TinyESPCNEnhanced and
# ChannelAttention to be importable under the same names, and on PyTorch >= 2.6
# torch.load(..., weights_only=False). The state_dict checkpoint is more portable.
# File renamed from "tiny_espcn_celeba_full.pth": this model is trained on LSDIR.
torch.save(model, "tiny_espcn_lsdir_full.pth")
print("Model saved successfully.")

Model saved successfully.
