## Introduction

This kernel uses the recent pip wheel of DALI to decode DICOMs on the GPU. It works for all JPEG2000 and most of the JPEG-lossless formatted images.

The decoding work is strongly based on the kernels of Theo Viel (@theoviel) and David Austin (@tivfrvqhs5).

***WARNING***: Although the GPU decoding works for all train images, a few of the JPEG-lossless formatted DICOMs (TransferSyntaxUID == '1.2.840.10008.1.2.4.70') of the hidden test set cannot be decoded. So it's crucial to have a CPU fallback in place so the notebook won't throw an exception in the submission re-run.

## Requirements

We start with installing pip requirements.

In [1]:
!pip install -q timm==0.6.5 --no-index --find-links=/kaggle/input/rsna-bc-pip-requirements
!pip install -q albumentations==1.2.1 --no-index --find-links=/kaggle/input/rsna-bc-pip-requirements
!pip install -q pylibjpeg-libjpeg==1.3.1 --no-index --find-links=/kaggle/input/rsna-bc-pip-requirements
!pip install -q pydicom==2.0.0 --no-index --find-links=/kaggle/input/rsna-bc-pip-requirements
!pip install -q python-gdcm==3.0.20 --no-index --find-links=/kaggle/input/rsna-bc-pip-requirements
!pip install -q dicomsdl==0.109.1 --no-index --find-links=/kaggle/input/rsna-bc-pip-requirements

