In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt
import seaborn as sn
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
print("Using PyTorch Version %s" %torch.__version__)
import os
from deep_KO_learning import Net

In [None]:
### Loading trained network ###

# os.path.dirname of a bare file name is '', so every path built from script_dir
# below is resolved relative to the notebook's current working directory.
script_dir = os.path.dirname('deep_KO_learning.py') # getting relative path
trained_models_path = os.path.join(script_dir, 'trained_models') # directory holding the saved network files
data_path = os.path.join(script_dir,'data/')

netsize_dir = trained_models_path + '/malathion_polyculture_netsize.pickle' # contains the shape of network
net_dir = trained_models_path+'/malathion_polyculture_net.pt' # contains params of network

# NOTE(review): pickle.load can execute arbitrary code from the file; only load trusted files.
# The context manager closes the file handle instead of leaving it to the garbage collector.
with open(netsize_dir,'rb') as netsize_file:
    NUM_INPUTS,NUM_OUTPUTS,HL_SIZES = pickle.load(netsize_file)

# Rebuild the architecture from the stored sizes, then restore the trained parameters.
model = Net(NUM_INPUTS,NUM_OUTPUTS,HL_SIZES)
# map_location='cpu' lets a checkpoint that was saved from a GPU be restored on a
# CPU-only machine; the freshly constructed model above lives on the CPU anyway.
model.load_state_dict(torch.load(net_dir, map_location='cpu'))
model.eval(); # switch to inference mode

In [None]:
### Loading corresponding dataset ###

# File name (not a directory, despite the variable name) of the pickled dataset;
# it is appended to data_path when the dataset is loaded further down in this cell.
file_dir = 'malathion_polyculture_pfluorescens_TPMs.p'

def get_snapshot_matrices(X,nT,nTraj):
    '''Split a global snapshot matrix into time-shifted (previous, forward) pairs.

    The columns of X are assumed to hold the trajectories one after another,
    each trajectory occupying nT consecutive columns.

    Returns a tuple (Xp, Xf): Xp keeps the first nT-1 columns of every
    trajectory and Xf keeps the last nT-1 columns, so column k of Xf is the
    snapshot that directly follows column k of Xp within the same trajectory.
    '''
    # Column offset of each trajectory. The first trajectory is always
    # included, even when nTraj is smaller than 1.
    starts = [0] + [t * nT for t in range(1, nTraj)]
    prevInds = [s + k for s in starts for k in range(0, nT - 1)]
    forInds = [s + k for s in starts for k in range(1, nT)]
    return X[:, prevInds], X[:, forInds]

# NOTE(review): pickle.load can execute arbitrary code; only open trusted dataset files.
# The context manager guarantees the file handle is closed after loading.
with open(data_path+file_dir,'rb') as data_file:
    X,nT,nTraj = pickle.load(data_file)

# Pair every snapshot with its successor inside the same trajectory.
Xp,Xf = get_snapshot_matrices(X,nT,nTraj)
# Transposed so that each row of the tensors is one snapshot
# (the columns of X are snapshots, per get_snapshot_matrices' docstring).
trainXp = torch.Tensor(Xp.T)
trainXf = torch.Tensor(Xf.T)
testX = torch.Tensor(X.T)

numDatapoints = nT*nTraj # number of total snapshots

print('Dimension of the state: ' + str(trainXp.shape[1]));
print('Number of trajectories: ' + str(nTraj));
print('Number of total snapshots: ' + str(numDatapoints));

In [None]:
# Weight matrix of the last layer in model.linears, as a numpy array
# (presumably the learned linear/Koopman operator — confirm against Net).
K = model.linears[-1].weight.detach().numpy()
# 'PsiXf' entry of the network output for every snapshot, transposed so that
# columns index snapshots (testX has one snapshot per row).
PsiX_test = model(testX)['PsiXf'].detach().numpy().T

In [None]:
import matplotlib.lines as mlines

numStates = NUM_INPUTS
traj = np.random.randint(0,nTraj) # np.random.randint(0,nTraj/2)
init_index = traj*(nT) # column of the chosen trajectory's initial snapshot

# Roll the initial lifted snapshot forward with powers of K over one full trajectory.
predHorizon = nT
PsiX_pred = np.zeros((K.shape[0],predHorizon))
for i in range(0,predHorizon):
    PsiX_pred[:,i:i+1] = np.dot(np.linalg.matrix_power(K,i),PsiX_test[:,init_index:init_index+1]) 

# NOTE: despite the name/label, this is a relative Frobenius-norm error
# (||truth - pred||_F / ||truth||_F), not a mean squared error.
mse = np.linalg.norm(PsiX_test[:,init_index:init_index+predHorizon] - PsiX_pred,'fro')/np.linalg.norm(PsiX_test[:,init_index:init_index+predHorizon],'fro')
print('Trajectory ' + str(traj) + ', MSE: ' + str(round(mse,5)))


