In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Report the installed PyTorch build alongside the results of this run.
print("Using PyTorch Version %s" % (torch.__version__,))

In [None]:
### Datasets ###

# Each dataset pickle contains the global snapshot matrix as well as the number
# of snapshots per trajectory and the number of trajectories.
dataset = 5

# Prefix that is concatenated verbatim with the file name when the pickle is
# opened, so include a trailing separator (e.g. 'data/'). The empty string
# means the current working directory.
data_path = ''

# add new datasets as below
DATASET_FILES = {
    0: 'toggle_switch_data.p',
    1: 'toggle_switch_data_normed.p',
    2: 'stable_linsys.p',
    3: 'slow_manifold_data.p',
    4: 'slow_manifold_data_normed.p',
    5: 'malathion_polyculture_pfluorescens_TPMs.p',
}

# Fail here, with a readable message, instead of leaving file_dir undefined
# (or stale from an earlier run) and crashing later at load time.
if dataset not in DATASET_FILES:
    raise ValueError('Unknown dataset id ' + repr(dataset) + '; expected one of ' + str(sorted(DATASET_FILES)))
file_dir = DATASET_FILES[dataset]


In [None]:
def get_snapshot_matrices(X,nT,nTraj):
    '''Split the global snapshot matrix into time-shifted "past" and "future" matrices.

    Assumes the trajectories are placed one after another in the columns of X,
    each occupying nT consecutive columns.

    Parameters
    ----------
    X : 2-D array indexable as X[:, list_of_columns]
    nT : number of snapshots per trajectory
    nTraj : number of trajectories

    Returns
    -------
    (Xp, Xf) : column k of Xf is the snapshot that directly follows column k of
        Xp within the same trajectory; the last snapshot of a trajectory is
        never paired with the first snapshot of the next one.
    '''
    # The first trajectory is always indexed, even if nTraj < 1.
    first_columns = [traj * nT for traj in range(max(nTraj, 1))]
    past_cols = [start + step for start in first_columns for step in range(nT - 1)]
    future_cols = [col + 1 for col in past_cols]
    return X[:, past_cols], X[:, future_cols]

In [None]:
# Load (global snapshot matrix, snapshots per trajectory, number of trajectories).
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
# The `with` block closes the file handle as soon as the data has been read.
with open(data_path+file_dir,'rb') as data_file:
    X,nT,nTraj = pickle.load(data_file)

Xp,Xf = get_snapshot_matrices(X,nT,nTraj)

# Transpose so that each row is one snapshot (samples along the first axis).
trainXp = torch.Tensor(Xp.T)
trainXf = torch.Tensor(Xf.T)
testX = torch.Tensor(X.T)

print('Dimension of the state: ' + str(trainXp.shape[1]))
print('Number of trajectories: ' + str(nTraj))
print('Number of total snapshots: ' + str(nT*nTraj))

In [None]:
### Neural network parameters ###

# dimension of input (one column of trainXp per state variable)
NUM_INPUTS = trainXp.shape[1]
# number of hidden layers (excludes the input and output layers)
NUM_HL = 8
# number of nodes per hidden layer (number of learned observables)
NODES_HL = 8
# widths of the NUM_HL + 1 ReLU layers handed to Net
HL_SIZES = [NODES_HL] * (NUM_HL + 1)
# output layer takes in 1 (constant) + dimension of input + dimension of the last hl
NUM_OUTPUTS = 1 + NUM_INPUTS + HL_SIZES[-1]
# mini-batch size (alternative that was tried: int(nT/10))
BATCH_SIZE = 2

In [None]:
class Net(nn.Module):
    '''Stack of ReLU layers followed by one bias-free linear map on a lifted state.

    The lifted state is the concatenation [1, x, h(x)], where h(x) is the output
    of the ReLU stack. The last entry of `self.linears` is a square, bias-free
    linear layer that is applied to that lifted state.

    Parameters
    ----------
    input_dim : width of the input vectors x
    output_dim : width of the lifted state; must equal
        1 + input_dim + hl_sizes[-1] (or 1 + 2*input_dim if hl_sizes is empty)
    hl_sizes : widths of the successive ReLU layers

    Raises
    ------
    ValueError : if output_dim does not match the width of the lifted state.
    '''

    def __init__(self, input_dim, output_dim, hl_sizes):
        super().__init__()
        hl_sizes = list(hl_sizes)

        # forward() concatenates a constant, the raw input and the last ReLU
        # output, so the final square layer must have exactly this width.
        lifted_dim = 1 + input_dim + (hl_sizes[-1] if hl_sizes else input_dim)
        if output_dim != lifted_dim:
            raise ValueError('output_dim must be ' + str(lifted_dim) + ' (1 + input_dim + last layer width), got ' + str(output_dim))

        current_dim = input_dim
        self.linears = nn.ModuleList()
        for hl_dim in hl_sizes:
            self.linears.append(nn.Linear(current_dim, hl_dim))
            current_dim = hl_dim
        self.linears.append(nn.Linear(output_dim, output_dim,bias=False))

    def forward(self, x):
        '''Return {'PsiXf': lifted state [1, x, h(x)], 'KPsiXp': final linear layer applied to it}.

        x is a 2-D batch with one sample per row.
        '''
        input_vecs = x
        for layer in self.linears[:-1]:
            x = F.relu(layer(x))
        # new_ones inherits dtype and device from x, so the concatenation also
        # works for double-precision or GPU inputs.
        ones = x.new_ones((x.shape[0], 1))
        y = torch.cat((ones,input_vecs,x),dim=1)
        x = self.linears[-1](y)
        return {'KPsiXp':x,'PsiXf':y}

