In [None]:
import os
import glob
from PIL import Image, ImageFilter
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from tqdm import tqdm

In [None]:
import os
import zipfile
import shutil
import itertools

# -------------------------
# 1️⃣ Configuration
# -------------------------
zip_path = "/content/drive/MyDrive/img_align_celeba.zip"  # path to your uploaded ZIP
extract_root = "/content/celeba_images"
subset_folder = "/content/celeba_images/subset_25k"
subset_size = 25000  # number of images copied into the training subset

os.makedirs(extract_root, exist_ok=True)
os.makedirs(subset_folder, exist_ok=True)

# -------------------------
# 2️⃣ Extract ZIP
# -------------------------
print("📂 Extracting CelebA zip...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    # Only extract archive members that are not on disk yet, so re-running
    # this cell does not unpack ~200k files a second time.
    # NOTE: a file left truncated by an interrupted extraction still counts as
    # "present"; delete `extract_root` to force a clean re-extraction.
    missing_members = [
        member for member in zip_ref.namelist()
        if not os.path.exists(os.path.join(extract_root, member))
    ]
    if missing_members:
        zip_ref.extractall(extract_root, members=missing_members)

# -------------------------
# 3️⃣ Identify full image folder
# -------------------------
image_folder = os.path.join(extract_root, "img_align_celeba")
if not os.path.exists(image_folder):
    raise FileNotFoundError("Extracted image folder not found!")

# Keep only .jpg files: the training dataset filters on the same extension,
# so any other entry would silently shrink the usable subset.
all_images = sorted(
    name for name in os.listdir(image_folder) if name.lower().endswith(".jpg")
)
print(f"Total images in full dataset: {len(all_images)}")

# -------------------------
# 4️⃣ Copy first 25k images to subset
# -------------------------
subset_images = all_images[:subset_size]

for img in subset_images:
    target_path = os.path.join(subset_folder, img)
    # Skip files copied by a previous run to keep the cell idempotent and fast.
    if not os.path.exists(target_path):
        shutil.copy(os.path.join(image_folder, img), target_path)

print(f"Subset created with {len(os.listdir(subset_folder))} images")
print(f"Subset folder path: {subset_folder}")


📂 Extracting CelebA zip...
Total images in full dataset: 202599
Subset created with 25000 images
Subset folder path: /content/celeba_images/subset_25k


In [None]:
# ===========================================
# Super-Resolution Full Pipeline (CelebA 2x, Local Images)
# ===========================================

import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from tqdm import tqdm

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# ----------------------------------------
# 1) Use local CelebA folder
# ----------------------------------------
dataset_folder = "/content/celeba_images/subset_25k"  # your local folder with images
# Case-insensitive match on the ".jpg" extension; everything else is ignored.
all_images = list(
    filter(lambda name: name.lower().endswith(".jpg"), os.listdir(dataset_folder))
)
print(f"Found {len(all_images)} images in: {dataset_folder}")

# ----------------------------------------
# 2) Dataset class for LR-HR pairs
# ----------------------------------------
class CelebADataset(Dataset):
    """Random-crop LR/HR training pairs built from a folder of ``.jpg`` images.

    Each item is a ``(lr, hr)`` tuple of tensors produced by
    ``transforms.ToTensor()``: ``hr`` is a random ``crop_size`` x ``crop_size``
    crop of the image and ``lr`` is that crop bicubically downscaled to
    ``crop_size // scale`` per side.

    Args:
        root: Directory scanned (non-recursively) for files ending in ``.jpg``
            (case-insensitive).
        scale: Integer downscaling factor between HR and LR.
        crop_size: Side length of the HR crop; must be a positive multiple of
            ``scale`` so that ``lr_size * scale == crop_size``.
        max_images: If given, keep only the first ``max_images`` files in
            sorted order.

    Raises:
        ValueError: If ``scale < 1`` or ``crop_size`` is not a positive
            multiple of ``scale``.
    """

    def __init__(self, root, scale=2, crop_size=64, max_images=None):
        # Fail fast here: bad values would otherwise surface later as a
        # ZeroDivisionError / PIL size error inside a DataLoader worker, or as
        # an SR-vs-HR shape mismatch in the loss.
        if scale < 1:
            raise ValueError(f"scale must be >= 1, got {scale}")
        if crop_size < scale or crop_size % scale != 0:
            raise ValueError(
                f"crop_size must be a positive multiple of scale, "
                f"got crop_size={crop_size}, scale={scale}"
            )

        self.files = sorted(
            os.path.join(root, f) for f in os.listdir(root) if f.lower().endswith(".jpg")
        )
        if max_images is not None:
            self.files = self.files[:max_images]
        self.scale = scale
        self.crop_size = crop_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return a freshly sampled ``(lr, hr)`` pair."""
        # The context manager releases the file handle right away; convert()
        # yields an independent in-memory RGB image that outlives the block.
        with Image.open(self.files[idx]) as src:
            img = src.convert("RGB")

        w, h = img.size
        if w < self.crop_size or h < self.crop_size:
            # Upscale only the undersized dimension(s) so a full crop fits.
            img = img.resize((max(w, self.crop_size), max(h, self.crop_size)), Image.BICUBIC)

        # Random crop (torch RNG, so DataLoader worker seeding applies)
        x = torch.randint(0, img.width - self.crop_size + 1, (1,)).item()
        y = torch.randint(0, img.height - self.crop_size + 1, (1,)).item()
        hr = img.crop((x, y, x + self.crop_size, y + self.crop_size))

        lr_size = self.crop_size // self.scale
        lr = hr.resize((lr_size, lr_size), Image.BICUBIC)

        return self.to_tensor(lr), self.to_tensor(hr)

# ----------------------------------------
# 3) DataLoader setup
# ----------------------------------------
# Hyper-parameters shared by the dataset, the model and the training run.
scale, crop_size = 2, 64
batch_size, epochs = 16, 50

# Random LR/HR crops from the local image folder, served in shuffled batches.
dataset = CelebADataset(root=dataset_folder, scale=scale, crop_size=crop_size)
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
print(f"DataLoader ready with {len(dataset)} samples")

# ----------------------------------------
# 4) Channel Attention module
# ----------------------------------------
class ChannelAttention(nn.Module):
    """Channel-wise gating (squeeze-and-excitation style).

    Globally average-pools each channel, passes the resulting vector through a
    two-layer bottleneck MLP ending in a sigmoid, and multiplies the input by
    the per-channel weights.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck ratio; the hidden width is
            ``max(1, channels // reduction)``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Guard against a zero-width bottleneck when channels < reduction,
        # which would make the gate a constant 0.5 regardless of the input.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale ``x`` of shape (B, C, H, W) by learned per-channel weights in (0, 1)."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)      # (B, C) channel descriptors
        y = self.fc(y).view(b, c, 1, 1)      # (B, C, 1, 1) gates, broadcast over H, W
        return x * y

# ----------------------------------------
# 5) TinyESPCN Enhanced Model
# ----------------------------------------
class TinyESPCNEnhanced(nn.Module):
    """Sub-pixel-convolution super-resolution network with a bicubic skip.

    Pipeline: 7x7 conv stem -> ten stacked (3x3 conv + ReLU) stages ->
    optional channel attention -> global feature skip -> 3x3 conv to
    ``3 * scale**2`` channels -> PixelShuffle. The shuffled output is added to
    a bicubic upsampling of the input and clamped to [0, 1].

    Args:
        scale: Upscaling factor used by both PixelShuffle and the bicubic skip.
        use_attention: Whether to apply ``ChannelAttention`` to the deep features.
    """

    def __init__(self, scale=2, use_attention=True):
        super().__init__()
        self.scale = scale
        self.use_attention = use_attention

        # Stem: RGB -> 64 feature channels, spatial size preserved.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)

        # Ten plain conv+ReLU stages (no per-stage skip despite the name).
        stages = []
        for _ in range(10):
            stages.append(
                nn.Sequential(
                    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                )
            )
        self.res_blocks = nn.Sequential(*stages)

        if use_attention:
            self.attention = ChannelAttention(64)

        # Produce scale^2 sub-pixel maps per RGB channel for PixelShuffle.
        self.conv2 = nn.Conv2d(64, 3 * (scale ** 2), kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        """Upscale a (B, 3, H, W) batch to (B, 3, H*scale, W*scale) in [0, 1]."""
        shallow = F.relu(self.conv1(x))
        deep = self.res_blocks(shallow)
        if self.use_attention:
            deep = self.attention(deep)

        # Global skip over the conv stack, then sub-pixel upsampling.
        residual = self.pixel_shuffle(self.conv2(deep + shallow))

        # The network only has to learn the correction on top of bicubic.
        bicubic = F.interpolate(
            x, scale_factor=self.scale, mode='bicubic', align_corners=False
        )
        return torch.clamp(residual + bicubic, 0, 1)

# ----------------------------------------
# 6) Enhanced Loss (Perceptual + Edge + Lab)
# ----------------------------------------
class EnhancedLoss(nn.Module):
    """Composite image loss: VGG19 perceptual + edge (Sobel/Laplacian) + Lab colour.

    ``loss = sum_l L1(vgg_l(sr), vgg_l(hr))
             + 0.2 * (L1(sobel_x) + L1(sobel_y) + L1(laplacian))
             + 0.1 * L1(rgb_to_lab(sr), rgb_to_lab(hr))``

    The perceptual term taps the outputs of ``vgg19().features`` at indices
    2, 7 and 12 (fixed at construction time).

    Args:
        device: Device on which the frozen VGG and the filter kernels are placed.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        self.device = device
        self.layers = [2,7,12]

        # Only the layers up to the deepest tap are ever needed; dropping the
        # rest avoids running (and storing) the remaining VGG19 layers for
        # both SR and HR on every batch. Loss values are unchanged.
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        vgg = vgg[:max(self.layers) + 1].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg.to(device)

        # ImageNet normalisation constants expected by the pretrained VGG;
        # built once here instead of on every forward pass.
        self.vgg_mean = torch.tensor([0.485,0.456,0.406], device=device).view(1,3,1,1)
        self.vgg_std = torch.tensor([0.229,0.224,0.225], device=device).view(1,3,1,1)

        sobel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=torch.float32)
        laplacian = torch.tensor([[0,-1,0],[-1,4,-1],[0,-1,0]], dtype=torch.float32)

        # One copy of each 3x3 kernel per RGB channel -> depthwise conv (groups=3).
        self.sobel_x = sobel_x.view(1,1,3,3).repeat(3,1,1,1).to(device)
        self.sobel_y = sobel_y.view(1,1,3,3).repeat(3,1,1,1).to(device)
        self.laplacian = laplacian.view(1,1,3,3).repeat(3,1,1,1).to(device)

    def forward(self,sr,hr):
        """Return the scalar loss between super-resolved ``sr`` and target ``hr``.

        Both inputs are (B, 3, H, W) RGB batches; they are clamped to [0, 1]
        before any term is computed.
        """
        sr = torch.clamp(sr,0,1)
        hr = torch.clamp(hr,0,1)

        # --- Perceptual term: L1 between VGG activations at the tapped layers.
        sr_f = (sr - self.vgg_mean) / self.vgg_std
        hr_f = (hr - self.vgg_mean) / self.vgg_std

        loss=0
        for i,layer in enumerate(self.vgg):
            sr_f = layer(sr_f)
            hr_f = layer(hr_f)
            if i in self.layers:
                loss += F.l1_loss(sr_f, hr_f)

        # --- Edge term: first-order (Sobel) and second-order (Laplacian) responses.
        grad_x_sr = F.conv2d(sr, self.sobel_x, padding=1, groups=3)
        grad_y_sr = F.conv2d(sr, self.sobel_y, padding=1, groups=3)
        grad_x_hr = F.conv2d(hr, self.sobel_x, padding=1, groups=3)
        grad_y_hr = F.conv2d(hr, self.sobel_y, padding=1, groups=3)
        edge_loss = F.l1_loss(grad_x_sr,grad_x_hr)+F.l1_loss(grad_y_sr,grad_y_hr)

        lap_sr = F.conv2d(sr,self.laplacian,padding=1,groups=3)
        lap_hr = F.conv2d(hr,self.laplacian,padding=1,groups=3)
        edge_loss += F.l1_loss(lap_sr, lap_hr)

        loss += 0.2 * edge_loss

        # --- Colour term: L1 in Lab space via the module-level rgb_to_lab helper.
        sr_lab = rgb_to_lab(sr)
        hr_lab = rgb_to_lab(hr)
        loss += 0.1 * F.l1_loss(sr_lab, hr_lab)

        return loss

def rgb_to_lab(tensor):
    """Convert a batch of sRGB images to CIE-Lab, differentiably, in pure torch.

    The previous implementation detached the tensor and round-tripped through
    NumPy/scikit-image, so a loss built on it carried no gradient back to the
    network and forced a GPU->CPU copy per batch. This version stays on the
    input's device and inside autograd. It uses the D65 / 2-degree-observer
    constants (the scikit-image ``rgb2lab`` defaults), so values agree with the
    old output up to floating-point tolerance.

    Args:
        tensor: Float tensor of shape (B, 3, H, W) with sRGB values in [0, 1].

    Returns:
        Tensor of shape (B, 3, H, W) with channels (L, a, b), same dtype and
        device as ``tensor``.

    Raises:
        ValueError: If ``tensor`` is not a 4-D batch with 3 channels.
    """
    if tensor.dim() != 4 or tensor.size(1) != 3:
        raise ValueError(
            f"rgb_to_lab expects a (B, 3, H, W) tensor, got shape {tuple(tensor.shape)}"
        )

    # 1) Undo sRGB gamma. The clamp keeps the unselected power branch finite so
    #    torch.where never multiplies a zero mask by an inf/NaN gradient.
    srgb_threshold = 0.04045
    linear = torch.where(
        tensor > srgb_threshold,
        ((tensor.clamp(min=srgb_threshold) + 0.055) / 1.055) ** 2.4,
        tensor / 12.92,
    )

    # 2) Linear RGB -> XYZ, then normalise by the D65 white point.
    rgb_to_xyz = tensor.new_tensor([
        [0.412453, 0.357580, 0.180423],
        [0.212671, 0.715160, 0.072169],
        [0.019334, 0.119193, 0.950227],
    ])
    white_point = tensor.new_tensor([0.95047, 1.0, 1.08883]).view(1, 3, 1, 1)
    xyz = torch.einsum('ij,bjhw->bihw', rgb_to_xyz, linear) / white_point

    # 3) XYZ -> Lab non-linearity (cube root above the threshold, linear below).
    #    Same clamp trick: d/dt t^(1/3) is infinite at t == 0.
    lab_threshold = 0.008856
    f_xyz = torch.where(
        xyz > lab_threshold,
        xyz.clamp(min=lab_threshold) ** (1.0 / 3.0),
        7.787 * xyz + 16.0 / 116.0,
    )
    f_x, f_y, f_z = f_xyz[:, 0], f_xyz[:, 1], f_xyz[:, 2]

    lightness = 116.0 * f_y - 16.0
    a_chan = 500.0 * (f_x - f_y)
    b_chan = 200.0 * (f_y - f_z)
    return torch.stack([lightness, a_chan, b_chan], dim=1)

# ----------------------------------------
# 7) Training function
# ----------------------------------------
def train_model(model, dataloader, epochs=50, lr=1e-3, device='cuda'):
    """Train ``model`` on (LR, HR) batches with Adam and ``EnhancedLoss``.

    Args:
        model: Network mapping an LR batch to an SR batch; moved to ``device``.
        dataloader: Iterable yielding ``(lr_img, hr_img)`` tensor pairs.
        epochs: Number of full passes over ``dataloader``.
        lr: Adam learning rate.
        device: Device used for the model, the loss and every batch.

    Returns:
        The same ``model`` instance, left in training mode.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = EnhancedLoss(device=device)
    model.train()

    for epoch_idx in range(epochs):
        progress = tqdm(dataloader)
        for low_res, high_res in progress:
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            # Standard step: clear grads, forward, loss, backward, update.
            optimizer.zero_grad()
            prediction = model(low_res)
            batch_loss = criterion(prediction, high_res)
            batch_loss.backward()
            optimizer.step()

            # Show the most recent batch loss on the progress bar.
            progress.set_description(
                f"Epoch {epoch_idx+1}/{epochs} Loss:{batch_loss.item():.6f}"
            )

    return model

# ----------------------------------------
# 8) Initialize and train model
# ----------------------------------------
model = TinyESPCNEnhanced(scale=scale)
print("Model initialized, starting training...")
# Use the `epochs` setting from the DataLoader/config section instead of a
# second hard-coded 50, so changing the config actually changes the run.
model = train_model(model, loader, epochs=epochs, lr=1e-3, device=device)
print("Training complete!")


Using device: cuda
Found 25000 images in: /content/celeba_images/subset_25k
DataLoader ready with 25000 samples
Model initialized, starting training...
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:07<00:00, 72.9MB/s]
Epoch 1/50 Loss:2.950243: 100%|██████████| 1563/1563 [01:44<00:00, 14.99it/s]
Epoch 2/50 Loss:2.959672: 100%|██████████| 1563/1563 [01:44<00:00, 14.93it/s]
Epoch 3/50 Loss:3.012973: 100%|██████████| 1563/1563 [01:42<00:00, 15.21it/s]
Epoch 4/50 Loss:2.712124: 100%|██████████| 1563/1563 [01:41<00:00, 15.41it/s]
Epoch 5/50 Loss:2.699619: 100%|██████████| 1563/1563 [01:42<00:00, 15.28it/s]
Epoch 6/50 Loss:2.529857: 100%|██████████| 1563/1563 [01:42<00:00, 15.32it/s]
Epoch 7/50 Loss:2.502214: 100%|██████████| 1563/1563 [01:40<00:00, 15.49it/s]
Epoch 8/50 Loss:2.545166: 100%|██████████| 1563/1563 [01:40<00:00, 15.51it/s]
Epoch 9/50 Loss:2.199218: 100%|██████████| 1563/1563 [01:40<00:00, 15.59it/s]
Epoch 10/50 Loss:3.131749: 100%|██████████| 1563/1563 [01:40<00:00, 15.51it/s]
Epoch 11/50 Loss:2.349300: 100%|██████████| 1563/1563 [01:40<00:00, 15.50it/s]
Epoch 12/50 Loss:2.250749: 100%|██████████| 1563/1563 [01:40<00:00, 15.57it/s]
Epoch 13/5

Training complete!





In [None]:
# Persist only the learned parameters (the model's state_dict) to
# "tiny_espcn_celeba.pth" in the current working directory.
torch.save(model.state_dict(), "tiny_espcn_celeba.pth")
print("Model weights saved successfully.")


Model weights saved successfully.


In [None]:
# Serialize the entire model object (not just its state_dict) to
# "tiny_espcn_celeba_full.pth" in the current working directory.
# NOTE(review): loading this file later presumably requires the
# TinyESPCNEnhanced / ChannelAttention class definitions to be available
# under the same names — confirm against the PyTorch serialization notes.
torch.save(model, "tiny_espcn_celeba_full.pth")
print("Model saved successfully.")

Model saved successfully.
