In [34]:
import av
import numpy as np
import torch
import os
from transformers import AutoImageProcessor, TimesformerModel
import faiss
import pickle
from tqdm import tqdm
import traceback

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [36]:
model = TimesformerModel.from_pretrained("facebook/timesformer-base-finetuned-k400").to(device).eval()
processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k400")


In [37]:
video_folder = "UCF101/train"
dim = model.config.hidden_size

In [38]:
faiss_indices = {
    "mean": faiss.IndexFlatL2(dim),
    "max": faiss.IndexFlatL2(dim),
    "cls": faiss.IndexFlatL2(dim)
}
id_maps = {"mean": {}, "max": {}, "cls": {}}

In [39]:
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len < converted_len:
        start_idx = 0
        end_idx = seg_len - 1
    else:
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


In [40]:
def read_video_pyav(video_path, indices):
    container = av.open(video_path)
    container.seek(0)
    frames = []
    start_index, end_index = indices[0], indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame.to_ndarray(format="rgb24"))
    container.close()
    return frames

In [41]:
def read_video_pyav(video_path, indices):
    container = av.open(video_path)
    container.seek(0)
    frames = []
    start_index, end_index = indices[0], indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame.to_ndarray(format="rgb24"))
    container.close()
    return frames

In [42]:
def get_video_embeddings_all_pools(frames):
    inputs = processor(images=frames, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    features = outputs.last_hidden_state  # (num_frames, seq_len, dim)

    mean_pool = features.mean(dim=1).mean(dim=0).cpu().numpy()
    max_pool = features.mean(dim=1).max(dim=0).values.cpu().numpy()
    cls_token = features[:, 0, :].mean(dim=0).cpu().numpy()

    return {
        "mean": mean_pool,
        "max": max_pool,
        "cls": cls_token
    }

In [43]:
video_paths = []
for root, _, files in os.walk(video_folder):
    for f in files:
        if f.endswith((".mp4", ".avi")):
            video_paths.append(os.path.join(root, f))


In [44]:
len(video_paths)

10055

In [45]:
success_count = 0  # Count of videos successfully processed
for i, video_path in tqdm(enumerate(video_paths), total=len(video_paths)):
    try:
        container = av.open(video_path)
        total_frames = container.streams.video[0].frames
        container.close()

        indices = sample_frame_indices(clip_len=8, frame_sample_rate=4, seg_len=total_frames)
        frames = read_video_pyav(video_path, indices)
        if len(frames) < 8:
            continue  # skip videos with insufficient frames

        embs = get_video_embeddings_all_pools(frames)

        for method in ["mean", "max", "cls"]:

            faiss_indices[method].add(np.expand_dims(embs[method], axis=0))
            id_maps[method][faiss_indices[method].ntotal - 1] = video_path
    except Exception as e:
            print(f"Error processing {video_path}: {repr(e)}")
            traceback.print_exc()
print(f"\n✅ Successfully processed {success_count} out of {len(video_paths)} videos.")

  return self.preprocess(images, **kwargs)
100%|██████████| 10055/10055 [1:01:51<00:00,  2.71it/s]


✅ Successfully processed 0 out of 10055 videos.





In [47]:
for method in ["mean", "max", "cls"]:
    faiss.write_index(faiss_indices[method], f"embeddings/video_embeddings_{method}.index")
    with open(f"embeddings/id_map_{method}.pkl", "wb") as f:
        pickle.dump(id_maps[method], f)

print("All indices saved.")

All indices saved.


FAISS index dimension: 768
Number of vectors in index: 10052
