In [None]:
## Created by Wentinn Liao

# Kalman Filter Research

In [None]:
#@title Mount Google Drive
# Mount Google Drive at /content/gdrive so the project folder under "My Drive"
# (used by the symlink cell below) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#@title Symlink Setup
import os

def ptpp(PATH: str) -> str:
    """Convert a shell-escaped path into a plain Python path.

    Every backslash in ``PATH`` is dropped, so the shell form
    ``/content/gdrive/My\\ Drive/KF_RNN`` becomes ``/content/gdrive/My Drive/KF_RNN``.
    """
    return ''.join(ch for ch in PATH if ch != '\\')

# Shell-escaped Drive path: the backslash escapes the space in "My Drive". It is a raw
# string so Python does not parse "\ " as an (invalid) escape sequence; the value is
# unchanged. Use ptpp(DRIVE_PATH) wherever a plain filesystem path is needed.
DRIVE_PATH = r'/content/gdrive/My\ Drive/KF_RNN'
# Create the project folder on Drive on first run. os.mkdir (not makedirs) is used so a
# missing parent, e.g. Drive not mounted, raises instead of silently writing to local disk.
if not os.path.exists(ptpp(DRIVE_PATH)):
    os.mkdir(ptpp(DRIVE_PATH))
# Space-free alias of the Drive folder.
SYM_PATH = '/content/KF_RNN'
# lexists (unlike exists) also detects a dangling symlink, which would still block
# creating a new link at SYM_PATH.
if not os.path.lexists(SYM_PATH):
    os.symlink(ptpp(DRIVE_PATH), SYM_PATH)
# Work from the project root; presumably this is what lets the later `model` /
# `infrastructure` imports resolve.
os.chdir(SYM_PATH)

In [None]:
!pip install numpy imageio matplotlib scikit-learn torch tensordict

In [None]:
#@title Configure Jupyter Notebook
import matplotlib
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Automatically reload all imported modules (including edited local project code)
# before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
#@title Library Setup
import numpy as np
import matplotlib.pyplot as plt
from typing import *
from argparse import Namespace
import copy
import time
import math
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torch.utils as ptu
import tensordict
from tensordict import TensorDict

from model.linear_system import LinearSystem
from model.kf import KF
from model.rnn_kf import RnnKF

from infrastructure import utils
from infrastructure.train import *

import random

# Optional global seed. Leave as None for an unseeded run, or set an int (e.g. 7) to seed
# Python's `random`, NumPy's legacy global RNG and PyTorch's global RNG.
# NOTE(review): which of these RNGs the project code draws from is not visible here.
seed = None
if seed is not None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

# Make float64 the default floating dtype for every tensor created from here on.
torch.set_default_dtype(torch.double)

# Accelerator selector. NOTE(review): in the visible cells `dev_type` only gates the XLA
# install below; no model or tensor is moved to this device here.
dev_type = 'cuda'
if dev_type == 'xla':
    # Install and import torch_xla for TPU runtimes. The wheel URL is pinned to
    # torch_xla 2.0 / CPython 3.10; NOTE(review): verify it matches the runtime.
    !pip install torch-xla cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
    import torch_xla
    import torch_xla.core.xla_model as xm

# Notebook-wide matplotlib defaults.
plt.rcParams.update({
    'figure.figsize': (7.0, 5.0),   # set default size of plots
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})

In [None]:
# Sanity check on a tiny problem: the 'form' and 'form_sqrt' evaluation modes of RnnKF
# should reproduce the result of the default mode for the same filter and inputs.
S_D, I_D, O_D, SNR = 2, 2, 1, 2.
B, L = 1, 4

system = LinearSystem.sample_stable_system(Namespace(
    S_D = S_D,
    I_D = I_D,
    O_D = O_D,
    SNR = SNR
))
optimal_kf = KF(system)
learned_kf = RnnKF(S_D, I_D, O_D)
# Overwrite the filter's K with a random (S_D, O_D) parameter.
learned_kf.K = nn.Parameter(torch.randn(S_D, O_D))

# Small integer-valued test data in [-10, 10], stored as float64.
test_state = torch.randint(-10, 11, (B, S_D), dtype=torch.double)
test_inputs = torch.randint(-10, 11, (B, L, I_D), dtype=torch.double)
test_observations = torch.randint(-10, 11, (B, L, O_D), dtype=torch.double)

result1 = learned_kf(test_state, test_inputs, test_observations)
result2 = learned_kf(test_state, test_inputs, test_observations, mode='form')
result3 = learned_kf(test_state, test_inputs, test_observations, mode='form_sqrt')

# Quantify agreement instead of eyeballing raw tensors: values near zero mean the
# alternative mode matches the default-mode result.
for mode, result in (('form', result2), ('form_sqrt', result3)):
    for key in ('state_estimation', 'observation_estimation'):
        deviation = torch.norm(result[key] - result1[key]).item()
        print(f'{mode} vs default, {key}: {deviation:.3e}')

