# Using data from 
https://github.com/sebastianpinedaar/finetuning_text_classifiers/blob/main/


### USE FTC KERNEL

In [1]:
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import os
import sys



In [None]:
print(os.getcwd())
sys.path.append(os.getcwd()+ '/finetuning_text_classifiers')

from metadataset.ftc.metadataset import FTCMetadataset

In [6]:
data_dir = "../data"
ftc = FTCMetadataset(data_dir=str(data_dir),
                       metric_name="error",
                       data_version="extended")
mini = FTCMetadataset(data_dir=str(data_dir),
                        metric_name="error",
                        data_version="mini")

splits = ['valid', 'test']
dataset_names = mini.get_dataset_names()

print(f"{'Dataset':<40} {'Split':<8} {'Mini':<8} {'FTC':<8}")
print("-" * 68)

for name in dataset_names:
    for split in splits:
        mini.set_state(name, split)
        mini_samples = len(mini.get_targets())

        ftc.set_state(name, split)
        ftc_samples = len(ftc.get_targets())

        print(f"{name:<40} {split:<8} {mini_samples:<8} {ftc_samples:<8}")

Dataset                                  Split    Mini     FTC     
--------------------------------------------------------------------
imdb                                     valid    5000     5000    
imdb                                     test     25000    25000   
mteb/tweet_sentiment_extraction          valid    5497     5497    
mteb/tweet_sentiment_extraction          test     3534     3534    
ag_news                                  valid    24000    24000   
ag_news                                  test     7600     7600    
dbpedia_14                               valid    112000   112000  
dbpedia_14                               test     70000    70000   
stanfordnlp/sst2                         valid    13470    13470   
stanfordnlp/sst2                         test     10776    10776   
SetFit/mnli                              valid    78541    78541   
SetFit/mnli                              test     62833    62833   


In [12]:
ftc = FTCMetadataset(data_dir=str(data_dir),
                       metric_name="error",
                       data_version="extended")
mini = FTCMetadataset(data_dir=str(data_dir),
                        metric_name="error",
                        data_version="mini")

splits = ['valid', 'test']
dataset_names = mini.get_dataset_names()

# Print header
print(f"{'Dataset':<40} {'Split':<8} {'Type':<8} {'Ensemble Members':<18} {'Classes':<10} {'Samples':<10}")
print("-" * 100)

for name in dataset_names:
    for split in splits:
        # Mini dataset
        mini.set_state(name, split)
        mini_targets = mini.get_targets()
        mini_hp_candidates, mini_indices = mini._get_hp_candidates_and_indices()
        mini_predictions = mini.get_predictions([[0]])
        mini_num_members = len(mini_indices)
        mini_num_classes = mini_predictions.shape[-1]
        mini_num_samples = len(mini_targets)

        # FTC dataset
        ftc.set_state(name, split)
        ftc_targets = ftc.get_targets()
        ftc_hp_candidates, ftc_indices = ftc._get_hp_candidates_and_indices()
        ftc_predictions = ftc.get_predictions([[0]])
        ftc_num_members = len(ftc_indices)
        ftc_num_classes = ftc_predictions.shape[-1]
        ftc_num_samples = len(ftc_targets)

        # Print mini row
        print(f"{name:<40} {split:<8} {'mini':<8} {mini_num_members:<18} {mini_num_classes:<10} {mini_num_samples:<10}")
        # Print FTC row
        print(f"{name:<40} {split:<8} {'ftc':<8} {ftc_num_members:<18} {ftc_num_classes:<10} {ftc_num_samples:<10}")

Dataset                                  Split    Type     Ensemble Members   Classes    Samples   
----------------------------------------------------------------------------------------------------
imdb                                     valid    mini     125                2          5000      
imdb                                     valid    ftc      125                2          5000      
imdb                                     test     mini     125                2          25000     
imdb                                     test     ftc      125                2          25000     
mteb/tweet_sentiment_extraction          valid    mini     100                3          5497      
mteb/tweet_sentiment_extraction          valid    ftc      100                3          5497      
mteb/tweet_sentiment_extraction          test     mini     100                3          3534      
mteb/tweet_sentiment_extraction          test     ftc      100                3          3534      

In [13]:
data_dir = "../data"
ftc = FTCMetadataset(data_dir=str(data_dir),
                       metric_name="error",
                       data_version="extended")
mini = FTCMetadataset(data_dir=str(data_dir),
                        metric_name="error",
                        data_version="mini")

splits = ['valid', 'test']
dataset_names = mini.get_dataset_names()

print(f"{'Dataset':<40} {'Split':<8} {'Mini (members, samples, classes)':<35} {'FTC (members, samples, classes)':<35}")
print("-" * 120)

for name in dataset_names:
    for split in splits:
        # Get all predictions and their shapes
        mini.set_state(name, split)
        _, mini_indices = mini._get_hp_candidates_and_indices()
        mini_shape = mini.get_predictions(mini_indices).shape

        ftc.set_state(name, split)
        _, ftc_indices = ftc._get_hp_candidates_and_indices()
        ftc_shape = ftc.get_predictions(ftc_indices).shape

        print(f"{name:<40} {split:<8} {str(mini_shape):<35} {str(ftc_shape):<35}")

Dataset                                  Split    Mini (members, samples, classes)    FTC (members, samples, classes)    
------------------------------------------------------------------------------------------------------------------------
imdb                                     valid    torch.Size([125, 5000, 2])          torch.Size([125, 5000, 2])         
imdb                                     test     torch.Size([125, 25000, 2])         torch.Size([125, 25000, 2])        
mteb/tweet_sentiment_extraction          valid    torch.Size([100, 5497, 3])          torch.Size([100, 5497, 3])         
mteb/tweet_sentiment_extraction          test     torch.Size([100, 3534, 3])          torch.Size([100, 3534, 3])         
ag_news                                  valid    torch.Size([120, 24000, 4])         torch.Size([99, 24000, 4])         
ag_news                                  test     torch.Size([120, 7600, 4])          torch.Size([99, 7600, 4])          
dbpedia_14               

