In [None]:
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet50
from torchvision.io import read_image

In [None]:
# Pick the compute device once; every model and tensor below is moved to it.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

In [None]:
def frozen_resnet50():
    """Build the frozen ResNet-50 feature extractor used for every frame.

    Loads torchvision's ``IMAGENET1K_V2`` weights, swaps the final ``fc``
    layer for an identity so the forward pass returns the pooled features
    (2048-d), disables gradients on all parameters, and returns the model
    on ``DEVICE`` in eval mode.
    """
    backbone = resnet50(weights="IMAGENET1K_V2")
    # Drop the classifier head: the network now outputs 2048-d features.
    backbone.fc = torch.nn.Identity()
    # Freeze every parameter in one call instead of looping over them.
    backbone.requires_grad_(False)
    backbone = backbone.to(DEVICE)
    backbone.eval()
    return backbone

# Preprocessing applied to every decoded frame before it reaches the backbone:
# resize the short side to 256, centre-crop 224x224, convert to float32
# (ConvertImageDtype rescales integer images to [0, 1]) and normalise with the
# ImageNet channel mean/std that torchvision's pretrained models expect.
PREPROC = T.Compose([
    # antialias is set explicitly: frames arrive as tensors, and older
    # torchvision releases did not antialias tensor inputs by default, which
    # makes downscaled frames differ between library versions.
    T.Resize(256, antialias=True),
    T.CenterCrop(224),
    T.ConvertImageDtype(torch.float32),
    T.Normalize([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225])
])

def load_frames(frames_dir, n_frames=32, pad_to_len=None):
    """Load up to ``n_frames`` preprocessed frames from a directory of JPEGs.

    Files matching ``*.jpg`` directly inside ``frames_dir`` are sorted by
    path. When there are more than ``n_frames`` of them, evenly spaced
    indices (first and last file included) are selected; otherwise every
    file is used. Each selected image is decoded as 3-channel RGB and run
    through ``PREPROC``.

    NOTE(review): temporal order relies on the file names sorting correctly
    (e.g. zero-padded frame numbers) — confirm against the frame extractor.

    Args:
        frames_dir: directory containing the ``*.jpg`` frames.
        n_frames: maximum number of frames to sample.
        pad_to_len: if truthy and fewer frames were loaded, the last frame
            is repeated until the batch has this many frames.

    Returns:
        A tensor on ``DEVICE`` of shape ``(T, 3, 224, 224)``; ``T`` is 0 when
        the directory contains no ``*.jpg`` files.
    """
    # Imported here so the function is self-contained; torchvision is
    # already a dependency of this notebook.
    from torchvision.io import ImageReadMode

    paths = sorted(Path(frames_dir).glob("*.jpg"))
    if not paths:
        return torch.empty(0, 3, 224, 224, device=DEVICE)

    # Sample up to n_frames; longer videos are downsampled uniformly.
    idx = np.linspace(0, len(paths) - 1,
                      num=min(n_frames, len(paths)),
                      dtype=int)
    # Force RGB decoding: read_image otherwise keeps the file's own channel
    # count, so a grayscale JPEG would yield a 1-channel tensor and the
    # 3-channel Normalize in PREPROC would fail.
    batch = torch.stack([
        PREPROC(read_image(str(paths[i]), mode=ImageReadMode.RGB))
        for i in idx
    ]).to(DEVICE)

    # Optional: repeat the last frame until the batch reaches pad_to_len.
    if pad_to_len and batch.shape[0] < pad_to_len:
        reps = pad_to_len - batch.shape[0]
        batch = torch.cat([batch,
                           batch[-1:].expand(reps, -1, -1, -1)], dim=0)
    return batch

