In [21]:
import pandas as pd
from PIL import Image
import os
import cv2
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np

In [2]:
# Define paths
# All four locations are relative paths, so they resolve against the notebook's
# current working directory. Judging by the directory names these are the small
# "demo" train/validation splits of the processed MMHS150K data.
train_csv_path = '../data/processed/MMHS150K_processed/training_demo/train_demo_data.csv'
val_csv_path = '../data/processed/MMHS150K_processed/validation_demo/val_demo_data.csv'
train_img_dir = '../data/processed/MMHS150K_processed/training_demo/train_demo_images_updated/'
val_img_dir = '../data/processed/MMHS150K_processed/validation_demo/val_demo_images_updated/'

# Read CSV files
# The frames are consumed by load_images, which reads their 'img_id' and
# 'label' columns.
train_df = pd.read_csv(train_csv_path)
val_df = pd.read_csv(val_csv_path)

def load_images(df, img_folder):
    """Load the RGB images referenced by ``df`` together with their labels.

    Args:
        df: DataFrame with an ``img_id`` column (the image file stem) and a
            ``label`` column.
        img_folder: Directory that contains the ``<img_id>.jpg`` files.

    Returns:
        A ``(images, labels)`` tuple of parallel lists. ``images`` holds PIL
        images converted to RGB; ``labels`` holds the matching ``label``
        values. Rows whose image file does not exist are skipped silently.
    """
    images = []
    labels = []
    # Iterate the two columns directly rather than df.iterrows(): iterrows()
    # builds one Series per row, which upcasts an integer img_id to float as
    # soon as the frame also contains a float column. The file name then
    # becomes "<id>.0.jpg" and every image is silently treated as missing.
    for img_id, label in zip(df['img_id'], df['label']):
        img_path = os.path.join(img_folder, f"{img_id}.jpg")
        if not os.path.exists(img_path):
            continue
        # convert() returns a fully loaded copy, so the file handle opened by
        # Image.open can be closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            images.append(img.convert('RGB'))
        labels.append(label)
    return images, labels

# Load both splits. load_images skips rows whose image file is missing without
# printing anything, so the counts below can be smaller than the CSV row counts.
train_images, train_labels = load_images(train_df, train_img_dir)
val_images, val_labels = load_images(val_df, val_img_dir)

# Basic verification: report how many images were actually found and loaded.
print(f"Training: {len(train_images)} images")
print(f"Validation: {len(val_images)} images")

Training: 180 images
Validation: 14 images


In [None]:


# 1. Video Frame Extraction
def extract_frames(video_path, sample_rate=1):
    """Decode a video and return every ``sample_rate``-th frame in RGB order.

    Args:
        video_path: Path (or other source accepted by ``cv2.VideoCapture``)
            of the video to read.
        sample_rate: Keep one frame out of every ``sample_rate`` decoded
            frames; ``1`` keeps all of them. Must be >= 1.

    Returns:
        List of frames as returned by OpenCV, converted from BGR to RGB.
        The list is empty when the video cannot be opened or has no frames.

    Raises:
        ValueError: If ``sample_rate`` is smaller than 1.
    """
    # Reject a zero rate before opening the capture; otherwise the modulo
    # below raises ZeroDivisionError with the capture still open.
    if sample_rate < 1:
        raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or a read failure.
                break
            if frame_count % sample_rate == 0:
                # OpenCV decodes to BGR; downstream code expects RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
    finally:
        # Release the capture even if decoding or colour conversion raises.
        cap.release()

    return frames

