In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import math
import os
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


import pandas as pd
import numpy as np


final_cols = ['anger_intensity',
'arousal_change_magnitude',
'arousal_deviation_from_neutral',
'arousal_onset_detected',
'arousal_stability',
'attention_focus_distracted',
'attention_focus_focused',
'attention_focus_moderate',
'attention_stability_index',
'avg_blink_duration_sec',
'avg_fixation_duration_sec',
'avg_saccade_amplitude',
'behavioral_complexity',
'behavioral_state_normal',
'blink_completeness_score',
'blink_rate_per_minute',
'blink_rhythm_score',
'cognitive_load_index',
'disengagement_indicator',
'disgust_intensity',
'emotion_quadrant_negative_high_arousal',
'emotion_quadrant_negative_low_arousal',
'emotion_quadrant_neutral',
'emotion_quadrant_positive_high_arousal',
'emotion_transition_frequency',
'engagement_proxy_score',
'engagement_score',
'engagement_state_low',
'expression_arousal_sync',
'expression_change_rate',
'eye_openness_score',
'eyebrow_furrow_intensity',
'eyebrow_raise_intensity',
'facial_asymmetry',
'fixation_count_per_window',
'frown_intensity',
'gaze_consistency_score',
'gaze_direction_center',
'gaze_direction_down',
'gaze_direction_down_left',
'gaze_direction_down_right',
'gaze_direction_left',
'gaze_direction_right',
'gaze_direction_up',
'gaze_direction_up_left',
'gaze_direction_up_right',
'gaze_head_coordination',
'head_Tx_velocity',
'head_Ty_velocity',
'head_Tz_velocity',
'head_gaze_alignment_score',
'head_movement_jerk',
'head_movement_stability',
'head_pitch',
'head_pitch_acceleration',
'head_pitch_velocity',
'head_roll',
'head_roll_acceleration',
'head_roll_velocity',
'head_tilt_direction_center',
'head_tilt_direction_left',
'head_yaw',
'head_yaw_acceleration',
'head_yaw_velocity',
'left_eye_aperture',
'micro_expression_frequency_per_min',
'mouth_openness',
'multimodal_consistency',
'nostril_flare_intensity',
'pupil_size_mean',
'pupil_size_std',
'right_eye_aperture',
'saccade_frequency_per_sec',
'sadness_intensity',
'smile_intensity',
'surprise_intensity',
'temporal_alignment_score',
'valence_deviation_from_neutral',
'valence_stability']


def apply_one_hot_encoding(df: pd.DataFrame, columns_to_encode: list) -> pd.DataFrame:
    """
    Applies one-hot encoding to specified categorical columns in a DataFrame.

    Args:
        df: The input pandas DataFrame containing the features.
        columns_to_encode: A list of column names to be one-hot encoded.

    Returns:
        A new DataFrame with the specified columns one-hot encoded.
        The original categorical columns will be dropped.
    """
    
    # Filter for columns that actually exist in the DataFrame
    existing_columns_to_encode = [col for col in columns_to_encode if col in df.columns]
    
    if not existing_columns_to_encode:
        print("No specified categorical columns found in the DataFrame to encode.")
        return df # Return original df if no columns exist
        

    # Use pd.get_dummies for one-hot encoding
    # prefix ensures unique column names (e.g., 'head_tilt_direction_left')
    # dtype=int ensures the new columns are integers (0 or 1)
    df_encoded = pd.get_dummies(
        df,
        columns=existing_columns_to_encode,
        prefix=existing_columns_to_encode,
        dtype=int
    )
    
    return df_encoded

