In [None]:
import mlflow
import argparse
import random
import numpy as np
import torch
from tqdm import tqdm
import os


In [None]:
from torch.distributions import Beta
import torch.nn.functional as F

class ModelSelector:
    """
    Interface for active model selection strategies.

    A selector repeatedly proposes an item to label, is told that item's true
    class, and reports which model it currently believes is best. Every method
    here is a no-op that returns None; concrete selectors override them.
    """

    def __init__(self):
        ...

    def get_next_item_to_label(self):
        """
        Return:
            (index, selection probability) -- implementations such as CODA
            return an acquisition score in the second slot.
        """
        ...

    def add_label(self, chosen_idx, true_class, selection_prob):
        """Incorporate the revealed `true_class` of item `chosen_idx`."""
        ...

    def get_best_model_prediction(self):
        """Return the index of the model currently believed to be best."""
        ...

class Ensemble:
    """Uniform average over a stack of per-model prediction scores."""

    def __init__(self, preds, **kwargs):
        """
        Args:
            preds: tensor expected to be (H, N, C); the shape unpacking below
                raises ValueError for any other rank.
            **kwargs: accepted for interface compatibility and ignored.
        """
        self.preds = preds
        self.device = preds.device
        _num_models, _num_items, _num_classes = preds.shape  # rank check only

    def get_preds(self, **kwargs):
        """Return the (N, C) mean of the stored predictions over the model axis."""
        stacked = self.preds
        averaged = stacked.mean(dim=0)
        return averaged
    

In [None]:
def accuracy_loss(preds, labels, **kwargs):
    """
    Per-item 0/1 loss (1 - accuracy), without reduction.

    Args:
        preds: (..., C) prediction scores.
        labels: either integer class ids shaped like preds[..., 0], or (..., C)
            score/one-hot labels (compared via their argmax).
        **kwargs: ignored (e.g. `reduction`), for signature compatibility.
    Returns:
        float tensor with 1.0 where the argmax prediction is wrong, else 0.0.
    """
    predicted_class = preds.argmax(dim=-1)
    target_class = labels.argmax(dim=-1) if labels.dim() > 1 else labels
    hits = (predicted_class == target_class).float()
    return 1 - hits

# Loss functions selectable with --loss; each returns an unreduced per-item loss.
# TODO: 'ce' (cross-entropy) is not registered -- it won't work out of the box,
# since only post-softmax scores (no logits) are available.
LOSS_FNS = dict(
    acc=accuracy_loss,
)

In [None]:
class Dataset:
    """
    A model selection dataset is a tensor of shape (H,N,C) containing post-softmax prediction scores,
    where H is the number of models, N is the number of datapoints, and C is the number of classes.

    Optionally, it can also contain an (N,)-shaped tensor of ground-truth class labels, stored next
    to the predictions file with '_labels' appended to the file stem (e.g. 'task.pt' -> 'task_labels.pt').
    """
    def __init__(self, filepath, device):
        """
        Args:
            filepath: path to the saved (H, N, C) predictions tensor.
            device: torch device (or device string) the tensors are mapped to.
        """
        self.device = device
        # NOTE(review): torch.load unpickles its input; only load files from trusted sources.
        self.preds = torch.load(filepath, map_location=device).float() # avoid fp16 precision errors
        print("Loaded preds of shape", self.preds.shape)

        self.labels = None
        # Rewrite only the file name's extension. str.replace('.pt', ...) would also alter any
        # '.pt' inside directory names, and for a path without '.pt' it would point back at the
        # predictions file itself (loading predictions as labels).
        root, ext = os.path.splitext(filepath)
        label_p = root + '_labels' + ext
        if os.path.exists(label_p):
            self.labels = torch.load(label_p, map_location=device)
            print("Loaded labels of shape", self.labels.shape)
        else:
            print("Did not load labels.")

