In [None]:
import numpy as np
from scipy.io import loadmat
from brain import Brain
from preferred_path import PreferredPath
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.nn import Sequential, Linear, ReLU
from torch.optim import Adam
from torch.distributions import Normal
from IPython import display
from scipy.stats import norm

# Global vars

In [None]:
# Notebook-wide settings read by the cells below
res = 68 # Brain resolution: number of nodes; also selects which data files are loaded
fns = 4  # Number of criteria in path algorithm; sizes the policy network output and the plot buffers

# Train/test node pair splits

In [None]:
def train_test_ind(res, train_pct=0.7):
    """
    Randomly split every ordered node pair (i, j), i != j, into train and test sets.

    Parameters
    ----------
    res : int
        Number of nodes (brain resolution)
    train_pct : float
        Fraction of the res * (res - 1) ordered pairs assigned to the train set
        (the train row count is truncated to an integer)

    Returns
    -------
    train_ind, test_ind : tuple of numpy.ndarray
        Each is a (rows, cols) pair of 1D integer arrays that can be used
        directly to index a (res, res) matrix
    """
    rows = res * (res - 1)

    # Builtin int: the np.int alias was removed in NumPy 1.24 and raises AttributeError there
    pairs = np.zeros((rows, 2), dtype=int)

    # Same ordering as the nested i < j loops: (i, j) on even rows, (j, i) on odd rows
    upper_i, upper_j = np.triu_indices(res, 1)
    pairs[0::2, 0], pairs[0::2, 1] = upper_i, upper_j
    pairs[1::2, 0], pairs[1::2, 1] = upper_j, upper_i

    # Shuffles rows in place using the global NumPy RNG
    np.random.shuffle(pairs)

    train_rows = int(train_pct * rows)
    train_ind, test_ind = pairs[:train_rows, :], pairs[train_rows:, :]
    return (train_ind[:, 0], train_ind[:, 1]), (test_ind[:, 0], test_ind[:, 1])

In [None]:
# One random split of the ordered node pairs, shared by every subject during training.
# Each of train_ind / test_ind is a (rows, cols) tuple of index arrays.
# NOTE(review): no RNG seed is set in the visible cells, so the split differs between kernel runs.
train_ind, test_ind = train_test_ind(res)

# Brain data

In [None]:
class BrainDataset(Dataset):
    """
    Per-subject inputs for the policy network and the path algorithm.

    Attributes
    ----------
    adj : numpy.ndarray
        Upper-triangular entries of each subject's `brain.sc_bin`,
        shape (number of subjects, resolution * (resolution - 1) / 2)
    sp : numpy.ndarray
        Result of `brain.shortest_paths()` for each subject,
        shape (number of subjects, resolution, resolution)
    pp : list
        One PreferredPath object per subject
    """

    def __init__(self, sc, fc, euc_dist, hubs):
        """
        Parameters
        ----------
        sc, fc : numpy.ndarray
            Connectivity matrices for each subject
            3D with shape: (number of subjects, resolution, resolution)
        euc_dist : numpy.ndarray
            Euclidean distance matrix
            2D with shape: (resolution, resolution)
        hubs : numpy.ndarray
            Array with the indexes of the hub nodes
        """

        n = len(sc)
        res = len(euc_dist)

        # Init vars
        triu = res * (res - 1) // 2
        triu_i = np.triu_indices(res, 1)
        self.adj = np.zeros((n, triu))
        self.sp = np.zeros((n, res, res))
        self.pp = [None] * n

        # Fill vars
        for i in range(n):
            brain = Brain(sc[i], fc[i], euc_dist, hubs=hubs)
            # Built in a helper so each subject's criteria close over that
            # subject's own data. Lambdas defined directly in this loop would
            # all see the variables of the final iteration (late binding).
            criteria = self._criteria(brain)
            # Random starting weights; reinforce() overwrites fn_weights before use
            weights = list(np.random.random(size=len(criteria)))
            self.adj[i] = brain.sc_bin[triu_i]
            self.sp[i] = brain.shortest_paths()
            self.pp[i] = PreferredPath(adj=brain.sc_bin, fn_vector=criteria, fn_weights=weights)

    @staticmethod
    def _criteria(brain):
        """Return the path criteria functions bound to a single Brain instance."""
        streamlines = brain.streamlines()
        node_str = brain.node_strength(weighted=False)
        is_target = brain.is_target
        is_hub = brain.hubs(binary=True)
        return [
            lambda loc, nxt, prev, target: streamlines[loc,nxt],
            lambda loc, nxt, prev, target: node_str[nxt],
            lambda loc, nxt, prev, target: is_target(nxt, target),
            lambda loc, nxt, prev, target: is_hub[nxt]]

    def __len__(self):
        """Number of subjects."""
        return len(self.adj)

    def __getitem__(self, idx):
        """Return (adj, sp, pp) for an index or slice of subjects."""
        return (self.adj[idx], self.sp[idx], self.pp[idx])

