In [1]:
import pandas as pd
import numpy as np
import os

import pydicom
from pydicom.filebase import DicomBytesIO
from pydicom.pixel_data_handlers import apply_windowing

import torch

import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali import pipeline_def
from nvidia.dali.types import DALIDataType
from nvidia.dali.plugin.pytorch import DALIGenericIterator

import matplotlib
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = '1'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
IMG_DIR = "/home/data4/share/rsna-breast-cancer-detection/train_images"

In [3]:
df = pd.read_csv("/home/data4/share/rsna-breast-cancer-detection/train.csv")
df['dcm'] = IMG_DIR + '/' + df.patient_id.astype(str) + '/' + df.image_id.astype(str) + '.dcm'
df_j2k, df_jll = df[df["site_id"]==2].reset_index(drop=True), df[df["site_id"]==1].reset_index(drop=True)
print('j2k:', len(df_j2k), 'jll:', len(df_jll))

j2k: 25187 jll: 29519


In [30]:
class J2KIterator(object):
    def __init__(self, df, img_dir, batch_size):
        self.df = df
        self.img_dir = img_dir
        self.batch_size = batch_size

    def dicom_to_j2k(self, img_dir, patient_id, image_id):
        dcmfile = pydicom.dcmread(f'{img_dir}/{patient_id}/{image_id}.dcm')
        
        if dcmfile.file_meta.TransferSyntaxUID == '1.2.840.10008.1.2.4.90':
            with open(f'{img_dir}/{patient_id}/{image_id}.dcm', 'rb') as fp:
                raw = DicomBytesIO(fp.read())
                ds = pydicom.dcmread(raw)
            offset = ds.PixelData.find(b"\x00\x00\x00\x0C")

            return np.frombuffer(ds.PixelData[offset:], dtype=np.uint8), dcmfile.PhotometricInterpretation == 'MONOCHROME1'

    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        #if self.i > len(self.df):
        if self.i > 32:
            raise StopIteration

        compressed_imgs, is_monochrome_imgs, patient_ids, img_ids = [], [], [], []

        batch_start, batch_finish = self.i, min(self.i + self.batch_size, len(self.df))
        df_batch = self.df.iloc[range(batch_start, batch_finish)]

        for patient_id, img_id in df_batch[["patient_id", "image_id"]].values:
            compressed_img, is_monochrome_img = self.dicom_to_j2k(self.img_dir, patient_id, img_id)

            compressed_imgs.append(compressed_img)
            is_monochrome_imgs.append(np.array([is_monochrome_img], dtype = np.bool_))

            patient_ids.append(np.array([patient_id], dtype = np.int64))
            img_ids.append(np.array([img_id], dtype = np.int64))

        self.i += self.batch_size

        return compressed_imgs, is_monochrome_imgs, patient_ids, img_ids

In [31]:
class JLLIterator(object):
    def __init__(self, df, batch_size, img_dir):
        self.df = df
        self.batch_size = batch_size
        self.img_dir = img_dir
       
    @staticmethod
    def decompress_jll(dm):
        info = dm.getPixelDataInfo()
        img = np.empty((info['Rows'], info['Cols']), dtype = info['dtype'])
        dm.copyFrameData(0, img)
        return img
    
    def process_img(self, img_dir, patient_id, image_id):
        dm = dicomsdl.open(f'{img_dir}/{patient_id}/{image_id}.dcm')
        return self.decompress_jll(dm).astype(np.uint16), dm.PhotometricInterpretation == 'MONOCHROME1'
        
    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        if self.i > len(self.df):
            raise StopIteration
        
        decompressed_imgs, is_monochrome_imgs, patient_ids, img_ids = [], [], [], []

        batch_start, batch_finish = self.i, min(self.i + self.batch_size, len(self.df))
        df_batch = self.df.iloc[range(batch_start, batch_finish)]
        
        for patient_id, img_id in df_batch[['patient_id', 'image_id']].values:
            decompressed_img, is_monochrome_img = self.process_img(self.img_dir, p_id, i_id)
            
            decompressed_imgs.append(np.expand_dims(decompressed_img, axis = 2))
            is_monochrome_imgs.append(np.array([is_monochrome_img], dtype = np.bool_))
            patient_ids.append(np.array([patient_id], dtype = np.int64))
            img_ids.append(np.array([img_id], dtype = np.int64))
            labels.append(np.array([cancer], dtype = np.int64))
            
        self.i += self.batch_size
                    
        return compressed_imgs, is_monochrome_imgs, patient_ids, img_ids

