# Online Adaptive Forecasting

In [1]:
%load_ext autoreload
%autoreload 2

import re
import os
import json
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import trajectron.visualization as visualization
import trajdata.visualization.vis as trajdata_vis

from torch import optim, nn
from torch.utils import data
from trajdata.data_structures.data_index import AgentDataIndex
from tqdm.notebook import tqdm
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.model.model_utils import UpdateMode
from trajectron.model.trajectron import Trajectron
from collections import defaultdict
from pathlib import Path
from typing import Dict, Final, List, Optional
from trajdata import UnifiedDataset, AgentType, AgentBatch

# Seed the RNGs up front so data shuffling and model sampling are reproducible.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    # Also seed the generators of every visible GPU.
    torch.cuda.manual_seed_all(seed)

In [2]:
# Change this to suit your computing environment and folder structure!

# Both paths can be overridden with environment variables of the same name, so the
# notebook runs on other machines without editing this cell. The defaults are the
# original author's local paths.
TRAJDATA_CACHE_DIR: Final[str] = os.environ.get(
    "TRAJDATA_CACHE_DIR", "/home/bivanovic/.unified_data_cache"
)
LYFT_SAMPLE_RAW_DATA_DIR: Final[str] = os.environ.get(
    "LYFT_SAMPLE_RAW_DATA_DIR", "/home/bivanovic/datasets/lyft/scenes/sample.zarr"
)

In [3]:
# Models to compare, as (checkpoint directory, epoch of the saved registrar).
# Per the directory names, base/k0/adaptive were trained on nuScenes and the oracle on Lyft.
base_model, base_checkpoint = "models/nusc_mm_base_tpp-11_Sep_2022_19_15_45", 20
k0_model, k0_checkpoint = "models/nusc_mm_k0_tpp-12_Sep_2022_00_40_16", 20
adaptive_model, adaptive_checkpoint = "models/nusc_mm_sec4_tpp-13_Sep_2022_11_06_01", 20
oracle_model, oracle_checkpoint = "models/lyft_mm_base_tpp-11_Sep_2022_18_56_49", 1

# trajdata scene tag that all models are evaluated on.
eval_data = "lyft_sample-mini_val"

# Observed history length and forecast horizon, in seconds.
history_sec, prediction_sec = 2.0, 6.0

In [4]:
# Colors for the non-adaptive reference models (presumably drawn with `axhline`).
AXHLINE_COLORS = {
    "Base": "#DD9787",
    "K0": "#A6C48A",
    "Oracle": "#BCB6FF",
}

# Palette for every method; the reference models reuse the colors defined above.
SEABORN_PALETTE = {
    "Finetune": "#AA7C85",
    "K0+Finetune": "#2D93AD",
    "Ours": "#67934D",
    "Ours+Finetune": "#FF8E61",
    **AXHLINE_COLORS,
}

In [5]:
def load_model(model_dir: str, device: str, epoch: int = 10, custom_hyperparams: Optional[Dict] = None):
    """Build a Trajectron model from ``model_dir`` and load the weights saved at ``epoch``.

    Args:
        model_dir: directory containing ``config.json`` and ``model_registrar-<epoch>.pt``.
        device: torch device string the registrar, model and checkpoint are mapped to.
        epoch: which ``model_registrar-<epoch>.pt`` checkpoint to load.
        custom_hyperparams: optional overrides merged on top of ``config.json``.

    Returns:
        ``(trajectron, hyperparams)``, where ``hyperparams`` is the merged config dict.
    """
    registrar = ModelRegistrar(model_dir, device)

    with open(os.path.join(model_dir, 'config.json'), 'r') as cfg_file:
        hyperparams = json.load(cfg_file)
    if custom_hyperparams is not None:
        hyperparams.update(custom_hyperparams)

    trajectron = Trajectron(registrar, hyperparams, None, device)
    trajectron.set_environment()
    trajectron.set_annealing_params()

    checkpoint_file = Path(model_dir) / f'model_registrar-{epoch}.pt'
    state = torch.load(checkpoint_file, map_location=device)
    # NOTE(review): strict=False silently tolerates missing/unexpected keys, so any
    # parameter absent from the checkpoint keeps its freshly-constructed value.
    trajectron.load_state_dict(state["model_state_dict"], strict=False)
    return trajectron, hyperparams

