In [17]:
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
#
"""
File for running all computer vision experiments.
"""
import argparse
from collections import defaultdict
import math
import os
from re import sub

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, LBFGS, SGD
import tqdm

# Our new algorithm
from online_conformal.magnitude_learner import MagnitudeLearner, MagnitudeLearnerV2
from online_conformal.mag_learner_undiscounted import MagLearnUndiscounted
from online_conformal.ogd_simple import SimpleOGD

# From previous work
from online_conformal.saocp import SAOCP
from online_conformal.faci import FACI, FACI_S
from online_conformal.nex_conformal import NExConformal
from online_conformal.ogd import ScaleFreeOGD
from online_conformal.split_conformal import SplitConformal
from online_conformal.utils import pinball_loss
from cv_utils import create_model, data_loader
from cv_utils import ImageNet, TinyImageNet, CIFAR10, CIFAR100, ImageNetC, TinyImageNetC, CIFAR10C, CIFAR100C

import time

from vision import *

In [18]:
# Define `__file__` by hand at notebook scope, pointing at this notebook's name.
# NOTE(review): presumably needed because some code builds paths from `__file__`,
# which a notebook kernel does not normally provide — confirm which code reads it.
__file__ = "vision.ipynb"

In [19]:

# Stage 1 of the pipeline: parse the run configuration and train the model if needed.
# (Saving logits and temperature scaling happen in the stage that follows.)
print("Train the model, save its logits on all the corrupted test datasets, and do temperature scaling")
args = parse_args()

# Training is skipped when `finished(args)` is true or when the dataset is ImageNet.
# `finished(args)` is evaluated first, exactly as before.
skip_training = finished(args) or args.dataset == "ImageNet"
if not skip_training:
    print("Training models")
    train(args)
    print("Finished training")
