In [1]:
import json
import numpy as np
import pandas as pd
import scipy.stats as st

import matplotlib.pyplot as plt

In [2]:
# Names of every method we have (or plan to have) results for; extend this list as needed.
# The loading cell picks the last entry.
models_list = [
    "clip_baseline",
    "coop_baseline",
    "coop_pseudo_baseline",
    "vpt_baseline",
    "vpt_pseudo_baseline",
    "teacher_student",
]

In [3]:
# Analyse the last method registered in `models_list`.
model = models_list[-1]
print(model)
filename = f"results/results_model_{model}.json"

# The results file holds one JSON object per line. Reading it inside a `with`
# block guarantees the file handle is closed (the previous bare `open(...)`
# left it open until garbage collection).
with open(filename, 'r') as results_file:
    data = [json.loads(line) for line in results_file]

teacher_student


In [5]:
# Per-run metrics of the runs kept by the filter below.
std_accuracies = []
gen_accuracies = []
seen_accuracies = []
unseen_accuracies = []
harmonic_accuracies = []

# Number of leading matching runs to discard before computing statistics.
fix_num = 0


def report_accuracy(name, values):
    """Print mean, standard deviation and 95% CI half-width for `values`.

    Args:
        name: Label inserted into the printed lines (e.g. "STD", "SEEN").
        values: Non-empty sequence of per-run accuracies. The mean and the
            half-width are multiplied by 100 before printing.

    Notes:
        * The confidence level is passed positionally to `st.t.interval`:
          its keyword was renamed from `alpha` to `confidence` in SciPy, and
          the old keyword was later removed, so `alpha=0.95` fails on recent
          releases. The positional form works on both.
        * `np.std` uses its default ddof=0, whereas `st.sem` uses ddof=1.
          That is kept as-is so previously reported numbers stay reproducible.
    """
    mean = np.sum(values) / len(values)
    print(f"Mean {name} accuracy: {round(mean*100, 2)}")
    print(f"Std {name} accuracy: {np.std(values)}")
    interval = st.t.interval(0.95, df=len(values)-1, loc=np.mean(values), scale=st.sem(values))
    # The interval is symmetric around the mean, so mean - lower bound is its half-width.
    print(f"95% {name} confidence interval {round((mean - interval[0])*100, 2)}")


for m in data:
    # Edit this condition to choose which runs are aggregated.
    # Currently: Flowers102, ViT-B/32 visual encoder, split seed 500.
    config = m['config']
    if (
        config['DATASET_NAME'] == 'Flowers102'
        and config['VIS_ENCODER'] == 'ViT-B/32'
        and config['SPLIT_SEED'] == 500
    ):
        std_accuracies.append(m['std_accuracy'])
        gen_accuracies.append(m['gen_accuracy'])
        seen_accuracies.append(m['gen_seen'])
        unseen_accuracies.append(m['gen_unseen'])
        # Harmonic mean of the seen / unseen accuracies of this run.
        harmonic_accuracies.append(st.hmean([m['gen_seen'], m['gen_unseen']]))

std_accuracies = std_accuracies[fix_num:]
gen_accuracies = gen_accuracies[fix_num:]
seen_accuracies = seen_accuracies[fix_num:]
unseen_accuracies = unseen_accuracies[fix_num:]
harmonic_accuracies = harmonic_accuracies[fix_num:]

# Fail loudly instead of printing NaN statistics when the filter selects nothing.
if not std_accuracies:
    raise ValueError("No runs in `data` match the dataset/encoder/seed filter.")

reports = [
    ("STD", std_accuracies),
    ("SEEN", seen_accuracies),
    ("UNSEEN", unseen_accuracies),
    ("HARMONIC", harmonic_accuracies),
]
for position, (name, values) in enumerate(reports):
    if position:
        # Blank separator between consecutive reports (none after the last one).
        print('\n')
    report_accuracy(name, values)

Mean STD accuracy: 82.76
Std STD accuracy: 0.006392898964132173
95% STD confidence interval 0.89


Mean SEEN accuracy: 75.78
Std SEEN accuracy: 0.005193545968350561
95% SEEN confidence interval 0.72


Mean UNSEEN accuracy: 78.92
Std UNSEEN accuracy: 0.008059541466209097
95% UNSEEN confidence interval 1.12


Mean HARMONIC accuracy: 77.32
Std HARMONIC accuracy: 0.004679945524518792
95% HARMONIC confidence interval 0.65


In [6]:
# Show the `gen_seen` values of the runs kept by the filter cell.
seen_accuracies

[0.7559820538384845,
 0.7619641076769691,
 0.7485044865403788,
 0.7597208374875374,
 0.7627118644067796]

In [7]:
# Show the `gen_unseen` values of the runs kept by the filter cell.
unseen_accuracies

[0.7810014038371549,
 0.789424426766495,
 0.7903603182030885,
 0.8034627983153955,
 0.7819372952737482]

