In [None]:
from pathlib import Path
import sys
import time
from typing import Dict, List

import numpy as np
import pandas as pd
import torch as T
import torch.nn.functional as F
import torch.utils.data as tdata

sys.path.append('..')
from agent.config import load_config, override_config
from agent.state.model import Supervisor
from agent.state.data import Dataset

# --- Notebook configuration -------------------------------------------------
# Number of passes over the dataset and the evaluation batch size.
N_EPOCHS = 3
BATCH_SIZE = 1024

# Input data location (relative to the notebook directory).
DATA_DIR = Path('../data')

# Training run to evaluate; the config and the encoder checkpoint both live
# inside the run's log directory.
LOG_DIR = Path('../runs/2019-09-05T155134/')
CONFIG_FILE = LOG_DIR / 'agent_train_conf.yaml'
CKPT_FILE = LOG_DIR / 'encoder_ckpt/epoch_495.tar'

In [None]:
# Inference only: switch autograd off globally for the rest of the notebook.
T.set_grad_enabled(False)

# Choose the device BEFORE loading the checkpoint so the stored tensors can be
# remapped onto it. Without map_location, a checkpoint saved from a CUDA run
# cannot be deserialized on a CPU-only machine.
if T.cuda.is_available() and T.backends.cudnn.is_available():
    device = T.device('cuda')
else:
    device = T.device('cpu')

conf = override_config(load_config(CONFIG_FILE), state={'num_workers': 4})
# NOTE: T.load unpickles the file; only load checkpoints from trusted sources.
ckpt = T.load(CKPT_FILE, map_location=device)

model = Supervisor(conf.state, raw=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
model.to(device)

# Summary of the loaded configuration and model.
print(conf)
print('')
print(model)
print('')
print('total parameters:  ', sum(p.numel() for p in model.parameters()))
print('encoder parameters:', sum(p.numel() for p in model.encoder.parameters()))
print('convnet parameters:', sum(p.numel() for p in model.encoder._conv_net.parameters()))
print('')
print('using device:', device)

In [None]:
# Build the evaluation dataset and a batched loader over it.
dataset = Dataset(DATA_DIR, conf)

# One entry of 1.0 per target is handed to the collate function
# (presumably uniform per-target weights -- confirm against Dataset.CollateFn).
collate_fn = Dataset.CollateFn(np.ones(conf.state.num_target))

# pin_memory=True asks the loader to place batches in page-locked host memory.
loader = tdata.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=conf.state.num_workers,
    pin_memory=True,
)

In [None]:
from sklearn.metrics import confusion_matrix

def get_decile_counts(class_out: np.ndarray) -> np.ndarray:
    """Count how many values of ``class_out`` fall into each decile of [0, 1].

    Bin ``i`` covers the half-open interval ``[i/10, (i+1)/10)``, except the
    last bin, whose upper edge is widened to 1.01 so that a value of exactly
    1.0 is counted. Values outside ``[0, 1.01)`` are not counted at all.

    Args:
        class_out: Array of values, expected to lie in [0, 1].

    Returns:
        Integer array of shape ``(10,)`` holding the per-decile counts.
    """
    # Start from zeros: an uninitialised array would leak garbage into any
    # bin that is not explicitly written.
    counts = np.zeros((10,), dtype=np.int64)
    deciles = np.linspace(0, 1, 11)
    # Include 1 in last decile
    deciles[-1] += 0.01
    # Pair every lower edge with its upper edge (10 pairs in total).
    for i, (start, stop) in enumerate(zip(deciles[:-1], deciles[1:])):
        counts[i] = np.count_nonzero(np.logical_and(start <= class_out, class_out < stop))
    return counts

