In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.path.abspath('..'))

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
plt.style.use('../plotting/style_sheet.mplstyle')
from synthetic_data.dataset import Dataset
from plotting import plotting_utils
import sklearn
from toy_rnn.dataset import Dataset
from toy_rnn.rnn import MultiSetRNN

In [None]:
# Configuration: every tunable value lives here so it can be changed in one place.
SEED = 0  # fixed seed so Restart & Run All reproduces data draws and model init
BATCH_SIZE = 32
NUM_SETS = 5
NUM_NEURONS = 100
HIDDEN_SIZE = 64
LATENT_DIM = 2
LATENT_TIMESCALE = 1.0
DT = 0.1
TRIAL_DURATION = 10.0

# Seed torch's global RNG (model weight init, torch-based sampling).
# NOTE(review): if the dataset classes draw from numpy or Python's `random`,
# those generators need seeding as well — confirm against their implementation.
torch.manual_seed(SEED);

In [None]:
# Small, finely time-resolved synthetic dataset used only to inspect one trial
# below; its settings are intentionally independent of the configuration cell.
# NOTE(review): `LinearGPDataset` is not imported by any visible cell (the
# `synthetic_data.dataset.Dataset` import is shadowed by `toy_rnn.dataset.Dataset`),
# so this raises NameError on a fresh kernel — TODO import it from its module.
dataset = LinearGPDataset(
    # population layout
    num_neurons=100,
    num_sets=2,
    # latent process
    latent_dim=2,
    latent_timescale=0.1,
    # time axis
    trial_duration=2.0,
    dt=0.01,
    # observation model
    max_firing_rate=100.0,
)

In [9]:
# Pull one trial from the dataset and unpack its fields into named variables.
data_iter = iter(dataset)
trial = next(data_iter)
time, latents, spike_counts, dataset_idx = (
    trial[field] for field in ("time", "latents", "spike_counts", "dataset_idx")
)

In [None]:


def fit_pcs(data, n_pcs=None, ground_truth=None, fit_intercept=True):
    """Project ``data`` onto its principal components, optionally aligned to a target.

    Parameters
    ----------
    data : array-like or torch.Tensor, shape (n_samples, n_features)
        Observations in rows (e.g. time bins x neurons). Integer inputs such as
        spike counts are fine: PCA converts to float and centers internally.
    n_pcs : int or None
        Number of components to keep; ``None`` keeps all of them (PCA default).
    ground_truth : array-like or torch.Tensor, shape (n_samples, ...), optional
        If given, the PC scores are linearly regressed onto it and the
        regression prediction is returned instead of the raw scores. PCs are
        only defined up to rotation/sign, so this alignment is what makes them
        comparable with the true latents.
    fit_intercept : bool
        Whether the alignment regression includes an offset. PC scores are
        zero-mean by construction, so without an intercept the aligned output
        is forced to be zero-mean and cannot match a target with nonzero mean.

    Returns
    -------
    numpy.ndarray
        PC scores of shape (n_samples, n_components), or, when ``ground_truth``
        is given, the aligned prediction with the same trailing shape as
        ``ground_truth``.
    """
    from sklearn.decomposition import PCA
    from sklearn.linear_model import LinearRegression

    # PCA subtracts the per-feature mean itself, so no manual centering is needed.
    # (The previous `data.mean(dim=0, keepdim=True)` only worked for float torch
    # tensors and failed on numpy arrays and integer spike counts.)
    pcs = PCA(n_components=n_pcs).fit_transform(data)
    if ground_truth is None:
        return pcs

    reg = LinearRegression(fit_intercept=fit_intercept).fit(pcs, ground_truth)
    return reg.predict(pcs)

# Score how well the top PCs of the spike counts recover the true latents.
# PCs are only defined up to rotation/sign, so `fit_pcs` first linearly aligns
# them to the ground-truth latents; the score is the R² of that alignment.
from sklearn.metrics import r2_score

N_PCS = latents.shape[1]  # one PC per true latent dimension (latents: (time, latent_dim))
pcs = fit_pcs(spike_counts, n_pcs=N_PCS, ground_truth=latents)
r_sq = r2_score(latents.numpy(), pcs)
print(f'R² of latents decoded from first {N_PCS} PCs: {r_sq:.3f}')

In [None]:
# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Data ---------------------------------------------------------------------
# Training dataset sized by the configuration cell.
dataset = Dataset(
    num_sets=NUM_SETS,
    num_neurons=NUM_NEURONS,
    latent_dim=LATENT_DIM,
    latent_timescale=LATENT_TIMESCALE,
    dt=DT,
    trial_duration=TRIAL_DURATION,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

# --- Model, optimizer, loss ---------------------------------------------------
model = MultiSetRNN(
    num_sets=NUM_SETS,
    input_size=NUM_NEURONS,   # spike counts are the input
    hidden_size=HIDDEN_SIZE,  # recurrent (LSTM) state size
    latent_dim=LATENT_DIM,    # dimension of the regression target
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# --- Training -----------------------------------------------------------------
# The dataset is an endless IterableDataset ('while True'), so cap the steps here.
max_steps = 1000

for step, batch in enumerate(dataloader):
    if step >= max_steps:
        break

    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # spike_counts: (Batch, Time, Neurons); latents: (Batch, Time, Latents);
    # dataset_idx: (Batch)
    x, y, indices = batch["spike_counts"], batch["latents"], batch["dataset_idx"]

    # The set indices are passed alongside the data so the model knows how to route.
    predictions = model(x, indices)
    loss = criterion(predictions, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if not step % 10:
        print(f"Step {step}, Loss: {loss.item():.4f}")