In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, Compose

import shutil
import time
from datetime import datetime
import argparse

from tqdm import tqdm
from sklearn import metrics
from dotmap import DotMap

from tensorboardX import SummaryWriter

from PIL import Image
from torchvision import models, transforms

from torch.nn import functional as F
import cv2
import pandas as pd
import os
import ctypes

import shap

import numpy as np

from matplotlib.colors import LinearSegmentedColormap
from sklearn.linear_model import LogisticRegression
import gc

# Train

In [None]:
def train_model(model, train_loader, epoch, num_epochs, optimizer, writer, current_lr, log_every=100):
    """Train ``model`` for one epoch and log per-step metrics.

    Args:
        model: Network invoked as ``model.forward(image.float())``.
        train_loader: Iterable of ``(image, label, weight)`` batches. Only the
            first element of ``label`` and ``weight`` is used, so the loader
            is expected to run with ``batch_size=1``.
        epoch: Zero-based index of the current epoch.
        num_epochs: Total number of epochs (only used in the progress line).
        optimizer: Optimizer that updates the model parameters.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        current_lr: Learning rate shown in the progress line.
        log_every: A progress line is printed every ``log_every`` batches.

    Returns:
        Tuple ``(mean_loss, auc)`` for the epoch, both rounded to 4 decimals.
    """
    model.train()

    if torch.cuda.is_available():
        model.cuda()

    y_preds = []
    y_trues = []
    losses = []
    # Neutral fallback; also keeps `auc` defined when the loader is empty.
    auc = 0.5
    steps_per_epoch = len(train_loader)

    for i, (image, label, weight) in enumerate(train_loader):
        optimizer.zero_grad()

        if torch.cuda.is_available():
            image = image.cuda()
            label = label.cuda()
            weight = weight.cuda()

        # Drop the batch dimension added by the loader.
        label = label[0]
        weight = weight[0]

        prediction = model.forward(image.float())

        loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)

        probas = torch.sigmoid(prediction)

        # Column 1 of the one-hot label / prediction is the positive class.
        y_trues.append(int(label[0][1]))
        y_preds.append(probas[0][1].item())

        try:
            auc = metrics.roc_auc_score(y_trues, y_preds)
        except ValueError:
            # Raised while only one class has been seen so far.
            auc = 0.5
        if np.isnan(auc):
            # Some scikit-learn versions return NaN (with a warning) instead
            # of raising when a single class is present.
            auc = 0.5

        global_step = epoch * steps_per_epoch + i
        writer.add_scalar('Train/Loss', loss_value, global_step)
        writer.add_scalar('Train/AUC', auc, global_step)

        if i > 0 and i % log_every == 0:
            print(
                '''[Epoch: {0} / {1} |Single batch number : {2} / {3} ]| avg train loss {4} | train auc : {5} | lr : {6}'''.
                format(
                    epoch + 1,
                    num_epochs,
                    i,
                    steps_per_epoch,
                    np.round(np.mean(losses), 4),
                    np.round(auc, 4),
                    current_lr
                )
            )

    # Epoch-level scalar is indexed by the epoch itself (the batch index does
    # not belong on this axis and is undefined for an empty loader).
    writer.add_scalar('Train/AUC_epoch', auc, epoch)

    train_loss_epoch = np.round(np.mean(losses), 4)
    train_auc_epoch = np.round(auc, 4)
    return train_loss_epoch, train_auc_epoch


class MRNet(nn.Module):
    """Slice-wise AlexNet feature extractor with a max-over-slices classifier.

    Every slice is passed through the AlexNet convolutional features and
    globally average-pooled to a 256-vector; the slice vectors are reduced
    with an element-wise maximum and fed to a 2-output linear layer.
    """

    def __init__(self, backbone="alexnet"):
        # NOTE: `backbone` is accepted but not consulted; AlexNet with
        # ImageNet weights is always instantiated.
        super().__init__()
        self.pretrained_model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        self.classifer = nn.Linear(256, 2)

    def forward(self, x):
        """Return the 2-way logits for ``x``.

        5-D and 6-D inputs are treated as batches (the first axis is the
        batch, the last four axes are ``slices, c, h, w``); any other rank
        is treated as a single exam whose leading axis is squeezed if it
        has size 1.
        """
        if x.dim() in (5, 6):
            # Batched input: fold batch and slices together for the CNN.
            batch_size = x.size(0)
            num_slices, c, h, w = x.shape[-4:]
            stacked = x.view(-1, c, h, w)
            pooled = self.pooling_layer(self.pretrained_model.features(stacked))
            per_slice = pooled.view(batch_size, num_slices, -1)
            reduced = per_slice.max(dim=1).values
        else:
            # Single exam: slices already sit on the leading axis.
            stacked = x.squeeze(0)
            pooled = self.pooling_layer(self.pretrained_model.features(stacked))
            per_slice = pooled.view(pooled.shape[0], -1)
            reduced = per_slice.max(dim=0, keepdim=True).values

        return self.classifer(reduced)


