In [None]:
# Re-import edited project modules (e.g. `cac.*`) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os.path import dirname, join, exists, splitext, isdir
from copy import deepcopy
from typing import List
import multiprocessing as mp
from glob import glob
import torch
import numpy as np
import pandas as pd
from scipy.special import softmax
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap
import seaborn as sns
from tqdm import tqdm
from IPython.display import display, HTML, clear_output, Markdown, Audio
from ipywidgets import HBox, Label, VBox, Dropdown, Layout, Output, Image

from cac.config import Config, DATA_ROOT
from cac.utils.logger import set_logger, color
from cac.utils.metrics import PrecisionAtRecall
from cac.data.dataloader import get_dataloader
from cac.analysis.classification import ClassificationAnalyzer

In [None]:
import warnings
warnings.simplefilter('ignore')

### Define inputs

In [None]:
# Run to analyse:
# - the experiment config file (.yml) passed to `Config`,
# - the user name passed alongside it,
# - the checkpoint epoch handed to the analyzer.
VERSION: str = 'experiments/covid-detection/v9_4_cough_adam_1e-4.yml'
USER: str = 'piyush'
BEST_EPOCH: int = 99

In [None]:
# Validation dataloader settings: samples per batch and number of loader workers.
BATCH_SIZE: int = 10
NUM_WORKERS: int = 10

### Define config

In [None]:
# Load the experiment configuration named by VERSION (a .yml path) for USER.
# `config.data` feeds the dataloader and `config` itself the analyzer.
config = Config(VERSION, USER)

### Load data

In [None]:
# Validation split only. shuffle=False / drop_last=False keep every sample, in a stable
# order (assuming get_dataloader forwards these to the underlying loader).
# The second return value is not needed here.
val_dataloader, _ = get_dataloader(config.data, 'val', BATCH_SIZE,
                                   num_workers=NUM_WORKERS,
                                   shuffle=False, drop_last=False)

### Initialize the analyzer module

In [None]:
# Analyzer for the run described by `config`, pinned to the BEST_EPOCH checkpoint.
# NOTE(review): load_best=False presumably disables automatic best-checkpoint selection
# so that `checkpoint` is honoured — confirm in ClassificationAnalyzer.
analyzer = ClassificationAnalyzer(config, checkpoint=BEST_EPOCH, load_best=False, debug=True)

### Load epochwise logs

In [None]:
# Epoch-wise validation logs. The cells below read the entries 'instance_loss',
# 'batch_loss', 'predict_labels' and 'predict_probs' from it.
# get_metrics=False presumably skips metric computation — TODO confirm.
logs = analyzer.load_epochwise_logs(mode='val', get_metrics=False)

In [None]:
# Show which log entries are available.
logs.keys()

#### Instance and batch losses

In [None]:
# Per-instance loss log (not referenced elsewhere in this notebook).
instance_losses = logs['instance_loss']

In [None]:
# Per-batch loss log (not referenced elsewhere in this notebook).
batch_losses = logs['batch_loss']

#### Predictions: Labels

In [None]:
# Hard predictions: a DataFrame with a 'targets' column plus per-epoch label columns.
# If the epoch columns are absent, they may be filled in below by thresholding
# `predict_probs`.
predict_labels = logs['predict_labels']

#### Predictions: Probabilities

In [None]:
# Predicted probabilities: one column per epoch (columns whose name contains 'epoch').
# Each cell holds one sample's class-probability vector.
predict_probs = logs['predict_probs']

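### Estimate `predict_labels` if they were not logged

If the logs contain probabilities but no per-epoch hard labels, derive them. For each epoch, take the positive-class probability (binary classification only) and threshold it at the value that achieves the target `recall` on the validation set.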

In [None]:
# Operating point (target recall) used to turn probabilities into hard labels.
recall = 0.9

# Epoch columns present in the probability log, and those among them without logged
# hard labels. Filling only the missing ones replaces the former
# `len(predict_labels.columns) == 3` heuristic:
# - it covers the "no labels at all" case that heuristic targeted,
# - it is a no-op when every epoch is already labelled,
# - it also handles a partially labelled log, which previously left
#   `predict_labels[epochs]` in later cells with missing columns.
epoch_columns = [col for col in predict_probs.columns if 'epoch' in col]
missing_epoch_columns = [col for col in epoch_columns if col not in predict_labels.columns]

if missing_epoch_columns:
    # Convert through `.to_numpy()`. Handing torch the Series itself can make it read
    # elements via the label-based `Series.__getitem__`, which fails (or misorders)
    # when the index is not 0..n-1.
    targets = torch.tensor(predict_labels['targets'].to_numpy())
    for epoch_column in tqdm(missing_epoch_columns):
        # Each cell holds one sample's probability vector -> (n_samples, n_classes).
        predict_proba = torch.from_numpy(np.stack(predict_probs[epoch_column].values))
        # Binary only: column 1 is taken as the positive-class probability.
        predict_proba = predict_proba[:, 1]
        _, _, threshold = PrecisionAtRecall(recall=recall)(targets, predict_proba)
        # Adds the column in place on the logged labels DataFrame.
        predict_labels[epoch_column] = predict_proba.ge(threshold).int().tolist()

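## Model prediction grid

Each point is one (epoch, sample) pair:
- blue: the model's label for that sample at that epoch matched the target;
- red: it did not.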

In [None]:
# Epoch columns of the probability log, kept in their logged order
# (presumably named 'epoch_<n>', as parsed by the grid plot).
epochs = list(filter(lambda column: 'epoch' in column, predict_probs.columns))

In [None]:
# Colour per correctness value: 0 (wrong) -> red, 1 (right) -> blue, matching the
# colours used in the prediction-grid plot.
COLOR_CODES = {0: 'red', 1: 'blue'}

In [None]:
# Hard labels restricted to the epoch columns (not referenced again in this notebook).
epoch_predictions = predict_labels.loc[:, epochs]

In [None]:
# Ground-truth labels, compared against each epoch's predictions below.
targets = predict_labels.loc[:, 'targets']

In [None]:
# 1 where an epoch's predicted label equals the target, 0 otherwise.
# The result has the same index and epoch columns as predict_labels[epochs].
# eq(..., axis=0) compares every epoch column with `targets` row by row.
prediction_correctness = (
    predict_labels[epochs]
    .eq(targets, axis=0)
    .astype(np.int64)
)

In [None]:
# Prediction grid: one point per (epoch, sample).
# - x is the epoch number parsed from the column name (e.g. 'epoch_12' -> 12).
# - y is the sample's row label.
# - colour comes from COLOR_CODES keyed by correctness (1 = right, 0 = wrong).
epoch_numbers = np.array([int(col.split('_')[-1]) for col in prediction_correctness.columns])
sample_ids = prediction_correctness.index.to_numpy()
correctness = prediction_correctness.to_numpy()

fig, ax = plt.subplots(figsize=(25, 15))

# Two vectorised scatter calls instead of two per sample. Every (epoch, sample) cell
# holds exactly one value, so points never overlap and draw order is irrelevant.
for outcome, outcome_label in ((1, 'correct'), (0, 'incorrect')):
    sample_pos, epoch_pos = np.nonzero(correctness == outcome)
    ax.scatter(epoch_numbers[epoch_pos], sample_ids[sample_pos],
               c=COLOR_CODES[outcome], s=0.4, label=outcome_label)

ax.set_title('Model prediction grid: $(t, x): (t, P(x | t))$')
ax.set_xlabel('Epochs')
ax.set_ylabel('Samples')
# Bound x by the largest logged epoch number rather than by the number of epoch
# columns, so sparsely logged runs (e.g. every 5th epoch) are not clipped.
# When epochs 0..E-1 are all logged this is E, exactly as before.
x_max = int(epoch_numbers.max()) + 1 if epoch_numbers.size else 1
ax.set_xlim([0, x_max])
ax.set_ylim([0, prediction_correctness.shape[0]])
ax.grid(True)
# Markers are tiny (s=0.4): enlarge them in the legend only, and keep the legend
# outside the grid area.
ax.legend(markerscale=20, loc='upper left', bbox_to_anchor=(1.0, 1.0))
plt.show()