# Seed PyTorch's RNG so the random weight initialisation (and with it the whole
# training run) is repeatable after Restart Kernel -> Run All.
torch.manual_seed(0)
net = Net(NUM_INPUTS,NUM_OUTPUTS,HL_SIZES)
print(net)

In [None]:
# Loss function and optimizer configuration

LEARNING_RATE = 0.05
MOMENTUM = 0.00
L2_REG = 0.0  # handed to SGD as weight_decay

# Mean-squared error between the two lifted batches.
loss_func = nn.MSELoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=L2_REG,
)

In [None]:
# Train the network
print_less_often = 200  # print and record the loss every this many epochs
eps = 1e-100  # tolerance for the (currently disabled) loss-change stopping test
train_loss = []  # epoch-mean losses, sampled every `print_less_often` epochs
maxEpochs = 100000
prev_loss = 0
curr_loss = 1e10
epoch = 0
numDatapoints = nT*nTraj  # total number of snapshots; reused by a later cell
numPairs = trainXp.shape[0]  # number of (past, future) snapshot pairs
if numPairs == 0:
    raise ValueError('No snapshot pairs to train on (trainXp is empty).')
net.train()
# Strict `<` so that exactly maxEpochs epochs are run and the final report
# below is not a repeat of the last in-loop report under a different label.
while epoch < maxEpochs: # and (np.abs(curr_loss-prev_loss) > eps):
    prev_loss = curr_loss
    epoch_loss_sum = 0.0
    for i in range(0,numPairs,BATCH_SIZE):
        xp_batch = trainXp[i:i+BATCH_SIZE]
        xf_batch = trainXf[i:i+BATCH_SIZE]

        # Final linear layer applied to the lifted past state should match the
        # lifted future state; gradients flow through both network passes.
        Kpsixp = net(xp_batch)['KPsiXp']
        psixf = net(xf_batch)['PsiXf']
        loss = loss_func(psixf, Kpsixp)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not skew the mean.
        epoch_loss_sum += loss.item()*xp_batch.shape[0]
    # Report the mean over the whole epoch rather than the last mini-batch only.
    curr_loss = epoch_loss_sum/numPairs
    if epoch % print_less_often == 0:
        print('['+str(epoch)+']'+' loss = '+str(curr_loss))
        train_loss.append(curr_loss)
    epoch+=1
print('['+str(epoch)+']'+' loss = '+ str(curr_loss))

In [None]:
# Put the network in evaluation mode and expose the weight matrix of the final
# bias-free linear layer as a NumPy array (it shares memory with the parameter).
net.eval()
K = net.linears[-1].weight.detach().numpy()


In [None]:
# quick test: K applied to the lifted state at time t should be close to the
# lifted state at time t+1. Row nT-2 of the pair matrices is the last snapshot
# pair of the first trajectory.
# `testXp` / `testXf` are not defined anywhere in this notebook, so the pair
# matrices built at load time (trainXp / trainXf) are used instead.
with torch.no_grad():  # inspection only, no gradients needed
    Kpsixp_test = net(trainXp[nT-2:nT-1])['KPsiXp']
    psixf_test = net(trainXf[nT-2:nT-1])['PsiXf']
print(Kpsixp_test)
print(psixf_test)


In [None]:
# Lift every snapshot in testX and transpose so that each column is one snapshot.
PsiX_test = net(testX)['PsiXf'].detach().numpy().T

In [None]:
import matplotlib.lines as mlines

# Number of original state variables (rows of the global snapshot matrix X).
# In the lifted state, row 0 is the constant 1 and rows 1..numStates are the
# original states, followed by the learned observables.
numStates = X.shape[0]
traj = 4 #np.random.randint(0,nTraj) # np.random.randint(0,nTraj/2)
if not 0 <= traj < nTraj:
    raise ValueError('traj must be in [0, ' + str(nTraj) + '), got ' + str(traj))
init_index = traj*(nT)

