## import

In [None]:
from perceptual_discrimination import PerceptualDiscrimination
import torch, torch.nn as nn, torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import math
import os
from EI_RNN import Net, compute_loss, accuracy
from torch.optim.lr_scheduler import StepLR

## Generate Training data 

In [None]:
# -- Task parameters (shared by the training and validation tasks) --
dt = 10     # time step in ms
tau = 100   # time constant in ms
T = 2000    # total time in ms
coherence = [0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 0.7]

N_batch = 64     # batch size of the training task
N_batch1 = 200   # batch size of the validation task

# Training task.
task = PerceptualDiscrimination(dt, tau, T, N_batch, coherence=coherence)
print("task loaded")

# Validation task: same timing and coherence values, larger batch.
task_val = PerceptualDiscrimination(dt, tau, T, N_batch1, coherence=coherence)
print("val_task loaded")

# Everything runs on the CPU; to pick a GPU when one is available, use instead:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'
print(f"device: {device}")

In [None]:
# Draw one batch from the training task. The cells below plot `x` as the
# input, `y` as the target output, and use `mask` to shade a time window;
# `trial_params` holds per-trial settings such as 'coherence'.
# NOTE(review): the arrays appear to be laid out (batch, time, channel),
# judging by how they are indexed and permuted later — confirm.
x, y, mask, trial_params = task.get_trial_batch()
# Parameters reported by the task object, printed for inspection.
network_params = task.get_task_params()
print(network_params)
# Sanity-check the array shapes.
print(f"x shape:{x.shape}, y shape: {y.shape}, mask shape: {mask.shape}")

In [None]:
import matplotlib.pyplot as plt

# Plot the input and target traces of one trial side by side and shade the
# time window in which the mask is non-zero.
fig, axes = plt.subplots(1, 2, figsize=(8, 3), sharey=False)
sample_index = 4

# Time axes in ms: one tick per step of the selected trial.
time_input = range(0, len(x[sample_index, :, :]) * dt, dt)
time_output = range(0, len(y[sample_index, :, :]) * dt, dt)

# Average the mask over its last axis to get one value per time step.
mask_sample = mask[sample_index, :, :].mean(axis=-1)
masked_indices = np.where(mask_sample > 0)[0]

# Bind the shaded-window bounds unconditionally: if the mask is zero
# everywhere they stay None and no span is drawn, instead of raising a
# NameError (or silently reusing values left over from an earlier run).
xmin = xmax = None
if len(masked_indices) > 0:
    xmin = masked_indices[0] * dt
    xmax = masked_indices[-1] * dt
has_masked_window = xmin is not None

axes[0].plot(time_input, x[sample_index, :, :])
axes[0].set_ylabel("Input Magnitude")
axes[0].set_xlabel("Time (ms)")
axes[0].set_title(f"Input(coherece={trial_params[sample_index]['coherence']})")
if has_masked_window:
    axes[0].axvspan(xmin, xmax, facecolor='silver', alpha=0.5)

axes[1].plot(time_output, y[sample_index, :, :])
axes[1].set_ylabel("Output")
axes[1].set_xlabel("Time (ms)")
axes[1].set_title("Output Data")
if has_masked_window:
    axes[1].axvspan(xmin, xmax, facecolor='silver', alpha=0.5)
plt.tight_layout()

## Train a model

In [None]:
# Directory for saved checkpoints; created if missing, no error if it exists.
save_root = "./savemodels/"
os.makedirs(save_root, exist_ok=True)

# -- Model / optimisation hyperparameters --
epoch_num = 2500    # number of training iterations; set lower for debugging
hidden_size = 50    # hidden layer size
lr = 5e-3           # learning rate
l1_lambda = 1e-4    # L1 regularisation coefficient
mode = 'none'       # 'dense', 'block', or 'none'

model = Net(
    input_size=2,
    hidden_size=hidden_size,
    output_size=2,
    dt=dt,
    sigma_rec=0.15,
    mode=mode,
    noneg=True,
    with_Tanh=False,
)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
# Halve the learning rate when the monitored (minimised) metric plateaus.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
)

# Lowest validation loss seen so far and the weights that produced it.
best_loss = float("inf")
best_state_dict = None

In [None]:
import copy  # local import: used to snapshot the best weights below


def _to_time_major_tensors(*arrays):
    """Return each array as a float32 tensor on `device` with axes 0 and 1 swapped."""
    return tuple(
        torch.tensor(arr, dtype=torch.float32).permute(1, 0, 2).to(device)
        for arr in arrays
    )


