In [None]:
## Created by Wentinn Liao

# Kalman Filter Research

In [None]:
#@title Mount Google Drive
from google.colab import drive

# Make the user's Google Drive available under /content/gdrive.
drive.mount(mountpoint='/content/gdrive')

In [None]:
#@title Symlink Setup
import os

def ptpp(PATH: str) -> str:
    """Turn a shell-escaped path into a plain Python path by dropping every backslash.

    Example: '/content/gdrive/My\\ Drive' becomes '/content/gdrive/My Drive'.
    """
    return ''.join(ch for ch in PATH if ch != '\\')

DRIVE_PATH = '/content/gdrive/My\ Drive/KF_RNN'
if not os.path.exists(ptpp(DRIVE_PATH)):
    %mkdir $DRIVE_PATH
SYM_PATH = '/content/KF_RNN'
if not os.path.exists(ptpp(SYM_PATH)):
    !ln -s $DRIVE_PATH $SYM_PATH
%cd $SYM_PATH

In [None]:
!pip install numpy imageio matplotlib scikit-learn torch==2.0.0 tensordict

In [None]:
#@title Configure Jupyter Notebook
import matplotlib
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
#@title Library Setup
import numpy as np
import matplotlib.pyplot as plt
from typing import *
from argparse import Namespace
import copy
import itertools
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torch.utils as ptu
import tensordict
from tensordict import TensorDict
from matplotlib.collections import PolyCollection

from model.linear_system import LinearSystem
from model.kf import KF
from model.rnn_kf import RnnKF

from infrastructure import utils
from infrastructure.train import *

# Seeding is currently disabled. To make runs reproducible, seed every RNG that is used,
# e.g. torch.manual_seed(7), plus random.seed(7) after an explicit `import random`
# (`random` is not imported explicitly in this notebook).
# Tensors created without an explicit dtype default to float64 from here on.
torch.set_default_dtype(torch.float64)

# Device string later passed to run_experiments (e.g. 'cuda'). When it is 'xla',
# the TPU/XLA runtime is installed and imported first.
dev_type = 'cuda'
if dev_type == 'xla':
    # NOTE(review): the wheel is pinned to torch_xla 2.0 / CPython 3.10, to match the
    # torch==2.0.0 installed earlier — confirm it matches the Colab runtime's Python.
    !pip install torch-xla cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
    import torch_xla
    import torch_xla.core.xla_model as xm

# Default plot appearance: 7x5 inch figures; images drawn with nearest-neighbour
# interpolation and a gray colormap.
plt.rcParams.update({
    'figure.figsize': (7.0, 5.0),  # set default size of plots
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})

In [None]:
# Small smoke test: sample a tiny stable system, then run the learned filter in its
# default, 'form' and 'form_sqrt' modes on identical random integer-valued inputs and
# print the three state estimates so they can be compared.
S_D, I_D, O_D, SNR = 2, 2, 1, 2.
B, L = 1, 4

system = LinearSystem.sample_stable_system(Namespace(S_D=S_D, I_D=I_D, O_D=O_D, SNR=SNR))
optimal_kf = KF(system)
learned_kf = RnnKF(S_D, I_D, O_D)
# Overwrite the learned gain K with a random (S_D, O_D) parameter.
learned_kf.K = nn.Parameter(torch.randn(S_D, O_D))

# Integers in [-10, 10], stored as float64 tensors.
test_state = torch.randint(-10, 11, (B, S_D), dtype=float)
test_inputs = torch.randint(-10, 11, (B, L, I_D), dtype=float)
test_observations = torch.randint(-10, 11, (B, L, O_D), dtype=float)

# Also available for comparison: system(test_state, test_inputs) and
# optimal_kf(test_state, test_inputs, test_observations).
result1 = learned_kf(test_state, test_inputs, test_observations)
result2 = learned_kf(test_state, test_inputs, test_observations, mode='form')
result3 = learned_kf(test_state, test_inputs, test_observations, mode='form_sqrt')

# For a numeric check, use e.g. torch.norm(result1['state_estimation'] - result2['state_estimation']).
print(
    result1['state_estimation'],
    result2['state_estimation'],
    result3['state_estimation'],
    sep='\n'
)

# Sample Complexity

In [None]:
#@title Model Parameters
# State (S_D), input (I_D) and observation (O_D) dimensions of the systems and the
# learned filter, plus the SNR setting.
ModelArgs = Namespace(**dict(S_D=6, I_D=6, O_D=4, SNR=2.))

