In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from model import Net
from task_mask import dataset
from task_mask import perf_trials
from task_mask import plot_activity
from task_mask import plot_behav_per

def set_seed(seed: int):
    """Seed every RNG the notebook draws from with the same value.

    Covers Python's ``random`` module, NumPy's legacy global RNG, PyTorch's CPU
    generator and the generators of all CUDA devices (in that order).
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

def separation_loss(activity, mask, e_size=0.8, tau=0.0, eps=1e-6):
    """Penalty pulling the motor- and cognitive-task coding vectors toward a target overlap.

    Two difference vectors are built over the first ``int(e_size * activity.shape[2])``
    units (the units this code treats as excitatory):

    * ``h_LEVER`` = mean activity where ``mask[:, :, 0] == 2`` (lever window)
      minus mean activity where ``mask[:, :, 0] == 1`` (baseline);
    * ``h_GNG``   = mean activity where ``mask[:, :, 1] == 3`` (Go)
      minus mean activity where ``mask[:, :, 1] == 4`` (No-Go).

    Args:
        activity: tensor whose first two dims are selected by ``mask[:, :, k]`` and
            whose third dim indexes units -- presumably (time, batch, hidden); TODO
            confirm against ``Net``.
        mask: array/tensor whose first two dims match ``activity`` and whose third
            dim has at least two label channels (NumPy or torch both work).
        e_size: fraction of units, counted from index 0, that are included.
        tau: target absolute cosine similarity (0 pushes toward orthogonal
            vectors, 1 toward parallel/anti-parallel ones).
        eps: numerical floor passed to ``F.cosine_similarity``.

    Returns:
        ``(penalty, h_LEVER, h_GNG)``, where ``penalty = |tau - |cos(h_LEVER, h_GNG)||``
        is a 0-dim tensor, and ``h_LEVER``/``h_GNG`` have length ``n_exc``.
        If a label selects no entries, the corresponding mean (and the penalty) is NaN.
    """
    n_exc = int(e_size * activity.shape[2])
    # Restrict to the leading units *before* boolean indexing: a 2-D boolean mask
    # consumes two dims, so it must be the only index left on the first two dims.
    exc = activity[..., :n_exc]

    def window_mean(channel, label):
        # Rows (one per selected (dim0, dim1) entry) averaged into a (n_exc,) vector.
        return exc[mask[:, :, channel] == label].mean(dim=0)

    h_LEVER = window_mean(0, 2) - window_mean(0, 1)  # lever window - baseline
    h_GNG = window_mean(1, 3) - window_mean(1, 4)    # Go - No-Go
    cs = F.cosine_similarity(h_LEVER, h_GNG, dim=0, eps=eps)
    penalty = (tau - cs.abs()).abs()  # deviation from the target task separation
    return penalty, h_LEVER, h_GNG

# Seeds swept by the training cell (one run per seed, indexed by position).
seed_list = [
    0, 42, 1337,
    271828, 314159, 1618033, 1414213, 1732051, 2236067, 57721566,
    1813382119, 827308000, 1627694679, 1911784258, 903170603,
    86939547, 556019486, 2073320062, 1097954098, 1043521779,
]
# Seeds whose networks failed to learn the dual task; the training cell skips them.
seed_fail = [
    0, 314159, 1618033, 1732051, 2236067,
    827308000, 1911784258, 86939547, 2073320062, 1043521779,
]

In [None]:
# Time step handed to both dataset() and Net.
dt = 50
# One dataset() draw, used only to read off the input/output feature sizes.
input, target, _, _, _, _ = dataset(dt)
input_size, output_size = input.shape[-1], target.shape[-1]

hidden_size = 100  # recurrent units
num_trial = 500    # evaluation trials collected before and after dual-task learning
num_DTT = 100      # dual-task training iterations per seed

# Directory receiving one pre- and one post-learning .mat file per seed.
save_dir = 'runs/20250808_decorrelate_only'
os.makedirs(save_dir, exist_ok=True)  # savemat does not create missing directories
num_pretrain = 1000  # single-task pre-training iterations per seed


def mask_grad_hook(mask):
    """Return a tensor hook that multiplies the incoming gradient by ``mask``."""
    def hook(grad):
        return grad * mask
    return hook


def _collect_eval_trials(model, dt, num_trial):
    """Run ``num_trial`` single-trial dual-task draws through ``model`` without autograd.

    Returns:
        ``(activity_dict, performance_dict, trial_infos, task_onsets)``, each keyed by
        trial index ``0..num_trial-1``. The first two hold NumPy copies of the model's
        outputs / recurrent activity for batch element 0; the last two hold the
        ``trial_type`` and ``task_onset`` values returned by ``dataset``.
    """
    activity_dict, performance_dict, trial_infos, task_onsets = {}, {}, {}, {}
    # Results are only detached for analysis, so skip building the autograd graph.
    with torch.no_grad():
        for k in range(num_trial):
            trial_inputs, _, _, _, task_onset, trial_type = dataset(dt, batch_size=1, DT=True)
            action_pred, rnn_activity = model(torch.from_numpy(trial_inputs).float())
            performance_dict[k] = action_pred[:, 0, :].detach().numpy()
            activity_dict[k] = rnn_activity[:, 0, :].detach().numpy()
            trial_infos[k] = trial_type
            task_onsets[k] = task_onset
    return activity_dict, performance_dict, trial_infos, task_onsets


for f, seed in enumerate(seed_list):
    set_seed(seed)

    if seed in seed_fail:
        print(f'Skipping seed {seed} due to known failure.')
        continue

    net = Net(input_size=input_size, hidden_size=hidden_size, output_size=output_size, dt=dt)
    print(net)

    # ---- Stage 1: pre-training, all parameters trainable ----
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=1e-4)
    criterion = nn.MSELoss(reduction='sum')
    lambda_l1 = 1e-4   # weight of the L1 activity penalty
    lambda_sel = 1     # weight of the separation penalty (stage 2)
    running_loss = 0

    for i in range(num_pretrain):
        inputs, labels, _, _, _, _ = dataset(dt)
        inputs = torch.from_numpy(inputs).float()
        labels = torch.from_numpy(labels.flatten()).float()

        outputs, activity = net(inputs)
        outputs = outputs.view(-1)

        # Summed squared error normalised by the number of non-zero targets, plus L1 on activity.
        loss = criterion(outputs, labels) / (labels != 0).sum() + lambda_l1 * torch.mean(torch.abs(activity))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    activity_dict, performance_dict, trial_infos, task_onsets = _collect_eval_trials(net, dt, num_trial)
    aligned_performance = plot_behav_per([-1, 6], dt, num_trial, trial_infos, performance_dict, task_onsets)
    aligned_activity = plot_activity([-1, 6], dt, num_trial, trial_infos, activity_dict, task_onsets, sort_task=None)
    # NOTE: the 'runing_loss' key spelling is kept so existing loaders of these files still work.
    sio.savemat(f'{save_dir}/run{f}_pre.mat',
                {'aligned_performance': aligned_performance, 'aligned_activity': aligned_activity,
                 'trial_infos': list(trial_infos.values()), 'runing_loss': running_loss})

    # ---- Stage 2: dual-task learning ----
    # Freeze the input and output layers; the recurrent weights keep learning.
    for param in net.rnn.input2h.parameters():
        param.requires_grad = False
    for param in net.rnn.h2h.parameters():
        param.requires_grad = True
    for param in net.fc.parameters():
        param.requires_grad = False
    # New optimiser over the still-trainable parameters only.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01, weight_decay=1e-4)

    # Split recurrent units into two groups using the pre-learning evaluation trials.
    idx_GO = [k for k, v in trial_infos.items() if v[0] == 4]  # dual-task Go trials
    idx_NG = [k for k, v in trial_infos.items() if v[0] == 5]  # dual-task No-Go trials
    aligned_activity = np.array(aligned_activity)  # presumably (num_trial, time_steps, n_neurons) -- TODO confirm
    activity_base = aligned_activity[idx_NG, 0:20, :].mean(axis=0).mean(axis=0)     # baseline window 0-20 (DT No-Go only)
    activity_LEVER = aligned_activity[idx_NG, 20:120, :].mean(axis=0).mean(axis=0)  # lever window 20-120 (DT No-Go only)
    activity_GO = aligned_activity[idx_GO, 60:84, :].mean(axis=0).mean(axis=0)      # Go cue + response window 60-84
    activity_NG = aligned_activity[idx_NG, 60:84, :].mean(axis=0).mean(axis=0)      # No-Go cue + response window 60-84
    # Units whose lever modulation and Go/No-Go modulation have opposite signs.
    freeze_idx = np.nonzero(((activity_LEVER - activity_base) * (activity_GO - activity_NG)) < 0)[0]
    freeze_idx_tensor = torch.tensor(freeze_idx, dtype=torch.long)
    mask1 = torch.zeros(hidden_size)
    mask1[freeze_idx_tensor] = 1.0
    mask2 = 1.0 - mask1
    # Column vectors: broadcasting against h2h.weight's gradient scales it row by row.
    mask1 = mask1.to(net.rnn.h2h.weight.device).view(-1, 1)
    mask2 = mask2.to(net.rnn.h2h.weight.device).view(-1, 1)

    running_loss = 0
    st_motorSuc_per, st_hit_per, st_cr_per = np.zeros(num_DTT), np.zeros(num_DTT), np.zeros(num_DTT)
    dt_motorSuc_per, dt_hit_per, dt_cr_per = np.zeros(num_DTT), np.zeros(num_DTT), np.zeros(num_DTT)
    for i in range(num_DTT):
        # Resample until the batch contains at least 5 distinct trial types.
        while True:
            inputs, labels, CD_mask, dual_motor_mask, task_onset, trial_type = dataset(dt, DT=True)
            inputs = torch.from_numpy(inputs).float()
            labels = torch.from_numpy(labels.flatten()).float()
            dual_motor_mask = torch.from_numpy(dual_motor_mask.flatten()).float()
            task_onset = torch.from_numpy(np.array(task_onset)).float()
            trial_type = torch.from_numpy(np.array(trial_type)).float()
            if trial_type.unique().numel() >= 5:
                break

        outputs, activity = net(inputs)
        outputs = outputs.view(-1)

        # Task performance on this batch, measured before the update.
        (st_motorSuc_per[i], st_hit_per[i], st_cr_per[i],
         dt_motorSuc_per[i], dt_hit_per[i], dt_cr_per[i]) = perf_trials(outputs, task_onset, trial_type, dt)

        in_dt = dual_motor_mask == 1
        out_dt = dual_motor_mask == 0
        # Motor outputs outside the dual-task phase and all cognitive outputs: plain summed squared error.
        loss_motor_outDT_cog = criterion(outputs[out_dt], labels[out_dt])
        # Within the dual-task phase, only outputs below 0.3 are penalised.
        below = outputs[in_dt] < 0.3
        loss_motor_inDT = criterion(outputs[in_dt][below], labels[in_dt][below])
        loss1 = (loss_motor_outDT_cog + loss_motor_inDT) / (labels != 0).sum() + lambda_l1 * activity.abs().mean()
        sep_term = lambda_sel * separation_loss(activity, CD_mask)[0]  # computed once, reused for logging
        loss2 = loss1 + sep_term

        optimizer.zero_grad()
        # Two masked backward passes into h2h.weight: the mask1 rows first, then the mask2 rows.
        # NOTE(review): both passes use loss2, although the design note says the mask1 rows
        # should follow loss1 (task performance only). With mask1 + mask2 == 1, h2h.weight ends
        # up with the full loss2 gradient; unhooked trainable parameters receive it twice.
        # Kept as-is because this run is saved under "decorrelate_only"; use
        # loss1.backward(retain_graph=True) for the first pass if the split is intended.
        handle = net.rnn.h2h.weight.register_hook(mask_grad_hook(mask1))
        loss2.backward(retain_graph=True)
        handle.remove()
        handle = net.rnn.h2h.weight.register_hook(mask_grad_hook(mask2))
        loss2.backward()
        handle.remove()
        optimizer.step()

        running_loss += loss1.item()
        print(f'Iteration {i+1}, Loss: {loss1:.4f}')
        print(f'separation_loss: {sep_term:.4f}')

    activity_dict, performance_dict, trial_infos, task_onsets = _collect_eval_trials(net, dt, num_trial)
    aligned_performance = plot_behav_per([-1, 6], dt, num_trial, trial_infos, performance_dict, task_onsets)
    aligned_activity = plot_activity([-1, 6], dt, num_trial, trial_infos, activity_dict, task_onsets, sort_task=None)

    sio.savemat(f'{save_dir}/run{f}_post.mat',
                {'aligned_performance': aligned_performance, 'aligned_activity': aligned_activity,
                 'trial_infos': list(trial_infos.values()), 'runing_loss': running_loss,
                 'st_motorSuc_per': st_motorSuc_per, 'st_hit_per': st_hit_per, 'st_cr_per': st_cr_per,
                 'dt_motorSuc_per': dt_motorSuc_per, 'dt_hit_per': dt_hit_per, 'dt_cr_per': dt_cr_per})

    # Task separation on the last dual-task training batch.
    _, h_LEVER, h_GNG = separation_loss(activity, CD_mask)
    fig_sep, ax_sep = plt.subplots(figsize=(5, 5))
    ax_sep.scatter(h_LEVER.detach().numpy(), h_GNG.detach().numpy(), c='blue')
    ax_sep.set_xlabel('activity in motor task')
    ax_sep.set_ylabel('activity in cognitive task')
    ax_sep.set_title(f'seed {seed}')
    plt.show()
    print(F.cosine_similarity(h_LEVER, h_GNG, dim=0, eps=1e-6))

    # Performance across dual-task training iterations.
    fig, ax = plt.subplots(1, 2, figsize=(6, 3), sharex=True, sharey=True)
    ax[0].plot(st_motorSuc_per, label='st_motorSuc')
    ax[0].plot(dt_motorSuc_per, label='dt_motorSuc')
    ax[0].set_ylabel('motor task performance')
    ax[1].plot(st_hit_per + st_cr_per, label='st_GNG')
    ax[1].plot(dt_hit_per + dt_cr_per, label='dt_GNG')
    ax[1].set_ylabel('cognitive task performance')
    for panel in ax:
        panel.set_xlabel('dual-task iteration')
        panel.legend()  # without this the line labels above are never shown
    plt.xlim([0, num_DTT])
    plt.tight_layout()
    plt.show()