In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import scipy.io as sio
from process_behavior_and_movement_data import SessionDataProcessor
from analyze_behavior_data_functions import BehaviorDataAnalyzer, MovementAnalyzer, BehaviorPlotter, GLMAnalyzer
import matplotlib as mpl
import scipy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hide the right and top axis spines on every figure (personal preference)
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
})

ModuleNotFoundError: No module named 'process_behavior_and_movement_data'

In [None]:
# Initialize the analyzers.
# Root folder of the raw behavior sessions; set the BEHAVIOR_BASE_DIR environment
# variable to point elsewhere (the default is the original lab-volume path).
behavior_base_dir = os.environ.get('BEHAVIOR_BASE_DIR', '/Volumes/Runyan5/Akhil/behavior/')
behavior_analyzer = BehaviorDataAnalyzer(base_dir=behavior_base_dir)
movement_analyzer = MovementAnalyzer()

In [None]:
# Folder for saved figures, and the dataset sub-folder name for this analysis
figures_dir, data_dir = "figures", 'dynamic_choice'

# Load Data

In [None]:
# Session dates to analyse for each mouse (presumably YYMMDD strings)
date_lists = {
    'IS-2-1L': ['250314', '250320', '250401', '250402', '250404', '250408'],
}
# Alternative (commented-out) subset of dates:
#     'IS-2-1L': ['250401', '250402', '250404', '250408'],

# Mice to process, in the same order as the keys of date_lists
mouse_list = list(date_lists)

In [None]:
# Minimum number of trials since the last context switch for a trial to be kept.
# 0 keeps every trial. NOTE: only task_df is filtered here, not trialized_data,
# so any value > 0 would misalign task_df rows with the trialized trials that
# are later paired with them in prepare_data.
MIN_TRIALS_SINCE_CHANGE = 0


def _add_trials_since_change(task_df):
    """Return a copy of `task_df` with a `trials_since_change` column.

    The counter is 0 on the first trial of every run of identical `context`
    values and increases by one on each following trial of that run.
    """
    df = task_df.copy()
    # diff() is NaN on the first row and NaN != 0, so the first trial opens a run
    run_starts = df['context'].diff().ne(0)
    run_ids = run_starts.cumsum().to_numpy()
    # positional arrays avoid any dependence on the loader's index labels
    df['trials_since_change'] = df.groupby(run_ids).cumcount().to_numpy()
    return df


# Per-mouse results: all_task_dfs[mouse] is one concatenated DataFrame;
# all_trialized_data[mouse] is a list with one entry per loaded session, in load order.
all_task_dfs = {}
all_trialized_data = {}

for mouse in mouse_list:
    mouse_task_dfs = []
    mouse_trialized_data = []

    for date in date_lists[mouse]:
        try:
            task_df, trialized_data = behavior_analyzer.load_session_data(
                mouse_name=mouse,
                date=date,
                verbose=False
            )

            # Tag every trial with its session date
            task_df['date'] = date

            filtered_task_df = _add_trials_since_change(task_df)
            filtered_task_df = filtered_task_df[
                filtered_task_df['trials_since_change'] >= MIN_TRIALS_SINCE_CHANGE
            ]

            # Append both together so list positions stay paired per session
            mouse_task_dfs.append(filtered_task_df)
            mouse_trialized_data.append(trialized_data)

        except Exception as e:
            # Best effort: report the failing session and continue with the rest
            print(f"Error loading data for date {date}: {str(e)}")
            continue

    if mouse_task_dfs:  # Only store if we successfully loaded any sessions
        all_task_dfs[mouse] = pd.concat(mouse_task_dfs, ignore_index=True)
        all_trialized_data[mouse] = mouse_trialized_data

        print(f"\nSummary for mouse {mouse}:")
        print(f"Total sessions: {len(mouse_task_dfs)}")
        print(f"Total trials: {len(all_task_dfs[mouse])}")
        print(f"Overall performance: {all_task_dfs[mouse]['outcome'].mean():.1%}")

In [None]:
# Inspect the combined trial table for the example mouse
example_mouse_df = all_task_dfs['IS-2-1L']
example_mouse_df

# Task Accuracy Across Contexts

