In [None]:
# ---- run configuration -------------------------------------------------------
DEBUG = False  # when True, run the optional DICOM-tag diagnostics cell

# Kaggle input locations: competition data and stage-2 classifier weights.
INPUT = '/kaggle/input/rsna-2023-abdominal-trauma-detection'
STAGE_2_MODELS = '/kaggle/input/rsna23-tsl-lrh-efb5'

# timm backbone name for the stage-2 (per-organ) classifier.
BACKBONE_2 = 'tf_efficientnet_b5_ap'

In [None]:
import os
import sys
sys.path = [
    '../input/covn3d-same',
    '../input/timm20221011/pytorch-image-models-master',
    '../input/smp20210127/segmentation_models.pytorch-master/segmentation_models.pytorch-master',
    '../input/smp20210127/pretrained-models.pytorch-master/pretrained-models.pytorch-master',
    '../input/smp20210127/EfficientNet-PyTorch-master/EfficientNet-PyTorch-master',
] + sys.path

!pip -q install ../input/pylibjpeg140py3/pylibjpeg-1.4.0-py3-none-any.whl
!pip -q install ../input/pylibjpeg140py3/python_gdcm-3.0.17.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

!cp -r ../input/timm-20220211/pytorch-image-models-master/timm ./timm4smp

In [None]:
import gc
from collections import Counter
import psutil
import shutil
import ctypes
from fastcore.all import Path

from matplotlib import animation
from IPython.display import HTML

from collections import defaultdict
import ast
import cv2
import time
import timm
import timm4smp
import pickle
import random
import pydicom
import argparse
import warnings
import threading
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from glob import glob
import albumentations
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pylab import rcParams

%matplotlib inline
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.benchmark = True

timm.__version__, timm4smp.__version__

In [None]:
def get_ram_usage(threshold: float=None):
    """Return this process's resident memory (RSS) in GB, optionally freeing memory.

    If ``threshold`` (in GB) is given and current usage exceeds it, run the
    Python garbage collector and then ask glibc to hand freed heap pages back
    to the OS, printing the usage before and after.

    Args:
        threshold: GB limit above which to collect. ``None`` disables
            collection; ``0`` collects on every call.

    Returns:
        RSS in GB (1 GB = 1e9 bytes), measured before any collection.
    """
    process = psutil.Process(os.getpid())
    ram_usage = process.memory_info().rss / (1000 ** 3)
    if threshold is not None and ram_usage > threshold:
        print(f'ram usage: {ram_usage}, garbage collecting')
        # Collect first so that memory released by gc can be returned by malloc_trim.
        gc.collect()
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except (OSError, AttributeError):
            pass  # not glibc (e.g. macOS / musl): trimming is best-effort only
        print(f'new ram usage: {process.memory_info().rss / (1000 ** 3)}')
    return ram_usage

In [None]:
# Short alias of get_ram_usage; not referenced elsewhere in this notebook
# section (presumably kept for interactive checks such as `gru()`).
gru = get_ram_usage

### Investigating scans (series) per patient

In [None]:
if DEBUG:
    # Diagnostics on the training DICOM tags: average number of CT series
    # (scans) per patient, extrapolated to the hidden test set.
    dicom_tags = pd.read_parquet(f'{INPUT}/train_dicom_tags.parquet')
    dicom_tags = dicom_tags.rename(columns={'PatientID': 'patient_id'})

    scans_per_patient = dicom_tags.SeriesInstanceUID.nunique() / dicom_tags.patient_id.nunique()
    print(scans_per_patient, 'scans per patient')
    # NOTE(review): 1100 is presumably the expected number of test patients — confirm.
    print(f'therefore, there should be around {1100 * scans_per_patient} scans in the test set')


In [None]:
# ---- stage 1 (3-D segmentation) -------------------------------------------
model_dir_seg = '/kaggle/input/rsna23-train-stage1/models'
# (cv2 width, cv2 height, number of sampled slices) of the volume fed to the seg model
image_size_seg = (128, 128, 128)
msk_size = image_size_seg[0]   # side of the predicted-mask grid, used to rescale crop boxes

