# Example of regions on the taxi dataset

In [41]:
%load_ext autoreload
%autoreload 2

In [30]:
import torch
import numpy as np
from tqdm import tqdm

from moc.conformal import *
from moc.models.mqf2.lightning_module import MQF2LightningModule
from moc.models.trainers.lightning_trainer import get_lightning_trainer
from moc.configs.config import get_config
from moc.utils.run_config import RunConfig
from moc.analysis.dataframes import load_datamodule
from moc.metrics.metrics_computer import compute_cum_region_size
from moc.metrics.cache import Cache, EmptyCache

import taxi_example.taxi_utils as utils

In [31]:
config = get_config()
# Set `config.fast = True` to train the model only on the first batches (quick run).
config.fast = False
# Uncomment the next line to force execution on the CPU.
# config.device = 'cpu'
config.default_batch_size = 250

In [32]:
# Build the run configuration for the taxi dataset and load the matching datamodule.
dataset = 'taxi'
# NOTE(review): the meaning of the positional arguments 'wang' and 0 is not visible here
# (presumably the dataset group and the run index) — confirm against RunConfig.
rc = RunConfig(config, 'wang', dataset, 0, hparams={})
datamodule = load_datamodule(rc)

In [33]:
# Whole test split as stored by the datamodule (inputs, targets).
scaled_test_X, scaled_test_Y = datamodule.data_test[:]

# Undo the datamodule's input/output scalers to get the same points in their original units.
test_X = datamodule.scaler_x.inverse_transform(scaled_test_X)
test_Y = datamodule.scaler_y.inverse_transform(scaled_test_Y)

In [34]:
# Fit an MQF2 model whose input/output sizes are taken from the datamodule.
p, q = datamodule.input_dim, datamodule.output_dim
mqf2_model = MQF2LightningModule(p, q)
trainer = get_lightning_trainer(rc)
trainer.fit(mqf2_model, datamodule)
# Move the fitted model to the configured device (presumably the trainer may leave it
# elsewhere after `fit` — confirm). As the last expression, this also displays the module.
mqf2_model.to(config.device)

In [35]:
# Conformal methods to compare, keyed by the display name later handed to the plotting helper.
conformalizers = {
    'M-CP': M_CP,
    'CopulaCPTS': CopulaCPTS,
    'DR-CP': DR_CP,
    'C-HDR': C_HDR,
    'PCP': PCP,
    'HD-PCP': HD_PCP,
    'STDQR': STDQR,
    'C-PCP': C_PCP,
    'L-CP': L_CP,
}

# `alpha` is passed to every conformal method below
# (presumably the miscoverage level, i.e. target coverage 1 - alpha — confirm).
alpha = 0.2
# Multiplier applied to the calibration-data extremes to obtain the plotting limits.
n_shift = 2

# Limits derived from the calibration split: first input column and first two output columns.
# NOTE(review): multiplying min/max by n_shift only widens the range when min < 0 < max,
# which presumably holds for standardized data — confirm.
# NOTE(review): `ylim` is not referenced by any other visible cell; the contour call
# uses (xlim, zlim).
x, y = datamodule.data_calib[:]
xlim = x[:, 0].min() * n_shift, x[:, 0].max() * n_shift
ylim = y[:, 0].min() * n_shift, y[:, 0].max() * n_shift
zlim = y[:, 1].min() * n_shift, y[:, 1].max() * n_shift

In [36]:
# Number of model samples requested from each cache.
n_samples = 100
# Batch size of the test dataloader backing `cache_test`. Keep it in sync with the batch size
# assumed by `get_cache_test_for_sample` (1000) when mapping a sample index to its batch.
batch_size_test = 1000

# Caches of model outputs for the calibration and test dataloaders.
# NOTE(review): the per-sample slicing later in the notebook reads the keys 'samples',
# 'log_probs', 'samples2' and 'log_probs2'; presumably the '*2' entries come from
# add_second_sample=True — confirm against Cache.
cache_calib = Cache(mqf2_model, datamodule.calib_dataloader(), n_samples, add_second_sample=True)
cache_test = Cache(mqf2_model, datamodule.get_dataloader(datamodule.data_test, batch_size=batch_size_test), n_samples, add_second_sample=True)

