In [1]:
import pandas as pd
import numpy as np
import os
import seaborn as sns
import torch
import pickle
import json
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt
from collections import Counter
from helpers import *
from explore import *
from sklearn.metrics import roc_auc_score
from itertools import product
pd.set_option('display.max_rows', 1000)

In [2]:
class_names = ["airplane", "automobile","bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

In [29]:
def softmax_t(logits, temp=1):
    logits = logits/temp
    _max = np.expand_dims(np.max(logits, axis=-1), axis=-1)
    probs = np.exp(logits - _max)
    _sum = np.expand_dims(np.sum(probs, axis=-1), axis=-1)
    return probs/_sum

def cluster_purity(kmeans_targets, in_targets):
    k_classes = np.unique(kmeans_targets).astype(int)
    k_class_idx = [np.nonzero(np.equal(cls_, kmeans_targets)) for cls_ in k_classes]
    in_classes_in_k_clstr = [in_targets[idx] for idx in k_class_idx]
    purity_list = []

    for cluster_k in in_classes_in_k_clstr:
        unique, counts = np.unique(cluster_k, return_counts=True)
        purity_list.append(np.round(np.asarray(counts).max()/len(cluster_k), 5))

    return purity_list

def save_pickle(root, filename, data):
    filename = os.path.join(root, 
                     f'{filename}.pickle')
    
    with open(filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"saved {filename}")

## Load Model

In [68]:
encoder_type = "key"
output_layer = "avg_pool"

## Load IID

In [69]:
# load id
iid = "CIFAR10"
prefix = f"{iid}_{encoder_type}_{output_layer}"
prefix

'CIFAR10_key_avg_pool'

In [71]:
id_train = load_features("../cache", f"{prefix}_id_train.npy")
id_test = load_features("../cache", f"{prefix}_id_test.npy")
#id_train_targ = load_features("../cache", f"{prefix}_id_train_targ.npy")
#id_test_targ = load_features("../cache", f"{prefix}_id_test_targ.npy")
id_train_targ = np.array(CIFAR10("/data/datasets/CIFAR10", train=True).targets)
id_test_targ = np.array(CIFAR10("/data/datasets/CIFAR10", train=False).targets)

## Load OOD

In [72]:
# load ood
oods = ["CIFAR100", "SVHN",  "LSUNResize"]

def get_ood_feats(ood):
    prefix = f"{ood}_{encoder_type}_{output_layer}"
    print(f"Using {prefix}")
    ood_train = load_features("../cache", f"{prefix}_ood_train.npy")
    ood_test = load_features("../cache",  f"{prefix}_ood_test.npy")
    return ood_train, ood_test

## Train Features

In [73]:
#train_linear_model(id_train, id_train_targ, id_test,  id_test_targ, 3)

## Get Prototypes

In [74]:
get_prototypes = True

In [75]:
if get_prototypes:
    num_cluster = 768
    id_im2cluster, id_prototypes, id_density = run_clustering(norm_feats(id_train), num_cluster)

performing kmeans clustering


## Perform OOD detection Prototypes

In [76]:
ood_results = []
gmc = [True]
metrics = ["mahalanobis", "cosine"]
pca_com = 0
cluster_method = "kmeans"

ood_evaluator = OodEvaluator(norm_feats(id_train), norm_feats(id_test), id_train_targ, id_test_targ,
                num_clusters = 10,
                pca_com = None,
                cluster_method = "kmeans",
                means = id_prototypes,
                im2cluster = id_im2cluster,
                clip = 0,
                clip_metric = "cosine")
    
for ood, met in product(oods, metrics):
    global_mal_cov = True
    ood_train, ood_test = get_ood_feats(ood)
    if met == "mahalanobis" and pca_com == 0 and cluster_method=="kmeans" and global_mal_cov:
        ood_evaluator(norm_feats(ood_test), met, global_cov=True, inv_choice="default", recal=True)
    else:
        ood_evaluator(norm_feats(ood_test), met)
        
    ood_evaluator.get_scores()

    ood_evaluator.get_auroc()
    
    print("\n")


    aurocs = ood_evaluator.auroc
    tnrs = ood_evaluator.tnr_at_tpr95

    result = {
        "ood": ood,   
        "metric": met,
            "pca": 0,
            "clusters": "prototypes",
            "n_auroc": aurocs[0],
            "n_tnr": tnrs[0],
            "e_auroc": aurocs[1],
            "e_tnr":  tnrs[1],
            "sklearn_auroc": aurocs[2],
            "global_mal_cov": global_mal_cov
            }
    pprint(result)
    ood_results.append(result)


setting pca and num_cluster to 0 because means are supplied
using supplied means
Using CIFAR100_key_avg_pool
** Using Ground Truths **
** No PCA **
calculating train distances
global conv


100%|██████████| 768/768 [20:47<00:00,  1.62s/it]


calculating test distances
global conv


100%|██████████| 768/768 [03:47<00:00,  3.38it/s]


calculating ood distances
global conv


100%|██████████| 768/768 [03:47<00:00,  3.38it/s]


train ID accuracy: 94.20%
test ID accuracy: 0.11%
n auroc: 65.73706849999999, tnr@tpr95: 10.75
e auroc: 65.7159715, tnr@tpr95: 11.51
sklearn-auroc: 66.415063 %


{'clusters': 'prototypes',
 'e_auroc': 65.7159715,
 'e_tnr': 11.51,
 'global_mal_cov': True,
 'metric': 'mahalanobis',
 'n_auroc': 65.73706849999999,
 'n_tnr': 10.75,
 'ood': 'CIFAR100',
 'pca': 0,
 'sklearn_auroc': 66.415063}
Using CIFAR100_key_avg_pool
** Using Ground Truths **
** No PCA **
calculating train distances
calculating test distances
calculating ood distances
train ID accuracy: 99.94%
test ID accuracy: 0.09%
n auroc: 69.5341345, tnr@tpr95: 13.81
e auroc: 72.22727599999999, tnr@tpr95: 16.21
sklearn-auroc: 72.7243485 %


{'clusters': 'prototypes',
 'e_auroc': 72.22727599999999,
 'e_tnr': 16.21,
 'global_mal_cov': True,
 'metric': 'cosine',
 'n_auroc': 69.5341345,
 'n_tnr': 13.81,
 'ood': 'CIFAR100',
 'pca': 0,
 'sklearn_auroc': 72.7243485}
Using SVHN_key_avg_pool
** Using Ground Truths **
** No PCA **
calculating tr

100%|██████████| 768/768 [20:24<00:00,  1.59s/it]


calculating test distances
global conv


100%|██████████| 768/768 [03:37<00:00,  3.53it/s]


calculating ood distances
global conv


100%|██████████| 768/768 [10:12<00:00,  1.25it/s]


train ID accuracy: 94.20%
test ID accuracy: 0.11%
n auroc: 94.84469153349723, tnr@tpr95: 94.69999999999999
e auroc: 95.14952808082361, tnr@tpr95: 94.78999999999999
sklearn-auroc: 99.2335602335587 %


{'clusters': 'prototypes',
 'e_auroc': 95.14952808082361,
 'e_tnr': 94.78999999999999,
 'global_mal_cov': True,
 'metric': 'mahalanobis',
 'n_auroc': 94.84469153349723,
 'n_tnr': 94.69999999999999,
 'ood': 'SVHN',
 'pca': 0,
 'sklearn_auroc': 99.2335602335587}
Using SVHN_key_avg_pool
** Using Ground Truths **
** No PCA **
calculating train distances
calculating test distances
calculating ood distances
train ID accuracy: 99.94%
test ID accuracy: 0.09%
n auroc: 94.17582782728947, tnr@tpr95: 79.32000000000001
e auroc: 97.99290565457898, tnr@tpr95: 84.44
sklearn-auroc: 98.1736326444376 %


{'clusters': 'prototypes',
 'e_auroc': 97.99290565457898,
 'e_tnr': 84.44,
 'global_mal_cov': True,
 'metric': 'cosine',
 'n_auroc': 94.17582782728947,
 'n_tnr': 79.32000000000001,
 'ood': 'SVHN',
 'pca': 0,

100%|██████████| 768/768 [20:28<00:00,  1.60s/it]


calculating test distances
global conv


100%|██████████| 768/768 [03:41<00:00,  3.46it/s]


calculating ood distances
global conv


100%|██████████| 768/768 [03:40<00:00,  3.48it/s]


train ID accuracy: 94.20%
test ID accuracy: 0.11%
n auroc: 79.80052850000001, tnr@tpr95: 36.95
e auroc: 79.736658, tnr@tpr95: 32.01
sklearn-auroc: 80.02505699999999 %


{'clusters': 'prototypes',
 'e_auroc': 79.736658,
 'e_tnr': 32.01,
 'global_mal_cov': True,
 'metric': 'mahalanobis',
 'n_auroc': 79.80052850000001,
 'n_tnr': 36.95,
 'ood': 'LSUNResize',
 'pca': 0,
 'sklearn_auroc': 80.02505699999999}
Using LSUNResize_key_avg_pool
** Using Ground Truths **
** No PCA **
calculating train distances
calculating test distances
calculating ood distances
train ID accuracy: 99.94%
test ID accuracy: 0.09%
n auroc: 75.88427949999999, tnr@tpr95: 25.979999999999997
e auroc: 80.5792735, tnr@tpr95: 39.989999999999995
sklearn-auroc: 81.41561049999999 %


{'clusters': 'prototypes',
 'e_auroc': 80.5792735,
 'e_tnr': 39.989999999999995,
 'global_mal_cov': True,
 'metric': 'cosine',
 'n_auroc': 75.88427949999999,
 'n_tnr': 25.979999999999997,
 'ood': 'LSUNResize',
 'pca': 0,
 'sklearn_auroc': 81.4156104

In [77]:
pd.DataFrame(ood_results)

Unnamed: 0,ood,metric,pca,clusters,n_auroc,n_tnr,e_auroc,e_tnr,sklearn_auroc,global_mal_cov
0,CIFAR100,mahalanobis,0,prototypes,65.737068,10.75,65.715971,11.51,66.415063,True
1,CIFAR100,cosine,0,prototypes,69.534134,13.81,72.227276,16.21,72.724349,True
2,SVHN,mahalanobis,0,prototypes,94.844692,94.7,95.149528,94.79,99.23356,True
3,SVHN,cosine,0,prototypes,94.175828,79.32,97.992906,84.44,98.173633,True
4,LSUNResize,mahalanobis,0,prototypes,79.800529,36.95,79.736658,32.01,80.025057,True
5,LSUNResize,cosine,0,prototypes,75.884279,25.98,80.579273,39.99,81.41561,True


In [78]:
pd.DataFrame(ood_results)

Unnamed: 0,ood,metric,pca,clusters,n_auroc,n_tnr,e_auroc,e_tnr,sklearn_auroc,global_mal_cov
0,CIFAR100,mahalanobis,0,prototypes,65.737068,10.75,65.715971,11.51,66.415063,True
1,CIFAR100,cosine,0,prototypes,69.534134,13.81,72.227276,16.21,72.724349,True
2,SVHN,mahalanobis,0,prototypes,94.844692,94.7,95.149528,94.79,99.23356,True
3,SVHN,cosine,0,prototypes,94.175828,79.32,97.992906,84.44,98.173633,True
4,LSUNResize,mahalanobis,0,prototypes,79.800529,36.95,79.736658,32.01,80.025057,True
5,LSUNResize,cosine,0,prototypes,75.884279,25.98,80.579273,39.99,81.41561,True


In [79]:
save_pickle("../cache", "CIFAR10_prototype_ood_result", ood_results)

saved ../cache/CIFAR10_prototype_ood_result.pickle


In [38]:
pd.DataFrame(pd.read_pickle("../cache/CIFAR10_prototype_ood_result.pickle"))

Unnamed: 0,ood,metric,pca,clusters,n_auroc,n_tnr,e_auroc,e_tnr,sklearn_auroc,global_mal_cov
0,CIFAR100,mahalanobis,10,prototypes,43.375665,0.0,0.333363,0.0,47.995792,True
1,CIFAR100,mahalanobis,10,prototypes,43.375665,0.0,0.333363,0.0,47.995792,False
2,CIFAR100,cosine,10,prototypes,70.912516,16.37,74.684123,11.14,75.230045,True
3,CIFAR100,cosine,10,prototypes,70.912516,16.37,74.684123,11.14,75.230045,False
4,SVHN,mahalanobis,10,prototypes,36.643526,0.0,0.270425,0.0,31.160693,True
5,SVHN,mahalanobis,10,prototypes,36.643526,0.0,0.270425,0.0,31.160693,False
6,SVHN,cosine,10,prototypes,93.819123,81.95,98.004371,87.27,98.157377,True
7,SVHN,cosine,10,prototypes,93.819123,81.95,98.004371,87.27,98.157377,False
8,LSUNResize,mahalanobis,10,prototypes,50.692852,11.17,0.22383,0.0,59.084254,True
9,LSUNResize,mahalanobis,10,prototypes,50.692852,11.17,0.22383,0.0,59.084254,False


In [31]:
pd.DataFrame(ood_results)

Unnamed: 0,metric,pca,clusters,n_auroc,n_tnr,e_auroc,e_tnr,sklearn_auroc,global_mal_cov
0,mahalanobis,10,prototypes,43.375665,0.0,0.333363,0.0,47.995792,True
1,mahalanobis,10,prototypes,43.375665,0.0,0.333363,0.0,47.995792,False
2,cosine,10,prototypes,70.912516,16.37,74.684123,11.14,75.230045,True
3,cosine,10,prototypes,70.912516,16.37,74.684123,11.14,75.230045,False
4,mahalanobis,10,prototypes,36.643526,0.0,0.270425,0.0,31.160693,True
5,mahalanobis,10,prototypes,36.643526,0.0,0.270425,0.0,31.160693,False
6,cosine,10,prototypes,93.819123,81.95,98.004371,87.27,98.157377,True
7,cosine,10,prototypes,93.819123,81.95,98.004371,87.27,98.157377,False
8,mahalanobis,10,prototypes,50.692852,11.17,0.22383,0.0,59.084254,True
9,mahalanobis,10,prototypes,50.692852,11.17,0.22383,0.0,59.084254,False


In [None]:
new_all_results = []
for res in ood_results:
    gmm_res = res['gmm_results']
    for k,v in gmm_res.items():
        if k == "gmm_results":
            k = "gmm_default"
        res[k] = v
    new_all_results.append(res)

In [None]:
gmm_new_all_results  = pd.DataFrame(new_all_results)
gmm_new_all_results

In [None]:
sec_res = []
_cols = ['auroc', 'sklearn-auroc', 'gmm_default', 'gmm_max_prob', 'gmm_weighted_max_prob']
for _d in new_all_results:
    for i in _cols:
        point = {}
        point["ood"] = _d['ood']
        point["cluster"] = _d['cluster']
        point["pca"] = _d['pca']
        point["metric"] = _d['metric']
        
        if i == 'auroc':
            _t = f'cluster-auroc-{point["metric"]}'
            point["auroc"] = _d[f'{i}']
        elif i == 'sklearn-auroc':
            _t = f'global-auroc-{point["metric"]}'
            point["auroc"] = _d[f'{i}']
        else:
            _t = i
            point["auroc"] = _d[f'{i}'] * 100
        point["auroc_type"] = _t
        sec_res.append(point)

In [None]:
sec_res = pd.DataFrame(sec_res)

In [None]:
sec_res 

In [None]:
in_train_y = id_im2cluster
train_dis = Counter(in_train_y)


In [None]:
plt.bar(np.arange(len(train_dis)), [v for i,v in train_dis.most_common()])

In [None]:
cs = cluster_purity(in_train_y, id_train_targ)

In [None]:
cp_p_clus = [[] for i in range(768)]
for pred, gt in zip(in_train_y, id_train_targ):
    cp_p_clus[pred].append(class_names[gt])

In [None]:
a = [k for i,k in Counter(cp_p_clus[20]).most_common()]
a/np.sum(a)

In [None]:
cs[20]

In [None]:
dom_class = [Counter(i).most_common()[0][0] for i in cp_p_clus]
dom_class

In [None]:
_train_dis = dict(train_dis)
_train_dis = sorted(_train_dis.items(), key=lambda x: x[0])
_, _train_dis_y = zip(*_train_dis)
plt.bar(np.arange(len(_train_dis_y)), _train_dis_y)

In [None]:
fig = plt.figure(figsize=(10, 6))
plt.title(f"Cluster Count Distribution for 768 Clusters on Feature Layer")
plt.ylabel("Count per Cluster")
plt.xlabel("Cluster number")
plt.grid()
plt.bar(np.arange(len(_train_dis_y)), _train_dis_y)
#plt.bar(np.arange(len(train_dis)), [v for i,v in train_dis.most_common()])
#plt.xlabel("Cluster number - sorted by count")
plt.legend()
plt.savefig(f"../cache/proto-feat")

In [None]:
fig = plt.figure(figsize=(50, 6))
plt.title(f"Cluster Count Distribution for 768 Clusters on FC Layer")
plt.bar(np.arange(len(cs)), _train_dis_y)
plt.bar(np.arange(len(cs)), [cs[i]*_train_dis_y[i] for i in range(768)], label="purity level")
plt.ylabel("Cluster Count")
plt.xlabel("Dominant Class in Cluster")
plt.xticks(np.arange(len(cs)), dom_class, rotation=90)
plt.grid()
plt.legend()
plt.savefig(f"../cache/proto-fc-count")

In [None]:
res_2 = []
for i in range(len(cs)):
    _d = {}
    _d["cluster purty"] = cs[i] * 100
    _d["normalized cluster count"] = _train_dis_y[i]/500
    res_2.append(_d)

In [None]:
res_2 = pd.DataFrame(res_2)

In [None]:
fig = plt.figure(figsize=(10, 6))
plt.title(f"GMM - Cosine, OOD CIFAR100, PCA 10")
sns.barplot(x="cluster", y="auroc", \
                    data=sec_res[sec_res.metric=="cosine"], hue="auroc_type")
plt.xlabel("Number of Clusters")
plt.ylabel("AUROC")
plt.grid()
plt.savefig(f"../cache/gmm-glob-cosine")
plt.legend(loc="center right")

In [None]:
res = pd.DataFrame(ood_results)

In [None]:
for ood in oods:
    for pca_com in pcas:
        fig = plt.figure(figsize=(8, 6))
        plt.title(f"Kmeans - OOD: {ood}, PCA: {pca_com}")
        sns.barplot(x="cluster", y="sklearn-auroc", \
                    data=res[(res.ood==ood) & (res.pca == pca_com)], hue="metric")
        plt.grid()
        plt.ylabel("Global AUROC")
        plt.xlabel("Number of Clusters")
        plt.savefig(f"../cache/kmeans-glob-{ood}-{pca_com}")
        plt.clf()

In [None]:
f = res[(res.ood=="CIFAR100") & (res.pca == 10)]
f = f.sort_values(by=['sklearn-auroc'], ascending=False)
plt.figure(figsize=(8, 6))
_d = f.iloc[0]
plt.title(f"Distance Score distribution:\n Cluster AUROC {_d.auroc}, Global AUROC {_d['sklearn-auroc']} \n Kmeans, {_d.metric.capitalize()}, PCA {_d.pca}, Cluster {_d.cluster}")
plt.hist(_d.o_pred, alpha=0.5, label="ood")
plt.hist(_d.i_pred, alpha=0.5, label="id")
plt.ylabel("Count")
plt.xlabel("Distance Score")
plt.legend()
plt.savefig(f"../cache/kmeans-discore-{_d.ood}-{_d.pca}")

In [None]:
f

In [None]:
ood_results_s = pd.DataFrame(ood_results_s)

In [None]:
ood_results.sort_values(by=["sklearn-auroc"], ascending=False)

In [None]:
plt.figure(figsize=(8, 6))
plt.plot(ood_results_s["clip"] * 100, ood_results_s.auroc, label="AUROC")
plt.plot(ood_results_s["clip"] * 100, ood_results_s["sklearn-auroc"] * 100,  label="Sklearn AUROC")
plt.plot(ood_results_s["clip"] * 100, ood_results["tnr@tpr95"], label="TNR@TPR95")
plt.xlabel("Clip Percentage")
plt.ylabel("AUROC / TNR@TPR95")
plt.title("OOD: SVHN, Metric: Cosine, Original Protype Count: 768")
plt.legend()
plt.savefig("../cache/sec-SVHN-clip-plot")

In [None]:
#sns.set_theme(style="whitegrid", palette="dark")
#plt.rcdefaults()
bc = oe.assgn_before_clip
plt.figure(figsize=(8, 6))
plt.bar(np.arange(len(bc)), bc, align='center', color=['blue'])
plt.xlabel("Cluster Number")
plt.ylabel("Count per Cluster")
plt.title("Distribution before Clipping")
plt.savefig("../cache/sec-dist-b-clip.jpg")

In [None]:
plt.figure(figsize=(8, 6))
ac = oe.assgn_after_clip
plt.bar(np.arange(len(ac)), ac, align='center', color=['blue'])
plt.xlabel("Cluster Number")
plt.ylabel("Count per Cluster")
plt.title("Distribution after Clipping 50% of Clusters")
plt.savefig("../cache/sec-dist-a-clip.jpg")

In [None]:
plt.hist(oe.out_pred_score)
plt.hist(oe.test_pred_score)

In [None]:
plt.hist(cluster_purity(id_ap_im2cluster, id_train_targets))

In [None]:
plt.hist(cluster_purity(id_fc_im2cluster, id_train_targets))

In [None]:
plt.hist(cluster_purity(run_clustering(id_train_fc, 30)[0], id_train_targets))

In [None]:
plt.hist(cluster_purity(run_clustering(id_train_ap, 30)[0], id_train_targets))

## Resnet18 Model

In [None]:
ckpt_root = "/data/temiloluwa.adeoti/fourth_experiments/CIFAR10_clus_768_neg_256/exp_202/"
sns.set_theme(style="whitegrid", palette="dark")

In [None]:
f = [i for i in os.walk(ckpt_root)  for j in i if "results.pickle" in j]
g = [(v[0] + "/results.pickle", v[0].split("/")[-1].split("_")) for v in f]
g = [(v, w[2], int(w[4]), w[-1]) for v, w in g]
len(g)

In [None]:
all_dat = []
for dat in g:
    temp = pd.read_pickle(dat[0])
    temp = pd.DataFrame(temp)
    temp["encoder"] = dat[1]
    temp["ckpt"] = dat[2]
    temp["output_layer"] = "avg_pool" if dat[3] == "pool" else "fc"
    all_dat.append(temp)
all_dat = pd.concat(all_dat)

In [None]:
#all_dat.to_pickle("../cache/exp_202_resnet18_all_data.pickle")
from itertools import product
ds = "SVHN"

In [None]:
ood_dat = all_dat[all_dat.ood == ds]
ood_100_dat = ood_dat.sort_values(by=["auroc"], ascending=False)

In [None]:
ood_100_dat.clusters.unique()

In [None]:
mets = ["cosine", "mahalanobis"]
pcas = [0, 10]
output_layers = ["fc", "avg_pool"]
encoders = ["query", "key"]
fig = plt.figure(figsize=(8, 6))
for output_layer, encoder, pca in product(output_layers, encoders, pcas):
    for met in mets:
        ood_100_dat_analyse = ood_100_dat[(ood_100_dat.metric==met) & (ood_100_dat["pca components"] == pca)]
        dat = ood_100_dat_analyse[(ood_100_dat_analyse.output_layer==output_layer) & (ood_100_dat_analyse.encoder==encoder)]
        dat = dat.sort_values(by=["ckpt"])
        plt.plot(dat.ckpt, dat.auroc, label=met)
        plt.xlabel("Epochs")
        plt.ylabel("AUROC")
        plt.legend()
        plt.title(f"Encoder: {encoder.capitalize()}, Output-Layer:{output_layer.capitalize()}, PCA: {pca}")
        plt.savefig(f"../cache/{ds}-{encoder}-{output_layer}-{pca}.jpg")
    fig.clf()

## accuracies

In [None]:
h = [i for i in os.walk(ckpt_root)  for j in i for k in j if "accuracies.txt" in k]
h = [(i[0], k) for i in h for j in i for k in j if ".txt" in k]
k = [os.path.join(*v) for v in h]
h = [v[0].split("/")[-2].split("_") for v in h]
h = [(v[2], int(v[4]), v[-1]) for v in h]

In [None]:
k = [open(h, "r").readlines() for h in k]
k = [i[0] for i in k]
k = [i.split(" ") for i in k]
n = []
for j in k:
    temp = []
    for i in j:
        if "%" in i:
            temp.append(float(i.strip("%")))
    n.append(temp)

In [None]:
acc_data = []
for (encoder, ckpt, layer),(train_acc, test_acc) in zip(h, n):
    acc_data.append({"encoder":encoder, 
               "ckpt":ckpt,
                "output_layer": "avg_pool" if layer=="pool" else "fc",
                "train_acc": train_acc,
                "test_acc": test_acc})

In [None]:
acc_data = pd.DataFrame(acc_data).sort_values(by=["encoder", "ckpt"])
acc_data.head()

In [None]:
#acc_data.to_pickle("../cache/exp_202_resnet18_acc_data.pickle")

In [None]:
max_accuracies = []

In [None]:
fig = plt.figure(figsize=(8, 6))
for output_layer, encoder in product(output_layers, encoders):
    acc_analyse = acc_data[(acc_data.encoder==encoder) & (acc_data.output_layer==output_layer)]
    plt.plot(acc_analyse.ckpt, acc_analyse.train_acc, label="train accuracy")
    plt.plot(acc_analyse.ckpt, acc_analyse.test_acc, label="test accuracy")
    plt.legend()
    plt.xlabel("Accuracy")
    plt.ylabel("Epochs")
    plt.legend()
    plt.title(f"Encoder: {encoder.capitalize()}, Output-Layer:{output_layer.capitalize()}")
    plt.savefig(f"../cache/acc-{encoder}-{output_layer}.jpg")
    fig.clf()
    bes = acc_analyse.sort_values(by=["train_acc"], ascending=False).iloc[0, :]
    max_accuracies.append({"encoder": encoder, "output_layer": output_layer, "best_ckpt":bes.ckpt, "best_test_acc":bes.train_acc})

In [None]:
pd.DataFrame(max_accuracies).to_excel("../cache/best_train_acc.xlsx")