In [1]:
import random
from random import sample
import numpy as np
import os
import pickle
from tqdm import tqdm
from collections import OrderedDict
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from scipy.spatial.distance import mahalanobis
from scipy.ndimage import gaussian_filter
import torch
import torch.nn.functional as F
from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# device setup
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Device: {}'.format(device))

Device: cuda


In [3]:
# set random seed
seed = 1024
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

In [4]:
config_path = 'config.yaml'
opt = read_config(config_path)
experiment_path = opt['dataset']['save_dir'] + '/' + opt['model']['backbone']
opt['model']['experiment_path'] = experiment_path

os.makedirs(os.path.join(experiment_path, 'normal_embeddings'), exist_ok=True)
train_feature_filepath = os.path.join(experiment_path, 'normal_embeddings', 'train_%s.pkl' % opt['dataset']['name'])
opt['model']['train_feature_filepath'] = train_feature_filepath

pic_save_path = os.path.join(experiment_path, 'pictures')
os.makedirs(pic_save_path, exist_ok=True)

In [6]:
opt

{'dataset': {'ann_path': './data/splits',
  'path': './data',
  'save_dir': './results',
  'name': 'fast_ixi',
  'target_size': [128, 128],
  'batch_size': 1},
 'model': {'backbone': 'resnet18',
  'target_dimension': 448,
  'output_dimension': 180,
  'experiment_path': './results/resnet18',
  'train_feature_filepath': './results/resnet18/normal_embeddings/train_fast_ixi.pkl'}}

### Load pre-trained CNN

In [7]:
# Load the pretrained CNN
model, t_d, d = load_pretrained_CNN(opt)
model = model.to(device)
model.eval()



Backbone: resnet18
Input dim size: 448
Output dim size after reduced: 180


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [8]:
# select randomly choosen dimension to reduce the dimensionality of the feature vector (like PCA)
idx = torch.tensor(sample(range(0, t_d), d))
# initialize the intermadiate outputs
outputs = []
def hook(module, input, output):
    outputs.append(output)
model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

<torch.utils.hooks.RemovableHandle at 0x7fa7f43c6970>

### Learning Normal Class Representation

In [9]:
train_dataloader = load_train_dataset(opt).train_dataloader()
learned_representation = OrderedDict([('layer1', []), ('layer2', []), ('layer3', [])])

Using 581 IXI images and 130 fastMRI images for training. Using 15 images for validation.


#### Extract embeddings fromm the train dataset and save it as Multivariate Gaussian Distribution

In [10]:
if not os.path.exists(train_feature_filepath):    
# for each batch in the dataloader (use tqdm bar), train dataloader get item returns only x
    for batch_idx, img in tqdm(enumerate(train_dataloader), '| feature extraction | train | %s |' % opt['dataset']['name']):
        img = img.to(device)
        with torch.no_grad():
            _ = model(img)
        for key, value in zip(learned_representation.keys(), outputs):
            learned_representation[key].append(value.cpu().detach())
        # initialize hook outputs
        outputs = []

    for key, value in learned_representation.items():
        learned_representation[key] = torch.cat(value, 0)

    print('first layer shape:', learned_representation['layer1'].shape)
    print('second layer shape:', learned_representation['layer2'].shape)
    print('third layer shape:', learned_representation['layer3'].shape)
    # Embedding concat
    embedding_vectors = learned_representation['layer1'] # get the maximum size of the embedding vectors
    """
    Rresearchers conceptually divide the input image into a grid based on the resolution of the largest activation map—typically
    the first layer of the pre-trained CNN. This way, each grid position, denoted as (i,j), 
    is associated with a unique embedding vector that represents the collective activation vectors for that particular image patch.
    """
    for layer_name in ['layer2', 'layer3']:
        embedding_vectors = embedding_concat(embedding_vectors, learned_representation[layer_name])

    # randomly select d dimension
    print('randomly select %d dimension' % opt['model']['output_dimension'])
    embedding_vectors = torch.index_select(embedding_vectors, 1, idx)

    B, C, H, W = embedding_vectors.size() # Get the shape of the embedding vectors which is same with the first layer of the pretrained model
    print('embedding_vectors shape:', embedding_vectors.shape)
    embedding_vectors = embedding_vectors.view(B, C, H * W)

    # calculate multivariate Gaussian distribution
    mean = torch.mean(embedding_vectors, dim=0).numpy()
    cov = torch.zeros(C, C, H * W).numpy()
    I = np.identity(C)

    # calculate mean, cov and inverse covariance matrix for each patch position at Xij 
    # (each patch position (i,j) is associated with a unique embedding vector)
    for i in range(H * W):
        # Xij = embedding_vectors[:, :, i].numpy()
        cov[:, :, i] = np.cov(embedding_vectors[:, :, i].numpy(), rowvar=False) + 0.01 * I

    # save learned distribution
    learned_representation = [mean, cov]
    with open(train_feature_filepath, 'wb') as f:
        pickle.dump(learned_representation, f)
