In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
import os
import csv
from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
import sys
from torchmetrics import ConfusionMatrix
from torchmetrics.classification import F1Score, Accuracy, Precision, Recall
sys.path.append('../')
from datamodule.datamodule import select_data
from models.models import Classifier

In [5]:
print(plt.style.available)
plt.style.use('seaborn-v0_8-dark-palette')

['Solarize_Light2', '_classic_test_patch', '_mpl-gallery', '_mpl-gallery-nogrid', 'bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-v0_8', 'seaborn-v0_8-bright', 'seaborn-v0_8-colorblind', 'seaborn-v0_8-dark', 'seaborn-v0_8-dark-palette', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8-deep', 'seaborn-v0_8-muted', 'seaborn-v0_8-notebook', 'seaborn-v0_8-paper', 'seaborn-v0_8-pastel', 'seaborn-v0_8-poster', 'seaborn-v0_8-talk', 'seaborn-v0_8-ticks', 'seaborn-v0_8-white', 'seaborn-v0_8-whitegrid', 'tableau-colorblind10']


In [6]:
path_root = '../../'
path_results = os.path.join(path_root, 'results/baseline_classifier_analysis')
config = yaml.load(open(os.path.join(path_root, 'config.yaml')), Loader=yaml.FullLoader)

os.makedirs(path_results, exist_ok=True)

In [7]:
result_files = os.listdir(os.path.join(path_root, 'results'))
result_files = [os.path.join(path_root, 'results', f) for f in result_files if 'classifier_baseline' in f]
result_files.sort()

In [8]:
bench_image_folder = result_files[0]
resampled_sample_folder = result_files[1]
sim_output_folder = result_files[2]

IndexError: list index out of range

In [None]:
bench_image_versions = os.listdir(bench_image_folder)
bench_image_versions = [i for i in bench_image_versions if '.ipynb' not in i]

resampled_sample_versions = os.listdir(resampled_sample_folder)
resampled_sample_versions = [i for i in resampled_sample_versions if '.ipynb' not in i]

sim_output_versions = os.listdir(sim_output_folder)
sim_output_versions = [i for i in sim_output_versions if '.ipynb' not in i]

In [None]:
print(sim_output_versions)

# Loss

## Bench image version

In [None]:
# Load in the metrics
version_metrics = {}
for version in bench_image_versions:
    metrics = {}
    path_metrics = os.path.join(bench_image_folder, version, 'logs', 'metrics.csv')
    try:
        with open(path_metrics) as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            for i,row in enumerate(reader):
                if i == 0:
                    for header in row:
                        metrics[header] = []
                    key_list = list(metrics.keys())
                else:
                    for j,value in enumerate(row):
                        metrics[key_list[j]].append(value)
        version_metrics[version] = metrics
    except:
        pass

In [None]:
# Clean up the metrics
for k,metrics in version_metrics.items():
    metrics['epoch'] = np.unique(np.asarray(metrics['epoch'], dtype=int))
    metrics['loss_train'] = np.asarray([float(i) for i in metrics['loss_train'] if i != ''])
    metrics['loss_val'] = np.asarray([float(i) for i in metrics['loss_val'] if i != ''])
    version_metrics[k] = metrics

In [None]:
# Get a tag for either pretrained or not from the version config
bench_image_version_tags = {}
for version in bench_image_versions:
    config = yaml.load(open(os.path.join(bench_image_folder, version, 'config.yaml'), 'r'), Loader = yaml.FullLoader)
    if config['classifier']['transfer_learn']:
        bench_image_version_tags[version] = 'pretrained'
    else:
        bench_image_version_tags[version] = 'non-pretrained'

In [None]:
# Plot the loss
fig , ax = plt.subplots(2,1, figsize=(5,5))

for version in bench_image_versions:
    if bench_image_version_tags[version] == 'pretrained':
        ax[0].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_train'], label='Pretrained')
        ax[1].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_val'], label = 'Pretrained')
    elif bench_image_version_tags[version] == 'non-pretrained':
        ax[0].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_train'], label='Non pretrained')
        ax[1].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_val'], label='Non pretrained')

ax[0].set_title("Train loss")
ax[1].set_title("Validation loss")
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')

for a in ax.flatten():
    a.set_xlabel("Epoch")
    a.set_ylabel("Cross Entropy Loss")
    a.set_xticks([i for i in range(0,21,2)], [i for i in range(0,21,2)])
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'bench_images_loss.pdf'))

## Resampled sample version

In [None]:
# Load in the metrics
version_metrics = {}
for version in resampled_sample_versions:
    metrics = {}
    path_metrics = os.path.join(resampled_sample_folder, version, 'logs', 'metrics.csv')
    try:
        with open(path_metrics) as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            for i,row in enumerate(reader):
                if i == 0:
                    for header in row:
                        metrics[header] = []
                    key_list = list(metrics.keys())
                else:
                    for j,value in enumerate(row):
                        metrics[key_list[j]].append(value)
        version_metrics[version] = metrics
    except:
        pass

In [None]:
# Clean up the metrics
for k,metrics in version_metrics.items():
    metrics['epoch'] = np.unique(np.asarray(metrics['epoch'], dtype=int))
    metrics['loss_train'] = np.asarray([float(i) for i in metrics['loss_train'] if i != ''])
    metrics['loss_val'] = np.asarray([float(i) for i in metrics['loss_val'] if i != ''])
    version_metrics[k] = metrics

In [None]:
# Get a tag for either pretrained or not from the version config
resampled_sample_version_tags = {}
for version in resampled_sample_versions:
    config = yaml.load(open(os.path.join(resampled_sample_folder, version, 'config.yaml'), 'r'), Loader = yaml.FullLoader)
    if config['classifier']['transfer_learn']:
        resampled_sample_version_tags[version] = 'pretrained'
    else:
        resampled_sample_version_tags[version] = 'non-pretrained'

In [None]:
# Plot the loss
fig , ax = plt.subplots(2,1, figsize=(5,5))

