In [1]:
import logging
import os
import re
import sys
import warnings

# Suppress warnings and logging
warnings.filterwarnings('ignore')
logging.disable(logging.WARNING)

# Core libraries
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

# Typing
from typing import List, Dict, Optional, Any

# External modules
from LRU_pytorch import LRU
from temporaldata import Data
from torch_brain.nn import InfiniteVocabEmbedding

# Hydra/OmegaConf
from omegaconf import DictConfig, OmegaConf

# Utils (project-specific)
sys.path.append("../")
from utils.data import get_dataset_config, get_train_val_loaders
from utils.preprocessing import bin_spikes
from utils.loss import r2_score, move_to_gpu
from utils.plotting import plot_training_curves

# Device and data root
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The processed-data directory can be overridden per machine through the
# FOUNDATIONAL_SSM_DATA_ROOT environment variable; when it is unset the
# original cluster path is used, so existing runs behave as before.
data_root = os.environ.get(
    "FOUNDATIONAL_SSM_DATA_ROOT",
    "/nfs/ghome/live/mlaimon/data/foundational_ssm/motor/processed/",
)


In [None]:
import wandb
from omegaconf import DictConfig, OmegaConf

# Experiment configuration shared by the data-loading, model-construction and training cells.
config = OmegaConf.create({
    "wandb": {
        "project": "foundational_ssm",
        "run_name": "ssm_neural_behavior",
        "tags": ["neural", "behavior", "masking"],
        "log_freq": 1  # Passed to wandb.watch as log_freq AND used by train() as the validation interval in epochs (1 = validate every epoch)
    },
    "dataset": {
        "name": "perich_miller_population_2018",
        "subjects": ["j"],
        "batch_size": 64
    },
    # The keys below match the keyword arguments of SSMNeuroModel.__init__.
    "model": {
        "num_neural_features": 192,
        "num_behavior_features": 2,
        "num_context_features": 32,
        "embedding_dim": 64,
        "ssm_projection_dim": 64,
        "ssm_hidden_dim": 64,
        "ssm_num_layers": 1,
        "ssm_dropout": 0.1,
        "pred_neural_dim": 192,
        "pred_behavior_dim": 2,
        "sequence_length": 1.0,  # seconds; multiplied by sampling_rate to get the number of timesteps
        "sampling_rate": 100,  # Hz
        "lin_dropout": 0.1,
        "activation_fn": "relu"
    },
    "training": {
        "learning_rate": 1e-3,
        "mask_prob": 0.5,  # per-sample probability that one modality (neural or behavior, equally likely) is masked
        "num_epochs": 100,
        "neural_weight": 1.0,
        "behavior_weight": 1.0
    }
})

# Print the configuration
print(OmegaConf.to_yaml(config))

wandb:
  project: foundational_ssm
  run_name: ssm_neural_behavior
  tags:
  - neural
  - behavior
  - masking
  log_freq: 10
dataset:
  name: perich_miller_population_2018
  subjects:
  - j
  batch_size: 64
model:
  num_neural_features: 192
  num_behavior_features: 2
  num_context_features: 32
  embedding_dim: 64
  ssm_projection_dim: 64
  ssm_hidden_dim: 64
  ssm_num_layers: 1
  ssm_dropout: 0.1
  pred_neural_dim: 192
  pred_behavior_dim: 2
  sequence_length: 1.0
  sampling_rate: 100
  lin_dropout: 0.1
  activation_fn: relu
training:
  learning_rate: 0.001
  mask_prob: 0.5
  num_epochs: 100
  neural_weight: 1.0
  behavior_weight: 1.0