In [None]:
class Oracle:
    """Ground-truth labeler: reveals item labels and computes each model's true loss."""
    def __init__(self, dataset, loss_fn=None):
        """
        Args:
            dataset: object with `device` and `labels` attributes; labels must not be None.
            loss_fn: callable (preds, labels, **kwargs) -> per-item loss; required by true_losses.
        """
        self.dataset = dataset
        self.loss_fn = loss_fn
        self.device = dataset.device
        self.labels = dataset.labels
        assert self.labels is not None, "Oracle needs labels!"

    def true_losses(self, preds):
        """
        Compute the mean loss for each model.
        
        Args:
        - preds: Tensor of shape (H, N, C) representing post-softmax scores from each model for each data point.
        
        Returns:
        - Tensor of shape (H,) representing the mean loss for each model.

        Raises:
        - ValueError if the oracle was built without a loss_fn.
        """
        if self.loss_fn is None:
            raise ValueError("Oracle.true_losses requires a loss_fn.")
        H, N, C = preds.shape
        # Tile the labels once per model so they line up with preds.reshape(-1, C)
        # (row h*N + n <-> model h, item n). Repeating only the leading dim keeps this
        # valid for integer labels (N,) as well as score labels (N, C).
        tiled_labels = self.labels.repeat(H, *([1] * (self.labels.dim() - 1)))
        per_item = self.loss_fn(preds.reshape(-1, C), tiled_labels, reduction='none')
        return per_item.view(H, N).mean(dim=1)

    def __call__(self, idx):
        """Return the ground-truth label of item `idx` as a Python scalar."""
        return self.labels[idx].item()

In [None]:
def dirichlet_to_beta(alpha_dirichlet: torch.Tensor):
    """
    Collapse each Dirichlet row to the Beta marginal of its diagonal entry.

    For a row with parameters a[c, :], the diagonal probability p_cc is
    distributed as Beta(a[c, c], sum_j a[c, j] - a[c, c]).

    Args:
        alpha_dirichlet: shape (..., H, C, C)
    Returns:
        alpha_cc, beta_cc: shape (..., H, C), freshly allocated (not views).
    """
    on_diag = torch.diagonal(alpha_dirichlet, dim1=-2, dim2=-1).clone()
    off_diag_mass = alpha_dirichlet.sum(dim=-1) - on_diag
    return on_diag, off_diag_mass


def create_confusion_matrices(true_labels: torch.Tensor,
                              model_predictions: torch.Tensor,
                              mode='hard') -> torch.Tensor:
    """
    Row-normalized confusion matrices for every model.

    Args:
        true_labels: (N,) integer class labels (may be pseudo-labels).
        model_predictions: (H, N, C) per-model class scores.
        mode: 'hard' counts argmax predictions; 'soft' accumulates the raw scores.
    Returns:
        (H, C, C) tensor; entry [h, c, j] is the (soft) fraction of items with label c
        that model h assigned to class j. Rows of unseen labels stay ~0 (the
        normalizer is clamped to 1e-6).
    Raises:
        ValueError: for an unknown mode.
    """
    H, N, C = model_predictions.shape
    dev = model_predictions.device

    if mode == 'hard':
        preds = F.one_hot(model_predictions.argmax(-1), C).float()
    elif mode == 'soft':
        preds = model_predictions
    else:
        raise ValueError(mode)

    # match the predictions' dtype: einsum does not mix e.g. float32 and float64 operands
    true_one_hot = F.one_hot(true_labels, C).to(device=dev, dtype=preds.dtype)

    conf = torch.einsum('nc, hnj -> hcj', true_one_hot, preds)
    return conf / conf.sum(-1, keepdim=True).clamp_min(1e-6)


def initialize_dirichlets(soft_confusion: torch.Tensor,
                          prior_strength: float,
                          disable_diag_prior=False) -> torch.Tensor:
    """
    Initial Dirichlet parameters: a fixed base prior plus a scaled soft confusion matrix.

    The diagonal base puts 1.0 on the diagonal and 1/(C-1) elsewhere (2 pseudo-counts
    per row). The uniform base (ablation) spreads the same 2 pseudo-counts evenly.

    Args:
        soft_confusion: (H, C, C) row-normalized confusion matrices.
        prior_strength: weight applied to soft_confusion.
        disable_diag_prior: use the uniform base instead of the diagonal one.
    Returns:
        (H, C, C) tensor of Dirichlet parameters.
    """
    num_models, num_classes, _ = soft_confusion.shape
    like = dict(dtype=soft_confusion.dtype, device=soft_confusion.device)

    if not disable_diag_prior:
        pseudo_counts = torch.full((num_classes, num_classes), 1.0 / (num_classes - 1), **like)
        pseudo_counts.fill_diagonal_(1.0)
    else:
        pseudo_counts = torch.full((num_classes, num_classes), 2 / num_classes, **like)

    tiled = pseudo_counts.unsqueeze(0).expand(num_models, num_classes, num_classes)
    return tiled + prior_strength * soft_confusion


