Here, I test whether my LSTM network can estimate the SGM parameters.

Now I run it on the real data from Parul (Apr 2, 2023).

The PSDs are converted to the dB scale.

In [1]:
# Flag not referenced by any cell shown here; presumably read when this
# notebook is exported and run as a script — TODO confirm.
RUN_PYTHON_SCRIPT = True

True

In [2]:
import sys
sys.path.append("../mypkg")
from constants import RES_ROOT, FIG_ROOT, DATA_ROOT

In [3]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
from easydict import EasyDict as edict
from tqdm import trange, tqdm
from joblib import Parallel, delayed
import time

# Apply the project-wide matplotlib style sheet stored under FIG_ROOT.
plt.style.use(FIG_ROOT/"base.mplstyle")

In [4]:
%load_ext autoreload
%autoreload 2
# %autoreload takes a mode (0, 1, 2 or 3) separated from the magic by a space;
# mode 2 reloads all modules before executing each cell.

In [5]:
from utils.reparam import theta2raw_torch, raw2theta_torch, raw2theta_np
from spectrome import Brain
from sgm.sgm import SGM
from utils.stable import paras_stable_check
from utils.misc import save_pkl, save_pkl_dict2folder, load_pkl, load_pkl_folder2dict, delta_time
from models.lstm import LSTM_SGM
from models.model_utils import weights_init
from models.loss import  weighted_mse_loss, reg_R_loss, lin_R_loss, lin_R_fn, reg_R_fn
from utils.standardize import std_mat, std_vec

In [6]:
# pkgs for pytorch ( Mar 27, 2023) 
import torch
import torch.nn as nn
from torch.functional import F
from torch.optim.lr_scheduler import ExponentialLR

# Work in float32 throughout. On a GPU machine, new tensors are allocated on
# CUDA by default and cuDNN is allowed to autotune its kernels.
df_dtype = torch.float32
torch.set_default_dtype(df_dtype)
use_cuda = torch.cuda.is_available()
torch.set_default_device("cuda" if use_cuda else "cpu")
if use_cuda:
    torch.backends.cudnn.benchmark = True

In [7]:
import os
import random

# Seed every RNG used in this notebook (python, numpy, torch).
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Deterministic cuBLAS (CUDA >= 10.2) requires this workspace setting, otherwise
# use_deterministic_algorithms(True) raises on affected ops (incl. cuDNN LSTMs).
# It must be in place before the first cuBLAS call.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
# cuDNN benchmarking (switched on above when CUDA is available) picks kernels
# nondeterministically, which defeats reproducible runs; turn it off.
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Data, functions and parameters

In [8]:
import netCDF4
# Sort so the chosen file does not depend on filesystem listing order.
fils = sorted(DATA_ROOT.glob("*s100tp.nc")) # 300/150
if not fils:
    raise FileNotFoundError(f"No '*s100tp.nc' file found in {DATA_ROOT}")
# `with` closes the dataset even if reading one of the variables fails.
with netCDF4.Dataset(fils[0], 'r') as file2read:
    psd_all_full = np.array(file2read.variables["__xarray_dataarray_variable__"][:])
    time_points = np.array(file2read.variables["timepoints"][:])
    freqs = np.array(file2read.variables["frequencies"][:])
    ROIs_order = np.array(file2read.variables["regionx"][:])
# to dB scale; note that zero power would become -inf here
psd_all_full = 10 * np.log10(psd_all_full)
# make it num_sub x num_roi x num_freqs x num_ts
psd_all_full = psd_all_full.transpose(3, 0, 1, 2)

In [24]:
# Display the frequency grid of the loaded PSDs.
# NOTE(review): this cell was executed out of order (In [24]); re-run top to bottom.
freqs

