In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy.interpolate import interp1d
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

def set_seed(seed):
    """Seed every RNG used by this notebook and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed forwarded to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Root directory of the competition data; sample paths from train.csv are
# joined onto this prefix when parquet files are read.
BASE_PATH = "/kaggle/input/asl-signs"
# Local directory name for preprocessed artifacts (not referenced in this cell).
PREPROCESSED_PATH = "./preprocessed_data"

# Load metadata
# train.csv supplies the 'path' and 'sign' columns used below.
train_csv = pd.read_csv(f"{BASE_PATH}/train.csv")
# JSON mapping used to translate a sign name into its class index.
with open(f"{BASE_PATH}/sign_to_prediction_index_map.json", "r") as f:
    label_map = json.load(f)

NUM_CLASSES = len(label_map)
print(f"Classes: {NUM_CLASSES}")

# Number of parallel joblib jobs used when scanning samples.
N_JOBS = 4

In [None]:

def has_valid_hands_fast(sample_path, min_frames=5):
    """Tell whether a sample has hand rows in at least ``min_frames`` frames.

    Only the ``frame`` and ``type`` columns are read, so this checks that
    hand rows exist; it does not look at the coordinate values.

    Args:
        sample_path: Parquet path relative to ``BASE_PATH``.
        min_frames: Minimum number of distinct frames with a hand row.

    Returns:
        ``True`` if the sample qualifies, ``False`` otherwise.
    """
    parquet_file = os.path.join(BASE_PATH, sample_path)
    table = pd.read_parquet(parquet_file, columns=["frame", "type"], engine="pyarrow")

    is_hand_row = table["type"].isin(["left_hand", "right_hand"])
    hand_frames = table.loc[is_hand_row, "frame"]
    return not hand_frames.empty and hand_frames.nunique() >= min_frames

def check_index(i):
    """Return ``i`` when row ``i`` of ``train_csv`` passes the hand check, else ``None``."""
    sample_path = train_csv.iloc[i]["path"]
    if has_valid_hands_fast(sample_path):
        return i
    return None

# Try loading pre-computed indices, otherwise compute them.
# `except Exception` (instead of a bare `except:`) keeps the best-effort
# fallback for any load failure while letting KeyboardInterrupt / SystemExit
# propagate, so the notebook can still be interrupted during the load.
try:
    valid_indices = np.load("/kaggle/input/valid-indices-npy/valid_indices.npy").tolist()
    print(f"Loaded {len(valid_indices)} valid indices from cache")
except Exception:
    print("Computing valid indices (this may take a while)...")
    # Threads are used because each task is dominated by parquet file reads.
    results = Parallel(n_jobs=N_JOBS, prefer="threads")(
        delayed(check_index)(i) for i in tqdm(range(len(train_csv)))
    )
    # check_index returns None for rejected rows.
    valid_indices = [i for i in results if i is not None]
    os.makedirs("/kaggle/working", exist_ok=True)
    np.save("/kaggle/working/valid_indices.npy", np.array(valid_indices))
    print(f"Valid samples: {len(valid_indices)} / {len(train_csv)}")

# Filter CSV to valid samples
train_csv = train_csv.iloc[valid_indices].reset_index(drop=True)
print(f"Filtered train_csv length: {len(train_csv)}")


class LandmarkPreprocessor:
    """Stateless clean-up steps for landmark arrays of shape ``(T, L, C)``.

    ``T`` is the number of frames, ``L`` the number of landmarks and ``C`` the
    number of coordinates. A coordinate equal to exactly 0 is treated as
    "missing" by every method in this class.
    """

    @staticmethod
    def filter_empty_frames(kp, threshold=0.5):
        """Drop frames in which too many coordinates are missing.

        A frame is kept when its fraction of zero-valued coordinates is at
        most ``threshold``. The comparison is inclusive on purpose: with 42
        hand landmarks, a frame in which exactly one hand is present has a
        missing fraction of exactly 0.5 and must survive the default
        threshold (a strict ``<`` discarded every such frame).

        Args:
            kp: Array of shape ``(T, L, C)``.
            threshold: Largest tolerated fraction of zero coordinates per frame.

        Returns:
            The retained frames, or ``kp[:1]`` when no frame qualifies.
        """
        # Use the real coordinate count instead of assuming C == 3.
        values_per_frame = kp.shape[1] * kp.shape[2]
        missing_per_frame = np.sum(kp == 0, axis=(1, 2)) / values_per_frame
        valid_mask = missing_per_frame <= threshold
        if not valid_mask.any():
            return kp[:1]
        return kp[valid_mask]

    @staticmethod
    def interpolate_missing(kp):
        """Fill zero-valued entries of each (landmark, coordinate) track over time.

        Tracks with no non-zero value are left untouched, tracks with a single
        non-zero value are filled with that constant, and all other tracks are
        linearly interpolated (and linearly extrapolated at both ends).

        Args:
            kp: Array of shape ``(T, L, C)``.

        Returns:
            A new array of the same shape; ``kp`` itself is not modified.
        """
        T, L, C = kp.shape
        kp_interp = kp.copy()
        frame_axis = np.arange(T)

        for l in range(L):
            for c in range(C):
                track = kp[:, l, c]
                present = track != 0
                n_present = np.count_nonzero(present)

                if n_present == 0:
                    continue
                if n_present == 1:
                    kp_interp[:, l, c] = track[present][0]
                    continue

                f = interp1d(np.flatnonzero(present), track[present], kind='linear',
                             fill_value='extrapolate', bounds_error=False)
                kp_interp[:, l, c] = f(frame_axis)

        return kp_interp

    @staticmethod
    def anchor_normalize(kp):
        """Centre each frame on its landmark centroid and rescale the clip.

        For every frame the centroid of the landmarks that have at least one
        non-zero coordinate is used as the anchor; the per-frame scale is the
        mean distance of those landmarks to the anchor. All frames are then
        divided by the mean of the positive per-frame scales.

        NOTE(review): the anchor is subtracted from every landmark of the
        frame, including landmarks that were entirely zero.

        Args:
            kp: Array of shape ``(T, L, C)``.

        Returns:
            Tuple ``(kp_normalized, anchor_pos, mean_scale)`` where
            ``anchor_pos`` has shape ``(T, C)`` and ``mean_scale`` is a scalar.
        """
        T, L, C = kp.shape
        valid_mask = kp != 0

        anchor_pos = np.zeros((T, C))
        scale = np.zeros(T)

        for t in range(T):
            valid_points = kp[t][np.any(valid_mask[t], axis=1)]
            if len(valid_points) > 0:
                anchor_pos[t] = valid_points.mean(axis=0)
                # Scale based on spread of points around the anchor.
                distances = np.linalg.norm(valid_points - anchor_pos[t], axis=1)
                scale[t] = distances.mean()
            else:
                # No landmark in this frame: keep the zero anchor, unit scale.
                scale[t] = 1.0

        kp_relative = kp - anchor_pos[:, np.newaxis, :]

        # One scale for the whole clip; frames whose scale is 0 are ignored.
        positive = scale > 0
        mean_scale = scale[positive].mean() if positive.any() else 1.0
        if mean_scale > 1e-6:
            kp_normalized = kp_relative / mean_scale
        else:
            kp_normalized = kp_relative

        return kp_normalized, anchor_pos, mean_scale

class LandmarkAugmentor:
    """Random augmentations for landmark arrays of shape ``(T, L, 3)``.

    Every method draws from NumPy's global RNG, applies its transform with
    probability ``p`` and returns the (possibly new) array. None of the
    methods modifies the array passed in.
    """

    @staticmethod
    def horizontal_flip(kp, left_hand_indices, right_hand_indices, p=0.5):
        """Negate the x coordinate and exchange the two hand landmark blocks.

        Args:
            kp: Array of shape ``(T, L, 3)``.
            left_hand_indices: Landmark indices of the left hand.
            right_hand_indices: Landmark indices of the right hand.
            p: Probability of applying the flip.

        Returns:
            A flipped copy, or ``kp`` itself when the flip is not applied.
        """
        if np.random.rand() < p:
            kp = kp.copy()
            kp[:, :, 0] *= -1  # Flip x-coordinate

            # Swap the hand blocks only when both index lists are non-empty.
            if len(left_hand_indices) > 0 and len(right_hand_indices) > 0:
                left_hand = kp[:, left_hand_indices, :].copy()
                right_hand = kp[:, right_hand_indices, :].copy()
                kp[:, left_hand_indices, :] = right_hand
                kp[:, right_hand_indices, :] = left_hand
        return kp

    @staticmethod
    def rotation_3d(kp, max_angle=60, p=0.5):
        """Rotate all landmarks by a random angle about the z axis.

        Despite the name, the rotation matrix leaves z untouched, so this is
        an in-plane (x/y) rotation.

        Args:
            kp: Array of shape ``(T, L, 3)``.
            max_angle: Angle is drawn uniformly from ``[-max_angle, max_angle]`` degrees.
            p: Probability of applying the rotation.
        """
        if np.random.rand() < p:
            angle = np.random.uniform(-max_angle, max_angle)
            angle_rad = np.deg2rad(angle)

            cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
            R = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])

            T, L, _ = kp.shape
            kp_flat = kp.reshape(-1, 3)
            kp_rotated = (R @ kp_flat.T).T
            kp = kp_rotated.reshape(T, L, 3)
        return kp

    @staticmethod
    def resize_3d(kp, scale_range=(0.8, 1.2), p=0.5):
        """Multiply all coordinates by one random factor drawn from ``scale_range``."""
        if np.random.rand() < p:
            scale = np.random.uniform(*scale_range)
            kp = kp * scale
        return kp

    @staticmethod
    def finger_dropout(kp, hand_indices, p_per_finger=0.1, p=0.3):
        """Zero out randomly chosen landmarks for the whole clip.

        With probability ``p`` the augmentation is active; each index in
        ``hand_indices`` is then dropped independently with probability
        ``p_per_finger``. Note that the unit of dropout is a single landmark.

        The input array is copied before it is modified, so the caller's
        data stays intact (matching ``horizontal_flip``).

        Returns:
            A modified copy, or ``kp`` itself when the augmentation is inactive.
        """
        if np.random.rand() < p:
            kp = kp.copy()
            for idx in hand_indices:
                if np.random.rand() < p_per_finger:
                    kp[:, idx, :] = 0
        return kp






In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * CE``.

    Args:
        alpha: Optional 1-D tensor of per-class weights indexed by target id.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the focal loss.

        Args:
            logits: Tensor of shape ``(batch, num_classes)``.
            targets: Integer class ids of shape ``(batch,)``.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        # CE is -log(p_t), so p_t can be recovered directly; this avoids a
        # second softmax and a (batch, num_classes) one-hot tensor.
        p_t = torch.exp(-ce_loss)
        focal_loss = (1 - p_t) ** self.gamma * ce_loss

        if self.alpha is not None:
            # Move the class weights lazily to the device of the inputs.
            if self.alpha.device != logits.device:
                self.alpha = self.alpha.to(logits.device)
            focal_loss = self.alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


In [None]:
class ASLDataset(Dataset):
    """Dataset that yields fixed-length hand-landmark clips and class ids.

    Each item is a ``(max_frames, 42, 3)`` float tensor (21 left-hand
    landmarks followed by 21 right-hand landmarks) and the integer label of
    the sample's sign.

    Args:
        dataframe: Frame with at least the ``path`` and ``sign`` columns.
        label_map: Mapping from sign name to class index.
        max_frames: Number of frames every clip is padded or subsampled to.
        split: Name of the split; augmentation is only enabled for ``'train'``.
        use_anchor_norm: Apply ``LandmarkPreprocessor.anchor_normalize``.
        use_interpolation: Apply ``LandmarkPreprocessor.interpolate_missing``.
        use_augmentation: Request augmentation (effective only for training).
    """

    def __init__(self, dataframe, label_map, max_frames=64, split='train',
                 use_anchor_norm=True, use_interpolation=True, use_augmentation=True):
        self.df = dataframe.reset_index(drop=True)
        self.label_map = label_map
        self.max_frames = max_frames
        self.split = split
        self.use_anchor_norm = use_anchor_norm
        self.use_interpolation = use_interpolation
        self.use_augmentation = use_augmentation and (split == 'train')

        # Initialize helpers
        self.preprocessor = LandmarkPreprocessor()
        self.augmentor = LandmarkAugmentor()

        # Hand indices (42 landmarks total, 21 per hand); must match the
        # row order produced by load_hand_keypoints.
        self.left_hand_indices = list(range(0, 21))
        self.right_hand_indices = list(range(21, 42))

        # Class balancing
        self.class_counts = self._analyze_class_distribution()
        self.class_weights = self._compute_class_weights()

        print(f"ASL Dataset - {split.upper()}")

        print(f"Samples: {len(self.df)}")
        print(f"Classes: {len(label_map)}")
        print(f"Max frames: {max_frames}")
        print(f"Anchor normalization: {use_anchor_norm}")
        print(f"Interpolation: {use_interpolation}")
        print(f"Augmentation: {self.use_augmentation}")

    def _label_ids(self):
        """Return the class index of every row, in row order."""
        # Iterating the column avoids building a Series per row (iterrows).
        return [self.label_map[sign] for sign in self.df['sign']]

    def _analyze_class_distribution(self):
        """Count how many samples each class index has."""
        return Counter(self._label_ids())

    def _compute_class_weights(self):
        """Inverse-frequency class weights, rescaled to sum to ``num_classes``.

        Classes absent from the data are treated as if they had one sample.
        """
        num_classes = len(self.label_map)
        weights = np.zeros(num_classes)
        total_samples = sum(self.class_counts.values())

        for class_idx in range(num_classes):
            count = self.class_counts.get(class_idx, 1)
            weights[class_idx] = total_samples / (num_classes * count)

        weights = weights / weights.sum() * num_classes
        return torch.FloatTensor(weights)

    def get_sample_weights(self):
        """Per-sample weights (the weight of each row's class) for WeightedRandomSampler."""
        label_ids = torch.tensor(self._label_ids(), dtype=torch.long)
        return self.class_weights[label_ids]

    def load_hand_keypoints(self, sample_path):
        """Load hand landmarks (42 landmarks: 21 left + 21 right).

        Rows are sorted by frame, then type, then landmark index, so within a
        frame the ``left_hand`` rows come before the ``right_hand`` rows.
        NaN coordinates are replaced by 0.

        Returns:
            Array of shape ``(frames, 42, 3)``.
        """
        full_path = os.path.join(BASE_PATH, sample_path)
        # Only the columns used below are read from disk.
        df = pd.read_parquet(
            full_path,
            columns=["frame", "type", "landmark_index", "x", "y", "z"],
            engine="pyarrow",
        )

        hand_df = df[df["type"].isin(["left_hand", "right_hand"])]
        hand_df = hand_df.sort_values(["frame", "type", "landmark_index"])

        frames = hand_df["frame"].nunique()
        coords = hand_df[["x", "y", "z"]].values
        coords = coords.reshape(frames, 42, 3)

        return np.nan_to_num(coords, nan=0.0)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(clip, label)`` for row ``idx``."""
        row = self.df.iloc[idx]

        # Load keypoints
        kp = self.load_hand_keypoints(row["path"])  # (T, 42, 3)

        # Preprocessing
        kp = self.preprocessor.filter_empty_frames(kp, threshold=0.5)

        if self.use_interpolation:
            kp = self.preprocessor.interpolate_missing(kp)

        if self.use_anchor_norm:
            kp, _, _ = self.preprocessor.anchor_normalize(kp)

        # Temporal resampling: zero-pad short clips at the end, evenly
        # subsample long ones.
        n_frames = kp.shape[0]
        if n_frames < self.max_frames:
            pad_len = self.max_frames - n_frames
            padding = np.zeros((pad_len,) + kp.shape[1:])
            kp = np.concatenate([kp, padding], axis=0)
        elif n_frames > self.max_frames:
            indices = np.linspace(0, n_frames - 1, self.max_frames).astype(int)
            kp = kp[indices]

        kp = np.nan_to_num(kp, nan=0.0)

        # Augmentation (training split only)
        if self.use_augmentation:
            kp = self.augmentor.horizontal_flip(
                kp, self.left_hand_indices, self.right_hand_indices, p=0.5)
            kp = self.augmentor.rotation_3d(kp, max_angle=60, p=0.5)
            kp = self.augmentor.resize_3d(kp, scale_range=(0.8, 1.2), p=0.5)

            all_hand_indices = self.left_hand_indices + self.right_hand_indices
            kp = self.augmentor.finger_dropout(kp, all_hand_indices, p=0.3)

        label = self.label_map[row["sign"]]

        return torch.FloatTensor(kp), label


In [None]:
def create_dataloaders(train_df, val_df, label_map, batch_size=32, 
                       max_frames=64, num_workers=4):
    """Build the train/val loaders and the class-weighted focal loss.

    The training loader draws samples with a ``WeightedRandomSampler`` (with
    replacement) using the per-class weights of the training set and drops
    the last incomplete batch; the validation loader iterates in order.

    Args:
        train_df: Training rows (``path`` and ``sign`` columns).
        val_df: Validation rows.
        label_map: Mapping from sign name to class index.
        batch_size: Batch size for both loaders.
        max_frames: Clip length forwarded to ``ASLDataset``.
        num_workers: Worker processes per loader; ``0`` loads in the main process.

    Returns:
        Tuple ``(train_loader, val_loader, focal_loss, class_weights)``.
    """
    train_ds = ASLDataset(
        train_df, label_map, max_frames=max_frames, split='train',
        use_anchor_norm=True, use_interpolation=True, use_augmentation=True
    )

    val_ds = ASLDataset(
        val_df, label_map, max_frames=max_frames, split='val',
        use_anchor_norm=True, use_interpolation=True, use_augmentation=False
    )

    # Weighted sampler for balanced training
    sample_weights = train_ds.get_sample_weights()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    loader_kwargs = dict(
        batch_size=batch_size, num_workers=num_workers, pin_memory=True
    )
    # DataLoader rejects prefetch_factor / persistent_workers when
    # num_workers == 0, so only pass them for multi-process loading.
    if num_workers > 0:
        loader_kwargs.update(prefetch_factor=2, persistent_workers=True)

    train_loader = DataLoader(
        train_ds, sampler=sampler, drop_last=True, **loader_kwargs
    )

    val_loader = DataLoader(
        val_ds, shuffle=False, **loader_kwargs
    )

    focal_loss = FocalLoss(alpha=train_ds.class_weights, gamma=2.0)

    return train_loader, val_loader, focal_loss, train_ds.class_weights


In [None]:
if __name__ == "__main__":
    set_seed(42)

    # Positional 80/20 split of the filtered metadata (no shuffling).
    n_train = int(0.8 * len(train_csv))
    train_df = train_csv.iloc[:n_train]
    val_df = train_csv.iloc[n_train:]

    print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")

    train_loader, val_loader, focal_loss, class_weights = create_dataloaders(
        train_df,
        val_df,
        label_map,
        batch_size=32,
        max_frames=64,
        num_workers=4,
    )

    # Smoke test: pull one batch and report its shapes.
    for data, labels in train_loader:
        print(f"Batch shape: {data.shape}")  # Should be [32, 64, 42, 3]
        print(f"Labels shape: {labels.shape}")
        break

In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
import math


In [None]:
def stratified_sample_dataset(train_csv, sample_ratio=0.75, random_state=42):
    """
    Sample dataset while preserving class distribution

    Every class keeps ``int(class_size * sample_ratio)`` rows, but never fewer
    than one and never more than the class contains. The result is shuffled
    and re-indexed from 0.

    Args:
        train_csv: DataFrame with 'path' and 'sign' columns
        sample_ratio: Fraction of data to keep (0.5 = 50%)
        random_state: Random seed for reproducibility

    Returns:
        sampled_df: Stratified sample of the dataset
    """
    print(f"Stratified Sampling: Keeping {sample_ratio*100:.0f}% of data")

    # An empty frame has nothing to sample; returning early avoids
    # pd.concat([]) and the division by len(train_csv) below.
    if len(train_csv) == 0:
        print(f"Original dataset: {len(train_csv)} samples")
        print("Sampled dataset: 0 samples")
        return train_csv.reset_index(drop=True)

    # Group by class and sample proportionally
    sampled_dfs = []

    for sign_class in train_csv['sign'].unique():
        class_df = train_csv[train_csv['sign'] == sign_class]
        # At least 1 sample per class, at most the whole class (so a
        # sample_ratio above 1 keeps everything instead of raising).
        n_samples = min(len(class_df), max(1, int(len(class_df) * sample_ratio)))

        sampled_class = class_df.sample(n=n_samples, random_state=random_state)
        sampled_dfs.append(sampled_class)

    sampled_df = pd.concat(sampled_dfs, ignore_index=True)
    # Shuffle so that classes are not stored in contiguous runs.
    sampled_df = sampled_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    print(f"Original dataset: {len(train_csv)} samples")
    print(f"Sampled dataset: {len(sampled_df)} samples")
    print(f"Reduction: {(1 - len(sampled_df)/len(train_csv))*100:.1f}%")

    orig_dist = train_csv['sign'].value_counts(normalize=True).sort_index()
    samp_dist = sampled_df['sign'].value_counts(normalize=True).sort_index()

    max_deviation = (orig_dist - samp_dist).abs().max()
    print(f"Max class distribution deviation: {max_deviation*100:.2f}%")
    print(f"{'='*60}\n")

    return sampled_df

In [None]:
class DifficultyAnalyzer:
    """
    Multi-dimensional difficulty scoring for intelligent cascading
    Combines: prediction uncertainty, spatial complexity, temporal complexity, motion patterns

    The three geometric scores are computed per sample on the frames that
    contain at least one non-zero value; all-zero frames are ignored.
    """

    @staticmethod
    def _valid_frames(sample):
        """Return the frames of ``sample`` (T, L, C) that are not entirely zero."""
        frame_energy = sample.abs().sum(dim=(1, 2))
        return sample[frame_energy > 0]

    @staticmethod
    def compute_prediction_uncertainty(logits):
        """Entropy-based uncertainty from model predictions

        Args: logits shape (batch, num_classes)
        Returns: softmax entropy divided by log(num_classes); 0 for every
            sample when there are fewer than two classes, because the
            normaliser log(1) would be zero.
        """
        probs = F.softmax(logits, dim=1)
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=1)
        num_classes = logits.shape[1]
        if num_classes < 2:
            return torch.zeros_like(entropy)
        return entropy / math.log(num_classes)  # [0, 1]

    @staticmethod
    def compute_spatial_complexity(x):
        """
        Measure spatial spread and hand coordination complexity
        Args: x shape (batch, max_len, 42, 3)
        Returns: complexity score [0, 1]
        """
        complexity_scores = []

        for sample in x:  # (max_len, 42, 3)
            valid_frames = DifficultyAnalyzer._valid_frames(sample)  # (T, 42, 3)
            if valid_frames.shape[0] == 0:
                complexity_scores.append(0.0)
                continue

            # 1. Hand spread (std of landmark positions within a frame)
            spread = valid_frames.std(dim=1).mean().item()

            # 2. Two-hand coordination (distance between the hand centroids)
            left_hand = valid_frames[:, :21, :]   # First 21 landmarks
            right_hand = valid_frames[:, 21:, :]  # Remaining landmarks

            left_center = left_hand.mean(dim=1)
            right_center = right_hand.mean(dim=1)
            hand_distance = torch.norm(left_center - right_center, dim=1).mean().item()

            # 3. Finger articulation (std within each hand)
            left_articulation = left_hand.std(dim=1).mean().item()
            right_articulation = right_hand.std(dim=1).mean().item()
            articulation = (left_articulation + right_articulation) / 2

            # Weighted sum, then divide by 2.0 and cap at 1.
            spatial_complexity = (spread * 0.3 + hand_distance * 0.4 + articulation * 0.3)
            spatial_complexity = min(spatial_complexity / 2.0, 1.0)

            complexity_scores.append(spatial_complexity)

        return torch.tensor(complexity_scores, device=x.device)

    @staticmethod
    def compute_temporal_complexity(x):
        """
        Measure temporal variation and motion smoothness
        Args: x shape (batch, max_len, 42, 3)
        Returns: complexity score [0, 1]
        """
        complexity_scores = []

        for sample in x:  # (max_len, 42, 3)
            valid_frames = DifficultyAnalyzer._valid_frames(sample)  # (T, 42, 3)
            T = valid_frames.shape[0]
            # At least two frames are needed for a velocity.
            if T <= 1:
                complexity_scores.append(0.0)
                continue

            # 1. Frame-to-frame velocity (first difference)
            velocity = torch.diff(valid_frames, dim=0)  # (T-1, 42, 3)
            avg_velocity = velocity.abs().mean().item()

            # 2. Acceleration (second difference) and
            # 4. direction changes: 1 - cosine similarity of consecutive
            #    flattened velocity vectors. Both need three frames.
            if T > 2:
                acceleration = torch.diff(velocity, dim=0)  # (T-2, 42, 3)
                avg_acceleration = acceleration.abs().mean().item()

                velocity_norm = F.normalize(velocity.reshape(T-1, -1), dim=1)
                direction_changes = (1 - (velocity_norm[:-1] * velocity_norm[1:]).sum(dim=1)).mean().item()
            else:
                avg_acceleration = 0.0
                direction_changes = 0.0

            # 3. Temporal variance (std of each coordinate over time)
            temporal_variance = valid_frames.std(dim=0).mean().item()

            # Combine metrics (higher = more complex temporal pattern)
            temporal_complexity = (
                avg_velocity * 0.3 + 
                avg_acceleration * 0.3 + 
                temporal_variance * 0.2 + 
                direction_changes * 0.2
            )
            # Divide by 1.5 and cap at 1.
            temporal_complexity = min(temporal_complexity / 1.5, 1.0)

            complexity_scores.append(temporal_complexity)

        return torch.tensor(complexity_scores, device=x.device)

    @staticmethod
    def compute_motion_pattern_complexity(x):
        """
        Analyze motion patterns (circular, linear, static, etc.)
        Args: x shape (batch, max_len, 42, 3)
        Returns: complexity score [0, 1]
        """
        complexity_scores = []

        for sample in x:
            valid_frames = DifficultyAnalyzer._valid_frames(sample)
            if valid_frames.shape[0] <= 2:
                complexity_scores.append(0.0)
                continue

            # Trajectory of the centroid of all landmarks
            hand_centroid = valid_frames.mean(dim=1)  # (T, 3)

            # 1. Path length vs straight-line distance (tortuosity)
            path_length = torch.norm(torch.diff(hand_centroid, dim=0), dim=1).sum().item()
            straight_distance = torch.norm(hand_centroid[-1] - hand_centroid[0]).item()

            if straight_distance > 1e-6:
                tortuosity = path_length / straight_distance
            else:
                # Start and end coincide: fall back to a neutral ratio.
                tortuosity = 1.0

            # 2. 3D motion (z-axis usage)
            z_variance = hand_centroid[:, 2].std().item()

            # 3. In-plane motion (mean std over the x and y axes)
            xy_variance = hand_centroid[:, :2].std(dim=0).mean().item()

            # Each term is capped at 1 before weighting.
            motion_complexity = (
                min(tortuosity / 3.0, 1.0) * 0.5 +
                min(z_variance / 0.3, 1.0) * 0.3 +
                min(xy_variance / 0.5, 1.0) * 0.2
            )

            complexity_scores.append(motion_complexity)

        return torch.tensor(complexity_scores, device=x.device)

    @classmethod
    def compute_difficulty_score(cls, x, logits, weights=None):
        """
        Compute comprehensive difficulty score combining all factors

        Args:
            x: input data (batch, max_len, 42, 3)
            logits: model predictions (batch, num_classes)
            weights: dict of weights for each component; must contain the keys
                'prediction_uncertainty', 'spatial_complexity',
                'temporal_complexity' and 'motion_complexity'.
                Defaults to 0.35 / 0.25 / 0.25 / 0.15.

        Returns:
            difficulty_score: [0, 1] where higher = more difficult
            components: dict with individual scores for analysis
        """
        if weights is None:
            weights = {
                'prediction_uncertainty': 0.35,
                'spatial_complexity': 0.25,
                'temporal_complexity': 0.25,
                'motion_complexity': 0.15
            }

        # Compute all components
        pred_uncertainty = cls.compute_prediction_uncertainty(logits)
        spatial_comp = cls.compute_spatial_complexity(x)
        temporal_comp = cls.compute_temporal_complexity(x)
        motion_comp = cls.compute_motion_pattern_complexity(x)

        # Weighted combination
        difficulty = (
            weights['prediction_uncertainty'] * pred_uncertainty +
            weights['spatial_complexity'] * spatial_comp +
            weights['temporal_complexity'] * temporal_comp +
            weights['motion_complexity'] * motion_comp
        )

        components = {
            'prediction_uncertainty': pred_uncertainty,
            'spatial_complexity': spatial_comp,
            'temporal_complexity': temporal_comp,
            'motion_complexity': motion_comp,
            'overall_difficulty': difficulty
        }

        return difficulty, components


In [None]:

class Conv1DBlock(nn.Module):
    """Residual block: depthwise conv -> pointwise conv -> BatchNorm -> GELU -> dropout.

    Operates on tensors laid out as ``(batch, dim, length)``. The depthwise
    convolution is padded with ``kernel_size // 2``.
    """

    def __init__(self, dim, kernel_size=17, drop_rate=0.2):
        super().__init__()
        # groups=dim makes the first convolution depthwise (one filter per channel).
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size, padding=kernel_size // 2, groups=dim
        )
        # 1x1 convolution mixes information across channels.
        self.pwconv = nn.Conv1d(dim, dim, 1)
        self.bn = nn.BatchNorm1d(dim, momentum=0.05)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        """Apply the block and add the input back as a skip connection."""
        out = self.pwconv(self.dwconv(x))
        out = self.dropout(F.gelu(self.bn(out)))
        return out + x


class LightweightTransformerBlock(nn.Module):
    """Pre-norm self-attention + feed-forward block with residual connections.

    Input and output are laid out as ``(batch, dim, length)``; the tensor is
    transposed to ``(batch, length, dim)`` internally for attention.
    """

    def __init__(self, dim, num_heads=4, expand=2, dropout=0.1):
        super().__init__()
        hidden = dim * expand
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)

        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Run attention and the feed-forward network over the length axis."""
        tokens = x.transpose(1, 2)
        query = self.norm1(tokens)
        # Index 0 of the attention result is the output (index 1: weights).
        tokens = tokens + self.attn(query, query, query)[0]
        tokens = tokens + self.ffn(self.norm2(tokens))
        return tokens.transpose(1, 2)


class LateDropout(nn.Module):
    """Dropout that stays inactive until ``step()`` has been called ``start_step`` times.

    The step counter is a plain Python attribute advanced by ``step()``.
    """

    def __init__(self, p=0.8, start_step=0):
        super().__init__()
        self.p = p
        self.start_step = start_step
        self.current_step = 0

    def forward(self, x):
        """Apply dropout in training mode once the start step is reached; otherwise pass through."""
        active = self.training and self.current_step >= self.start_step
        return F.dropout(x, p=self.p, training=True) if active else x

    def step(self):
        """Advance the internal step counter by one."""
        self.current_step += 1




In [None]:
class ASLFilterModel(nn.Module):
    """
    Lightweight first-pass filter with multi-dimensional difficulty assessment

    Architecture: linear stem + BatchNorm, two stages of three ``Conv1DBlock``
    plus one ``LightweightTransformerBlock``, a linear expansion to ``2 * dim``,
    global average pooling over time, late dropout and a linear classifier.

    Args:
        num_classes: Size of the classifier output.
        max_len: Accepted for interface compatibility; not used by the layers.
        channels: Flattened feature size per frame (42 landmarks * 3 = 126).
        dim: Width of the convolution / transformer stages.
        dropout_step: Step from which the late dropout becomes active.
    """
    def __init__(self, num_classes, max_len=64, channels=126, dim=96, dropout_step=0):
        super().__init__()
        self.channels = channels
        self.pad_value = 0.0

        # Stem
        self.stem_conv = nn.Linear(channels, dim, bias=False)
        self.stem_bn = nn.BatchNorm1d(dim, momentum=0.05)

        # First block: 3 Conv1D + 1 Transformer
        self.conv_block1 = nn.Sequential(
            Conv1DBlock(dim, kernel_size=17, drop_rate=0.2),
            Conv1DBlock(dim, kernel_size=17, drop_rate=0.2),
            Conv1DBlock(dim, kernel_size=17, drop_rate=0.2),
        )
        self.transformer1 = LightweightTransformerBlock(dim, num_heads=4, expand=2)

        # Second block: 3 Conv1D + 1 Transformer
        self.conv_block2 = nn.Sequential(
            Conv1DBlock(dim, kernel_size=17, drop_rate=0.2),
            Conv1DBlock(dim, kernel_size=17, drop_rate=0.2),
            Conv1DBlock(dim, kernel_size=17, drop_rate=0.2),
        )
        self.transformer2 = LightweightTransformerBlock(dim, num_heads=4, expand=2)

        # Top conv and classification
        self.top_conv = nn.Linear(dim, dim * 2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.late_dropout = LateDropout(0.8, start_step=dropout_step)
        self.classifier = nn.Linear(dim * 2, num_classes)

        # Difficulty analyzer
        self.difficulty_analyzer = DifficultyAnalyzer()

        # Explicit initialization of all linear / conv / norm layers
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize Linear (truncated normal), Conv1d (Kaiming) and norm layers (1 / 0)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, return_difficulty=False, difficulty_weights=None):
        """
        Args:
            x: (batch, max_len, 42, 3) - raw landmark data
            return_difficulty: If True, returns (logits, difficulty_score, components)
            difficulty_weights: Custom weights for difficulty components

        Returns:
            ``logits`` of shape (batch, num_classes), or the tuple
            ``(logits, difficulty, components)`` when ``return_difficulty`` is set.
        """
        batch_size = x.shape[0]
        # Keep a reference to the unflattened input for the difficulty
        # analysis. No copy is needed: nothing below modifies it in place.
        landmarks = x

        # Reshape: (batch, max_len, 42, 3) -> (batch, max_len, 126)
        x = x.reshape(batch_size, x.shape[1], -1)

        # Stem: project per-frame features, then switch to (batch, dim, time)
        x = self.stem_conv(x)
        x = x.transpose(1, 2)
        x = self.stem_bn(x)

        # First stage
        x = self.conv_block1(x)
        x = self.transformer1(x)

        # Second stage
        x = self.conv_block2(x)
        x = self.transformer2(x)

        # Top projection (applied on the feature axis) and pooling over time
        x = x.transpose(1, 2)
        x = self.top_conv(x)
        x = F.gelu(x)
        x = x.transpose(1, 2)
        x = self.global_pool(x).squeeze(-1)

        # Classification
        x = self.late_dropout(x)
        logits = self.classifier(x)

        if return_difficulty:
            difficulty, components = self.difficulty_analyzer.compute_difficulty_score(
                landmarks, logits, difficulty_weights
            )
            return logits, difficulty, components

        return logits

    def get_model_size_mb(self):
        """Calculate model size in MB (parameters plus buffers)."""
        param_size = sum(p.numel() * p.element_size() for p in self.parameters())
        buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())
        size_mb = (param_size + buffer_size) / (1024 ** 2)
        return size_mb

