In [2]:
import numpy as np
import pandas as pd
import os

def statistics(arr, correct=True, order=1):
    """Display an accuracy matrix and print its transfer / avg / last summaries.

    Parameters
    ----------
    arr : 2-D numpy array
        One column per evaluated dataset. After the optional column removal
        below it must have exactly 11 columns, which are labelled with the
        dataset names. Rows are presumably successive training steps --
        confirm against the caller.
    correct : bool
        When true, column 2 (the CIFAR10 evaluation result) is removed first.
    order : int
        ``2`` rearranges the columns (and their labels) into the "order II"
        sequence; any other value keeps the default column order.

    Returns
    -------
    None. The table is always shown with ``display``; the summaries are only
    printed when ``arr`` has at least as many rows as there are datasets.
    """
    names = ['Aircraft','Caltech101','CIFAR100','DTD','EuroSAT','Flowers','Food','MNIST','OxfordPet','StanfordCars','SUN397']
    # Column permutation that maps the default dataset order onto order II.
    order_two = [9, 6, 7, 8, 5, 10, 0, 1, 3, 4, 2]

    if correct:
        arr = np.delete(arr, 2, axis=1)  # drop the cifar10 eval result
    if order == 2:
        names = [names[k] for k in order_two]
        arr = arr[:, order_two]

    table = pd.DataFrame(arr)
    table.columns = names
    display(table)

    n_tasks = len(names)
    if len(arr) < n_tasks:
        # Not enough rows yet to compute the summaries.
        return

    def overall(values):
        # Single headline number: mean of the vector, rounded to one decimal.
        return np.around(values.mean(), decimals=1)

    def per_column(values):
        # One-decimal values joined with ' & ' (ready to paste into a LaTeX row).
        rounded = np.around(values, decimals=1).tolist()
        return ' & '.join([str(v) for v in rounded])

    avg = arr.mean(axis=0)
    last = np.array(arr[-1, :])
    # For column c (c >= 1): mean of the entries in rows 0 .. c-1, i.e. the
    # part of the column that lies above the diagonal.
    transfer = np.array([arr[:col, col].mean() for col in range(1, n_tasks)])

    summaries = [("transfer: ", transfer), ("avg: ", avg), ("last: ", last)]
    for label, values in summaries:
        print(label, overall(values))
    for label, values in summaries:
        print(label, per_column(values))
    
def visualize(name, order=1):
    """Load ``results/<name>`` with ``np.load`` and summarise it.

    The loaded array is reshaped to 11 rows (column count inferred) and passed
    to ``statistics`` with the given ``order``; ``statistics`` is called with
    its default ``correct`` setting. Returns None.
    """
    path = os.path.join("results", name)
    per_step = np.load(path).reshape(11, -1)
    statistics(per_step, order=order)

In [5]:
# Summarise the 5shot_order*.out logs. For every log file this cell
#   1. prints the per-dataset 'Top-1 accuracy' values found before the last
#      'Computing threshold' line of the most recent '[Dataset]' segment, and
#   2. collects one accuracy row per complete '[Dataset]' segment and passes
#      the resulting matrix to `statistics`.
# NOTE(review): the same parsing steps are repeated in the cell that handles
# the full_order*.out logs; consider moving them into one shared function.
filenames = [
    '5shot_order1.out',
    '5shot_order2.out',
]

