# 360-degree audio-visual content-based viewport prediction
LSTM-EKF hybrid model


## 1. Setup & Dependencies

In [None]:
import copy
import os
import warnings
from pathlib import Path
from typing import Tuple, List, Optional, Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import wavfile
from torch.utils.data import Dataset, DataLoader
warnings.filterwarnings('ignore')

# Reproducibility: fix the NumPy and PyTorch seeds used throughout the notebook.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Configuration

In [None]:
class Config:
    """Paths and hyper-parameters shared by every stage of the notebook.

    All values are plain class attributes, so ``Config.X`` and ``config.X``
    are interchangeable. ``DATA_ROOT`` can be overridden with the
    ``VIEWPORT_DATA_ROOT`` environment variable; when it is unset (or empty)
    the original local development path is used.
    """

    # Paths. The env-var override keeps the notebook usable on other machines
    # without editing source; the default preserves the original layout.
    DATA_ROOT = Path(
        os.environ.get("VIEWPORT_DATA_ROOT")
        or "/home/kirill/Projects/audioVisualVRAttentionModel/data"
    )
    HEAD_DATA_DIR = DATA_ROOT / "head"
    AUDIO_DIR = DATA_ROOT / "360videos_with_ambisonic"

    DEV_MODE = True
    DEV_VIDEO_ID = "0001"

    # Audio settings
    USE_AUDIO = True
    AUDIO_SAMPLE_RATE = 48000  # D-SAV360 ambisonic is 48kHz
    AUDIO_FEATURE_DIM = 9  # 1 omnidirectional + 8 directional sectors
    VIDEO_FPS = 60

    # Temporal windows; step counts are derived from the head-tracking rate.
    PREDICTION_HORIZON_SEC = 2.5
    INPUT_HISTORY_SEC = 2.0
    SAMPLE_RATE_HZ = 90  # Head tracking sample rate
    PREDICTION_STEPS = int(PREDICTION_HORIZON_SEC * SAMPLE_RATE_HZ)
    INPUT_STEPS = int(INPUT_HISTORY_SEC * SAMPLE_RATE_HZ)

    # LSTM correction network
    LSTM_HIDDEN_SIZE = 128
    LSTM_NUM_LAYERS = 2
    LSTM_DROPOUT = 0.2

    # EKF noise scales (diagonal process / measurement covariances)
    EKF_PROCESS_NOISE = 0.01
    EKF_MEASUREMENT_NOISE = 0.001

    # Optimisation
    BATCH_SIZE = 64
    LEARNING_RATE = 1e-3
    NUM_EPOCHS = 50
    EARLY_STOPPING_PATIENCE = 10

    # Evaluation horizons (seconds) and participant split ratios
    # (the test split receives the remainder).
    EVAL_HORIZONS = [0.5, 1.0, 1.5, 2.0, 2.5]
    TRAIN_RATIO = 0.7
    VAL_RATIO = 0.15


config = Config()
print(f"Prediction: {config.INPUT_STEPS} steps -> {config.PREDICTION_STEPS} steps")
print(f"Audio features: {config.USE_AUDIO} (dim={config.AUDIO_FEATURE_DIM})")

## 3. Spherical Geometry Utilities

In [None]:
class SphericalUtils:
    """NumPy helpers for viewing directions on the unit sphere.

    Shared convention: ``u`` maps to azimuth ``theta = 2*pi*u`` and ``v`` maps
    to elevation ``phi = (v - 0.5)*pi``; unit vectors are ``(x, y, z)`` on the
    last axis.
    """

    @staticmethod
    def uv_to_unit_vector(u, v):
        """Convert equirectangular (u, v) coordinates to unit vectors of shape (..., 3)."""
        theta = u * 2 * np.pi
        phi = (v - 0.5) * np.pi
        x = np.cos(phi) * np.cos(theta)
        y = np.cos(phi) * np.sin(theta)
        z = np.sin(phi)
        return np.stack([x, y, z], axis=-1)

    @staticmethod
    def unit_vector_to_uv(p):
        """Convert unit vectors (..., 3) back to (u, v); inverse of uv_to_unit_vector.

        ``u`` is wrapped into [0, 1): a tiny negative azimuth plus 2*pi can
        round to exactly 2*pi, which would otherwise yield u == 1.0.
        """
        x, y, z = p[..., 0], p[..., 1], p[..., 2]
        theta = np.arctan2(y, x)
        theta = np.where(theta < 0, theta + 2 * np.pi, theta)
        phi = np.arcsin(np.clip(z, -1, 1))
        return np.mod(theta / (2 * np.pi), 1.0), phi / np.pi + 0.5

    @staticmethod
    def tangent_velocity(p_t, p_next):
        """Log map at ``p_t``: tangent vector pointing towards ``p_next`` whose
        length is the great-circle angle (radians) between the two points.

        The angle is computed as atan2(|tangent|, dot), which stays accurate
        for tiny steps where arccos(dot) loses precision. The tangent is scaled
        by angle/|tangent| directly, so no epsilon shrinks small steps.
        Identical inputs give the zero vector.
        """
        dot = np.clip(np.sum(p_t * p_next, axis=-1, keepdims=True), -1.0, 1.0)
        tangent = p_next - dot * p_t
        tangent_norm = np.linalg.norm(tangent, axis=-1, keepdims=True)
        angle = np.arctan2(tangent_norm, dot)
        # Where the tangent vanishes there is no direction to move in: scale 0.
        scale = np.divide(angle, tangent_norm,
                          out=np.zeros_like(angle), where=tangent_norm > 0)
        return tangent * scale

    @staticmethod
    def exp_map(p, v):
        """Move from unit vector ``p`` along ``v`` by angle |v|; result is re-normalized."""
        v_norm = np.linalg.norm(v, axis=-1, keepdims=True) + 1e-8
        result = np.cos(v_norm) * p + np.sin(v_norm) * (v / v_norm)
        return result / (np.linalg.norm(result, axis=-1, keepdims=True) + 1e-8)

    @staticmethod
    def normalize(p):
        """Scale vectors on the last axis to unit length (eps avoids division by zero)."""
        return p / (np.linalg.norm(p, axis=-1, keepdims=True) + 1e-8)