In [37]:
def get_cache_test_for_sample(idx: int, cache, batch_size: int = 1000):
    """Return the cached entries of one batch restricted to a single sample.

    Args:
        idx: Position of the sample in the dataset the cache was built from.
        cache: Indexable container; ``cache[b]`` must return a dict-like object
            (supporting ``.copy()``) holding the keys ``'samples'``,
            ``'log_probs'``, ``'samples2'`` and ``'log_probs2'``, each sliceable
            along its second axis.
        batch_size: Number of samples per cached batch. It must match the batch
            size of the dataloader used to build ``cache``; defaults to 1000.

    Returns:
        A shallow copy of the batch entry in which the four keys above are
        narrowed to the single requested sample (the second axis keeps
        length 1). Other keys are carried over unchanged, and the original
        cache entry is not modified.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    batch, pos_in_batch = divmod(idx, batch_size)

    # Shallow copy: reassigning keys below leaves the cached batch itself untouched.
    sample_cache = cache[batch].copy()
    for key in ('samples', 'log_probs', 'samples2', 'log_probs2'):
        # Slice (rather than index) so the sample axis is kept with length 1.
        sample_cache[key] = sample_cache[key][:, pos_in_batch:pos_in_batch + 1]

    return sample_cache

In [38]:
# Get uncertainty for X
# Rank the first test points by the size of their C-HDR region and report the indices of the
# smallest (low uncertainty) and largest (high uncertainty) regions.
conformal_method_unc = C_HDR(datamodule.calib_dataloader(), mqf2_model, cache_calib=cache_calib)
n_points_unc = min(100, len(test_X))
log_region_sizes = []

for current_idx in tqdm(range(n_points_unc)):
    current_x = scaled_test_X[current_idx].unsqueeze(0)
    associated_y = scaled_test_Y[current_idx].unsqueeze(0)

    current_cache_test = get_cache_test_for_sample(current_idx, cache_test)

    # Points whose target is not covered, or whose size is non-positive or overflowing, are
    # stored as NaN. Imputing a running mean instead would yield NaN whenever the list is
    # still empty and then contaminate every later mean and the argmin/argmax below.
    region_size = np.nan
    if conformal_method_unc.is_in_region(current_x, associated_y, alpha=alpha, cache=current_cache_test):
        candidate_size = compute_cum_region_size(
            conformal_method_unc, mqf2_model, alpha, current_x,
            n_samples=n_samples, cache_test=current_cache_test,
        )[-1].item()
        if 0 < candidate_size < 1E308:
            region_size = candidate_size
    log_region_sizes.append(region_size)

if np.all(np.isnan(log_region_sizes)):
    raise ValueError('No test point produced a valid region size; cannot rank uncertainty.')

# NaN-aware selection: only points with a valid region size can be picked.
low_uncertainty_idx = np.nanargmin(log_region_sizes)
high_uncertainty_idx = np.nanargmax(log_region_sizes)
low_uncertainty_idx, high_uncertainty_idx

In [49]:
# Test points to visualize. `alpha` is the notebook-level value defined alongside
# `conformalizers`, so the plotted regions and the uncertainty ranking share one level.
point_indices = [96, 88]

# Instantiate each conformal method once: construction depends only on the calibration data
# and the model, not on the test point, so redoing it for every plotted point is wasted work.
calibrated_methods = {
    name: conformalizer(datamodule.calib_dataloader(), mqf2_model, cache_calib=cache_calib)
    for name, conformalizer in tqdm(conformalizers.items(), desc='calibrating')
}

for idx in point_indices:
    scaled_x_test = scaled_test_X[idx].unsqueeze(0).to(config.device)
    scaled_y_test = scaled_test_Y[idx].unsqueeze(0).to(config.device)
    x_test, y_test = test_X[idx], test_Y[idx]

    current_cache_test = get_cache_test_for_sample(idx, cache_test)

    # Show the selected point together with the four test points that follow it.
    visual_group_X, visual_group_Y = test_X[idx:idx+5], test_Y[idx:idx+5]
    utils.visualize_data_on_map(visual_group_X.numpy(), visual_group_Y.numpy(), idx)

    for name, method in tqdm(calibrated_methods.items()):
        # NOTE(review): the limits passed here are (xlim, zlim), where xlim comes from the
        # first input column; if the contour grid spans the two output dimensions, (ylim, zlim)
        # may have been intended — confirm against utils.get_contour.
        scaled_contour = utils.get_contour(scaled_x_test, method, alpha, xlim, zlim, grid_side=3000, cache=current_cache_test)
        # NOTE(review): n_samples=1000 differs from the n_samples used to build cache_test;
        # confirm this is intended when a cache is supplied.
        region_size = compute_cum_region_size(method, mqf2_model, alpha, scaled_x_test, n_samples=1000, cache_test=current_cache_test)[-1].item()

        utils.show_contours_on_map(name, x_test, y_test, scaled_contour, idx, region_size, datamodule)