for filename in filenames:
    print(filename)

    with open(filename) as file:
        all_lines = [line.rstrip() for line in file]

    # Positions of the '[Dataset]' marker lines; each marker opens one segment.
    split_exp = [pos for pos, text in enumerate(all_lines) if '[Dataset]' in text]

    datasets = ['Aircraft','Caltech101','CIFAR100','DTD','EuroSAT','Flowers','Food','MNIST','OxfordPet','StanfordCars','SUN397']

    # ---- Part 1: summary of the most recent segment -------------------------
    # Start at the last marker; if nothing from there on contains a
    # 'Computing threshold' line, start at the second-to-last marker instead.
    # Missing markers no longer raise IndexError: the summary is just skipped.
    lines = []
    threshold_at = []
    for start in reversed(split_exp[-2:]):
        lines = all_lines[start:]
        threshold_at = [pos for pos, text in enumerate(lines) if 'Computing threshold' in text]
        if threshold_at:
            break
    # Keep only what precedes the last 'Computing threshold' line.
    lines = lines[:threshold_at[-1]] if threshold_at else []

    results = []
    for dataset in datasets:
        mentions = [pos for pos, text in enumerate(lines) if dataset in text]
        if not mentions:
            break
        # First 'Top-1 accuracy' line at or after the last mention of the dataset.
        top1 = next((text for text in lines[mentions[-1]:] if 'Top-1 accuracy' in text), None)
        if top1 is None:
            # The log ends before this accuracy was written; stop here, the
            # same way the per-segment loop below treats an incomplete log.
            break
        accuracy = float(top1.split(':')[-1])
        print(dataset, accuracy)

        results.append(accuracy)

    if len(results) == len(datasets):
        print(np.mean(results))

    # ---- Part 2: one accuracy row per complete segment ----------------------
    # Sentinel so that the last segment runs to the end of the file.
    split_exp = split_exp + [len(all_lines)]

    result = []
    for start, stop in zip(split_exp[:-1], split_exp[1:]):
        lines = all_lines[start:stop]
        # Stop at the first segment that has no 'Computing threshold' line.
        if not any('Computing threshold' in text for text in lines):
            break

        all_accuracy = []
        for dataset in datasets:
            mentions = [pos for pos, text in enumerate(lines) if dataset in text]
            if not mentions:
                break
            top1 = next((text for text in lines[mentions[-1]:] if 'Top-1 accuracy' in text), None)
            if top1 is None:
                break
            all_accuracy.append(float(top1.split(':')[-1]))

        # A short row means a dataset or its accuracy line was missing:
        # discard this segment and everything after it.
        if len(all_accuracy) < len(datasets):
            break

        result.append(all_accuracy)

    result = np.array(result).reshape(-1, len(datasets))

    statistics(result, correct=False, order=2 if 'order2' in filename else 1)
    print()


5shot_order1.out
Aircraft 37.74
Caltech101 92.34
CIFAR100 77.38
DTD 63.09
EuroSAT 89.83
Flowers 94.0
Food 87.65
MNIST 94.7
OxfordPet 91.55
StanfordCars 73.59
SUN397 72.4
79.47909090909091


Unnamed: 0,Aircraft,Caltech101,CIFAR100,DTD,EuroSAT,Flowers,Food,MNIST,OxfordPet,StanfordCars,SUN397
0,37.74,88.36,68.26,44.68,55.3,71.04,88.52,59.35,89.1,64.71,65.26
1,37.74,92.34,68.26,44.68,55.3,71.41,88.52,61.4,89.37,64.71,65.26
2,37.74,92.34,77.38,44.68,55.3,71.36,88.52,61.4,89.37,64.71,65.26
3,37.74,92.34,77.38,58.88,55.3,70.22,88.52,61.4,89.37,64.71,65.26
4,37.74,92.34,77.38,58.88,89.69,70.22,88.52,61.4,89.37,64.71,65.26
5,37.74,92.34,77.38,58.88,89.69,89.45,88.52,61.4,89.37,64.71,65.26
6,37.74,92.34,77.38,58.88,89.69,89.45,88.2,61.4,89.37,64.71,65.26
7,37.74,92.34,77.38,58.88,89.69,89.45,88.2,94.7,89.37,64.71,65.26
8,37.74,92.34,77.38,58.88,89.69,89.45,88.2,94.7,91.55,64.71,65.26
9,37.74,92.34,77.38,58.88,89.69,89.45,88.2,94.7,91.55,73.59,65.26


transfer:  69.6
avg:  73.0
last:  78.7
transfer:  88.4 & 68.3 & 44.7 & 55.3 & 70.8 & 88.5 & 61.1 & 89.3 & 64.7 & 65.3
avg:  37.7 & 92.0 & 75.7 & 55.0 & 77.2 & 81.0 & 88.4 & 73.3 & 89.9 & 66.3 & 65.9
last:  37.7 & 92.3 & 77.4 & 58.9 & 89.7 & 89.4 & 88.2 & 94.7 & 91.6 & 73.6 & 72.4

5shot_order2.out
Aircraft 36.06
Caltech101 92.11
CIFAR100 76.97
DTD 66.01
EuroSAT 84.94
Flowers 94.29
Food 87.63
MNIST 93.2
OxfordPet 91.25
StanfordCars 72.49
SUN397 73.27
78.92909090909092


