In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt
import seaborn as sn
import pickle
import torch
import os
from deep_KO_learning import Net

In [None]:
### Loading trained network ###

# os.path.dirname of a bare file name is '', so every path below is relative to the
# notebook's current working directory.
script_dir = os.path.dirname('deep_KO_learning.py')
trained_models_path = os.path.join(script_dir, 'trained_models')
data_path = os.path.join(script_dir, 'data/')  # trailing '/' kept so `data_path + filename` is a valid path

# Pickled tuple (NUM_INPUTS, NUM_OUTPUTS, HL_SIZES) describing the network shape.
netsize_dir = os.path.join(trained_models_path, 'malathion_fluorescens_netsize.pickle')
# Saved state_dict (parameters) of the trained network.
net_dir = os.path.join(trained_models_path, 'malathion_fluorescens_net.pt')

# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(netsize_dir, 'rb') as fh:
    NUM_INPUTS, NUM_OUTPUTS, HL_SIZES = pickle.load(fh)

model = Net(NUM_INPUTS, NUM_OUTPUTS, HL_SIZES)
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(net_dir, map_location='cpu'))
model.eval();

In [None]:
### Loading corresponding dataset ###

file_dir = 'malathion_fluorescens_tpm.p' # dataset

def get_snapshot_matrices(X,nT,nTraj):
    '''Split a global snapshot matrix into one-step-shifted "past" and "future" matrices.

    Assumes the global snapshot matrix is constructed with trajectories sequentially
    placed in the columns: columns i*nT .. i*nT + nT - 1 hold trajectory i in time order.

    Parameters
    ----------
    X : 2-D array whose columns are snapshots (nT per trajectory, nTraj trajectories).
    nT : number of time points per trajectory.
    nTraj : number of trajectories; a value <= 0 yields empty matrices.

    Returns
    -------
    Xp, Xf : arrays with X.shape[0] rows and nTraj*(nT-1) columns. Column k of Xf is
        the snapshot one time step after column k of Xp within the same trajectory;
        no pair straddles two trajectories.
    '''
    n_traj = max(nTraj, 0)
    # Row i of `cols` holds the column indices of trajectory i, in time order.
    cols = np.arange(n_traj * nT).reshape(n_traj, nT)
    prevInds = cols[:, :-1].ravel()  # every snapshot except each trajectory's last
    forInds = cols[:, 1:].ravel()    # every snapshot except each trajectory's first
    Xp = X[:,prevInds]
    Xf = X[:,forInds]
    return Xp,Xf

# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(data_path + file_dir, 'rb') as fh:
    X, nT, nTraj = pickle.load(fh)

numDatapoints = nT*nTraj # number of total snapshots
# get_snapshot_matrices and the N-step prediction both assume exactly nT*nTraj columns.
assert X.shape[1] == numDatapoints, 'X must have nT*nTraj snapshot columns'

Xp,Xf = get_snapshot_matrices(X,nT,nTraj)
# Transpose so each row is one snapshot (the layout the network is called with).
trainXp = torch.Tensor(Xp.T)
trainXf = torch.Tensor(Xf.T)
testX = torch.Tensor(X.T)

print('Dimension of the state: ' + str(trainXp.shape[1]))
print('Number of trajectories: ' + str(nTraj))
print('Number of total snapshots: ' + str(numDatapoints))

In [None]:
# K: the learned linear operator, taken from the weights of the network's last layer
# (used below as psi_{t+1} = K @ psi_t).
K = model.linears[-1].weight.detach().numpy()
# Lifted observables of every snapshot, transposed so each column is one snapshot.
# no_grad: this is inference only, so don't build an autograd graph over the whole dataset.
with torch.no_grad():
    PsiX_test = model(testX)['PsiXf'].detach().numpy().T

In [None]:
import math
# Eigenvalues of K against the unit circle. Since predictions are psi_j = K^j psi_0,
# eigenvalues outside the circle correspond to growing modes, inside to decaying ones.
theta = np.linspace(0, 2*math.pi, 100)
x = np.cos(theta)
y = np.sin(theta)
L = np.linalg.eigvals(K)
fig1 = plt.figure()
plt.title('eigenvalues of linear operator')
plt.plot(np.real(L), np.imag(L), 'o', label='eigenvalues')
plt.plot(x, y, label='unit circle')
plt.xlabel(r'Re($\lambda$)')
plt.ylabel(r'Im($\lambda$)')
plt.axis('equal')
plt.grid(True)
plt.legend();