In [None]:
#@title Training Parameters
# Sweep grids: ceil(2 ** k) for k on half-integer steps, deduplicated and sorted.
total_trace_lengths = sorted({*torch.pow(2, torch.arange(7., 12.5, 0.5)).ceil().to(int).tolist()})
num_traces = sorted({*torch.pow(2, torch.arange(0., 7.5, 0.5)).ceil().to(int).tolist()})

BaseTrainArgs = Namespace(**{
    # Dataset (the two list-valued entries are the swept dimensions)
    'train_dataset_size': num_traces,
    'valid_dataset_size': 100,
    'total_train_sequence_length': total_trace_lengths,
    'total_valid_sequence_length': 20000,

    # Batch sampling
    'subsequence_length': 10,
    'subsequence_initial_mode': "random",  # one of {"random", "replay_buffer"}
    'sample_efficiency': 5,
    'replay_buffer': 10,
    'batch_size': 128,

    # Optimizer
    'beta': 0.1,
    'lr': 3e-4,
    'momentum': 0.9,
    'lr_decay': 0.95,
    'optim_type': "Adam",                  # one of {"GD", "SGD", "SGDMomentum", "Adam"}
    'l2_reg': 0.1,

    # Iteration
    'iterations_per_epoch': 100,
    'epochs': 100,
})

In [None]:
#@title Experiment Parameters
BaseExperimentArgs = Namespace(**{
    'n_systems': 16,        # also used as the leading batch size when loading saved systems
    'ensemble_size': 1,
    'log_frequency': 5,
    'print_frequency': 20,
})

In [None]:
#@title Experiment Configurations
base_exp_name, output_dir, output_fname = 'SC', 'sample_complexity_strong_convergence', 'result'

# (optimizer name, overrides applied on top of BaseTrainArgs); the name also becomes optim_type.
optim_configs = [
    ('SGDMomentum', dict(lr=1.5e-4)),
    ('Adam', dict(lr=1.5e-2)),
    ('GD', dict(lr=5e-4)),
]
# (name fragment used in experiment names, kwargs forwarded to run_experiments as system_kwargs).
system_configs = [
    ('', dict(fname='systems')),
]

# One sample-complexity sweep per (optimizer, system) configuration pair; results are keyed
# by the (optim_config_name, system_config_name) tuple.
result = {}
for (optim_config_name, optim_config), (system_config_name, system_config) in itertools.product(optim_configs, system_configs):
    # Shallow copies of the base arguments with this configuration's overrides merged in.
    TrainArgs = Namespace(**{**vars(BaseTrainArgs), **optim_config, 'optim_type': optim_config_name})
    ExperimentArgs = Namespace(**{
        **vars(BaseExperimentArgs),
        'exp_name': f'Full{system_config_name}{optim_config_name}{base_exp_name}'
    })

    Args = Namespace(model=ModelArgs, train=TrainArgs, experiment=ExperimentArgs)

    # The two listed training arguments are the swept dimensions.
    result[optim_config_name, system_config_name] = run_experiments(
        Args,
        ['total_train_sequence_length', 'train_dataset_size'],
        dev_type,
        {'dir': output_dir, 'fname': output_fname},
        system_kwargs=system_config,
    )

# Sample Complexity Results

In [None]:
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (20.0, 24.0)

# Normalized-loss landscape over the (total_trace_length, num_traces) sweep: for each
# configuration, a 3D surface, loss vs. total_trace_length and loss vs. num_traces.
loss_type = 'overfit'

tail = 10       # number of final logged losses averaged for each run
threshold = 5.  # normalized losses above this are set to NaN and excluded from the statistics

# The sweep grid and line colours do not depend on the configuration, so build them once.
x_range, y_range = torch.log2(torch.tensor(total_trace_lengths)), torch.log2(torch.tensor(num_traces))
# Explicit 'ij' indexing (the current default) gives meshes of shape
# (len(total_trace_lengths), len(num_traces)), matching mean_nl, and avoids torch's warning.
x_mesh, y_mesh = torch.meshgrid(x_range, y_range, indexing='ij')

