In [None]:
# Dataset roots on Kaggle. Going by the folder names, `val_sharp` holds the
# full-resolution frames and `val_sharp_bicubic` their low-resolution versions.
BASE_HR = "/kaggle/input/reds-data-sample/DATA/val_sharp"
BASE_LR = "/kaggle/input/reds-data-sample/DATA/val_sharp_bicubic"

In [None]:
import os

from glob import glob

# Take the alphabetically first PNG of scene "000" under each root (LR first,
# then HR) as a quick check that both folders can be listed.
sample_lr, sample_hr = (
    sorted(glob(os.path.join(base, "000", "*.png")))[0]
    for base in (BASE_LR, BASE_HR)
)

print("Sample LR:", sample_lr)
print("Sample HR:", sample_hr)

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os

# Scene folder and frame number to preview
scene = "000"
frame_idx = 10

# Both trees use the same file name: the frame number zero-padded to 8 digits
frame_name = f"{frame_idx:08d}.png"
lr_path = os.path.join(BASE_LR, scene, frame_name)
hr_path = os.path.join(BASE_HR, scene, frame_name)

# Decode both frames as 3-channel RGB
lr_img = Image.open(lr_path).convert("RGB")
hr_img = Image.open(hr_path).convert("RGB")

# One row, two panels: LR on the left, HR on the right
fig, (ax_lr, ax_hr) = plt.subplots(1, 2, figsize=(12, 6))

ax_lr.imshow(lr_img)
ax_lr.set_title("Low-Resolution (bicubic)")
ax_lr.axis("off")

ax_hr.imshow(hr_img)
ax_hr.set_title("High-Resolution (ground truth)")
ax_hr.axis("off")

fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import os
from PIL import Image
from torchvision.utils import save_image