def batch_update_dirichlet_for_item(dirichlet_alphas: torch.Tensor,
                                    classifier_preds: torch.Tensor,
                                    update_weight: float = 1.0) -> torch.Tensor:
    """
    Hypothetical Dirichlet updates for every (item, true class) pair.

    Args:
        dirichlet_alphas: (H, C, C) current parameters.
        classifier_preds: (N, H, C) per-item model scores.
        update_weight: scale applied to the scores before they are added.
    Returns:
        (N, C, H, C, C) tensor whose [n, c] slice is `dirichlet_alphas` with row c
        of every model incremented by update_weight * classifier_preds[n].
    """
    num_items, num_models, num_classes = classifier_preds.shape
    hypothetical = dirichlet_alphas[None, None].expand(
        num_items, num_classes, num_models, num_classes, num_classes).clone()
    increment = classifier_preds * update_weight  # (N, H, C), identical for every class
    for true_cls in range(num_classes):
        hypothetical[:, true_cls, :, true_cls, :] += increment
    return hypothetical


def compute_pbest_beta_batched(alpha_batch: torch.Tensor,  # (B_, C_, C, H)
                                beta_batch:  torch.Tensor, # (B_, C_, C, H)
                                num_points: int = 256,
                                eps: float = 1e-30,
                                chunk_size: int = None) -> torch.Tensor:
    """
    Probability that each of H independent Beta variables is the largest.

    For theta_h ~ Beta(alpha_h, beta_h) along the last dim, estimates
        P(h is best) = integral f_h(x) * prod_{k != h} F_k(x) dx
    on a uniform grid over (0, 1) with the trapezoid rule, then renormalizes
    each row to sum to 1.

    Args:
        alpha_batch, beta_batch: identically shaped (N, ..., H) Beta parameters;
            every index except the last is an independent row.
        num_points: number of grid points for numerical integration.
        eps: floor for the CDF (before taking logs) and for the normalizer.
        chunk_size: entries along dim 0 processed per pass (default: all at once).
    Returns:
        Tensor shaped like alpha_batch with P(h is best) along the last dim.
    """
    device = alpha_batch.device
    N = alpha_batch.shape[0]
    H = alpha_batch.shape[-1]
    # max(N, 1) keeps the range() step non-zero for an empty batch
    chunk_size = chunk_size or max(N, 1)
    x = torch.linspace(1e-6, 1 - 1e-6, num_points, device=device)  # (P,)

    prob_out = torch.zeros_like(alpha_batch)
    for start in range(0, N, chunk_size):
        end = min(start + chunk_size, N)
        a_flat = alpha_batch[start:end].reshape(-1)   # (R*H,), R = rows in this chunk
        b_flat = beta_batch[start:end].reshape(-1)

        # (P, 1) grid broadcast against (R*H,) parameters -> (P, R*H) log-densities
        logpdf = Beta(a_flat, b_flat).log_prob(x.unsqueeze(-1))
        pdf = logpdf.exp().T.reshape(-1, H, num_points)          # (R, H, P)

        # CDF on the grid: running trapezoid integral of the pdf, 0 at x[0]
        cdf = torch.cat([torch.zeros_like(pdf[..., :1]),
                         torch.cumulative_trapezoid(pdf, x, dim=-1)], dim=-1)

        log_cdf = torch.log(cdf.clamp_min(eps))
        # clamp to min/max float32 +-(log(3.4 * 1e38) = ~88) to avoid inf; 
        # rare that this happens (only observed with uniform prior)
        prod_excl = torch.exp((log_cdf.sum(1, keepdim=True) - log_cdf).clamp(-80, 80))
        integrand = pdf * prod_excl

        prob = torch.trapezoid(integrand, x, dim=-1)              # (R, H)
        prob = prob / prob.sum(-1, keepdim=True).clamp_min(eps)
        prob_out[start:end] = prob.reshape(alpha_batch[start:end].shape)

    return prob_out # (B_, C_, C, H)