In [6]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device != 'cpu':
    torch.cuda.set_device(0)

In [7]:
def restrict_to_predchal(
    dataset: UnifiedDataset,
    split: str
) -> None:
    """Restrict ``dataset`` in place to the entries listed in ``predchal_<split>_index.pkl``.

    Each pickled entry is ``(scene_info_path, num_elems, elems)`` with the scene path relative
    to ``dataset.cache_path``. The dataset's private scene index, data index and length are
    overwritten.
    """
    index_file = f'predchal_{split}_index.pkl'
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load trusted index files.
    with open(index_file, 'rb') as index_fh:
        raw_entries = pickle.load(index_fh)

    resolved_entries = []
    for scene_info_path, num_elems, elems in raw_entries:
        resolved_entries.append((dataset.cache_path / scene_info_path, num_elems, elems))

    dataset._scene_index = [entry[0] for entry in resolved_entries]

    # The data index is effectively a big list of tuples taking the form:
    # (scene_path: str, index_len: int, valid_timesteps: np.ndarray[, agent_name: str])
    dataset._data_index = AgentDataIndex(resolved_entries, dataset.verbose)
    dataset._data_len: int = len(dataset._data_index)

In [8]:
def finetune_update(model: Trajectron, batch: AgentBatch = None, dataloader: data.DataLoader = None, num_epochs: int = None, update_mode: UpdateMode = UpdateMode.NO_UPDATE) -> float:
    """Fine-tune all of ``model`` with gradient descent on one batch or on a dataloader.

    A fresh Adam optimizer and LR scheduler are created on every call, so no optimizer
    state carries over between calls. The map encoder trains at
    ``map_enc_learning_rate / 10`` and all other parameters at ``learning_rate / 10``.

    Args:
        model: model to update in place.
        batch: single batch to take exactly one optimization step on.
        dataloader: loader to iterate over for ``num_epochs`` full passes.
        num_epochs: number of passes over ``dataloader``; required together with it.
        update_mode: forwarded to ``model(...)`` on every optimization step.

    Returns:
        Mean training loss over all optimization steps taken (NaN if none were taken).

    Raises:
        ValueError: if not exactly one of ``batch``/``dataloader`` is given, or if
            ``dataloader`` is given without ``num_epochs``.
    """
    if (batch is None) == (dataloader is None):
        raise ValueError("Exactly one of batch or dataloader must be passed in.")

    if dataloader is not None and num_epochs is None:
        raise ValueError("num_epochs must not be None if dataloader is not None.")

    hyperparams = model.hyperparams
    optimizer = optim.Adam([{'params': model.model_registrar.get_all_but_name_match('map_encoder').parameters()},
                            {'params': model.model_registrar.get_name_match('map_encoder').parameters(),
                             'lr': hyperparams['map_enc_learning_rate']/10}],
                           lr=hyperparams['learning_rate']/10)

    # Set Learning Rate
    lr_scheduler = None
    if hyperparams['learning_rate_style'] == 'const':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif hyperparams['learning_rate_style'] == 'exp':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=hyperparams['learning_decay_rate'])

    def _step(step_batch) -> float:
        # One forward/backward/optimizer step; returns the scalar training loss.
        model.step_annealers()
        optimizer.zero_grad(set_to_none=True)

        train_loss = model(step_batch, update_mode=update_mode)
        train_loss.backward()

        # Clipping gradients.
        if hyperparams['grad_clip'] is not None:
            nn.utils.clip_grad_value_(model.model_registrar.parameters(), hyperparams['grad_clip'])

        optimizer.step()

        # Any other learning_rate_style leaves no scheduler; skip instead of crashing.
        if lr_scheduler is not None:
            lr_scheduler.step()

        return train_loss.item()

    if batch is not None:
        losses = [_step(batch)]
    else:
        # num_epochs was previously validated but ignored; honour it now.
        losses = [_step(loader_batch) for _ in range(num_epochs) for loader_batch in dataloader]

    return sum(losses) / len(losses) if losses else float('nan')

