In [None]:
## Created by Wentinn Liao

# Kalman Filter Research

In [None]:
#@title Mount Google Drive
# Mount the user's Google Drive inside the Colab VM at /content/gdrive.
# The symlink-setup cell below expects the project folder under
# '/content/gdrive/My Drive/KF_RNN', so this cell has to run first.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#@title Symlink Setup
import os

def ptpp(PATH: str) -> str:
    """Turn a shell-escaped path into a plain Python path.

    Every backslash in ``PATH`` is dropped, so ``'My\\ Drive'`` becomes
    ``'My Drive'``. Strings without backslashes are returned unchanged.
    """
    # Splitting on the backslash and re-joining removes all occurrences.
    return ''.join(PATH.split('\\'))

DRIVE_PATH = '/content/gdrive/My\ Drive/KF_RNN'
if not os.path.exists(ptpp(DRIVE_PATH)):
    %mkdir $DRIVE_PATH
SYM_PATH = '/content/KF_RNN'
if not os.path.exists(ptpp(SYM_PATH)):
    !ln -s $DRIVE_PATH $SYM_PATH
%cd $SYM_PATH

In [None]:
!pip install numpy imageio matplotlib ipympl scikit-learn torch==2.0.0 tensordict

In [None]:
#@title Configure Jupyter Notebook
# The top-level matplotlib package is imported because a later cell calls
# matplotlib.cm.inferno directly.
import matplotlib
# Render figures as static images inside the notebook output.
%matplotlib inline
# Reload all imported modules before each cell executes, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
#@title Library Setup
import numpy as np
import matplotlib.pyplot as plt
from typing import *
from argparse import Namespace
import copy
import itertools
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torch.utils as ptu
import tensordict
from tensordict import TensorDict
from matplotlib.collections import PolyCollection

from model.linear_system import LinearSystem
from model.analytical_kf import AnalyticalKF
from model.rnn_kf import RnnKF
from model.cnn_kf import CnnKF

from infrastructure import utils
from infrastructure.train import *


torch.set_default_dtype(torch.double)

dev_type = 'cpu'
if dev_type == 'xla':
    !pip install torch-xla cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
    import torch_xla
    import torch_xla.core.xla_model as xm

