In [None]:
import numpy as np
import torch
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
from ml import BrainDataset, PolicyEstimator, reinforce

In [None]:
# Input Parameters
num_fns = 6           # passed to PolicyEstimator; also sizes the per-fn 'mu'/'sig' history lists
res = 68              # selects the data/*{res}* files and the 'eu{res}' distance matrix
subj = [0]            # zero-based subject numbers; subject z is stored under key 's{z+1:03d}'
epochs = 5000
batch = 1
hidden_units = 20     # NOTE(review): not passed to PolicyEstimator below — confirm it is used
lr = 0.005            # Adam learning rate (also forwarded to reinforce)
seed = 0              # fixes the train/test split and torch initialisation for reproducibility
save_path = None      # NOTE(review): not passed to reinforce() below — confirm how checkpoints are saved
load_path = None      # checkpoint to resume from; None starts a fresh run

# train_test_split falls back to NumPy's global RNG when random_state is None,
# so seeding it here makes the split reproducible; torch seeding covers training.
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Read brain data (change file locations as necessary)
# Each subjfiles .mat dict holds one matrix per subject under keys 's001', 's002', ...;
# the selected subjects are stacked in `subj` order, so row i of sc/fc is subject subj[i].
sc_all = loadmat(f'data/subjfiles_SC{res}.mat')
fc_all = loadmat(f'data/subjfiles_FC{res}.mat')
sc = np.array([sc_all[f's{z + 1:03d}'] for z in subj])
fc = np.array([fc_all[f's{z + 1:03d}'] for z in subj])
del sc_all, fc_all  # drop the full per-subject dicts once the selection is stacked
euc_dist = loadmat('data/euc_dist.mat')[f'eu{res}']
# np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
hubs = np.loadtxt(f'data/hubs_{res}.txt', dtype=int, delimiter=',')
regions = np.loadtxt(f'data/regions_{res}.txt', dtype=int, delimiter=',')

In [None]:
# Init network parameters: build the policy estimator, then an Adam optimizer over its network weights
pe = PolicyEstimator(res, num_fns)
opt = torch.optim.Adam(params=pe.network.parameters(), lr=lr)

In [None]:
# Init new/load previous training data
if not load_path:
    # Fresh run: empty training history with one mu/sig trace per function
    plt_data = {
        'rewards': [],
        'success': [],
        'mu': [[] for _ in range(num_fns)],
        'sig': [[] for _ in range(num_fns)],
    }
    # Split subjects 70/30 when there is more than one; otherwise train on all, test on none
    if len(subj) > 1:
        plt_data['train_idx'], plt_data['test_idx'] = train_test_split(subj, train_size=0.7)
    else:
        plt_data['train_idx'], plt_data['test_idx'] = subj, []
else:
    # Resume: restore training history, model weights and optimizer state from the checkpoint.
    # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(load_path)
    plt_data = {
        name: checkpoint[name]
        for name in ('rewards', 'success', 'mu', 'sig', 'train_idx', 'test_idx')
    }
    pe.network.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Train / test split
# train_idx / test_idx hold subject IDs drawn from `subj`, but sc/fc are stacked in
# `subj` order, so translate each ID to its row position before indexing. Indexing
# with the raw IDs only works by coincidence when subj == [0, 1, 2, ...].
train_idx, test_idx = plt_data['train_idx'], plt_data['test_idx']
subj_row = {s: row for row, s in enumerate(subj)}
train_rows = [subj_row[s] for s in train_idx]
test_rows = [subj_row[s] for s in test_idx]
train_data = BrainDataset(sc[train_rows], fc[train_rows], euc_dist, hubs, regions)
test_data = BrainDataset(sc[test_rows], fc[test_rows], euc_dist, hubs, regions)

In [None]:
# Reinforce and save after each epoch
# NOTE(review): save_path is not passed to reinforce() and test_data is not used here —
# confirm whether reinforce() saves checkpoints / evaluates on its own.
reinforce(
    pe,
    opt,
    train_data,
    epochs=epochs,
    batch=batch,
    lr=lr,
    plt_data=plt_data,  # presumably appended to in place with training history — TODO confirm
    inc_plt=True,
    plt_freq=5,
    plt_off=0,
    plt_avg=50,
)

In [None]:
# Most recent reward (None when no rewards have been recorded yet, e.g. epochs == 0)
plt_data['rewards'][-1] if plt_data['rewards'] else None