In [None]:
def plot_context_accuracy_with_sessions(task_df, show_sem=False):
    """
    Plot mean accuracy per context as bars, with per-session averages as points joined by lines.

    Args:
        task_df (pd.DataFrame): Trials for one mouse; uses the 'context' (0, 1, 2),
            'outcome' and 'date' columns
        show_sem (bool): If True, draw SEM error bars (over pooled trials) on the bars.
            Defaults to False, matching the previous figure.

    Returns:
        tuple: (fig, ax) matplotlib figure and axes objects
    """
    fig, ax = plt.subplots(figsize=(3.5, 2.5), dpi=800)

    context_ids = [0, 1, 2]
    context_names = ['Congruent', 'Visual', 'Audio']
    context_colors = {0: 'purple', 1: '#EC008C', 2: '#27AAE1'}
    x_positions = np.arange(len(context_names))

    # Overall accuracy per context, pooled over all trials of all sessions, in percent
    ctx_outcomes = [task_df.loc[task_df['context'] == ctx, 'outcome'] for ctx in context_ids]
    means = [outcomes.mean() * 100 for outcomes in ctx_outcomes]
    ax.bar(x_positions, means,
           color=[context_colors[ctx] for ctx in context_ids],
           alpha=0.7)

    if show_sem:
        sems = [scipy.stats.sem(outcomes) * 100 for outcomes in ctx_outcomes]
        ax.errorbar(x_positions, means, yerr=sems,
                    fmt='none', color='black', capsize=5)

    # One gray line plus colored points per session
    for session_date in task_df['date'].unique():
        session_df = task_df[task_df['date'] == session_date]
        session_means = []
        for ctx in context_ids:
            ctx_trials = session_df[session_df['context'] == ctx]
            # NaN for a context absent from the session: no point, and a gap in the line
            session_means.append(ctx_trials['outcome'].mean() * 100 if len(ctx_trials) > 0 else np.nan)

        ax.plot(x_positions, session_means,
                color='gray', alpha=0.3, linewidth=0.5)

        for ctx, value in zip(context_ids, session_means):
            if not np.isnan(value):
                ax.scatter(x_positions[ctx], value,
                           color=context_colors[ctx],
                           s=30, alpha=0.7,
                           edgecolor='white', linewidth=0.5)

    ax.set_xticks(x_positions)
    ax.set_xticklabels(context_names)
    ax.set_ylabel('Accuracy (%)')
    ax.set_ylim(40, 100)

    # Chance level for a two-choice task
    ax.axhline(y=50, color='gray', linestyle=':', alpha=0.5)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.tight_layout()
    return fig, ax

In [None]:
fig, ax = plot_context_accuracy_with_sessions(all_task_dfs['IS-2-1L'])
# savefig does not create missing folders
os.makedirs(figures_dir, exist_ok=True)
fig.savefig(os.path.join(figures_dir, 'imaging_animals_performance.svg'), format='svg', bbox_inches='tight')
plt.show()

# Dynamic Choice Model - Kinematic Data

## Initialize Model

In [None]:
# Prepare one session's kinematic inputs and choice labels for the CNN model
def prepare_data(task_df, trialized_data):
    """Stack the five kinematic channels of a session and pair them with choices.

    Args:
        task_df (pd.DataFrame): Behavioral trials of the session; uses the 'choice'
            column (0 = left, 1 = right).
        trialized_data: Mapping with 'X_pos', 'Y_pos', 'View', 'X_velocity' and
            'Y_velocity' entries; assumed to be (n_trials, timesteps) each — TODO confirm.

    Returns:
        tuple: (X_data, y_data) with X_data of shape (n_kept, timesteps, 5) and
        y_data of shape (n_kept,).

    Raises:
        ValueError: if task_df and trialized_data disagree on the number of trials.
    """
    channel_names = ['X_pos', 'Y_pos', 'View', 'X_velocity', 'Y_velocity']
    X_data = np.stack([trialized_data[name] for name in channel_names], axis=2)

    y_data = np.asarray(task_df['choice'])
    if len(y_data) != len(X_data):
        raise ValueError(
            f"task_df has {len(y_data)} trials but trialized_data has {len(X_data)}; "
            "behavior and kinematics must describe the same trials in the same order"
        )

    # Remove trials with abnormal length (> 2x average).
    # NOTE(review): for a rectangular (n_trials, timesteps) array every trial has the
    # same len(), so this mask keeps everything; it only filters if trials are ragged.
    trial_lengths = np.array([len(trial) for trial in trialized_data['X_pos']])
    keep = trial_lengths <= 2 * np.mean(trial_lengths)

    return X_data[keep], y_data[keep]

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Use the Apple-Silicon GPU backend (MPS) when PyTorch reports it, else the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