plt.rcParams['figure.figsize'] = (7.0, 5.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

In [None]:
# Smoke test: run one small RnnKF in its three forward modes on the same
# random integer-valued data and print the resulting state estimates.
modelArgs = Namespace(S_D=2, I_D=2, O_D=1, SNR=2.)
B, L = 1, 4

system = LinearSystem.sample_stable_system(modelArgs)
# NOTE(review): `KF` is not imported explicitly in the Library Setup cell
# (only `AnalyticalKF` is); presumably it arrives via
# `from infrastructure.train import *` -- confirm.
optimal_kf = KF(system)
learned_kf = RnnKF(modelArgs)

# Integer values in [-10, 10] stored as floats; drawn in the order
# state -> inputs -> observations.
test_state, test_inputs, test_observations = (
    torch.randint(-10, 11, shape, dtype=float)
    for shape in (
        (B, modelArgs.S_D),
        (B, L, modelArgs.I_D),
        (B, L, modelArgs.O_D),
    )
)

# print(system(test_state, test_inputs))
# print(optimal_kf(test_state, test_inputs, test_observations))
result1 = learned_kf(test_state, test_inputs, test_observations)
result2 = learned_kf(test_state, test_inputs, test_observations, mode='form')
result3 = learned_kf(test_state, test_inputs, test_observations, mode='form_sqrt')

# print(torch.norm(result1['state_estimation'] - result2['state_estimation']))
# print(torch.norm(result1['observation_estimation'] - result2['observation_estimation']))
print(
    result1['state_estimation'],
    result2['state_estimation'],
    result3['state_estimation'],
    sep='\n'
)

# Infinite Impulse Response

In [None]:
#@title Model Parameters
# Shared model configuration for every experiment below.
ModelArgs = Namespace(**{
    'S_D': 6,     # state dimension
    'I_D': 6,     # input dimension
    'O_D': 3,     # observation dimension
    'SNR': 2.,    # presumably the signal-to-noise ratio of sampled systems -- TODO confirm
})

In [None]:
#@title Training Parameters
# Baseline training settings; the experiment loop copies this namespace and
# overrides individual fields (e.g. lr, optim_type) per optimizer config.
BaseTrainArgs = Namespace(**{
    # Dataset
    'train_dataset_size': 1,
    'valid_dataset_size': 100,
    'total_train_sequence_length': 2048,
    'total_valid_sequence_length': 20000,

    # Batch sampling
    'subsequence_length': 10,
    'subsequence_initial_mode': "random",   # one of {"random", "replay_buffer"}
    'sample_efficiency': 5,
    'replay_buffer': 10,
    'batch_size': 128,

    # Optimizer
    'beta': 0.1,
    'lr': 3e-4,
    'momentum': 0.9,
    'lr_decay': 0.95,
    'optim_type': "Adam",                   # one of {"GD", "SGD", "SGDMomentum", "Adam"}
    'l2_reg': 0.1,

    # Iteration
    'iterations_per_epoch': 100,
    'epochs': 100,
})

In [None]:
#@title Experiment Parameters
# Baseline experiment settings; copied per configuration by the experiment loop.
BaseExperimentArgs = Namespace()
BaseExperimentArgs.n_systems = 4          # number of systems (leading batch axis of the stacked systems)
BaseExperimentArgs.ensemble_size = 8      # learned models per system
BaseExperimentArgs.log_frequency = 5
BaseExperimentArgs.print_frequency = 20

In [None]:
#@title Experiment Configurations
base_exp_name = 'IIR'
output_dir = 'infinite_impulse_response_backup'
output_fname = 'result'

# Each entry is (optimizer name, overrides applied on top of BaseTrainArgs).
optim_configs = [
    ('SGDMomentum', {'lr': 1.5e-4}),
    ('Adam', {'lr': 1.5e-2}),
    ('GD', {'lr': 5.e-4}),
]
# Each entry is (name suffix, keyword arguments forwarded as `system_kwargs`).
system_configs = [
    ('', {'fname': 'systems'}),
]

# Results are keyed by the tuple (optimizer name, system config name).
result = {}
for optim_config_name, optim_config in optim_configs:
    for system_config_name, system_config in system_configs:
        # Shallow copies keep the Base* namespaces untouched between runs.
        TrainArgs = copy.copy(BaseTrainArgs)
        vars(TrainArgs).update(optim_config)
        TrainArgs.optim_type = optim_config_name

        ExperimentArgs = copy.copy(BaseExperimentArgs)
        ExperimentArgs.exp_name = f'Full{system_config_name}{optim_config_name}{base_exp_name}'

        Args = Namespace(model=ModelArgs, train=TrainArgs, experiment=ExperimentArgs)

        output_kwargs = {
            'dir': output_dir,
            'fname': output_fname
        }
        result[optim_config_name, system_config_name] = run_experiments(
            Args, [], dev_type, output_kwargs, system_kwargs=system_config
        )

# Eigenvalue Comparison

In [None]:
# Scatter the eigenvalues of (I - K H) F for the learned filters of one system
# against those of the true system, one figure per optimizer configuration.
plt.rcParams['figure.figsize'] = (20.0, 12.0)

loss_type = 'validation'    # entry of r['loss'] used for the normalized loss
tail = 10                   # trailing entries along the last loss axis that are averaged
threshold = 5.              # normalized losses above this are masked to NaN
n_idx = 0                   # index of the system whose eigenvalues are plotted
I = torch.eye(ModelArgs.S_D)

for system_config_name, system_config in system_configs:
    # NOTE: torch.load unpickles arbitrary Python objects -- only load trusted files.
    systems = torch.load(f'output/{output_dir}/{system_config["fname"]}.pt', map_location='cpu')
    ensembled_systems = TensorDict(torch.func.stack_module_state(systems)[0], batch_size=(BaseExperimentArgs.n_systems,))

    # True-system eigenvalues do not depend on the optimizer, so compute them
    # once per system config. They are detached because stacked module state
    # may require grad, and matplotlib cannot convert such tensors to numpy.
    eig_sys_M = torch.linalg.eig(
        (I - ensembled_systems['K'] @ ensembled_systems['H']) @ ensembled_systems['F']
    )[0][n_idx].detach()

    for optim_config_name, _ in optim_configs:
        exp_name = system_config_name + optim_config_name + base_exp_name
        r = result[optim_config_name, system_config_name].to('cpu')

        learned_kfs = r['learned_kf']

        # Tail-averaged loss normalized by the irreducible loss; masked to NaN
        # wherever it exceeds `threshold`.
        il, l = r['irreducible_loss'].detach(), r['loss'][loss_type].detach()
        nl = torch.mean(l[:, :, -tail:], dim=-1) / il
        nl[nl > threshold] = float('nan')

        # diverged = torch.sum(torch.isnan(learned_kfs['F']) + torch.isinf(learned_kfs['F']), dim=(-2, -1))[n_idx]
        # print(torch.sum(diverged))
        eig_lkf_M = torch.linalg.eig(((I - learned_kfs['K'] @ learned_kfs['H']) @ learned_kfs['F'])[n_idx])[0]

        print(nl[n_idx])
        # One colour/size value per eigenvalue: each model's normalized loss is
        # repeated S_D times. Marker size shrinks rapidly as the loss grows.
        c = nl[n_idx, :, None].expand(-1, ModelArgs.S_D)
        plt.scatter(
            torch.real(eig_lkf_M).detach(),
            torch.imag(eig_lkf_M).detach(),
            s=128 / (c ** 10),
            c=c,
            cmap='inferno',
            label='RnnKF'
        )

        plt.scatter(
            torch.real(eig_sys_M),
            torch.imag(eig_sys_M),
            s=256,
            color='black',
            label='System'
        )

        plt.title(f'{exp_name} eigenvalues')
        plt.legend()
        plt.show()

# plt.show()

# Infinite Impulse Response Error

In [None]:
from mpl_toolkits.mplot3d import Axes3D

from google.colab import output
output.disable_custom_widget_manager()
# %matplotlib widget

# Compare the impulse responses of the learned filters with those of the true
# Kalman filters. Every (optimizer, system) pair gets two adjacent panels:
# a 2-D projection of the response error and its norm over time.
H, W = len(optim_configs), 2 * len(system_configs)
plt.rcParams['figure.figsize'] = (12.0, 18.0)

n_idx = 1           # system whose per-model error projections are scattered
iir_length = 16     # number of time steps in each impulse response
n_models = BaseExperimentArgs.n_systems * BaseExperimentArgs.ensemble_size

# Setup impulses: zero initial state and zero inputs; observation channel i
# receives a unit impulse at t = 0 in the i-th impulse sequence.
X_initial_iir = torch.zeros(1, ModelArgs.O_D, ModelArgs.S_D)
X_input_iir = torch.zeros(1, ModelArgs.O_D, iir_length, ModelArgs.I_D)
y_iir = torch.cat([
    torch.eye(ModelArgs.O_D)[None, :, None, :],
    torch.zeros(1, ModelArgs.O_D, iir_length - 1, ModelArgs.O_D)
], dim=2)

# Per-time-step marker colour and size for the projection scatter.
c = matplotlib.cm.inferno(torch.linspace(0, 1, iir_length) ** 0.5)
s = torch.pow(2, torch.linspace(8, 4, iir_length))

for w, (system_config_name, system_config) in enumerate(system_configs):
    # NOTE: torch.load unpickles arbitrary Python objects -- only load trusted files.
    systems = torch.load(f'output/{output_dir}/{system_config["fname"]}.pt', map_location='cpu')

    kfs = [KF(linear_system) for linear_system in systems]
    ensembled_kfs = TensorDict(torch.func.stack_module_state(kfs)[0], batch_size=(BaseExperimentArgs.n_systems,))

    # Impulse response for true KF. It depends only on the systems, so it is
    # computed once per system config rather than once per optimizer.
    base_KF = KF(systems[0]).to(dev_type)
    def run_kf(kf_dicts, state, inputs, observations):
        return torch.func.functional_call(base_KF, kf_dicts, (state, inputs, observations), {'steady_state': True})

    with torch.no_grad():
        kf_iir = torch.func.vmap(run_kf)(
            dict(ensembled_kfs),
            X_initial_iir.expand(BaseExperimentArgs.n_systems, -1, -1),
            X_input_iir.expand(BaseExperimentArgs.n_systems, -1, -1, -1),
            y_iir.expand(BaseExperimentArgs.n_systems, -1, -1, -1)
        )['observation_estimation']

    for h, (optim_config_name, _) in enumerate(optim_configs):
        exp_name = system_config_name + optim_config_name + base_exp_name
        r = result[optim_config_name, system_config_name].to('cpu')

        # Impulse response for learned KF
        base_RnnKF = RnnKF(ModelArgs).to(dev_type)
        def run_lkf(kf_dicts, state, inputs, observations):
            return torch.func.functional_call(base_RnnKF, kf_dicts, (state, inputs, observations))

        with torch.no_grad():
            lkf_iir = torch.func.vmap(run_lkf)(
                dict(r['learned_kf'].flatten()),
                X_initial_iir.expand(n_models, -1, -1),
                X_input_iir.expand(n_models, -1, -1, -1),
                y_iir.expand(n_models, -1, -1, -1)
            )['observation_estimation'].view(
                BaseExperimentArgs.n_systems,
                BaseExperimentArgs.ensemble_size,
                ModelArgs.O_D,
                iir_length,
                ModelArgs.O_D
            )

        iir_d = lkf_iir - kf_iir[:, None]   # [N x E x I x L x O_D]

        # Projections of IIR error.
        # The panel index advances by 2 per system config so the two panels of
        # neighbouring configs never land on the same subplot.
        ax = plt.subplot(H, W, W * h + 2 * w + 1)

        # torch.svd returns V (not V^H): its columns are the right singular
        # vectors, so multiplying by the first two columns projects the errors
        # onto their two dominant directions.
        reshaped_iir_d = iir_d.flatten(1, -2)
        _, _, V = torch.svd(reshaped_iir_d, some=True)
        iir_d_reduced = (reshaped_iir_d @ V[:, :, :2]).unflatten(1, (
            BaseExperimentArgs.ensemble_size,
            ModelArgs.O_D,
            iir_length
        ))

        # One scatter per (impulse channel, ensemble member).
        for impulse in range(ModelArgs.O_D):
            for l in range(BaseExperimentArgs.ensemble_size):
                ax.scatter(
                    *iir_d_reduced[n_idx, l, impulse].T,
                    c=c,
                    s=s
                )

        ax.set_xlim(left=-0.15, right=0.15)
        ax.set_ylim(bottom=-0.15, top=0.15)
        # No legend here: none of the scatters carries a label.
        ax.set_title(f'{exp_name} Error Projection')

        # Frobenius norm of IIR error (over impulse and output axes), averaged
        # over the ensemble, plotted per system against time step.
        ax = plt.subplot(H, W, W * h + 2 * w + 2)

        iir_d_fnorm = torch.mean(torch.norm(iir_d, dim=(2, -1)), dim=1)
        for n in range(BaseExperimentArgs.n_systems):
            ax.plot(iir_d_fnorm[n], marker='.', label=f'sys{n}')
        ax.legend()
        ax.set_title(f'{exp_name} Error Frobenius Norm')

plt.show()

In [None]:
# Shape check: an unbatched (C, H, W) = (3, 100, 100) input convolved with a
# single 3-channel 5x5 kernel.
Fn.conv2d(
    input=torch.randn(3, 100, 100),
    weight=torch.randn(1, 3, 5, 5),
).shape

In [None]:
# Small hand-checkable convolution: a 5x5 integer image with a 2x2 integer kernel.
M1 = torch.randint(0, 4, (5, 5))
M2 = torch.randint(0, 4, (2, 2))
print(M1)
print(M2)
# randint yields int64 tensors; conv2d is cast to the default floating dtype
# because integer dtypes may not be implemented for it on CPU -- the values
# stay exact for these small integers.
conv_dtype = torch.get_default_dtype()
print(Fn.conv2d(M1[None].to(conv_dtype), M2[None, None].to(conv_dtype)))

In [None]:
# Keep the first three columns of a 5x7 zero matrix (result is 5x3).
torch.index_select(torch.zeros(5, 7), 1, torch.arange(3))

In [None]:
# Scratch check that CnnKF returns observation estimates for a batch of sequences.
ir_length = 256
# Named `ir_filter` (not `filter`) so the Python builtin is not shadowed for
# the rest of the notebook. The first tap is fixed to zero and the remaining
# ir_length - 1 taps are random.
# NOTE(review): the tensor is not used below, and wrapping part of it in
# nn.Parameter before torch.cat does not make the concatenated result a parameter.
ir_filter = torch.cat([
    torch.zeros(ModelArgs.O_D, ModelArgs.O_D, 1),
    nn.Parameter(torch.randn(ModelArgs.O_D, ModelArgs.O_D, ir_length - 1))
], dim=-1)

B, L = 5, 2048
inputs = torch.randn(B, L, ModelArgs.I_D)
observations = torch.randn(B, L, ModelArgs.O_D)

print(observations.shape)
print(CnnKF(ModelArgs)(None, inputs, observations)['observation_estimation'].shape)

# Fn.conv2d(ir_filter, observations[:, :, None], padding=(0, L)).squeeze(0).shape

In [None]:
# NOTE(review): this cell held only a dangling `from`, which is a SyntaxError
# and aborts "Run All". The intended import is unknown; if it is still needed,
# add it to the Library Setup cell.

In [None]:
# Normalized overfit / validation loss curves, one panel per
# (optimizer, system) configuration, showing every second system.
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (12.0, 20.0)
# squeeze=False keeps `axs` two-dimensional even when H or W is 1, so that
# axs[h, w] is always valid.
fig, axs = plt.subplots(H, W, squeeze=False)

cutoff = 100    # leading entries of each loss curve that are not plotted
for h, (optim_config_name, _) in enumerate(optim_configs):
    for w, (system_config_name, _) in enumerate(system_configs):
        ax = axs[h, w]
        exp_name = system_config_name + optim_config_name + base_exp_name

        # `result` is keyed by the tuple (optimizer name, system config name).
        r = result[optim_config_name, system_config_name]

        # NOTE(review): other cells read losses as r['loss'][<type>]; confirm
        # that these flat '*_loss' keys still exist in the result.
        training_loss, overfit_loss, validation_loss, irreducible_loss = (r[k].detach().cpu() for k in (
            'training_loss',
            'overfit_loss',
            'validation_loss',
            'irreducible_loss'
        ))

        x = torch.arange(cutoff, training_loss.shape[-1], dtype=float)

        # After normalization the irreducible loss is the constant line y = 1.
        ax.plot(x, torch.ones_like(x), linestyle='--', linewidth=1., color='black', label='normalized irreducible_loss')
        normalized_overfit_loss = overfit_loss / irreducible_loss[:, :, None]
        normalized_validation_loss = validation_loss / irreducible_loss[:, :, None]

        N = len(irreducible_loss)
        # Every second system; the loop variables get their own names so the
        # full loss tensors above are not overwritten.
        for n, (system_overfit, system_validation) in list(enumerate(zip(normalized_overfit_loss, normalized_validation_loss)))[::2]:
            mean_overfit_loss = torch.mean(system_overfit, dim=0)[cutoff:]
            mean_validation_loss = torch.mean(system_validation, dim=0)[cutoff:]

            # NOTE(review): `color` is not defined in this notebook; presumably
            # it comes from a star import -- confirm.
            c = color(n, scale=N)
            ax.plot(x, mean_overfit_loss, linewidth=1., color=c, label=f'Experiment {n}')
            ax.plot(x, mean_validation_loss, linewidth=1., color=c)
            ax.fill_between(x, mean_overfit_loss, mean_validation_loss, alpha=0.1, color=c)

        ax.set_xlabel('Batch')
        ax.set_ylabel('Normalized Loss')
        ax.set_title(f'{exp_name} - Normalized Overfit Loss')
        # ax.legend()
plt.show()

In [None]:
# Spread of the tail-averaged normalized overfit loss across each system's
# ensemble, with systems ordered by their irreducible loss.
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (12.0, 20.0)
# squeeze=False keeps `axs` two-dimensional even when H or W is 1, so that
# axs[h, w] is always valid.
fig, axs = plt.subplots(H, W, squeeze=False)

tail = 10   # trailing entries of each loss curve that are averaged
for h, (optim_config_name, _) in enumerate(optim_configs):
    for w, (system_config_name, _) in enumerate(system_configs):
        ax = axs[h, w]
        exp_name = system_config_name + optim_config_name + base_exp_name

        # `result` is keyed by the tuple (optimizer name, system config name).
        r = result[optim_config_name, system_config_name]

        # NOTE(review): other cells read losses as r['loss'][<type>]; confirm
        # that these flat '*_loss' keys still exist in the result.
        overfit_loss, irreducible_loss = (r[k].detach().cpu() for k in (
            'overfit_loss',
            'irreducible_loss'
        ))

        normalized_overfit_loss = overfit_loss / irreducible_loss[:, :, None]

        # Sort systems by irreducible loss and reorder every statistic to match.
        irreducible_, indices = torch.sort(irreducible_loss[:, 0])
        tail_ = torch.mean(normalized_overfit_loss[:, :, -tail:], dim=-1)
        mean_ = torch.mean(tail_, dim=1)[indices]
        std_ = torch.std(tail_, dim=1)[indices]
        min_ = torch.min(tail_, dim=1).values[indices]
        max_ = torch.max(tail_, dim=1).values[indices]

        # Left axis: deviations from the ensemble mean; right (twin) axis: the mean itself.
        axs_twinx = ax.twinx()
        ax.plot(irreducible_, torch.zeros_like(mean_), linewidth=1., linestyle='--', marker='.', color='black', label='Mean loss')
        axs_twinx.plot(irreducible_, mean_, linewidth=1., marker='.', color='blue', label='mean overfit loss')

        ax.plot(irreducible_, min_ - mean_, linewidth=0.5, marker='.', color='turquoise')
        # ax.plot(mean_, max_ - mean_, linewidth=0.5, marker='.', color='turquoise')

        ax.fill_between(irreducible_, min_ - mean_, max_ - mean_, color='aquamarine', alpha=0.2, label='min-max')
        ax.fill_between(irreducible_, -std_, std_, color='aquamarine', alpha=0.5, label='1 std')

        # NOTE(review): the x values plotted are the sorted irreducible losses;
        # confirm this label matches the intended quantity.
        ax.set_xlabel('Normalized mean loss per system')
        ax.set_title(f'{exp_name}')
        # ax.legend()
plt.show()