In [None]:
# === Dataset ===
class PairedImageDataset(torch.utils.data.Dataset):
    """Serve (LR, HR) image pairs that share a file name across two folders.

    The file list comes from ``lr_dir`` only (sorted); each name is then opened
    from both ``lr_dir`` and ``hr_dir``. When ``transform`` is truthy, that same
    callable is applied to the LR image and then to the HR image; otherwise the
    RGB ``PIL`` images are returned as they are.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir, self.hr_dir = lr_dir, hr_dir
        self.transform = transform
        names = os.listdir(lr_dir)
        names.sort()
        self.filenames = names

    def __len__(self):
        # One sample per entry found in the LR folder
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        lr_file = os.path.join(self.lr_dir, name)
        hr_file = os.path.join(self.hr_dir, name)

        lr = Image.open(lr_file).convert("RGB")
        hr = Image.open(hr_file).convert("RGB")

        if not self.transform:
            return lr, hr

        lr = self.transform(lr)
        hr = self.transform(hr)
        return lr, hr

In [None]:
# === Generator (your model) ===
class ResidualBlock(nn.Module):
    """Conv3x3 -> ReLU -> Conv3x3 with an identity skip connection.

    Both convolutions keep the channel count and use padding=1, so the output
    has the same shape as the input and can be added to it.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        layers = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class SimpleSRResNet(nn.Module):
    """Small SRResNet-style generator: 9x9 entry conv, residual trunk, one PixelShuffle stage.

    Args:
        upscale_factor: spatial factor of the single PixelShuffle upsampling stage.
        num_blocks: number of ``ResidualBlock`` modules in the trunk.
        in_channels: channel count of the input image; the output always has 3.
    """

    def __init__(self, upscale_factor=4, num_blocks=5, in_channels=3):
        super().__init__()
        width = 64  # feature channels used throughout the network

        self.entry = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
        )

        trunk = [ResidualBlock(width) for _ in range(num_blocks)]
        self.res_blocks = nn.Sequential(*trunk)

        self.mid = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # The conv expands channels by upscale_factor**2 so PixelShuffle can
        # fold them back into `width` channels at the larger resolution.
        self.upscale = nn.Sequential(
            nn.Conv2d(width, width * (upscale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor),
            nn.ReLU(inplace=True),
        )

        self.out = nn.Conv2d(width, 3, kernel_size=9, padding=4)

    def forward(self, x):
        shallow = self.entry(x)
        deep = self.mid(self.res_blocks(shallow))
        # Long skip connection around the whole residual trunk
        upsampled = self.upscale(deep + shallow)
        return self.out(upsampled)

In [None]:
# === Discriminator ===
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional real/fake classifier that outputs one probability per image.

    A stride-1 stem is followed by three stride-2 Conv-BatchNorm-LeakyReLU
    stages (64 -> 128 -> 256 -> 512 channels). Adaptive average pooling fixes
    the feature map at 4x4 before the linear head, and a sigmoid maps the
    score into [0, 1].
    """

    def __init__(self):
        super().__init__()

        layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Each stage halves the spatial size and doubles the channel count
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers += [
            nn.AdaptiveAvgPool2d((4, 4)),  # output is always 4x4
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
            nn.Sigmoid(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# === Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LR_SIZE = 64                        # side length of the generator input
UPSCALE_FACTOR = 4                  # must equal the generator's upscale_factor (SimpleSRResNet defaults to 4)
HR_SIZE = LR_SIZE * UPSCALE_FACTOR  # side length of the ground-truth target
SPLIT_SEED = 42                     # keeps the train/val split identical across kernel restarts

# LR and HR need different target sizes: resizing both to 64x64 leaves the HR
# target at 64x64 while the 4x generator emits 256x256, so the pixel-wise MSE
# in the training loop cannot be computed.
# `transform` stays the name of the LR-side pipeline so existing references to
# it keep resolving.
transform = transforms.Compose([
    transforms.Resize((LR_SIZE, LR_SIZE)),
    transforms.ToTensor()
])
hr_transform = transforms.Compose([
    transforms.Resize((HR_SIZE, HR_SIZE)),
    transforms.ToTensor()
])


class _PerSideTransform(torch.utils.data.Dataset):
    """Wrap a dataset of (lr, hr) pairs and transform each side separately.

    Args:
        pairs: indexable dataset whose items are ``(lr, hr)`` tuples.
        lr_transform: callable applied to the first element of each pair.
        hr_transform: callable applied to the second element of each pair.
    """

    def __init__(self, pairs, lr_transform, hr_transform):
        self.pairs = pairs
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        lr_img, hr_img = self.pairs[idx]
        return self.lr_transform(lr_img), self.hr_transform(hr_img)


lr_path = "/kaggle/input/reds-data-sample/DATA/val_sharp_bicubic/000"
hr_path = "/kaggle/input/reds-data-sample/DATA/val_sharp/000"

# PairedImageDataset is built without a transform so it yields raw PIL images;
# the wrapper then resizes each side to its own resolution.
full_dataset = _PerSideTransform(
    PairedImageDataset(lr_path, hr_path),
    transform,
    hr_transform,
)
train_len = int(0.8 * len(full_dataset))
val_len = len(full_dataset) - train_len
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
# === Models ===
# Build each network with its default arguments, then move it to `device`.
generator = SimpleSRResNet()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [None]:
# === Loss and Optimizers ===
# criterion_G: mean-squared error, used for the pixel-wise term between SR and HR.
# criterion_D: binary cross-entropy on probabilities, used for the adversarial terms.
criterion_G, criterion_D = nn.MSELoss(), nn.BCELoss()

In [None]:
# The training loop steps both optimizers, so the generator needs one as well;
# without it the loop's optimizer_G.zero_grad() raises NameError on a fresh kernel.
# The generator uses the same Adam settings as the discriminator.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [None]:
# === Training ===
# Alternating GAN training: for every batch the discriminator is updated first,
# then the generator. Expects `generator`, `discriminator`, `optimizer_G`,
# `optimizer_D`, `criterion_G`, `criterion_D`, `train_loader`, `val_loader` and
# `device` to already exist in the kernel.
for epoch in range(1, 3):  # just 2 epochs for test
    # Back to training mode; generator.eval() is called at the end of each epoch.
    generator.train()
    discriminator.train()

    for lr_imgs, hr_imgs in train_loader:
        lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

        # --- Train discriminator ---
        sr_imgs = generator(lr_imgs)
        # Targets: 1 for real HR images, 0 for generated images, one per sample.
        real_labels = torch.ones(hr_imgs.size(0), 1).to(device)
        fake_labels = torch.zeros(hr_imgs.size(0), 1).to(device)

        real_output = discriminator(hr_imgs)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        fake_output = discriminator(sr_imgs.detach())

        d_loss_real = criterion_D(real_output, real_labels)
        d_loss_fake = criterion_D(fake_output, fake_labels)
        d_loss = d_loss_real + d_loss_fake

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # --- Train generator ---
        # Score the same SR batch again, this time without detach, so the
        # adversarial loss reaches the generator's parameters.
        fake_output = discriminator(sr_imgs)
        # The generator is rewarded when the discriminator labels its output as real.
        adv_loss = criterion_D(fake_output, real_labels)
        # Pixel-wise term: sr_imgs and hr_imgs must have the same shape here.
        mse_loss = criterion_G(sr_imgs, hr_imgs)
        g_loss = mse_loss + 1e-3 * adv_loss  # small weight for GAN

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    # Reports the losses of the last batch of the epoch only, not an epoch average.
    print(f"Epoch {epoch}: G_Loss={g_loss.item():.4f}, D_Loss={d_loss.item():.4f}")

    # Save example
    # Writes the super-resolved first validation batch to output_epoch<N>.png
    # in the current working directory.
    generator.eval()
    with torch.no_grad():
        for lr_imgs, _ in val_loader:
            lr_imgs = lr_imgs.to(device)
            sr_imgs = generator(lr_imgs)
            save_image(sr_imgs, f"output_epoch{epoch}.png")
            break  # just one image

In [None]:
def __getitem__(self, idx):
    """Debug variant of a dataset ``__getitem__`` that prints both tensor shapes.

    This function is defined at module level, so this cell does not attach it
    to any class. It opens ``self.hr_paths[idx]`` and ``self.lr_paths[idx]``,
    passes each image through ``self.transform`` and returns
    ``(lr_tensor, hr_tensor)``.
    """
    hr_source = Image.open(self.hr_paths[idx])
    lr_source = Image.open(self.lr_paths[idx])

    # The same transform is applied to both sides
    hr_tensor = self.transform(hr_source)
    lr_tensor = self.transform(lr_source)

    # Shape dump for debugging size mismatches
    print("HR:", hr_tensor.shape, "LR:", lr_tensor.shape)

    pair = (lr_tensor, hr_tensor)
    return pair

In [None]:
# Pull a single batch from the training loader and report both tensor shapes.
# For the generator's output to match the target, the HR height/width should be
# the LR height/width times the upscale factor.
first_batch = next(iter(train_loader))
lr, hr = first_batch
print("LR batch shape:", lr.shape)
print("HR batch shape:", hr.shape)

In [None]:
def visualize_results(lr_imgs, sr_imgs, hr_imgs, num_images=4):
    """Plot LR input, SR output and HR target side by side, one row per sample.

    Args:
        lr_imgs, sr_imgs, hr_imgs: batches indexed along their first axis. Each
            sample is permuted with ``(1, 2, 0)`` before display, i.e. it is
            treated as channel-first ``(C, H, W)``.
        num_images: maximum number of rows to draw; capped at the length of the
            shortest batch so a small batch does not raise ``IndexError``.

    Returns:
        None. Nothing is drawn when there are no samples to show.
    """
    num_images = min(num_images, len(lr_imgs), len(sr_imgs), len(hr_imgs))
    if num_images <= 0:
        return

    lr_imgs = lr_imgs[:num_images]
    sr_imgs = sr_imgs[:num_images]
    hr_imgs = hr_imgs[:num_images]

    # Clip to the displayable [0, 1] range. Values outside it are clipped, not
    # rescaled, so outputs in [-1, 1] lose their negative half.
    def denorm(img):
        return img.clamp(0, 1)

    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # axs[i, j] indexing below also works when num_images == 1.
    fig, axs = plt.subplots(num_images, 3, figsize=(10, num_images * 3), squeeze=False)
    titles = ['Low-Res Input', 'Super-Resolved', 'High-Res Target']

    for i in range(num_images):
        images = [lr_imgs[i], sr_imgs[i], hr_imgs[i]]
        for j in range(3):
            # detach() lets tensors that still track gradients be converted to numpy
            shown = denorm(images[j].detach().cpu().permute(1, 2, 0)).numpy()
            axs[i, j].imshow(shown)
            axs[i, j].set_title(titles[j])
            axs[i, j].axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class SuperResolutionDataset(Dataset):
    """(LR, HR) image pairs from two folders, with an independent transform per side.

    Each folder is listed and sorted separately, and sample ``idx`` pairs the
    ``idx``-th LR file with the ``idx``-th HR file. The dataset length is the
    number of entries in the LR folder.
    """

    def __init__(self, lr_dir, hr_dir, transform_lr=None, transform_hr=None):
        self.lr_dir, self.hr_dir = lr_dir, hr_dir

        lr_names = os.listdir(lr_dir)
        lr_names.sort()
        self.lr_images = lr_names

        hr_names = os.listdir(hr_dir)
        hr_names.sort()
        self.hr_images = hr_names

        self.transform_lr, self.transform_hr = transform_lr, transform_hr

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_file = os.path.join(self.lr_dir, self.lr_images[idx])
        hr_file = os.path.join(self.hr_dir, self.hr_images[idx])

        lr_image = Image.open(lr_file).convert('RGB')
        hr_image = Image.open(hr_file).convert('RGB')

        # Each side is transformed only when its own transform was supplied
        lr_image = self.transform_lr(lr_image) if self.transform_lr else lr_image
        hr_image = self.transform_hr(hr_image) if self.transform_hr else hr_image

        return lr_image, hr_image

# Paths to your dataset
lr_path = "/kaggle/input/reds-data-sample/DATA/val_sharp_bicubic/000"
hr_path = "/kaggle/input/reds-data-sample/DATA/val_sharp/000"

# Image transforms
# The HR side is resized to 4x the LR side (64 -> 256).
transform_lr = transforms.Compose([
    transforms.Resize((64, 64)),   # Adjust as needed
    transforms.ToTensor()
])

transform_hr = transforms.Compose([
    transforms.Resize((256, 256)),  # Adjust as needed
    transforms.ToTensor()
])

# Create dataset and dataloader
# The folders are held in `lr_path` / `hr_path`; no `lr_dir` / `hr_dir`
# variables exist at module level, so passing those raised NameError.
# NOTE(review): this rebinds `train_dataset`; a `train_loader` built from an
# earlier `train_dataset` is not affected by the rebinding.
train_dataset = SuperResolutionDataset(lr_path, hr_path, transform_lr, transform_hr)
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as T

# Dummy data loader
# 20 random-noise "image" pairs, served in order in batches of `batch_size`.
batch_size = 4
lr_imgs = torch.randn(size=(20, 3, 64, 64))  # Low-resolution input
hr_imgs = torch.randn(size=(20, 3, 64, 64))  # Ground truth HR (can be resized if needed)
dataset = TensorDataset(lr_imgs, hr_imgs)
loader = DataLoader(dataset=dataset, batch_size=batch_size)

# Generator model (two 2x upsampling stages, e.g. 64x64 -> 256x256)
class Generator(nn.Module):
    """Minimal 4x upsampler: two Conv-ReLU-Upsample(x2) stages and a Tanh output conv.

    The final Tanh bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        stages = []
        # First stage lifts 3 -> 64 channels, the second keeps 64; each doubles H and W.
        for c_in in (3, 64):
            stages += [
                nn.Conv2d(c_in, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Upsample(scale_factor=2),
            ]
        stages += [
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        return self.main(x)

# Discriminator model
class Discriminator(nn.Module):
    """Two stride-2 conv stages followed by a linear + sigmoid head.

    The linear layer is sized for a 128 x 64 x 64 feature map, which is what the
    two stride-2 convolutions produce from a 256x256 input; other input sizes
    do not match the head.
    """

    def __init__(self):
        super().__init__()
        flat_features = 128 * 64 * 64
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),  # 256 -> 128
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128 -> 64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(flat_features, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Initialize models and optimizers
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

# One Adam optimizer per network, both at the same learning rate
optimizer_G = optim.Adam(params=generator.parameters(), lr=1e-4)
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=1e-4)

# MSE for the pixel-wise term, BCE for the adversarial terms
criterion_G, criterion_D = nn.MSELoss(), nn.BCELoss()

import matplotlib.pyplot as plt
import torchvision.utils as vutils

def visualize_results(lr_imgs, sr_imgs, hr_imgs, num_images=4):
    """Plot LR input, SR output and HR target side by side, one row per sample.

    Args:
        lr_imgs, sr_imgs, hr_imgs: batches indexed along their first axis. Each
            sample is permuted with ``(1, 2, 0)`` before display, i.e. it is
            treated as channel-first ``(C, H, W)``.
        num_images: maximum number of rows to draw; capped at the length of the
            shortest batch so a small batch does not raise ``IndexError``.

    Returns:
        None. Nothing is drawn when there are no samples to show.
    """
    num_images = min(num_images, len(lr_imgs), len(sr_imgs), len(hr_imgs))
    if num_images <= 0:
        return

    lr_imgs = lr_imgs[:num_images]
    sr_imgs = sr_imgs[:num_images]
    hr_imgs = hr_imgs[:num_images]

    # Clip to the displayable [0, 1] range. Values outside it are clipped, not
    # rescaled, so outputs in [-1, 1] lose their negative half.
    def denorm(img):
        return img.clamp(0, 1)

    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # axs[i, j] indexing below also works when num_images == 1.
    fig, axs = plt.subplots(num_images, 3, figsize=(10, num_images * 3), squeeze=False)
    titles = ['Low-Res Input', 'Super-Resolved', 'High-Res Target']

    for i in range(num_images):
        images = [lr_imgs[i], sr_imgs[i], hr_imgs[i]]
        for j in range(3):
            # detach() lets tensors that still track gradients be converted to numpy
            shown = denorm(images[j].detach().cpu().permute(1, 2, 0)).numpy()
            axs[i, j].imshow(shown)
            axs[i, j].set_title(titles[j])
            axs[i, j].axis('off')
    plt.tight_layout()
    plt.show()

# Visualize BEFORE training: one batch from `dataloader` through the generator
preview_batch = next(iter(dataloader))
lr_imgs, hr_imgs = preview_batch
with torch.no_grad():
    sr_imgs = generator(lr_imgs.to(device))
# Plotting needs CPU tensors; sr_imgs itself stays on `device`
visualize_results(lr_imgs, sr_imgs.cpu(), hr_imgs)


# Training loop
# Alternating GAN updates on the dummy `loader`: discriminator first, then generator.
for epoch in range(1):  # For quick testing
    for lr, hr in loader:
        lr, hr = lr.to(device), hr.to(device)

        # Generate fake high-res images
        sr = generator(lr)

        # Resize HR images to the generator's output size so both losses
        # compare tensors of the same shape
        hr_resized = F.interpolate(hr, size=sr.shape[-2:], mode='bilinear', align_corners=False)

        # Train Discriminator
        # Labels are sized from the actual batch, not the global `batch_size`,
        # so a shorter final batch does not cause a shape mismatch in BCELoss.
        cur_batch = lr.size(0)
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        real_output = discriminator(hr_resized)
        # detach() keeps the discriminator loss from backpropagating into the generator
        fake_output = discriminator(sr.detach())

        d_loss_real = criterion_D(real_output, real_labels)
        d_loss_fake = criterion_D(fake_output, fake_labels)
        d_loss = d_loss_real + d_loss_fake

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        # Score the same SR batch without detach so gradients reach the generator
        fake_output = discriminator(sr)
        adv_loss = criterion_D(fake_output, real_labels)
        mse_loss = criterion_G(sr, hr_resized)
        g_loss = mse_loss + 1e-3 * adv_loss  # pixel loss dominates; small adversarial weight

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        print(f"D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")


# Visualize AFTER training: same procedure as before, on a freshly drawn batch
preview_batch = next(iter(dataloader))
lr_imgs, hr_imgs = preview_batch
with torch.no_grad():
    sr_imgs = generator(lr_imgs.to(device))
# Plotting needs CPU tensors; sr_imgs itself stays on `device`
visualize_results(lr_imgs, sr_imgs.cpu(), hr_imgs)