In [32]:
J2Ki = J2KIterator(df_j2k, batch_size=32, img_dir="/home/data4/share/rsna-breast-cancer-detection/train_images")
JLLi = JLLIterator(df_jll, batch_size=32, img_dir="/home/data4/share/rsna-breast-cancer-detection/train_images")

In [33]:
@pipeline_def
def j2k_decode_pipeline(width, height):
    imgs, is_monochromes, patient_ids, img_ids, labels = fn.external_source(
        source=J2Ki, num_outputs=4, device="cpu", dtype = [types.UINT8, types.BOOL, types.INT64, types.INT64],
    )
    
    imgs = fn.decoders.image(
        imgs, device='mixed', output_type=types.ANY_DATA, dtype=DALIDataType.UINT8
    )

    imgs = fn.resize(imgs, size=[width, height], device="gpu")
    
    return imgs, is_monochromes, patient_ids, img_ids, labels

In [34]:
@pipeline_def
def jll_process_pipeline(width, height):
    imgs, is_monochromes, patient_ids, img_ids, labels = fn.external_source(
        source=JLLi, num_outputs=4, device="gpu", dtype = [types.UINT8, types.BOOL, types.INT64, types.INT64],
    )
    
    imgs = fn.reinterpret(imgs, layout = "HWC")
    
    imgs = fn.resize(imgs, size=[width, height], device="gpu")
    
    return imgs, is_monochromes, patient_ids, img_ids, labels

In [None]:
def process(batch, roi_model):
    processed_imgs = []
    imgs, is_monochromes = batch[0]['decompressed_imgs'], batch[0]['is_monochrome_imgs']
    patient_ids, img_ids = batch[0]['patient_ids'], batch[0]['img_ids']
    for i in range(len(imgs)):
        is_monochrome = is_monochromes.numpy().tolist()[i][0]
        patient_id = patient_ids.numpy().tolist()[i][0]
        


    return img, img_window, img_id

In [40]:
class CustomDALIGenericIterator(DALIGenericIterator):
    def __init__(self, length, pipelines, **argw):
        #self.model = yolo_model
        self._len = length 
        super().__init__(pipelines, ['imgs', 'is_monochrome_imgs', 'patient_ids', 'img_ids'])
    
    def __next__(self):
        batch = super().__next__()
        #batch_imgs = process(batch, self.model)
        return batch_imgs
    
    def __len__(self):
        return self._len

In [46]:
j2k_pipe = j2k_decode_pipeline(width=768, height=768, batch_size=32, num_threads=2, device_id=0, 
                               py_num_workers=1, exec_async=False, exec_pipelined=False)
j2k_iter = CustomDALIGenericIterator(length=len(J2Ki), pipelines=[j2k_pipe])

In [52]:
for batch in j2k_iter:
    p_imgs = batch[0]['imgs']
    p_is_monochrome_imgs = batch[0]['is_monochrome_imgs']
    p_patient_ids = batch[0]['patient_ids']
    p_img_ids = batch[0]['img_ids']
    break

In [68]:
for i in range(32):
    print(type(p_imgs.numpy().tolist()[i][0]))
    print(type(p_is_monochrome_imgs.numpy().tolist()[i][0]))
    print(type(p_patient_ids.numpy().tolist()[i][0]))
    print(type(p_img_ids.numpy().tolist()[i][0]))
    break

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.