def train_filter_model(model, train_loader, val_loader, criterion,
                       num_epochs=40, lr=1e-3, device='cuda',
                       difficulty_threshold=0.4, warmup_epochs=10,
                       difficulty_start_epoch=20,
                       difficulty_weights=None, label_smoothing=0.1):
    """Train the filter model and log per-epoch difficulty statistics.

    The training loss is label-smoothed cross entropy for the first
    ``warmup_epochs`` epochs and ``criterion`` afterwards. The difficulty
    score is only evaluated on the validation set for logging; it does not
    enter the training loss.

    Args:
        model: Module called as ``model(data)`` for logits and as
            ``model(data, return_difficulty=True, difficulty_weights=...)``
            for ``(logits, difficulty, components)``.
        train_loader: Iterable of ``(data, labels)`` training batches.
        val_loader: Iterable of ``(data, labels)`` validation batches.
        criterion: Loss used after warmup and for every validation loss.
        num_epochs: Total number of epochs.
        lr: Peak learning rate; warmup starts at ``0.01 * lr`` and the cosine
            phase decays to ``0.01 * lr``.
        device: Device the model and batches are moved to.
        difficulty_threshold: Final threshold above which a validation sample
            is counted as difficult.
        warmup_epochs: Length of the linear LR warmup / cross-entropy phase.
        difficulty_start_epoch: First (0-based) epoch at which the annealed
            threshold is used; before it the threshold is fixed at 1.0.
        difficulty_weights: Passed through to the model's difficulty scoring.
        label_smoothing: Smoothing factor of the warmup cross-entropy loss.

    Returns:
        dict mapping metric names to per-epoch lists of plain Python numbers.

    Side effects:
        Writes ``best_filter_model.pth`` whenever validation accuracy improves.
    """
    # Threshold annealing schedule: start value and number of epochs over
    # which it moves linearly to ``difficulty_threshold``.
    anneal_start = 0.8
    anneal_epochs = 10

    model = model.to(device)

    # Use label smoothing for better generalization
    ce_criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    # Linear warmup followed by cosine decay
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_epochs
    )
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        # Clamp so warmup_epochs >= num_epochs cannot produce a zero period
        T_max=max(1, num_epochs - warmup_epochs),
        eta_min=lr * 0.01,
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_epochs]
    )

    best_val_acc = 0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'difficult_pct': [],
        'avg_pred_uncertainty': [],
        'avg_spatial_complexity': [],
        'avg_temporal_complexity': [],
        'avg_motion_complexity': [],
        'difficulty_threshold_used': []
    }

    print(f"\n{'='*60}")
    print(f"Training Schedule:")
    print(f"  Epochs 1-{difficulty_start_epoch}: Pure classification (no difficulty)")
    print(f"  Epochs {difficulty_start_epoch+1}-{num_epochs}: Difficulty-aware training")
    print(f"{'='*60}\n")

    for epoch in range(num_epochs):
        # Threshold used to count "difficult" validation samples this epoch
        if epoch < difficulty_start_epoch:
            use_difficulty = False
            current_threshold = 1.0  # Don't cascade anything
        else:
            use_difficulty = True
            anneal_progress = min(1.0, (epoch - difficulty_start_epoch) / anneal_epochs)
            current_threshold = (
                anneal_start - (anneal_start - difficulty_threshold) * anneal_progress
            )

        # ---- Training ----
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0

        phase_name = "Classification" if not use_difficulty else f"Difficulty (thresh={current_threshold:.2f})"
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [{phase_name}]")

        for data, labels in pbar:
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(data)

            # Label-smoothed cross entropy during warmup, `criterion` afterwards
            if epoch < warmup_epochs:
                loss = ce_criterion(logits, labels)
            else:
                loss = criterion(logits, labels)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Called once per optimizer step when the model exposes it;
            # presumably advances the late-dropout schedule (defined elsewhere).
            if hasattr(model, 'late_dropout'):
                model.late_dropout.step()

            train_loss += loss.item()
            preds = logits.argmax(dim=1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100*train_correct/train_total:.2f}%'
            })

        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # ---- Validation with difficulty analysis ----
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        difficult_count = 0

        # Per-sample difficulty components gathered over the whole epoch
        all_pred_unc, all_spatial, all_temporal, all_motion = [], [], [], []

        with torch.no_grad():
            for data, labels in val_loader:
                data, labels = data.to(device), labels.to(device)

                # Difficulty is always computed for logging, even while inactive
                logits, difficulty, components = model(
                    data, return_difficulty=True, difficulty_weights=difficulty_weights
                )
                # NOTE(review): validation always uses `criterion`, so during
                # warmup the train and val losses come from different loss
                # functions and are not directly comparable.
                loss = criterion(logits, labels)

                val_loss += loss.item()
                preds = logits.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Track difficult predictions using current threshold
                difficult_count += (difficulty > current_threshold).sum().item()

                # Collect components for analysis
                all_pred_unc.extend(components['prediction_uncertainty'].cpu().numpy())
                all_spatial.extend(components['spatial_complexity'].cpu().numpy())
                all_temporal.extend(components['temporal_complexity'].cpu().numpy())
                all_motion.extend(components['motion_complexity'].cpu().numpy())

        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        difficult_pct = 100 * difficult_count / val_total

        # Convert to built-in floats: `history` is pickled into the checkpoint,
        # and NumPy scalars are rejected by torch.load(weights_only=True).
        avg_pred_unc = float(np.mean(all_pred_unc))
        avg_spatial = float(np.mean(all_spatial))
        avg_temporal = float(np.mean(all_temporal))
        avg_motion = float(np.mean(all_motion))

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['difficult_pct'].append(difficult_pct)
        history['avg_pred_uncertainty'].append(avg_pred_unc)
        history['avg_spatial_complexity'].append(avg_spatial)
        history['avg_temporal_complexity'].append(avg_temporal)
        history['avg_motion_complexity'].append(avg_motion)
        history['difficulty_threshold_used'].append(current_threshold)

        scheduler.step()

        # Report every 5th epoch, the first and last epoch, and the epoch at
        # which the difficulty threshold becomes active.
        if (epoch + 1) % 5 == 0 or epoch == 0 or epoch == num_epochs - 1 or epoch == difficulty_start_epoch:
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {100*train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {100*val_acc:.2f}%")
            if use_difficulty:
                print(f"  Difficult samples (cascade to next): {difficult_pct:.2f}% (threshold={current_threshold:.2f})")
                print(f"  Difficulty breakdown:")
                print(f"    - Pred uncertainty: {avg_pred_unc:.3f}")
                print(f"    - Spatial complex:  {avg_spatial:.3f}")
                print(f"    - Temporal complex: {avg_temporal:.3f}")
                print(f"    - Motion complex:   {avg_motion:.3f}")
            else:
                print(f"  (Difficulty assessment inactive - pure classification phase)")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'history': history,
                'difficulty_weights': difficulty_weights,
                'difficulty_threshold': difficulty_threshold
            }, 'best_filter_model.pth')
            # The checkpoint is written on every improvement; the message is
            # only shown on reporting epochs to keep the log short.
            if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
                print(f"  ✓ Saved best model")

    print(f"\n{'='*60}")
    print(f"Training Complete! Best Val Acc: {100*best_val_acc:.2f}%")
    print(f"{'='*60}\n")

    return history


