In [1]:
pip install '../input/rsna-monai-packages/monai-0.6.0-202107081903-py3-none-any.whl'

Processing /kaggle/input/rsna-monai-packages/monai-0.6.0-202107081903-py3-none-any.whl
Installing collected packages: monai
Successfully installed monai-0.6.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
import glob
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [3]:
import albumentations as A
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score, accuracy_score
from torch.optim import lr_scheduler
from tqdm import tqdm
import re

In [4]:
NUM_IMAGES_3D = 64
TRAINING_BATCH_SIZE = 8
TEST_BATCH_SIZE = 8
IMAGE_SIZE = 256
N_EPOCHS = 15
do_valid = True
n_workers = 4
type_ = "FLAIR"
MODEL_NAME = 'version2.FLAIR'

In [5]:
def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array

    if rotate > 0:
        rot_choices = [
            0,
            cv2.ROTATE_90_CLOCKWISE,
            cv2.ROTATE_90_COUNTERCLOCKWISE,
            cv2.ROTATE_180,
        ]
        data = cv2.rotate(data, rot_choices[rotate])

    data = cv2.resize(data, (img_size, img_size))
    return data

In [6]:
import random

import cv2
from torch.utils.data import Dataset


class BrainRSNADataset(Dataset):
    def __init__(
        self, data, transform=None, target="MGMT_value", mri_type="FLAIR", is_train=True
    ):
        self.target = target
        self.data = data
        self.type = mri_type

        self.transform = transform
        self.is_train = is_train
        self.folder = "train" #if self.is_train else "test"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.loc[index]
        case_id = int(row.BraTS21ID)
        target = int(row[self.target])
        _3d_images = self.load_dicom_images_3d(case_id)
        _3d_images = torch.tensor(_3d_images).float()
        if self.is_train:
            return {"image": _3d_images, "target": target, "case_id": case_id}
        else:
            return {"image": _3d_images, "case_id": case_id}

    def load_dicom_images_3d(
        self,
        case_id,
        num_imgs=NUM_IMAGES_3D,
        img_size=IMAGE_SIZE,
        rotate=0,
    ):
        case_id = str(case_id).zfill(5)

        path = f"../input/rsna-miccai-brain-tumor-radiogenomic-classification/{self.folder}/{case_id}/{self.type}/*.dcm"
        files = sorted(
            glob.glob(path),
            key=lambda var: [
                int(x) if x.isdigit() else x for x in re.findall(r"[^0-9]|[0-9]+", var)
            ],
        )

        middle = len(files) // 2
        num_imgs2 = num_imgs // 2
        p1 = max(0, middle - num_imgs2)
        p2 = min(len(files), middle + num_imgs2)
        
        image_stack = [load_dicom_image(f, rotate=rotate) for f in files[p1:p2]]
        img3d = np.stack(image_stack).T
        if img3d.shape[-1] < num_imgs:
            n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
            img3d = np.concatenate((img3d, n_zero), axis=-1)

        if np.min(img3d) < np.max(img3d):
            img3d = img3d - np.min(img3d)
            img3d = img3d / np.max(img3d)

        return np.expand_dims(img3d, 0)



In [7]:
ls ../input/

