# Submission Notebook

In [1]:
import shutil

import polars as pl

import kaggle_evaluation.rsna_inference_server

All the code used by this notebook is defined inline below rather than imported from separate Python modules.

In [2]:
# Identifier column of each prediction row; predict() drops it before returning.
ID_COL = 'SeriesInstanceUID'

# Output label columns, in the order used to name the model's scores in _predict
# (scores[0] is zipped with this list). The last entry, 'Aneurysm Present', is
# overwritten in _predict from the other 13 location scores.
# NOTE(review): this order must match the label order used at training time — TODO confirm.
LABEL_COLS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present',
]

# All tags (other than PixelData and SeriesInstanceUID) that may be in a test set dcm file.
# Not referenced by the inference code in this notebook; kept as documentation of which
# DICOM attributes the loaders may rely on.
DICOM_TAG_ALLOWLIST = [
    'BitsAllocated',
    'BitsStored',
    'Columns',
    'FrameOfReferenceUID',
    'HighBit',
    'ImageOrientationPatient',
    'ImagePositionPatient',
    'InstanceNumber',
    'Modality',
    'PatientID',
    'PhotometricInterpretation',
    'PixelRepresentation',
    'PixelSpacing',
    'PlanarConfiguration',
    'RescaleIntercept',
    'RescaleSlope',
    'RescaleType',
    'Rows',
    'SOPClassUID',
    'SOPInstanceUID',
    'SamplesPerPixel',
    'SliceThickness',
    'SpacingBetweenSlices',
    'StudyInstanceUID',
    'TransferSyntaxUID',
]

I changed some functions to make the code robust to differences between the test set and the provided training set. For example, `dicom_serie_load` recursively collects the file paths under a given base path and keeps only the `.dcm` files.

In [3]:
import os
import pydicom

def dicom_serie_load(serie_path):
    """Read every DICOM file found (recursively) under ``serie_path``.

    Files are selected by a case-insensitive ``.dcm`` extension. Directories and
    files are visited in sorted order, so repeated runs build the same list.
    Spatial ordering of the slices is done later by ``dicom_serie_process``.

    Args:
        serie_path: directory of one series; sub-directories are searched too.

    Returns:
        list of datasets read with ``pydicom.dcmread``; empty if no ``.dcm`` file exists.

    Raises:
        FileNotFoundError: if ``serie_path`` is not an existing directory.
    """
    if not os.path.isdir(serie_path):
        raise FileNotFoundError(f"DICOM series directory not found: {serie_path}")

    ds_l = []
    for root, dirnames, filenames in os.walk(serie_path):
        dirnames.sort()  # os.walk descends into sub-directories in this (now sorted) order
        for filename in sorted(filenames):
            if filename.lower().endswith(".dcm"):
                ds_l.append(pydicom.dcmread(os.path.join(root, filename)))

    return ds_l


def dicom_get_zposition(ds):
    """Return the slice position used to order slices within a series, as a float.

    Uses the third component of ``ImagePositionPatient`` when that tag is present
    with at least three values. Otherwise it falls back to ``InstanceNumber``, and
    to 0.0 when that is missing or empty (pydicom may return None for an empty element).
    """
    image_position = getattr(ds, "ImagePositionPatient", None)
    if image_position and len(image_position) >= 3:
        return float(image_position[2])

    # Tag missing or without a third (z) component: order by instance number instead.
    instance_number = getattr(ds, "InstanceNumber", None)
    return 0.0 if instance_number is None else float(instance_number)


def dicom_get_rescale_factors(ds):
    """Return ``(slope, intercept)`` as floats for converting stored pixel values.

    Missing tags, and tags present but empty (pydicom may return None for
    them), fall back to the identity transform: slope 1.0, intercept 0.0.
    """
    slope = getattr(ds, "RescaleSlope", None)
    intercept = getattr(ds, "RescaleIntercept", None)

    return (
        1.0 if slope is None else float(slope),
        0.0 if intercept is None else float(intercept),
    )
    

