In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
import os
import csv
from sklearn.metrics import confusion_matrix
from tqdm.notebook import tqdm
import pandas as pd
import seaborn as sns
import sys
from torchmetrics import ConfusionMatrix
from torchmetrics.classification import F1Score, Accuracy, Precision, Recall
sys.path.append('../')
from datamodule.datamodule import select_data
from models.models import Classifier

In [None]:
# List the style sheets this matplotlib install offers, then activate the one
# used for every figure in this notebook.
available_styles = plt.style.available
print(available_styles)
plt.style.use('seaborn-v0_8-dark-palette')

In [None]:
# Project-relative locations: everything is resolved from the repository root,
# and analysis artefacts go to a dedicated results folder.
path_root = '../../'
path_results = os.path.join(path_root, 'results/pt_classifier_analysis')

# Read the project configuration inside a context manager so the file handle is
# closed deterministically (the bare `open(...)` passed to yaml.load never was).
# NOTE(review): FullLoader can build more than plain data types; if config.yaml
# only holds plain mappings/scalars, yaml.safe_load would be the safer choice.
with open(os.path.join(path_root, 'config.yaml')) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Override the configured root with the notebook-relative one and make sure the
# output directory exists before anything tries to write into it.
config['paths']['path_root'] = path_root
os.makedirs(path_results, exist_ok=True)

In [None]:
# Location of the logged training metrics for the run being analysed.
csv_filename = '../../results/coop_bench_alpha_0.0_beta_0.0_gamma_0.0_delta_1.0/version_4/logs/metrics.csv'

# Read the CSV column-wise: the header row names the columns, and every later
# row appends one raw (string) value to each column it has a field for.
metrics = {}
with open(csv_filename) as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    header_row = next(reader, None)
    if header_row is not None:
        metrics = {header: [] for header in header_row}
        key_list = list(metrics)
        for row in reader:
            for column_index, value in enumerate(row):
                metrics[key_list[column_index]].append(value)

In [None]:
# Column names that were found in the metrics CSV header.
metric_columns = metrics.keys()
print(metric_columns)

In [None]:
# Clean up the metrics.
# The loss columns contain empty strings on the rows where that loss was not
# logged; drop those and convert the remaining entries to floats.
for loss_key in ('loss_train', 'loss_val'):
    metrics[loss_key] = np.asarray(
        [float(entry) for entry in metrics[loss_key] if entry != '']
    )

# The epoch column repeats per logged row; keep each epoch once, as integers.
metrics['epoch'] = np.unique(np.asarray(metrics['epoch'], dtype=int))

In [None]:
# Train vs. validation loss per epoch on a single axis.
fig, ax = plt.subplots(1, 1, figsize=(8, 5))

for series_key, series_label in (('loss_train', "Train loss"), ('loss_val', "Validation loss")):
    ax.plot(metrics['epoch'], metrics[series_key], label=series_label)

ax.set_xlabel("Epoch")
ax.set_ylabel("Cross Entropy Loss")

# Tick every second epoch from 0 through 20, labelled with the epoch number.
even_epochs = list(range(0, 21, 2))
ax.set_xticks(even_epochs, even_epochs)
ax.legend(loc='upper left')

In [None]:
# Index/indices at which the training loss reaches its minimum (all ties are
# reported). Single-argument np.where is shorthand for exactly this nonzero call.
np.nonzero(metrics['loss_train'] == metrics['loss_train'].min())

In [None]:
# Point the config at the post-training dataset and list its files in a
# deterministic (sorted) order so positional splits are reproducible.
config['paths']['path_data'] = 'data/post_training'
pt_data_dir = os.path.join(config['paths']['path_root'], config['paths']['path_data'])
pt_filenames = sorted(os.path.join(pt_data_dir, name) for name in os.listdir(pt_data_dir))