class MRDataset(data.Dataset):
    """Dataset of per-exam ``.npy`` volumes with binary labels read from CSV.

    Files are expected at ``<root_dir><split>/<plane>/<id>.npy`` and labels
    at ``<root_dir><split>-<task>.csv`` (two columns: id, label), where
    ``split`` is ``train`` or ``valid``. ``root_dir`` is concatenated as a
    plain string, so it must end with a path separator.
    """

    def __init__(self, root_dir, task, plane, train=True, transform=None, weights=None):
        super().__init__()
        self.task = task
        self.plane = plane
        self.root_dir = root_dir
        self.train = train

        split = 'train' if self.train else 'valid'
        if not self.train:
            # The validation split never uses the supplied transform.
            transform = None
        self.transform = transform

        self.folder_path = self.root_dir + '{0}/{1}/'.format(split, plane)
        self.records = pd.read_csv(
            self.root_dir + '{0}-{1}.csv'.format(split, task),
            header=None,
            names=['id', 'label'],
        )

        # Ids are left-padded with zeros to four characters to match file names.
        self.records['id'] = self.records['id'].map(lambda i: str(i).rjust(4, '0'))
        self.paths = [
            self.folder_path + exam_id + '.npy'
            for exam_id in self.records['id'].tolist()
        ]
        self.labels = self.records['label'].tolist()

        if weights is not None:
            self.weights = torch.FloatTensor(weights)
        else:
            # Default: weight the positive class by the negative/positive ratio.
            positives = np.sum(self.labels)
            negatives = len(self.labels) - positives
            self.weights = torch.FloatTensor([1, negatives / positives])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(volume, one_hot_label, class_weights)`` for one exam."""
        volume = np.load(self.paths[index])

        target = self.labels[index]
        if target == 1:
            target = torch.FloatTensor([[0, 1]])
        elif target == 0:
            target = torch.FloatTensor([[1, 0]])

        if not self.transform:
            # No transform: replicate each slice into three channels on axis 1.
            volume = torch.FloatTensor(np.stack([volume, volume, volume], axis=1))
        else:
            volume = self.transform(volume)

        return volume, target, self.weights


def evaluate_model(model, val_loader, epoch, num_epochs, writer, current_lr, log_every=20):
    """Evaluate ``model`` on one pass over ``val_loader`` and log metrics.

    Args:
        model: Network invoked as ``model.forward(image.float())``.
        val_loader: Iterable of ``(image, label, weight)`` batches. Only the
            first element of ``label`` and ``weight`` is used, so the loader
            is expected to run with ``batch_size=1``.
        epoch: Zero-based index of the current epoch.
        num_epochs: Total number of epochs (only used in the progress line).
        writer: Object exposing ``add_scalar(tag, value, step)``.
        current_lr: Learning rate shown in the progress line.
        log_every: A progress line is printed every ``log_every`` batches.

    Returns:
        Tuple ``(mean_loss, auc)`` for the pass, both rounded to 4 decimals.
    """
    model.eval()

    if torch.cuda.is_available():
        model.cuda()

    y_trues = []
    y_preds = []
    losses = []
    # Neutral fallback; also keeps `auc` defined when the loader is empty.
    auc = 0.5
    steps_per_epoch = len(val_loader)

    # No gradients are needed for validation; skipping autograd bookkeeping
    # saves memory and time without changing the computed values.
    with torch.no_grad():
        for i, (image, label, weight) in enumerate(val_loader):

            if torch.cuda.is_available():
                image = image.cuda()
                label = label.cuda()
                weight = weight.cuda()

            # Drop the batch dimension added by the loader.
            label = label[0]
            weight = weight[0]

            prediction = model.forward(image.float())

            loss = torch.nn.BCEWithLogitsLoss(weight=weight)(prediction, label)

            loss_value = loss.item()
            losses.append(loss_value)

            probas = torch.sigmoid(prediction)

            # Column 1 of the one-hot label / prediction is the positive class.
            y_trues.append(int(label[0][1]))
            y_preds.append(probas[0][1].item())

            try:
                auc = metrics.roc_auc_score(y_trues, y_preds)
            except ValueError:
                # Raised while only one class has been seen so far.
                auc = 0.5
            if np.isnan(auc):
                # Some scikit-learn versions return NaN (with a warning)
                # instead of raising when a single class is present.
                auc = 0.5

            global_step = epoch * steps_per_epoch + i
            writer.add_scalar('Val/Loss', loss_value, global_step)
            writer.add_scalar('Val/AUC', auc, global_step)

            if i > 0 and i % log_every == 0:
                print(
                    '''[Epoch: {0} / {1} |Single batch number : {2} / {3} ] | avg val loss {4} | val auc : {5} | lr : {6}'''.
                    format(
                        epoch + 1,
                        num_epochs,
                        i,
                        steps_per_epoch,
                        np.round(np.mean(losses), 4),
                        np.round(auc, 4),
                        current_lr
                    )
                )

    # Epoch-level scalar is indexed by the epoch itself (the batch index does
    # not belong on this axis and is undefined for an empty loader).
    writer.add_scalar('Val/AUC_epoch', auc, epoch)

    val_loss_epoch = np.round(np.mean(losses), 4)
    val_auc_epoch = np.round(auc, 4)
    return val_loss_epoch, val_auc_epoch

In [None]:
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


def run(args):
    """Train and validate an MRNet for one task/plane combination.

    ``args`` must expose the attributes ``task``, ``plane``, ``prefix_name``,
    ``lr``, ``lr_scheduler``, ``gamma``, ``epochs``, ``patience``,
    ``log_every``, ``flush_history`` and ``save_model``.

    Side effects:
        * TensorBoard logs are written to a timestamped folder under
          ``./logs/<task>/<plane>/``; with ``flush_history == 1`` the existing
          sub-folders there are deleted first.
        * With a truthy ``save_model``, the model with the best validation AUC
          so far is saved in ``./models/`` and older files there whose names
          contain the task, plane and prefix are removed.

    Training stops early once the validation loss has not improved for
    ``patience`` consecutive epochs.
    """
    log_root_folder = "./logs/{0}/{1}/".format(args.task, args.plane)
    # Guard with isdir: on a first run the log folder does not exist yet and
    # os.listdir would raise FileNotFoundError.
    if args.flush_history == 1 and os.path.isdir(log_root_folder):
        for f in os.listdir(log_root_folder):
            if os.path.isdir(log_root_folder + f):
                shutil.rmtree(log_root_folder + f)

    now = datetime.now()
    logdir = log_root_folder + now.strftime("%Y%m%d-%H%M%S") + "/"
    os.makedirs(logdir)

    writer = SummaryWriter(logdir)

    # NOTE: args.augment is not consulted; the training split is always augmented.
    augmentor = Compose([
        transforms.Lambda(lambda x: torch.Tensor(x)),
        RandomRotate(25),
        RandomTranslate([0.11, 0.11]),
        RandomFlip(),
        # Replicate each slice into three channels: (s, h, w) -> (s, 3, h, w).
        transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
    ])

    train_dataset = MRDataset('./data/', args.task,
                              args.plane, transform=augmentor, train=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True, drop_last=False)

    validation_dataset = MRDataset(
        './data/', args.task, args.plane, train=False)
    # Validation order is kept fixed (the original `shuffle=-True` was a typo
    # that evaluated to a truthy -1 and shuffled the validation set).
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=1, shuffle=False, drop_last=False)

    mrnet = MRNet()

    if torch.cuda.is_available():
        mrnet = mrnet.cuda()

    optimizer = optim.Adam(mrnet.parameters(), lr=args.lr, weight_decay=0.1)

    if args.lr_scheduler == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=.3, threshold=1e-4, verbose=True)
    elif args.lr_scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=3, gamma=args.gamma)

    best_val_loss = float('inf')
    best_val_auc = float(0)

    num_epochs = args.epochs
    # Number of consecutive epochs without a new best validation loss.
    iteration_change_loss = 0
    patience = args.patience
    log_every = args.log_every

    t_start_training = time.time()

    for epoch in range(num_epochs):
        current_lr = get_lr(optimizer)

        t_start = time.time()

        train_loss, train_auc = train_model(
            mrnet, train_loader, epoch, num_epochs, optimizer, writer, current_lr, log_every)
        val_loss, val_auc = evaluate_model(
            mrnet, validation_loader, epoch, num_epochs, writer, current_lr)

        if args.lr_scheduler == 'plateau':
            scheduler.step(val_loss)
        elif args.lr_scheduler == 'step':
            scheduler.step()

        t_end = time.time()
        delta = t_end - t_start

        print("train loss : {0} | train auc {1} | val loss {2} | val auc {3} | elapsed time {4} s".format(
            train_loss, train_auc, val_loss, val_auc, delta))

        iteration_change_loss += 1
        print('-' * 30)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            if bool(args.save_model):
                file_name = f'model_{args.prefix_name}_{args.task}_{args.plane}_val_auc_{val_auc:0.4f}_train_auc_{train_auc:0.4f}_epoch_{epoch + 1}.pth'

                os.makedirs('./models/', exist_ok=True)
                # Keep a single checkpoint per (task, plane, prefix).
                for f in os.listdir('./models/'):
                    if (args.task in f) and (args.plane in f) and (args.prefix_name in f):
                        os.remove(f'./models/{f}')

                torch.save(mrnet, f'./models/{file_name}')
                print(f"Saving new best epoch with {val_auc=}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            iteration_change_loss = 0

        if iteration_change_loss == patience:
            print('Early stopping after {0} iterations without the decrease of the val loss'.
                  format(iteration_change_loss))
            break

    t_end_training = time.time()
    print(f'training took {t_end_training - t_start_training} s')


def parse_arguments():
    """Parse the training options from the command line.

    ``--task``, ``--plane`` and ``--prefix_name`` are required; every other
    option has a default.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    binary_flag = {'type': int, 'choices': [0, 1]}

    # (flags, keyword arguments) in the order the options are registered.
    option_table = (
        (('-t', '--task'),
         {'type': str, 'required': True, 'choices': ['abnormal', 'acl', 'meniscus']}),
        (('-p', '--plane'),
         {'type': str, 'required': True, 'choices': ['sagittal', 'coronal', 'axial']}),
        (('--prefix_name',), {'type': str, 'required': True}),
        (('--augment',), dict(binary_flag, default=1)),
        (('--lr_scheduler',),
         {'type': str, 'default': 'plateau', 'choices': ['plateau', 'step']}),
        (('--gamma',), {'type': float, 'default': 0.5}),
        (('--epochs',), {'type': int, 'default': 50}),
        (('--lr',), {'type': float, 'default': 1e-5}),
        (('--flush_history',), dict(binary_flag, default=0)),
        (('--save_model',), dict(binary_flag, default=1)),
        (('--patience',), {'type': int, 'default': 5}),
        (('--log_every',), {'type': int, 'default': 100}),
    )

    parser = argparse.ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    return parser.parse_args()