In [10]:
class Config:
    # Data Configuration - 30 second windows for production
    SEQUENCE_LENGTH = 750  # 30 seconds at 25 FPS
    FEATURE_DIM = 456  # Your total features
    FRAME_RATE = 25
    
    # Model Architecture - Optimized for 2x16GB
    D_MODEL = 768
    N_HEADS = 12
    N_ENCODER_LAYERS = 8
    DROPOUT = 0.1
    
    # SSL Configuration
    SSL_TASKS = ['temporal_prediction', 'behavioral_consistency', 'attention_flow']
    PREDICTION_HORIZON = 50  # Predict next 2 seconds
    CONSISTENCY_WINDOW = 125  # 5 seconds for consistency check
    
    # Training - Full utilization of 2x16GB setup
    BATCH_SIZE = 16  # Maximize GPU utilization
    SSL_EPOCHS = 250  # Testing epochs
    SSL_LR = 2e-4
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Data paths for Kaggle
    TRAIN_PATH = '/kaggle/input/daisee-feature-processed/Train-final'
    TEST_PATH = '/kaggle/input/daisee-feature-processed/Test-final'

class BehavioralDataset(Dataset):
    """Dataset for temporal behavioral data with SSL objectives"""
    
    def __init__(self, data_folder, sequence_length=750, is_train=True, max_files=None):
        self.data_folder = Path(data_folder)
        self.sequence_length = sequence_length
        self.is_train = is_train
        
        # Collect all CSV files
        self.csv_files = list(self.data_folder.glob("*.csv"))
        
        # Limit files for testing if specified
        if max_files:
            self.csv_files = self.csv_files[:max_files]
        
        print(f"Found {len(self.csv_files)} videos in {data_folder}")
        
        # Get feature dimension from first file
        if len(self.csv_files) > 0:
            sample_df = pd.read_csv(self.csv_files[0])
            self.feature_columns = final_cols
            self.actual_feature_dim = len(final_cols)
        else:
            raise ValueError("No CSV files found in the directory")
    

    
    def __len__(self):
        return len(self.csv_files)
    
    def __getitem__(self, idx):
        try:
            csv_path = self.csv_files[idx]
            df = pd.read_csv(csv_path)
            categorical_features_for_encoding = [
                    "head_tilt_direction",
                    "emotion_quadrant",
                    "engagement_state",
                    "attention_focus",
                    "behavioral_state"
                ]

            df = apply_one_hot_encoding(df, categorical_features_for_encoding)
            
            if len(df) == 0:
                print(f"Warning: Empty dataframe in {csv_path}")
                # Return dummy data
                features = np.zeros((self.sequence_length, self.actual_feature_dim), dtype=np.float32)
            else:
                # Use the pre-determined feature columns for consistency
                for col in self.feature_columns:
                    if col not in df.columns:
                        df[col] = 0

                feature_data = df[self.feature_columns].copy()
                
                # Convert all columns to numeric, forcing errors to NaN
                for col in feature_data.columns:
                    feature_data[col] = pd.to_numeric(feature_data[col], errors='coerce')

                feature_data = feature_data.fillna(feature_data.mean()).fillna(0)
                
                # Handle variable length videos
                if len(feature_data) < self.sequence_length:
                    # Pad shorter videos by repeating last frame
                    padding_needed = self.sequence_length - len(feature_data)
                    if len(feature_data) > 0:
                        last_values = feature_data.iloc[-1:].values
                        padding = np.repeat(last_values, padding_needed, axis=0)
                        features = np.vstack([feature_data.values, padding])
                    else:
                        features = np.zeros((self.sequence_length, len(feature_data.columns)), dtype=np.float32)
                        
                elif len(feature_data) > self.sequence_length:
                    # For longer videos, sample random window during training
                    if self.is_train:
                        start_idx = np.random.randint(0, len(feature_data) - self.sequence_length + 1)
                        features = feature_data.iloc[start_idx:start_idx + self.sequence_length].values
                    else:
                        # Use first window for validation
                        features = feature_data.iloc[:self.sequence_length].values
                else:
                    features = feature_data.values
                
                # Ensure we have the right shape and type
                features = features.astype(np.float32)
                
                # Handle any remaining NaN or inf values
                features = np.nan_to_num(features, nan=0.0, posinf=1.0, neginf=-1.0)
                
                # Normalize features to prevent extreme values
                features = np.clip(features, -10, 10)
            
            # Create SSL targets
            ssl_targets = self._create_ssl_targets(features)
            
            return {
                'features': torch.tensor(features),
                'ssl_targets': ssl_targets,
                'video_name': csv_path.stem
            }
            
        except Exception as e:
            print(f"Error processing {csv_path}: {str(e)}")
            dummy_features = np.zeros((self.sequence_length, self.actual_feature_dim), dtype=np.float32)
            ssl_targets = self._create_ssl_targets(dummy_features)
            return {
                'features': torch.tensor(dummy_features),
                'ssl_targets': ssl_targets,
                'video_name': f'error_{idx}'
            }
    
    def _create_ssl_targets(self, features):
        """Create multiple SSL objectives with fixed tensor dimensions"""
        targets = {}
        seq_len, feat_dim = features.shape
        
        # 1. Temporal Prediction: Fixed to always produce same sequence length
        horizon = Config.PREDICTION_HORIZON
        
        if seq_len > horizon:
            # Take input context from beginning, predict future frames
            context_frames = features[:-horizon]  # Remove last 'horizon' frames
            future_frames = features[horizon:]    # Remove first 'horizon' frames
            
            # Ensure both have same length (should be seq_len - horizon)
            min_len = min(len(context_frames), len(future_frames))
            targets['temporal_context'] = torch.tensor(context_frames[:min_len], dtype=torch.float32)
            targets['temporal_future'] = torch.tensor(future_frames[:min_len], dtype=torch.float32)
        else:
            # For short sequences, predict next frame
            targets['temporal_context'] = torch.tensor(features[:-1] if seq_len > 1 else features, dtype=torch.float32)
            targets['temporal_future'] = torch.tensor(features[1:] if seq_len > 1 else features, dtype=torch.float32)
        
        # 2. Behavioral Consistency: Use different feature subsets
        # Split features into different behavioral modalities
        third = feat_dim // 3
        attention_features = features[:, :third] if third > 0 else features
        engagement_features = features[:, third:2*third] if third > 0 else features
        emotion_features = features[:, 2*third:] if third > 0 else features
        
        targets['attention_trajectory'] = torch.tensor(attention_features, dtype=torch.float32)
        targets['engagement_trajectory'] = torch.tensor(engagement_features, dtype=torch.float32)
        targets['emotion_trajectory'] = torch.tensor(emotion_features, dtype=torch.float32)
        
        # 3. Cross-modal alignment
        targets['cross_modal_pairs'] = (
            torch.tensor(attention_features, dtype=torch.float32),
            torch.tensor(engagement_features, dtype=torch.float32)
        )
        
        return targets

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):  # Increased for longer sequences
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]

