In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

In [3]:
!pip install torch torchvision eva-decord tqdm numpy mamba-ssm triton

Collecting eva-decord
  Downloading eva_decord-0.6.1-py3-none-manylinux2010_x86_64.whl.metadata (449 bytes)
Collecting mamba-ssm
  Downloading mamba_ssm-2.2.4.tar.gz (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.8/91.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting triton
  Using cached triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading eva_decord-0.6.1-py3-none-manylinux2010_x86_64.whl (13.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.6/13.6 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m
[?25hUsing cached triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.1 MB)
Building wheels for collected packages: mamba-ssm
  Building wheel for mamba-ssm (pyproject.toml) 

In [4]:
import os
import json
import logging
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, ChainedScheduler
import decord
import numpy as np
from mamba_ssm import Mamba
from einops import rearrange
import torchvision.transforms as T
from tqdm.notebook import tqdm
from typing import Dict, List, Tuple
from collections import defaultdict
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

In [5]:

# Setting up environment variables for distributed training for kaggle
def setup_distributed():
    # Set environment variables if not already set
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ['WORLD_SIZE'] = '2'  # For 2 GPUs
    if 'RANK' not in os.environ:
        os.environ['RANK'] = '0'
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = '0'
    
    # Initialize process group
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=int(os.environ['WORLD_SIZE']),
        rank=int(os.environ['RANK'])
    )
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

# Enhanced logging configuration
def setup_logging():
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", 
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)