In [9]:
def finetune_last_layer_update(model: Trajectron, batch: AgentBatch = None, dataloader: data.DataLoader = None, num_epochs: int = None) -> float:
    """Fine-tune only the ``last_layer`` parameters of ``model`` on one batch or a dataloader.

    All parameters not matching ``'last_layer'`` go into an Adam group with learning rate 0,
    so Adam leaves them unchanged. The ``last_layer`` group trains at ``learning_rate / 10``.
    A fresh optimizer and scheduler are created on every call.

    Args:
        model: model to update in place.
        batch: single batch to take exactly one optimization step on.
        dataloader: loader to iterate over for ``num_epochs`` full passes.
        num_epochs: number of passes over ``dataloader``; required together with it.

    Returns:
        Mean training loss over all optimization steps taken (NaN if none were taken).

    Raises:
        ValueError: if not exactly one of ``batch``/``dataloader`` is given, or if
            ``dataloader`` is given without ``num_epochs``.
    """
    if (batch is None) == (dataloader is None):
        raise ValueError("Exactly one of batch or dataloader must be passed in.")

    if dataloader is not None and num_epochs is None:
        raise ValueError("num_epochs must not be None if dataloader is not None.")

    hyperparams = model.hyperparams
    # lr=0 for the default group freezes everything except the last layer.
    optimizer = optim.Adam([{'params': model.model_registrar.get_all_but_name_match('last_layer').parameters()},
                            {'params': model.model_registrar.get_name_match('last_layer').parameters(),
                             'lr': hyperparams['learning_rate']/10}],
                           lr=0)

    # Set Learning Rate
    lr_scheduler = None
    if hyperparams['learning_rate_style'] == 'const':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif hyperparams['learning_rate_style'] == 'exp':
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                        gamma=hyperparams['learning_decay_rate'])

    def _step(step_batch) -> float:
        # One forward/backward/optimizer step; returns the scalar training loss.
        model.step_annealers()
        optimizer.zero_grad(set_to_none=True)

        train_loss = model(step_batch)
        train_loss.backward()

        # Clipping gradients.
        if hyperparams['grad_clip'] is not None:
            nn.utils.clip_grad_value_(model.model_registrar.parameters(), hyperparams['grad_clip'])

        optimizer.step()

        # Any other learning_rate_style leaves no scheduler; skip instead of crashing.
        if lr_scheduler is not None:
            lr_scheduler.step()

        return train_loss.item()

    if batch is not None:
        losses = [_step(batch)]
    else:
        # num_epochs was previously validated but ignored; honour it now.
        losses = [_step(loader_batch) for _ in range(num_epochs) for loader_batch in dataloader]

    return sum(losses) / len(losses) if losses else float('nan')

In [10]:
def _load_with_cache(model_dir: str, epoch: int, single_mode_multi_sample: bool):
    """Load a checkpoint via ``load_model`` with the shared trajdata cache configured."""
    return load_model(
        model_dir, device, epoch=epoch,
        custom_hyperparams={"trajdata_cache_dir": TRAJDATA_CACHE_DIR,
                            "single_mode_multi_sample": single_mode_multi_sample}
    )


# The adaptive model's hyperparameters drive the dataset configuration below.
adaptive_trajectron, hyperparams = _load_with_cache(adaptive_model, adaptive_checkpoint, True)
# For offline test: an independent copy of the adaptive model.
adaptive_finetune_trajectron, _ = _load_with_cache(adaptive_model, adaptive_checkpoint, True)

k0_trajectron, _ = _load_with_cache(k0_model, k0_checkpoint, False)
k0_finetune_trajectron, _ = _load_with_cache(k0_model, k0_checkpoint, False)

base_trajectron, _ = _load_with_cache(base_model, base_checkpoint, False)
finetune_trajectron, _ = _load_with_cache(base_model, base_checkpoint, False)

oracle_trajectron, _ = _load_with_cache(oracle_model, oracle_checkpoint, False)