[0m[01;34mmonai-v060-deep-learning-in-healthcare-imaging[0m/
[01;34mresnet10rsna[0m/
[01;34mrsna-miccai-brain-tumor-radiogenomic-classification[0m/
[01;34mrsna-monai-packages[0m/


In [8]:
import monai

# model 
model = monai.networks.nets.resnet10(spatial_dims=3, n_input_channels=1, n_classes=1)
device = torch.device("cuda")
model.to(device);
all_weights = os.listdir("../input/resnet10rsna")
fold_files = [f for f in all_weights if type_ in f]
criterion = nn.BCEWithLogitsLoss()

In [9]:
import argparse

import pandas as pd
from sklearn.model_selection import StratifiedKFold

# parser = argparse.ArgumentParser()
# parser.add_argument("--n_folds", default=5, type=int)
# args = parser.parse_args()

train = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=518)
oof = []
targets = []
target = "MGMT_value"

for fold, (trn_idx, val_idx) in enumerate(
    skf.split(train, train[target])
):
    train.loc[val_idx, "fold"] = int(fold)


train.to_csv("train.csv", index=False)

In [10]:
train.head()

Unnamed: 0,BraTS21ID,MGMT_value,fold
0,0,1,1.0
1,2,1,2.0
2,3,0,0.0
3,5,1,1.0
4,6,1,2.0


In [11]:
data = pd.read_csv("./train.csv")
curr_fold = 2
train_df = data[data.fold != curr_fold].reset_index(drop=False)
val_df = data[data.fold == curr_fold].reset_index(drop=False)

In [12]:
train_df.head()

Unnamed: 0,index,BraTS21ID,MGMT_value,fold
0,0,0,1,1.0
1,2,3,0,0.0
2,3,5,1,1.0
3,5,8,1,1.0
4,6,9,0,4.0


In [13]:
train_dataset = BrainRSNADataset(data=train_df, mri_type=type_, is_train=True)

valid_dataset = BrainRSNADataset(data=val_df, mri_type=type_, is_train=True)


train_dl = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAINING_BATCH_SIZE,
    shuffle=True,
    num_workers=n_workers,
    drop_last=True,
    pin_memory=True,
)

validation_dl = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=True,
)

In [14]:
model = monai.networks.nets.resnet10(spatial_dims=3, n_input_channels=1, n_classes=1)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.5, last_epoch=-1, verbose=True)

model.zero_grad()
model.to(device)
best_loss = 9999
best_auc = 0
criterion = nn.BCEWithLogitsLoss()
best_model = None
final_thresh = 0.5
for counter in range(N_EPOCHS):

    epoch_iterator_train = tqdm(train_dl)
    tr_loss = 0.0
    for step, batch in enumerate(epoch_iterator_train):
        model.train()
        images, targets = batch["image"].to(device), batch["target"].to(device)

        outputs = model(images)
        targets = targets  # .view(-1, 1)
        loss = criterion(outputs.squeeze(1), targets.float())

        loss.backward()
        optimizer.step()
        model.zero_grad()
        optimizer.zero_grad()

        tr_loss += loss.item()
        epoch_iterator_train.set_postfix(
            batch_loss=(loss.item()), loss=(tr_loss / (step + 1))
        )
    scheduler.step()  # Update learning rate schedule

    if do_valid:
        with torch.no_grad():
            val_loss = 0.0
            preds = []
            true_labels = []
            case_ids = []
            epoch_iterator_val = tqdm(validation_dl)
            for step, batch in enumerate(epoch_iterator_val):
                model.eval()
                images, targets = batch["image"].to(device), batch["target"].to(device)

                outputs = model(images)
                targets = targets  # .view(-1, 1)
                loss = criterion(outputs.squeeze(1), targets.float())
                val_loss += loss.item()
                epoch_iterator_val.set_postfix(
                    batch_loss=(loss.item()), loss=(val_loss / (step + 1))
                )
                preds.append(outputs.sigmoid().detach().cpu().numpy())
                true_labels.append(targets.cpu().numpy())
                case_ids.append(batch["case_id"])
        preds = np.vstack(preds).T[0].tolist()
        true_labels = np.hstack(true_labels).tolist()
        case_ids = np.hstack(case_ids).tolist()
        auc_score = roc_auc_score(true_labels, preds)
        auc_score_adj_best = 0
        for thresh in np.linspace(0, 1, 50):
            auc_score_adj = roc_auc_score(true_labels, list(np.array(preds) > thresh))
            if auc_score_adj > auc_score_adj_best:
                best_thresh = thresh
                auc_score_adj_best = auc_score_adj
                best_acc = accuracy_score(true_labels, list(np.array(preds) > thresh))

        print(
            f"EPOCH {counter}/{N_EPOCHS}: Validation average loss: {val_loss/(step+1)} + AUC SCORE = {auc_score} + AUC SCORE THRESH {best_thresh} = {auc_score_adj_best}"
        )
        print(f'Best Accuracy: {best_acc}')
        if auc_score > best_auc:
            print("Saving the model...")
            final_thresh = best_thresh
            all_files = os.listdir("./")

            for f in all_files:
                if f"{MODEL_NAME}_{type_}_fold{curr_fold}" in f:
                    os.remove(f"./{f}")

            best_auc = auc_score
            best_model = f"./3d-{MODEL_NAME}_{type_}_fold{curr_fold}_{round(best_auc,3)}.pth"
            torch.save(
                model.state_dict(),
                f"./3d-{MODEL_NAME}_{type_}_fold{curr_fold}_{round(best_auc,3)}.pth",
            )

print(best_auc)

  0%|          | 0/58 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 58/58 [02:48<00:00,  2.91s/it, batch_loss=0.69, loss=0.716]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:43<00:00,  2.92s/it, batch_loss=0.958, loss=0.729]


EPOCH 0/15: Validation average loss: 0.7286055366198222 + AUC SCORE = 0.48448477751756447 + AUC SCORE THRESH 0.42857142857142855 = 0.5103922716627635
Best Accuracy: 0.5299145299145299
Saving the model...


100%|██████████| 58/58 [02:37<00:00,  2.71s/it, batch_loss=0.552, loss=0.708]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:38<00:00,  2.59s/it, batch_loss=0.825, loss=0.707]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 1/15: Validation average loss: 0.7065736373265584 + AUC SCORE = 0.5043911007025761 + AUC SCORE THRESH 0.44897959183673464 = 0.5461065573770492
Best Accuracy: 0.5641025641025641
Saving the model...


100%|██████████| 58/58 [02:32<00:00,  2.63s/it, batch_loss=0.65, loss=0.695]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:38<00:00,  2.58s/it, batch_loss=0.484, loss=0.794]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 2/15: Validation average loss: 0.79442444841067 + AUC SCORE = 0.4803864168618267 + AUC SCORE THRESH 0.5306122448979591 = 0.5408372365339579
Best Accuracy: 0.5555555555555556


100%|██████████| 58/58 [02:32<00:00,  2.64s/it, batch_loss=0.779, loss=0.691]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:39<00:00,  2.65s/it, batch_loss=0.861, loss=0.708]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 3/15: Validation average loss: 0.7079754869143168 + AUC SCORE = 0.4973653395784543 + AUC SCORE THRESH 0.36734693877551017 = 0.5275175644028103
Best Accuracy: 0.5470085470085471


100%|██████████| 58/58 [02:32<00:00,  2.62s/it, batch_loss=0.7, loss=0.677]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:38<00:00,  2.56s/it, batch_loss=0.547, loss=0.712]