In [None]:
class SphericalUtilsTorch:
    """Torch counterparts of the sphere helpers, usable inside the model graph."""

    @staticmethod
    def normalize(p):
        """Scale vectors on the last axis to unit length (eps avoids division by zero)."""
        return p / (torch.norm(p, dim=-1, keepdim=True) + 1e-8)

    @staticmethod
    def exp_map(p, v):
        """Move from unit vector ``p`` along ``v`` by angle |v|; result is re-normalized.

        This is a great-circle step only when ``v`` is tangent to ``p``.
        """
        v_norm = torch.norm(v, dim=-1, keepdim=True) + 1e-8
        result = torch.cos(v_norm) * p + torch.sin(v_norm) * (v / v_norm)
        return SphericalUtilsTorch.normalize(result)

    @staticmethod
    def cosine_loss(p_pred, p_target):
        """Return 1 - <p_pred, p_target> per vector pair (1 - cos(angle) for unit vectors)."""
        return 1.0 - torch.sum(p_pred * p_target, dim=-1)

    @staticmethod
    def angular_error_degrees(p_pred, p_target):
        """Great-circle angle between ``p_pred`` and ``p_target`` in degrees.

        Uses atan2(|a x b|, a . b). Unlike acos of a clamped dot product, this
        keeps precision for very small and near-180 degree angles (acos
        quantizes sub-0.02 degree errors to 0 in float32) and is insensitive
        to the input lengths.
        """
        a, b = torch.broadcast_tensors(p_pred, p_target)
        sin_part = torch.norm(torch.cross(a, b, dim=-1), dim=-1)
        cos_part = torch.sum(a * b, dim=-1)
        return torch.atan2(sin_part, cos_part) * (180.0 / np.pi)

## 3.5 Audio Feature Extraction

