In [1]:
import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy.interpolate import interp1d
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

BASE_PATH = "/kaggle/input/asl-signs"
PREPROCESSED_PATH = "./preprocessed_data"

# Load metadata
train_csv = pd.read_csv(f"{BASE_PATH}/train.csv")
with open(f"{BASE_PATH}/sign_to_prediction_index_map.json", "r") as f:
    label_map = json.load(f)

NUM_CLASSES = len(label_map)
print(f"Classes: {NUM_CLASSES}")

N_JOBS = 4

Classes: 250


In [2]:

def has_valid_hands_fast(sample_path, min_frames=5):
    """Filter samples with insufficient hand data"""
    full_path = os.path.join(BASE_PATH, sample_path)
    df = pd.read_parquet(full_path, columns=["frame", "type"], engine="pyarrow")
    
    hand_df = df[df["type"].isin(["left_hand", "right_hand"])]
    if hand_df.empty or hand_df["frame"].nunique() < min_frames:
        return False
    return True

def check_index(i):
    return i if has_valid_hands_fast(train_csv.iloc[i]["path"]) else None

# Try loading pre-computed indices, otherwise compute
try:
    valid_indices = np.load("/kaggle/input/valid-indices-npy/valid_indices.npy").tolist()
    print(f"Loaded {len(valid_indices)} valid indices from cache")
except:
    print("Computing valid indices (this may take a while)...")
    results = Parallel(n_jobs=N_JOBS, prefer="threads")(
        delayed(check_index)(i) for i in tqdm(range(len(train_csv)))
    )
    valid_indices = [i for i in results if i is not None]
    os.makedirs("/kaggle/working", exist_ok=True)
    np.save("/kaggle/working/valid_indices.npy", np.array(valid_indices))
    print(f"Valid samples: {len(valid_indices)} / {len(train_csv)}")

# Filter CSV to valid samples
train_csv = train_csv.iloc[valid_indices].reset_index(drop=True)
print(f"Filtered train_csv length: {len(train_csv)}")


class LandmarkPreprocessor:
    @staticmethod
    def filter_empty_frames(kp, threshold=0.5):
        """Remove frames with too many missing landmarks"""
        missing_per_frame = np.sum(kp == 0, axis=(1, 2)) / (kp.shape[1] * 3)
        valid_mask = missing_per_frame < threshold
        if np.sum(valid_mask) == 0:
            return kp[:1]
        return kp[valid_mask]
    
    @staticmethod
    def interpolate_missing(kp):
        T, L, C = kp.shape
        kp_interp = kp.copy()
        
        for l in range(L):
            for c in range(C):
                track = kp[:, l, c]
                valid_mask = track != 0
                n_valid = np.sum(valid_mask)
                
                if n_valid == 0:
                    continue
                elif n_valid == 1:
                    kp_interp[:, l, c] = track[valid_mask][0]
                else:
                    valid_indices = np.where(valid_mask)[0]
                    valid_values = track[valid_mask]
                    f = interp1d(valid_indices, valid_values, kind='linear',
                               fill_value='extrapolate', bounds_error=False)
                    kp_interp[:, l, c] = f(np.arange(T))
        
        return kp_interp
    
    @staticmethod
    def anchor_normalize(kp):
        """Normalize relative to hand centroid (FIXED for hand-only data)"""
        T, L, C = kp.shape
        
        # For hand landmarks, use centroid of all valid points
        valid_mask = kp != 0
        
        # Calculate centroid per frame
        anchor_pos = np.zeros((T, C))
        scale = np.zeros(T)
        
        for t in range(T):
            valid_points = kp[t][np.any(valid_mask[t], axis=1)]
            if len(valid_points) > 0:
                anchor_pos[t] = valid_points.mean(axis=0)
                # Scale based on spread of points
                distances = np.linalg.norm(valid_points - anchor_pos[t], axis=1)
                scale[t] = distances.mean() if len(distances) > 0 else 1.0
            else:
                anchor_pos[t] = 0
                scale[t] = 1.0
        
        # Apply normalization
        anchor_pos_expanded = anchor_pos[:, np.newaxis, :]
        kp_relative = kp - anchor_pos_expanded
        
        # Use mean scale to avoid division by zero
        mean_scale = scale[scale > 0].mean() if np.any(scale > 0) else 1.0
        if mean_scale > 1e-6:
            kp_normalized = kp_relative / mean_scale
        else:
            kp_normalized = kp_relative
        
        return kp_normalized, anchor_pos, mean_scale