# ---- stage 2 (per-organ classifier) ---------------------------------------
image_size = 224               # default TimmModel input height/width
image_size_cls = 224           # side of each resized organ crop
n_slice_per_c = 15             # slices sampled per organ crop
n_ch = 5                       # neighbouring DICOM slices stacked per sampled slice

# ---- data loading ----------------------------------------------------------
batch_size_seg = 1
num_workers = 4

### Load the dataframe of local test studies (one DICOM folder per row)

In [None]:
# Local hold-out "test" set. Each row's `image_folder` column (read by the
# datasets and the prediction loop below) is a study's DICOM directory.
df = pd.read_csv('/kaggle/input/rsna23-local-test-set/df_test.csv')

# Dataset

In [None]:
def _first_dicom_value(value):
    """Return the first entry of a multi-valued DICOM element, else the value itself."""
    from collections.abc import Sequence
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return value[0]
    return value


def standardize_pixel_array(dcm: "pydicom.dataset.FileDataset") -> np.ndarray:
    """Convert a DICOM slice's stored pixels to windowed modality values.

    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217

    Steps:
      1. If PixelRepresentation == 1 (signed), sign-extend values whose
         BitsStored < BitsAllocated by shifting left and then arithmetically right.
      2. Apply RescaleSlope / RescaleIntercept.
      3. Clip to the window [center - width/2, center + width/2]. When the
         window tags hold several values, the first window is used.

    Args:
        dcm: a pydicom dataset (anything exposing ``pixel_array`` and the tags above).

    Returns:
        Float ndarray with the same shape as ``dcm.pixel_array``.
    """
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype
        pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift

    intercept = float(dcm.RescaleIntercept)
    slope = float(dcm.RescaleSlope)
    # Multi-window DICOMs store WindowCenter/WindowWidth as MultiValue, which int() rejects.
    center = int(_first_dicom_value(dcm.WindowCenter))
    width = int(_first_dicom_value(dcm.WindowWidth))
    low = center - width / 2
    high = center + width / 2

    pixel_array = (pixel_array * slope) + intercept
    pixel_array = np.clip(pixel_array, low, high)

    return pixel_array


def study_path_to_3D_image(path, image_size_seg=(128, 128, 128), plot_image=False, z_df=None):
    """Load one study folder as a z-ordered uint8 volume for the segmentation model.

    ``image_size_seg[2]`` files are sampled evenly by index (files sorted by their
    numeric name). Each slice is windowed with standardize_pixel_array, resized by
    cv2 to (image_size_seg[0], image_size_seg[1]) and inverted for MONOCHROME1.
    Slices are then ordered by the z of ImagePositionPatient, and the whole volume
    is min-max scaled to 0..255.

    Args:
        path: directory containing the study's ``<n>.dcm`` files.
        image_size_seg: (cv2 width, cv2 height, number of sampled slices).
        plot_image: if True, show every 100th z-sorted slice.
        z_df: optional DataFrame; if given, ``z_df.loc[study, 'pos_z_ascending']``
            records whether z increases along the filename order.

    Returns:
        uint8 array of shape (image_size_seg[1], image_size_seg[0], image_size_seg[2]).

    Raises:
        ValueError: if the folder contains no usable files.
    """
    parts = path.split('/')
    study = parts[-1]
    # Assumes a <patient>/<study>/<n>.dcm layout (as in the corrupted-file path below).
    patient = parts[-2] if len(parts) > 1 else '?'

    t_paths = sorted(glob(os.path.join(path, "*")), key=lambda x: int(x.split('/')[-1].split(".")[0]))
    t_paths = [p for p in t_paths if '3124/5842/514.dcm' not in p]  # corrupted file
    n_scans = len(t_paths)
    if n_scans == 0:
        raise ValueError(f'no DICOM files found in {path!r}')
    # Evenly spaced file indices; they repeat when the study has fewer files than requested.
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., image_size_seg[2])).round().astype(int)
    t_paths = [t_paths[i] for i in indices]

    imgs = {}
    pos_zs = []
    for filename in t_paths:
        dicom = pydicom.dcmread(filename)
        pos_z = dicom[(0x20, 0x32)].value[-1]  # z of ImagePositionPatient: the true frame order
        pos_zs.append(pos_z)
        img = standardize_pixel_array(dicom)
        img = cv2.resize(img, (image_size_seg[0], image_size_seg[1]), interpolation=cv2.INTER_AREA)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            img = 1 - img
        imgs[pos_z] = img

    # A single sampled slice has no direction; treat it as ascending.
    pos_z_ascending = pos_zs[-1] > pos_zs[0] if len(pos_zs) > 1 else True
    if z_df is not None:
        z_df.loc[study, 'pos_z_ascending'] = pos_z_ascending

    images = []
    cnt = Counter(pos_zs)
    for i, k in enumerate(sorted(imgs.keys())):
        img = imgs[k]
        # Re-insert slices sampled more than once so the depth stays image_size_seg[2].
        images.extend([img] * cnt[k])
        if plot_image and i % 100 == 0:
            plt.figure(figsize=(5, 5))
            plt.imshow(img, cmap="gray")
            plt.title(f"Patient {patient} - Study {study} - Frame {i}/{len(imgs)}")
            plt.axis(False)
            plt.show()
    images = np.stack(images, -1)

    images = images - np.min(images)
    images = images / (np.max(images) + 1e-4)
    images = (images * 255).astype(np.uint8)
    return images

