In [1]:
NUM_PARTS     = 9
PART_ID       = 2

In [2]:
!pip install faiss-cpu torchscale av

Collecting faiss-cpu
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Collecting torchscale
  Downloading torchscale-0.3.0-py3-none-any.whl.metadata (11 kB)
Collecting av
  Downloading av-16.0.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Collecting fairscale==0.4.0 (from torchscale)
  Downloading fairscale-0.4.0.tar.gz (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting timm==0.6.13 (from torchscale)
  Downloading timm-0.6.13-py3-none-any.whl.metadata (38 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->torchscale)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none

In [3]:
import os, json, glob
from queue import Queue
from threading import Thread, Lock

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import av
import faiss
from tqdm import tqdm
from PIL import Image
from natsort import natsorted

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_ as trunc_normal_fn
from huggingface_hub import hf_hub_download

INPUT_JSON_DIR = "/kaggle/input/irscene/SceneIR"
OUTPUT_DIR     = "/kaggle/working/keyframes_out"

SAVE_FORMAT    = "webp"
SAVE_QUALITY   = 90
OUTPUT_SIZE    = (640, 360)

IMG_SIZE       = 384
BATCH_SIZE     = 128

THRESH_KF      = 0.95    
THRESH_PREV    = 0.95   
STEP_CAND      = 8       
MIN_SCENE_GAP  = 40    
BEIT3_REPO    = "Quintu/beit3"
BEIT3_CKPT    = "beit3_large_patch16_384_coco_retrieval.pth"

DATA_STORE = []
lock = Lock()

def ensure(p): os.makedirs(p, exist_ok=True)
class BEiT3ForRetrieval(nn.Module):
    def __init__(self, args):
        super().__init__()
        from torchscale.model.BEiT3 import BEiT3
        self.beit3 = BEiT3(args)
        d = args.encoder_embed_dim
        self.vision_head = nn.Linear(d, d, bias=False)
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_fn(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    
    @torch.no_grad()
    def forward(self, x):
        out = self.beit3(textual_tokens=None, visual_tokens=x, text_padding_position=None)
        v = self.vision_head(out["encoder_out"][:, 0])
        return F.normalize(v, dim=-1)

def load_model(device):
    from torchscale.architecture.config import EncoderConfig
    
    args = EncoderConfig(
        img_size=IMG_SIZE,
        patch_size=16,
        vocab_size=64010,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=0.0,
        encoder_embed_dim=1024,
        encoder_attention_heads=16,
        encoder_ffn_embed_dim=4096,
        encoder_layers=24,
        checkpoint_activations=None,
    )
    
    ckpt = hf_hub_download(BEIT3_REPO, BEIT3_CKPT)
    model = BEiT3ForRetrieval(args).to(device)
    state = torch.load(ckpt, map_location="cpu")
    state_dict = state.get("model", state)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    mean = torch.tensor(IMAGENET_DEFAULT_MEAN, device=device).view(1,3,1,1)
    std  = torch.tensor(IMAGENET_DEFAULT_STD,  device=device).view(1,3,1,1)
    return model,mean,std

def grab_frames(video, frame_list):
    out = {}
    container = av.open(video, options={"hwaccel":"cuda","hwaccel_output_format":"cuda"})
    stream = container.streams.video[0]

    fps = float(stream.average_rate) if stream.average_rate else float(stream.rate)
    tb = stream.time_base
    
    max_fr = max(frame_list)
    min_fr = min(frame_list)
    want = set(frame_list)
    
    if min_fr > 0:
        seek_pts = int((min_fr / fps) / float(tb))
        container.seek(seek_pts, any_frame=False, backward=True, stream=stream)
    
    for frame in container.decode(stream):
        if frame.pts is None:
            continue
        actual_frame = int(round(frame.pts * float(tb) * fps))
        
        if actual_frame > max_fr:
            break
        
        if actual_frame in want:
            out[actual_frame] = frame.to_ndarray(format="rgb24")
            if len(out) == len(want):
                break

    container.close()
    return out



def sim(a,b):return float(np.dot(a,b))
def process_scene(job, state):
    video,s,e,out_dir = job
    ensure(out_dir)

    frame_order = list(range(s, e+1, STEP_CAND))
    frames = grab_frames(video, frame_order)
    frame_order = [f for f in frame_order if f in frames]
    if len(frame_order) == 0:
        return

    model,mean,std,device = state["model"],state["mean"],state["std"],state["device"]

    imgs = [frames[f] for f in frame_order]
    embs=[]
    for i in range(0,len(imgs),BATCH_SIZE):
        x = torch.from_numpy(np.stack(imgs[i:i+BATCH_SIZE])).permute(0,3,1,2).float().to(device)/255
        x = F.interpolate(x, size=(IMG_SIZE, IMG_SIZE), mode='bicubic', align_corners=False, antialias=True)
        x = (x-mean)/std
        with torch.no_grad():
            v = model(x).cpu().numpy()
        embs.append(v)
    embs = np.concatenate(embs,0)

    kf = [0]  
    for i in range(1,len(frame_order)):
        f_i = frame_order[i]
        last_frame = frame_order[kf[-1]]

        if f_i - last_frame >= MIN_SCENE_GAP:
            kf.append(i)
            continue

        if sim(embs[i], embs[kf[-1]]) < THRESH_KF and sim(embs[i], embs[i-1]) < THRESH_PREV:
            kf.append(i)

    video_name = os.path.basename(out_dir) 
    for idx in kf:
        fr = frame_order[idx]
        pil = Image.fromarray(frames[fr]).resize(OUTPUT_SIZE, Image.BICUBIC)
        path = os.path.join(out_dir, f"{fr}.{SAVE_FORMAT}")
        pil.save(path, SAVE_FORMAT.upper(), quality=SAVE_QUALITY)
        with lock:
            DATA_STORE.append((video_name, s, fr, path, embs[idx].astype("float32")))

def make_worker(device,pbar):
    def _work():
        model,mean,std = load_model(device)
        state={"model":model,"mean":mean,"std":std,"device":device}
        while True:
            try: job = task_queue.get_nowait()
            except: break
            try: process_scene(job,state)
            except Exception as e: print("⚠️",e)
            task_queue.task_done()
            pbar.update(1)
    return _work

def main():
    ensure(OUTPUT_DIR)
    all_json = natsorted(glob.glob(os.path.join(INPUT_JSON_DIR,"**/*.json"),recursive=True))

    sz=len(all_json)//NUM_PARTS
    rem=len(all_json)%NUM_PARTS
    pid=max(1,min(PART_ID,NUM_PARTS))
    start=(sz+1)*(pid-1) if pid<=rem else rem*(sz+1)+(pid-rem-1)*sz
    end=start+(sz+1 if pid<=rem else sz)
    selected = all_json[start:end]

    global task_queue, DATA_STORE
    task_queue = Queue()
    DATA_STORE = []  

    for jp in selected:
        d=json.load(open(jp))
        video=d["video_path"]
        name=os.path.splitext(os.path.basename(video))[0]
        out_dir=os.path.join(OUTPUT_DIR,name)
        scenes = natsorted(d["scenes"], key=lambda x: x[0])
        for s,e in scenes:
            task_queue.put((video,s,e,out_dir))

    total=task_queue.qsize()
    pbar=tqdm(total=total,desc=f"Part {PART_ID}",dynamic_ncols=True)

    t0=Thread(target=make_worker("cuda:0",pbar))
    t1=Thread(target=make_worker("cuda:1",pbar))
    t0.start(); t1.start()
    task_queue.join()
    t0.join(); t1.join()
    pbar.close()

    if not DATA_STORE:
        print("⚠️ No embeddings, skip index.")
        return

    DATA_STORE = natsorted(DATA_STORE, key=lambda x: (x[0], x[1], x[2]))
    
    PATH_STORE = [item[3] for item in DATA_STORE]
    EMB_STORE = [item[4] for item in DATA_STORE]
    
    embs=np.stack(EMB_STORE).astype("float32")
    embs/=np.linalg.norm(embs,axis=1,keepdims=True)
    index=faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    faiss.write_index(index,f"/kaggle/working/keyframes_part_{PART_ID:03d}.bin")
    open(f"/kaggle/working/keyframes_paths_part_{PART_ID:03d}.txt","w").write("\n".join(PATH_STORE))
    print("✅ DONE")

if __name__=="__main__":
    main()


Part 2:   0%|          | 0/26445 [00:00<?, ?it/s]

beit3_large_patch16_384_coco_retrieval.p(…):   0%|          | 0.00/1.35G [00:00<?, ?B/s]

Part 2:   0%|          | 2/26445 [00:30<92:55:41, 12.65s/it] mmco: unref short failure
Part 2: 100%|██████████| 26445/26445 [8:30:25<00:00,  1.16s/it]


✅ DONE


In [4]:
os.system(f"cd /kaggle/working && zip -0 -r full_keyframes_part_{PART_ID}.zip "
          f"keyframes_out keyframes_part_{PART_ID:03d}.bin keyframes_paths_part_{PART_ID:03d}.txt")

  adding: keyframes_out/ (stored 0%)
  adding: keyframes_out/L05_V011/ (stored 0%)
  adding: keyframes_out/L05_V011/15330.webp (stored 0%)
  adding: keyframes_out/L05_V011/20855.webp (stored 0%)
  adding: keyframes_out/L05_V011/21052.webp (stored 0%)
  adding: keyframes_out/L05_V011/7837.webp (stored 0%)
  adding: keyframes_out/L05_V011/16267.webp (stored 0%)
  adding: keyframes_out/L05_V011/26081.webp (stored 0%)
  adding: keyframes_out/L05_V011/20428.webp (stored 0%)
  adding: keyframes_out/L05_V011/22864.webp (stored 0%)
  adding: keyframes_out/L05_V011/16620.webp (stored 0%)
  adding: keyframes_out/L05_V011/8387.webp (stored 0%)
  adding: keyframes_out/L05_V011/713.webp (stored 0%)
  adding: keyframes_out/L05_V011/9879.webp (stored 0%)
  adding: keyframes_out/L05_V011/25433.webp (stored 0%)
  adding: keyframes_out/L05_V011/12705.webp (stored 0%)
  adding: keyframes_out/L05_V011/15382.webp (stored 0%)
  adding: keyframes_out/L05_V011/18968.webp (stored 0%)
  adding: keyframes_out/L0

0