Unnamed: 0,StanfordCars,Food,MNIST,OxfordPet,Flowers,SUN397,Aircraft,Caltech101,DTD,EuroSAT,CIFAR100
0,72.49,88.52,59.35,89.1,71.04,65.26,24.36,88.36,44.68,55.3,68.26
1,72.49,87.62,59.35,89.1,71.04,65.26,24.36,88.36,44.68,55.3,68.26
2,72.49,87.62,93.2,89.1,71.04,65.26,24.36,88.36,44.68,55.3,68.26
3,72.49,87.62,93.2,91.25,71.04,65.26,24.36,88.36,44.68,55.3,68.26
4,72.49,87.61,93.2,91.25,90.05,65.26,24.36,88.36,44.68,55.3,68.26
5,72.49,87.28,93.2,91.25,90.05,73.27,20.64,88.88,44.68,55.3,68.26
6,72.49,87.28,93.2,91.25,90.05,73.27,36.06,88.88,44.68,55.3,68.26
7,72.49,88.03,93.2,91.25,90.05,73.27,36.06,92.11,44.68,55.3,68.26
8,72.49,88.03,93.2,91.25,90.05,73.27,36.06,92.11,58.94,55.3,68.26
9,72.49,88.03,93.2,91.25,90.05,73.27,36.06,92.11,58.94,84.94,68.26


transfer:  65.4
avg:  71.7
last:  77.9
transfer:  88.5 & 59.4 & 89.1 & 71.0 & 65.3 & 23.7 & 88.5 & 44.7 & 55.3 & 68.3
avg:  72.5 & 87.8 & 87.0 & 90.7 & 83.1 & 69.6 & 29.3 & 89.8 & 48.6 & 60.7 & 69.1
last:  72.5 & 88.0 & 93.2 & 91.2 & 90.0 & 73.3 & 36.1 & 92.1 & 58.9 & 84.9 & 77.0



In [6]:

import numpy as np

# Summarise the full_order*.out logs. For every log file this cell
#   1. prints the per-dataset 'Top-1 accuracy' values found before the last
#      'Computing threshold' line of the most recent '[Dataset]' segment, and
#   2. collects one accuracy row per complete '[Dataset]' segment and passes
#      the resulting matrix to `statistics`.
# NOTE(review): the same parsing steps are repeated in the cell that handles
# the 5shot_order*.out logs; consider moving them into one shared function.
filenames = [
    'full_order1.out',
    'full_order2.out',
]

for filename in filenames:
    print(filename)

    with open(filename) as file:
        all_lines = [line.rstrip() for line in file]

    # Positions of the '[Dataset]' marker lines; each marker opens one segment.
    split_exp = [pos for pos, text in enumerate(all_lines) if '[Dataset]' in text]

    datasets = ['Aircraft','Caltech101','CIFAR100','DTD','EuroSAT','Flowers','Food','MNIST','OxfordPet','StanfordCars','SUN397']

    # ---- Part 1: summary of the most recent segment -------------------------
    # Start at the last marker; if nothing from there on contains a
    # 'Computing threshold' line, start at the second-to-last marker instead.
    # Missing markers no longer raise IndexError: the summary is just skipped.
    lines = []
    threshold_at = []
    for start in reversed(split_exp[-2:]):
        lines = all_lines[start:]
        threshold_at = [pos for pos, text in enumerate(lines) if 'Computing threshold' in text]
        if threshold_at:
            break
    # Keep only what precedes the last 'Computing threshold' line.
    lines = lines[:threshold_at[-1]] if threshold_at else []

    results = []
    for dataset in datasets:
        mentions = [pos for pos, text in enumerate(lines) if dataset in text]
        if not mentions:
            break
        # First 'Top-1 accuracy' line at or after the last mention of the dataset.
        top1 = next((text for text in lines[mentions[-1]:] if 'Top-1 accuracy' in text), None)
        if top1 is None:
            # The log ends before this accuracy was written; stop here, the
            # same way the per-segment loop below treats an incomplete log.
            break
        accuracy = float(top1.split(':')[-1])
        print(dataset, accuracy)

        results.append(accuracy)

    if len(results) == len(datasets):
        print(np.mean(results))

    # ---- Part 2: one accuracy row per complete segment ----------------------
    # Sentinel so that the last segment runs to the end of the file.
    split_exp = split_exp + [len(all_lines)]

    result = []
    for start, stop in zip(split_exp[:-1], split_exp[1:]):
        lines = all_lines[start:stop]
        # Stop at the first segment that has no 'Computing threshold' line.
        if not any('Computing threshold' in text for text in lines):
            break

        all_accuracy = []
        for dataset in datasets:
            mentions = [pos for pos, text in enumerate(lines) if dataset in text]
            if not mentions:
                break
            top1 = next((text for text in lines[mentions[-1]:] if 'Top-1 accuracy' in text), None)
            if top1 is None:
                break
            all_accuracy.append(float(top1.split(':')[-1]))

        # A short row means a dataset or its accuracy line was missing:
        # discard this segment and everything after it.
        if len(all_accuracy) < len(datasets):
            break

        result.append(all_accuracy)

    result = np.array(result).reshape(-1, len(datasets))

    statistics(result, correct=False, order=2 if 'order2' in filename else 1)
    print()