# Proxy artists so the legend only distinguishes line style (solid vs dashed).
truthLeg = mlines.Line2D([], [], color='black',linestyle='-',marker='',label='Truth')
predLeg = mlines.Line2D([], [], color='black',linestyle='--',label='Predicted')
if numStates > 20: # just for plotting
    # Too many rows to plot individually: draw a random sample (with replacement).
    numPlots = 20
    plotStates = np.random.randint(1,numStates-1,numPlots)
else:
    # Previously plotStates was left undefined here, raising NameError for
    # small systems. Plot every state row instead.
    # Assumes the states occupy rows 1..numStates of PsiX (same row range the
    # per-gene error loop uses) — TODO confirm against Net.forward.
    numPlots = numStates
    plotStates = np.arange(1,numStates+1)
for i in plotStates:
    plt.figure();
    plt.plot(PsiX_test[i,init_index:init_index+predHorizon],'.-',ms=10,lw=3,color='tab:blue');
    plt.plot(PsiX_pred[i,0:predHorizon],'.--',ms=10,lw=3,color='tab:orange');
    plt.ylabel(r'$\mathbf{x}$'+str(i))
    plt.legend(handles=[truthLeg,predLeg]);

In [None]:
### A better prediction calculation ###

# Predict every trajectory from its own initial lifted snapshot:
# column (i*nT + j) of PsiX_pred is K^j applied to trajectory i's first column.
PsiX_pred = np.zeros((K.shape[0],numDatapoints))
# Time index (0..nT-1) of every column; not used in this cell, kept because
# other cells may read it.
trajInds = list(range(0,nT))*nTraj
count = 0
initInd = 0 # not used in this cell; kept because other cells may read it
# K^j does not depend on the trajectory, so compute each power once instead of
# once per trajectory. The resulting predictions are numerically identical.
K_powers = [np.linalg.matrix_power(K,j) for j in range(0,nT)]
for i in range(0,nTraj):
    psix_test_ic = PsiX_test[:,i*nT:i*nT+1] # initial condition, kept as a column vector
    for j in range(0,nT):
        PsiX_pred[:,count:count+1] = np.dot(K_powers[j],psix_test_ic)
        count += 1

In [None]:
### storing the relative prediction error for each gene (row) ###
# Each entry is ||pred_row - truth_row||_2 / ||truth_row||_2, i.e. a normalized
# 2-norm error rather than a literal mean squared error (name kept for other cells).
# Rows 1..NUM_INPUTS are used; presumably row 0 is not a gene — TODO confirm against Net.forward.
per_gene_mse = []
for k in range(1,NUM_INPUTS+1):
    dist = np.linalg.norm(PsiX_pred[k,:] - PsiX_test[k,:],ord=2)/np.linalg.norm(PsiX_test[k,:],ord=2)
    # An all-zero truth row makes the ratio inf (nonzero/0) or NaN (0/0);
    # both are mapped to 0 so they do not distort later statistics and plots.
    if not np.isfinite(dist):
        dist = 0
    per_gene_mse.append(dist)


In [None]:
# Mean of every row of X across all snapshots.
total_mu = np.mean(X,axis=1)

# Main axes: per-gene error next to the magnitude of the row means.
# Inset axes (upper right): the error curve alone, zoomed on a narrow y-range.
fig, ax1 = plt.subplots();
left, bottom, width, height = [0.65, 0.6, 0.2, 0.2] # inset position in figure coordinates
ax2 = fig.add_axes([left, bottom, width, height]);
ax1.plot(per_gene_mse,label='per-gene relative error');
ax1.plot(np.abs(total_mu),label='|row mean of X|');
ax1.set_xlabel('index');
ax1.legend(loc='upper left'); # upper left keeps the legend clear of the inset
ax2.plot(per_gene_mse);
# NOTE(review): hard-coded zoom window for the inset; adjust if the error scale changes.
ax2.set_ylim([0.00005,0.0006]);


In [None]:
# Unnormalized inner products between predicted rows and true rows
# (a Gram-style matrix; no mean-centering or scaling is applied).
corr = PsiX_pred @ PsiX_test.T
plt.figure(figsize=(7, 5));
# NOTE(review): the 1..122 window is hard-coded; presumably it is tied to the
# number of states in this dataset — confirm before reusing on other data.
window = slice(1, 123)
sn.heatmap(corr[window, window], cmap='coolwarm');

In [None]:
# Eigenvalues of K in the complex plane, drawn against the unit circle.
theta = np.linspace(0,2*math.pi,100) # angles used to trace the unit circle
eigs = np.linalg.eigvals(K) # computed once and reused for both coordinates
plt.figure(figsize=(6,5));
plt.plot(np.real(eigs),np.imag(eigs),'o',ms=10);
plt.plot(np.cos(theta),np.sin(theta),color='black');
# Raw strings: '\l' is not a valid escape sequence in a normal string literal
# and triggers a warning in recent Python versions. The label text is unchanged.
plt.ylabel(r'$Imag(\lambda)$');
plt.xlabel(r'$Real(\lambda)$');
plt.axis('equal'); # equal aspect so the unit circle is not distorted