# Training loop: one optimisation step per "epoch" on a freshly drawn batch,
# with a validation pass (and possible checkpoint) on epoch 1 and every 50th.
for epoch in range(1, epoch_num + 1):
    x, y, mask, _ = task.get_trial_batch()

    model.train()
    x, y, mask = _to_time_major_tensors(x, y, mask)

    y_pred, _ = model(x)

    loss = compute_loss(y_pred, y, mask)

    # ---- add L1 ----
    # l1_norm = sum(param.abs().sum() for name, param in model.named_parameters() if 'h2h.weight' in name)
    # loss = loss + l1_lambda * l1_norm

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    if epoch % 50 == 0 or epoch == 1:
        model.eval()
        x, y, mask, _ = task_val.get_trial_batch()
        x, y, mask = _to_time_major_tensors(x, y, mask)

        with torch.no_grad():
            y_pred, _ = model(x)
            val_loss = compute_loss(y_pred, y, mask)
            acc = accuracy(y_pred, y, mask)

            # Current learning rate, read before the scheduler step. This
            # rebinds the module-level `lr`; the log line and the checkpoint
            # file name both use this value.
            lr = optimizer.param_groups[0]['lr']
            print(f"[{lr=} | {hidden_size=} | Epoch {epoch}] Train-val_Loss: {loss.item():.4f}- {val_loss.item():.4f}, Acc: {acc:.3f}")
            scheduler.step(val_loss)

            if val_loss.item() < best_loss:
                best_loss = val_loss.item()
                # state_dict() hands back references to the live parameters,
                # so take a deep copy; otherwise later optimizer steps would
                # silently overwrite the "best" weights kept in memory.
                best_state_dict = copy.deepcopy(model.state_dict())

                model_path = f"{save_root}{mode}model_lr{lr:.0e}_hidden{hidden_size}_loss{best_loss:.4f}.pt"
                torch.save(best_state_dict, model_path)
                print(f"✅ Best model for hidden_size={hidden_size} saved to: {model_path}")

## Evaluate

In [None]:
# Evaluation settings. Presumably re-declared here so this section can be run
# without the training cells — confirm.
device = 'cpu'
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dt = 10          # time step in ms
tau = 100        # time constant in ms
T = 2000         # total time in ms
N_batch = 500    # batch size

print("loading task")
# No `coherence` argument is passed here, unlike the training/validation tasks.
task_test = PerceptualDiscrimination(dt, tau, T, N_batch)

In [None]:
# Rebuild the network with the architecture of the checkpoint and load its weights.
hidden_size = 50
mode = 'none'  # 'dense', 'block', or 'none'
# NOTE(review): noneg=False here, whereas the training cell builds the model
# with noneg=True — confirm this matches how the checkpoint was trained.
net = Net(input_size=2, hidden_size=hidden_size,
          output_size=2, dt=dt, sigma_rec=0.15, mode=mode,
          noneg=False, with_Tanh=False).to(device)
# Build the path with os.path.join rather than a hard-coded backslash so the
# checkpoint is found on POSIX systems as well as on Windows.
model_path = os.path.join("muti_models4_noneg",
                          "nonemodel_lr1e-04_hidden50_loss0.0093.pt")
# map_location places the stored tensors on the selected device.
net.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# Inference only: put the network in evaluation mode.
net.eval()

# Use the test data: draw one batch and convert each array to a float32
# tensor on `device` with axes 0 and 1 swapped.
x_test, y_test, mask_test, trial_params = task_test.get_trial_batch()
test_inputs, test_targets, test_masks = (
    torch.tensor(batch_array, dtype=torch.float32).permute(1, 0, 2).to(device)
    for batch_array in (x_test, y_test, mask_test)
)

# Forward pass without gradient tracking; the second return value is kept
# under the name rnn_activity.
with torch.no_grad():
    test_outputs, rnn_activity = net(test_inputs)
    test_acc = accuracy(test_outputs, test_targets, test_masks)

print(f"[Best Model] Test Accuracy: {test_acc:.3f}")


In [None]:
import matplotlib.pyplot as plt

idx = 9  # index (along axis 1) of the test trial to display

# Slice out the chosen trial once instead of converting inside each call.
input_trace = test_inputs.cpu().numpy()[:, idx, :]
output_trace = test_outputs.cpu().numpy()[:, idx, :]

# Figure 1: the input fed to the network for this trial.
plt.figure(figsize=(4, 2))
input_time_ms = range(0, len(input_trace) * dt, dt)
plt.plot(input_time_ms, input_trace)
plt.ylabel("Input Magnitude")
plt.xlabel("Time (ms)")
plt.title("Input Data")
plt.legend(["Input Channel 1", "Input Channel 2"])

# Figure 2: the network's output for the same trial.
plt.figure(figsize=(4, 2))
output_time_ms = range(0, len(output_trace) * dt, dt)
plt.plot(output_time_ms, output_trace)
plt.ylabel("Activity of Output Unit")
plt.xlabel("Time (ms)")
plt.title("Output on New Sample")
plt.legend(["Output Channel 1", "Output Channel 2"])