In [1]:
import torch
import numpy as np
import cv2
import os
from functools import partial
import itertools
from collections import defaultdict

import mafat_radar_challenge.data_loader.augmentation as module_aug
import mafat_radar_challenge.data_loader.data_loaders as module_data
import mafat_radar_challenge.model.loss as module_loss
import mafat_radar_challenge.model.metric as module_metric
import mafat_radar_challenge.model.model as module_arch
from mafat_radar_challenge.trainer import Trainer, MAFATTrainer
from mafat_radar_challenge.utils import setup_logger
from mafat_radar_challenge.tester import MAFATTester
import mafat_radar_challenge.data_loader.data_splitter as module_splitter
import mafat_radar_challenge.data_loader.samplers as module_sampler
import mafat_radar_challenge.data_loader.mixers as module_mix

from mafat_radar_challenge.main import get_instance, setup_device
from mafat_radar_challenge.cli import load_config

import matplotlib.pyplot as plt
from sklearn import metrics
from scipy.stats import pearsonr
import pandas as pd

In [2]:
def plot_roc_curve(gt, score, suffix=""):
    """Draw one model's ROC curve (log-scaled FPR axis) on the current figure.

    Args:
        gt: Binary ground-truth labels.
        score: Predicted scores for the positive class.
        suffix: Text appended to the legend entry after the AUC (e.g. a model name).
    """
    false_pos_rate, true_pos_rate, _thresholds = metrics.roc_curve(gt, score)
    area = round(metrics.auc(false_pos_rate, true_pos_rate), 4)
    legend_label = "AUC = {}".format(area) + " " + suffix
    plt.semilogx(false_pos_rate, true_pos_rate, "-", label=legend_label)
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.grid(which="both")
    plt.legend()
    
def precision_recall_curve(gt, score, suffix=""):
    """Add one model's precision-recall curve to the current matplotlib figure.

    Recall is on the x-axis and precision on the y-axis, which is the usual PR-curve
    orientation. The legend entry reports the average precision (AP), analogous to
    the AUC shown by ``plot_roc_curve``, followed by ``suffix``.

    Args:
        gt: Binary ground-truth labels.
        score: Predicted scores for the positive class.
        suffix: Text appended to the legend entry (e.g. a model name).
    """
    precision, recall, _ = metrics.precision_recall_curve(gt, score)
    # metrics.auc(precision, recall) is not valid here because precision is not
    # monotonic; average precision is the standard PR summary.
    ap = round(metrics.average_precision_score(gt, score), 4)
    plt.plot(recall, precision, "-", label="AP = {}".format(ap) + " " + suffix)
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.grid(which="both")
    plt.legend()
    
def plot_histogram(positive_preds, negative_preds, bins=30):
    """Show score histograms of positive and negative samples on twin y-axes.

    Negatives (blue) are drawn on the left y-axis and positives (orange) on the
    right y-axis, so both stay readable when class sizes differ. Both series share
    the same bin edges, which are computed from the finite scores of both classes,
    so their bars line up.

    Args:
        positive_preds: Scores of samples whose ground truth is 1.
        negative_preds: Scores of samples whose ground truth is 0.
        bins: Number of bins, or explicit bin edges, used for both histograms.
    """
    all_scores = np.concatenate([
        np.asarray(positive_preds, dtype=float).ravel(),
        np.asarray(negative_preds, dtype=float).ravel(),
    ])
    # Previously each hist() chose its own bins over its own score range, so the
    # positive and negative bars were misaligned and hard to compare.
    bin_edges = np.histogram_bin_edges(all_scores[np.isfinite(all_scores)], bins=bins)

    fig = plt.figure(figsize=(15,10))
    ax1 = fig.add_subplot(111)
    ax2 = ax1.twinx()
    ax2.hist(positive_preds, alpha=0.6, log=False, label='+', cumulative=False, bins=bin_edges, color='orange')
    ax1.hist(negative_preds, alpha=0.4, log=False, label='-', cumulative=False, bins=bin_edges, color='blue')
    ax1.legend(loc=1)
    ax2.legend(loc=2)
    plt.show()
    
def plot_stacked_histogram(df, grouping_col="snr_type", bins=30):
    """Plot stacked score histograms of positives vs. negatives, split by a column.

    Rows with ``target_type == 1`` are stacked on the right y-axis and rows with
    ``target_type == 0`` on the left y-axis. Each stack has one layer per distinct
    value of ``grouping_col``. Both classes share the same bin edges.

    Args:
        df: DataFrame with ``target_type``, ``score`` and ``grouping_col`` columns.
        grouping_col: Column whose values split each class into stacked layers.
        bins: Number of bins, or explicit bin edges, used for both classes.
    """
    group_values = df[grouping_col].unique()
    positive_df = df[df.target_type == 1]
    negative_df = df[df.target_type == 0]

    # Shared bin edges, so positive and negative bars line up.
    scores = np.asarray(df["score"], dtype=float)
    bin_edges = np.histogram_bin_edges(scores[np.isfinite(scores)], bins=bins)

    # Cycle the palettes: matplotlib needs exactly one color per dataset, and a
    # truncating zip would raise when there are more than three groups.
    pos_colors = [c for _, c in zip(group_values, itertools.cycle(["orange", "lime", "yellow"]))]
    neg_colors = [c for _, c in zip(group_values, itertools.cycle(["blue", "purple", "cyan"]))]

    fig = plt.figure(figsize=(15,10))
    ax1 = fig.add_subplot(111)
    ax2 = ax1.twinx()
    ax2.hist([positive_df.loc[positive_df[grouping_col] == x, "score"].values.tolist() for x in group_values],
             alpha=0.6, log=False, label=["{} +".format(x) for x in group_values], stacked=True,
             cumulative=False, bins=bin_edges, color=pos_colors)
    ax1.hist([negative_df.loc[negative_df[grouping_col] == x, "score"].values.tolist() for x in group_values],
             alpha=0.4, log=False, label=["{} -".format(x) for x in group_values], stacked=True,
             cumulative=False, bins=bin_edges, color=neg_colors)
    ax1.legend(loc=2)
    ax2.legend(loc=1)
    plt.show()
    
