In [1]:
import os
import pathlib
import sys
import argparse
import pandas as pd
import torch
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import seaborn as sns
import numpy as np

from datasets import  DATASETS, get_dataset
from architectures import get_architecture

import matplotlib.pyplot as plt

from scipy.stats import norm

from utils import ModelManager
from utils import standard_l2_norm
from tqdm import tqdm
from certify_utils import *
import wandb

import cp.transformations as cp_t
import cp.graph_transformations as cp_gt
from cp.graph_cp import GraphCP
import pickle
from easydict import EasyDict
import re

Torch Graph Models are running on cuda
v16


In [2]:
def run_to_row(run):
    """Flatten one W&B run into a pandas Series of its config plus two summary fields.

    The returned Series is named after the run and gains:
      * ``acc``: the ``acc`` value in the first row of the run history;
      * ``trained_noise``: the noise level parsed from the ``checkpoint`` config
        entry (``noise_<number>``), or NaN when no such component is present.
    """
    row = pd.Series(run.config)
    row.name = run.name
    # Item assignment, not attribute assignment: for a label that does not exist
    # yet, `row.acc = ...` only sets a Python attribute on the Series object, so
    # the value was silently dropped when the rows were concatenated.
    row["acc"] = run.history()["acc"].loc[0]
    # Accept both "noise_0.25" and "noise_1"; missing/unparsable -> NaN instead of crashing.
    match = re.search(r"noise_(\d+(?:\.\d+)?)", str(row.get("checkpoint", "")))
    row["trained_noise"] = float(match.group(1)) if match else float("nan")
    return row


api = wandb.Api()
runs = api.runs("run_name")
all_run = []
for run in runs:
    # Only keep runs that completed.
    if run.state != "finished":
        continue
    all_run.append(run_to_row(run))

In [8]:
# One row per finished run: stack the per-run Series side by side, transpose so
# runs become rows, then move the run names from the index into an "index" column.
configs_df = (
    pd.concat(all_run, axis="columns")
    .transpose()
    .reset_index()
)

In [None]:
i = 0  # row of configs_df to analyse (other runs used: 1, 2)
args = EasyDict(configs_df.iloc[i])  # run config with attribute access: args.alpha, args.n_iters, ...

In [12]:
def load_pkl(path):
    """Deserialize and return the object pickled in the file at ``path``.

    Security: ``pickle`` can execute arbitrary code while loading, so only call
    this on files produced by a trusted process.
    """
    with open(path, mode="rb") as handle:
        return pickle.Unpickler(handle).load()


In [None]:
# Conformal score transformation used to build sc_scores below: "APS" or "TPS".
method = "APS"

In [None]:
# Directory containing y_pred_logits_y_true.pkl; fill in before running.
path_to_logits = ""

In [None]:
# Cached model outputs: predictions, logits and labels.
# NOTE: pickle can execute arbitrary code on load — only load trusted files.
# os.path.join instead of f"{dir}/file": with path_to_logits left empty the
# f-string resolved to the filesystem root ("/y_pred_logits_y_true.pkl").
y_pred, logits, y_true = load_pkl(os.path.join(path_to_logits, "y_pred_logits_y_true.pkl"))
# Downstream code indexes logits[:, s, :] and pairs its first axis with y_true.
assert logits.ndim == 3, f"expected 3-D logits, got shape {tuple(logits.shape)}"
assert logits.shape[0] == y_true.shape[0], "logits and y_true disagree on the number of examples"

In [None]:
acc = ((y_pred == y_true).sum() / y_true.shape[0]).item()
print(f"acc = {acc}")

# Size the one-hot mask from the logits' class axis instead of letting
# F.one_hot infer it from max(y_true) + 1: if the highest class label never
# occurs in y_true, the inferred mask would be narrower than the score matrix.
num_classes = logits.shape[-1]
y_true_mask = F.one_hot(y_true, num_classes=num_classes).bool()