full_order1.out
Aircraft 54.31
Caltech101 95.28
CIFAR100 86.16
DTD 78.62
EuroSAT 98.63
Flowers 96.1
Food 92.29
MNIST 99.34
OxfordPet 93.79
StanfordCars 85.81
SUN397 80.89
87.38363636363636


Unnamed: 0,Aircraft,Caltech101,CIFAR100,DTD,EuroSAT,Flowers,Food,MNIST,OxfordPet,StanfordCars,SUN397
0,54.31,88.36,68.26,44.68,55.3,71.04,88.52,59.35,89.1,64.71,65.26
1,54.31,94.07,68.26,44.68,55.3,71.04,88.52,69.71,86.73,64.71,65.26
2,54.31,94.07,86.16,44.68,55.3,71.25,88.52,69.35,86.73,64.71,65.26
3,54.31,94.07,86.16,75.48,55.3,68.65,88.52,69.35,86.73,64.71,65.26
4,54.31,94.07,86.16,75.48,98.63,68.65,88.52,69.35,86.73,64.71,65.26
5,54.31,94.07,86.16,75.48,98.63,91.43,88.52,69.35,86.73,64.71,65.26
6,54.31,94.07,86.16,75.48,98.63,91.43,92.19,69.35,86.73,64.71,65.26
7,54.31,94.07,86.16,75.48,98.63,91.43,92.19,99.34,86.73,64.71,65.26
8,54.31,94.07,86.16,75.48,98.63,91.43,92.19,99.34,93.79,64.71,65.26
9,54.31,94.07,86.16,75.48,98.63,91.43,92.19,99.34,93.79,85.81,65.26


transfer:  70.0
avg:  77.8
last:  86.6
transfer:  88.4 & 68.3 & 44.7 & 55.3 & 70.1 & 88.5 & 68.0 & 87.0 & 64.7 & 65.3
avg:  54.3 & 93.6 & 82.9 & 67.1 & 82.9 & 81.7 & 90.2 & 79.4 & 88.9 & 68.5 & 66.7
last:  54.3 & 94.1 & 86.2 & 75.5 & 98.6 & 91.4 & 92.2 & 99.3 & 93.8 & 85.8 & 80.9

full_order2.out
Aircraft 54.76
Caltech101 95.22
CIFAR100 86.14
DTD 79.36
EuroSAT 98.57
Flowers 95.92
Food 92.18
MNIST 99.43
OxfordPet 93.32
StanfordCars 85.56
SUN397 80.77
87.38454545454546


Unnamed: 0,StanfordCars,Food,MNIST,OxfordPet,Flowers,SUN397,Aircraft,Caltech101,DTD,EuroSAT,CIFAR100
0,85.56,88.52,59.35,89.1,71.04,65.26,24.36,88.36,44.68,55.3,68.26
1,85.56,92.11,59.35,89.1,71.04,65.26,24.36,88.36,44.68,55.3,68.26
2,85.56,92.11,99.43,89.1,71.04,65.26,24.36,88.36,44.68,55.3,68.26
3,85.56,92.11,99.43,93.32,71.04,65.26,24.36,88.36,44.68,55.3,68.26
4,85.56,92.11,99.43,93.32,91.36,65.26,24.36,88.36,44.68,55.3,68.26
5,85.56,92.11,99.43,93.32,91.36,80.77,19.47,86.92,44.68,55.3,68.26
6,85.56,92.11,99.43,93.32,91.36,80.77,54.76,86.92,44.68,55.3,68.26
7,85.56,92.11,99.43,93.32,91.36,80.77,54.76,95.22,44.68,55.3,68.26
8,85.56,92.11,99.43,93.32,91.36,80.77,54.76,95.22,76.12,55.3,68.26
9,85.56,92.11,99.43,93.32,91.36,80.77,54.76,95.22,76.12,98.57,68.26


transfer:  65.3
avg:  75.8
last:  86.7
transfer:  88.5 & 59.4 & 89.1 & 71.0 & 65.3 & 23.5 & 87.9 & 44.7 & 55.3 & 68.3
avg:  85.6 & 91.8 & 92.1 & 92.2 & 84.0 & 73.7 & 37.7 & 90.6 & 53.3 & 63.2 & 69.9
last:  85.6 & 92.1 & 99.4 & 93.3 & 91.4 & 80.8 & 54.8 & 95.2 & 76.1 & 98.6 & 86.1