class LandmarkAugmentor:
    @staticmethod
    def horizontal_flip(kp, left_hand_indices, right_hand_indices, p=0.5):
        """FIXED: Proper handling of hand indices"""
        if np.random.rand() < p:
            kp = kp.copy()
            kp[:, :, 0] *= -1  # Flip x-coordinate
            
            # Swap hands if both are present
            if len(left_hand_indices) > 0 and len(right_hand_indices) > 0:
                left_hand = kp[:, left_hand_indices, :].copy()
                right_hand = kp[:, right_hand_indices, :].copy()
                kp[:, left_hand_indices, :] = right_hand
                kp[:, right_hand_indices, :] = left_hand
        return kp
    
    @staticmethod
    def rotation_3d(kp, max_angle=60, p=0.5):
        if np.random.rand() < p:
            angle = np.random.uniform(-max_angle, max_angle)
            angle_rad = np.deg2rad(angle)
            
            cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
            R = np.array([[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]])
            
            T, L, _ = kp.shape
            kp_flat = kp.reshape(-1, 3)
            kp_rotated = (R @ kp_flat.T).T
            kp = kp_rotated.reshape(T, L, 3)
        return kp
    
    @staticmethod
    def resize_3d(kp, scale_range=(0.8, 1.2), p=0.5):
        if np.random.rand() < p:
            scale = np.random.uniform(*scale_range)
            kp = kp * scale
        return kp
    
    @staticmethod
    def finger_dropout(kp, hand_indices, p_per_finger=0.1, p=0.3):
        if np.random.rand() < p:
            for idx in hand_indices:
                if np.random.rand() < p_per_finger:
                    kp[:, idx, :] = 0
        return kp






Computing valid indices (this may take a while)...


100%|██████████| 94477/94477 [08:00<00:00, 196.62it/s]

Valid samples: 94198 / 94477
Filtered train_csv length: 94198





In [3]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, logits, targets):
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=logits.shape[1])
        p_t = (probs * targets_one_hot).sum(dim=1)
        
        focal_weight = (1 - p_t) ** self.gamma
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        focal_loss = focal_weight * ce_loss
        
        if self.alpha is not None:
            if self.alpha.device != logits.device:
                self.alpha = self.alpha.to(logits.device)
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


