# Brain Tumor Classification with PyTorch⚡Lightning & EfficientNet

The goal of this challenge is to predict the status of a genetic biomarker important for brain cancer treatment (the `MGMT_value` column of `train_labels.csv`).

In [None]:
! pip install -qU pytorch_lightning pydicom rising torchsummary
! pip install -q https://github.com/shijianjian/EfficientNet-PyTorch-3D/archive/refs/heads/master.zip
! pip uninstall -q -y wandb
! ls -l /home/jovyan/work/rsna-miccai-brain-tumor
! nvidia-smi
! mkdir /home/jovyan/temp

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data exploration

The data is split into three cohorts that share the same structure: each independent case has a dedicated folder identified by a five-digit number.
Within each of these “case” folders there are four sub-folders, one for each of the structural multi-parametric MRI (mpMRI) scans, in DICOM format.
The exact mpMRI scans included are:

- **FLAIR**: Fluid Attenuated Inversion Recovery
- **T1w**: T1-weighted pre-contrast
- **T1Gd**: T1-weighted post-contrast
- **T2**: T2-weighted

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Root of the competition data and a scratch folder used for cached volumes.
PATH_DATASET = "/home/jovyan/work/rsna-miccai-brain-tumor"
PATH_TEMP = "/home/jovyan/temp"
# Names of the per-case scan sub-folders. The post-contrast T1 folder is named
# "T1wCE" in the dataset; a misspelled name would be silently skipped by any
# loader that filters on existing directories.
SCAN_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")

df_train = pd.read_csv(os.path.join(PATH_DATASET, "train_labels.csv"))
# Case folders use zero-padded five-digit names, so format the numeric IDs the same way.
df_train["BraTS21ID"] = df_train["BraTS21ID"].apply(lambda i: "%05d" % i)
display(df_train.head())

In [None]:
# Pie chart of how often each MGMT_value label occurs; the returned axes object is discarded.
_= df_train["MGMT_value"].value_counts().plot(kind="pie", title="label distribution")

In [None]:
import glob
import re
import torch
import pydicom
from torch import Tensor
import torch.nn.functional as F
from typing import Optional, Tuple
from pydicom.pixel_data_handlers import apply_voi_lut


def parse_name_index(dcm_path) -> int:
    """Extract the slice number from a DICOM file name such as ``Image-12.dcm``.

    Args:
        dcm_path: path (or bare name) of a file containing ``-<digits>.dcm``.

    Returns:
        The number in front of the ``.dcm`` extension as an integer.

    Raises:
        ValueError: if the name does not follow the expected pattern.
    """
    found = re.match(r".*-(\d+)\.dcm", dcm_path)
    if found is None:
        # fail with the offending path instead of an opaque AttributeError on None
        raise ValueError(f"cannot parse slice index from DICOM path: {dcm_path}")
    return int(found.group(1))


def load_dicom(path_file: str) -> Optional[np.ndarray]:
    """Read one DICOM file and return its pixel data as a float32 array.

    The VOI LUT stored in the file is applied to the raw pixels. When that step
    raises ``RuntimeError`` the error is printed and ``None`` is returned.
    """
    dataset = pydicom.dcmread(path_file)
    # TODO: adjust spacing in particular dimension according DICOM meta
    try:
        pixels = apply_voi_lut(dataset.pixel_array, dataset)
        return pixels.astype(np.float32)
    except RuntimeError as err:
        print(err)
    return None