In [None]:
class AudioFeatureExtractor:
    """Extract spatial audio features from D-SAV360 ambisonic WAV files.

    D-SAV360 uses first-order ambisonics (FOA) with 4 channels, read here in
    the order:
    - W: Omnidirectional (pressure)
    - X: Front-back axis
    - Y: Left-right axis
    - Z: Up-down axis

    NOTE(review): AmbiX-formatted files store FOA in ACN order (W, Y, Z, X);
    confirm the dataset's channel order before relying on the directional
    features.

    A feature vector is ``[omni, sector_0, ..., sector_{num_sectors-1}]``:
    mean-square energies over a short window around a timestamp, divided by
    their maximum so the largest entry is 1. All-zero vectors are returned
    when audio is unavailable or the window falls outside the recording.
    """

    def __init__(self, audio_dir, video_id, num_sectors=8,
                 audio_sample_rate=48000, head_sample_rate=90):
        """Store settings and try to load ``audio_dir/video_id/video_id.wav``.

        ``audio_sample_rate`` is the expected rate; after a successful load
        the rate read from the file is used for time-to-sample conversion.
        ``head_sample_rate`` is stored but not used by this class.
        """
        self.audio_dir = Path(audio_dir)
        self.video_id = video_id
        self.num_sectors = num_sectors
        self.audio_sample_rate = audio_sample_rate
        self.head_sample_rate = head_sample_rate
        self.feature_dim = num_sectors + 1  # omnidirectional + sectors

        # Defaults so both attributes always exist, even if loading fails.
        self.audio_data = None
        self.actual_sample_rate = audio_sample_rate
        self._load_audio()

    @staticmethod
    def _pcm_to_float(audio):
        """Convert integer PCM samples to float32 in about [-1, 1]; float data is returned as-is."""
        if audio.dtype == np.uint8:
            # 8-bit WAV PCM is unsigned with a mid-point of 128.
            return (audio.astype(np.float32) - 128.0) / 128.0
        if audio.dtype == np.int16:
            return audio.astype(np.float32) / 32768.0
        if audio.dtype == np.int32:
            return audio.astype(np.float32) / 2147483648.0
        return audio

    def _load_audio(self):
        """Load the video's ambisonic WAV into ``self.audio_data`` as (samples, channels).

        A missing or unreadable file leaves ``audio_data`` as None (best effort).
        """
        # D-SAV360 structure: audio_dir/video_id/video_id.wav
        wav_path = self.audio_dir / self.video_id / f"{self.video_id}.wav"

        if not wav_path.exists():
            print(f"Warning: Audio file not found at {wav_path}")
            return

        try:
            sample_rate, audio = wavfile.read(wav_path)
            audio = self._pcm_to_float(audio)

            # Expect 4 channels for FOA (W, X, Y, Z)
            if audio.ndim == 1:
                print(f"Warning: Mono audio, expected 4-channel ambisonics")
                audio = audio.reshape(-1, 1)

            if sample_rate != self.audio_sample_rate:
                print(f"Note: audio is {sample_rate}Hz, expected "
                      f"{self.audio_sample_rate}Hz; using the file rate")

            self.audio_data = audio
            self.actual_sample_rate = sample_rate
            print(f"Loaded audio: {wav_path.name}, {audio.shape}, {sample_rate}Hz")

        except Exception as e:
            # Deliberate best effort: features fall back to zeros.
            print(f"Error loading audio: {e}")
            self.audio_data = None

    def _time_to_audio_sample(self, t_seconds):
        """Convert head tracking timestamp to audio sample index."""
        return int(t_seconds * self.actual_sample_rate)

    def _get_directional_energy(self, w, x, y, z, theta, phi):
        """Compute energy in a specific direction from ambisonics.

        Args:
            w, x, y, z: Ambisonic channels (scalars or arrays)
            theta: Azimuth angle (0 = front, pi/2 = left)
            phi: Elevation angle (0 = horizon, pi/2 = up)

        Returns:
            Mean of the squared decoded signal (directional energy estimate).
        """
        # Unit direction vector for (theta, phi).
        dx = np.cos(phi) * np.cos(theta)
        dy = np.cos(phi) * np.sin(theta)
        dz = np.sin(phi)

        # Decode: pressure + velocity components weighted by direction
        decoded = w + dx * x + dy * y + dz * z

        return np.mean(decoded ** 2)  # Energy

    def extract_features(self, u, v, t_seconds, window_ms=50):
        """Extract audio features at viewport position and time.

        Args:
            u, v: Normalized viewport coordinates (0-1)
            t_seconds: Timestamp in seconds
            window_ms: Analysis window in milliseconds, centered on t_seconds

        Returns:
            float32 vector [omnidirectional_energy, sector_0, ..., sector_{n-1}],
            scaled so its maximum is 1 (all zeros if nothing can be computed).
        """
        if self.audio_data is None:
            return np.zeros(self.feature_dim, dtype=np.float32)

        # Window centered on the timestamp, clipped to the recording.
        center_sample = self._time_to_audio_sample(t_seconds)
        window_samples = int(window_ms / 1000 * self.actual_sample_rate)
        start = max(0, center_sample - window_samples // 2)
        end = min(len(self.audio_data), center_sample + window_samples // 2)

        if start >= end:
            return np.zeros(self.feature_dim, dtype=np.float32)

        audio_window = self.audio_data[start:end]

        # Handle different channel counts
        if audio_window.shape[1] >= 4:
            w, x, y, z = audio_window[:, 0], audio_window[:, 1], audio_window[:, 2], audio_window[:, 3]
        elif audio_window.shape[1] == 1:
            # Mono fallback: no directional information.
            w = audio_window[:, 0]
            x = y = z = np.zeros_like(w)
        else:
            return np.zeros(self.feature_dim, dtype=np.float32)

        # Omnidirectional energy (W channel)
        omni_energy = np.mean(w ** 2)

        # Convert viewport (u, v) to spherical angles
        viewport_theta = u * 2 * np.pi  # Azimuth
        viewport_phi = (v - 0.5) * np.pi  # Elevation

        # Sample energy on a ring of directions 0.3 rad (~17°) around the viewport.
        sector_features = []
        for i in range(self.num_sectors):
            sector_angle = 2 * np.pi * i / self.num_sectors
            sector_theta = viewport_theta + 0.3 * np.cos(sector_angle)
            sector_phi = viewport_phi + 0.3 * np.sin(sector_angle)
            sector_phi = np.clip(sector_phi, -np.pi/2, np.pi/2)

            energy = self._get_directional_energy(w, x, y, z, sector_theta, sector_phi)
            sector_features.append(energy)

        # Normalize features to [0, 1] relative to the loudest entry.
        features = np.array([omni_energy] + sector_features, dtype=np.float32)
        max_val = features.max()
        if max_val > 0:
            features = features / max_val

        return features

    def extract_sequence_features(self, u_seq, v_seq, t_seq):
        """Extract features for a sequence of positions; returns (len(t_seq), feature_dim)."""
        features = np.zeros((len(t_seq), self.feature_dim), dtype=np.float32)
        for i, (u, v, t) in enumerate(zip(u_seq, v_seq, t_seq)):
            features[i] = self.extract_features(u, v, t)
        return features

## 4. Data Loading

In [None]:
def load_head_tracking_data(video_id, data_dir):
    """Read ``head_video_<video_id>.csv`` and append unit-vector columns.

    ``data_dir`` must support the ``/`` operator (e.g. a pathlib.Path). The CSV
    must provide 'u' and 'v' columns; the returned frame gains 'px', 'py' and
    'pz' computed with SphericalUtils.uv_to_unit_vector.
    """
    frame = pd.read_csv(data_dir / f"head_video_{video_id}.csv")
    xyz = SphericalUtils.uv_to_unit_vector(frame['u'].values, frame['v'].values)
    for axis, column in enumerate(('px', 'py', 'pz')):
        frame[column] = xyz[:, axis]
    return frame

def split_by_participant(participant_ids, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Shuffle participant ids and split them into train / val / test lists.

    Splitting by participant keeps all windows of one person in a single split.

    Args:
        participant_ids: iterable of ids.
        train_ratio, val_ratio: fractions of ids for train and validation;
            counts are truncated with int(), the test split gets the rest.
        seed: seed for the shuffle.

    Returns:
        (train_ids, val_ids, test_ids) as plain Python lists.
    """
    # A private RandomState yields the same permutation as seeding the global
    # NumPy RNG, but leaves the global RNG state untouched for other code.
    rng = np.random.RandomState(seed)
    ids = np.array(participant_ids)
    rng.shuffle(ids)
    n = len(ids)
    n_train, n_val = int(n * train_ratio), int(n * val_ratio)
    return (
        ids[:n_train].tolist(),
        ids[n_train:n_train + n_val].tolist(),
        ids[n_train + n_val:].tolist(),
    )

## 5. Dataset

In [None]:
class ViewportDataset(Dataset):
    """Sliding windows (stride 1) over per-participant head trajectories.

    Each window spans ``input_steps + prediction_steps`` consecutive samples.
    An item exposes the first ``input_steps`` positions and velocities (plus
    audio features when enabled) and, for each ``h`` in
    ``eval_horizons_steps``, the position ``h`` samples after the last input.
    """

    def __init__(self, df, participant_ids, input_steps, prediction_steps,
                 eval_horizons_steps, audio_extractor=None, use_audio=False):
        """Build every window up front.

        Args:
            df: frame with columns 'id', 't', 'u', 'v', 'px', 'py', 'pz'.
            participant_ids: ids to include; windows are built in this order.
            input_steps, prediction_steps: window part lengths in samples.
            eval_horizons_steps: target offsets (samples) after the last input.
            audio_extractor: optional object exposing ``feature_dim`` and
                ``extract_sequence_features(u, v, t)``.
            use_audio: attach audio features (only if an extractor is given).

        Raises:
            ValueError: if a horizon exceeds ``prediction_steps``, since its
                target would lie outside the window.
        """
        too_far = [h for h in eval_horizons_steps if h > prediction_steps]
        if too_far:
            raise ValueError(
                f"Horizons {too_far} exceed prediction_steps={prediction_steps}"
            )

        self.input_steps = input_steps
        self.prediction_steps = prediction_steps
        self.eval_horizons_steps = eval_horizons_steps
        self.use_audio = use_audio and audio_extractor is not None
        self.audio_dim = audio_extractor.feature_dim if self.use_audio else 0
        self.sequences = []

        total_len = input_steps + prediction_steps
        df_filtered = df[df['id'].isin(participant_ids)]
        # Group once instead of re-scanning the frame for every participant.
        by_participant = {pid: group for pid, group in df_filtered.groupby('id')}

        for pid in participant_ids:
            pdf = by_participant.get(pid)
            if pdf is None or len(pdf) < total_len:
                continue
            pdf = pdf.sort_values('t')

            positions = pdf[['px', 'py', 'pz']].values
            velocities = self._causal_velocities(positions)

            audio_features = None
            if self.use_audio:
                audio_features = audio_extractor.extract_sequence_features(
                    pdf['u'].values, pdf['v'].values, pdf['t'].values
                )

            for i in range(len(positions) - total_len + 1):
                window = slice(i, i + total_len)
                seq_data = {
                    'positions': positions[window],
                    'velocities': velocities[window],
                    'pid': pid
                }
                if audio_features is not None:
                    seq_data['audio'] = audio_features[window]
                self.sequences.append(seq_data)

        print(f"Created {len(self.sequences)} sequences (audio={self.use_audio})")

    @staticmethod
    def _causal_velocities(positions):
        """Per-sample tangent velocities that only use past and current samples.

        ``velocities[k]`` is the tangent vector at ``positions[k]`` pointing
        away from ``positions[k-1]`` (the negated log map towards the previous
        sample). A forward difference would make the last input velocity
        depend on the first future sample, leaking target information.
        ``velocities[0]`` copies ``velocities[1]``.
        """
        velocities = np.zeros_like(positions)
        if len(positions) < 2:
            return velocities
        velocities[1:] = -SphericalUtils.tangent_velocity(positions[1:], positions[:-1])
        velocities[0] = velocities[1]
        return velocities

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return input tensors (first input_steps samples) and horizon targets."""
        seq = self.sequences[idx]
        last_input = self.input_steps - 1
        targets = [seq['positions'][last_input + h] for h in self.eval_horizons_steps]

        item = {
            'input_positions': torch.FloatTensor(seq['positions'][:self.input_steps]),
            'input_velocities': torch.FloatTensor(seq['velocities'][:self.input_steps]),
            'targets': torch.FloatTensor(np.array(targets))
        }

        if self.use_audio:
            item['input_audio'] = torch.FloatTensor(seq['audio'][:self.input_steps])

        return item

## 6. Extended Kalman Filter (Frozen)

In [None]:
class SphericalEKF:
    """Constant-velocity Kalman filter for a viewing direction on the unit sphere.

    The state ``x`` is 6-D: position (3) followed by velocity (3). Only the
    position is observed. After every prediction and correction the position
    part is re-normalized to unit length; the covariance follows the standard
    linear Kalman update.
    """

    def __init__(self, process_noise=0.01, measurement_noise=0.001, dt=1.0/90):
        self.dt = dt

        # Process noise: velocity entries at `process_noise`, position entries 10x smaller.
        q = np.eye(6) * process_noise
        q[:3, :3] *= 0.1
        self.Q = q

        # Measurement noise on the observed position.
        self.R = np.eye(3) * measurement_noise

        # Transition: position += velocity * dt, velocity unchanged.
        transition = np.eye(6)
        transition[:3, 3:] = np.eye(3) * dt
        self.F = transition

        # Observation picks out the position block of the state.
        observe = np.zeros((3, 6))
        observe[:3, :3] = np.eye(3)
        self.H = observe

        self.reset()

    def reset(self):
        """Clear the state; the next update() re-initializes from its measurement."""
        self.x = np.zeros(6)
        self.x[2] = 1.0
        self.P = np.eye(6) * 0.1
        self.initialized = False

    def predict_trajectory(self, steps):
        """Extrapolate ``steps`` samples ahead without changing the filter.

        Row k of the returned (steps, 3) array is the position k + 1 steps
        after the current state.
        """
        state = self.x.copy()
        path = np.zeros((steps, 3))
        for k in range(steps):
            state = self.F @ state
            state[:3] = SphericalUtils.normalize(state[:3])
            path[k] = state[:3]
        return path

    def update(self, measurement):
        """Fuse one position measurement.

        Returns ``(filtered_position, innovation_magnitude)``; the first call
        only seeds the position and reports an innovation of 0.0.
        """
        if not self.initialized:
            self.x[:3] = measurement
            self.initialized = True
            return measurement, 0.0

        # Predict.
        x_prior = self.F @ self.x
        x_prior[:3] = SphericalUtils.normalize(x_prior[:3])
        P_prior = self.F @ self.P @ self.F.T + self.Q

        # Correct.
        residual = measurement - self.H @ x_prior
        S = self.H @ P_prior @ self.H.T + self.R
        gain = P_prior @ self.H.T @ np.linalg.inv(S)
        self.x = x_prior + gain @ residual
        self.x[:3] = SphericalUtils.normalize(self.x[:3])
        self.P = (np.eye(6) - gain @ self.H) @ P_prior
        return self.x[:3], np.linalg.norm(residual)

In [None]:
class BatchEKF:
    """Run a freshly reset SphericalEKF over every sequence of a batch.

    Runs in NumPy on the CPU; the returned tensors carry no gradient.
    """

    def __init__(self, config):
        # The filter time step follows the configured head-tracking rate;
        # configs without SAMPLE_RATE_HZ keep the previous 90 Hz default.
        sample_rate_hz = getattr(config, 'SAMPLE_RATE_HZ', 90)
        self.ekf = SphericalEKF(
            config.EKF_PROCESS_NOISE,
            config.EKF_MEASUREMENT_NOISE,
            dt=1.0 / sample_rate_hz,
        )

    def process_batch(self, input_positions, eval_horizons_steps):
        """Filter each input sequence, then extrapolate to the requested horizons.

        Args:
            input_positions: tensor (batch, seq_len, 3).
            eval_horizons_steps: step offsets after the last input sample.

        Returns:
            (ekf_preds of shape (batch, num_horizons, 3),
             innovations of shape (batch, seq_len)) as float32 tensors on the
            input's device.
        """
        batch_size, seq_len = input_positions.shape[:2]
        positions_np = input_positions.detach().cpu().numpy()
        ekf_preds = np.zeros((batch_size, len(eval_horizons_steps), 3))
        innovations = np.zeros((batch_size, seq_len))
        max_steps = max(eval_horizons_steps)

        for b in range(batch_size):
            self.ekf.reset()
            for t in range(seq_len):
                _, innov = self.ekf.update(positions_np[b, t])
                innovations[b, t] = innov
            traj = self.ekf.predict_trajectory(max_steps)
            for i, h in enumerate(eval_horizons_steps):
                # traj[0] is one step ahead, so horizon h lives at row h - 1.
                ekf_preds[b, i] = traj[h - 1]

        device = input_positions.device
        return (
            torch.as_tensor(ekf_preds, dtype=torch.float32, device=device),
            torch.as_tensor(innovations, dtype=torch.float32, device=device),
        )

## 7. LSTM Model with Gating

In [None]:
class SphericalLSTM(nn.Module):
    """LSTM over [position, velocity(, audio)] features with per-horizon heads.

    Produces, from the last LSTM output:
        corrections: (batch, num_horizons, 3) unconstrained vectors.
        gates: (batch, num_horizons) in (0, 1) via a sigmoid, conditioned on
            the hidden state and the mean innovation magnitude.
    """

    def __init__(self, input_dim=6, hidden_size=128, num_layers=2,
                 dropout=0.2, num_horizons=5, audio_dim=0):
        super().__init__()

        total_input_dim = input_dim + audio_dim
        self.audio_dim = audio_dim

        # nn.LSTM only applies dropout between stacked layers.
        self.lstm = nn.LSTM(total_input_dim, hidden_size, num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)

        self.prediction_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size // 2, 3)
            ) for _ in range(num_horizons)
        ])

        # +1 input for the mean innovation magnitude.
        self.gate_head = nn.Sequential(
            nn.Linear(hidden_size + 1, hidden_size // 4),
            nn.ReLU(),
            nn.Linear(hidden_size // 4, num_horizons),
            nn.Sigmoid()
        )

    def forward(self, input_positions, input_velocities, innovation_magnitude, input_audio=None):
        """Return ``(corrections, gates)`` for a batch of input sequences.

        Args:
            input_positions, input_velocities: (batch, seq, 3) tensors.
            innovation_magnitude: (batch, seq) tensor; only its mean over the
                sequence is used.
            input_audio: optional (batch, seq, audio_dim) tensor. When the
                model was built with ``audio_dim > 0`` but no audio is given,
                zeros are used instead of failing with an input-size mismatch.
                Ignored when ``audio_dim == 0``.
        """
        lstm_input = torch.cat([input_positions, input_velocities], dim=-1)

        if self.audio_dim > 0:
            if input_audio is None:
                input_audio = lstm_input.new_zeros((*lstm_input.shape[:-1], self.audio_dim))
            lstm_input = torch.cat([lstm_input, input_audio], dim=-1)

        lstm_out, _ = self.lstm(lstm_input)
        hidden = lstm_out[:, -1, :]

        corrections = torch.stack([head(hidden) for head in self.prediction_heads], dim=1)
        gate_input = torch.cat([hidden, innovation_magnitude.mean(dim=-1, keepdim=True)], dim=-1)
        gates = self.gate_head(gate_input)

        return corrections, gates

In [None]:
class KalmanLSTMHybrid(nn.Module):
    """EKF extrapolation refined by gated LSTM corrections on the sphere.

    For each horizon the LSTM proposes a correction vector and a gate in
    (0, 1). The gated correction is projected onto the tangent plane at the
    EKF prediction and applied with the exponential map.
    """

    def __init__(self, config):
        super().__init__()
        self.use_audio = config.USE_AUDIO
        # Horizons in seconds -> step offsets after the last input sample.
        self.eval_horizons_steps = [int(h * config.SAMPLE_RATE_HZ) for h in config.EVAL_HORIZONS]
        # Plain (non-nn.Module) EKF wrapper: contributes no trainable parameters.
        self.batch_ekf = BatchEKF(config)

        audio_dim = config.AUDIO_FEATURE_DIM if config.USE_AUDIO else 0
        self.lstm = SphericalLSTM(
            input_dim=6,
            hidden_size=config.LSTM_HIDDEN_SIZE,
            num_layers=config.LSTM_NUM_LAYERS,
            dropout=config.LSTM_DROPOUT,
            num_horizons=len(config.EVAL_HORIZONS),
            audio_dim=audio_dim
        )

    def forward(self, input_positions, input_velocities, input_audio=None):
        """Predict unit vectors for every horizon.

        Returns:
            (predictions of shape (batch, num_horizons, 3),
             dict with 'ekf_predictions', 'gates' and 'innovations').
        """
        ekf_preds, innovations = self.batch_ekf.process_batch(
            input_positions, self.eval_horizons_steps
        )

        corrections, gates = self.lstm(
            input_positions, input_velocities, innovations, input_audio
        )

        gated_corrections = gates.unsqueeze(-1) * corrections
        # exp_map is a great-circle step only for vectors tangent to the base
        # point. Drop the radial component (along ekf_preds) so a correction
        # rotates the EKF prediction instead of shrinking or flipping it.
        radial = torch.sum(gated_corrections * ekf_preds, dim=-1, keepdim=True)
        tangent_corrections = gated_corrections - radial * ekf_preds
        predictions = SphericalUtilsTorch.exp_map(ekf_preds, tangent_corrections)

        return predictions, {
            'ekf_predictions': ekf_preds,
            'gates': gates,
            'innovations': innovations
        }

## 8. Loss & Metrics

In [None]:
class MultiHorizonLoss(nn.Module):
    """Cosine loss averaged over batch and horizons, optionally horizon-weighted."""

    def __init__(self, horizon_weights=None):
        super().__init__()
        # None or an empty sequence means an unweighted mean.
        self.horizon_weights = horizon_weights

    def forward(self, predictions, targets):
        """Return ``(scalar_loss, {'loss_h<i>': unweighted mean loss of horizon i})``.

        ``predictions`` and ``targets`` are indexed as (batch, horizon, ...).
        """
        per_item = SphericalUtilsTorch.cosine_loss(predictions, targets)

        if not self.horizon_weights:
            total = per_item.mean()
        else:
            weight_row = torch.tensor(self.horizon_weights, device=predictions.device).unsqueeze(0)
            total = (per_item * weight_row).mean()

        breakdown = {}
        for idx in range(per_item.shape[1]):
            breakdown[f'loss_h{idx}'] = per_item[:, idx].mean().item()
        return total, breakdown

def evaluate_model(model, dataloader, config, device):
    """Mean angular error in degrees for each evaluation horizon.

    Returns ``{'MAE_<h>s': mean_error}`` for every ``h`` in config.EVAL_HORIZONS,
    averaged over all samples of ``dataloader``.
    """
    model.eval()
    horizons = config.EVAL_HORIZONS
    per_horizon = {h: [] for h in horizons}

    with torch.no_grad():
        for batch in dataloader:
            audio = batch.get('input_audio', None)
            audio = None if audio is None else audio.to(device)

            preds, _ = model(
                batch['input_positions'].to(device),
                batch['input_velocities'].to(device),
                audio
            )
            targets = batch['targets'].to(device)

            for h_idx, horizon in enumerate(horizons):
                errs = SphericalUtilsTorch.angular_error_degrees(preds[:, h_idx], targets[:, h_idx])
                per_horizon[horizon].extend(errs.cpu().numpy().tolist())

    return {f'MAE_{h}s': np.mean(per_horizon[h]) for h in horizons}

## 9. Training Loop

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Gradients are clipped to a global norm of 1.0 before each step.

    Returns:
        {'train_loss': mean of the per-batch losses}.

    Raises:
        ValueError: if the loader yields no batches (previously a bare
            ZeroDivisionError).
    """
    model.train()
    total_loss, num_batches = 0.0, 0

    for batch in dataloader:
        optimizer.zero_grad()

        input_audio = batch.get('input_audio', None)
        if input_audio is not None:
            input_audio = input_audio.to(device)

        preds, _ = model(
            batch['input_positions'].to(device),
            batch['input_velocities'].to(device),
            input_audio
        )

        loss, _ = criterion(preds, batch['targets'].to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader (no training windows?)")
    return {'train_loss': total_loss / num_batches}

def validate(model, dataloader, criterion, device):
    """Average ``criterion`` loss over ``dataloader`` without gradients.

    Returns:
        {'val_loss': mean of the per-batch losses}.

    Raises:
        ValueError: if the loader yields no batches (e.g. an empty validation
            split), instead of a bare ZeroDivisionError.
    """
    model.eval()
    total_loss, num_batches = 0.0, 0

    with torch.no_grad():
        for batch in dataloader:
            input_audio = batch.get('input_audio', None)
            if input_audio is not None:
                input_audio = input_audio.to(device)

            preds, _ = model(
                batch['input_positions'].to(device),
                batch['input_velocities'].to(device),
                input_audio
            )
            loss, _ = criterion(preds, batch['targets'].to(device))
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate received an empty dataloader (empty validation split?)")
    return {'val_loss': total_loss / num_batches}

In [None]:
def train_model(model, train_loader, val_loader, config, device):
    """Train with early stopping and return the model with its best weights.

    Uses Adam (config.LEARNING_RATE), ReduceLROnPlateau on the validation
    loss (factor 0.5, patience 5), and stops after
    config.EARLY_STOPPING_PATIENCE epochs without improvement. Horizon
    weights for MultiHorizonLoss come from ``config.HORIZON_WEIGHTS`` when
    defined, otherwise [0.5, 0.7, 0.85, 1.0, 1.0].

    Returns:
        (model, history) where history holds per-epoch 'train_loss' and
        'val_loss'.

    Raises:
        ValueError: if the number of horizon weights differs from the number
            of config.EVAL_HORIZONS (the loss would otherwise fail or
            broadcast incorrectly mid-training).
    """
    horizon_weights = list(getattr(config, 'HORIZON_WEIGHTS', [0.5, 0.7, 0.85, 1.0, 1.0]))
    if len(horizon_weights) != len(config.EVAL_HORIZONS):
        raise ValueError(
            f"Expected {len(config.EVAL_HORIZONS)} horizon weights, got {len(horizon_weights)}"
        )

    model = model.to(device)
    criterion = MultiHorizonLoss(horizon_weights)
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

    history = {'train_loss': [], 'val_loss': []}
    best_val_loss, patience_counter, best_state = float('inf'), 0, None

    for epoch in range(config.NUM_EPOCHS):
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device)
        val_metrics = validate(model, val_loader, criterion, device)
        scheduler.step(val_metrics['val_loss'])

        history['train_loss'].append(train_metrics['train_loss'])
        history['val_loss'].append(val_metrics['val_loss'])

        if val_metrics['val_loss'] < best_val_loss:
            best_val_loss = val_metrics['val_loss']
            patience_counter = 0
            # Deep copy: state_dict() returns references to live tensors.
            best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        if epoch % 5 == 0:
            print(f"Epoch {epoch+1}: train={train_metrics['train_loss']:.4f}, val={val_metrics['val_loss']:.4f}")

        if patience_counter >= config.EARLY_STOPPING_PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

## 10. Model Save/Load & Inference

In [None]:
def save_model(model, config, filepath='viewport_model.pth'):
    """Persist the model weights plus the config fields needed to rebuild it.

    Writes ``{'model_state_dict': ..., 'config': {...}}`` with torch.save.
    """
    saved_keys = (
        'USE_AUDIO',
        'AUDIO_FEATURE_DIM',
        'LSTM_HIDDEN_SIZE',
        'LSTM_NUM_LAYERS',
        'LSTM_DROPOUT',
        'EVAL_HORIZONS',
        'SAMPLE_RATE_HZ',
        'INPUT_STEPS',
        'EKF_PROCESS_NOISE',
        'EKF_MEASUREMENT_NOISE',
    )
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': {key: getattr(config, key) for key in saved_keys},
    }
    torch.save(checkpoint, filepath)
    print(f"Model saved to {filepath}")

def load_model(filepath='viewport_model.pth', device='cpu'):
    """Rebuild a KalmanLSTMHybrid from a checkpoint written by save_model.

    Returns ``(model, loaded_config)``: the model is moved to ``device`` and
    put in eval mode; ``loaded_config`` exposes each saved config entry as an
    attribute.

    NOTE(review): torch.load may unpickle objects from the file (depending on
    the installed torch version's ``weights_only`` default); only load
    trusted checkpoints.
    """
    checkpoint = torch.load(filepath, map_location=device)

    # Minimal attribute container standing in for the original Config.
    class LoadedConfig:
        pass

    loaded_config = LoadedConfig()
    loaded_config.__dict__.update(checkpoint['config'])

    model = KalmanLSTMHybrid(loaded_config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"Model loaded from {filepath}")
    return model, loaded_config

def predict_viewport(model, positions, velocities, audio_features=None, device='cpu'):
    """
    Run inference on one trajectory (a batch of size 1).

    Args:
        model: Trained KalmanLSTMHybrid model
        positions: np.array of shape (seq_len, 3) - unit vectors
        velocities: np.array of shape (seq_len, 3) - tangent velocities
        audio_features: Optional np.array of shape (seq_len, audio_dim)
        device: torch device

    Returns:
        predictions: np.array of shape (num_horizons, 3) - predicted unit vectors
        metadata: dict with 'ekf_predictions' and 'gates' as NumPy arrays
    """
    model.eval()

    def as_batch(array):
        # Add a leading batch axis of size 1 and move to the target device.
        return torch.FloatTensor(array).unsqueeze(0).to(device)

    with torch.no_grad():
        pos_batch = as_batch(positions)
        vel_batch = as_batch(velocities)
        audio_batch = None if audio_features is None else as_batch(audio_features)

        preds, metadata = model(pos_batch, vel_batch, audio_batch)

        extras = {key: metadata[key][0].cpu().numpy() for key in ('ekf_predictions', 'gates')}
        return preds[0].cpu().numpy(), extras

def predict_to_uv(model, positions, velocities, audio_features=None, device='cpu'):
    """Predict future viewports and also express them as equirectangular (u, v).

    Returns a dict with keys 'u', 'v', 'unit_vectors', 'ekf_predictions', 'gates'.
    """
    unit_vectors, extras = predict_viewport(model, positions, velocities, audio_features, device)
    u_coords, v_coords = SphericalUtils.unit_vector_to_uv(unit_vectors)

    result = {'u': u_coords, 'v': v_coords, 'unit_vectors': unit_vectors}
    result['ekf_predictions'] = extras['ekf_predictions']
    result['gates'] = extras['gates']
    return result

## 11. Data Export for Graphs & Plots

In [None]:
def export_training_history(history, filepath='training_history.csv'):
    """Write per-epoch train/validation loss (1-based epochs) to CSV and return the frame."""
    train_losses = history['train_loss']
    columns = {
        'epoch': range(1, len(train_losses) + 1),
        'train_loss': train_losses,
        'val_loss': history['val_loss'],
    }
    frame = pd.DataFrame(columns)
    frame.to_csv(filepath, index=False)
    print(f"Training history exported to {filepath}")
    return frame

def export_error_metrics(metrics, config, filepath='error_metrics.csv'):
    """Write mean angular error per horizon (from ``MAE_<h>s`` keys) to CSV and return the frame."""
    horizons = config.EVAL_HORIZONS
    maes = [metrics[f'MAE_{h}s'] for h in horizons]
    frame = pd.DataFrame({'horizon_sec': horizons, 'mae_degrees': maes})
    frame.to_csv(filepath, index=False)
    print(f"Error metrics exported to {filepath}")
    return frame

def export_detailed_predictions(model, dataloader, config, device, filepath='detailed_predictions.csv'):
    """Write per-sample, per-horizon predictions vs targets to CSV (first 10 batches).

    Each row holds predicted and target (u, v), the angular error in degrees
    of the hybrid prediction and of the raw EKF prediction, and the gate value.
    Returns the written DataFrame.
    """
    max_batches = 10  # keep the CSV small

    def angle_deg(a, b):
        # Angle between two unit vectors from the clamped dot product.
        return np.degrees(np.arccos(np.clip(np.dot(a, b), -1, 1)))

    model.eval()
    rows = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break

            audio = batch.get('input_audio', None)
            if audio is not None:
                audio = audio.to(device)

            preds, metadata = model(
                batch['input_positions'].to(device),
                batch['input_velocities'].to(device),
                audio
            )
            targets = batch['targets'].to(device)

            pred_arr = preds.cpu().numpy()
            target_arr = targets.cpu().numpy()
            ekf_arr = metadata['ekf_predictions'].cpu().numpy()
            gate_arr = metadata['gates'].cpu().numpy()

            for sample_idx in range(pred_arr.shape[0]):
                for h_idx, horizon in enumerate(config.EVAL_HORIZONS):
                    pred = pred_arr[sample_idx, h_idx]
                    target = target_arr[sample_idx, h_idx]
                    hybrid_err = angle_deg(pred, target)
                    ekf_err = angle_deg(ekf_arr[sample_idx, h_idx], target)
                    pred_u, pred_v = SphericalUtils.unit_vector_to_uv(pred)
                    target_u, target_v = SphericalUtils.unit_vector_to_uv(target)

                    rows.append({
                        'batch': batch_idx,
                        'sample': sample_idx,
                        'horizon_sec': horizon,
                        'pred_u': float(pred_u),
                        'pred_v': float(pred_v),
                        'target_u': float(target_u),
                        'target_v': float(target_v),
                        'error_hybrid_deg': hybrid_err,
                        'error_ekf_deg': ekf_err,
                        'gate_value': gate_arr[sample_idx, h_idx]
                    })

    frame = pd.DataFrame(rows)
    frame.to_csv(filepath, index=False)
    print(f"Detailed predictions exported to {filepath} ({len(frame)} rows)")
    return frame

def export_gate_analysis(model, dataloader, config, device, filepath='gate_analysis.csv'):
    """Write gate values with per-sample innovation statistics to CSV (first 20 batches).

    One row per (batch, sample, horizon). The innovation mean/max are taken
    over the sample's input sequence, so they repeat across that sample's
    horizons. Returns the written DataFrame.
    """
    batch_limit = 20
    model.eval()
    rows = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= batch_limit:
                break

            audio = batch.get('input_audio', None)
            audio = audio.to(device) if audio is not None else None

            _, metadata = model(
                batch['input_positions'].to(device),
                batch['input_velocities'].to(device),
                audio
            )

            gate_arr = metadata['gates'].cpu().numpy()
            innov_arr = metadata['innovations'].cpu().numpy()

            for sample_idx in range(gate_arr.shape[0]):
                # Per-sample statistics do not depend on the horizon.
                innov_mean = innov_arr[sample_idx].mean()
                innov_max = innov_arr[sample_idx].max()
                for h_idx, horizon in enumerate(config.EVAL_HORIZONS):
                    rows.append({
                        'batch': batch_idx,
                        'sample': sample_idx,
                        'horizon_sec': horizon,
                        'gate_value': gate_arr[sample_idx, h_idx],
                        'mean_innovation': innov_mean,
                        'max_innovation': innov_max
                    })

    frame = pd.DataFrame(rows)
    frame.to_csv(filepath, index=False)
    print(f"Gate analysis exported to {filepath} ({len(frame)} rows)")
    return frame

## 12. Visualization

In [None]:
def plot_training_history(history):
    """Draw per-epoch training and validation loss, save the chart and show it.

    Args:
        history: Mapping with ``'train_loss'`` and ``'val_loss'`` sequences
            (one value per epoch).

    The figure is written to ``training_history.png`` in the working
    directory before being displayed.
    """
    _, ax = plt.subplots(1, 1, figsize=(10, 5))

    curves = (('train_loss', 'Train'), ('val_loss', 'Validation'))
    for history_key, curve_label in curves:
        ax.plot(history[history_key], label=curve_label, linewidth=2)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss (Cosine)', fontsize=12)
    ax.set_title('Training History', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: training_history.png")

def plot_error_vs_horizon(metrics, config):
    """Chart mean angular error against prediction horizon.

    Args:
        metrics: Mapping holding one ``'MAE_<h>s'`` entry per horizon ``h``
            (plotted as degrees).
        config: Object whose ``EVAL_HORIZONS`` lists the horizons in seconds.

    The chart is saved to ``error_vs_horizon.png`` and then shown.
    """
    accent = '#2196F3'
    horizon_list = config.EVAL_HORIZONS
    mae_values = [metrics[f'MAE_{sec}s'] for sec in horizon_list]

    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(horizon_list, mae_values, 'o-', lw=2.5, ms=10, color=accent)
    ax.fill_between(horizon_list, 0, mae_values, alpha=0.2, color=accent)
    ax.set_xlabel('Prediction Horizon (seconds)', fontsize=12)
    ax.set_ylabel('Mean Angular Error (degrees)', fontsize=12)
    ax.set_title('Prediction Accuracy vs Horizon', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Label each point with its value, offset 12 points above the marker.
    for sec, mae in zip(horizon_list, mae_values):
        ax.annotate(
            f'{mae:.1f}°', (sec, mae),
            textcoords='offset points', xytext=(0, 12),
            ha='center', fontsize=11, fontweight='bold',
        )

    plt.tight_layout()
    plt.savefig('error_vs_horizon.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: error_vs_horizon.png")

def plot_gate_distribution(gate_df):
    """Visualise gate values per horizon and against EKF innovation.

    The left panel is a box plot of ``gate_value`` grouped by ``horizon_sec``.
    The right panel scatters ``gate_value`` against ``mean_innovation``,
    coloured by ``horizon_sec``. The figure is saved to ``gate_analysis.png``
    and shown.

    Args:
        gate_df: DataFrame with ``gate_value``, ``horizon_sec`` and
            ``mean_innovation`` columns.
    """
    _, (box_ax, scatter_ax) = plt.subplots(1, 2, figsize=(14, 5))

    gate_df.boxplot(column='gate_value', by='horizon_sec', ax=box_ax)
    box_ax.set_xlabel('Prediction Horizon (s)', fontsize=12)
    box_ax.set_ylabel('Gate Value', fontsize=12)
    box_ax.set_title('Gate Values by Horizon', fontsize=14)
    # Blank the figure-level title that pandas' grouped boxplot adds.
    plt.suptitle('')

    points = scatter_ax.scatter(
        gate_df['mean_innovation'], gate_df['gate_value'],
        c=gate_df['horizon_sec'], cmap='viridis', alpha=0.5, s=20,
    )
    scatter_ax.set_xlabel('Mean Innovation Magnitude', fontsize=12)
    scatter_ax.set_ylabel('Gate Value', fontsize=12)
    scatter_ax.set_title('Gate vs EKF Innovation', fontsize=14)
    plt.colorbar(points, ax=scatter_ax, label='Horizon (s)')

    plt.tight_layout()
    plt.savefig('gate_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: gate_analysis.png")

def plot_prediction_scatter(pred_df, horizon=2.5):
    """Scatter predicted against target U and V coordinates at one horizon.

    Args:
        pred_df: DataFrame with ``horizon_sec``, ``target_u``, ``pred_u``,
            ``target_v`` and ``pred_v`` columns.
        horizon: Horizon (seconds) to select; rows are matched with ``==``.

    Each panel includes a dashed y = x reference line over [0, 1]. The figure
    is saved to ``prediction_scatter_<horizon>s.png`` and shown.
    """
    subset = pred_df[pred_df['horizon_sec'] == horizon]

    _, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Both panels are identical apart from which coordinate they show.
    for ax, coord in zip(axes, ('U', 'V')):
        suffix = coord.lower()
        ax.scatter(subset[f'target_{suffix}'], subset[f'pred_{suffix}'], alpha=0.5, s=20)
        ax.plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect')
        ax.set_xlabel(f'Target {coord}', fontsize=12)
        ax.set_ylabel(f'Predicted {coord}', fontsize=12)
        ax.set_title(f'{coord} Coordinate @ {horizon}s', fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'prediction_scatter_{horizon}s.png', dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: prediction_scatter_{horizon}s.png")

def plot_ekf_vs_hybrid(pred_df):
    """Grouped bar chart of mean EKF-only vs hybrid error per horizon.

    Args:
        pred_df: DataFrame with ``horizon_sec``, ``error_hybrid_deg`` and
            ``error_ekf_deg`` columns.

    The chart is saved to ``ekf_vs_hybrid.png`` and shown.
    """
    bar_width = 0.15
    means = (
        pred_df.groupby('horizon_sec')[['error_hybrid_deg', 'error_ekf_deg']]
        .mean()
        .reset_index()
    )
    centers = means['horizon_sec']

    _, ax = plt.subplots(figsize=(10, 6))

    # The EKF bar sits left of each horizon tick and the hybrid bar sits right.
    bar_specs = (
        (-bar_width / 2, 'error_ekf_deg', 'EKF Only', '#FF9800'),
        (bar_width / 2, 'error_hybrid_deg', 'Hybrid (EKF+LSTM)', '#4CAF50'),
    )
    for offset, column, bar_label, bar_color in bar_specs:
        ax.bar(centers + offset, means[column], bar_width, label=bar_label, color=bar_color)

    ax.set_xlabel('Prediction Horizon (seconds)', fontsize=12)
    ax.set_ylabel('Mean Angular Error (degrees)', fontsize=12)
    ax.set_title('EKF vs Hybrid Model Comparison', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('ekf_vs_hybrid.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved: ekf_vs_hybrid.png")

## 13. Main Execution

In [None]:
# Load data (single video for development)
# Only config.DEV_VIDEO_ID is loaded; every dataset built in the next cell is
# drawn from this single `df`.
# NOTE(review): config.DEV_MODE is not consulted here — this cell always loads
# just the development video.
print("Loading data...")
df = load_head_tracking_data(config.DEV_VIDEO_ID, config.HEAD_DATA_DIR)
print(f"Loaded {len(df)} samples from video {config.DEV_VIDEO_ID}")

# Initialize audio extractor with D-SAV360 ambisonic WAV
# audio_extractor stays None when config.USE_AUDIO is False; that None is what
# gets passed to the datasets.
audio_extractor = None
if config.USE_AUDIO:
    # Both rates are passed, presumably so the extractor can align audio
    # frames to head-tracking samples (see AudioFeatureExtractor).
    audio_extractor = AudioFeatureExtractor(
        config.AUDIO_DIR, 
        config.DEV_VIDEO_ID,
        audio_sample_rate=config.AUDIO_SAMPLE_RATE,
        head_sample_rate=config.SAMPLE_RATE_HZ
    )
    print(f"Audio feature dim: {audio_extractor.feature_dim}")

# Split by participant
# The ids are sorted so the input to the split has a deterministic order.
# Presumably the splits are disjoint by participant — split_by_participant is
# defined elsewhere; confirm it honours config.TRAIN_RATIO / VAL_RATIO, since
# no ratios are passed here.
participant_ids = sorted(df['id'].unique().tolist())
train_ids, val_ids, test_ids = split_by_participant(participant_ids)
print(f"Participants - Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")

In [None]:
# Create datasets with audio
# Evaluation horizons are converted from seconds to head-tracking sample counts.
# int() truncates; this is exact for the configured horizons at SAMPLE_RATE_HZ.
eval_horizons_steps = [
    int(horizon_sec * config.SAMPLE_RATE_HZ) for horizon_sec in config.EVAL_HORIZONS
]

# The three splits differ only in which participant ids they take from `df`.
# They are built in the order train -> val -> test.
train_dataset, val_dataset, test_dataset = (
    ViewportDataset(
        df, split_ids, config.INPUT_STEPS, config.PREDICTION_STEPS,
        eval_horizons_steps, audio_extractor, config.USE_AUDIO,
    )
    for split_ids in (train_ids, val_ids, test_ids)
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=config.BATCH_SIZE)
    for split_dataset in (val_dataset, test_dataset)
)

In [None]:
# Initialize and train model
# The printed count covers only trainable parameters (requires_grad=True).
model = KalmanLSTMHybrid(config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# train_model returns a model (rebound to `model`) and a history that later
# cells export and plot; plot_training_history reads its 'train_loss' and
# 'val_loss' entries.
# NOTE(review): the model is not moved to `device` in this cell — presumably
# train_model does that; confirm.
model, history = train_model(model, train_loader, val_loader, config, device)

In [None]:
# Save model weights for later inference
# (save_model is defined elsewhere; presumably it stores config alongside the
# weights, since config is passed — verify.)
save_model(model, config, 'viewport_model.pth')

# Export training history
# Written to a CSV in the working directory; the returned DataFrame is kept
# as history_df.
history_df = export_training_history(history, 'training_history.csv')

In [None]:
# Evaluate model and export metrics
# evaluate_model is expected to return one 'MAE_<h>s' entry per horizon in
# config.EVAL_HORIZONS; the loop below and plot_error_vs_horizon both read
# those keys.
test_metrics = evaluate_model(model, test_loader, config, device)
print("\nTest Results:")
for h in config.EVAL_HORIZONS:
    print(f"  MAE @ {h}s: {test_metrics[f'MAE_{h}s']:.2f}°")

# Export error metrics
metrics_df = export_error_metrics(test_metrics, config, 'error_metrics.csv')

In [None]:
# Export detailed predictions and gate analysis
# Both exports run over the test split. export_gate_analysis processes only the
# first 20 batches of test_loader, so gate_df covers a subset of the test set.
pred_df = export_detailed_predictions(model, test_loader, config, device, 'detailed_predictions.csv')
gate_df = export_gate_analysis(model, test_loader, config, device, 'gate_analysis.csv')

In [None]:
# Generate all plots
# Each plotting helper saves a PNG in the working directory and then shows it.
plot_training_history(history)
plot_error_vs_horizon(test_metrics, config)
plot_ekf_vs_hybrid(pred_df)
plot_gate_distribution(gate_df)
# Use the longest configured horizon instead of a hard-coded 2.5 s. The scatter
# filters pred_df by exact horizon value, so this keeps the selection in step
# with config.EVAL_HORIZONS (assumed to be the horizons pred_df was exported
# with). With the current config this is still 2.5.
plot_prediction_scatter(pred_df, horizon=max(config.EVAL_HORIZONS))