In [None]:
# Organs handled by the stage-2 classifiers, their label columns, and the
# [start, end) range each organ occupies in the stacked crop tensor.
ORGANS = ['liver']
LABELS = [[f'{organ}_{grade}' for grade in ('healthy', 'low', 'high')] for organ in ORGANS]
ABS = [[0, 15]]
N_ORGANS = len(ORGANS)

In [None]:
def load_sample(row, has_mask=True):
    """Load the study at ``row.image_folder`` as a 3-channel uint8 volume.

    ``has_mask`` is accepted for interface compatibility and ignored.
    Returns an array whose first axis has the single grey channel repeated 3 times.
    """
    volume = study_path_to_3D_image(row.image_folder)
    # Give 3-D volumes a leading channel axis, then replicate it for 3-channel encoders.
    if volume.ndim < 4:
        volume = volume[np.newaxis, ...]
    return np.repeat(volume, 3, axis=0)


class SegTestDataset(Dataset):
    """Inference dataset yielding one [0, 1]-scaled, 3-channel study volume per df row."""

    def __init__(self, df):
        # reset_index makes rows positional 0..len-1 (old index kept as an 'index' column).
        self.df = df.reset_index()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        volume = load_sample(self.df.iloc[index], has_mask=False)
        return torch.tensor(volume / 255.).float()