# Sample Complexity

In [None]:
#@title Model Parameters
# Dimensions (S_D: state, I_D: input, O_D: observation) and SNR setting passed to
# LinearSystem.sample_stable_system and RnnKF further down.
ModelArgs = Namespace(S_D=6, I_D=6, O_D=4, SNR=2.0)

In [None]:
#@title Training Parameters
def _ceil_pow2_grid(max_exponent):
    """Return the sorted, de-duplicated values ceil(2**e) for e = 0, 0.5, 1, ..., max_exponent."""
    half_step_exponents = torch.arange(0., max_exponent + 0.5, 0.5)
    rounded_up = torch.ceil(torch.pow(2, half_step_exponents)).to(int)
    return sorted(set(rounded_up.tolist()))


# Sweep grids on a half-power-of-two scale: trace lengths 1 ... 4096, trace counts 1 ... 64.
total_trace_lengths = _ceil_pow2_grid(12)
num_traces = _ceil_pow2_grid(6)

BaseTrainArgs = Namespace(
    # Dataset (the train-side sizes hold the sweep grids above)
    train_dataset_size=num_traces,
    valid_dataset_size=100,
    total_train_sequence_length=total_trace_lengths,
    total_valid_sequence_length=20000,

    # Batch sampling
    subsequence_length=10,
    subsequence_initial_mode="random",      # {"random", "replay_buffer"}
    sample_efficiency=5,
    replay_buffer=10,
    batch_size=128,

    # Optimizer
    beta=0.1,
    lr=3e-4,
    momentum=0.9,
    lr_decay=0.99,
    optim_type="Adam",                      # {"GD", "SGD", "SGDMomentum", "Adam"}
    l2_reg=0.1,

    # Iteration
    iterations_per_epoch=100,
    epochs=20,
)

In [None]:
# Objects for the timing benchmark below; these rebind the small test objects from the
# sanity-check cell. NOTE(review): `system` is not used by the visible benchmark cells.
system = LinearSystem.sample_stable_system(ModelArgs)
# RNN Kalman filter with ModelArgs' state/input/observation dimensions.
kf = RnnKF(ModelArgs.S_D, ModelArgs.I_D, ModelArgs.O_D)
# Optimizer (stepped by the benchmark) and scheduler built from BaseTrainArgs;
# get_optimizer presumably comes from the `infrastructure.train` star import.
# NOTE(review): `scheduler` is unused in the visible cells.
optim, scheduler = get_optimizer(kf.parameters(), BaseTrainArgs)

In [None]:
# Benchmark one training step (forward, backward, optimizer step) of each RnnKF mode
# across trace lengths. The batch size is scaled inversely with the trace length, so each
# batch holds roughly B_ * total_trace_lengths[-1] time steps.
n_iter = 10     # timed iterations per (mode, length) configuration
n_warmup = 1    # untimed iterations run first to absorb one-off setup costs
B_ = 16         # batch size used at the longest trace length


def _sync_device() -> None:
    """Block until queued CUDA work finishes so timings include asynchronous kernels (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


average_times = dict()
for mode, L in itertools.product(['serial', 'form', 'form_sqrt'], total_trace_lengths):
    B = int(math.ceil(B_ * total_trace_lengths[-1] / L))
    # Sample the batch once, outside the timed region, so only the training step is measured.
    test_state = torch.randn((B, ModelArgs.S_D))
    test_inputs = torch.randn((B, L, ModelArgs.I_D))
    test_observations = torch.randn((B, L, ModelArgs.O_D))

    for it in range(n_warmup + n_iter):
        if it == n_warmup:
            # Warm-up finished: start the clock (monotonic, high resolution).
            _sync_device()
            start_t = time.perf_counter_ns()
        result = kf(test_state, test_inputs, test_observations, mode=mode)['observation_estimation']
        loss = torch.norm(result)

        optim.zero_grad()
        loss.backward()
        optim.step()
    _sync_device()
    end_t = time.perf_counter_ns()
    avg_t = 1e-6 * (end_t - start_t) / n_iter   # ns -> ms per iteration
    print(f'Length {L} mode {mode}: {avg_t}')
    average_times.setdefault(mode, []).append(avg_t)

In [None]:
# Per-step runtime of each timed mode versus trace length (log-log). The batch size is not
# fixed: it shrinks as traces grow, so every point processes about
# B_ * total_trace_lengths[-1] time steps per batch.
fig, ax = plt.subplots()
for mode, times in average_times.items():
    ax.plot(total_trace_lengths, times, marker='.', label=mode)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Trace length')
ax.set_ylabel('Time (ms)')
ax.set_title(f'Time per training step (~{B_ * total_trace_lengths[-1]} time steps per batch)')
ax.legend()
plt.show()