def get_args(**kwargs):
    """Return the default training options as a dict, with overrides applied.

    Each keyword argument replaces the default of the same name and is
    announced on stdout. An unknown keyword raises ``KeyError`` because the
    previous value is looked up for the message.
    """
    args = dict(
        task="abnormal",
        plane="sagittal",
        prefix_name="prefix",
        augment=1,
        lr_scheduler="plateau",
        gamma=0.5,
        epochs=50,
        lr=1e-5,
        flush_history=0,
        save_model=1,
        patience=5,
        log_every=100,
    )

    for key in kwargs:
        val = kwargs[key]
        print(f"override {key} from {args[key]} -> {val}")
        args[key] = val

    return args

## Run

In [None]:
# IPython magics: load the TensorBoard notebook extension and open it on the
# `logs` directory (these lines only work inside a notebook/IPython session).
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
# Valid tasks:  'abnormal', 'acl', 'meniscus'

# Valid planes: 'sagittal', 'coronal', 'axial'

# Alternative configurations (one model is trained per plane):
# args = get_args(epochs=30, task="acl", plane='sagittal', patience=7)
# args = get_args(epochs=30, task="acl", plane='coronal', patience=7)
args = get_args(epochs=35, task="acl", plane='axial', patience=7)
# run() reads its options as attributes, so wrap the dict in a DotMap.
mapped = DotMap(args)
run(mapped)

# Global

In [None]:
def extract_predictions(task, plane, train=True):
    """Run the saved model for ``task``/``plane`` over one data split.

    The first file in ``models/`` whose name contains both ``task`` and
    ``plane`` is loaded and applied, one exam at a time and without
    augmentation, to the training split (``train=True``) or the validation
    split (``train=False``) under ``data/``.

    Args:
        task: One of ``'acl'``, ``'meniscus'``, ``'abnormal'``.
        plane: One of ``'axial'``, ``'coronal'``, ``'sagittal'``.
        train: Selects the training split when true, validation otherwise.

    Returns:
        ``(predictions, labels)``: two lists with one entry per exam, the
        sigmoid probability of the positive class and the positive-class
        entry of the one-hot label.

    Raises:
        IndexError: If no matching model file exists in ``models/``.
    """
    assert task in ['acl', 'meniscus', 'abnormal']
    assert plane in ['axial', 'coronal', 'sagittal']

    # Named `candidates` so the module-level `models` import is not shadowed.
    candidates = [name for name in os.listdir('models/')
                  if task in name and plane in name]
    if not candidates:
        raise IndexError(
            f"no saved model for task={task!r}, plane={plane!r} in 'models/'")
    model_path = f'models/{candidates[0]}'

    # Run on the GPU when there is one, otherwise on the CPU; model and
    # inputs are always placed on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torch.load unpickles a whole model object; only load
    # checkpoints from a trusted source.
    mrnet = torch.load(model_path, map_location=device)
    mrnet = mrnet.to(device)
    mrnet.eval()

    dataset = MRDataset('data/',
                        task,
                        plane,
                        transform=None,
                        train=train)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         drop_last=False)
    predictions = []
    labels = []
    with torch.no_grad():
        for image, label, _ in tqdm(loader):
            logit = mrnet(image.to(device))
            prediction = torch.sigmoid(logit)
            predictions.append(prediction[0][1].item())
            labels.append(label[0][0][1].item())

    return predictions, labels

In [None]:
# Quick check: validation-split predictions of the axial ACL model.
predictions, labels = extract_predictions("acl", "axial", train=False)

In [None]:
# Fit the plane-fusion model: one feature per plane (the per-plane model's
# probability on the training split), target = exam label.
task = 'acl'
results = {}

for plane in ('axial', 'coronal', 'sagittal'):
    predictions, labels = extract_predictions(task, plane)
    results.update({'labels': labels, plane: predictions})

# Columns are ordered axial, coronal, sagittal.
X = np.column_stack([results['axial'], results['coronal'], results['sagittal']])

y = np.array(labels)

logreg = LogisticRegression(solver='lbfgs')
logreg.fit(X, y)

In [None]:
# Collect the per-plane probabilities and labels on the validation split.
task = 'acl'
results_val = {}

for plane in ('axial', 'coronal', 'sagittal'):
    predictions, labels = extract_predictions(task, plane, train=False)
    results_val.update({'labels': labels, plane: predictions})