[0m

Then we install the latest DALI packaging which we will use for GPU decoding

In [2]:
!pip install -q /kaggle/input/nvidia-dali-nightly-cuda110-1230dev/nvidia_dali_nightly_cuda110-1.23.0.dev20230203-7187866-py3-none-manylinux2014_x86_64.whl

[0m

Next, we import all the packages we need and patch a function to allow for INT16 support

In [3]:
import timm
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
from copy import copy
import gc
import shutil 

import glob
from scipy.special import expit

import albumentations as A
import cv2
cv2.setNumThreads(0)

import dicomsdl
import pydicom
from pydicom.filebase import DicomBytesIO

from os.path import join

from tqdm import tqdm

from joblib import Parallel, delayed
import multiprocessing as mp

from types import SimpleNamespace
from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast


import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali import pipeline_def
from nvidia.dali.types import DALIDataType

In [4]:
#we need to patch DALI for Int16 support


from nvidia.dali.backend import TensorGPU, TensorListGPU
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
from nvidia.dali import types
from nvidia.dali.plugin.base_iterator import _DaliBaseIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
import torch
import torch.utils.dlpack as torch_dlpack
import ctypes
import numpy as np
import torch.nn.functional as F
import pydicom

# Lookup table from DALI element types to the torch dtype used for the
# destination tensor in feed_ndarray.
# Note that UINT16 is deliberately mapped to torch.int16 (this is the
# "Int16 support" patch): the decode pipeline emits UINT16 and the target
# tensor is allocated as torch.int16.
# NOTE(review): presumably pixel values above 32767 read back as negative
# numbers after this reinterpretation — confirm against the data.
to_torch_type = {
    types.DALIDataType.FLOAT:   torch.float32,
    types.DALIDataType.FLOAT64: torch.float64,
    types.DALIDataType.FLOAT16: torch.float16,
    types.DALIDataType.UINT8:   torch.uint8,
    types.DALIDataType.INT8:    torch.int8,
    types.DALIDataType.UINT16:  torch.int16,
    types.DALIDataType.INT16:   torch.int16,
    types.DALIDataType.INT32:   torch.int32,
    types.DALIDataType.INT64:   torch.int64
}


def feed_ndarray(dali_tensor, arr, cuda_stream=None):
    """
    Copy the contents of a DALI tensor into an existing PyTorch tensor.

    Parameters
    ----------
    `dali_tensor` : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
                    Source of the copy
    `arr` : torch.Tensor
            Pre-allocated destination; its dtype and shape must match the source
    `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
                    CUDA stream used for the copy (when omitted, an internal
                    user stream is selected). Usually this should be PyTorch's
                    current stream, e.g. when the destination was allocated
                    with torch.zeros(...)

    Returns
    -------
    The destination tensor `arr`.
    """
    # Element type the destination must have, according to the lookup table.
    expected_dtype = to_torch_type[dali_tensor.dtype]
    assert expected_dtype == arr.dtype, ("The element type of DALI Tensor/TensorList"
                                         " doesn't match the element type of the target PyTorch Tensor: "
                                         "{} vs {}".format(expected_dtype, arr.dtype))

    target_shape = list(arr.size())
    assert dali_tensor.shape() == target_shape, \
        ("Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}".
            format(dali_tensor.shape(), target_shape))

    raw_stream = types._raw_cuda_stream(cuda_stream)

    # Wrap the raw destination address as a C void pointer.
    destination = ctypes.c_void_p(arr.data_ptr())

    if not isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
        dali_tensor.copy_to_external(destination)
        return arr

    stream = ctypes.c_void_p(raw_stream) if raw_stream is not None else None
    dali_tensor.copy_to_external(destination, stream, non_blocking=True)
    return arr





Next I set major variables which handle the public run and the re-run on the hidden test set, and also allow for simulating the size of the hidden test set by setting RAM_CHECK = True

In [5]:
# Run configuration: paths, run-mode switches and the dataframe to infer on.

COMP_FOLDER = '/kaggle/input/rsna-breast-cancer-detection/'
DATA_FOLDER = COMP_FOLDER + 'test_images/'

sample_submission = pd.read_csv(COMP_FOLDER + 'sample_submission.csv')

# The run is treated as "public" when the sample submission has exactly two rows.
PUBLIC_RUN = len(sample_submission) == 2

N_CORES = mp.cpu_count()
MIXED_PRECISION = False
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# User switches; both are forced off below when this is not the public run.
RAM_CHECK = True
DEBUG = True

test_df = pd.read_csv(COMP_FOLDER + 'test.csv')
test_df['cancer'] = 0 #dummy value

if not PUBLIC_RUN:
    RAM_CHECK = False
    DEBUG = False

if RAM_CHECK:
    # Run on train data instead, restricted to the first 8000 patient ids.
    test_df = pd.read_csv(COMP_FOLDER + 'train.csv')
    patient_filter = sorted(set(test_df.patient_id.unique()))[:8000]
    test_df = test_df[test_df.patient_id.isin(patient_filter)]
    DATA_FOLDER = DATA_FOLDER.replace('test','train')

if DEBUG:
    test_df = test_df.head(500)

test_df

Unnamed: 0,site_id,patient_id,image_id,laterality,view,age,cancer,biopsy,invasive,BIRADS,implant,density,machine_id,difficult_negative_case
0,2,10006,462822612,L,CC,61.0,0,0,0,,0,,29,False
1,2,10006,1459541791,L,MLO,61.0,0,0,0,,0,,29,False
2,2,10006,1864590858,R,MLO,61.0,0,0,0,,0,,29,False
3,2,10006,1874946579,R,CC,61.0,0,0,0,,0,,29,False
4,2,10011,220375232,L,CC,55.0,0,0,0,0.0,0,,21,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,2,10512,457122509,R,CC,40.0,0,0,0,,0,,29,False
496,2,10512,474818568,R,MLO,40.0,0,0,0,,0,,29,False
497,2,10514,1196389201,L,CC,68.0,0,0,0,,0,,48,False
498,2,10514,1887797144,L,MLO,68.0,0,0,0,,0,,48,False


In [6]:
# Sanity check: number of rows and number of distinct patients to process.
print(f'Len df : {len(test_df)}')
test_df['patient_id'].nunique()

Len df : 500


110

In [7]:
# Relative DICOM path "<patient_id>/<image_id>.dcm" for every row.
test_df["fns"] = test_df['patient_id'].astype(str) + '/' + test_df['image_id'].astype(str) + '.dcm'

Next, we define the function for GPU-based decoding using DALI and processing the dicom images

In [8]:
def convert_dicom_to_jpg(file, save_folder=""):
    """Dump the compressed image bitstream embedded in a DICOM file to disk.

    The DICOM is read once. For the two supported transfer syntaxes the pixel
    data is searched for the start marker of the embedded codestream, and
    everything from that marker onwards is written to
    ``<save_folder><patient>_<image>.jpg``.

    Parameters
    ----------
    file : str
        Path of the form ``.../<patient>/<image>.dcm``.
    save_folder : str
        Output folder prefix; concatenated as-is, so it should end with ``/``.

    Returns
    -------
    None. No file is written when the transfer syntax is not one of the two
    handled here, or when the start marker is not present in the pixel data;
    such images are then simply absent from the GPU decoding step.
    """
    # Start marker of the embedded codestream, keyed by TransferSyntaxUID.
    stream_markers = {
        '1.2.840.10008.1.2.4.90': b"\x00\x00\x00\x0C",  # jpeg2000 header
        '1.2.840.10008.1.2.4.70': b"\xff\xd8\xff\xe0",  # jpeg lossless header
    }

    path_parts = file.split('/')
    patient = path_parts[-2]
    image = path_parts[-1][:-4]  # strip ".dcm"

    dcmfile = pydicom.dcmread(file)
    marker = stream_markers.get(str(dcmfile.file_meta.TransferSyntaxUID))
    if marker is None:
        return

    pixel_data = dcmfile.PixelData
    offset = pixel_data.find(marker)
    if offset < 0:
        # find() returned -1: slicing from -1 would write a single garbage byte.
        return

    with open(save_folder + f"{patient}_{image}.jpg", "wb") as binary_file:
        binary_file.write(pixel_data[offset:])

            
@pipeline_def
def jpg_decode_pipeline(jpgfiles):
    """DALI pipeline: read the listed files and decode them with the 'mixed' backend as UINT16."""
    encoded, _labels = fn.readers.file(files=jpgfiles)
    decoded = fn.experimental.decoders.image(
        encoded,
        device='mixed',
        output_type=types.ANY_DATA,
        dtype=DALIDataType.UINT16,
    )
    return decoded

def parse_window_element(elem):
    """Convert a DICOM window value (center or width) to a float.

    Parameters
    ----------
    elem : list, tuple, str, int, float, pydicom DataElement or anything else
        The raw value as returned by the DICOM reader. For multi-valued
        inputs the first entry is used.

    Returns
    -------
    float or None
        The parsed value, or ``None`` when the input is an empty sequence or
        of an unsupported type.
    """
    if isinstance(elem, (list, tuple)):
        return float(elem[0]) if len(elem) else None
    if isinstance(elem, (str, int, float)):
        return float(elem)
    if isinstance(elem, pydicom.dataelem.DataElement):
        try:
            # Multi-valued element: take the first value.
            return float(elem[0])
        except Exception:
            # Single-valued element: indexing failed, use the value itself.
            return float(elem.value)
    return None

def linear_window(data, center, width):
    """Clamp `data` to the window [center - width // 2, center + width // 2]."""
    half_width = width // 2
    return torch.clamp(data, min=center - half_width, max=center + half_width)

def process_dicom(img, dicom):
    """Apply DICOM windowing, min-max scaling and MONOCHROME1 inversion.

    Parameters
    ----------
    img : torch.Tensor
        Raw pixel values.
    dicom : object
        Metadata holder supporting attribute access for
        ``PhotometricInterpretation`` and item access for ``WindowCenter`` /
        ``WindowWidth`` (in this notebook a pydicom dataset or a dicomsdl
        object).

    Returns
    -------
    torch.Tensor
        Image scaled to [0, 1]. A constant image yields all zeros (all ones
        when inverted) instead of NaN.
    """
    try:
        invert = getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1"
    except Exception:
        invert = False

    # Windowing is optional: skip it when the tags are absent.
    try:
        raw_center = dicom["WindowCenter"]
        raw_width = dicom["WindowWidth"]
    except KeyError:
        raw_center = raw_width = None

    if raw_center is not None and raw_width is not None:
        center = parse_window_element(raw_center)
        width = parse_window_element(raw_width)
        if center is not None and width is not None:
            img = linear_window(img, center, width)

    lowest, highest = img.min(), img.max()
    value_range = highest - lowest
    if value_range > 0:
        img = (img - lowest) / value_range
    else:
        # Constant image: avoid 0/0 and keep the dtype a division would give.
        img = torch.zeros_like(img, dtype=torch.result_type(img, 1.0))

    if invert:
        img = 1 - img
    return img

In [9]:
# Inference configuration shared by preprocessing, dataset and model.
cfg = SimpleNamespace(
    img_size=1024,
    backbone='seresnext50_32x4d',
    pretrained=False,
    in_channels=1,
    classes=['cancer'],
    batch_size=16,  #8
    data_folder="/tmp/output/",
    device=DEVICE,
)
# Center crop down to the model input size.
cfg.val_aug = A.CenterCrop(always_apply=False, p=1.0, height=cfg.img_size, width=cfg.img_size)

We will process the dicoms in chunks so the disk space does not become an issue. 

In [10]:
# Images are stored slightly larger than the model input size.
SAVE_SIZE = int(cfg.img_size * 1.125)
SAVE_FOLDER = cfg.data_folder
os.makedirs(SAVE_FOLDER, exist_ok=True)

# Split the file list into chunks of roughly 2000 files. Boundaries are
# computed with integer arithmetic so no index can be lost to float rounding
# followed by truncation.
_n_files = len(test_df["fns"])
N_CHUNKS = _n_files // 2000 if _n_files > 2000 else 1
CHUNKS = [(_n_files * k // N_CHUNKS, _n_files * (k + 1) // N_CHUNKS) for k in range(N_CHUNKS)]
CHUNKS = np.array(CHUNKS).astype(int)
JPG_FOLDER = "/tmp/jpg/"

In [11]:


# GPU decoding, one chunk at a time: extract the bitstreams to JPG_FOLDER,
# decode them with DALI, preprocess on the GPU and save PNGs to SAVE_FOLDER.
# Any image that fails here is left without a PNG.
for ttt, chunk in enumerate(CHUNKS):
    print(f'chunk {ttt} of {len(CHUNKS)} chunks')
    os.makedirs(JPG_FOLDER, exist_ok=True)

    _ = Parallel(n_jobs=2)(
        delayed(convert_dicom_to_jpg)(f'{DATA_FOLDER}/{img}', save_folder=JPG_FOLDER)
        for img in test_df["fns"].tolist()[chunk[0]: chunk[1]]
    )

    jpgfiles = glob.glob(JPG_FOLDER + "*.jpg")

    # Only build a pipeline when there is something to read, so a chunk
    # without extracted files cannot abort the run.
    pipe = None
    if jpgfiles:
        pipe = jpg_decode_pipeline(jpgfiles, batch_size=1, num_threads=2, device_id=0)
        pipe.build()

    for i, f in enumerate(tqdm(jpgfiles)):

        patient, dicom_id = f.split('/')[-1][:-4].split('_')
        # Only metadata is needed here; the pixels come from DALI.
        dicom = pydicom.dcmread(DATA_FOLDER + f"/{patient}/{dicom_id}.dcm", stop_before_pixels=True)
        try:
            out = pipe.run()
            # Dali -> Torch
            img = out[0][0]
            # NOTE(review): UINT16 output is received as int16 (see
            # to_torch_type); confirm no pixel values exceed 32767.
            img_torch = torch.empty(img.shape(), dtype=torch.int16, device="cuda")
            feed_ndarray(img, img_torch, cuda_stream=torch.cuda.current_stream(device=0))
            img = img_torch.float()

            #apply dicom preprocessing
            img = process_dicom(img, dicom)

            #resize the torch image
            img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]

            img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy()
            out_file_name = SAVE_FOLDER + f"{patient}_{dicom_id}.png"
            cv2.imwrite(out_file_name, img)

        except Exception as e:
            print(i, e)
            # Restart decoding after the failed file. When the failure was on
            # the last file there is nothing left, and rebuilding on an empty
            # list must be avoided because it happens outside the try block.
            remaining = jpgfiles[i+1:]
            if remaining:
                pipe = jpg_decode_pipeline(remaining, batch_size=1, num_threads=2, device_id=0)
                pipe.build()
            continue

    # Free the disk space used by this chunk's bitstreams.
    shutil.rmtree(JPG_FOLDER)
print(f'DALI Raw image load complete')

chunk 0 of 1 chunks


100%|██████████| 500/500 [01:17<00:00,  6.47it/s]


DALI Raw image load complete


In [12]:
# Count the PNGs written so far (GPU path only at this point).
fns = glob.glob(f'{SAVE_FOLDER}/*.png')
n_saved = len(fns)
print(f'Image on disk count : {n_saved}')

Image on disk count : 500


A few hidden test set images might not be decodable via DALI, so we fall back to the CPU for those.

In [13]:
# Map each saved "<patient>_<image>.png" back to its "<patient>/<image>.dcm"
# key, then collect every DICOM that has no PNG yet.
gpu_processed_files = [
    png_path.split('/')[-1].replace('_','/').replace('png','dcm') for png_path in fns
]
# Set lookup keeps the membership test linear overall.
_gpu_processed_lookup = set(gpu_processed_files)
to_process = [f for f in test_df["fns"].values if f not in _gpu_processed_lookup]
len(gpu_processed_files), len(to_process)

(500, 0)

In [14]:

def process(f, save_folder=""):
    """CPU fallback: decode one DICOM with dicomsdl, preprocess it and save a PNG.

    Parameters
    ----------
    f : str
        Path of the form ``.../<patient>/<image>.dcm``.
    save_folder : str
        Output folder prefix (should end with ``/``). When empty, the
        module-level ``SAVE_FOLDER`` is used.

    Returns
    -------
    str
        Path of the written PNG file.
    """
    path_parts = f.split('/')
    patient = path_parts[-2]
    dicom_id = path_parts[-1][:-4]  # strip ".dcm"

    dicom = dicomsdl.open(f)
    pixels = dicom.pixelData()
    # Cast first: torch.from_numpy rejects dtypes torch cannot represent
    # (e.g. uint16 on older versions), and the resize below needs floats.
    img = torch.from_numpy(pixels.astype(np.float32))
    img = process_dicom(img, dicom)

    img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]

    img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy()
    out_file_name = (save_folder or SAVE_FOLDER) + f"{patient}_{dicom_id}.png"
    cv2.imwrite(out_file_name, img)
    return out_file_name

In [15]:
# Run the CPU fallback for every DICOM the GPU path did not produce a PNG for.
_cpu_jobs = (
    delayed(process)(f'{DATA_FOLDER}/{img}', save_folder=SAVE_FOLDER)
    for img in tqdm(to_process)
)
cpu_processed_filenames = [name for name in Parallel(n_jobs=2)(_cpu_jobs) if name]
print(f'CPU Raw image load complete with {len(cpu_processed_filenames)} loaded')

0it [00:00, ?it/s]

CPU Raw image load complete with 0 loaded





In [16]:
# Release Python garbage and cached CUDA memory before loading the models.
gc.collect()
torch.cuda.empty_cache()

In [17]:
# Recount the PNGs after the CPU fallback.
n_saved = len(glob.glob(f'{SAVE_FOLDER}/*.png'))
print(f'Image on disk count : {n_saved}')

Image on disk count : 500


In [18]:
# Every row of test_df must have a preprocessed image on disk.
assert n_saved == len(test_df)

We finished with preprocessing all the dicoms to images. So next, we set-up the dataloading and model

In [19]:

def batch_to_device(batch, device):
    """Return a new dict with every value of `batch` moved to `device`."""
    moved = {}
    for key in batch:
        moved[key] = batch[key].to(device)
    return moved


class CustomDataset(Dataset):
    """Dataset serving the preprocessed PNGs together with their labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``patient_id``, ``image_id`` and every column named
        in ``cfg.classes``. The frame passed in is not modified.
    cfg : namespace
        Must provide ``classes`` (label column names) and ``data_folder``
        (folder prefix of the PNG files, ending with ``/``).
    aug : callable or None
        Called as ``aug(image=img)``; its ``"image"`` entry is used. Skipped
        when falsy.
    """

    def __init__(self, df, cfg, aug):

        self.cfg = cfg
        # Image 1942326353 is explicitly excluded (reason not recorded here).
        # Copy after filtering so the column assignment below writes to an
        # independent frame rather than to a slice.
        keep = df['image_id'].astype(str) != '1942326353'
        self.df = df[keep].copy()
        self.labels = self.df[self.cfg.classes].values
        self.df["fns"] = self.df['patient_id'].astype(str) + '_' + self.df['image_id'].astype(str) + '.png'
        self.fns = self.df["fns"].astype(str).values
        self.aug = aug
        self.data_folder = cfg.data_folder

    def __getitem__(self, idx):
        """Return ``{"input": C x H x W float tensor, "target": label tensor}``."""
        label = self.labels[idx]
        img = self.load_one(idx)

        if self.aug:
            img = self.augment(img)

        img = self.normalize_img(img)
        # H x W x C -> C x H x W
        torch_img = torch.tensor(img).float().permute(2,0,1)

        feature_dict = {
            "input": torch_img,
            "target": torch.tensor(label),
        }
        return feature_dict

    def __len__(self):
        return len(self.fns)

    def load_one(self, idx):
        """Read image `idx` from disk as an H x W x C array.

        Raises
        ------
        FileNotFoundError
            If the file is missing or cannot be decoded (cv2.imread returned
            None). Previously this surfaced later as an unrelated error.
        """
        path = self.data_folder + self.fns[idx]
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        if len(img.shape) == 2:
            # Grayscale: add a trailing channel axis.
            img = img[:,:,None]
        return img

    def augment(self, img):
        """Apply ``self.aug`` to a float32 copy of `img` and return the result."""
        img = img.astype(np.float32)
        transformed = self.aug(image=img)
        trans_img = transformed["image"]
        return trans_img

    def normalize_img(self, img):
        """Scale 8-bit pixel values to [0, 1]."""
        img = img / 255
        return img


In [20]:
def gem(x, p=5, eps=1e-5):
    """Generalized-mean pooling over the last two dimensions of `x`.

    Values are clamped to at least `eps`, raised to the power `p`, averaged
    over the full spatial extent and raised to ``1 / p``.
    """
    spatial_extent = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, spatial_extent)
    return pooled.pow(1.0 / p)


class GeM(nn.Module):
    """Generalized-mean pooling layer wrapping :func:`gem`.

    Parameters
    ----------
    p : float
        Pooling exponent.
    eps : float
        Lower clamp applied to the input before exponentiation.
    p_trainable : bool
        When True, `p` is stored as a learnable one-element ``Parameter``;
        otherwise it is kept as a plain number.
    """

    def __init__(self, p=5, eps=1e-5, p_trainable=False):
        super().__init__()
        if p_trainable:
            self.p = Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        ret = gem(x, p=self.p, eps=self.eps)
        return ret

    def __repr__(self):
        # `p` is either a one-element Parameter or a plain number; the old
        # implementation assumed a tensor and raised for the latter.
        p_value = self.p.item() if isinstance(self.p, torch.Tensor) else float(self.p)
        return (self.__class__.__name__  + f"(p={p_value:.4f},eps={self.eps})")



class Net(nn.Module):
    """timm backbone followed by GeM pooling and a linear classification head.

    `cfg` must provide ``backbone``, ``pretrained``, ``in_channels`` and
    ``classes`` (one output per class).
    """

    def __init__(self, cfg: Any):
        super().__init__()

        self.cfg = cfg
        self.n_classes = len(cfg.classes)

        # Headless backbone: no classifier and no built-in pooling.
        self.backbone = timm.create_model(
            cfg.backbone,
            pretrained=cfg.pretrained,
            num_classes=0,
            global_pool="",
            in_chans=self.cfg.in_channels,
        )
        n_features = self.backbone.feature_info[-1]['num_chs']

        self.global_pool = GeM(p_trainable=False)
        self.head = torch.nn.Linear(n_features, self.n_classes)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        """Return ``{"loss": ...}`` in training mode, ``{"logits": ...}`` otherwise.

        `batch` is a dict with key ``"input"`` and, in training mode,
        ``"target"``.
        """
        feature_map = self.backbone(batch['input'])
        pooled = self.global_pool(feature_map)[:, :, 0, 0]
        logits = self.head(pooled)

        if not self.training:
            return {"logits": logits}

        return {"loss": self.loss_fn(logits, batch["target"].float())}


In [21]:
def get_dl(test_df, cfg):
    """Build the unshuffled inference DataLoader.

    Returns the loader together with the `batch_to_device` helper.
    """
    dataset = CustomDataset(test_df, cfg, cfg.val_aug)
    loader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=cfg.batch_size,
        num_workers=N_CORES,
        pin_memory=True,
    )
    return loader, batch_to_device

