<a href="https://colab.research.google.com/github/KaushalKD279/Jumbled_frame_reconstruction/blob/main/reconstuction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ========================================
# CELL 1: Install dependencies & imports
# ========================================
!pip install -q opencv-python-headless torch torchvision tqdm scikit-learn gdown

import os
import time
import cv2
import numpy as np
from tqdm import tqdm
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import gdown
import warnings
from concurrent.futures import ThreadPoolExecutor
warnings.filterwarnings('ignore')
print("Imports done.")


[31mERROR: Operation cancelled by user[0m[31m
[0mImports done.


In [None]:
# ========================================
# CELL 2: Download jumbled video (Google Drive)
# ========================================
file_id = "1DbR9yap-vgUaPiI3hCEKUnniXr-TrdOt"
output = "jumbled_video.mp4"

# Skip the download when a non-empty copy is already on disk, so that
# "Restart & Run All" does not fetch the file again on every run.
if os.path.exists(output) and os.path.getsize(output) > 0:
    print("Using existing file:", output)
else:
    print("Downloading jumbled video...")
    downloaded = gdown.download(f"https://drive.google.com/uc?id={file_id}", output, quiet=False)
    # Do not report success unless the download produced a file; some gdown
    # versions signal failure by returning None instead of raising.
    if downloaded is None or not os.path.exists(output):
        raise RuntimeError(f"Download failed for Google Drive file id {file_id!r}")
    print("Download complete:", output)


Downloading jumbled video...


Downloading...
From: https://drive.google.com/uc?id=1DbR9yap-vgUaPiI3hCEKUnniXr-TrdOt
To: /content/jumbled_video.mp4
100%|██████████| 90.6M/90.6M [00:01<00:00, 77.9MB/s]

Download complete: jumbled_video.mp4





In [None]:
# ========================================
# CELL 3: Optimized FrameReconstructor (batched features, GPU similarity, in-memory frames, limited 2-opt)
# ========================================
import math
from typing import List, Tuple

class FrameReconstructorOptimized:
    """Re-order the frames of a shuffled video by visual similarity.

    Pipeline (see ``reconstruct``): decode every frame into memory, compute
    one feature vector per frame (ResNet pooled features or a colour
    histogram), build the pairwise cosine-similarity matrix, chain the frames
    with a greedy nearest-neighbour walk, polish the chain with a
    neighbour-limited 2-opt pass and write the frames out in the new order.

    Attributes:
        device: torch device string used for the model and the similarity matmul.
        use_deep_features: True -> ResNet features, False -> colour histograms.
        model_name: 'resnet50' or 'resnet18'; anything else falls back to resnet50.
        model: the truncated ResNet (None when deep features are disabled).
        transform: torchvision preprocessing pipeline (None when disabled).
    """

    def __init__(self, device=None, use_deep_features=True, model_name='resnet50'):
        """Select the device and, if requested, load the feature extractor.

        Args:
            device: torch device string; when None, 'cuda' is used if available,
                otherwise 'cpu'.
            use_deep_features: load the ResNet backbone now when True.
            model_name: backbone to load ('resnet50' or 'resnet18').
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.use_deep_features = use_deep_features
        self.model_name = model_name
        self.model = None
        self.transform = None
        if self.use_deep_features:
            self._init_feature_extractor()
        print(f"Initialized reconstructor on device={self.device}, deep_features={self.use_deep_features}")

    def _init_feature_extractor(self):
        """Load the ImageNet-pretrained backbone (final FC removed) and the preprocessing."""
        # Use torchvision ResNet backbone without final FC: outputs (B, C, 1, 1)
        print("Initializing feature extractor:", self.model_name)
        builders = {'resnet50': models.resnet50, 'resnet18': models.resnet18}
        builder = builders.get(self.model_name)
        if builder is None:
            # Keep the historical fallback, but say so instead of doing it silently.
            print(f"Unknown model_name {self.model_name!r}; falling back to resnet50")
            builder = models.resnet50
        try:
            # `pretrained=True` is deprecated in torchvision >= 0.13; the
            # IMAGENET1K_V1 weights are the ones `pretrained=True` loads.
            base = builder(weights="IMAGENET1K_V1")
        except TypeError:
            # Older torchvision releases do not know the `weights` keyword.
            base = builder(pretrained=True)

        # strip fc layer
        self.model = torch.nn.Sequential(*list(base.children())[:-1]).to(self.device)
        self.model.eval()

        # transform pipeline (expects RGB uint8 HxWx3 numpy arrays)
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # warmup: one dummy forward pass so later timings exclude first-call overhead
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 224, 224).to(self.device)
            self.model(dummy)
        print("Feature extractor ready on", self.device)

    def extract_frames_in_memory(self, video_path: str) -> Tuple[List[np.ndarray], float, float]:
        """Decode the whole video into a list of BGR numpy arrays.

        Args:
            video_path: path of the video file to read.

        Returns:
            (frames, fps, elapsed): the decoded frames, the frame rate reported
            by OpenCV (30.0 when it reports 0) and the decoding time in seconds.
            ``frames`` is empty when the video cannot be opened or read.
        """
        start = time.time()
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            if not cap.isOpened():
                print(f"Warning: could not open video {video_path!r}")
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
            expected = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
            pbar = tqdm(total=expected, desc="Decoding frames", unit="f")
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(frame)  # BGR numpy array
                    pbar.update(1)
            finally:
                pbar.close()
        finally:
            # Release the capture handle even if decoding raised.
            cap.release()
        elapsed = time.time() - start
        print(f"Decoded {len(frames)} frames at {fps:.2f} FPS in {elapsed:.2f}s")
        return frames, fps, elapsed

    def _batchify_and_extract(self, frames: List[np.ndarray], batch_size: int = 32) -> np.ndarray:
        """Run the frames through the backbone in batches.

        Frames are preprocessed one batch at a time, so only ``batch_size``
        preprocessed tensors are alive at once instead of one per frame.

        Args:
            frames: BGR images as numpy arrays.
            batch_size: number of frames per forward pass.

        Returns:
            Array of shape (N, feat_dim); an empty (0, 0) float32 array when
            ``frames`` is empty.
        """
        N = len(frames)
        if N == 0:
            return np.empty((0, 0), dtype=np.float32)

        preprocess = self.transform
        feats_list = []
        with torch.no_grad():
            for i in tqdm(range(0, N, batch_size), desc="Extracting features (batched)"):
                batch = torch.stack([
                    preprocess(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                    for img in frames[i:i + batch_size]
                ]).to(self.device)  # (B,3,224,224)
                out = self.model(batch)  # (B, feat, 1, 1)
                feats_list.append(out.flatten(1).cpu().numpy())  # (B, feat_dim)

        return np.concatenate(feats_list, axis=0)  # (N, feat_dim)

    def compute_features(self, frames: List[np.ndarray], batch_size: int = 32) -> Tuple[np.ndarray, float]:
        """Compute one feature vector per frame.

        Args:
            frames: BGR images as numpy arrays.
            batch_size: batch size for the deep-feature path (unused for histograms).

        Returns:
            (features, elapsed): the feature matrix and the time taken in seconds.
        """
        t0 = time.time()
        if self.use_deep_features:
            feats = self._batchify_and_extract(frames, batch_size=batch_size)
        else:
            # fallback: 8x8x8 colour histograms of a 64x64 thumbnail (fast, CPU)
            feats = []
            for img in tqdm(frames, desc="Histogram features"):
                small = cv2.resize(img, (64, 64))
                hist = cv2.calcHist([small], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
                feats.append(hist.flatten())
            feats = np.array(feats)
        elapsed = time.time() - t0
        print(f"Feature extraction done: shape={feats.shape} time={elapsed:.2f}s")
        return feats, elapsed

    def build_similarity_matrix(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Build the cosine-similarity matrix with torch on ``self.device``.

        Args:
            features: array of shape (N, dim).

        Returns:
            (S, elapsed): S is an (N, N) numpy array with values in [-1, 1];
            elapsed is the time taken in seconds.
        """
        t0 = time.time()
        with torch.no_grad():
            feats_t = torch.from_numpy(features).float().to(self.device)  # (N, dim)
            # normalize rows; the clamp avoids dividing by zero for all-zero rows
            norms = feats_t.norm(dim=1, keepdim=True).clamp(min=1e-8)
            feats_t = feats_t / norms
            # cosine similarity of unit rows is a plain matrix product
            S_t = torch.mm(feats_t, feats_t.t())  # (N,N) on device
            # clamp numeric noise
            S_t = S_t.clamp(-1.0, 1.0)
            S = S_t.cpu().numpy()
        elapsed = time.time() - t0
        print(f"Built similarity matrix (N={S.shape[0]}) in {elapsed:.2f}s")
        return S, elapsed

    def greedy_path_from_sim(self, S: np.ndarray, start_idx: int = None) -> Tuple[List[int], float]:
        """Chain all frames by repeatedly moving to the most similar unvisited frame.

        Args:
            S: (N, N) similarity matrix.
            start_idx: first frame of the chain. When None, the frame with the
                largest row sum of S is used. Note that such a frame need not be
                an end of the true sequence; ``two_opt_limited`` can repair the
                resulting misplaced run.

        Returns:
            (path, elapsed): a permutation of range(N) (empty when N == 0) and
            the time taken in seconds.

        Raises:
            IndexError: if ``start_idx`` is not in [0, N).
        """
        t0 = time.time()
        N = S.shape[0]
        path = []
        if N > 0:
            if start_idx is None:
                start_idx = int(np.argmax(np.sum(S, axis=1)))
            if not 0 <= start_idx < N:
                # A negative index would alias a real frame and duplicate it in the path.
                raise IndexError(f"start_idx {start_idx} out of range for {N} frames")
            visited = np.zeros(N, dtype=bool)
            visited[start_idx] = True
            path.append(start_idx)
            cur = start_idx
            for _ in range(N - 1):
                # Unvisited nodes in ascending order; ties go to the lowest index.
                candidates = np.flatnonzero(~visited)
                nxt = int(candidates[int(np.argmax(S[cur, candidates]))])
                visited[nxt] = True
                path.append(nxt)
                cur = nxt
        elapsed = time.time() - t0
        print(f"Greedy path computed in {elapsed:.2f}s")
        return path, elapsed

    def two_opt_limited(self, path: List[int], S: np.ndarray, max_iter: int = 200, top_k: int = 15) -> Tuple[List[int], float]:
        """Improve an open path with neighbour-limited 2-opt segment reversals.

        Each iteration applies the first improving move found and then rescans:
        - interior move: for consecutive nodes (a, b), reverse b..c where c is
          one of the top_k most similar nodes of a or b, replacing edges
          (a,b),(c,d) by (a,c),(b,d);
        - endpoint move (tried only when no interior move improves): reverse a
          prefix or a suffix of the path, which changes a single edge. This is
          what repairs a path whose first run is placed on the wrong side.

        The gain formulas assume S is symmetric (true for cosine similarity).
        The top_k lists are taken straight from S, so a node's own index may
        occupy one of its slots.

        Args:
            path: permutation of range(N); it is copied, not modified.
            S: (N, N) similarity matrix.
            max_iter: maximum number of iterations (at most one move each).
            top_k: number of most-similar nodes considered per node.

        Returns:
            (new_path, elapsed): the improved path and the time taken in seconds.
        """
        t0 = time.time()
        path = list(path)  # work on a copy so the caller's list is left untouched
        N = len(path)
        # Precompute top_k neighbors (sorted descending)
        topk = np.argsort(-S, axis=1)[:, :top_k]  # (N, top_k)
        # position lookup: pos[node] == index of node in path
        pos = np.empty(N, dtype=int)
        for idx, node in enumerate(path):
            pos[node] = idx

        min_gain = 1e-6  # ignore gains that are only floating-point noise

        def reverse_segment(lo, hi):
            # Reverse path[lo..hi] in place and refresh the position lookup.
            path[lo:hi + 1] = path[lo:hi + 1][::-1]
            for j in range(lo, hi + 1):
                pos[path[j]] = j

        def try_interior_move():
            # Apply the first improving reversal that keeps both path ends fixed.
            for i in range(1, N - 2):
                a = path[i - 1]
                b = path[i]
                candidate_nodes = set()
                candidate_nodes.update(topk[a])
                candidate_nodes.update(topk[b])
                for c in candidate_nodes:
                    kpos = pos[c]
                    # need a successor d after c, and c strictly after b
                    if kpos <= i or kpos + 1 >= N:
                        continue
                    d = path[kpos + 1]
                    gain = (S[a, c] + S[b, d]) - (S[a, b] + S[c, d])
                    if gain > min_gain:  # positive gain means total similarity increases
                        reverse_segment(i, kpos)
                        return True
            return False

        def try_endpoint_move():
            # Apply the first improving prefix or suffix reversal.
            if N < 3:
                return False
            head = path[0]
            for d in topk[head]:
                # Reversing path[0..kpos] replaces edge (c, d) by (head, d).
                kpos = pos[d] - 1
                if kpos < 1:  # d is the head itself or already its neighbour
                    continue
                c = path[kpos]
                if S[head, d] - S[c, d] > min_gain:
                    reverse_segment(0, kpos)
                    return True
            tail = path[N - 1]
            for a in topk[tail]:
                # Reversing path[i..N-1] replaces edge (a, b) by (a, tail).
                i = pos[a] + 1
                if i > N - 2:  # a is the tail itself or already its neighbour
                    continue
                b = path[i]
                if S[a, tail] - S[a, b] > min_gain:
                    reverse_segment(i, N - 1)
                    return True
            return False

        improved = True
        it = 0
        while improved and it < max_iter:
            it += 1
            improved = try_interior_move() or try_endpoint_move()
        elapsed = time.time() - t0
        print(f"2-opt limited finished in {it} iterations, time={elapsed:.2f}s")
        return path, elapsed

    def write_ordered_video_in_memory(self, path_order: List[int], frames: List[np.ndarray], output_path: str, fps: float):
        """Write ``frames`` in the order given by ``path_order`` to an mp4v file.

        Args:
            path_order: frame indices in output order.
            frames: the decoded frames; the first written frame defines the size.
            output_path: destination file.
            fps: frame rate of the output video.

        Returns:
            Time taken in seconds.

        Raises:
            ValueError: if ``path_order`` is empty.
            RuntimeError: if the video writer cannot be opened.
        """
        t0 = time.time()
        if len(path_order) == 0:
            raise ValueError("Empty path order.")
        first_frame = frames[path_order[0]]
        h, w = first_frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
        if not out.isOpened():
            # Fail loudly instead of silently writing nothing.
            out.release()
            raise RuntimeError(f"Could not open video writer for {output_path!r}")
        try:
            for idx in tqdm(path_order, desc="Writing video"):
                out.write(frames[idx])
        finally:
            # Finalize the file even if a write raised.
            out.release()
        elapsed = time.time() - t0
        print(f"Wrote video to {output_path} in {elapsed:.2f}s")
        return elapsed

    def reconstruct(self, video_path: str, output_path: str = 'reconstructed.mp4', batch_size: int = 32, two_opt_iters: int = 200, top_k: int = 15):
        """Run the complete pipeline.

        Args:
            video_path: shuffled input video.
            output_path: where the re-ordered video is written.
            batch_size: batch size for feature extraction.
            two_opt_iters: ``max_iter`` for ``two_opt_limited``.
            top_k: ``top_k`` for ``two_opt_limited``.

        Returns:
            (path, timings): the frame order and a dict of per-stage times in
            seconds, including a 'total' entry.

        Raises:
            ValueError: if no frame could be decoded.
        """
        timings = {}
        total_t0 = time.time()

        # 1. Decode frames into memory
        frames, fps, t_decode = self.extract_frames_in_memory(video_path)
        timings['decode'] = t_decode
        if len(frames) == 0:
            raise ValueError("No frames decoded.")

        # 2. Feature extraction (batched)
        feats, t_feat = self.compute_features(frames, batch_size=batch_size)
        timings['feature_extraction'] = t_feat

        # 3. Similarity matrix (on device, then copied to CPU)
        S, t_sim = self.build_similarity_matrix(feats)
        timings['similarity_matrix'] = t_sim

        # 4. Greedy initial path
        path, t_greedy = self.greedy_path_from_sim(S)
        timings['greedy_path'] = t_greedy

        # 5. 2-opt limited
        path, t_twoopt = self.two_opt_limited(path, S, max_iter=two_opt_iters, top_k=top_k)
        timings['two_opt'] = t_twoopt

        # 6. Write reordered video
        t_write = self.write_ordered_video_in_memory(path, frames, output_path, fps=fps)
        timings['video_writing'] = t_write

        timings['total'] = time.time() - total_t0
        print("Timings:", timings)
        return path, timings


In [None]:
# ========================================
# CELL 4: Run the optimized reconstruction
# ========================================
# Input / output locations, named once so the summary below cannot drift
# from the path actually passed to reconstruct().
INPUT_VIDEO = 'jumbled_video.mp4'
OUTPUT_VIDEO = 'reconstructed.mp4'

# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)
if device == 'cuda':
    try:
        print("GPU:", torch.cuda.get_device_name(0))
    except Exception as exc:
        # The GPU name is informational only: report the problem and carry on.
        print("GPU name unavailable:", exc)

# initialize reconstructor (choose resnet18 for speed if you want)
reconstructor = FrameReconstructorOptimized(device=device, use_deep_features=True, model_name='resnet50')

# run reconstruction (tune batch_size, two_opt_iters, top_k for speed/accuracy tradeoff)
path, timings = reconstructor.reconstruct(
    video_path=INPUT_VIDEO,
    output_path=OUTPUT_VIDEO,
    batch_size=32,        # larger batch_size -> better GPU utilization
    two_opt_iters=200,    # reduce iterations for faster runs
    top_k=12              # smaller top_k -> faster but possibly slightly worse
)

# Sanity check: every decoded frame must appear exactly once in the output order.
assert sorted(path) == list(range(len(path))), "reconstructed path is not a permutation of the frames"

print("\n✅ Reconstruction finished. Output:", OUTPUT_VIDEO)
print("Timings summary:")
for stage, seconds in timings.items():
    print(f"  {stage:20s} : {seconds:.2f}s")


Using device: cpu
Initializing feature extractor: resnet50
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 190MB/s]


Feature extractor ready on cpu
Initialized reconstructor on device=cpu, deep_features=True


Decoding frames: 100%|██████████| 300/300 [00:03<00:00, 93.08f/s]


Decoded 300 frames at 30.00 FPS in 3.29s


Extracting features (batched): 100%|██████████| 10/10 [01:14<00:00,  7.50s/it]


Feature extraction done: shape=(300, 2048) time=80.02s
Built similarity matrix (N=300) in 0.01s
Greedy path computed in 0.01s
2-opt limited finished in 37 iterations, time=0.11s


Writing video: 100%|██████████| 300/300 [00:07<00:00, 38.21it/s]

Wrote video to reconstructed.mp4 in 7.86s
Timings: {'decode': 3.2944602966308594, 'feature_extraction': 80.02252984046936, 'similarity_matrix': 0.011424064636230469, 'greedy_path': 0.006487131118774414, 'two_opt': 0.11117029190063477, 'video_writing': 7.860515356063843, 'total': 91.30696988105774}

✅ Reconstruction finished. Output: reconstructed.mp4
Timings summary:
  decode               : 3.29s
  feature_extraction   : 80.02s
  similarity_matrix    : 0.01s
  greedy_path          : 0.01s
  two_opt              : 0.11s
  video_writing        : 7.86s
  total                : 91.31s





In [None]:
# ========================================
# CELL 5: Download result (Colab helper)
# ========================================
# google.colab only exists inside Colab; outside it, report where the file is
# instead of failing the whole "Run All" with an ImportError.
try:
    from google.colab import files
except ImportError:
    files = None

if files is None:
    print("Not running in Google Colab; result is at:", os.path.abspath('reconstructed.mp4'))
else:
    print("Preparing download...")
    files.download('reconstructed.mp4')
    print("Download started.")


Preparing download...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Download started.