def plot_calibration_curves(gt, score, suffix=""):
    """Plot FPR and FNR against the decision threshold for one model.

    FPR, TPR and the thresholds returned by ``metrics.roc_curve`` are all clipped
    into [0, 1]. This presumably keeps roc_curve's sentinel first threshold, which
    lies above every score, on the axis. FNR is computed as 1 - TPR.

    Args:
        gt: Binary ground-truth labels.
        score: Predicted scores for the positive class.
        suffix: Text appended to both legend entries (e.g. a model name).
    """
    false_pos, true_pos, thresholds = (np.clip(values, 0, 1) for values in metrics.roc_curve(gt, score))
    for curve_name, curve in (("FPR", false_pos), ("FNR", 1 - true_pos)):
        plt.plot(thresholds, curve, label=curve_name + " " + suffix)
    # For a log-scaled y-axis: plt.yscale("log")
    plt.xlabel("TH")
    plt.ylabel("FPR & FNR")
    plt.grid(which="both")
    plt.legend()

# Validation dataset

In [3]:
# Checkpoint directories of the five cross-validation folds of the v8 EfficientNet-B2
# experiment. The trailing "/" is kept because these strings are also used as
# df_dicts keys and printed.
models = [
    ("/mnt/agarcia_HDD/mafat-radar-challenge-experiment/"
     "MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold{}/{}/checkpoints/").format(fold, run_stamp)
    for fold, run_stamp in enumerate(
        ["0924-200924", "0924-222622", "0925-004336", "0925-030037", "0925-051739"]
    )
]

WITHOUT_SYNTH_NOISE = False  # True: drop rows whose source is "synth"
WITHOUT_AUX = True  # True: drop every segment that has a row with source "aux"
ONLY_BEST_EPOCH = True  # True: evaluate only the best checkpoint of each fold

In [4]:
# Score the selected checkpoint(s) of every fold on that fold's validation split.
# Each scored DataFrame goes into df_dicts[checkpoint_folder]; labels and scores are
# also appended to gt_list / preds_list for the plotting cells further down.
preds_list = list()
gt_list = list()
df_dicts = defaultdict(list)

for checkpoint_folder in models:
    # Select checkpoints by file name, not by position in a sorted listing. Assuming
    # the folder holds checkpoint-epoch*.pth, config.yml and model_best.pth, sorted()
    # puts the epoch files first, so the old index-based pops discarded one epoch
    # checkpoint. They also relied on model_best.pth sorting last.
    best_model = os.path.join(checkpoint_folder, "model_best.pth")
    if ONLY_BEST_EPOCH:
        model_paths = [best_model]
    else:
        model_paths = sorted(
            os.path.join(checkpoint_folder, name)
            for name in os.listdir(checkpoint_folder)
            if name.startswith("checkpoint-epoch") and name.endswith(".pth")
        )
    for model_path in model_paths:
        print(model_path)
        # Setup: rebuild the validation loader and the network from the run's config.
        cfg = load_config(os.path.join(os.path.dirname(model_path), "config.yml"))
        transforms = get_instance(module_aug, "augmentation", cfg)
        if "sampler" in cfg:
            sampler = partial(getattr(module_sampler, cfg["sampler"]["type"]), **cfg["sampler"]["args"])
        else:
            sampler = None
        # Only the validation loader is needed; the training loader (and its mixer)
        # that used to be built here was never used.
        valid_data_loader = get_instance(module_data, "val_data_loader", cfg, transforms, sampler)
        validation_df = valid_data_loader.dataset.df.copy()
        model = get_instance(module_arch, "arch", cfg)
        # Load to CPU first so checkpoints saved on a GPU load on any host; the
        # weights are copied into the freshly built (CPU) model anyway.
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        model, device = setup_device(model, cfg["target_devices"])
        model.eval()

        # Predict: sigmoid scores plus the first label column, batch by batch.
        preds = list()
        gt = list()
        with torch.no_grad():
            for image_batch, label_batch in valid_data_loader:
                if isinstance(image_batch, list):
                    # Multi-input models receive a list of tensors.
                    data = [tensor.to(device) for tensor in image_batch]
                else:
                    data = image_batch.to(device)
                output = torch.sigmoid(model(data).cpu())
                preds.append(output.numpy())
                # Index the label batch directly. Zipping with image_batch walked the
                # list of inputs instead of the samples for multi-input models.
                gt.append(label_batch[:, 0].cpu().numpy())

        preds = np.vstack(preds).reshape(-1).tolist()
        gt = np.concatenate(gt).reshape(-1).tolist()

        validation_df["label"] = gt
        validation_df["score"] = preds
        if WITHOUT_SYNTH_NOISE:
            validation_df = validation_df[validation_df.source != "synth"]
        if WITHOUT_AUX:
            # Drop every row of any segment that also appears with source "aux".
            aux_segments = validation_df.loc[validation_df.source == "aux", "segment_id"]
            validation_df = validation_df[~validation_df.segment_id.isin(aux_segments)]
        df_dicts[checkpoint_folder].append(validation_df.copy())
        # Feed the plotting cells, which iterate over gt_list / preds_list.
        gt_list.append(validation_df["label"].tolist())
        preds_list.append(validation_df["score"].tolist())