# 2. ViT-based Spatial Feature Extraction
class ViTFeatureExtractor(nn.Module):
    """Frozen ViT-B/16 backbone that returns one embedding per input image.

    The classification head is replaced by ``nn.Identity`` so the wrapped
    model's own forward pass yields the class-token embedding instead of
    class logits. All backbone parameters are frozen.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=True` is the legacy torchvision argument;
        # newer releases prefer `weights=...`. Kept as-is so the code still
        # runs on the torchvision version this notebook was written against.
        self.vit = models.vit_b_16(pretrained=True)

        # Remove classification head
        self.vit.heads = nn.Identity()

        # Freeze all parameters
        for param in self.vit.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the class-token embedding for a batch of images.

        Args:
            x: Image batch, expected as [batch, 3, 224, 224].

        Returns:
            Tensor of shape [batch, embed_dim] (768 for ViT-B/16).
        """
        # Delegate to the model's own forward pass (patchify -> prepend class
        # token -> encoder -> take class token -> heads). The encoder adds the
        # position embedding itself; adding it by hand before calling the
        # encoder, as the earlier version did, applies it twice and feeds the
        # pretrained weights out-of-distribution inputs.
        return self.vit(x)

# 3. Temporal Transformer for Context
class TemporalTransformer(nn.Module):
    """Transformer encoder that mixes information across a sequence of frame embeddings.

    Args:
        embed_dim: Size of each frame embedding (``d_model``).
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.

    NOTE(review): no positional encoding is added before the encoder, so the
    frame order is not visible to the attention layers -- confirm this is intended.
    """

    def __init__(self, embed_dim=768, num_heads=8, num_layers=2):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x):
        """Contextualise frame embeddings.

        Args:
            x: Either a single sequence ``[seq_len, embed_dim]`` or a batch of
                sequences ``[batch, seq_len, embed_dim]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if x.dim() == 3:
            # Already batched (the encoder is batch_first): pass straight through.
            return self.transformer(x)
        # Single sequence: add a batch axis for the encoder and drop it again.
        return self.transformer(x.unsqueeze(0)).squeeze(0)

# 4. Full Pipeline
def process_video_with_vit(video_path, sample_rate=1, vit_extractor=None, temporal_model=None):
    """Turn a video into one contextual embedding per sampled frame.

    Pipeline: sample frames -> ImageNet-style preprocessing -> per-frame
    ViT embedding -> temporal transformer over the frame sequence.

    Args:
        video_path: Path of the video to process.
        sample_rate: Keep one frame out of every ``sample_rate`` frames.
        vit_extractor: Optional spatial feature extractor to reuse. When
            omitted, a new ``ViTFeatureExtractor`` is built (which loads the
            pretrained weights) and put in eval mode.
        temporal_model: Optional temporal model to reuse. When omitted, a new
            ``TemporalTransformer`` is built and put in eval mode. Its weights
            are then randomly initialised, so the embeddings differ from call
            to call; pass a trained model for meaningful, repeatable output.

    Returns:
        NumPy array with one row per sampled frame.

    Raises:
        RuntimeError: If no frame could be read from ``video_path``.
    """
    # Extract frames
    frames = extract_frames(video_path, sample_rate)
    if not frames:
        # Fail with a clear message instead of torch.stack's error about an
        # empty tensor list.
        raise RuntimeError(f"No frames could be read from video: {video_path}")

    # Preprocessing transforms (same as ViT training)
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Models created here are switched to eval mode so dropout is disabled;
    # caller-supplied models are used in whatever mode the caller set.
    if vit_extractor is None:
        vit_extractor = ViTFeatureExtractor().eval()
    if temporal_model is None:
        temporal_model = TemporalTransformer().eval()

    frame_tensors = torch.stack([preprocess(frame) for frame in frames])

    # Run both stages without autograd: this is inference only, and calling
    # .numpy() on a tensor that requires grad raises a RuntimeError.
    with torch.no_grad():
        spatial_features = vit_extractor(frame_tensors)  # [num_frames, 768]
        contextual_embeddings = temporal_model(spatial_features)  # [num_frames, 768]

    return contextual_embeddings.cpu().numpy()

# 5. Usage Example
if __name__ == "__main__":
    # NOTE(review): inside a notebook kernel this guard is normally true, so
    # executing the cell runs the example -- confirm that is intended.
    # "your_video.mp4" is a placeholder; point it at a real video file first.
    video_path = "your_video.mp4"
    embeddings = process_video_with_vit(video_path, sample_rate=1)
    print(f"ViT Contextual Embeddings Shape: {embeddings.shape}")
    # Output: (num_frames, 768)