log_n, log_t = np.log2(num_traces), np.log2(total_trace_lengths)
# `or 1.` avoids dividing by zero when a sweep axis has a single value.
n_colors = plt.cm.plasma((log_n - np.min(log_n)) / (np.ptp(log_n) or 1.))
t_colors = plt.cm.twilight_shifted((log_t - np.min(log_t)) / (np.ptp(log_t) or 1.))

for (h, (optim_config_name, _)), (w, (system_config_name, _)) in itertools.product(
    enumerate(optim_configs),
    enumerate(system_configs)
):
    exp_name = system_config_name + optim_config_name + base_exp_name
    r = result[optim_config_name, system_config_name]

    # Mean/std of the tail-averaged loss divided by the irreducible loss, after thresholding.
    # utils.remove_nans_and_infs presumably drops the NaN-ed entries — confirm.
    mean_nl = torch.zeros(len(total_trace_lengths), len(num_traces))
    std_nl = torch.zeros(len(total_trace_lengths), len(num_traces))
    for (i, t), (j, n) in itertools.product(enumerate(total_trace_lengths), enumerate(num_traces)):
        str_args = (str(t), str(n))
        il, l = r[str_args]['irreducible_loss'].detach(), r[str_args]['loss'][loss_type].detach()

        nl = torch.mean(l[:, :, -tail:], dim=-1) / il
        nl[nl > threshold] = float('nan')

        converged_nl = utils.remove_nans_and_infs(nl)
        mean_nl[i, j] = torch.mean(converged_nl)
        std_nl[i, j] = torch.std(converged_nl)

    # Plot 3D: wireframe + points of the mean, contours projected onto the y=-1 plane,
    # a reference plane at 1, and mean±std bands drawn at each num_traces slice.
    ax_3d = plt.subplot(H, 3 * W, (3 * W) * h + 3 * w + 1, projection='3d')

    ax_3d.view_init(elev=15., azim=30.)
    ax_3d.plot_wireframe(x_mesh, y_mesh, mean_nl, color='black', linewidth=0.5)
    ax_3d.scatter(x_mesh, y_mesh, mean_nl, s=12, c=mean_nl.flatten(), cmap=plt.cm.YlOrRd_r, alpha=1)
    ax_3d.contour(x_mesh, y_mesh, mean_nl, zdir='y', offset=-1, cmap='plasma')
    ax_3d.plot_surface(x_mesh, y_mesh, torch.ones_like(x_mesh), color='black', alpha=0.2)
    ax_3d.add_collection3d(PolyCollection(
        torch.stack([
            torch.cat([x_mesh, torch.flip(x_mesh, dims=(0,))], dim=0),
            torch.cat([mean_nl + std_nl, torch.flip(mean_nl - std_nl, dims=(0,))], dim=0)
        ], dim=-1).permute(1, 0, 2),
        facecolors=plt.cm.plasma(torch.linspace(0, 1, len(num_traces))),
        alpha=0.5
    ), zs=y_range, zdir='y')

    ax_3d.set_xlabel('total_trace_length')
    ax_3d.set_ylabel('num_traces')
    ax_3d.set_zlim(bottom=0)
    ax_3d.set_zlabel(f'normalized_{loss_type}_loss')
    ax_3d.set_title(f'{exp_name} 3D')

    # Plot 2D: one curve (with ±std band) per num_traces, against total_trace_length.
    ax_2d_n = plt.subplot(H, 3 * W, (3 * W) * h + 3 * w + 2)

    for j, n in enumerate(num_traces):
        m, s = mean_nl[:, j], std_nl[:, j]
        ax_2d_n.plot(
            total_trace_lengths, m,
            linewidth=0.5,
            marker='.',
            color=n_colors[j],
            label=f'num_traces{n}'
        )
        ax_2d_n.fill_between(
            total_trace_lengths, m - s, m + s,
            color=n_colors[j],
            alpha=0.05
        )
    ax_2d_n.plot(total_trace_lengths, torch.ones(len(total_trace_lengths)), color='black', linestyle='--')

    ax_2d_n.set_xlabel('total_trace_length')
    ax_2d_n.set_xscale('log')
    ax_2d_n.set_ylabel(f'normalized_{loss_type}_loss')
    ax_2d_n.set_title(f'{exp_name} 2D')
    ax_2d_n.legend()

    # Plot 2D: one curve per total_trace_length, against num_traces.
    ax_2d_t = plt.subplot(H, 3 * W, (3 * W) * h + 3 * w + 3)

    for i, t in enumerate(total_trace_lengths):
        ax_2d_t.plot(
            num_traces, mean_nl[i],
            linewidth=0.5,
            marker='.',
            color=t_colors[i],
            label=f'total_trace_length{t}'
        )
    ax_2d_t.plot(num_traces, torch.ones(len(num_traces)), color='black', linestyle='--')

    ax_2d_t.set_xlabel('num_traces')
    ax_2d_t.set_xscale('log')
    ax_2d_t.set_ylabel(f'normalized_{loss_type}_loss')
    ax_2d_t.set_title(f'{exp_name} 2D')
    ax_2d_t.legend()