In [None]:
data_dir = "../data"
#data_version = "mini" #10% of total with 20% val split
data_version = "extended"                              # NOTE not yet downloaded
metadataset = FTCMetadataset(data_dir=str(data_dir), 
                             metric_name="error",
                             data_version=data_version)
dataset_names = metadataset.get_dataset_names()
dataset_names

In [None]:
for dataset_name in dataset_names:

    metadataset.set_state(dataset_name=dataset_name,
                        split="valid")
    hp_candidates, indices = metadataset._get_hp_candidates_and_indices()
    predictions = metadataset.get_predictions([[0]])
    targets = metadataset.get_targets()
    num_configs = len(hp_candidates)
    num_classes = max(targets)
    num_val_samples = len(targets)

    metadataset.set_state(dataset_name=dataset_name,
                        split="test")
    targets = metadataset.get_targets()
    num_test_samples = len(targets)

    print("Dataset:", dataset_name,
        "num_configs:", num_configs,
          "num_classes:", num_classes,
           "num_val_samples:", num_val_samples,
           "num_test_samples:", num_test_samples)

Dataset: imdb num_configs: 125 num_classes: tensor(1) num_val_samples: 5000 num_test_samples: 25000
Dataset: mteb/tweet_sentiment_extraction num_configs: 100 num_classes: tensor(2) num_val_samples: 5497 num_test_samples: 3534
Dataset: ag_news num_configs: 120 num_classes: tensor(3) num_val_samples: 24000 num_test_samples: 7600
Dataset: dbpedia_14 num_configs: 65 num_classes: tensor(13) num_val_samples: 112000 num_test_samples: 70000
Dataset: stanfordnlp/sst2 num_configs: 125 num_classes: tensor(1) num_val_samples: 13470 num_test_samples: 10776
Dataset: SetFit/mnli num_configs: 100 num_classes: tensor(2) num_val_samples: 78541 num_test_samples: 62833


### Retreive the IMDB dataset

- Hyper parameters consist of LORA rank, learning rate and model (GPT2, Bert-Large, Albert-Large, Bart-Large, T5-Large)

- Greedy 50 always provide the highest score (NLL)

- Target tensor has 0, 1 => binary classification

In [None]:
imdb_name = dataset_names[0]
metadataset.set_state(dataset_name=imdb_name,
                    split="valid")
hp_candidates, indices = metadataset._get_hp_candidates_and_indices()

hp_candidates = pd.DataFrame(hp_candidates)
indices = pd.DataFrame(indices)
indices

#understand the data get unique values
cols = hp_candidates.columns
for col in cols:
    print(col, hp_candidates[col].unique())

print('----------------------')

print('indices', indices[0].unique())
print('Shape of hp_candidates:', hp_candidates.shape)
print('Shape of indices:', indices.shape)

print('----------------------')
predictions = metadataset.get_predictions([[0]])
print('Shape of predictions:', predictions.shape)

targets = metadataset.get_targets()
print('Shape of targets:', targets.shape)
print('Unique values in targets:', targets.unique())

#

0 [  8.  16.  32.  64. 128.]
1 [5.e-03 1.e-03 5.e-04 1.e-04 1.e-05]
2 [0. 1.]
3 [0. 1.]
4 [0. 1.]
5 [0. 1.]
6 [1. 0.]
----------------------
indices [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124]
Shape of hp_candidates: (125, 7)
Shape of indices: (125, 1)
----------------------
Shape of predictions: torch.Size([1, 1, 5000, 2])
Shape of targets: torch.Size([5000])
Unique values in targets: tensor([0, 1])


In [None]:
all_indices = list(range(hp_candidates.shape[0]))  # hp_candidates.shape[0] should be 125
predictions_all = metadataset.get_predictions(all_indices)
print("Shape of predictions for all candidates:", predictions_all.shape)


Shape of predictions for all candidates: torch.Size([125, 5000, 2])


In [None]:
metadataset.set_state(dataset_name=imdb_name,
                    split="test")
hp_candidates_test, indices_test = metadataset._get_hp_candidates_and_indices()
metadataset.set_state(dataset_name=imdb_name,
                    split="valid")
hp_candidates_valid, indices_valid = metadataset._get_hp_candidates_and_indices()
(hp_candidates_test !=  hp_candidates_valid).sum()

tensor(0)

### Preprocess to get all ensemble probabilites for all models across validation and test sets

In [None]:
def retrieve_data(dataset_name, splits):
    """
       hp_candidates, indices, predictions, targets 
    """
    results = {}
    for split in splits:
        metadataset.set_state(dataset_name=dataset_name,
                        split=split)
        hp_candidates, indices = metadataset._get_hp_candidates_and_indices()
        predictions = metadataset.get_predictions(indices)
        targets = metadataset.get_targets()
        results[split] = (hp_candidates, indices, predictions, targets)
    return results

splits = ["valid", "test"]
imdb_name = dataset_names[0]
results = retrieve_data(imdb_name, splits)


# lets have a look at the data
print('Preds: ',results['valid'][2].shape, 'Targets: ', results['valid'][3].shape)
# Sanity check
print((results['valid'][1] != results['test'][1]).sum())

Preds:  torch.Size([125, 5000, 2]) Targets:  torch.Size([5000])
tensor(0)


# Comparision with the paper

### Starting with looking at a greedy implementation

1. Best
2. Greedy 5
3. Greedy 50