# Load training and evaluation environments and scenes
attention_radius = defaultdict(lambda: 20.0) # Default range is 20m unless otherwise specified.
attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0
attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0
attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0
attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0

map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)}


def _make_eval_dataset(min_history_sec: float) -> UnifiedDataset:
    """Build the ``eval_data`` dataset; the two eval datasets differ only in minimum history."""
    return UnifiedDataset(
        desired_data=[eval_data],
        history_sec=(min_history_sec, history_sec),
        future_sec=(prediction_sec, prediction_sec),
        agent_interaction_distances=attention_radius,
        incl_robot_future=hyperparams['incl_robot_node'],
        incl_raster_map=hyperparams['map_encoding'],
        raster_map_params=map_params,
        only_predict=[AgentType.VEHICLE],
        no_types=[AgentType.UNKNOWN],
        num_workers=0,
        cache_location=TRAJDATA_CACHE_DIR,
        # Both datasets point at the real raw-data dir. The batch dataset used "" before,
        # which only worked when the trajdata cache had already been built.
        data_dirs={
            "lyft_sample": LYFT_SAMPLE_RAW_DATA_DIR,
        },
        verbose=True
    )


# Online evaluation accepts agents with as little as 0.1 s of history.
online_eval_dataset = _make_eval_dataset(0.1)
# Batch evaluation requires the full history window.
batch_eval_dataset = _make_eval_dataset(history_sec)

Loading data for matched scene tags: ['mini_val-palo_alto-lyft_sample']


Calculating Agent Data (Serially): 100%|██████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 49578.06it/s]


20 scenes in the scene index.


Creating Agent Data Index (Serially): 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 149.97it/s]
Structuring Agent Data Index: 100%|███████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 17870.92it/s]

Loading data for matched scene tags: ['mini_val-palo_alto-lyft_sample']



Calculating Agent Data (Serially): 100%|██████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 49402.87it/s]


20 scenes in the scene index.


Creating Agent Data Index (Serially): 100%|█████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 526.36it/s]
Structuring Agent Data Index: 100%|███████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 24008.61it/s]


In [11]:
# Extracts the scene name as the second-to-last component of a scene info path.
prog = re.compile("(.*)/(?P<scene_name>.*)/(.*)$")

def plot_outputs(
    eval_dataset: UnifiedDataset,
    dataset_idx: int,
    model: Trajectron,
    model_name: str,
    agent_ts: int,
    save=True,
    extra_str=None,
    subfolder="",
    filetype="png"
):
    """Plot the model's full predicted distribution for one sample of ``eval_dataset``.

    Args:
        eval_dataset: dataset to draw sample ``dataset_idx`` from.
        dataset_idx: index of the sample to plot.
        model: model whose ``predict(..., full_dist=True, output_dists=True)`` is visualized.
        model_name: prefix used in the saved file name.
        agent_ts: timestep label used only in the saved file name.
        save: if True, save to ``plots/<subfolder>...`` and close the figure.
        extra_str: optional suffix appended to the file name.
        subfolder: sub-path (including trailing ``/``) under ``plots/``.
        filetype: image extension passed to ``savefig`` via the file name.

    Raises:
        ValueError: if no scene name can be parsed from the sample's scene path.
    """
    batch: AgentBatch = eval_dataset.get_collate_fn(pad_format="right")([eval_dataset[dataset_idx]])
    with torch.no_grad():
        pred_dists, _ = model.predict(batch,
                                      z_mode=False,
                                      gmm_mode=False,
                                      full_dist=True,
                                      output_dists=True)

    batch.to("cpu")

    fig, ax = plt.subplots()
    trajdata_vis.plot_agent_batch(batch, batch_idx=0, ax=ax, show=False, close=False)
    visualization.visualize_distribution(ax, pred_dists, batch_idx=0)

    scene_info_path, _, scene_ts = eval_dataset._data_index[dataset_idx]
    # The index may hold Path objects (see restrict_to_predchal); the regex needs a str.
    scene_match = prog.match(str(scene_info_path))
    if scene_match is None:
        raise ValueError(f"Could not parse a scene name from {scene_info_path!r}")
    scene_name = scene_match.group("scene_name")

    agent_name = batch.agent_name[0]
    agent_type_name = f"{str(AgentType(batch.agent_type[0].item()))}/{agent_name}"

    ax.set_title(f"{scene_name}/t={scene_ts} {agent_type_name}")

    if save:
        fname = f"plots/{subfolder}{model_name}_{scene_name}_{agent_name}_t{agent_ts}"
        if extra_str:
            fname += "_" + extra_str
        # savefig does not create missing directories.
        Path(fname).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(fname + f".{filetype}")

        plt.close(fig)

