In [None]:
from datasets.rg_data import AstroDataLoaders
from pathlib import Path
from utils.logging import Logger
import torch
import torch.nn as nn
import time
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from models import tiramisu
from scipy import ndimage
import utils.training as train_utils
import numpy as np

## Constants

In [None]:
# Checkpoint file name, resolved relative to WEIGHTS_PATH when the weights are loaded.
resume = 'latest.pth'
# Dataset root handed to the data loader factory.
DATA_PATH = Path('data_reduced')
# Directory that contains the checkpoint named by `resume`.
WEIGHTS_PATH = Path('weights')
results_dir = 'results'
# File name passed to the project Logger.
log_file = 'log.txt'
batch_size = 20
# Number of output classes requested from the segmentation network.
n_classes = 4
# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable so
# the notebook can still run end-to-end on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Data Loading

In [None]:
# Build the project data-loader factory and obtain the loader for the test set.
data_loader = AstroDataLoaders(DATA_PATH, batch_size)
test_loader = data_loader.get_test_loader()
# `device` is a torch.device, which does not compare equal to a plain string,
# so the device *type* has to be inspected for the CUDA seed to be applied.
if device.type == 'cuda':
    torch.cuda.manual_seed(0)
# Sanity check: fetch one batch; as the last expression it is shown as the cell output.
next(iter(test_loader))

## Model Definition

In [None]:
# Place the network and the class-weighted loss on the configured `device`
# (instead of the implicit current CUDA device) so they always live together.
model = tiramisu.FCDenseNet67(n_classes=n_classes).to(device)
criterion = nn.NLLLoss(weight=data_loader.class_weight.to(device)).to(device)
# Restore the checkpoint named by `resume` from the weights directory.
train_utils.load_weights(model, str(WEIGHTS_PATH / resume))
# Display the model as the cell output.
model

## Logger instance

In [None]:
# TensorBoard writer created with no arguments, i.e. with SummaryWriter's default settings.
writer = SummaryWriter()
# Project logger, given the log file name, the test dataset's class list and the writer above.
logger = Logger(log_file, test_loader.dataset.classes, writer)

In [None]:
# Names of the raw per-batch quantities the derived metrics are built from.
metric_values = [
    'union',
    'tp', 'fp', 'fn',
    'obj_tp', 'obj_fp', 'obj_fn',
]

# Names of the derived metrics tracked for every class.
metric_names = [
    'accuracy', 'iou', 'precision', 'recall', 'dice',
    'obj_precision', 'obj_recall',
]

# Class names; the position in the list is the class id used by the metric
# loop, where index 0 ('Void') is treated as background and skipped.
classes = [
    'Void',
    'Sidelobe',
    'Source',
    'Galaxy',
]

## Helper functions

In [None]:
def get_predictions(output_batch):
    """Turn a batch of per-class scores into a batch of predicted class ids.

    Args:
        output_batch: 4-D tensor; dimension 1 is reduced with ``max`` so it
            is taken to index the classes.

    Returns:
        CPU tensor of shape ``(batch, height, width)`` holding, for every
        position, the index along dimension 1 with the highest score.
    """
    batch, _, height, width = output_batch.size()
    # Only the arg-max indices matter; the maximal scores are discarded.
    _, class_ids = output_batch.data.cpu().max(1)
    return class_ids.view(batch, height, width)

## Metrics functions

In [None]:
def compute_union(preds, targets, class_id):
    """Count the positions where either tensor carries ``class_id``.

    Args:
        preds: tensor of predicted class ids.
        targets: tensor of ground-truth class ids, same shape as ``preds``.
        class_id: class whose union is measured.

    Returns:
        The size of the union as a Python float.
    """
    predicted = preds == class_id
    expected = targets == class_id
    in_either = torch.logical_or(predicted, expected)
    # Summing a 1./0. mask keeps the floating-point return type.
    return torch.where(in_either, 1., 0.).sum().item()

