In [1]:
import os
import sys
import math
import argparse
from typing import List, Optional, Tuple

import cv2
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models


In [2]:
def extract_frames(video_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], float]:
    """Read the frames of a video with OpenCV.

    Args:
        video_path: path of the video file to open with ``cv2.VideoCapture``.
        max_frames: stop after this many frames; ``None`` (or 0) reads to the end.

    Returns:
        ``(frames, fps)``: ``frames`` are copies of the arrays returned by
        ``cv2.VideoCapture.read`` (OpenCV's BGR channel order), and ``fps`` is the
        container's reported frame rate, or 30.0 when it reports 0.

    Raises:
        RuntimeError: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture handle is released on every exit path,
    # including the "cannot open" error and any exception raised mid-read.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Unable to open video {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        frames: List[np.ndarray] = []
        while True:
            success, frame = cap.read()
            if not success:
                break
            frames.append(frame.copy())
            if max_frames and len(frames) >= max_frames:
                break
    finally:
        cap.release()
    return frames, fps

In [3]:
class FrameEmbedder:
    """Turn video frames into feature vectors with an ImageNet-pretrained ResNet-18.

    The classifier (``fc``) layer is dropped, so each frame is mapped to the
    backbone's pooled feature vector of width ``embed_dim`` (``model.fc.in_features``).
    """

    def __init__(self, device='cpu'):
        self.device = torch.device(device)
        # torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights`
        # enum; IMAGENET1K_V1 is the checkpoint that `pretrained=True` loads. Older
        # torchvision versions have no enum, so fall back to the legacy flag there.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            model = models.resnet18(pretrained=True)
        # Width of the pooled features; also used to shape the result for empty input.
        self.embed_dim = model.fc.in_features
        # Keep every child module except the final fc layer.
        modules = list(model.children())[:-1]
        self.backbone = nn.Sequential(*modules).to(self.device).eval()
        # ImageNet preprocessing; ToPILImage receives the RGB array produced in embed().
        self.transform = T.Compose([
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    @torch.no_grad()
    def embed(self, frames: List[np.ndarray], batch_size: int = 32) -> np.ndarray:
        """Embed BGR frames (as produced by OpenCV) in batches.

        Args:
            frames: frames in OpenCV's BGR channel order.
            batch_size: frames pushed through the backbone at once; must be >= 1.

        Returns:
            Array of shape ``(len(frames), embed_dim)``; an empty
            ``(0, embed_dim)`` float32 array when ``frames`` is empty.

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(frames) == 0:
            # np.vstack([]) would raise; return a well-shaped empty result instead.
            return np.empty((0, self.embed_dim), dtype=np.float32)

        chunks = []
        for start in range(0, len(frames), batch_size):
            batch = frames[start:start + batch_size]
            x = torch.stack(
                [self.transform(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)) for f in batch]
            ).to(self.device)
            feats = self.backbone(x)  # [B, C, 1, 1]
            chunks.append(feats.flatten(start_dim=1).cpu().numpy())  # [B, C]
        return np.vstack(chunks)  # shape (N, C)


In [4]:
class ContextualPropagator(nn.Module):
    """Mix each frame's feature with its temporal neighbours.

    A 1-D convolution over the time axis (window of ``kernel_size`` frames), a ReLU,
    and a linear projection back to ``feat_dim``, added residually to the input.
    """

    def __init__(self, feat_dim: int, kernel_size: int = 9, hidden_dim: int = 512):
        super().__init__()
        # With symmetric padding of kernel_size // 2, only an odd kernel keeps the
        # output length equal to T. An even kernel would yield T + 1 steps and the
        # residual add in forward() would fail with a shape mismatch.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        padding = kernel_size // 2
        self.conv = nn.Conv1d(in_channels=feat_dim, out_channels=hidden_dim,
                              kernel_size=kernel_size, padding=padding, bias=True)
        self.act = nn.ReLU()
        self.project_back = nn.Linear(hidden_dim, feat_dim)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """
        feats: [T, D] per-frame features (D == feat_dim)
        returns contextual_feats: [T, D] = feats + project_back(relu(conv(feats)))

        Raises ValueError if feats is not 2-D.
        """
        if feats.dim() != 2:
            raise ValueError(f"expected feats of shape [T, D], got {tuple(feats.shape)}")
        x = feats.transpose(0, 1).unsqueeze(0)  # [1, D, T]: Conv1d wants channels first
        y = self.act(self.conv(x))  # [1, H, T]
        y = y.squeeze(0).transpose(0, 1)  # [T, H]
        y = self.project_back(y)  # [T, D]
        # residual connection: the block adds a correction on top of the input
        return feats + y