def load_volume(path_volume: str, percentile: Optional[float] = 0.01) -> Tensor:
    """Load all DICOM slices of one scan folder into a single normalised volume.

    Slices are ordered by the number in their file name; slices that fail to
    decode (``load_dicom`` returns ``None``) are skipped.

    Args:
        path_volume: folder containing the ``*.dcm`` slices of one scan.
        percentile: fraction cut from both intensity tails when choosing the
            normalisation range. ``0`` uses the plain min/max and ``None``
            skips normalisation entirely.

    Returns:
        Float32 tensor of shape ``(rows, cols, n_slices)``.

    Raises:
        ValueError: if the folder yields no readable slice.
    """
    path_slices = sorted(glob.glob(os.path.join(path_volume, '*.dcm')), key=parse_name_index)
    slices = []
    for p_slice in path_slices:
        img = load_dicom(p_slice)
        if img is not None:
            slices.append(img.T)
    if not slices:
        raise ValueError(f"no readable DICOM slices in: {path_volume}")

    # Stack in NumPy first: building a tensor from a list of arrays is very slow.
    stack = np.stack(slices).astype(np.float32)
    volume = torch.from_numpy(stack)
    if percentile is not None:
        # get extreme values
        p_low = float(np.quantile(stack, percentile)) if percentile else float(stack.min())
        p_high = float(np.quantile(stack, 1 - percentile)) if percentile else float(stack.max())
        # normalize; a constant volume has no range, so only shift it to avoid NaN/inf
        span = p_high - p_low
        volume = volume - p_low
        if span > 0:
            volume = volume / span
    # Reverse the axes (slices, cols, rows) -> (rows, cols, slices); `.T` is
    # deprecated for tensors that are not 2D.
    return volume.permute(2, 1, 0)


def interpolate_volume(volume: Tensor) -> Tensor:
    """Resample the third axis so its length equals the smaller of the first two axes.

    The first two axes keep their size. A volume whose third axis already has
    the target length is returned as-is (same object).
    """
    size_0, size_1, size_2 = volume.shape[0], volume.shape[1], volume.shape[2]
    depth = min(volume.shape[:2])
    if depth == size_2:
        return volume
    # F.interpolate works on (batch, channel, *spatial) input
    batched = volume[None, None]
    resampled = F.interpolate(batched, size=(size_0, size_1, depth), mode="trilinear", align_corners=False)
    return resampled[0, 0]


def _tuple_int(t: Tensor) -> tuple:
    return tuple(t.numpy().astype(int))

def resize_volume(volume: Tensor, size: int = 128) -> Tensor:
    shape_old = torch.tensor(volume.shape)
    shape_new = torch.tensor([size] * 3)
    scale = torch.max(shape_old.to(float) / shape_new)
    shape_scale = shape_old / scale
    # print(f"{shape_old} >> {shape_scale} >> {shape_new}")
    vol_ = F.interpolate(volume.unsqueeze(0).unsqueeze(0), size=_tuple_int(shape_scale), mode="trilinear", align_corners=False)[0, 0]
    offset = _tuple_int((shape_new - shape_scale) / 2)
    volume = torch.zeros(*_tuple_int(shape_new), dtype=torch.float32)
    shape_scale = _tuple_int(shape_scale)
    volume[offset[0]:offset[0]+shape_scale[0], offset[1]:offset[1]+shape_scale[1], offset[2]:offset[2]+shape_scale[2]] = vol_
    return volume


def find_dim_min(vec: list, thr: float) -> int:
    """Return the index of the first entry of ``vec`` that is at least ``thr``.

    When no entry reaches the threshold the result is 0.
    """
    return np.argmax(np.array(vec) >= thr)


def find_dim_max(vec: list, thr: float) -> int:
    """Return one past the index of the last entry of ``vec`` that is at least ``thr``.

    When no entry reaches the threshold the result is ``len(vec)``.
    """
    above = np.array(vec) >= thr
    trailing_below = np.argmax(above[::-1])
    return len(above) - trailing_below


def crop_volume(volume: Tensor, thr: float = 1e-6) -> Tensor:
    """Trim the low-intensity margins of a 3D volume.

    For every axis the volume is summed over the two other axes and divided by
    the total voxel count; the crop keeps the span between the first and last
    position whose value is at least ``thr``.
    """
    n_voxels = np.prod(volume.shape)
    profiles = (
        volume.sum(1).sum(-1) / n_voxels,
        volume.sum(0).sum(-1) / n_voxels,
        volume.sum(0).sum(0) / n_voxels,
    )
    bounds = tuple(slice(find_dim_min(p, thr), find_dim_max(p, thr)) for p in profiles)
    return volume[bounds]