In [None]:
def compute_confusion_matrix(preds, targets, class_id):
    """Pixel-wise confusion counts of one class against the ground truth.

    Args:
        preds: tensor of predicted class ids.
        targets: tensor of ground-truth class ids, same shape as ``preds``.
        class_id: class treated as "positive"; every other id is "negative".

    Returns:
        Tuple ``(tp, fp, fn, tn)`` of Python ints.

    Raises:
        AssertionError: if ``preds`` and ``targets`` differ in shape.
    """
    assert preds.size() == targets.size()
    predicted = preds == class_id  # isolates the class of interest
    expected = targets == class_id

    tp = (predicted & expected).sum().item()
    fp = (predicted & ~expected).sum().item()
    fn = (~predicted & expected).sum().item()
    # A true negative is negative in BOTH masks. The former expression counted
    # every position that was not a false negative (tp + fp + tn), which
    # inflated the accuracy derived from these counts.
    tn = (~predicted & ~expected).sum().item()

    return tp, fp, fn, tn

## Metrics for Object Detection comparison

In [None]:
def compute_object_confusion_matrix(preds, targets, class_id, threshold=0.5):
    """Object-level confusion counts of one class over a batch.

    For each (prediction, target) pair the masks of ``class_id`` are split
    into connected components with ``scipy.ndimage.label``. A predicted
    component is a true positive when its best IoU with any target component
    reaches ``threshold``, otherwise a false positive.

    Args:
        preds: batch of predicted class-id maps (iterated along dim 0).
        targets: batch of ground-truth class-id maps, aligned with ``preds``.
        class_id: class whose objects are evaluated.
        threshold: minimum IoU for a predicted object to count as a match.

    Returns:
        Tuple ``(tp, fp, fn)`` of Python ints accumulated over the batch.
    """
    tp = 0
    fp = 0
    fn = 0

    for pred, target in zip(preds, targets):
        # Binary masks of the class of interest, handed to SciPy as NumPy arrays.
        pred_mask = (pred == class_id).cpu().numpy()
        gt_mask = (target == class_id).cpu().numpy()
        pred_objects, nr_pred_objects = ndimage.label(pred_mask)
        target_objects, nr_target_objects = ndimage.label(gt_mask)

        # NOTE(review): predictions on samples without any target object are
        # not counted as false positives (unchanged behavior) -- confirm that
        # this is intended.
        if nr_target_objects != 0:
            # ndimage.label numbers the components 1..N and reserves 0 for the
            # background, so iterate over 1..N. range(N) would score the
            # background as an object and skip the last real one.
            for pred_label in range(1, nr_pred_objects + 1):
                current_obj_pred = torch.from_numpy(pred_objects == pred_label)
                obj_iou = get_obj_iou(nr_target_objects, target_objects, current_obj_pred)
                if obj_iou >= threshold:
                    tp += 1
                else:
                    fp += 1

        # Missed objects are estimated from the surplus of target objects over
        # predicted objects (unchanged heuristic).
        if nr_target_objects > nr_pred_objects:
            fn += (nr_target_objects - nr_pred_objects)

    return tp, fp, fn

In [None]:
def get_obj_iou(nr_target_objects, target_objects, current_obj_pred):
    """Best IoU between one predicted object and any target object.

    Args:
        nr_target_objects: number of labelled components in ``target_objects``.
        target_objects: NumPy label map in which 0 is background and the
            components are numbered ``1..nr_target_objects`` (the layout
            returned by ``scipy.ndimage.label``).
        current_obj_pred: tensor mask of a single predicted object; non-zero
            entries belong to the object.

    Returns:
        The highest IoU found as a float, or 0.0 when there is no target
        object at all.
    """
    pred_mask = torch.as_tensor(current_obj_pred).bool()
    best_iou = 0.0
    # Labels start at 1; label 0 is the background and must not be treated as
    # an object (range(N) did exactly that and also skipped the last object).
    for target_label in range(1, nr_target_objects + 1):
        target_mask = torch.from_numpy(target_objects == target_label)
        union = torch.logical_or(pred_mask, target_mask).sum().item()
        if union == 0:
            # Both masks are empty; there is nothing to compare.
            continue
        intersection = torch.logical_and(pred_mask, target_mask).sum().item()
        best_iou = max(best_iou, intersection / union)
    return best_iou

## Aggregate metrics for each batch

In [None]:
def division(x, y):
    """Return ``x / y``, or 0 when ``y`` is zero (or otherwise falsy)."""
    return x / y if y else 0