In [None]:
def greedy_select_ensemble(member_probs, labels, m, eps=1e-12):
    """
    Args:
        member_probs shape [num_members, num_samples, num_classes] 
        labels shape [num_samples] 
    Returns:
        selected (list of int): The indices (with respect to member_probs) of the selected ensemble members.
        ensemble_nlls (list of float): The ensemble NLL after each member is added.
    """
    # Ensure labels is a tensor.
    if not torch.is_tensor(member_probs) and torch.is_tensor(labels):
        raise ValueError('Invalid data type as input')

    
    num_members = member_probs.shape[0]
    
    # Function to compute ensemble NLL given a list of candidate indices.
    def compute_ensemble_nll(indices, member_probs, labels, epsilon):
        ensemble_probs = member_probs[indices].mean(dim=0)
        nll = -torch.gather(torch.log(ensemble_probs + epsilon), dim = 1, index = labels.unsqueeze(1)).squeeze()
        nll_mean = nll.mean()
        return nll_mean.item()
    
    selected = []       # List to hold the indices of selected members.
    ensemble_nlls = []  # List to hold the ensemble NLL after each addition.
    
    # Create a set of all candidate indices.
    remaining = set(range(num_members))
    
    # Greedy forward selection:
    for i in range(m):
        best_candidate = None
        best_nll = float('inf')
        
        # Evaluate each candidate in the remaining pool.
        for candidate in remaining:
            candidate_set = selected + [candidate]
            candidate_nll = compute_ensemble_nll(candidate_set, member_probs, labels, eps)
            if candidate_nll < best_nll:
                best_nll = candidate_nll
                best_candidate = candidate
        
        # Add the best candidate for this iteration.
        selected.append(best_candidate)
        remaining.remove(best_candidate)
        ensemble_nlls.append(best_nll)
        if i == (m-1):
            print(f"Ensemble members {selected}: NLL = {best_nll:.4f}")
    
    return selected, ensemble_nlls

# Example usage:
def compute_ensemble_nll(member_probs, labels, eps=1e-12):
    ensemble_probs = member_probs.mean(dim=0)  # [num_samples, num_classes]
    nll = -torch.gather(torch.log(ensemble_probs + eps), 1, labels.unsqueeze(1)).squeeze(1)
    return nll.mean().item()

member_probs = results['valid'][2]
labels = results['valid'][3]

for m in [1, 5, 50]:
    indices, losses = greedy_select_ensemble(member_probs, labels, m)

    temp_test_probs = results['test'][2][indices]
    temp_test_labels = results['test'][3]
    print(temp_test_probs.shape, temp_test_labels.shape)
    print("Non claibrated nll :", compute_ensemble_nll(temp_test_probs, temp_test_labels))



