In [35]:
import os
import sys
import json

import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from tdigest import TDigest
from torch.utils.data import DataLoader

sys.path.append(os.path.join('..', '..', 'tta_uia_segmentation', 'src'))

from dataset.dataset_in_memory import get_datasets
from tta_uia_segmentation.src.models.io import load_norm_from_configs_and_cpt
from tta_uia_segmentation.src.utils.io import load_config
from tta_uia_segmentation.src.utils.loss import dice_score
from tta_uia_segmentation.src.dataset.utils import onehot_to_class, class_to_onehot

# Load trained model and dataset

In [22]:
# Load Normalizer and Segmentation model
trained_model_dir   = '/scratch_net/biwidl319/jbermeo/data/models/brain/segmentation/hcp_t1/no_bg_supp_norm_w_3x3_conv'
dataset_name        = 'hcp_t1'
split               = 'train'
image_size          = (1, 256, 256)
device              = 'cuda' if torch.cuda.is_available() else 'cpu' 

params              = load_config(os.path.join(trained_model_dir, 'params.yaml'))
dataset_params      = params['dataset'][dataset_name]
model_params        = params['model']
train_params        = params['training']['segmentation']


In [23]:
# Load dataset with original preprocessed images (val)
(ds, )  = get_datasets(
        splits          = [split],
        paths           = dataset_params['paths_processed'],
        paths_original  = dataset_params['paths_original'], 
        image_size      = train_params['image_size'],
        resolution_proc = dataset_params['resolution_proc'],
        dim_proc        = dataset_params['dim'],
        n_classes       = dataset_params['n_classes'],
        aug_params      = train_params['augmentation'],
        deformation     = None,
        load_original   = True,
        bg_suppression_opts = train_params['bg_suppression_opts']
    )


MemoryError: Unable to allocate 1.29 GiB for an array with shape (20, 256, 260, 260) and data type float32

In [None]:
norm = load_norm_from_configs_and_cpt(
    model_params_norm=model_params['normalization_2D'],
    cpt_fp=os.path.join(trained_model_dir, 'checkpoint_best.pth'),
    device=device
)
norm.eval()

Normalization(
  (layers): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=reflect)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): RBF()
    (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=reflect)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): RBF()
    (6): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=reflect)
  )
)

# Calculate quantiles and first two moments and save them as a dictionary

In [42]:
train_params

{'augmentation': {'alpha': 1000,
  'brightness_max': 0.1,
  'brightness_min': 0.0,
  'da_ratio': 0.25,
  'gamma_max': 2.0,
  'gamma_min': 0.5,
  'noise_mean': 0.0,
  'noise_std': 0.1,
  'rot_max': 10,
  'rot_min': -10,
  'scale_max': 1.1,
  'scale_min': 0.9,
  'sigma': 20,
  'trans_max': 10,
  'trans_min': -10},
 'batch_size': 16,
 'bg_suppression_opts': {'bg_value': -0.5,
  'bg_value_max': 1,
  'bg_value_min': -0.5,
  'hole_filling': True,
  'mask_source': 'thresholding',
  'thresholding': 'otsu',
  'type': 'none'},
 'dataset': 'hcp_t1',
 'epochs': 150,
 'image_size': [1, 256, 256],
 'learning_rate': 0.001,
 'logdir': '/scratch_net/biwidl319/jbermeo/data/models/brain/segmentation/hcp_t1/no_bg_supp_norm_w_3x3_conv',
 'num_workers': 3,
 'validate_every': 3,
 'wandb_project': 'mt-segmentation_models',
 'with_bg_supression': False}

In [26]:
dl = DataLoader(ds, batch_size=16, shuffle=False)

In [27]:
n_epochs = 1
sum = 0
sum_sq = 0
n = 0
digest = TDigest()