In [None]:
class FrameCNNBiGRUVectorizer:
    """Turn a directory of video frames into one fixed-length vector.

    Sampled frames go through the frozen ResNet-50 backbone; the per-frame
    features are fed as a single sequence to a one-layer bidirectional GRU,
    and the final hidden states of both directions are concatenated into a
    vector of length ``2 * hidden``.

    The GRU is never trained here, so its weights are a random projection.
    They are drawn from a private, seeded generator so that two runs of the
    notebook write identical vectors.

    Args:
        n_frames: maximum number of frames sampled per video.
        hidden: GRU hidden size per direction.
        seed: seed for the GRU weights. ``None`` keeps the weights that
            ``nn.GRU`` draws from the global RNG (the previous,
            non-reproducible behaviour).
    """

    def __init__(self, n_frames=32, hidden=384, seed=0):
        self.n_frames = n_frames
        self.backbone = frozen_resnet50()
        self.gru = self._make_gru(hidden, seed).to(DEVICE)

    @staticmethod
    def _make_gru(hidden, seed):
        """Build the bidirectional GRU on CPU, with seeded weights if asked."""
        gru = nn.GRU(
            input_size=2048,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        if seed is not None:
            # Redraw every parameter from a private generator so the global
            # RNG state is left alone. The range matches nn.GRU's documented
            # default init, U(-1/sqrt(hidden), 1/sqrt(hidden)).
            gen = torch.Generator().manual_seed(seed)
            bound = hidden ** -0.5
            with torch.no_grad():
                for param in gru.parameters():
                    param.uniform_(-bound, bound, generator=gen)
        return gru

    @torch.inference_mode()
    def vectorize_vid(self, frames_dir):
        """Embed the frames found in ``frames_dir``.

        Returns:
            A 1-D float32 NumPy array of length ``2 * hidden``. It is all
            zeros when ``load_frames`` finds no frames in the directory.
        """
        frames = load_frames(frames_dir,
                             n_frames=self.n_frames,
                             pad_to_len=None)
        if frames.nelement() == 0:
            return np.zeros(2 * self.gru.hidden_size, np.float32)

        feats = self.backbone(frames)                        # (T,2048)
        feats = feats.unsqueeze(0)                           # (1,T,2048)
        _, h = self.gru(feats)                               # h: (2,1,H)
        # Put the batch axis first, then flatten both directions together.
        h = h.transpose(0, 1).reshape(1, -1)                 # (1, 2H)
        return h.squeeze(0).cpu().numpy().astype(np.float32)


In [None]:
# Mount Google Drive at /content/drive; the input and output paths used by
# the extraction run at the end of this notebook live under that mount point.
# NOTE(review): google.colab is supplied by the Colab runtime, so this cell
# will presumably fail when the notebook is run anywhere else — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import tqdm.notebook as tq

In [None]:
def iter_frame_dirs(frames_root: Path):
    """
    Walk the immediate children of ``frames_root`` in sorted order and yield
    a ``(video_id, frames_dir)`` pair for each directory whose name ends in
    '_frames'. ``video_id`` is the directory name with that suffix removed.
    """
    suffix = "_frames"
    entries = list(frames_root.iterdir())
    entries.sort()
    for entry in entries:
        if not entry.is_dir():
            continue
        if not entry.name.endswith(suffix):
            continue
        # The name is known to end with the suffix, so slicing it off is
        # the same as splitting on its last occurrence.
        video_id = entry.name[:-len(suffix)]
        yield video_id, entry


def save_vector(vec: np.ndarray, video_id: str, out_dir: Path):
    """Write ``vec`` as float32 to ``<out_dir>/<video_id>_bigru.npy``.

    ``out_dir`` (and any missing parents) is created first if needed.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"{video_id}_bigru.npy"
    as_float32 = vec.astype(np.float32)
    np.save(target, as_float32)


In [None]:
def run_bigru_extraction(frames_root: str,
                         output_dir: str,
                         n_frames: int = 32,
                         hidden: int = 384):
    """Vectorise every ``*_frames`` directory under ``frames_root``.

    One ``<video_id>_bigru.npy`` file is written to ``output_dir`` per frame
    directory, where ``video_id`` is the directory name without its
    ``_frames`` suffix.

    Args:
        frames_root: directory whose sub-directories hold extracted frames.
        output_dir: destination directory; created if missing.
        n_frames: maximum number of frames sampled per video.
        hidden: GRU hidden size per direction (vectors have length
            ``2 * hidden``).
    """
    frames_root = Path(frames_root)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Use the shared helpers rather than repeating their logic here: the
    # discovery and naming rules stay in one place, and videos are processed
    # in sorted (repeatable) order. Materialised so tqdm knows the total.
    videos = list(iter_frame_dirs(frames_root))
    vectoriser = FrameCNNBiGRUVectorizer(n_frames=n_frames, hidden=hidden)

    for video_id, fdir in tq.tqdm(videos, desc='Extracting', unit='vid'):
        save_vector(vectoriser.vectorize_vid(fdir), video_id, output_dir)

    print(f' Done. Wrote {len(videos)} vectors to {output_dir}')


In [None]:
# Drive locations: one '<video_id>_frames' folder per video as input, and the
# folder that receives the '<video_id>_bigru.npy' vectors as output.
frames_dir = "/content/drive/MyDrive/all_videos_frames"
output_dir = "/content/drive/MyDrive/video_vectorizer/cnn_output"
run_bigru_extraction(frames_root=frames_dir, output_dir=output_dir)