In [12]:
def get_dataloader(
    eval_dataset: UnifiedDataset,
    batch_size: int = 128,
    num_workers: int = 0,
    shuffle: bool = False
):
    """Wrap ``eval_dataset`` in a ``DataLoader`` that uses its right-padded collate function.

    Memory pinning is enabled whenever the notebook-level ``device`` is not the CPU.
    """
    collate = eval_dataset.get_collate_fn(pad_format="right")
    use_pinned_memory = device != 'cpu'
    return data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
        pin_memory=use_pinned_memory,
    )

In [13]:
# Metrics summarized for each model; each must be a key of the per-agent-type metric
# dicts returned by `predict_and_evaluate_batch`.
metrics_list = [
    "ml_ade",
    "ml_fde",
    "nll_mean",
    "min_ade_5",
    "min_ade_10",
]

### Batch Evaluation of the Non-Adaptive Baselines (Base, K0, Oracle)

In [14]:
def add_results_to_summary(
    model_name: str,
    overall_summary_dict: Dict[str, List[float]],
    model_perf_dict: Dict[AgentType, Dict[str, np.ndarray]]
):
    """Append ``model_name`` and its mean VEHICLE metrics to ``overall_summary_dict``.

    For each metric in ``metrics_list``, the per-batch arrays stored under
    ``model_perf_dict[AgentType.VEHICLE][metric]`` are concatenated and averaged into one
    Python float.
    """
    overall_summary_dict["model"].append(model_name)
    for metric_name in metrics_list:
        pooled_values = np.concatenate(model_perf_dict[AgentType.VEHICLE][metric_name])
        overall_summary_dict[metric_name].append(np.mean(pooled_values).item())

In [15]:
eval_dict = defaultdict(list)

# (display name, model, batch size, dataloader workers) for each non-adaptive baseline.
# K0 uses a smaller batch size (presumably for memory reasons -- TODO confirm).
BATCH_EVAL_RUNS = [
    ("Base", base_trajectron, 128, 16),
    ("K0", k0_trajectron, 64, 8),
    ("Oracle", oracle_trajectron, 128, 8),
]

with torch.no_grad():
    for run_name, run_model, run_batch_size, run_num_workers in BATCH_EVAL_RUNS:
        batch_eval_dataloader = get_dataloader(
            batch_eval_dataset, batch_size=run_batch_size, num_workers=run_num_workers
        )
        model_perf = defaultdict(lambda: defaultdict(list))
        for eval_batch in tqdm(batch_eval_dataloader, desc=f'Eval {run_name}'):
            eval_results: Dict[AgentType, Dict[str, torch.Tensor]] = run_model.predict_and_evaluate_batch(eval_batch)
            for agent_type, metric_dict in eval_results.items():
                for metric, values in metric_dict.items():
                    model_perf[agent_type][metric].append(values.cpu().numpy())

        add_results_to_summary(run_name, eval_dict, model_perf)

    # Release the last batch and results so they don't pin memory for the rest of the notebook.
    # Assigning None instead of `del` keeps this safe even if a dataloader yielded nothing.
    eval_batch = eval_results = model_perf = None

Eval Base:   0%|          | 0/218 [00:00<?, ?it/s]

Eval K0:   0%|          | 0/436 [00:00<?, ?it/s]

Eval Oracle:   0%|          | 0/218 [00:00<?, ?it/s]

In [16]:
# Per-model summary: "model" lists the model names and each metric key holds one mean value per model, in the same order.
eval_dict