array([ 2.7244898 ,  3.83673469,  4.94897959,  6.06122449,  7.17346939,
        8.28571429,  9.39795918, 10.51020408, 11.62244898, 12.73469388,
       13.84693878, 14.95918367, 16.07142857, 17.18367347, 18.29591837,
       19.40816327, 20.52040816, 21.63265306, 22.74489796, 23.85714286,
       24.96938776, 26.08163265, 27.19387755, 28.30612245, 29.41836735,
       30.53061224, 31.64285714, 32.75510204, 33.86734694, 34.97959184,
       36.09183673, 37.20408163, 38.31632653, 39.42857143, 40.54081633,
       41.65306122, 42.76530612, 43.87755102, 44.98979592])

In [9]:
# Drop `rm_lim` time points at BOTH ends of the series (not just the first and
# last one); set rm_lim = 0 to keep everything.
rm_lim = 5
if rm_lim > 0:
    keep = slice(rm_lim, -rm_lim)
    psd_all_full = psd_all_full[..., keep]
    time_points = time_points[keep]

In [10]:
# Load the Connectome from DATA_ROOT and preprocess it with spectrome's Brain
# helpers; the calls below act on `brain` in sequence (presumably in place).
brain = Brain.Brain()
brain.add_connectome(DATA_ROOT)
# presumably reorders connectome and distance matrix into a common ROI order — TODO confirm
brain.reorder_connectome(brain.connectome, brain.distance_matrix)
# presumably symmetrizes the connectome across hemispheres — TODO confirm
brain.bi_symmetric_c()
# presumably tames extreme weights and fills brain.reducedConnectome (used below) — TODO confirm
brain.reduce_extreme_dir()

In [11]:
# Constant parameters shared by the cells below.
paras = edict()

## Parameters are listed alphabetically and tauC is renamed tauG (Mar 27, 2023).
## The original order was: taue, taui, tauC, speed, alpha, gii, gei, with
## paras.par_low = np.asarray([0.005,0.005,0.005,5, 0.1,0.001,0.001])
## paras.par_high = np.asarray([0.03, 0.20, 0.03,20,  1,    2,  0.7])

# Order:                   alpha, gei,  gii,   taue,  tauG,  taui,  speed
paras.par_low = np.array([0.1, 0.001, 0.001, 0.005, 0.005, 0.005, 5])
paras.par_high = np.array([1, 0.7, 2, 0.03, 0.03, 0.20, 20])
# one (low, high) row per parameter
paras.prior_bds = np.column_stack((paras.par_low, paras.par_high))
paras.names = ["alpha", "gei", "gii", "Taue", "TauG", "Taui", "Speed"]

# structural inputs and frequency grid for the SGM
paras.C = brain.reducedConnectome
paras.D = brain.distance_matrix
paras.freqs = freqs

# Train the model

In [12]:
# Load the pretrained SGM network (files matching 'opt*' are skipped).
# NOTE: these are pickle files — only load them from trusted locations.
trained_model = load_pkl_folder2dict(RES_ROOT/"SGM_net", excluding=['opt*'])
sgm_net = trained_model.model
sgm_net.to(dtype=df_dtype)
sgm_net.eval()
# sgm_net is only used as a fixed decoder (the optimizer below only receives the
# LSTM's parameters). Freeze it so each backward pass does not allocate and
# accumulate .grad on its weights; gradients still flow through it to the LSTM.
sgm_net.requires_grad_(False);

Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/SGM_net/freqs.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/SGM_net/loss.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/SGM_net/loss_test.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/SGM_net/model.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/SGM_net/paras.pkl


In [13]:
# functions to generate training sample (Apr 1, 2023)
def random_choice(n, batchsize=1, len_seg=None):
    """Randomly select the lower and upper bounds of fixed-length segments.

    The bounds are slice bounds: ``X[low_bd:up_bd]`` is a segment of exactly
    `len_seg` time points, and a segment may end at the last time point.

        args:
            n: len of the total time series
            batchsize: number of segments to draw
            len_seg: segment length, an int or a one-element tensor; if None,
                     it is drawn uniformly from [10, 100)
        return:
            low_bd, up_bd: integer tensors of shape (batchsize,)
        raises:
            ValueError: if the segment is longer than the series
    """
    if len_seg is None:
        len_seg = torch.randint(low=10, high=100, size=(1, ))
    # accept plain ints as well as one-element tensors
    seg = int(len_seg)
    if seg > n:
        raise ValueError(f"Segment length {seg} exceeds the series length {n}.")
    # `high` is exclusive, so n + 1 lets a segment end at the final time point
    up_bd = torch.randint(low=seg, high=n + 1, size=(batchsize, ))
    low_bd = up_bd - seg
    return low_bd, up_bd