def evaluate_filter_efficiency(model, val_loader, device='cuda',
                               difficulty_threshold=0.4, difficulty_weights=None,
                               component_threshold=0.5):
    """Evaluate with detailed difficulty breakdown.

    Runs the model over ``val_loader``, splits samples into "easy"
    (``difficulty <= difficulty_threshold``) and "difficult" (the rest), and
    counts, for samples with ``difficulty > difficulty_threshold``, how many
    have each difficulty component above ``component_threshold``.

    Args:
        model: Callable as ``model(data, return_difficulty=True,
            difficulty_weights=...)`` returning ``(logits, difficulty,
            components)``; ``components`` must contain the keys
            ``prediction_uncertainty``, ``spatial_complexity``,
            ``temporal_complexity`` and ``motion_complexity`` with one value
            per sample.
        val_loader: Iterable of ``(data, labels)`` batches.
        device: Device the batches are moved to (the model is not moved).
        difficulty_threshold: Cut-off separating easy from difficult samples.
        difficulty_weights: Passed through to the model.
        component_threshold: Cut-off above which a single component counts as
            a cascade reason.

    Returns:
        dict with keys ``total``, ``easy``, ``difficult``, ``easy_acc``,
        ``overall_acc`` and ``cascade_reasons``. Accuracies are 0 when their
        denominator is 0 (e.g. an empty loader).
    """
    model.eval()

    total = 0
    easy_count = 0
    difficult_count = 0

    easy_correct = 0
    overall_correct = 0

    # Reported reason -> key of the component that triggers it
    reason_to_component = {
        'high_pred_uncertainty': 'prediction_uncertainty',
        'high_spatial': 'spatial_complexity',
        'high_temporal': 'temporal_complexity',
        'high_motion': 'motion_complexity',
    }
    cascade_reasons = {reason: 0 for reason in reason_to_component}

    with torch.no_grad():
        for data, labels in val_loader:
            data = data.to(device)
            labels = labels.to(device)

            logits, difficulty, components = model(
                data, return_difficulty=True, difficulty_weights=difficulty_weights
            )
            correct = logits.argmax(dim=1) == labels

            total += data.size(0)

            # Track filtering
            easy_mask = difficulty <= difficulty_threshold
            easy_count += easy_mask.sum().item()
            difficult_count += (~easy_mask).sum().item()

            # Track accuracy
            easy_correct += (correct & easy_mask).sum().item()
            overall_correct += correct.sum().item()

            # Count cascade reasons with masked sums instead of a per-sample
            # Python loop (each scalar tensor comparison forced a device sync).
            hard_mask = (difficulty > difficulty_threshold).reshape(-1)
            for reason, key in reason_to_component.items():
                exceeds = (components[key] > component_threshold).reshape(-1)
                cascade_reasons[reason] += (hard_mask & exceeds).sum().item()

    def _pct(count):
        # Percentage of all samples; 0.0 when nothing was evaluated
        return 100 * count / total if total > 0 else 0.0

    easy_acc = easy_correct / easy_count if easy_count > 0 else 0
    overall_acc = overall_correct / total if total > 0 else 0

    print(f"\n{'='*60}")
    print(f"Multi-Dimensional Difficulty Cascade Analysis")
    print(f"{'='*60}")
    print(f"Total samples: {total}")
    print(f"Easy (handled by filter): {easy_count} ({_pct(easy_count):.1f}%)")
    print(f"  → Accuracy on easy: {100*easy_acc:.2f}%")
    print(f"Difficult (cascade to main): {difficult_count} ({_pct(difficult_count):.1f}%)")
    print(f"\nCascade reasons (samples can have multiple):")
    print(f"  - High prediction uncertainty: {cascade_reasons['high_pred_uncertainty']}")
    print(f"  - High spatial complexity: {cascade_reasons['high_spatial']}")
    print(f"  - High temporal complexity: {cascade_reasons['high_temporal']}")
    print(f"  - High motion complexity: {cascade_reasons['high_motion']}")
    print(f"\nOverall filter accuracy: {100*overall_acc:.2f}%")
    print(f"{'='*60}\n")

    return {
        'total': total,
        'easy': easy_count,
        'difficult': difficult_count,
        'easy_acc': easy_acc,
        'overall_acc': overall_acc,
        'cascade_reasons': cascade_reasons
    }