def pbest_row_mixture_batched(updated_dirichlet: torch.Tensor,
                                pi_hat: torch.Tensor,
                                num_points: int = 256) -> torch.Tensor:
    """
    Marginal probability that each model is best, mixed over confusion-matrix rows.

    Args:
        updated_dirichlet: (B_, C_, H, C, C) Dirichlet parameters. B_ and C_ are extra
            batch dims (e.g. hypothetical items and hypothetical classes) that every
            operation broadcasts over.
        pi_hat: (C,) estimated marginal class distribution P(class = c) over the dataset.
        num_points: integration grid size forwarded to compute_pbest_beta_batched.

    Returns:
        (B_, C_, H) tensor equal to sum_c pi_hat[c] * P(h is best | row c).
    """
    n_classes = updated_dirichlet.shape[-1]

    # per-row P(h is best), computed from the diagonal Beta marginals
    diag_alpha, diag_beta = dirichlet_to_beta(updated_dirichlet)           # (B_, C_, H, C)
    per_row = compute_pbest_beta_batched(diag_alpha.transpose(-1, -2),
                                         diag_beta.transpose(-1, -2),
                                         num_points=num_points)            # (B_, C_, C, H)

    # weight each row by its class prior and sum the rows out
    class_weights = pi_hat.view(1, n_classes, 1)
    weighted = per_row * class_weights
    return weighted.sum(-2)


def batch_update_beta(selector, # selector.dirichlets: (H,C,C)
                      preds,    # (B, H)
                      update_weight=1.0
                      ): 
    """
    Diagonal Beta parameters after a hypothetical label, for every item and class.

    For item b and hypothetical true class c, model h's diagonal Beta for row c gains
    `update_weight` on alpha if h predicted c, and on beta otherwise.

    Args:
        selector: object whose `dirichlets` attribute is an (H, C, C) tensor.
        preds: (B, H) integer hard predictions.
        update_weight: pseudo-count added by the hypothetical label.
    Returns:
        (alpha_batch, beta_batch), each (B, H, C).
    """
    num_items, num_models = preds.shape
    num_classes = selector.dirichlets.shape[-1]
    diag_alpha, diag_beta = dirichlet_to_beta(selector.dirichlets)  # (H, C)

    # hit[b, h, c] is True when model h predicted class c for item b
    class_ids = torch.arange(num_classes, device=diag_alpha.device)
    hit = preds.unsqueeze(-1).expand(num_items, num_models, num_classes) == class_ids

    bump = 1.0 * update_weight
    alpha_batch = torch.where(hit, diag_alpha + bump, diag_alpha)
    beta_batch = torch.where(hit, diag_beta, diag_beta + bump)
    return alpha_batch, beta_batch # (B, H, C), (B, H, C)
       

