In [1]:
## Created by Wentinn Liao

# Kalman Filter Research

In [None]:
#@title Mount Google Drive
# Colab-only cell (it imports google.colab): mounts Google Drive at /content/gdrive,
# the directory under which the Symlink Setup cell's DRIVE_PATH is located.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#@title Symlink Setup
import os
def ptpp(PATH: str) -> str:
    """Convert a shell-escaped path into a plain Python path.

    Every backslash character in ``PATH`` is dropped, so a shell-escaped
    space (backslash followed by a space) becomes a plain space.
    """
    # Splitting on the backslash and re-joining with nothing removes all of them.
    pieces = PATH.split('\\')
    return ''.join(pieces)

DRIVE_PATH = '/content/gdrive/My\ Drive/KF_RNN'
if not os.path.exists(ptpp(DRIVE_PATH)):
    %mkdir $DRIVE_PATH
SYM_PATH = '/content/KF_RNN'
if not os.path.exists(ptpp(SYM_PATH)):
    !ln -s $DRIVE_PATH $SYM_PATH
%cd $SYM_PATH

In [2]:
!pip install numpy imageio matplotlib ipympl scikit-learn torch==2.0.0 pytorch-ignite tensordict

[0m[31mERROR: Could not find a version that satisfies the requirement torch==2.0.0 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for torch==2.0.0[0m[31m
[0m

In [2]:
#@title Configure Jupyter Notebook
import matplotlib
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Load the autoreload extension and use mode 2 so edited modules (e.g. the project's
# model/ and infrastructure/ packages imported below) are reloaded before code runs.
%load_ext autoreload
%autoreload 2

In [2]:
#@title Library Setup
import collections
import numpy as np
import matplotlib.pyplot as plt
from typing import *
from argparse import Namespace
import copy
import itertools
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torch.utils as ptu
import tensordict
from tensordict import TensorDict
from matplotlib.collections import PolyCollection

from model.linear_system import LinearSystem
from model.kf import KF
from model.analytical_kf import AnalyticalKF
from model.rnn_kf import *
from model.cnn_kf import *

from infrastructure import utils, loader, validate
from infrastructure.settings import device, dtype
from infrastructure.train import *
from infrastructure.experiment import *

# Use the dtype from infrastructure.settings as torch's default floating-point dtype.
torch.set_default_dtype(dtype)

# NOTE(review): `device` is imported from infrastructure.settings; this branch only runs
# if it compares equal to the plain string 'xla' — confirm its type.
if device == 'xla':
    # NOTE(review): the wheel filename is tagged torch_xla-2.0 / cp310 — confirm it matches
    # the runtime's Python and torch versions.
    !pip install torch-xla cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
    import torch_xla
    import torch_xla.core.xla_model as xm

# Notebook-wide matplotlib defaults; the plotting functions below override figure.figsize.
plt.rcParams['figure.figsize'] = (7.0, 5.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

TypeError: unsupported operand type(s) for |: 'torch._C._TensorMeta' and 'type'

In [None]:
# Sanity check: build one sampled system plus every filter variant and run them on a
# random trace to confirm the forward passes work and to eyeball the outputs.
# S_D, I_D and O_D are used below as the last dimension of the 'state', 'input' and
# 'observation' tensors respectively.
modelArgs = Namespace(
    S_D = 3,
    I_D = 3,
    O_D = 2,
    SNR = 2.,
    ir_length = 2,
    input_enabled = False
)
# B is the batch size of the trace; L is the second dimension of the input/observation tensors.
B, L = 1, 16

system = LinearSystem.sample_stable_system(modelArgs).to(device)
optimal_kf = AnalyticalKF(system).to(device)
rnn_kf = RnnKF(modelArgs).to(device).train()
cnn_kf = CnnKF(modelArgs).to(device).train()
cnn_kf_lstsq = CnnKFLeastSquares(modelArgs).to(device).train()

# Random trace shared by all models so their outputs can be compared side by side.
test_trace = TensorDict({
    'state': torch.randn((B, modelArgs.S_D)),
    'input': torch.randn((B, L, modelArgs.I_D)),
    'observation': torch.randn((B, L, modelArgs.O_D))
}, batch_size=(B,), device=device)

# Overwrite the least-squares CNN's impulse responses with the values returned by its
# own initialization routine before evaluating it.
lstsq, err = cnn_kf_lstsq._least_squares_initialization(test_trace)
cnn_kf_lstsq.input_IR.data = lstsq['input_IR']
cnn_kf_lstsq.observation_IR.data = lstsq['observation_IR']

# Run the RNN filter in its three modes without tracking gradients. no_grad() restores
# the previous grad mode on exit, so grad mode is not forced on globally afterwards.
with torch.no_grad():
    result1 = rnn_kf(test_trace, mode='sequential')
    result2 = rnn_kf(test_trace, mode='form')
    result3 = rnn_kf(test_trace, mode='form_sqrt')

# print(torch.norm(result1['state_estimation'] - result2['state_estimation']))
# print(torch.norm(result1['observation_estimation'] - result2['observation_estimation']))
print(result1['observation_estimation'])
print(result2['observation_estimation'])
print(result3['observation_estimation'])

result4 = cnn_kf(test_trace)
print(result4['observation_estimation'])

result5 = cnn_kf_lstsq(test_trace)
print(result5['observation_estimation'])

In [None]:
#@title Plotting Code
# Index read as a global by plot_experiment and plot_comparison to pick one entry along
# the leading axis of the (rearranged) loss tensors before plotting.
# NOTE(review): presumably the index of the system to plot — confirm.
n_idx = 0

def plot_experiment(
        args: Namespace,
        configurations: Sequence[Tuple[str, Dict[str, Any]]],
        systems: List[LinearSystem],
        result: Dict[str, Union[torch.Tensor, np.ndarray]],
        base_exp_name: str,
        output_dir: str,
        clip: float=1e-6,
        histogram: bool=False
):
    """Plot the loss gap ``M.al - M.il`` against total training sequence length.

    One curve is drawn per value of the first hyperparameter in ``configurations``:
    the median over the last axis, a shaded band between the 25% and 75% quantiles,
    and a dashed power-law fit obtained by least squares in log-log space. Only the
    entry selected by the notebook-global ``n_idx`` is shown.

    Args:
        args: Experiment arguments (not read by this function).
        configurations: Sequence of ``(name, mapping)`` pairs. The first pair supplies
            the hyperparameter values (its ``'name'`` entry if present, otherwise its
            first value); the last pair must contain
            ``'train.total_train_sequence_length'``.
        systems: Systems used in the experiment (not read by this function).
        result: Experiment output; ``result['metric']`` must expose ``al`` and ``il``.
            NOTE(review): assumes their difference is laid out as
            (hyperparameter, sequence length, n_idx axis, ..., samples) — confirm.
        base_exp_name: Experiment name, used in the plot title.
        output_dir: Output directory, used in the plot title.
        clip: Lower bound applied to the quantile band so it stays valid on a log axis.
        histogram: If True, additionally show for every sequence length a histogram of
            the difference between the second and first hyperparameter value
            (requires at least two hyperparameter values).
    """
    plt.rcParams['figure.figsize'] = (8.0, 6.0)
    exp_name = f'{output_dir}/{base_exp_name}'

    seq_lengths = torch.tensor(configurations[-1][1]['train.total_train_sequence_length'], device=device)
    M = result['metric']
    # snvl_ = (M.l - M.eil).squeeze(-1)
    # detach() so the values can be handed to matplotlib even if they track gradients.
    snvl_ = (M.al - M.il).squeeze(-1).detach()

    # Plotting function
    def make_plot_for_learned_kfs(snvl: torch.Tensor, name: object, color: np.ndarray):
        """Draw median, quartile band and power-law fit for one hyperparameter value."""
        name = str(name)
        if len(name) > 50:
            name = name[:50]

        quantiles = torch.tensor([0.25, 0.75], device=device)
        snvl_median = snvl.median(-1).values
        snvl_train_quantiles = torch.quantile(snvl, quantiles, dim=-1)

        # Compute the best fit line: log(y) = slope * log(x) + intercept, solved via pseudo-inverse
        log_seq_lengths, log_snvl_median = torch.log(seq_lengths), torch.log(snvl_median)
        augmented_log_seq_lengths = torch.stack([log_seq_lengths, torch.ones_like(log_seq_lengths)], dim=-1)
        line = (torch.linalg.pinv(augmented_log_seq_lengths) @ log_snvl_median.unsqueeze(-1))
        snvl_median_fit = torch.exp(augmented_log_seq_lengths @ line).squeeze(-1)

        # line[..., 0, :] is the slope (exponent), line[..., 1, :] the intercept (log coefficient)
        coefficient = line[n_idx, 1].exp().item()
        exponent = line[n_idx, 0].item()

        # Generate the plots
        plt.plot(seq_lengths.cpu(), snvl_median[n_idx].cpu(), color=color, marker='.', markersize=16, label=f'{name}_median')
        plt.fill_between(
            seq_lengths.cpu(),
            snvl_train_quantiles[0, n_idx, :].clamp_min(clip).cpu(),
            snvl_train_quantiles[1, n_idx, :].clamp_min(clip).cpu(),
            color=color,
            alpha=0.3,
            label=f'{name}_training_quartiles'
        )
        plt.plot(
            seq_lengths.cpu(),
            snvl_median_fit[n_idx].cpu(),
            color='black',
            linestyle='dashed',
            label=f'$y = {coefficient}x^{{{exponent}}}$'
        )

    color_list = np.array([
        [76, 186, 182],
        [237, 125, 102],
        [127, 113, 240],
        [247, 214, 124]
    ], dtype=float) / 255

    _hp_name, hp_config = configurations[0]
    hp_list = hp_config.get('name', list(hp_config.values())[0])
    for i, hp_value in enumerate(hp_list):
        # Cycle through the palette so more than len(color_list) values do not raise IndexError.
        make_plot_for_learned_kfs(snvl_[i].transpose(0, 1), hp_value, color_list[i % len(color_list)])

    plt.xscale('log')
    plt.xlabel('total_trace_length')
    plt.yscale('log')
    plt.ylabel(r'normalized_validation_loss: $\frac{1}{L}|| F_\theta(\tau) - \tau ||^2 - || KF(\tau) - \tau ||^2$')
    plt.title(exp_name)
    plt.legend(fontsize=6)
    plt.show()

    if histogram:
        n_bins = 64
        diff_threshold = 0
        # Move to the CPU once: matplotlib cannot consume tensors that live on an accelerator.
        snvl_difference = (snvl_[1] - snvl_[0]).transpose(0, 1).cpu()
        for i, seq_length in enumerate(seq_lengths.tolist()):
            snvl_ttl_difference = snvl_difference[n_idx, i].flatten()
            print(f'{torch.sum(snvl_ttl_difference > diff_threshold).item()} / {snvl_ttl_difference.numel()} greater than {diff_threshold}')

            # An integer bin count gives n_bins equal-width bins spanning [min, max] of the data.
            counts, _bin_edges, _patches = plt.hist(snvl_ttl_difference.numpy(), bins=n_bins, color=color_list[-1], alpha=0.5, label=f'{hp_list[1]} - {hp_list[0]}')
            plt.plot([diff_threshold, diff_threshold], [0, max(counts)], color='black', linestyle='dashed')
            plt.title(f'train_sequence_length{seq_length} difference histogram')

            plt.legend()
            plt.show()


def plot_comparison(
        args: Namespace,
        configurations: Sequence[Tuple[str, Dict[str, Any]]],
        systems: List[LinearSystem],
        learned_kf_arr: Union[Dict[str, Any], np.ndarray],
        base_exp_name: str,
        output_dir: str,
        log_xscale: bool
):
    """Plot the median loss gap ``M.al - M.il`` over an inner hyperparameter sweep.

    One curve is drawn per value of the outer (first) hyperparameter, with the inner
    (second) hyperparameter on the x-axis; the minimum of each curve is marked with a
    star. Only the entry selected by the notebook-global ``n_idx`` is shown.

    Args:
        args: Experiment arguments (not read by this function).
        configurations: Sequence of ``(name, mapping)`` pairs; the first two entries
            are the outer and inner hyperparameters. Values come from the mapping's
            ``'name'`` entry if present, otherwise from its first value.
        systems: Systems used in the experiment (not read by this function).
        learned_kf_arr: Preferably the experiment-result mapping containing
            ``'metric'``. Any other value is ignored and the notebook-global
            ``result`` is read instead, as older versions of this function did.
        base_exp_name: Experiment name, used in the plot title.
        output_dir: Output directory, used in the plot title.
        log_xscale: If True, use a logarithmic x-axis.
    """
    plt.rcParams['figure.figsize'] = (8.0, 6.0)
    exp_name = f'{output_dir}/{base_exp_name}'

    outer_hp_name, outer_hp_config = configurations[0]
    outer_hp_values = outer_hp_config.get('name', list(outer_hp_config.values())[0])

    inner_hp_name, inner_hp_config = configurations[1]
    inner_hp_values = inner_hp_config.get('name', list(inner_hp_config.values())[0])

    # Use the result mapping handed in by the caller when there is one, so the plot does
    # not silently depend on whatever `result` happens to be in the kernel namespace.
    if hasattr(learned_kf_arr, 'keys') and 'metric' in learned_kf_arr.keys():
        experiment_result = learned_kf_arr
    else:
        # Legacy behavior: fall back to the notebook-global `result`.
        experiment_result = result
    M = experiment_result['metric']
    # snvl_ = (M.l - M.eil).cpu()
    # snvl_median = snvl_.median(-1).values.median(-1).values.permute(-1, *range(snvl_.ndim - 3))[n_idx]

    # detach() so the values can be handed to matplotlib even if they track gradients.
    snvl_ = (M.al - M.il).squeeze(-1).detach().cpu()
    # Median over the last axis, then bring the new last axis to the front and select n_idx.
    snvl_median = snvl_.median(-1).values.permute(-1, *range(snvl_.ndim - 2))[n_idx]

    c = plt.cm.pink(np.linspace(0, 0.8, len(outer_hp_values)))
    for i, outer_hp_value in enumerate(outer_hp_values):
        plt.plot(inner_hp_values, snvl_median[i], c=c[i], marker='.', markersize=16, label=f'{outer_hp_name}{outer_hp_value}')
        argmin = int(torch.argmin(snvl_median[i]))
        # `color=` (not `c=`) marks this as a single RGBA color and avoids matplotlib's
        # warning about an ambiguous numeric `c` sequence.
        plt.scatter([inner_hp_values[argmin]], [snvl_median[i, argmin]], color=c[i] * 0.5, s=256, marker='*')
    # Use snvl_median[:, 0, i] for multiple RNN initializations

    plt.xlabel(inner_hp_name)
    if log_xscale:
        plt.xscale('log')
    plt.ylabel(r'normalized_validation_loss: $\frac{1}{L}|| F_\theta(\tau) - \tau ||^2 - || KF(\tau) - \tau ||^2$')
    plt.yscale('log')
    plt.title(exp_name)
    plt.legend(fontsize=6)

    plt.show()

# Adversarial Systems

In [None]:
#@title Adversarial Systems

# Names used for this run; output_dir and output_fname are passed to run_experiments below.
base_exp_name = 'AdversarialSystemsBasicDebug2'
output_dir = 'system2_CNN'
output_fname = 'result'

# Load a system together with its argument namespace from the project's data directory.
system2, args = loader.load_system_and_args('data/2dim_scalar_system_matrices')
systems = [system2]

# Override the loaded arguments for this experiment.
args.model.model = CnnKFAnalyticalLeastSquares
args.model.ir_length = 8
args.experiment.ensemble_size = 1
args.experiment.metrics = {'analytical_validation'}
args.experiment.exp_name = base_exp_name

# Empty list of (name, mapping) hyperparameter configurations.
# NOTE(review): presumably this means a single run with no sweep — confirm against run_experiments.
configurations = []

result = run_experiments(
    args, configurations, {
        'dir': output_dir,
        'fname': output_fname
    }, systems
)

In [None]:
# Gradient check: backpropagate the sum of a learned observation impulse response and
# print the resulting gradient on every entry of result['system_ptr'].
# Depends on `result` from the experiment cell above.

# Clear gradients left over from any earlier backward pass.
for k, v in result['system_ptr'].items():
    v.grad = None
# NOTE(review): `[()]` looks like unwrapping a 0-d array and `[1][0, 0]` like selecting one
# learned model — confirm the layout of result['learned_kf'].
observation_IR = result['learned_kf'][()][1][0, 0]['observation_IR']
observation_IR.sum().backward()
for k, v in result['system_ptr'].items():
    print(f"{k}: {v.grad}")