In [1]:
import random
from random import sample
import numpy as np
import os
import pickle
from tqdm import tqdm
from collections import OrderedDict
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from scipy.spatial.distance import mahalanobis
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from preprocess_data import DataHandler
from torch.utils.data import DataLoader
from torchvision.models import wide_resnet50_2, resnet18
from fast_ixi import FAST_IXI
from utils import *
from eval import *

  from .autonotebook import tqdm as notebook_tqdm


[2024-02-01 18:34:06,720] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# device setup
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('Device: {}'.format(device))

Device: cuda


In [3]:
# set random seed
seed = 1024
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

In [4]:
config_path = 'config.yaml'
opt = read_config(config_path)
experiment_path = opt['dataset']['save_dir'] + '/' + opt['model']['backbone']
opt['model']['experiment_path'] = experiment_path

os.makedirs(os.path.join(experiment_path, 'temp_%s' % opt['model']['backbone']), exist_ok=True)
train_feature_filepath = os.path.join(experiment_path, 'temp_%s' % opt['model']['backbone'], 'train_%s.pkl' % 'brainmri')
opt['model']['train_feature_filepath'] = train_feature_filepath

pic_save_path = os.path.join(experiment_path, 'pictures')
os.makedirs(pic_save_path, exist_ok=True)

### Load pre-trained CNN

In [5]:
# Load the pretrained CNN
model, t_d, d = load_pretrained_CNN(opt)
model = model.to(device)
model.eval()

Backbone: wide_resnet50_2
Input dim size: 1792
Output dim size after reduced: 550


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid

In [6]:
# select randomly choosen dimension to reduce the dimensionality of the feature vector (like PCA)
idx = torch.tensor(sample(range(0, t_d), d))
# initialize the intermadiate outputs
outputs = []
def hook(module, input, output):
    outputs.append(output)
model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

<torch.utils.hooks.RemovableHandle at 0x7f84a68a3940>

### Learning Normal Class Representation

In [7]:
train_dataloader = load_train_dataset(opt).train_dataloader()
learned_representation = OrderedDict([('layer1', []), ('layer2', []), ('layer3', [])])

Using 581 IXI images and 130 fastMRI images for training. Using 15 images for validation.




#### Extract embeddings fromm the train dataset and save it as Multivariate Gaussian Distribution

In [8]:
if not os.path.exists(train_feature_filepath):    
# for each batch in the dataloader (use tqdm bar), train dataloader get item returns only x
    for batch_idx, img in tqdm(enumerate(train_dataloader), '| feature extraction | train | %s |' % 'brainmri'):
        img = img.to(device)
        with torch.no_grad():
            _ = model(img)
        for key, value in zip(learned_representation.keys(), outputs):
            learned_representation[key].append(value.cpu().detach())
        # initialize hook outputs
        outputs = []

    for key, value in learned_representation.items():
        learned_representation[key] = torch.cat(value, 0)

    print('first layer shape:', learned_representation['layer1'].shape)
    print('second layer shape:', learned_representation['layer2'].shape)
    print('third layer shape:', learned_representation['layer3'].shape)
    # Embedding concat
    embedding_vectors = learned_representation['layer1'] # get the maximum size of the embedding vectors
    """
    Rresearchers conceptually divide the input image into a grid based on the resolution of the largest activation map—typically
    the first layer of the pre-trained CNN. This way, each grid position, denoted as (i,j), 
    is associated with a unique embedding vector that represents the collective activation vectors for that particular image patch.
    """
    for layer_name in ['layer2', 'layer3']:
        embedding_vectors = embedding_concat(embedding_vectors, learned_representation[layer_name])

    # randomly select d dimension
    print('randomly select %d dimension' % opt['model']['output_dimension'])
    embedding_vectors = torch.index_select(embedding_vectors, 1, idx)

    B, C, H, W = embedding_vectors.size() # Get the shape of the embedding vectors which is same with the first layer of the pretrained model
    print('embedding_vectors shape:', embedding_vectors.shape)
    embedding_vectors = embedding_vectors.view(B, C, H * W)

    # calculate multivariate Gaussian distribution
    mean = torch.mean(embedding_vectors, dim=0).numpy()
    cov = torch.zeros(C, C, H * W).numpy()
    I = np.identity(C)

    # calculate mean, cov and inverse covariance matrix for each patch position at Xij 
    # (each patch position (i,j) is associated with a unique embedding vector)
    for i in range(H * W):
        # Xij = embedding_vectors[:, :, i].numpy()
        cov[:, :, i] = np.cov(embedding_vectors[:, :, i].numpy(), rowvar=False) + 0.01 * I

    # save learned distribution
    learned_representation = [mean, cov]
    with open(train_feature_filepath, 'wb') as f:
        pickle.dump(learned_representation, f)