/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold0/0924-200924/checkpoints/model_best.pth
Loaded pretrained weights for efficientnet-b2
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold1/0924-222622/checkpoints/model_best.pth
Loaded pretrained weights for efficientnet-b2
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold2/0925-004336/checkpoints/model_best.pth
Loaded pretrained weights for efficientnet-b2
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold3/0925-030037/checkpoints/model_best.pth
Loaded pretrained weights for efficientnet-b2
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold4/0925-051739/checkpoi

In [5]:
# Out-of-fold summary: one AUC and one BCE per fold, taken from the first stored
# DataFrame of each fold, printed as mean +- standard deviation across folds.
bce_list, auc_list = list(), list()
for checkpoint_model, val_dfs in df_dicts.items():
    print(checkpoint_model)
    val_df = val_dfs[0]
    fpr, tpr, th = metrics.roc_curve(val_df.label, val_df.score)
    auc_list.append(metrics.auc(fpr, tpr))
    bce_list.append(metrics.log_loss(val_df.label, val_df.score))
for metric_name, metric_values in (("AUC", auc_list), ("BCE", bce_list)):
    print("{}: {}+-{}".format(metric_name, np.mean(metric_values), np.std(metric_values)))

/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold0/0924-200924/checkpoints/
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold1/0924-222622/checkpoints/
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold2/0925-004336/checkpoints/
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold3/0925-030037/checkpoints/
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold4/0925-051739/checkpoints/
AUC: 0.9783557597905593+-0.00761875175147112
BCE: 0.2690891935471226+-0.04848699502373529


# Public test dataset

In [58]:
# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_fold0/0922-202644/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_fold1/0923-175040/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_fold2/0923-200711/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_fold3/0923-222449/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_fold4/0924-004209/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold0/0924-200924/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold1/0924-222622/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold2/0925-004336/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold3/0925-030037/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam_fold4/0925-051739/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v9_adam_fold0/0925-151207/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v9_adam_fold1/0925-185430/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v9_adam_fold2/0925-223701/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v9_adam_fold3/0926-021930/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v9_adam_fold4/0926-060202/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_more_aug_v9_adam_centered0/0927-185118/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_more_aug_v9_adam_centered1/0927-210936/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_more_aug_v9_adam_centered2/0927-232928/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_more_aug_v9_adam_centered3/0928-014604/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_more_aug_v9_adam_centered4/0928-040237/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered0/0928-210739/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered1/0928-232816/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered2/0929-014620/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered3/0929-040430/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered4/0929-062240/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered_blend0/0929-225220/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered_blend1/0930-011037/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered_blend2/0930-032751/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered_blend3/0930-054506/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_less_aug_v9_adam_centered_blend4/0930-080225/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b1_more_aux_more_synth_specaug_simple_aug_v9_adam_centered0/0930-213909/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b1_more_aux_more_synth_specaug_simple_aug_v9_adam_centered1/0930-234106/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b1_more_aux_more_synth_specaug_simple_aug_v9_adam_centered2/1001-014329/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b1_more_aux_more_synth_specaug_simple_aug_v9_adam_centered3/1001-034550/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b1_more_aux_more_synth_specaug_simple_aug_v9_adam_centered4/1001-054814/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_pw_0.9_fold0/1001-195710/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_pw_0.9_fold1/1001-221433/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_pw_0.9_fold2/1002-003156/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_pw_0.9_fold3/1002-024950/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_pw_0.9_fold4/1002-050747/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_focal_fold0/1002-072547/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_focal_fold1/1002-094356/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_focal_fold2/1002-120142/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_focal_fold3/1002-141913/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_focal_fold4/1002-163722/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold0/1003-101910/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold1/1003-175906/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold2/1004-013832/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold3/1004-091304/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold4/1004-164923/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_pruned_more_aux_more_synth_specaug_simple_aug_v9_adam_centered0/1005-005432/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_pruned_more_aux_more_synth_specaug_simple_aug_v9_adam_centered1/1005-021455/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_pruned_more_aux_more_synth_specaug_simple_aug_v9_adam_centered2/1005-033519/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_pruned_more_aux_more_synth_specaug_simple_aug_v9_adam_centered3/1005-045546/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_pruned_more_aux_more_synth_specaug_simple_aug_v9_adam_centered4/1005-061620/checkpoints/"   
# ]