plt.show()

In [None]:
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (20.0, 24.0)

# Same figure layout as the overfit-loss cell above, for the validation loss.
# NOTE(review): the two cells differ only in loss_type; consider a shared plotting function.
loss_type = 'validation'

tail = 10       # number of final logged losses averaged for each run
threshold = 5.  # normalized losses above this are set to NaN and excluded from the statistics

# The sweep grid and line colours do not depend on the configuration, so build them once.
x_range, y_range = torch.log2(torch.tensor(total_trace_lengths)), torch.log2(torch.tensor(num_traces))
# Explicit 'ij' indexing (the current default) gives meshes of shape
# (len(total_trace_lengths), len(num_traces)), matching mean_nl, and avoids torch's warning.
x_mesh, y_mesh = torch.meshgrid(x_range, y_range, indexing='ij')

log_n, log_t = np.log2(num_traces), np.log2(total_trace_lengths)
# `or 1.` avoids dividing by zero when a sweep axis has a single value.
n_colors = plt.cm.plasma((log_n - np.min(log_n)) / (np.ptp(log_n) or 1.))
t_colors = plt.cm.twilight_shifted((log_t - np.min(log_t)) / (np.ptp(log_t) or 1.))

for (h, (optim_config_name, _)), (w, (system_config_name, _)) in itertools.product(
    enumerate(optim_configs),
    enumerate(system_configs)
):
    exp_name = system_config_name + optim_config_name + base_exp_name
    r = result[optim_config_name, system_config_name]

    # Mean/std of the tail-averaged loss divided by the irreducible loss, after thresholding.
    # utils.remove_nans_and_infs presumably drops the NaN-ed entries — confirm.
    mean_nl = torch.zeros(len(total_trace_lengths), len(num_traces))
    std_nl = torch.zeros(len(total_trace_lengths), len(num_traces))
    for (i, t), (j, n) in itertools.product(enumerate(total_trace_lengths), enumerate(num_traces)):
        str_args = (str(t), str(n))
        il, l = r[str_args]['irreducible_loss'].detach(), r[str_args]['loss'][loss_type].detach()

        nl = torch.mean(l[:, :, -tail:], dim=-1) / il
        nl[nl > threshold] = float('nan')

        converged_nl = utils.remove_nans_and_infs(nl)
        mean_nl[i, j] = torch.mean(converged_nl)
        std_nl[i, j] = torch.std(converged_nl)

    # Plot 3D: wireframe + points of the mean, contours projected onto the y=-1 plane,
    # a reference plane at 1, and mean±std bands drawn at each num_traces slice.
    ax_3d = plt.subplot(H, 3 * W, (3 * W) * h + 3 * w + 1, projection='3d')

    ax_3d.view_init(elev=15., azim=30.)
    ax_3d.plot_wireframe(x_mesh, y_mesh, mean_nl, color='black', linewidth=0.5)
    ax_3d.scatter(x_mesh, y_mesh, mean_nl, s=12, c=mean_nl.flatten(), cmap=plt.cm.YlOrRd_r, alpha=1)
    ax_3d.contour(x_mesh, y_mesh, mean_nl, zdir='y', offset=-1, cmap='plasma')
    ax_3d.plot_surface(x_mesh, y_mesh, torch.ones_like(x_mesh), color='black', alpha=0.2)
    ax_3d.add_collection3d(PolyCollection(
        torch.stack([
            torch.cat([x_mesh, torch.flip(x_mesh, dims=(0,))], dim=0),
            torch.cat([mean_nl + std_nl, torch.flip(mean_nl - std_nl, dims=(0,))], dim=0)
        ], dim=-1).permute(1, 0, 2),
        facecolors=plt.cm.plasma(torch.linspace(0, 1, len(num_traces))),
        alpha=0.5
    ), zs=y_range, zdir='y')

    ax_3d.set_xlabel('total_trace_length')
    ax_3d.set_ylabel('num_traces')
    ax_3d.set_zlim(bottom=0)
    ax_3d.set_zlabel(f'normalized_{loss_type}_loss')
    ax_3d.set_title(f'{exp_name} 3D')

    # Plot 2D: one curve (with ±std band) per num_traces, against total_trace_length.
    ax_2d_n = plt.subplot(H, 3 * W, (3 * W) * h + 3 * w + 2)

    for j, n in enumerate(num_traces):
        m, s = mean_nl[:, j], std_nl[:, j]
        ax_2d_n.plot(
            total_trace_lengths, m,
            linewidth=0.5,
            marker='.',
            color=n_colors[j],
            label=f'num_traces{n}'
        )
        ax_2d_n.fill_between(
            total_trace_lengths, m - s, m + s,
            color=n_colors[j],
            alpha=0.05
        )
    ax_2d_n.plot(total_trace_lengths, torch.ones(len(total_trace_lengths)), color='black', linestyle='--')

    ax_2d_n.set_xlabel('total_trace_length')
    ax_2d_n.set_xscale('log')
    ax_2d_n.set_ylabel(f'normalized_{loss_type}_loss')
    ax_2d_n.set_title(f'{exp_name} 2D')
    ax_2d_n.legend()

    # Plot 2D: one curve per total_trace_length, against num_traces.
    ax_2d_t = plt.subplot(H, 3 * W, (3 * W) * h + 3 * w + 3)

    for i, t in enumerate(total_trace_lengths):
        ax_2d_t.plot(
            num_traces, mean_nl[i],
            linewidth=0.5,
            marker='.',
            color=t_colors[i],
            label=f'total_trace_length{t}'
        )
    ax_2d_t.plot(num_traces, torch.ones(len(num_traces)), color='black', linestyle='--')

    ax_2d_t.set_xlabel('num_traces')
    ax_2d_t.set_xscale('log')
    ax_2d_t.set_ylabel(f'normalized_{loss_type}_loss')
    ax_2d_t.set_title(f'{exp_name} 2D')
    ax_2d_t.legend()