def show_volume_slice(axarr_, vol_slice, ax_name: str, v_min_max: tuple = (0., 1.)):
    """Draw one 2D slice together with its intensity profiles.

    ``axarr_[0]`` receives the grayscale image; ``axarr_[1]`` receives the sums
    of the slice along each of its two axes.
    """
    ax_image = axarr_[0]
    ax_image.set_title(f"axis: {ax_name}")
    ax_image.imshow(vol_slice, cmap="gray", vmin=v_min_max[0], vmax=v_min_max[1])

    ax_profile = axarr_[1]
    # row sums are plotted against the reversed row index
    ax_profile.plot(torch.sum(vol_slice, 1), list(reversed(range(vol_slice.shape[0]))))
    ax_profile.plot(list(range(vol_slice.shape[1])), torch.sum(vol_slice, 0))
    ax_profile.set_aspect('equal')
    ax_profile.grid()


def idx_middle_if_none(volume: Tensor, *xyz: Optional[int]):
    """Resolve per-axis indices, replacing every ``None`` with the middle of that axis.

    Each resolved index is asserted to lie inside the matching axis of
    ``volume``. Returns the indices as a list, in the order given.
    """
    resolved = []
    for axis, index in enumerate(xyz):
        axis_len = volume.shape[axis]
        if index is None:
            index = int(axis_len / 2)
        assert 0 <= index < axis_len
        resolved.append(index)
    return resolved


def show_volume(volume: Tensor, x: Optional[int] = None, y: Optional[int] = None, z: Optional[int] = None, fig_size: Tuple[int, int] = (14, 9), v_min_max: tuple = (0., 1.),):
    """Plot three orthogonal slices of a volume, each with its intensity profiles.

    Args:
        volume: 3D tensor to visualise.
        x, y, z: index of the slice taken along each axis; ``None`` selects the middle.
        fig_size: size of the created figure.
        v_min_max: intensity range mapped to the grayscale colormap.

    Returns:
        The created matplotlib figure (2 rows x 3 columns of axes).
    """
    x, y, z = idx_middle_if_none(volume, x, y, z)
    fig, axarr = plt.subplots(nrows=2, ncols=3, figsize=fig_size)
    # report the actually used z index (the message used to repeat y)
    print(f"shape: {volume.shape}, x={x}, y={y}, z={z}  >> {volume.dtype}")
    show_volume_slice(axarr[:, 0], volume[x, :, :], "X", v_min_max)
    show_volume_slice(axarr[:, 1], volume[:, y, :], "Y", v_min_max)
    show_volume_slice(axarr[:, 2], volume[:, :, z], "Z", v_min_max)
    # plt.show(fig)
    return fig

In [None]:
from ipywidgets import interact, IntSlider

def interactive_show(volume_path: str, crop_thr: float):
    """Load, interpolate and crop one scan, then browse it with three index sliders.

    Args:
        volume_path: folder with the DICOM slices of one scan.
        crop_thr: threshold passed to ``crop_volume``.
    """
    print(f"loading: {volume_path}")
    volume = load_volume(volume_path, percentile=0)
    print(f"sample shape: {volume.shape} >> {volume.dtype}")
    volume = interpolate_volume(volume)
    print(f"interp shape: {volume.shape} >> {volume.dtype}")
    volume = crop_volume(volume, crop_thr)
    print(f"crop shape: {volume.shape} >> {volume.dtype}")
    vol_shape = volume.shape

    def _axis_slider(length: int) -> IntSlider:
        # the last valid index is length - 1; a larger value fails the range check in show_volume
        return IntSlider(min=0, max=length - 1, step=5, value=int(length / 2))

    interact(
        lambda x, y, z: plt.show(show_volume(volume, x, y, z)),
        x=_axis_slider(vol_shape[0]),
        y=_axis_slider(vol_shape[1]),
        z=_axis_slider(vol_shape[2]),
    )