# models = [
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b5_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold0/1005-184411/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b5_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold1/1006-082322/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b5_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold2/1006-220329/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b5_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold3/1007-114115/checkpoints/",
#     "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b5_more_aux_more_synth_specaug_simple_aug_v9_adam_centered_fold4/1008-012514/checkpoints/"   
# ]

# Individual checkpoint files scored on the full public test set. Each entry is
# the path of one .pth file inside a run's checkpoints/ directory.
models = [
    "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/" + run_checkpoint
    for run_checkpoint in (
        "MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0905-131049/checkpoints/model_best.pth",
        # "MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v5_adam_centered_rnd_crop_90/0905-231358/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam/0906-104511/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v8_adam/0906-160951/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0908-144825/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0912-121429/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_b1_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0913-154517/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0912-121429/checkpoints/checkpoint-epoch11-13.pth",
        "MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v9_adam_centered/0920-104137/checkpoints/model_best.pth",
        "MAFAT_replica_aug_eff_timm_b2_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0912-190748/checkpoints/model_best.pth",
        "MAFAT_replica_aug_resnest50_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0918-193440/checkpoints/checkpoint-epoch23.pth",
    )
]

ONLY_BEST_EPOCH = True  # each entry above is already a single checkpoint file
FULL_PUBLIC = True  # False: keep only rows whose source is "public"

In [59]:
# Score every checkpoint file in `models` on the full public test set. Each scored
# DataFrame goes into df_dicts[model_path]; labels and scores are also appended to
# gt_list / preds_list for the plotting cells further down.
preds_list = list()
gt_list = list()
df_dicts = defaultdict(list)

# Each entry of `models` is already one checkpoint file, so it is scored directly.
# Previously model_paths was only assigned when ONLY_BEST_EPOCH was True; otherwise
# it was undefined or left over from an earlier cell.
for model_path in models:
    print(model_path)
    # Setup: reuse the run's config, but point its validation loader at the public test set.
    cfg = load_config(os.path.join(os.path.dirname(model_path), "config.yml"))
    val_loader_args = cfg["val_data_loader"]["args"]
    val_data_dir = val_loader_args["data_dir"]
    # str.replace is a silent no-op when "mafat_val_" is missing, which would score
    # the model on its validation data instead of the public test set.
    if "mafat_val_" not in val_data_dir and "mafat_full_public_test_" not in val_data_dir:
        raise ValueError("Cannot derive the public test data path from {}".format(val_data_dir))
    val_loader_args["data_dir"] = val_data_dir.replace("mafat_val_", "mafat_full_public_test_")
    val_loader_args["csv_dir"] = "/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/mafat_full_public_test_set.csv"
    print(val_loader_args["data_dir"])

    transforms = get_instance(module_aug, "augmentation", cfg)
    if "sampler" in cfg:
        sampler = partial(getattr(module_sampler, cfg["sampler"]["type"]), **cfg["sampler"]["args"])
    else:
        sampler = None
    valid_data_loader = get_instance(module_data, "val_data_loader", cfg, transforms, sampler)
    validation_df = valid_data_loader.dataset.df.copy()
    model = get_instance(module_arch, "arch", cfg)
    # Load to CPU first so checkpoints saved on a GPU load on any host.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model, device = setup_device(model, cfg["target_devices"])
    model.eval()

    # Predict: sigmoid scores plus the first label column, batch by batch.
    preds = list()
    gt = list()
    with torch.no_grad():
        for image_batch, label_batch in valid_data_loader:
            if isinstance(image_batch, list):
                # Multi-input models receive a list of tensors.
                data = [tensor.to(device) for tensor in image_batch]
            else:
                data = image_batch.to(device)
            output = torch.sigmoid(model(data).cpu())
            preds.append(output.numpy())
            # Index the label batch directly. Zipping with image_batch walked the
            # list of inputs instead of the samples for multi-input models.
            gt.append(label_batch[:, 0].cpu().numpy())

    preds = np.vstack(preds).reshape(-1).tolist()
    gt = np.concatenate(gt).reshape(-1).tolist()

    validation_df["label"] = gt
    validation_df["score"] = preds
    if not FULL_PUBLIC:
        validation_df = validation_df[validation_df.source == "public"]
    df_dicts[model_path].append(validation_df.copy())
    # Feed the plotting cells, which iterate over gt_list / preds_list.
    gt_list.append(validation_df["label"].tolist())
    preds_list.append(validation_df["score"].tolist())

/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0905-131049/checkpoints/model_best.pth
/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/mafat_full_public_test_v7_spectrogram.npy
Loaded pretrained weights for efficientnet-b2
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam/0906-104511/checkpoints/model_best.pth
/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/mafat_full_public_test_v8_spectrogram.npy
Loaded pretrained weights for efficientnet-b2
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v8_adam/0906-160951/checkpoints/model_best.pth
/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/mafat_full_public_test_v8_spectrogram.npy
Loaded pretrained weights for efficientnet-b4
/mnt/agarcia_HDD/mafat-radar-challeng

Using cache found in /home/agarcia/.cache/torch/hub/zhanghang1989_ResNeSt_master


In [60]:
# Per-model BCE/AUC on the public test set, their mean +- std, and then the metrics
# of the ensemble that averages all model scores row by row.
if not df_dicts:
    raise ValueError("df_dicts is empty; run the prediction cell first")