In [None]:
# Segmentation dataset/loader over the local test studies, in df order.
dataset_seg = SegTestDataset(df)
loader_seg = torch.utils.data.DataLoader(
    dataset_seg,
    batch_size=batch_size_seg,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
print('Getting the raw images and the processed images which are ready \n'
      'to input into the segmentation model')
idx = 0
processed_images = dataset_seg[idx]  # (3, H, W, D) float tensor in [0, 1]
path = df.iloc[idx].image_folder
t_paths = sorted(
    glob(os.path.join(path, "*")),
    key=lambda p: int(os.path.basename(p).split(".")[0]),
)
# Raw stored pixel values (no rescale/windowing), in file-number order.
raw_imgs = [pydicom.dcmread(fn).pixel_array for fn in t_paths]

In [None]:
# Original animation code from https://www.kaggle.com/code/franklinshih0617/rsna-abdominal-trauma-detect-eda-animation
def animate_images(images, figsize=(6, 6), cmap=None):
    """Render a sequence of 2-D frames as an interactive JS/HTML animation.

    Run HTML({returned_object}) to see animation.

    Args:
        images: non-empty sequence of 2-D arrays, one per frame. The colour
            scale is fixed by the first frame.
        figsize: matplotlib figure size in inches.
        cmap: colormap passed to ``imshow`` (e.g. ``'gray'``).

    Returns:
        str: HTML/JS markup suitable for ``IPython.display.HTML``.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError('animate_images needs at least one frame')
    fig, ax = plt.subplots(figsize=figsize)
    try:
        im = ax.imshow(images[0], cmap=cmap)

        def update(i):
            im.set_array(images[i])
            return (im,)

        ani = animation.FuncAnimation(fig, update, frames=range(len(images)), repeat=True)
        return ani.to_jshtml()
    finally:
        # Otherwise the inline backend also renders the last frame as a static
        # figure, and the figure stays alive in pyplot's registry.
        plt.close(fig)

In [None]:
# Slideshow of every other raw DICOM slice (stored pixel values, file order).
every_other_slice = raw_imgs[::2]
HTML(animate_images(every_other_slice, cmap='gray'))

In [None]:
# Animate the preprocessed volume: channel 0 of the (3, H, W, D) tensor with
# depth moved first, rescaled back to 0..255 integers.
frames = (processed_images[0].permute(2, 0, 1).numpy() * 255).astype(int)
print(frames.min(), frames.max())
HTML(animate_images(list(frames), cmap='gray'))

### Above: slideshows of the raw CT scan and of the preprocessed volume fed to the segmentation model
   

# Model

In [None]:
# Encoder dropout / stochastic-depth rates used when building the models below
# (both disabled), and the number of mask channels of the segmentation head.
drop_rate, drop_path_rate = 0., 0.
out_dim = 5
class TimmSegModel(nn.Module):
    """2-D timm encoder + smp U-Net decoder producing ``out_dim`` mask logits.

    Built in 2-D, then inflated to 3-D with ``convert_3d`` before use. Reads the
    module-level ``drop_rate``, ``drop_path_rate`` and ``out_dim``. The U-Net depth
    comes from the global ``n_blocks`` unless ``num_blocks`` is passed.

    Args:
        backbone: timm model name used as a ``features_only`` encoder.
        segtype: decoder type; only ``'unet'`` is implemented.
        pretrained: load timm pretrained encoder weights.
        num_blocks: number of encoder stages / decoder blocks (default: global ``n_blocks``).

    Raises:
        ValueError: for an unsupported ``segtype``.
    """
    def __init__(self, backbone, segtype='unet', pretrained=False, num_blocks=None):
        super(TimmSegModel, self).__init__()
        if segtype != 'unet':
            raise ValueError(f"unsupported segtype {segtype!r}; only 'unet' is implemented")
        # Fix the depth at construction time so forward() cannot drift if the
        # global `n_blocks` is reassigned later in the notebook.
        self.n_blocks = n_blocks if num_blocks is None else num_blocks

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        # Dry run on a tiny input to read each encoder stage's channel count.
        g = self.encoder(torch.rand(1, 3, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        self.decoder = smp.unet.decoder.UnetDecoder(
            encoder_channels=encoder_channels[:self.n_blocks+1],
            decoder_channels=decoder_channels[:self.n_blocks],
            n_blocks=self.n_blocks,
        )

        self.segmentation_head = nn.Conv2d(decoder_channels[self.n_blocks-1], out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        # The dummy 0 fills the first slot of the feature list, which smp's
        # UnetDecoder presumably discards (verify against the installed smp).
        global_features = [0] + self.encoder(x)[:self.n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


def convert_3d(module):
    """Recursively inflate a 2-D network into its 3-D counterpart.

    Mapping (children are converted recursively and re-attached by name):
      BatchNorm2d -> BatchNorm3d  (affine parameters and running stats are shared)
      Conv2dSame  -> Conv3dSame   (2-D kernel repeated along a new last axis)
      Conv2d      -> Conv3d       (same, keeping padding_mode)
      MaxPool2d   -> MaxPool3d
      AvgPool2d   -> AvgPool3d    (keeping count_include_pad / divisor_override)
    Any other module is returned as-is, with its children converted.

    Only square convolutions are handled: element [0] of each 2-D kernel/stride/
    padding/dilation tuple is used for all three dimensions. Repeated conv
    weights are not rescaled, so a converted model presumably needs 3-D trained
    weights loaded afterwards (as done when the segmentation models are loaded).

    Args:
        module: a torch.nn.Module (modified in place where it is kept).

    Returns:
        The converted module.
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    # Must be tested before plain Conv2d: timm's Conv2dSame subclasses nn.Conv2d.
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        # Carry over the averaging semantics too; e.g. timm's avg-down shortcuts
        # are built with count_include_pad=False.
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
            count_include_pad=module.count_include_pad,
            divisor_override=getattr(module, 'divisor_override', None),
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )

    return module_output


# backbone = 'resnet18d'
# n_blocks = 4
# model = TimmSegModel(backbone)
# model = convert_3d(model)
# model(torch.rand(1, 3, 128,128,128)).shape
    

In [None]:
    
class TimmModel(nn.Module):
    """Stage-2 classifier: per-slice timm CNN features -> 2-layer BiLSTM -> per-slice logits.

    Input ``(bs, n_slice_per_c, in_chans, h, w)``; output ``(bs, n_slice_per_c, out_dim)``.
    Reads the module-level ``in_chans``, ``n_slice_per_c``, ``drop_rate``,
    ``drop_path_rate`` and ``drop_rate_last`` when constructed.

    Args:
        backbone: timm model name; must be an EfficientNet (``'efficient'`` in the
            name) or a ConvNeXt (``'convnext'``).
        pretrained: load timm pretrained weights.
        out_dim: logits per slice.
        h, w: slice height/width (default ``image_size``).

    Raises:
        ValueError: for a backbone family whose classifier head is not handled.
    """
    def __init__(self, backbone, pretrained=False, out_dim=3, h=image_size, w=image_size):
        super(TimmModel, self).__init__()
        self.h = h
        self.w = w
        # Capture these globals now so forward() always matches the layout the
        # weights were built with, even if the globals are reassigned later.
        self.in_chans = in_chans
        self.n_slice = n_slice_per_c

        self.encoder = timm.create_model(
            backbone,
            in_chans=self.in_chans,
            num_classes=out_dim,
            features_only=False,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )

        # Replace the timm classifier with Identity so the encoder returns pooled features.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            raise ValueError(f'unsupported backbone {backbone!r}: expected an EfficientNet or ConvNeXt variant')

        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=drop_rate, bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512, 256),  # 512 = 2 directions x 256 hidden units
            nn.BatchNorm1d(256),
            nn.Dropout(drop_rate_last),
            nn.LeakyReLU(0.1),
            nn.Linear(256, out_dim),
        )

    def forward(self, x):  # (bs, nslice, ch, sz, sz)
        bs = x.shape[0]
        # Fold slices into the batch for the 2-D encoder, then unfold for the LSTM.
        x = x.view(bs * self.n_slice, self.in_chans, self.h, self.w)
        feat = self.encoder(x)
        feat = feat.view(bs, self.n_slice, -1)
        feat, _ = self.lstm(feat)
        feat = feat.contiguous().view(bs * self.n_slice, -1)
        feat = self.head(feat)
        feat = feat.view(bs, self.n_slice, -1).contiguous()

        return feat

# Load Models

In [None]:
# Stage 1: 3-D segmentation model(s). Each one is built in 2-D, inflated with
# convert_3d, then loaded from its per-fold checkpoint.
models_seg = []

kernel_type = 'timm3d_res18d_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
backbone = 'resnet18d'
n_blocks = 4  # U-Net depth, read by TimmSegModel
for fold in range(1):
    model = TimmSegModel(backbone, pretrained=False)
    model = convert_3d(model)
    load_model_file = f'{model_dir_seg}/{kernel_type}_fold{fold}_best.pth'
    sd = torch.load(load_model_file, map_location=torch.device('cpu'))
    # Checkpoints may wrap the weights and/or carry a DataParallel 'module.' prefix.
    if 'model_state_dict' in sd.keys():
        sd = sd['model_state_dict']
    sd = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=True)
    model = model.to(device)
    model.eval()
    models_seg.append(model)
