In [None]:
from disentangle.metrics.calibration import Calibration
# from disentangle.metrics.calibration import get_calibrated_factor_for_stdev
from disentangle.analysis.paper_plots import get_first_index, get_last_index

def get_calibration_stats(calibration_factors, pred, pred_std, tar_normalized):
    """
    Compute calibration statistics for an affinely corrected predicted std.

    The predicted std is transformed as
    `pred_std * calibration_factors['scalar'] + calibration_factors['offset']`
    and passed, together with `pred` and `tar_normalized`, to
    `Calibration(num_bins=30).compute_stats`. Whatever that call returns is
    returned unchanged.
    """
    scale = calibration_factors['scalar']
    shift = calibration_factors['offset']
    calibrated_std = pred_std * scale + shift
    return Calibration(num_bins=30).compute_stats(pred, calibrated_std, tar_normalized)

# def get_calibration_factor(pred, pred_std, tar_normalized, epochs = 300, lr = 160.0, eps= 1e-8):
#     calib_dicts = []
#     for col_idx in range(pred.shape[-1]):

#         calib_dict = get_calibrated_factor_for_stdev(pred[...,col_idx], pred_std[...,col_idx], tar_normalized[...,col_idx], 
#                                                           lr=lr, epochs=epochs)
#         calib_dicts.append(calib_dict)
    
#     return calib_dicts


def plot_calib(ax, calib_stats, ch_idx, color='b', q_s=0.00001, q_e=0.99999):
    """
    Plot the calibration curve of one channel on `ax`.

    Args:
        ax: matplotlib axes to draw on.
        calib_stats: dict with a 'calib_stats' entry; `calib_stats['calib_stats'][ch_idx]`
            must provide the sequences 'rmv', 'rmse', 'bin_count' and 'rmse_err'.
        ch_idx: key/index of the channel to plot (also used in the legend label).
        color: colour of the curve.
        q_s: value forwarded to `get_first_index` to pick the first plotted bin.
        q_e: value forwarded to `get_last_index` to pick how many trailing bins to drop.

    The curve plots 'rmv' on the x-axis against 'rmse' on the y-axis, with a shaded band of
    +/- 'rmse_err' around it, and a dashed black identity line spanning the axis limits.
    The x/y limits are printed.
    """
    ch_stats = calib_stats['calib_stats'][ch_idx]
    rmv = ch_stats['rmv']
    rmse = ch_stats['rmse']
    count = ch_stats['bin_count']
    rmse_err = ch_stats['rmse_err']

    first_idx = get_first_index(count, q_s)
    last_idx = get_last_index(count, q_e)
    # `last_idx` is used as a count of trailing bins to drop. A literal `[first_idx:-last_idx]`
    # turns into `[first_idx:0]` (an empty slice) when `last_idx == 0`, so nothing would be
    # plotted; use an open-ended slice in that case.
    stop = -last_idx if last_idx != 0 else None
    keep = slice(first_idx, stop)

    rmv_kept = rmv[keep]
    rmse_kept = np.array(rmse[keep])
    err_kept = np.array(rmse_err[keep])

    # plot the calibration curve with error bars
    ax.plot(rmv_kept,
            rmse[keep],
            '-+',
            label='C{}'.format(ch_idx),
            color=color,
            )
    ax.fill_between(rmv_kept, rmse_kept - err_kept, rmse_kept + err_kept, alpha=0.3, label='error band')

    # enable the grid and set the background color to gray
    ax.grid(True)
    ax.set_facecolor('0.75')
    # draw the identity line over the common range of the current x and y limits
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    minv = min(xlim[0], ylim[0])
    maxv = max(xlim[1], ylim[1])
    ax.plot([minv, maxv], [minv, maxv], '--', color='black')
    print('xlim:', xlim)
    print('ylim:', ylim)


def l2_fitting(pred, tar, std):
    """
    Mean squared difference between the element-wise error magnitude and `std`.

    The element-wise error is computed as the square root of the squared
    difference between `pred` and `tar`; the returned scalar tensor is the
    mean squared error between that quantity and `std`.
    """
    squared_err = torch.nn.functional.mse_loss(pred, tar, reduction='none')
    err_magnitude = torch.sqrt(squared_err)
    return torch.nn.functional.mse_loss(err_magnitude, std, reduction='mean')