In [None]:
def load_con(con, res, subj=None):
    """
    Load one connectivity matrix per subject from `data/subjfiles_{con}{res}.mat`.

    Parameters
    ----------
    con : str
        Connectivity type used in the file name (called with 'sc' and 'fc')
    res : int
        Brain resolution used in the file name
    subj : iterable of int, optional
        Subject numbers to load; defaults to 1..484

    Returns
    -------
    numpy.ndarray
        Stacked matrices, one entry per requested subject, in request order
    """
    subject_ids = np.arange(1, 485) if subj is None else subj
    mat_data = loadmat(f'data/subjfiles_{con}{res}.mat')

    # Subjects are stored under zero-padded keys: s001, s002, ...
    matrices = []
    for z in subject_ids:
        matrices.append(mat_data[f's{str(z).zfill(3)}'])
    return np.array(matrices)

def load_data(res, subj=None):
    """
    Load connectivity, distance and hub data from `data/` and build a BrainDataset.

    Parameters
    ----------
    res : int
        Brain resolution; selects the files and the `eu{res}` distance matrix
    subj : iterable of int, optional
        Subject numbers to load, passed through to load_con

    Returns
    -------
    BrainDataset
    """
    sc = load_con('sc', res, subj)
    fc = load_con('fc', res, subj)
    euc_dist = loadmat('data/euc_dist.mat')[f'eu{res}']
    # Builtin int: the np.int alias was removed in NumPy 1.24 and raises AttributeError there
    hubs = np.loadtxt(f'data/hubs_{res}.txt', dtype=int, delimiter=',')
    return BrainDataset(sc, fc, euc_dist, hubs)

In [None]:
# Load every subject (load_con defaults to subjects 1..484) and build the per-subject
# adjacency vectors, shortest-path matrices and PreferredPath objects.
data = load_data(res)

# Policy estimator network

In [None]:
class PolicyEstimator():
    """
    Small feed-forward network mapping an upper-triangular adjacency vector
    to the parameters of a Normal distribution over criteria weights.
    """

    def __init__(self, res, fn_len, hidden_units=20, load_path=None):
        """
        Parameters
        ----------
        res : int
            Brain resolution; the input size is res * (res - 1) / 2
        fn_len : int
            Number of path criteria; the output holds a mean and ln(sigma) for each
        hidden_units : int
            Width of the single hidden layer
        load_path : str, optional
            Path of a saved state dict to restore into the network
        """
        self.n_inputs = int(res * (res - 1) / 2)
        self.n_outputs = fn_len * 2 # includes both mean and ln(sigma)

        layers = [
            Linear(self.n_inputs, hidden_units),
            ReLU(),
            Linear(hidden_units, self.n_outputs),
        ]
        self.network = Sequential(*layers)

        if load_path:
            saved_state = torch.load(load_path)
            self.network.load_state_dict(saved_state)

    def predict(self, state):
        """Run the network on `state` after converting it to a float tensor."""
        inputs = torch.FloatTensor(state)
        return self.network(inputs)

In [None]:
# Fresh (untrained) policy network: res*(res-1)/2 inputs, 2*fns outputs (mean and ln(sigma) per criterion)
pe = PolicyEstimator(res, fns)

# Reward function

In [None]:
def reward(pred, sp):
    """
    Mean reward of predicted path values against the shortest-path values.

    Each finite prediction scores 1 / (1 + pred - sp), which is 1 when the
    prediction equals the shortest path and shrinks as it exceeds it. Each
    infinite prediction scores -1.

    Parameters
    ----------
    pred, sp : numpy.ndarray
        Predicted and shortest-path values, same shape

    Returns
    -------
    float
        Mean of the per-element rewards
    """
    # The ratio is 0 for infinite predictions; np.where replaces those with the -1 penalty
    ratio = 1 / (1 + pred - sp)
    scores = np.where(np.isinf(pred), -1, ratio)
    return scores.mean()