import torch
from copy import deepcopy

def dicom_split_array_from_metadata(ds):  # process a single DICOM
    """Split one DICOM dataset into its pixel tensor and a pixel-free metadata copy.

    The caller's dataset is left untouched: the pixels are decoded from a deep
    copy, and ``PixelData`` is then deleted from that copy only.

    Unsigned pixel arrays wider than 8 bits (e.g. uint16, typical for
    PixelRepresentation 0) are first widened to a signed integer dtype. torch has
    only limited operator support for uint16/uint32/uint64 tensors, and older
    versions cannot build them from numpy at all. The integer values are preserved.

    Returns:
        (pixel_pt, ds_copy): tensor of the decoded pixel array, and the
        metadata-only copy of ``ds``.
    """
    ds_copy = deepcopy(ds)  # copy the dicom object so the input keeps its PixelData
    pixel_array = ds_copy.pixel_array
    if pixel_array.dtype.kind == "u" and pixel_array.dtype.itemsize > 1:
        # uint16 -> int32 (lossless); uint32/uint64 -> int64 (uint64 >= 2**63 would wrap, not expected for pixels)
        pixel_array = pixel_array.astype("int32" if pixel_array.dtype.itemsize == 2 else "int64")
    pixel_pt = torch.from_numpy(pixel_array)
    del ds_copy.PixelData  # remove pixel data from the dicom object copy

    return pixel_pt, ds_copy


def _dicom_rescale(pixel_pt, ds):
    """Apply ``ds``'s RescaleSlope/RescaleIntercept to ``pixel_pt`` and return a float32 tensor."""
    slope, intercept = dicom_get_rescale_factors(ds)
    return pixel_pt.to(torch.float32) * slope + intercept


def dicom_serie_process(ds_l):
    """Assemble a list of DICOM datasets into one rescaled float32 volume.

    * A single dataset is treated as holding the whole volume (e.g. a
      multi-frame file); its pixel array is used as-is.
    * Several datasets are treated as one slice each. They are ordered by
      ``dicom_get_zposition`` with a stable sort (ties keep load order) and
      stacked along a new leading axis.

    Every dataset is rescaled with its own RescaleSlope/RescaleIntercept
    (``value * slope + intercept``), so slices carrying different factors are
    converted correctly.

    Args:
        ds_l: non-empty list of pydicom datasets, e.g. from ``dicom_serie_load``.

    Returns:
        (volume, ds_metadata_l): the float32 volume, and the pixel-free metadata
        copies listed in the same order as the slices of ``volume``.

    Raises:
        ValueError: if ``ds_l`` is empty.
    """
    if not ds_l:
        raise ValueError("dicom_serie_process received an empty DICOM series")

    if len(ds_l) == 1:  # one dicom with the whole volume
        pixel_pt, ds_metadata = dicom_split_array_from_metadata(ds_l[0])
        return _dicom_rescale(pixel_pt, ds_l[0]), [ds_metadata]

    # Each dicom holds one slice: sort indices by position. Python's sort is
    # stable, unlike torch.argsort's default, so equal positions keep load order.
    zpositions = [dicom_get_zposition(ds) for ds in ds_l]
    order = sorted(range(len(ds_l)), key=zpositions.__getitem__)

    volume = None
    ds_metadata_l = []
    for slot, index in enumerate(order):
        pixel_pt, ds_metadata = dicom_split_array_from_metadata(ds_l[index])
        if volume is None:
            # Allocate once the slice shape is known; all slices are expected to share it.
            volume = torch.empty((len(ds_l), *pixel_pt.shape), dtype=torch.float32)
        volume[slot] = _dicom_rescale(pixel_pt, ds_l[index])
        ds_metadata_l.append(ds_metadata)  # kept aligned with the sorted slices

    return volume, ds_metadata_l