bce_list = list()
auc_list = list()
# Per-model results
for checkpoint_model, val_dfs in df_dicts.items():
    print(checkpoint_model)
    val_df = val_dfs[0]
    bce = metrics.log_loss(val_df.label, val_df.score)
    fpr, tpr, th = metrics.roc_curve(val_df.label, val_df.score)
    auc = metrics.auc(fpr, tpr)
    bce_list.append(bce)
    auc_list.append(auc)
    print("BCE: {}".format(bce))
    print("AUC: {}".format(auc))
print()
print("AUC: {}+-{}".format(np.mean(auc_list), np.std(auc_list)))
print("BCE: {}+-{}".format(np.mean(bce_list), np.std(bce_list)))
print()

# Mean-of-predictions ensemble. Build a new frame instead of assigning to
# val_df.score: val_df is the last model's DataFrame stored in df_dicts, so
# overwriting it corrupted that model's results and made this cell non-idempotent.
scores = [val_dfs[0].score.values for val_dfs in df_dicts.values()]
labels_per_model = [val_dfs[0].label.values for val_dfs in df_dicts.values()]
for labels in labels_per_model[1:]:
    # Averaging row by row is only meaningful if every model saw the same rows in the same order.
    assert np.array_equal(labels, labels_per_model[0]), "models were scored on differently ordered rows"
ensemble_df = pd.DataFrame({"label": labels_per_model[0], "score": np.mean(scores, axis=0)})
bce = metrics.log_loss(ensemble_df.label, ensemble_df.score)
fpr, tpr, th = metrics.roc_curve(ensemble_df.label, ensemble_df.score)
auc = metrics.auc(fpr, tpr)
print("BCE: {}".format(bce))
print("AUC: {}".format(auc))


/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0905-131049/checkpoints/model_best.pth
BCE: 0.37340993950797163
AUC: 0.944758784991066
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v8_adam/0906-104511/checkpoints/model_best.pth
BCE: 0.5296185924152703
AUC: 0.9304645622394282
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v8_adam/0906-160951/checkpoints/model_best.pth
BCE: 0.5471652164676182
AUC: 0.921878102044868
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b4_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0908-144825/checkpoints/model_best.pth
BCE: 0.62097043412678
AUC: 0.9162696049235656
/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b3_more_aux_more_synth_specaug_simple_aug_v5_adam_centered/0912-121429/checkp

In [74]:
# Ensemble prediction: element-wise mean of the per-model score arrays.
np.asarray(scores).mean(axis=0)

(284,)

In [None]:
# ROC curves of all models overlaid on a single figure.
fig = plt.figure(figsize=(15, 10))
for model_gt, model_preds, model_name in zip(gt_list, preds_list, models):
    plot_roc_curve(model_gt, model_preds, suffix=model_name)

In [None]:
# FPR/FNR-vs-threshold curves of all models overlaid on a single figure.
fig = plt.figure(figsize=(15, 10))
for model_gt, model_preds, model_name in zip(gt_list, preds_list, models):
    plot_calibration_curves(model_gt, model_preds, suffix=model_name)

In [None]:
# One positive-vs-negative score histogram per model.
for model_gt, model_preds, model_name in zip(gt_list, preds_list, models):
    labelled_scores = list(zip(model_gt, model_preds))
    positive_preds = [p for g, p in labelled_scores if g == 1]
    negative_preds = [p for g, p in labelled_scores if g == 0]
    print(model_name)
    plot_histogram(positive_preds, negative_preds)

In [None]:
# Score histograms per model, stacked by snr_type.
for _, model_preds, model_name in zip(gt_list, preds_list, models):
    aux_validation_df = validation_df.copy()
    aux_validation_df.loc[:, "score"] = model_preds
    if WITHOUT_SYNTH_NOISE:
        # Apply the same source filter as the prediction cell, instead of assuming
        # that the first 309 rows are exactly the non-synthetic ones.
        aux_validation_df = aux_validation_df[aux_validation_df.source != "synth"]
    print(model_name)
    plot_stacked_histogram(aux_validation_df)

In [None]:
# Precision-recall curves of all models overlaid on a single figure.
fig = plt.figure(figsize=(15, 10))
for model_gt, model_preds, model_name in zip(gt_list, preds_list, models):
    precision_recall_curve(model_gt, model_preds, suffix=model_name)

In [None]:
# Number of non-null values per column for each (target_type, snr_type) combination.
validation_df.groupby(by=["target_type", "snr_type"]).agg("count")

In [None]:
# False positives: rows with target_type == 0 that a model scores above 0.4.
for _, model_preds, model_name in zip(gt_list, preds_list, models):
    print(model_name)
    aux_df = validation_df.copy()
    aux_df["score"] = model_preds
    if WITHOUT_SYNTH_NOISE:
        # Same source filter as the prediction cell (no hard-coded 309-row cut).
        aux_df = aux_df[aux_df.source != "synth"]
    display(aux_df[(aux_df.score > 0.4) & (aux_df.target_type == 0)])

In [None]:
# False negatives: rows with target_type == 1 that a model scores below 0.6.
for _, model_preds, model_name in zip(gt_list, preds_list, models):
    print(model_name)
    aux_df = validation_df.copy()
    aux_df["score"] = model_preds
    if WITHOUT_SYNTH_NOISE:
        # Same source filter as the prediction cell (no hard-coded 309-row cut).
        aux_df = aux_df[aux_df.source != "synth"]
    display(aux_df[(aux_df.score < 0.6) & (aux_df.target_type == 1)])