In [None]:
import numpy as np 
# Load the saved prediction, predicted-std and normalized-target arrays from disk.
# NOTE(review): these are absolute, user-specific paths, so this cell only runs on a machine
# that has the same directory layout.
pred = np.load('/home/ashesh.ashesh/code/Disentangle/pred_MMSE50.npy')
pred_std = np.load('/home/ashesh.ashesh/code/Disentangle/pred_std_MMSE50.npy')
tar_normalized = np.load('/home/ashesh.ashesh/code/Disentangle/tar_normalized_MMSE50.npy')

In [None]:
# Show the dimensions of the loaded prediction array.
pred.shape

In [None]:
# Fit the std calibration factors on the full arrays.
# NOTE(review): `get_calibrated_factor_for_stdev` is only defined in a later cell of this
# notebook (its import at the top is commented out), so this cell fails when the notebook is
# run top-to-bottom on a fresh kernel.
# NOTE(review): the following cell indexes the result with `ch_idx`, while the function defined
# in this notebook returns a single dict with keys such as 'scalar' and 'offset' — confirm
# which version of the function this cell is meant to use.
calib_factor_dict = get_calibrated_factor_for_stdev(pred, pred_std, tar_normalized)

In [None]:
import matplotlib.pyplot as plt
# Channel whose calibration is inspected; the arrays are sliced with `ch_idx:ch_idx+1` so the
# last axis is kept with length 1.
ch_idx = 2
stats_dict = get_calibration_stats(calib_factor_dict[ch_idx], pred[...,ch_idx:ch_idx+1], pred_std[...,ch_idx:ch_idx+1], tar_normalized[...,ch_idx:ch_idx+1])
_,ax = plt.subplots()
# `plot_calib` looks up `['calib_stats'][0]`, i.e. entry 0 of `stats_dict`.
plot_calib(ax, {'calib_stats':stats_dict}, 0, color='b', q_s = 0.00001,q_e =0.99999)

In [None]:
from scipy import stats
# Fit a straight line to the trimmed calibration curve ('rmv' -> 'rmse') of entry 0 of
# `stats_dict` and print its slope and intercept.
q_s = 0.00001
q_e = 0.99999
y = stats_dict[0]['rmse']
x = stats_dict[0]['rmv']
count = stats_dict[0]['bin_count']
# rmse_err = tmp_stats['rmse_err']

first_idx = get_first_index(count, q_s)
last_idx = get_last_index(count, q_e)
# `last_idx` is used as a count of trailing bins to drop. `[first_idx:-last_idx]` would be
# `[first_idx:0]` (empty) when `last_idx == 0`, so use an open-ended slice in that case.
stop_idx = -last_idx if last_idx != 0 else None
x = x[first_idx:stop_idx]
y = y[first_idx:stop_idx]
slope, intercept, r_value, p_value, std_err = stats.linregress(x,y)
print('slope:', slope)
print('intercept:', intercept)

In [None]:
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from disentangle.metrics.calibration import nll

def get_batch_mask(bin_masks, batch_size):
    """
    Build a mask that marks a random sample drawn evenly from every bin.

    For each mask in `bin_masks`, `batch_size // len(bin_masks)` positions are drawn
    with replacement (via `np.random.choice`) from the first-axis indices at which
    that mask is set. The returned array is created with `np.zeros_like(bin_masks[0])`
    and has the drawn positions set to 1. A position drawn more than once is only
    marked once, so the number of marked entries can be smaller than `batch_size`.
    """
    per_bin = batch_size // len(bin_masks)
    drawn = np.concatenate([
        np.random.choice(np.where(bin_mask)[0], size=per_bin, replace=True)
        for bin_mask in bin_masks
    ])
    batch_mask = np.zeros_like(bin_masks[0])
    batch_mask[drawn] = 1
    return batch_mask

def get_binned_masks(target, q_low=0.000001, q_high=0.999999, bincount=50):
    """
    Split `target` into value bins and return one boolean mask per non-empty bin.

    Args:
        target: array of target values.
        q_low: lower quantile of `target` used as the first bin edge.
        q_high: upper quantile of `target` used as the last bin edge.
        bincount: number of equally spaced bin *edges* between the two quantiles,
            i.e. `bincount - 1` candidate bins. Defaults to 50, the value that was
            previously hard-coded.

    Returns:
        List of boolean arrays shaped like `target`. Each bin is half-open,
        `edge[i] <= target < edge[i+1]`, so values below the lower quantile and values
        at or above the upper quantile belong to no mask. Bins without any element
        are skipped.
    """
    vlow, vmax = np.quantile(target, [q_low, q_high])
    edges = np.linspace(vlow, vmax, bincount)
    bin_masks = []
    for lower, upper in zip(edges[:-1], edges[1:]):
        mask = np.logical_and(target >= lower, target < upper)
        if mask.any():
            bin_masks.append(mask)
    return bin_masks


