In [None]:
import pickle
import os
import numpy as np
import torch

In [None]:
# Load the pickled records that pair an MRI study id (field 0) with a score
# (field 1); both fields are read by the path-collection cell below.
# NOTE(review): pickle.load executes arbitrary code from the file — only use
# it on files you produced yourself.
with open('MRI_Image_TO_EDSS', 'rb') as pickle_file:
    MRI_Image_TO_EDSS = pickle.load(pickle_file)
# Number of records loaded.
print(len(MRI_Image_TO_EDSS))

In [None]:
# Root folder that holds one sub-directory per MRI study.
IMG_ROOT = "img_debias"
# A study is labelled positive (1) when the second field of its record is
# strictly greater than this value.
EDSS_THRESHOLD = 4.0
# A study directory is only used when it contains exactly this many entries.
EXPECTED_FILES_PER_STUDY = 5

# Per-modality file paths, kept index-aligned with `labels` and `MRIs`.
images_T1_post = []
images_T1_pre = []
images_T2 = []
images_flair = []

labels = []  # 0/1 target per kept study
MRIs = []    # study id (as a string) per kept study

# List the root directory once instead of once per record.
available_studies = set(os.listdir('./img_debias'))

for i in range(len(MRI_Image_TO_EDSS)):
    # Field 0 is the study id; it is normalised to an integer string so it
    # matches the directory name.
    MRI_Image = str(int(MRI_Image_TO_EDSS[i][0]))
    if MRI_Image not in available_studies:
        continue

    study_dir = os.path.join(IMG_ROOT, MRI_Image)
    # Skip studies whose directory does not hold the expected number of files.
    if len(os.listdir(study_dir)) != EXPECTED_FILES_PER_STUDY:
        continue

    images_T1_post.append(os.path.join(study_dir, "t1_post_reg_des_debis.nii.gz"))
    images_T1_pre.append(os.path.join(study_dir, "t1_pre_reg_des_debis.nii.gz"))
    images_T2.append(os.path.join(study_dir, "t2_reg_des_debis.nii.gz"))
    images_flair.append(os.path.join(study_dir, "flair_reg_des_debis.nii.gz"))

    labels.append(int(MRI_Image_TO_EDSS[i][1] > EDSS_THRESHOLD))
    MRIs.append(MRI_Image)

print(len(images_T1_post), len(images_T1_pre),len(images_T2), len(images_flair), len(labels))
print(labels)
print(len(labels))

## Augment the training set

In [None]:
def _reorder(items, order):
    """Return a new list holding `items` rearranged by the index sequence `order`."""
    return [items[i] for i in order]


## shuffle first
# NOTE: this cell rebinds MRIs / labels / images_* to their shuffled versions,
# so running it twice in the same kernel shuffles an already-shuffled list.
# Re-run the path-collection cell first if you need to repeat it.
randinds = np.arange(len(labels))
np.random.seed(42)
np.random.shuffle(randinds)

MRIs = _reorder(MRIs, randinds)
labels = _reorder(labels, randinds)
images_T1_post = _reorder(images_T1_post, randinds)
images_T1_pre = _reorder(images_T1_pre, randinds)
images_T2 = _reorder(images_T2, randinds)
images_flair = _reorder(images_flair, randinds)

## split train / test
# NOTE(review): the original assignment had no right-hand side (`train_frac = `),
# which is a SyntaxError. 0.8 is a placeholder — confirm the intended fraction.
train_frac = 0.8
Ntot = len(labels)

cut = int(train_frac * Ntot)

# Slicing copies, so the oversampling below does not touch the full lists.
MRIs_train, MRIs_test = MRIs[:cut], MRIs[cut:]
labels_train, labels_test = labels[:cut], labels[cut:]
images_T1_post_train, images_T1_post_test = images_T1_post[:cut], images_T1_post[cut:]
images_T1_pre_train, images_T1_pre_test = images_T1_pre[:cut], images_T1_pre[cut:]
images_T2_train, images_T2_test = images_T2[:cut], images_T2[cut:]
images_flair_train, images_flair_test = images_flair[:cut], images_flair[cut:]

## Augment train only
# Every positive training sample is appended this many extra times
# (plain duplication of the list entries).
OVERSAMPLE_FACTOR = 10
inds = [i for i, x in enumerate(labels_train) if x == 1]
MRIs_train.extend(_reorder(MRIs_train, inds) * OVERSAMPLE_FACTOR)
labels_train.extend(_reorder(labels_train, inds) * OVERSAMPLE_FACTOR)
images_T1_post_train.extend(_reorder(images_T1_post_train, inds) * OVERSAMPLE_FACTOR)
images_T1_pre_train.extend(_reorder(images_T1_pre_train, inds) * OVERSAMPLE_FACTOR)
images_T2_train.extend(_reorder(images_T2_train, inds) * OVERSAMPLE_FACTOR)
images_flair_train.extend(_reorder(images_flair_train, inds) * OVERSAMPLE_FACTOR)

## shuffle train again
# Needed because the duplicates were appended at the end and the training
# DataLoader is created with shuffle=False.
randinds = np.arange(len(labels_train))
np.random.seed(42)
np.random.shuffle(randinds)
MRIs_train = _reorder(MRIs_train, randinds)
labels_train = _reorder(labels_train, randinds)
images_T1_post_train = _reorder(images_T1_post_train, randinds)
images_T1_pre_train = _reorder(images_T1_pre_train, randinds)
images_T2_train = _reorder(images_T2_train, randinds)
images_flair_train = _reorder(images_flair_train, randinds)

print('====>Train set size:')
print(len(images_T1_post_train), len(images_T1_pre_train),len(images_T2_train), 
      len(images_flair_train), len(labels_train))
print('=======> # positives in Train')
print(np.sum(labels_train))
print('====>Test set size:')
print(len(images_T1_post_test), len(images_T1_pre_test),len(images_T2_test), 
      len(images_flair_test), len(labels_test))
print('=======> # positives in Test')
print(np.sum(labels_test))