class BehavioralTransformer(nn.Module):
    """Transformer designed for behavioral temporal sequences - Full scale"""
    
    def __init__(self, config, actual_feature_dim):
        super().__init__()
        self.config = config
        self.actual_feature_dim = actual_feature_dim
        
        # Input projection with layer norm
        self.input_projection = nn.Sequential(
            nn.Linear(actual_feature_dim, config.D_MODEL),
            nn.LayerNorm(config.D_MODEL),
            nn.Dropout(config.DROPOUT)
        )
        
        self.pos_encoding = PositionalEncoding(config.D_MODEL, max_len=config.SEQUENCE_LENGTH)
        
        # Transformer encoder - Full scale
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.D_MODEL,
            nhead=config.N_HEADS,
            dim_feedforward=config.D_MODEL * 4,
            dropout=config.DROPOUT,
            batch_first=True,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, config.N_ENCODER_LAYERS)
        
        # SSL heads with better architectures
        self.temporal_predictor = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL),
            nn.GELU(),
            nn.LayerNorm(config.D_MODEL),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.GELU(),
            nn.Linear(config.D_MODEL // 2, actual_feature_dim)
        )
        
        self.consistency_projector = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.GELU(),
            nn.LayerNorm(config.D_MODEL // 2),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL // 2, 256)  # Consistency embedding
        )
        
        self.attention_flow_predictor = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.GELU(),
            nn.LayerNorm(config.D_MODEL // 2),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL // 2, max(actual_feature_dim // 3, 64))  # Attention flow
        )
        
        # Cross-modal predictors
        self.cross_modal_predictor = nn.Sequential(
            nn.Linear(config.D_MODEL, config.D_MODEL // 2),
            nn.GELU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(config.D_MODEL // 2, max(actual_feature_dim // 3, 64))
        )
    
    def forward(self, x, return_embeddings=False):
        # x shape: (batch, sequence, features)
        batch_size, seq_len, _ = x.shape
        
        # Project to model dimension
        x = self.input_projection(x)  # (batch, seq, d_model)
        
        # Add positional encoding
        x = x.transpose(0, 1)  # (seq, batch, d_model)
        x = self.pos_encoding(x)
        x = x.transpose(0, 1)  # (batch, seq, d_model)
        
        # Transformer encoding
        embeddings = self.transformer(x)  # (batch, seq, d_model)
        
        if return_embeddings:
            return embeddings
        
        # SSL predictions
        ssl_outputs = {}
        
        # Temporal prediction - predict for reduced sequence length
        prediction_length = seq_len - Config.PREDICTION_HORIZON
        if prediction_length > 0:
            temporal_embeddings = embeddings[:, :prediction_length]  # Match context length
            ssl_outputs['temporal'] = self.temporal_predictor(temporal_embeddings)
        else:
            # Fallback for short sequences
            ssl_outputs['temporal'] = self.temporal_predictor(embeddings[:, :-1] if seq_len > 1 else embeddings)
        
        # Behavioral consistency (use mean pooling)
        ssl_outputs['consistency'] = self.consistency_projector(embeddings.mean(dim=1))
        
        # Attention flow
        ssl_outputs['attention_flow'] = self.attention_flow_predictor(embeddings)
        
        # Cross-modal prediction
        ssl_outputs['cross_modal'] = self.cross_modal_predictor(embeddings)
        
        return ssl_outputs

class BehavioralSSLLoss(nn.Module):
    """Multi-objective SSL loss for behavioral data with fixed tensor handling"""
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.mse_loss = nn.MSELoss()
        self.cosine_loss = nn.CosineEmbeddingLoss()
        self.huber_loss = nn.HuberLoss(delta=1.0)
        
    def forward(self, predictions, targets):
        total_loss = 0
        loss_components = {}
        batch_size = None
        device = None
        
        # Get batch size and device from any available tensor
        for key, value in predictions.items():
            if isinstance(value, torch.Tensor):
                batch_size = value.size(0)
                device = value.device
                break
        
        if batch_size is None or device is None:
            # Create a default tensor on CPU if no predictions available
            device = torch.device('cpu')
            return torch.tensor(0.0, device=device), {}
        
        # 1. Temporal Prediction Loss (main objective)
        if 'temporal_future' in targets and 'temporal' in predictions:
            future_frames = targets['temporal_future']
            pred_frames = predictions['temporal']
            
            # Ensure both tensors are on the same device
            if future_frames.device != device:
                future_frames = future_frames.to(device)
            
            # Ensure batch dimension matches
            if future_frames.size(0) == batch_size and pred_frames.size(0) == batch_size:
                # Match sequence lengths
                min_seq_len = min(future_frames.size(1), pred_frames.size(1))
                min_feat_dim = min(future_frames.size(2), pred_frames.size(2))
                
                if min_seq_len > 0 and min_feat_dim > 0:
                    temporal_loss = self.huber_loss(
                        pred_frames[:, :min_seq_len, :min_feat_dim], 
                        future_frames[:, :min_seq_len, :min_feat_dim]
                    )
                    loss_components['temporal'] = temporal_loss
                    total_loss += 2.0 * temporal_loss  # Higher weight for main task
        
        # 2. Behavioral Consistency Loss
        if 'attention_trajectory' in targets and 'consistency' in predictions:
            try:
                consistency_loss = self._compute_consistency_loss(
                    predictions['consistency'], 
                    targets['attention_trajectory'],
                    device
                )
                loss_components['consistency'] = consistency_loss
                total_loss += 0.5 * consistency_loss
            except Exception as e:
                print(f"Consistency loss error: {e}")
        
        # 3. Attention Flow Smoothness Loss
        if 'attention_flow' in predictions:
            try:
                flow_loss = self._compute_flow_smoothness_loss(
                    predictions['attention_flow']
                )
                loss_components['flow'] = flow_loss
                total_loss += 0.3 * flow_loss
            except Exception as e:
                print(f"Flow loss error: {e}")
        
        # 4. Cross-modal alignment loss
        if 'cross_modal_pairs' in targets and 'cross_modal' in predictions:
            try:
                cross_modal_loss = self._compute_cross_modal_loss(
                    predictions['cross_modal'],
                    targets['cross_modal_pairs'],
                    device
                )
                loss_components['cross_modal'] = cross_modal_loss
                total_loss += 0.4 * cross_modal_loss
            except Exception as e:
                print(f"Cross-modal loss error: {e}")
        
        # Ensure we have at least some loss
        if total_loss == 0:
            total_loss = torch.tensor(0.001, device=device, requires_grad=True)
        
        return total_loss, loss_components
    
    def _compute_consistency_loss(self, embeddings, attention_trajectories, device):
        """Encourage similar attention patterns to have similar embeddings"""
        batch_size = embeddings.size(0)
        
        if batch_size < 2:
            return torch.tensor(0.0, device=device)
        
        # Ensure attention_trajectories is on the correct device
        if attention_trajectories.device != device:
            attention_trajectories = attention_trajectories.to(device)
        
        # Compute attention similarity matrix
        attention_flat = attention_trajectories.view(batch_size, -1)
        attention_sim = F.cosine_similarity(
            attention_flat.unsqueeze(1), 
            attention_flat.unsqueeze(0), 
            dim=2
        )
        
        # Compute embedding similarity matrix
        embedding_sim = F.cosine_similarity(
            embeddings.unsqueeze(1), 
            embeddings.unsqueeze(0), 
            dim=2
        )
        
        # Loss: embedding similarity should match attention similarity
        consistency_loss = self.mse_loss(embedding_sim, attention_sim)
        return consistency_loss
    
    def _compute_flow_smoothness_loss(self, attention_flow):
        """Encourage smooth attention transitions"""
        if attention_flow.size(1) <= 1:
            return torch.tensor(0.0, device=attention_flow.device)
            
        # Compute temporal differences
        flow_diff = attention_flow[:, 1:] - attention_flow[:, :-1]
        
        # Penalize large jumps in attention flow
        smoothness_loss = torch.mean(torch.abs(flow_diff))
        return smoothness_loss
    
    def _compute_cross_modal_loss(self, predictions, target_pairs, device):
        """Cross-modal alignment loss"""
        attention_targets, engagement_targets = target_pairs
        
        # Ensure all tensors are on the correct device
        if attention_targets.device != device:
            attention_targets = attention_targets.to(device)
        if engagement_targets.device != device:
            engagement_targets = engagement_targets.to(device)
        
        # Predict engagement from attention features
        pred_engagement = predictions
        target_engagement = engagement_targets.mean(dim=1)  # Pool over sequence: [batch, features]
        
        # Pool predictions over sequence to match target dimensions
        pred_engagement_pooled = pred_engagement.mean(dim=1)  # [batch, seq, features] -> [batch, features]
        
        # Match dimensions
        min_dim = min(pred_engagement_pooled.size(-1), target_engagement.size(-1))
        
        if min_dim > 0:
            cross_modal_loss = self.mse_loss(
                pred_engagement_pooled[:, :min_dim], 
                target_engagement[:, :min_dim]
            )
            return cross_modal_loss
        else:
            # Return zero loss tensor on the correct device
            return torch.tensor(0.0, device=device, requires_grad=True)

class BehavioralSSLTrainer:
    """SSL Trainer for behavioral data - Full GPU utilization"""
    
    def __init__(self, config, actual_feature_dim):
        self.config = config
        self.device = config.DEVICE
        
        # Model
        self.model = BehavioralTransformer(config, actual_feature_dim).to(self.device)
        
        # Data parallel if multiple GPUs - maximize utilization
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
            self.model = nn.DataParallel(self.model)
        
        # Loss
        self.criterion = BehavioralSSLLoss(config)
        
        # Optimizer - higher learning rate for larger model
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), 
            lr=config.SSL_LR,
            weight_decay=1e-4,
            betas=(0.9, 0.95),  # Better for transformers
            eps=1e-8
        )
        
        # Scheduler with warmup
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config.SSL_LR,
            epochs=config.SSL_EPOCHS,
            steps_per_epoch=100,  # Approximate
            pct_start=0.1,
            anneal_strategy='cos'
        )
        
        # Gradient scaler for mixed precision
        self.scaler = torch.cuda.amp.GradScaler() if self.device == 'cuda' else None
    
    def train_epoch(self, dataloader):
        self.model.train()
        total_loss = 0
        loss_components_sum = {}
        num_batches = 0
        
        pbar = tqdm(dataloader, desc='Training')
        for batch_idx, batch in enumerate(pbar):
            try:
                features = batch['features'].to(self.device, non_blocking=True)
                ssl_targets = batch['ssl_targets']
                
                # Move targets to device - Fixed to handle all tensor types properly
                for key, value in ssl_targets.items():
                    if isinstance(value, torch.Tensor):
                        ssl_targets[key] = value.to(self.device, non_blocking=True)
                    elif isinstance(value, tuple):
                        ssl_targets[key] = tuple(
                            v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v 
                            for v in value
                        )
                
                # Forward pass with mixed precision
                if self.scaler is not None:
                    with torch.cuda.amp.autocast():
                        predictions = self.model(features)
                        loss, loss_components = self.criterion(predictions, ssl_targets)
                    
                    # Backward pass
                    self.optimizer.zero_grad()
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    predictions = self.model(features)
                    loss, loss_components = self.criterion(predictions, ssl_targets)
                    
                    # Backward pass
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
                
                self.scheduler.step()
                
                total_loss += loss.item()
                num_batches += 1
                
                # Accumulate loss components
                for key, value in loss_components.items():
                    if key not in loss_components_sum:
                        loss_components_sum[key] = 0
                    loss_components_sum[key] += value.item()
                
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}',
                    'gpu_mem': f'{torch.cuda.memory_allocated()/1e9:.1f}GB' if torch.cuda.is_available() else 'N/A'
                })
                
            except Exception as e:
                print(f"Error in batch {batch_idx}: {str(e)}")
                continue
        
        return total_loss / max(num_batches, 1), loss_components_sum
    
    def validate(self, dataloader):
        self.model.eval()
        total_loss = 0
        num_batches = 0
        
        with torch.no_grad():
            pbar = tqdm(dataloader, desc='Validation')
            for batch in pbar:
                try:
                    features = batch['features'].to(self.device, non_blocking=True)
                    ssl_targets = batch['ssl_targets']
                    
                    # Move targets to device - Fixed to handle all tensor types properly
                    for key, value in ssl_targets.items():
                        if isinstance(value, torch.Tensor):
                            ssl_targets[key] = value.to(self.device, non_blocking=True)
                        elif isinstance(value, tuple):
                            ssl_targets[key] = tuple(
                                v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v 
                                for v in value
                            )
                    
                    if self.scaler is not None:
                        with torch.cuda.amp.autocast():
                            predictions = self.model(features)
                            loss, _ = self.criterion(predictions, ssl_targets)
                    else:
                        predictions = self.model(features)
                        loss, _ = self.criterion(predictions, ssl_targets)
                    
                    total_loss += loss.item()
                    num_batches += 1
                    
                    pbar.set_postfix({
                        'val_loss': f'{loss.item():.4f}',
                        'gpu_mem': f'{torch.cuda.memory_allocated()/1e9:.1f}GB' if torch.cuda.is_available() else 'N/A'
                    })
                    
                except Exception as e:
                    print(f"Error in validation batch: {str(e)}")
                    continue
        
        return total_loss / max(num_batches, 1)
    
    def save_model(self, path):
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        torch.save({
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'actual_feature_dim': model_to_save.actual_feature_dim
        }, path)
    
    def load_model(self, path):
        checkpoint = torch.load(path, map_location=self.device)
        if hasattr(self.model, 'module'):
            self.model.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])