def accuracy(cm: np.ndarray) -> float:
    """Return the accuracy for a 2x2 confusion matrix.

    Computed as ``(tn + tp) / (tn + fn + fp + tp)``, where ``cm[0, 0]`` is tn
    and ``cm[1, 1]`` is tp.

    Args:
        cm: 2x2 confusion matrix of counts.

    Returns:
        The fraction of correct predictions, or 0 when the matrix is empty
        (same convention as the other metric helpers, which return 0 on a
        zero denominator instead of producing NaN).
    """
    total = np.sum(cm)
    if total == 0:
        return 0
    # (tn + tp) / (tn + fn + fp + tp)
    return (cm[0, 0] + cm[1, 1]) / total

def precision(cm: np.ndarray) -> float:
    """Return tp / (tp + fp) for a 2x2 confusion matrix.

    ``cm[1, 1]`` is read as the true positives and ``cm[0, 1]`` as the false
    positives. When there are no positive predictions the result is 0.
    """
    true_pos = cm[1, 1]
    predicted_pos = true_pos + cm[0, 1]
    return true_pos / predicted_pos if predicted_pos != 0 else 0

def recall(cm: np.ndarray) -> float:
    """Return tp / (tp + fn) for a 2x2 confusion matrix.

    ``cm[1, 1]`` is read as the true positives and ``cm[1, 0]`` as the false
    negatives. When there are no actual positives the result is 0.
    """
    true_pos = cm[1, 1]
    actual_pos = true_pos + cm[1, 0]
    return true_pos / actual_pos if actual_pos != 0 else 0

def f1_score(cm: np.ndarray) -> float:
    """Return the harmonic mean of precision and recall for ``cm``.

    Yields 0 when precision and recall sum to zero.
    """
    prec, rec = precision(cm), recall(cm)
    pr_sum = prec + rec
    if pr_sum == 0:
        return 0
    return 2 * (prec * rec) / pr_sum

# Running totals over all batches and epochs:
#   conf_matrices[j] -- 2x2 confusion matrix for target j
#   out_deciles[j]   -- histogram of raw model outputs for target j, by decile
conf_matrices = np.zeros((conf.state.num_target, 2, 2), dtype=np.int64)
out_deciles = np.zeros((conf.state.num_target, 10), dtype=np.int64)
for epoch in range(1, N_EPOCHS + 1):
    epoch_start = time.perf_counter()
    # no_grad keeps this cell correct on its own: .numpy() below fails on
    # tensors that require grad if the global grad switch was re-enabled.
    with T.no_grad():
        for batch in loader:
            # NOTE(review): the return value is discarded, so this relies on
            # batch.to(...) moving the batch in place -- confirm against the
            # batch type produced by Dataset.CollateFn.
            batch.to(device, non_blocking=True)
            out = model(batch)
            out_np = out.cpu().numpy()
            # Threshold outputs at 0.5 to get hard 0/1 predictions.
            # (np.int was removed in NumPy 1.24; use an explicit width.)
            pred_np = (out_np >= 0.5).astype(np.int64)
            target_np = batch.target.cpu().numpy()

            for j in range(conf.state.num_target):
                # labels=[0, 1] forces a 2x2 result even if a batch contains
                # only one of the two classes.
                conf_matrices[j] += confusion_matrix(target_np[:, j], pred_np[:, j], labels=[0, 1])
                out_deciles[j] += get_decile_counts(out_np[:, j])

    d_time = round((time.perf_counter() - epoch_start) / 60, 2)
    print('Epoch {} complete in {} m'.format(epoch, d_time))

In [None]:
# One row of summary metrics per target, derived from that target's
# accumulated confusion matrix.
metrics = [
    [f1_score(cm), precision(cm), recall(cm), accuracy(cm)]
    for cm in conf_matrices
]

df = pd.DataFrame(
    metrics,
    index=Dataset.TARGET_COLS,
    columns=['f1', 'precision', 'recall', 'accuracy'],
)
df

In [None]:
# Median F1 over the targets whose F1 is strictly positive.
df.loc[df['f1'] > 0, 'f1'].median()