def dicom_get_spacing(ds):
    """Return ``[PixelSpacing[0], PixelSpacing[1], SliceThickness]`` of one dataset as floats.

    Top-level tags are used when present. Missing ones are looked up in the first
    item of ``SharedFunctionalGroupsSequence`` -> ``PixelMeasuresSequence``, where
    enhanced multi-frame files typically store them. Values still missing become
    0.0. Missing and empty sequences are treated the same way instead of raising
    IndexError.
    """
    def _first_item(sequence):
        # First element of a possibly missing (None) or empty DICOM sequence, else None.
        return sequence[0] if sequence else None

    pixel_spacing = getattr(ds, "PixelSpacing", None)
    slice_thickness = getattr(ds, "SliceThickness", None)

    if pixel_spacing is None or slice_thickness is None:
        shared_group = _first_item(getattr(ds, "SharedFunctionalGroupsSequence", None))
        pixel_measures = _first_item(getattr(shared_group, "PixelMeasuresSequence", None))
        if pixel_measures is not None:
            if pixel_spacing is None:
                pixel_spacing = getattr(pixel_measures, "PixelSpacing", None)
            if slice_thickness is None:
                slice_thickness = getattr(pixel_measures, "SliceThickness", None)

    in_plane = [0.0, 0.0] if pixel_spacing is None else [float(value) for value in pixel_spacing]
    through_plane = 0.0 if slice_thickness is None else float(slice_thickness)

    return [*in_plane, through_plane]


def dicom_serie_get_spacing(ds_l):
    """Return the most frequent spacing of a series as a float32 tensor of shape (3,).

    Each dataset contributes one ``dicom_get_spacing`` row. The mode is taken
    independently per column, so the three components may come from different slices.
    """
    per_slice_spacing = torch.tensor(
        [dicom_get_spacing(ds) for ds in ds_l], dtype=torch.float32
    ).reshape(-1, 3)
    most_common = per_slice_spacing.mode(dim=0)
    return most_common.values

In [4]:
import torch.nn.functional as F

class NormalizeSpacing:
    """Resample a 5-D volume from its own voxel spacing towards a per-domain reference spacing.

    The caller obtains ``(domain, spacing)`` from ``get_metadata`` and passes
    them to ``transform``.
    """

    def __init__(self, interp_mode, domain_spacings_dict, get_metadata):
        # interp_mode: ``mode`` argument of F.interpolate (e.g. "trilinear").
        # domain_spacings_dict: domain key -> 3-tuple reference spacing.
        # get_metadata: callable returning ``(domain, spacing)`` for ``transform``.
        self.interp_mode = interp_mode
        self.domain_spacings_dict = domain_spacings_dict
        self.get_metadata = get_metadata

    def transform(self, volume, domain, spacing):
        """Return ``volume`` as float32, interpolated to a spacing-dependent size.

        Args:
            volume: 5-D tensor; its last three axes are resized.
            domain: key looked up in ``domain_spacings_dict``; when it is absent,
                the series' own spacing is used, so the size is unchanged.
            spacing: 3 per-axis spacing values for the series.

        If any ``spacing`` entry is non-positive or non-finite (e.g. the 0.0
        placeholders for missing tags), resampling is skipped, because the size
        formula would divide by zero. Every target dimension is at least 1.
        """
        volume = volume.to(torch.float32)
        spacing = torch.as_tensor(spacing)

        if not bool(torch.all(torch.isfinite(spacing) & (spacing > 0))):
            return volume

        if domain in self.domain_spacings_dict:
            domain_spacing = torch.tensor(self.domain_spacings_dict[domain])
        else:
            domain_spacing = spacing

        # NOTE(review): physical resampling would use shape * spacing / domain_spacing;
        # this formula is the inverse. Also, volume.shape[2:] is (slices, rows, cols)
        # while spacing is ordered [PixelSpacing..., SliceThickness]. Kept unchanged
        # because the model was presumably trained with this preprocessing — confirm
        # before changing.
        target_size = (torch.tensor(volume.shape[2:], dtype=torch.float32) / spacing * domain_spacing).to(torch.int32)
        target_size = target_size.clamp(min=1)  # F.interpolate rejects zero-sized outputs

        return F.interpolate(volume, size=tuple(int(s) for s in target_size.tolist()), mode=self.interp_mode)