for version in resampled_sample_versions:
    if resampled_sample_version_tags[version] == 'pretrained':
        ax[0].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_train'], label='Pretrained')
        ax[1].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_val'], label = 'Pretrained')
    elif resampled_sample_version_tags[version] == 'non-pretrained':
        ax[0].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_train'], label='Non pretrained')
        ax[1].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_val'], label='Non pretrained')

ax[0].set_title("Train loss")
ax[1].set_title("Validation loss")
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')

for a in ax.flatten():
    a.set_xlabel("Epoch")
    a.set_ylabel("Cross Entropy Loss")
    a.set_xticks([i for i in range(0,21,2)], [i for i in range(0,21,2)])
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'resampled_sample_loss.pdf'))

## Sim output version

In [None]:
# Load in the metrics
version_metrics = {}
for version in sim_output_versions:
    metrics = {}
    path_metrics = os.path.join(sim_output_folder, version, 'logs', 'metrics.csv')
    try:
        with open(path_metrics) as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            for i,row in enumerate(reader):
                if i == 0:
                    for header in row:
                        metrics[header] = []
                    key_list = list(metrics.keys())
                else:
                    for j,value in enumerate(row):
                        metrics[key_list[j]].append(value)
        version_metrics[version] = metrics
    except:
        pass

In [None]:
# Clean up the metrics
for k,metrics in version_metrics.items():
    metrics['epoch'] = np.unique(np.asarray(metrics['epoch'], dtype=int))
    metrics['loss_train'] = np.asarray([float(i) for i in metrics['loss_train'] if i != ''])
    metrics['loss_val'] = np.asarray([float(i) for i in metrics['loss_val'] if i != ''])
    version_metrics[k] = metrics

In [None]:
# Get a tag for either pretrained or not from the version config
sim_output_version_tags = {}
for version in sim_output_versions:
    config = yaml.load(open(os.path.join(sim_output_folder, version, 'config.yaml'), 'r'), Loader = yaml.FullLoader)
    if config['classifier']['transfer_learn']:
        sim_output_version_tags[version] = 'pretrained'
    else:
        sim_output_version_tags[version] = 'non-pretrained'

In [None]:
# Plot the loss
fig , ax = plt.subplots(2,1, figsize=(5,5))

for version in sim_output_versions:
    if sim_output_version_tags[version] == 'pretrained':
        ax[0].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_train'], label='Pretrained')
        ax[1].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_val'], label = 'Pretrained')
    elif sim_output_version_tags[version] == 'non-pretrained':
        ax[0].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_train'], label='Non pretrained')
        ax[1].plot(version_metrics[version]['epoch'], version_metrics[version]['loss_val'], label='Non pretrained')

ax[0].set_title("Train loss")
ax[1].set_title("Validation loss")
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')

for a in ax.flatten():
    a.set_xlabel("Epoch")
    a.set_ylabel("Cross Entropy Loss")
    a.set_xticks([i for i in range(0,21,2)], [i for i in range(0,21,2)])
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sim_output_loss.pdf'))

# Confusion matrices

In [None]:
bench_image_preds = {}
sim_output_preds = {}
resampled_sample_preds = {}

## Bench images

In [None]:
bench_image_preds['bench_image'] = {}
sim_output_preds['bench_image'] = {}
resampled_sample_preds['bench_image'] = {}

In [None]:
config['which_data'] = 'bench_image'
dm = select_data(config)
dm.setup()
train_dataloader = dm.train_dataloader()
valid_dataloader = dm.val_dataloader()

In [None]:
bench_image_preds['bench_image'] = {}
sim_output_preds['bench_image'] = {}
resampled_sample_preds['bench_image'] = {}

In [None]:

