In [None]:
import os
import random
import torch
import cv2 as cv
import torchvision
import torch.nn as nn
from pathlib import Path
from tqdm.auto import tqdm
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score

In [None]:
# Root folder of the video data. It is searched recursively for .mp4 files and each
# clip is labelled with the name of the folder that contains it.
# Taken from the DATA_DIR environment variable so that no machine-specific absolute
# path is hard-coded; defaults to a "Train" folder relative to the working directory.
DATA_DIR = Path(os.environ.get("DATA_DIR", "Train"))

# Seed the random number generators used by this notebook (Python's `random` and
# torch) so that Restart & Run All draws the same random numbers every time.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# 1. Gather video paths
# rglob() yields files in an arbitrary, filesystem-dependent order; sorting first is
# what makes the random_state-seeded split below reproducible across runs and machines.
video_paths = sorted(Path(DATA_DIR).rglob("*.mp4"))
# Fail early with a readable message instead of an opaque error from the splitter.
assert video_paths, f"No .mp4 files found under {DATA_DIR}"
# Hold out 20% of the clips for evaluation.
train_video_paths, test_video_paths = train_test_split(video_paths, test_size=0.2, random_state=42)

In [None]:
# 2. Labeling function
def labeling(vid_paths):
    """Return the label of every video, taken from the name of its parent folder.

    Args:
        vid_paths: Iterable of path objects exposing ``.parent.name``
            (e.g. ``pathlib.Path``).

    Returns:
        list: One folder name per input path, in input order.
    """
    folder_names = []
    for video_file in vid_paths:
        folder_names.append(video_file.parent.name)
    return folder_names

# Folder-name labels for both splits.
train_labels, test_labels = labeling(train_video_paths), labeling(test_video_paths)
# Class names are collected from the training split only, in sorted order.
unique_labels = sorted(set(train_labels))
# Map every class name to its position in that sorted list.
label_dict = dict(zip(unique_labels, range(len(unique_labels))))


In [None]:
# 3. Video loader
def video_to_tensor(video_path, expected_frames=32, height=112, width=112):
    """Decode the first ``expected_frames`` frames of a video into a float tensor.

    Every frame is resized to ``(height, width)``, converted from BGR to RGB and
    divided by 255. Clips with fewer than ``expected_frames`` readable frames are
    zero-padded at the end of the time axis. If no frame can be read at all, an
    all-zero clip is returned and a ``RuntimeWarning`` is emitted.

    Args:
        video_path: Location of the video file (string or path-like).
        expected_frames: Number of frames (T) in the returned clip.
        height: Frame height of the returned clip, in pixels.
        width: Frame width of the returned clip, in pixels.

    Returns:
        torch.Tensor: Float tensor laid out as ``(C, T, H, W)`` with
        ``T == expected_frames``.
    """
    import warnings

    capture = cv.VideoCapture(str(video_path))
    frames = []
    try:
        for _ in range(expected_frames):
            ret, frame = capture.read()
            if not ret:
                # End of stream or decode failure: keep whatever was read so far.
                break
            frame = cv.resize(frame, (width, height), interpolation=cv.INTER_AREA)
            frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
            frames.append(torch.from_numpy(frame))
    finally:
        # Always free the decoder handle, even when resizing/conversion raises.
        capture.release()

    if len(frames) == 0:
        if expected_frames > 0:
            # An all-zero clip still carries a label downstream, so make the
            # failure visible instead of substituting it silently.
            warnings.warn(
                f"No frames could be read from {video_path}; returning an all-zero clip.",
                RuntimeWarning,
                stacklevel=2,
            )
        return torch.zeros(3, expected_frames, height, width)

    clip = torch.stack(frames).float() / 255.0  # (T, H, W, C)
    clip = clip.permute(3, 0, 1, 2)  # (C, T, H, W)

    missing = expected_frames - clip.shape[1]
    if missing > 0:
        # F.pad takes (before, after) pairs starting from the LAST dimension:
        # W, then H, then T -- so only the end of the time axis is padded.
        clip = F.pad(clip, (0, 0, 0, 0, 0, missing))
    return clip