class MovementDataset(Dataset):
    """Custom Dataset pairing movement features with targets, both stored as float32 tensors."""

    def __init__(self, X, y):
        # Convert once up front; __getitem__ then just slices tensors
        self.X, self.y = torch.FloatTensor(X), torch.FloatTensor(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features, target = self.X[idx], self.y[idx]
        return features, target

def normalize_data(all_sessions_data):
    """Z-score every session's 'X_data' with per-feature statistics pooled over all sessions.

    Mean and std are taken over trials and timesteps (axes 0 and 1) of the
    concatenated data; 1e-8 is added to the std to avoid division by zero.

    Returns:
        tuple: (list of shallow-copied session dicts with normalized 'X_data', mean, std)
    """
    pooled = np.concatenate([sess['X_data'] for sess in all_sessions_data], axis=0)
    mean = pooled.mean(axis=(0, 1))
    std = pooled.std(axis=(0, 1))
    scale = std + 1e-8

    def _normalized_copy(sess):
        out = sess.copy()
        out['X_data'] = (sess['X_data'] - mean) / scale
        return out

    return [_normalized_copy(sess) for sess in all_sessions_data], mean, std

class Dynamic_Choice_CNN(nn.Module):
    """Per-timestep choice classifier on kinematic sequences.

    Three Conv1d(kernel 3, padding 1) -> BatchNorm -> ReLU stages (n_features -> 32 -> 64 -> 32)
    keep the temporal length, then a 1x1 conv and sigmoid give one probability per timestep.
    """

    def __init__(self, n_features):
        super().__init__()

        widths = [n_features, 32, 64, 32]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv1d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
        self.conv_layers = nn.Sequential(*stages)

        self.prediction_layer = nn.Conv1d(32, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map (batch, timesteps, n_features) to predicted probabilities of shape (batch, timesteps)."""
        features = self.conv_layers(x.transpose(1, 2))  # Conv1d wants channels on axis 1
        logits = self.prediction_layer(features)        # (batch, 1, timesteps)
        return self.sigmoid(logits).squeeze(1)

def _expand_labels(y, n_timesteps):
    """Repeat each trial's choice at every timestep -> float array (n_trials, n_timesteps)."""
    return np.repeat(np.asarray(y, dtype=float)[:, np.newaxis], n_timesteps, axis=1)


def _mean_loss(model, loader, criterion, device):
    """Trial-weighted mean loss of `model` over `loader`, without gradients."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for batch_X, batch_y in loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            total += criterion(model(batch_X), batch_y).item() * len(batch_X)
            count += len(batch_X)
    return total / max(count, 1)


def _predict(model, loader, device):
    """Concatenate model outputs over `loader` into one numpy array (n_trials, timesteps)."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch_X, _ in loader:
            chunks.append(model(batch_X.to(device)).cpu().numpy())
    return np.concatenate(chunks, axis=0)


def train_model_leave_one_session_out_CNN(all_sessions_data, device=None, n_epochs=50,
                                          patience=10, batch_size=64, lr=0.001,
                                          val_fraction=0.1, seed=None):
    """Leave-one-session-out training of Dynamic_Choice_CNN with per-timestep choice labels.

    For every session, a fresh model is trained on all other sessions and then
    predicts the held-out session. The held-out session is used for nothing but
    the final loss report and the predictions:
      * z-scoring statistics come from the training sessions only;
      * early stopping monitors a random `val_fraction` slice of the training
        trials, and the weights with the best validation loss are restored.

    Args:
        all_sessions_data (list[dict]): Sessions with 'X_data' (trials, timesteps,
            features), 'y_data' (trials,) and 'date'.
        device (torch.device, optional): Defaults to MPS when available, else CPU.
        n_epochs (int): Maximum training epochs per fold.
        patience (int): Epochs without validation improvement before stopping.
        batch_size (int): Mini-batch size.
        lr (float): Adam learning rate.
        val_fraction (float): Fraction of training trials held out for early
            stopping; 0 disables early stopping (train all epochs).
        seed (int, optional): Seeds torch (init, shuffling) and the validation split.

    Returns:
        list[tuple]: (test_date, predictions) per session, predictions of shape
        (n_trials, timesteps).

    Raises:
        ValueError: if fewer than two sessions are given.
    """
    if device is None:
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    n_sessions = len(all_sessions_data)
    if n_sessions < 2:
        raise ValueError("Leave-one-session-out training needs at least 2 sessions")

    if seed is not None:
        torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    predictions = []
    for test_session_idx, test_session in enumerate(all_sessions_data):
        test_date = test_session['date']
        print(f"\n=== Training model {test_session_idx + 1}/{n_sessions} ===")
        print(f"Test session date: {test_date}")

        train_sessions = [s for i, s in enumerate(all_sessions_data) if i != test_session_idx]
        train_X = np.concatenate([s['X_data'] for s in train_sessions], axis=0)
        train_y = np.concatenate([s['y_data'] for s in train_sessions])

        # Normalize with training-fold statistics only (no leakage from the test session)
        mean = train_X.mean(axis=(0, 1))
        std = train_X.std(axis=(0, 1))
        train_X = (train_X - mean) / (std + 1e-8)
        test_X = (test_session['X_data'] - mean) / (std + 1e-8)

        # Split training trials into fit / validation; keep at least one fit trial
        order = rng.permutation(len(train_X))
        n_val = 0
        if val_fraction > 0:
            n_val = min(int(round(val_fraction * len(train_X))), len(train_X) - 1)
        val_idx, fit_idx = order[:n_val], order[n_val:]

        n_timesteps = train_X.shape[1]
        fit_loader = DataLoader(
            MovementDataset(train_X[fit_idx], _expand_labels(train_y[fit_idx], n_timesteps)),
            batch_size=batch_size, shuffle=True)
        val_loader = None
        if n_val > 0:
            val_loader = DataLoader(
                MovementDataset(train_X[val_idx], _expand_labels(train_y[val_idx], n_timesteps)),
                batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(
            MovementDataset(test_X, _expand_labels(test_session['y_data'], test_X.shape[1])),
            batch_size=batch_size, shuffle=False)

        model = Dynamic_Choice_CNN(n_features=train_X.shape[2]).to(device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        best_val_loss = float('inf')
        best_state = None
        patience_counter = 0

        for epoch in range(n_epochs):
            model.train()
            train_loss = 0.0
            for batch_X, batch_y in fit_loader:
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                optimizer.zero_grad()
                loss = criterion(model(batch_X), batch_y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            avg_train_loss = train_loss / len(fit_loader)

            if val_loader is None:
                if (epoch + 1) % 5 == 0:
                    print(f'Epoch [{epoch+1}/{n_epochs}], Train Loss: {avg_train_loss:.4f}')
                continue

            avg_val_loss = _mean_loss(model, val_loader, criterion, device)
            if (epoch + 1) % 5 == 0:
                print(f'Epoch [{epoch+1}/{n_epochs}], '
                      f'Train Loss: {avg_train_loss:.4f}, '
                      f'Val Loss: {avg_val_loss:.4f}')

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # Snapshot (copy) the weights; state_dict tensors are live references
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered at epoch {epoch+1}")
                    break

        if best_state is not None:
            model.load_state_dict(best_state)

        test_loss = _mean_loss(model, test_loader, criterion, device)
        print(f"Held-out session loss: {test_loss:.4f}")

        predictions.append((test_date, _predict(model, test_loader, device)))

        # Drop per-fold objects before building the next model
        del model, optimizer

    return predictions

## Train and Make Predictions

In [None]:
# Build one {'X_data', 'y_data', 'date'} dict per session for the choice model
all_sessions_data = []

print("Processing sessions for each mouse:")
for mouse in mouse_list:
    print(f"\nMouse: {mouse}")
    if mouse not in all_task_dfs:
        # The loading cell only stores mice with at least one successfully loaded session
        print(f"No loaded sessions for mouse {mouse}; skipping")
        continue

    mouse_df = all_task_dfs[mouse]
    mouse_sessions = all_trialized_data[mouse]

    # Dates in order of appearance in the concatenated task_df. The loading cell
    # appends each task_df and its trialized data together, so this order pairs
    # every date with its trialized session; sorting would mis-pair them whenever
    # date_lists is not already in ascending order.
    mouse_dates = list(mouse_df['date'].unique())
    print(f"Dates in task_df: {mouse_dates}")
    print(f"Number of trialized sessions: {len(mouse_sessions)}")

    if len(mouse_dates) != len(mouse_sessions):
        # e.g. a session whose task_df is empty: the date/session pairing is no longer recoverable
        print(f"Warning: {len(mouse_dates)} dates but {len(mouse_sessions)} trialized sessions "
              f"for mouse {mouse}; skipping to avoid mis-paired sessions")
        continue

    for session_idx, (date, trialized_data) in enumerate(zip(mouse_dates, mouse_sessions)):
        session_df = mouse_df[mouse_df['date'] == date]

        print(f"\nProcessing session {session_idx + 1} for date {date}")
        print(f"Number of trials in behavioral data: {len(session_df)}")

        try:
            # prepare_data indexes trialized_data by channel name, so count trials
            # from one channel rather than len(trialized_data)
            print(f"Number of trials in trialized data: {len(trialized_data['X_pos'])}")

            X_data, y_data = prepare_data(session_df, trialized_data)

            all_sessions_data.append({
                'X_data': X_data,
                'y_data': y_data,
                'date': date
            })
            print(f"Successfully added session for date {date}")
            print(f"X_data shape: {X_data.shape}, y_data shape: {y_data.shape}")
        except Exception as e:
            # Best effort: report and skip sessions that cannot be prepared
            print(f"Error preparing data for date {date}: {e}")

print("\nFinal all_sessions_data summary:")
print(f"Total number of sessions: {len(all_sessions_data)}")
for i, session in enumerate(all_sessions_data):
    print(f"\nSession {i}:")
    print(f"Date: {session['date']}")
    print(f"X_data shape: {session['X_data'].shape}")
    print(f"y_data shape: {session['y_data'].shape}")

In [None]:
# One (date, per-timestep predictions) tuple per held-out session
all_predictions = train_model_leave_one_session_out_CNN(all_sessions_data, device=device)

In [None]:
def convert_predictions_to_array(predictions_list):
    """
    Convert list of (session_date, predictions) tuples to a single numpy array.

    3-D session arrays have their singleton axis 1 and/or 2 removed, e.g.
    (n, T, 1) or (n, 1, T) -> (n, T). The trial axis 0 is never squeezed, so a
    session with a single trial still concatenates correctly. Entries whose
    predictions are not numpy arrays are skipped.

    Args:
        predictions_list: List of tuples (session_date, predictions) from model

    Returns:
        numpy.ndarray: Predictions of all sessions concatenated along the trial axis

    Raises:
        ValueError: if no entry holds a numpy array
    """
    def _drop_singleton_non_trial_axes(pred):
        if pred.ndim == 3:
            singleton_axes = tuple(ax for ax in (1, 2) if pred.shape[ax] == 1)
            if singleton_axes:
                pred = pred.squeeze(axis=singleton_axes)
        return pred

    session_arrays = [
        _drop_singleton_non_trial_axes(session_pred)
        for _, session_pred in predictions_list
        if isinstance(session_pred, np.ndarray)
    ]

    if not session_arrays:
        raise ValueError("No valid predictions found in the list")
    return np.concatenate(session_arrays, axis=0)

def calculate_latency_to_choice(predictions_input, all_choices, threshold=0.95):
    """
    Calculate latency to choice for each trial based on when model prediction
    crosses threshold.

    For a right choice (1) the latency is the first timepoint with prediction
    > threshold; for any other choice value, the first with prediction < 1 - threshold.

    Args:
        predictions_input: Either a numpy array of predictions (n_trials, n_timepoints)
            or a list of (date, predictions) tuples. 3-D inputs have singleton
            axes 1/2 removed; the trial axis is never squeezed.
        all_choices: Actual choices (0 or 1), one per trial in the same order
        threshold: Confidence threshold for choice (default 0.95)

    Returns:
        numpy.ndarray: Array of latencies (as percentage of trial duration,
                      i.e. first_index / n_timepoints * 100);
                      -1 indicates threshold never crossed

    Raises:
        ValueError: if predictions are not 2-D after squeezing, or the number of
            choices differs from the number of prediction trials
    """
    def _as_trials_by_time(pred):
        pred = np.asarray(pred)
        if pred.ndim == 3:
            singleton_axes = tuple(ax for ax in (1, 2) if pred.shape[ax] == 1)
            if singleton_axes:
                pred = pred.squeeze(axis=singleton_axes)
        return pred

    if isinstance(predictions_input, list):
        predictions = np.concatenate(
            [_as_trials_by_time(session_pred) for _, session_pred in predictions_input], axis=0)
    else:
        predictions = _as_trials_by_time(predictions_input)

    if predictions.ndim != 2:
        raise ValueError(
            f"Expected predictions of shape (n_trials, n_timepoints), got {predictions.shape}")

    choices = np.asarray(all_choices)
    n_trials, trial_length = predictions.shape
    if len(choices) != n_trials:
        raise ValueError(
            f"{len(choices)} choices for {n_trials} prediction trials; they must be aligned")

    if trial_length == 0:
        return np.full(n_trials, -1.)

    # Per trial, test the side matching the actual choice; NaN predictions never count as crossed
    is_right = (choices == 1)[:, np.newaxis]
    crossed = np.where(is_right, predictions > threshold, predictions < (1 - threshold))
    first_crossing = crossed.argmax(axis=1)  # index of the first True (0 when none)

    return np.where(crossed.any(axis=1), (first_crossing / trial_length) * 100, -1.)

In [None]:
# Choices of every trial, concatenated over mice in mouse_list order
all_choices = np.array([
    choice
    for mouse in mouse_list
    for choice in all_task_dfs[mouse]['choice'].values
])

In [None]:
# Flatten the per-session (date, predictions) pairs into one trials-by-timesteps array
predictions_array = convert_predictions_to_array(all_predictions)

# Defensive: drop any leftover singleton axes if the result is still 3-D
predictions_array = np.squeeze(predictions_array) if predictions_array.ndim == 3 else predictions_array

In [None]:
# Latency (% of trial) at which the prediction first commits to the actual choice
latencies = calculate_latency_to_choice(predictions_array, all_choices, threshold=0.95)

In [None]:
def plot_mean_latency_by_context(task_df, latencies, ylim=(20, 85), rng=None):
    """
    Plot mean latency for each context with SEM error bars and per-session mean points.

    Trials with latency < 0 (threshold never crossed) are excluded.

    Args:
        task_df (pd.DataFrame): Trials with 'context' (0, 1, 2) and 'date' columns
        latencies (np.array): One latency per row of task_df, in the same order
        ylim (tuple or None): y-axis limits; None keeps matplotlib's autoscaling
        rng (np.random.Generator, optional): Source of the x-jitter for session
            points; defaults to NumPy's global random functions

    Returns:
        tuple: (fig, ax) matplotlib figure and axes objects

    Raises:
        ValueError: if len(latencies) != len(task_df)
    """
    latencies = np.asarray(latencies)
    if len(latencies) != len(task_df):
        raise ValueError(
            f"latencies has {len(latencies)} entries but task_df has {len(task_df)} trials; "
            "they must describe the same trials in the same order")
    jitter_source = np.random if rng is None else rng

    fig, ax = plt.subplots(figsize=(3.5, 2.5), dpi=800)

    context_names = ['Congruent', 'Visual', 'Audio']
    context_colors = {0: 'purple', 1: '#EC008C', 2: '#27AAE1'}

    contexts = task_df['context'].to_numpy()
    dates = task_df['date'].to_numpy()
    valid = latencies >= 0  # -1 marks trials whose prediction never crossed threshold

    # Overall mean / SEM over valid trials of each context
    means = []
    sems = []
    for ctx in [0, 1, 2]:
        ctx_latencies = latencies[(contexts == ctx) & valid]
        if len(ctx_latencies) > 0:
            means.append(np.mean(ctx_latencies))
            sems.append(scipy.stats.sem(ctx_latencies))
        else:
            means.append(np.nan)
            sems.append(np.nan)

    x_positions = np.arange(len(context_names))
    ax.bar(x_positions, means,
           color=[context_colors[i] for i in range(3)],
           alpha=0.3)  # light bars so the session points stay visible
    ax.errorbar(x_positions, means, yerr=sems,
                fmt='none', color='black', capsize=5)

    # Per-session mean latency, jittered horizontally to reduce overlap
    for ctx in [0, 1, 2]:
        for session_date in task_df['date'].unique():
            session_latencies = latencies[(dates == session_date) & (contexts == ctx) & valid]
            if len(session_latencies) > 0:
                jitter = jitter_source.normal(0, 0.05)
                ax.scatter(x_positions[ctx] + jitter, np.mean(session_latencies),
                           color=context_colors[ctx],
                           s=50,
                           alpha=0.7,
                           edgecolor='white',
                           linewidth=0.5)

    ax.set_xticks(x_positions)
    ax.set_xticklabels(context_names)
    ax.set_ylabel('Latency to choice\n(% of maze)')
    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.tight_layout()
    return fig, ax

In [None]:
fig, ax = plot_mean_latency_by_context(
    task_df=all_task_dfs['IS-2-1L'],
    latencies=latencies,
)
plt.show()

In [None]:
def plot_example_trajectories_predictions(predictions, all_choices, all_task_dfs, mouse_list, n_examples=1, seed=45):
    """
    Plot example model prediction trajectories per context for correct left and right choices.

    Left-choice trials are drawn solid, right-choice trials dashed, in the context color.

    Args:
        predictions (np.array): Model predictions over time (n_trials, n_timepoints), in the
            trial order of the task DataFrames of `mouse_list` concatenated
        all_choices (np.array): Actual choices made (n_trials); 0 = left, 1 = right
        all_task_dfs (dict): Dictionary of task DataFrames keyed by mouse ID
        mouse_list (list): Mouse IDs, in concatenation order
        n_examples (int): Number of example trajectories to plot per choice per context
        seed (int): Seed for the random trial selection, reused for every context.
            Uses a private RandomState, so NumPy's global RNG is left untouched.

    Returns:
        tuple: (fig, axs) matplotlib figure and axes objects

    Raises:
        ValueError: if predictions, choices and task rows differ in trial count
    """
    predictions = np.asarray(predictions)
    all_choices = np.asarray(all_choices)
    task_df = pd.concat([all_task_dfs[mouse] for mouse in mouse_list], ignore_index=True)
    if not (len(predictions) == len(all_choices) == len(task_df)):
        raise ValueError(
            f"Trial counts differ: predictions={len(predictions)}, "
            f"choices={len(all_choices)}, task rows={len(task_df)}")

    fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=800)

    context_names = ['Congruent', 'Visual', 'Audio']
    context_colors = {0: 'purple', 1: '#EC008C', 2: '#27AAE1'}

    # x-axis: position in the maze as 0-100% of the trial
    maze_positions = np.linspace(0, 100, predictions.shape[1])
    contexts = task_df['context'].to_numpy()
    outcomes = task_df['outcome'].to_numpy()

    for ctx_idx, ctx in enumerate([0, 1, 2]):
        ax = axs[ctx_idx]

        ctx_trials = np.flatnonzero(contexts == ctx)
        left_trials = ctx_trials[all_choices[ctx_trials] == 0]
        right_trials = ctx_trials[all_choices[ctx_trials] == 1]
        correct_left = left_trials[outcomes[left_trials] == 1]
        correct_right = right_trials[outcomes[right_trials] == 1]

        # A fresh RandomState(seed) per context draws exactly what np.random.seed(seed)
        # followed by np.random.choice would, without touching the global RNG
        selector = np.random.RandomState(seed)
        selected_left = selector.choice(correct_left, size=min(n_examples, len(correct_left)), replace=False)
        selected_right = selector.choice(correct_right, size=min(n_examples, len(correct_right)), replace=False)

        for i, trial_idx in enumerate(selected_left):
            ax.plot(maze_positions, predictions[trial_idx],
                    color=context_colors[ctx],
                    alpha=0.7,
                    linestyle='-',
                    label='Left choice' if i == 0 else '')

        for i, trial_idx in enumerate(selected_right):
            ax.plot(maze_positions, predictions[trial_idx],
                    color=context_colors[ctx],
                    alpha=0.7,
                    linestyle='--',
                    label='Right choice' if i == 0 else '')

        # Undecided level
        ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)

        ax.set_title(context_names[ctx_idx])
        ax.set_xlabel('% of maze')
        ax.set_ylabel('P(Right)' if ctx_idx == 0 else '')
        ax.set_ylim(-0.1, 1.1)
        ax.set_xlim(0, 100)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if ctx_idx == 0:
            ax.legend(frameon=False)

    fig.tight_layout()
    return fig, axs

In [None]:
fig, axs = plot_example_trajectories_predictions(
    predictions_array, all_choices, all_task_dfs, mouse_list,
    n_examples=3,
)
plt.show()

# Dynamic Choice - Neural Data

In [None]:
class NeuralDynamicCNN(nn.Module):
    """Per-frame choice classifier for neural activity.

    Pipeline: Conv1d(k=5) stem with BatchNorm/ReLU/Dropout -> two ResidualBlocks
    (64 -> 128 -> 128 channels) -> TemporalAttention -> 1x1 conv + sigmoid.
    """

    def __init__(self, n_neurons, dropout_rate=0.2):
        super().__init__()

        stem = [
            nn.Conv1d(n_neurons, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]
        self.conv1 = nn.Sequential(*stem)

        self.res_block1 = ResidualBlock(64, 128)
        self.res_block2 = ResidualBlock(128, 128)
        self.attention = TemporalAttention(128)

        self.prediction_layer = nn.Conv1d(128, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map (batch, frames, n_neurons) to per-frame probabilities of shape (batch, frames)."""
        h = x.transpose(1, 2)  # Conv1d wants (batch, channels, frames)
        for stage in (self.conv1, self.res_block1, self.res_block2, self.attention):
            h = stage(h)
        probs = self.sigmoid(self.prediction_layer(h))  # (batch, 1, frames)
        return probs.squeeze(1)


class ResidualBlock(nn.Module):
    """Two Conv1d(k=3)/BatchNorm layers plus a shortcut, followed by ReLU.

    The shortcut is an identity (empty Sequential) when channel counts match,
    otherwise a 1x1 Conv1d + BatchNorm projection.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv_path = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
        )

        needs_projection = in_channels != out_channels
        self.skip = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv_path(x) + self.skip(x))


class TemporalAttention(nn.Module):
    """Self-attention over the frame axis with a learnable residual gate.

    Queries/keys are 1x1-conv projections to channels // 8; values keep all channels.
    The output is gamma * attended + x, and gamma starts at 0, so the module is an
    identity at initialization. Attention energies are not scaled by sqrt(channels // 8).
    """

    def __init__(self, channels):
        super().__init__()

        reduced = channels // 8
        self.query = nn.Conv1d(channels, reduced, kernel_size=1)
        self.key = nn.Conv1d(channels, reduced, kernel_size=1)
        self.value = nn.Conv1d(channels, channels, kernel_size=1)

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # x: (batch, channels, frames)
        n, _, t = x.size()

        q = self.query(x).view(n, -1, t).transpose(1, 2)  # (B, F, C')
        k = self.key(x).view(n, -1, t)                     # (B, C', F)
        weights = self.softmax(torch.bmm(q, k))            # (B, F, F), rows sum to 1

        attended = torch.bmm(self.value(x), weights.transpose(1, 2))  # (B, C, F)
        return self.gamma * attended + x

In [None]:
# Train NeuralDynamicCNN on neural data (full-batch updates).
# NOTE(review): no cell above defines `train_X` / `train_y` at notebook scope (the
# kinematic pipeline only has them as locals inside
# train_model_leave_one_session_out_CNN), so on a fresh kernel this cell raises
# NameError. TODO: add the cell that loads and splits the neural data.
# Assumed layout, from NeuralDynamicCNN.forward: train_X is (trials, frames, neurons);
# train_y is per-trial choices (trials,) or per-frame targets (trials, frames) — confirm.
train_X_t = torch.as_tensor(train_X, dtype=torch.float32, device=device)
train_y_t = torch.as_tensor(train_y, dtype=torch.float32, device=device)
if train_y_t.ndim == 1:
    # BCELoss needs targets shaped like the (trials, frames) output: repeat each trial's label per frame
    train_y_t = train_y_t[:, None].expand(-1, train_X_t.shape[1]).contiguous()

# Initialize model, criterion and optimizer
model = NeuralDynamicCNN(n_neurons=train_X_t.shape[2]).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5
model.train()

for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(train_X_t)
    loss = criterion(outputs, train_y_t)
    loss.backward()
    optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [None]:
# Evaluate the neural model on held-out data.
# NOTE(review): no cell above defines `test_X` / `test_y` at notebook scope — TODO
# create them together with the neural training split. Same assumed layout as the
# training cell: test_X (trials, frames, neurons); test_y (trials,) or (trials, frames).
test_X_t = torch.as_tensor(test_X, dtype=torch.float32, device=device)
test_y_t = torch.as_tensor(test_y, dtype=torch.float32, device=device)
if test_y_t.ndim == 1:
    # Match the (trials, frames) output shape required by BCELoss
    test_y_t = test_y_t[:, None].expand(-1, test_X_t.shape[1]).contiguous()

model.eval()
with torch.no_grad():
    test_outputs = model(test_X_t)
    test_loss = criterion(test_outputs, test_y_t)
    print(f'Test Loss: {test_loss.item():.4f}')