else:
    with open(train_feature_filepath, 'rb') as f:
        learned_representation = pickle.load(f)   

| feature extraction | train | fast_ixi |: 711it [00:04, 174.88it/s]


first layer shape: torch.Size([711, 64, 32, 32])
second layer shape: torch.Size([711, 128, 16, 16])
third layer shape: torch.Size([711, 256, 8, 8])
randomly select 180 dimension
embedding_vectors shape: torch.Size([711, 180, 32, 32])


### Evaluate on Test Dataset

In [11]:
from data_loader import get_all_test_dataloaders

In [12]:
test_dataloaders = get_all_test_dataloaders( opt['dataset']['ann_path'],  opt['dataset']['target_size'], opt['dataset']['batch_size'])

In [13]:
print('number of anomaly classes:', len(test_dataloaders))

number of anomaly classes: 17


In [14]:
all_images = {key : [] for key in test_dataloaders.keys()}
all_labels = {key : [] for key in test_dataloaders.keys()}
all_pos_masks = {key : [] for key in test_dataloaders.keys()}
all_neg_masks = {key : [] for key in test_dataloaders.keys()}
all_thresholds = {key : [] for key in test_dataloaders.keys()}
all_test_outputs = { key: OrderedDict([('layer1', []), ('layer2', []), ('layer3', [])]) for key in test_dataloaders.keys()}
all_embedding_vectors = {key : [] for key in test_dataloaders.keys()}
all_scores = {key : [] for key in test_dataloaders.keys()}

In [15]:
for anomaly_class in test_dataloaders.keys():
    print('******************* DATASET: {} ****************'.format(anomaly_class)) 
    imgs = []
    labels = []
    pos_masks = []
    neg_masks = []
    for (img, label, pos_mask, neg_mask) in tqdm(test_dataloaders[anomaly_class], '| feature extraction | test | %s |' % anomaly_class):
        imgs.extend(img.cpu().detach().numpy())
        labels.extend(label.cpu().detach().numpy())
        pos_masks.extend(pos_mask.cpu().detach().numpy())
        neg_masks.extend(neg_mask.cpu().detach().numpy())

        # get the model prediction
        with torch.no_grad():
            _ = model(img.to(device))
        # get intermediate outputs
        for key, value in zip(all_test_outputs[anomaly_class].keys(), outputs):
            all_test_outputs[anomaly_class][key].append(value.cpu().detach())
        # initialize hook outputs
        outputs = []

    for key, value in all_test_outputs[anomaly_class].items():
        all_test_outputs[anomaly_class][key] = torch.cat(value, 0)
    # Embedding concat
    embedding_vectors = all_test_outputs[anomaly_class]['layer1']
    for layer_name in ['layer2', 'layer3']:
        embedding_vectors = embedding_concat(embedding_vectors, all_test_outputs[anomaly_class][layer_name])
    
    # randomly select d dimension
    embedding_vectors = torch.index_select(embedding_vectors, 1, idx)
    #print(embedding_vectors.shape)
    

    all_images[anomaly_class] = imgs
    all_labels[anomaly_class] = labels
    all_pos_masks[anomaly_class] = pos_masks
    all_neg_masks[anomaly_class] = neg_masks
    all_embedding_vectors[anomaly_class] = embedding_vectors

******************* DATASET: absent_septum ****************


| feature extraction | test | absent_septum |: 100%|██████████| 1/1 [00:00<00:00, 42.30it/s]


******************* DATASET: artefacts ****************


| feature extraction | test | artefacts |: 100%|██████████| 16/16 [00:00<00:00, 117.90it/s]


******************* DATASET: craniatomy ****************


| feature extraction | test | craniatomy |: 100%|██████████| 15/15 [00:00<00:00, 131.69it/s]


