In [1]:
import json

import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import numpy as np

2022-08-26 15:21:49.400797: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [2]:
def _load_json(json_file):
    with open(json_file) as f:
        data = json.load(f)
    return data

def load_results_to_df(json_results_patern):
    """Load every JSON file matching the given glob pattern(s) into a DataFrame.

    Args:
        json_results_patern: A glob pattern, or a list of glob patterns,
            expanded with ``tf.io.gfile.glob``.

    Returns:
        A ``pd.DataFrame`` with one row per matched file, built from the
        object parsed out of each file. Files are loaded in match order and a
        file matched by more than one pattern is loaded only once.
    """
    if isinstance(json_results_patern, list):
        json_files = [path
                      for pattern in json_results_patern
                      for path in tf.io.gfile.glob(pattern)]
    else:
        json_files = tf.io.gfile.glob(json_results_patern)

    # Overlapping patterns can match the same file several times, which would
    # produce duplicate rows. dict.fromkeys drops repeats while keeping the
    # first-seen order.
    json_files = list(dict.fromkeys(json_files))

    return pd.DataFrame([_load_json(jfile) for jfile in json_files])

def format_results(df):
    """Build the presentation table from the raw per-run results frame.

    The input frame is left untouched; a new frame is returned.

    Args:
        df: DataFrame with columns ``dataset``, ``model``, ``training_setup``,
            ``f1_score_cls`` and ``acc_bins``. Each ``f1_score_cls`` entry is
            reduced with ``np.mean``. Each ``acc_bins`` entry is expanded via
            ``pd.Series`` and must yield the labels ``1``, ``2``, ``3``, ``4``
            and ``all`` (presumably a dict keyed by those strings - confirm
            against the results JSON).

    Returns:
        A new DataFrame with columns ``dataset``, ``model``,
        ``training_setup``, ``acc_1`` ... ``acc_4``, ``acc_all`` (scaled by 100
        and rounded to 2 decimals) and ``f_1^m`` (mean of ``f1_score_cls``
        rounded to 3 decimals).

    Raises:
        KeyError: If a required input column is missing or the expanded
            ``acc_bins`` do not provide every ``acc_*`` column.
    """
    acc_columns = ['acc_1', 'acc_2', 'acc_3', 'acc_4', 'acc_all']
    output_columns = ['dataset', 'model', 'training_setup'] + acc_columns + ['f_1^m']

    f1_mean = df['f1_score_cls'].apply(np.mean).round(3)
    # drop() + assign() return new frames, so the caller's frame never gains
    # the derived column.
    base = df.drop(['f1_score_cls', 'acc_bins'], axis=1).assign(**{'f_1^m': f1_mean})

    # One column per accuracy bin, named acc_<bin label>.
    acc_bins = df['acc_bins'].apply(pd.Series).add_prefix('acc_')

    # Explicit copy: the scaled values below are written into an independent
    # frame rather than into a column selection (avoids SettingWithCopyWarning).
    formatted = pd.concat([base, acc_bins], axis=1)[output_columns].copy()
    formatted[acc_columns] = (formatted[acc_columns] * 100).round(2)

    return formatted

def filter_results(df, dataset, column_list):
    """Extract one dataset's rows as a table indexed by model and setup.

    Args:
        df: Frame holding ``dataset``, ``model``, ``training_setup`` and the
            requested metric columns.
        dataset: Value of the ``dataset`` column to keep.
        column_list: List of metric column names to keep, in display order.

    Returns:
        A new DataFrame indexed by ``(model, training_setup)`` whose columns
        are a two-level MultiIndex of ``(dataset, metric)`` pairs.
    """
    wanted = ['model', 'training_setup'] + column_list
    table = df.loc[df.dataset == dataset, wanted].set_index(['model', 'training_setup'])

    # Prefix every metric with the dataset name so tables for several datasets
    # can be placed side by side without column clashes.
    table.columns = pd.MultiIndex.from_tuples(
        [(dataset, metric) for metric in table.columns]
    )

    return table.copy()

In [3]:
# Glob patterns locating every per-run ``*_results.json`` file to aggregate.
results_json = [
    '/data/fagner/training/bags_paper/results/exp1/*_results.json',
    '/data/fagner/training/bags_paper/checkpoints/*imagenet*/*_results.json',
    '/data/fagner/training/bags_paper/checkpoints/*_ssb_*/*_results.json',
]