def compute_batch_metrics(union, tp, fp, fn, tn):
    """Derive the pixel-level metrics of one class from its confusion counts.

    Args:
        union: number of positions predicted or labelled as the class.
        tp, fp, fn, tn: confusion-matrix counts for the class.

    Returns:
        Tuple ``(accuracy, iou, precision, recall, dice)``. A metric whose
        denominator is zero is reported as 0.
    """
    accuracy = division(tp + tn, tp + fp + tn + fn)
    iou = division(tp, union)
    precision = division(tp, tp + fp)
    recall = division(tp, tp + fn)
    # Dice / F1 = 2*TP / (2*TP + FP + FN). The former tp / (tp + fp + fn) is
    # the Jaccard index, which is why it used to coincide with the IoU.
    dice = division(2 * tp, 2 * tp + fp + fn)

    return accuracy, iou, precision, recall, dice

def compute_batch_obj_metrics(obj_tp, obj_fp, obj_fn):
    """Object-level precision and recall from object-level counts.

    Args:
        obj_tp: number of matched predicted objects.
        obj_fp: number of unmatched predicted objects.
        obj_fn: number of missed target objects.

    Returns:
        Tuple ``(obj_precision, obj_recall)``; each value is 0 when its
        denominator is zero.
    """
    predicted_objects = obj_tp + obj_fp
    actual_objects = obj_tp + obj_fn
    return division(obj_tp, predicted_objects), division(obj_tp, actual_objects)


## Run Test

In [None]:
model.eval()
test_loss = 0
# Final per-class metrics; every entry defaults to 0.
test_metrics = {class_name: {metric_name: 0. for metric_name in metric_names} for class_name in classes}
# Per-class lists holding one value per batch in which the class occurred.
batch_metrics = {class_name: {metric_name: [] for metric_name in metric_names} for class_name in classes}

since = time.time()


for data, target in tqdm(test_loader, desc="Testing"):
    with torch.no_grad():
        data = data.to(device)
        targets = target.to(device)
        output = model(data)
        test_loss += criterion(output, targets).item()
        preds = get_predictions(output)
        # The metric helpers run on CPU tensors; copy the targets once per
        # batch instead of once per helper call.
        targets_cpu = targets.data.cpu()

        # Skipping Background class in metric computation (class ids start at 1)
        for class_id, class_name in enumerate(classes[1:], start=1):
            union = compute_union(preds, targets_cpu, class_id)
            if union == 0:
                # There is no object with that class, skipping...
                continue

            tp, fp, fn, tn = compute_confusion_matrix(preds, targets_cpu, class_id)
            obj_tp, obj_fp, obj_fn = compute_object_confusion_matrix(preds, targets_cpu, class_id)

            accuracy, iou, precision, recall, dice = compute_batch_metrics(union, tp, fp, fn, tn)
            obj_precision, obj_recall = compute_batch_obj_metrics(obj_tp, obj_fp, obj_fn)

            batch_values = {
                'accuracy': accuracy,
                'iou': iou,
                'precision': precision,
                'recall': recall,
                'dice': dice,
                'obj_precision': obj_precision,
                'obj_recall': obj_recall,
            }
            for metric_name, value in batch_values.items():
                batch_metrics[class_name][metric_name].append(value)

test_loss /= len(test_loader)

time_elapsed = time.time() - since


for class_name in classes[1:]:
    for metric_name in metric_names:
        values = batch_metrics[class_name][metric_name]
        # A class that never occurred keeps its default of 0 instead of
        # becoming NaN (np.mean of an empty list) with a RuntimeWarning.
        if values:
            test_metrics[class_name][metric_name] = np.mean(values)

In [None]:
# Hand the aggregated loss, per-class metrics and elapsed time to the project logger.
# NOTE(review): 'Test' and 1 look like a phase label and an epoch index -- confirm against Logger.log_metrics.
logger.log_metrics('Test', 1, test_loss, test_metrics, time_elapsed)
# Plot the same metrics under the 'test' tag (presumably to Weights & Biases, judging by the method name -- verify).
logger.wandb_plot_metrics(test_metrics, 'test')
# Show sample predictions of the model on the test loader.
# NOTE(review): the meaning of the trailing arguments (1, 100, None) is not visible here -- check train_utils.view_sample_predictions.
train_utils.view_sample_predictions(model, test_loader, 1, 100, None)