In [3]:
class ValidationMetrics:
    """Validation losses and behaviour R² for a joint neural/behaviour model.

    Every batch is pushed through the model twice:

    * encoding pass - neural input masked out (``neural_mask`` = 0), behaviour
      visible; scored with a Poisson NLL on ``pred_neural``.
    * decoding pass - behaviour input masked out (``behavior_mask`` = 0), neural
      visible; scored with an MSE on ``pred_behavior``.

    The Poisson NLL is evaluated with ``log_input=True`` so that it is the same
    quantity the training loss (``CombinedLoss``, which relies on the PyTorch
    default) optimises. The previous ``log_input=False`` treated the model's
    unconstrained outputs as rates, which yields NaN for negative values.
    """

    # Entries accumulated as running sums and averaged at the end.
    _LOSS_KEYS = ("encoding_loss", "decoding_loss", "combined_loss")

    def __init__(self, device):
        """Store the device that batches and masks are placed on."""
        self.device = device

    def compute_metrics(self, dataloader, model):
        """Evaluate ``model`` on every batch of ``dataloader``.

        Args:
            dataloader: Iterable of batch dicts containing ``neural_input``,
                ``behavior_input`` and ``subject_id``; the whole dict is
                forwarded to ``model`` as keyword arguments.
            model: Callable model exposing ``subject_ids`` and ``eval()`` and
                returning a dict with ``pred_neural`` and ``pred_behavior``.

        Returns:
            dict with ``encoding_loss``, ``decoding_loss`` and ``combined_loss``
            (means over batches), ``behavior_r2`` (computed over all samples)
            and ``per_subject`` (the same entries per subject id, averaged over
            that subject's samples, plus ``sample_count``). When the dataloader
            yields no batches every value keeps its initial ``0.0``.

        Note:
            The model is left in eval mode; callers that continue training
            must call ``model.train()`` again.
        """
        model.eval()

        # Overall running sums.
        metrics = {
            "encoding_loss": 0.0,
            "decoding_loss": 0.0,
            "combined_loss": 0.0,
            "behavior_r2": 0.0,
        }

        # Per-subject running sums.
        subject_metrics = {
            subj_id: {
                "encoding_loss": 0.0,
                "decoding_loss": 0.0,
                "combined_loss": 0.0,
                "behavior_r2": 0.0,
                "sample_count": 0,
            }
            for subj_id in model.subject_ids
        }

        # R² needs every prediction/target pair, so they are collected here.
        all_behavior_targets = []
        all_behavior_preds = []
        subject_behavior_targets = {subj_id: [] for subj_id in model.subject_ids}
        subject_behavior_preds = {subj_id: [] for subj_id in model.subject_ids}

        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = move_to_gpu(batch, self.device)
                batch_size = batch["behavior_input"].shape[0]

                # Encoding: predict neural activity from behaviour only.
                encoding_predictions = model(
                    **batch,
                    neural_mask=torch.zeros(batch_size, device=self.device),
                    behavior_mask=torch.ones(batch_size, device=self.device),
                )

                # Decoding: predict behaviour from neural activity only.
                decoding_predictions = model(
                    **batch,
                    neural_mask=torch.ones(batch_size, device=self.device),
                    behavior_mask=torch.zeros(batch_size, device=self.device),
                )

                pred_neural = encoding_predictions["pred_neural"]
                pred_behavior = decoding_predictions["pred_behavior"]
                neural_target = batch["neural_input"]
                behavior_target = batch["behavior_input"]

                # Batch-level losses.
                encoding_loss = F.poisson_nll_loss(
                    input=pred_neural,
                    target=neural_target,
                    log_input=True,  # outputs are log-rates, as in training
                    reduction='mean',
                )
                decoding_loss = F.mse_loss(
                    input=pred_behavior,
                    target=behavior_target,
                    reduction='mean',
                )
                combined_loss = encoding_loss + decoding_loss

                metrics["encoding_loss"] += encoding_loss.item()
                metrics["decoding_loss"] += decoding_loss.item()
                metrics["combined_loss"] += combined_loss.item()

                all_behavior_targets.append(behavior_target)
                all_behavior_preds.append(pred_behavior)

                # Per-sample losses, attributed to the sample's subject.
                for i, subj_id in enumerate(batch["subject_id"]):
                    subj_encoding_loss = F.poisson_nll_loss(
                        input=pred_neural[i:i+1],
                        target=neural_target[i:i+1],
                        log_input=True,
                        reduction='mean',
                    )
                    subj_decoding_loss = F.mse_loss(
                        input=pred_behavior[i:i+1],
                        target=behavior_target[i:i+1],
                        reduction='mean',
                    )

                    entry = subject_metrics[subj_id]
                    entry["encoding_loss"] += subj_encoding_loss.item()
                    entry["decoding_loss"] += subj_decoding_loss.item()
                    entry["combined_loss"] += (subj_encoding_loss + subj_decoding_loss).item()
                    entry["sample_count"] += 1

                    subject_behavior_targets[subj_id].append(behavior_target[i:i+1])
                    subject_behavior_preds[subj_id].append(pred_behavior[i:i+1])

                num_batches += 1

        # Mean over batches (sums are already 0.0 when nothing was seen).
        if num_batches > 0:
            for key in self._LOSS_KEYS:
                metrics[key] /= num_batches

        # Overall behaviour R²; torch.cat would raise on an empty list, so an
        # empty dataloader simply keeps the initial 0.0.
        if all_behavior_targets:
            behavior_targets = torch.cat(all_behavior_targets)
            behavior_preds = torch.cat(all_behavior_preds)
            metrics["behavior_r2"] = r2_score(behavior_preds.cpu(), behavior_targets.cpu())

        # Per-subject means and R², only for subjects that actually appeared.
        for subj_id in model.subject_ids:
            entry = subject_metrics[subj_id]
            count = entry["sample_count"]
            if count > 0:
                for key in self._LOSS_KEYS:
                    entry[key] /= count

                if len(subject_behavior_targets[subj_id]) > 0:
                    subj_targets = torch.cat(subject_behavior_targets[subj_id])
                    subj_preds = torch.cat(subject_behavior_preds[subj_id])
                    entry["behavior_r2"] = r2_score(
                        subj_preds.cpu(), subj_targets.cpu()
                    )

        metrics["per_subject"] = subject_metrics

        return metrics

In [None]:

def train(model, optimizer, train_loader, val_loader, loss_fn, config):
    """Run the full training loop with Weights & Biases logging.

    Args:
        model: Model to optimise; passed to ``training_step`` and to
            ``ValidationMetrics.compute_metrics``.
        optimizer: Optimiser handed to ``training_step``.
        train_loader: Iterable of training batches (must support ``len``).
        val_loader: Iterable of validation batches.
        loss_fn: Callable ``loss_fn(predictions, targets)`` used by
            ``training_step``.
        config: Config object providing ``wandb.project``, ``wandb.run_name``,
            ``wandb.tags``, ``wandb.log_freq``, ``training.num_epochs`` and
            ``training.mask_prob``.

    Returns:
        ``(train_losses, val_metrics_history)``: the per-batch training losses
        and the list of metric dicts produced at each validation point.

    Notes:
        * Batches are moved to the notebook-level ``device``.
        * ``wandb.log_freq`` doubles as the validation interval in epochs;
          values below 1 are treated as 1 so the modulo below cannot divide by
          zero.
        * The wandb run is always closed: with exit code 1 if anything raises
          (the exception is re-raised), normally otherwise.
    """
    wandb.init(
        project=config.wandb.project,
        name=config.wandb.run_name,
        tags=config.wandb.tags,
        config=OmegaConf.to_container(config, resolve=True)
    )

    try:
        wandb.watch(model, log="all", log_freq=config.wandb.log_freq)
        validator = ValidationMetrics(device)

        # Tracking metrics
        train_losses = []
        val_metrics_history = []

        num_epochs = config.training.num_epochs
        val_every = max(1, int(config.wandb.log_freq))

        for epoch in range(num_epochs):
            model.train()  # validation leaves the model in eval mode
            epoch_loss = 0
            num_batches = 0

            # Training steps
            for batch_idx, batch in enumerate(train_loader):
                batch = move_to_gpu(batch, device)
                loss = training_step(batch, model, optimizer, loss_fn, config.training.mask_prob)
                batch_loss = loss.item()  # read the value once
                train_losses.append(batch_loss)
                epoch_loss += batch_loss
                num_batches += 1

                wandb.log({
                    "batch/loss": batch_loss,
                    "batch/step": epoch * len(train_loader) + batch_idx
                })

            avg_epoch_loss = epoch_loss / num_batches if num_batches > 0 else 0
            wandb.log({
                "train/loss": avg_epoch_loss,
                "train/epoch": epoch + 1
            })

            # Validate every `val_every` epochs and always after the last one.
            if (epoch + 1) % val_every == 0 or epoch == num_epochs - 1:
                val_metrics = validator.compute_metrics(val_loader, model)
                val_metrics_history.append(val_metrics)

                # Log overall validation metrics
                wandb.log({
                    "val/encoding_loss": val_metrics["encoding_loss"],
                    "val/decoding_loss": val_metrics["decoding_loss"],
                    "val/combined_loss": val_metrics["combined_loss"],
                    "val/behavior_r2": val_metrics["behavior_r2"],
                    "val/epoch": epoch + 1
                })

                # Log per-subject metrics
                for subj_id, subj_metrics in val_metrics["per_subject"].items():
                    wandb.log({
                        f"val/subject/{subj_id}/encoding_loss": subj_metrics["encoding_loss"],
                        f"val/subject/{subj_id}/decoding_loss": subj_metrics["decoding_loss"],
                        f"val/subject/{subj_id}/combined_loss": subj_metrics["combined_loss"],
                        f"val/subject/{subj_id}/behavior_r2": subj_metrics["behavior_r2"],
                        "epoch": epoch + 1
                    })

                # Print validation summary
                print(f"Epoch {epoch+1}/{num_epochs} | " +
                      f"Train Loss: {avg_epoch_loss:.4f} | " +
                      f"Val Decoding Loss: {val_metrics['decoding_loss']:.4f} | " +
                      f"Val Behavior R²: {val_metrics['behavior_r2']:.4f} | " +
                      f"Val Encoding Loss: {val_metrics['encoding_loss']:.4f}")
    except BaseException:
        # Includes KeyboardInterrupt from the notebook's stop button: mark the
        # run as failed instead of leaving it open in the kernel.
        wandb.finish(exit_code=1)
        raise

    # Close wandb run
    wandb.finish()
    return train_losses, val_metrics_history


def training_step(batch, model, optimizer, loss_fn, mask_prob=0.5):
    """Perform one optimisation step with random per-sample modality masking.

    For every sample one of three outcomes is drawn with ``np.random.choice``:
    mask the neural input (probability ``mask_prob / 2``), mask the behaviour
    input (``mask_prob / 2``) or mask nothing (``1 - mask_prob``). A mask value
    of 0.0 marks the masked modality, 1.0 leaves it visible.

    Args:
        batch: Dict forwarded to ``model`` as keyword arguments; must contain
            ``neural_input``, whose first dimension and device define the masks.
        model: Callable accepting ``**batch`` plus ``neural_mask`` and
            ``behavior_mask``.
        optimizer: Optimiser whose gradients are reset, then stepped.
        loss_fn: Callable ``loss_fn(predictions, batch)`` returning a scalar
            tensor.
        mask_prob: Total probability that a sample has one modality masked.

    Returns:
        The loss tensor of this step (after ``backward`` and ``step``).
    """
    reference = batch["neural_input"]
    n_samples = reference.shape[0]

    # One draw per sample, in sample order.
    drawn = [
        np.random.choice(['neural', 'behavior', 'none'], p=[mask_prob/2, mask_prob/2, 1-mask_prob])
        for _ in range(n_samples)
    ]
    neural_mask = torch.tensor(
        [0.0 if choice == 'neural' else 1.0 for choice in drawn],
        device=reference.device,
    )
    behavior_mask = torch.tensor(
        [0.0 if choice == 'behavior' else 1.0 for choice in drawn],
        device=reference.device,
    )

    # Forward pass on the masked batch; the batch itself serves as the target.
    optimizer.zero_grad()
    predictions = model(**batch, neural_mask=neural_mask, behavior_mask=behavior_mask)

    # Loss, gradients, parameter update.
    step_loss = loss_fn(predictions, batch)
    step_loss.backward()
    optimizer.step()
    return step_loss

In [16]:
import torch.nn.functional as F

class CombinedLoss(nn.Module):
    """Weighted sum of a Poisson NLL (neural) and an MSE (behaviour) loss.

    The Poisson term uses PyTorch's default settings, written out explicitly
    below: ``log_input=True`` (``pred_neural`` is interpreted as a log-rate),
    ``full=False`` (no Stirling term) and ``reduction='mean'``.

    Args:
        neural_weight: Multiplier applied to the Poisson NLL term.
        behavior_weight: Multiplier applied to the MSE term.
    """

    def __init__(self, neural_weight=1.0, behavior_weight=1.0):
        super().__init__()
        self.neural_weight = neural_weight
        self.behavior_weight = behavior_weight

    def forward(self, predictions: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]):
        """Return ``neural_weight * poisson_nll + behavior_weight * mse``.

        ``predictions`` must provide ``pred_neural`` and ``pred_behavior``;
        ``targets`` must provide ``neural_input`` and ``behavior_input``.
        """
        # Neural term: Poisson negative log-likelihood on log-rates.
        neural_term = F.poisson_nll_loss(
            predictions["pred_neural"],
            targets["neural_input"],
            log_input=True,
            full=False,
            reduction='mean',
        )

        # Behaviour term: mean squared error.
        behavior_term = F.mse_loss(
            predictions["pred_behavior"],
            targets["behavior_input"],
            reduction='mean',
        )

        return (self.neural_weight * neural_term) + (self.behavior_weight * behavior_term)

In [17]:
def map_binned_features_to_global(
    session_binned_features,
    session_unit_id_strings,
    max_global_units=192
):
    """
    Map session-specific binned neural features to a global, padded array.

    The global column of a unit is taken from the first ``elec<N>`` pattern in
    its id string: electrode ``N`` (1-based) goes to column ``N - 1``. When
    several units sit on the same electrode their counts are summed into that
    column; previously the later unit overwrote the earlier one, silently
    discarding its spikes. Sessions with at most one unit per electrode produce
    exactly the same output as before.

    Units whose id contains no ``elec<N>`` pattern, or whose electrode index
    falls outside ``[0, max_global_units)``, are skipped.

    Args:
        session_binned_features (np.ndarray): (num_bins, num_session_units)
        session_unit_id_strings (list/np.ndarray): Unit ID strings for the session (len = num_session_units)
        max_global_units (int): Output array's second dimension size

    Returns:
        np.ndarray: (num_bins, max_global_units), same dtype as the input,
        zero in every column no unit was mapped to.

    Raises:
        ValueError: If the features are not a 2D ``np.ndarray`` or the number
            of id strings differs from the number of feature columns.
    """
    if not (isinstance(session_binned_features, np.ndarray) and session_binned_features.ndim == 2):
        raise ValueError("session_binned_features must be 2D np.ndarray")
    if len(session_unit_id_strings) != session_binned_features.shape[1]:
        raise ValueError("session_unit_id_strings length must match num_session_units")

    electrode_pattern = re.compile(r'elec(\d+)')  # compiled once, outside the loop
    num_bins = session_binned_features.shape[0]
    global_binned = np.zeros((num_bins, max_global_units), dtype=session_binned_features.dtype)

    for col, unit_str in enumerate(session_unit_id_strings):
        match = electrode_pattern.search(unit_str)
        if match is None:
            continue
        idx = int(match.group(1)) - 1  # electrode numbers are 1-based
        if 0 <= idx < max_global_units:
            # Accumulate so multiple units on one electrode are all kept.
            global_binned[:, idx] += session_binned_features[:, col]
    return global_binned

class SSMNeuroModel(nn.Module):
    """
    SSM-based model that jointly predicts neural activity and behaviour.

    Three input streams are embedded to ``embedding_dim`` each - neural
    features and behaviour features through subject-specific linear layers,
    and a per-session context embedding through a shared linear layer. The
    embeddings are concatenated along the feature axis, passed through a stack
    of LRU blocks, and decoded by subject-specific linear heads into
    ``pred_neural`` and ``pred_behavior``.

    Notes on arguments that are accepted but not used by the current
    implementation: ``ssm_projection_dim`` is only stored (the projection layer
    is commented out), and ``ssm_dropout`` and ``bin_size`` are not referenced.
    ``unit_emb`` is created and must have its vocabulary initialised before
    ``forward`` is called, but it does not contribute to the forward
    computation.
    """
    def __init__(self,
                 # Input feature dimensions for each modality (before embedding)
                 num_neural_features: int,
                 num_behavior_features: int,
                 num_context_features: int,
                 # Embedding dimension (D in the diagram)
                 subject_ids: List[str], # List of subject IDs for which the model is initialized
                 embedding_dim: int,
                 # SSM core dimensions
                 ssm_projection_dim: int, # Dimension after initial projection, M in diagram
                 ssm_hidden_dim: int,     # Hidden dimension of SSM blocks
                 ssm_num_layers: int,
                 ssm_dropout: float,
                 # Output feature dimensions for each prediction head
                 pred_neural_dim: int,
                 pred_behavior_dim: int,
                 # General params
                 sequence_length: float, # Max duration of input sequence (seconds)
                 sampling_rate: float = 100, # Hz, e.g., for converting sec to samples
                 lin_dropout: float = 0.1,
                 activation_fn: str = "relu", # Type of activation: "relu", "gelu", "tanh", etc.
                 embed_init_scale: float = 0.02,
                 bin_size: float = 1e-3, # Size of bins for neural features (seconds)
                 ):
        """Build embedders, the LRU stack and the per-subject decoders.

        Raises:
            ValueError: If ``activation_fn`` is not "relu", "gelu" or "tanh".
        """
        super().__init__()

        self.num_neural_features = num_neural_features
        self.num_behavior_features = num_behavior_features
        self.num_context_features = num_context_features
        self.embedding_dim = embedding_dim
        self.ssm_projection_dim = ssm_projection_dim
        self.ssm_hidden_dim = ssm_hidden_dim
        self.sequence_length_sec = sequence_length
        self.sampling_rate = sampling_rate
        # Calculate sequence length in samples (timesteps)
        self.num_timesteps = int(sequence_length * sampling_rate)
        self.lin_dropout_rate = lin_dropout
        self.embed_init_scale = embed_init_scale
        self.subject_ids = subject_ids

        # Activation function
        if activation_fn == "relu": self.activation = nn.ReLU()
        elif activation_fn == "gelu": self.activation = nn.GELU()
        elif activation_fn == "tanh": self.activation = nn.Tanh()
        else: raise ValueError(f"Unsupported activation_fn: {activation_fn}")

        # 1. Tokenization Embeddings
        # Session embedding has width num_context_features; it is the raw
        # context input that context_embedder maps to embedding_dim.
        self.session_emb = InfiniteVocabEmbedding(embedding_dim=self.num_context_features,
                                                  init_scale=self.embed_init_scale)

        self.unit_emb = InfiniteVocabEmbedding(embedding_dim=self.embedding_dim, init_scale=self.embed_init_scale)


        # 2. Subject-Specific Embedders (map modality features to embedding_dim)
        self.neural_embedders = nn.ModuleDict({
            subj_id: nn.Linear(self.num_neural_features, self.embedding_dim)
            for subj_id in self.subject_ids
        })
        self.behavior_embedders = nn.ModuleDict({
            subj_id: nn.Linear(self.num_behavior_features, self.embedding_dim)
            for subj_id in self.subject_ids
        })

        self.context_embedder = nn.Linear(self.num_context_features, self.embedding_dim)

        # Total dimension after concatenating D-dimensional embeddings from 3 modalities
        num_active_modalities = 3 # Neural, Behavior, Context (No Stimuli in this version)
        concatenated_dim = self.embedding_dim * num_active_modalities

        # 3. Foundational SSM Core
        # self.projection_to_ssm_input = nn.Linear(concatenated_dim, ssm_projection_dim)
        self.ssm_blocks = nn.ModuleList()
        for i in range(ssm_num_layers):
            self.ssm_blocks.append(
                LRU(
                    in_features=concatenated_dim,
                    out_features=concatenated_dim,
                    state_features=ssm_hidden_dim
                )
            )
        self.ssm_output_dim = concatenated_dim

        # 4. Subject-Specific Decoders
        self.decoder_neural_modules = nn.ModuleDict({
            subj_id: nn.Linear(self.ssm_output_dim, pred_neural_dim)
            for subj_id in self.subject_ids
        })
        self.decoder_behavior_modules = nn.ModuleDict({
            subj_id: nn.Linear(self.ssm_output_dim, pred_behavior_dim)
            for subj_id in self.subject_ids
        })

        self.dropout = nn.Dropout(self.lin_dropout_rate)

    def forward(self,
                neural_input: torch.Tensor, # Shape: (batch, seq_len, num_neural_features)
                behavior_input: torch.Tensor, # Shape: (batch, seq_len, num_behavior_features)
                session_id: torch.Tensor,  # Iterated per sample and passed to the session tokenizer
                subject_id: List[str], # Shape: (batch)
                neural_mask: Optional[torch.Tensor] = None,
                behavior_mask: Optional[torch.Tensor] = None,
                ) -> Dict[str, torch.Tensor]:
        """Embed, run the SSM core and decode.

        Args:
            neural_input: (batch, seq_len, num_neural_features).
            behavior_input: (batch, seq_len, num_behavior_features).
            session_id: One session id per sample; each is tokenised with
                ``session_emb.tokenizer``.
            subject_id: One subject id per sample; selects the embedders and
                decoders used for that sample.
            neural_mask: Optional (batch,) multiplier applied to the whole
                neural input of each sample (0 masks it out).
            behavior_mask: Optional (batch,) multiplier applied to the whole
                behaviour input of each sample.

        Returns:
            Dict with ``pred_neural`` (batch, seq_len, pred_neural_dim) and
            ``pred_behavior`` (batch, seq_len, pred_behavior_dim).

        Raises:
            ValueError: If either vocabulary is still lazy, or a subject id in
                the batch is not one the model was built for.
        """

        if self.unit_emb.is_lazy():
            raise ValueError(
                "Unit vocabulary has not been initialized, please use "
                "`model.unit_emb.initialize_vocab(unit_ids)`"
            )
        if self.session_emb.is_lazy():
            raise ValueError(
                "Session vocabulary has not been initialized, please use "
                "`model.session_emb.initialize_vocab(session_ids)`"
            )

        batch_size, seq_len, _ = neural_input.shape
        # Index tensors must live on the same device as the embedding weights;
        # the inputs are expected to already be on the model's device.
        input_device = neural_input.device

        session_tokens = [self.session_emb.tokenizer(sid) for sid in session_id]
        session_embs = torch.stack([
            self.session_emb(torch.tensor(tokens, device=input_device).unsqueeze(0)).squeeze(0)
            for tokens in session_tokens
        ])
        context_input = session_embs.unsqueeze(1).expand(-1, seq_len, -1) # Shape: (batch, seq_len, num_context_features)


        embedded_neural = []
        embedded_behavior = []

        # Apply masks if provided (broadcast the per-sample value over time and features)
        if neural_mask is not None:
            neural_input = neural_input * neural_mask.unsqueeze(1).unsqueeze(2)
        if behavior_mask is not None:
            behavior_input = behavior_input * behavior_mask.unsqueeze(1).unsqueeze(2)

        # 1. Embed inputs
        for i in range(batch_size):
            subj_id = subject_id[i]
            if subj_id not in self.subject_ids:
                raise ValueError(f"Unknown subject_id '{subj_id}' in batch. Model not initialized for this subject.")

            emb_n = self.neural_embedders[subj_id](neural_input[i]) # Shape: (seq_len, embedding_dim)
            embedded_neural.append(self.dropout(self.activation(emb_n)))

            emb_b = self.behavior_embedders[subj_id](behavior_input[i]) # Shape: (seq_len, embedding_dim)
            embedded_behavior.append(self.dropout(self.activation(emb_b)))

        embedded_neural = torch.stack(embedded_neural, dim=0)  # Shape: (batch, seq_len, embedding_dim)
        embedded_behavior = torch.stack(embedded_behavior, dim=0)  # Shape: (batch, seq_len, embedding_dim)
        embedded_context = self.dropout(self.activation(self.context_embedder(context_input))) # Shape: (batch, seq_len, embedding_dim)


        # 2. Concatenate embeddings
        # Shape: (batch, seq_len, embedding_dim * 3)
        concatenated_embeddings = torch.cat([
            embedded_neural, embedded_behavior, embedded_context
        ], dim=-1)

        # 3. Pass through Foundational SSM Core
        # ssm_core_input = self.dropout(self.activation(self.projection_to_ssm_input(concatenated_embeddings)))
        ssm_core_input = concatenated_embeddings # Shape: (batch, seq_len, concatenated_dim)
        ssm_layer_output = ssm_core_input
        for ssm_block in self.ssm_blocks:
            ssm_layer_output = ssm_block(ssm_layer_output) # Adjust if it returns state tuple
            if isinstance(ssm_layer_output, tuple): # e.g. LSTM output, hidden
                ssm_layer_output = ssm_layer_output[0]
            ssm_layer_output = self.dropout(ssm_layer_output) # General dropout after each block processing
        final_ssm_output = ssm_layer_output # Shape: (batch, seq_len, self.ssm_output_dim)

        # 4. Decode
        all_pred_neural = []
        all_pred_behavior = []

        for i in range(batch_size):
            subj_id = subject_id[i]
            ssm_out_sample = final_ssm_output[i] # (seq_len, ssm_output_dim)

            pred_n = self.decoder_neural_modules[subj_id](ssm_out_sample)
            all_pred_neural.append(pred_n)

            pred_b = self.decoder_behavior_modules[subj_id](ssm_out_sample)
            all_pred_behavior.append(pred_b)

        predictions = {
            "pred_neural": torch.stack(all_pred_neural),
            "pred_behavior": torch.stack(all_pred_behavior),
        }

        return predictions


    def tokenize(self, data: Data) -> Dict:
        r"""Convert one ``Data`` sample into the tensors ``forward`` expects.

        This tokenizer can be called as a transform. If you are applying multiple
        transforms, make sure to apply this one last.

        This code runs on CPU. Do not access GPU tensors inside this function.

        Spikes are binned at ``1 / sampling_rate`` seconds into
        ``num_timesteps`` bins, then scattered into ``num_neural_features``
        global columns by ``map_binned_features_to_global``. Behaviour is read
        from ``data.cursor.vel``.

        Returns:
            Dict with ``neural_input`` and ``behavior_input`` (float32 tensors)
            and the sample's ``session_id`` and ``subject_id``.
        """

        unit_ids = data.units.id
        spikes = data.spikes
        # The result is transposed before use - presumably bin_spikes returns
        # (num_units, num_bins); TODO confirm against utils.preprocessing.
        binned_spikes = bin_spikes(
            spikes=spikes,
            num_units=len(unit_ids),
            bin_size= 1 / self.sampling_rate,
            num_bins=self.num_timesteps
        ).T
        neural_input = map_binned_features_to_global(
            session_binned_features=binned_spikes,
            session_unit_id_strings=unit_ids,
            max_global_units=self.num_neural_features
        ) # (N_timesteps, N_global_units)

        behavior_input = data.cursor.vel # (N_timesteps, N_behavior_features)

        data_dict = {
            "neural_input": torch.tensor(neural_input, dtype=torch.float32),
            "behavior_input": torch.tensor(behavior_input, dtype=torch.float32),
            "session_id": data.session.id,
            "subject_id": data.subject.id
        }

        return data_dict
            


