In [3]:
import os
import time
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from math import inf

# ---------- Parameters ----------
VIDEO_PATH = "jumbled_video.mp4"
FRAMES_DIR = "frames_improved"
OUTPUT_VIDEO = "reconstructed_improved.mp4"
FPS = 30
IMG_SIZE = 224             # for CNN embedding
BEAM_WIDTH = 6             # larger -> slower but more robust (try 4..12)
DIRECTION_WEIGHT = 0.35    # how strongly direction influences cost (0..1)
# -------------------------------

def extract_frames(video_path, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        cv2.imwrite(os.path.join(out_dir, f"frame_{idx:04d}.jpg"), frame)
        idx += 1
    cap.release()
    return idx

def get_resnet_feature_extractor(device):
    
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    modules = list(model.children())[:-1]  
    model = nn.Sequential(*modules)
    model.to(device).eval()
    return model

def extract_features_gpu(frames_dir, device):
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std =[0.229, 0.224, 0.225])
    ])
    files = sorted(os.listdir(frames_dir))
    feats = []
    model = get_resnet_feature_extractor(device)
    with torch.no_grad():
        for f in tqdm(files, desc="Extracting features"):
            img = Image.open(os.path.join(frames_dir, f)).convert("RGB")
            t = transform(img).unsqueeze(0).to(device)
            out = model(t)            # shape: (1, 512, 1, 1)
            out = out.squeeze().cpu().numpy().reshape(-1)
            feats.append(out)
    feats = np.vstack(feats)

    norms = np.linalg.norm(feats, axis=1, keepdims=True)
    feats = feats / (norms + 1e-10)
    return files, feats

def compute_pairwise_cosine_distance(feats):
   
    sims = feats @ feats.T
    dists = 1.0 - sims
    np.fill_diagonal(dists, inf)
    return dists

def phase_correlation_shift(grayA, grayB):
   
    a = np.float32(grayA)
    b = np.float32(grayB)
   
    hann = cv2.createHanningWindow(a.shape[::-1], cv2.CV_32F)
    a_win = a * hann
    b_win = b * hann
    shift, response = cv2.phaseCorrelate(a_win, b_win)
   
    return np.array(shift), response

def compute_direction_matrix(frames_dir, frame_files, use_downscale=(640,360)):
    n = len(frame_files)
    shifts = np.zeros((n, n, 2), dtype=np.float32)
    responses = np.zeros((n,n), dtype=np.float32)
  
    gray_list = []
    for f in frame_files:
        img = cv2.imread(os.path.join(frames_dir, f))
        h,w = img.shape[:2]
        if use_downscale:
            img = cv2.resize(img, use_downscale, interpolation=cv2.INTER_LINEAR)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray_list.append(gray)
    for i in tqdm(range(n), desc="Computing phase-correlation shifts"):
        a = gray_list[i]
        for j in range(n):
            if i == j:
                shifts[i,j] = (0.0,0.0)
                responses[i,j] = 0.0
            else:
                s, r = phase_correlation_shift(a, gray_list[j])
                shifts[i,j] = s
                responses[i,j] = r
    return shifts, responses

def infer_global_forward(shifts, responses, mask_threshold=0.0001):
   
    n = shifts.shape[0]
    flat_shifts = shifts.reshape(-1,2)
    flat_resp = responses.reshape(-1)
    
    sel = flat_resp > np.percentile(flat_resp, 60) 
    if sel.sum() == 0:
        sel = flat_resp > np.mean(flat_resp)
    if sel.sum() == 0:
       
        median = np.median(flat_shifts, axis=0)
    else:
        median = np.median(flat_shifts[sel], axis=0)
    
    if np.linalg.norm(median) < 0.3:
        median = np.array([1.0, 0.0])
    return median

def build_directed_cost_matrix(sim_dists, shifts, median_forward, responses, direction_weight):
   
    n = sim_dists.shape[0]
    cost = sim_dists.copy()
    
    med = median_forward / (np.linalg.norm(median_forward) + 1e-9)
    for i in range(n):
        for j in range(n):
            if i==j:
                cost[i,j] = inf
                continue
            shift = shifts[i,j]
            proj = float(np.dot(shift, med)) 
            resp = responses[i,j]
            
            dir_score = np.tanh(proj / (np.linalg.norm(shift)+1e-9)) * (resp / (resp + 1.0))
            
            cost[i,j] = sim_dists[i,j] * (1.0 - direction_weight * dir_score)
   
    np.fill_diagonal(cost, inf)
    return cost