def get_calibrated_factor_for_stdev(pred, pred_std, target, batch_size=32*(512**2), epochs=500, lr=0.01, q_low=0.000001, q_high=0.999999):
    """
    Fit an affine correction `pred_std * scalar + offset` for the predicted standard deviation.

    The two parameters are optimised with Adam so that the corrected std matches the
    element-wise prediction error (see `l2_fitting`). Each step uses a batch drawn evenly
    from target-value bins (`get_binned_masks` / `get_batch_mask`).

    Args:
        pred: array of predictions; flattened internally.
        pred_std: array of predicted standard deviations; flattened internally.
        target: array of targets; flattened internally.
        batch_size: requested number of elements per optimisation step.
        epochs: maximum number of optimisation steps.
        lr: initial Adam learning rate.
        q_low, q_high: quantiles of `target` delimiting the binned value range.

    Returns:
        dict with
            'scalar': fitted multiplicative factor for the std,
            'offset': fitted additive offset for the std,
            'loss':   list with the loss of every step,
            'vlow', 'vmax': the `q_low` / `q_high` quantiles of `target`.

    Raises:
        ValueError: if no non-empty target bin exists (nothing to sample from).
    """
    import torch
    from tqdm import tqdm

    pred = pred.reshape(-1)
    pred_std = pred_std.reshape(-1)
    target = target.reshape(-1)

    # These bounds are reported in the output. They used to be referenced without being
    # defined, which raised a NameError when building the result.
    vlow, vmax = np.quantile(target, [q_low, q_high])
    bin_masks = get_binned_masks(target, q_low, q_high)
    if len(bin_masks) == 0:
        raise ValueError('No non-empty target bin found; cannot sample batches for calibration.')

    # Use the GPU when present, otherwise fall back to the CPU instead of failing on `.cuda()`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create the learnable scalar and offset
    std_scalar = torch.nn.Parameter(torch.tensor(2.0))
    std_offset = torch.nn.Parameter(torch.tensor(0.0))
    optimizer = torch.optim.Adam([std_scalar, std_offset], lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=50)
    loss_arr = []
    # tqdm with text description as loss
    bar = tqdm(range(epochs))
    for _ in bar:
        optimizer.zero_grad()
        mask = get_batch_mask(bin_masks, batch_size)
        pred_batch = torch.Tensor(pred[mask]).to(device)
        pred_std_batch = torch.Tensor(pred_std[mask]).to(device) * std_scalar + std_offset
        target_batch = torch.Tensor(target[mask]).to(device)

        loss = l2_fitting(pred_batch, target_batch, pred_std_batch)
        loss.backward()
        loss_arr.append(loss.item())
        optimizer.step()
        scheduler.step(loss)
        # stop once the scheduler has reduced the learning rate below 1e-4
        if optimizer.param_groups[0]['lr'] < 1e-4:
            break
        # The label reads 'nll', but the displayed value is the running mean of the
        # l2-fitting loss computed above.
        bar.set_description(f'nll: {np.mean(loss_arr[-10:])} scalar: {std_scalar.item()} offset: {std_offset.item()}')

    output = {'scalar': std_scalar.item(),
              'offset': std_offset.item(),
              'loss': loss_arr,
              'vlow': float(vlow),
              'vmax': float(vmax),
              }

    return output

In [None]:
# Build the per-bin boolean masks from the first entry along the last axis of the targets.
masks = get_binned_masks(tar_normalized[...,0], q_low=0.000001, q_high=0.999999)

In [None]:
from tqdm import tqdm
# Sanity check: every pair of bin masks must be disjoint (no element belongs to two bins).
for i in tqdm(range(len(masks))):
    for j in range(i+1, len(masks)):
        assert np.logical_and(masks[i], masks[j]).sum() ==0


In [None]:
# Plot the mean of the target values selected by each bin mask.
# NOTE(review): the masks were built from `tar_normalized[...,0]` but are applied here to the
# full `tar_normalized`, so the mean presumably runs over every entry of the last axis, not
# only entry 0 — confirm this is intended.
plt.plot([tar_normalized[mask].mean() for mask in masks])

In [None]:
# Fit scalar/offset for entry 0 of the last axis, overriding the default lr and epochs.
output = get_calibrated_factor_for_stdev(pred[...,0], pred_std[...,0], tar_normalized[...,0], lr=0.1, epochs=200)

