In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive; the model checkpoint used by the
# inference cell is read from a path under this mount point.
drive.mount("/content/drive")

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTemporalUNet(nn.Module):
    """U-Net that maps a channel-stacked group of frames to a single image.

    The network has four encoder stages, a bottleneck, and four decoder
    stages with skip connections, followed by a 1x1 convolution and a
    sigmoid so outputs lie in [0, 1].

    Args:
        in_channels: Channels of the input tensor (default 9, i.e. three
            RGB frames concatenated along the channel axis).
        out_channels: Channels of the output image (default 3).
        base_filters: Width of the first encoder stage; each deeper stage
            doubles it.

    Submodule names are part of the checkpoint format (state_dict keys) and
    must not be renamed.
    """

    # The encoder halves the resolution four times (three pooling modules plus
    # the pooling before the bottleneck), so the skip connections only line up
    # when H and W are multiples of 2**4.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=9, out_channels=3, base_filters=32):
        super().__init__()

        def conv_block(in_ch, out_ch):
            """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # Encoder path (channel counts in comments assume the defaults)
        self.enc1 = conv_block(in_channels, base_filters)       # 9 -> 32
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = conv_block(base_filters, base_filters*2)    # 32 -> 64
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = conv_block(base_filters*2, base_filters*4)  # 64 -> 128
        self.pool3 = nn.MaxPool2d(2)

        self.enc4 = conv_block(base_filters*4, base_filters*8)  # 128 -> 256

        # Bottleneck
        self.bottleneck = conv_block(base_filters*8, base_filters*16)  # 256 -> 512

        # Decoder path: each stage upsamples by 2, then fuses the matching
        # encoder feature map (hence the doubled input width of each dec block).
        self.up4 = nn.ConvTranspose2d(base_filters*16, base_filters*8, kernel_size=2, stride=2)
        self.dec4 = conv_block(base_filters*16, base_filters*8)

        self.up3 = nn.ConvTranspose2d(base_filters*8, base_filters*4, kernel_size=2, stride=2)
        self.dec3 = conv_block(base_filters*8, base_filters*4)

        self.up2 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base_filters*4, base_filters*2)

        self.up1 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2)
        self.dec1 = conv_block(base_filters*2, base_filters)

        # Final conv (to RGB output)
        self.final_conv = nn.Conv2d(base_filters, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Tensor of shape [B, in_channels, H, W]. H and W may be any
               positive size; inputs that are not multiples of 16 are padded
               on the bottom/right by edge replication and the output is
               cropped back, so the result always matches the input size.

        Returns:
            Tensor of shape [B, out_channels, H, W] with values in [0, 1].
        """
        height, width = x.shape[-2], x.shape[-1]
        pad_h = (-height) % self._SIZE_MULTIPLE
        pad_w = (-width) % self._SIZE_MULTIPLE
        needs_pad = pad_h > 0 or pad_w > 0
        if needs_pad:
            # Replicate padding works for any input size (reflect would require
            # the pad to be smaller than the image) and avoids a dark border.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="replicate")

        # Encoder
        e1 = self.enc1(x)        # [B, 32, H, W]
        p1 = self.pool1(e1)      # Downsample

        e2 = self.enc2(p1)       # [B, 64, H/2, W/2]
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)       # [B, 128, H/4, W/4]
        p3 = self.pool3(e3)

        e4 = self.enc4(p3)       # [B, 256, H/8, W/8]

        # Bottleneck (the fourth pooling is applied functionally; it has no
        # parameters, so it does not need a registered module)
        b = self.bottleneck(F.max_pool2d(e4, 2))  # [B, 512, H/16, W/16]

        # Decoder
        d4 = self.up4(b)                               # Upsample
        d4 = torch.cat([d4, e4], dim=1)               # Skip connection
        d4 = self.dec4(d4)

        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        out = self.final_conv(d1)
        out = torch.sigmoid(out)  # Normalize output to [0,1]

        if needs_pad:
            # Drop the rows/columns that were only added for alignment.
            out = out[..., :height, :width]

        return out

In [3]:
import os
import glob
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms

class VideoTripletDataset(Dataset):
    """Dataset of (three neighbouring blurred frames, ground-truth centre frame).

    ``blur_dir`` and ``gt_dir`` are expected to contain one sub-directory per
    scene. Within every scene present in both trees, each interior ``*.png``
    frame (in sorted filename order) that has a same-named ground-truth image
    yields one sample.
    """

    def __init__(self, blur_dir, gt_dir):
        self.blur_dir = blur_dir
        self.gt_dir = gt_dir
        self.samples = []
        self.transform = transforms.ToTensor()

        for scene_name in sorted(os.listdir(blur_dir)):
            scene_blur_dir = os.path.join(blur_dir, scene_name)
            scene_gt_dir = os.path.join(gt_dir, scene_name)

            # Only scenes that exist as directories on both sides are usable.
            if not (os.path.isdir(scene_blur_dir) and os.path.isdir(scene_gt_dir)):
                continue

            frame_paths = sorted(glob.glob(os.path.join(scene_blur_dir, "*.png")))

            # Slide a window of three consecutive files; the middle one is the
            # frame being restored, so the first and last files are never centres.
            windows = zip(frame_paths, frame_paths[1:], frame_paths[2:])
            for before, center, after in windows:
                target_path = os.path.join(scene_gt_dir, os.path.basename(center))
                if not os.path.exists(target_path):
                    continue
                self.samples.append(([before, center, after], target_path))

    def __len__(self):
        """Number of (triplet, ground-truth) samples found at construction."""
        return len(self.samples)

    def _read_rgb_image(self, path):
        """Load ``path`` with OpenCV and return it in RGB channel order.

        Raises:
            IOError: if OpenCV could not read the file.
        """
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is not None:
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        raise IOError(f"Failed to load image: {path}")

    def __getitem__(self, idx):
        """Return ``(blur_tensor, gt_tensor)`` for sample ``idx``.

        All four images are centre-cropped to their common minimum size,
        rounded down to a multiple of 16, then passed through ``self.transform``.
        The three blurred frames are concatenated on the channel axis
        ([9, H, W]); the ground truth is [3, H, W]. Both are float tensors.
        """
        triplet_paths, target_path = self.samples[idx]

        # Read triplet + GT
        blurred = [self._read_rgb_image(path) for path in triplet_paths]
        target = self._read_rgb_image(target_path)

        # Common crop size: smallest height/width over all four images,
        # rounded down so the UNet's 16x downsampling divides evenly.
        everything = blurred + [target]
        crop_h = min(image.shape[0] for image in everything)
        crop_w = min(image.shape[1] for image in everything)
        crop_h -= crop_h % 16
        crop_w -= crop_w % 16

        def centered(image):
            top = (image.shape[0] - crop_h) // 2
            left = (image.shape[1] - crop_w) // 2
            return image[top:top + crop_h, left:left + crop_w]

        # Stack 3 frames into 9-channel input
        blur_tensor = torch.cat(
            [self.transform(centered(image)) for image in blurred], dim=0
        )
        gt_tensor = self.transform(centered(target))

        return blur_tensor.float(), gt_tensor.float()


In [4]:
import cv2
import numpy as np
import torch
from torchvision.transforms.functional import to_tensor
from google.colab.patches import cv2_imshow
from IPython.display import HTML, display
from base64 import b64encode
from tqdm import tqdm
from google.colab import files

# Inference setup: choose the device, build the network, restore the trained
# weights (deserialised on CPU first), then move the model over and switch to
# evaluation mode. `model` and `device` are read as globals by the helpers below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CHECKPOINT_PATH = '/content/drive/MyDrive/model_checkpoints/best_model.pth'  # adjust path

model = SimpleTemporalUNet()  # replace with your model initialization
checkpoint_state = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint_state)
model.to(device)
model.eval()