else:
    with open(train_feature_filepath, 'rb') as f:
        learned_representation = pickle.load(f)   

### Evaluate on Test Dataset

In [9]:
from data_loader import get_all_test_dataloaders

In [10]:
test_dataloaders = get_all_test_dataloaders( opt['dataset']['ann_path'],  opt['dataset']['target_size'], opt['dataset']['batch_size'])

In [11]:
print('number of anomaly classes:', len(test_dataloaders))

number of anomaly classes: 17


In [12]:
all_images = {key : [] for key in test_dataloaders.keys()}
all_labels = {key : [] for key in test_dataloaders.keys()}
all_pos_masks = {key : [] for key in test_dataloaders.keys()}
all_neg_masks = {key : [] for key in test_dataloaders.keys()}
all_thresholds = {key : [] for key in test_dataloaders.keys()}
all_test_outputs = { key: OrderedDict([('layer1', []), ('layer2', []), ('layer3', [])]) for key in test_dataloaders.keys()}
all_embedding_vectors = {key : [] for key in test_dataloaders.keys()}
all_scores = {key : [] for key in test_dataloaders.keys()}

extend_images = []
extend_labels = []
extend_pos_masks = []
extend_neg_masks = []
extend_scores = []


In [13]:
for anomaly_class in test_dataloaders.keys():
    print('******************* DATASET: {} ****************'.format(anomaly_class)) 
    imgs = []
    labels = []
    pos_masks = []
    neg_masks = []
    for (img, label, pos_mask, neg_mask) in tqdm(test_dataloaders[anomaly_class], '| feature extraction | test | %s |' % anomaly_class):
        imgs.extend(img.cpu().detach().numpy())
        labels.extend(label.cpu().detach().numpy())
        pos_masks.extend(pos_mask.cpu().detach().numpy())
        neg_masks.extend(neg_mask.cpu().detach().numpy())
        extend_images.extend(img.cpu().detach().numpy())
        extend_labels.extend(label.cpu().detach().numpy())
        extend_pos_masks.extend(pos_mask.cpu().detach().numpy())
        extend_neg_masks.extend(neg_mask.cpu().detach().numpy())

        # get the model prediction
        with torch.no_grad():
            _ = model(img.to(device))
        # get intermediate outputs
        for key, value in zip(all_test_outputs[anomaly_class].keys(), outputs):
            all_test_outputs[anomaly_class][key].append(value.cpu().detach())
            
    # initialize hook outputs
    outputs = []
    for key, value in all_test_outputs[anomaly_class].items():
        all_test_outputs[anomaly_class][key] = torch.cat(value, 0)
    # Embedding concat
    embedding_vectors = all_test_outputs[anomaly_class]['layer1']
    for layer_name in ['layer2', 'layer3']:
        embedding_vectors = embedding_concat(embedding_vectors, all_test_outputs[anomaly_class][layer_name])
    
    # randomly select d dimension
    embedding_vectors = torch.index_select(embedding_vectors, 1, idx)
    print(embedding_vectors.shape)
    

    all_images[anomaly_class] = imgs
    all_labels[anomaly_class] = labels
    all_pos_masks[anomaly_class] = pos_masks
    all_neg_masks[anomaly_class] = neg_masks
    all_embedding_vectors[anomaly_class] = embedding_vectors

******************* DATASET: absent_septum ****************


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
| feature extraction | test | absent_septum |: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.93it/s]


torch.Size([1, 550, 32, 32])
******************* DATASET: artefacts ****************


| feature extraction | test | artefacts |: 100%|███████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 45.08it/s]


torch.Size([16, 550, 32, 32])
******************* DATASET: craniatomy ****************


| feature extraction | test | craniatomy |: 100%|██████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 47.69it/s]


torch.Size([15, 550, 32, 32])
******************* DATASET: dural ****************


| feature extraction | test | dural |: 100%|█████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 49.07it/s]


torch.Size([7, 550, 32, 32])
******************* DATASET: ea_mass ****************


| feature extraction | test | ea_mass |: 100%|███████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 49.91it/s]