In [None]:
# Leakage check: number of distinct study ids in train, in test, and in both.
# The sets collapse the duplicated (oversampled) training entries; the third
# value is the count of ids that appear in both splits.
len(set(MRIs_train)), len(set(MRIs_test)), \
len(set(MRIs_train).intersection(MRIs_test))

## Set class weights from the training labels

In [None]:
# Number of distinct label values present in the (oversampled) training set.
num_classes = len(set(labels_train))
# Occurrences of each label value 0 .. num_classes-1.
class_count = [
    np.sum(np.array(labels_train) == class_id)
    for class_id in range(num_classes)
]
print(num_classes, class_count)
# Inverse-frequency weight per class: rarer classes get a larger weight.
class_weights = 1. / torch.tensor(class_count, dtype=torch.float)
print(class_weights)

In [None]:
# (sum of the 0/1 training labels, i.e. number of positives; total training samples)
np.sum(labels_train), len(labels_train)

# Util

In [None]:
from sklearn.metrics import roc_curve, confusion_matrix , auc, precision_recall_curve, average_precision_score
from sklearn import metrics

def print_metrics(y_true, y_pred):
    """Print binary-classification metrics and plot ROC and precision-recall curves.

    The decision threshold is the ROC threshold that maximises the geometric
    mean of sensitivity and specificity; the confusion-matrix based metrics
    are reported at that threshold.

    Args:
        y_true: sequence of ground-truth labels, with 1 as the positive class
            and 0 as the negative class.
        y_pred: sequence of scores for the positive class (same length as
            ``y_true``); a list or an array is accepted.

    Returns:
        None. Results are printed and two matplotlib figures are shown.
    """
    import matplotlib.pyplot as plt

    # Arrays are required for the element-wise threshold comparison below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    fpr, tpr, thresholds = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)
    auprc = average_precision_score(y_true, y_pred)
    print('AUC: ',roc_auc, "AUPRC: ", auprc)

    # calculate the g-mean for each threshold
    gmeans = np.sqrt(tpr * (1-fpr))
    ix = np.argmax(gmeans)
    print('Best Threshold=%f, G-Mean=%.3f' % (thresholds[ix], gmeans[ix]))

    # roc_curve counts a sample as positive when score >= threshold, so the
    # same comparison is used here to reproduce the selected operating point.
    y_hat = (y_pred >= thresholds[ix]).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_hat, labels=[0, 1]).ravel()

    def _ratio(numerator, denominator):
        # Report nan instead of dividing by zero for an empty denominator.
        return numerator / denominator if denominator else float("nan")

    print('TN: ', tn, ", FP: ",fp, ", FN:", fn, ", TP:", tp)
    print("==> Sensitivity (Recall, TPR): %.3f"%(_ratio(tp, tp+fn)))
    print("==> Specifity: %.3f"%(_ratio(tn, tn+fp)))
    print("==> Positive Predictive Value (PPV) (Precision): %.3f"%(_ratio(tp, tp + fp)))
    print("==> Negative Predictive Value (NPV): %.3f"%(_ratio(tn, tn + fn)))
    print("==> Accuracy: %.3f"%(_ratio(tp+tn, tn+ fp+ fn+tp)))
    print("==> F1 score: %.3f"%(_ratio(2*tp, 2*tp + fp + fn)))

    # ROC plot: chance diagonal versus the model curve.
    plt.plot([0, 1], [0, 1], linestyle='--', label='No Skill, AUC = %0.2f' % 0.5)
    plt.plot(fpr, tpr, marker='.', label = 'Our model: AUC = %0.2f' % roc_auc)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    plt.show()

    # Precision-recall plot: the no-skill baseline is the positive prevalence.
    lr_precision, lr_recall, _ = precision_recall_curve(y_true, y_pred)
    no_skill = len(y_true[y_true==1]) / len(y_true)
    plt.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No skill, AUPRC = %0.2f' % no_skill)
    plt.plot(lr_recall, lr_precision, marker='.',label = 'Our model: AUPRC = %0.2f' % auprc)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend(loc = 'upper right')
    plt.show()


# model

In [None]:
import torch.nn as nn
from typing import Callable, Sequence, Type, Union
import torch
from collections import OrderedDict
from monai.networks.layers.factories import Conv, Dropout, Pool
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils.module import look_up_option