def video_to_frames(video_path):
    """Decode every frame of a video into a list of RGB images.

    Args:
        video_path: Path of the video file to read.

    Returns:
        List of frames in RGB channel order, in playback order. All frames
        are held in memory at once.

    Raises:
        IOError: if OpenCV cannot open ``video_path`` (previously this case
            silently produced an empty list).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError(f"Failed to open video: {video_path}")

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        # Release the decoder even if opening or colour conversion raised.
        cap.release()

def frames_to_video(frames, output_path, fps=30):
    """Encode a sequence of RGB frames as an 'mp4v' video file.

    Args:
        frames: Non-empty sequence of H x W x 3 RGB images; the first frame
            determines the video size.
        output_path: Destination file path.
        fps: Frame rate of the written video (may be fractional).

    Raises:
        ValueError: if ``frames`` is empty.
        IOError: if the video writer could not be opened for ``output_path``.
    """
    if len(frames) == 0:
        raise ValueError("frames_to_video requires at least one frame")

    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        if not out.isOpened():
            # Without this check a failed open would silently drop every write.
            raise IOError(f"Failed to open video writer for: {output_path}")
        for f in frames:
            # Convert RGB to BGR for saving
            out.write(cv2.cvtColor(f, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalise/close the container, even if a write failed.
        out.release()

def pad_frames(frames, pad=1):
    """Return a new list with the first and last frame each repeated ``pad`` extra times.

    This gives every original frame a previous and a next neighbour so that
    the edge frames can also be used as the centre of a temporal window.
    """
    first, last = frames[0], frames[-1]
    padded = [first for _ in range(pad)]
    padded.extend(frames)
    padded.extend(last for _ in range(pad))
    return padded

def prepare_input(frames, idx):
    """Build the model input for the frame at ``idx``.

    The frames at ``idx - 1``, ``idx`` and ``idx + 1`` are each converted with
    ``to_tensor`` and concatenated along the channel dimension; a leading batch
    dimension is added and the result is moved to the module-level ``device``.
    """
    window = (frames[idx - 1], frames[idx], frames[idx + 1])
    stacked = torch.cat([to_tensor(frame) for frame in window], dim=0)
    batched = stacked.unsqueeze(0)  # add batch dimension
    return batched.to(device)

def deblur_video(video_path, output_path):
    """Deblur a video frame by frame and write the result to ``output_path``.

    Every frame is restored from a window of (previous, current, next) frames;
    the first and last frames are replicated so the output has as many frames
    as the input. Uses the module-level ``model`` and the helper functions
    defined in this cell.

    Args:
        video_path: Path of the input video.
        output_path: Path of the restored video to write.

    Returns:
        An ``IPython.display.HTML`` object embedding the written video as a
        base64 data URL.

    Raises:
        ValueError: if the video has fewer than 3 frames.
    """
    frames = video_to_frames(video_path)
    if len(frames) < 3:
        raise ValueError("Video must have at least 3 frames for temporal input")

    # Read the source frame rate and release the probing handle afterwards.
    probe = cv2.VideoCapture(video_path)
    try:
        fps = probe.get(cv2.CAP_PROP_FPS)
    finally:
        probe.release()
    # Unknown rates are reported as 0 (or a non-finite value); fall back to the
    # same default that frames_to_video uses.
    if not np.isfinite(fps) or fps <= 0:
        fps = 30.0

    padded_frames = pad_frames(frames, pad=1)

    restored_frames = []
    with torch.no_grad():
        for i in tqdm(range(1, len(padded_frames) - 1), desc="Deblurring frames"):
            input_tensor = prepare_input(padded_frames, i)
            output = model(input_tensor)
            output = output.squeeze(0).cpu().clamp(0, 1)  # (3, H, W)
            # Round to the nearest 8-bit level (plain truncation biases every
            # pixel downwards) and hand OpenCV a contiguous H x W x 3 array.
            output_img = (
                output.mul(255).round().to(torch.uint8)
                .permute(1, 2, 0).contiguous().numpy()
            )
            restored_frames.append(output_img)

    # Keep the fractional frame rate (e.g. 29.97): truncating it to an int
    # would change the playback speed/duration of the restored video.
    frames_to_video(restored_frames, output_path, fps=fps)
    print(f"Restored video saved to: {output_path}")

    # Display output video inline in Colab
    with open(output_path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML(f"""
    <video width=640 controls>
        <source src="{data_url}" type="video/mp4">
    </video>
    """)

# --- Usage in Colab ---

uploaded = files.upload()
if not uploaded:
    # A cancelled/empty upload would otherwise leave nothing to display.
    print('No file uploaded; nothing to process.')

# NOTE: every upload is written to the same 'restored_output.mp4', so when
# several files are uploaded only the last restored video remains on disk.
for fn in uploaded.keys():
    print(f'Processing {fn} ...')
    output_html = deblur_video(fn, 'restored_output.mp4')
    # Show each result as soon as it is ready instead of only the last one.
    display(output_html)


Saving GOPR0384_11_00.mp4 to GOPR0384_11_00.mp4
Processing GOPR0384_11_00.mp4 ...


Deblurring frames: 100%|██████████| 100/100 [00:15<00:00,  6.57it/s]


Restored video saved to: restored_output.mp4


In [5]:
from google.colab import files
# Download the restored video produced by the inference cell to the local machine.
# NOTE(review): the inference cell writes 'restored_output.mp4' relative to the
# working directory; this absolute path assumes that directory is /content — confirm.
restored="/content/restored_output.mp4"
files.download(restored)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>