def random_samples_rnn(X, Y=None, batchsize=1, 
                       bds=None, 
                       is_std=True, 
                       theta2raw_fn=None):
    """Randomly select sequence segments (and matching parameters) for the RNN.

        args:
            X: PSD, num_seq x 68 x nfreq or 
               PSD, num_sub x num_seq x 68 x nfreq
               (with 4 dims, one subject is picked at random)
            Y: params, num x 7, in original sgm scale
            batchsize: number of segments when `bds` is None
            bds: optional (low_bds, up_bds) slice bounds; otherwise drawn
                 with random_choice
            is_std: standardize each ROI spectrum over frequencies
            theta2raw_fn: optional map applied to Y
        return:
            X_seqs: len_seq x batchsize x num_fs
            Y_seqs: len_seq x batchsize x 7 (only when Y is given)
        Non-tensor inputs are converted to torch's default dtype so they match
        the network weights.
    """
    if X.ndim == 4:
        # if multiple subjects, pick up a subject
        num_sub = X.shape[0]
        sub_idx = np.random.randint(low=0, high=num_sub)
        X = X[sub_idx]

    # torch.tensor(ndarray) would keep numpy's float64, which mismatches the
    # float32 networks; cast to the default dtype as _evaluate does.
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, dtype=torch.get_default_dtype())
    if is_std:
        # standardize each (time, ROI) spectrum over the frequency axis
        X = (X-X.mean(axis=2, keepdims=True))/X.std(axis=2, keepdims=True)
    if Y is not None:
        if not isinstance(Y, torch.Tensor):
            Y = torch.tensor(Y, dtype=torch.get_default_dtype())
        if theta2raw_fn: 
            Y = theta2raw_fn(Y)
    if bds is None:
        low_bds, up_bds = random_choice(len(X), batchsize)
    else:
        low_bds, up_bds = bds

    # num_seq x (68 * nfreq)
    X = X.flatten(1)
    X_seqs = []
    Y_seqs = []
    for low_bd, up_bd in zip(low_bds, up_bds):
        # add a batch axis so segments can be concatenated along dim 1
        X_seqs.append(X[low_bd:up_bd, :].unsqueeze(1))
        if Y is not None:
            Y_seqs.append(Y[low_bd:up_bd].unsqueeze(1))
    if Y is not None:
        return torch.cat(X_seqs, dim=1), torch.cat(Y_seqs, dim=1)
    return torch.cat(X_seqs, dim=1)
        

In [14]:
def _evaluate(all_data):
    """Score the current `rnn` on all data via the `sgm_net` reconstruction.

    The PSDs are encoded to SGM parameters by the global `rnn`, decoded back to
    PSDs by the global `sgm_net`, and compared with the raw PSDs by `reg_R_fn`.

        args:
            all_data: PSD array, num_sub x len_seq x nroi x nfreq
        return:
            numpy array of shape num_sub x len_seq x ... holding the
            `reg_R_fn` scores (presumably correlations; higher is better)
    """
    num_sub, len_seq, _, _ = all_data.shape
    # time-major: len_seq x num_sub x nroi x nfreq
    all_data_raw = torch.tensor(all_data, dtype=df_dtype).transpose(1, 0)
    # standardize each spectrum over frequencies, as done for the training input
    all_data_input = (all_data_raw - all_data_raw.mean(axis=-1, keepdims=True))/all_data_raw.std(axis=-1, keepdims=True)
    all_data_input = all_data_input.flatten(2)

    with torch.no_grad():
        Y_pred = rnn(all_data_input)
        X_pred = sgm_net(Y_pred.flatten(0, 1))
    corrs = reg_R_fn(all_data_raw.flatten(0, 1), X_pred)
    corrs = corrs.reshape(len_seq, num_sub, -1).transpose(1, 0)
    # The default device is CUDA when available and .numpy() only works on
    # host tensors, so copy to the CPU first.
    return corrs.detach().cpu().numpy()