def get_state_dict(sd_fp):
    """Load the 'model' entry of a checkpoint on CPU and remove "module." from its keys."""
    checkpoint = torch.load(sd_fp, map_location="cpu")
    cleaned = {}
    for key, value in checkpoint['model'].items():
        cleaned[key.replace("module.", "")] = value
    return cleaned

def get_nets(cfg,state_dicts):
    """Create one eval-mode `Net` on DEVICE per checkpoint path and load its weights."""
    nets = []

    for checkpoint_path in state_dicts:
        net = Net(cfg).eval().to(DEVICE)
        print("loading dict")
        weights = get_state_dict(checkpoint_path)
        net.load_state_dict(weights, strict=True)
        nets.append(net)
        # Drop the CPU copy of the weights before loading the next one.
        del weights
        gc.collect()

    return nets

In [22]:
# Inference DataLoader over the preprocessed PNGs.
sub_dl, batch_to_device = get_dl(test_df, cfg)

In [23]:
# Collect the checkpoint files (sorted for a stable order) and load one model per file.
state_dicts = sorted(glob.glob('/kaggle/input/rsna-seresnext50-5fold/check*.pth'))
print(state_dicts)

nets = get_nets(cfg,state_dicts)

['/kaggle/input/rsna-seresnext50-5fold/checkpoint_last_seed298515.pth', '/kaggle/input/rsna-seresnext50-5fold/checkpoint_last_seed334760.pth', '/kaggle/input/rsna-seresnext50-5fold/checkpoint_last_seed607282.pth', '/kaggle/input/rsna-seresnext50-5fold/checkpoint_last_seed758935.pth', '/kaggle/input/rsna-seresnext50-5fold/checkpoint_last_seed779477.pth']
loading dict
loading dict
loading dict
loading dict
loading dict