# Plotting helpers

In [None]:
def cust_plot(ax, x, y, title=None, xlab=None, ylab=None, labels=None, off=0, avg=None):
    """
    Draw one or more line series on `ax`, optionally with a moving average.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on
    x, y : list or numpy.ndarray
        Either a single series, or a list/array of series (detected by
        checking whether the first element is itself a list or array).
        A single `x` is shared by every series in `y`.
    title, xlab, ylab : str, optional
        Axes title and axis labels
    labels : list of str, optional
        One legend label per series; the legend is drawn only if given
    off : int
        Number of leading points dropped from every series
    avg : int, optional
        Window size for an extra moving-average line per series

    Returns
    -------
    The same `ax`
    """
    def is_seq(obj):
        return isinstance(obj, (list, np.ndarray))

    x_nested = is_seq(x[0])
    y_nested = is_seq(y[0])

    ax.set_title(title)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)

    n_series = len(y) if y_nested else 1
    for k in range(n_series):
        series_y = (y[k] if y_nested else y)[off:]
        series_x = (x[k] if x_nested else x)[off:]
        if labels:
            label = labels[k]
        else:
            label = None
        ax.plot(series_x, series_y, label=label)
        if avg:
            # move_avg returns positions relative to the series start; shift onto the x axis
            x_avg, y_avg = move_avg(series_y, avg)
            ax.plot(x_avg + series_x[0], y_avg, label=f'{label} {avg} point avg')

    if labels:
        ax.legend()
    return ax

def cust_plot_pdf(ax, mu, sig, title=None, xlab=None, ylab=None, labels=None):
    """
    Plot one Normal probability density curve per (mu[i], sig[i]) pair on `ax`.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on
    mu, sig : numpy.ndarray
        Means and standard deviations, one per curve
    title, xlab, ylab, labels : optional
        Passed through to cust_plot

    Returns
    -------
    The value returned by cust_plot
    """
    # Shared x grid covering three standard deviations either side of every mean
    lower = (mu - 3 * sig).min()
    upper = (mu + 3 * sig).max()
    grid = np.arange(lower, upper, 0.001)

    curves = []
    for idx in range(len(mu)):
        curves.append(norm.pdf(grid, mu[idx], sig[idx]))

    return cust_plot(ax=ax, x=grid, y=curves, title=title, xlab=xlab, ylab=ylab, labels=labels)

def move_avg(y, p):
    """
    Trailing moving average of `y` over windows of `p` points.

    Parameters
    ----------
    y : list or numpy.ndarray
        1D series
    p : int
        Window size, must be at least 1

    Returns
    -------
    x_avg : numpy.ndarray
        0-based index in `y` of the last point of each window
        (p - 1, p, ..., len(y) - 1)
    y_avg : numpy.ndarray
        Mean of each window, len(y) - p + 1 values
        Both arrays are empty when `y` has fewer than `p` points

    Raises
    ------
    ValueError
        If `p` is smaller than 1
    """
    if p < 1:
        raise ValueError(f'window size must be >= 1, got {p}')
    if len(y) < p:
        return np.array([]), np.array([])
    # A leading 0 makes c[k] the sum of the first k points, so c[p:] - c[:-p]
    # includes the very first window y[0:p]. Without it that window is dropped.
    c = np.concatenate(([0], np.cumsum(y)))
    y_avg = (c[p:] - c[:-p]) / p
    x_avg = np.arange(len(y_avg)) + p - 1
    return x_avg, y_avg