# Load every sample file exactly once and pull out the four fields needed below
# (the original deserialized each file four times, once per field).
pt_bench_images = []
pt_sim_images = []
pt_ideal_images = []
pt_targets = []
for f in tqdm(pt_filenames):
    sample = torch.load(f, weights_only=True)
    pt_bench_images.append(sample['bench_image'].squeeze().detach())
    pt_sim_images.append(sample['sim_output'].squeeze().detach())
    pt_ideal_images.append(sample['resampled_sample'].squeeze().detach())
    # Class label = index of the largest entry of the stored target vector.
    pt_targets.append(int(torch.argmax(sample['target'])))

# Keep the labels as an integer NumPy array: a plain list has no .squeeze()
# (the original call raised AttributeError) and the next cell converts this
# array to a tensor.
pt_targets = np.asarray(pt_targets, dtype=np.int64)
pt_unique_targets = np.unique(pt_targets)

In [None]:
# Convert the labels to a tensor. torch.as_tensor accepts an ndarray as well as
# an existing tensor, so re-running this cell is a no-op, whereas
# torch.from_numpy raises TypeError on anything that is not an ndarray.
pt_targets = torch.as_tensor(pt_targets)

In [None]:
# Peek at the sorted file paths from position 800 to the end of the list.
pt_filenames[800:len(pt_filenames)]

In [None]:
# NOTE(review): this cell was left unfinished in the original (`for i,fil` with
# no iterable or body), which is a SyntaxError and stops a Run All. The
# completion below assumes the positional split used by the results cells
# (first 800 files -> train, everything after -> validation) — confirm that
# this is the intended rule before relying on these lists.
train_indices = []
valid_indices = []
for i, filename in enumerate(pt_filenames):
    if i < 800:
        train_indices.append(i)
    else:
        valid_indices.append(i)

In [None]:
# Checkpoint of the baseline classifier that is evaluated in this notebook.
checkpoint_path = '../../results/classifier_baseline_bench_resampled_sample/version_0/checkpoints/last.ckpt'

# Restore the model in double precision on the GPU and switch it to eval mode:
# it is only used for inference here, so layers such as dropout / batch norm
# (if the model has any) must not run in training mode on the single-sample
# batches fed to it below. `.eval()` returns the module, so it can be chained.
classifier = Classifier.load_from_checkpoint(checkpoint_path).double().cuda().eval()

In [None]:
def _predict_class(image):
    """Return the classifier's predicted class index for one image.

    The image gets two leading singleton dimensions and is repeated three times
    along dim 1 before being moved to the GPU (presumably turning a single
    channel (H, W) image into a (1, 3, H, W) batch — TODO confirm the input
    shape). The prediction is the argmax over the last dimension of the
    classifier output, returned on the CPU.
    """
    image = image.unsqueeze(0).unsqueeze(0)
    sample = torch.cat((image, image, image), dim=1).cuda()
    return torch.argmax(classifier(sample), dim=-1).cpu()


# Inference only: make sure the model is in eval mode and skip autograd
# bookkeeping so no graph is built for any of the forward passes.
classifier.eval()

bench_results = []
sim_results = []
ideal_results = []
with torch.no_grad():
    # tqdm wraps the iterator directly, so the bar is advanced and closed
    # automatically; `total` is needed because zip() has no length.
    for sim_image, bench_image, ideal_image, target in tqdm(
        zip(pt_sim_images, pt_bench_images, pt_ideal_images, pt_targets),
        total=len(pt_sim_images),
    ):
        sim_pred = _predict_class(sim_image)
        bench_pred = _predict_class(bench_image)
        ideal_pred = _predict_class(ideal_image)

        # Each entry is a [prediction, ground-truth label] pair.
        bench_results.append([bench_pred, target])
        sim_results.append([sim_pred, target])
        ideal_results.append([ideal_pred, target])

In [None]:
# Persist the [prediction, target] lists next to the notebook so the inference
# loop does not have to be repeated.
for results, results_file in (
    (bench_results, 'pt_bench_results.pt'),
    (sim_results, 'pt_sim_results.pt'),
    (ideal_results, 'pt_ideal_results.pt'),
):
    torch.save(results, results_file)

In [None]:
# Reload the cached [prediction, target] lists written by the save cell.
bench_results, sim_results, ideal_results = (
    torch.load(results_file, weights_only=True)
    for results_file in ('pt_bench_results.pt', 'pt_sim_results.pt', 'pt_ideal_results.pt')
)