# Display rank used for sorting. ``sort_values`` applies ``key`` to each
# ``by`` column separately, so this one lookup serves both columns: it holds
# the training-setup names as well as the model names. Values missing from
# the lookup map to NaN.
training_setup_sort = {
    'repre': 0,
    'imagenet': 1,
    'crt': 2,
    'imagenetcrt': 3,
    'cbfocal': 4,
    'imagenetcbfocal': 5,
    'bags': 6,
    'imagenetbags': 7,
    'ssb': 8,
    'resnet50': 0,
    'mbnetv3': 1,
    'effv2b2': 2,
    'swin-s': 3,
}

# Load the raw runs, format them, then order rows by model and training setup.
results = (
    format_results(load_results_to_df(results_json))
    .sort_values(by=['model', 'training_setup'],
                 key=lambda column: column.map(training_setup_sort))
)

In [4]:
# Keep only the rows whose training setup is one of the five listed here.
results = results.loc[
    results['training_setup'].isin(['repre', 'crt', 'cbfocal', 'bags', 'ssb'])
]

In [5]:
# Table for the 'wcs' dataset with all four accuracy bins.
wcs_results = filter_results(
    results,
    dataset='wcs',
    column_list=['acc_1', 'acc_2', 'acc_3', 'acc_4', 'acc_all', 'f_1^m'],
)
wcs_results

Unnamed: 0_level_0,Unnamed: 1_level_0,wcs,wcs,wcs,wcs,wcs,wcs
Unnamed: 0_level_1,Unnamed: 1_level_1,acc_1,acc_2,acc_3,acc_4,acc_all,f_1^m
model,training_setup,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
resnet50,repre,0.0,13.05,32.06,85.02,81.53,0.289
resnet50,crt,16.25,28.94,48.98,75.38,73.47,0.322
resnet50,cbfocal,5.0,14.42,40.61,83.41,80.43,0.316
resnet50,bags,16.25,14.19,34.84,84.43,81.14,0.283
resnet50,ssb,15.0,21.79,40.97,84.85,81.91,0.343
mbnetv3,repre,0.0,11.92,36.52,84.24,80.98,0.287
mbnetv3,crt,7.5,28.26,46.85,76.71,74.61,0.292
mbnetv3,cbfocal,0.0,9.19,36.98,79.41,76.42,0.276
mbnetv3,bags,10.0,10.44,28.68,84.09,80.48,0.248
mbnetv3,ssb,1.25,20.32,40.43,83.99,81.04,0.319


In [6]:
# Table for the 'serengeti' dataset; the acc_1 bin is not shown here.
serengeti_results = filter_results(
    results,
    dataset='serengeti',
    column_list=['acc_2', 'acc_3', 'acc_4', 'acc_all', 'f_1^m'],
)
serengeti_results

Unnamed: 0_level_0,Unnamed: 1_level_0,serengeti,serengeti,serengeti,serengeti,serengeti
Unnamed: 0_level_1,Unnamed: 1_level_1,acc_2,acc_3,acc_4,acc_all,f_1^m
model,training_setup,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
resnet50,repre,2.9,44.82,92.96,92.79,0.522
resnet50,crt,6.52,51.58,87.29,87.16,0.396
resnet50,cbfocal,0.72,38.66,92.88,92.69,0.492
resnet50,bags,9.42,41.27,93.12,92.94,0.499
resnet50,ssb,5.07,48.16,92.91,92.75,0.518
mbnetv3,repre,2.17,29.78,92.73,92.52,0.457
mbnetv3,crt,4.35,43.95,87.62,87.47,0.396
mbnetv3,cbfocal,0.72,29.78,92.44,92.23,0.46
mbnetv3,bags,5.07,28.78,92.19,91.97,0.432
mbnetv3,ssb,3.62,36.97,92.72,92.53,0.478


In [7]:
# Table for the 'caltech' dataset; the acc_1 and acc_3 bins are not shown here.
caltech_results = filter_results(
    results,
    dataset='caltech',
    column_list=['acc_2', 'acc_4', 'acc_all', 'f_1^m'],
)
caltech_results

Unnamed: 0_level_0,Unnamed: 1_level_0,caltech,caltech,caltech,caltech
Unnamed: 0_level_1,Unnamed: 1_level_1,acc_2,acc_4,acc_all,f_1^m
model,training_setup,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
resnet50,repre,0.0,44.95,42.19,0.301
resnet50,crt,0.0,39.59,37.16,0.297
resnet50,cbfocal,0.0,44.47,41.74,0.314
resnet50,bags,1.56,48.04,45.19,0.326
resnet50,ssb,0.03,46.09,43.26,0.31
mbnetv3,repre,0.0,46.14,43.31,0.319
mbnetv3,crt,0.16,37.76,35.45,0.284
mbnetv3,cbfocal,0.0,41.27,38.74,0.302
mbnetv3,bags,0.0,43.96,41.26,0.304
mbnetv3,ssb,0.11,45.93,43.12,0.323