class NormalizeSizeInterp:
    """Resize a 5-D volume to a fixed spatial size with ``F.interpolate``."""

    get_metadata = None  # fixed output size: no per-series metadata needed

    def __init__(self, target_size, mode):
        # target_size: ``size`` argument of F.interpolate; mode: its ``mode`` argument.
        self.target_size = target_size
        self.mode = mode

    def transform(self, volume):
        """Cast ``volume`` to float32 and interpolate it to ``self.target_size``."""
        as_float = volume.to(torch.float32)
        return F.interpolate(as_float, size=self.target_size, mode=self.mode)

import numpy as np

class PercentileCropIntensity:
    """Clip voxel intensities to two percentiles of the volume's own value distribution."""

    get_metadata = None  # thresholds come from the volume itself

    def __init__(self, percentiles):
        # (low, high) percentiles in [0, 100], as accepted by np.percentile.
        self.percentiles = percentiles

    def transform(self, volume):
        """Return ``volume`` clamped to its ``self.percentiles[0]``-th..``[1]``-th percentile."""
        flat_values = volume.detach().flatten().cpu().numpy()
        bounds = np.percentile(flat_values, self.percentiles)
        low, high = bounds[0], bounds[1]
        return torch.clamp(volume, min=low, max=high)


class StandardizeIntensity:
    """Z-score a volume: subtract its mean and divide by its standard deviation."""

    get_metadata = None  # statistics come from the volume itself

    def transform(self, volume):
        """Return ``(volume - mean) / std`` as float32.

        When the standard deviation is zero (constant volume) or not finite
        (e.g. a single voxel, or NaN input), the division would yield NaN
        everywhere. An all-zero volume is returned explicitly instead.
        """
        volume = volume.to(torch.float32)

        mean = volume.mean()
        std = volume.std()
        if std == 0 or not torch.isfinite(std):
            return torch.zeros_like(volume)

        return (volume - mean) / std

Because the Kaggle server that runs inference has no Internet access, the ResNet-50 model code is copied here instead of being installed from its repository:

In [5]:
from torch import nn
from torch.nn import functional as F

# -------------------- Copied from https://github.com/Warvito/MedicalNet-models/blob/main/medicalnet_models/models/resnet.py
def conv3x3x3(in_planes: int, out_planes: int, stride: int = 1, dilation: int = 1) -> nn.Conv3d:
    """Bias-free 3x3x3 convolution whose padding equals ``dilation``.

    With stride 1 the spatial size is therefore preserved.
    """
    conv_options = dict(kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)


def conv1x1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv3d:
    """Bias-free pointwise (1x1x1) convolution with the given stride."""
    pointwise = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise


class BasicBlock(nn.Module):
    """Residual block of two 3x3x3 convolutions (copied upstream ResNet code).

    Not used by MedResNet in this notebook, which builds its ResNet from Bottleneck.
    Attribute names (conv1, bn1, ...) determine state_dict keys.
    """
    expansion: int = 1  # output channels = planes * expansion

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: nn.Module = None,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        # The stride (if any) is applied by the first convolution.
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        # Optional projection applied to the shortcut when shapes differ.
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """conv-bn-relu -> conv-bn, add the (optionally projected) input, then relu."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    """Residual block 1x1x1 (reduce) -> 3x3x3 -> 1x1x1 (expand to ``planes * 4``).

    Copied upstream ResNet code; the attribute names (conv1, bn1, ...) determine
    the state_dict keys and must stay unchanged for the saved checkpoint to load.
    """
    expansion: int = 4  # matches the hard-coded ``planes * 4`` of conv3/bn3

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: nn.Module = None,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        self.conv1 = conv1x1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        # The stride (if any) is applied by the middle 3x3x3 convolution.
        self.conv2 = conv3x3x3(
            planes,
            planes,
            stride=stride,
            dilation=dilation,
        )
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, planes * 4)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the shortcut when shapes differ.
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        """Three conv-bn stages (relu after the first two), residual add, final relu."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    """3-D ResNet feature extractor for single-channel input, with no classification head.

    ``layers`` gives the number of blocks in each of the four stages. Stages 3
    and 4 use stride 1 with dilation 2 and 4 instead of further striding. The
    overall spatial downsampling is therefore 8x (conv1, maxpool and layer2 each
    stride by 2). The output has ``512 * block.expansion`` channels.
    """
    def __init__(
        self,
        block,
        layers,
    ) -> None:
        super().__init__()

        # Channels entering the next stage; _make_layer updates it as stages are built.
        self.inplanes = 64
        self.layers = layers

        # Stem: 7x7x7 stride-2 convolution on a 1-channel input, then stride-2 max pooling.
        self.conv1 = nn.Conv3d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

    def _make_layer(
        self, block, planes: int, blocks: int, stride: int = 1, dilation: int = 1
    ):
        """Build one stage of ``blocks`` blocks as an nn.Sequential.

        Only the first block receives ``stride`` and the shortcut projection,
        which is created when the stride or channel count changes. Sets
        ``self.inplanes`` to the stage's output channels.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1x1(
                    self.inplanes,
                    planes * block.expansion,
                    stride=stride,
                ),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem (conv-bn-relu-maxpool) followed by the four stages; returns the feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
# --------------------

class MedResNet(nn.Module):
    """3-D ResNet-50 backbone (Bottleneck, [3, 4, 6, 3]), global average pooling and a bias-free linear head.

    In training mode ``forward`` returns raw logits; in eval mode it returns
    sigmoid probabilities.
    """

    def __init__(self, out_features):
        super().__init__()
        # Submodule names fix the state_dict keys ("pretrained.*", "linear.weight").
        self.pretrained = ResNet(Bottleneck, [3, 4, 6, 3])
        self.global_avg_pooling = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.linear = nn.Linear(2048, out_features, bias=False)
        nn.init.xavier_normal_(self.linear.weight)  # re-initialises the weight in place

    def forward(self, x):
        features = self.pretrained(x)
        pooled = self.global_avg_pooling(features).flatten(start_dim=1)  # (B, 2048, 1, 1, 1) -> (B, 2048)
        logits = self.linear(pooled)
        # training uses BCEWithLogitsLoss, which expects logits
        return logits if self.training else torch.sigmoid(logits)

Now we write the prediction function:

In [6]:
# Constants

# -- Transforms
# Reference spacing per modality for NormalizeSpacing; the keys are compared with the
# DICOM Modality returned by get_metadata_dicom. Each tuple follows the order produced
# by dicom_get_spacing: (PixelSpacing[0], PixelSpacing[1], SliceThickness).
# Presumably median spacings of the training data, per the variable name — TODO confirm.
volume_domain_median_spacing_dict = {
    "CT": (0.46875, 0.46875, 0.8),
    "MR": (0.410156, 0.410156, 0.6),
}

def get_metadata_dicom(ds_metadata_l):
    """Return ``(modality, spacing)`` for ``NormalizeSpacing.transform``.

    ``modality`` is the Modality of the first dataset, or None when the tag is
    absent. In that case NormalizeSpacing finds no reference spacing and keeps
    the series' own spacing, instead of the whole prediction failing.
    ``spacing`` is ``dicom_serie_get_spacing`` over all datasets.
    """
    modality = getattr(ds_metadata_l[0], "Modality", None)
    spacing = dicom_serie_get_spacing(ds_metadata_l)
    return modality, spacing