def beam_search_order(cost_matrix, beam_width=6):
    n = cost_matrix.shape[0]
   
    row_sums = np.nan_to_num(cost_matrix.sum(axis=1), posinf=1e9)
    start_candidates = np.argsort(row_sums)[:min(6, n)]
    
    beams = []
    for s in start_candidates:
        beams.append((0.0, [s], set(range(n)) - {s}))
    for step in tqdm(range(n-1), desc="Beam search ordering"):
        new_beams = []
        for total, path, rem in beams:
            last = path[-1]
            
            options = sorted(list(rem), key=lambda j: cost_matrix[last,j])
          
            k = min(beam_width, len(options))
            for next_idx in options[:k]:
                new_total = total + float(cost_matrix[last, next_idx])
                new_path = path + [next_idx]
                new_rem = rem - {next_idx}
                new_beams.append((new_total, new_path, new_rem))
      
        new_beams.sort(key=lambda x: x[0])
        beams = [(c, p, r) for (c,p,r) in new_beams[:beam_width]]

    best = min(beams, key=lambda x: x[0])
    return best[1]

def write_video_from_order(frames_dir, frame_files, order, output_path, fps=30):
    first = cv2.imread(os.path.join(frames_dir, frame_files[order[0]]))
    h,w = first.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (w,h))
    for idx in tqdm(order, desc="Writing output video"):
        frame = cv2.imread(os.path.join(frames_dir, frame_files[idx]))
        out.write(frame)
    out.release()

def main():
    t0 = time.time()
    print("1) Extracting frames...")
    n = extract_frames(VIDEO_PATH, FRAMES_DIR)
    print(f"  extracted {n} frames to {FRAMES_DIR}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"2) Extracting CNN features on device: {device}")
    frame_files, feats = extract_features_gpu(FRAMES_DIR, device)
    print("  features shape:", feats.shape)

    print("3) Pairwise similarity (cosine distance)...")
    sim_dists = compute_pairwise_cosine_distance(feats)

    print("4) Phase-correlation shifts (direction cues)...")
    shifts, responses = compute_direction_matrix(FRAMES_DIR, frame_files, use_downscale=(512,288))
    median_forward = infer_global_forward(shifts, responses)
    print("  inferred median forward shift:", median_forward)

    print("5) Building directed cost matrix...")
    cost = build_directed_cost_matrix(sim_dists, shifts, median_forward, responses, DIRECTION_WEIGHT)

    print("6) Reconstructing order with beam search...")
    order_indices = beam_search_order(cost, beam_width=BEAM_WIDTH)

    print("7) Writing reconstructed video...")
    write_video_from_order(FRAMES_DIR, frame_files, order_indices, OUTPUT_VIDEO, fps=FPS)

    elapsed = time.time() - t0
    print(f"\nDone — reconstructed saved to {OUTPUT_VIDEO} (time {elapsed:.2f}s)")
    with open("execution_time_log.txt", "w") as f:
        f.write(f"Execution Time: {elapsed:.2f} seconds\n")

if __name__ == "__main__":
    main()


1) Extracting frames...
  extracted 300 frames to frames_improved
2) Extracting CNN features on device: cuda


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\RITANKAR/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:03<00:00, 12.2MB/s]
Extracting features: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:16<00:00, 18.32it/s]


  features shape: (300, 512)
3) Pairwise similarity (cosine distance)...
4) Phase-correlation shifts (direction cues)...


Computing phase-correlation shifts: 100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [03:01<00:00,  1.65it/s]


  inferred median forward shift: [1. 0.]
5) Building directed cost matrix...
6) Reconstructing order with beam search...


Beam search ordering: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 299/299 [00:00<00:00, 4059.54it/s]


7) Writing reconstructed video...


Writing output video: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:11<00:00, 26.89it/s]


Done — reconstructed saved to reconstructed_improved.mp4 (time 236.82s)



