In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import pandas as pd
from sklearn import metrics

In [None]:
# Alternative hex palette kept for reference (ids 0-9 are matplotlib's default tab10 cycle):
# color_list = {0: '#1f77b4', 1: '#ff7f0e', 2: '#2ca02c', 3: '#d62728', 4: '#9467bd', 5: '#8c564b', 6: '#e377c2', 7: '#7f7f7f', 8: '#bcbd22', 9: '#17becf', -1: '#000000', -2: '#ff0000', -3: '#000000', -4: '#000000', -5: '#000000'}

# Label id -> matplotlib single-letter colour. Ids 0-4 get distinct colours,
# ids 5-8 all share black ('k').
color_list = dict(enumerate('bgrcm'))
color_list.update({label_id: 'k' for label_id in range(5, 9)})

# Default colormap for later plots. NOTE(review): the scatter cells pass explicit
# colour strings through `c=`, so this colormap does not affect those plots.
plt.set_cmap('Set1')

# Directory holding the saved .npz files; used as a plain string prefix.
features_dir = '../features/'

# Dataset names used in this study.
datasets = 'cifar10 cifar100 dtd places365 svhn tin mnist fashionmnist notmnist'.split()

In [None]:
# Load the saved projections/logits/labels and scatter a single label group.
# NOTE(review): the file name ends in `_tsne`, so `features` is presumably a
# t-SNE embedding with at least 2 columns — confirm against the writer script.
file = 'resnet18_cifar10_cifar100_dtd_svhn_tin_mnist_fashionmnist_notmnist_places365_tsne.npz'

# Open the archive once and close it deterministically. Indexing an NpzFile
# returns in-memory arrays, so they remain valid after the `with` block exits.
with np.load(f'{features_dir}{file}') as npz:
    features = npz['features']
    logits = npz['logits']
    labels = npz['labels']

# Label id to plot. NOTE(review): an earlier comment called this "cifar10", but
# which dataset id 8 denotes is unclear (index 8 of `datasets` is 'notmnist';
# the 9th dataset in the file name is 'places365') — verify against the writer.
target_label = 8
mask = labels == target_label

fig, ax = plt.subplots(figsize=(15, 10), dpi=200)
ax.scatter(features[mask, 0], features[mask, 1],
           c=pd.Series(labels[mask]).map(color_list), s=10, alpha=0.9)
ax.set(title=f'features (label {target_label})', xlabel='dim 0', ylabel='dim 1')
plt.show()

In [None]:

# Scatter the first two logit dimensions for the same masked label group.
# Uses `logits`, `labels` and `mask` from the loading cell above.
fig, ax = plt.subplots(figsize=(15, 10), dpi=200)
ax.scatter(logits[mask, 0], logits[mask, 1],
           c=pd.Series(labels[mask]).map(color_list), s=10, alpha=0.9)
ax.set(title='logits: dimension 0 vs dimension 1', xlabel='logit 0', ylabel='logit 1')
plt.show()

In [None]:
# Per-pair evaluation curves. File names look like "<model>-<dataset>-<dataset>";
# presumably in-distribution vs OOD — confirm against the evaluation script.
files = ['ResNet18_32x32-cifar10-cifar100', 'ResNet18_32x32-cifar10-fashionmnist', 'ResNet18_32x32-cifar10-mnist', 'ResNet18_32x32-cifar10-svhn', 'ResNet18_32x32-cifar10-tin']

# Arrays read from each archive.
curve_keys = ('fpr', 'tpr', 'precision_in', 'recall_in', 'precision_out', 'recall_out')

data_list = []
roc, pr_in, pr_out = [], [], []
for fname in files:
    # Reuse the configured directory rather than a second hard-coded path, and
    # close each NpzFile as soon as the needed arrays are loaded into memory.
    with np.load(f'{features_dir}{fname}.npz') as npz:
        data = {key: npz[key] for key in curve_keys}
    data_list.append(data)

    roc.append(metrics.RocCurveDisplay(fpr=data['fpr'], tpr=data['tpr']))
    pr_in.append(metrics.PrecisionRecallDisplay(precision=data['precision_in'], recall=data['recall_in']))
    pr_out.append(metrics.PrecisionRecallDisplay(precision=data['precision_out'], recall=data['recall_out']))

In [None]:
# Overlay the ROC curve of every dataset pair on a single axis.
fig, ax = plt.subplots(figsize=(15, 10), dpi=400)
for fname, disp in zip(files, roc):
    # Legend entry built from the last two '-'-separated parts of the file name.
    pair = '-'.join(fname.split('-')[-2:])
    disp.plot(ax=ax, label=f'roc-{pair}')

ax.set_title('ROC curves')
ax.legend()
plt.show()

In [None]:
# Overlay the precision-recall curves built from `precision_in` / `recall_in`.
fig, ax = plt.subplots(figsize=(15, 10), dpi=400)
for fname, disp in zip(files, pr_in):
    # Legend entry built from the last two '-'-separated parts of the file name.
    pair = '-'.join(fname.split('-')[-2:])
    disp.plot(ax=ax, label=f'pr_in-{pair}')

ax.set_title('Precision-recall curves (precision_in / recall_in)')
ax.legend()
plt.show()

In [None]:
# Overlay the precision-recall curves built from `precision_out` / `recall_out`.
fig, ax = plt.subplots(figsize=(15, 10), dpi=400)
for fname, disp in zip(files, pr_out):
    # Legend entry built from the last two '-'-separated parts of the file name.
    pair = '-'.join(fname.split('-')[-2:])
    disp.plot(ax=ax, label=f'pr_out-{pair}')

ax.set_title('Precision-recall curves (precision_out / recall_out)')
ax.legend()
plt.show()