In [None]:
from google.colab import drive
# Google Drive is mounted here; dataset and checkpoint paths below live under it.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import zipfile
import os

gopro_zip_path = '/content/drive/MyDrive/GoPro.zip'  # replace path of the dataset
gopro_extract_path = '/content/gopro_dataset/'

# Unzip GoPro only when the target directory is missing or empty.
# Re-extracting on a re-run would recreate the training scenes that the
# train/val split cell moves out of `train/`. Those scenes would then sit in
# both splits (train/val leakage).
# Delete `gopro_extract_path` to force a fresh extraction (e.g. after an
# interrupted unzip).
if os.path.isdir(gopro_extract_path) and os.listdir(gopro_extract_path):
    print(" GoPro dataset already present at:", gopro_extract_path, "- skipping unzip")
else:
    with zipfile.ZipFile(gopro_zip_path, 'r') as zip_ref:
        zip_ref.extractall(gopro_extract_path)
    print(" GoPro dataset unzipped to:", gopro_extract_path)


In [None]:
import os
import shutil
import random

def split_train_val(data_dir, val_ratio=0.1, seed=42):
    """Move a random subset of training scenes into a validation split.

    Scenes are directories under ``data_dir/train/blur`` and ``data_dir/train/GT``.
    Selected scenes are moved (not copied) to ``data_dir/val/blur`` and
    ``data_dir/val/GT``.

    The function is idempotent. If ``data_dir/val/blur`` already contains scene
    directories (e.g. the cell is re-run), nothing is moved. Repeated runs
    therefore do not keep shrinking the training set.

    Args:
        data_dir: Dataset root containing the ``train`` directory.
        val_ratio: Fraction of scenes to move to validation. When it is positive
            but rounds down to zero scenes, one scene is still moved, provided
            at least two scenes exist.
        seed: Seed for scene sampling. A local RNG is used, so the global
            ``random`` module state is left untouched.
    """
    blur_dir = os.path.join(data_dir, 'train', 'blur')
    gt_dir = os.path.join(data_dir, 'train', 'GT')
    val_blur_dir = os.path.join(data_dir, 'val', 'blur')
    val_gt_dir = os.path.join(data_dir, 'val', 'GT')

    os.makedirs(val_blur_dir, exist_ok=True)
    os.makedirs(val_gt_dir, exist_ok=True)

    # A previous run already produced the split. Moving more scenes would leak
    # an extra fraction of the training set into validation on every re-run.
    existing_val = [
        s for s in os.listdir(val_blur_dir)
        if os.path.isdir(os.path.join(val_blur_dir, s))
    ]
    if existing_val:
        print(f"Validation set already has {len(existing_val)} scenes; skipping split.")
        return

    # Only scenes present in both blur and GT can be moved as a pair.
    scenes = [
        s for s in sorted(os.listdir(blur_dir))
        if os.path.isdir(os.path.join(blur_dir, s)) and os.path.isdir(os.path.join(gt_dir, s))
    ]

    n_val = int(len(scenes) * val_ratio)
    if n_val == 0 and val_ratio > 0 and len(scenes) >= 2:
        # Guarantee a non-empty validation set (an empty one makes the
        # validation loss a division by zero), while keeping training non-empty.
        n_val = 1

    # Random(seed).sample draws the same scenes as random.seed(seed); random.sample.
    val_scenes = random.Random(seed).sample(scenes, n_val)

    for scene in val_scenes:
        shutil.move(os.path.join(blur_dir, scene), os.path.join(val_blur_dir, scene))
        shutil.move(os.path.join(gt_dir, scene), os.path.join(val_gt_dir, scene))

    print(f"Moved {len(val_scenes)} scenes to validation set.")

# Hold out ~10% of the extracted GoPro training scenes as a validation split.
gopro_root = '/content/gopro_dataset/GoPro'
split_train_val(data_dir=gopro_root, val_ratio=0.1)