## t-SNE of backbone embeddings

In [None]:
# Extract average-pooled backbone features of the validation set for every model in
# `models` (one array per model in output_list, labels in gt_list). The next cell
# projects them with t-SNE.
preds_list = list()
gt_list = list()
output_list = list()
for model_path in models:
    # Setup: rebuild the validation loader and the network from the run's config.
    cfg = load_config(os.path.join(os.path.dirname(model_path), "config.yml"))
    transforms = get_instance(module_aug, "augmentation", cfg)
    if "sampler" in cfg:
        sampler = partial(getattr(module_sampler, cfg["sampler"]["type"]), **cfg["sampler"]["args"])
    else:
        sampler = None

    # Only the validation loader is needed; the training loader (and its mixer)
    # that used to be built here was never used.
    valid_data_loader = get_instance(module_data, "val_data_loader", cfg, transforms, sampler)
    validation_df = valid_data_loader.dataset.df.copy()
    model = get_instance(module_arch, "arch", cfg)
    # Load to CPU first so checkpoints saved on a GPU load on any host.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model, device = setup_device(model, cfg["target_devices"])
    model.eval()

    # Embed: backbone features -> average pooling -> one flat vector per sample.
    gt = list()
    outputs = list()
    with torch.no_grad():
        for image_batch, label_batch in valid_data_loader:
            data = image_batch.to(device)
            # NOTE(review): assumes model.model exposes the EfficientNet-style
            # extract_features/_avg_pooling API; confirm for other architectures.
            features = model.model.extract_features(data).cpu()
            features = model.model._avg_pooling(features).flatten(1)
            outputs.append(features.numpy())
            gt.append(label_batch[:, 0].cpu().numpy())

    outputs = np.vstack(outputs)
    gt = np.concatenate(gt).reshape(-1).tolist()

    if WITHOUT_SYNTH_NOISE:
        # The plotting cells below slice validation_df the same way, so keep this
        # 309-row cut in step with them.
        outputs = outputs[:309]
        gt = gt[:309]
        assert len(outputs) == 309

    output_list.append(outputs)
    gt_list.append(gt)

In [None]:
from sklearn.manifold import TSNE

# t-SNE is stochastic: fix the seed so the 2-D layout is reproducible
# under Restart & Run All.
TSNE_RANDOM_STATE = 0

# Project each model's validation feature matrix (one row per sample, as
# stacked in the feature-extraction cell) down to 2-D for the plots below.
# tsne_embeddings[k] corresponds to output_list[k], i.e. to models[k].
tsne_embeddings = [
    TSNE(n_components=2, random_state=TSNE_RANDOM_STATE).fit_transform(model_features)
    for model_features in output_list
]

In [None]:
# Validation t-SNE per model: one colour per target_type, each point labelled
# with its track_id.
for model, tsne_embedding in zip(models, tsne_embeddings):
    print(model)
    aux_df = validation_df.copy()
    if WITHOUT_SYNTH_NOISE:
        # Features were cut to the first 309 rows before embedding, so the
        # metadata must be cut identically to keep the masks aligned.
        aux_df = aux_df.iloc[:309, :]
        assert len(aux_df) == 309
    fig, ax = plt.subplots(figsize=(20, 20))
    y = aux_df.target_type
    text_column = "track_id"
    for c, label in zip(['r', 'g'], list(aux_df.target_type.unique())):
        mask = y == label
        xs = tsne_embedding[mask, 0]
        ys = tsne_embedding[mask, 1]
        ax.scatter(xs, ys, c=c, label=label)
        for text, coord_x, coord_y in zip(aux_df.loc[mask, text_column], xs, ys):
            ax.text(coord_x + .03, coord_y + .03, text, fontsize=13)
    ax.legend()
    plt.show()

In [None]:
# Validation t-SNE per model, one colour per snr_type.
for model, tsne_embedding in zip(models, tsne_embeddings):
    print(model)
    fig, ax = plt.subplots(figsize=(10, 10))
    aux_df = validation_df.copy()
    if WITHOUT_SYNTH_NOISE:
        # Match the 309-row truncation applied to the features.
        aux_df = aux_df.iloc[:309, :]
        assert len(aux_df) == 309
    y = aux_df.snr_type
    for c, snr in zip(['r', 'g', 'b'], list(y.unique())):
        in_group = y == snr
        ax.scatter(tsne_embedding[in_group, 0], tsne_embedding[in_group, 1], c=c, label=snr)
    ax.legend()
    plt.show()

In [None]:
# Validation t-SNE per model, one colour per data source.
for model, tsne_embedding in zip(models, tsne_embeddings):
    print(model)
    plt.figure(figsize=(10, 10))
    aux_df = validation_df.copy()
    if WITHOUT_SYNTH_NOISE:
        # The embeddings were built from only the first 309 validation rows;
        # the boolean masks must come from the same rows, otherwise indexing
        # tsne_embedding with a full-length mask raises IndexError.
        aux_df = aux_df.iloc[:309, :]
        assert len(aux_df) == 309
    assert len(aux_df) == len(tsne_embedding), "metadata rows and embeddings are misaligned"
    y = aux_df.source
    colors = ['r', 'g', 'b']
    target_ids = aux_df.source.unique().tolist()
    names = list(aux_df.source.unique())
    for i, c, label in zip(target_ids, colors, names):
        plt.scatter(tsne_embedding[y == i, 0], tsne_embedding[y == i, 1], c=c, label=label)
    plt.legend()
    plt.show()