def plot_all(batch, plt_avg=None, plt_off=0, figsize=(20,30)):
    """
    Create a 4x2 figure summarising the training history held in the
    module-level plt_* lists.

    Panels: train/test mean rewards, train/test success ratio, mean and
    standard deviation of each criterion weight, and the Normal densities
    given by the most recent mu/sigma values. The last panel is left empty.

    Parameters
    ----------
    batch : int
        Batch size, shown in the panel titles only
    plt_avg : int, optional
        Moving-average window added to the reward and success panels
    plt_off : int
        Number of leading batches dropped from the line plots
    figsize : tuple
        Figure size passed to plt.subplots

    Notes
    -----
    Reads the globals `data`, `res`, `fns` and the plt_* lists. The figure is
    left open as the current pyplot figure; nothing is returned.
    """
    _, ax = plt.subplots(nrows=4, ncols=2, figsize=figsize, facecolor='w')
    len_data = len(data)
    x = np.arange(len(plt_train_rewards)) + 1  # batch numbers start at 1
    fn_labs = [f'fn{j+1}' for j in range(fns)]

    def def_plot(axis, y, ylab, labels, title, avg=None):
        # Line plot against batch number with the run settings appended to the title
        return cust_plot(axis, x, y, xlab='Batches', ylab=ylab, labels=labels, off=plt_off,
                         title=f'{title}\n(n={len_data}, res={res}, batch size={batch})', avg=avg)

    def_plot(ax[0,0], plt_train_rewards, ylab='Mean rewards (Train)',  labels=['Train'], title='Mean rewards (Train) vs. batches',  avg=plt_avg)
    def_plot(ax[0,1], plt_test_rewards,  ylab='Mean rewards (Test)',   labels=['Test'],  title='Mean rewards (Test) vs. batches',   avg=plt_avg)
    def_plot(ax[1,0], plt_train_success, ylab='Success ratio (Train)', labels=['Train'], title='Success ratio (Train) vs. batches', avg=plt_avg)
    def_plot(ax[1,1], plt_test_success,  ylab='Success ratio (Test)',  labels=['Test'],  title='Success ratio (Test) vs. batches',  avg=plt_avg)
    def_plot(ax[2,0], plt_mu,            ylab='Mu',                    labels=fn_labs,   title='Mean criteria weight vs. batches')
    def_plot(ax[2,1], plt_sig,           ylab='Sigma',                 labels=fn_labs,   title='Standard deviation for criteria weight vs. batches')

    # Densities from the latest batch: last column of the (criteria, batches) histories
    cust_plot_pdf(ax[3,0], np.array(plt_mu)[:,-1], np.array(plt_sig)[:,-1], xlab='Weight', ylab='Density', labels=fn_labs,
        title=f"Probability density function for criteria weights (n={len_data}, res={res}, batch={batch})")

# Reinforce

In [None]:
# Plotting data
# Module-level history buffers: reinforce() appends one value per batch, plot_all() and save() read them.
# Re-run this cell to reset the history before a new training run.

plt_train_rewards = []
plt_test_rewards = []
plt_train_success = []
plt_test_success = []
# One list per path criterion, so these are indexed [criterion][batch]
plt_mu = [[] for _ in range(fns)]
plt_sig = [[] for _ in range(fns)]