torch.Size([4, 550, 32, 32])
******************* DATASET: edema ****************


| feature extraction | test | edema |: 100%|███████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 48.19it/s]


torch.Size([18, 550, 32, 32])
******************* DATASET: encephalomalacia ****************


| feature extraction | test | encephalomalacia |: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 47.45it/s]


torch.Size([1, 550, 32, 32])
******************* DATASET: enlarged_ventricles ****************


| feature extraction | test | enlarged_ventricles |: 100%|█████████████████████████████████████████████████████████| 19/19 [00:00<00:00, 49.42it/s]


torch.Size([19, 550, 32, 32])
******************* DATASET: intraventricular ****************


| feature extraction | test | intraventricular |: 100%|██████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 41.94it/s]


torch.Size([1, 550, 32, 32])
******************* DATASET: lesions ****************


| feature extraction | test | lesions |: 100%|█████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 49.66it/s]


torch.Size([22, 550, 32, 32])
******************* DATASET: mass ****************


| feature extraction | test | mass |: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 50.35it/s]


torch.Size([22, 550, 32, 32])
******************* DATASET: posttreatment ****************


| feature extraction | test | posttreatment |: 100%|███████████████████████████████████████████████████████████████| 44/44 [00:00<00:00, 49.96it/s]


torch.Size([44, 550, 32, 32])
******************* DATASET: resection ****************


| feature extraction | test | resection |: 100%|███████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 51.40it/s]


torch.Size([10, 550, 32, 32])
******************* DATASET: sinus ****************


| feature extraction | test | sinus |: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 51.16it/s]


torch.Size([2, 550, 32, 32])
******************* DATASET: wml ****************


| feature extraction | test | wml |: 100%|███████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 48.70it/s]


torch.Size([5, 550, 32, 32])
******************* DATASET: other ****************


| feature extraction | test | other |: 100%|█████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 53.12it/s]


torch.Size([5, 550, 32, 32])
******************* DATASET: normal ****************


| feature extraction | test | normal |: 100%|██████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 50.93it/s]


torch.Size([30, 550, 32, 32])


#### Get Anomaly Map for each dataset

In [14]:
for anomaly_class in test_dataloaders.keys():
    print('******************* DATASET: {} ****************'.format(anomaly_class)) 
    # calculate mahalanobis distance between learned_representation to give anomaly score to each patch position of the test images
    embedding_vectors = all_embedding_vectors[anomaly_class]
    B, C, H, W = embedding_vectors.size()
    print('embedding_vectors shape:', embedding_vectors.shape)
    
    embedding_vectors = embedding_vectors.view(B, C, H * W).numpy()
    dist_list = []
    for i in range(H * W):
        mean = learned_representation[0][:, i]
        conv_inv = np.linalg.inv(learned_representation[1][:, :, i])
        dist = [mahalanobis(sample[:, i], mean, conv_inv) for sample in embedding_vectors]
        dist_list.append(dist)
    dist_list = np.array(dist_list).transpose(1, 0).reshape(B, H, W)
    
    # upsample to image size to get anomaly score map
    dist_list = torch.tensor(dist_list)
    score_map = F.interpolate(dist_list.unsqueeze(1), size=all_images[anomaly_class][0].shape[2], mode='bilinear',
                                align_corners=False).squeeze().numpy()
    
    # apply gaussian smoothing on the score map
    for i in range(score_map.shape[0]):
        score_map[i] = gaussian_filter(score_map[i], sigma=4)
    
    # Normalize the score map
    max_score = score_map.max()
    min_score = score_map.min()
    scores = (score_map - min_score) / (max_score - min_score)
    all_scores[anomaly_class] = scores