In [15]:
# Hyper-parameters of the LSTM that maps PSD sequences to SGM parameters.
paras_rnn = edict()
# batchsize is not actually used: the whole dataset is fed as one batch.
paras_rnn.batchsize = 128
paras_rnn.niter = 1000  #!!!! 500
paras_rnn.loss_out = 1  # report the training loss every `loss_out` iterations
paras_rnn.eval_out = 20  # evaluate on all data every `eval_out` iterations
paras_rnn.clip = 1  # max gradient norm for clip_grad_norm_
paras_rnn.lr_step = 300  #!!!! 10
paras_rnn.gamma = 0.5  #!!!! 0.9
paras_rnn.lr = 2e-4

paras_rnn.k = 1
paras_rnn.hidden_dim = 1024 // 4
paras_rnn.output_dim = 7
paras_rnn.input_dim = 68 * len(paras.freqs)
paras_rnn.is_bidirectional = False  #!!!!False
# Penalty weight on parameter samples flagged as unstable; 0 disables it.
paras_rnn.unstable_pen = 10000
paras_rnn.loss_name = "wmse"  # linR, corr, wmse or mse

# Per-parameter switch in the order of paras.names: 1 dynamic, 0 static.
paras_rnn.dy_mask = [1, 1, 1, 1, 1, 1, 0]
static_names = np.array(paras.names)[np.array(paras_rnn.dy_mask) == 0]
# The last static name is left out of the folder name
# (that is Speed, the last parameter, whenever it is static).
stat_part = "_".join(static_names[:-1])
folder_name = (f"LSTM_simu_net_36meg_{paras_rnn.loss_name}_{stat_part}" if stat_part
               else f"LSTM_simu_net_36meg_{paras_rnn.loss_name}")
paras_rnn.save_dir = RES_ROOT/folder_name

psd_all = psd_all_full
# all_data is the real data: num_sub x len_seq x nrois x nfreqs
all_data = psd_all.transpose(0, 3, 1, 2)

# time-major tensor for the LSTM: len_seq x num_sub x nrois x nfreqs
all_data_raw = torch.tensor(all_data, dtype=df_dtype).transpose(1, 0)
# standardize each spectrum over frequencies, then flatten ROIs x freqs
psd_mean = all_data_raw.mean(axis=-1, keepdims=True)
psd_std = all_data_raw.std(axis=-1, keepdims=True)
all_data_input = ((all_data_raw - psd_mean) / psd_std).flatten(2)

In [16]:
# Build the LSTM that maps PSD sequences to SGM parameters.
rnn = LSTM_SGM(
    input_dim=paras_rnn.input_dim,
    hidden_dim=paras_rnn.hidden_dim,
    output_dim=paras_rnn.output_dim,
    is_bidirectional=paras_rnn.is_bidirectional,
    prior_bds=torch.tensor(paras.prior_bds, dtype=df_dtype),
    k=paras_rnn.k,
    dy_mask=paras_rnn.dy_mask,
)
rnn.apply(weights_init)
rnn.to(dtype=df_dtype)

# Pick the loss by the prefix of paras_rnn.loss_name (first match wins).
loss_candidates = [
    ("corr", reg_R_loss),
    ("linR", lin_R_loss),
    ("wmse", weighted_mse_loss),
    ("mse", nn.MSELoss()),
]
matches = [fn for prefix, fn in loss_candidates if paras_rnn.loss_name.startswith(prefix)]
if not matches:
    raise KeyError("No such loss")
loss_fn = matches[0]

optimizer = torch.optim.AdamW(rnn.parameters(), lr=paras_rnn.lr, weight_decay=0)
scheduler = ExponentialLR(optimizer, gamma=paras_rnn.gamma, verbose=True)

