<a href="https://colab.research.google.com/github/NajibaTagougui/HPE/blob/main/DOATA_test_UC11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_video
from torchvision.transforms import Compose, Resize, Normalize
from transformers import ViTModel
import cv2
import warnings
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple

# Configuration
@dataclass
class Config:
    """Central hyper-parameter bundle shared by the dataset, the model and train().

    Every field has a default, so ``Config()`` yields the UCF11 experiment setup
    used by this notebook.
    """
    # Dataset
    # Parent directory of the data; get_ucf11_splits() appends 'UCF11_updated_mpg' to it.
    root_dir: str = "/content/drive/MyDrive/UCF11"  # Changed to UCF11
    # Side length (pixels) frames are resized to; also the size of predicted occlusion maps.
    frame_size: int = 224
    # Clips are sub-sampled or last-frame padded to exactly this many frames.
    num_frames: int = 32
    # Samples per DataLoader batch.
    batch_size: int = 8

    # Model
    # Token width used by every attention / conv / linear layer built from this config.
    embed_dim: int = 768
    # Attention heads; head size is computed as embed_dim // num_heads.
    num_heads: int = 8
    # Number of TemporalBlock layers stacked inside TemporalAggregator.
    num_layers: int = 4

    # Training
    # Learning rate handed to AdamW.
    lr: float = 1e-4
    # Number of passes over the training split.
    epochs: int = 50
    # NOTE: evaluated once, when the class body is executed, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Occlusion
    # Scale (pixels) of the synthetic occluder's Gaussian-shaped falloff.
    occlusion_radius: int = 25
    # Horizontal shift (pixels) applied to the occluder centre on every frame.
    occlusion_speed: float = 2.0
    # When False the generated mask stays all zeros.
    occlusion_enabled: bool = True