Moved 2 scenes to validation set.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTemporalUNet(nn.Module):
    """U-Net mapping a channel-stacked group of frames to a single output image.

    Four 2x max-pool stages reduce the spatial size by 16: three in the encoder
    and one before the bottleneck. Each decoder stage concatenates the matching
    encoder features (skip connection) before its conv block. The output goes
    through a sigmoid, so values lie in (0, 1).

    Args:
        in_channels: Channels of the stacked input (default 9, e.g. three RGB frames).
        out_channels: Channels of the output image (default 3, RGB).
        base_filters: Width of the first encoder stage. Encoder stage ``k``
            uses ``base_filters * 2**k`` channels; the bottleneck uses
            ``base_filters * 16``.
    """

    # Total downsampling of the four pooling stages. Inputs are padded to a
    # multiple of this so skip-connection shapes always match.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=9, out_channels=3, base_filters=32):
        super().__init__()

        def conv_block(in_ch, out_ch):
            """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # Encoder path
        self.enc1 = conv_block(in_channels, base_filters)       # 9 -> 32
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = conv_block(base_filters, base_filters*2)    # 32 -> 64
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = conv_block(base_filters*2, base_filters*4)  # 64 -> 128
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = conv_block(base_filters*4, base_filters*8)  # 128 -> 256

        # Bottleneck
        self.bottleneck = conv_block(base_filters*8, base_filters*16)  # 256 -> 512

        # Decoder path: each dec block sees upsampled features + skip features
        # (hence the doubled input channel count).
        self.up4 = nn.ConvTranspose2d(base_filters*16, base_filters*8, kernel_size=2, stride=2)
        self.dec4 = conv_block(base_filters*16, base_filters*8)

        self.up3 = nn.ConvTranspose2d(base_filters*8, base_filters*4, kernel_size=2, stride=2)
        self.dec3 = conv_block(base_filters*8, base_filters*4)

        self.up2 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base_filters*4, base_filters*2)

        self.up1 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2)
        self.dec1 = conv_block(base_filters*2, base_filters)

        # Final conv (to RGB output)
        self.final_conv = nn.Conv2d(base_filters, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the U-Net on ``x``.

        Args:
            x: Tensor ``[B, in_channels, H, W]``. H and W need not be multiples
                of 16. Such inputs are replicate-padded on the bottom/right up to
                the next multiple of 16, and the output is cropped back to H x W.

        Returns:
            Tensor ``[B, out_channels, H, W]`` with values in (0, 1).
        """
        h, w = x.shape[-2:]
        pad_h = -h % self._SIZE_MULTIPLE
        pad_w = -w % self._SIZE_MULTIPLE
        padded = bool(pad_h or pad_w)
        if padded:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            # Replicate padding accepts any pad size, unlike reflect.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder
        e1 = self.enc1(x)        # [B, 32, H, W]
        p1 = self.pool1(e1)      # Downsample

        e2 = self.enc2(p1)       # [B, 64, H/2, W/2]
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)       # [B, 128, H/4, W/4]
        p3 = self.pool3(e3)

        e4 = self.enc4(p3)       # [B, 256, H/8, W/8]

        # Bottleneck
        b = self.bottleneck(F.max_pool2d(e4, 2))  # [B, 512, H/16, W/16]

        # Decoder
        d4 = self.up4(b)                               # Upsample
        d4 = torch.cat([d4, e4], dim=1)               # Skip connection
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        out = torch.sigmoid(self.final_conv(d1))  # Squash output to (0, 1)

        if padded:
            out = out[..., :h, :w]
        return out

In [None]:
import os
import glob
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