In [None]:
def balanced_nll(tar, pred, predlogvar, bins=50):
    """
    Average of the per-bin mean `nll`, with bins taken over the target values.

    The range [min(tar), max(tar)] is split by `bins` equally spaced edges (`bins - 1`
    bins). For every non-empty bin the mean of `nll(tar, pred, predlogvar)` over its
    elements is computed, and the unweighted mean of those per-bin values is returned,
    so every populated bin contributes equally regardless of its size.

    Args:
        tar: target tensor.
        pred: prediction tensor, indexable by a boolean mask shaped like `tar`.
        predlogvar: predicted log-variance tensor, indexable the same way.
        bins: number of bin edges.

    Returns:
        numpy float with the mean over the populated bins.
    """
    vmin, vmax = torch.min(tar), torch.max(tar)
    edges = torch.linspace(vmin, vmax, bins)
    n_bins = edges.shape[0] - 1
    nll_vals = []
    for i in range(n_bins):
        in_bin = tar >= edges[i]
        if i < n_bins - 1:
            in_bin = torch.logical_and(in_bin, tar < edges[i + 1])
        # The last bin is closed on the right: with half-open bins everywhere, elements
        # equal to max(tar) would not fall into any bin and would be silently ignored.
        if not in_bin.any():
            continue
        nll_vals.append(nll(tar[in_bin], pred[in_bin], predlogvar[in_bin]).mean().item())
    return np.mean(nll_vals)


def balanced_l2fitting(tar, pred, pred_std, bins=50):
    """
    Average of the per-bin `l2_fitting` loss, with bins taken over the target values.

    The range [min(tar), max(tar)] is split by `bins` equally spaced edges (`bins - 1`
    bins). For every non-empty bin `l2_fitting` is evaluated on the elements of that bin,
    and the unweighted mean of those per-bin values is returned, so every populated bin
    contributes equally regardless of its size.

    Args:
        tar: target tensor.
        pred: prediction tensor, indexable by a boolean mask shaped like `tar`.
        pred_std: predicted std tensor, indexable the same way.
        bins: number of bin edges.

    Returns:
        numpy float with the mean over the populated bins.
    """
    vmin, vmax = torch.min(tar), torch.max(tar)
    edges = torch.linspace(vmin, vmax, bins)
    n_bins = edges.shape[0] - 1
    losses = []
    for i in range(n_bins):
        in_bin = tar >= edges[i]
        if i < n_bins - 1:
            in_bin = torch.logical_and(in_bin, tar < edges[i + 1])
        # The last bin is closed on the right: with half-open bins everywhere, elements
        # equal to max(tar) would not fall into any bin and would be silently ignored.
        if not in_bin.any():
            continue
        # `l2_fitting` is symmetric in its first two arguments; pass them in its declared
        # (pred, tar, std) order.
        losses.append(l2_fitting(pred[in_bin], tar[in_bin], pred_std[in_bin]).mean().item())
    return np.mean(losses)




In [None]:
# Evaluate the balanced l2-fitting loss for a hand-picked scalar/offset pair on entry 0 of
# the last axis.
# NOTE(review): 11 and 0.8 are unexplained manual values — document where they come from.
factor = 11
offset = 0.8
# pred_logvar = 2 * torch.log(torch.Tensor(pred_std[...,0]*factor + offset))
# loss = balanced_nll(torch.Tensor(tar_normalized[...,0]), torch.Tensor(pred[...,0]), pred_logvar)
loss = balanced_l2fitting(torch.Tensor(tar_normalized[...,0]), torch.Tensor(pred[...,0]), torch.Tensor(pred_std[...,0]*factor + offset))
loss

In [None]:
# List the entries of the fitted calibration result.
output.keys()

In [None]:
import matplotlib.pyplot as plt
# Recompute the calibration statistics with the fitted scalar/offset and plot the curve.
# NOTE(review): `output` was fitted on `[...,0]`, while `ch_idx` was set to 2 in an earlier
# cell, so the factors appear to be applied to a different channel than they were fitted
# on — confirm.
# NOTE(review): the name `stats` is also used for the `scipy.stats` import in an earlier cell.
stats = get_calibration_stats(output, pred[...,ch_idx:ch_idx+1], pred_std[...,ch_idx:ch_idx+1], tar_normalized[...,ch_idx:ch_idx+1])
_,ax = plt.subplots()
plot_calib(ax, {'calib_stats':stats}, 0, color='b')