class _Transition(nn.Sequential):
    """Norm -> activation -> 1x1 convolution -> 2x average pooling, applied in that order."""

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        act: Union[str, tuple] = ("relu", {"inplace": False}),
        norm: Union[str, tuple] = "batch",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels entering the block.
            out_channels: number of channels produced by the 1x1 convolution.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        """
        super().__init__()

        conv_cls: Callable = Conv[Conv.CONV, spatial_dims]
        avg_pool_cls: Callable = Pool[Pool.AVG, spatial_dims]

        # Sub-modules are registered in execution order.
        stages = (
            ("norm", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)),
            ("relu", get_act_layer(name=act)),
            ("conv", conv_cls(in_channels, out_channels, kernel_size=1, bias=False)),
            ("pool", avg_pool_cls(kernel_size=2, stride=2)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

class _DenseLayer(nn.Module):
    """Bottleneck layer whose output is its input concatenated with the new feature maps."""

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        growth_rate: int,
        bn_size: int,
        dropout_prob: float,
        act: Union[str, tuple] = ("relu", {"inplace": False}),
        norm: Union[str, tuple] = "batch",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels entering the layer.
            growth_rate: number of feature maps the layer adds (k in paper).
            bn_size: multiplier for the bottleneck width
                (the 1x1 convolution outputs bn_size * growth_rate channels).
            dropout_prob: dropout rate after the layer; no dropout module is added when it is 0 or less.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        """
        super().__init__()

        bottleneck_channels = bn_size * growth_rate
        conv_cls: Callable = Conv[Conv.CONV, spatial_dims]
        dropout_cls: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

        self.layers = nn.Sequential()

        # 1x1 bottleneck convolution.
        self.layers.add_module("norm1", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels))
        self.layers.add_module("relu1", get_act_layer(name=act))
        self.layers.add_module("conv1", conv_cls(in_channels, bottleneck_channels, kernel_size=1, bias=False))

        # 3x3 convolution producing `growth_rate` new feature maps.
        self.layers.add_module(
            "norm2", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=bottleneck_channels)
        )
        self.layers.add_module("relu2", get_act_layer(name=act))
        self.layers.add_module(
            "conv2", conv_cls(bottleneck_channels, growth_rate, kernel_size=3, padding=1, bias=False)
        )

        if dropout_prob > 0:
            self.layers.add_module("dropout", dropout_cls(dropout_prob))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` concatenated along dim 1 with the feature maps computed from it."""
        return torch.cat([x, self.layers(x)], 1)
    
class _DenseBlock(nn.Sequential):
    """A chain of `_DenseLayer`s in which each layer receives `growth_rate` more channels than the previous one."""

    def __init__(
        self,
        spatial_dims: int,
        layers: int,
        in_channels: int,
        bn_size: int,
        growth_rate: int,
        dropout_prob: float,
        act: Union[str, tuple] = ("relu", {"inplace": False}),
        norm: Union[str, tuple] = "batch",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            layers: number of dense layers in the block.
            in_channels: number of channels entering the first layer.
            bn_size: multiplier for the bottleneck width
                (i.e. bn_size * k features in the bottleneck layer).
            growth_rate: number of feature maps each layer adds (k in paper).
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        """
        super().__init__()
        channels = in_channels
        for layer_number in range(1, layers + 1):
            dense_layer = _DenseLayer(
                spatial_dims, channels, growth_rate, bn_size, dropout_prob, act=act, norm=norm
            )
            self.add_module("denselayer%d" % layer_number, dense_layer)
            # The layer concatenates its output onto its input, so the next
            # layer sees `growth_rate` additional channels.
            channels += growth_rate
            
class DenseNet(nn.Module):
    """
    Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.

    ``forward`` returns a 3-tuple ``(embedding, logits, probabilities)``
    rather than a single tensor.

    The classification head is reachable both as the ``class_layers``
    sequential and as the individual ``class_layers_1`` .. ``class_layers_4``
    attributes. Both views hold the *same* module objects, so code that runs
    ``class_layers`` (for example a Grad-CAM extractor) uses exactly the
    weights that ``forward`` trains.

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
            (i.e. bn_size * k features in the bottleneck layer)
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        act: Union[str, tuple] = ("relu", {"inplace": False}),
        norm: Union[str, tuple] = "batch",
        dropout_prob: float = 0.0,
    ) -> None:

        super().__init__()

        conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv[Conv.CONV, spatial_dims]
        pool_type: Type[Union[nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d]] = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: Type[Union[nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d]] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
        ]

        # Stem: strided convolution followed by max pooling.
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=init_features)),
                    ("relu0", get_act_layer(name=act)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        in_channels = init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=in_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
                act=act,
                norm=norm,
            )
            self.features.add_module(f"denseblock{i + 1}", block)
            in_channels += num_layers * growth_rate
            if i == len(block_config) - 1:
                # The last block is followed by a norm layer instead of a transition.
                self.features.add_module(
                    "norm5", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
                )
            else:
                # Transitions halve the channel count (and pool spatially).
                _out_channels = in_channels // 2
                trans = _Transition(
                    spatial_dims, in_channels=in_channels, out_channels=_out_channels, act=act, norm=norm
                )
                self.features.add_module(f"transition{i + 1}", trans)
                in_channels = _out_channels

        # pooling and classification
        # Number of modality embeddings concatenated before the linear layer;
        # `forward` uses a single modality.
        MODAL = 1
        head_act = get_act_layer(name=act)
        head_pool = avg_pool_type(1)
        head_flatten = nn.Flatten(1)
        head_linear = nn.Linear(in_channels * MODAL, out_channels)

        # The original built a second, independent nn.Linear inside
        # `class_layers` that `forward` never called, so it was never trained.
        # Registering the same objects under both names keeps every attribute
        # and state_dict key while guaranteeing a single set of weights.
        # NOTE(review): a checkpoint saved before this change stores different
        # values under `class_layers.out.*` and `class_layers_4.*`; after
        # loading one, verify that the trained `class_layers_4.*` values are
        # the ones held by the shared layer.
        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", head_act),
                    ("pool", head_pool),
                    ("flatten", head_flatten),
                    ("out", head_linear),
                ]
            )
        )
        self.class_layers_1 = head_act
        self.class_layers_2 = head_pool
        self.class_layers_3 = head_flatten
        self.class_layers_4 = head_linear

        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight))
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)

    def forward(self, x_T1_pre: torch.Tensor) -> torch.Tensor:
        """Run the network on one image batch.

        Args:
            x_T1_pre: input batch; the first two dimensions are batch and
                channel, followed by ``spatial_dims`` spatial dimensions.

        Returns:
            A tuple ``(x_join, linear_y, out)``:
            ``x_join`` is the flattened, globally average-pooled feature vector,
            ``linear_y`` is the linear layer output (logits), and
            ``out`` is the element-wise sigmoid of ``linear_y``.
        """
        x_T1_pre = self.features(x_T1_pre)        # feature maps
        x_T1_pre = self.class_layers_1(x_T1_pre)  # activation
        x_T1_pre = self.class_layers_2(x_T1_pre)  # adaptive average pool to size 1
        x_T1_pre = self.class_layers_3(x_T1_pre)  # flatten to (batch, channels)
        x_join = x_T1_pre
        linear_y = self.class_layers_4(x_join)    # (batch, out_channels)
        out = self.sigmoid(linear_y)
        return x_join, linear_y , out

    
class DenseNet121(DenseNet):
    """DenseNet variant with optional pretrained support when `spatial_dims` is 2.

    The defaults used here (``init_features=32``, ``growth_rate=16``,
    ``block_config=(3, 6, 12, 8)``) are half of the values left in the
    trailing comments (64, 32, (6, 12, 24, 16)).

    Args:
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer.
        block_config: how many layers in each pooling block.
        pretrained: when True, try to load PyTorch Hub "densenet121" weights.
        progress: forwarded to the weight loader.
        **kwargs: forwarded to ``DenseNet.__init__`` (``spatial_dims``,
            ``in_channels``, ``out_channels``, ...).

    Raises:
        NotImplementedError: if ``pretrained`` is True and ``spatial_dims`` > 2.
    """

    def __init__(
        self,
        init_features: int = 32, # 64,
        growth_rate: int = 16, # 32,
        block_config: Sequence[int] = (3,6,12,8),#(6, 12, 24, 16),
        pretrained: bool = False,
        progress: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
        if pretrained:
            if kwargs["spatial_dims"] > 2:
                raise NotImplementedError(
                    "Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not "
                    "provide pretrained models for more than two spatial dimensions."
                )
            # `_load_state_dict` was referenced without ever being imported,
            # so this branch raised NameError. It is imported locally because
            # it is only needed on this rarely used path.
            # NOTE(review): this is a private MONAI helper; confirm it exists
            # in the installed MONAI version and that it tolerates the reduced
            # layer sizes used by this class.
            from monai.networks.nets.densenet import _load_state_dict

            _load_state_dict(self, "densenet121", progress)

# Train

In [None]:
import logging
import os
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate, Resize, ScaleIntensity, EnsureType



# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((256,256,44)), RandRotate(range_x=0.02, range_y=0.02, range_z=0.02, prob=0.5), EnsureType()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((256,256,44)), EnsureType()])

#### ======================================= T1 pre ===============================================
# Define image dataset, data loader
check_ds_T1_pre = ImageDataset(image_files=images_T1_pre_train, labels=labels_train, transform=train_transforms)
check_loader_T1_pre = DataLoader(check_ds_T1_pre, batch_size=10, num_workers=2)# , pin_memory=torch.cuda.is_available())
# create a training data loader (shuffle=False: the file lists were already shuffled when the split was built)
train_ds_T1_pre = ImageDataset(image_files=images_T1_pre_train, labels=labels_train, transform=train_transforms)
train_loader_T1_pre = DataLoader(train_ds_T1_pre, batch_size=10, shuffle=False, num_workers=2)#, pin_memory=torch.cuda.is_available())
#create a validation data loader
val_ds_T1_pre = ImageDataset(image_files=images_T1_pre_test, labels=labels_test, transform=val_transforms)
val_loader_T1_pre = DataLoader(val_ds_T1_pre, batch_size=10, num_workers=2)#, pin_memory=torch.cuda.is_available())


# CUDA device index used when a GPU is available.
GPU_INDEX = 14
# Directory that receives the checkpoints of every new best-AUC epoch.
CHECKPOINT_DIR = "Models_saved-copy3-pre-Medcam"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # torch.save does not create missing directories

# Create the DenseNet model, the class-weighted BCE-with-logits loss and the Adam optimizer
device = torch.device(GPU_INDEX if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss_function = torch.nn.BCEWithLogitsLoss(weight=class_weights.to(device))
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

triplet_loss = \
    nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance(), margin=1.5)


# start a typical PyTorch training
EPOCH = 400
val_interval = 2
best_metric = -1
best_metric_epoch = -1   # 1-based epoch of the best validation AUC; -1 until one is recorded
epoch_loss_values = list()
metric_values = list()
best_AUC=0

for epoch in range(EPOCH):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{EPOCH}")
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0
    used_batches = 0   # batches that actually contributed an optimizer step

    # `batch_labels` (not `labels`) so the dataset-level `labels` list is not overwritten.
    for inputs_T1_pre, batch_labels in train_loader_T1_pre:
        inputs_T1_pre = inputs_T1_pre.to(device)
        optimizer.zero_grad()
        joint_embedding, out, predict = model(inputs_T1_pre)
        # Two-column one-hot target matching the two logits per sample.
        one_hot_label = torch.tensor(np.eye(2)[np.array(batch_labels, dtype="int")])
        one_hot_label = one_hot_label.type_as(out).to(device)
        loss_1 = loss_function(out, one_hot_label)

        positive = [i for i, x in enumerate(batch_labels) if x == 1]
        negative = [i for i, x in enumerate(batch_labels) if x == 0]
        # The triplet term needs at least one sample of each class; batches
        # without both are skipped entirely (no optimizer step).
        if not positive or not negative:
            continue
        loss = loss_1
        # The anchor is an all-zero vector (0 * embedding), so each term
        # compares the distances of a positive and a negative embedding from
        # the origin.
        for i in positive:
            for j in negative:
                loss = loss + triplet_loss(0*joint_embedding[0,:], joint_embedding[i,:], joint_embedding[j,:])
        loss.backward()
        _, predicted = torch.max(predict.detach().cpu(), 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()
        optimizer.step()
        epoch_loss += loss.item()
        used_batches += 1

    # Average over the batches that were actually trained on; skipped batches
    # added nothing to `epoch_loss` and must not dilute the mean.
    epoch_loss = epoch_loss / used_batches if used_batches else float("nan")
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    if total:
        print(f"epoch {epoch + 1} train accuracy: {correct / total:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()

        prediction_probablity = []   # sigmoid output of column 1, per validation sample
        label_list = []
        val_total = 0
        val_correct = 0

        with torch.no_grad():
            for val_images_T1_pre, val_labels in val_loader_T1_pre:
                val_images_T1_pre = val_images_T1_pre.to(device)
                _, out, predict = model(val_images_T1_pre)
                predict = predict.detach().cpu()
                _, predicted = torch.max(predict, 1)
                prediction_probablity.extend(predict.numpy()[:, 1].tolist())
                label_list.extend(val_labels.cpu().numpy().tolist())
                val_total += val_labels.size(0)
                val_correct += (predicted == val_labels).sum().item()

        Accuracy = val_correct/val_total
        y = np.array(label_list).flatten()
        scores = np.array(prediction_probablity).flatten()
        false_positive_rate, recall, thresholds = roc_curve(y, scores)
        roc_auc = auc(false_positive_rate, recall)
        auprc = average_precision_score(y, scores)
        metric_values.append(roc_auc)

        if roc_auc > best_AUC:
            best_AUC = roc_auc
            best_metric = roc_auc
            # 1-based, matching the "epoch N/EPOCH" log lines; the checkpoint
            # file names below use the 0-based `epoch` index.
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, "State_checkpoints_{}.thr".format(epoch)))
            torch.save(model, os.path.join(CHECKPOINT_DIR, "Model_checkpoints_{}.thr".format(epoch)))
        print("AUC: {} , Accuracy: {}, AUPRC: {}".format(roc_auc, Accuracy, auprc))

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

In [None]:
# Highest validation ROC AUC recorded by the training loop above.
best_AUC

## Gradcam

In [None]:
import logging
import os
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate, Resize, ScaleIntensity, EnsureType


# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((256,256,44)), RandRotate(range_x=0.02, range_y=0.02, range_z=0.02, prob=0.5), EnsureType()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((256,256,44)), EnsureType()])

#### ======================================= T1 pre ===============================================
# Same datasets as in training, but with batch_size=1 so that one CAM is
# produced per image.
check_ds_T1_pre = ImageDataset(image_files=images_T1_pre_train, labels=labels_train, transform=train_transforms)
check_loader_T1_pre = DataLoader(check_ds_T1_pre, batch_size=1, num_workers=2)# , pin_memory=torch.cuda.is_available())
# create a training data loader
train_ds_T1_pre = ImageDataset(image_files=images_T1_pre_train, labels=labels_train, transform=train_transforms)
train_loader_T1_pre = DataLoader(train_ds_T1_pre, batch_size=1, shuffle=False, num_workers=2)#, pin_memory=torch.cuda.is_available())
#create a validation data loader
val_ds_T1_pre = ImageDataset(image_files=images_T1_pre_test, labels=labels_test, transform=val_transforms)
val_loader_T1_pre = DataLoader(val_ds_T1_pre, batch_size=1, num_workers=2)#, pin_memory=torch.cuda.is_available())


# The model is kept on the CPU in this cell (the `.to(device)` calls are
# disabled), so `device` is defined but not used below.
device = torch.device(14 if torch.cuda.is_available() else "cpu")
# NOTE(review): this builds a freshly initialised network and no checkpoint is
# loaded in this cell, so the CAMs describe random weights. Load the saved
# state_dict of the best epoch here before interpreting any result.
model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=2)# .to(device)
model.eval()

# NOTE(review): `GradCam` is defined in a later cell of this notebook; that
# cell must be executed before this one.
# One GradCam instance is enough for all samples.
gcv2 = GradCam(model, target_layer='norm5')

# `batch_labels` (not `labels`) so the dataset-level `labels` list is not
# overwritten. No optimizer is created here: nothing is being trained, and
# rebinding `optimizer` would shadow the one used by the training cell.
for inputs_T1_pre, batch_labels in train_loader_T1_pre:
    print(inputs_T1_pre.shape)

    # Generate the CAM for the ground-truth class of this sample.
    target_class = int(batch_labels[0])
    cam = gcv2.generate_cam(inputs_T1_pre, target_class)
    print('Grad cam completed')

    # Interactive pause: waits for Enter before processing the next sample.
    input()


In [None]:
# for module_pos, module in model.features._modules.items():
#     print(module_pos, module)
#     if module_pos == 'norm5':
#         print('fefe')

# utils

In [None]:
from torch.nn import ReLU

class GuidedBackprop():
    """
    Produces gradients generated with guided back propagation from the given image.

    The model must expose a ``features`` container. Every ``ReLU`` found
    anywhere inside ``features`` (including nested sub-modules) is hooked so
    that only positive gradients flowing through positive activations are
    propagated. The gradient with respect to the network input is captured on
    the first child of ``features``.

    Note: the hooks stay registered on the model after construction, so every
    later backward pass through ``model.features`` is affected as well.
    """
    def __init__(self, model):
        self.model = model
        # Gradient w.r.t. the input of the first `features` child (set during backward).
        self.gradients = None
        # Stack of 0/1 masks (activation > 0), one per ReLU call of the current forward pass.
        self.forward_relu_outputs = []
        # Put model in evaluation mode
        self.model.eval()
        self.update_relus()
        self.hook_layers()

    def hook_layers(self):
        """Capture the gradient that reaches the input of the first `features` child."""
        def hook_function(module, grad_in, grad_out):
            self.gradients = grad_in[0]
        # Register hook to the first layer
        first_layer = next(iter(self.model.features.children()))
        first_layer.register_full_backward_hook(hook_function)

    def update_relus(self):
        """
            Updates relu activation functions so that
                1- the positive-activation mask is stored in the forward pass
                2- gradient values that are less than zero are replaced by zero
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            Keep only positive gradients at positions whose activation was positive.
            """
            # Masks are consumed in reverse order of the forward pass.
            positive_mask = self.forward_relu_outputs.pop()
            modified_grad_out = positive_mask * torch.clamp(grad_in[0], min=0.0)
            return (modified_grad_out,)

        def relu_forward_hook_function(module, ten_in, ten_out):
            """
            Store the 0/1 mask of positive activations. A new tensor is stored
            so the activation used by the network is never modified in place.
            """
            self.forward_relu_outputs.append((ten_out > 0).to(ten_out.dtype))

        # Walk the whole sub-tree so ReLUs inside nested blocks are hooked too.
        for module in self.model.features.modules():
            if isinstance(module, ReLU):
                module.register_full_backward_hook(relu_backward_hook_function)
                module.register_forward_hook(relu_forward_hook_function)

    def generate_gradients(self, input_image, target_class):
        """Return the guided-backprop gradient of one class score w.r.t. the input.

        Args:
            input_image: input batch for the model; only the gradient of the
                first sample is returned.
            target_class: index of the output unit to back-propagate from.

        Returns:
            numpy array holding the gradient for the first sample, i.e. the
            input shape without its leading batch dimension.
        """
        # The input gradient only exists when the input takes part in autograd.
        if not input_image.requires_grad:
            input_image = input_image.detach().requires_grad_(True)
        # Drop masks left over from forward passes that had no backward pass.
        self.forward_relu_outputs.clear()
        # Forward pass
        model_output = self.model(input_image)
        # The DenseNet in this notebook returns (embedding, logits,
        # probabilities); back-propagate from the logits in that case.
        if isinstance(model_output, (tuple, list)):
            model_output = model_output[1]
        # Zero gradients
        self.model.zero_grad()
        # Target for backprop: one-hot on the requested class of the first sample
        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[0][target_class] = 1
        # Backward pass
        model_output.backward(gradient=one_hot_output)
        # [0] drops the batch dimension
        gradients_as_arr = self.gradients.detach().cpu().numpy()[0]
        return gradients_as_arr


def guided_grad_cam(grad_cam_mask, guided_backprop_mask):
    """
        Combine a class activation map with a guided-backprop map by
        element-wise multiplication.
    Args:
        grad_cam_mask (np_arr): Class activation map mask
        guided_backprop_mask (np_arr): Guided backprop mask
    Returns:
        Element-wise product of the two masks.
    """
    return np.multiply(grad_cam_mask, guided_backprop_mask)

class CamExtractor():
    """
        Extracts cam features from the model.

        The model must expose a ``features`` container whose direct children
        are applied in order, plus a classification head that is either the
        four attributes ``class_layers_1`` .. ``class_layers_4`` or a
        ``class_layers`` container.
    """
    def __init__(self, model, target_layer):
        self.model = model
        # Name of the direct child of `model.features` whose output is captured.
        self.target_layer = target_layer
        # Gradient w.r.t. the captured output; filled in during backward.
        self.gradients = None

    def save_gradient(self, grad):
        """Tensor hook: remember the gradient of the captured layer output."""
        self.gradients = grad

    def forward_pass_on_convolutions(self, x):
        """
            Does a forward pass through `model.features`, hooking the output
            of the target layer.

            Returns:
                (conv_output, x): the target layer output and the final
                output of `features`.

            Raises:
                ValueError: if `features` has no direct child named `target_layer`.
        """
        conv_output = None
        for module_pos, module in self.model.features.named_children():
            x = module(x)  # Forward
            if module_pos == self.target_layer:
                x.register_hook(self.save_gradient)
                conv_output = x  # Save the output of that layer
        if conv_output is None:
            raise ValueError(
                "target_layer %r is not a direct child of model.features" % (self.target_layer,)
            )
        return conv_output, x

    def _classifier_modules(self):
        """Return the head modules in execution order.

        The notebook's DenseNet.forward applies `class_layers_1..4`, so those
        are preferred when all four exist; otherwise the children of
        `class_layers` are used.
        """
        split_head = [getattr(self.model, "class_layers_%d" % k, None) for k in range(1, 5)]
        if all(module is not None for module in split_head):
            return split_head
        return list(self.model.class_layers.children())

    def forward_pass(self, x):
        """
            Does a full forward pass on the model.

            Returns:
                (conv_output, x): the target layer output and the output of
                the classification head.
        """
        # Forward pass on the convolutions
        conv_output, x = self.forward_pass_on_convolutions(x)
        # Forward pass on the classifier
        for module in self._classifier_modules():
            x = module(x)
        return conv_output, x


class GradCam():
    """
        Produces class activation map

    Args:
        model: Model exposing ``features`` and ``class_layers``; it is switched
            to eval mode on construction.
        target_layer: Key of the layer in ``model.features`` whose activations
            and gradients are used to build the map.
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.model.eval()
        # Define extractor
        self.extractor = CamExtractor(self.model, target_layer)

    def generate_cam(self, input_image, target_class=None):
        """
            Builds the class activation map for one input.

        Args:
            input_image (torch.Tensor): Model input; dimensions 2 and 3 give
                the size the map is resized to.
            target_class (int, optional): Output index to explain. Defaults to
                the arg-max of the model output.
        returns:
            cam (numpy arr): uint8 activation map with values 0-255.
        """
        # Full forward pass
        # conv_output is the output of convolutions at specified layer
        # model_output is the final output of the model
        conv_output, model_output = self.extractor.forward_pass(input_image)
        if target_class is None:
            target_class = np.argmax(model_output.data.numpy())
        # Target for backprop: one-hot vector selecting the class to explain
        print('conv_output, model_output shape: ', conv_output.shape, model_output.shape)
        one_hot_output = torch.zeros(1, model_output.size()[-1], dtype=torch.float32)
        print('one_hot_output shape: ', one_hot_output.shape)
        one_hot_output[0][target_class] = 1
        print('target class is:', target_class)
        # Zero grads
        self.model.features.zero_grad()
        self.model.class_layers.zero_grad()

        # Backward pass with specified target
        model_output.backward(gradient=one_hot_output, retain_graph=True)
        # Get hooked gradients
        guided_gradients = self.extractor.gradients.data.numpy()[0]
        # Get convolution outputs
        print('conv_output, guided_gradients are ',conv_output.shape, guided_gradients.shape)
        target = conv_output.data.numpy()[0]
        # Get weights from gradients
        weights = np.mean(guided_gradients, axis=(1, 2))  # Take averages for each gradient
        # Create the accumulator for the cam (initialised to ones).
        # NOTE(review): shape[1:-1] drops the last axis of the activation, so
        # this presumably expects an extra trailing axis after the two spatial
        # ones - confirm against the model's feature-map layout.
        cam = np.ones(target.shape[1:-1], dtype=np.float32)
        # Multiply each weight with its conv output and then, sum
        for i, w in enumerate(weights):
            cam += w * target[i, :, :]
        cam = np.maximum(cam, 0)
        # Normalize between 0-1. A constant map has zero range; emit zeros
        # instead of dividing 0 by 0 and casting NaN to uint8.
        cam_min = np.min(cam)
        cam_range = np.max(cam) - cam_min
        if cam_range > 0:
            cam = (cam - cam_min) / cam_range
        else:
            cam = np.zeros_like(cam)
        cam = np.uint8(cam * 255)  # Scale between 0-255 to visualize
        print('cam is:', cam.shape, cam.dtype)
        # Resize through PIL: array -> PIL image -> resized -> array again.
        # Image.LANCZOS is the same filter as the former Image.ANTIALIAS alias,
        # which was removed in Pillow 10.
        # NOTE(review): PIL's resize takes (width, height) while shape[2] and
        # shape[3] are passed in that order; this only matters for non-square
        # inputs - confirm the intended orientation.
        cam = np.uint8(Image.fromarray(cam).resize((input_image.shape[2],
                       input_image.shape[3]), Image.LANCZOS))
        return cam


import os
import copy
import numpy as np
from PIL import Image
import matplotlib.cm as mpl_color_map

import torch
from torchvision import models


def convert_to_grayscale(im_as_arr):
    """
        Converts 3d image to grayscale

        Sums absolute values over the first axis, then rescales so that the
        minimum maps to 0 and the 99th percentile (and everything above it)
        maps to 1.
    Args:
        im_as_arr (numpy arr): RGB image with shape (D,W,H)
    returns:
        grayscale_im (numpy_arr): Grayscale image with shape (1,W,H)
    """
    grayscale_im = np.sum(np.abs(im_as_arr), axis=0)
    im_max = np.percentile(grayscale_im, 99)
    im_min = np.min(grayscale_im)
    value_range = im_max - im_min
    if value_range > 0:
        grayscale_im = np.clip((grayscale_im - im_min) / value_range, 0, 1)
    else:
        # Flat image (or one whose 99th percentile equals its minimum): dividing
        # would give NaN/inf. Keep the limiting behaviour instead - values above
        # the percentile saturate to 1, everything else is 0.
        grayscale_im = np.where(grayscale_im > im_min, 1.0, 0.0)
    grayscale_im = np.expand_dims(grayscale_im, axis=0)
    return grayscale_im


def save_gradient_images(gradient, file_name):
    """
        Exports the original gradient image

        The gradient is min-max normalised to 0-1 and written to
        ``results/<file_name>.jpg``.
    Args:
        gradient (np arr): Numpy array of the gradient with shape (3, 224, 224)
        file_name (str): File name to be exported
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs('results', exist_ok=True)
    # Normalize. The division is not done in place so integer arrays work too,
    # and a constant gradient (peak 0) yields zeros instead of 0/0 = NaN.
    shifted = gradient - gradient.min()
    peak = shifted.max()
    normalized = shifted / (peak if peak != 0 else 1)
    # Save image
    path_to_file = os.path.join('results', file_name + '.jpg')
    save_image(normalized, path_to_file)


def save_class_activation_images(org_img, activation_map, file_name):
    """
        Saves cam activation map and activation map on the original image

        Three files are written into ``results``: the coloured heatmap, the
        heatmap blended over the original image, and the grayscale map.
    Args:
        org_img (PIL img): Original image
        activation_map (numpy arr): Activation map (grayscale) 0-255
        file_name (str): File name of the exported image
    """
    # Create the directory that is actually written to below (the previous
    # code created '../results' while saving into 'results').
    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)
    # Grayscale activation map
    heatmap, heatmap_on_image = apply_colormap_on_image(org_img, activation_map, 'hsv')
    # Save colored heatmap
    path_to_file = os.path.join(results_dir, file_name+'_Cam_Heatmap.png')
    save_image(heatmap, path_to_file)
    # Save heatmap on image
    path_to_file = os.path.join(results_dir, file_name+'_Cam_On_Image.png')
    save_image(heatmap_on_image, path_to_file)
    # Save grayscale heatmap
    path_to_file = os.path.join(results_dir, file_name+'_Cam_Grayscale.png')
    save_image(activation_map, path_to_file)


def apply_colormap_on_image(org_im, activation, colormap_name):
    """
        Apply heatmap on image
    Args:
        org_img (PIL img): Original image
        activation_map (numpy arr): Activation map (grayscale) 0-255
        colormap_name (str): Name of the colormap
    returns:
        no_trans_heatmap (PIL img): Colour-mapped activation with the
            colormap's own alpha channel
        heatmap_on_image (PIL img): RGBA image of the original with the
            heatmap (alpha set to 0.4) composited on top
    raises:
        ValueError: If ``colormap_name`` is not a registered colormap.
    """
    # Get colormap. ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # the ``matplotlib.colormaps`` registry is the replacement. Older
    # Matplotlib versions without the registry still use get_cmap.
    import matplotlib
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None:
        try:
            color_map = registry[colormap_name]
        except KeyError as exc:
            # get_cmap reported unknown names with ValueError; keep that type.
            raise ValueError('Unknown colormap: %r' % (colormap_name,)) from exc
    else:
        color_map = mpl_color_map.get_cmap(colormap_name)
    no_trans_heatmap = color_map(activation)
    # Change alpha channel in colormap to make sure original image is displayed
    heatmap = copy.copy(no_trans_heatmap)
    heatmap[:, :, 3] = 0.4
    heatmap = Image.fromarray((heatmap*255).astype(np.uint8))
    no_trans_heatmap = Image.fromarray((no_trans_heatmap*255).astype(np.uint8))

    # Apply heatmap on image: start from a blank RGBA canvas, lay the original
    # on it, then the semi-transparent heatmap.
    heatmap_on_image = Image.new("RGBA", org_im.size)
    heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
    heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
    return no_trans_heatmap, heatmap_on_image


def save_image(im, path):
    """
        Saves a numpy matrix of shape D(1 or 3) x W x H as an image

        Anything that is not a numpy array is assumed to be a PIL image and is
        saved as is.
    Args:
        im_as_arr (Numpy array): Matrix of shape DxWxH
        path (str): Path to the image
    """
    if isinstance(im, np.ndarray):
        if len(im.shape) == 2:
            im = np.expand_dims(im, axis=0)
        if im.shape[0] == 1:
            # Converting an image with depth = 1 to depth = 3, repeating the same values
            # For some reason PIL complains when I want to save channel image as jpg without
            # additional format in the .save()
            im = np.repeat(im, 3, axis=0)
        if im.shape[0] == 3:
            # Channels-first -> W,H,D as expected by PIL
            im = im.transpose(1, 2, 0)
            # Only data normalised to 0-1 is scaled up to 0-255. Arrays that
            # are already in the 0-255 range (e.g. a uint8 activation map)
            # would wrap around if they were multiplied by 255 again.
            # Heuristic: "max <= 1" is taken to mean normalised data.
            if im.size and np.max(im) <= 1:
                im = im * 255
        im = Image.fromarray(im.astype(np.uint8))
    im.save(path)


def preprocess_image(pil_im, resize_im=True):
    """
        Turns an image into a normalised, batched float tensor for a CNN.

        Note that when ``resize_im`` is true the image passed in is itself
        shrunk (``thumbnail`` works in place), so the caller's object changes.
    Args:
        PIL_img (PIL_img): Image to process
        resize_im (bool): Shrink the image to fit within 512x512 or not
    returns:
        im_as_var (torch tensor): Float tensor with a leading batch axis and
            ``requires_grad`` enabled
    """
    # Per-channel statistics (Imagenet)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    if resize_im:
        pil_im.thumbnail((512, 512))
    # Channels-last -> channels-first
    chw = np.float32(pil_im).transpose(2, 0, 1)
    # Scale to 0-1, then standardise each channel in place
    for idx in range(len(chw)):
        plane = chw[idx]
        plane /= 255
        plane -= channel_mean[idx]
        plane /= channel_std[idx]
    batched = torch.from_numpy(chw).float()
    # Prepend the batch axis
    batched.unsqueeze_(0)
    # Gradients w.r.t. the input are needed by the visualisation methods
    batched.requires_grad = True
    return batched


def recreate_image(im_as_var):
    """
        Undoes the channel normalisation of a preprocessed tensor and returns
        a displayable image array.
    Args:
        im_as_var (torch variable): Image to recreate
    returns:
        recreated_im (numpy arr): uint8 array with channels last, values 0-255
    """
    # Kept in "reverse" form: dividing by 1/std multiplies by std, and
    # subtracting -mean adds the mean back.
    inverse_std = [1/0.229, 1/0.224, 1/0.225]
    negated_mean = [-0.485, -0.456, -0.406]
    # Work on a copy of the first batch element so the tensor is untouched
    pixels = copy.copy(im_as_var.data.numpy()[0])
    for channel, (scale, shift) in enumerate(zip(inverse_std, negated_mean)):
        pixels[channel] /= scale
        pixels[channel] -= shift
    # Clamp to the valid 0-1 range before scaling to 8 bit
    np.clip(pixels, 0, 1, out=pixels)
    as_bytes = np.uint8(np.round(pixels * 255))
    # Channels-first -> channels-last
    return as_bytes.transpose(1, 2, 0)


def get_positive_negative_saliency(gradient):
    """
        Generates positive and negative saliency maps based on the gradient
    Args:
        gradient (numpy arr): Gradient of the operation to visualize
    returns:
        pos_saliency (numpy arr): Positive part of the gradient divided by
            the gradient's maximum
        neg_saliency (numpy arr): Magnitude of the negative part divided by
            the magnitude of the gradient's minimum
    """
    pos_scale = gradient.max()
    neg_scale = -gradient.min()
    # When a scale is not positive the corresponding part is all zeros, so
    # dividing by 1 keeps it at zero instead of producing 0/0 = NaN.
    pos_saliency = np.maximum(0, gradient) / (pos_scale if pos_scale > 0 else 1)
    neg_saliency = np.maximum(0, -gradient) / (neg_scale if neg_scale > 0 else 1)
    return pos_saliency, neg_saliency


def get_example_params(example_index):
    """
        Bundles everything the visualisations need for one built-in example:
        the image, its preprocessed tensor, the class to explain, an export
        name and a model.
    Args:
        example_index (int): Image id to use from examples
    returns:
        original_image (PIL img): Original image read from the file
        prep_img (torch tensor): Processed image
        target_class (int): Target class for the image
        file_name_to_export (string): File name to export the visualizations
        pretrained_model(Pytorch model): Model to use for the operations
    """
    # (image path, target class) pairs of the available examples
    examples = (('input_images/snake.jpg', 56),
                ('input_images/cat_dog.png', 243),
                ('input_images/spider.png', 72))
    img_path, target_class = examples[example_index]
    # Bare file name: strip the directory part and the extension
    file_name_to_export = img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]
    # Load as RGB, then build the network input from it
    original_image = Image.open(img_path).convert('RGB')
    prep_img = preprocess_image(original_image)

    print("program flow")

    # AlexNet with pretrained weights serves as the example model
    pretrained_model = models.alexnet(pretrained=True)
    return (original_image, prep_img, target_class,
            file_name_to_export, pretrained_model)


def normalize_gradient_image(gradient):
    """
        Min-max normalises a gradient array to the range 0-1.
    Args:
        gradient (numpy arr): Gradient to normalise; it is not modified
    returns:
        gradient (numpy arr): New array with minimum 0 and maximum 1 (all
            zeros when the input is constant)
    """
    shifted = gradient - gradient.min()
    peak = shifted.max()
    # Not done in place so integer arrays are accepted, and a constant input
    # (peak 0) gives zeros instead of 0/0 = NaN.
    return shifted / (peak if peak != 0 else 1)