In [8]:
# Show the per-run harmonic means of `gen_seen` and `gen_unseen` computed by the filter cell.
harmonic_accuracies

[0.7682880936356221,
 0.7754512368306479,
 0.7688631806185371,
 0.7809798107472069,
 0.7722049354570382]

In [9]:
# Show the `std_accuracy` values of the runs kept by the filter cell.
std_accuracies

[0.8207767898923725,
 0.8235844642021526,
 0.8296677585400094,
 0.8390266729059429,
 0.8249883013570426]

In the Google Sheet, I report the mean and the 95% confidence interval (the half-width printed above) for each metric.

In [10]:
# Hand-entered values; paired element-wise with `b` in the harmonic-mean check.
a = [
    0.906,
    0.772,
    0.567,
    0.8936,
]

In [11]:
# Hand-entered values; paired element-wise with `a` in the harmonic-mean check.
b = [
    0.8175,
    0.8405,
    0.8705,
    0.7975,
]

In [12]:
# Stack `a` and `b` as two rows, take the harmonic mean of each column
# (one value per a[i], b[i] pair), then average those values.
np.mean(st.hmean(np.array([a, b]), axis=0))

0.7984508171723251

In [13]:
import pickle

In [14]:
# classes = []
# with open(f"pseudolabels/train.json", "r") as f:
#     data = json.load(f)
#     for d in data["categories"]:
#         classes.append(d["name"].replace("_", " "))

# One class name per line of the file; surrounding whitespace is stripped.
classes = []
with open(f"pseudolabels/flo_classes.txt", 'r') as file:
    classes.extend(line.strip() for line in file)

# Map the zero-based line index to the class name.
class_dict = dict(enumerate(classes))
class_dict