In [None]:
class CODA(ModelSelector):
    """
    Active model selector that keeps a Dirichlet posterior over each model's
    confusion-matrix rows.

    `self.dirichlets` has shape (H, C, C): dirichlets[h, c] parameterizes row c
    (true class c) of model h. It starts from a prior built around the ensemble's
    pseudo-labels and receives `learning_rate` pseudo-counts per revealed label.
    """
    def __init__(self, 
                 dataset,
                 prefilter_n=0,
                 alpha=0.9,
                 learning_rate=0.01,
                 multiplier=2.0,
                 disable_diag_prior=False,  # for ablation 1
                 q='eig',                   # for ablation 2
                 ):
        """
        Args:
            dataset: object with `preds` of shape (H, N, C) holding post-softmax scores.
            prefilter_n: if > 0, candidates are randomly subsampled to at most this many per step.
            alpha: the consensus confusion matrix enters the prior with weight (1 - alpha).
            learning_rate: pseudo-count added per revealed label.
            multiplier: global scale applied to the initial Dirichlet parameters.
            disable_diag_prior: use a uniform base prior instead of the diagonal one (ablation).
            q: acquisition function; only 'eig' is implemented.
        """
        self.dataset = dataset
        self.device = dataset.preds.device
        self.H, self.N, self.C = dataset.preds.shape
        self.prefilter_n = prefilter_n
        self.disable_diag_prior = disable_diag_prior
        self.q = q

        # hyperparams
        self.prior_strength = (1 - alpha)
        self.update_strength = learning_rate

        # initialize dirichlets from each model's soft confusion against ensemble pseudo-labels
        ens_pred = Ensemble(dataset.preds).get_preds()
        ens_pred_hard = ens_pred.argmax(-1)  # pseudo labels
        soft_conf = create_confusion_matrices(ens_pred_hard, dataset.preds, mode='soft')
        self.dirichlets = multiplier * initialize_dirichlets(soft_conf, self.prior_strength, self.disable_diag_prior)
        self.update_pi_hat()

        self.labeled_idxs, self.labels = [], []
        self.unlabeled_idxs = list(range(self.N))
        self.q_vals = []
        # becomes True once any random choice (subsampling or tie-breaking) is made
        self.stochastic = False
        self.step = 0

    @classmethod
    def from_args(cls, dataset, args):
        """Build a selector from a parsed CLI namespace (see parse_args)."""
        return cls(dataset,
                   prefilter_n=args.prefilter_n,
                   alpha=args.alpha,
                   learning_rate=args.learning_rate,
                   multiplier=args.multiplier,
                   disable_diag_prior=args.no_diag_prior,
                   q=args.q)

    def _prefilter(self, idxs):
        """
        Restrict candidate indices to items on which the models disagree.

        Items where every model predicts the majority class are dropped (they are
        treated as uninformative, so scoring them is wasted compute). If
        `prefilter_n` is set and more candidates remain, a random subset of that
        size is drawn and the run is marked stochastic.
        """
        hard_preds = self.dataset.preds.argmax(-1)             # (H, N)
        maj, _ = torch.mode(hard_preds, dim=0)                 # (N,) majority vote
        disagree = ((hard_preds != maj).sum(0) > 0).tolist()   # per-item Python bools
        idxs = [i for i in idxs if disagree[i]]
        # can also randomly subsample (disabled by default)
        if self.prefilter_n and len(idxs) > self.prefilter_n:
            idxs = random.sample(idxs, self.prefilter_n)
            self.stochastic = True
        return idxs

    def update_pi_hat(self):
        """
        Refresh class-probability estimates from the current Dirichlets.

        Sets:
            pi_hat_xi: (N, C) per-item class distribution. Each model's scores are mapped
                through its unnormalized Dirichlet parameters
                (adjusted[h, n, c] = sum_s dirichlets[h, c, s] * preds[h, n, s]),
                summed over models, and normalized per item.
            pi_hat: (C,) dataset-level class marginal: normalized sum of pi_hat_xi over items.
        """
        adjusted = torch.einsum('hcs, hns -> hnc', self.dirichlets, self.dataset.preds)
        # per item
        self.pi_hat_xi = adjusted.sum(0)
        self.pi_hat_xi = self.pi_hat_xi / self.pi_hat_xi.sum(dim=-1, keepdim=True).clamp_(min=1e-12)
        # marginal (entire dataset)
        self.pi_hat = self.pi_hat_xi.sum(0)
        self.pi_hat = self.pi_hat / self.pi_hat.sum()

    def eig_batched(self, chunk_size: int = 100, update_weight: float = 1.0, num_points: int = 256):
        """
        Expected information gain (bits) about which model is best, per candidate item.

        For each candidate b and each hypothetical true class c, the item's hard
        predictions update the diagonal Beta of row c. The change in that row's
        P(h is best) is applied to the current mixture over models, and the
        resulting entropies are averaged under the item's class posterior
        `pi_hat_xi[b]`.

        Candidates are the prefiltered unlabeled items, or all unlabeled items if
        the prefilter leaves none.

        Args:
            chunk_size: number of candidate items scored per batch.
            update_weight: pseudo-count used for the hypothetical update.
            num_points: integration grid size for every P(best) computation.
        Returns:
            (eig, candidate_ids): eig is a (len(candidate_ids),) tensor aligned with
            the Python list candidate_ids.
        """
        candidate_ids = self._prefilter(self.unlabeled_idxs) or self.unlabeled_idxs
        classifier_preds = self.dataset.preds.permute(1, 0, 2)   # (N, H, C)
        candidates = torch.tensor(candidate_ids, device=classifier_preds.device)
        N, H, C = classifier_preds.shape

        # compute current pbest per row
        dirichlets_before = self.dirichlets.unsqueeze(0).unsqueeze(0).expand(1, 1, H, C, C)
        
        # get diagonal betas
        alpha_cc_before, beta_cc_before = dirichlet_to_beta(dirichlets_before) # (1, 1, H, C)
        alpha_cc_before = alpha_cc_before.permute(0,3,1,2)  # (1, C, 1, H)
        beta_cc_before  = beta_cc_before.permute(0,3,1,2)   # (1, C, 1, H)
        # same integration grid as the hypothetical rows below, so the deltas compare like with like
        pbest_rows_before = compute_pbest_beta_batched(alpha_cc_before, beta_cc_before,
                                                       num_points=num_points).squeeze(-2) # (1, C, H)

        mixture0 = (self.pi_hat[:, None] * pbest_rows_before).sum(1)   # (1,H)
        H_before = -(mixture0.clamp_min(1e-12).mul(mixture0.clamp_min(1e-12).log2())).sum(-1)

        # broadcast helpers
        mixture0_bc = mixture0.view(1, 1, H)      # (1,1,H)
        pi_hat_row  = self.pi_hat.view(1, C, 1)   # (1,C,1)

        eig_chunks = []
        for s in tqdm(range(0, len(candidates), chunk_size)):
            ids   = candidates[s:s + chunk_size] # (B,)
            preds = classifier_preds[ids].argmax(-1) # (B, H)
            pi_hat_xi = self.pi_hat_xi[ids]          # (B, C)

            # do all hypothetical updates at once
            alpha_hypothetical, beta_hypothetical = batch_update_beta(self, preds, update_weight) # (B,H,C_)
            alpha_hypothetical = alpha_hypothetical.permute(0,2,1).unsqueeze(-2)  # (B, C_, 1, H)
            beta_hypothetical = beta_hypothetical.permute(0,2,1).unsqueeze(-2)    # (B, C_, 1, H)

            pbest_hypothetical_rows = compute_pbest_beta_batched(alpha_hypothetical, 
                                                                    beta_hypothetical, 
                                                                    num_points=num_points).squeeze(-2) # (B, C_, H)
            # only row c changes under hypothetical class c
            deltas = pi_hat_row * (pbest_hypothetical_rows - pbest_rows_before) # (B,C,H)
            mix_new = mixture0_bc + deltas # (B,C,H)
            H_after = -(mix_new.clamp_min(1e-12).mul(mix_new.clamp_min(1e-12).log2())).sum(-1) # (B,C)
            
            eig = H_before - (pi_hat_xi * H_after).sum(-1) # (B,)
            eig_chunks.append(eig)

        return torch.cat(eig_chunks), candidate_ids

    def get_next_item_to_label(self):
        """
        Pick the unlabeled item with the highest acquisition value.

        Exact ties (within torch.isclose tolerances) are broken uniformly at random,
        which marks the run as stochastic.

        Returns:
            (item index, acquisition value as a Python float)
        Raises:
            NotImplementedError: for an unsupported acquisition function `q`.
            RuntimeError: if every item is already labeled.
        """
        if self.q != 'eig':
            raise NotImplementedError(self.q)
        if not self.unlabeled_idxs:
            raise RuntimeError("No unlabeled items left to select.")

        # default; expected information gain
        q_vals, cand = self.eig_batched()

        # greedy sampling with random selection between ties
        best = q_vals.max()
        tie_positions = torch.nonzero(torch.isclose(q_vals, best, rtol=1e-8), as_tuple=True)[0].tolist()
        if len(tie_positions) > 1:
            idx_local = random.choice(tie_positions)
            self.stochastic = True
        else:
            idx_local = torch.argmax(q_vals).item()

        return cand[idx_local], q_vals[idx_local].item()

    def add_label(self, idx, true_class, selection_prob):
        """
        Record the revealed label of item `idx` and update the posterior.

        Each model's one-hot hard prediction for the item, scaled by
        `update_strength`, is added to its Dirichlet row `true_class`.
        """
        preds = F.one_hot(self.dataset.preds[:, idx].argmax(-1), self.C).float()  # (H, C)
        self.dirichlets[:, true_class] += self.update_strength * preds

        self.update_pi_hat()
        self.labeled_idxs.append(idx)
        self.labels.append(int(true_class))
        self.q_vals.append(selection_prob)
        self.unlabeled_idxs.remove(idx)

    def get_pbest(self):
        """Return an (H,) tensor of the current marginal P(model h is best)."""
        H, C, _ = self.dirichlets.shape
        expanded = self.dirichlets.unsqueeze(0).unsqueeze(0).expand(1, 1, H, C, C)
        # pbest_row_mixture_batched returns (1, 1, H); drop both singleton batch dims
        pbest = pbest_row_mixture_batched(expanded, self.pi_hat)[0, 0] # (H,)

        return pbest

    def get_best_model_prediction(self):
        """Return the index (0-d tensor) of the model most likely to be best; counts calls in `step`."""
        pbest = self.get_pbest()
    
        # track how many times we've done this
        self.step += 1 

        return torch.argmax(pbest)