# Stage 2: compute/load the saved logits, then benchmark every conformal
# prediction method (coverage, width, regret and normalized runtime).
# Only runs for local_rank -1 or 0.
# NOTE(review): presumably -1 means non-distributed and 0 the main process — confirm.
if args.local_rank in [-1, 0]:
    print("Getting temp file...")
    temp_file = get_temp_file(args)
    print("Done")
    if not finished(args):
        # Produce the logits and the temperature, and persist the temperature as text.
        print("get_logits(args)")
        get_logits(args)
        print("...Done")
        temp = temperature_scaling(args)
        with open(temp_file, "w") as f:
            f.write(str(temp))

    # Load the saved logits
    print("Load the saved logits")
    with open(temp_file) as f:
        temp = float(f.readline())
    n_data = None  # smallest number of examples over all loaded result files
    sev2results = defaultdict(list)  # severity -> list of [(probs, label), ...], one list per corruption
    for corruption in corruptions:
        # The uncorrupted data (corruption is None) is stored under severity 0.
        severities = [0] if corruption is None else [1, 2, 3, 4, 5]
        for severity in severities:
            try:
                logits, labels = torch.load(get_results_file(args, corruption, severity))
            except Exception:
                # Best-effort: skip result files that are missing or unreadable.
                # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still propagate.
                continue
            sev2results[severity].append(list(zip(F.softmax(logits / temp, dim=-1).numpy(), labels.numpy())))
            n_data = len(labels) if n_data is None else min(n_data, len(labels))

    if n_data is None:
        # Without this guard the run fails later with an obscure error in the sampling code.
        raise RuntimeError("No saved logits could be loaded; cannot run the conformal prediction benchmark.")

    # Initialize conformal prediction methods, along with accumulators for results
    print("Initialize conformal prediction methods, along with accumulators for results")
    lmbda, k_reg, n_class = raps_params(args.dataset)
    # D is handed to every predictor as `max_scale`.
    # NOTE(review): appears to be an upper bound on the scores computed below
    # (cumulative probability <= 1 plus the regularization term) — confirm.
    D_old = 1 + lmbda * np.sqrt(n_class - k_reg)
    D = 1*D_old
    methods_iterate = [SimpleOGD, MagnitudeLearnerV2, SplitConformal, NExConformal, FACI, ScaleFreeOGD, FACI_S, SAOCP, MagnitudeLearner, MagLearnUndiscounted]
    label2err = defaultdict(list)
    h = 5 + 0.5 * (len(methods_iterate) > 5)  # not used in this cell; kept in case later cells read it
    np.random.seed(0)
    num_trials = 10
    elapsed_time = np.zeros(num_trials,)  # per-trial runtime of the method currently being benchmarked
    elasped_time = elapsed_time  # legacy (misspelled) alias to the same array, kept for any later cells
    time_simpleogd = np.zeros_like(elapsed_time)  # per-trial runtime of the SimpleOGD baseline
    mean_time_norm = np.zeros(len(methods_iterate))
    std_time_norm = np.zeros(len(methods_iterate))
    # Methods are benchmarked one at a time so each one gets its own timing.
    # NOTE: runtimes are normalized by SimpleOGD's mean runtime, which is only
    # recorded when SimpleOGD itself runs, so it must come first in `methods_iterate`.
    for i_method, method_ in enumerate(methods_iterate):
        methods = [method_]
        print("Current method: ", method_.__name__)
        for jj in range(num_trials):
            for i_shift, shift in enumerate(["sudden"]):
                sevs, s_opts, w_opts = [], [], []
                # warmup, window, run_length = 1000, 100, 500 # original code
                warmup, window, run_length = 1000, 100, 1000 # our code
                # Fixed seed: every trial (and every method) sees the same data stream.
                state = np.random.RandomState(0)
                order = state.permutation(n_data)[: 6 * run_length + window // 2 + warmup]
                coverages, s_hats, widths = [{m.__name__: [] for m in methods} for _ in range(3)]
                predictors = [m(None, None, max_scale = D, lifetime = 32, coverage = args.target_cov) for m in methods]
                t_vec = np.zeros(len(order))
                # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
                # The timed region covers score computation as well as the predictor calls.
                start_time = time.perf_counter()
                # t starts at -warmup; steps with t < 0 only update the predictors.
                for t, i in tqdm.tqdm(enumerate(order, start=-warmup), total=len(order)):
                    # Get saved results for the desired severity
                    sev = t_to_sev(t, window=window, schedule = shift)
                    probs, label = sev2results[sev][state.randint(0, len(sev2results[sev]))][i]

                    # Convert probability to APS score
                    i_sort = np.flip(np.argsort(probs))
                    p_sort_cumsum = np.cumsum(probs[i_sort]) - state.rand() * probs[i_sort]
                    s_sort_cumsum = p_sort_cumsum + lmbda * np.sqrt(np.cumsum([i > k_reg for i in range(n_class)]))
                    w_opt = np.argsort(i_sort)[label] + 1  # 1-based rank of the true label
                    s_opt = s_sort_cumsum[w_opt - 1]  # score at the true label's rank
                    if t >= 0:
                        sevs.append(sev)
                        s_opts.append(s_opt)
                        w_opts.append(w_opt)
                        t_vec[t] = t

                    # Update all the conformal predictors
                    for predictor in predictors:
                        name = type(predictor).__name__
                        if t >= 0:
                            # Predict before updating, then record set size and coverage.
                            _, s_hat = predictor.predict(horizon=1)
                            w = np.sum(s_sort_cumsum <= s_hat)
                            s_hats[name].append(s_hat)
                            widths[name].append(w)
                            coverages[name].append(w >= w_opt)
                        predictor.update(ground_truth=pd.Series([s_opt]), forecast=pd.Series([0]), horizon=1)
                elapsed_time[jj] = time.perf_counter() - start_time
                if method_.__name__ == "SimpleOGD":
                    time_simpleogd[jj] = elapsed_time[jj]
                # Perform evaluation & produce a pretty graph
                plot_loss = False

                s_opts = np.asarray(s_opts)
                # Rolling target-coverage quantile of the true scores over each window.
                int_q = pd.Series(s_opts).rolling(window).quantile(args.target_cov).dropna()
                # Summary statistics are only reported once, after the final trial.
                if jj == num_trials-1:
                    for i, m in enumerate(methods):
                        # Compute various summary statistics
                        name = m.__name__ # name of the method (OGD, SAOCP, etc.)
                        label = sub("Split", "S", sub("Conformal", "CP", sub("ScaleFree", "SF-", sub("_", "-", name))))
                        s_hat = np.asarray(s_hats[name])
                        int_cov = gaussian_filter1d(pd.Series(coverages[name]).rolling(window).mean().dropna(), sigma=3)
                        int_w = pd.Series(s_hats[name] if plot_loss else widths[name]).rolling(window).mean().dropna()
                        int_losses = pd.Series(pinball_loss(s_opts, s_hat, args.target_cov)).rolling(window).mean().dropna()
                        # Pinball loss of each window's rolling quantile, used as the regret reference.
                        opts = [pinball_loss(s_opts[i : i + window], q, args.target_cov).mean() for i, q in enumerate(int_q)]
                        int_regret = int_losses.values - np.asarray(opts)
                        int_miscov = np.abs(args.target_cov - int_cov)

                        # Record the results
                        label2err[label].append(f"{np.max(int_miscov):.2f}")
                        # Runtime mean/std over all trials, normalized by the SimpleOGD mean.
                        mean_time_simpleogd = np.mean(time_simpleogd)
                        mean_time_norm[i_method] = (np.mean(elapsed_time))/mean_time_simpleogd
                        std_time_norm[i_method] = (np.std(elapsed_time))/mean_time_simpleogd
                        print(
                            f"{name:15}: "
                            f"Cov: {np.mean(coverages[name]):.3f}, "
                            f"Avg Width: {np.mean(widths[name]):.1f}, "
                            f"Avg Miscov: {np.mean(int_miscov):.3f}, "
                            f"Avg Regret: {np.mean(int_regret):.4f}, "
                            f"Avg. Runtime (n={num_trials}): {mean_time_norm[i_method]:.6f}, "
                            f"Std. Runtime (n={num_trials}): {std_time_norm[i_method]:.6f}, "
                            f"LCEk: {np.max(int_miscov):.2f},"
                        )

Train the model, save its logits on all the corrupted test datasets, and do temperature scaling
Getting temp file...
Done
Load the saved logits


Initialize conformal prediction methods, along with accumulators for results
Current method:  SimpleOGD


100%|██████████| 7050/7050 [00:01<00:00, 4403.13it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4064.49it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4088.26it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4238.50it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4399.64it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4456.39it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4249.87it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4020.52it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4121.32it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4305.31it/s]


SimpleOGD      : Cov: 0.899, Avg Width: 125.0, Avg Miscov: 0.018, Avg Regret: 0.0013, Avg. Runtime (n=10): 1.000000, Std. Runtime (n=10): 0.034854, LCEk: 0.11,
Current method:  MagnitudeLearnerV2


100%|██████████| 7050/7050 [00:01<00:00, 4292.18it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4323.20it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3964.38it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3820.18it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4229.92it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4262.42it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4124.75it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4241.30it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4260.10it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4135.90it/s]


MagnitudeLearnerV2: Cov: 0.888, Avg Width: 122.1, Avg Miscov: 0.018, Avg Regret: 0.0034, Avg. Runtime (n=10): 1.016792, Std. Runtime (n=10): 0.038785, LCEk: 0.07,
Current method:  SplitConformal


100%|██████████| 7050/7050 [00:02<00:00, 2493.43it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2688.52it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2701.89it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2721.31it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2745.08it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2756.88it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2725.19it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2733.57it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2749.21it/s]
100%|██████████| 7050/7050 [00:02<00:00, 2699.94it/s]


SplitConformal : Cov: 0.843, Avg Width: 129.5, Avg Miscov: 0.114, Avg Regret: 0.0059, Avg. Runtime (n=10): 1.566431, Std. Runtime (n=10): 0.044899, LCEk: 0.47,
Current method:  NExConformal


100%|██████████| 7050/7050 [00:03<00:00, 1884.50it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1934.71it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1923.87it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1895.78it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1909.47it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1918.12it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1892.42it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1899.85it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1891.06it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1935.38it/s]


NExConformal   : Cov: 0.892, Avg Width: 123.0, Avg Miscov: 0.026, Avg Regret: 0.0012, Avg. Runtime (n=10): 2.215250, Std. Runtime (n=10): 0.020513, LCEk: 0.14,
Current method:  FACI


100%|██████████| 7050/7050 [00:04<00:00, 1451.46it/s]
100%|██████████| 7050/7050 [00:04<00:00, 1423.03it/s]
100%|██████████| 7050/7050 [00:04<00:00, 1442.24it/s]
100%|██████████| 7050/7050 [00:04<00:00, 1413.70it/s]
100%|██████████| 7050/7050 [00:04<00:00, 1418.65it/s]
100%|██████████| 7050/7050 [00:04<00:00, 1448.80it/s]
100%|██████████| 7050/7050 [00:05<00:00, 1401.68it/s]
100%|██████████| 7050/7050 [00:04<00:00, 1442.92it/s]
100%|██████████| 7050/7050 [00:05<00:00, 1339.83it/s]
100%|██████████| 7050/7050 [00:05<00:00, 1404.66it/s]


FACI           : Cov: 0.889, Avg Width: 123.4, Avg Miscov: 0.025, Avg Regret: 0.0011, Avg. Runtime (n=10): 2.980709, Std. Runtime (n=10): 0.068055, LCEk: 0.12,
Current method:  ScaleFreeOGD
[]


100%|██████████| 7050/7050 [00:01<00:00, 4034.13it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 3888.80it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 4096.34it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 3918.86it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 3918.83it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 4237.53it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 4153.99it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 3869.07it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 4201.63it/s]


[]


100%|██████████| 7050/7050 [00:01<00:00, 4101.97it/s]


ScaleFreeOGD   : Cov: 0.899, Avg Width: 124.5, Avg Miscov: 0.017, Avg Regret: 0.0013, Avg. Runtime (n=10): 1.047247, Std. Runtime (n=10): 0.033382, LCEk: 0.09,
Current method:  FACI_S


100%|██████████| 7050/7050 [00:03<00:00, 1836.68it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1929.84it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1998.93it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1881.86it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1906.59it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1992.38it/s]
100%|██████████| 7050/7050 [00:03<00:00, 2006.47it/s]
100%|██████████| 7050/7050 [00:03<00:00, 1933.57it/s]
100%|██████████| 7050/7050 [00:03<00:00, 2040.66it/s]
100%|██████████| 7050/7050 [00:03<00:00, 2010.04it/s]


FACI_S         : Cov: 0.892, Avg Width: 124.7, Avg Miscov: 0.024, Avg Regret: 0.0011, Avg. Runtime (n=10): 2.165840, Std. Runtime (n=10): 0.070238, LCEk: 0.11,
Current method:  SAOCP


100%|██████████| 7050/7050 [00:18<00:00, 390.24it/s]
100%|██████████| 7050/7050 [00:18<00:00, 385.66it/s]
100%|██████████| 7050/7050 [00:18<00:00, 376.46it/s]
100%|██████████| 7050/7050 [00:18<00:00, 371.38it/s]
100%|██████████| 7050/7050 [00:18<00:00, 391.08it/s]
100%|██████████| 7050/7050 [00:18<00:00, 383.89it/s]
100%|██████████| 7050/7050 [00:18<00:00, 382.91it/s]
100%|██████████| 7050/7050 [00:18<00:00, 377.64it/s]
100%|██████████| 7050/7050 [00:18<00:00, 374.02it/s]
100%|██████████| 7050/7050 [00:18<00:00, 386.35it/s]


SAOCP          : Cov: 0.883, Avg Width: 119.2, Avg Miscov: 0.025, Avg Regret: 0.0016, Avg. Runtime (n=10): 11.066553, Std. Runtime (n=10): 0.186862, LCEk: 0.10,
Current method:  MagnitudeLearner


100%|██████████| 7050/7050 [00:01<00:00, 4075.33it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4010.26it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4082.09it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4136.77it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3983.52it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3839.13it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3949.12it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3773.17it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3818.39it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4138.46it/s]


MagnitudeLearner: Cov: 0.884, Avg Width: 118.5, Avg Miscov: 0.023, Avg Regret: 0.0021, Avg. Runtime (n=10): 1.063503, Std. Runtime (n=10): 0.034109, LCEk: 0.08,
Current method:  MagLearnUndiscounted


100%|██████████| 7050/7050 [00:01<00:00, 4241.31it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4003.13it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4044.93it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4053.69it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3970.65it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4006.08it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4034.47it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3915.85it/s]
100%|██████████| 7050/7050 [00:01<00:00, 4138.86it/s]
100%|██████████| 7050/7050 [00:01<00:00, 3956.03it/s]

MagLearnUndiscounted: Cov: 0.894, Avg Width: 122.1, Avg Miscov: 0.020, Avg Regret: 0.0014, Avg. Runtime (n=10): 1.048302, Std. Runtime (n=10): 0.022795, LCEk: 0.09,





In [31]:
# Save the results
# `methods_iterate` (class objects) and `label2err` (a defaultdict) are stored as
# pickled object arrays, so reading those keys needs np.load(..., allow_pickle=True).
# `method_names` is an additional plain-string array that can be read without pickle;
# all pre-existing keys are kept unchanged.
np.savez_compressed(
    'runtime_results.npz',
    methods_iterate=methods_iterate,
    method_names=np.array([m.__name__ for m in methods_iterate]),
    mean_time_norm=mean_time_norm,
    std_time_norm=std_time_norm,
    label2err=label2err,
)

In [33]:


# Print one "name: mean +/- std" line per method, using the normalized runtimes
# rounded to two decimals.
for method_cls, runtime_mean, runtime_std in zip(methods_iterate, mean_time_norm, std_time_norm):
    mean_rounded = np.round(runtime_mean, 2)
    std_rounded = np.round(runtime_std, 2)
    print(f"{method_cls.__name__}: {mean_rounded} +/- {std_rounded}")

SimpleOGD: 1.0 +/- 0.03
MagnitudeLearnerV2: 1.02 +/- 0.04
SplitConformal: 1.57 +/- 0.04
NExConformal: 2.22 +/- 0.02
FACI: 2.98 +/- 0.07
ScaleFreeOGD: 1.05 +/- 0.03
FACI_S: 2.17 +/- 0.07
SAOCP: 11.07 +/- 0.19
MagnitudeLearner: 1.06 +/- 0.03
MagLearnUndiscounted: 1.05 +/- 0.02