In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to True and
    ``torch.backends.cudnn.benchmark`` to False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False



In [None]:


if __name__ == "__main__":
    # Script entry point: sample the dataset, pick the largest model that fits
    # the size budget, train it, evaluate the best checkpoint and plot results.

    set_seed(42)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Sample dataset to 50% while preserving class distribution
    sampled_train_csv = stratified_sample_dataset(train_csv, sample_ratio=0.5, random_state=42)

    # Create train/val split on SAMPLED data
    # NOTE(review): this is a positional 80/20 split; it is only a fair split
    # if stratified_sample_dataset returns rows in shuffled order (not grouped
    # by class) - confirm against that helper.
    split_idx = int(len(sampled_train_csv) * 0.8)
    train_df = sampled_train_csv[:split_idx]
    val_df = sampled_train_csv[split_idx:]

    print(f"\nDataset: {len(train_df)} train, {len(val_df)} val")

    # Create data loaders
    train_loader, val_loader, focal_loss, class_weights = create_dataloaders(
        train_df, val_df, label_map,
        batch_size=32, max_frames=64, num_workers=4
    )

    num_classes = len(label_map)

    # Constructor arguments shared by the size probes and the final model
    model_kwargs = dict(
        num_classes=num_classes,
        max_len=64,
        channels=126,
        dropout_step=5000,
    )

    # Pick the first (largest) candidate dim whose model is under 24 MB,
    # leaving a 1 MB margin below the 25 MB limit.
    optimal_dim = None
    for test_dim in [192, 160, 128, 96]:
        test_model = ASLFilterModel(dim=test_dim, **model_kwargs)
        test_size = test_model.get_model_size_mb()
        del test_model
        print(f"Testing dim={test_dim}: {test_size:.2f} MB")
        if test_size < 24:  # Leave 1MB margin
            optimal_dim = test_dim
            break

    # Previously an unbound `optimal_dim` surfaced as a NameError below
    if optimal_dim is None:
        raise RuntimeError("No candidate model dimension fits under the 24 MB size budget")

    print(f"Selected optimal dim: {optimal_dim}")

    # Create final model
    filter_model = ASLFilterModel(dim=optimal_dim, **model_kwargs)

    model_size = filter_model.get_model_size_mb()
    print(f"\n{'='*60}")
    print(f"ASL Filter Model (Multi-Dimensional Difficulty)")
    print(f"Size: {model_size:.2f} MB / 25 MB limit")
    print(f"Parameters: {sum(p.numel() for p in filter_model.parameters()):,}")
    print(f"Dimension: {optimal_dim}")
    print(f"{'='*60}\n")
    # Explicit check instead of `assert`, which is stripped under `python -O`
    if model_size >= 25:
        raise RuntimeError(f"Model too large: {model_size:.2f} MB")

    # Custom difficulty weights (tune these based on your data)
    difficulty_weights = {
        'prediction_uncertainty': 0.35,  # Model confidence
        'spatial_complexity': 0.25,      # Hand coordination
        'temporal_complexity': 0.25,     # Motion smoothness
        'motion_complexity': 0.15        # 3D motion patterns
    }

    # Lower learning rate for the larger model variants
    initial_lr = 5e-4 if optimal_dim >= 160 else 1e-3

    print(f"Starting training with lr={initial_lr}...\n")

    history = train_filter_model(
        model=filter_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=focal_loss,
        num_epochs=40,
        lr=initial_lr,
        device=device,
        difficulty_threshold=0.4,
        warmup_epochs=10,
        difficulty_start_epoch=20,
        difficulty_weights=difficulty_weights
    )

    # Load best and evaluate.
    # The checkpoint was written by this script, so full unpickling is safe;
    # weights_only is spelled out because its default differs between PyTorch
    # versions and the checkpoint also carries non-tensor training history.
    checkpoint = torch.load(
        'best_filter_model.pth', map_location=device, weights_only=False
    )
    filter_model.load_state_dict(checkpoint['model_state_dict'])

    results = evaluate_filter_efficiency(
        filter_model, val_loader, device=device,
        difficulty_threshold=0.4,
        difficulty_weights=difficulty_weights
    )

    # Enhanced visualization: 2x3 grid of training diagnostics
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    epochs = range(1, len(history['train_loss']) + 1)

    # Loss
    axes[0, 0].plot(epochs, history['train_loss'], label='Train')
    axes[0, 0].plot(epochs, history['val_loss'], label='Val')
    axes[0, 0].set_title('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Accuracy
    axes[0, 1].plot(epochs, [100*x for x in history['train_acc']], label='Train')
    axes[0, 1].plot(epochs, [100*x for x in history['val_acc']], label='Val')
    axes[0, 1].set_title('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Cascade percentage
    axes[0, 2].plot(epochs, history['difficult_pct'], color='red')
    axes[0, 2].set_title('% Samples Cascaded to Next Layer')
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].grid(True)

    # Difficulty components over time
    axes[1, 0].plot(epochs, history['avg_pred_uncertainty'], label='Pred Uncertainty')
    axes[1, 0].plot(epochs, history['avg_spatial_complexity'], label='Spatial')
    axes[1, 0].plot(epochs, history['avg_temporal_complexity'], label='Temporal')
    axes[1, 0].plot(epochs, history['avg_motion_complexity'], label='Motion')
    axes[1, 0].set_title('Difficulty Components')
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # Difficulty threshold annealing
    axes[1, 1].plot(epochs, history['difficulty_threshold_used'], color='purple')
    axes[1, 1].set_title('Difficulty Threshold (Annealing)')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Threshold')
    axes[1, 1].grid(True)

    # Cascade breakdown; dict views are materialised as lists for plotting
    cascade_data = results['cascade_reasons']
    axes[1, 2].bar(list(cascade_data.keys()), list(cascade_data.values()))
    axes[1, 2].set_title('Cascade Reasons (Final Model)')
    axes[1, 2].tick_params(axis='x', rotation=45)
    axes[1, 2].grid(True, axis='y')

    plt.tight_layout()
    plt.savefig('filter_training_analysis.png', dpi=150)
    print("Saved training analysis plot to 'filter_training_analysis.png'")