class VideoTripletDataset(Dataset):
    """Dataset of (three consecutive blurred frames, sharp centre frame) samples.

    Expected layout: ``blur_dir/<scene>/*.png`` plus ``gt_dir/<scene>/<same name>``.
    One sample is created for each blurred frame that meets two conditions:
    it has a predecessor and a successor in its scene (sorted filename order),
    and a ground-truth image with the same file name exists.

    Each item is ``(blur_tensor, gt_tensor)``:
    - ``blur_tensor`` is the previous, centre and next RGB frames concatenated
      on the channel axis, ``[9, H, W]``.
    - ``gt_tensor`` is the sharp centre frame, ``[3, H, W]``.

    Both come from ``transforms.ToTensor`` and are returned as float tensors.
    """

    # SimpleTemporalUNet pools 2x four times (total factor 16), so its skip
    # connections only line up when H and W are multiples of 16. Crops are
    # therefore trimmed to this multiple; it is required, not optional.
    _SIZE_MULTIPLE = 16

    def __init__(self, blur_dir, gt_dir):
        self.blur_dir = blur_dir
        self.gt_dir = gt_dir
        # Each entry: ([prev, centre, next] blur paths, centre GT path)
        self.samples = []
        self.transform = transforms.ToTensor()

        for scene in sorted(os.listdir(blur_dir)):
            blur_scene_path = os.path.join(blur_dir, scene)
            gt_scene_path = os.path.join(gt_dir, scene)

            if not os.path.isdir(blur_scene_path) or not os.path.isdir(gt_scene_path):
                continue

            blur_images = sorted(glob.glob(os.path.join(blur_scene_path, "*.png")))

            # First and last frames lack a neighbour on one side, so they are skipped.
            for i in range(1, len(blur_images) - 1):
                center_name = os.path.basename(blur_images[i])
                gt_path = os.path.join(gt_scene_path, center_name)

                if os.path.exists(gt_path):
                    triplet = [blur_images[i-1], blur_images[i], blur_images[i+1]]
                    self.samples.append((triplet, gt_path))

    def __len__(self):
        return len(self.samples)

    def _read_rgb_image(self, path):
        """Load ``path`` with OpenCV and convert it from BGR to RGB.

        Raises:
            IOError: if OpenCV cannot decode the file.
        """
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            raise IOError(f"Failed to load image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _crop_center(img, h, w):
        """Return the centred ``h`` x ``w`` window of an HxWxC array."""
        ch, cw = img.shape[:2]
        start_y = (ch - h) // 2
        start_x = (cw - w) // 2
        return img[start_y:start_y+h, start_x:start_x+w]

    def __getitem__(self, idx):
        """Return ``(blur_tensor [9, H, W], gt_tensor [3, H, W])`` for sample ``idx``.

        Raises:
            IOError: if any image cannot be read.
            ValueError: if the images are smaller than 16 px in height or width,
                which would otherwise yield empty tensors.
        """
        blur_paths, gt_path = self.samples[idx]

        # Read triplet + GT
        blurs = [self._read_rgb_image(p) for p in blur_paths]
        gt = self._read_rgb_image(gt_path)

        # Largest size shared by all four images, trimmed to a multiple of 16.
        multiple = self._SIZE_MULTIPLE
        h = min(img.shape[0] for img in blurs + [gt])
        w = min(img.shape[1] for img in blurs + [gt])
        h -= h % multiple
        w -= w % multiple
        if h == 0 or w == 0:
            raise ValueError(
                f"Sample {idx} is smaller than {multiple} px in height or width "
                f"and cannot be aligned (gt: {gt_path})"
            )

        blurs = [self._crop_center(img, h, w) for img in blurs]
        gt = self._crop_center(gt, h, w)

        # Stack 3 frames into 9-channel input
        blur_tensor = torch.cat([self.transform(img) for img in blurs], dim=0)  # Shape: [9, H, W]
        gt_tensor = self.transform(gt)  # Shape: [3, H, W]

        return blur_tensor.float(), gt_tensor.float()


In [None]:
from torch.utils.data import DataLoader

def get_loaders(train_blur, train_gt, val_blur, val_gt, batch_size=4, num_workers=2,
                pin_memory=None):
    """
    Returns PyTorch DataLoaders for training and validation datasets.

    Args:
        train_blur (str): Path to training blurred frames.
        train_gt (str): Path to training ground truth frames.
        val_blur (str): Path to validation blurred frames.
        val_gt (str): Path to validation ground truth frames.
        batch_size (int): Batch size for training loader (validation uses 1).
        num_workers (int): Number of subprocesses for data loading.
        pin_memory (bool | None): Whether both loaders pin host memory.
            ``None`` (default) enables pinning only when CUDA is available.
            Pinning has no benefit on CPU-only runtimes.

    Returns:
        tuple: (train_loader, val_loader)

    Raises:
        ValueError: if the training set has fewer samples than ``batch_size``
            (with ``drop_last=True`` the loader would yield zero batches), or
            if the validation set is empty.
    """
    import torch

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    train_dataset = VideoTripletDataset(train_blur, train_gt)
    val_dataset = VideoTripletDataset(val_blur, val_gt)

    if len(train_dataset) < batch_size:
        raise ValueError(
            f"Training set has {len(train_dataset)} samples, fewer than batch_size={batch_size}; "
            "with drop_last=True no batches would be produced."
        )
    if len(val_dataset) == 0:
        raise ValueError(f"Validation set is empty (blur: {val_blur}, gt: {val_gt}).")

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True  # Ensures batch size consistency for training
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader


In [None]:
import os
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

def _evaluate_mse(model, loader, device, device_type, use_amp):
    """Return the mean per-batch MSE of ``model`` over ``loader``.

    Puts ``model`` in eval mode and runs without gradients. ``loader`` must
    yield at least one batch.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            with torch.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = F.mse_loss(outputs, targets)
            total_loss += loss.item()
    return total_loss / len(loader)


def train_model(model, train_loader, val_loader, num_epochs=10, lr=1e-4, device='cuda',
                drive_path='/content/drive/MyDrive/model_checkpoints', model_name='best_model.pth',
                use_amp=True):
    """Train ``model`` with MSE loss and keep the weights with the lowest validation loss.

    The best weights are written (``state_dict`` only) to ``drive_path/model_name``.

    Resuming: if that file already exists, its weights are loaded and evaluated
    on ``val_loader`` before training. That loss becomes the baseline, so a
    resumed run only overwrites the checkpoint when an epoch actually improves
    on it. Only weights are stored, so the Adam optimizer state starts fresh.

    Args:
        model: ``nn.Module`` mapping an input batch to predictions shaped like the targets.
        train_loader, val_loader: Iterables of ``(inputs, targets)`` batches with ``len()``.
        num_epochs: Number of training epochs.
        lr: Adam learning rate.
        device: Device string or ``torch.device`` (e.g. ``'cuda'``, ``'cuda:0'``, ``'cpu'``).
        drive_path: Directory for the checkpoint (created if missing).
        model_name: Checkpoint file name.
        use_amp: Enable ``torch.autocast``. Gradient scaling is applied only on CUDA.

    Raises:
        ValueError: if either loader yields no batches (the average losses
            would otherwise be a division by zero).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            "train_loader and val_loader must each yield at least one batch "
            f"(got {len(train_loader)} and {len(val_loader)})"
        )

    os.makedirs(drive_path, exist_ok=True)
    best_model_path = os.path.join(drive_path, model_name)

    # torch.autocast needs a bare device type ('cuda'/'cpu'), not e.g. 'cuda:0'.
    device_type = torch.device(device).type
    is_cuda = device_type == 'cuda'

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Loss scaling is only needed for float16 CUDA autocast.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp and is_cuda)

    best_val_loss = float('inf')
    best_epoch = -1  # 0 means the checkpoint loaded at start is still the best

    # Resume training from saved model, using its validation loss as the baseline
    if os.path.exists(best_model_path):
        print(f"Loading best model from {best_model_path}...")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        best_val_loss = _evaluate_mse(model, val_loader, device, device_type, use_amp)
        best_epoch = 0
        print(f"Loaded checkpoint val loss: {best_val_loss:.6f}")

    for epoch in range(num_epochs):
        if is_cuda:
            torch.cuda.empty_cache()
        model.train()
        train_loss = 0.0

        for inputs, targets in tqdm(train_loader, desc=f"[Epoch {epoch+1}/{num_epochs}] Training"):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            with torch.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = F.mse_loss(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Avg Train Loss: {avg_train_loss:.6f}")

        # Validation
        avg_val_loss = _evaluate_mse(model, val_loader, device, device_type, use_amp)
        print(f"Epoch {epoch+1} - Avg Val Loss: {avg_val_loss:.6f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved at epoch {best_epoch} with val loss: {best_val_loss:.6f}")

        if is_cuda:
            torch.cuda.empty_cache()

    print(f"Training complete. Best epoch: {best_epoch} with val loss: {best_val_loss:.6f}")


In [None]:
import torch

# Dataset paths
train_blur = '/content/gopro_dataset/GoPro/train/blur'
train_gt = '/content/gopro_dataset/GoPro/train/GT'
val_blur = '/content/gopro_dataset/GoPro/val/blur'
val_gt = '/content/gopro_dataset/GoPro/val/GT'

# Training settings
SEED = 42
batch_size = 4
num_epochs = 60
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Mixed precision is meant for the GPU. On CPU, torch.autocast would switch to
# bfloat16 instead, and the CUDA GradScaler does not apply, so AMP stays off there.
use_amp = device == 'cuda'
checkpoint_path = '/content/drive/MyDrive/model_checkpoints'
model_name = 'best_model.pth'

# Seed weight initialisation and DataLoader shuffling so runs are reproducible.
torch.manual_seed(SEED)

# Load data
train_loader, val_loader = get_loaders(
    train_blur, train_gt,
    val_blur, val_gt,
    batch_size=batch_size,
    num_workers=2
)

# Initialize and train model
model = SimpleTemporalUNet()
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    lr=learning_rate,
    device=device,
    drive_path=checkpoint_path,
    model_name=model_name,
    use_amp=use_amp  # Mixed precision only on CUDA (see config above)
)