In [None]:
### N-step prediction ###

# From each trajectory's first lifted snapshot psi_0, predict psi_j = K^j psi_0 for
# j = 0..nT-1. All trajectories are advanced together by one multiplication with K per
# step, instead of recomputing the matrix power K^j from scratch at every step.
PsiX_pred = np.zeros((K.shape[0],numDatapoints))
psi_j = PsiX_test[:, ::nT][:, :nTraj]  # initial conditions, one column per trajectory
for j in range(nT):
    PsiX_pred[:, j::nT] = psi_j  # column i*nT + j holds step j of trajectory i
    psi_j = K @ psi_j

In [None]:
### plotting predictions ###
# Predictions (blue circles) vs. lifted data (orange squares) for a sample of observables.
# All trajectories are concatenated along the x-axis of each panel.

nrows = 5
ncols = 3
nPanels = nrows*ncols
# Seeded generator so the same observables are shown on every run. Index 0 stays excluded
# (as with the original randint(1, ...) lower bound). Sample without replacement whenever
# there are enough observables, so no panel is a duplicate.
rng = np.random.default_rng(0)
candidates = np.arange(1, PsiX_test.shape[0])
plotidx = rng.choice(candidates, size=nPanels, replace=len(candidates) < nPanels)
plt.rcParams.update({'font.size': 12})

fig, ax = plt.subplots(nrows, ncols, figsize=(15, 10))
plt.suptitle('M-step predictions')
for panel, obs in zip(ax.flat, plotidx):
    panel.plot(PsiX_pred[obs,:],'o--',ms=16,mec='black',lw=6,color='tab:blue',label='predicted')
    panel.plot(PsiX_test[obs,:],'s--',ms=8,mec='black',lw=4,color='tab:orange',label='lifted data')
    panel.set_title(f'observable {obs}')
    panel.grid()
    panel.spines['right'].set_visible(False)
    panel.spines['top'].set_visible(False)
ax[0, 0].legend();

In [None]:
# Eigendecomposition of K, with modes ordered by decreasing |eigenvalue|.
# Rows of W = V^{-1} are left eigenvectors of K; projecting the lifted data onto the
# first p of them gives the dominant observable modes Yo.
numObs = K.shape[0]
L, V = np.linalg.eig(K)
sortLinds = np.argsort(np.absolute(L))[::-1]
L = L[sortLinds]  # keep eigenvalues aligned with the reordered eigenvector columns
V = V[:,sortLinds]
W = np.linalg.inv(V)
p = min(20, numObs)  # number of dominant modes kept, capped at the number of observables
Wh = W[:p, :]  # equivalent to [I_p | 0] @ W, without building the selector matrix
Yo = Wh @ PsiX_test

In [None]:
# plot observable modes after smoothing 
from scipy.interpolate import make_interp_spline

# Interpolate the real part of each observable mode of the first trajectory
# (columns 0..nT-1 of Yo) onto a fine time grid with a spline, for smooth curves.
tSpan = np.linspace(0,nT-1,nT)
tNew = np.linspace(tSpan.min(), tSpan.max(), 200)
spline_order = min(3, nT - 1)  # make_interp_spline needs at least k+1 data points

plt.rcParams.update({'font.size': 16})

fig_modes, ax_modes = plt.subplots(figsize=(12,7))
Ysmooth = np.zeros([len(Yo),len(tNew)])
for i in range(Yo.shape[0]):
    spl = make_interp_spline(tSpan, np.real(Yo[i,0:nT]), k=spline_order)
    Ysmooth[i,:] = spl(tNew)
    ax_modes.plot(tNew, Ysmooth[i,:],'o-',lw=4)
ax_modes.set_title('Smoothed observable modes')
ax_modes.set_xlabel('time')
ax_modes.set_ylabel(r'$Real(y_{obs})$')
ax_modes.spines['top'].set_visible(False)
ax_modes.spines['right'].set_visible(False)

In [None]:
### Sensitivity analysis ###