In [6]:
# This function is used to create a dataloader for the dataset
# It is used to load the data in batches and shuffle the data
class TALDataset(Dataset):
    """Temporal Action Localization Dataset"""
    def __init__(
        self,
        json_path: str,
        video_root: str,
        clip_len: int = 16,
        frame_size: int = 224,
        sample_rate: int = 4,
        temporal_stride: int = 4,
        num_clips: int = 32,
        subset: str = 'validation'  # Default to validation since that's what we have
    ):
        self.video_root = video_root
        self.clip_len = clip_len
        self.frame_size = frame_size
        self.sample_rate = sample_rate
        self.temporal_stride = temporal_stride
        self.num_clips = num_clips
        
        logging.info(f"Initializing TALDataset with json_path: {json_path}")
        logging.info(f"Video root directory: {video_root}")
        
        # Load and validate JSON data
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            logging.info(f"Successfully loaded JSON data with {len(data['database'])} total videos")
        except Exception as e:
            logging.error(f"Error loading JSON file: {str(e)}")
            raise
        
        self.database = data['database']
        
        # Get all unique classes
        self.classes = sorted({
            ann['label'] for vid in self.database.values() 
            for ann in vid['annotations']
        })
        logging.info(f"Found {len(self.classes)} unique action classes")
        
        self.class_to_idx = {c:i for i,c in enumerate(self.classes)}
        
        # Filter videos by subset
        self.video_list = []
        for vid_id, vid_data in self.database.items():
            # Check if video file exists - using correct path structure
            video_path = os.path.join(self.video_root, vid_id, f"{vid_id}.mp4")
            if not os.path.exists(video_path):
                logging.warning(f"Video file not found: {video_path}")
                continue
                
            if vid_data['subset'] == subset:
                self.video_list.append(vid_id)
        
        logging.info(f"Found {len(self.video_list)} videos in {subset} subset")
        
        if len(self.video_list) == 0:
            logging.error(f"No videos found for subset '{subset}'")
            logging.error(f"Available subsets in data: {set(v['subset'] for v in self.database.values())}")
            raise ValueError("No valid videos found in the dataset")
        
        self.transform = T.Compose([
            T.Resize((frame_size, frame_size)),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_video(self, video_id: str) -> torch.Tensor:
        try:
            video_path = os.path.join(self.video_root, video_id, f"{video_id}.mp4")
            # logging.info(f"Loading video from: {video_path}")
            
            vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
            
            total_frames = len(vr)
            indices = np.linspace(
                0, total_frames-1, 
                num=self.num_clips*self.temporal_stride,
                dtype=int
            )
            
            frames = vr.get_batch(indices).asnumpy()
            frames = torch.from_numpy(frames).float() / 255.0
            return frames.permute(0, 3, 1, 2)  # (T, C, H, W)
            
        except Exception as e:
            logging.error(f"Error loading video {video_id}: {str(e)}")
            raise

    def __getitem__(self, idx: int) -> Dict:
        video_id = self.video_list[idx]
        video_data = self.database[video_id]
        
        try:
            # Load and process video
            frames = self._load_video(video_id)
            frames = self.transform(frames)
            
            # Convert to (C, T, H, W)
            frames = frames.permute(1, 0, 2, 3)
            
            # Create temporal targets
            duration = video_data['duration']
            num_frames = frames.shape[1]
            time_unit = duration / num_frames
            
            # Initialize targets
            action_map = torch.zeros((num_frames, len(self.classes)))
            reg_targets = torch.zeros((num_frames, 2))
            
            for ann in video_data['annotations']:
                label = self.class_to_idx[ann['label']]
                start = max(0, int(ann['segment'][0] / time_unit))
                end = min(num_frames-1, int(ann['segment'][1] / time_unit))
                
                # Skip invalid segments
                if start >= end:
                    continue
                
                # Classification target
                action_map[start:end+1, label] = 1.0
                
                # Calculate center and length, ensuring no division by zero
                length = end - start
                if length > 0:  # Only calculate regression targets for valid segments
                    center = (start + end) / 2
                    
                    for t in range(start, end+1):
                        reg_targets[t, 0] = (t - center) / length  # Center offset
                        reg_targets[t, 1] = (t - start) / length  # Length ratio
            
            return {
                'video_id': video_id,
                'frames': frames,
                'action_map': action_map,
                'reg_targets': reg_targets,
                'duration': duration
            }
            
        except Exception as e:
            logging.error(f"Error processing video {video_id}: {str(e)}")
            raise

    def __len__(self):
        return len(self.video_list)

## Model Architecture:

In [7]:
class FeatureAggregatedBiS6(nn.Module):
    """Enhanced Bi-directional S6 Block"""
    def __init__(self, dim: int, kernel_sizes: List[int] = [3,5,7], expansion: int = 2):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.ConstantPad1d((k//2, (k-1)//2), 0),
                nn.Conv1d(dim, dim, k),
                nn.GELU()
            ) for k in kernel_sizes
        ])
        
        self.s6_fwd = Mamba(
            d_model=dim,
            d_state=16,
            d_conv=4,
            expand=expansion
        )
        self.s6_bwd = Mamba(
            d_model=dim,
            d_state=16,
            d_conv=4,
            expand=expansion
        )
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T = x.shape
        residual = x
        
        # Multi-scale temporal aggregation
        conv_outs = [conv(x) for conv in self.convs]
        x = sum(conv_outs) * self.gate
        
        # Bi-directional processing
        x = x.permute(0, 2, 1)  # [B, T, C]
        x_fwd = self.s6_fwd(x)
        x_bwd = self.s6_bwd(x.flip(1)).flip(1)
        x = x_fwd + x_bwd
        
        return self.norm(x.permute(0, 2, 1) + residual)

class DualBiS6TAL(nn.Module):
    """Dual-path S6 Architecture for Temporal Action Localization"""
    def __init__(self, num_classes: int, dim: int = 128, recur_steps: int = 4):
        super().__init__()
        # Feature extractor
        self.encoder = nn.Sequential(
            nn.Conv3d(3, dim, kernel_size=(3,7,7), stride=(1,2,2), padding=(1,3,3)),
            nn.GELU(),
            nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
            nn.BatchNorm3d(dim)
        )
        
        # Temporal modeling
        self.temporal_blocks = nn.ModuleList([
            FeatureAggregatedBiS6(dim)
            for _ in range(recur_steps)
        ])
        
        # Pyramid branches
        self.pyramid = nn.ModuleList([
            nn.Sequential(
                FeatureAggregatedBiS6(dim),
                nn.MaxPool1d(2, stride=2)
            ) for _ in range(4)
        ])
        
        # Prediction heads
        self.cls_head = nn.Sequential(
            nn.Conv1d(dim, dim//2, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(dim//2, num_classes, 1)
        )
        
        self.reg_head = nn.Sequential(
            nn.Conv1d(dim, dim//2, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(dim//2, 2, 1),
            nn.Tanh()  # Bounded regression outputs
        )

    def forward(self, x: torch.Tensor) -> Dict:
        # Initial features
        x = self.encoder(x)  # [B, C, T, H, W]
        x = x.flatten(3).mean(-1)  # [B, C, T]
        
        # Temporal modeling
        for block in self.temporal_blocks:
            x = block(x)
        
        # Multi-scale pyramid
        pyramid_features = [x]
        for branch in self.pyramid:
            x = branch(x)
            pyramid_features.append(x)
        
        # Merge pyramid features
        merged = torch.cat([
            nn.functional.interpolate(f, size=pyramid_features[0].shape[-1])
            for f in pyramid_features
        ], dim=1)
        
        return {
            'cls_logits': self.cls_head(merged),
            'reg_outputs': self.reg_head(merged)
        }

In [8]:
class TemporalAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T = x.shape
        x = x.permute(0, 2, 1)
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        x = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(x).permute(0, 2, 1)

class RefinedFeatureAggregatedBiS6(nn.Module):
    """Enhanced Bi-directional S6 Block with temporal attention"""
    def __init__(self, dim: int, kernel_sizes: List[int] = [3,5,7], expansion: int = 2):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.ConstantPad1d((k//2, (k-1)//2), 0),
                nn.Conv1d(dim, dim, k),
                nn.GELU()
            ) for k in kernel_sizes
        ])
        
        self.temporal_attn = TemporalAttention(dim)
        
        self.s6_fwd = Mamba(
            d_model=dim,
            d_state=16,
            d_conv=4,
            expand=expansion
        )
        self.s6_bwd = Mamba(
            d_model=dim,
            d_state=16,
            d_conv=4,
            expand=expansion
        )
        
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Multi-scale temporal aggregation
        conv_outs = [conv(x) for conv in self.convs]
        x = sum(conv_outs) * self.gate
        
        # Temporal attention
        x = x + self.temporal_attn(x)
        x = self.norm1(x.permute(0, 2, 1)).permute(0, 2, 1)
        
        # Bi-directional processing
        x_p = x.permute(0, 2, 1)  # [B, T, C]
        x_fwd = self.s6_fwd(x_p)
        x_bwd = self.s6_bwd(x_p.flip(1)).flip(1)
        x = x_fwd + x_bwd
        x = self.norm2(x).permute(0, 2, 1)
        
        return x

class RefinedDualBiS6TAL(nn.Module):
    """Improved Dual-path S6 Architecture with enhanced temporal modeling"""
    def __init__(self, num_classes: int, dim: int = 128, recur_steps: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(3, dim, kernel_size=(3,7,7), stride=(1,2,2), padding=(1,3,3)),
            nn.GELU(),
            nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
            nn.BatchNorm3d(dim)
        )
        
        # Temporal modeling with attention
        self.temporal_blocks = nn.ModuleList([
            RefinedFeatureAggregatedBiS6(dim)
            for _ in range(recur_steps)
        ])
        
        # Multi-scale pyramid with attention
        self.pyramid = nn.ModuleList([
            nn.Sequential(
                RefinedFeatureAggregatedBiS6(dim),
                nn.MaxPool1d(2, stride=2)
            ) for _ in range(4)
        ])
        
        # Prediction heads with confidence modeling
        self.cls_head = nn.Sequential(
            nn.Conv1d(dim * 5, dim, 1),
            nn.GELU(),
            nn.Conv1d(dim, dim//2, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(dim//2, num_classes, 1)
        )
        
        self.reg_head = nn.Sequential(
            nn.Conv1d(dim * 5, dim, 1),
            nn.GELU(),
            nn.Conv1d(dim, dim//2, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(dim//2, 2, 1),
            nn.Tanh()
        )
        
        self.conf_head = nn.Sequential(
            nn.Conv1d(dim * 5, dim//2, 3, padding=1),
            nn.GELU(),
            nn.Conv1d(dim//2, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> Dict:
        # Initial features
        x = self.encoder(x)
        x = x.flatten(3).mean(-1)
        
        # Temporal modeling
        for block in self.temporal_blocks:
            x = block(x)
        
        # Multi-scale pyramid
        pyramid_features = [x]
        curr_feat = x
        for branch in self.pyramid:
            curr_feat = branch(curr_feat)
            pyramid_features.append(
                nn.functional.interpolate(
                    curr_feat, 
                    size=x.shape[-1],
                    mode='linear'
                )
            )
        
        # Merge pyramid features
        merged = torch.cat(pyramid_features, dim=1)
        
        return {
            'cls_logits': self.cls_head(merged),
            'reg_outputs': self.reg_head(merged),
            'confidence': self.conf_head(merged)
        }

## Training Code for TAL

In [9]:
class TALTrainer:
    """Distributed Training Framework"""
    def __init__(self, config: Dict):
        self.config = config
        self.device = torch.device(f'cuda:{dist.get_rank()}')
        
        # Initialize dataset
        self.dataset = TALDataset(
            json_path=config['json_path'],
            video_root=config['video_root'],
            clip_len=config['clip_len'],
            frame_size=config['frame_size']
        )
        
        # Split dataset
        train_size = int(0.9 * len(self.dataset))
        val_size = len(self.dataset) - train_size
        self.train_set, self.val_set = torch.utils.data.random_split(
            self.dataset, [train_size, val_size]
        )
        
        # Create dataloaders
        self.train_loader = DataLoader(
            self.train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        
        self.val_loader = DataLoader(
            self.val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )
        
        # Initialize model
        self.model = DualBiS6TAL(
            num_classes=len(self.dataset.classes),
            dim=config['dim'],
            recur_steps=config['recur_steps']
        ).to(self.device)
        self.model = DistributedDataParallel(self.model, device_ids=[self.device])
        
        # Optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=0.01
        )
        
        self.scheduler = ChainedScheduler([
            LinearLR(
                self.optimizer, 
                start_factor=0.1,
                total_iters=config['warmup_steps']
            ),
            CosineAnnealingLR(
                self.optimizer,
                T_max=config['epochs']-config['warmup_steps']
            )
        ])
        
        # Loss functions
        self.cls_criterion = nn.BCEWithLogitsLoss()
        self.reg_criterion = nn.SmoothL1Loss()

    def train_epoch(self, epoch: int):
        self.model.train()
        total_loss = 0.0
        
        for batch in tqdm(self.train_loader, desc=f"Epoch {epoch+1}"):
            frames = batch['frames'].to(self.device)
            action_map = batch['action_map'].to(self.device)
            reg_targets = batch['reg_targets'].to(self.device)
            
            # Forward pass
            outputs = self.model(frames)
            cls_loss = self.cls_criterion(outputs['cls_logits'], action_map)
            reg_loss = self.reg_criterion(outputs['reg_outputs'], reg_targets)
            loss = cls_loss + 0.5*reg_loss
            
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            
            total_loss += loss.item()
            
        return total_loss / len(self.train_loader)

    def validate(self):
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_targets = []
        
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validating"):
                frames = batch['frames'].to(self.device)
                action_map = batch['action_map'].to(self.device)
                reg_targets = batch['reg_targets'].to(self.device)
                
                outputs = self.model(frames)
                
                # Calculate loss
                cls_loss = self.cls_criterion(outputs['cls_logits'], action_map)
                reg_loss = self.reg_criterion(outputs['reg_outputs'], reg_targets)
                loss = cls_loss + 0.5*reg_loss
                total_loss += loss.item()
                
                # Collect predictions
                all_preds.append({
                    'cls': outputs['cls_logits'].sigmoid().cpu(),
                    'reg': outputs['reg_outputs'].cpu(),
                    'video_ids': batch['video_id']
                })
                all_targets.append({
                    'cls': action_map.cpu(),
                    'reg': reg_targets.cpu()
                })
        
        # Calculate metrics
        metrics = self.calculate_metrics(all_preds, all_targets)
        metrics['val_loss'] = total_loss / len(self.val_loader)
        return metrics

    def calculate_metrics(self, preds: List, targets: List) -> Dict:
        """Calculate temporal action localization metrics"""
        # Implement proper TAL metrics here
        cls_preds = torch.cat([p['cls'] for p in preds])
        cls_targets = torch.cat([t['cls'] for t in targets])
        
        # Simple classification metrics
        cls_ap = []
        for c in range(cls_preds.shape[1]):
            ap = average_precision_score(
                cls_targets[:, c].numpy(),
                cls_preds[:, c].numpy()
            )
            cls_ap.append(ap)
        
        return {
            'mAP': np.mean(cls_ap),
            'mAP@0.5': self.calculate_temporal_map(preds, targets, iou_thresh=0.5)
        }

    def calculate_temporal_map(predictions: List[Dict], targets: List[Dict], iou_thresholds: List[float] = None) -> float:
        """Calculate temporal mean Average Precision
        
        Args:
            predictions: List of prediction dicts containing 'cls', 'reg', 'confidence' 
            targets: List of target dicts containing 'cls', 'reg'
            iou_thresholds: List of IoU thresholds for evaluation
        
        Returns:
            mean Average Precision averaged across IoU thresholds
        """
        if iou_thresholds is None:
            iou_thresholds = np.linspace(0.5, 0.95, 10)
        
        def compute_temporal_iou(pred_seg, target_seg):
            """Compute IoU between predicted and target segments"""
            intersection = max(0, min(pred_seg[1], target_seg[1]) - max(pred_seg[0], target_seg[0]))
            union = max(pred_seg[1], target_seg[1]) - min(pred_seg[0], target_seg[0])
            return intersection / (union + 1e-8)
        
        # Convert predictions to segments
        all_segments = defaultdict(list)
        for pred_dict in predictions:
            cls_scores = pred_dict['cls'].sigmoid()
            reg_outputs = pred_dict['reg']
            conf_scores = pred_dict['confidence']
            
            for t in range(cls_scores.shape[1]):
                for c in range(cls_scores.shape[2]):
                    if cls_scores[0, t, c] > 0.05:  # Detection threshold
                        center = t + reg_outputs[0, t, 0]
                        length = reg_outputs[0, t, 1] * 100  # Scale factor
                        start = max(0, center - length/2)
                        end = min(cls_scores.shape[1], center + length/2)
                        
                        score = cls_scores[0, t, c] * conf_scores[0, t, 0]
                        all_segments[c].append({
                            'segment': [start, end],
                            'score': score.item()
                        })
        
        # Convert targets to segments
        target_segments = defaultdict(list)
        for target_dict in targets:
            cls_targets = target_dict['cls']
            for t in range(cls_targets.shape[1]):
                for c in range(cls_targets.shape[2]):
                    if cls_targets[0, t, c] > 0.5:
                        target_segments[c].append([t, t+1])  # Unit segments
        
        # Calculate AP for each class and IoU threshold
        ap_scores = []
        for iou_thresh in iou_thresholds:
            class_aps = []
            
            for c in range(len(all_segments)):
                # Sort predictions by score
                predictions = sorted(all_segments[c], key=lambda x: x['score'], reverse=True)
                
                # Initialize precision calculation
                num_positive = len(target_segments[c])
                if num_positive == 0:
                    continue
                    
                tp = np.zeros(len(predictions))
                fp = np.zeros(len(predictions))
                
                # Match predictions to targets
                used_targets = set()
                
                for i, pred in enumerate(predictions):
                    max_iou = 0
                    max_idx = -1
                    
                    for j, target in enumerate(target_segments[c]):
                        if j in used_targets:
                            continue
                            
                        iou = compute_temporal_iou(pred['segment'], target)
                        if iou > max_iou:
                            max_iou = iou
                            max_idx = j
                    
                    if max_iou >= iou_thresh:
                        tp[i] = 1
                        used_targets.add(max_idx)
                    else:
                        fp[i] = 1
                
                # Calculate precision and recall
                tp_cumsum = np.cumsum(tp)
                fp_cumsum = np.cumsum(fp)
                recalls = tp_cumsum / num_positive
                precisions = tp_cumsum / (tp_cumsum + fp_cumsum)
                
                # Calculate AP using 11-point interpolation
                ap = 0
                for t in np.linspace(0, 1, 11):
                    if np.sum(recalls >= t) == 0:
                        p = 0
                    else:
                        p = np.max(precisions[recalls >= t])
                    ap += p / 11
                
                class_aps.append(ap)
            
            if class_aps:
                ap_scores.append(np.mean(class_aps))
        
        return np.mean(ap_scores)
    def save_checkpoint(self, epoch: int, is_best: bool = False):
        checkpoint = {
            'epoch': epoch,
            'model_state': self.model.module.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'classes': self.dataset.classes
        }
        
        if is_best:
            torch.save(checkpoint, os.path.join(self.config['save_dir'], 'best_model.pth'))
        else:
            torch.save(checkpoint, os.path.join(self.config['save_dir'], f'checkpoint_{epoch}.pth'))

    def train(self):
        best_map = 0.0
        for epoch in range(self.config['epochs']):
            train_loss = self.train_epoch(epoch)
            val_metrics = self.validate()
            self.scheduler.step()
            
            # Log metrics
            logging.info(
                f"Epoch {epoch+1}/{self.config['epochs']} | "
                f"Train Loss: {train_loss:.4f} | "
                f"Val Loss: {val_metrics['val_loss']:.4f} | "
                f"mAP: {val_metrics['mAP']:.4f} | "
                f"mAP@0.5: {val_metrics['mAP@0.5']:.4f}"
            )
            
            # Save checkpoints
            if (epoch+1) % self.config['save_interval'] == 0:
                self.save_checkpoint(epoch)
                
            if val_metrics['mAP@0.5'] > best_map:
                best_map = val_metrics['mAP@0.5']
                self.save_checkpoint(epoch, is_best=True)

In [10]:
# Modified training configuration without distributed setup
def train_single_gpu(config):
    # Set up logging
    setup_logging()
    
    # Initialize dataset
    dataset = TALDataset(
        json_path=config['json_path'],
        video_root=config['video_root'],
        clip_len=config['clip_len'],
        frame_size=config['frame_size']
    )

    # Save class labels properly using JSON
    with open("class_labels.txt", 'w') as f:
        json.dump(dataset.class_to_idx, f) 


    
    # Split dataset
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_set,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_set,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    
    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RefinedDualBiS6TAL(
        num_classes=len(dataset.classes),
        dim=config['dim'],
        recur_steps=config['recur_steps']
    ).to(device)
    
    # Optimizer and scheduler setup
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=0.01
    )
    
    scheduler = ChainedScheduler([
        LinearLR(
            optimizer, 
            start_factor=0.1,
            total_iters=config['warmup_steps']
        ),
        CosineAnnealingLR(
            optimizer,
            T_max=config['epochs']-config['warmup_steps']
        )
    ])
    
    # Training loop
    best_map = 0.0
    for epoch in range(config['epochs']):
        # Train
        model.train()
        train_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            frames = batch['frames'].to(device)
            action_map = batch['action_map'].to(device)
            reg_targets = batch['reg_targets'].to(device)
            
            outputs = model(frames)
            cls_loss = nn.BCEWithLogitsLoss()(outputs['cls_logits'].permute(0, 2, 1), action_map)
            reg_loss = nn.SmoothL1Loss()(outputs['reg_outputs'].permute(0, 2, 1), reg_targets)
            loss = cls_loss + 0.5 * reg_loss
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validate
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating"):
                frames = batch['frames'].to(device)
                action_map = batch['action_map'].to(device)
                reg_targets = batch['reg_targets'].to(device)
                
                outputs = model(frames)
                cls_loss = nn.BCEWithLogitsLoss()(outputs['cls_logits'].permute(0, 2, 1), action_map)
                reg_loss = nn.SmoothL1Loss()(outputs['reg_outputs'].permute(0, 2, 1), reg_targets)
                loss = cls_loss + 0.5 * reg_loss
                
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        scheduler.step()
        
        # Log metrics
        logging.info(
            f"Epoch {epoch+1}/{config['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f}"
        )
        
        # Save checkpoint
        if (epoch + 1) % config['save_interval'] == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss
            }, f"{config['save_dir']}/checkpoint_epoch_{epoch+1}.pth")

# Configuration for Kaggle T4
config = {
        'json_path': '/kaggle/working/validation_data_cleaned.json',
        'video_root': '/kaggle/input/activitynet-100-validation/kinetics_dataset/videos',
        'clip_len': 8,
        'frame_size': 128,
        'batch_size': 8,  # Reduced batch size
        'epochs': 20,
        'lr': 1e-4,
        'warmup_steps': 5,
        'dim': 64,  # Reduced dimension
        'recur_steps': 2,  # Reduced recurrent steps
        'save_interval': 5,
        'save_dir': '/kaggle/working/checkpoints'
    }


# Create save directory
os.makedirs(config['save_dir'], exist_ok=True)

# Start training
train_single_gpu(config)

2025-02-17 03:39:37 - INFO - Initializing TALDataset with json_path: /kaggle/working/validation_data_cleaned.json
2025-02-17 03:39:37 - INFO - Video root directory: /kaggle/input/activitynet-100-validation/kinetics_dataset/videos
2025-02-17 03:39:37 - INFO - Successfully loaded JSON data with 170 total videos
2025-02-17 03:39:37 - INFO - Found 10 unique action classes
2025-02-17 03:39:38 - INFO - Found 170 videos in validation subset


Epoch 1:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 03:44:10 - INFO - Epoch 1/20 | Train Loss: 0.7216 | Val Loss: 0.7051


Epoch 2:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 03:48:38 - INFO - Epoch 2/20 | Train Loss: 0.6909 | Val Loss: 0.6576


Epoch 3:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 03:53:10 - INFO - Epoch 3/20 | Train Loss: 0.6153 | Val Loss: 0.5644


Epoch 4:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 03:57:34 - INFO - Epoch 4/20 | Train Loss: 0.5107 | Val Loss: 0.4573


Epoch 5:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:02:06 - INFO - Epoch 5/20 | Train Loss: 0.4158 | Val Loss: 0.3746


Epoch 6:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:06:43 - INFO - Epoch 6/20 | Train Loss: 0.3398 | Val Loss: 0.3174


Epoch 7:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:11:09 - INFO - Epoch 7/20 | Train Loss: 0.3021 | Val Loss: 0.2873


Epoch 8:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:15:32 - INFO - Epoch 8/20 | Train Loss: 0.2792 | Val Loss: 0.2722


Epoch 9:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:20:03 - INFO - Epoch 9/20 | Train Loss: 0.2725 | Val Loss: 0.2649


Epoch 10:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:24:29 - INFO - Epoch 10/20 | Train Loss: 0.2664 | Val Loss: 0.2597


Epoch 11:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:28:58 - INFO - Epoch 11/20 | Train Loss: 0.2602 | Val Loss: 0.2577


Epoch 12:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:33:27 - INFO - Epoch 12/20 | Train Loss: 0.2644 | Val Loss: 0.2568


Epoch 13:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:37:57 - INFO - Epoch 13/20 | Train Loss: 0.2577 | Val Loss: 0.2550


Epoch 14:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:42:25 - INFO - Epoch 14/20 | Train Loss: 0.2553 | Val Loss: 0.2548


Epoch 15:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:46:55 - INFO - Epoch 15/20 | Train Loss: 0.2624 | Val Loss: 0.2547


Epoch 16:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:51:19 - INFO - Epoch 16/20 | Train Loss: 0.2578 | Val Loss: 0.2546


Epoch 17:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 04:55:46 - INFO - Epoch 17/20 | Train Loss: 0.2591 | Val Loss: 0.2546


Epoch 18:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 05:00:15 - INFO - Epoch 18/20 | Train Loss: 0.2630 | Val Loss: 0.2545


Epoch 19:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 05:04:47 - INFO - Epoch 19/20 | Train Loss: 0.2625 | Val Loss: 0.2538


Epoch 20:   0%|          | 0/20 [00:00<?, ?it/s]

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

2025-02-17 05:09:18 - INFO - Epoch 20/20 | Train Loss: 0.2581 | Val Loss: 0.2524