## Test dataset

In [None]:
# Score the public test set with every model. preds_list[k] is a flat list of
# sigmoid scores from models[k], in data-loader order; test_df keeps the row
# metadata of the (last) test loader for the cells below.
preds_list = list()
for model_path in models:
    print(model_path)
    cfg_path = os.path.join(os.path.dirname(model_path), "config.yml")
    cfg = load_config(cfg_path)
    # Swap the training loader for the unlabelled test loader.
    cfg["data_loader"] = {
        "type": "MAFATTestDataLoader", 
        "args": {
            "batch_size": 64, 
            "data_dir": "/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/mafat_test_v7_spectrogram.npy",
            "csv_dir": "/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/MAFAT RADAR Challenge - Public Test Set V1.csv",
            "nworkers": 2
        }
    }
    transforms = get_instance(module_aug, "augmentation", cfg)
    data_loader = get_instance(module_data, "data_loader", cfg, transforms)
    model = get_instance(module_arch, "arch", cfg)
    # load_state_dict copies weights into the model's own parameters, so loading
    # the checkpoint on CPU is safe (and works on CPU-only machines); device
    # placement is left to setup_device below.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model, device = setup_device(model, cfg["target_devices"])
    model.eval()

    test_df = data_loader.dataset.df.copy()

    preds = list()
    with torch.no_grad():
        # The test loader yields image batches only (no labels).
        for image_batch in data_loader:
            data = image_batch.to(device)
            output = torch.sigmoid(model(data).cpu())
            preds.append(output.numpy())

    preds = np.vstack(preds).reshape(-1).tolist()
    preds_list.append(preds)

In [None]:
# For every pair of models: Pearson correlation of their test scores and a
# log-log scatter (one point per segment, annotated with its segment_id).
# Points on the diagonal are segments on which both models agree.
# Pairs are visited in itertools.combinations order (deterministic).
# NOTE(review): positional pairing assumes the test loader does not shuffle,
# so preds_list order matches test_df row order.
segment_ids = test_df["segment_id"].to_numpy()
for first, second in itertools.combinations(range(len(models)), 2):
    _, ax = plt.subplots(figsize=(10, 10))
    print(models[first])
    print(models[second])
    print("\t PC: ", pearsonr(preds_list[first], preds_list[second]))
    ax.plot([0.0, 1.0], [0.0, 1.0])
    ax.loglog(preds_list[first], preds_list[second], 'bo')
    # Annotate positionally (not by index label) and with this pair's scores.
    for segment_id, score_x, score_y in zip(segment_ids, preds_list[first], preds_list[second]):
        ax.annotate(segment_id, (score_x, score_y))
    plt.show()

In [None]:
# For every pair of models: the summed and histogrammed per-segment score
# difference (first model minus second). Pairs are visited in deterministic
# itertools.combinations order.
for first, second in itertools.combinations(range(len(models)), 2):
    fig = plt.figure(figsize=(10,8))
    ax = fig.add_subplot(111)
    print(models[first])
    print(models[second])
    score_diff = np.array(preds_list[first]) - np.array(preds_list[second])
    print(np.sum(score_diff))
    ax.hist(score_diff)
    ax.set_xlabel("score difference (first - second)")
    ax.set_ylabel("number of segments")
    plt.show()

In [None]:
# The five test segments on which models 0 and 1 disagree most,
# largest absolute score gap first.
abs_gap = np.abs(np.array(preds_list[0]) - np.array(preds_list[1]))
max_diff_indexes = np.argsort(abs_gap)[::-1]
for max_diff_idx in max_diff_indexes[:5]:
    for scores in (preds_list[0], preds_list[1]):
        print(scores[max_diff_idx])
    print(test_df.iloc[max_diff_idx, :])

In [None]:
# All absolute score gaps between models 0 and 1, largest first.
sorted(np.abs(np.subtract(preds_list[0], preds_list[1])), reverse=True)

In [None]:
# The five test segments on which models 0 and 2 disagree most,
# largest absolute score gap first.
abs_gap = np.abs(np.array(preds_list[0]) - np.array(preds_list[2]))
max_diff_indexes = np.argsort(abs_gap)[::-1]
for max_diff_idx in max_diff_indexes[:5]:
    for scores in (preds_list[0], preds_list[2]):
        print(scores[max_diff_idx])
    print(test_df.iloc[max_diff_idx, :])

In [None]:
# The five test segments on which models 1 and 2 disagree most,
# largest absolute score gap first. Prints the scores of the two models that
# were actually compared (previously model 0's score was printed by mistake).
max_diff_indexes = np.argsort(np.abs(np.array(preds_list[1]) - np.array(preds_list[2])))
max_diff_indexes = max_diff_indexes[::-1]
for max_diff_idx in max_diff_indexes[:5]:
    print(preds_list[1][max_diff_idx])
    print(preds_list[2][max_diff_idx])
    print(test_df.iloc[max_diff_idx, :])