# Propagate the lifted initial condition of the chosen trajectory with powers of K.
predHorizon = nT
PsiX_true = PsiX_test[:,init_index:init_index+predHorizon]
PsiX_pred = np.zeros((K.shape[0],predHorizon))
for i in range(0,predHorizon):
    PsiX_pred[:,i:i+1] = np.dot(np.linalg.matrix_power(K,i),PsiX_test[:,init_index:init_index+1])

# NOTE: despite the name/label this is the relative Frobenius-norm error of the
# prediction, not a mean squared error.
mse = np.linalg.norm(PsiX_true - PsiX_pred,'fro')/np.linalg.norm(PsiX_true,'fro')
print('Trajectory ' + str(traj) + ', MSE: ' + str(round(mse,5)))

if numStates > 20: # just for plotting: show a random sample of the states
    numPlots = 20
    plotStates = np.random.randint(1,numStates-1,numPlots)
else: # few enough states to plot every one of them
    plotStates = np.arange(1,numStates+1)

# Legend handles must exist before the plotting loop uses them.
truthLeg = mlines.Line2D([], [], color='black',linestyle='-',marker='',label='Truth')
predLeg = mlines.Line2D([], [], color='black',linestyle='--',label='Predicted')
for i in plotStates:
    plt.figure();
    plt.plot(PsiX_true[i,:],'.-',ms=10,lw=3,color='tab:blue');
    plt.plot(PsiX_pred[i,0:predHorizon],'.--',ms=10,lw=3,color='tab:orange');
    plt.ylabel(r'$\mathbf{x}$'+str(i))
    plt.legend(handles=[truthLeg,predLeg]);
# plt.savefig('repr_preds_traj'+str(traj)+'.pdf')

In [None]:
### A better prediction calculation ###

# Roll every trajectory forward from its own lifted initial condition:
# column i*nT + j of PsiX_pred is K^j applied to the first snapshot of trajectory i.
numCols = nT*nTraj
PsiX_pred = np.zeros((K.shape[0],numCols))
psix_test_ics = PsiX_test[:,0:numCols:nT]  # first snapshot of each trajectory
for j in range(0,nT):
    # K^j is the same for every trajectory, so compute it once per time step
    # and apply it to all initial conditions at the same time.
    PsiX_pred[:,j:numCols:nT] = np.dot(np.linalg.matrix_power(K,j),psix_test_ics)

In [None]:
### storing the relative prediction error for each gene (row) ###
# Rows 1..numGenes of the lifted state are the original state variables (row 0
# is the constant 1), so the number of rows of X gives the number of genes.
# (`trainData` is not defined anywhere in this notebook.)
numGenes = X.shape[0]
per_gene_mse = []
for k in range(1,numGenes+1):
    denom = np.linalg.norm(PsiX_test[k,:],ord=2)
    if denom == 0:
        # An all-zero row has no meaningful relative error; record 0 instead
        # of dividing by zero (which would give inf or nan).
        dist = 0
    else:
        dist = np.linalg.norm(PsiX_pred[k,:] - PsiX_test[k,:],ord=2)/denom
    per_gene_mse.append(dist)


In [None]:
# Per-gene error next to the magnitude of each state's mean over all snapshots,
# with a zoomed inset that shows the error curve alone.
total_mu = np.mean(X,axis=1)

fig, ax1 = plt.subplots();
# inset position in figure coordinates
left, bottom, width, height = 0.65, 0.6, 0.2, 0.2
ax2 = fig.add_axes([left, bottom, width, height]);

# same error curve on the main axes and on the inset
for axis in (ax1, ax2):
    axis.plot(per_gene_mse);
ax1.plot(np.abs(total_mu));

# zoom the inset onto the small-error range
ax2.set_ylim([0.00005,0.0006]);



In [None]:
# Matrix of inner products between predicted rows and measured rows of the
# lifted state (not normalised, so not a correlation coefficient).
corr = np.dot(PsiX_pred,PsiX_test.T)
plt.figure(figsize=(7,5));
# `sn` is never imported in this notebook, so the heat map is drawn with
# matplotlib, which is already imported.
# NOTE(review): the 1:123 bounds are hard-coded -- presumably the state rows of
# one particular dataset; confirm before switching datasets.
plt.imshow(corr[1:123,1:123],cmap='coolwarm',aspect='auto');
plt.colorbar();

In [None]:
import math
# Eigenvalues of K in the complex plane, drawn against the unit circle.
eigvals_K = np.linalg.eigvals(K)  # computed once and reused for both axes
theta = np.linspace(0,2*math.pi,100)
plt.figure(figsize=(6,5));
plt.plot(np.real(eigvals_K),np.imag(eigvals_K),'o',ms=10);
plt.plot(np.cos(theta),np.sin(theta),color='black');
# Raw strings: '\l' is not a valid escape sequence in an ordinary string literal.
plt.ylabel(r'$Imag(\lambda)$');
plt.xlabel(r'$Real(\lambda)$');
plt.axis('equal');
# plt.savefig('toggleswitch_eigvals.pdf')