## MSKA: Multi-Stream Keypoint-based Action Recognition for Sign Language (Modified for Hands Only)

### 1. Setup and Imports

In [17]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedShuffleSplit # train_test_split not used directly if using StratifiedShuffleSplit
from collections import Counter
from tqdm import tqdm # Changed from tqdm.notebook
from torchmetrics import F1Score, Accuracy
import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange # If used by model components

In [None]:
# --- Configuration ---
# Root directory holding the preprocessed .npy keypoint files.
# Expected layout: data_dir/className/splitName/videoName.npy
# Each videoName.npy stores holistic keypoints shaped (N_FRAMES, 543, 3).
DATA_DIR = "SignLanguage_Processed_Data/keypoints_for_model"  # <--- UPDATE THIS PATH
NUM_CLASSES = 4  # <--- UPDATE THIS for "clavier, disque dur, ordinateur, souris"

# Optimisation hyper-parameters
BATCH_SIZE = 16
NUM_EPOCHS = 50  # Adjust as needed; 100 may be long for a first run
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 1e-4

# Architecture hyper-parameters (all tunable)
D_MODEL = 128   # Embedding width used throughout the model
N_HEAD = 4      # Attention heads in the temporal transformer
NUM_LAYERS = 2  # Transformer encoder layers

# Layout of the 543 holistic keypoints:
#   Pose 0-32 (33), Face 33-500 (468), Left hand 501-521 (21), Right hand 522-542 (21)
NUM_HAND_KEYPOINTS = 21
KEYPOINT_FEATURES = 3  # x, y, z
LEFT_HAND_INDICES = slice(501, 501 + NUM_HAND_KEYPOINTS)   # 501..521
RIGHT_HAND_INDICES = slice(522, 522 + NUM_HAND_KEYPOINTS)  # 522..542

# Device configuration: prefer CUDA when it is available
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

### 2. Data Loading and Preprocessing

In [19]:
# --- Custom PyTorch Dataset for Sign Language Keypoints (Hands Only) ---
class SignLanguageKeypointsDataset(Dataset):
    """Dataset yielding fixed-length left/right hand keypoint streams.

    Every item is read from a ``.npy`` file of holistic keypoints, sliced
    down to the two hand ranges (``LEFT_HAND_INDICES`` / ``RIGHT_HAND_INDICES``)
    and padded or truncated along the frame axis to ``max_frames``.
    """

    def __init__(self, data_samples, labels, max_frames):
        """
        Args:
            data_samples (list): List of file paths to .npy files.
            labels (list): List of corresponding labels (integers).
            max_frames (int): The maximum number of frames to pad/truncate sequences to.
        """
        self.data_samples = data_samples
        self.labels = labels
        self.max_frames = max_frames

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(left_hand, right_hand, label)`` tensors for sample ``idx``.

        Both hand tensors are float32 with ``max_frames`` frames; the label is
        a long scalar. A file that cannot be loaded is reported on stdout and
        replaced by an all-zero sequence so that one bad file does not abort
        an entire epoch.
        """
        file_path = self.data_samples[idx]
        label = self.labels[idx]

        try:
            # Load holistic keypoints (N_FRAMES, 543, 3)
            holistic_data = np.load(file_path)
        except Exception as e:
            # Deliberate best-effort fallback: keep training, but make it visible.
            print(f"Error loading {file_path}: {e}. Returning zeros.")
            holistic_data = np.zeros((self.max_frames, 543, KEYPOINT_FEATURES), dtype=np.float32)

        # Slice out hand keypoints
        # Shape: (N_FRAMES, NUM_HAND_KEYPOINTS, KEYPOINT_FEATURES)
        left_hand_kps = holistic_data[:, LEFT_HAND_INDICES, :]
        right_hand_kps = holistic_data[:, RIGHT_HAND_INDICES, :]

        # Pad or truncate each stream to max_frames
        left_hand_kps = self.pad_or_truncate_stream(left_hand_kps, NUM_HAND_KEYPOINTS)
        right_hand_kps = self.pad_or_truncate_stream(right_hand_kps, NUM_HAND_KEYPOINTS)

        return (
            torch.tensor(left_hand_kps, dtype=torch.float32),
            torch.tensor(right_hand_kps, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )

    def pad_or_truncate_stream(self, stream, num_keypoints_per_stream):
        """Return ``stream`` with exactly ``self.max_frames`` frames.

        Longer streams keep their first ``max_frames`` frames; shorter ones
        are extended with zero frames at the end.

        Args:
            stream (np.ndarray): Array whose first axis is the frame axis,
                e.g. (num_frames, num_keypoints, features).
            num_keypoints_per_stream (int): Kept for backward compatibility
                with existing callers. The padding shape is taken from
                ``stream`` itself, so this value is not used.

        Returns:
            np.ndarray: Array of shape (max_frames, *stream.shape[1:]).
        """
        current_frames = stream.shape[0]
        if current_frames >= self.max_frames:
            # Covers both the exact-length and the too-long case.
            return stream[:self.max_frames]

        # Derive the trailing dimensions from the data rather than from
        # constants, so the zero block always matches what is concatenated.
        missing_frames = self.max_frames - current_frames
        padding = np.zeros((missing_frames, *stream.shape[1:]), dtype=stream.dtype)
        return np.concatenate([stream, padding], axis=0)


### 3. Model Architecture (Hands Only)

In [20]:
# --- Model Components (largely from your original MSKA, adapted) ---
class SpatialAttention(nn.Module):
    """Single-head self-attention over the keypoints of each individual frame."""

    def __init__(self, d_model):  # d_model: per-keypoint feature width after projection
        super().__init__()
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend across keypoints independently for every (batch, frame) pair.

        Args:
            x: Tensor of shape [batch, frames, keypoints, d_model].

        Returns:
            Tensor with the same shape as ``x``.
        """
        n_batch, n_frames, n_kps, width = x.shape

        # Fold frames into the batch axis so each frame attends on its own.
        per_frame = x.reshape(n_batch * n_frames, n_kps, width)

        q = self.query(per_frame)
        k = self.key(per_frame)
        v = self.value(per_frame)

        # Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V
        scores = (q @ k.transpose(-2, -1)) / (width ** 0.5)
        attended = scores.softmax(dim=-1) @ v

        # Unfold back to separate batch and frame axes.
        return attended.reshape(n_batch, n_frames, n_kps, width)