******************* DATASET: absent_septum ****************
embedding_vectors shape: torch.Size([1, 550, 32, 32])
******************* DATASET: artefacts ****************
embedding_vectors shape: torch.Size([16, 550, 32, 32])
******************* DATASET: craniatomy ****************
embedding_vectors shape: torch.Size([15, 550, 32, 32])
******************* DATASET: dural ****************
embedding_vectors shape: torch.Size([7, 550, 32, 32])
******************* DATASET: ea_mass ****************
embedding_vectors shape: torch.Size([4, 550, 32, 32])
******************* DATASET: edema ****************
embedding_vectors shape: torch.Size([18, 550, 32, 32])
******************* DATASET: encephalomalacia ****************
embedding_vectors shape: torch.Size([1, 550, 32, 32])
******************* DATASET: enlarged_ventricles ****************
embedding_vectors shape: torch.Size([19, 550, 32, 32])
******************* DATASET: intraventricular ****************
embedding_vectors shape: torch.Size([1, 5

### Calculate Metrics

In [15]:
total_pixel_rocauc = []
total_auprc = []
total_dice_score =[]
for anomaly_class in test_dataloaders.keys():
    if anomaly_class == 'normal':
        pass
    else:
        print('******************* DATASET: {} ****************'.format(anomaly_class))
        gt_mask = np.asarray(all_pos_masks[anomaly_class])
        gt_mask = gt_mask.astype(int)
        gt_mask = gt_mask.astype(np.float32)
        scores = all_scores[anomaly_class]
        
        roc_auc = roc_auc_score(gt_mask.flatten(), scores.flatten())
        precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), scores.flatten())
        total_pixel_rocauc.append(roc_auc)


        ########## calculate optimal threshold #############

        a = 2 * precision * recall
        b = precision + recall
        f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
        threshold = thresholds[np.argmax(f1)]
        all_thresholds[anomaly_class] = threshold\
        
        print('threshold for masks: %.3f' % (threshold))
        pred_mask = scores.copy()
        pred_mask[pred_mask > threshold] = 1
        pred_mask[pred_mask <= threshold] = 0

        ####################################################
        
        precision, recall, _ = precision_recall_curve(gt_mask.flatten(), pred_mask.flatten())
        pr_auc = auc(recall, precision)
        total_auprc.append(pr_auc)
        
        intersection = np.sum(scores.flatten() * pred_mask.flatten())
        dice_score = (2. * intersection) / (np.sum(scores.flatten()) + np.sum(pred_mask.flatten()))
        total_dice_score.append(dice_score)
        
        print('pixel ROCAUC: %.3f' % (roc_auc))
        print('pixel AUPRC: %.3f' % (pr_auc))
        print('pixel DICE: %.3f' % (dice_score))

******************* DATASET: absent_septum ****************
threshold for masks: 0.512
pixel ROCAUC: 0.810
pixel AUPRC: 0.456
pixel DICE: 0.511
******************* DATASET: artefacts ****************
threshold for masks: 0.294
pixel ROCAUC: 0.713
pixel AUPRC: 0.363
pixel DICE: 0.455
******************* DATASET: craniatomy ****************
threshold for masks: 0.446
pixel ROCAUC: 0.587
pixel AUPRC: 0.447
pixel DICE: 0.756
******************* DATASET: dural ****************
threshold for masks: 0.471
pixel ROCAUC: 0.740
pixel AUPRC: 0.476
pixel DICE: 0.581
******************* DATASET: ea_mass ****************
threshold for masks: 0.331
pixel ROCAUC: 0.523
pixel AUPRC: 0.479
pixel DICE: 0.717
******************* DATASET: edema ****************
threshold for masks: 0.478
pixel ROCAUC: 0.727
pixel AUPRC: 0.524
pixel DICE: 0.642
******************* DATASET: encephalomalacia ****************
threshold for masks: 0.753
pixel ROCAUC: 0.950
pixel AUPRC: 0.596
pixel DICE: 0.252
******************

In [16]:
print('total AUROC:', np.mean(total_pixel_rocauc))

total AUROC: 0.7537521874199522


In [17]:
print('total AUPRC:', np.mean(total_auprc))

total AUPRC: 0.4056331415630455


In [18]:
print('total DICE:', np.mean(total_dice_score))

total DICE: 0.47759564603874627


### Save Qualitatives

In [19]:
from eval import *

In [21]:
# for anomaly_class in test_dataloaders.keys():
#     if anomaly_class == 'normal':
#         pass
#     elif len(all_images[anomaly_class]) == 1:
#         pass
#     else:
#         print('******************* DATASET: {} ****************'.format(anomaly_class))
#         images = all_images[anomaly_class]
#         scores = all_scores[anomaly_class]
#         masks = all_pos_masks[anomaly_class]
#         neg_masks = all_neg_masks[anomaly_class]
#         threshold = all_thresholds[anomaly_class]
#         save_dir = os.path.join(experiment_path, anomaly_class)
#         os.makedirs(save_dir, exist_ok=True)
#         plot_fig(images, scores, neg_masks, masks, threshold, save_dir, anomaly_class)