In [8]:
import torch
from collections import OrderedDict
from basicsr.models import create_model, load_finetuned_model
from basicsr.utils.options import parse
import numpy as np
from tqdm import tqdm
from basicsr.data import create_dataloader, create_dataset
from tqdm.notebook import tqdm as log_progress
from torch.utils.data import TensorDataset, DataLoader

In [9]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def fraction_missed_loss(pset, label):
    """Per-sample fraction of target positions that fall outside the predicted interval.

    Args:
        pset: ``(lower, prediction, upper)`` tuple; only ``lower`` and ``upper`` are read.
            In this notebook they come from ``nested_sets_from_output`` with shape
            ``(B, H, W)``.
        label: ground-truth tensor whose channel 1 (``label[:, 1]``) is the target. It must
            have the same number of elements per sample as ``lower`` / ``upper``.

    Returns:
        1-D tensor of length ``B``: for each sample, the fraction of positions where the
        target is strictly below ``lower`` or strictly above ``upper``.
    """
    target = label[:, 1]
    batch = target.shape[0]
    # Flatten every non-batch axis explicitly instead of using ``.squeeze()``: squeeze would
    # also drop a batch axis of size 1 and the mean would then run over the wrong dims.
    target = target.reshape(batch, -1)
    lower = pset[0].reshape(batch, -1)
    upper = pset[2].reshape(batch, -1)
    # Logical OR counts a position once even if both edges are violated (possible when
    # lower > upper, e.g. for a negative lam).
    misses = ((lower > target) | (upper < target)).float()
    return misses.mean(dim=1)

def get_rcps_loss_fn(string):
    """Look up an RCPS loss function by name.

    Raises:
        NotImplementedError: if ``string`` is not a known loss name.
    """
    if string == 'fraction_missed':
        return fraction_missed_loss
    raise NotImplementedError(f"Unknown RCPS loss: {string!r}")

rcps_loss_fn = get_rcps_loss_fn('fraction_missed')

In [None]:
def inn_nested_sets_from_output(model, output, lam=None):
    """Build a nested ``(lower, prediction, upper)`` interval from a 3-channel output.

    Channels 0, 1 and 2 of ``output`` (indexed as ``output[:, c, :, :]``, so a 4-D tensor
    is expected) are read as lower edge, point prediction and upper edge. The raw edges
    are first forced to lie strictly around the prediction (by 1e-6), then their distance
    to the prediction is scaled by ``lam``.

    Args:
        model: unused; kept so existing call sites keep working.
        output: model output tensor. It is not modified.
        lam: scale applied to both half-widths. Required.

    Returns:
        ``(lower_edge, prediction, upper_edge)``, each of shape ``(B, H, W)``.

    Raises:
        TypeError: if ``lam`` is None (no calibrated default is available here).
    """
    if lam is None:
        raise TypeError("`lam` must be given explicitly; no calibrated default is available.")
    prediction = output[:, 1, :, :]
    # Clamp out-of-place so the caller's `output` tensor is left untouched.
    lower = torch.minimum(output[:, 0, :, :], prediction - 1e-6)
    upper = torch.maximum(output[:, 2, :, :], prediction + 1e-6)
    upper_edge = lam * (upper - prediction) + prediction
    lower_edge = prediction - lam * (prediction - lower)
    return lower_edge, prediction, upper_edge

def nested_sets_from_output(output, lam=None):
    """Convenience wrapper around ``inn_nested_sets_from_output``.

    ``inn_nested_sets_from_output`` ignores its ``model`` argument, so ``None`` is passed
    rather than depending on a notebook-global ``model`` defined in a later cell.
    """
    return inn_nested_sets_from_output(None, output, lam)