In [None]:
# Validation feature matrix, columns ordered axial, coronal, sagittal.
X_val = np.column_stack(
    [results_val[plane_name] for plane_name in ('axial', 'coronal', 'sagittal')]
)

y_val = np.array(results_val['labels'])

In [None]:
# Fused positive-class probability per validation exam, and its ROC AUC.
y_pred = logreg.predict_proba(X_val)[:, 1]
metrics.roc_auc_score(y_val, y_pred)

# XAI mrnet

In [None]:
# Select which saved model and which validation data the CAM cells work on.
task = 'acl'
plane = 'sagittal'  # 'axial', 'coronal', 'sagittal'
prefix = ''  # empty string matches every file name

# First file in models/ whose name contains task, plane and prefix.
model_name = [name for name in os.listdir('models/')
              if (task in name) and
              (plane in name) and
              (prefix in name)][0]
print(f"loading model {model_name}")

is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine (and vice versa); without it torch.load tries to restore the
# tensors on the device they were saved from.
mrnet = torch.load(f'models/{model_name}', map_location=device)
mrnet = mrnet.to(device)

_ = mrnet.eval()

# Validation split, unaugmented, one exam per batch, in file order.
dataset = MRDataset(
    'data/',
    task,
    plane,
    transform=None,
    train=False
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=False
)


def returnCAM(feature_conv, weight_softmax, class_idx, size_upsample=(256, 256)):
    """Build class-activation maps for every slice and requested class.

    Args:
        feature_conv: Array of shape ``(slices, channels, h, w)`` holding
            convolutional activations.
        weight_softmax: Array indexed by class id; each row has ``channels``
            entries (the classifier weights for that class).
        class_idx: Iterable of class ids to build maps for.
        size_upsample: Output size handed to ``cv2.resize``.

    Returns:
        List of ``uint8`` maps scaled to 0..255 and resized to
        ``size_upsample``, ordered by slice and, within a slice, by the
        order of ``class_idx``.
    """
    bz, nc, h, w = feature_conv.shape
    slice_cams = []
    for s in range(bz):
        # Flatten the spatial grid once per slice (same for every class).
        flat_features = feature_conv[s].reshape((nc, h * w))
        for idx in class_idx:
            cam = weight_softmax[idx].dot(flat_features).reshape(h, w)
            cam = cam - np.min(cam)
            peak = np.max(cam)
            # A constant map has peak 0; dividing by it would yield NaNs and
            # an undefined uint8 cast, so emit an all-zero map instead.
            cam_img = cam / peak if peak > 0 else np.zeros_like(cam)
            cam_img = np.uint8(255 * cam_img)
            slice_cams.append(cv2.resize(cam_img, size_upsample))
    return slice_cams

In [None]:
def find_last_conv_layer(model):
    """Return ``(module, name)`` of the last ``nn.Conv2d`` in ``model``.

    Modules are visited in ``model.named_modules()`` order and each one is
    printed along the way. ``(None, None)`` is returned when the model has
    no ``nn.Conv2d``.
    """
    latest = (None, None)

    for name, module in model.named_modules():
        print(f"{name=}, {module}")
        if not isinstance(module, nn.Conv2d):
            continue
        # Later matches overwrite earlier ones, leaving the last conv layer.
        latest = (module, name)

    return latest


# Locate the last Conv2d of the loaded model; the bare `res` displays the
# (module, name) pair as the cell output.
res = find_last_conv_layer(mrnet)
res

In [None]:
# finalconv_name = "pretrained_model"
# 
def hook_feature(module, input, output):
    """Forward hook: store the hooked layer's output as a NumPy array.

    The array is appended to the module-level ``feature_blobs`` list;
    ``module`` and ``input`` are part of the hook signature but unused.
    """
    activation = output.data.cpu().numpy()
    feature_blobs.append(activation)

#     
# mrnet._modules.get(finalconv_name).register_forward_hook(hook_feature);

In [None]:
# Start with an empty activation buffer and capture the output of
# `pretrained_model.features.10` on every forward pass.
# NOTE: re-running this cell registers one more hook on the same layer.
feature_blobs = list()
mrnet.get_submodule("pretrained_model.features.10").register_forward_hook(hook_feature);

In [None]:
# Materialise the validation loader as a list of dicts: the exam volume, its
# positive-class label and a zero-padded four-character id derived from the
# position in the loader.
patients = []

for i, (image, label, _) in tqdm(enumerate(loader), total=len(loader)):
    patients.append({
        'mri': image,
        'label': label[0][0][1].item(),
        'id': str(i).rjust(4, '0'),
    })

acl = [patient for patient in patients if patient['label'] == 1]
no_acl = [patient for patient in patients if patient['label'] == 0]

len(acl), len(no_acl)

In [None]:
# Experiment: feed two exams through the network as one batch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): this rebinds the global `mrnet` to a freshly constructed MRNet,
# replacing the checkpoint loaded earlier. A forward hook registered on the
# previous object is not attached to this one, and create_patiens_cam() reads
# the global `mrnet` - re-run the loading and hook cells before building CAMs.
mrnet = MRNet()
mrnet = mrnet.to(device)

_ = mrnet.eval()

# NOTE(review): `case` is rebound to a patient dict here, while the viewer
# cell uses a global of the same name as a case-id string.
case = patients[94]
mri = case['mri']
mri = mri.to(device)

case = patients[102]
mri2 = case['mri']
mri2 = mri2.to(device)

# torch.stack requires both tensors to have identical shapes, so this only
# works when the two exams have the same number of slices.
batched_mri = torch.stack([mri, mri2], dim=0)
batched_mri.shape

In [None]:
# Forward the stacked exams; the cell output shows the resulting logits.
mrnet(batched_mri)

In [None]:
# Print the volume shape of every exam next to its position in `patients`.
for i, patient in enumerate(patients):
    print(np.shape(patient['mri']), i)

In [None]:
# Shape of the first exam's volume tensor.
np.shape(patients[0]["mri"])

