In [1]:
from rnn import *
from fixed_points import *
from viz import *

In [2]:
batch_size = 1
n_epochs = 1000
bits = 1

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Create task
task = FlipFlopTask(batch_size=batch_size, T=100, dt=1.0, flip_prob=0.05, bits=bits)
U, _ = task.get()

# Create RNN
rnn = RNN(input_size=task.bits, hidden_size=100, output_size=task.bits, tau=2.0, dt=1.0)

In [3]:
# Train the RNN
train_R_list, losses, weight_snapshots = train_rnn(rnn, task, n_epochs=n_epochs, learning_rate=1e-2, batch_size=batch_size)

print(np.array(train_R_list).shape)
train_R_list = torch.cat(train_R_list, dim=0)
train_R_list = list(torch.unbind(train_R_list, dim=0))

print(np.array(train_R_list).shape)

Training RNN on flip flop task...
Epochs: 1000, Learning rate: 0.01, Batch size: 1
------------------------------------------------------------
Epoch 100/1000, Loss: 0.250805
Epoch 200/1000, Loss: 0.050319
Epoch 300/1000, Loss: 0.017270
Epoch 400/1000, Loss: 0.011128
Epoch 500/1000, Loss: 0.003984
Epoch 600/1000, Loss: 0.001163
Epoch 700/1000, Loss: 0.002687
Epoch 800/1000, Loss: 0.000401
Epoch 900/1000, Loss: 0.000508
Epoch 1000/1000, Loss: 0.000227
------------------------------------------------------------
Training complete!
(1000, 1, 101, 100)
(1000, 101, 100)


In [4]:
pca = fit_pca(train_R_list[0])

animate_fixed_points(rnn, train_R_list, pca, save_path="plots/FP_training1.mp4", stride=10, weight_snapshots=weight_snapshots)

Finding fixed points for 1000 epochs...
101 initializations, found 1 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 2 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 0 unique fixed points
101 initializations, found 2 unique fixed points
101 initializations, found 0 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 0 unique fixed points
101 initializations, found 0 unique fixed points
101 initializations, found 2 unique fixed points
101 initializations, found 1 unique fixed points
101 initializations, found 3 unique fixed points
101 initializations, found 2 unique fixed points
101 initializations, found 2 

(PCA(n_components=2),
 (np.float64(-1.5202435553068523),
  np.float64(2.0495178258797067),
  np.float64(-1.8790789823619312),
  np.float64(0.9413575311210696)))

In [12]:
#pca = fit_pca(train_R_list[500])


unique_fps = find_fixed_points_KE_min(rnn, train_R_list[80])
pca, lim = animate_R(train_R_list[80], save_path="plots/train_1.mp4", fps=30, stride=1, fixed_points=unique_fps, title="RNN trajectory, first training", pca=pca)

101 initializations, found 3 unique fixed points
r_batch shape:  (101, 100)
Generating animation with 101 frames...
Saving animation to plots/train_1.mp4...
Animation saved successfully!


In [8]:
unique_fps[0].shape

torch.Size([1, 100])

In [4]:
R_list, _ = inference_rnn(rnn, task)

hidden_states = rnn.get_hidden_states(U)
unique_fps = find_fixed_points_KE_min(rnn, hidden_states)

pca, lim = animate_R(R_list, save_path="plots/inference_1.mp4", fps=30, stride=100, fixed_points=unique_fps, title="RNN trajectory, inference after first training")

Running inference...
------------------------------------------------------------
Epoch 100/1000, Loss: 0.000674
Epoch 200/1000, Loss: 0.001830
Epoch 300/1000, Loss: 0.000162
Epoch 400/1000, Loss: 0.000967
Epoch 500/1000, Loss: 0.000199
Epoch 600/1000, Loss: 0.000209
Epoch 700/1000, Loss: 0.000728
Epoch 800/1000, Loss: 0.001304
Epoch 900/1000, Loss: 0.001206
Epoch 1000/1000, Loss: 0.000272
------------------------------------------------------------
Inference complete! Mean loss: 0.000465
101 initializations, found 101 fixed points
Generating animation with 1010 frames...
Saving animation to plots/inference_1.mp4...
Animation saved successfully!


In [5]:
rnn.reinitialize_weights(weights=['W_in', 'W_out', 'b_out'])

Reinitializing weights: ['W_in', 'W_out', 'b_out']


In [None]:
train_R_list, losses = train_rnn(rnn, task, n_epochs=n_epochs, learning_rate=1e-2, batch_size=batch_size)

hidden_states = rnn.get_hidden_states(U)
unique_fps, stabilities = find_fixed_points(rnn, hidden_states)
pca, lim = animate_R(train_R_list, save_path="plots/train_2.mp4", fps=30, stride=100, fixed_points=unique_fps, title="RNN trajectory, training after reinitialization")

Training RNN on flip flop task...
Epochs: 1000, Learning rate: 0.01, Batch size: 1
------------------------------------------------------------
Epoch 100/1000, Loss: 0.852188
Epoch 200/1000, Loss: 0.856369
Epoch 300/1000, Loss: 0.776592
Epoch 400/1000, Loss: 0.642530
Epoch 500/1000, Loss: 0.185191
Epoch 600/1000, Loss: 0.038812
Epoch 700/1000, Loss: 0.024373
Epoch 800/1000, Loss: 0.021704
Epoch 900/1000, Loss: 0.011636
Epoch 1000/1000, Loss: 0.014084
------------------------------------------------------------
Training complete!
Generating animation with 1010 frames...
Saving animation to train_2.mp4...
Animation saved successfully!


In [None]:
R_list, _ = inference_rnn(rnn, task)

hidden_states = rnn.get_hidden_states(U)
unique_fps, stabilities = find_fixed_points(rnn, hidden_states)

pca, lim = animate_R(R_list, save_path="plots/inference_2.mp4", fps=30, stride=100, fixed_points=unique_fps, title="RNN trajectory, inference after second training")

Running inference...
------------------------------------------------------------
Epoch 100/1000, Loss: 0.005014
Epoch 200/1000, Loss: 0.006800
Epoch 300/1000, Loss: 0.004893
Epoch 400/1000, Loss: 0.010733
Epoch 500/1000, Loss: 0.007764
Epoch 600/1000, Loss: 0.006275
Epoch 700/1000, Loss: 0.007647
Epoch 800/1000, Loss: 0.007318
Epoch 900/1000, Loss: 0.009302
Epoch 1000/1000, Loss: 0.006620
------------------------------------------------------------
Inference complete! Mean loss: 0.008087
Generating animation with 1010 frames...
Saving animation to inference_2.mp4...
Animation saved successfully!