# Per-method score transformation and the constant added to the stacked scores.
# NOTE(review): APS scores are shifted by +1 and TPS scores are not; the reason
# is not visible here (presumably what the bound functions expect) — confirm.
score_transformations = {
    "APS": (cp_t.APSTransformation, 1),
    "TPS": (cp_t.TPSTransformation, 0),
}
if method not in score_transformations:
    # Fail loudly: otherwise cp / sc_scores would silently keep values from a
    # previous execution of this cell (or be undefined on a fresh kernel).
    raise ValueError(f"method must be one of {sorted(score_transformations)}, got {method!r}")
transformation_cls, score_offset = score_transformations[method]

cp = GraphCP(transformation_sequence=[transformation_cls(softmax=True)], coverage_guarantee=0.9)
# Score every slice logits[:, s, :] along axis 1 (presumably one noise sample
# each), stack them on a new leading axis and move that axis last. Assumes
# get_scores_from_logits returns (n, n_classes), so sc_scores is
# (n, n_classes, n_samples) — TODO confirm.
sc_scores = torch.stack(
    [cp.get_scores_from_logits(logits[:, s, :]) for s in range(logits.shape[1])]
).permute(1, 2, 0) + score_offset

# Mean score per (example, class) over the sample axis.
esc_scores = sc_scores.mean(dim=2)


In [17]:
# Evaluate NP- and DKW-bounded prediction sets over a grid of radii (multiples
# of the smoothing sigma) and nominal coverage levels, repeated args.n_iters
# times with a fresh calibration split each time.
# TODO: seed the RNG used by get_cal_mask so the splits are reproducible.
coverages = np.array([0.7, 0.8, 0.85, 0.9, 0.95])
radi_range = np.array([0, 0.25, 0.5, 0.75, 1, 1.5]) * args.smoothing_sigma
print(f"radi_range = {radi_range}")

# The original wrote to "/result.csv" (filesystem root), which is normally not
# writable; write to the working directory instead (adjust as needed).
RESULTS_CSV = pathlib.Path("result.csv")


def _summarize_pred_set(pred_set, true_mask, meta):
    """Return a copy of ``meta`` extended with coverage, average set size and singleton hits."""
    return {
        **meta,
        "coverage": cp.coverage(pred_set, true_mask),
        "set_size": cp.average_set_size(pred_set),
        "singleton_hits": singleton_hit(pred_set, true_mask),
    }


records = []
for radi in tqdm(radi_range):
    np_upper = np_upperbound_tensor(sc_scores, SIGMA=args.smoothing_sigma, radi=radi, n_classes=num_classes, alpha=args.alpha)
    dkw_upper = dkw_upperbound_tensor(sc_scores, SIGMA=args.smoothing_sigma, radi=radi, n_classes=num_classes, alpha=args.alpha)

    # NOTE(review): computed but never evaluated below — an RSCP baseline row appears to be missing.
    rscp_upper = esc_scores + radi / args.smoothing_sigma

    for coverage_level in coverages:
        # Keep args in sync, as the original loop target did (args ends at the last level).
        args.coverage_guarantee = coverage_level
        cp.coverage_guarantee = coverage_level
        for iter_i in range(args.n_iters):
            cal_mask = get_cal_mask(esc_scores, fraction=args.fraction)
            eval_mask = ~cal_mask
            eval_true_mask = y_true_mask[eval_mask]

            # Calibrate on mean scores, predict from the certified upper bounds.
            cp.calibrate_from_scores(esc_scores[cal_mask], y_true_mask[cal_mask])
            np_pred_set = cp.predict_from_scores(np_upper[eval_mask])
            dkw_pred_set = cp.predict_from_scores(dkw_upper[eval_mask])

            for bound_name, pred_set in (("NP", np_pred_set), ("DKW", dkw_pred_set)):
                records.append(_summarize_pred_set(pred_set, eval_true_mask, {
                    "SIGMA": args.smoothing_sigma,
                    "radi": radi,
                    "iter": iter_i,
                    "method": bound_name,
                    "$1-\\alpha$": coverage_level,
                }))

result = pd.DataFrame(records)
result.to_csv(RESULTS_CSV, index=False)

radi_range = [0.    0.125 0.25  0.375 0.5   0.75 ]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [02:53<00:00, 28.89s/it]