defaultdict(list,
            {'model': ['Base', 'K0', 'Oracle'],
             'ml_ade': [85.19214630126953,
              48.29320526123047,
              2.1609904766082764],
             'ml_fde': [179.68934631347656,
              89.61846160888672,
              5.328505516052246],
             'nll_mean': [7.990538120269775,
              10.895771026611328,
              0.7372894287109375],
             'min_ade_5': [48.203311920166016,
              35.55550003051758,
              1.2764157056808472],
             'min_ade_10': [38.24354553222656,
              25.114696502685547,
              1.0219292640686035]})

In [18]:
# Persist the baseline summary. Create results/ first so a fresh checkout doesn't fail.
base_perf_path = Path("results") / "base_performance.pkl"
base_perf_path.parent.mkdir(parents=True, exist_ok=True)
with open(base_perf_path, 'wb') as f:
    pickle.dump(eval_dict, f)

In [19]:
# Reload the saved baseline summary. pickle can execute arbitrary code on load, so only
# read result files that this notebook produced.
eval_dict = pickle.loads(Path("results/base_performance.pkl").read_bytes())

In [20]:
# Summary as reloaded from results/base_performance.pkl.
eval_dict

defaultdict(list,
            {'model': ['Base', 'K0', 'Oracle'],
             'ml_ade': [85.19214630126953,
              48.29320526123047,
              2.1609904766082764],
             'ml_fde': [179.68934631347656,
              89.61846160888672,
              5.328505516052246],
             'nll_mean': [7.990538120269775,
              10.895771026611328,
              0.7372894287109375],
             'min_ade_5': [48.203311920166016,
              35.55550003051758,
              1.2764157056808472],
             'min_ade_10': [38.24354553222656,
              25.114696502685547,
              1.0219292640686035]})

## Calibration

In [14]:
def compute_calibration_values(model: Trajectron, dataloader: data.DataLoader):
    """Evaluate ``model``'s predicted density at every ground-truth future position.

    For each batch and each agent type in it, the per-type node model predicts a full
    distribution over the future. Its log-density at the ground-truth ``agent_fut[..., :2]``
    is clamped below at -20 and then exponentiated.

    Returns:
        1-D float64 array of the estimated densities, concatenated over all batches/types
        (empty if the dataloader yields nothing).
    """
    est_probs = []

    # Pure evaluation: don't build autograd graphs.
    with torch.no_grad():
        for batch in tqdm(dataloader, leave=False):
            batch.to(model.device)

            node_type: AgentType
            for node_type in batch.agent_types():
                mgcvae = model.node_models_dict[node_type.name]

                agent_type_batch = batch.for_agent_type(node_type)
                ph = agent_type_batch.agent_fut.shape[1]

                # Run forward pass
                y_dists, _ = mgcvae.predict(agent_type_batch,
                                            prediction_horizon=ph,
                                            num_samples=1,
                                            z_mode=False,
                                            gmm_mode=False,
                                            full_dist=True,
                                            output_dists=True)

                # Score against the same agent-type subset the distribution was predicted
                # for. The full `batch` would mismatch when several agent types are present.
                pred_log_pdf = torch.clamp(y_dists.log_prob(agent_type_batch.agent_fut[..., :2]), min=-20.)
                est_probs.append(torch.exp(pred_log_pdf).flatten().cpu())

    if not est_probs:
        return np.asarray([])
    return torch.cat(est_probs).numpy().astype(np.float64)

In [15]:
# Calibration of the non-adapting reference models, computed once up front because
# base/K0/Oracle are not updated by the online loop.
calib_values = []

base_calib = compute_calibration_values(
    base_trajectron,
    get_dataloader(batch_eval_dataset, num_workers=16),
)

# K0 and Oracle share a single dataloader.
batch_dataloader = get_dataloader(batch_eval_dataset, num_workers=8)
k0_calib = compute_calibration_values(k0_trajectron, batch_dataloader)
oracle_calib = compute_calibration_values(oracle_trajectron, batch_dataloader)

# NOTE(review): this first entry is (step, base, k0, oracle), whereas the entries appended
# by the online loop are (step, ours, ours+finetune, finetune, k0+finetune).
calib_values.append((0, base_calib, k0_calib, oracle_calib))

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