******************* DATASET: dural ****************


| feature extraction | test | dural |: 100%|██████████| 7/7 [00:00<00:00, 128.81it/s]


******************* DATASET: ea_mass ****************


| feature extraction | test | ea_mass |: 100%|██████████| 4/4 [00:00<00:00, 121.16it/s]


******************* DATASET: edema ****************


| feature extraction | test | edema |: 100%|██████████| 18/18 [00:00<00:00, 134.78it/s]


******************* DATASET: encephalomalacia ****************


| feature extraction | test | encephalomalacia |: 100%|██████████| 1/1 [00:00<00:00, 141.20it/s]


******************* DATASET: enlarged_ventricles ****************


| feature extraction | test | enlarged_ventricles |: 100%|██████████| 19/19 [00:00<00:00, 140.94it/s]


******************* DATASET: intraventricular ****************


| feature extraction | test | intraventricular |: 100%|██████████| 1/1 [00:00<00:00, 139.09it/s]


******************* DATASET: lesions ****************


| feature extraction | test | lesions |: 100%|██████████| 22/22 [00:00<00:00, 141.54it/s]


******************* DATASET: mass ****************


| feature extraction | test | mass |: 100%|██████████| 22/22 [00:00<00:00, 143.80it/s]


******************* DATASET: posttreatment ****************


| feature extraction | test | posttreatment |: 100%|██████████| 44/44 [00:00<00:00, 137.79it/s]


******************* DATASET: resection ****************


| feature extraction | test | resection |: 100%|██████████| 10/10 [00:00<00:00, 147.89it/s]


******************* DATASET: sinus ****************


| feature extraction | test | sinus |: 100%|██████████| 2/2 [00:00<00:00, 143.42it/s]


******************* DATASET: wml ****************


| feature extraction | test | wml |: 100%|██████████| 5/5 [00:00<00:00, 128.51it/s]


******************* DATASET: other ****************


| feature extraction | test | other |: 100%|██████████| 5/5 [00:00<00:00, 124.79it/s]


******************* DATASET: normal ****************


| feature extraction | test | normal |: 100%|██████████| 30/30 [00:00<00:00, 136.20it/s]


#### Get Anomaly Map for each dataset

In [16]:
for anomaly_class in test_dataloaders.keys():
    #print('******************* DATASET: {} ****************'.format(anomaly_class)) 
    # calculate mahalanobis distance between learned_representation to give anomaly score to each patch position of the test images
    embedding_vectors = all_embedding_vectors[anomaly_class]
    B, C, H, W = embedding_vectors.size()
    #print('embedding_vectors shape:', embedding_vectors.shape)
    
    embedding_vectors = embedding_vectors.view(B, C, H * W).numpy()
    dist_list = []
    for i in range(H * W):
        mean = learned_representation[0][:, i]
        conv_inv = np.linalg.inv(learned_representation[1][:, :, i])
        dist = [mahalanobis(sample[:, i], mean, conv_inv) for sample in embedding_vectors]
        dist_list.append(dist)
    dist_list = np.array(dist_list).transpose(1, 0).reshape(B, H, W)
    
    # upsample to image size to get anomaly score map
    dist_list = torch.tensor(dist_list)
    score_map = F.interpolate(dist_list.unsqueeze(1), size=all_images[anomaly_class][0].shape[2], mode='bilinear',
                                align_corners=False).squeeze().numpy()
    
    # apply gaussian smoothing on the score map
    for i in range(score_map.shape[0]):
        score_map[i] = gaussian_filter(score_map[i], sigma=4)
    
    # Normalize the score map
    max_score = score_map.max()
    min_score = score_map.min()
    scores = (score_map - min_score) / (max_score - min_score)
    all_scores[anomaly_class] = scores

### Calculate Metrics

