In [3]:
# Environment Setup Cell
# If you're seeing import errors, run this cell first to install required packages

import subprocess
import sys

def install_package(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Required packages
packages = [
    "numpy>=1.21.0",
    "pandas>=1.3.0", 
    "scipy>=1.7.0",
    "jax>=0.4.0",
    "jaxlib>=0.4.0",
    "dm-haiku>=0.0.10",
    "optax>=0.1.4",
    "matplotlib>=3.5.0",
    "seaborn>=0.11.0",
    "plotnine>=0.8.0",
    "requests>=2.25.0",
    "gdown>=4.0.0",
    "tqdm>=4.62.0"
]

print("Installing required packages...")
for package in packages:
    try:
        __import__(package.split('>=')[0].split('[')[0])
        print(f"✓ {package} already installed")
    except ImportError:
        print(f"Installing {package}...")
        install_package(package)

# Install local package
try:
    from CogModelingRNNsTutorial import hybrnn
    print("✓ Local CogModelingRNNsTutorial package already installed")
except ImportError:
    print("Installing local CogModelingRNNsTutorial package...")
    install_package("-e ./CogModelingRNNsTutorial/")

print("\nSetup complete! Please restart the kernel after installation.")

Installing required packages...
Installing numpy>=1.21.0...
Collecting numpy>=1.21.0
  Downloading numpy-2.4.0-cp312-cp312-macosx_14_0_x86_64.whl.metadata (6.6 kB)
Downloading numpy-2.4.0-cp312-cp312-macosx_14_0_x86_64.whl (6.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m20.7 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: numpy
Successfully installed numpy-2.4.0
Installing pandas>=1.3.0...
Collecting pandas>=1.3.0
  Downloading pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl.metadata (91 kB)
Collecting pytz>=2020.1 (from pandas>=1.3.0)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas>=1.3.0)
  Downloading tzdata-2025.3-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m23.6 MB/s[0m  [33m0:00:00[0m eta [36m0:

# Test Notebook for BiControlRNN Model

This notebook tests the new learnable gated hybrid architecture of the BiControlRNN model.

## 1. Import and Setup

In [4]:
# Basic imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

# JAX and Haiku imports
import jax
import jax.numpy as jnp
import haiku as hk
import optax

# Local imports
from CogModelingRNNsTutorial import bandits
from CogModelingRNNsTutorial import hybrnn
from CogModelingRNNsTutorial import rnn_utils
from CogModelingRNNsTutorial import plotting

# Set up random seed
rng_seq = hk.PRNGSequence(np.random.randint(2**32))

print(f"JAX devices: {jax.devices()}")

JAX devices: [CpuDevice(id=0)]


## 2. Load and Prepare Data

In [5]:
# Load one of the datasets (e.g., Qasim dataset)
import requests
from io import StringIO

# Download Qasim dataset
osf_url = 'https://osf.io/xe6yu/download?direct=1'
response = requests.get(osf_url)

if response.status_code == 200:
    qasim_data = pd.read_csv(StringIO(response.text))
    print('Dataset downloaded successfully!')
else:
    print('Failed to download dataset')

# Prepare data
selected_columns = ['participant', 'trials_gamble', 'gamble', 'prob', 'reward']
qasim = qasim_data[selected_columns]
qasim_filtered = qasim[qasim['trials_gamble'].notna()]
qasim_sorted = qasim_filtered.groupby('participant', group_keys=False).apply(lambda x: x.sort_values('trials_gamble'))
qasim_sorted = qasim_sorted.reset_index(drop=True)
qasim_sorted['participant'] = qasim_sorted.groupby(['participant']).ngroup() + 1
qasim_sorted['action'] = qasim_sorted['gamble']

# Fill missing actions
qasim_sorted['action'] = qasim_sorted['action'].fillna(-1).astype(int)

# Generate next action
qasim_sorted['action_n'] = (
    qasim_sorted
    .groupby('participant')['action']
    .shift(-1)
)
last_idxs = qasim_sorted.groupby('participant').tail(1).index
qasim_sorted.loc[last_idxs, 'action_n'] = -1

# Create sequences
xs_list, ys_list = [], []
for pid, grp in qasim_sorted.groupby('participant'):
    grp = grp.sort_values('trials_gamble')
    x = grp[['prob', 'reward']].to_numpy().astype(float)
    y = grp[['action_n']].to_numpy().astype(int)
    xs_list.append(x)
    ys_list.append(y)

# Stack into arrays
xs = np.stack(xs_list, axis=1)  # (n_trials, n_sessions, 2)
ys = np.stack(ys_list, axis=1)  # (n_trials, n_sessions, 1)

print(f"xs.shape: {xs.shape}")
print(f"ys.shape: {ys.shape}")

Dataset downloaded successfully!
xs.shape: (60, 206, 2)
ys.shape: (60, 206, 1)


# Define model parameters
rl_params = {
    's': True,        # Use state feedback
    'o': True,        # Use output feedback
    'w_h': 0.5,       # Habit weight
    'w_v': 0.5,       # Value weight
    'forget': 0.1,    # Forget rate
    'fit_forget': False  # Don't fit forget parameter
}

network_params = {
    'n_actions': 2,           # Number of actions
    'hidden_size': 64         # Hidden layer size
}

print("Model parameters defined:")
print(f"RL params: {rl_params}")
print(f"Network params: {network_params}")

In [6]:
def BiControlRNN_model(rl_params, network_params):
    """BiControlRNN model with learnable gated hybrid architecture."""
    model = hybrnn.BiControlRNN(rl_params, network_params)
    
    def forward(xs):
        # xs shape: (batch_size, sequence_length, input_dim)
        batch_size = xs.shape[0]
        
        # Initialize hidden state
        initial_state = model.initial_state(batch_size)
        
        # Unroll the RNN
        outputs, final_state = hk.dynamic_unroll(
            model, xs, initial_state
        )
        
        return outputs
    
    return forward

## 4. Model Parameters

In [7]:
# Define model parameters
rl_params = {
    's': True,        # Use state feedback
    'o': True,        # Use output feedback
    'w_h': 0.5,       # Habit weight
    'w_v': 0.5,       # Value weight
    'forget': 0.1,    # Forget rate
    'fit_forget': False  # Don't fit forget parameter
}

network_params = {
    'n_actions': 2,           # Number of actions
    'hidden_size': 64,        # Hidden layer size
    'gate_hidden_size': 32    # Gate MLP hidden size
}

print("Model parameters defined:")
print(f"RL params: {rl_params}")
print(f"Network params: {network_params}")

Model parameters defined:
RL params: {'s': True, 'o': True, 'w_h': 0.5, 'w_v': 0.5, 'forget': 0.1, 'fit_forget': False}
Network params: {'n_actions': 2, 'hidden_size': 64, 'gate_hidden_size': 32}


## 5. Prepare Dataset

In [8]:
# Split data into train/test
n_sessions = xs.shape[1]
train_sessions = int(0.8 * n_sessions)
test_sessions = n_sessions - train_sessions

# Randomly shuffle sessions
session_indices = np.random.permutation(n_sessions)
train_idx = session_indices[:train_sessions]
test_idx = session_indices[train_sessions:]

# Create train/test datasets
xs_train = xs[:, train_idx, :]
ys_train = ys[:, train_idx, :]
xs_test = xs[:, test_idx, :]
ys_test = ys[:, test_idx, :]

# Create DatasetRNN objects
dataset_train = rnn_utils.DatasetRNN(xs_train, ys_train, batch_size=32)
dataset_test = rnn_utils.DatasetRNN(xs_test, ys_test, batch_size=32)

print(f"Train sessions: {train_sessions}")
print(f"Test sessions: {test_sessions}")
print(f"Train dataset shape: {xs_train.shape}")
print(f"Test dataset shape: {xs_test.shape}")

Train sessions: 164
Test sessions: 42
Train dataset shape: (60, 164, 2)
Test dataset shape: (60, 42, 2)


## 6. Train the Model

In [9]:
# Transform model function
model_fun = hk.transform(lambda xs: BiControlRNN_model(rl_params, network_params)(xs))

# Initialize parameters
xs_batch, ys_batch = next(dataset_train)
rng_key = next(rng_seq)
params = model_fun.init(rng_key, xs_batch)

print(f"Model initialized with {sum(p.size for p in jax.tree_util.tree_leaves(params))} parameters")

TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 1 for shapes (32, 1), (32, 2), (60, 2), (60, 64).

In [None]:
# Train the model
print("Training BiControlRNN model...")

trained_params, train_losses = rnn_utils.fit_model(
    model_fun=model_fun,
    dataset_train=dataset_train,
    dataset_test=dataset_test,
    loss_fun='categorical',
    optimizer=optax.adam(learning_rate=1e-3),
    n_steps_per_call=100,
    n_steps_max=5000,
    early_stop_step=100,
    if_early_stop=True
)

print(f"\nTraining completed. Final loss: {train_losses[-1]:.4f}")

## 8. Visualize Results

In [None]:
# Compute log likelihood on test set
print("\nEvaluating model on test set...")
mean_ll, std_ll = compute_log_likelihood(dataset_test, model_fun, trained_params)

print(f"Test set normalized likelihood: {100 * mean_ll:.1f}% ± {100 * std_ll:.1f}%")

## 8. Analyze Gate Behavior

In [None]:
# Function to extract gate signals during model execution
def BiControlRNN_with_gate_analysis(rl_params, network_params):
    """Modified BiControlRNN that returns gate signals for analysis."""
    
    class BiControlRNNWithGates(hybrnn.BiControlRNN):
        def __call__(self, inputs, prev_state):
            # Call parent method
            logits, next_state = super().__call__(inputs, prev_state)
            
            # Extract gate signal from value module
            h_state, v_state, habit, value = prev_state
            action_onehot = jax.nn.one_hot(inputs[:, 0].astype(int), self._n_actions)
            reward = inputs[:, -1]
            
            # Calculate gate inputs
            gate_inputs = jnp.concatenate([
                reward[:, jnp.newaxis],
                action_onehot,
                value,
                v_state
            ], axis=-1)
            
            # Get gate signal
            gate_hidden = jax.nn.relu(self._gate_mlp_layer1(gate_inputs))
            gate_signal = jax.nn.sigmoid(self._gate_mlp_layer2(gate_hidden))
            
            return logits, next_state, gate_signal
    
    model = BiControlRNNWithGates(rl_params, network_params)
    
    def forward(xs):
        batch_size = xs.shape[0]
        initial_state = model.initial_state(batch_size)
        
        # Custom unroll to capture gate signals
        def step(carry, x):
            state = carry
            logits, next_state, gate_signal = model(x, state)
            return next_state, (logits, gate_signal)
        
        _, (outputs, gate_signals) = hk.scan(
            step, initial_state, xs
        )
        
        return outputs, gate_signals
    
    return forward

In [None]:
# Analyze gate behavior on a test batch
model_with_gates = hk.transform(lambda xs: BiControlRNN_with_gate_analysis(rl_params, network_params)(xs))

# Get a test batch
xs_test_batch, ys_test_batch = next(dataset_test)

# Run model with gate analysis
outputs, gate_signals = model_with_gates.apply(trained_params, rng_key, xs_test_batch)

print(f"Gate signals shape: {gate_signals.shape}")
print(f"Mean gate signal: {gate_signals.mean():.3f}")
print(f"Gate signal std: {gate_signals.std():.3f}")

## 9. Visualize Results

In [None]:
# Plot gate signals over time
plt.figure(figsize=(12, 4))

# Plot 1: Gate signals over trials for a few sessions
plt.subplot(1, 2, 1)
for i in range(min(5, gate_signals.shape[1])):
    plt.plot(gate_signals[:, i, 0], label=f'Session {i+1}', alpha=0.7)
plt.xlabel('Trial')
plt.ylabel('Gate Signal')
plt.title('Gate Signals Over Time')
plt.legend()
plt.ylim([0, 1])

# Plot 2: Gate signal distribution
plt.subplot(1, 2, 2)
plt.hist(gate_signals.flatten(), bins=50, alpha=0.7, density=True)
plt.xlabel('Gate Signal')
plt.ylabel('Density')
plt.title('Gate Signal Distribution')
plt.xlim([0, 1])

plt.tight_layout()
plt.show()

In [None]:
# Plot training loss
plt.figure(figsize=(8, 4))
plt.plot(train_losses)
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.show()

## 10. Compare with Baseline Models

In [None]:
# Train a baseline BiRNN (without learnable gates) for comparison
def BiRNN_baseline_model(rl_params, network_params):
    """Baseline BiRNN model."""
    model = hybrnn.BiRNN(rl_params, network_params)
    
    def forward(xs):
        batch_size = xs.shape[0]
        initial_state = model.initial_state(batch_size)
        outputs, _ = hk.dynamic_unroll(model, xs, initial_state)
        return outputs
    
    return forward

# Train baseline
baseline_model_fun = hk.transform(lambda xs: BiRNN_baseline_model(rl_params, network_params)(xs))
baseline_params = baseline_model_fun.init(rng_key, xs_batch)

print("Training baseline BiRNN...")
baseline_trained_params, baseline_losses = rnn_utils.fit_model(
    model_fun=baseline_model_fun,
    dataset_train=dataset_train,
    dataset_test=dataset_test,
    loss_fun='categorical',
    optimizer=optax.adam(learning_rate=1e-3),
    n_steps_per_call=100,
    n_steps_max=5000,
    early_stop_step=100,
    if_early_stop=True
)

# Evaluate baseline
baseline_mean_ll, baseline_std_ll = compute_log_likelihood(dataset_test, baseline_model_fun, baseline_trained_params)

print(f"\nComparison:")
print(f"BiControlRNN (learnable gates): {100 * mean_ll:.1f}% ± {100 * std_ll:.1f}%")
print(f"BiRNN (baseline):            {100 * baseline_mean_ll:.1f}% ± {100 * baseline_std_ll:.1f}%")

## Summary

This notebook demonstrates:
1. **Learnable Gate Architecture**: The BiControlRNN model learns to dynamically arbitrate between context and memory streams
2. **Gate Analysis**: We can extract and visualize the gate signals to understand how the model balances fast vs. slow learning
3. **Performance Comparison**: The learnable gated model can be compared against baseline models

The key innovation is that the model learns *when* to rely on immediate context (high gate signal) vs. historical memory (low gate signal), rather than using hardcoded RPE-based weighting.