In [5]:
class FrameScorer(nn.Module):
    """Map each frame feature to an importance score via a sigmoid-headed MLP.

    Architecture: Linear(feat_dim -> hidden) -> ReLU -> Linear(hidden -> 1) -> Sigmoid.
    """

    def __init__(self, feat_dim: int, hidden: int = 128):
        super().__init__()
        layers = [
            nn.Linear(feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        # Kept as one Sequential named `net` so parameter names remain net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """feats: [T, D] -> scores: [T] (the trailing size-1 output dim is dropped)."""
        per_frame = self.net(feats)  # [T, 1]
        return per_frame.squeeze(dim=-1)

In [6]:
def select_keyframes(scores: np.ndarray, max_keyframes: int, suppression_window: int = 15) -> List[int]:
    """
    Select up to ``max_keyframes`` frames by greedy temporal non-maximum suppression:
    - pick the highest-scoring frame that is still available
    - make the frames within +/- suppression_window of it unavailable (itself included)
    - repeat until max_keyframes are selected or no frame is available

    Availability is tracked with a boolean mask rather than by overwriting scores
    with 0, so frames whose score is genuinely 0 (e.g. the minimum after min-max
    normalisation, or every frame of a constant-score video) remain selectable.
    A negative suppression_window is treated as 0.

    Returns the selected frame indices in ascending (temporal) order; ties are
    broken in favour of the earliest frame.
    """
    sc = np.asarray(scores, dtype=np.float64)
    n = len(sc)
    window = max(0, int(suppression_window))
    available = np.ones(n, dtype=bool)
    selected: List[int] = []
    for _ in range(min(max_keyframes, n)):
        candidates = np.flatnonzero(available)
        if candidates.size == 0:
            break
        idx = int(candidates[np.argmax(sc[candidates])])
        selected.append(idx)
        # slicing past the end is safe; the lower bound is clamped at 0
        available[max(0, idx - window): idx + window + 1] = False
    selected.sort()
    return selected

In [None]:
def summarize_video(video_path: str,
                    out_path: str,
                    device: str = 'cpu',
                    max_frames: int = None,
                    max_keyframes: int = 20,
                    suppression_window: int = 15):
    # 1. Extract frames
    print("Extracting frames...")
    frames, fps = extract_frames(video_path, max_frames=max_frames)
    n = len(frames)
    print(f"Total frames: {n}, fps={fps}")

    # 2. Embeddings
    print("Computing embeddings...")
    embedder = FrameEmbedder(device=device)
    embeddings = embedder.embed(frames)  # (n, D)
    embeddings = embeddings.astype(np.float32)
    feat_dim = embeddings.shape[1]

    # 3. Contextual propagation
    print("Applying contextual propagation...")
    cp = ContextualPropagator(feat_dim=feat_dim)
    cp.to(device)
    cp.eval()
    with torch.no_grad():
        feats_t = torch.from_numpy(embeddings).to(device)  # [T, D]
        contextual_feats = cp(feats_t).cpu().numpy()

    # 4. Scoring
    print("Scoring frames...")
    scorer = FrameScorer(feat_dim=feat_dim)
    scorer.to(device)
    scorer.eval()
    with torch.no_grad():
        scores_t = scorer(torch.from_numpy(contextual_feats).to(device))
        scores = scores_t.cpu().numpy()  # [T,]

    # normalize scores to [0,1]
    scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)

    # 5. Keyframe selection
    print("Selecting keyframes...")
    selected_idxs = select_keyframes(scores, max_keyframes=max_keyframes,
                                     suppression_window=suppression_window)
    print(f"Selected {len(selected_idxs)} keyframes: {selected_idxs}")

    # 6. Summary generation: collect frames in temporal order
    summary_frames = [frames[i] for i in selected_idxs]

    # If you'd rather produce a short video with small segments around selected frames,
    # you can include a range around each selected index. This code only uses the single frame.

    # 7. Write summary video
    write_fps = min(5.0, fps)  # keep summary fps lower for quick viewing
    print(f"Writing summary to {out_path} (fps={write_fps}) ...")
    write_summary_video(summary_frames, out_path, fps=write_fps)
    print("Done.")


In [10]:
def main():
    parser = argparse.ArgumentParser(description="Simple video summarizer pipeline")
    parser.add_argument("input", help="Input video path")
    parser.add_argument("output", help="Output summary video path")
    parser.add_argument("--device", default="cpu", help="cpu or cuda")
    parser.add_argument("--max-frames", type=int, default=None, help="Max frames to read (for testing)")
    parser.add_argument("--k", type=int, default=20, help="Max keyframes to select")
    parser.add_argument("--suppress", type=int, default=15, help="Temporal suppression window (frames)")
    args = parser.parse_args()

    summarize_video(args.input, args.output,
                    device=args.device,
                    max_frames=args.max_frames,
                    max_keyframes=args.k,
                    suppression_window=args.suppress)


if __name__ == "__main__":
    main()

usage: ipykernel_launcher.py [-h] [--device DEVICE] [--max-frames MAX_FRAMES]
                             [--k K] [--suppress SUPPRESS]
                             input output
ipykernel_launcher.py: error: the following arguments are required: input, output


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [7]:
# Batch-mode locations: videos are read from `video_folder`, and each video's
# keyframes go to its own sub-directory of `output_root`.
video_folder = "videos/"
output_root = "keyframes5/"
# exist_ok=True keeps this cell safe to re-run.
os.makedirs(output_root, mode=0o777, exist_ok=True)

In [9]:
# Batch mode: for every video in `video_folder`, score its frames, then save the
# selected keyframes as JPEGs plus a text file of their indices under
# `output_root/<video name>/`.
DEVICE = "cpu"
SEED = 0                  # seeds the randomly initialised propagator/scorer
KEYFRAME_RATIO = 0.15     # requested keyframes as a fraction of the video's frames
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")

# sorted() gives a deterministic processing order (os.listdir order is arbitrary);
# .lower() also accepts upper-case extensions such as ".MP4".
video_files = sorted(
    f for f in os.listdir(video_folder) if f.lower().endswith(VIDEO_EXTENSIONS)
)

# Load the pretrained backbone once rather than once per video.
embedder = FrameEmbedder(device=DEVICE)
# NOTE(review): the propagator and scorer are randomly initialised and never trained
# or loaded from a checkpoint in this notebook, so the scores are not learned
# importance scores. They are built once (when the embedding width is first known)
# from a fixed seed, so every video is scored by the same network and runs repeat.
cp = None
scorer = None

for video_file in video_files:
    video_path = os.path.join(video_folder, video_file)
    print(f"Processing {video_file}...")

    try:
        frames, fps = extract_frames(video_path)
    except RuntimeError as exc:
        # one unreadable file should not abort the whole batch
        print(f"⚠️ Skipped {video_file} ({exc})")
        continue
    if len(frames) == 0:
        print(f"⚠️ Skipped {video_file} (too short)")
        continue

    print("Computing embeddings...")
    embeddings = embedder.embed(frames).astype(np.float32)  # [T, D]
    feat_dim = embeddings.shape[1]

    if cp is None:
        torch.manual_seed(SEED)
        cp = ContextualPropagator(feat_dim=feat_dim).to(DEVICE).eval()
        scorer = FrameScorer(feat_dim=feat_dim).to(DEVICE).eval()

    print("Applying contextual propagation...")
    with torch.no_grad():
        contextual_feats = cp(torch.from_numpy(embeddings).to(DEVICE))  # [T, D]

    print("Scoring frames...")
    with torch.no_grad():
        scores = scorer(contextual_feats).cpu().numpy()  # [T,]
    scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)

    print("Selecting keyframes...")
    # at least one keyframe, so videos shorter than 7 frames are not left empty
    num_keyframes = max(1, int(KEYFRAME_RATIO * len(frames)))
    selected_idxs = select_keyframes(scores, max_keyframes=num_keyframes)

    base_name = os.path.splitext(video_file)[0]
    out_dir = os.path.join(output_root, base_name)
    os.makedirs(out_dir, exist_ok=True)

    for rank, frame_idx in enumerate(selected_idxs, start=1):
        out_path = os.path.join(out_dir, f"keyframe_{rank}_frame{frame_idx}.jpg")
        # extract_frames returns BGR frames, which is what cv2.imwrite expects, so they
        # are written unchanged (an RGB2BGR conversion here would swap red and blue).
        if not cv2.imwrite(out_path, frames[frame_idx]):
            print(f"⚠️ Failed to write {out_path}")

    print(f"✅ Saved {len(selected_idxs)} keyframes to {out_dir}")

    txt_path = os.path.join(out_dir, f"{base_name}_indices.txt")
    with open(txt_path, "w") as f:
        f.writelines(f"{idx}\n" for idx in selected_idxs)
    print(f"🗒️ Saved indices to {txt_path}")

    # release this video's frames before the next video is decoded
    del frames, embeddings, contextual_feats, scores


    

    
    

Processing Air_Force_One.mp4...
Computing embeddings...




Applying contextual propagation...
Scoring frames...
Selecting keyframes...
✅ Saved 220 keyframes to keyframes5/Air_Force_One
🗒️ Saved indices to keyframes5/Air_Force_One\Air_Force_One_indices.txt
Processing Base jumping.mp4...
Computing embeddings...


KeyboardInterrupt: 