In [16]:
def calibration_curve_fractions(est_gt_probs: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
    """Fraction of ``est_gt_probs`` that are >= each threshold.

    NaN entries count in the denominator but never in the numerator. This matches an
    elementwise ``>=`` comparison. Returns all-NaN when ``est_gt_probs`` is empty.
    """
    flat = np.ravel(est_gt_probs)
    if flat.size == 0:
        return np.full(np.shape(thresholds), np.nan)
    finite_sorted = np.sort(flat[~np.isnan(flat)])
    # side="left" yields, per threshold t, the number of values strictly below t.
    num_at_or_above = finite_sorted.size - np.searchsorted(finite_sorted, thresholds, side="left")
    return num_at_or_above / flat.size


def calibration_plot(est_gt_probs: np.ndarray, method_name: str, ax, num_thresholds: int = 101):
    """Plot, on ``ax``, the fraction of GT densities >= each threshold in [0, 1].

    Args:
        est_gt_probs: 1-D array of estimated densities at ground-truth points.
        method_name: legend label; also the key used to pick the color from SEABORN_PALETTE.
        ax: matplotlib axes to draw on.
        num_thresholds: number of evenly spaced thresholds in [0, 1].
    """
    thresholds = np.linspace(0, 1, num_thresholds)

    # Values are not rescaled by est_gt_probs.max(); leaving raw PDF heights makes
    # overconfidence easier to see.
    fraction_gt_included = calibration_curve_fractions(est_gt_probs, thresholds)

    ax.plot(thresholds, fraction_gt_included, label=method_name, lw=3, c=SEABORN_PALETTE[method_name])

In [17]:
def make_calibration_plots(base_calib, k0_calib, oracle_calib, ours_calib, ours_finetune_calib, finetune_calib, k0_finetune_calib, num_updates, out_dir: str = "results", close: bool = False):
    """Draw calibration curves for all methods and save them as a PDF.

    The figure is written to ``<out_dir>/calib_after<num_updates>.pdf``, and ``out_dir`` is
    created if missing. Pass ``close=True`` when calling this in a loop, so that figures
    don't accumulate in memory.
    """
    fig, ax = plt.subplots()

    ax.plot([0, 1], [1, 0], label="Ideal", ls="--", c="k")

    # Drawing order determines legend order and z-order.
    for calib, method_name in (
        (base_calib, "Base"),
        (k0_calib, "K0"),
        (oracle_calib, "Oracle"),
        (ours_finetune_calib, "Ours+Finetune"),
        (ours_calib, "Ours"),
        (finetune_calib, "Finetune"),
        (k0_finetune_calib, "K0+Finetune"),
    ):
        calibration_plot(calib, method_name, ax)

    ax.legend(bbox_to_anchor=(1.04, 0.5), loc="center left", borderaxespad=0)
    ax.set_xlabel("Probability Threshold")
    ax.set_ylabel("Fraction of GT Points")

    ax.grid(False)

    out_path = Path(out_dir) / f"calib_after{num_updates}.pdf"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight")

    if close:
        plt.close(fig)

In [18]:
online_eval_dataloader = get_dataloader(online_eval_dataset, batch_size=1, shuffle=True)
batch_eval_dataloader = get_dataloader(batch_eval_dataset, num_workers=8)

# The adaptive+finetune model's weights are changed by finetune_update below. Reload it so
# that re-running this cell starts from the checkpoint again. fork_rng keeps this extra
# model construction from shifting the torch RNG stream, and hence the shuffled sample order.
with torch.random.fork_rng():
    adaptive_finetune_trajectron, _ = load_model(
        adaptive_model, device, epoch=adaptive_checkpoint,
        custom_hyperparams={"trajdata_cache_dir": TRAJDATA_CACHE_DIR,
                            "single_mode_multi_sample": True}
    )

adaptive_trajectron.reset_adaptive_info()
adaptive_finetune_trajectron.reset_adaptive_info()

# Resetting the finetune baselines to their base.
finetune_trajectron, _ = load_model(base_model, device, epoch=base_checkpoint,
    custom_hyperparams={"trajdata_cache_dir": TRAJDATA_CACHE_DIR,
                        "single_mode_multi_sample": False})
k0_finetune_trajectron, _ = load_model(k0_model, device, epoch=k0_checkpoint,
    custom_hyperparams={"trajdata_cache_dir": TRAJDATA_CACHE_DIR,
                        "single_mode_multi_sample": False})

# Number of online samples (batch_size=1) to stream through the models.
N_SAMPLES = 1001
# "Ours+Finetune" uses the adaptive update for samples [0, ADAPTIVE_WARMUP_SAMPLES) and
# switches to gradient fine-tuning afterwards.
ADAPTIVE_WARMUP_SAMPLES = 101
# Calibration on the batch-eval set is recomputed after every CALIB_EVERY online samples.
CALIB_EVERY = 100
# When True, save a calibration figure each time calibration values are recomputed.
plot_per_eval = False

outer_pbar = tqdm(
    online_eval_dataloader,
    total=min(N_SAMPLES, len(online_eval_dataloader)),
    desc=f'Adaptive Eval PH={prediction_sec}',
    position=0,
)

online_batch: AgentBatch
for data_sample, online_batch in enumerate(outer_pbar):
    if data_sample >= N_SAMPLES:
        outer_pbar.close()
        break

    if data_sample == 0:
        # Before any online update, "Ours" and "Ours+Finetune" are loaded from the same
        # checkpoint and freshly reset, so ours_calib fills both slots.
        ours_calib = compute_calibration_values(adaptive_trajectron, batch_eval_dataloader)
        finetune_calib = compute_calibration_values(finetune_trajectron, batch_eval_dataloader)
        k0_finetune_calib = compute_calibration_values(k0_finetune_trajectron, batch_eval_dataloader)

        calib_values.append((data_sample, ours_calib, ours_calib, finetune_calib, k0_finetune_calib))

        if plot_per_eval:
            make_calibration_plots(base_calib, k0_calib, oracle_calib, ours_calib, ours_calib,
                                   finetune_calib, k0_finetune_calib, data_sample, close=True)

    with torch.no_grad():
        # This is the inference call that internally updates L_n and K_n.
        adaptive_trajectron.adaptive_predict(
            online_batch,
            update_mode=UpdateMode.ITERATIVE
        )

    if data_sample < ADAPTIVE_WARMUP_SAMPLES:
        with torch.no_grad():
            adaptive_finetune_trajectron.adaptive_predict(
                online_batch,
                update_mode=UpdateMode.ITERATIVE
            )
    else:
        finetune_update(adaptive_finetune_trajectron, online_batch, update_mode=UpdateMode.NO_UPDATE)

    finetune_update(finetune_trajectron, online_batch)
    finetune_last_layer_update(k0_finetune_trajectron, online_batch)

    if (data_sample + 1) % CALIB_EVERY == 0:
        ours_calib = compute_calibration_values(adaptive_trajectron, batch_eval_dataloader)
        ours_finetune_calib = compute_calibration_values(adaptive_finetune_trajectron, batch_eval_dataloader)
        finetune_calib = compute_calibration_values(finetune_trajectron, batch_eval_dataloader)
        k0_finetune_calib = compute_calibration_values(k0_finetune_trajectron, batch_eval_dataloader)

        # NOTE(review): the stored step is the 0-based sample index (99, 199, ...), one less
        # than the number of updates applied. It is kept as-is for compatibility with
        # existing result files.
        calib_values.append((data_sample, ours_calib, ours_finetune_calib, finetune_calib, k0_finetune_calib))

        if plot_per_eval:
            make_calibration_plots(base_calib, k0_calib, oracle_calib, ours_calib, ours_finetune_calib,
                                   finetune_calib, k0_finetune_calib, data_sample + 1, close=True)

Adaptive Eval PH=6.0:   0%|          | 0/1001 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

  0%|          | 0/218 [00:00<?, ?it/s]

In [19]:
# Persist all calibration snapshots. Create results/ first so a fresh checkout doesn't fail.
calib_path = Path("results") / "model_calibrations.pkl"
calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(calib_path, 'wb') as f:
    pickle.dump(calib_values, f)