{0: 'pink primrose',
 1: 'hard-leaved pocket orchid',
 2: 'canterbury bells',
 3: 'sweet pea',
 4: 'english marigold',
 5: 'tiger lily',
 6: 'moon orchid',
 7: 'bird of paradise',
 8: 'monkshood',
 9: 'globe thistle',
 10: 'snapdragon',
 11: "colt's foot",
 12: 'king protea',
 13: 'spear thistle',
 14: 'yellow iris',
 15: 'globe flower',
 16: 'purple coneflower',
 17: 'peruvian lily',
 18: 'balloon flower',
 19: 'giant white arum lily',
 20: 'fire lily',
 21: 'pincushion flower',
 22: 'fritillary',
 23: 'red ginger',
 24: 'grape hyacinth',
 25: 'corn poppy',
 26: 'prince of wales feathers',
 27: 'stemless gentian',
 28: 'artichoke',
 29: 'sweet william',
 30: 'carnation',
 31: 'garden phlox',
 32: 'love in the mist',
 33: 'mexican aster',
 34: 'alpine sea holly',
 35: 'ruby-lipped cattleya',
 36: 'cape flower',
 37: 'great masterwort',
 38: 'siam tulip',
 39: 'lenten rose',
 40: 'barbeton daisy',
 41: 'daffodil',
 42: 'sword lily',
 43: 'poinsettia',
 44: 'bolero deep blue',
 45: 'wall

In [294]:
# with open('pseudolabels/RESICS45_ViT-B32_14_pseudolabels.pickle', 'rb') as f:
#     clip = pickle.load(f)

In [29]:
# Pseudolabel dump tagged "before_student" in its file name; later cells read
# its 'filepaths' and 'labels' entries.
# NOTE: pickle.load can execute arbitrary code; only open files written by our own runs.
# Alternative input:
#   'pseudolabels/DTD_teacher_student_ViT-B32_teacher_iter_3_pseudolabels_spl_500.pickle'
teacher_bef_path = 'pseudolabels/Flowers102_teacher_student_ViT-B32_teacher_iter_2_opt_1_pseudolabels_spl_500_s_epochs_50_before_student.pickle'
with open(teacher_bef_path, 'rb') as pickle_file:
    teacher_bef = pickle.load(pickle_file)

In [34]:
# Teacher pseudolabel dump (iteration 2 per the file name); later cells read
# its 'filepaths' and 'labels' entries.
# NOTE: pickle.load can execute arbitrary code; only open files written by our own runs.
teacher_path = 'pseudolabels/Flowers102_teacher_student_ViT-B32_teacher_iter_2_opt_1_pseudolabels_spl_500_s_epochs_50.pickle'
with open(teacher_path, 'rb') as pickle_file:
    teacher = pickle.load(pickle_file)

In [36]:
# Student pseudolabel dump (iteration 3 per the file name); later cells read
# its 'filepaths' and 'labels' entries.
# NOTE: pickle.load can execute arbitrary code; only open files written by our own runs.
student_path = 'pseudolabels/Flowers102_teacher_student_ViT-B32_student_iter_3_opt_1_pseudolabels_spl_500_s_epochs_50.pickle'
with open(student_path, 'rb') as pickle_file:
    student = pickle.load(pickle_file)

In [562]:
# with open('pseudolabels/Flowers102_ablation_teacher_student_ViT-B32_teacher_iter_2_pseudolabels_spl_500.pickle', 'rb') as f:
#     student = pickle.load(f)

In [30]:
# Count the (filepath, label) pairs in `teacher_bef`.
sum(1 for _ in zip(teacher_bef['filepaths'], teacher_bef['labels']))

156

In [35]:
# Count the (filepath, label) pairs in `teacher`.
sum(1 for _ in zip(teacher['filepaths'], teacher['labels']))

234

In [37]:
# Count the (filepath, label) pairs in `student`.
sum(1 for _ in zip(student['filepaths'], student['labels']))

234

In [49]:
# Number of distinct file paths that appear in both the teacher and the student dumps.
len(set(teacher['filepaths']) & set(student['filepaths']))

234

In [38]:
# Agreement between the class encoded in each file path and the label stored
# alongside it in `teacher_bef`.
# The previous `else: continue` followed by a print was unreachable and is
# removed; `zip` is iterated directly instead of being copied into a list.
correct = 0
for f, l in zip(teacher_bef['filepaths'], teacher_bef['labels']):
    # Class taken from the path: the parent directory name is parsed as an integer index.
    cls = class_dict[int(f.split('/')[-2])]
    # Class taken from the stored label.
    true_cls = class_dict[int(l)]
    if cls == true_cls:
        correct += 1

print(len(teacher_bef['filepaths']))
print(correct)
print(correct/len(teacher_bef['filepaths']))

156
146
0.9358974358974359


In [39]:
# Agreement between the class encoded in each file path and the label stored
# alongside it in `teacher`.
# The previous `else: continue` followed by a print was unreachable and is
# removed; `zip` is iterated directly instead of being copied into a list.
correct = 0
for f, l in zip(teacher['filepaths'], teacher['labels']):
    # Class taken from the path: the parent directory name is parsed as an integer index.
    cls = class_dict[int(f.split('/')[-2])]
    # Class taken from the stored label.
    true_cls = class_dict[int(l)]
    if cls == true_cls:
        correct += 1

print(len(teacher['filepaths']))
print(correct)
print(correct/len(teacher['filepaths']))

234
211
0.9017094017094017


In [40]:
# Agreement between the class encoded in each file path and the label stored
# alongside it in `student`.
# The previous `else: continue` followed by a print was unreachable and is
# removed; `zip` is iterated directly instead of being copied into a list.
correct = 0
for f, l in zip(student['filepaths'], student['labels']):
    # Class taken from the path: the parent directory name is parsed as an integer index.
    cls = class_dict[int(f.split('/')[-2])]
    # Class taken from the stored label.
    true_cls = class_dict[int(l)]
    if cls == true_cls:
        correct += 1

print(len(student['filepaths']))
print(correct)
print(correct/len(student['filepaths']))

234
211
0.9017094017094017


In [403]:
# Show the counter left behind by whichever accuracy cell ran last.
# NOTE(review): this cell's execution count is out of sequence with its neighbours,
# so the saved output may not correspond to the cell directly above — re-run to confirm.
correct

212

In [580]:
# Agreement between the class name parsed from each file name and the label
# stored alongside it in `clip`.
# NOTE(review): `clip` is only loaded by a cell that is currently commented out,
# so this cell raises NameError on a fresh kernel until that load is restored.
# The previous `else: continue` followed by a print was unreachable and is
# removed; `zip` is iterated directly instead of being copied into a list.
correct = 0
for f, l in zip(clip['filepaths'], clip['labels']):
    # Class name from the file name: drop the last '_'-separated token and
    # join the remaining tokens with spaces.
    cls = ' '.join(f.split('/')[-1].split('_')[:-1]).strip()
    # Class name looked up from the stored label.
    true_cls = class_dict[int(l)]
    if cls == true_cls:
        correct += 1

print(len(clip['filepaths']))
print(correct)
print(correct/len(clip['filepaths']))

252
0
0.0


In [75]:
# Show the counter left behind by whichever accuracy cell ran last.
# NOTE(review): this cell's execution count is out of sequence with its neighbours,
# so the saved output may not correspond to the cell directly above — re-run to confirm.
correct

156

In [540]:
# Rebuild `classes` from the "categories" list of the RESICS45 train.json,
# replacing underscores in each category name with spaces.
# Note that `class_dict` is not rebuilt here.
classes = []
with open(f"../development/RESICS45/train.json", "r") as f:
    datas = json.load(f)
    classes.extend(category["name"].replace("_", " ") for category in datas["categories"])

In [141]:
# For every entry, print the class name looked up from its label next to the file's base name.
# NOTE(review): the only visible assignment of `data` is the list of result records
# built in the loading cell, which cannot be indexed with 'labels'. This cell
# presumably ran while `data` held a pseudolabel dict with 'labels' and 'filepaths'
# keys — confirm and load that object explicitly under its own name.
for idx, i in enumerate(data['labels']):
    print(classes[i], data['filepaths'][idx].split('/')[-1])