In [None]:
import os
import pickle
import sys
from pprint import pprint

import matplotlib.pyplot as plt
import torch

def add_to_sys_path(relative_path):
    """Append ``relative_path`` (resolved against the cwd) to ``sys.path`` once.

    Skipping paths that are already present keeps this cell idempotent:
    re-running it no longer stacks duplicate entries onto ``sys.path``.
    """
    absolute_path = os.path.abspath(os.path.join(os.getcwd(), relative_path))
    if absolute_path not in sys.path:
        sys.path.append(absolute_path)


# ".." makes the ``src`` package importable (``from src.… import …`` below);
# "../src/" presumably lets modules inside ``src`` import each other by bare name — TODO confirm.
for relative_path in ("..", "../src/"):
    add_to_sys_path(relative_path)


from src.pytorch_models import LightningRNNModule

from src.dataset import YangTasks
from torch.utils.data import DataLoader
from src.dataset import collate_fn

# Load models 
Models should be stored in a checkpoint folder with the following layout:
``` 
checkpoint_dir 
|- epoch_perf_.ckpt   # best performing model 
|- hp_pl_module.pkl   # model hyperparameters
|- last.ckpt          # final model after full training 
|- task_hp.pkl        # task hyperparameters 
``` 



In [None]:
# Checkpoint folder, relative to the parent of the current working directory.
checkpoint_dir = "saved_models/cernn/floral-spaceship-693_contextdelaydm1_contextdelaydm2_V1_5_FEF_5_3b_5"
checkpoint_path = os.path.join("..", checkpoint_dir)


def load_pickle(path):
    """Load and return the single pickled object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while unpickling —
    only load checkpoint folders you produced yourself or otherwise trust.
    """
    with open(path, "rb") as file:
        return pickle.load(file)


hp_pl_module = load_pickle(os.path.join(checkpoint_path, "hp_pl_module.pkl"))  # model hyperparameters
task_hp = load_pickle(os.path.join(checkpoint_path, "task_hp.pkl"))  # task hyperparameters

# Load the final model (last.ckpt) by default; point this at the best-performing
# checkpoint file instead to analyse that model.
pretrained_model = LightningRNNModule.load_from_checkpoint(
    os.path.join(checkpoint_path, "last.ckpt")
)

In [None]:
# Display the loaded LightningRNNModule as this cell's output (last-expression repr).
pretrained_model

### Recurrent weight matrix 

In [None]:
# Recurrent (weight_hh) matrix of the model's RNN cell, detached and copied to a CPU NumPy array.
rnn_cell = pretrained_model.model.rnn.rnncell
recurrent_weights = rnn_cell.weight_hh.detach().cpu().numpy()
recurrent_weights.shape

## Regularisers and their values

In [None]:
# Regularisers used for this model, converted to a plain dict for readable pretty-printing.
regularisers = dict(hp_pl_module.regularisers)
pprint(regularisers)

# Cortical embedding (CE) object
The CE object holds the cortical-embedding information, e.g.
```
cortical_areas
duplicates 
distance_matrix
sensory and motor areas 
dmn_areas
```
together with the values used for this model.



In [None]:
# Cortical-embedding (CE) object attached to the model; vars() lists all of its attributes.
cortical_embedding = pretrained_model.model.ce
# Backward-compatible alias for the original, misspelled variable name.
cortical_embeddign = cortical_embedding
vars(cortical_embedding)

## Corresponding dataset and dataloader
The dataset is needed for activity analysis, but not for connectivity analysis.

Input dimension: 2 rings × 2 dims per ring + 26 task IDs + 1 fixation = 31.

Output dimension: 2 dims per ring + 1 fixation = 3.

In [None]:
# Fixed seed so the shuffled order of test items (and hence the batch drawn below)
# is reproducible across Restart & Run All.
# NOTE(review): YangTasks may draw its own random trials independently of this seed — TODO confirm.
DATALOADER_SEED = 0

dataset_test = YangTasks(task_hp, mode="test")

dataloader_test = DataLoader(
    dataset_test,
    batch_size=1,  # batch size 1 here because we need all trials in batch to be same task/rule
    collate_fn=collate_fn,
    num_workers=0,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),  # seeded shuffle order
)

In [None]:
# Draw one batch (a single task/rule, since batch_size=1) from the shuffled test loader.
# NOTE: "trail" is a typo for "trial"; the name is kept because later cells reference it.
test_batches = iter(dataloader_test)
trail_batch = next(test_batches)

In [None]:
# Drop the leading size-1 dimension, presumably the DataLoader batch dim (batch_size=1) — TODO confirm against collate_fn.
x = trail_batch.x.squeeze(0)
y = trail_batch.y.squeeze(0)

# Expected shapes: x -> [Timesteps, batchsize, D_in], y -> [Timesteps, batchsize, D_out]
for tensor_shape in (x.shape, y.shape):
    print(tensor_shape)

# Analysis 

In [None]:
from src.analysis_connectivity import (
    fig_3_plot_connectivity_matrix,
    fig_3_weights_over_distance_lambda_fitted,
    fig_3_FLN_matrix,
)

In [None]:
# All three connectivity figures are computed from the same underlying RNN model.
rnn_model = pretrained_model.model

fig1 = fig_3_plot_connectivity_matrix(rnn_model)
fig2 = fig_3_FLN_matrix(rnn_model)
fig3 = fig_3_weights_over_distance_lambda_fitted(rnn_model)