In [None]:
def reinforce(pe, data, epochs, batch=22, lr=0.01, plot=True, print_freq=22, plt_off=0, plt_avg=None): # batch: 2,4,11,22,44,121,242 divide 484
    """
    Train the policy network with a REINFORCE-style policy-gradient update.

    For each batch of subjects the network predicts a mean and ln(sigma) per
    criterion, weights are sampled from the resulting Normal distribution and
    assigned to each subject's PreferredPath object. The train-pair reward
    then scales the log-probability of the sampled weights in the loss.

    Parameters
    ----------
    pe : PolicyEstimator
        Policy network to optimise in place
    data : BrainDataset
        Sliced as data[offset:offset+batch] to obtain (adj, sp, pp)
    epochs : int
        Number of passes over the data
    batch : int
        Subjects per gradient step; trailing subjects that do not fill a
        whole batch are skipped in each epoch
    lr : float
        Adam learning rate
    plot : bool
        If True, append per-batch statistics to the module-level plt_* lists
        and periodically redraw the summary figure
    print_freq : int
        Redraw period in batches (see the note at the plotting condition)
    plt_off, plt_avg :
        Passed through to plot_all

    Notes
    -----
    Uses the globals `train_ind` / `test_ind` and prints a progress counter
    for every subject processed.
    """
    # Setup
    opt = Adam(pe.network.parameters(), lr=lr)
    len_data = len(data)
    len_fn = data.pp[0].fn_length

    # Run
    for e in range(epochs):
        print(f'(Epoch {e+1}):', end=' ')
        offset = 0

        # Epoch
        while offset + batch <= len_data:
            train_rewards = torch.zeros(batch,1)
            test_rewards = torch.zeros(batch, 1)
            train_success = np.zeros(batch)
            test_success = np.zeros(batch)
            adj, sp, pp = data[offset:offset+batch]
            probs = pe.predict(adj)
            # First half of the outputs are means, second half are ln(sigma)
            mu, sig = probs[:,:len_fn], torch.exp(probs[:,len_fn:])
            m = Normal(mu, sig)
            actions = m.sample()

            # Batch
            for i in range(batch):
                print(i+1+offset, end=' ')
                pp[i].fn_weights = actions[i].tolist()
                # presumably a (res, res) array with np.inf where no path was found; confirm in preferred_path.py
                pred = pp[i].retrieve_all_paths()

                # Train
                train_pred = pred[train_ind]
                train_sp = sp[i][train_ind]
                # Score only pairs that have a positive, finite shortest path
                train_mask = np.where((train_sp > 0) & (~np.isinf(train_sp)))
                train_rewards[i] = reward(train_pred[train_mask], train_sp[train_mask])
                train_success[i] = 1 - np.isinf(train_pred).sum() / len(train_pred)

                # Test
                test_pred = pred[test_ind]
                test_sp = sp[i][test_ind]
                test_mask = np.where((test_sp > 0) & (~np.isinf(test_sp)))
                test_rewards[i] = reward(test_pred[test_mask], test_sp[test_mask])
                test_success[i] = 1 - np.isinf(test_pred).sum() / len(test_pred)

            # Step
            opt.zero_grad()
            # (batch, 1) rewards broadcast across the per-criterion log-probabilities
            loss = -m.log_prob(actions) * train_rewards
            loss = loss.mean()
            loss.backward()
            opt.step()

            # Plotting
            if plot:
                # Add data to arrays
                plt_train_rewards.append(train_rewards.mean().item())
                plt_test_rewards.append(test_rewards.mean().item())
                plt_train_success.append(train_success.mean())
                plt_test_success.append(test_success.mean())
                for j in range(len_fn):
                    plt_mu[j].append(mu[:,j].mean().item())
                    plt_sig[j].append(sig[:,j].mean().item())
                len_plt_rewards = len(plt_train_rewards)

                # Plot data
                # NOTE(review): fires when len + 1 is a multiple of print_freq, i.e. one
                # batch before each multiple; confirm the + 1 is intended.
                if (len_plt_rewards + 1) % print_freq == 0:
                    plot_all(batch, plt_avg, plt_off)
                    fig = plt.gcf()
                    display.clear_output(wait=True)
                    display.display(fig)
                    # Close the figure once shown. Otherwise every redraw leaves another
                    # open figure behind for the rest of the run.
                    plt.close(fig)

            # Run next batch
            offset += batch

In [None]:
# Long-running cell: 100 passes over the data with one subject per gradient step (batch=1).
# The summary figure is redrawn about every 484 batches, with a 200-batch moving average.
reinforce(pe, data, epochs=100, batch=1, lr=0.005, print_freq=484, plt_off=0, plt_avg=200)

In [None]:
# Final summary of the full training history, without moving averages
plot_all(batch=1, plt_avg=None, plt_off=0, figsize=(30,30))

# Save data

In [None]:
def save(fname_prefix):
    """
    Write the training history and the network weights to disk.

    Each module-level plt_* history is saved as comma-separated text in
    `{fname_prefix}_{name}.txt`, and the state dict of the global `pe`
    network is saved to `{fname_prefix}_torch_state.pt`.

    Parameters
    ----------
    fname_prefix : str
        Path prefix shared by every output file
    """
    histories = [
        ('train_rewards', plt_train_rewards),
        ('test_rewards', plt_test_rewards),
        ('train_success', plt_train_success),
        ('test_success', plt_test_success),
        ('mu', plt_mu),
        ('sig', plt_sig),
    ]
    for ext, values in histories:
        np.savetxt(f'{fname_prefix}_{ext}.txt', values, fmt='%.18f', delimiter=',')

    torch.save(pe.network.state_dict(), f'{fname_prefix}_torch_state.pt')

In [None]:
#save('results/s001')