for img, *_ in tqdm(dl):
    img = img.to(device)
    
    # Get normalized image
    with torch.no_grad():
        img = norm(img)
    
    n += torch.prod(torch.tensor(img.shape))
    sum += img.sum().item()
    sum_sq += (img ** 2).sum().item()
    
    digest.batch_update(img.cpu().numpy().flatten())
            
mean = sum / n
std = np.sqrt(sum_sq / n - mean ** 2)
print(f'Mean: {mean}, Std: {std}')
print(f'Deciles: {[digest.percentile(i) for i in np.linspace(0, 1, 11)]}')
        

  1%|          | 3/320 [01:44<3:04:13, 34.87s/it]


KeyboardInterrupt: 

In [37]:
mean = sum / n
std = np.sqrt(sum_sq / n - mean ** 2)

In [40]:
quantiles = np.concatenate([
    np.array([0.001, 0.01, 0.025, 0.05]),
    np.arange(.1, 0.9, 0.05),
    np.array([0.95, 0.975, 0.99, 0.999])
])

statistics = {
    'moments': {
        'mean': mean,   
        'std': std  
    }, 
    'quantiles': {
        q: digest.percentile(q * 100) for q in quantiles
    }
}

In [41]:
statistics

{'moments': {'mean': tensor(-0.5634), 'std': tensor(582.3605)},
 'quantiles': {0.001: -0.9945433342916964,
  0.01: -0.9021149695185093,
  0.025: -0.854265213743379,
  0.05: -0.809861784110772,
  0.1: -0.7521351086885315,
  0.15000000000000002: -0.7335615792542854,
  0.20000000000000004: -0.7335615575900846,
  0.25000000000000006: -0.7335615359258834,
  0.30000000000000004: -0.7335607424755598,
  0.3500000000000001: -0.7335499086800722,
  0.40000000000000013: -0.7335390748845846,
  0.45000000000000007: -0.733528241089097,
  0.5000000000000001: -0.714010575237264,
  0.5500000000000002: -0.7035516091813828,
  0.6000000000000002: -0.6783800194401608,
  0.6500000000000001: -0.6635622541876772,
  0.7000000000000002: -0.6441343705656857,
  0.7500000000000002: -0.6422032084829853,
  0.8000000000000002: -0.5911896446707012,
  0.8500000000000002: -0.2608513039432001,
  0.95: 0.09741911641490902,
  0.975: 0.1882849318456169,
  0.99: 0.24772406815528256,
  0.999: 0.6185989379882812}}

In [36]:
json.dumps(digest.to_dict())

'{"n": 3207722, "delta": 0.01, "K": 25, "centroids": [{"m": -1.2559771537780762, "c": 1.0}, {"m": -1.1844812631607056, "c": 1.0}, {"m": -1.1696264743804932, "c": 1.0}, {"m": -1.1674108505249023, "c": 1.0}, {"m": -1.1660230159759521, "c": 1.0}, {"m": -1.1632336378097534, "c": 1.0}, {"m": -1.1623378992080688, "c": 1.0}, {"m": -1.1572933197021484, "c": 1.0}, {"m": -1.15138840675354, "c": 1.0}, {"m": -1.150769591331482, "c": 1.0}, {"m": -1.1501342058181763, "c": 1.0}, {"m": -1.1498668193817139, "c": 1.0}, {"m": -1.1497288942337036, "c": 1.0}, {"m": -1.1497209072113037, "c": 1.0}, {"m": -1.1487590074539185, "c": 1.0}, {"m": -1.1452170610427856, "c": 1.0}, {"m": -1.1438095569610596, "c": 1.0}, {"m": -1.14359450340271, "c": 1.0}, {"m": -1.1434494256973267, "c": 1.0}, {"m": -1.1413179636001587, "c": 1.0}, {"m": -1.1407928466796875, "c": 1.0}, {"m": -1.1398078203201294, "c": 1.0}, {"m": -1.1379530429840088, "c": 1.0}, {"m": -1.1377865076065063, "c": 1.0}, {"m": -1.137774109840393, "c": 1.0}, {"