# Preprocessing pipeline, applied in order by _predict. A transform whose
# ``get_metadata`` is not None also receives the values returned by
# ``get_metadata(ds_metadata_l)``.
transforms = [
    NormalizeSpacing("trilinear", volume_domain_median_spacing_dict, get_metadata_dicom),
    PercentileCropIntensity(percentiles=(0.5, 99.5)),
    StandardizeIntensity(),
    NormalizeSizeInterp((32, 224, 224), "nearest"),
]

# -- Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# -- Model
# Weights attached to this notebook as a Kaggle input.
MODEL_WEIGHTS_PATH = "/kaggle/input/med3d-resnet50-rsna-iad/pytorch/default/1/resnet50.pth"

model_resnet50 = MedResNet(len(LABEL_COLS))  # one output per label column (14)
# map_location="cpu" loads every tensor on the CPU, whatever device it was saved
# from, so a GPU-saved checkpoint also loads on a CPU-only machine. The model is
# moved to ``device`` below.
model_resnet50.load_state_dict(
    torch.load(MODEL_WEIGHTS_PATH, map_location="cpu", weights_only=True)
)
model_resnet50.eval()  # BatchNorm uses running stats; MedResNet returns sigmoid probabilities
model_resnet50 = model_resnet50.to(device)

def _predict(series_path: str):
    """Predict the label probabilities for one series and return them as a one-row polars frame.

    Steps:
    1. Load every DICOM under ``series_path`` and assemble the rescaled volume.
    2. Prepend singleton axes until it is 5-D and run ``transforms`` in order.
    3. Replace NaN/+-inf with 0, then run the model on ``device``.

    The last score ("Aneurysm Present") is overwritten with a hard 1.0 if any
    other score exceeds 0.5, and 0.0 otherwise.
    NOTE(review): a hard 0/1 value discards the model's ranking for this label.
    If the metric is ranking-based (e.g. AUC), a continuous score such as
    max(model score, max location score) may be preferable — confirm before changing.
    """
    # -- Load and Transform
    ds_l = dicom_serie_load(series_path)
    volume, ds_metadata_l = dicom_serie_process(ds_l)

    # Prepend singleton axes until the volume is 5-D (as F.interpolate and Conv3d expect).
    while volume.dim() < 5:
        volume = volume.unsqueeze(0)

    for transform in transforms:
        if transform.get_metadata is not None:
            volume = transform.transform(volume, *transform.get_metadata(ds_metadata_l))
        else:
            volume = transform.transform(volume)

    # -- Predict
    with torch.no_grad():
        # nan_to_num also maps NaN to 0 (its default).
        volume = volume.to(device).nan_to_num(posinf=0.0, neginf=0.0)
        scores = model_resnet50(volume)  # call the module (not .forward) so hooks run

        any_location_positive = bool((scores[:, :-1] > 0.5).any())
        scores[:, -1] = 1.0 if any_location_positive else 0.0

    # -- Make polars DataFrame
    # normpath strips a trailing separator that would make basename return ''.
    series_id = os.path.basename(os.path.normpath(series_path))
    return pl.DataFrame(
        data=[[series_id] + scores[0].tolist()],
        schema=[ID_COL, *LABEL_COLS],
        orient='row',
    )

So that the script always runs to completion, we wrap the prediction in a `try: ... except: ...` block and fall back to a constant prediction (0.5 for every label) whenever the prediction function fails:

In [7]:
def _fallback(series_path: str):
    """Return a neutral one-row prediction (0.5 for every label) for ``series_path``."""
    neutral_row = [os.path.basename(series_path), *([0.5] * len(LABEL_COLS))]
    return pl.DataFrame(
        data=[neutral_row],
        schema=[ID_COL, *LABEL_COLS],
        orient='row',
    )

