In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file under the Kaggle input directory so the
# available datasets are visible in the notebook output.
input_file_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Preprocessing the videos
# This code extracts optical-flow and frame-delta feature vectors from each video; the model consumes them later.

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import os
import yaml
from tqdm.notebook import tqdm  # Use notebook version for Kaggle
from dotmap import DotMap
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
import torchvision.transforms.functional as TF
from torchvision import transforms as T
from einops import rearrange

def get_device():
    """Return the torch device to run on: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def load_video_frames(video_path, num_frames, resize_shape=(224, 224)):
    """Sample ``num_frames`` evenly spaced RGB frames from a video file.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to sample, spread uniformly (first and
            last reported frame included) over the reported frame count.
        resize_shape: Target size passed to ``cv2.resize``, which interprets
            it as ``(width, height)``.

    Returns:
        A tensor of stacked ``TF.to_tensor`` frames, i.e. shape
        ``(num_frames, 3, height, width)`` with float values in ``[0, 1]``,
        or ``None`` if the video cannot be opened, reports fewer than two
        frames, or any sampled frame fails to decode.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < 2:
            return None
        # NOTE(review): CAP_PROP_FRAME_COUNT is container metadata and may be
        # inaccurate for some files; a failed seek/read below drops the whole
        # video rather than returning a shorter clip.
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if not ok:
                return None
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frames.append(TF.to_tensor(cv2.resize(rgb, resize_shape)))
    finally:
        # Release the capture on every path: early returns and exceptions too.
        cap.release()
    return torch.stack(frames) if len(frames) == num_frames else None

def compute_optical_flow(frames, raft_model, device, config):
    """Compute a mean optical-flow vector for every spatial patch of a clip.

    RAFT is run on all consecutive frame pairs in one batch. The final flow
    fields are averaged over time, then averaged within each non-overlapping
    ``patch_size x patch_size`` patch, keeping the x and y channels separate.

    Args:
        frames: Tensor of shape ``(T, C, H, W)``; assumed to hold RGB values
            in ``[0, 1]`` (as produced by ``TF.to_tensor``) — TODO confirm for
            any caller other than ``load_video_frames``.
        raft_model: Callable ``(img1, img2) -> list of flow tensors``; the last
            entry, of shape ``(T-1, 2, H, W)``, is used (torchvision RAFT
            returns the list of iterative refinements, last = final).
        device: Device the frames are moved to before inference.
        config: Config with ``model.video_patch_size`` and
            ``model.moe.experts.motion.flow_dim`` (assumed ``>= 2``).

    Returns:
        CPU tensor of shape ``(num_patches, flow_dim)`` with
        ``num_patches = (H // patch_size) * (W // patch_size)`` in row-major
        patch order. Columns 0-1 hold the mean flow channels; the remaining
        ``flow_dim - 2`` columns are zero padding. Trailing pixels that do
        not fill a whole patch are ignored.
    """
    patch_size = config.model.video_patch_size
    flow_dim = config.model.moe.experts.motion.flow_dim

    # torchvision RAFT expects float images normalized to [-1, 1]; rescale
    # from [0, 1] instead of multiplying by 255.
    frames = frames.to(device) * 2.0 - 1.0

    with torch.no_grad():
        flow_preds = raft_model(frames[:-1], frames[1:])[-1]

    avg_flow_map = flow_preds.mean(dim=0)  # (2, H, W): time-averaged flow
    # Per-patch mean with channels kept apart: (1, 2, ph, pw) -> (ph*pw, 2).
    patch_flow = F.avg_pool2d(avg_flow_map.unsqueeze(0), kernel_size=patch_size)
    avg_flow_per_patch = patch_flow.squeeze(0).flatten(1).transpose(0, 1)
    return F.pad(avg_flow_per_patch, (0, flow_dim - 2)).cpu()

def compute_frame_deltas(frames, config):
    """Build a frame-change feature replicated across every spatial patch.

    The mean absolute difference between consecutive frames (averaged over
    all frame pairs, channels and pixels) is computed as one scalar. That
    scalar is then broadcast to a ``(num_patches, delta_dim)`` tensor.

    Args:
        frames: Tensor of shape ``(T, C, H, W)``.
        config: Config with ``model.video_patch_size`` and
            ``model.moe.experts.fast_change.delta_dim``.

    Returns:
        CPU float tensor of shape ``(num_patches, delta_dim)`` filled with the
        scalar delta, where ``num_patches = (H // patch_size) * (W // patch_size)``
        is the same patch grid used for the flow features.
    """
    patch_size = config.model.video_patch_size
    height, width = frames.shape[-2:]
    # Derive the patch grid from the actual frame size rather than assuming 224x224.
    num_patches_per_frame = (height // patch_size) * (width // patch_size)
    delta = (frames[1:] - frames[:-1]).abs().mean()
    # NOTE(review): a single scalar per clip is replicated to every patch, so
    # this feature carries no spatial information — confirm this is intended.
    return torch.full(
        (num_patches_per_frame, config.model.moe.experts.fast_change.delta_dim),
        delta.item(),
    ).cpu()

if __name__ == '__main__':
    # Path adjustments for Kaggle
    CONFIG_PATH = '/kaggle/input/training-msvd/training_msvd.yaml'
    VIDEO_DIR = '/kaggle/input/youtube-clips-qna/YouTubeClips/YouTubeClips'
    OUTPUT_DIR = '/kaggle/working/features'

    # Create output folders: one .pt file per video in each.
    FLOW_DIR = os.path.join(OUTPUT_DIR, "flow")
    DELTAS_DIR = os.path.join(OUTPUT_DIR, "deltas")
    os.makedirs(FLOW_DIR, exist_ok=True)
    os.makedirs(DELTAS_DIR, exist_ok=True)

    # Load config (safe_load: no arbitrary object construction from YAML).
    with open(CONFIG_PATH, encoding="utf-8") as f:
        config = DotMap(yaml.safe_load(f))

    device = get_device()
    print(f"Using device: {device}")

    weights = Raft_Small_Weights.DEFAULT
    raft_model = raft_small(weights=weights).to(device)
    raft_model.eval()

    # Sort for a deterministic processing order (os.listdir order is arbitrary)
    # and match the extension case-insensitively so ".AVI" files are not missed.
    video_filenames = sorted(
        f for f in os.listdir(VIDEO_DIR) if f.lower().endswith('.avi')
    )
    print(f"Found {len(video_filenames)} AVI files to process.")

    for video_filename in tqdm(video_filenames, desc="Processing videos"):
        video_id_no_ext = os.path.splitext(video_filename)[0]
        flow_path = os.path.join(FLOW_DIR, f"{video_id_no_ext}.pt")
        delta_path = os.path.join(DELTAS_DIR, f"{video_id_no_ext}.pt")

        video_path = os.path.join(VIDEO_DIR, video_filename)
        frames = load_video_frames(video_path, config.model.frames_per_video)
        if frames is None:
            print(f"Warning: Could not load frames for {video_filename}. Skipping.")
            continue

        flow_vectors = compute_optical_flow(frames, raft_model, device, config)
        delta_vectors = compute_frame_deltas(frames, config)

        torch.save(flow_vectors, flow_path)
        torch.save(delta_vectors, delta_path)

    print("\nPre-computation of feature files complete.")