class TemporalTransformer(nn.Module):
    """Transformer encoder over a sequence of per-frame features, followed by LayerNorm."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward_factor=4):
        super().__init__()
        feedforward_width = d_model * dim_feedforward_factor
        layer = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=feedforward_width,
            batch_first=True,  # inputs arrive as (batch, seq, feature)
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Encode ``x`` of shape [batch, frames, d_model]; the output has the same shape."""
        encoded = self.transformer_encoder(x)
        return self.norm(encoded)

class ConvTransformerBlock(nn.Module):
    """Temporal 1D convolutions with a residual connection, then a TemporalTransformer."""

    def __init__(self, d_model, nhead, num_transformer_layers):
        super().__init__()
        expanded = d_model * 2
        # kernel_size=3 with padding=1 keeps the number of frames unchanged.
        self.conv_block = nn.Sequential(
            nn.Conv1d(d_model, expanded, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(expanded, d_model, kernel_size=3, padding=1),
        )
        self.norm_after_conv = nn.LayerNorm(d_model)
        self.transformer = TemporalTransformer(d_model, nhead, num_transformer_layers)

    def forward(self, x):
        """Process ``x`` of shape [batch, frames, d_model]; the output has the same shape."""
        # Conv1d wants channels first: [batch, d_model, frames]; swap back afterwards.
        local_context = self.conv_block(x.transpose(1, 2)).transpose(1, 2)

        # Residual sum, normalised before it enters the transformer.
        normed = self.norm_after_conv(x + local_context)

        return self.transformer(normed)

class HandStreamProcessor(nn.Module):
    """Encodes one hand's keypoint sequence (left or right) into a single vector."""

    def __init__(self, in_keypoint_features, num_keypoints, d_model, nhead, num_layers):
        # ``num_keypoints`` is accepted as part of the signature but is not
        # needed to build any layer below.
        super().__init__()
        # Lift each keypoint's raw coordinates (x, y, z) to d_model features.
        self.projection = nn.Sequential(
            nn.Linear(in_keypoint_features, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model),
        )
        self.spatial_attn = SpatialAttention(d_model)
        # Operates on per-frame vectors obtained by averaging over keypoints.
        self.temporal_conv_transformer = ConvTransformerBlock(d_model, nhead, num_layers)
        # One score per frame, used for attention pooling over time.
        self.temporal_attn_pool = nn.Linear(d_model, 1)

    def forward(self, x):
        """Map [batch, frames, num_keypoints, in_keypoint_features] to [batch, d_model]."""
        keypoint_feats = self.spatial_attn(self.projection(x))  # [batch, frames, kps, d_model]

        # Collapse the keypoint axis: one feature vector per frame.
        frame_feats = keypoint_feats.mean(dim=2)  # [batch, frames, d_model]

        sequence = self.temporal_conv_transformer(frame_feats)  # [batch, frames, d_model]

        # Attention pooling over time: softmax-normalised per-frame scores
        # weight the frames, which are then summed into one embedding.
        frame_scores = self.temporal_attn_pool(sequence).squeeze(-1)  # [batch, frames]
        frame_weights = frame_scores.softmax(dim=-1).unsqueeze(-1)    # [batch, frames, 1]
        return (sequence * frame_weights).sum(dim=1)                  # [batch, d_model]

class SignLanguageHandsModel(nn.Module):
    """Two-stream sign classifier built on left-hand and right-hand keypoints only."""

    def __init__(self, num_classes, in_keypoint_features, num_hand_keypoints, d_model, nhead, num_layers):
        super().__init__()
        # Both hands use the same architecture but separate weights.
        stream_args = (in_keypoint_features, num_hand_keypoints, d_model, nhead, num_layers)
        self.left_hand_processor = HandStreamProcessor(*stream_args)
        self.right_hand_processor = HandStreamProcessor(*stream_args)

        # Each stream yields [batch, d_model]; the classifier consumes their
        # concatenation, i.e. [batch, 2 * d_model].
        fused_width = d_model * 2
        self.classifier = nn.Sequential(
            nn.Linear(fused_width, d_model),  # feature fusion layer
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(d_model, num_classes),
        )

    def forward(self, left_hand_input, right_hand_input):
        """Return class logits of shape [batch, num_classes].

        Both inputs are [batch, frames, num_keypoints, features_per_keypoint].
        """
        stream_embeddings = [
            self.left_hand_processor(left_hand_input),    # [batch, d_model]
            self.right_hand_processor(right_hand_input),  # [batch, d_model]
        ]
        fused = torch.cat(stream_embeddings, dim=-1)      # [batch, d_model * 2]
        return self.classifier(fused)

# --- Stream dimensions: aliases of the configuration constants used when building the model ---
NUM_KP_PER_HAND = NUM_HAND_KEYPOINTS      # keypoints per hand; expected to be 21
IN_KEYPOINT_FEATURES = KEYPOINT_FEATURES  # values per keypoint (x, y, z); expected to be 3

### 4. Loss Function (Balanced Focal Loss)

In [21]:
class BalancedFocalLoss(nn.Module):
    """
    Balanced Focal Loss, using effective number of samples for class weighting.

    The weight of class ``c`` is proportional to
    ``(1 - beta) / (1 - beta ** n_c)`` where ``n_c`` is its sample count, and
    weights are rescaled so that classes that actually occur have an average
    weight of 1. Classes with no samples get weight 0: otherwise their
    clamped denominator would produce a huge weight that, after
    normalisation, pushes every real class's weight towards zero.

    Args:
        class_counts: Per-class sample counts (sequence or tensor).
        gamma (float): Focusing exponent applied to ``(1 - p_t)``.
        beta (float): Smoothing factor of the effective-number formula.
    """
    def __init__(self, class_counts, gamma=2.0, beta=0.9999):
        super().__init__()
        self.gamma = gamma

        # Accept lists, numpy arrays and tensors alike; always work in float.
        class_counts = torch.as_tensor(class_counts, dtype=torch.float)
        present = class_counts > 0

        # Effective number of samples calculation
        effective_num = 1.0 - torch.pow(beta, class_counts)
        # The clamp only protects the division; empty classes are zeroed below.
        raw_weights = (1.0 - beta) / torch.clamp(effective_num, min=1e-8)
        weights = torch.where(present, raw_weights, torch.zeros_like(raw_weights))

        if bool(present.any()):
            # Normalise so the average weight over the present classes is 1.
            weights = weights / weights.sum() * present.sum()
        else:
            # No class has samples: fall back to uniform weighting.
            weights = torch.ones_like(raw_weights)
        self.register_buffer('weights', weights)  # buffer follows .to()/.cuda()

    def forward(self, inputs, targets):
        """Return the mean class-weighted focal loss.

        Args:
            inputs: Logits of shape [batch_size, num_classes].
            targets: Long tensor of class indices, shape [batch_size].
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')  # [batch_size]
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_term = (1 - pt) ** self.gamma

        # Move the weight table to the logits' device *before* indexing, so a
        # criterion left on the CPU still works with GPU targets.
        alpha = self.weights.to(inputs.device)[targets]

        balanced_focal_loss = alpha * focal_term * ce_loss
        return balanced_focal_loss.mean()


### 5. Training and Validation Loop

In [22]:
def _plot_training_history(history, output_path):
    """Plot loss, F1 and accuracy curves side by side and save the figure to ``output_path``."""
    panels = [
        ('train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss Evolution', 'Loss'),
        ('train_f1', 'val_f1', 'Train F1', 'Val F1', 'F1 Score Evolution', 'F1 Score'),
        ('train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy', 'Accuracy Evolution', 'Accuracy'),
    ]
    plt.figure(figsize=(18, 6))
    for position, (train_key, val_key, train_label, val_label, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.plot(history[train_key], label=train_label)
        plt.plot(history[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
    plt.tight_layout()
    plt.savefig(output_path)


def train_model(model, train_loader, val_loader, num_classes, device, num_epochs, learning_rate, weight_decay, y_train_labels_for_loss, checkpoint_path='best_sign_model_hands_only.pth'):
    """Train ``model`` and return it with the best validation weights loaded.

    Uses AdamW, gradient clipping at norm 1.0, BalancedFocalLoss weighted by
    the training label counts, and ReduceLROnPlateau driven by validation
    macro-F1. The state dict of the best epoch (by validation macro-F1) is
    written to ``checkpoint_path`` and reloaded before returning. Training
    curves are saved to 'sign_language_training_hands_only.png'.

    Args:
        model: Module called as ``model(left_kps, right_kps)`` returning logits.
        train_loader, val_loader: Loaders yielding ``(left_kps, right_kps, labels)``.
        num_classes (int): Number of target classes.
        device: Device on which batches, metrics and the loss live.
        num_epochs (int): Number of epochs to run.
        learning_rate (float), weight_decay (float): AdamW settings.
        y_train_labels_for_loss: Integer training labels used for class weights.
        checkpoint_path (str): Where the best state dict is stored.

    Returns:
        tuple: ``(model, history)`` where ``history`` maps 'train_loss',
        'val_loss', 'train_f1', 'val_f1', 'train_acc', 'val_acc' to per-epoch lists.

    Raises:
        ValueError: If ``y_train_labels_for_loss`` is empty.
    """
    # Initialize metrics
    train_f1_metric = F1Score(task='multiclass', num_classes=num_classes, average='macro').to(device)
    val_f1_metric = F1Score(task='multiclass', num_classes=num_classes, average='macro').to(device)
    train_accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes).to(device)
    val_accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # ``verbose`` is deprecated on this scheduler in recent PyTorch releases;
    # learning-rate reductions are reported explicitly after each step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5, factor=0.5)

    # Calculate class counts for BalancedFocalLoss
    if len(y_train_labels_for_loss) == 0:
        raise ValueError("y_train_labels_for_loss is empty. Ensure data loading provides training labels.")
    class_counts_np = np.bincount(y_train_labels_for_loss, minlength=num_classes)
    class_counts_tensor = torch.tensor(class_counts_np, dtype=torch.float).to(device)
    criterion = BalancedFocalLoss(class_counts=class_counts_tensor, gamma=2.0).to(device)

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint, even when validation F1 is exactly 0.
    best_val_f1 = float('-inf')
    checkpoint_saved = False
    history = {'train_loss': [], 'val_loss': [], 'train_f1': [], 'val_f1': [], 'train_acc': [], 'val_acc': []}

    # Guard the per-epoch averages against empty loaders.
    num_train_batches = max(len(train_loader), 1)
    num_val_batches = max(len(val_loader), 1)

    print("\n--- Starting Training ---")
    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)

        for left_kps, right_kps, labels in train_pbar:
            left_kps = left_kps.to(device)
            right_kps = right_kps.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(left_kps, right_kps)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()

            epoch_train_loss += loss.item()
            # Metrics only need values, not the autograd graph.
            detached_outputs = outputs.detach()
            train_f1_metric.update(detached_outputs, labels)
            train_accuracy_metric.update(detached_outputs, labels)
            train_pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_train_loss = epoch_train_loss / num_train_batches
        epoch_train_f1 = train_f1_metric.compute().item()
        epoch_train_acc = train_accuracy_metric.compute().item()
        train_f1_metric.reset()
        train_accuracy_metric.reset()

        # Validation Phase
        model.eval()
        epoch_val_loss = 0.0
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)
        with torch.no_grad():
            for left_kps, right_kps, labels in val_pbar:
                left_kps = left_kps.to(device)
                right_kps = right_kps.to(device)
                labels = labels.to(device)

                outputs = model(left_kps, right_kps)
                loss = criterion(outputs, labels)

                epoch_val_loss += loss.item()
                val_f1_metric.update(outputs, labels)
                val_accuracy_metric.update(outputs, labels)
                val_pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        avg_val_loss = epoch_val_loss / num_val_batches
        epoch_val_f1 = val_f1_metric.compute().item()
        epoch_val_acc = val_accuracy_metric.compute().item()
        val_f1_metric.reset()
        val_accuracy_metric.reset()

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_f1'].append(epoch_train_f1)
        history['val_f1'].append(epoch_val_f1)
        history['train_acc'].append(epoch_train_acc)
        history['val_acc'].append(epoch_val_acc)

        print(f"Epoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train F1: {epoch_train_f1:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss:   {avg_val_loss:.4f}, Val F1:   {epoch_val_f1:.4f}, Val Acc:   {epoch_val_acc:.4f}")

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(epoch_val_f1)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        if epoch_val_f1 > best_val_f1:
            best_val_f1 = epoch_val_f1
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  🔥 New best model saved with Val F1: {best_val_f1:.4f}")

    print("--- Training Finished ---")
    # Plot training history
    _plot_training_history(history, 'sign_language_training_hands_only.png')

    # Reload only a checkpoint written during *this* run; a file left over
    # from an earlier run may belong to a different model.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print(f"Warning: no checkpoint was saved during this run; returning the model as of the last epoch.")
    return model, history


### 7. Evaluation on Test Set 

In [23]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score

def evaluate_model(model, test_loader, device, label_mapping_dict, num_classes_eval):
    """Evaluate ``model`` on ``test_loader`` and report classification metrics.

    Prints a classification report plus accuracy, macro-F1 and weighted-F1,
    and saves a confusion-matrix heatmap to 'confusion_matrix_hands_only.png'.

    Args:
        model: Module called as ``model(left_kps, right_kps)`` returning logits.
        test_loader: Loader yielding ``(left_kps, right_kps, labels)``.
        device: Device the inputs are moved to.
        label_mapping_dict (dict): Maps class name -> integer label.
        num_classes_eval (int): Number of classes shown in the confusion matrix.

    Returns:
        dict: Keys 'accuracy', 'f1_macro' and 'f1_weighted'.
    """
    model.eval()
    all_preds = []
    all_labels = []

    test_pbar = tqdm(test_loader, desc='Evaluating on Test Set', leave=False)
    with torch.no_grad():
        for left_kps, right_kps, labels in test_pbar:
            left_kps = left_kps.to(device)
            right_kps = right_kps.to(device)

            outputs = model(left_kps, right_kps)
            preds = torch.argmax(outputs, dim=1)

            # Plain Python ints keep the set/sort/lookup logic below simple.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Invert the mapping once. Indices without a name get a placeholder so
    # the name lists always match the label lists in length, even when the
    # mapping and ``num_classes_eval`` disagree.
    index_to_name = {int(idx): name for name, idx in label_mapping_dict.items()}

    def name_for(label_idx):
        return index_to_name.get(label_idx, f"class_{label_idx}")

    unique_labels_in_test = sorted(set(all_labels))
    target_names_report = [name_for(label_idx) for label_idx in unique_labels_in_test]

    cm_labels = list(range(num_classes_eval))
    cm_target_names = [name_for(label_idx) for label_idx in cm_labels]

    print("\n--- Test Set Evaluation ---")
    print("\n📊 Classification Report:")
    print(classification_report(all_labels, all_preds, labels=unique_labels_in_test, target_names=target_names_report, digits=4, zero_division=0))

    test_accuracy = accuracy_score(all_labels, all_preds)
    test_f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    test_f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    print(f"\n🔍 Test Accuracy: {test_accuracy:.4f}")
    print(f"🔍 Test F1-Score (Macro): {test_f1_macro:.4f}")
    print(f"🔍 Test F1-Score (Weighted): {test_f1_weighted:.4f}")

    cm = confusion_matrix(all_labels, all_preds, labels=cm_labels)
    plt.figure(figsize=(max(10, num_classes_eval // 1.5), max(8, num_classes_eval // 2)))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=cm_target_names,
                yticklabels=cm_target_names)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('confusion_matrix_hands_only.png')

    return {'accuracy': test_accuracy, 'f1_macro': test_f1_macro, 'f1_weighted': test_f1_weighted}


### 8. Model Instantiation and Training Execution

In [None]:
if __name__ == '__main__':
    # --- Load data paths and labels ---
    # Walk DATA_DIR/className/splitName/*.npy. Class labels are assigned in
    # sorted directory order. The split sub-directories are only traversed;
    # the actual train/val/test split is drawn further below.
    all_file_paths = []
    all_labels = []
    label_mapping = {}  # class name -> integer label
    frame_counts = []   # frame counts of the files that could be read

    print(f"Loading data from: {DATA_DIR}")
    if not os.path.exists(DATA_DIR):
        raise FileNotFoundError(f"DATA_DIR not found: {DATA_DIR}. Please run preprocessing first.")

    class_dirs = [
        entry for entry in sorted(os.listdir(DATA_DIR))
        if os.path.isdir(os.path.join(DATA_DIR, entry))
    ]
    for class_label, class_name in enumerate(class_dirs):
        label_mapping[class_name] = class_label
        class_path = os.path.join(DATA_DIR, class_name)

        for split_name in sorted(os.listdir(class_path)):  # e.g., 'train', 'val', 'test'
            split_path = os.path.join(class_path, split_name)
            if not os.path.isdir(split_path):
                continue

            for file_name in os.listdir(split_path):
                if not file_name.endswith(".npy"):
                    continue
                file_path = os.path.join(split_path, file_name)
                all_file_paths.append(file_path)
                all_labels.append(class_label)
                try:
                    # Memory-map the file: only the shape is needed here, so
                    # there is no reason to read the whole array.
                    frame_counts.append(np.load(file_path, mmap_mode='r').shape[0])
                except Exception as e:
                    # The sample stays in the dataset (the Dataset has its own
                    # fallback), but it must not distort the length statistics.
                    print(f"Warning: Could not load {file_path} to get frame count: {e}")

    if not all_file_paths:
        raise ValueError(f"No .npy files found in {DATA_DIR}. Please check the path and data structure.")

    # Labels outside [0, NUM_CLASSES) would only fail later, inside the loss,
    # with a much less helpful error.
    observed_labels = set(all_labels)
    if max(observed_labels) >= NUM_CLASSES:
        raise ValueError(
            f"Found class label {max(observed_labels)} but NUM_CLASSES = {NUM_CLASSES}. "
            f"Classes discovered in {DATA_DIR}: {sorted(label_mapping)}. Please update NUM_CLASSES."
        )
    if len(observed_labels) < NUM_CLASSES:
        print(f"Warning: only {len(observed_labels)} class(es) have samples but NUM_CLASSES = {NUM_CLASSES}.")

    if frame_counts:
        # At least one frame, so sequences are never truncated to nothing.
        MAX_FRAMES = max(1, int(np.percentile(frame_counts, 95)))
        print(f"Using MAX_FRAMES = {MAX_FRAMES} (95th percentile of frame counts)")
    else:
        MAX_FRAMES = 50
        print(f"Warning: No frame counts. Using default MAX_FRAMES = {MAX_FRAMES}")

    X = np.array(all_file_paths)
    y = np.array(all_labels)

    # --- Stratified Train-Validation-Test Split ---
    # 20% test first, then 25% of the remainder for validation (60/20/20 overall).
    sss_test = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_val_indices, test_indices = next(sss_test.split(X, y))
    X_train_val, X_test_paths = X[train_val_indices], X[test_indices]
    y_train_val, y_test_labels = y[train_val_indices], y[test_indices]

    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
    train_indices, val_indices = next(sss_val.split(X_train_val, y_train_val))
    X_train_paths, X_val_paths = X_train_val[train_indices], X_train_val[val_indices]
    y_train_labels, y_val_labels = y_train_val[train_indices], y_train_val[val_indices]

    print(f"\nDataset splits:")
    print(f"Train samples: {len(X_train_paths)}")
    print(f"Validation samples: {len(X_val_paths)}")
    print(f"Test samples: {len(X_test_paths)}")

    train_dataset = SignLanguageKeypointsDataset(X_train_paths, y_train_labels, MAX_FRAMES)
    val_dataset = SignLanguageKeypointsDataset(X_val_paths, y_val_labels, MAX_FRAMES)
    test_dataset = SignLanguageKeypointsDataset(X_test_paths, y_test_labels, MAX_FRAMES)

    use_pinned_memory = DEVICE.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=use_pinned_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=use_pinned_memory)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=use_pinned_memory)

    print("\nClass distribution in splits:")
    index_to_name = {idx: name for name, idx in label_mapping.items()}
    for split_title, split_labels in [('Train', y_train_labels), ('Validation', y_val_labels), ('Test', y_test_labels)]:
        counts = Counter(int(label) for label in split_labels)
        print(f"{split_title}:")
        for label_idx, count in sorted(counts.items()):
            if label_idx in index_to_name:
                print(f"  {index_to_name[label_idx]} (ID {label_idx}): {count}")
            else:
                print(f"  Unknown Label ID {label_idx}: {count}")

    # Initialize the model
    model_instance = SignLanguageHandsModel(
        num_classes=NUM_CLASSES,
        in_keypoint_features=IN_KEYPOINT_FEATURES,
        num_hand_keypoints=NUM_KP_PER_HAND,
        d_model=D_MODEL,
        nhead=N_HEAD,
        num_layers=NUM_LAYERS
    ).to(DEVICE)

    print(f"Model instantiated with {sum(p.numel() for p in model_instance.parameters() if p.requires_grad)} trainable parameters.")

    trained_model_instance = None
    if len(train_loader) > 0 and len(val_loader) > 0:
        trained_model_instance, history_data = train_model(
            model=model_instance,
            train_loader=train_loader,
            val_loader=val_loader,
            num_classes=NUM_CLASSES,
            device=DEVICE,
            num_epochs=NUM_EPOCHS,
            learning_rate=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY,
            y_train_labels_for_loss=y_train_labels
        )
    else:
        print("Dataloaders not initialized. Skipping training.")

    # Evaluation
    has_test_batches = len(test_loader) > 0
    if trained_model_instance is not None and has_test_batches:
        print("\nEvaluating the model trained in this session.")
        evaluation_results = evaluate_model(trained_model_instance, test_loader, DEVICE, label_mapping, NUM_CLASSES)
    elif os.path.exists('best_sign_model_hands_only.pth') and has_test_batches:
        print("\nLoading best model from 'best_sign_model_hands_only.pth' for evaluation.")
        model_to_evaluate_loaded = SignLanguageHandsModel(
            num_classes=NUM_CLASSES,
            in_keypoint_features=IN_KEYPOINT_FEATURES,
            num_hand_keypoints=NUM_KP_PER_HAND,
            d_model=D_MODEL,
            nhead=N_HEAD,
            num_layers=NUM_LAYERS
        ).to(DEVICE)
        try:
            model_to_evaluate_loaded.load_state_dict(torch.load('best_sign_model_hands_only.pth', map_location=DEVICE))
            evaluation_results = evaluate_model(model_to_evaluate_loaded, test_loader, DEVICE, label_mapping, NUM_CLASSES)
        except FileNotFoundError:
            print("Error: 'best_sign_model_hands_only.pth' not found during evaluation phase.")
    else:
        print("No trained model or test_loader available. Skipping evaluation.")

    # Display the figures created during training and evaluation.
    plt.show()