Adjusting learning rate of group 0 to 2.0000e-04.


<torch.optim.lr_scheduler.ExponentialLR at 0x7f1169652a30>

In [17]:
# training
loss_cur = 0
losses = []
losses_test = []

t0 = time.time()
sgm_net.eval()
loss_add = 0
for ix in range(paras_rnn.niter):
    rnn.train()
    # The whole dataset is small, so all subjects are used as one batch.
    # random_samples_rnn(all_data, batchsize=paras_rnn.batchsize) could be
    # used instead to draw random segments.
    X_seq = all_data_input
    # Zero the gradients
    optimizer.zero_grad()

    # LSTM -> SGM parameters per (time, subject); sgm_net (not optimized here)
    # maps them back to PSDs, which are compared with the input PSDs.
    theta_pred = rnn(X_seq)
    X_pred = sgm_net(theta_pred.flatten(0, 1))
    loss_main = loss_fn(X_seq.flatten(0, 1).reshape(-1, 68, len(paras.freqs)),
                   X_pred)
    if paras_rnn.unstable_pen > 0:
        # The stability check works on numpy arrays; copy to the CPU first
        # because .numpy() fails on CUDA tensors.
        unstable_inds = paras_stable_check(theta_pred.flatten(0, 1).detach().cpu().numpy())
        # one flag per (time, subject) sample
        unstable_inds = torch.tensor(unstable_inds).reshape(*theta_pred.shape[:2])
        loss_add = (paras_rnn.unstable_pen * unstable_inds.unsqueeze(-1) * theta_pred).mean()
    loss = loss_main + loss_add

    # Perform backward pass
    loss.backward()

    torch.nn.utils.clip_grad_norm_(rnn.parameters(), paras_rnn.clip)
    # Perform optimization
    optimizer.step()

    # decay the learning rate every `lr_step` iterations
    if ix % paras_rnn.lr_step == (paras_rnn.lr_step-1):
        scheduler.step()

    loss_cur = loss_cur + loss_main.item()
    if ix % paras_rnn.loss_out == (paras_rnn.loss_out-1):
        losses.append(loss_cur/paras_rnn.loss_out)
        print(f"At iter {ix+1}/{paras_rnn.niter}, "
              f"the losses are {loss_cur/paras_rnn.loss_out:.5f} (train). "
              f"The time used is {delta_time(t0):.3f}s. "
             )
        loss_cur = 0
        t0 = time.time()

    if ix % paras_rnn.eval_out == (paras_rnn.eval_out-1):
        rnn.eval()
        # NOTE: _evaluate returns reg_R_fn scores (presumably correlations),
        # so despite the name, higher `loss_test` is better.
        loss_test = _evaluate(all_data).mean()
        losses_test.append(loss_test)
        print(f"="*100)
        print(f"At iter {ix+1}/{paras_rnn.niter}, "
              f"the losses on all data are {loss_test:.5f}. "
              f"The time used is {delta_time(t0):.3f}s. "
             )
        print(f"="*100)
        t0 = time.time()
    


At iter 1/1000, the losses are 2.83012 (train). The time used is 1.875s. 
At iter 2/1000, the losses are 2.71379 (train). The time used is 1.805s. 
At iter 3/1000, the losses are 2.67381 (train). The time used is 1.893s. 
At iter 4/1000, the losses are 2.63547 (train). The time used is 1.811s. 
At iter 5/1000, the losses are 2.57832 (train). The time used is 1.879s. 
At iter 6/1000, the losses are 2.50872 (train). The time used is 1.786s. 
At iter 7/1000, the losses are 2.44638 (train). The time used is 1.786s. 
At iter 8/1000, the losses are 2.38293 (train). The time used is 1.821s. 
At iter 9/1000, the losses are 2.32352 (train). The time used is 1.780s. 
At iter 10/1000, the losses are 2.25683 (train). The time used is 1.794s. 
At iter 11/1000, the losses are 2.19898 (train). The time used is 1.812s. 
At iter 12/1000, the losses are 2.14877 (train). The time used is 1.792s. 
At iter 13/1000, the losses are 2.10108 (train). The time used is 1.812s. 
At iter 14/1000, the losses are 2.