# Dataset
class UCF11Dataset(Dataset):
    """Video-clip dataset yielding fixed-length clips plus a synthetic occlusion mask.

    Each item is a dict with:
      * ``'video'``: float tensor holding ``num_frames`` frames after ``transform``.
      * ``'mask'``:  float tensor ``(num_frames, 1, H, W)`` whose spatial size
        matches the transformed video.
      * ``'label'``: long scalar tensor; ``-1`` marks a clip that could not be loaded.

    Args:
        file_list: ``(video_path, class_index)`` pairs.
        transform: callable applied to the ``(T, C, H, W)`` float video. Defaults to
            a resize to ``config.frame_size`` followed by mean/std normalisation.
        config: object exposing ``frame_size``, ``num_frames`` and the
            ``occlusion_*`` settings. When omitted a default ``Config()`` is used
            (the original dereferenced ``None`` and crashed).
    """

    def __init__(self, file_list: List[Tuple[str, int]], transform=None, config=None):
        if config is None:
            config = Config()
        self.config = config
        self.file_list = file_list
        self.transform = transform or Compose([
            Resize((config.frame_size, config.frame_size)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load clip ``idx``; on any failure return a zero-filled sample labelled ``-1``."""
        path, label = self.file_list[idx]
        try:
            video, mask = self._read_video(path)
            video = self.transform(video)
            # The mask is generated at the native resolution, so bring it to the
            # transformed frame size; otherwise it disagrees with both the video
            # and the fallback sample below and cannot be batched with them.
            mask = self._align_mask(mask, video)
            return {
                'video': video.float(),  # Ensure float type
                'mask': mask.float(),    # Ensure float type
                'label': torch.tensor(label, dtype=torch.long)  # Ensure long type
            }
        except Exception as e:
            warnings.warn(f"Error loading video {path}: {str(e)}")
            # Return empty tensors of correct shape
            size = self.config.frame_size
            frames = self.config.num_frames
            return {
                'video': torch.zeros((frames, 3, size, size)),
                'mask': torch.zeros((frames, 1, size, size)),
                'label': torch.tensor(-1, dtype=torch.long)  # Invalid label
            }

    @staticmethod
    def _align_mask(mask, video):
        """Resize ``mask`` (T, 1, h, w) to the spatial size of ``video`` if they differ."""
        target = tuple(video.shape[-2:])
        if tuple(mask.shape[-2:]) == target:
            return mask
        return F.interpolate(mask, size=target, mode='bilinear', align_corners=False)

    def _read_video(self, path):
        """Decode ``path`` into exactly ``num_frames`` frames and build its occlusion mask.

        Returns:
            ``(video, mask)`` where video is float ``(T, C, H, W)`` scaled to [0, 1]
            and mask is ``(T, 1, H, W)`` at the same (native) resolution.

        Raises:
            RuntimeError: if neither decoder yields a single frame.
        """
        try:
            # Try PyAV first
            video, _, _ = read_video(path, pts_unit='sec')
            if video.shape[0] == 0:
                # An empty result is a decode failure too; route it to the fallback
                # instead of hitting an IndexError while padding below.
                raise RuntimeError("decoder returned no frames")
            video = video.permute(0, 3, 1, 2).float() / 255.0
        except Exception as e:
            # Fallback to OpenCV
            warnings.warn(f"PyAV failed for {path}: {e}, using OpenCV")
            video = self._read_with_opencv(path)

        # Ensure we have the correct number of frames
        wanted = self.config.num_frames
        available = len(video)
        if available > wanted:
            # Uniformly sub-sample across the whole clip
            indices = torch.linspace(0, available - 1, wanted).long()
            video = video[indices]
        elif available < wanted:
            # Pad with last frame if needed
            last_frame = video[-1].unsqueeze(0)
            padding = last_frame.repeat(wanted - available, 1, 1, 1)
            video = torch.cat([video, padding], dim=0)

        # Generate occlusion mask
        mask = self._generate_occlusion(video)
        return video, mask

    def _read_with_opencv(self, path):
        """Decode every frame of ``path`` with OpenCV into a float ``(T, C, H, W)`` tensor.

        Raises:
            RuntimeError: if no frame can be read. Returning a black clip here
                would pair zeros with a valid label and slip past the ``-1`` filter.
        """
        cap = cv2.VideoCapture(str(path))
        frames = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(torch.from_numpy(frame))
        finally:
            # Release the capture handle even if colour conversion raises
            cap.release()
        if not frames:
            raise RuntimeError(f"No frames could be decoded from {path}")
        return torch.stack(frames).permute(0, 3, 1, 2).float() / 255.0

    def _generate_occlusion(self, video):
        """Return a ``(T, 1, H, W)`` soft mask of a blob drifting across the frame.

        The blob starts at a random pixel, moves ``occlusion_speed`` pixels right per
        frame with a small random vertical jitter, and falls off as
        ``exp(-d^2 / radius^2)``. The mask is all zeros when occlusion is disabled.
        """
        T, _, H, W = video.shape
        mask = torch.zeros(T, 1, H, W)
        if not self.config.occlusion_enabled:
            return mask

        radius = self.config.occlusion_radius
        # The pixel grid is identical for every frame: build it once, with explicit
        # row/column ('ij') indexing so Y varies along H and X along W.
        Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')

        x = torch.randint(0, W, (1,))
        y = torch.randint(0, H, (1,))
        for t in range(T):
            x = x + self.config.occlusion_speed
            y = y + torch.randn(1) * 0.5
            dist = ((X - x) / radius) ** 2 + ((Y - y) / radius) ** 2
            mask[t] = torch.exp(-dist)
        return mask



# Model Components
class OcclusionModulatedAttention(nn.Module):
    """Multi-head self-attention over the N tokens of each frame, gated by a mask.

    Before the softmax every attention logit is multiplied by
    ``alpha * mask + (1 - alpha)``, where ``mask`` is indexed by the key token and
    ``alpha`` is a learnable scalar initialised to 0.5.
    """

    def __init__(self, config):
        super().__init__()
        dim, heads = config.embed_dim, config.num_heads
        self.embed_dim = dim
        self.num_heads = heads
        self.head_dim = dim // heads

        # One fused projection produces queries, keys and values together
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, mask):
        """Attend within each frame.

        Args:
            x: tensor of shape (B, T, N, C).
            mask: tensor viewable as (B, T, N), one value per token.

        Returns:
            Tensor of shape (B, T, N, C).
        """
        B, T, N, C = x.shape

        # (B, T, N, 3, heads, head_dim) -> three tensors of (B, T, heads, N, head_dim)
        packed = self.qkv(x).reshape(B, T, N, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(3, 0, 1, 4, 2, 5).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)

        # Broadcast the per-token mask over heads and query positions
        key_mask = mask.view(B, T, N)[:, :, None, None, :]
        scores = scores * (self.alpha * key_mask + (1 - self.alpha))
        weights = F.softmax(scores, dim=-1)

        # Merge the heads back into the channel dimension
        merged = (weights @ v).transpose(2, 3).reshape(B, T, N, C)
        return self.proj(merged)

class OcclusionEstimator(nn.Module):
    """Predict one occlusion map per frame from patch tokens.

    A small conv stack captures local context on the patch grid, a 2-layer
    transformer encoder mixes information across all patches, and a 1x1-conv head
    with a sigmoid turns their sum into a single-channel map that is upsampled to
    ``frame_size`` x ``frame_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        dim = config.embed_dim

        def conv3x3():
            # Same-size 3x3 convolution that keeps the channel count
            return nn.Conv2d(dim, dim, kernel_size=3, padding=1)

        self.local_cnn = nn.Sequential(conv3x3(), nn.ReLU(), conv3x3(), nn.ReLU())

        self.global_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=config.num_heads,
                dim_feedforward=dim * 4
            ),
            num_layers=2
        )

        self.pred_head = nn.Sequential(
            nn.Conv2d(dim, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map tokens (B, T, N, C) to occlusion maps (B, T, frame_size, frame_size).

        N is treated as a square grid of side ``int(N ** 0.5)``.
        """
        B, T, N, C = x.shape
        side = int(N ** 0.5)
        out_size = self.config.frame_size

        # Tokens -> channel-first image grid, one image per (batch, frame) pair
        grid = x.permute(0, 1, 3, 2).reshape(B * T, C, side, side)
        local_feats = self.local_cnn(grid)

        # The encoder expects (sequence, batch, channels)
        tokens = local_feats.flatten(2).permute(2, 0, 1)
        tokens = self.global_transformer(tokens)
        global_feats = tokens.permute(1, 2, 0).reshape(B * T, C, side, side)

        maps = self.pred_head(local_feats + global_feats)
        maps = F.interpolate(
            maps,
            size=(out_size, out_size),
            mode='bilinear',
            align_corners=False
        )
        return maps.reshape(B, T, out_size, out_size)

class TemporalAggregator(nn.Module):
    """Stack of ``TemporalBlock`` layers driven by down-sampled occlusion maps.

    A learned per-frame positional embedding is added to the tokens, the occlusion
    maps are resized to the token grid, and the result of the last block is
    layer-normalised.
    """

    def __init__(self, config):
        super().__init__()
        blocks = []
        for _ in range(config.num_layers):
            blocks.append(TemporalBlock(config))
        self.layers = nn.ModuleList(blocks)
        # One embedding vector per frame index, shared by all tokens of that frame
        self.pos_embed = nn.Parameter(torch.randn(1, config.num_frames, 1, config.embed_dim) * 0.02)
        self.norm = nn.LayerNorm(config.embed_dim)

    def forward(self, x, occlusion_maps):
        """Refine tokens ``x`` (B, T, N, C) using ``occlusion_maps`` (B, T, h, w).

        Returns:
            Tensor of shape (B, T, N, C).
        """
        B, T, N, C = x.shape
        side = int(N ** 0.5)

        # Shrink every map to the side x side token grid, then flatten to one value per token
        per_frame_maps = occlusion_maps.flatten(0, 1).unsqueeze(1)
        attn_masks = F.interpolate(
            per_frame_maps,
            size=(side, side),
            mode='bilinear',
            align_corners=False
        ).view(B, T, side * side)

        # Only the first T positional slots are used
        x = x + self.pos_embed[:, :T]

        for block in self.layers:
            x = block(x, attn_masks)
        return self.norm(x)

class TemporalBlock(nn.Module):
    """Pre-norm residual block: occlusion-modulated attention followed by a GELU MLP."""

    def __init__(self, config):
        super().__init__()
        dim = config.embed_dim
        hidden = dim * 4

        self.attention = OcclusionModulatedAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, attn_masks):
        """Apply both residual sub-layers; the output has the same shape as ``x``."""
        # Attention sub-layer: normalise, attend, add back
        x = x + self.attention(self.norm1(x), attn_masks)
        # Feed-forward sub-layer: normalise, expand/contract, add back
        return x + self.mlp(self.norm2(x))

# Main DOATA Model
class DOATA(nn.Module):
    """Video classifier: frozen ViT patch features -> occlusion maps -> gated attention -> logits.

    Args:
        config: object exposing ``embed_dim`` plus whatever ``OcclusionEstimator``
            and ``TemporalAggregator`` read from it. An optional ``num_classes``
            attribute sets the width of the output layer (default 11 for UCF11).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load ViT backbone with correct configuration
        self.backbone = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            add_pooling_layer=False,
            ignore_mismatched_sizes=True
        )

        # Freeze backbone initially
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.occlusion_estimator = OcclusionEstimator(config)
        self.temporal_aggregator = TemporalAggregator(config)

        # Classifier head; class count is configurable but defaults to UCF11's 11
        num_classes = getattr(config, "num_classes", 11)
        self.classifier = nn.Sequential(
            nn.LayerNorm(config.embed_dim),
            nn.Linear(config.embed_dim, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

        # Initialize weights correctly
        self._init_weights()

    def _init_weights(self):
        """Xavier/zero-init the Linear layers and reset the LayerNorms of the new heads.

        The backbone is deliberately skipped: iterating ``self.named_modules()``
        also visits every Linear/LayerNorm inside the ViT and would overwrite the
        pretrained weights that were just loaded.
        """
        trainable_parts = (
            self.occlusion_estimator,
            self.temporal_aggregator,
            self.classifier,
        )
        for part in trainable_parts:
            for module in part.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.constant_(module.bias, 0)
                elif isinstance(module, nn.LayerNorm):
                    nn.init.constant_(module.weight, 1.0)
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: tensor of shape (B, T, C, H, W).

        Returns:
            ``(logits, occlusion_maps)`` where logits has shape (B, num_classes)
            and occlusion_maps is the output of the occlusion estimator.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        x = x.reshape(B * T, C, H, W)

        # Get features from backbone
        with torch.no_grad():
            features = self.backbone(x).last_hidden_state[:, 1:]  # Remove CLS token

        features = features.reshape(B, T, -1, self.config.embed_dim)
        occlusion_maps = self.occlusion_estimator(features)
        temporal_features = self.temporal_aggregator(features, occlusion_maps)

        # Pool over frames AND patch tokens. Averaging over frames alone leaves a
        # (B, N, C) tensor, so the classifier would emit (B, N, classes) logits that
        # cross_entropy cannot match against (B,) labels.
        pooled = temporal_features.mean(dim=(1, 2))
        logits = self.classifier(pooled)

        return logits, occlusion_maps


# Training Utilities
def get_ucf11_splits(root_dir, split='train'):
    """Collect ``(video_path, class_index)`` pairs for one split of UCF11.

    Class folders are read from ``<root_dir>/UCF11_updated_mpg`` and numbered in
    sorted order. Inside each class every video file found at any depth is sorted
    by path; the first 80% form the 'train' split and the remainder is returned for
    any other ``split`` value.

    Raises:
        FileNotFoundError: the dataset folder is missing or holds no class folders.
        RuntimeError: the requested split ends up empty.
    """
    base_dir = os.path.join(root_dir, 'UCF11_updated_mpg')
    if not os.path.exists(base_dir):
        raise FileNotFoundError(f"UCF11 directory not found at {base_dir}")

    # Every sub-directory of the dataset folder is one class
    classes = sorted(
        entry for entry in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, entry))
    )
    if not classes:
        raise FileNotFoundError(f"No classes found in {base_dir}")

    video_extensions = ('.mpg', '.avi', '.mp4', '.mpeg')
    want_train = split == 'train'
    file_list = []

    for label, class_name in enumerate(classes):
        # Videos sit in nested group folders, so walk the whole class tree
        videos = []
        for dirpath, _, filenames in os.walk(os.path.join(base_dir, class_name)):
            videos.extend(
                os.path.join(dirpath, fname)
                for fname in filenames
                if fname.lower().endswith(video_extensions)
            )

        if not videos:
            warnings.warn(f"No videos found in class: {class_name}")
            continue

        # A fixed ordering keeps the 80/20 cut reproducible between calls
        videos.sort()
        cut = int(0.8 * len(videos))
        chosen = videos[:cut] if want_train else videos[cut:]
        file_list.extend((video_path, label) for video_path in chosen)

    if not file_list:
        raise RuntimeError(f"No videos found for {split} split. Check dataset structure.")

    print(f"Found {len(file_list)} videos for {split} split")
    return file_list

def _load_hf_token():
    """Return the Colab secret named HF_TOKEN, or None (with a warning) if unavailable."""
    try:
        from google.colab import userdata
        return userdata.get('HF_TOKEN')
    except Exception:
        # Covers running outside Colab (ImportError) as well as a missing or
        # inaccessible secret, without swallowing KeyboardInterrupt/SystemExit.
        warnings.warn("No HF_TOKEN found, using anonymous access")
        return None


def _drop_unreadable(dataset):
    """Remove entries whose sample comes back with the invalid label -1.

    NOTE: this loads every clip once, so it is slow on large datasets.
    """
    dataset.file_list = [
        entry for idx, entry in enumerate(dataset.file_list)
        if dataset[idx]['label'].item() != -1
    ]


def _make_loader(dataset, config, shuffle=False):
    """Build a DataLoader whose worker and pinning options suit the current machine."""
    # os.cpu_count() may be None, and a single-core box must end up with 0 workers
    workers = max(0, min(2, (os.cpu_count() or 1) - 1))
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=workers,
        # Pinned host memory only pays off when batches are copied to a GPU
        pin_memory=str(config.device).startswith("cuda"),
        # DataLoader rejects persistent_workers=True together with num_workers=0
        persistent_workers=workers > 0,
    )


def _run_epoch(model, loader, config, optimizer=None, epoch=0):
    """Run one pass over ``loader``; optimise when ``optimizer`` is given, else evaluate.

    Returns:
        ``(loss_sum, correct, total)`` accumulated over the processed batches.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for batch_idx, batch in enumerate(loader):
            videos = batch['video'].to(config.device, non_blocking=True)
            labels = batch['label'].to(config.device, non_blocking=True)

            # A clip that fails to load at this point arrives with label -1, which
            # cross_entropy cannot handle; drop such samples from the batch.
            valid = labels >= 0
            if not bool(valid.all()):
                if not bool(valid.any()):
                    continue
                videos, labels = videos[valid], labels[valid]

            if training:
                optimizer.zero_grad(set_to_none=True)
            logits, _ = model(videos)
            loss = F.cross_entropy(logits, labels)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            loss_sum += loss.item()
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if training and batch_idx % 20 == 0:
                print(f"Epoch {epoch+1}/{config.epochs} | "
                      f"Batch {batch_idx}/{len(loader)} | "
                      f"Loss: {loss.item():.4f} | "
                      f"Acc: {100.*correct/total:.2f}%")

    return loss_sum, correct, total


# Updated training function with Hugging Face token handling
def train():
    """Train DOATA on UCF11 and save the best-validation checkpoint to 'best_model.pth'.

    Raises:
        Exception: any failure during data loading or training is re-raised after
            a short list of debugging hints has been printed.
    """
    config = Config()

    # Handle Hugging Face token: export it so the Hugging Face libraries, which
    # read the HF_TOKEN environment variable, authenticate model downloads.
    HF_TOKEN = _load_hf_token()
    if HF_TOKEN:
        os.environ.setdefault("HF_TOKEN", HF_TOKEN)

    # Initialize with deterministic behavior
    torch.manual_seed(42)
    if config.device == 'cuda':
        torch.cuda.manual_seed_all(42)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    try:
        # Load datasets
        train_files = get_ucf11_splits(config.root_dir, 'train')
        test_files = get_ucf11_splits(config.root_dir, 'test')

        train_dataset = UCF11Dataset(train_files, config=config)
        test_dataset = UCF11Dataset(test_files, config=config)

        # Filter out invalid samples
        _drop_unreadable(train_dataset)
        _drop_unreadable(test_dataset)

        print(f"\nTraining samples: {len(train_dataset)}")
        print(f"Test samples: {len(test_dataset)}\n")

        # Fail with a clear message instead of a sampler error or a division by zero
        if len(train_dataset) == 0 or len(test_dataset) == 0:
            raise RuntimeError("No readable videos left after filtering invalid samples")

        # DataLoaders
        train_loader = _make_loader(train_dataset, config, shuffle=True)
        test_loader = _make_loader(test_dataset, config)

        # Initialize model
        model = DOATA(config).to(config.device)

        # Optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.lr,
            weight_decay=0.01
        )

        # Training loop
        best_acc = 0.0
        for epoch in range(config.epochs):
            train_loss, correct, total = _run_epoch(
                model, train_loader, config, optimizer=optimizer, epoch=epoch
            )

            # Validation
            val_loss, val_correct, val_total = _run_epoch(model, test_loader, config)

            # Print epoch summary (max(..., 1) guards an epoch with no valid sample)
            train_acc = 100. * correct / max(total, 1)
            val_acc = 100. * val_correct / max(val_total, 1)
            print(f"\nEpoch {epoch+1} Summary:")
            print(f"Train Loss: {train_loss/len(train_loader):.4f} | Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss/len(test_loader):.4f} | Acc: {val_acc:.2f}%")

            # Save best model
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'epoch': epoch+1,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }, 'best_model.pth')
                print(f"Saved best model with val acc: {best_acc:.2f}%")

            print("-" * 50)

        print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")

    except Exception as e:
        print(f"\nError during training: {str(e)}")
        print("\nDebugging steps:")
        print("1. Try running with num_workers=0")
        print("2. Verify your HF_TOKEN is set if using private models")
        print("3. Check CUDA memory with nvidia-smi")
        print("4. Test with smaller batch size")
        raise

# Start training only when this file is executed directly (script or notebook
# cell); importing it as a module just defines the classes and functions.
if __name__ == "__main__":
    train()






Found 1283 videos for train split
Found 328 videos for test split

Training samples: 1283
Test samples: 328



Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [1]:
pip install av

Collecting av
  Downloading av-14.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.7 kB)
Downloading av-14.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.2/35.2 MB[0m [31m53.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: av
Successfully installed av-14.3.0