In [8]:
# Table for the 'wellington' dataset; the acc_1 bin is not shown here.
wellington_results = filter_results(
    results,
    dataset='wellington',
    column_list=['acc_2', 'acc_3', 'acc_4', 'acc_all', 'f_1^m'],
)
wellington_results

Unnamed: 0_level_0,Unnamed: 1_level_0,wellington,wellington,wellington,wellington,wellington
Unnamed: 0_level_1,Unnamed: 1_level_1,acc_2,acc_3,acc_4,acc_all,f_1^m
model,training_setup,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
resnet50,repre,0.0,4.58,75.66,72.23,0.309
resnet50,crt,0.0,10.65,70.91,68.01,0.315
resnet50,cbfocal,0.0,4.84,75.12,71.73,0.312
resnet50,bags,0.0,25.92,75.0,72.64,0.319
resnet50,ssb,0.0,5.51,75.62,72.24,0.312
mbnetv3,repre,0.0,5.7,75.65,72.28,0.326
mbnetv3,crt,0.0,21.68,70.45,68.1,0.313
mbnetv3,cbfocal,0.0,4.58,74.94,71.55,0.313
mbnetv3,bags,0.0,23.41,74.02,71.58,0.292
mbnetv3,ssb,0.0,13.54,75.34,72.36,0.326


In [9]:
# Place the four per-dataset tables side by side, aligned on their shared
# (model, training_setup) index.
result = pd.concat(
    [wcs_results, serengeti_results, caltech_results, wellington_results],
    axis='columns',
)
result

Unnamed: 0_level_0,Unnamed: 1_level_0,wcs,wcs,wcs,wcs,wcs,wcs,serengeti,serengeti,serengeti,serengeti,serengeti,caltech,caltech,caltech,caltech,wellington,wellington,wellington,wellington,wellington
Unnamed: 0_level_1,Unnamed: 1_level_1,acc_1,acc_2,acc_3,acc_4,acc_all,f_1^m,acc_2,acc_3,acc_4,acc_all,f_1^m,acc_2,acc_4,acc_all,f_1^m,acc_2,acc_3,acc_4,acc_all,f_1^m
model,training_setup,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
resnet50,repre,0.0,13.05,32.06,85.02,81.53,0.289,2.9,44.82,92.96,92.79,0.522,0.0,44.95,42.19,0.301,0.0,4.58,75.66,72.23,0.309
resnet50,crt,16.25,28.94,48.98,75.38,73.47,0.322,6.52,51.58,87.29,87.16,0.396,0.0,39.59,37.16,0.297,0.0,10.65,70.91,68.01,0.315
resnet50,cbfocal,5.0,14.42,40.61,83.41,80.43,0.316,0.72,38.66,92.88,92.69,0.492,0.0,44.47,41.74,0.314,0.0,4.84,75.12,71.73,0.312
resnet50,bags,16.25,14.19,34.84,84.43,81.14,0.283,9.42,41.27,93.12,92.94,0.499,1.56,48.04,45.19,0.326,0.0,25.92,75.0,72.64,0.319
resnet50,ssb,15.0,21.79,40.97,84.85,81.91,0.343,5.07,48.16,92.91,92.75,0.518,0.03,46.09,43.26,0.31,0.0,5.51,75.62,72.24,0.312
mbnetv3,repre,0.0,11.92,36.52,84.24,80.98,0.287,2.17,29.78,92.73,92.52,0.457,0.0,46.14,43.31,0.319,0.0,5.7,75.65,72.28,0.326
mbnetv3,crt,7.5,28.26,46.85,76.71,74.61,0.292,4.35,43.95,87.62,87.47,0.396,0.16,37.76,35.45,0.284,0.0,21.68,70.45,68.1,0.313
mbnetv3,cbfocal,0.0,9.19,36.98,79.41,76.42,0.276,0.72,29.78,92.44,92.23,0.46,0.0,41.27,38.74,0.302,0.0,4.58,74.94,71.55,0.313
mbnetv3,bags,10.0,10.44,28.68,84.09,80.48,0.248,5.07,28.78,92.19,91.97,0.432,0.0,43.96,41.26,0.304,0.0,23.41,74.02,71.58,0.292
mbnetv3,ssb,1.25,20.32,40.43,83.99,81.04,0.319,3.62,36.97,92.72,92.53,0.478,0.11,45.93,43.12,0.323,0.0,13.54,75.34,72.36,0.326