In [None]:
# Set to True to store MLflow tracking data in a local SQLite database (coda.sqlite).
USE_DB = False
if USE_DB:
    mlflow.set_tracking_uri('sqlite:///coda.sqlite')

def seed_all(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs, and make cuDNN deterministic."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # multi-GPU
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def parse_args(optiions):
    """
    Build the benchmark's CLI parser and parse `optiions`.

    Args:
        optiions: list of argument strings (parameter name kept for backward
            compatibility); None makes argparse read sys.argv[1:].
    Returns:
        argparse.Namespace with dataset, benchmarking, method and CODA settings.
    """
    parser = argparse.ArgumentParser()
    # dataset settings
    parser.add_argument("--task", help="{ 'sketch_painting', ... }", default=None)
    parser.add_argument("--data-dir", default='data')

    # benchmarking settings
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--seeds", type=int, default=5) # how many seeds to use - one experiment per seed
    parser.add_argument("--force-rerun", action="store_true", help="Overwrite existing runs.")
    parser.add_argument("--experiment-name", default=None) # overrides default of using task as experiment name
    parser.add_argument("--no-mlflow", action="store_true", help="Disable MLflow logging.")

    # general method settings
    parser.add_argument("--loss", help="{ 'acc', ... }", default="acc",)
    # must match the dispatch in do_model_selection_experiment
    parser.add_argument("--method", help="{ 'iid', 'uncertainty', 'coda', 'activetesting', 'vma', 'model_picker' }", default='iid')
    
    # CODA settings
    parser.add_argument("--alpha", default=0.9, type=float)      # TODO: change to 1-alpha
    parser.add_argument("--learning-rate", default=0.01, type=float)
    parser.add_argument("--multiplier", default=2.0, type=float) # TODO: change to temperature
    parser.add_argument("--prefilter-n", type=int, default=0, help="Subsample n test data points each iteration. Useful for speeding up EIG calculations on large datasets. Disabled by default.")
    parser.add_argument("--no-diag-prior", action="store_true", help="Disable diagonal prior (Eq 7); used for ablation 1.")
    parser.add_argument("--q", default="eig", help="Acquisition function {eig, iid, uncertainty}. Default EIG (eq 17). Used for ablation 2.")

    return parser.parse_args(optiions)

def do_model_selection_experiment(dataset, oracle, args, loss_fn, seed=0):
    """
    Run one seeded active model selection experiment and report its regret.

    Regret at a step is the true loss of the currently predicted best model minus
    the lowest true loss over all models. Per-step and cumulative regret are printed
    and, unless `args.no_mlflow`, logged to the active MLflow run.

    Args:
        dataset: Dataset with `preds` of shape (H, N, C).
        oracle: Oracle providing labels and true per-model losses.
        args: parsed CLI namespace (see parse_args).
        loss_fn: loss passed to the baseline selectors.
        seed: RNG seed for this run.

    Returns:
        The selector's `stochastic` flag; the caller skips further seeds when it is False.

    Raises:
        ValueError: for an unsupported `args.method`.
    """
    seed_all(seed)
    true_losses = oracle.true_losses(dataset.preds)
    # reuse the computed losses instead of querying the oracle a second time
    best_loss = true_losses.min()
    print("Best possible loss is", best_loss)

    # initialize method
    # NOTE(review): IID, Uncertainty, ActiveTesting, VMA and ModelPicker are not defined in this
    # section of the notebook; they must be defined/imported elsewhere before those methods are used.
    if args.method == 'iid':
        selector = IID(dataset, loss_fn)
    elif args.method == 'uncertainty':
        selector = Uncertainty(dataset, loss_fn)
    elif args.method.startswith('coda'):
        selector = CODA.from_args(dataset, args)
    elif args.method == 'activetesting':
        selector = ActiveTesting(dataset, loss_fn)
    elif args.method == 'vma':
        selector = VMA(dataset, loss_fn)
    elif args.method == 'model_picker':
        from coda.baselines.modelpicker import TASK_EPS
        if args.task in TASK_EPS.keys():
            selector = ModelPicker(dataset, epsilon=TASK_EPS[args.task])
        else:
            print(args.task, "not in TASK_EPS; using default")
            selector = ModelPicker(dataset)
    else:
        raise ValueError(args.method + " is not a supported method.")

    # Get prior regret
    best_model_idx_pred = selector.get_best_model_prediction()
    regret_loss = true_losses[best_model_idx_pred] - best_loss
    print("Regret at 0:", regret_loss)

    ## Active model selection loop
    cumulative_regret_loss = 0
    for m in tqdm(range(args.iters)):
        # select item, label, select model
        chosen_idx, selection_prob = selector.get_next_item_to_label()
        true_class = oracle(chosen_idx)
        selector.add_label(chosen_idx, true_class, selection_prob)
        best_model_idx_pred = selector.get_best_model_prediction()

        # compute and log metrics
        regret_loss = true_losses[best_model_idx_pred] - best_loss
        cumulative_regret_loss += regret_loss
        print("Regret at", m+1, ":", regret_loss)
        print("Cuml Regret at", m+1, ":", cumulative_regret_loss)
        if not args.no_mlflow:
            mlflow.log_metric("regret", regret_loss.item(), step=m+1)
            mlflow.log_metric("cumulative regret", cumulative_regret_loss.item(), step=m+1)

    return selector.stochastic

def main(options):
    """
    Parse `options`, load `<data_dir>/<task>.pt`, and run active model selection per seed.

    Without --no-mlflow, runs are grouped as experiment (task or --experiment-name) ->
    run (method) -> nested run (seed). Finished seed runs are skipped unless
    --force-rerun. Further seeds are skipped once a run reports it is not stochastic.

    Args:
        options: list of CLI-style argument strings (see parse_args); None reads sys.argv.
    Raises:
        ValueError: if --task is not given.
    """
    args = parse_args(options)
    if args.task is None:
        raise ValueError("--task is required (predictions are loaded from <data-dir>/<task>.pt).")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device is", device)

    # Load prediction results of all hypotheses
    dataset = Dataset(os.path.join(args.data_dir, args.task + ".pt"), device=device)

    # Create oracle
    loss_fn = LOSS_FNS[args.loss]
    oracle = Oracle(dataset, loss_fn=loss_fn)
    
    ## Model selection loop
    if args.no_mlflow:
        # simple run loop without MLflow logging
        for seed in range(args.seeds):
            print("Running active model selection with seed", seed)
            print("DEBUG ARGS", args.__dict__)
            seed_stochastic = do_model_selection_experiment(dataset, oracle, args, loss_fn, seed=seed)

            if not seed_stochastic:
                print("Method is not stochastic for this task. Skipping further seeds.")
                break
    else:
        # create mlflow 'experiment' (= dataset/task)
        experiment_name = args.experiment_name or args.task
        mlflow.set_experiment(experiment_name)

        def get_mlflow_run_id(run_name):
            # Look up a run by name; returns (run_id or None, finished?, stochastic param or None).
            run_id = None
            matching_runs = mlflow.search_runs(experiment_names=[experiment_name], filter_string=f"tags.mlflow.runName = '{run_name}'", max_results=1)
            finished = False
            stochastic = None
            if len(matching_runs):
                run_id = matching_runs.run_id.values[0]
                finished = matching_runs.status.values[0] == 'FINISHED'
                stochastic = 'params.stochastic' in matching_runs.columns and matching_runs['params.stochastic'].values[0] == 'True'
            return run_id, finished, stochastic

        # create mlflow 'run' (= algorithm)
        run_name = "-".join([experiment_name, args.method])
        run_id, _, _ = get_mlflow_run_id(run_name)
        with mlflow.start_run(run_id=run_id, run_name=run_name):
            mlflow.log_params(args.__dict__)
            for seed in range(args.seeds):
                # create nested ml flow 'run' (= seed)
                seed_run_name = "-".join([experiment_name, args.method, str(seed)])
                seed_run_id, seed_finished, seed_stochastic = get_mlflow_run_id(seed_run_name)
                if seed_finished and not args.force_rerun:
                    print("Seed", seed, "finished. Skipping.")
                else:
                    with mlflow.start_run(nested=True, run_id=seed_run_id, run_name=seed_run_name):
                        mlflow.log_param("seed", seed)
                        print("Running active model selection with seed", seed)
                        print("DEBUG ARGS", args.__dict__)
                        seed_stochastic = do_model_selection_experiment(dataset, oracle, args, loss_fn, seed=seed)
                        mlflow.log_param("stochastic", seed_stochastic)

                if not seed_stochastic:
                    print("Method is not stochastic for this task. Skipping further seeds.")
                    break


In [None]:
main([
    '--task', 'cifar10_5592',
    '--method', 'coda',
    '--data-dir', 'dataset',
])
