In [1]:
from rnn import *
from fixed_points import *
from viz import *
from utils import *

In [2]:
# Experiment configuration reused by the cells below:
#   batch_size - passed to both FlipFlopTask and train_rnn
#   n_epochs   - number of epochs for the first training run
#   bits       - passed to FlipFlopTask; the RNN's input/output sizes follow task.bits
batch_size = 1
n_epochs = 1000
bits = 1

# Set random seed for reproducibility.
# Both the torch and the NumPy global generators are seeded, and this happens before the
# task and the RNN are constructed, so statement order in this cell matters.
# NOTE(review): `torch` and `np` are not imported explicitly in this notebook; presumably
# they are re-exported by one of the star imports in the first cell - confirm.
torch.manual_seed(42)
np.random.seed(42)

# Create task
task = FlipFlopTask(batch_size=batch_size, T=100, dt=1.0, flip_prob=0.05, bits=bits)
# Draw one sample from the task; only the first element of the returned pair is kept.
# NOTE(review): U is not used in the cells visible here. This call may also advance the
# random state, so removing it could change downstream results - verify before deleting.
U, _ = task.get()

# Create RNN: input_size and output_size both equal task.bits, hidden_size=100, built with use_bias=False.
rnn = RNN(input_size=task.bits, hidden_size=100, output_size=task.bits, tau=2.0, dt=1.0, use_bias=False)

In [3]:
# Train the RNN (first run): n_epochs epochs at learning rate 1e-2 on `task`.
# train_rnn returns three values, kept here as train_R_list, losses and weight_snapshots;
# weight_snapshots is later handed to animate_fixed_points.
# In the recorded output of this cell, R_list has shape (1000, 1, 101, 100) and
# process_R_list turns it into (1000, 101, 100) - presumably (epoch, time, hidden unit)
# with the size-1 batch axis dropped; TODO confirm against process_R_list.
train_R_list, losses, weight_snapshots = train_rnn(rnn, task, n_epochs=n_epochs, learning_rate=1e-2, batch_size=batch_size)
train_R_list = process_R_list(train_R_list)

Training RNN on flip flop task...
Epochs: 1000, Learning rate: 0.01, Batch size: 1
------------------------------------------------------------
Epoch 100/1000, Loss: 0.250805
Epoch 200/1000, Loss: 0.050319
Epoch 300/1000, Loss: 0.017270
Epoch 400/1000, Loss: 0.011128
Epoch 500/1000, Loss: 0.003984
Epoch 600/1000, Loss: 0.001163
Epoch 700/1000, Loss: 0.002687
Epoch 800/1000, Loss: 0.000401
Epoch 900/1000, Loss: 0.000508
Epoch 1000/1000, Loss: 0.000227
------------------------------------------------------------
Training complete!
Shape of R_list: (1000, 1, 101, 100)
Shape of R_list after processing: (1000, 101, 100)


In [4]:
# Animate fixed points over training, using PCA fitted on all fixed points found across training.
# The animation is written to plots/FP_training1_fullpca.mp4; the return value is discarded.
# With stride=10 the recorded run produced 100 frames for the 1000 training epochs.
# weight_snapshots is passed alongside rnn - presumably so that fixed points for each epoch
# are computed with that epoch's weights rather than the final ones; TODO confirm in
# animate_fixed_points.
_ = animate_fixed_points(rnn, train_R_list, save_path="plots/FP_training1_fullpca.mp4", stride=10, weight_snapshots=weight_snapshots)

Finding fixed points for 1000 epochs...
  Epoch 100/1000: 1 fixed points
  Epoch 200/1000: 3 fixed points
  Epoch 300/1000: 3 fixed points
  Epoch 400/1000: 3 fixed points
  Epoch 500/1000: 3 fixed points
  Epoch 600/1000: 3 fixed points
  Epoch 700/1000: 3 fixed points
  Epoch 800/1000: 3 fixed points
  Epoch 900/1000: 3 fixed points
  Epoch 1000/1000: 3 fixed points
Generating animation with 100 frames...
Saving animation to plots/FP_training1_fullpca.mp4...
Animation saved successfully!


In [5]:
# # NOTE: this is incorrect because it uses the current (final) RNN weights to find fixed points, rather than the weights the network had when train_R_list[80] was recorded
# unique_fps = find_fixed_points_KE_min(rnn, train_R_list[80])
# pca, lim = animate_R(train_R_list[80], save_path="plots/train_1.mp4", fps=30, stride=1, fixed_points=unique_fps, title="RNN trajectory, first training", pca=pca)

In [6]:
# animate fixed points over inference
# fixed points should not change during inference, but this is a good sanity check that they don't
# R_list, _, weight_snapshots = inference_rnn(rnn, task)
# R_list = process_R_list(R_list)
# _ = animate_fixed_points(rnn, R_list, save_path="plots/FP_inference1_fullpca.mp4", stride=10, weight_snapshots=weight_snapshots)

In [7]:
# Reinitialize only the weights named in the list before the second training run.
# Parameters not listed are presumably left at their trained values - confirm in
# RNN.reinitialize_weights.
# NOTE(review): the RNN above was constructed with use_bias=False, yet 'b_out' is listed
# here. The recorded run did not raise, but check whether this entry has any effect.
rnn.reinitialize_weights(weights=['W_in', 'W_out', 'b_out'])

Reinitializing weights: ['W_in', 'W_out', 'b_out']


In [8]:
# Train the RNN again (second run) on the same task: 300 epochs at learning rate 1e-2.
# This runs after the reinitialize_weights cell, so run order matters.
# NOTE: this overwrites train_R_list, losses and weight_snapshots from the first run;
# the first run's values are no longer available after this cell executes.
train_R_list, losses, weight_snapshots = train_rnn(rnn, task, n_epochs=300, learning_rate=1e-2, batch_size=batch_size)
train_R_list = process_R_list(train_R_list)

# Animate fixed points across the second run and save to plots/FP_training2_fullpca.mp4.
# stride=1 here, whereas the first run's animation used stride=10.
_ = animate_fixed_points(rnn, train_R_list, save_path="plots/FP_training2_fullpca.mp4", stride=1, weight_snapshots=weight_snapshots)

Training RNN on flip flop task...
Epochs: 300, Learning rate: 0.01, Batch size: 1
------------------------------------------------------------
Epoch 100/300, Loss: 0.127564
Epoch 200/300, Loss: 0.065321
Epoch 300/300, Loss: 0.015510
------------------------------------------------------------
Training complete!
Shape of R_list: (300, 1, 101, 100)
Shape of R_list after processing: (300, 101, 100)
Finding fixed points for 300 epochs...
  Epoch 10/300: 3 fixed points
  Epoch 20/300: 3 fixed points
  Epoch 30/300: 3 fixed points
  Epoch 40/300: 3 fixed points
  Epoch 50/300: 1 fixed points
  Epoch 60/300: 1 fixed points
  Epoch 70/300: 1 fixed points
  Epoch 80/300: 3 fixed points
  Epoch 90/300: 3 fixed points
  Epoch 100/300: 3 fixed points
  Epoch 110/300: 3 fixed points
  Epoch 120/300: 3 fixed points
  Epoch 130/300: 1 fixed points
  Epoch 140/300: 3 fixed points
  Epoch 150/300: 3 fixed points
  Epoch 160/300: 3 fixed points
  Epoch 170/300: 3 fixed points
  Epoch 180/300: 3 fixed po