Normalization parameters:

for real images: mean: 17.181818181818183 ; std: 15.545454545454545

for semi-synthetic: mean: 61.66156689772726 ; std: 36.4139976810122

for synthetic lines: mean: 51.227415953125 ; std: 31.554029192471624

real images, new annotations: mean: 595.7211111111111 ; std: 1159.925

In [1]:
from pathlib import Path

import os
import pylab as plt
from skimage import io
import numpy as np
import json
import pandas as pd
import torch
import argparse
import albumentations as A
from tracenet.utils.loader import get_loaders
from tracenet.utils.plot import plot_keypoints, plot_traces, show_imgs
from tracenet.models.detr import DETR
from tracenet.losses.accuracy_metric import Metric

In [2]:
def show_imgs(imgs, s=4, titles=None):
    if titles is None:
        titles = [''] * len(imgs)
    fig, ax = plt.subplots(1, len(imgs), figsize=(len(imgs) * s, s))
    for i, im in enumerate(imgs):
        plt.sca(ax[i])
        io.imshow(im)
        plt.sca(ax[i])
        plt.title(titles[i])


In [3]:
path = '/research/sharedresources/cbi/data_exchange/hangrp/Development/mt_detection/'


data_dir = path + 'data_synth/lines_2D_400_high_contrast'
img_dir = 'img'
gt_dir = 'csv'
mean_std = (21.38, 23.5)
model_dir = './models/tracenet_01_lines_2p'


# data_dir = path + 'data_synth/lines_2D_400_high_contrast'
# img_dir = 'img'
# gt_dir = 'csv'
# mean_std = (21.38, 23.5)
# model_dir = './models/tracenet_02_lines_3p'

# data_dir = path + 'data_synth/filaments_2D_high_contrast'
# img_dir = 'img'
# gt_dir = 'csv'
# mean_std = (21.42, 20.46)
# model_dir = './models/tracenet_03_filaments_3p'


# data_dir = path + 'data_synth/filaments_2D_high_contrast'
# img_dir = 'img'
# gt_dir = 'csv'
# mean_std = (21.42, 20.46)
# model_dir = './models/tracenet_04_filaments_5p'


# data_dir = path + 'data_synth/filaments_2D_high_contrast'
# img_dir = 'img'
# gt_dir = 'csv'
# mean_std = (21.42, 20.46)
# model_dir = './models/tracenet_05_filaments_7p'


# data_dir = path + 'data_synth/filament_2D_low_contrast'
# img_dir = 'img'
# gt_dir = 'csv'
# mean_std = (53.9, 32.33)


# data_dir = path + 'data/training_data'
# img_dir = 'img'
# gt_dir = 'traces'
# mean_std = (595.72, 1159.925)

# data_dir = path + 'data/training_data'
# img_dir = 'img_synth'
# mean_std = (57.3, 32.7)
# gt_dir = 'traces'
train_dir = 'train'
val_dir = 'val'
bs = 2

## Prediction

In [4]:
models = os.listdir(model_dir)
models

['fresh-star-9',
 'crisp-resonance-12',
 'swept-bird-7',
 'playful-meadow-9',
 'spring-violet-15',
 'driven-moon-16',
 'prime-elevator-14',
 'misty-moon-8',
 'usual-yogurt-13',
 'devoted-darkness-11']

In [None]:
outputs = []
imgs_gt = []
metrics = []
for model in models:
    output = []
    gt = []
    model_path = rf'{model_dir}/{model}/best_model.pth'

    # load the model
    with open(os.path.join(os.path.dirname(model_path), 'config.json')) as f:
        config = json.load(f)
        
    n_points=config['n_points']
    net = DETR(n_points=n_points, n_classes=1).cuda()
    net.load_state_dict(torch.load(model_path))
    net.eval();
    train_dl, val_dl = get_loaders(data_dir, train_dir=train_dir, val_dir=val_dir, 
                               img_dir=img_dir, gt_dir=gt_dir, batch_size=1, 
                               mean_std=mean_std, shuffle=False, n_points=config['n_points'])
    metric = Metric()
    
    # predict
    for i, (imgs, _, targets) in enumerate(val_dl):
        gt.append(plot_traces(imgs[0][0], targets['trace'][0], return_image=True, n_points=config['n_points']))
        with torch.no_grad():
            out = net(imgs.cuda())
            for key in ['trace', 'trace_class']:
                targets[key] = [t.cuda() for t in targets[key]]
            metric(out, targets)
            probas = out['pred_logits'].softmax(-1)[0, :, 1:].cpu()
            keep = probas.max(-1).values > 0.7
            output.append(plot_traces(imgs[0][0], out['pred_traces'][0, keep].cpu(), 
                                      return_image=True, n_points=config['n_points']))
    outputs.append(output)
    imgs_gt.append(gt)
    metric.aggregate()
    metrics.append(metric)

In [6]:
config['n_points']

2

In [7]:
df = pd.DataFrame()
for metric, model in zip(metrics, models):
    metric.mean['model'] = model
    df = pd.concat([df, pd.Series(metric.mean).to_frame().transpose()], ignore_index=True)

In [8]:
df

Unnamed: 0,cardinality error,relative cardinality error,Precision,Recall,F1 Score,end distance,model
0,1.066667,0.09338,0.925576,0.993888,0.9561,0.03849,fresh-star-9
1,0.78,0.069322,0.939362,0.992508,0.963479,0.01904,crisp-resonance-12
2,0.64,0.055392,0.956722,0.989405,0.970834,0.015154,swept-bird-7
3,0.853333,0.072986,0.936584,0.996578,0.964042,0.022843,playful-meadow-9
4,6.08,0.974448,0.571291,0.994851,0.70745,0.089473,spring-violet-15
5,6.58,0.973005,0.555121,0.994157,0.697876,0.136416,driven-moon-16
6,3.4,0.458805,0.72704,0.98416,0.823335,0.109459,prime-elevator-14
7,0.786667,0.066228,0.948239,0.993503,0.968535,0.01679,misty-moon-8
8,2.28,0.346946,0.786052,0.982665,0.861059,0.194221,usual-yogurt-13
9,0.88,0.068098,0.943109,0.996166,0.967252,0.016974,devoted-darkness-11


In [9]:
ind = 2

In [10]:
len(metrics[ind].buffer['cardinality error'])

150

In [11]:
sum(metrics[ind].buffer['cardinality error'] > 0)

tensor(62, device='cuda:0')

In [12]:
sum(metrics[ind].buffer['relative cardinality error'] > 0.1)

tensor(35, device='cuda:0')

In [13]:
sum(metrics[ind].buffer['end distance'] > 5 / 512)

tensor(81)

In [None]:
# plot predictions
for i, (imgs, _, _) in enumerate(val_dl):
    if metrics[ind].buffer['relative cardinality error'][i] > 0.1:
        img_in = imgs[0][0].numpy()
        img_in = img_in - np.min(img_in)
        img_in = img_in / np.max(img_in)
        per = np.percentile(img_in, 100)
        img_in = np.clip(img_in, 0, per)
        prediction = outputs[ind][i]
        show_imgs([img_in, imgs_gt[0][i], prediction], s=10, 
                  titles=['input', 'ground_truth', models[ind]])