plt.show()

# Kalman Filter Eigenvalue Comparison

In [None]:
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (20.0, 12.0)

# Compare eigenvalues of the learned F with those of the true system's F, one figure per
# (system config, optimizer config). Learned eigenvalues are coloured and sized by the
# mean normalized loss of their (total_trace_length, num_traces) run.
loss_type = 'validation'

tail = 10            # number of final logged losses averaged for each run
threshold = 5.       # normalized losses above this are set to NaN and ignored
n_idx, e_idx = 0, 0  # which system (n_systems axis) and ensemble member to plot

for w, (system_config_name, system_config) in enumerate(system_configs):
    # Saved systems stacked into a TensorDict with batch shape (n_systems, ensemble_size).
    # NOTE(review): torch.load unpickles the saved modules; only load trusted files.
    systems = tensordict.utils.expand_right(
        TensorDict(torch.func.stack_module_state(
            torch.load(f'output/{output_dir}/{system_config["fname"]}.pt', map_location='cpu')
        )[0], batch_size=(BaseExperimentArgs.n_systems,)),
        (BaseExperimentArgs.n_systems, BaseExperimentArgs.ensemble_size)
    )
    # The true system's eigenvalues are the same for every optimizer config.
    eig_sys_F = torch.linalg.eig(systems['F'][n_idx, e_idx])[0]

    for h, (optim_config_name, _) in enumerate(optim_configs):
        exp_name = system_config_name + optim_config_name + base_exp_name
        r = result[optim_config_name, system_config_name]

        # Template of RnnKF parameters with batch shape
        # (len(total_trace_lengths), len(num_traces), n_systems, ensemble_size), filled in below.
        learned_kfs = tensordict.utils.expand_right(TensorDict(
            RnnKF(ModelArgs.S_D, ModelArgs.I_D, ModelArgs.O_D).state_dict(),
        batch_size=())[None, None, None, None], (
            len(total_trace_lengths),
            len(num_traces),
            BaseExperimentArgs.n_systems,
            BaseExperimentArgs.ensemble_size
        )).clone()

        mean_nl = torch.zeros(len(total_trace_lengths), len(num_traces))
        for (i, t), (j, n) in itertools.product(enumerate(total_trace_lengths), enumerate(num_traces)):
            learned_kfs[i, j] = r[str(t), str(n), 'learned_kf']

            il, l = r[str(t), str(n), 'irreducible_loss'].detach(), r[str(t), str(n), 'loss'][loss_type].detach()
            nl = torch.mean(l[:, :, -tail:], dim=-1) / il
            nl[nl > threshold] = float('nan')
            mean_nl[i, j] = torch.mean(utils.remove_nans_and_infs(nl))

        # Only the selected system / ensemble member is plotted, so decompose just that slice.
        # (The closed-loop matrix (I - K H) F could be compared in the same way.)
        eig_lkf_F = torch.linalg.eig(learned_kfs['F'][:, :, n_idx, e_idx])[0]

        # One loss value per eigenvalue; larger normalized loss -> much smaller marker.
        c = mean_nl[:, :, None].expand(-1, -1, ModelArgs.S_D)
        plt.scatter(
            torch.real(eig_lkf_F),
            torch.imag(eig_lkf_F),
            s=64 / (c ** 10),
            c=c,
            cmap='plasma',
            label='RnnKF'
        )
        plt.scatter(
            torch.real(eig_sys_F),
            torch.imag(eig_sys_F),
            s=256,
            color='black',
            label='System'
        )

        plt.xlabel('Re(eigenvalue of F)')
        plt.ylabel('Im(eigenvalue of F)')
        plt.title(f'{exp_name} - eigenvalues of F')
        plt.legend()
        plt.show()