# FLAIR scan of training case "00005", used as the example volume for the interactive viewer.
PATH_SAMPLE_VOLUME = os.path.join(PATH_DATASET, "train", "00005", "FLAIR")

# Launch the slider-driven viewer on the sample scan.
interactive_show(PATH_SAMPLE_VOLUME, crop_thr=1e-6)

## Prepare dataset

In [None]:
import os
from typing import Union, Sequence, Optional

import pandas as pd
import torch
from torch.utils.data import Dataset


class BrainScansDataset(Dataset):
    """Dataset of single MRI scans (one item per case and scan type).

    Each item is a dict with the volume under ``"data"`` (a leading channel
    axis is added) and its ``MGMT_value`` under ``"label"``. Items whose label
    is ``None`` return the relative scan path (``<case>/<scan type>``) as label.

    Args:
        image_dir: folder containing one sub-folder per case.
        df_table: DataFrame, or path to a CSV file, with the columns
            ``BraTS21ID`` and ``MGMT_value``.
        scan_types: names of the scan sub-folders to use; missing ones are skipped.
        cache_dir: optional folder where preprocessed volumes are stored as ``.pt`` files.
        crop_thr: threshold passed to ``crop_volume``; ``None`` disables cropping.
        mode: ``'train'`` selects the first part of the shuffled table, any other
            value selects the remaining part.
        split: fraction of the table that belongs to the ``'train'`` part.
        in_memory: keep every loaded volume in memory after its first access.
        random_state: seed used to shuffle the table before splitting.
    """

    def __init__(
        self,
        image_dir: str = 'train',
        df_table: Union[str, pd.DataFrame] = 'train_labels.csv',
        scan_types: Sequence[str] = ("FLAIR", "T2w"),
        cache_dir: Optional[str] = None,
        crop_thr: float = 1e-6,
        mode: str = 'train',
        split: float = 0.8,
        in_memory: bool = False,
        random_state=42,
    ):
        self.image_dir = image_dir
        self.scan_types = scan_types
        self.cache_dir = cache_dir
        self.crop_thr = crop_thr
        self.mode = mode
        self.in_memory = in_memory

        # set or load the config table
        if isinstance(df_table, pd.DataFrame):
            assert all(c in df_table.columns for c in ["BraTS21ID", "MGMT_value"])
            self.table = df_table
        elif isinstance(df_table, str):
            assert os.path.isfile(df_table), f"missing file: {df_table}"
            self.table = pd.read_csv(df_table)
        else:
            raise ValueError(f'unrecognised input for DataFrame/CSV: {df_table}')

        # shuffle data
        self.table = self.table.sample(frac=1, random_state=random_state).reset_index(drop=True)

        # split dataset
        assert 0.0 <= split <= 1.0, f"split {split} is out of range"
        frac = int(split * len(self.table))
        self.table = self.table[:frac] if mode == 'train' else self.table[frac:]

        # populate images/labels
        self.images = []
        self.labels = []
        for _, row in self.table.iterrows():
            id_ = row["BraTS21ID"]
            name = id_ if isinstance(id_, str) else "%05d" % id_
            imgs = [os.path.join(name, tp) for tp in self.scan_types]
            # keep only the scan types that actually exist for this case
            imgs = [p for p in imgs if os.path.isdir(os.path.join(self.image_dir, p))]
            self.images += imgs
            self.labels += [row["MGMT_value"]] * len(imgs)
        assert len(self.images) == len(self.labels)
        # `self.images` entries may later be replaced by tensors (in_memory),
        # so keep the relative paths separately for use as prediction labels.
        self.image_names = list(self.images)

    @staticmethod
    def load_image(rltv_path: str, image_dir: str, cache_dir: Optional[str], crop_thr: Optional[float]):
        """Return the preprocessed volume of one scan, using the cache when possible.

        Args:
            rltv_path: scan path relative to ``image_dir`` (``<case>/<scan type>``).
            image_dir: folder containing the case sub-folders.
            cache_dir: folder holding cached ``.pt`` volumes; falsy disables writing the cache.
            crop_thr: threshold passed to ``crop_volume``; ``None`` disables cropping.
        """
        vol_path = os.path.join(cache_dir or "", f"{rltv_path}.pt")
        if os.path.isfile(vol_path):
            try:
                return torch.load(vol_path)
            except EOFError:
                # truncated cache file: fall through and rebuild it from the DICOM slices
                print(f"failed loading: {vol_path}")
        img_path = os.path.join(image_dir, rltv_path)
        assert os.path.isdir(img_path)
        img = load_volume(img_path)
        img = interpolate_volume(img)
        if crop_thr is not None:
            img = crop_volume(img, thr=crop_thr)
        if cache_dir:
            os.makedirs(os.path.dirname(vol_path), exist_ok=True)
            torch.save(img, vol_path)
        return img

    def _load_image(self, rltv_path: str):
        """Load one scan with this dataset's directories and crop threshold."""
        return BrainScansDataset.load_image(rltv_path, self.image_dir, self.cache_dir, self.crop_thr)

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"data": volume with channel axis, "label": label}`` for item ``idx``."""
        label = self.labels[idx]
        img = self.images[idx]
        if isinstance(img, str):
            img = self._load_image(img)
        if self.in_memory:
            self.images[idx] = img
        # in case of predictions, return image name as label
        if label is None:
            label = self.image_names[idx]
        return {"data": img.unsqueeze(0), "label": label}

    def __len__(self) -> int:
        """Number of (case, scan type) items."""
        return len(self.images)


# ==============================
from tqdm.auto import tqdm

# Build the training part of the dataset without cropping; volumes are cached under PATH_TEMP.
ds = BrainScansDataset(
    image_dir=os.path.join(PATH_DATASET, "train"),
    df_table=os.path.join(PATH_DATASET, "train_labels.csv"),
    crop_thr=None, cache_dir=PATH_TEMP,
)
# Preview items 0 and 10: drop the channel axis, resize to the default cube and plot.
for i in tqdm(range(2)):
    img = ds[i * 10]["data"]
    img = resize_volume(img[0])
    show_volume(img, fig_size=(12, 8))

In [None]:
import logging
from multiprocessing import Pool
from functools import partial
from typing import Optional, Sequence
from pytorch_lightning import LightningDataModule
import rising.transforms as rtr
from rising.loading import DataLoader, default_transform_call
from rising.random import DiscreteParameter, UniformParameter

# define transformations
# Validation batches are only normalised (NormZeroMeanUnitStd on the "data" key).
VAL_TRANSFORMS = [
    rtr.NormZeroMeanUnitStd(keys=["data"]),
]
# Training batches currently get the same normalisation; the augmentations below are disabled.
TRAIN_TRANSFORMS = [
    rtr.NormZeroMeanUnitStd(keys=["data"]),
    # rtr.Rot90((0, 1, 2), keys=["data"], p=0.5),
    # rtr.Mirror(dims=DiscreteParameter([0, 1, 2]), keys=["data"]),
    # rtr.Rotate(UniformParameter(0, 180), degree=True),
]


def rising_resize(size: int = 64, **batch):
    """Resize every volume stored under ``batch["data"]`` to a ``size``-sided cube.

    ``batch["data"]`` must be 4-dimensional; each entry along its first axis is
    passed through ``resize_volume`` and the results are stacked back together.
    Returns the updated batch dict.
    """
    volumes = batch["data"]
    assert len(volumes.shape) == 4
    resized = [resize_volume(vol, size) for vol in volumes]
    batch["data"] = torch.stack(resized, dim=0)
    return batch

# ==============================


class BrainScansDM(LightningDataModule):
    """Lightning data module serving train/validation/test loaders of brain scans.

    Args:
        data_dir: dataset root; expected to contain ``train`` and optionally ``test``.
        path_csv: label table, either a valid path or a file name inside ``data_dir``.
        cache_dir: folder for cached preprocessed volumes; falsy disables pre-caching.
        scan_types: names of the scan sub-folders to use.
        crop_thr: threshold passed to ``crop_volume``.
        in_memory: keep loaded training/validation volumes in memory.
        input_size: edge length of the cube every sample is resized to.
        batch_size: batch size of all loaders.
        num_workers: loader workers for training/validation; ``None`` uses the CPU count.
        train_transforms: batch transforms applied to training batches.
        valid_transforms: batch transforms applied to validation and test batches.
        split: fraction of the label table used for training.
    """

    def __init__(
        self,
        data_dir: str = '.',
        path_csv: str = 'train_labels.csv',
        cache_dir: str = '.',
        scan_types: Sequence[str] = ("FLAIR", "T2w"),
        crop_thr: float = 1e-6,
        in_memory: bool = False,
        input_size: int = 64,
        batch_size: int = 4,
        num_workers: Optional[int] = None,
        train_transforms=None,
        valid_transforms=None,
        split: float = 0.8,
    ):
        super().__init__()
        # path configurations
        assert os.path.isdir(data_dir), f"missing folder: {data_dir}"
        self.train_dir = os.path.join(data_dir, 'train')
        self.test_dir = os.path.join(data_dir, 'test')
        self.cache_dir = cache_dir

        if not os.path.isfile(path_csv):
            path_csv = os.path.join(data_dir, path_csv)
        assert os.path.isfile(path_csv), f"missing table: {path_csv}"
        self.path_csv = path_csv

        # other configs
        self.scan_types = scan_types
        self.crop_thr = crop_thr
        self.input_size = input_size
        self.batch_size = batch_size
        self.split = split
        self.in_memory = in_memory
        # os.cpu_count() may return None; fall back to loading in the main process
        self.num_workers = num_workers if num_workers is not None else (os.cpu_count() or 0)

        # need to be filled in setup()
        self.train_dataset = None
        self.valid_dataset = None
        self.test_table = []
        self.test_dataset = None
        self.train_transforms = train_transforms
        self.valid_transforms = valid_transforms

    def prepare_data(self, num_proc: int = 0):
        """Preprocess and cache every training scan (no-op without a cache folder).

        Args:
            num_proc: number of worker processes; values up to 1 run sequentially.
        """
        if not self.cache_dir:
            return
        ds = BrainScansDataset(
            image_dir=self.train_dir,
            df_table=self.path_csv,
            scan_types=self.scan_types,
            split=1.0,
            cache_dir=self.cache_dir,
            crop_thr=self.crop_thr,
            in_memory=False,
        )

        pool = Pool(processes=num_proc) if num_proc > 1 else None
        mapping = pool.imap_unordered if pool else map

        _cache_img = partial(BrainScansDataset.load_image, image_dir=ds.image_dir, cache_dir=ds.cache_dir, crop_thr=ds.crop_thr)
        pbar = tqdm(desc=f"preparing/caching scans @{num_proc} jobs", total=len(ds))
        try:
            for _ in mapping(_cache_img, ds.images):
                pbar.update()
        finally:
            # release the progress bar and the worker processes even if a scan fails
            pbar.close()
            if pool:
                pool.close()
                pool.join()

    def setup(self, *_, **__) -> None:
        """Prepare datasets"""
        ds_defaults = dict(
            image_dir=self.train_dir,
            df_table=self.path_csv,
            scan_types=self.scan_types,
            cache_dir=self.cache_dir,
            crop_thr=self.crop_thr,
            split=self.split,
            in_memory=self.in_memory,
        )
        self.train_dataset = BrainScansDataset(**ds_defaults, mode='train')
        logging.info(f"training dataset: {len(self.train_dataset)}")
        self.valid_dataset = BrainScansDataset(**ds_defaults, mode='valid')
        logging.info(f"validation dataset: {len(self.valid_dataset)}")

        if not os.path.isdir(self.test_dir):
            return
        # test cases have no labels: list the case folders and use a 0.5 placeholder
        ls_cases = [os.path.basename(p) for p in glob.glob(os.path.join(self.test_dir, '*'))]
        self.test_table = [dict(BraTS21ID=n, MGMT_value=0.5) for n in ls_cases]
        self.test_dataset = BrainScansDataset(
            image_dir=self.test_dir,
            df_table=pd.DataFrame(self.test_table),
            scan_types=self.scan_types,
            cache_dir=self.cache_dir,
            crop_thr=self.crop_thr,
            split=0,  # with a non-'train' mode this keeps the whole table
            mode='test'
        )
        logging.info(f"test dataset: {len(self.test_dataset)}")

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            sample_transforms=partial(rising_resize, size=self.input_size),  # todo: resize to fix size
            batch_transforms=self.train_transforms,
        )

    def val_dataloader(self) -> DataLoader:
        """Sequential loader over the validation dataset."""
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            sample_transforms=partial(rising_resize, size=self.input_size),  # todo: resize to fix size
            batch_transforms=self.valid_transforms,
        )

    def test_dataloader(self) -> Optional[DataLoader]:
        """Sequential loader over the test dataset, or ``None`` when it is missing or empty."""
        if not self.test_dataset:
            logging.warning('no testing images found')
            return None
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=False,
            sample_transforms=partial(rising_resize, size=self.input_size),  # todo: resize to fix size
            batch_transforms=self.valid_transforms,
        )

# ==============================
from tqdm.auto import tqdm

# Data module restricted to T2w scans, resized to 224-voxel cubes, with cached volumes.
dm = BrainScansDM(
    data_dir=PATH_DATASET,
    scan_types=["T2w"],
    input_size=224,
    crop_thr=1e-6,
    batch_size=8,
    cache_dir=PATH_TEMP,
    # in_memory=True,
    num_workers=2,
    train_transforms=rtr.Compose(TRAIN_TRANSFORMS, transform_call=default_transform_call),
    valid_transforms=rtr.Compose(VAL_TRANSFORMS, transform_call=default_transform_call),
)
# dm.prepare_data(2)
dm.setup()

# Quick view: plot the first three volumes of the first training batch, then stop iterating
for batch in dm.train_dataloader():
    for i in range(3):
        show_volume(batch["data"][i][0], fig_size=(9, 6), v_min_max=(-1., 1.))
    break

## Prepare 3D model

In [None]:
from torchmetrics import Accuracy, F1, Precision
from pytorch_lightning import LightningModule
from efficientnet_pytorch_3d import EfficientNet3D


class LitBrainMRI(LightningModule):
    """Two-class 3D EfficientNet classifier for single-channel brain scans.

    Args:
        model_name: EfficientNet variant passed to ``EfficientNet3D.from_name``.
        lr: learning rate of the AdamW optimizer.
    """

    def __init__(
        self,
        model_name: str = "efficientnet-b0",
        lr: float = 1e-4,
    ):
        super().__init__()
        self.model_name = model_name
        # build the requested variant (it used to be hard-coded to "efficientnet-b0")
        self.model = EfficientNet3D.from_name(model_name, override_params={'num_classes': 2}, in_channels=1)
        self.learning_rate = lr

        self.train_accuracy = Accuracy()
        self.train_precision = Precision()
        self.train_f1_score = F1()
        self.val_accuracy = Accuracy()
        self.val_precision = Precision()
        self.val_f1_score = F1()

    def forward(self, x: Tensor) -> Tensor:
        """Return class probabilities (softmax over the class axis) for a batch of volumes."""
        return F.softmax(self.model(x), dim=1)

    def compute_loss(self, y_hat: Tensor, y: Tensor):
        """Cross-entropy between raw logits ``y_hat`` and class indices ``y``."""
        return F.cross_entropy(y_hat, y)

    def _shared_step(self, batch):
        """Run the network once and return ``(loss, class probabilities, labels)``."""
        img, y = batch["data"], batch["label"]
        logits = self.model(img)
        # F.cross_entropy applies log-softmax itself, so the loss must see the raw
        # logits; probabilities are derived separately for the metrics.
        loss = self.compute_loss(logits, y)
        return loss, F.softmax(logits, dim=1), y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and metrics for one batch; return the loss."""
        loss, y_hat, y = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=False)
        self.log("train_acc", self.train_accuracy(y_hat, y), prog_bar=False)
        self.log("train_prec", self.train_precision(y_hat, y), prog_bar=False)
        self.log("train_f1", self.train_f1_score(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and metrics for one batch."""
        loss, y_hat, y = self._shared_step(batch)
        self.log("valid_loss", loss, prog_bar=False)
        self.log("valid_acc", self.val_accuracy(y_hat, y), prog_bar=True)
        self.log("valid_prec", self.val_precision(y_hat, y), prog_bar=True)
        self.log("valid_f1", self.val_f1_score(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        """AdamW with cosine annealing of the learning rate over ``trainer.max_epochs``."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0)
        return [optimizer], [scheduler]


# ==============================
from torchsummary import summary

# Instantiate the classifier with its default arguments; the layer summary call is left disabled.
model = LitBrainMRI()
# summary(model, input_size=(1, 128, 128, 128))

## Train a model

In [None]:
import pytorch_lightning as pl

# Metrics are written as CSV under logs/<model name>/.
logger = pl.loggers.CSVLogger(save_dir='logs/', name=model.model_name)
# swa = pl.callbacks.StochasticWeightAveraging(swa_epoch_start=0.6)
# Keep the checkpoint with the highest "valid_f1" plus the most recent one.
ckpt = pl.callbacks.ModelCheckpoint(
    monitor='valid_f1',
    save_top_k=1,
    save_last=True,
    # save_weights_only=True,
    filename='checkpoint/{epoch:02d}-{valid_acc:.4f}-{valid_f1:.4f}',
    # verbose=False,
    mode='max',
)

# ==============================

# Single-GPU, 16-bit precision run for 3 epochs; gradients are accumulated over
# 24 batches and validation runs twice per epoch.
trainer = pl.Trainer(
    # fast_dev_run=True,
    gpus=1,
    callbacks=[ckpt],  # , swa
    logger=logger,
    max_epochs=3,
    precision=16,
    #overfit_batches=5,
    accumulate_grad_batches=24,
    val_check_interval=0.5,
    progress_bar_refresh_rate=1,
    weights_summary='top',
)

# ==============================

# Train the model on the data module prepared above.
trainer.fit(model=model, datamodule=dm)

In [None]:
# Load the metrics written by the trainer's logger during training.
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
print(metrics.head())

# Average every logged column per epoch, keeping the epoch number as the group key.
aggreg_metrics = []
agg_col = "epoch"
for i, dfg in metrics.groupby(agg_col):
    agg = dict(dfg.mean())
    agg[agg_col] = i
    aggreg_metrics.append(agg)

# One figure for the losses and one for the validation/training scores.
df_metrics = pd.DataFrame(aggreg_metrics)
df_metrics[['train_loss', 'valid_loss']].plot(grid=True, legend=True, xlabel=agg_col)
df_metrics[['valid_f1', 'valid_acc', 'valid_prec', 'train_acc']].plot(grid=True, legend=True, xlabel=agg_col)