len(models_seg)

In [None]:
# Stage 2: one classifier per organ, matched to ORGANS by position.
# kernel_type is only the checkpoint filename stem (it mentions effv2s even
# though the backbone is BACKBONE_2); it must match the files on disk.
kernel_type = '0920_1bonev2_effv2s_224_15_6ch_augv2_mixupp5_drl3_rov1p2_bs8_lr23e5_eta23e6_50ep'
backbone = BACKBONE_2
model_dir_cls = F'{STAGE_2_MODELS}/models'
in_chans = 6            # n_ch image slices + 1 mask channel, as stacked by load_bone
drop_rate_last = 0.3
models = [TimmModel(backbone, pretrained=False),
         ]
assert len(models) == len(ORGANS), 'need exactly one classifier per organ'

for model, organ in zip(models, ORGANS):
    print(organ)
    load_model_file = f'{model_dir_cls}/{organ}_{kernel_type}_fold0_last.pth'
    sd = torch.load(load_model_file, map_location=torch.device('cpu'))
    # Checkpoints may wrap the weights and/or carry a DataParallel 'module.' prefix.
    if 'model_state_dict' in sd.keys():
        sd = sd['model_state_dict']
    sd = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=True)
    # nn.Module.to() / eval() act in place, so the objects in `models` are updated.
    model.to(device)
    model.eval()
len(models)

In [None]:
# timm.list_models()

In [None]:
# Show GPU memory usage after the models have been loaded.
!nvidia-smi