In [24]:
# Number of images the dataset will serve.
print(f'Dataloader length : {len(sub_dl.dataset)}')

Dataloader length : 500


In [25]:
# Run every model on every batch and collect per-model sigmoid probabilities.
preds = [[] for _ in range(len(nets))]

with torch.inference_mode():
    for batch in tqdm(sub_dl):
        batch = batch_to_device(batch, cfg.device)
        for model_idx, net in enumerate(nets):
            probs = net(batch)['logits'].sigmoid().float().detach().cpu().numpy()
            preds[model_idx].append(probs)

# (n_models, n_images, n_classes) -> mean over models -> first class column
preds = np.array([np.concatenate(model_preds, axis=0) for model_preds in preds])
preds = preds.mean(0) #average fold predictions
preds = preds[:,0]
preds.shape

100%|██████████| 32/32 [02:40<00:00,  5.01s/it]


(500,)

In [26]:
preds.shape

(500,)

In [27]:
# Build "<patient_id>_<laterality>" ids in dataset order.
patient_id = sub_dl.dataset.df['patient_id'].values
laterality = sub_dl.dataset.df['laterality'].values
prediction_id = [f'{pid}_{side}' for pid, side in zip(patient_id, laterality)]

pred_df = pd.DataFrame({'prediction_id': prediction_id, 'cancer_raw': preds})

