In [1]:
import json
import numpy as np
import pandas as pd
import scipy.stats as st

import matplotlib.pyplot as plt

In [2]:
# Methods whose results can be summarised in this notebook; extend the list
# when a new method is added. The loading cell reads the results of a method
# from ``results_model_<method>.json``.
models_list = [
    'coop_baseline',
    'coop_pseudo_baseline',
    'vpt_baseline',
    'vpt_pseudo_baseline',
    'teacher_student',
]

In [16]:
# Pick the method to summarise (currently the last entry of models_list).
model = models_list[-1]
print(model)
filename = f"results_model_{model}.json"

# The results file holds one JSON object per line (one per run).
# A context manager guarantees the file handle is closed. Whitespace-only
# lines (e.g. a stray trailing blank line) are skipped because json.loads
# would reject them.
with open(filename, 'r') as results_file:
    data = [json.loads(line) for line in results_file if line.strip()]

teacher_student


In [17]:
# A run is kept only if every key/value pair below matches its 'config'.
# Narrow the selection by adding entries, e.g. 'VIS_ENCODER': 'RN101' or
# 'SPLIT_SEED': 0.
RUN_FILTER = {
    'DATASET_NAME': 'Flowers102',
}

std_accuracies = []
gen_accuracies = []
seen_accuracies = []
unseen_accuracies = []
harmonic_accuracies = []

for run in data:
    if all(run['config'][key] == value for key, value in RUN_FILTER.items()):
        std_accuracies.append(run['std_accuracy'])
        gen_accuracies.append(run['gen_accuracy'])
        seen_accuracies.append(run['gen_seen'])
        unseen_accuracies.append(run['gen_unseen'])
        # Per-run harmonic mean of the seen and unseen accuracies.
        harmonic_accuracies.append(st.hmean([run['gen_seen'], run['gen_unseen']]))

# Fail loudly instead of printing NaNs (with divide-by-zero warnings) when
# the filter selects no runs at all.
assert std_accuracies, f"No runs match RUN_FILTER={RUN_FILTER}"


def report_accuracy(label, accuracies, confidence=0.95):
    """Print the mean, std and confidence-interval half-width of accuracies.

    Args:
        label: Name used in the printed lines (e.g. 'SEEN').
        accuracies: Per-run accuracies; presumably fractions in [0, 1], since
            the mean and interval are multiplied by 100 for display.
        confidence: Confidence level of the Student-t interval.

    The mean and the CI half-width are printed as percentages rounded to two
    decimals. The standard deviation is printed unscaled and is the
    population std (``np.std`` default, ddof=0). With a single run the
    standard error, and hence the interval, is NaN.
    """
    values = np.asarray(accuracies, dtype=float)
    mean = values.mean()
    # The confidence level is passed positionally on purpose: the keyword was
    # ``alpha`` in older SciPy and is ``confidence`` since SciPy 1.9
    # (``alpha`` was removed in 1.11), so this call works with both.
    lower, _ = st.t.interval(confidence, df=len(values) - 1, loc=mean, scale=st.sem(values))
    print(f"Mean {label} accuracy: {round(mean * 100, 2)}")
    print(f"Std {label} accuracy: {np.std(values)}")
    # The interval is symmetric around the mean, so mean - lower is the half-width.
    print(f"{round(confidence * 100)}% {label} confidence interval {round((mean - lower) * 100, 2)}")


for index, (label, accuracies) in enumerate([
    ('STD', std_accuracies),
    ('SEEN', seen_accuracies),
    ('UNSEEN', unseen_accuracies),
    ('HARMONIC', harmonic_accuracies),
]):
    # Blank separator between sections, as in the original output.
    if index:
        print('\n')
    report_accuracy(label, accuracies)

Mean STD accuracy: 79.39
Std STD accuracy: 0.007259043305443267
95% STD confidence interval 1.01


Mean SEEN accuracy: 76.91
Std SEEN accuracy: 0.005305268273933617
95% SEEN confidence interval 0.74


Mean UNSEEN accuracy: 75.01
Std UNSEEN accuracy: 0.00909790458363294
95% UNSEEN confidence interval 1.26


Mean HARMONIC accuracy: 75.94
Std HARMONIC accuracy: 0.0030370373270242235
95% HARMONIC confidence interval 0.42


In the Google Sheet, I report the mean and the half-width of the 95% confidence interval (both as percentages).

In [18]:
import pickle

In [20]:
# Load a pickled object saved for a single run (the filename suggests
# Flowers102 / coop_baseline / RN101 / split 0).
# NOTE(review): this rebinds `data`, replacing the JSON results loaded
# earlier, so re-running the summary cell afterwards would iterate over this
# object instead. Unpickling can also execute arbitrary code: only open
# trusted files.
with open('Flowers102_coop_baseline_RN101_opt_100_spl_0.pickle', 'rb') as pickle_file:
    data = pickle.load(pickle_file)

In [23]:
# Peek at the first element of the unpickled object; the captured output is a
# float32 numpy array.
# NOTE(review): showing `data[0].shape` / `data[0].dtype` instead of the full
# array would keep the saved notebook output small.
data[0]

array([[[ 0.0141661 ,  0.01154805, -0.00731809, ...,  0.0088033 ,
          0.00533164,  0.00627134],
        [ 0.00435409,  0.0012099 , -0.00862264, ..., -0.00905044,
          0.01261343,  0.00636003],
        [ 0.00259562,  0.00175074, -0.00354808, ..., -0.01755268,
          0.00199584, -0.01359151],
        ...,
        [ 0.0136095 , -0.01302855, -0.00152316, ..., -0.00584143,
          0.01086486,  0.01312908],
        [-0.0004344 ,  0.0026134 , -0.00459469, ...,  0.0004229 ,
          0.01705721,  0.02296569],
        [ 0.0063952 ,  0.002413  , -0.00985158, ..., -0.00904753,
         -0.00120628, -0.00495799]]], dtype=float32)