In [None]:
# Distribution of test-set scores for each model (30 bins, linear scale).
for model, preds in zip(models, preds_list):
    print(model)
    fig, ax1 = plt.subplots(figsize=(15, 10))
    ax1.hist(preds, bins=30, color='blue', log=False, cumulative=False)
    plt.show()

In [None]:
# Side-by-side ranking of test segments per model. For model k,
# concat_df["score{k}"] holds that model's scores in ascending order, and
# concat_df["segment_id{k}"] holds the row position of the segment each score
# belongs to. df_list[k] keeps the full sorted frame.
df_list = []
concat_df = pd.DataFrame()
for idx, (model, preds) in enumerate(zip(models, preds_list)):
    aux_df = test_df.copy()
    model_alias = model.split(os.sep)[-4]
    print(model_alias)
    aux_df["score"] = preds
    # NOTE(review): this overwrites the CSV segment_id with the row position
    # (0..n-1), i.e. the index into preds_list[k]; confirm this is intended.
    aux_df["segment_id"] = list(range(aux_df.shape[0]))
    aux_df = aux_df.sort_values("score")
    # Assign plain arrays. Assigning the Series would align on the index and
    # undo the sort for every model after the first, so score{k} would no
    # longer match segment_id{k} row by row.
    concat_df["score" + str(idx)] = aux_df["score"].to_numpy()
    concat_df["segment_id" + str(idx)] = aux_df["segment_id"].to_numpy()
    df_list.append(aux_df.copy())
concat_df

## T-SNE on the test set

In [None]:
# Inspect the config of the last model loaded by the test-set scoring loop
# (debug display; cfg is reassigned for every model in that loop).
cfg

In [None]:
# Extract backbone features for the public test set with every model, as input
# to the t-SNE below. output_list[k] is a 2-D array (one row per test sample)
# for models[k].
output_list = list()
for model_path in models:
    # Setup
    cfg = load_config(os.path.join(os.path.dirname(model_path), "config.yml"))
    transforms = get_instance(module_aug, "augmentation", cfg)

    # Adapt the training loader config to the test loader. pop() tolerates
    # configs that never defined these keys (del would raise KeyError).
    cfg["data_loader"]["args"].pop("shuffle", None)
    cfg["data_loader"]["args"].pop("use_metadata", None)
    cfg["data_loader"]["type"] = "MAFATTestDataLoader"
    cfg["data_loader"]["args"]["data_dir"] = "/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/mafat_test_v7_spectrogram.npy"
    cfg["data_loader"]["args"]["csv_dir"] = "/home/agarcia/repos/mafat-radar-challenge/mafat_radar_challenge/data/MAFAT RADAR Challenge - Public Test Set V1.csv"
    cfg["data_loader"]["args"]["train"] = False

    data_loader = get_instance(module_data, "data_loader", cfg, transforms)
    model = get_instance(module_arch, "arch", cfg)
    # Load on CPU; load_state_dict copies into the model's parameters and
    # setup_device handles device placement afterwards.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model, device = setup_device(model, cfg["target_devices"])
    model.eval()

    # Predict
    outputs = list()
    with torch.no_grad():
        for image_batch in data_loader:
            data = image_batch.to(device)
            # Feature maps -> model.model._avg_pooling -> flattened to one
            # row per sample.
            output = model.model.extract_features(data).cpu()
            output = model.model._avg_pooling(output)
            outputs.append(output.flatten(1).numpy())

    outputs = np.vstack(outputs)
    output_list.append(outputs)

In [None]:
from sklearn.manifold import TSNE

# t-SNE is stochastic: fix the seed so the 2-D layout is reproducible
# under Restart & Run All.
TSNE_RANDOM_STATE = 0

# Project each model's test-set feature matrix (one row per sample) to 2-D.
# tsne_embeddings[k] corresponds to output_list[k], i.e. to models[k].
tsne_embeddings = [
    TSNE(n_components=2, random_state=TSNE_RANDOM_STATE).fit_transform(model_features)
    for model_features in output_list
]

In [None]:
# Row metadata of the last test loader built in the feature-extraction cell.
# NOTE: this is a reference to the dataset's DataFrame, not a copy, so
# mutating test_df would also mutate data_loader.dataset.df.
test_df = data_loader.dataset.df

In [None]:
# Test-set t-SNE per model, coloured by that model's predicted class
# (score > 0.5 -> "1", otherwise "0") and labelled with segment_id.
for model, tsne_embedding, outputs in zip(models, tsne_embeddings, preds_list):
    print(model)
    fig, ax = plt.subplots(figsize=(10, 10))
    y = (np.asarray(outputs) > 0.5).astype(int)
    text_column = "segment_id"

    for i, c in enumerate(['r', 'g']):
        label = str(i)
        text_values = test_df.loc[y == i, text_column]
        xs = tsne_embedding[y == i, 0]
        ys = tsne_embedding[y == i, 1]
        ax.scatter(xs, ys, c=c, label=label)
        for text, coord_x, coord_y in zip(text_values, xs, ys):
            ax.text(coord_x + .03, coord_y + .03, text, fontsize=13)
    ax.legend()
    plt.show()

In [None]:
# Raw test-set scores: one list of sigmoid scores per entry of `models`
# (displays every value; prefer a slice for large test sets).
preds_list