In [17]:
total_pixel_rocauc = []
total_auprc = []
total_dice_score =[]
for anomaly_class in test_dataloaders.keys():
    if anomaly_class == 'normal':
        pass
    else:
        print('******************* DATASET: {} ****************'.format(anomaly_class))
        gt_mask = np.asarray(all_pos_masks[anomaly_class])
        gt_mask = gt_mask.astype(int)
        gt_mask = gt_mask.astype(np.float32)
        scores = all_scores[anomaly_class]
        
        roc_auc = roc_auc_score(gt_mask.flatten(), scores.flatten())
        precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), scores.flatten())
        total_pixel_rocauc.append(roc_auc)


        ########## calculate optimal threshold #############

        a = 2 * precision * recall
        b = precision + recall
        f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
        threshold = thresholds[np.argmax(f1)]
        all_thresholds[anomaly_class] = threshold\
        
        #print('threshold for masks: %.3f' % (threshold))
        pred_mask = scores.copy()
        pred_mask[scores > threshold] = 1
        pred_mask[scores <= threshold] = 0

        ####################################################
        
        precision, recall, _ = precision_recall_curve(gt_mask.flatten(), pred_mask.flatten())
        pr_auc = auc(recall, precision)
        total_auprc.append(pr_auc)
        
        intersection = np.sum(scores.flatten() * pred_mask.flatten())
        dice_score = (2. * intersection) / (np.sum(scores.flatten()) + np.sum(pred_mask.flatten()))
        total_dice_score.append(dice_score)
        
        print('ROCAUC: %.3f' % (roc_auc))
        print('AUPRC: %.3f' % (pr_auc))
        print('DICE: %.3f' % (dice_score))

******************* DATASET: absent_septum ****************
ROCAUC: 0.810
AUPRC: 0.488
DICE: 0.495
******************* DATASET: artefacts ****************
ROCAUC: 0.902
AUPRC: 0.277
DICE: 0.132
******************* DATASET: craniatomy ****************
ROCAUC: 0.814
AUPRC: 0.168
DICE: 0.212
******************* DATASET: dural ****************
ROCAUC: 0.758
AUPRC: 0.547
DICE: 0.449
******************* DATASET: ea_mass ****************
ROCAUC: 0.956
AUPRC: 0.437
DICE: 0.189
******************* DATASET: edema ****************
ROCAUC: 0.892
AUPRC: 0.486
DICE: 0.359
******************* DATASET: encephalomalacia ****************
ROCAUC: 0.956
AUPRC: 0.586
DICE: 0.267
******************* DATASET: enlarged_ventricles ****************
ROCAUC: 0.875
AUPRC: 0.491
DICE: 0.368
******************* DATASET: intraventricular ****************
ROCAUC: 0.979
AUPRC: 0.539
DICE: 0.117
******************* DATASET: lesions ****************
ROCAUC: 0.909
AUPRC: 0.286
DICE: 0.129
******************* DATASET: mass

In [18]:
print('total AUROC:', np.mean(total_pixel_rocauc))

total AUROC: 0.8644207298490829


In [19]:
print('total AUPRC:', np.mean(total_auprc))

total AUPRC: 0.3953308950012978


In [20]:
print('total DICE:', np.mean(total_dice_score))

total DICE: 0.2968046507141038


### Save Qualitatives

In [21]:
for anomaly_class in test_dataloaders.keys():
    if anomaly_class == 'normal':
        pass
    else:        
        print('******************* DATASET: {} ****************'.format(anomaly_class))
        images = all_images[anomaly_class]
        scores = all_scores[anomaly_class]
        masks = all_pos_masks[anomaly_class]
        neg_masks = all_neg_masks[anomaly_class]
        threshold = all_thresholds[anomaly_class]
        save_dir = os.path.join(pic_save_path, anomaly_class)
        os.makedirs(save_dir, exist_ok=True)
        print('saving images to:', save_dir)    
        visualize_images(images, scores, neg_masks, masks, threshold, save_dir, anomaly_class)

******************* DATASET: absent_septum ****************
saving images to: ./results/resnet18/pictures/absent_septum
******************* DATASET: artefacts ****************
saving images to: ./results/resnet18/pictures/artefacts
******************* DATASET: craniatomy ****************
saving images to: ./results/resnet18/pictures/craniatomy
******************* DATASET: dural ****************
saving images to: ./results/resnet18/pictures/dural
******************* DATASET: ea_mass ****************
saving images to: ./results/resnet18/pictures/ea_mass
******************* DATASET: edema ****************
saving images to: ./results/resnet18/pictures/edema
******************* DATASET: encephalomalacia ****************
saving images to: ./results/resnet18/pictures/encephalomalacia
******************* DATASET: enlarged_ventricles ****************
saving images to: ./results/resnet18/pictures/enlarged_ventricles
******************* DATASET: intraventricular ****************
saving images to: