In [7]:
import numpy as np
import numpy.random as npr
import pylab
import curbd
import math
import matplotlib.pyplot as plt
import seaborn as sns


import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import torch.optim as optim
import pickle 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [8]:
sim = torch.load("REAL_SIM.pt")
dtData=sim['params']['dtData'] # time step of training data
# dtData=None
dtFactor=5 # number of interpolation steps for RNN 
# tauRNN=2*sim['params']['tau']/2 # decay constant of RNN 
tauRNN=50 # decay constant of RNN 
dtRNN = 0.002
g=1.5 # instability constant 
tauWN=0.1 #decay constant on filtered white noise inputs
ampInWN=0.01 #input amplitude of filtered white noise 
resetPoints=None
plotStatus=True
regions=None


# Model Definition

In [9]:
import sys 

class CTRNN(nn.Module):
    """Continuous-time RNN."""
    def __init__(self,
                 g,
                 tauRNN,
                 dtRNN,
                 input_units,
                 output_units,
                 dt=None,
                 **kwargs):
      
        super(CTRNN,self).__init__()

        self.tauRNN = tauRNN
        self.dtRNN=dtRNN
        self.g=g
        self.input_units=input_units
        self.output_units=output_units
        self.J = nn.Parameter(self.g * torch.Tensor(npr.randn(input_units, output_units) / math.sqrt(output_units)))

    def forward(self, hidden, noise=None):
        if(noise is not None):
           noise = torch.zeros_like(noise)
        
        """Propogate input through the network."""
        
        # When calculating fixed points: We want to find _hidden_ vectors such that -hidden + JR equals the 0 vector.
        activation = torch.tanh(torch.Tensor(hidden))
        JR = torch.matmul(self.J, activation).reshape((self.input_units,1))
        

        new_hidden = hidden + self.dtRNN*(-hidden + JR)/self.tauRNN
        return activation, new_hidden



In [10]:
class pseudomodel(nn.Module):
    def __init__(self, hidden_size, input_units):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_units = input_units
        self.candidate_hidden = nn.Parameter(10*(torch.rand(hidden_size)-0.5))
        

    def forward(self, J_matrix):
        activation = torch.tanh(torch.Tensor(self.candidate_hidden))
        JR = torch.matmul(J_matrix, activation).reshape((self.input_units, 1))
        
        difference = -self.candidate_hidden + JR
        return difference

    

In [11]:
number_units = 300
rnn_model = CTRNN(g=g,
                  tauRNN=tauRNN,
                  dtRNN=dtRNN,
                  input_units=number_units,
                  output_units=number_units)

rnn_model.to(device)

weights_initial = rnn_model.J.detach().cpu().numpy().copy()


rnn_model.load_state_dict(torch.load("REAL_WEIGHTS_NONOISE.pt"))

<All keys matched successfully>

In [24]:
def find_stable_fps(model, num_to_find=10, max_epochs=200000, epsilon=1e-7):
    fixed_points = []

    for index in range(num_to_find):
        # Initialize random starting point:

        starting_point = 100*(torch.rand(300,1)-0.5)

        print(f"Starting to look for fixed point {index}.")
        for i in range(max_epochs):
            
            new_point = model(starting_point)[1]
            print(torch.sum( (starting_point - new_point)**2 ) )
            if(torch.sum( (starting_point - new_point)**2 ) < epsilon):
                print(f"Found a stable fixed point after {i} epochs")
                fixed_points.append(new_point.detach().clone())
                break
            starting_point = new_point
        
       
    return fixed_points

stable_fixed_points = find_stable_fps(rnn_model)


Starting to look for fixed point 0.
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumBackward0>)
tensor(0.0015, grad_fn=<SumB

KeyboardInterrupt: 

In [23]:
stable_fixed_points[4]

tensor([[ 0.0380],
        [ 0.0380],
        [-0.0380],
        [ 0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [ 0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [ 0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [ 0.0380],
        [ 0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [ 0.0380],
        [ 0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [-0.0380],
        [-0.0380],
        [ 0.0380],
        [-0.0380],
        [ 0.

In [31]:

def find_fixed_points(model, num_to_find=10, epochs=8000):
    fixed_points = []

    training_epochs = epochs

    for index in range(num_to_find):
        loss_function = nn.MSELoss()
        sm = pseudomodel(hidden_size=(300,1), input_units=300)
        optimizer = torch.optim.Adam(sm.parameters(), lr=1e-2)
        
        print(f"Starting to look for fixed point {index}.")
        for i in range(training_epochs):
            diffs = sm(rnn_model.J)
            loss = loss_function(diffs, torch.zeros_like(diffs))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        #   print(i, loss)
        print(f"Finished approximating. The error is {torch.sum((rnn_model(sm.candidate_hidden)[1] - sm.candidate_hidden) ** 2)}. The final loss is {loss}")
        fixed_points.append(sm.candidate_hidden.detach().clone())

    return fixed_points

fixed_points = find_fixed_points(rnn_model)
torch.save(fixed_points, "fixed_points.pt")
            

Starting to look for fixed point 0.
Finished approximating. The error is 8.84524098410111e-08. The final loss is 0.17280635237693787
Starting to look for fixed point 1.
Finished approximating. The error is 3.867171116667123e-08. The final loss is 0.08109264075756073
Starting to look for fixed point 2.
Finished approximating. The error is 4.256128249835456e-08. The final loss is 0.08944143354892731
Starting to look for fixed point 3.
Finished approximating. The error is 5.1106766818520555e-08. The final loss is 0.10792297124862671
Starting to look for fixed point 4.
Finished approximating. The error is 2.8708901922414043e-08. The final loss is 0.059825796633958817
Starting to look for fixed point 5.
Finished approximating. The error is 4.555435495490201e-08. The final loss is 0.0964672863483429
Starting to look for fixed point 6.
Finished approximating. The error is 5.1635012709994044e-08. The final loss is 0.1076478585600853
Starting to look for fixed point 7.
Finished approximating. T

In [4]:
fixed_points = torch.load("REAL_FIXED_POINTS.pt")

In [None]:
def remove_duplicate_fixed_points(fixed_points_array):
    

In [32]:
torch.save(fixed_points, "REAL_FIXED_POINTS.pt")