In [None]:
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (12.0, 20.0)
# squeeze=False keeps `axs` two-dimensional even when H or W is 1 (W == 1 here), so
# axs[h, w] is always valid.
fig, axs = plt.subplots(H, W, squeeze=False)

for h, (optim_config_name, _) in enumerate(optim_configs):
    for w, (system_config_name, _) in enumerate(system_configs):
        exp_name = system_config_name + optim_config_name + base_exp_name
        ax = axs[h, w]

        # `result` is keyed by (optim_config_name, system_config_name) tuples.
        r = result[optim_config_name, system_config_name]
        # NOTE(review): the sample-complexity cells read losses via r[(t, n)]['loss'][...];
        # confirm these flat keys still exist in the current result layout.
        losses = {k: r[k].detach().cpu() for k in ('training_loss', 'overfit_loss', 'validation_loss')}
        irreducible_loss = r['irreducible_loss'].detach().cpu()[:, :1]

        x = torch.arange(losses['training_loss'].shape[-1], dtype=float)

        ax.plot(x, torch.ones_like(x), linestyle='--', linewidth=0.5, color='black', label='normalized irreducible_loss')
        for lname, loss in losses.items():
            # Average over dim 1, divide each row by its irreducible loss, then take the
            # minimum over dim 0 (the quantity that is actually plotted).
            normalized_loss = torch.mean(loss, dim=1) / irreducible_loss
            min_normalized_loss = torch.min(normalized_loss, dim=0).values
            ax.plot(x, min_normalized_loss, linewidth=0.5, label=f'min normalized {lname}')

        ax.set_xlabel('Batch')
        ax.set_ylabel('Normalized Loss')
        ax.set_title(exp_name)
        # ax.legend()
plt.show()

In [None]:
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (12.0, 20.0)
# squeeze=False keeps `axs` two-dimensional even when H or W is 1 (W == 1 here), so
# axs[h, w] is always valid.
fig, axs = plt.subplots(H, W, squeeze=False)