In [4]:
class ASLDataset(Dataset):
    def __init__(self, dataframe, label_map, max_frames=64, split='train',
                 use_anchor_norm=True, use_interpolation=True, use_augmentation=True):
        self.df = dataframe.reset_index(drop=True)
        self.label_map = label_map
        self.max_frames = max_frames
        self.split = split
        self.use_anchor_norm = use_anchor_norm
        self.use_interpolation = use_interpolation
        self.use_augmentation = use_augmentation and (split == 'train')
        
        # Initialize helpers
        self.preprocessor = LandmarkPreprocessor()
        self.augmentor = LandmarkAugmentor()
        
        # CRITICAL: Define hand indices (42 landmarks total, 21 per hand)
        self.left_hand_indices = list(range(0, 21))
        self.right_hand_indices = list(range(21, 42))
        
        # Class balancing
        self.class_counts = self._analyze_class_distribution()
        self.class_weights = self._compute_class_weights()
        

        print(f"ASL Dataset - {split.upper()}")
        
        print(f"Samples: {len(self.df)}")
        print(f"Classes: {len(label_map)}")
        print(f"Max frames: {max_frames}")
        print(f"Anchor normalization: {use_anchor_norm}")
        print(f"Interpolation: {use_interpolation}")
        print(f"Augmentation: {self.use_augmentation}")

    
    def _analyze_class_distribution(self):
        labels = [self.label_map[row['sign']] for _, row in self.df.iterrows()]
        return Counter(labels)
    
    def _compute_class_weights(self):
        num_classes = len(self.label_map)
        weights = np.zeros(num_classes)
        total_samples = sum(self.class_counts.values())
        
        for class_idx in range(num_classes):
            count = self.class_counts.get(class_idx, 1)
            weights[class_idx] = total_samples / (num_classes * count)
        
        weights = weights / weights.sum() * num_classes
        return torch.FloatTensor(weights)
    
    def get_sample_weights(self):
        """For WeightedRandomSampler"""
        sample_weights = []
        for _, row in self.df.iterrows():
            label = self.label_map[row['sign']]
            sample_weights.append(self.class_weights[label].item())
        return torch.FloatTensor(sample_weights)
    
    def load_hand_keypoints(self, sample_path):
        """Load hand landmarks (42 landmarks: 21 left + 21 right)"""
        full_path = os.path.join(BASE_PATH, sample_path)
        df = pd.read_parquet(full_path, engine="pyarrow")
        
        hand_df = df[df["type"].isin(["left_hand", "right_hand"])]
        hand_df = hand_df.sort_values(["frame", "type", "landmark_index"])
        
        frames = hand_df["frame"].nunique()
        coords = hand_df[["x", "y", "z"]].values
        coords = coords.reshape(frames, 42, 3)
        
        return np.nan_to_num(coords, nan=0.0)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load keypoints
        kp = self.load_hand_keypoints(row["path"])  # (T, 42, 3)
        original_length = kp.shape[0]
        
        # Preprocessing
        kp = self.preprocessor.filter_empty_frames(kp, threshold=0.5)
        
        if self.use_interpolation:
            kp = self.preprocessor.interpolate_missing(kp)
        
        if self.use_anchor_norm:
            kp, _, _ = self.preprocessor.anchor_normalize(kp)
        
        # Temporal resampling
        if kp.shape[0] < self.max_frames:
            pad_len = self.max_frames - kp.shape[0]
            kp = np.concatenate([kp, np.zeros((pad_len, 42, 3))], axis=0)
        elif kp.shape[0] > self.max_frames:
            indices = np.linspace(0, kp.shape[0]-1, self.max_frames).astype(int)
            kp = kp[indices]
        
        kp = np.nan_to_num(kp, nan=0.0)
        
        # Augmentation (FIXED with correct indices)
        if self.use_augmentation:
            kp = self.augmentor.horizontal_flip(
                kp, self.left_hand_indices, self.right_hand_indices, p=0.5)
            kp = self.augmentor.rotation_3d(kp, max_angle=60, p=0.5)
            kp = self.augmentor.resize_3d(kp, scale_range=(0.8, 1.2), p=0.5)
            
            all_hand_indices = self.left_hand_indices + self.right_hand_indices
            kp = self.augmentor.finger_dropout(kp, all_hand_indices, p=0.3)
        
        label = self.label_map[row["sign"]]
        
        return torch.FloatTensor(kp), label


In [5]:
def create_dataloaders(train_df, val_df, label_map, batch_size=32, 
                       max_frames=64, num_workers=4):
    
    train_ds = ASLDataset(
        train_df, label_map, max_frames=max_frames, split='train',
        use_anchor_norm=True, use_interpolation=True, use_augmentation=True
    )
    
    val_ds = ASLDataset(
        val_df, label_map, max_frames=max_frames, split='val',
        use_anchor_norm=True, use_interpolation=True, use_augmentation=False
    )
    
    # Weighted sampler for balanced training
    sample_weights = train_ds.get_sample_weights()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=True,
        prefetch_factor=2, persistent_workers=True, drop_last=True
    )
    
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
        prefetch_factor=2, persistent_workers=True
    )
    
    
    focal_loss = FocalLoss(alpha=train_ds.class_weights, gamma=2.0)
    
    return train_loader, val_loader, focal_loss, train_ds.class_weights


In [6]:
if __name__ == "__main__":
    set_seed(42)
    
    split_idx = int(len(train_csv) * 0.8)
    train_df = train_csv[:split_idx]
    val_df = train_csv[split_idx:]
    
    print(f"Train samples: {len(train_df)}, Val samples: {len(val_df)}")
    
    train_loader, val_loader, focal_loss, class_weights = create_dataloaders(
        train_df, val_df, label_map, 
        batch_size=32, max_frames=64, num_workers=4
    )
    
    for batch_idx, (data, labels) in enumerate(train_loader):
        print(f"Batch shape: {data.shape}")  # Should be [32, 64, 42, 3]
        print(f"Labels shape: {labels.shape}")
        break

Train samples: 75358, Val samples: 18840
ASL Dataset - TRAIN
Samples: 75358
Classes: 250
Max frames: 64
Anchor normalization: True
Interpolation: True
Augmentation: True
ASL Dataset - VAL
Samples: 18840
Classes: 250
Max frames: 64
Anchor normalization: True
Interpolation: True
Augmentation: False
Batch shape: torch.Size([32, 64, 42, 3])
Labels shape: torch.Size([32])