Ensemble members [124]: NLL = 0.2141
torch.Size([1, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.2085971087217331
Ensemble members [124, 98, 118, 84, 87]: NLL = 0.1305
torch.Size([5, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.12724649906158447
Ensemble members [124, 98, 118, 84, 87, 43, 113, 112, 82, 93, 117, 109, 108, 94, 54, 83, 97, 103, 99, 114, 7, 123, 77, 38, 79, 102, 88, 107, 73, 78, 11, 92, 89, 68, 122, 119, 28, 53, 69, 63, 48, 58, 104, 8, 64, 22, 59, 33, 12, 2]: NLL = 0.1325
torch.Size([50, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.13053670525550842


### Simple best n

In [None]:
def best_n_ensemble(member_probs, labels, m, eps=1e-12):
    num_members = member_probs.shape[0]

    individual_nlls = []
    for i in range(num_members):
        probs = member_probs[i]  # [num_samples, num_classes]
        nll = -torch.gather(torch.log(probs + eps), 1, labels.unsqueeze(1)).squeeze(1)
        individual_nlls.append(nll.mean().item())

    sorted_indices = sorted(range(num_members), key=lambda i: individual_nlls[i])
    indices = sorted_indices[:m]
    print(indices)
    return indices

for i in [1, 5, 50]:
    ensemble_indices = best_n_ensemble(member_probs, labels, i)
    temp_test_probs = results['test'][2][ensemble_indices]
    temp_test_labels = results['test'][3]
    print(temp_test_probs.shape, temp_test_labels.shape)
    print("Non claibrated nll :", compute_ensemble_nll(temp_test_probs, temp_test_labels))

[124]
torch.Size([1, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.2085971087217331
[124, 109, 99, 94, 114]
torch.Size([5, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.15763503313064575
[124, 109, 99, 94, 114, 84, 89, 119, 93, 104, 98, 123, 88, 83, 79, 118, 78, 113, 108, 103, 69, 64, 74, 44, 87, 112, 117, 77, 97, 43, 59, 34, 82, 49, 73, 39, 92, 48, 107, 122, 38, 29, 53, 63, 68, 102, 58, 23, 28, 33]
torch.Size([50, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.13196343183517456


### Implementation based of email from author

In [None]:
def greedy_select_ensemble_with_metric(member_probs, labels, m, metric_name="nll", no_resample=True, eps=1e-12 ):
    """
    Args: 
        member_probs shape [num_members, num_samples, num_classes]
        labels shape [num_samples]
        m (int): Number of ensemble members to select.
        metric_name (str): The metric to use for evaluating the ensemble.
        no_resample (bool): If True, do not resample the selected members.
        eps (float): Small value to avoid log(0).
        
    """
    if not torch.is_tensor(member_probs) and torch.is_tensor(labels):
        raise ValueError('Invalid data type as input')
    
    num_members = member_probs.shape[0]
    selected = []       # List to hold the indices of selected models.
    remaining = set(range(num_members))
    ensemble_metrics = []  # To record the metric after each addition.
    
    def compute_metric(indices):
        # Average the probabilities of the models in 'indices'
        ensemble_probs = member_probs[list(indices)].mean(dim=0)  # Shape: [num_samples, num_classes]
        if metric_name == "nll":
            # Negative log-likelihood.
            nll = -torch.gather(torch.log(ensemble_probs + eps), 1, labels.unsqueeze(1)).squeeze(1)
            return nll.mean().item()
        elif metric_name == "error":
            # Classification error rate.
            preds = ensemble_probs.argmax(dim=1)
            error_rate = (preds != labels).float().mean().item()
            return error_rate
        elif metric_name == "relative_absolute_error":
            # Relative absolute error: |pred - true|/clamp(|true|,min=1)
            preds = ensemble_probs.argmax(dim=1).float()
            rel_abs_error = (torch.abs(preds - labels) / torch.clamp(torch.abs(labels), min=1.0)).mean().item()
            return rel_abs_error
        else:
            raise ValueError(f"Unknown metric_name: {metric_name}")
    
    # Greedy forward selection: add one candidate at each iteration.
    for i in range(m):
        best_candidate = None
        best_metric_value = float('inf')
        # Evaluate each candidate in the remaining set.
        for candidate in remaining:
            candidate_set = selected + [candidate]
            candidate_metric = compute_metric(candidate_set)
            if candidate_metric < best_metric_value:
                best_metric_value = candidate_metric
                best_candidate = candidate
        # Add the best candidate
        selected.append(best_candidate)
        # If no resampling, remove it from the pool.
        if no_resample:
            remaining.remove(best_candidate)
        ensemble_metrics.append(best_metric_value)
        #print(f"Step {i+1}: Selected ensemble members {selected} => {metric_name} = {best_metric_value:.4f}")
    
    return selected, ensemble_metrics


member_probs = results['valid'][2]
labels = results['valid'][3]

for i in [1, 5, 50]:
    ensemble_indices, _ = greedy_select_ensemble_with_metric( member_probs, labels, m=i, metric_name="nll", 
                                                                     no_resample=False)
    temp_test_probs = results['test'][2][ensemble_indices]
    temp_test_labels = results['test'][3]
    print(temp_test_probs.shape, temp_test_labels.shape)
    print("Non claibrated nll :", compute_ensemble_nll(temp_test_probs, temp_test_labels))


torch.Size([1, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.2085971087217331
torch.Size([5, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.12724649906158447
torch.Size([50, 25000, 2]) torch.Size([25000])
Non claibrated nll : 0.12414136528968811


### Greedy as by Caruana, R., Niculescu-Mizil, A., Crew, G., and Ksikes, A. 

Below is the interpretation based on the limited information from their paper

In [None]:
# Suppose we have 5 models and their individual NLL values:
individual_nlls = [0.45, 0.23, 0.50, 0.31, 0.29]
num_members = len(individual_nlls)  # 5

# Create a list of indices [0, 1, 2, 3, 4]
all_indices = range(num_members)

# Sort these indices by looking up the NLL in individual_nlls
# The default sort order is ascending, so the model with the smallest NLL is first.
sorted_indices = sorted(range(num_members), key=lambda i: individual_nlls[i], reverse = False)

print("Individual NLLs:", individual_nlls)
print("Sorted indices:", sorted_indices) 


Individual NLLs: [0.45, 0.23, 0.5, 0.31, 0.29]
Sorted indices: [1, 4, 3, 0, 2]


In [None]:
def greedy_select_ensemble_with_initial(member_probs, labels, m, init_N, no_resample = True, tolerance = 3, eps=1e-12):
    """
    Greedily selects an ensemble of m models using an initial candidate pool of the best init_N models.
    Only adds a candidate if it reduces the ensemble's NLL."

    Args:
        m: Attempted ensemble size
        init_N: Number of initial ensemble

    """

    # ------ Helper fn for the ensemble NLL ------
    def compute_ensemble_nll(indices):
        ensemble_probs = member_probs[indices].mean(dim=0)  # [num_samples, num_classes]
        nll = -torch.gather(torch.log(ensemble_probs + eps), 1, labels.unsqueeze(1)).squeeze(1)
        return nll.mean().item()
    
    #ensure m >= init_N
    if init_N > m:
        raise ValueError('Initial ensemble must be smaller than m')

    # Ensure data are tensor.
    if not (torch.is_tensor(member_probs) and torch.is_tensor(labels)):
        raise ValueError('member_probs and labels must both be tensors')

    
    num_members = member_probs.shape[0]

    # 1. Get the individual nll
    individual_nlls = []
    for i in range(num_members):
        probs = member_probs[i]  # [num_samples, num_classes]
        nll = -torch.gather(torch.log(probs + eps), 1, labels.unsqueeze(1)).squeeze(1)
        individual_nlls.append(nll.mean().item())
    
    # 2. Sort models by individual NLL (lower is better).
    sorted_indices = sorted(range(num_members), key=lambda i: individual_nlls[i])
    
    # 3. Form the initial candidate pool (the top init_N models).
    init_N = min(init_N, num_members)
    candidate_pool = sorted_indices[:init_N] # best init_N members
    print(f"Initial candidate pool (top {init_N} models): {candidate_pool}")
    
    # 4. Initialize the ensemble with the full candidate pool.
    selected = candidate_pool.copy()
    # Based of indices
    remaining = set(range(num_members)) - set(candidate_pool)
    current_nll = compute_ensemble_nll(selected)
    ensemble_nlls = [current_nll]
    
    # 5. Greedy forward selection: add candidates only if they improve (i.e. lower) the NLL.
    counter = 0
    while len(selected) < m:
        best_candidate = None
        best_nll = float('inf')
        for candidate in remaining:
            candidate_set = selected + [candidate]
            candidate_nll = compute_ensemble_nll(candidate_set)
            if candidate_nll < best_nll:
                best_nll = candidate_nll
                best_candidate = candidate
        if best_nll < current_nll:
            selected.append(best_candidate)
            if no_resample:
                remaining.remove(best_candidate)
            ensemble_nlls.append(best_nll)
            current_nll = best_nll
            #print(f"Added candidate {best_candidate}, ensemble {selected}: NLL = {best_nll:.4f}")
            #reset counter
            counter = 0
        else:
            counter += 1
            if counter >= tolerance:
                print("Stopping early.")
                break
            else:
                print("No improvement of NLL - continouing")

    return selected, ensemble_nlls


# Example usage:
member_probs = results['valid'][2]
labels = results['valid'][3]

m = 50
for i in np.linspace(1,10, 10):
    init_N = int(i)    
    selected_indices, ensemble_nlls = greedy_select_ensemble_with_initial(member_probs, labels, m, init_N,
                                                                          no_resample=False)
    print(f"Initial ensemble size {i}")
    print("Final selected ensemble indices:", selected_indices)

    temp_test_probs = results['test'][2][selected_indices]
    temp_test_labels = results['test'][3]
    print("Non claibrated nll :", compute_ensemble_nll(temp_test_probs, temp_test_labels))

Initial candidate pool (top 1 models): [124]
No improvement of NLL - continouing
No improvement of NLL - continouing
Stopping early.
Initial ensemble size 1.0
Final selected ensemble indices: [124, 98, 118, 84, 87, 43, 113, 112, 84, 82, 118]
Non claibrated nll : 0.12684710323810577
Initial candidate pool (top 2 models): [124, 109]
No improvement of NLL - continouing
No improvement of NLL - continouing
Stopping early.
Initial ensemble size 2.0
Final selected ensemble indices: [124, 109, 83, 118, 87, 84, 43, 113, 112, 82, 84, 118, 93, 117, 54, 87, 118, 98]
Non claibrated nll : 0.1262870877981186
Initial candidate pool (top 3 models): [124, 109, 99]
No improvement of NLL - continouing
No improvement of NLL - continouing
Stopping early.
Initial ensemble size 3.0
Final selected ensemble indices: [124, 109, 99, 118, 87, 84, 112, 43, 113, 98, 82, 118, 84, 117, 93, 54, 87, 108]
Non claibrated nll : 0.1260664016008377
Initial candidate pool (top 4 models): [124, 109, 99, 94]
No improvement of N

### Create a new greedy 50 function that aims to perform well with calibration 

**Version A**
1. c_1-temperature-calibrate each ensemble candidate
2. sort the c_1-temperature-calibrated models by their NLL performance and pick the top 5(?) as an initial ensemble
3. for in in range(50):
4. calibrate c_1 and c_2 for the ensemble
5. add the model that improves the calibrated ensemble NNL the most (without changing c_1 and c_2 to safe computational time)
6. calibrate c_1 and c_2 for the ensemble
7. return ensemble

**Version B**

1. c_1-temperature-calibrate each ensemble candidate
2. sort the c_1-temperature-calibrated models by their NLL performance and pick the top 5(?) as an initial ensemble
3. for in in range(50):
4. add the model that improves the calibrated ensemble NNL the most where we recalibrate c_1 and c_2 for each of the 125 candidate ensembles (maybe only use a small grid locally around the winning c_1 and c_2 of the ensemble winning the previous iteration to safe computational time)
5. calibrate c_1 and c_2 for the ensemble on a large very fine grid
6. return ensemble

### Comment:

Starting with an initial ensemble and then allowing resampling - beats results from the paper

# Sandbox stops

## Adjusted calibrator

In [None]:
sys.path.append('../../src')
sys.path.append(os.getcwd()+ '/finetuning_text_classifiers')

from calibrator import PrecomputedCalibrator
from metadataset.ftc.metadataset import FTCMetadataset

data_dir = "../data"
data_version = "mini" #10% of total with 20% val split
#data_version = "extended"                              # NOTE not yet downloaded
metadataset = FTCMetadataset(data_dir=str(data_dir), 
                             metric_name="error",
                             data_version=data_version)
dataset_names = metadataset.get_dataset_names()
dataset_names

['imdb',
 'mteb/tweet_sentiment_extraction',
 'ag_news',
 'dbpedia_14',
 'stanfordnlp/sst2',
 'SetFit/mnli']

### Testing the calibrator

In [None]:
def create_new_split(val_probs, val_labels, test_probs, test_labels, seed):
    """
    Creates a new validation/test split by randomly shuffling the union of the default splits.
    
    Args:
        shapes: [num_models, num_samples, num_classes]
    Outputs:
        shapes: [num_models, num_samples, num_classes]
    """
    # Combine along the sample dimension (assumed to be 0)
    combined_probs = torch.cat([val_probs, test_probs], dim=1)
    combined_labels = torch.cat([val_labels, test_labels], dim=0)
    total = combined_labels.shape[0]
    val_count = val_labels.shape[0]
    
    # Create a permutation using numpy's random generator with the given seed.
    rng = np.random.default_rng(seed)
    permuted_indices = rng.permutation(total)
    
    # Compute the new split sizes
    new_val_indices = permuted_indices[:val_count]
    new_test_indices = permuted_indices[val_count:]
    
    new_val_member_probs = combined_probs[:,new_val_indices,:]
    new_val_labels = combined_labels[new_val_indices]
    new_test_member_probs = combined_probs[:,new_test_indices,:]
    new_test_labels = combined_labels[new_test_indices]
    
    return new_val_member_probs, new_val_labels, new_test_member_probs, new_test_labels


# Parameters for experiment
num_datasets = 2
datasets = dataset_names[:1]

def retrieve_data(dataset_name, splits):
    """
       hp_candidates, indices, predictions, targets 
    """
    results = {}
    for split in splits:
        metadataset.set_state(dataset_name=dataset_name,
                        split=split)
        hp_candidates, indices = metadataset._get_hp_candidates_and_indices()
        predictions = metadataset.get_predictions(indices)
        targets = metadataset.get_targets()
        results[split] = (hp_candidates, indices, predictions, targets)
    return results

splits = ["valid", "test"]

# used to evaluate non calibrated ensembles
def compute_ensemble_nll(member_probs, labels, eps=1e-12):
    ensemble_probs = member_probs.mean(dim=0)  # [num_samples, num_classes]
    nll = -torch.gather(torch.log(ensemble_probs + eps), 1, labels.unsqueeze(1)).squeeze(1)
    return nll.mean().item()


# ------------- START EXPERIMENTS -----------------

experimental_results = []

#Loop over the datasets
for name in datasets:
    print("Processing dataset:", name)
    results = retrieve_data(name, splits)
    default_val_member_probs = results['valid'][2]
    default_val_labels = results['valid'][3]
    default_test_member_probs = results['test'][2]
    default_test_labels = results['test'][3]

    for split in range(num_datasets):
        if split == 0:
            # Use the original split
            val_member_probs = default_val_member_probs
            val_labels = default_val_labels
            test_member_probs = default_test_member_probs
            test_labels = default_test_labels
        else:
            val_member_probs, val_labels, test_member_probs, test_labels = create_new_split(val_member_probs, val_labels, 
                                                                                            test_member_probs, test_labels,
                                                                                             seed=split)
        for method in ["convex_comb", "pure_logits"]:
            calibrator = PrecomputedCalibrator(adjusting_alpha_method=method, 
                                           clamping_alphas=False, 
                                           logits_based_adjustments=True)
            # --- Ensemble Selection Methods ---
            greedy_indices, _ = calibrator.greedy_ensemble(member_probs=val_member_probs, 
                                                       labels=val_labels, m=50, no_resample=False)
            #NOTE no longer init_N, also just one retured value
            greedy_init_indices = calibrator.greedy_ensemble_with_initial(member_probs=val_member_probs, 
                                                                         labels=val_labels, m=50, no_resample=False, 
                                                                         tolerance=3, eps=1e-12)
            
            if method == 'convex_comb':
                c2_vals = np.linspace(0, 3, 100)
            elif method == 'pure_logits':
                c2_vals = np.linspace(0, 10, 100)
            temps = np.linspace(0.5, 2, 50)
            epi_scalar_vals = np.array([1])

            greedy_init_temp_indices, _ = calibrator.greedy_ensemble_with_initial_and_temp(member_probs=val_member_probs, 
                                                    labels=val_labels, m=50, init_N=5, no_resample=False, tolerance=3, eps=1e-12,
                                                    c1_vals=temps, c2_vals=c2_vals, epi_scalar_vals=epi_scalar_vals)

            # Create a dictionary for the three ensemble selection methods.
            ensemble_methods = {"greedy": greedy_indices,
                                "greedy_init": greedy_init_indices,
                                "greedy_init_temp": greedy_init_temp_indices}

            for ens_method, indices in ensemble_methods.items():
                # Update validation and test ensemble probabilities based on the selected indices
                val_probs_ens = val_member_probs[indices]
                test_probs_ens = test_member_probs[indices]

                # Grid search calibration using the validation ensemble probabilities.
                _, best_params = calibrator.grid_search_c1_c2_precomputed(val_probs_ens, val_labels, temps, c2_vals, 
                                                                          epi_scalar_vals)
                c1_prim = best_params['c1']
                c2_prim = best_params['c2']
                epi_scalar_prim = best_params['epi_scalar']

                # Apply calibration on the test ensemble
                calibrator_results = calibrator.predict(test_probs_ens, c1_prim, c2_prim, epi_scalar_prim, test_labels)
                # Extract calibrated NLL (ensure a scalar by taking the mean over samples)
                calibrated_nll = calibrator_results['nll'].mean()
                # Evaluate baseline (non-calibrated) ensemble NLL on the test set
                non_calibrated_nll = compute_ensemble_nll(test_probs_ens, test_labels)

                # Store experimental results for this ensemble type
                experimental_results.append({'dataset': name,
                                             'Split': split,
                                             'method': method,
                                             'ensemble_type': ens_method,
                                             'ensemble_size': len(indices),
                                             'non_calibrated_nll': non_calibrated_nll,
                                             'calibrated_nll': calibrated_nll,
                                             'c1': c1_prim,
                                             'c2': c2_prim,
                                             'epi_scalar': epi_scalar_prim
                                             })

df_results = pd.DataFrame(experimental_results)
df_results


Processing dataset: imdb
Greedy ensemble completed; selected [124, 98, 118, 84, 87, 43, 113, 112, 84, 82, 118, 109, 93, 117, 54, 87, 108, 84, 124, 118, 97, 113, 84, 7, 87, 118, 38, 93, 108, 124, 84, 118, 87, 84, 102, 113, 11, 118, 93, 117, 84, 87, 112, 114, 83, 118, 43, 82, 84, 113] with NLL = 0.12283
Greedy enemble with initial 6; selected [94, 99, 84, 93, 98, 89, 124, 118, 43, 87, 112, 113, 124, 118, 82, 117, 87, 108, 54, 118, 97, 113, 109, 7, 87, 118, 38, 108, 83, 114, 118, 77, 113, 102, 87] with NLL = 0.12412
Greedy enemble with temp scaling; selected [94, 99, 84, 93, 98, 109, 118, 87, 43, 113, 112, 118, 82, 117] with NLL = 0.12526
Greedy ensemble completed; selected [124, 98, 118, 84, 87, 43, 113, 112, 84, 82, 118, 109, 93, 117, 54, 87, 108, 84, 124, 118, 97, 113, 84, 7, 87, 118, 38, 93, 108, 124, 84, 118, 87, 84, 102, 113, 11, 118, 93, 117, 84, 87, 112, 114, 83, 118, 43, 82, 84, 113] with NLL = 0.12283
Greedy enemble with initial 6; selected [94, 99, 84, 93, 98, 89, 124, 118, 43,

Unnamed: 0,dataset,method,ensemble_type,ensemble_size,non_calibrated_nll,calibrated_nll,c1,c2,epi_scalar
0,imdb,convex_comb,greedy,50,0.124141,0.12423,0.969388,0.030303,1
1,imdb,convex_comb,greedy_init,35,0.124,0.124087,0.969388,0.060606,1
2,imdb,convex_comb,greedy_init_temp,14,0.135776,0.125154,1.5,0.151515,1
3,imdb,pure_logits,greedy,50,0.124141,0.124155,0.969388,1.010101,1
4,imdb,pure_logits,greedy_init,35,0.124,0.124017,0.969388,1.010101,1
5,imdb,pure_logits,greedy_init_temp,50,0.153543,0.12242,1.173469,2.929293,1
6,imdb,convex_comb,greedy,50,0.12419,0.124204,0.989796,0.030303,1
7,imdb,convex_comb,greedy_init,27,0.123634,0.123645,0.989796,0.060606,1
8,imdb,convex_comb,greedy_init_temp,17,0.136334,0.125467,1.5,0.090909,1
9,imdb,pure_logits,greedy,50,0.12419,0.124157,0.989796,1.010101,1


### Testing pipeline

In [None]:
datasets = dataset_names[:3]

experimental_results = []

for name in datasets:
    print("Processing dataset:", name)
    results = retrieve_data(name, splits)
    val_member_probs = results['valid'][2]
    val_labels = results['valid'][3]
    test_member_probs = results['test'][2]
    test_labels = results['test'][3]

    for method in ["convex_comb", "pure_logits"]:
        calibrator = PrecomputedCalibrator(adjusting_alpha_method=method, 
                                           clamping_alphas=False, 
                                           logits_based_adjustments=True)
        # --- Ensemble Selection Methods ---
        # 1. Greedy ensemble (with resampling)
        greedy_indices, _ = calibrator.greedy_ensemble(member_probs=val_member_probs, 
                                                       labels=val_labels, m=50, no_resample=False)
        # 2. Greedy ensemble with an initial candidate pool
        greedy_init_indices = calibrator.greedy_ensemble_with_initial(member_probs=val_member_probs, 
                                                                         labels=val_labels, m=50,
                                                                         no_resample=False, 
                                                                         tolerance=3, eps=1e-12)
        # 3. Greedy ensemble with initial + temperature scaling calibration
        if method == 'convex_comb':
            c2_vals = np.linspace(0, 3, 100)
        elif method == 'pure_logits':
            c2_vals = np.linspace(0, 10, 100)
        temps = np.linspace(0.5, 1.5, 50)
        epi_scalar_vals = np.array([1])
        greedy_init_temp_indices, _ = calibrator.greedy_ensemble_with_initial_and_temp(member_probs=val_member_probs, 
                                                    labels=val_labels, m=50, init_N=5, no_resample=False, tolerance=3, eps=1e-12,
                                                    c1_vals=temps, c2_vals=c2_vals, epi_scalar_vals=epi_scalar_vals)

        # Create a dictionary for the three ensemble selection methods.
        ensemble_methods = {
            "greedy": greedy_indices,
            "greedy_init": greedy_init_indices,
            "greedy_init_temp": greedy_init_temp_indices
        }

        # Loop over each ensemble type
        for ens_method, indices in ensemble_methods.items():
            # Update validation and test ensemble probabilities based on the selected indices
            val_probs_ens = val_member_probs[indices]
            test_probs_ens = test_member_probs[indices]

            # Grid search calibration using the validation ensemble probabilities.
            _, best_params = calibrator.grid_search_c1_c2_precomputed(
                val_probs_ens, val_labels, temps, c2_vals, epi_scalar_vals)
            c1_prim = best_params['c1']
            c2_prim = best_params['c2']
            epi_scalar_prim = best_params['epi_scalar']

            # Apply calibration on the test ensemble
            calibrator_results = calibrator.predict(test_probs_ens, c1_prim, c2_prim, epi_scalar_prim, test_labels)
            # Extract calibrated NLL (ensure a scalar by taking the mean over samples)
            calibrated_nll = calibrator_results['nll'].mean()
            # Evaluate baseline (non-calibrated) ensemble NLL on the test set
            non_calibrated_nll = compute_ensemble_nll(test_probs_ens, test_labels)

            # Store experimental results for this ensemble type
            experimental_results.append({
                'dataset': name,
                'method': method,
                'ensemble_type': ens_method,
                'ensemble_size': len(indices),
                'non_calibrated_nll': non_calibrated_nll,
                'calibrated_nll': calibrated_nll,
                'c1': c1_prim,
                'c2': c2_prim,
                'epi_scalar': epi_scalar_prim
            })

df_results = pd.DataFrame(experimental_results)
df_results


Processing dataset: imdb
Ensemble [124, 109, 99, 94, 114, 118, 87, 113, 43, 83, 112, 84, 118, 82, 117, 93, 87, 84, 108, 54, 118, 7, 98, 113, 84, 87, 38, 118, 93, 108]: NLL = 0.12387
Initial candidate pool (top 5 models): [94, 99, 84, 93, 98]
Completed calibration for initial candidate pool ([94, 99, 84, 93, 98]), with an NLL of:  0.14777784049510956
Stopping early.
Ensemble [94, 99, 84, 93, 98, 109, 118, 87, 43, 113, 112, 118, 82, 117]: NLL = 0.12526
Ensemble [124, 109, 99, 94, 114, 118, 87, 113, 43, 83, 112, 84, 118, 82, 117, 93, 87, 84, 108, 54, 118, 7, 98, 113, 84, 87, 38, 118, 93, 108]: NLL = 0.12387
Initial candidate pool (top 5 models): [94, 99, 84, 93, 98]
Completed calibration for initial candidate pool ([94, 99, 84, 93, 98]), with an NLL of:  0.14334142208099365
Ensemble [94, 99, 84, 93, 98, 118, 108, 123, 48, 113, 108, 88, 118, 87, 103, 108, 97, 123, 88, 73, 118, 79, 113, 108, 87, 112, 123, 88, 102, 79, 108, 97, 118, 43, 113, 108, 123, 79, 82, 88, 103, 87, 112, 108, 78, 123, 

Unnamed: 0,dataset,method,ensemble_type,ensemble_size,non_calibrated_nll,calibrated_nll,c1,c2,epi_scalar
0,imdb,convex_comb,greedy,50,0.124141,0.12423,0.969388,0.030303,1
1,imdb,convex_comb,greedy_init,50,0.12488,0.124498,1.071429,0.060606,1
2,imdb,convex_comb,greedy_init_temp,50,0.135776,0.125154,1.5,0.151515,1
3,imdb,pure_logits,greedy,50,0.124141,0.124155,0.969388,1.010101,1
4,imdb,pure_logits,greedy_init,50,0.12488,0.124374,1.091837,1.010101,1
5,imdb,pure_logits,greedy_init_temp,50,0.153543,0.12242,1.173469,2.929293,1
6,mteb/tweet_sentiment_extraction,convex_comb,greedy,50,0.507702,0.50726,1.05102,0.0,1
7,mteb/tweet_sentiment_extraction,convex_comb,greedy_init,50,0.508739,0.50794,1.091837,0.0,1
8,mteb/tweet_sentiment_extraction,convex_comb,greedy_init_temp,50,0.523816,0.508146,1.5,0.0,1
9,mteb/tweet_sentiment_extraction,pure_logits,greedy,50,0.507702,0.506611,1.193878,0.909091,1


In [None]:
def create_new_split(val_probs, val_labels, test_probs, test_labels, seed):
    """
    Creates a new validation/test split by randomly shuffling the union of the default splits.
    
    Args:
        shapes: [num_models, num_samples, num_classes]
    Outputs:
        shapes: [num_models, num_samples, num_classes]
    """
    # Combine along the sample dimension (assumed to be 0)
    combined_probs = torch.cat([val_probs, test_probs], dim=1)
    combined_labels = torch.cat([val_labels, test_labels], dim=0)
    total = combined_labels.shape[0]
    val_count = val_labels.shape[0]
    
    # Create a permutation using numpy's random generator with the given seed.
    rng = np.random.default_rng(seed)
    permuted_indices = rng.permutation(total)
    
    # Compute the new split sizes
    new_val_indices = permuted_indices[:val_count]
    new_test_indices = permuted_indices[val_count:]
    
    new_val_member_probs = combined_probs[:,new_val_indices,:]
    new_val_labels = combined_labels[new_val_indices]
    new_test_member_probs = combined_probs[:,new_test_indices,:]
    new_test_labels = combined_labels[new_test_indices]
    
    return new_val_member_probs, new_val_labels, new_test_member_probs, new_test_labels


#test delete later - shape [num_models, num_samples, num_classes]
test_tensor = torch.randn(5, 100, 10)
test_labels = torch.randint(0, 10, (100,))
val_tensor = torch.randn(5, 20, 10)
val_labels = torch.randint(0, 10, (20,))

#example usage
new_val_probs, new_val_labels, new_test_probs, new_test_labels = create_new_split(val_tensor, val_labels, test_tensor, test_labels, seed=42)
new_val_probs.shape, new_val_labels.shape, new_test_probs.shape, new_test_labels.shape

(torch.Size([5, 20, 10]),
 torch.Size([20]),
 torch.Size([5, 100, 10]),
 torch.Size([100]))

### Simple pipeline (old) on mini

In [None]:
# No actual need for seeds

data_dir = "../data"
data_version = "mini"                                   #10% of total with 20% val split
#data_version = "extended"                              # NOTE not yet downloaded
metadataset = FTCMetadataset(data_dir=str(data_dir), 
                             metric_name="error",
                             data_version=data_version)
dataset_names = metadataset.get_dataset_names()

splits = ["valid", "test"]      # based of github

calibration_results = []

for name in dataset_names:
    print('\n ----------------------')
    print(f'Dataset: {name}')
    results = retrieve_data(name, splits)
    val_member_probs = results['valid'][2]
    val_labels = results['valid'][3]
    test_member_probs = results['test'][2]
    test_labels = results['test'][3]

    for method  in ["convex_comb", "pure_logits", "convex_comb_no_exp", "convex_comb_global"]:
        #print(f'Method: {method}')
        calibrator = PrecomputedCalibrator(adjusting_alpha_method=method, clamping_alphas=False, logits_based_adjustments=True)
        ensemble_indices, _ = calibrator.greedy_select_ensemble(member_probs = val_member_probs, labels= val_labels, m = 50, 
                                                        no_resample=False) 
        # only take the probs of the selected ensemble
        val_member_probs = val_member_probs[ensemble_indices]
        test_member_probs = test_member_probs[ensemble_indices]

        if method == 'convex_comb':
            c2_vals = np.linspace(0, 3, 100)
        elif method == 'pure_logits':
            c2_vals = np.linspace(0, 10, 100)
        elif method == 'convex_comb_no_exp':
            c2_vals = np.linspace(0, 10, 100)
        elif method == 'convex_comb_global':
            c2_vals = np.linspace(0, 1, 100)

        c1_vals = np.linspace(0.8, 1.2, 10)
        epi_scalar_vals = np.array([1])

        _, best_params = calibrator.grid_search_c1_c2_precomputed(val_member_probs, val_labels,
                                                       c1_vals, c2_vals, epi_scalar_vals)

        c1_prim = best_params['c1']
        c2_prim = best_params['c2']
        epi_scalar_prim = best_params['epi_scalar']

        calibrator_results = calibrator.predict(test_member_probs, c1_prim, c2_prim, epi_scalar_prim, test_labels)
        calibration_results.append({'dataset': name, 
                                    'Calibration method': method,
                                    'C1': c1_prim,
                                    'C2': c2_prim,
                                    'epi_scalar': epi_scalar_prim,
                                    'nll': calibrator_results['nll'].mean()})

        #print("Calibrated test nll: ", calibrator_results['nll'].mean())
        #print("Non claibrated test nll :", compute_ensemble_nll(test_member_probs, test_labels))

calibration_results = pd.DataFrame(calibration_results)
calibration_results


 ----------------------
Dataset: imdb

 ----------------------
Dataset: mteb/tweet_sentiment_extraction

 ----------------------
Dataset: ag_news

 ----------------------
Dataset: dbpedia_14

 ----------------------
Dataset: stanfordnlp/sst2

 ----------------------
Dataset: SetFit/mnli


Unnamed: 0,dataset,Calibration method,C1,C2,epi_scalar,nll
0,imdb,convex_comb,0.977778,0.030303,1,0.124203
1,imdb,pure_logits,0.977778,1.010101,1,0.124134
2,imdb,convex_comb_no_exp,0.977778,0.0,1,0.124203
3,imdb,convex_comb_global,0.977778,0.10101,1,0.124203
4,mteb/tweet_sentiment_extraction,convex_comb,1.066667,0.0,1,0.50716
5,mteb/tweet_sentiment_extraction,pure_logits,1.2,0.909091,1,0.506576
6,mteb/tweet_sentiment_extraction,convex_comb_no_exp,1.066667,0.0,1,0.50716
7,mteb/tweet_sentiment_extraction,convex_comb_global,1.066667,0.0,1,0.50716
8,ag_news,convex_comb,1.2,0.0,1,0.195051
9,ag_news,pure_logits,1.2,0.909091,1,0.19521
