In [7]:
import numpy as np 
import torch
from transnetv2_pytorch import TransNetV2
from tqdm import tqdm
import cv2

In [35]:
import warnings 
warnings.filterwarnings('ignore')

#### Step 1: Extract the keyframes

In [8]:
# Load model 
ckpt_file_path = 'transnetv2-pytorch-weights.pth'
state_dict = torch.load(ckpt_file_path)
model = TransNetV2() 
model.load_state_dict(state_dict)
model.eval().cuda() 

TransNetV2(
  (SDDCNN): ModuleList(
    (0): StackedDDCNNV2(
      (DDCNN): ModuleList(
        (0): DilatedDCNNV2(
          (Conv3D_1): Conv3DConfigurable(
            (layers): ModuleList(
              (0): Conv3d(3, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
              (1): Conv3d(32, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
            )
          )
          (Conv3D_2): Conv3DConfigurable(
            (layers): ModuleList(
              (0): Conv3d(3, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
              (1): Conv3d(32, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(2, 0, 0), dilation=(2, 1, 1), bias=False)
            )
          )
          (Conv3D_4): Conv3DConfigurable(
            (layers): ModuleList(
              (0): Conv3d(3, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
              (1): Conv3d(32, 16, kernel_size=(3, 1, 1), 

In [41]:
def detect_scene_boundary(model, video_path, threshold= 0.5, CHUNK_SIZE= 512, FRAME_SIZE=(48,27)): 
    frames = [] 
    scene_scores = [] 
    boundaries = [] 

    # Resize the frames 
    cap = cv2.VideoCapture(video_path)
    while True: 
        ret, frame = cap.read() 
        if not ret: 
            break 
        frame = cv2.resize(frame, FRAME_SIZE)
        frames.append(frame)
    total_frames = len(frames)
    cap.release() 

    # Predict 
    for i in range(0, len(frames), CHUNK_SIZE): 
        chunk = frames[i : i + CHUNK_SIZE]
        chunk_np = np.stack(chunk, axis= 0)
        chunk_tensor = torch.from_numpy(chunk_np).unsqueeze(0).cuda() 

        with torch.no_grad(): 
            output = model(chunk_tensor)[0]
            output = torch.sigmoid(output)
            scene_scores.append(output[0, :, 0].cpu().numpy())
    scene_scores = np.concatenate(scene_scores, axis=0)

    # Find boudaries 
    boundaries = [
        i for i in range(1, len(scene_scores))
        if scene_scores[i - 1] < threshold and scene_scores[i] >= threshold
    ]

    scene_ranges = [0] + boundaries + [total_frames]
    scenes = [
        (scene_ranges[i], scene_ranges[i + 1])
        for i in range(len(scene_ranges) - 1)
    ]

    return scenes

In [None]:
video_path = 'videos/L01_V030.mp4'
scenes = detect_scene_boundary(model, video_path)

[(0, 47),
 (47, 436),
 (436, 497),
 (497, 535),
 (535, 580),
 (580, 624),
 (624, 663),
 (663, 704),
 (704, 748),
 (748, 835),
 (835, 871),
 (871, 1284),
 (1284, 1373),
 (1373, 1439),
 (1439, 1518),
 (1518, 1690),
 (1690, 1764),
 (1764, 1834),
 (1834, 1901),
 (1901, 1966),
 (1966, 2018),
 (2018, 2086),
 (2086, 2152),
 (2152, 2205),
 (2205, 2262),
 (2262, 2330),
 (2330, 2387),
 (2387, 2427),
 (2427, 2983),
 (2983, 3141),
 (3141, 3266),
 (3266, 3940),
 (3940, 4162),
 (4162, 4333),
 (4333, 4500),
 (4500, 4815),
 (4815, 5023),
 (5023, 5106),
 (5106, 5225),
 (5225, 5266),
 (5266, 5355),
 (5355, 5422),
 (5422, 5487),
 (5487, 5559),
 (5559, 5603),
 (5603, 5672),
 (5672, 6168),
 (6168, 6255),
 (6255, 6828),
 (6828, 6896),
 (6896, 7013),
 (7013, 7042),
 (7042, 7079),
 (7079, 7136),
 (7136, 7174),
 (7174, 7205),
 (7205, 7237),
 (7237, 7278),
 (7278, 7308),
 (7308, 7481),
 (7481, 7508),
 (7508, 7606),
 (7606, 7698),
 (7698, 7822),
 (7822, 7874),
 (7874, 7919),
 (7919, 7963),
 (7963, 8026),
 (8026,

#### Visualization

In [None]:
import matplotlib.pyplot as plt
import cv2

cap = cv2.VideoCapture(video_path)
original_frames = []

while True:
    ret, frame = cap.read()
    if not ret:
        break
    original_frames.append(frame)

cap.release()
plt.figure(figsize=(16, 2 * len(scenes) // 5))

for i in range(min(len(scenes), 10)):  # display first 10 scenes
    start, end = scenes[i]
    mid = (start + end) // 2
    frame = original_frames[mid][:, :, ::-1]  # Convert BGR to RGB

    plt.subplot(2, 5, i + 1)
    plt.imshow(frame)
    plt.title(f"Scene {i} ({start}-{end})")
    plt.axis("off")

plt.tight_layout()
plt.show()