EPOCH 4/15: Validation average loss: 0.712087853749593 + AUC SCORE = 0.5752341920374707 + AUC SCORE THRESH 0.44897959183673464 = 0.6001170960187353
Best Accuracy: 0.5982905982905983
Saving the model...


100%|██████████| 58/58 [02:29<00:00,  2.58s/it, batch_loss=0.526, loss=0.682]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:40<00:00,  2.73s/it, batch_loss=0.721, loss=0.717]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 5/15: Validation average loss: 0.7169462283452351 + AUC SCORE = 0.5231264637002342 + AUC SCORE THRESH 0.6122448979591836 = 0.5292740046838407
Best Accuracy: 0.5213675213675214


100%|██████████| 58/58 [02:31<00:00,  2.61s/it, batch_loss=0.634, loss=0.685]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:38<00:00,  2.60s/it, batch_loss=0.645, loss=0.699]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 6/15: Validation average loss: 0.6987573822339376 + AUC SCORE = 0.5281030444964872 + AUC SCORE THRESH 0.5102040816326531 = 0.5564988290398127
Best Accuracy: 0.5726495726495726


100%|██████████| 58/58 [02:30<00:00,  2.59s/it, batch_loss=0.754, loss=0.667]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:38<00:00,  2.57s/it, batch_loss=0.532, loss=0.766]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 7/15: Validation average loss: 0.7661856452624003 + AUC SCORE = 0.5687939110070258 + AUC SCORE THRESH 0.3877551020408163 = 0.5898711943793911
Best Accuracy: 0.5982905982905983


100%|██████████| 58/58 [02:30<00:00,  2.59s/it, batch_loss=0.625, loss=0.673]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-04.


100%|██████████| 15/15 [00:38<00:00,  2.57s/it, batch_loss=0.348, loss=0.816]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 8/15: Validation average loss: 0.8158074736595153 + AUC SCORE = 0.5032201405152226 + AUC SCORE THRESH 0.32653061224489793 = 0.5379098360655737
Best Accuracy: 0.5555555555555556


100%|██████████| 58/58 [02:31<00:00,  2.61s/it, batch_loss=0.696, loss=0.66]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 15/15 [00:38<00:00,  2.53s/it, batch_loss=0.85, loss=0.698]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 9/15: Validation average loss: 0.6977291266123454 + AUC SCORE = 0.5708430913348946 + AUC SCORE THRESH 0.5918367346938775 = 0.5642564402810304
Best Accuracy: 0.5555555555555556


100%|██████████| 58/58 [02:29<00:00,  2.57s/it, batch_loss=0.581, loss=0.647]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 15/15 [00:38<00:00,  2.59s/it, batch_loss=0.218, loss=0.779]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 10/15: Validation average loss: 0.7790408005317052 + AUC SCORE = 0.5518149882903982 + AUC SCORE THRESH 0.4081632653061224 = 0.5541569086651054
Best Accuracy: 0.5641025641025641


100%|██████████| 58/58 [02:30<00:00,  2.60s/it, batch_loss=0.622, loss=0.645]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 15/15 [00:38<00:00,  2.56s/it, batch_loss=0.355, loss=0.822]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 11/15: Validation average loss: 0.8221994698047638 + AUC SCORE = 0.5547423887587822 + AUC SCORE THRESH 0.5306122448979591 = 0.5525468384074941
Best Accuracy: 0.5555555555555556


100%|██████████| 58/58 [02:29<00:00,  2.58s/it, batch_loss=0.587, loss=0.647]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 15/15 [00:38<00:00,  2.59s/it, batch_loss=0.619, loss=0.734]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 12/15: Validation average loss: 0.7335605184237163 + AUC SCORE = 0.5357142857142857 + AUC SCORE THRESH 0.5306122448979591 = 0.5749414519906324
Best Accuracy: 0.5811965811965812


100%|██████████| 58/58 [02:33<00:00,  2.64s/it, batch_loss=0.59, loss=0.647]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 15/15 [00:38<00:00,  2.55s/it, batch_loss=0.238, loss=0.976]
  0%|          | 0/58 [00:00<?, ?it/s]

EPOCH 13/15: Validation average loss: 0.9759975999593735 + AUC SCORE = 0.5055620608899297 + AUC SCORE THRESH 0.4897959183673469 = 0.5623536299765808
Best Accuracy: 0.5726495726495726


100%|██████████| 58/58 [02:31<00:00,  2.62s/it, batch_loss=0.652, loss=0.652]
  0%|          | 0/15 [00:00<?, ?it/s]

Adjusting learning rate of group 0 to 5.0000e-05.


100%|██████████| 15/15 [00:38<00:00,  2.58s/it, batch_loss=0.671, loss=0.711]

EPOCH 14/15: Validation average loss: 0.7108534932136535 + AUC SCORE = 0.5377634660421545 + AUC SCORE THRESH 0.2857142857142857 = 0.5668911007025761
Best Accuracy: 0.5811965811965812
0.5752341920374707





In [15]:
tta_true_labels = []
tta_preds = []
test_dataset = BrainRSNADataset(data=data, mri_type=type_, is_train=True)
test_dl = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=4
    )

# preds_f = np.zeros(len(sample))
# for fold in range(5):
image_ids = []
model.load_state_dict(torch.load(best_model))
preds = []
labels = []
epoch_iterator_test = tqdm(test_dl)
with torch.no_grad():
    for  step, batch in enumerate(epoch_iterator_test):
        model.eval()
        images = batch["image"].to(device)

        outputs = model(images)
        preds.append(outputs.sigmoid().detach().cpu().numpy())
        image_ids.append(batch["case_id"].detach().cpu().numpy())
        labels.append(batch['target'].detach().cpu().numpy())

#     preds_f += np.vstack(preds).T[0]/5

#     ids_f = np.hstack(image_ids)

100%|██████████| 585/585 [02:49<00:00,  3.45it/s]


In [16]:
all_preds = []
for batch in preds:
    for pred in batch:
        all_preds.append(pred[0])

In [17]:
all_labels = []
for batch in labels:
    for lab in batch:
        all_labels.append(lab)

In [18]:
assert len(all_preds) == len(all_labels)

In [19]:
from sklearn.metrics import accuracy_score, roc_auc_score

all_preds_thresh = [val >= final_thresh for val in all_preds]
print(accuracy_score(all_labels, all_preds_thresh))
print(roc_auc_score(all_labels, all_preds_thresh))

0.6188034188034188
0.6179492887774471