In [8]:
# Replace this function with your inference code.
# You can return either a Pandas or Polars dataframe, though Polars is recommended.
# Each prediction (except the very first) must be returned within 30 minutes of the series being provided.
def predict(series_path: str) -> pl.DataFrame:
    """Make a prediction.

    Runs ``_predict`` on the series directory. If it raises, the traceback is
    logged and ``_fallback`` (0.5 for every label) is used instead, so one bad
    series never stops the submission. The frame's columns are checked against
    ``[ID_COL, *LABEL_COLS]`` and the ID column is dropped before returning.
    ``/kaggle/shared`` is removed on every exit path.

    Raises:
        TypeError: if the prediction is neither a polars nor a pandas DataFrame.
        AssertionError: if the columns do not match the expected schema.
    """
    try:
        try:
            predictions = _predict(series_path)
        except Exception:
            # Best-effort by design, but leave a trace in the logs rather than failing silently.
            import traceback

            print(f"_predict failed for {series_path!r}; using fallback predictions")
            traceback.print_exc()
            predictions = _fallback(series_path)

        if isinstance(predictions, pl.DataFrame):
            assert predictions.columns == [ID_COL, *LABEL_COLS]
            return predictions.drop(ID_COL)

        import pandas as pd  # not imported at the top of the notebook

        if isinstance(predictions, pd.DataFrame):
            assert (predictions.columns == [ID_COL, *LABEL_COLS]).all()
            # pandas' drop() removes row labels by default; drop the column explicitly.
            return predictions.drop(columns=ID_COL)

        raise TypeError('The predict function must return a DataFrame')
    finally:
        # ----------------------------- IMPORTANT ------------------------------
        # You MUST have the following code in your `predict` function
        # to prevent "out of disk space" errors. This is a temporary workaround
        # as we implement improvements to our evaluation system.
        # Placed in `finally` so it also runs when a check above fails.
        shutil.rmtree('/kaggle/shared', ignore_errors=True)
        # ----------------------------------------------------------------------

In [9]:
# Small test
#predict("/kaggle/input/rsna-intracranial-aneurysm-detection/series/1.2.826.0.1.3680043.8.498.10012790035410518400400834395242853657")

In [10]:
inference_server = kaggle_evaluation.rsna_inference_server.RSNAInferenceServer(predict)

# KAGGLE_IS_COMPETITION_RERUN is presumably set only during the scored rerun, where
# requests are served to the evaluation gateway. Otherwise, run the local gateway and
# display the resulting /kaggle/working/submission.parquet.
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.run_local_gateway()
    display(pl.read_parquet('/kaggle/working/submission.parquet'))
else:
    inference_server.serve()

SeriesInstanceUID,Left Infraclinoid Internal Carotid Artery,Right Infraclinoid Internal Carotid Artery,Left Supraclinoid Internal Carotid Artery,Right Supraclinoid Internal Carotid Artery,Left Middle Cerebral Artery,Right Middle Cerebral Artery,Anterior Communicating Artery,Left Anterior Cerebral Artery,Right Anterior Cerebral Artery,Left Posterior Communicating Artery,Right Posterior Communicating Artery,Basilar Tip,Other Posterior Circulation,Aneurysm Present
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""1.2.826.0.1.3680043.8.498.1007…",0.006142,0.005146,0.037685,0.023692,0.038315,0.025038,0.106118,0.001882,0.007146,0.008368,0.006331,0.008048,0.008292,0.0
"""1.2.826.0.1.3680043.8.498.1002…",0.016863,0.011341,0.054065,0.034428,0.023518,0.023615,0.038864,0.002557,0.003989,0.008769,0.010296,0.007462,0.018473,0.0
"""1.2.826.0.1.3680043.8.498.1005…",0.018186,0.024549,0.05438,0.037644,0.056714,0.054527,0.086452,0.01043,0.009046,0.015601,0.033935,0.021262,0.029377,0.0
