In [1]:
import os

import numpy as np

import calibrator as cal

In [2]:
import torch
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Number of classes ViLT has (per the original author's note).
NUM_CLASSES = 3129


In [3]:
# Directory holding the saved per-batch "Logits_and_labels<N>.pt" files.
# NOTE(review): despite the "_file" suffix this is a directory path, and no other
# cell visible in this notebook reads this variable - load_data and
# calibrate_logits spell out the same string themselves.
logits_and_labels_file = "/teamspace/studios/this_studio/Selective_Prediction_VQA/predictions/logits_and_labels/"

## **Process data (logits and labels) from saved .pt files**

In [4]:
from collections import Counter

def max_occurence(labels):
    """Return the element of ``labels`` that appears most often.

    When several elements share the highest count, the one encountered first
    wins, because ``Counter`` preserves insertion order and ``max`` returns the
    first maximal key. An empty ``labels`` raises ``ValueError`` (from ``max``).
    """
    frequencies = Counter(labels)
    most_frequent = max(frequencies.keys(), key=lambda label: frequencies[label])
    return most_frequent
    
def load_data(batch_no, data_path="/teamspace/studios/this_studio/Selective_Prediction_VQA/predictions/logits_and_labels/"):
    """Load one saved batch and reduce it to (logits, majority labels).

    Parameters:
    - batch_no: index of the batch; the file read is
      ``Logits_and_labels<batch_no>.pt`` inside ``data_path``.
    - data_path: directory containing the batch files. Defaults to the
      location this notebook originally hard-coded.

    Returns:
    - logits: NumPy array made by concatenating, along the first axis, the
      logits of every sample that has at least one label.
    - labels: NumPy array with one entry per kept sample, holding that
      sample's most frequent label (see ``max_occurence``).

    Raises:
    - ValueError: if no sample in the batch has any label.

    NOTE(review): ``torch.load`` relies on pickle, so only point this at files
    you trust. Newer PyTorch releases default to ``weights_only=True``, which
    may reject the NumPy arrays stored in these files - confirm against the
    installed version.
    """
    file_path = os.path.join(data_path, "Logits_and_labels" + str(batch_no) + ".pt")
    data = torch.load(file_path)

    kept_logits = []
    kept_labels = []
    for log, answer_labels in zip(data['logits'], data['labels']):
        # Samples without any label cannot be assigned a target; drop them.
        if len(answer_labels) == 0:
            continue
        kept_logits.append(torch.from_numpy(log))
        kept_labels.append(answer_labels)

    # torch.cat fails with an opaque message on an empty list; name the file instead.
    if not kept_logits:
        raise ValueError("No labelled samples found in " + file_path)

    labels = np.array([max_occurence(answer_labels) for answer_labels in kept_labels])
    logits = torch.cat(kept_logits).numpy()
    return logits, labels


## **Save Calibrator weights**

In [6]:
def save_calibrator(calibrator, calibrator_type="vector_calibrator"):
    """Write the calibrator's parameters to ``scaling/<calibrator_type>.pt``.

    Parameters:
    - calibrator: object exposing ``biasFlag``, ``temperature``, ``bias``,
      ``weights`` and ``num_label`` attributes.
    - calibrator_type: string used as the base name of the output file.

    The dictionary of those five attributes is written with ``torch.save``
    under the hard-coded ``calibration_methods`` directory. Returns nothing.
    """
    # Build the destination first, so a non-string calibrator_type fails
    # before any calibrator attribute is read.
    scaling_model_path = (
        "/teamspace/studios/this_studio/Selective_Prediction_VQA/calibration_methods/"
        + ("scaling/" + calibrator_type + ".pt")
    )

    # Attribute names double as the dictionary keys read back by load_calibrator.
    # Author's reference note on the constructor arguments:
    #   self, num_label, bias=False, weights=None, device=None, print_verbose=False
    saved_fields = ("biasFlag", "temperature", "bias", "weights", "num_label")
    snapshot = {field: getattr(calibrator, field) for field in saved_fields}

    torch.save(snapshot, scaling_model_path)

## **Load Calibrator from a saved path**

In [7]:
def load_calibrator(path, calibrator_type="vector_calibrator"):
    """Rebuild a calibrator from a file written by ``save_calibrator``.

    Parameters:
    - path: location of the saved dictionary, read with ``torch.load``.
    - calibrator_type: kind of calibrator to rebuild; only
      ``"vector_calibrator"`` is implemented.

    Returns:
    - a ``cal.VectorScaling`` with the saved temperature and bias restored,
      or ``None`` for any other ``calibrator_type``.
    """
    state = torch.load(path)

    if calibrator_type != "vector_calibrator":
        # todo: other calibrator types are not supported yet
        return None

    restored = cal.VectorScaling(
        bias=state['biasFlag'],
        weights=state['weights'],
        num_label=state['num_label'],
        device=device,
        print_verbose=False,
    )
    # These two are not constructor arguments in the call above, so set them afterwards.
    restored.temperature = state['temperature']
    restored.bias = state['bias']
    return restored

## **Training Loop**

In [9]:
def train(calibator, epochs=1, num_batches=2139):
    """Fit ``calibator`` on every saved batch and collect the reported metrics.

    Parameters:
    - calibator: calibrator to fit; must expose ``fit(logits, labels)``. The
      misspelled parameter name is kept so existing callers keep working.
    - epochs: number of passes over the batches (default 1, as before).
    - num_batches: number of batch files to read, indexed
      ``0 .. num_batches - 1`` (default 2139, the file count recorded by the
      original author).

    Returns:
    - (loss_accumulate, ece_accumulate): the values returned by each ``fit``
      call, concatenated in processing order. ``fit`` is assumed to return two
      lists - TODO confirm against ``calibrator.VectorScaling``.
    """
    loss_accumulate = []
    ece_accumulate = []
    for _epoch in range(epochs):
        for batch_no in range(num_batches):
            print("Processing batch no. : ", batch_no)
            logits, labels = load_data(batch_no)
            # Fit the object that was passed in. The previous body referenced the
            # notebook-level `calibrator` global and silently ignored the argument.
            loss, ece = calibator.fit(logits, labels)
            loss_accumulate = loss_accumulate + loss
            ece_accumulate = ece_accumulate + ece
    return loss_accumulate, ece_accumulate
            

## **Define Calibrator**

In [10]:
# Instantiate the vector-scaling calibrator with the bias term enabled, sized for
# NUM_CLASSES labels and placed on the device selected above.
calibrator = cal.VectorScaling(bias=True, num_label = NUM_CLASSES, device=device, print_verbose= False)

In [11]:
# Run the training loop; `loss` and `ece` hold the accumulated values it returns
# and are the inputs to the plotting cell further down.
loss, ece = train(calibrator)

Processing batch no. :  0
Processing batch no. :  1
Processing batch no. :  2


Processing batch no. :  3
Processing batch no. :  4
Processing batch no. :  5
Processing batch no. :  6
Processing batch no. :  7
Processing batch no. :  8
Processing batch no. :  9
Processing batch no. :  10
Processing batch no. :  11
Processing batch no. :  12
Processing batch no. :  13
Processing batch no. :  14
Processing batch no. :  15
Processing batch no. :  16
Processing batch no. :  17
Processing batch no. :  18
Processing batch no. :  19
Processing batch no. :  20
Processing batch no. :  21
Processing batch no. :  22
Processing batch no. :  23
Processing batch no. :  24
Processing batch no. :  25
Processing batch no. :  26
Processing batch no. :  27
Processing batch no. :  28
Processing batch no. :  29
Processing batch no. :  30
Processing batch no. :  31
Processing batch no. :  32
Processing batch no. :  33
Processing batch no. :  34
Processing batch no. :  35
Processing batch no. :  36
Processing batch no. :  37
Processing batch no. :  38
Processing batch no. :  39
Processi

  ret = umr_sum(arr, axis, dtype, out, keepdims, where=where)


Processing batch no. :  157
Processing batch no. :  158
Processing batch no. :  159
Processing batch no. :  160
Processing batch no. :  161
Processing batch no. :  162
Processing batch no. :  163
Processing batch no. :  164
Processing batch no. :  165
Processing batch no. :  166
Processing batch no. :  167
Processing batch no. :  168
Processing batch no. :  169
Processing batch no. :  170
Processing batch no. :  171
Processing batch no. :  172
Processing batch no. :  173
Processing batch no. :  174
Processing batch no. :  175
Processing batch no. :  176
Processing batch no. :  177
Processing batch no. :  178
Processing batch no. :  179
Processing batch no. :  180
Processing batch no. :  181
Processing batch no. :  182
Processing batch no. :  183
Processing batch no. :  184
Processing batch no. :  185
Processing batch no. :  186
Processing batch no. :  187
Processing batch no. :  188
Processing batch no. :  189
Processing batch no. :  190
Processing batch no. :  191
Processing batch no.

## **Plot Loss & ECE vs #Batches**

In [None]:
import matplotlib.pyplot as plt

def plot_loss_ece_vs_batches(loss_values, ece_values):
    """
    Draw loss and ECE against batch number on one figure with twin y-axes.

    Parameters:
    - loss_values: List of loss values (left axis, red).
    - ece_values: List of ECE values (right axis, blue).

    Raises:
    - ValueError: if the two lists differ in length.
    """
    num_points = len(loss_values)
    if num_points != len(ece_values):
        raise ValueError("The length of loss_values and ece_values must be the same.")

    # 1-based x positions shared by both curves
    batch_numbers = list(range(1, num_points + 1))

    def draw_series(axis, values, name, color):
        # Label, plot and colour the y-ticks of one curve on its own axis.
        axis.set_ylabel(name, color=color)
        axis.plot(batch_numbers, values, color=color, label=name)
        axis.tick_params(axis='y', labelcolor=color)

    fig, loss_axis = plt.subplots()
    loss_axis.set_xlabel('Batch Number')
    draw_series(loss_axis, loss_values, 'Loss', 'tab:red')

    # Second y-axis sharing the same x-axis; the x-label is already set above.
    ece_axis = loss_axis.twinx()
    draw_series(ece_axis, ece_values, 'ECE', 'tab:blue')

    fig.tight_layout()  # keeps the right-hand y-label from being clipped
    plt.title('Loss and ECE vs Batches')
    fig.legend(loc="upper left", bbox_to_anchor=(0.1,0.9))
    plt.show()

# Plot the values returned by the training cell; `loss` and `ece` must already
# exist in the kernel, so run the training cell first.
plot_loss_ece_vs_batches(loss, ece)


## **Save our Calibrator**

In [12]:
# Persist the calibrator's parameters using save_calibrator's default
# calibrator_type ("vector_calibrator").
save_calibrator(calibrator)

## **Calibrating ViLT Logits using Vector Scaling**

In [25]:
def calibrate_logits(
    calibrator,
    num_batches=2139,
    data_path="/teamspace/studios/this_studio/Selective_Prediction_VQA/predictions/logits_and_labels/",
    result_path="/teamspace/studios/this_studio/Selective_Prediction_VQA/calibration_methods/calibrated_logits_and_labels/",
):
    """Calibrate every saved batch of logits and write the results to disk.

    Parameters:
    - calibrator: object exposing ``calibrate(logits)``.
    - num_batches: number of batch files to process, indexed
      ``0 .. num_batches - 1`` (default 2139, as originally hard-coded).
    - data_path: directory containing ``Logits_and_labels<N>.pt`` inputs.
    - result_path: existing directory that receives
      ``vec_Logits_and_labels<N>.pt`` outputs.

    Each output file is a dictionary with the untouched ``'labels'`` of the
    input batch and the calibrated ``'logits'``. Returns nothing.

    NOTE(review): ``torch.load`` relies on pickle, so only point this at files
    you trust.
    """
    for batch_no in range(num_batches):
        print("Calibrating batch no. : ", batch_no)
        source_file = os.path.join(data_path, "Logits_and_labels" + str(batch_no) + ".pt")
        target_file = os.path.join(result_path, "vec_Logits_and_labels" + str(batch_no) + ".pt")

        data = torch.load(source_file)

        # np.stack needs every per-sample entry to have the same shape.
        logits = np.stack(data['logits'])
        calibrated = {
            'labels': data['labels'],
            'logits': calibrator.calibrate(logits),
        }
        torch.save(calibrated, target_file)
        
        

In [26]:
# Calibrate the saved batches with the calibrator held in the kernel and write
# the results to calibrate_logits' output directory.
calibrate_logits(calibrator)

Calibrating batch no. :  0
Calibrating batch no. :  1


Calibrating batch no. :  2
Calibrating batch no. :  3
Calibrating batch no. :  4
Calibrating batch no. :  5
Calibrating batch no. :  6
Calibrating batch no. :  7
Calibrating batch no. :  8
Calibrating batch no. :  9
Calibrating batch no. :  10
Calibrating batch no. :  11
Calibrating batch no. :  12
Calibrating batch no. :  13
Calibrating batch no. :  14
Calibrating batch no. :  15
Calibrating batch no. :  16
Calibrating batch no. :  17
Calibrating batch no. :  18
Calibrating batch no. :  19
Calibrating batch no. :  20
Calibrating batch no. :  21
Calibrating batch no. :  22
Calibrating batch no. :  23
Calibrating batch no. :  24
Calibrating batch no. :  25
Calibrating batch no. :  26
Calibrating batch no. :  27
Calibrating batch no. :  28
Calibrating batch no. :  29
Calibrating batch no. :  30
Calibrating batch no. :  31
Calibrating batch no. :  32
Calibrating batch no. :  33
Calibrating batch no. :  34
Calibrating batch no. :  35
Calibrating batch no. :  36
Calibrating batch no. :  37