In [None]:
def _stack_pairs(results):
    """Stack a sequence of [prediction, target] entries into one tensor.

    Both members of every entry are squeezed and stacked into a row; the rows
    are then stacked along a new leading dimension (one row per sample).
    """
    rows = []
    for entry in results:
        rows.append(torch.stack((entry[0].squeeze(), entry[1].squeeze())))
    return torch.stack(rows)


bench_results = _stack_pairs(bench_results)
sim_results = _stack_pairs(sim_results)
ideal_results = _stack_pairs(ideal_results)

In [None]:
# One split point for both halves. The original mixed `[:-200]` (train) with
# `[800:]` (valid), which only describe the same partition when there are
# exactly 1000 samples; for any other count the two sets overlap or leave a gap.
# NOTE(review): 800 is taken from the validation slices and the file-list
# inspection cell — confirm it is the intended boundary.
n_train = 800

bench_results_train = bench_results[:n_train]
bench_results_valid = bench_results[n_train:]

sim_results_train = sim_results[:n_train]
sim_results_valid = sim_results[n_train:]

ideal_results_train = ideal_results[:n_train]
ideal_results_valid = ideal_results[n_train:]

In [None]:
# Ten-class confusion-matrix metric shared by all splits below.
confmat = ConfusionMatrix(num_classes=10, task="multiclass")

In [None]:
# Evaluate the confusion-matrix metric on every split. Column 0 of each result
# tensor holds the predictions and column 1 the ground-truth labels.
bench_cfm_train, sim_cfm_train, ideal_cfm_train = (
    confmat(split[:, 0], split[:, 1])
    for split in (bench_results_train, sim_results_train, ideal_results_train)
)

bench_cfm_valid, sim_cfm_valid, ideal_cfm_valid = (
    confmat(split[:, 0], split[:, 1])
    for split in (bench_results_valid, sim_results_valid, ideal_results_valid)
)

In [None]:
# Wrap each confusion matrix in a DataFrame whose rows and columns are labelled
# with the class indices 0-9, ready for seaborn.
class_labels = list(range(10))

bench_df_train, sim_df_train, ideal_df_train = (
    pd.DataFrame(cfm, index=class_labels, columns=class_labels)
    for cfm in (bench_cfm_train, sim_cfm_train, ideal_cfm_train)
)

bench_df_valid, sim_df_valid, ideal_df_valid = (
    pd.DataFrame(cfm, index=class_labels, columns=class_labels)
    for cfm in (bench_cfm_valid, sim_cfm_valid, ideal_cfm_valid)
)

In [None]:
# 2x3 grid of confusion matrices: top row = train split, bottom row = valid
# split; columns are ideal, simulated and bench images (left to right).
fig, ax = plt.subplots(2, 3, figsize=(15, 10))

panels = (
    (ideal_df_train, sim_df_train, bench_df_train),
    (ideal_df_valid, sim_df_valid, bench_df_valid),
)
for row_axes, row_frames in zip(ax, panels):
    for axis, frame in zip(row_axes, row_frames):
        sns.heatmap(frame, annot=True, ax=axis, square=True, cbar=False, cmap='Blues')
        axis.set_ylabel("Truth")
        axis.set_xlabel("Prediction")

plt.tight_layout()
fig.savefig('pt_cfm.pdf')
# The original name 'pt_cfm_png' had no extension (underscore typo), so the
# file was written as 'pt_cfm_png.png'; use the intended name instead.
fig.savefig('pt_cfm.png')

In [None]:
# Ten-class F1 metric with the library's default settings.
f1 = F1Score(num_classes=10, task='multiclass')

In [None]:
# F1 of the classifier on bench images for each split (column 0 = predictions,
# column 1 = ground-truth labels).
bench_f1_train, bench_f1_valid = (
    f1(split[:, 0], split[:, 1])
    for split in (bench_results_train, bench_results_valid)
)


In [None]:
# Report the train score first, then the validation score, one per line.
print(bench_f1_train, bench_f1_valid, sep='\n')