cutoff = 100  # number of leading logged batches left out of the plot
for h, (optim_config_name, _) in enumerate(optim_configs):
    for w, (system_config_name, _) in enumerate(system_configs):
        exp_name = system_config_name + optim_config_name + base_exp_name
        ax = axs[h, w]

        # `result` is keyed by (optim_config_name, system_config_name) tuples.
        r = result[optim_config_name, system_config_name]
        # NOTE(review): the sample-complexity cells read losses via r[(t, n)]['loss'][...];
        # confirm these flat keys still exist in the current result layout.
        training_loss, overfit_loss, validation_loss, irreducible_loss = (r[k].detach().cpu() for k in (
            'training_loss',
            'overfit_loss',
            'validation_loss',
            'irreducible_loss'
        ))

        x = torch.arange(cutoff, training_loss.shape[-1], dtype=float)

        ax.plot(x, torch.ones_like(x), linestyle='--', linewidth=1., color='black', label='normalized irreducible_loss')
        normalized_overfit_loss = overfit_loss / irreducible_loss[:, :, None]
        normalized_validation_loss = validation_loss / irreducible_loss[:, :, None]

        N = len(irreducible_loss)
        # Every second experiment only, to keep the panel readable; the gap between its
        # overfit and validation curves is shaded. Distinct loop names avoid shadowing the
        # raw loss tensors above.
        for n, (exp_overfit, exp_validation) in list(enumerate(zip(normalized_overfit_loss, normalized_validation_loss)))[::2]:
            mean_overfit_loss = torch.mean(exp_overfit, dim=0)[cutoff:]
            mean_validation_loss = torch.mean(exp_validation, dim=0)[cutoff:]

            # NOTE(review): `color` is not defined in this notebook; presumably it comes from
            # one of the star imports — confirm.
            c = color(n, scale=N)
            ax.plot(x, mean_overfit_loss, linewidth=1., color=c, label=f'Experiment {n}')
            ax.plot(x, mean_validation_loss, linewidth=1., color=c)
            ax.fill_between(x, mean_overfit_loss, mean_validation_loss, alpha=0.1, color=c)

        ax.set_xlabel('Batch')
        ax.set_ylabel('Normalized Loss')
        ax.set_title(f'{exp_name} - Normalized Overfit and Validation Loss')
        # ax.legend()
plt.show()

In [None]:
H, W = len(optim_configs), len(system_configs)
plt.rcParams['figure.figsize'] = (12.0, 20.0)
# squeeze=False keeps `axs` two-dimensional even when H or W is 1 (W == 1 here), so
# axs[h, w] is always valid.
fig, axs = plt.subplots(H, W, squeeze=False)

tail = 10  # number of final logged losses averaged for each run
for h, (optim_config_name, _) in enumerate(optim_configs):
    for w, (system_config_name, _) in enumerate(system_configs):
        exp_name = system_config_name + optim_config_name + base_exp_name
        ax = axs[h, w]

        # `result` is keyed by (optim_config_name, system_config_name) tuples.
        r = result[optim_config_name, system_config_name]
        # NOTE(review): the sample-complexity cells read losses via r[(t, n)]['loss'][...];
        # confirm these flat keys still exist in the current result layout.
        overfit_loss, irreducible_loss = (r[k].detach().cpu() for k in ('overfit_loss', 'irreducible_loss'))

        normalized_overfit_loss = overfit_loss / irreducible_loss[:, :, None]

        # Order rows (presumably one per system) by the first column of the irreducible loss,
        # and summarize each row's tail-averaged normalized overfit loss across dim 1.
        irreducible_, indices = torch.sort(irreducible_loss[:, 0])
        tail_ = torch.mean(normalized_overfit_loss[:, :, -tail:], dim=-1)
        mean_ = torch.mean(tail_, dim=1)[indices]
        std_ = torch.std(tail_, dim=1)[indices]
        min_ = torch.min(tail_, dim=1).values[indices]
        max_ = torch.max(tail_, dim=1).values[indices]

        # Left axis: spread around each row's mean (zero line, min curve, min-max and ±1 std bands).
        # Right axis: the mean itself.
        ax_twin = ax.twinx()
        ax.plot(irreducible_, torch.zeros_like(mean_), linewidth=1., linestyle='--', marker='.', color='black', label='Mean loss')
        ax_twin.plot(irreducible_, mean_, linewidth=1., marker='.', color='blue', label='mean overfit loss')

        ax.plot(irreducible_, min_ - mean_, linewidth=0.5, marker='.', color='turquoise')
        ax.fill_between(irreducible_, min_ - mean_, max_ - mean_, color='aquamarine', alpha=0.2, label='min-max')
        ax.fill_between(irreducible_, -std_, std_, color='aquamarine', alpha=0.5, label='1 std')

        # The x values are the sorted irreducible losses, not normalized losses.
        ax.set_xlabel('Irreducible loss per system')
        ax.set_ylabel('Normalized overfit loss minus mean')
        ax_twin.set_ylabel('Mean normalized overfit loss')
        ax.set_title(f'{exp_name}')
        # ax.legend()
plt.show()