for version in bench_image_versions:
    checkpoint_path = os.path.join(bench_image_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    bench_image_preds['bench_image'][bench_image_version_tags[version]] = {}
    
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['train'] = {}
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid'] = {}

    # Training dataset
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['pred'] = []
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['pred'].append(pred)
        bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['pred'] = []
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['pred'].append(pred)
        bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in bench_image_versions:
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['pred'] = torch.tensor(bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['pred']).squeeze()
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['truth'] = torch.tensor(bench_image_preds['bench_image'][bench_image_version_tags[version]]['train']['truth']).squeeze()
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['truth'] = torch.tensor(bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['truth']).squeeze()
    bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['pred'] = torch.tensor(bench_image_preds['bench_image'][bench_image_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
confmat = ConfusionMatrix(task="multiclass", num_classes=10)

In [None]:
print(bench_image_version_tags)

In [None]:
pt_bench_bench_train_cfm = confmat(bench_image_preds['bench_image'][bench_image_version_tags['version_0']]['train']['pred'], bench_image_preds['bench_image'][bench_image_version_tags['version_0']]['train']['truth'])
pt_bench_bench_val_cfm = confmat(bench_image_preds['bench_image'][bench_image_version_tags['version_0']]['valid']['pred'], bench_image_preds['bench_image'][bench_image_version_tags['version_0']]['valid']['truth'])
npt_bench_bench_train_cfm = confmat(bench_image_preds['bench_image'][bench_image_version_tags['version_1']]['train']['pred'], bench_image_preds['bench_image'][bench_image_version_tags['version_1']]['train']['truth'])
pnt_bench_bench_val_cfm = confmat(bench_image_preds['bench_image'][bench_image_version_tags['version_1']]['valid']['pred'], bench_image_preds['bench_image'][bench_image_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_bench_bench_train_cfm_df = pd.DataFrame(npt_bench_bench_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_bench_bench_val_cfm_df = pd.DataFrame(npt_bench_bench_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_bench_bench_train_cfm_df = pd.DataFrame(pt_bench_bench_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_bench_bench_val_cfm_df = pd.DataFrame(pt_bench_bench_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_bench_bench_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_bench_bench_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_bench_bench_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_bench_bench_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'bench_bench_cfm.pdf'))

In [None]:
for version in sim_output_versions:
    checkpoint_path = os.path.join(sim_output_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    sim_output_preds['bench_image'][sim_output_version_tags[version]] = {}
    
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['train'] = {}
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid'] = {}

    # Training dataset
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['pred'] = []
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['pred'].append(pred)
        sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['pred'] = []
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['pred'].append(pred)
        sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in sim_output_versions:
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['pred'] = torch.tensor(sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['pred']).squeeze()
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['truth'] = torch.tensor(sim_output_preds['bench_image'][sim_output_version_tags[version]]['train']['truth']).squeeze()
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['truth'] = torch.tensor(sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['truth']).squeeze()
    sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['pred'] = torch.tensor(sim_output_preds['bench_image'][sim_output_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
print(sim_output_version_tags)

In [None]:
pt_sim_bench_train_cfm = confmat(sim_output_preds['bench_image'][sim_output_version_tags['version_0']]['train']['pred'], sim_output_preds['bench_image'][sim_output_version_tags['version_0']]['train']['truth'])
pt_sim_bench_val_cfm = confmat(sim_output_preds['bench_image'][sim_output_version_tags['version_0']]['valid']['pred'], sim_output_preds['bench_image'][sim_output_version_tags['version_0']]['valid']['truth'])
npt_sim_bench_train_cfm = confmat(sim_output_preds['bench_image'][sim_output_version_tags['version_1']]['train']['pred'], sim_output_preds['bench_image'][sim_output_version_tags['version_1']]['train']['truth'])
npt_sim_bench_val_cfm = confmat(sim_output_preds['bench_image'][sim_output_version_tags['version_1']]['valid']['pred'], sim_output_preds['bench_image'][sim_output_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_sim_bench_train_cfm_df = pd.DataFrame(npt_sim_bench_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_sim_bench_val_cfm_df = pd.DataFrame(npt_sim_bench_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sim_bench_train_cfm_df = pd.DataFrame(pt_sim_bench_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sim_bench_val_cfm_df = pd.DataFrame(pt_sim_bench_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_sim_bench_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_sim_bench_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sim_bench_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sim_bench_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sim_bench_cfm.pdf'))

In [None]:
for version in resampled_sample_versions:
    checkpoint_path = os.path.join(resampled_sample_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]] = {}
    
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train'] = {}
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid'] = {}

    # Training dataset
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['pred'] = []
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['pred'].append(pred)
        resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['pred'] = []
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['pred'].append(pred)
        resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in resampled_sample_versions:
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['pred'] = torch.tensor(resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['pred']).squeeze()
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['truth'] = torch.tensor(resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['train']['truth']).squeeze()
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['truth'] = torch.tensor(resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['truth']).squeeze()
    resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['pred'] = torch.tensor(resampled_sample_preds['bench_image'][resampled_sample_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
print(resampled_sample_version_tags)

In [None]:
pt_sample_bench_train_cfm = confmat(resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_0']]['train']['pred'], resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_0']]['train']['truth'])
pt_sample_bench_val_cfm = confmat(resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_0']]['valid']['pred'], resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_0']]['valid']['truth'])
npt_sample_bench_train_cfm = confmat(resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_1']]['train']['pred'], resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_1']]['train']['truth'])
npt_sample_bench_val_cfm = confmat(resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_1']]['valid']['pred'], resampled_sample_preds['bench_image'][resampled_sample_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_sample_bench_train_cfm_df = pd.DataFrame(npt_sample_bench_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_sample_bench_val_cfm_df = pd.DataFrame(npt_sample_bench_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sample_bench_train_cfm_df = pd.DataFrame(pt_sample_bench_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sample_bench_val_cfm_df = pd.DataFrame(pt_sample_bench_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_sample_bench_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_sample_bench_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sample_bench_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sample_bench_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sample_bench_cfm.pdf'))

## Simulation images

In [None]:
bench_image_preds['sim_output'] = {}
sim_output_preds['sim_output'] = {}
resampled_sample_preds['sim_output'] = {}

In [None]:
config['which_data'] = 'sim_output'
dm = select_data(config)
dm.setup()
train_dataloader = dm.train_dataloader()
valid_dataloader = dm.val_dataloader()

In [None]:

for version in bench_image_versions:
    checkpoint_path = os.path.join(bench_image_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    bench_image_preds['sim_output'][bench_image_version_tags[version]] = {}
    
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['train'] = {}
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid'] = {}

    # Training dataset
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['pred'] = []
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['pred'].append(pred)
        bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['pred'] = []
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['pred'].append(pred)
        bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in bench_image_versions:
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['pred'] = torch.tensor(bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['pred']).squeeze()
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['truth'] = torch.tensor(bench_image_preds['sim_output'][bench_image_version_tags[version]]['train']['truth']).squeeze()
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['truth'] = torch.tensor(bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['truth']).squeeze()
    bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['pred'] = torch.tensor(bench_image_preds['sim_output'][bench_image_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
confmat = ConfusionMatrix(task="multiclass", num_classes=10)

In [None]:
print(bench_image_version_tags)

In [None]:
pt_bench_sim_train_cfm = confmat(bench_image_preds['sim_output'][bench_image_version_tags['version_0']]['train']['pred'], bench_image_preds['sim_output'][bench_image_version_tags['version_0']]['train']['truth'])
pt_bench_sim_val_cfm = confmat(bench_image_preds['sim_output'][bench_image_version_tags['version_0']]['valid']['pred'], bench_image_preds['sim_output'][bench_image_version_tags['version_0']]['valid']['truth'])
npt_bench_sim_train_cfm = confmat(bench_image_preds['sim_output'][bench_image_version_tags['version_1']]['train']['pred'], bench_image_preds['sim_output'][bench_image_version_tags['version_1']]['train']['truth'])
npt_bench_sim_val_cfm = confmat(bench_image_preds['sim_output'][bench_image_version_tags['version_1']]['valid']['pred'], bench_image_preds['sim_output'][bench_image_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_bench_sim_train_cfm_df = pd.DataFrame(npt_bench_sim_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_bench_sim_val_cfm_df = pd.DataFrame(npt_bench_sim_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_bench_sim_train_cfm_df = pd.DataFrame(pt_bench_sim_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_bench_sim_val_cfm_df = pd.DataFrame(pt_bench_sim_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_bench_sim_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_bench_sim_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_bench_sim_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_bench_sim_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'bench_sim_cfm.pdf'))

In [None]:
for version in sim_output_versions:
    checkpoint_path = os.path.join(sim_output_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    sim_output_preds['sim_output'][sim_output_version_tags[version]] = {}
    
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['train'] = {}
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid'] = {}

    # Training dataset
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['pred'] = []
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['pred'].append(pred)
        sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['pred'] = []
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['pred'].append(pred)
        sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in sim_output_versions:
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['pred'] = torch.tensor(sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['pred']).squeeze()
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['truth'] = torch.tensor(sim_output_preds['sim_output'][sim_output_version_tags[version]]['train']['truth']).squeeze()
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['truth'] = torch.tensor(sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['truth']).squeeze()
    sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['pred'] = torch.tensor(sim_output_preds['sim_output'][sim_output_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
print(sim_output_version_tags)

In [None]:
pt_sim_sim_train_cfm = confmat(sim_output_preds['sim_output'][sim_output_version_tags['version_0']]['train']['pred'], sim_output_preds['sim_output'][sim_output_version_tags['version_0']]['train']['truth'])
pt_sim_sim_val_cfm = confmat(sim_output_preds['sim_output'][sim_output_version_tags['version_0']]['valid']['pred'], sim_output_preds['sim_output'][sim_output_version_tags['version_0']]['valid']['truth'])
npt_sim_sim_train_cfm = confmat(sim_output_preds['sim_output'][sim_output_version_tags['version_1']]['train']['pred'], sim_output_preds['sim_output'][sim_output_version_tags['version_1']]['train']['truth'])
npt_sim_sim_val_cfm = confmat(sim_output_preds['sim_output'][sim_output_version_tags['version_1']]['valid']['pred'], sim_output_preds['sim_output'][sim_output_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_sim_sim_train_cfm_df = pd.DataFrame(npt_sim_sim_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_sim_sim_val_cfm_df = pd.DataFrame(npt_sim_sim_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sim_sim_train_cfm_df = pd.DataFrame(pt_sim_sim_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sim_sim_val_cfm_df = pd.DataFrame(pt_sim_sim_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_sim_sim_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_sim_sim_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sim_sim_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sim_sim_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sim_sim_cfm.pdf'))

In [None]:
for version in resampled_sample_versions:
    checkpoint_path = os.path.join(resampled_sample_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]] = {}
    
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train'] = {}
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid'] = {}

    # Training dataset
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['pred'] = []
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['pred'].append(pred)
        resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['pred'] = []
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['pred'].append(pred)
        resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in resampled_sample_versions:
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['pred'] = torch.tensor(resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['pred']).squeeze()
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['truth'] = torch.tensor(resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['train']['truth']).squeeze()
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['truth'] = torch.tensor(resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['truth']).squeeze()
    resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['pred'] = torch.tensor(resampled_sample_preds['sim_output'][resampled_sample_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
print(resampled_sample_version_tags)

In [None]:
pt_sample_sim_train_cfm = confmat(resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_0']]['train']['pred'], resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_0']]['train']['truth'])
pt_sample_sim_val_cfm = confmat(resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_0']]['valid']['pred'], resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_0']]['valid']['truth'])
npt_sample_sim_train_cfm = confmat(resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_1']]['train']['pred'], resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_1']]['train']['truth'])
npt_sample_sim_val_cfm = confmat(resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_1']]['valid']['pred'], resampled_sample_preds['sim_output'][resampled_sample_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_sample_sim_train_cfm_df = pd.DataFrame(npt_sample_sim_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_sample_sim_val_cfm_df = pd.DataFrame(npt_sample_sim_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sample_sim_train_cfm_df = pd.DataFrame(pt_sample_sim_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sample_sim_val_cfm_df = pd.DataFrame(pt_sample_sim_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_sample_sim_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_sample_sim_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sample_sim_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sample_sim_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sample_sim_cfm.pdf'))

## Ideal images

In [None]:
bench_image_preds['resampled_sample'] = {}
sim_output_preds['resampled_sample'] = {}
resampled_sample_preds['resampled_sample'] = {}

In [None]:
config['which_data'] = 'resampled_sample'
dm = select_data(config)
dm.setup()
train_dataloader = dm.train_dataloader()
valid_dataloader = dm.val_dataloader()

In [None]:

for version in bench_image_versions:
    checkpoint_path = os.path.join(bench_image_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]] = {}
    
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train'] = {}
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid'] = {}

    # Training dataset
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['pred'] = []
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['pred'].append(pred)
        bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['pred'] = []
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['pred'].append(pred)
        bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in bench_image_versions:
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['pred'] = torch.tensor(bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['pred']).squeeze()
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['truth'] = torch.tensor(bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['train']['truth']).squeeze()
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['truth'] = torch.tensor(bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['truth']).squeeze()
    bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['pred'] = torch.tensor(bench_image_preds['resampled_sample'][bench_image_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
confmat = ConfusionMatrix(task="multiclass", num_classes=10)

In [None]:
pt_bench_sample_train_cfm = confmat(bench_image_preds['resampled_sample'][bench_image_version_tags['version_0']]['train']['pred'], bench_image_preds['resampled_sample'][bench_image_version_tags['version_0']]['train']['truth'])
pt_bench_sample_val_cfm = confmat(bench_image_preds['resampled_sample'][bench_image_version_tags['version_0']]['valid']['pred'], bench_image_preds['resampled_sample'][bench_image_version_tags['version_0']]['valid']['truth'])
npt_bench_sample_train_cfm = confmat(bench_image_preds['resampled_sample'][bench_image_version_tags['version_1']]['train']['pred'], bench_image_preds['resampled_sample'][bench_image_version_tags['version_1']]['train']['truth'])
npt_bench_sample_val_cfm = confmat(bench_image_preds['resampled_sample'][bench_image_version_tags['version_1']]['valid']['pred'], bench_image_preds['resampled_sample'][bench_image_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_bench_sample_train_cfm_df = pd.DataFrame(npt_bench_sample_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_bench_sample_val_cfm_df = pd.DataFrame(npt_bench_sample_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_bench_sample_train_cfm_df = pd.DataFrame(pt_bench_sample_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_bench_sample_val_cfm_df = pd.DataFrame(pt_bench_sample_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_bench_sample_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_bench_sample_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_bench_sample_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_bench_sample_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'bench_sample_cfm.pdf'))

In [None]:
for version in sim_output_versions:
    checkpoint_path = os.path.join(sim_output_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]] = {}
    
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train'] = {}
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid'] = {}

    # Training dataset
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['pred'] = []
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['pred'].append(pred)
        sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['pred'] = []
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['pred'].append(pred)
        sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in sim_output_versions:
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['pred'] = torch.tensor(sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['pred']).squeeze()
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['truth'] = torch.tensor(sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['train']['truth']).squeeze()
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['truth'] = torch.tensor(sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['truth']).squeeze()
    sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['pred'] = torch.tensor(sim_output_preds['resampled_sample'][sim_output_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
print(sim_output_version_tags)

In [None]:
pt_sim_sample_train_cfm = confmat(sim_output_preds['resampled_sample'][sim_output_version_tags['version_0']]['train']['pred'], sim_output_preds['resampled_sample'][sim_output_version_tags['version_0']]['train']['truth'])
pt_sim_sample_val_cfm = confmat(sim_output_preds['resampled_sample'][sim_output_version_tags['version_0']]['valid']['pred'], sim_output_preds['resampled_sample'][sim_output_version_tags['version_0']]['valid']['truth'])
npt_sim_sample_train_cfm = confmat(sim_output_preds['resampled_sample'][sim_output_version_tags['version_1']]['train']['pred'], sim_output_preds['resampled_sample'][sim_output_version_tags['version_1']]['train']['truth'])
npt_sim_sample_val_cfm = confmat(sim_output_preds['resampled_sample'][sim_output_version_tags['version_1']]['valid']['pred'], sim_output_preds['resampled_sample'][sim_output_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_sim_sample_train_cfm_df = pd.DataFrame(npt_sim_sample_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_sim_sample_val_cfm_df = pd.DataFrame(npt_sim_sample_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sim_sample_train_cfm_df = pd.DataFrame(pt_sim_sample_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sim_sample_val_cfm_df = pd.DataFrame(pt_sim_sample_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_sim_sample_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_sim_sample_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sim_sample_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sim_sample_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sim_sample_cfm.pdf'))

In [None]:
for version in resampled_sample_versions:
    checkpoint_path = os.path.join(resampled_sample_folder, version, 'checkpoints', 'last.ckpt')
    classifier = Classifier.load_from_checkpoint(checkpoint_path).cuda()
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]] = {}
    
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train'] = {}
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid'] = {}

    # Training dataset
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['pred'] = []
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['truth'] = []
    for batch in tqdm(train_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['pred'].append(pred)
        resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['truth'].append(target)
    # Validation dataset
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['pred'] = []
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['truth'] = []
    for batch in tqdm(valid_dataloader):
        sample, target = batch
        
        sample = torch.cat((sample, sample, sample), dim=1).cuda()
        pred = classifier(sample)
        target = torch.argmax(target, dim=-1).detach().cpu()
        pred = torch.argmax(pred, dim=-1).cpu()

        resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['pred'].append(pred)
        resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['truth'].append(target)

In [None]:
for version in resampled_sample_versions:
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['pred'] = torch.tensor(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['pred']).squeeze()
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['truth'] = torch.tensor(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['train']['truth']).squeeze()
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['truth'] = torch.tensor(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['truth']).squeeze()
    resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['pred'] = torch.tensor(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags[version]]['valid']['pred']).squeeze()

In [None]:
print(resampled_sample_version_tags)

In [None]:
pt_sample_sample_train_cfm = confmat(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_0']]['train']['pred'], resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_0']]['train']['truth'])
pt_sample_sample_val_cfm = confmat(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_0']]['valid']['pred'], resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_0']]['valid']['truth'])
npt_sample_sample_train_cfm = confmat(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_1']]['train']['pred'], resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_1']]['train']['truth'])
npt_sample_sample_val_cfm = confmat(resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_1']]['valid']['pred'], resampled_sample_preds['resampled_sample'][resampled_sample_version_tags['version_1']]['valid']['truth'])

In [None]:
npt_sample_sample_train_cfm_df = pd.DataFrame(npt_sample_sample_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
npt_sample_sample_val_cfm_df = pd.DataFrame(npt_sample_sample_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sample_sample_train_cfm_df = pd.DataFrame(pt_sample_sample_train_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])
pt_sample_sample_val_cfm_df = pd.DataFrame(pt_sample_sample_val_cfm, index = [i for i in range(0,10)], columns = [i for i in range(0,10)])

In [None]:
fig, ax = plt.subplots(2,2, figsize=(10,10))
sns.heatmap(npt_sample_sample_train_cfm_df, annot=True, ax=ax[0][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(npt_sample_sample_val_cfm_df, annot=True, ax=ax[1][0], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sample_sample_train_cfm_df, annot=True, ax=ax[0][1], square=True, cbar=False, cmap='Blues')
sns.heatmap(pt_sample_sample_val_cfm_df, annot=True, ax=ax[1][1], square=True, cbar=False, cmap='Blues')

for a in ax.flatten():
    a.set_ylabel("Truth")
    a.set_xlabel("Prediction")
plt.tight_layout()
fig.savefig(os.path.join(path_results, 'sample_sample_cfm.pdf'))

## Classification scores

In [None]:
f1 = F1Score(task='multiclass', num_classes=10)
acc = Accuracy(task='multiclass', num_classes=10)

In [None]:
# Bench model
# Bench data
bench_bench_npt_train_f1 = f1(bench_image_preds['bench_image']['non-pretrained']['train']['pred'], bench_image_preds['bench_image']['non-pretrained']['train']['truth'])
bench_bench_npt_train_acc = acc(bench_image_preds['bench_image']['non-pretrained']['train']['pred'], bench_image_preds['bench_image']['non-pretrained']['train']['truth'])

bench_bench_pt_train_f1 = f1(bench_image_preds['bench_image']['pretrained']['train']['pred'], bench_image_preds['bench_image']['pretrained']['train']['truth'])
bench_bench_pt_train_acc = acc(bench_image_preds['bench_image']['pretrained']['train']['pred'], bench_image_preds['bench_image']['pretrained']['train']['truth'])

bench_bench_npt_valid_f1 = f1(bench_image_preds['bench_image']['non-pretrained']['valid']['pred'], bench_image_preds['bench_image']['non-pretrained']['valid']['truth'])
bench_bench_npt_valid_acc = acc(bench_image_preds['bench_image']['non-pretrained']['valid']['pred'], bench_image_preds['bench_image']['non-pretrained']['valid']['truth'])

bench_bench_pt_valid_f1 = f1(bench_image_preds['bench_image']['pretrained']['valid']['pred'], bench_image_preds['bench_image']['pretrained']['valid']['truth'])
bench_bench_pt_valid_acc = acc(bench_image_preds['bench_image']['pretrained']['valid']['pred'], bench_image_preds['bench_image']['pretrained']['valid']['truth'])

# Sim data
bench_sim_npt_train_f1 = f1(bench_image_preds['sim_output']['non-pretrained']['train']['pred'], bench_image_preds['sim_output']['non-pretrained']['train']['truth'])
bench_sim_npt_train_acc = acc(bench_image_preds['sim_output']['non-pretrained']['train']['pred'], bench_image_preds['sim_output']['non-pretrained']['train']['truth'])

bench_sim_pt_train_f1 = f1(bench_image_preds['sim_output']['pretrained']['train']['pred'], bench_image_preds['sim_output']['pretrained']['train']['truth'])
bench_sim_pt_train_acc = acc(bench_image_preds['sim_output']['pretrained']['train']['pred'], bench_image_preds['sim_output']['pretrained']['train']['truth'])

bench_sim_npt_valid_f1 = f1(bench_image_preds['sim_output']['non-pretrained']['valid']['pred'], bench_image_preds['sim_output']['non-pretrained']['valid']['truth'])
bench_sim_npt_valid_acc = acc(bench_image_preds['sim_output']['non-pretrained']['valid']['pred'], bench_image_preds['sim_output']['non-pretrained']['valid']['truth'])

bench_sim_pt_valid_f1 = f1(bench_image_preds['sim_output']['pretrained']['valid']['pred'], bench_image_preds['sim_output']['pretrained']['valid']['truth'])
bench_sim_pt_valid_acc = acc(bench_image_preds['sim_output']['pretrained']['valid']['pred'], bench_image_preds['sim_output']['pretrained']['valid']['truth'])

# Sample data
bench_sample_npt_train_f1 = f1(bench_image_preds['resampled_sample']['non-pretrained']['train']['pred'], bench_image_preds['resampled_sample']['non-pretrained']['train']['truth'])
bench_sample_npt_train_acc = acc(bench_image_preds['resampled_sample']['non-pretrained']['train']['pred'], bench_image_preds['resampled_sample']['non-pretrained']['train']['truth'])

bench_sample_pt_train_f1 = f1(bench_image_preds['resampled_sample']['pretrained']['train']['pred'], bench_image_preds['resampled_sample']['pretrained']['train']['truth'])
bench_sample_pt_train_acc = acc(bench_image_preds['resampled_sample']['pretrained']['train']['pred'], bench_image_preds['resampled_sample']['pretrained']['train']['truth'])

bench_sample_npt_valid_f1 = f1(bench_image_preds['resampled_sample']['non-pretrained']['valid']['pred'], bench_image_preds['resampled_sample']['non-pretrained']['valid']['truth'])
bench_sample_npt_valid_acc = acc(bench_image_preds['resampled_sample']['non-pretrained']['valid']['pred'], bench_image_preds['resampled_sample']['non-pretrained']['valid']['truth'])

bench_sample_pt_valid_f1 = f1(bench_image_preds['resampled_sample']['pretrained']['valid']['pred'], bench_image_preds['resampled_sample']['pretrained']['valid']['truth'])
bench_sample_pt_valid_acc = acc(bench_image_preds['resampled_sample']['pretrained']['valid']['pred'], bench_image_preds['resampled_sample']['pretrained']['valid']['truth'])

In [None]:
# Sample model
# Bench data
sample_bench_npt_train_f1 = f1(resampled_sample_preds['bench_image']['non-pretrained']['train']['pred'], resampled_sample_preds['bench_image']['non-pretrained']['train']['truth'])
sample_bench_npt_train_acc = acc(resampled_sample_preds['bench_image']['non-pretrained']['train']['pred'], resampled_sample_preds['bench_image']['non-pretrained']['train']['truth'])

sample_bench_pt_train_f1 = f1(resampled_sample_preds['bench_image']['pretrained']['train']['pred'], resampled_sample_preds['bench_image']['pretrained']['train']['truth'])
sample_bench_pt_train_acc = acc(resampled_sample_preds['bench_image']['pretrained']['train']['pred'], resampled_sample_preds['bench_image']['pretrained']['train']['truth'])

sample_bench_npt_valid_f1 = f1(resampled_sample_preds['bench_image']['non-pretrained']['valid']['pred'], resampled_sample_preds['bench_image']['non-pretrained']['valid']['truth'])
sample_bench_npt_valid_acc = acc(resampled_sample_preds['bench_image']['non-pretrained']['valid']['pred'], resampled_sample_preds['bench_image']['non-pretrained']['valid']['truth'])

sample_bench_pt_valid_f1 = f1(resampled_sample_preds['bench_image']['pretrained']['valid']['pred'], resampled_sample_preds['bench_image']['pretrained']['valid']['truth'])
sample_bench_pt_valid_acc = acc(resampled_sample_preds['bench_image']['pretrained']['valid']['pred'], resampled_sample_preds['bench_image']['pretrained']['valid']['truth'])

# Sim data
sample_sim_npt_train_f1 = f1(resampled_sample_preds['sim_output']['non-pretrained']['train']['pred'], resampled_sample_preds['sim_output']['non-pretrained']['train']['truth'])
sample_sim_npt_train_acc = acc(resampled_sample_preds['sim_output']['non-pretrained']['train']['pred'], resampled_sample_preds['sim_output']['non-pretrained']['train']['truth'])

sample_sim_pt_train_f1 = f1(resampled_sample_preds['sim_output']['pretrained']['train']['pred'], resampled_sample_preds['sim_output']['pretrained']['train']['truth'])
sample_sim_pt_train_acc = acc(resampled_sample_preds['sim_output']['pretrained']['train']['pred'], resampled_sample_preds['sim_output']['pretrained']['train']['truth'])

sample_sim_npt_valid_f1 = f1(resampled_sample_preds['sim_output']['non-pretrained']['valid']['pred'], resampled_sample_preds['sim_output']['non-pretrained']['valid']['truth'])
sample_sim_npt_valid_acc = acc(resampled_sample_preds['sim_output']['non-pretrained']['valid']['pred'], resampled_sample_preds['sim_output']['non-pretrained']['valid']['truth'])

sample_sim_pt_valid_f1 = f1(resampled_sample_preds['sim_output']['pretrained']['valid']['pred'], resampled_sample_preds['sim_output']['pretrained']['valid']['truth'])
sample_sim_pt_valid_acc = acc(resampled_sample_preds['sim_output']['pretrained']['valid']['pred'], resampled_sample_preds['sim_output']['pretrained']['valid']['truth'])

# Sample data
sample_sample_npt_train_f1 = f1(resampled_sample_preds['resampled_sample']['non-pretrained']['train']['pred'], resampled_sample_preds['resampled_sample']['non-pretrained']['train']['truth'])
sample_sample_npt_train_acc = acc(resampled_sample_preds['resampled_sample']['non-pretrained']['train']['pred'], resampled_sample_preds['resampled_sample']['non-pretrained']['train']['truth'])

sample_sample_pt_train_f1 = f1(resampled_sample_preds['resampled_sample']['pretrained']['train']['pred'], resampled_sample_preds['resampled_sample']['pretrained']['train']['truth'])
sample_sample_pt_train_acc = acc(resampled_sample_preds['resampled_sample']['pretrained']['train']['pred'], resampled_sample_preds['resampled_sample']['pretrained']['train']['truth'])

sample_sample_npt_valid_f1 = f1(resampled_sample_preds['resampled_sample']['non-pretrained']['valid']['pred'], resampled_sample_preds['resampled_sample']['non-pretrained']['valid']['truth'])
sample_sample_npt_valid_acc = acc(resampled_sample_preds['resampled_sample']['non-pretrained']['valid']['pred'], resampled_sample_preds['resampled_sample']['non-pretrained']['valid']['truth'])

sample_sample_pt_valid_f1 = f1(resampled_sample_preds['resampled_sample']['pretrained']['valid']['pred'], resampled_sample_preds['resampled_sample']['pretrained']['valid']['truth'])
sample_sample_pt_valid_acc = acc(resampled_sample_preds['resampled_sample']['pretrained']['valid']['pred'], resampled_sample_preds['resampled_sample']['pretrained']['valid']['truth'])

In [None]:
# Sim model
# Bench data
sim_bench_npt_train_f1 = f1(sim_output_preds['bench_image']['non-pretrained']['train']['pred'], sim_output_preds['bench_image']['non-pretrained']['train']['truth'])
sim_bench_npt_train_acc = acc(sim_output_preds['bench_image']['non-pretrained']['train']['pred'], sim_output_preds['bench_image']['non-pretrained']['train']['truth'])

sim_bench_pt_train_f1 = f1(sim_output_preds['bench_image']['pretrained']['train']['pred'], sim_output_preds['bench_image']['pretrained']['train']['truth'])
sim_bench_pt_train_acc = acc(sim_output_preds['bench_image']['pretrained']['train']['pred'], sim_output_preds['bench_image']['pretrained']['train']['truth'])

sim_bench_npt_valid_f1 = f1(sim_output_preds['bench_image']['non-pretrained']['valid']['pred'], sim_output_preds['bench_image']['non-pretrained']['valid']['truth'])
sim_bench_npt_valid_acc = acc(sim_output_preds['bench_image']['non-pretrained']['valid']['pred'], sim_output_preds['bench_image']['non-pretrained']['valid']['truth'])

sim_bench_pt_valid_f1 = f1(sim_output_preds['bench_image']['pretrained']['valid']['pred'], sim_output_preds['bench_image']['pretrained']['valid']['truth'])
sim_bench_pt_valid_acc = acc(sim_output_preds['bench_image']['pretrained']['valid']['pred'], sim_output_preds['bench_image']['pretrained']['valid']['truth'])

# Sim data
sim_sim_npt_train_f1 = f1(sim_output_preds['sim_output']['non-pretrained']['train']['pred'], sim_output_preds['sim_output']['non-pretrained']['train']['truth'])
sim_sim_npt_train_acc = acc(sim_output_preds['sim_output']['non-pretrained']['train']['pred'], sim_output_preds['sim_output']['non-pretrained']['train']['truth'])

sim_sim_pt_train_f1 = f1(sim_output_preds['sim_output']['pretrained']['train']['pred'], sim_output_preds['sim_output']['pretrained']['train']['truth'])
sim_sim_pt_train_acc = acc(sim_output_preds['sim_output']['pretrained']['train']['pred'], sim_output_preds['sim_output']['pretrained']['train']['truth'])

sim_sim_npt_valid_f1 = f1(sim_output_preds['sim_output']['non-pretrained']['valid']['pred'], sim_output_preds['sim_output']['non-pretrained']['valid']['truth'])
sim_sim_npt_valid_acc = acc(sim_output_preds['sim_output']['non-pretrained']['valid']['pred'], sim_output_preds['sim_output']['non-pretrained']['valid']['truth'])

sim_sim_pt_valid_f1 = f1(sim_output_preds['sim_output']['pretrained']['valid']['pred'], sim_output_preds['sim_output']['pretrained']['valid']['truth'])
sim_sim_pt_valid_acc = acc(sim_output_preds['sim_output']['pretrained']['valid']['pred'], sim_output_preds['sim_output']['pretrained']['valid']['truth'])

# Sample data
sim_sample_npt_train_f1 = f1(sim_output_preds['resampled_sample']['non-pretrained']['train']['pred'], sim_output_preds['resampled_sample']['non-pretrained']['train']['truth'])
sim_sample_npt_train_acc = acc(sim_output_preds['resampled_sample']['non-pretrained']['train']['pred'], sim_output_preds['resampled_sample']['non-pretrained']['train']['truth'])

sim_sample_pt_train_f1 = f1(sim_output_preds['resampled_sample']['pretrained']['train']['pred'], sim_output_preds['resampled_sample']['pretrained']['train']['truth'])
sim_sample_pt_train_acc = acc(sim_output_preds['resampled_sample']['pretrained']['train']['pred'], sim_output_preds['resampled_sample']['pretrained']['train']['truth'])

sim_sample_npt_valid_f1 = f1(sim_output_preds['resampled_sample']['non-pretrained']['valid']['pred'], sim_output_preds['resampled_sample']['non-pretrained']['valid']['truth'])
sim_sample_npt_valid_acc = acc(sim_output_preds['resampled_sample']['non-pretrained']['valid']['pred'], sim_output_preds['resampled_sample']['non-pretrained']['valid']['truth'])

sim_sample_pt_valid_f1 = f1(sim_output_preds['resampled_sample']['pretrained']['valid']['pred'], sim_output_preds['resampled_sample']['pretrained']['valid']['truth'])
sim_sample_pt_valid_acc = acc(sim_output_preds['resampled_sample']['pretrained']['valid']['pred'], sim_output_preds['resampled_sample']['pretrained']['valid']['truth'])

In [None]:
# Training data f1 table npt
f1row0 = [sample_sample_npt_train_f1, sim_sample_npt_train_f1, bench_sample_npt_train_f1]
f1row1 = [sample_sim_npt_train_f1, sim_sim_npt_train_f1, bench_sim_npt_train_f1]
f1row2 = [sample_bench_npt_train_f1, sim_bench_npt_train_f1, bench_bench_npt_train_f1]

# Training data f1 table pt
f1row3 = [sample_sample_pt_train_f1, sim_sample_pt_train_f1, bench_sample_pt_train_f1]
f1row4 = [sample_sim_pt_train_f1, sim_sim_pt_train_f1, bench_sim_pt_train_f1]
f1row5 = [sample_bench_pt_train_f1, sim_bench_pt_train_f1, bench_bench_pt_train_f1]

# Validation data f1 table npt
f1row6 = [sample_sample_npt_valid_f1, sim_sample_npt_valid_f1, bench_sample_npt_valid_f1]
f1row7 = [sample_sim_npt_valid_f1, sim_sim_npt_valid_f1, bench_sim_npt_valid_f1]
f1row8 = [sample_bench_npt_valid_f1, sim_bench_npt_valid_f1, bench_bench_npt_valid_f1]

# Validation data f1 table pt
f1row9 = [sample_sample_pt_valid_f1, sim_sample_pt_valid_f1, bench_sample_pt_valid_f1]
f1row10 = [sample_sim_pt_valid_f1, sim_sim_pt_valid_f1, bench_sim_pt_valid_f1]
f1row11 = [sample_bench_pt_valid_f1, sim_bench_pt_valid_f1, bench_bench_pt_valid_f1]

In [None]:
# Training data acc table npt
accrow0 = [sample_sample_npt_train_acc, sim_sample_npt_train_acc, bench_sample_npt_train_acc]
accrow1 = [sample_sim_npt_train_acc, sim_sim_npt_train_acc, bench_sim_npt_train_acc]
accrow2 = [sample_bench_npt_train_acc, sim_bench_npt_train_acc, bench_bench_npt_train_acc]

# Training data acc table pt
accrow3 = [sample_sample_pt_train_acc, sim_sample_pt_train_acc, bench_sample_pt_train_acc]
accrow4 = [sample_sim_pt_train_acc, sim_sim_pt_train_acc, bench_sim_pt_train_acc]
accrow5 = [sample_bench_pt_train_acc, sim_bench_pt_train_acc, bench_bench_pt_train_acc]

# Validation data acc table npt
accrow6 = [sample_sample_npt_valid_acc, sim_sample_npt_valid_acc, bench_sample_npt_valid_acc]
accrow7 = [sample_sim_npt_valid_acc, sim_sim_npt_valid_acc, bench_sim_npt_valid_acc]
accrow8 = [sample_bench_npt_valid_acc, sim_bench_npt_valid_acc, bench_bench_npt_valid_acc]

# Validation data acc table pt
accrow9 = [sample_sample_pt_valid_acc, sim_sample_pt_valid_acc, bench_sample_pt_valid_acc]
accrow10 = [sample_sim_pt_valid_acc, sim_sim_pt_valid_acc, bench_sim_pt_valid_acc]
accrow11 = [sample_bench_pt_valid_acc, sim_bench_pt_valid_acc, bench_bench_pt_valid_acc]

In [None]:
row0 = np.concatenate([f1row3, f1row0])
row1 = np.concatenate([f1row4, f1row1])
row2 = np.concatenate([f1row5, f1row2])

In [None]:
print(row0)
print(row1)
print(row2)

In [None]:
row3 = np.concatenate([f1row9, f1row6])
row4 = np.concatenate([f1row10, f1row7])
row5 = np.concatenate([f1row11, f1row8])

In [None]:
print(row3)
print(row4)
print(row5)

In [None]:
row6 = np.concatenate([accrow3, accrow0])
row7 = np.concatenate([accrow4, accrow1])
row8 = np.concatenate([accrow5, accrow2])

In [None]:
print(row6)
print(row7)
print(row8)

In [None]:
row9 = np.concatenate([accrow9, accrow6])
row10 = np.concatenate([accrow10, accrow7])
row11 = np.concatenate([accrow11, accrow8])

In [None]:
print(row9)
print(row10)
print(row11)