In [None]:
def create_patiens_cam(case, plane):
    """Write every slice of one exam, and its CAM overlay, as PNG files.

    Output goes to ``./CAMS/<plane>/<id>/slices/<s>.png`` and
    ``./CAMS/<plane>/<id>/cams/<s>.png``; an existing folder for the exam is
    deleted first.

    Relies on module-level state: ``mrnet`` (the model, with a forward hook
    that appends to ``feature_blobs``), ``device`` and ``returnCAM``.

    Args:
        case: Dict with the exam ``'id'`` and its ``'mri'`` tensor. The tensor
            is indexed as ``mri[0][s]`` and ``mri.shape[1]`` is taken as the
            slice count (assumes a ``(1, slices, 3, H, W)`` layout with
            256x256 slices - TODO confirm).
        plane: Plane name used in the output path.
    """
    global feature_blobs

    patient_id = case['id']
    mri = case['mri']

    folder_path = f'./CAMS/{plane}/{patient_id}/'
    if os.path.isdir(folder_path):
        shutil.rmtree(folder_path)
    os.makedirs(folder_path)
    os.makedirs(folder_path + 'slices/')
    os.makedirs(folder_path + 'cams/')

    # Second-to-last parameter tensor: the weights of the final linear layer.
    params = list(mrnet.parameters())
    weight_softmax = np.squeeze(params[-2].cpu().data.numpy())

    num_slices = mri.shape[1]
    # Reset the buffer so it only holds activations of this forward pass.
    feature_blobs = []
    mri = mri.to(device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logit = mrnet(mri)

    # Class ids sorted by descending probability; only the top one is used.
    h_x = F.softmax(logit, dim=1).data.squeeze(0)
    _, idx = h_x.sort(0, True)
    idx = idx.cpu().numpy()
    slice_cams = returnCAM(feature_blobs[-1], weight_softmax, idx[:1])

    # Copy the volume to host memory once instead of once per slice.
    mri_cpu = mri.cpu()
    for s in range(num_slices):
        slice_tensor = mri_cpu[0][s]

        slice_pil = transforms.ToPILImage()(slice_tensor / 255)
        slice_pil.save(folder_path + f'slices/{s}.png',
                       dpi=(300, 300))

        # Channels-last layout for blending with the heat map.
        img = slice_tensor.numpy().transpose(1, 2, 0)
        heatmap = cv2.cvtColor(
            cv2.applyColorMap(
                cv2.resize(slice_cams[s], (256, 256)),
                cv2.COLORMAP_JET),
            cv2.COLOR_BGR2RGB)
        result = heatmap * 0.3 + img * 0.6

        pil_img_cam = Image.fromarray(np.uint8(result))
        pil_img_cam.save(folder_path + f'cams/{s}.png', dpi=(300, 300))

In [None]:
# Export slices and CAM overlays for every exam of the currently selected plane.
for person in tqdm(patients):
    create_patiens_cam(person, plane)

In [None]:
# Drawing constants for the viewer; the colour tuples are handed to
# cv2.putText (OpenCV channel order is blue, green, red).
font = cv2.FONT_HERSHEY_SIMPLEX
black = (0, 0, 0)
red = (0, 0, 255)
green = (0, 255, 0)


def center_window(w_, h_):
    """Center the preview window ``win_name`` on the primary screen.

    The screen size comes from the Win32 ``GetSystemMetrics`` API. On
    platforms where ``ctypes`` has no ``windll`` (anything but Windows) the
    call does nothing instead of raising ``AttributeError``.

    Args:
        w_: Width of the window content in pixels.
        h_: Height of the window content in pixels.
    """
    windll = getattr(ctypes, "windll", None)
    if windll is None:
        return

    user32 = windll.user32
    # Metrics 0 and 1 are the primary screen's width and height.
    w_s, h_s = user32.GetSystemMetrics(0), user32.GetSystemMetrics(1)
    cv2.moveWindow(win_name, int((w_s - w_) / 2), int((h_s - h_) / 2) - 30)


def get_images(source_pth, patient):
    """Load one exam's CAM overlays and raw slices in slice order.

    Files are read from ``<source_pth>/<patient>/cams`` and
    ``<source_pth>/<patient>/slices``; names are sorted by the integer that
    precedes the four-character extension and the two listings are paired
    with ``zip`` (so extra files in the longer folder are ignored).

    Returns:
        ``(imgs_cam, imgs_slice)``: two equally long lists of images as
        returned by ``cv2.imread(..., cv2.IMREAD_COLOR)``.
    """
    def slice_number(file_name):
        return int(file_name[:-4])

    cam_dir = os.path.join(source_pth, patient, "cams")
    slice_dir = os.path.join(source_pth, patient, "slices")

    cam_names = os.listdir(cam_dir)
    slice_names = os.listdir(slice_dir)
    cam_names = sorted(cam_names, key=slice_number)
    slice_names = sorted(slice_names, key=slice_number)

    imgs_cam, imgs_slice = [], []
    for cam_name, slice_name in zip(cam_names, slice_names):
        imgs_cam.append(
            cv2.imread(os.path.join(cam_dir, cam_name), cv2.IMREAD_COLOR))
        imgs_slice.append(
            cv2.imread(os.path.join(slice_dir, slice_name), cv2.IMREAD_COLOR))

    return imgs_cam, imgs_slice


def get_shap_images(source_pth, patient):
    """Load one exam's SHAP images in slice order.

    The exam folder is ``<source_pth>/<int(patient)>`` (the id without its
    zero padding); file names are sorted by the integer that precedes the
    four-character extension.

    Returns:
        List of images as returned by ``cv2.imread(..., cv2.IMREAD_COLOR)``.
    """
    exam_dir = os.path.join(source_pth, str(int(patient)))
    ordered_names = sorted(os.listdir(exam_dir), key=lambda x: int(x[:-4]))
    return [
        cv2.imread(os.path.join(exam_dir, file_name), cv2.IMREAD_COLOR)
        for file_name in ordered_names
    ]


def paste(bg_, img_, row, col):
    """Copy ``img_`` into the 256x256 grid cell ``(row, col)`` of ``bg_``.

    ``row`` and ``col`` are 1-based; consecutive cells are separated by the
    module-level ``pad``. ``bg_`` is modified in place.
    """
    top = pad * row + 256 * (row - 1)
    left = pad * col + 256 * (col - 1)
    bg_[top:top + 256, left:left + 256] = img_


def add_text(header_, txt, row, show_plane: str = None):
    """Write ``txt`` above grid column ``row`` of the header image.

    When ``show_plane`` is given, a second line ``"<plane>:<value>"`` is
    drawn underneath, where the value is the ``show_plane`` column of
    ``df_results`` for the current global ``case``, rounded to 2 decimals.
    ``header_`` is modified in place.
    """
    # Despite its name, `row` selects the horizontal (column) position.
    pos_x = pad * row + 90 + 256 * (row - 1)
    pox_y = 55
    cv2.putText(header_, txt, (pos_x, pox_y), font, 1, black, 1)

    if not show_plane:
        return

    plane_value = df_results.iloc[int(case)][show_plane]
    caption = f"{show_plane}:{str(round(plane_value, 2))}"
    (text_width, _), _ = cv2.getTextSize(caption, font, 1, 1)
    # Center the caption around a point 40 px right of the first line's origin.
    caption_origin = (int(pos_x + 40 - text_width / 2), pox_y + 30)
    cv2.putText(header_, caption, caption_origin, font, 1, black, 1)


def get_header():
    """Render the status row and header strip shown above the image grid.

    Reads the viewer's module-level state: the current ``case``, the shown
    slice indices and loaded CAM lists per plane, ``df`` (ground truth) and
    ``df_results`` (predictions).

    Returns:
        ``uint8`` image made of a 30-pixel status row ("True Positive",
        "False Negative", ...) stacked on top of a 100-pixel header, both
        ``win_w`` pixels wide.
    """

    def put_centered(canvas, text, shift, baseline_y, color):
        # Draw `text` centered on the canvas midpoint moved by `shift` pixels.
        (text_width, _), _ = cv2.getTextSize(text, font, 1, 1)
        origin_x = int(win_w / 2 - text_width / 2 + shift)
        cv2.putText(canvas, text, (origin_x, baseline_y), font, 1, color, 1)

    header_ = np.full((100, win_w, 3), 255, dtype=np.uint8)

    # "<shown slice>/<last slice>" plus the per-plane score above each column.
    add_text(header_, f"{show_a_id}/{len(images_cam_a) - 1}", row=1, show_plane="axial")
    add_text(header_, f"{show_c_id}/{len(images_cam_c) - 1}", row=2, show_plane='coronal')
    add_text(header_, f"{show_s_id}/{len(images_cam_s) - 1}", row=3, show_plane='sagittal')

    top_line_y = 25
    put_centered(header_, f"case:{case}", 0, top_line_y, black)

    ground_truth = str(df.iloc[int(case)]["label"])
    put_centered(header_, f"label:{ground_truth}", -256, top_line_y, black)

    val = df_results.iloc[int(case)]["preds"]
    predicted_positive = val > 0.5
    put_centered(
        header_,
        f"pred:{str(round(val, 2))}",
        256 + 10,
        top_line_y,
        red if predicted_positive else green,
    )

    status = np.full((30, win_w, 3), 255, dtype=np.uint8)
    pred_lbl = 1 if predicted_positive else 0
    ground_truth = int(ground_truth)
    is_correct = ground_truth == pred_lbl
    if ground_truth == 0:
        txt = "True Negative" if is_correct else "False Positive"
    else:  # any non-zero label counts as positive
        txt = "True Positive" if is_correct else "False Negative"
    put_centered(status, f"{txt}", 0, 25, green if is_correct else red)

    return np.vstack([status, header_])


def add_sidebar(img_merge):
    """Return ``img_merge`` with a 110-pixel white label column on its left.

    The column names the three grid rows (CAMs, IMG, SHAPs) and shows the
    active SHAP source (module-level ``source_shap``) in small red type.
    """
    sidebar = np.full((img_merge.shape[0], 110, 3), 255, dtype=np.uint8)

    # (text, origin, font scale, colour, thickness)
    captions = (
        ("CAMs", (15, 270), 1, black, 2),
        ("IMG", (25, 530), 1, black, 2),
        ("SHAPs", (10, 790), 1, black, 2),
        (f"V:{source_shap}", (10, 830), 0.4, red, 1),
    )
    for text, origin, font_scale, color, thickness in captions:
        cv2.putText(sidebar, text, origin, font, font_scale, color, thickness)

    return np.hstack([sidebar, img_merge])


def fill_str(txt):
    """Return ``txt`` as a string left-padded with ``"0"`` to 4 characters.

    Non-string values are converted with ``str`` first; strings of four or
    more characters come back unchanged.
    """
    text = txt if type(txt) is str else str(txt)
    return text.rjust(4, "0")


# ---------------------------------------------------------------- viewer setup
# Ground-truth labels of the validation split.
df = pd.read_csv('data/valid-acl.csv', header=None, names=['id', 'label'])
# NOTE(review): presumably shifts the validation ids so they start at 0 -
# confirm 1130 against the data set.
df["id"] = df["id"] - 1130

# Per-plane and fused predictions; the file's own header row is replaced by
# the column names given here.
df_results = pd.read_csv(
    'data/results.csv', header=0, names=['id', 'labels', 'axial', 'coronal', 'sagittal', 'preds']
)

win_name = "preview"
cv2.namedWindow(win_name)

# Canvas geometry: a 3x3 grid of 256x256 tiles separated by `pad` pixels.
pad = 10
img_h, img_w = 256, 256
n_rows, n_cols = 3, 3
win_h = img_h * n_rows + (n_rows + 1) * pad
win_w = img_w * n_cols + (n_cols + 1) * pad

# bg = np.zeros((win_h, win_w, 3), dtype=np.uint8)
# White background that the tiles are pasted onto.
bg = np.full((win_h, win_w, 3), 255, dtype=np.uint8)

# -------------------------------- CAMS
source = "CAMS"
source_a = os.path.join(source, "axial")
source_c = os.path.join(source, "coronal")
source_s = os.path.join(source, "sagittal")

# One sub-folder per exam; folder names must be integers for the sort below.
cases_a = os.listdir(source_a)
cases_c = os.listdir(source_c)
cases_s = os.listdir(source_s)

cases_a.sort(key=lambda x: int(x))
cases_c.sort(key=lambda x: int(x))
cases_s.sort(key=lambda x: int(x))

# Viewer state: zoom factor, displayed case and the case requested next.
scale = 1.0
case = "0010"
new_case = case

images_cam_a, images_slice_a = get_images(source_a, case)
images_cam_c, images_slice_c = get_images(source_c, case)
images_cam_s, images_slice_s = get_images(source_s, case)
# Start each plane on its middle slice.
show_a_id, show_c_id, show_s_id = (
    int(len(images_cam_a) / 2),
    int(len(images_cam_c) / 2),
    int(len(images_cam_s) / 2)
)

# -------------------------------- SHAPS

# Available SHAP output folders; the "v" key cycles through them.
shap_sources = ["SHAPS", "SHAPS_V2"]
shap_sources_id = 0
source_shap = shap_sources[shap_sources_id]
source_shap_a = os.path.join(source_shap, "axial")
source_shap_c = os.path.join(source_shap, "coronal")
source_shap_s = os.path.join(source_shap, "sagittal")

cases_shap_a = os.listdir(source_shap_a)
cases_shap_c = os.listdir(source_shap_c)
cases_shap_s = os.listdir(source_shap_s)

cases_shap_a.sort(key=lambda x: int(x))
cases_shap_c.sort(key=lambda x: int(x))
cases_shap_s.sort(key=lambda x: int(x))

images_shap_a = get_shap_images(source_shap_a, case)
images_shap_c = get_shap_images(source_shap_c, case)
images_shap_s = get_shap_images(source_shap_s, case)

# -------------------------------- SHAPS

header = get_header()

# ------------------------------------------------------------ viewer event loop
# Keys (compared against ASCII codes):
#   q / Esc      quit
#   v            switch to the next SHAP source folder
#   1 / 2 / 3    previous slice (axial / coronal / sagittal)
#   7 / 8 / 9    next slice     (axial / coronal / sagittal)
#   4 / 6        previous / next case
#   Enter        type a case id on stdin (blocks until input is given)
#   = / - / 0    zoom in / zoom out / reset zoom
while True:
    key = cv2.waitKey(100) & 0xFF
    if key == ord("q") or key == 27:
        break

    if key == ord("d"):  # d = 100
        print("use debug mode")

    if key == ord("v") and len(shap_sources) > 1:  # v = 118
        # Advance to the next SHAP source, wrapping around at the end.
        if shap_sources_id + 1 < len(shap_sources):
            shap_sources_id += 1
            print("next", shap_sources_id, len(shap_sources))
        else:
            shap_sources_id = 0
            print("first")
        
        source_shap = shap_sources[shap_sources_id]
        source_shap_a = os.path.join(source_shap, "axial")
        source_shap_c = os.path.join(source_shap, "coronal")
        source_shap_s = os.path.join(source_shap, "sagittal")
        
        cases_shap_a = os.listdir(source_shap_a)
        cases_shap_c = os.listdir(source_shap_c)
        cases_shap_s = os.listdir(source_shap_s)
        
        cases_shap_a.sort(key=lambda x: int(x))
        cases_shap_c.sort(key=lambda x: int(x))
        cases_shap_s.sort(key=lambda x: int(x))
        
        images_shap_a = get_shap_images(source_shap_a, case)
        images_shap_c = get_shap_images(source_shap_c, case)
        images_shap_s = get_shap_images(source_shap_s, case)

    # down
    if key == 49 and show_a_id - 1 >= 0:
        show_a_id -= 1
    if key == 50 and show_c_id - 1 >= 0:
        show_c_id -= 1
    if key == 51 and show_s_id - 1 >= 0:
        show_s_id -= 1

    # up
    if key == 55 and show_a_id + 1 < len(images_cam_a):
        show_a_id += 1
    if key == 56 and show_c_id + 1 < len(images_cam_c):
        show_c_id += 1
    if key == 57 and show_s_id + 1 < len(images_cam_s):
        show_s_id += 1

    # case switch key:[ENTER]
    if key == 13:
        min_case_id, max_case_id = "0000", fill_str(len(cases_a) - 1)
        case_tmp = input(f"enter case id in range [{min_case_id}]-[{max_case_id}]")
        # Only numeric ids inside the available range are accepted.
        if case_tmp and case_tmp.isnumeric():
            if 0 <= int(case_tmp) < len(cases_a):
                new_case = fill_str(case_tmp)

    # left right
    if key == 52 and int(case) - 1 >= 0:
        new_case = str(int(case) - 1)
        new_case = fill_str(new_case)
    if key == 54 and int(case) + 1 < len(cases_a):
        new_case = str(int(case) + 1)
        new_case = fill_str(new_case)

    if case != new_case:
        # A different case was requested: reload its images and jump back to
        # the middle slice of every plane.
        case = new_case

        images_cam_a, images_slice_a = get_images(source_a, case)
        images_cam_c, images_slice_c = get_images(source_c, case)
        images_cam_s, images_slice_s = get_images(source_s, case)

        images_shap_a = get_shap_images(source_shap_a, case)
        images_shap_c = get_shap_images(source_shap_c, case)
        images_shap_s = get_shap_images(source_shap_s, case)

        show_a_id, show_c_id, show_s_id = (
            int(len(images_cam_a) / 2),
            int(len(images_cam_c) / 2),
            int(len(images_cam_s) / 2)
        )

    # Grid layout: rows = CAM / slice / SHAP, columns = axial / coronal / sagittal.
    # NOTE(review): the SHAP lists are indexed with the CAM slice index, which
    # assumes every SHAP folder holds as many images as the CAM folder.
    paste(bg, images_cam_a[show_a_id], 1, 1)
    paste(bg, images_slice_a[show_a_id], 2, 1)
    paste(bg, images_shap_a[show_a_id], 3, 1)

    paste(bg, images_cam_c[show_c_id], 1, 2)
    paste(bg, images_slice_c[show_c_id], 2, 2)
    paste(bg, images_shap_c[show_c_id], 3, 2)

    paste(bg, images_cam_s[show_s_id], 1, 3)
    paste(bg, images_slice_s[show_s_id], 2, 3)
    paste(bg, images_shap_s[show_s_id], 3, 3)

    header = get_header()
    img_merge = np.vstack([header, bg])
    img_merge = add_sidebar(img_merge)

    # scale
    if key == 61 and scale + .1 < 3:
        scale += .1
    if key == 45 and scale - .1 > .5:
        scale -= .1
    if key == 48:
        scale = 1.0

    if scale != 1.0:
        width = int(img_merge.shape[1] * scale)
        height = int(img_merge.shape[0] * scale)
        img_merge = cv2.resize(img_merge, (width, height), interpolation=cv2.INTER_CUBIC)
        # Re-center the window for the resized frame.
        center_window(width, height)

    cv2.imshow(win_name, img_merge)

cv2.destroyAllWindows()

# Load/save

In [None]:
# Class balance of the validation labels loaded into `df`.
df['label'].value_counts()

In [None]:
# Class balance of the ACL training labels.
df_train = pd.read_csv('data/train-acl.csv', header=None, names=['id', 'label'])

df_train['label'].value_counts()

In [None]:
# Gather the validation labels, the per-plane probabilities and the fused
# prediction (`y_pred`) into one table, one row per exam.
tmp_results_val = {
    'labels': np.array(results_val['labels']),
    'axial': np.array(results_val['axial']),
    'coronal': np.array(results_val['coronal']),
    'sagittal': np.array(results_val['sagittal']),
    'preds': np.array(y_pred),
}
df_results = pd.DataFrame(tmp_results_val)
df_results

In [None]:
# Uncomment to persist the table (the row index is written as the "id" column):
# df_results.to_csv("data/results.csv", index_label="id")

# Uncomment to reload a previously saved table instead:
# df_results = pd.read_csv('data/results.csv', header=0, names=['id', 'labels', 'axial', 'coronal', 'sagittal'])
df_results

# Create shap images

In [None]:
# Free unreferenced objects before the memory-heavy SHAP computation.
gc.collect()

# Model-name filter (empty string matches every file) and compute device
# used by the SHAP cells below.
prefix = ""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def clear(additional: list = None):
    """Drop a fixed set of names from this module's globals and collect garbage.

    The names ``background_samples``, ``background_tensor``, ``shap_values``,
    ``indexes`` and ``e`` are always targeted; names that are not currently
    defined are silently ignored.

    Args:
        additional: Optional extra global names to remove as well.
    """
    names = ["background_samples", "background_tensor", "shap_values", "indexes", "e"]
    if additional:
        names = names + list(additional)

    module_globals = globals()
    for name in names:
        # pop() with a default is a no-op for names that were never defined.
        module_globals.pop(name, None)

    gc.collect()


def _shap_images_exist(directory, image_count):
    """Return True when ``0.png`` .. ``{image_count - 1}.png`` all exist in ``directory``."""
    return all(
        os.path.exists(os.path.join(directory, f"{img_idx}.png"))
        for img_idx in range(image_count)
    )


def _sum_regions(values, size):
    """Sum a 2-D array over non-overlapping ``size`` x ``size`` tiles.

    Tiles on the bottom/right edge are truncated when a dimension is not a
    multiple of ``size``, so every input cell contributes to exactly one
    output cell.
    """
    rows = -(-values.shape[0] // size)  # ceiling division
    cols = -(-values.shape[1] // size)
    pooled = np.zeros((rows, cols))
    for x in range(0, values.shape[0], size):
        for y in range(0, values.shape[1], size):
            pooled[x // size, y // size] = np.sum(values[x:x + size, y:y + size])
    return pooled


def _blend_shap_overlay(shap_img, mri_img, size, alpha):
    """Blend one MRI slice with a JET heat map of its tile-summed SHAP values.

    Both inputs are indexed channel-first: ``shap_img`` is summed over axis 0
    and ``mri_img`` is transposed to channel-last before blending.
    """
    pooled = _sum_regions(np.sum(shap_img, axis=0), size)
    base = np.transpose(mri_img, (1, 2, 0)).astype(np.uint8)

    # Scale the tile grid back up to the slice size and map it to 0..255.
    mask = cv2.resize(pooled, (base.shape[1], base.shape[0]))
    mask = cv2.normalize(
        mask, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U
    )
    colored = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    return cv2.addWeighted(base, 1 - alpha, colored, alpha, 0)


def _write_patient_overlays(explainer, to_explain, images_in_set, directory, alpha):
    """Explain one patient's volume and save one overlay PNG per slice.

    Kept as a separate function so the SHAP arrays are released on return.
    """
    # NOTE(review): the two-way unpack only succeeds when shap_values() returns
    # exactly two items; confirm against the installed shap version that the
    # first item is the attribution array that should be visualised.
    shap_values, indexes = explainer.shap_values(to_explain, nsamples=images_in_set)

    volume = to_explain.cpu().numpy()[0]
    os.makedirs(directory, exist_ok=True)
    for img_idx, shap_img in enumerate(shap_values[0]):
        blended = _blend_shap_overlay(shap_img, volume[img_idx], region_size, alpha)
        cv2.imwrite(os.path.join(directory, f"{img_idx}.png"), blended)


def create_patiens_shaps(task="acl"):
    """Write SHAP heat-map overlays for every slice of every patient.

    For each plane (sagittal, coronal, axial) the first file in ``models/``
    whose name contains ``task``, the plane and ``prefix`` is loaded. Patients
    are grouped by their slice count so that each group can be concatenated
    into one background tensor for ``shap.GradientExplainer``. Every patient
    of a group is then explained against that background and the overlays are
    saved to ``SHAPS_V3/{plane}/{patient index}/{slice index}.png``. Patients
    and groups whose images already exist are skipped, so an interrupted run
    can be resumed.

    Relies on module-level names defined elsewhere in the notebook:
    ``prefix``, ``device``, ``region_size``, ``MRDataset`` and ``clear``.

    Args:
        task: Substring used to pick the model file; also passed to ``MRDataset``.

    Raises:
        FileNotFoundError: If no file in ``models/`` matches the task, plane
            and prefix.
    """
    # Single root for both the "already done?" checks and the writes, so the
    # resume logic looks at the directory that is actually produced.
    output_root = "SHAPS_V3"
    alpha = 0.3
    planes = ['sagittal', 'coronal', 'axial']

    # NOTE(review): 120 per plane is hard-coded; presumably the size of the
    # evaluation split — confirm against len(dataset).
    pbar = tqdm(total=len(planes) * 120, position=0, leave=True)
    try:
        for plane in planes:
            candidates = [
                name for name in os.listdir('models/')
                if (task in name) and (plane in name) and (prefix in name)
            ]
            if not candidates:
                raise FileNotFoundError(
                    f"No model file in 'models/' matches task={task!r}, "
                    f"plane={plane!r}, prefix={prefix!r}"
                )
            model_name = candidates[0]
            print(f"Loading model for {plane}:{model_name}")

            dataset = MRDataset(
                'data/',
                task,
                plane,
                transform=None,
                train=False
            )
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                shuffle=False,
                num_workers=0,
                drop_last=False
            )

            patients = []
            for i, (image, label, _) in enumerate(loader):
                patients.append({
                    'mri': image,
                    'label': label[0][0][1].item(),
                    'id': str(i).zfill(4),
                })

            # map_location lets a GPU-saved model load on a CPU-only machine,
            # matching the fallback already encoded in `device`.
            model = torch.load(f'models/{model_name}', map_location=device).eval()
            model = model.to(device)

            # Group patient indices by slice count; only equal-shaped volumes
            # can be concatenated into one background tensor.
            groups = {}
            for patient_idx, patient in enumerate(patients):
                groups.setdefault(np.shape(patient['mri'])[1], []).append(patient_idx)
            ordered_groups = sorted(
                groups.items(), key=lambda item: len(item[1]), reverse=True
            )

            for images_in_set, patient_ids in ordered_groups:
                directories = [
                    f"{output_root}/{plane}/{patient_idx}/" for patient_idx in patient_ids
                ]
                if all(_shap_images_exist(d, images_in_set) for d in directories):
                    # Whole group already rendered: skip building the explainer.
                    pbar.update(len(patient_ids))
                    continue

                background_tensor = torch.cat(
                    [patients[patient_idx]['mri'] for patient_idx in patient_ids], dim=0
                ).to(device)
                explainer = shap.GradientExplainer(model, background_tensor)

                for set_idx, directory in enumerate(directories):
                    if not _shap_images_exist(directory, images_in_set):
                        _write_patient_overlays(
                            explainer,
                            background_tensor[set_idx:set_idx + 1],
                            images_in_set,
                            directory,
                            alpha,
                        )
                    pbar.update(1)

                # Drop the local references first; clear() only removes
                # same-named globals and then runs the garbage collector.
                del explainer, background_tensor
                clear()
    finally:
        pbar.close()


# Side length of the square tiles over which SHAP values are summed; read as a
# module-level global inside create_patiens_shaps, so it must be set before the call.
region_size = 32
# Generate the overlays using the default task ("acl").
create_patiens_shaps()