In [None]:
# 4. Pair dataset (pairs are sampled once at construction and then stay fixed)
class PairVideoDataset(Dataset):
    """Dataset of precomputed ``(video, video, similarity)`` pairs.

    For every sampled anchor clip two pairs are stored: a similar pair (same
    label, target 1) and a dissimilar pair (different label, target 0), so the
    dataset holds ``2 * pairs_per_epoch`` items. Pairs are drawn once in
    ``__init__`` from Python's global ``random`` module and never resampled;
    seed that module beforehand to obtain the same pair list on every run.
    """

    def __init__(self, video_paths, labels, pairs_per_epoch=1000):
        """Sample and store the index pairs.

        Args:
            video_paths: Sequence of video locations.
            labels: Sequence of class labels, ``labels[i]`` belonging to
                ``video_paths[i]``.
            pairs_per_epoch: Number of anchors to sample; each contributes one
                similar and one dissimilar pair.

        Raises:
            ValueError: If ``video_paths`` and ``labels`` differ in length, or
                if pairs are requested but fewer than two distinct labels exist
                (no dissimilar pair can be formed).
        """
        if len(video_paths) != len(labels):
            raise ValueError(
                f"video_paths and labels must have the same length "
                f"(got {len(video_paths)} and {len(labels)})"
            )
        self.video_paths = video_paths
        self.labels = labels
        self.pairs = []

        # Build the label -> indices map in a single pass over the labels.
        label_to_indices = {}
        for index, label in enumerate(labels):
            label_to_indices.setdefault(label, []).append(index)
        # Classes in first-seen order. Iterating a set of strings instead would
        # change order between interpreter runs and make seeded sampling
        # irreproducible.
        classes = list(label_to_indices)

        if pairs_per_epoch > 0 and len(classes) < 2:
            raise ValueError(
                "PairVideoDataset needs at least two distinct labels to build dissimilar pairs"
            )

        for _ in range(pairs_per_epoch):
            idx1 = random.randint(0, len(video_paths) - 1)
            label1 = labels[idx1]
            # Similar pair: prefer a *different* clip of the same class, since a
            # clip paired with itself is a trivial positive. Fall back to the
            # anchor only when it is the sole member of its class.
            same_class = [i for i in label_to_indices[label1] if i != idx1]
            idx2 = random.choice(same_class) if same_class else idx1
            self.pairs.append((idx1, idx2, 1))
            # Dissimilar pair: any clip from any other class.
            label2 = random.choice([c for c in classes if c != label1])
            idx3 = random.choice(label_to_indices[label2])
            self.pairs.append((idx1, idx3, 0))

    def __getitem__(self, idx):
        """Load both clips of pair ``idx`` and return them with the float target."""
        idx1, idx2, sim = self.pairs[idx]
        video1 = video_to_tensor(self.video_paths[idx1])
        video2 = video_to_tensor(self.video_paths[idx2])
        return video1, video2, torch.tensor(sim, dtype=torch.float32)

    def __len__(self):
        """Number of stored pairs."""
        return len(self.pairs)