In [None]:
def _read_neighbour_slice(t_paths, index):
    """Return the windowed slice ``t_paths[index]``, or None if out of range or unreadable."""
    # Explicit bounds check: a negative index must not wrap to the end of the study.
    if not 0 <= index < len(t_paths):
        return None
    try:
        dicom = pydicom.dcmread(t_paths[index])
        img = standardize_pixel_array(dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            img = 1 - img
    except Exception:
        return None
    return img


def load_bone(msk, cid, t_paths, cropped_images):
    """Crop organ/structure ``cid`` from a study and store it in ``cropped_images[cid]``.

    The bounding box of mask channel ``cid`` (thresholded at 0.2, or at 0.05 if
    that is empty) is padded by one mask voxel and mapped from the ``msk_size``
    grid to DICOM pixel and slice coordinates. ``n_slice_per_c`` slices are
    sampled evenly over the box's z-range. For each, the ``n_ch`` neighbouring
    DICOM slices are windowed, min-max scaled to uint8, cropped and resized to
    ``image_size_cls``, and the matching mask crop is appended as an extra channel.

    Args:
        msk: predicted masks indexed ``msk[cid, x, y, z]`` on the ``msk_size`` grid
            (presumably a (5, 128, 128, 128) float array — TODO confirm).
        cid: mask channel to crop.
        t_paths: the study's DICOM paths, already ordered along z.
        cropped_images: list written at position ``cid`` (called from worker threads).

    Side effects:
        Sets ``cropped_images[cid]`` to a uint8 tensor of shape
        (n_slice_per_c, image_size_cls, image_size_cls, n_ch + 1). If anything
        fails, the whole crop falls back to a tensor of ones with that shape.
    """
    n_scans = len(t_paths)
    bone = []
    try:
        msk_b = msk[cid] > 0.2
        msk_c = msk[cid] > 0.05

        x = np.where(msk_b.sum(1).sum(1) > 0)[0]
        y = np.where(msk_b.sum(0).sum(1) > 0)[0]
        z = np.where(msk_b.sum(0).sum(0) > 0)[0]

        if len(x) == 0 or len(y) == 0 or len(z) == 0:
            # Nothing above the confident threshold: retry with the looser one.
            x = np.where(msk_c.sum(1).sum(1) > 0)[0]
            y = np.where(msk_c.sum(0).sum(1) > 0)[0]
            z = np.where(msk_c.sum(0).sum(0) > 0)[0]

        x1, x2 = max(0, x[0] - 1), min(msk.shape[1], x[-1] + 1)
        y1, y2 = max(0, y[0] - 1), min(msk.shape[2], y[-1] + 1)
        z1, z2 = max(0, z[0] - 1), min(msk.shape[3], z[-1] + 1)
        zz1, zz2 = int(z1 / msk_size * n_scans), int(z2 / msk_size * n_scans)

        inds = np.linspace(zz1, zz2 - 1, n_slice_per_c).astype(int)   # DICOM slice indices
        inds_ = np.linspace(z1, z2 - 1, n_slice_per_c).astype(int)    # matching mask z indices
        for ind, ind_ in zip(inds, inds_):
            msk_this = msk[cid, :, :, ind_]

            slices = [_read_neighbour_slice(t_paths, ind + i) for i in range(-n_ch // 2 + 1, n_ch // 2 + 1)]
            # Missing neighbours become zeros shaped like a slice that did load (512x512 if none did).
            ref = next((s for s in slices if s is not None), None)
            fill_shape = ref.shape if ref is not None else (512, 512)
            images = [s if s is not None else np.zeros(fill_shape) for s in slices]

            data = np.stack(images, -1)
            data = data - np.min(data)
            data = data / (np.max(data) + 1e-4)
            data = (data * 255).astype(np.uint8)
            msk_this = msk_this[x1:x2, y1:y2]
            xx1 = int(x1 / msk_size * data.shape[0])
            xx2 = int(x2 / msk_size * data.shape[0])
            yy1 = int(y1 / msk_size * data.shape[1])
            yy2 = int(y2 / msk_size * data.shape[1])
            data = data[xx1:xx2, yy1:yy2]
            data = np.stack([cv2.resize(data[:, :, i], (image_size_cls, image_size_cls), interpolation=cv2.INTER_LINEAR) for i in range(n_ch)], -1)
            msk_this = (msk_this * 255).astype(np.uint8)
            msk_this = cv2.resize(msk_this, (image_size_cls, image_size_cls), interpolation=cv2.INTER_LINEAR)

            data = np.concatenate([data, msk_this[:, :, np.newaxis]], -1)

            bone.append(torch.tensor(data))

    except Exception:
        # Best effort: one bad crop must not abort the study. Discard partial
        # results so the fallback always has exactly n_slice_per_c slices.
        bone = [torch.ones((image_size_cls, image_size_cls, n_ch + 1), dtype=torch.uint8)
                for _ in range(n_slice_per_c)]

    cropped_images[cid] = torch.stack(bone, 0)


def load_cropped_images(msk, image_folder, n_ch=n_ch):
    """Crop every mask channel of one study in parallel and concatenate the crops.

    DICOM paths are sorted by file number and reversed when the first two files
    show non-increasing ImagePositionPatient z. Index 0 is then the lowest z,
    matching the z-sorted volume the masks were predicted on.

    Args:
        msk: predicted masks for one study; one load_bone thread per channel ``msk[cid]``.
        image_folder: the study's DICOM directory.
        n_ch: unused; kept for backward compatibility (load_bone reads the global ``n_ch``).

    Returns:
        Tensor of shape (len(msk) * n_slice_per_c, image_size_cls, image_size_cls, n_ch + 1),
        with the crops of channel 0 first, then channel 1, and so on.
    """
    t_paths = sorted(glob(os.path.join(image_folder, "*")), key=lambda x: int(x.split('/')[-1].split(".")[0]))
    # z of ImagePositionPatient for the first two files gives the slice direction.
    pos_z = [pydicom.dcmread(filename)[(0x20, 0x32)].value[-1] for filename in t_paths[:2]]
    if len(pos_z) > 1 and not pos_z[1] > pos_z[0]:
        t_paths.reverse()

    # Local buffers (not notebook globals), so callers that rebind
    # `cropped_images` / `threads` cannot clobber results between studies.
    n_masks = len(msk)
    cropped_images = [None] * n_masks
    threads = [
        threading.Thread(target=load_bone, args=(msk, cid, t_paths, cropped_images))
        for cid in range(n_masks)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    return torch.cat(cropped_images, 0)


# Predict

In [None]:
# Fallback predictions used by the prediction loop when a study fails:
# the training-set mean of every label column.
train = pd.read_csv(f'{INPUT}/train.csv')
tar_means = train.mean(numeric_only=True)
n_exceptions = 0

In [None]:
# Fresh segmentation dataset/loader for the prediction loop, in df order.
display(df.head())
dataset_seg = SegTestDataset(df)
loader_seg = torch.utils.data.DataLoader(
    dataset=dataset_seg, batch_size=batch_size_seg, shuffle=False, num_workers=num_workers,
)

In [None]:
%%time
preds = defaultdict(list)
bar = tqdm(loader_seg)
with torch.no_grad():
    for batch_id, (images) in enumerate(bar):
        try:
            pred_masks = []
            images = images.to(device)
            pred_masks = []
            for model in models_seg:
                pmask = model(images).sigmoid()
                pred_masks.append(pmask)
            pred_masks = torch.stack(pred_masks, 0).mean(0).cpu().numpy()

            # Build cls input
            cls_inp = []
            threads = [None] * 5
            cropped_images = [None] * 5
            for i in range(pred_masks.shape[0]):
                row = df.iloc[batch_id*batch_size_seg+i]
                cropped_images = load_cropped_images(pred_masks[i], row.image_folder)
                cls_inp.append(cropped_images.permute(0, 3, 1, 2).float() / 255.)
            cls_inp = torch.stack(cls_inp, 0).to(device)  # (1, 75, 6, 224, 224)

            image_full = cls_inp[0]
            out = defaultdict(dict)
            for organ, cols, (a, b) in zip(ORGANS, LABELS, ABS):
                images = []
                for image in image_full[a: b]: 
                    images.append(image.cpu())
                images = np.stack(images, 0)
                if organ == 'kidney': 
                    images = np.concatenate((images[:15, :, :, :], images[15:, :, :, :]), 2)
                out[organ]['images'] = torch.tensor(images).float().to(device)

            for organ, model, label_cols in zip(ORGANS, models, LABELS): 
                try: 
                    logits = model(out[organ]['images'].unsqueeze(0)).squeeze()
                    data = logits.sigmoid()
                    min_first_column = torch.min(data[:, 0])
                    max_last_two_columns = torch.max(data[:, 1:], dim=0).values
                    result = torch.cat((min_first_column.unsqueeze(0), max_last_two_columns), dim=0)
#                     print('***********', result.shape) ###################
                    preds[organ].append(result.cpu())
                except: 
                    print('XXXXXXXX')
                    preds[organ].append(torch.tensor(tar_means[label_cols]))
        except: 
            print('problem in loop')
            n_exceptions += 1
            for organ, label_cols in zip(ORGANS, LABELS):
                preds[organ].append(torch.tensor(tar_means[label_cols]))
        get_ram_usage(1)
        get_ram_usage(1)
        if batch_id % 100 == 0:
            get_ram_usage(1)
            time.sleep(1)

In [None]:
# One prediction column per label (e.g. pred_liver_healthy); rows follow df order.
pred_cols = [f'pred_{col}' for label_cols in LABELS for col in label_cols]
for organ, label_cols in zip(ORGANS, LABELS):
    organ_preds = torch.stack(preds[organ]).float().numpy()
    assert len(organ_preds) == len(df), f'{organ}: {len(organ_preds)} predictions for {len(df)} studies'
    for j, col in enumerate(label_cols):
        df[f'pred_{col}'] = organ_preds[:, j]

In [None]:
# Persist the local test studies together with their pred_* columns.
df.to_csv('df_with_preds.csv', index=False)