def get_rcps_losses_from_outputs(model, out_dataset, rcps_loss_fn, lam, device, batch_size=64):
    """Evaluate ``rcps_loss_fn`` at interval scale ``lam`` for every sample of ``out_dataset``.

    Args:
        model: moved to ``device`` (kept for compatibility; not otherwise used).
        out_dataset: dataset yielding ``(output, label)`` pairs, e.g. a ``TensorDataset``.
        rcps_loss_fn: callable ``(sets, labels) -> per-sample losses``.
        lam: interval scale passed to ``nested_sets_from_output``.
        device: device on which the sets and losses are computed.
        batch_size: DataLoader batch size (default 64).

    Returns:
        The per-batch losses concatenated along dim 0 on CPU, in dataset order
        (an empty tensor if the dataset is empty).
    """
    dataloader = DataLoader(out_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=False)
    model.to(device)
    losses = []
    with torch.no_grad():
        for x, labels in dataloader:
            # Outputs and labels must be on the same device for the loss comparisons.
            sets = nested_sets_from_output(x.to(device), lam)
            losses.append(rcps_loss_fn(sets, labels.to(device)).cpu())
    if not losses:
        return torch.empty(0)
    return torch.cat(losses, dim=0)

In [7]:
def get_lhat(calib_loss_table, lambdas, alpha, B=1):
    """Pick the calibrated lambda from a table of calibration losses.

    Args:
        calib_loss_table: tensor ``(n, len(lambdas))``; entry ``[i, j]`` is the loss of
            calibration sample ``i`` at ``lambdas[j]``.
        lambdas: 1-D tensor of candidate lambdas. Assumed sorted ascending (as produced by
            ``torch.linspace`` in this notebook), with mean loss non-increasing in lambda.
            TODO: confirm this for any new loss.
        alpha: target risk level.
        B: upper bound on the per-sample loss, as in conformal risk control (default 1;
            ``fraction_missed_loss`` lies in [0, 1]).

    Returns:
        0-d tensor. Scanning lambdas from largest to smallest, this is the last lambda
        before the corrected risk ``n/(n+1) * rhat + B/(n+1)`` first reaches ``alpha``.
        If even the largest lambda reaches ``alpha``, the largest lambda is returned.
        If no lambda reaches ``alpha``, the smallest lambda is returned.
    """
    n = calib_loss_table.shape[0]
    # Reverse so that index 0 is the largest lambda.
    rhat_desc = torch.flip(calib_loss_table.mean(dim=0), dims=(0,))
    lambdas_desc = torch.flip(lambdas, dims=(0,))
    corrected_risk = (n / (n + 1)) * rhat_desc + B / (n + 1)
    exceeds = corrected_risk >= alpha
    if not bool(exceeds.any()):
        # Every candidate satisfies the bound, so the smallest lambda is valid.
        return lambdas_desc[-1]
    first_exceed = int(torch.nonzero(exceeds)[0].item())
    # Step back to the last lambda that still satisfied the bound; clamp at index 0
    # when even the largest lambda fails.
    return lambdas_desc[max(first_exceed - 1, 0)]

In [11]:
# Candidate scale factors for the interval half-widths; calibration picks one of these.
lambdas = torch.linspace(0, 10, 500)

# Placeholders: assign the real tensors / model here before running this cell.
labels = None   # Ground Truth Data -- channel 1 (labels[:, 1]) is what the loss compares against
outputs = None  # Noisy Data -- channels 0/1/2 are used as lower / prediction / upper
model = None    # Define the model of choice
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if labels is None or outputs is None or model is None:
    raise ValueError("Fill in `labels`, `outputs` and `model` before running calibration.")

out_dataset = TensorDataset(outputs, labels)
dlambda = lambdas[1] - lambdas[0]

# One row per calibration sample, one column per candidate lambda.
calib_loss_table = torch.zeros((outputs.shape[0], lambdas.shape[0]))
for col, lam in enumerate(log_progress(lambdas)):
    # Column `col` holds the loss evaluated one grid step *below* lambdas[col].
    # NOTE(review): presumably a deliberately conservative offset -- confirm.
    calib_loss_table[:, col] = get_rcps_losses_from_outputs(
        model, out_dataset, rcps_loss_fn, lam - dlambda, device
    )

Output dataset
Computing losses


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [12]:
# Target risk: bound on the mean fraction of target pixels missed by the intervals.
ALPHA = 0.1
# Keep the calibrated value so later cells can use it without recomputing.
lhat = get_lhat(calib_loss_table, lambdas, alpha=ALPHA)
print("lhat = ", lhat.numpy())

lhat =  3.3066132