# Save

In [18]:
if paras_rnn.save_dir.exists():
    # An earlier run is on disk: load it rather than overwrite it.
    # NOTE: the model trained above is then NOT saved; remove the folder to keep it.
    print(f"{paras_rnn.save_dir} exists; loading the saved run instead of saving the new one.")
    trained_model = load_pkl_folder2dict(paras_rnn.save_dir)
else:
    trained_model = edict()
    trained_model.model = rnn
    trained_model.loss_fn = loss_fn
    trained_model.optimizer = optimizer
    trained_model.paras = paras_rnn
    trained_model.loss = losses
    # also keep the periodic full-data evaluations, like the SGM_net run does
    trained_model.loss_test = losses_test
    save_pkl_dict2folder(paras_rnn.save_dir, trained_model, is_force=True)

/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse
Create a folder /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/model.pkl
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss_fn.pkl
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/optimizer.pkl
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/paras.pkl
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss.pkl


# PSD 

In [19]:
# Reload the saved run so the cells below work from what is on disk.
trained_model = load_pkl_folder2dict(paras_rnn.save_dir)

Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss_fn.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/model.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/optimizer.pkl
Load file /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/paras.pkl


In [20]:
# Estimate the SGM parameters for every subject and time point with the trained LSTM.
lstm = trained_model.model
lstm.eval()
with torch.no_grad():
    Y_pred = lstm(all_data_input)
# swap the first two axes: len_seq x num_sub x 7 -> num_sub x len_seq x 7 (presumably)
sgm_paramss_est = np.transpose(Y_pred.cpu().numpy(), (1, 0, 2))
trained_model.sgm_paramss_est = sgm_paramss_est
save_pkl_dict2folder(paras_rnn.save_dir, trained_model, is_force=False)

/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss_fn.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/model.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/optimizer.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/paras.pkl exists! Use is_force=True to save it anyway
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/sgm_paramss_est.pkl


In [23]:
# Reconstruct the PSDs from the estimated SGM parameters and save them (only needs to run once).
sgmmodel = SGM(paras.C, paras.D, paras.freqs)
N_JOBS = 20  # number of parallel SGM forward runs

def _run_fn(sgm_param):
    """Run the SGM forward model for one parameter vector and keep the first 68 rows
    (presumably the cortical ROIs)."""
    cur_PSD = sgmmodel.run_local_coupling_forward(sgm_param)
    return cur_PSD[:68]

# dy_mask is fixed for the whole run, so decide the path once.
is_all_static = np.sum(paras_rnn.dy_mask) == 0
X_recs = []
# One worker pool is reused for all subjects instead of a new pool per subject.
with Parallel(n_jobs=N_JOBS) as parallel:
    for sgm_params_est in tqdm(trained_model.sgm_paramss_est):
        if is_all_static:
            # all parameters are static: one forward run, repeated over time
            X_rec = _run_fn(sgm_params_est[0])
            X_rec = np.tile(X_rec, (len(sgm_params_est), 1, 1))
        else:
            X_rec = parallel(delayed(_run_fn)(param) for param in sgm_params_est)
        X_recs.append(X_rec)

# save
trained_model.Rec_PSD = np.array(X_recs)
save_pkl_dict2folder(paras_rnn.save_dir, trained_model, is_force=False)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [08:15<00:00, 13.75s/it]


/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/loss_fn.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/model.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/optimizer.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/paras.pkl exists! Use is_force=True to save it anyway
/data/rajlab1/user_data/jin/MyResearch/TV-SGM/notebooks/../mypkg/../results/LSTM_simu_net_36meg_wmse/sgm_paramss_est.pkl exists! Use is_force=True to save it anyway
Save to /data/rajlab1/user_data/jin/MyResearch/TV-SGM/n