def test_dataset_loading(data_path, max_files=5):
    """Test dataset loading with detailed diagnostics"""
    print(f"Testing dataset loading from: {data_path}")
    
    if not os.path.exists(data_path):
        print(f"Path does not exist: {data_path}")
        return False
    
    try:
        test_dataset = BehavioralDataset(data_path, sequence_length=750, is_train=True, max_files=max_files)
        
        if len(test_dataset) == 0:
            print("No data found in dataset")
            return False
        
        # Test loading multiple samples
        print(f"\nTesting {min(3, len(test_dataset))} samples:")
        for i in range(min(3, len(test_dataset))):
            sample = test_dataset[i]
            print(f"Sample {i}:")
            print(f"  Features shape: {sample['features'].shape}")
            print(f"  Video name: {sample['video_name']}")
            print(f"  SSL targets keys: {list(sample['ssl_targets'].keys())}")
            
            # Check temporal target shapes
            if 'temporal_future' in sample['ssl_targets']:
                print(f"  Temporal future shape: {sample['ssl_targets']['temporal_future'].shape}")
            if 'temporal_context' in sample['ssl_targets']:
                print(f"  Temporal context shape: {sample['ssl_targets']['temporal_context'].shape}")
            
            # Check for actual data (not all zeros)
            if torch.sum(torch.abs(sample['features'])) > 0:
                print(f"  ✓ Contains non-zero data")
            else:
                print(f"  ⚠ Warning: All zeros detected")
        
        return True
        
    except Exception as e:
        print(f"Error testing dataset: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

In [11]:




def main():
    config = Config()
    
    print("=== Behavioral SSL Training - Full Scale ===")
    print(f"Device: {config.DEVICE}")
    print(f"Available GPUs: {torch.cuda.device_count()}")     
    print(f"Sequence Length: {config.SEQUENCE_LENGTH} frames (30 seconds)")
    print(f"Batch Size: {config.BATCH_SIZE}")
    print(f"Model Size: D_MODEL={config.D_MODEL}, Layers={config.N_ENCODER_LAYERS}")
    
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_properties(i).name}")
            print(f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f}GB")
    
    # Test dataset loading first
    print("\n=== Testing Dataset Loading ===")
    if not test_dataset_loading(config.TRAIN_PATH, max_files=3):
        print("Dataset loading test failed. Please check your data paths and format.")
        return
    
    print("\n=== Creating Datasets ===")
    try:
 
        train_dataset = BehavioralDataset(
            config.TRAIN_PATH, 
            sequence_length=config.SEQUENCE_LENGTH, 
            is_train=True
        )
        
        val_dataset = BehavioralDataset(
            config.TEST_PATH, 
            sequence_length=config.SEQUENCE_LENGTH, 
            is_train=False
        )
        
        print(f"Train samples: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")
        
        # Get actual feature dimension
        sample = train_dataset[0]
        actual_feature_dim = sample['features'].shape[-1]
        print(f"Actual feature dimension: {actual_feature_dim}")
        
    except Exception as e:
        print(f"Error creating datasets: {str(e)}")
        return
    
    # Create data loaders with optimized settings
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=2,  # Reduced for stability
        pin_memory=True if config.DEVICE == 'cuda' else False,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True if config.DEVICE == 'cuda' else False,
        drop_last=False
    )
    
    print(f"Train batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    
    # Initialize trainer
    print("\n=== Initializing Model ===")
    trainer = BehavioralSSLTrainer(config, actual_feature_dim)
    
    # Model info
    total_params = sum(p.numel() for p in trainer.model.parameters())
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: ~{total_params * 4 / 1e9:.2f}GB (FP32)")
    
    # Training loop
    print("\n=== Starting SSL Training ===")
    
    # --- EARLY STOPPING PARAMETERS ---
    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stopping_patience = 5  # Number of epochs to wait for improvement
    min_delta = 0.005            # Minimum change in validation loss to count as an improvement
    # --- END EARLY STOPPING PARAMETERS ---

    for epoch in range(config.SSL_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{config.SSL_EPOCHS}")
        print("-" * 50)
        
        # Training
        train_loss, train_components = trainer.train_epoch(train_loader)
        
        # Validation
        val_loss = trainer.validate(val_loader)
        
        # Print epoch results
        print(f"\nEpoch {epoch + 1} Results:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        
        if train_components:
            print("Training Loss Components:")
            for component, value in train_components.items():
                avg_value = value / len(train_loader)
                print(f"  {component}: {avg_value:.4f}")
        
        # --- EARLY STOPPING LOGIC ---
        # Check if validation loss has improved by at least min_delta
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            epochs_no_improve = 0 # Reset patience
            print(f"New best validation loss: {val_loss:.4f}. Saving model...")
            trainer.save_model('best_behavioral_ssl_model.pth') # Save the best model
            print("Model saved!")
        else:
            epochs_no_improve += 1
            print(f"Validation loss did not improve by {min_delta:.4f}. Patience: {epochs_no_improve}/{early_stopping_patience}")

        if epochs_no_improve >= early_stopping_patience:
            print(f"\nEarly stopping triggered after {early_stopping_patience} epochs without significant improvement.")
            break # Exit the training loop
        # --- END EARLY STOPPING LOGIC ---
        
        # Memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.1f}GB allocated, "
                  f"{torch.cuda.memory_reserved()/1e9:.1f}GB reserved")
    
    print("\n=== Training Complete ===")
    print(f"Final best validation loss: {best_val_loss:.4f}")
    print("Best model checkpoint saved as 'best_behavioral_ssl_model.pth'")
    
    # Test model inference
    print("\n=== Testing Model Inference ===")
    try:
        # Load the best model if it was saved, before inference
        print("Loading best model for inference test...")
        trainer.model.load_state_dict(torch.load('best_behavioral_ssl_model.pth'))
        trainer.model.eval()
        
        test_batch = next(iter(val_loader))
        test_features = test_batch['features'].to(config.DEVICE)
        
        with torch.no_grad():
            # Test embeddings extraction
            embeddings = trainer.model(test_features, return_embeddings=True)
            print(f"Embeddings shape: {embeddings.shape}")
            
            # Test SSL predictions
            predictions = trainer.model(test_features)
            print("SSL prediction shapes:")
            for key, value in predictions.items():
                print(f"  {key}: {value.shape}")
        
        print("✓ Model inference test successful!")
        
    except Exception as e:
        print(f"Model inference test failed: {str(e)}")


if __name__ == "__main__":
    main()

=== Behavioral SSL Training - Full Scale ===
Device: cuda
Available GPUs: 2
Sequence Length: 750 frames (30 seconds)
Batch Size: 16
Model Size: D_MODEL=768, Layers=8
GPU 0: Tesla T4
  Memory: 15.8GB
GPU 1: Tesla T4
  Memory: 15.8GB

=== Testing Dataset Loading ===
Testing dataset loading from: /kaggle/input/daisee-feature-processed/Train-final
Found 3 videos in /kaggle/input/daisee-feature-processed/Train-final

Testing 3 samples:
Sample 0:
  Features shape: torch.Size([750, 79])
  Video name: 4100241051
  SSL targets keys: ['temporal_context', 'temporal_future', 'attention_trajectory', 'engagement_trajectory', 'emotion_trajectory', 'cross_modal_pairs']
  Temporal future shape: torch.Size([700, 79])
  Temporal context shape: torch.Size([700, 79])
  ✓ Contains non-zero data
Sample 1:
  Features shape: torch.Size([750, 79])
  Video name: 1100072077
  SSL targets keys: ['temporal_context', 'temporal_future', 'attention_trajectory', 'engagement_trajectory', 'emotion_trajectory', 'cross_mod

Training:   8%|▊         | 32/406 [00:17<03:25,  1.82it/s, loss=5.1867, lr=8.08e-06, gpu_mem=3.9GB]


KeyboardInterrupt: 