In [7]:
# Build the train/validation datasets and their loaders for the configured
# dataset name and subjects, using the configured batch size.
train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(
    batch_size=config.dataset.batch_size,
    train_config=get_dataset_config(config.dataset.name, subjects=config.dataset.subjects),
)

In [23]:
# Report how many units the training dataset exposes
num_units = len(train_dataset.get_unit_ids())
print(f"Num Units in Session: {num_units}")

# Build the model: every key of config.model is a keyword argument of
# SSMNeuroModel, so the section is unpacked directly; subjects come from the data.
model = SSMNeuroModel(
    subject_ids=train_dataset.get_subject_ids(),
    **OmegaConf.to_container(config.model, resolve=True),
).to(device)

# The lazy vocabularies must be initialised before the first forward pass
model.session_emb.initialize_vocab(train_dataset.get_session_ids())
model.unit_emb.initialize_vocab(train_dataset.get_unit_ids())

# Both datasets use the model's tokenizer as their transform
transform = train_dataset.transform = val_dataset.transform = model.tokenize

# Optimizer and loss
optimizer = torch.optim.AdamW(model.parameters(), lr=config.training.learning_rate)
loss_fn = CombinedLoss(config.training.neural_weight, config.training.behavior_weight)

# Train with wandb logging
train_losses, val_metrics_history = train(model, optimizer, train_loader, val_loader, loss_fn, config)

# Plot training curves (original visualization)
plot_training_curves(val_metrics_history, train_losses)

Num Units in Session: 75


[34m[1mwandb[0m: Currently logged in as: [33mmelinajingting[0m ([33mmelinajingting-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


: 

In [21]:

# Smoke test: run a single, manually batched sample through the model's forward pass.
for i, idx in enumerate(train_loader.sampler):
    # Index the dataset directly (not the loader) to get one un-collated item;
    # presumably this applies the `transform` attached to the dataset - verify.
    item = train_dataset[idx]
    # The item provides neural_input, behavior_input, session_id and subject_id.
    # Tensors get a leading batch dimension of 1; ids are wrapped in one-element
    # lists. No masks are passed, so neither modality is masked.
    model_output = model(
        neural_input=item["neural_input"].unsqueeze(0).to(device),
        behavior_input=item["behavior_input"].unsqueeze(0).to(device),
        session_id=[item["session_id"]],
        subject_id=[item["subject_id"]]
    )
    if i == 0:  # Stop after the first sampled index
        break