#aggregate by prediction_id , i.e. by patient_laterality
sub = pred_df.groupby('prediction_id')[['cancer_raw']].mean()

# binarize predictions: everything above the 0.97935 quantile becomes 1
raw_scores = sub['cancer_raw'].values
th = np.quantile(raw_scores, 0.97935)
sub['cancer'] = (raw_scores > th).astype(int)

In [28]:
# Write the submission: index prediction_id plus the binarized 'cancer' column.
sub[['cancer']].to_csv('submission.csv')

For debugging purposes, we can calculate the pF1 score if we inferred on the train data by setting RAM_CHECK=True in the beginning.

In [29]:
if RAM_CHECK:

    def pfbeta(labels, predictions, beta):
        """Probabilistic F-beta score (marked as the official implementation by the author).

        Parameters
        ----------
        labels : sequence
            Ground-truth values; any truthy entry counts as positive.
        predictions : sequence of numbers
            Predicted values, clipped to [0, 1].
        beta : float
            Weight of recall relative to precision.

        Returns
        -------
        The score, or 0 when it is undefined (no positive labels, no
        predicted positive mass) or when precision or recall is zero.
        """
        y_true_count = 0
        ctp = 0
        cfp = 0

        for label, raw_prediction in zip(labels, predictions):
            prediction = min(max(raw_prediction, 0), 1)
            if label:
                y_true_count += 1
                ctp += prediction
            else:
                cfp += prediction

        # Guard both denominators: previously these cases divided by zero.
        if y_true_count == 0 or (ctp + cfp) == 0:
            return 0

        beta_squared = beta * beta
        c_precision = ctp / (ctp + cfp)
        c_recall = ctp / y_true_count
        if (c_precision > 0 and c_recall > 0):
            result = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall)
            return result
        return 0

    #aggregate by prediction_id , i.e. by patient_laterality
    test_df['prediction_id'] = test_df.apply(lambda x: f'{x.patient_id}_{x.laterality}', 1)
    test_df_gr = test_df.groupby('prediction_id')[['cancer']].agg('mean')

    # Sort both the same
    test_df_gr = test_df_gr.loc[sub.index]

    y = test_df_gr['cancer'].values
    y_pred = sub['cancer'].values

    score = pfbeta(y, y_pred, 1)
    print(th, score)

0.1133195487327874 0.6666666666666665