In [None]:
# 5. Siamese network (single backbone)
class SiamesePredictor(nn.Module):
    """Embedding network for Siamese comparison.

    A pretrained R3D-18 video backbone whose classifier is replaced by a
    512-unit linear layer, followed by a 128-unit linear projection, dropout
    and L2 normalisation. Both clips of a pair are passed through this one
    module, so they share weights.
    """

    def __init__(self, finetune_last_stage=False):
        """Build the backbone and the projection head.

        Args:
            finetune_last_stage: When True, the last residual stage of the
                backbone (``layer4``) is left trainable in addition to the new
                linear layers. Defaults to False: only the new layers train.
        """
        super().__init__()
        self.backbone = torchvision.models.video.r3d_18(weights='DEFAULT')
        # Freeze every pretrained weight.
        for param in self.backbone.parameters():
            param.requires_grad = False
        # The last two children of r3d_18 are `avgpool` (parameter-free) and `fc`
        # (replaced below), so "unfreezing the last two children" would not make
        # any pretrained weight trainable. Fine-tuning part of the backbone is
        # therefore an explicit opt-in.
        if finetune_last_stage:
            for param in self.backbone.layer4.parameters():
                param.requires_grad = True
        num_features = self.backbone.fc.in_features
        # Newly created layers are trainable by default.
        self.backbone.fc = nn.Linear(num_features, 512)
        self.fc = nn.Linear(512, 128)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a batch of clips to unit-length 128-dimensional embeddings."""
        x = self.backbone(x)
        # NOTE(review): there is no nonlinearity between backbone.fc and self.fc,
        # and dropout is applied to the final embedding -- confirm both are intended.
        x = self.fc(x)
        x = self.dropout(x)
        # Unit-normalise each embedding along the feature dimension.
        x = F.normalize(x, p=2, dim=1)
        return x

model = SiamesePredictor().to(device)

In [None]:
# 6. Contrastive loss (vectorized)
class ContrastiveLoss(nn.Module):
    """Contrastive loss on cosine distance for pairs labelled 1 (similar) or 0 (dissimilar).

    With ``d = 1 - cosine_similarity(output1, output2)``::

        loss = mean( label * d + (1 - label) * max(0, margin - d) )

    Similar pairs are pulled together (d -> 0) and dissimilar pairs are pushed
    apart until their cosine distance reaches ``margin``. With the default
    ``margin=1.0`` a dissimilar pair is penalised only while its cosine
    similarity is above 0.
    """

    def __init__(self, margin=1.0):
        """Store the cosine-distance margin required between dissimilar pairs."""
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss of a batch of embedding pairs.

        Args:
            output1: Embeddings of the first clip of each pair, ``(batch, dim)``.
            output2: Embeddings of the second clip of each pair, ``(batch, dim)``.
            label: Float tensor of shape ``(batch,)``; 1 for similar pairs,
                0 for dissimilar pairs.

        Returns:
            torch.Tensor: Scalar loss.
        """
        cos_sim = F.cosine_similarity(output1, output2)
        distance = 1.0 - cos_sim
        # Similar pairs pay for any remaining distance ...
        positive_term = label * distance
        # ... dissimilar pairs pay only while they are closer than the margin.
        negative_term = (1.0 - label) * torch.clamp(self.margin - distance, min=0.0)
        return torch.mean(positive_term + negative_term)

loss_fn = ContrastiveLoss()

In [None]:
# 7. DataLoader
# Settings shared by both loaders: batches of four, loading in the main process.
loader_kwargs = dict(batch_size=4, num_workers=0)

# Training pairs are reshuffled by the loader.
train_dataset = PairVideoDataset(train_video_paths, train_labels, pairs_per_epoch=500)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Evaluation pairs are served in their stored order.
test_dataset = PairVideoDataset(test_video_paths, test_labels, pairs_per_epoch=100)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


In [None]:
# 8. Optimizer
# Optimise exactly the parameters that are trainable on the model. Selecting
# backbone children by position is brittle: the second-to-last child of the
# backbone (avgpool) has no parameters, and any other layer that gets unfrozen
# would silently be left out of the optimiser.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001, weight_decay=1e-3)

In [None]:
# 9. Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Put the model in training mode at the start of each epoch.
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for video1, video2, sim_label in progress:
        # Move the whole batch to the compute device.
        video1, video2, sim_label = (
            batch_part.to(device) for batch_part in (video1, video2, sim_label)
        )
        optimizer.zero_grad()
        # Both clips go through the same network.
        out1, out2 = model(video1), model(video2)
        loss = loss_fn(out1, out2, sim_label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean loss per batch for this epoch.
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Loss = {mean_loss:.4f}")

In [None]:
# 10. Evaluation
# Evaluation mode, and no gradient tracking while scoring the test pairs.
model.eval()
predictions, true_labels = [], []
with torch.no_grad():
    for video1, video2, sim_label in test_loader:
        video1, video2 = video1.to(device), video2.to(device)
        out1, out2 = model(video1), model(video2)
        cos_sim = F.cosine_similarity(out1, out2)
        # A pair is predicted "similar" (1) when its cosine similarity exceeds 0.8.
        pred = (cos_sim > 0.8).int().cpu().numpy()
        predictions += pred.tolist()
        true_labels += sim_label.int().cpu().numpy().tolist()

# Fraction of pairs whose predicted similarity matches the target.
accuracy = accuracy_score(true_labels, predictions)
print(f"Test Accuracy: {accuracy:.4f}")