def calc_SensitivityMat(Wh,net,X_global,nGridpts,nOutputs):
    '''Estimate how strongly each output mode responds to each state variable.

    For every state variable s, the state is held at its mean over the columns of
    X_global, except entry s. Entry s is swept over nGridpts evenly spaced values in
    [mean_s - std_s, mean_s + std_s]. Each swept state is lifted with
    ``net(x)['PsiXf']`` and projected with ``Wh``. The output at the mean state,
    y(x_mean), is subtracted, and the differences are averaged over the sweep.

    Parameters
    ----------
    Wh : (nOutputs, numObs) array projecting lifted observables onto output modes.
    net : callable taking a (batch, nStates) torch tensor and returning a dict whose
          'PsiXf' entry is a (batch, numObs) tensor.
    X_global : (nStates, nSamples) array, one snapshot per column.
    nGridpts : number of perturbation values per state variable.
    nOutputs : number of output modes (rows of Wh).

    Returns
    -------
    (nOutputs, nStates) real array of |sensitivity|, scaled so the largest entry is 1
    (returned unscaled, i.e. all zeros, if every sensitivity is zero).
    '''
    X_global = np.asarray(X_global)
    nStates = X_global.shape[0]
    x_mean = X_global.mean(axis=1)  # reference state, shape (nStates,)
    x_std = X_global.std(axis=1)

    def _outputs(states):
        '''Map a (batch, nStates) array of states to a (nOutputs, batch) array of outputs.'''
        with torch.no_grad():
            psi = net(torch.Tensor(states))['PsiXf'].detach().numpy().T
        return np.dot(Wh, psi)

    y_mean = _outputs(x_mean[None, :])  # (nOutputs, 1): output at the reference state

    # Row s holds the nGridpts perturbation values for state variable s.
    X_range = np.linspace(x_mean - x_std, x_mean + x_std, nGridpts, axis=1)

    S = np.zeros((nOutputs, nStates), dtype=complex)  # sensitivity matrix
    for s in range(nStates):
        # One batch of nGridpts copies of the mean state, with only entry s swept.
        # Assumes net treats rows independently (e.g. eval mode), so a batched call
        # matches nGridpts single-row calls.
        X_sens = np.tile(x_mean, (nGridpts, 1))
        X_sens[:, s] = X_range[s]
        S[:, s] = np.mean(_outputs(X_sens) - y_mean, axis=1)

    # Normalize magnitudes to [0, 1]. np.max on a complex array orders entries by real
    # part (lexicographically), not by magnitude, so take |S| first and guard against 0.
    S_abs = np.abs(S)
    S_max = S_abs.max() if S_abs.size else 0.0
    return S_abs / S_max if S_max > 0 else S_abs

# Sensitivity of the p dominant output modes to every state variable,
# using nGridpts perturbation values per variable.
nGridpts = 100
S = calc_SensitivityMat(Wh, model, X, nGridpts=nGridpts, nOutputs=p)


In [None]:
fig_S, ax_S = plt.subplots(figsize=(20,10))
sn.heatmap(S, cmap='viridis', ax=ax_S)
ax_S.set_xlabel('state index')
ax_S.set_ylabel('output mode')
ax_S.set_title('Normalized output sensitivity')

# 2-norm of each column of S: overall sensitivity of the outputs to each state variable.
colNorms = np.linalg.norm(S, axis=0)

In [None]:
colNorms = np.array(colNorms)
num2sel = min(50, len(colNorms))  # number of most-sensitive states to show
order_desc = np.argsort(colNorms)[::-1]
inds_maxcolNorms = order_desc[:num2sel]   # state indices, most sensitive first
maxcolNorm = colNorms[inds_maxcolNorms]   # their norms, in the same descending order

# plot means of background subtracted data: one panel per selected state (row of X),
# with all trajectories concatenated along the x-axis
ncols = 5
nrows = max(1, math.ceil(num2sel / ncols))
plt.rcParams.update({'font.size': 12})
f1, ax1 = plt.subplots(nrows, ncols, figsize=(20, 4*nrows), squeeze=False)
for panel, state in zip(ax1.flat, inds_maxcolNorms):
    panel.plot(X[state,:],'o--',color='tab:blue')
    panel.set_title(f'state {state}')
    panel.spines['right'].set_visible(False)
    panel.spines['top'].set_visible(False)
for panel in ax1.flat[num2sel:]:
    panel.set_visible(False)  # hide leftover panels when num2sel is not a multiple of ncols

In [None]:
# Show the indices of the num2sel states with the largest sensitivity column norms (largest first).
inds_maxcolNorms