In [1]:
!pip3 install pylibjpeg pylibjpeg-libjpeg pydicom python-gdcm

Collecting pylibjpeg
  Downloading pylibjpeg-1.4.0-py3-none-any.whl (28 kB)
Collecting pylibjpeg-libjpeg
  Downloading pylibjpeg_libjpeg-1.3.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.3/4.3 MB[0m [31m45.1 MB/s[0m eta [36m0:00:00[0m
Collecting python-gdcm
  Downloading python_gdcm-3.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.0/13.0 MB[0m [31m62.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: python-gdcm, pylibjpeg-libjpeg, pylibjpeg
Successfully installed pylibjpeg-1.4.0 pylibjpeg-libjpeg-1.3.4 python-gdcm-3.0.21
[0m

In [2]:
import os, cv2 as cv
import json
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import albumentations as album

from scipy.sparse import csr_matrix, save_npz, load_npz

from pathlib import Path
import pydicom as dicom
import nibabel as nib
from time import time
from sklearn.model_selection import train_test_split
import wandb

In [3]:
!pip install -q -U segmentation-models-pytorch albumentations > /dev/null
import segmentation_models_pytorch as smp

[0m

In [4]:
# Seed every random number generator the notebook has imported so that
# repeated runs start from the same state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # covers every visible GPU; ignored when CUDA is absent

# Ask cuDNN for deterministic kernels. The autotuner (benchmark mode) must
# stay off: it may select a different convolution algorithm from one run to
# the next, which would defeat the deterministic flag.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
# Input locations under /kaggle/input plus the relative checkpoint folder.
_input_root = Path('/kaggle/input')
_competition_root = _input_root / 'rsna-2022-cervical-spine-fracture-detection'

dicom_path = _competition_root / 'train_images'
segm_path = _competition_root / 'segmentations'
bounds_path = _input_root / 'ct-lowhigh-bounds' / 'ct_lowhigh_bounds.csv'
model_path = _input_root / 'effnetv2s-slicet-and-num'
checkpoint_path = Path('checkpoints')

In [6]:
# Segmentation network configuration.
CLASSES = ['background', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'others']
ENCODER = 'tu-tf_efficientnetv2_m'
ENCODER_WEIGHTS = 'imagenet'
# 'softmax2d' is used for multiclass segmentation; None would return raw logits instead.
ACTIVATION = 'softmax2d'

# DeepLabV3+ on top of the pretrained encoder, one output channel per entry in CLASSES.
_model_config = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = smp.DeepLabV3Plus(**_model_config)

# Input normalisation function that belongs to the chosen encoder / weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnetv2_m-cc09e0cd.pth


In [7]:
# NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints
# from a trusted source. Pass map_location=torch.device('cpu') to load on a
# machine without a GPU.
weights = torch.load('/kaggle/input/best-segm-model/best_model.pth')

# The network goes to the GPU first and is then wrapped in DataParallel
# before the parameters are copied in from the loaded object's state_dict().
model.to('cuda')
model = nn.DataParallel(model)
model.load_state_dict(weights.state_dict())
model.eval()

# Inference only: switch off gradient tracking for every parameter.
for frozen_param in model.parameters():
    frozen_param.requires_grad = False

In [8]:
# NOTE(review): this value is reassigned to 32 at the top of the inference cell
# further down and is not read anywhere in between.
BATCH_SIZE = 16

In [9]:
def to_tensor(x, **kwargs):
    """Reorder the axes of ``x`` to (2, 0, 1) and cast the result to float32.

    Extra keyword arguments are accepted and ignored.
    """
    reordered = np.transpose(x, (2, 0, 1))
    return reordered.astype(np.float32)

def get_preprocessing(preprocessing_fn=None):
    """Build the preprocessing pipeline applied to every image.

    Args:
        preprocessing_fn (callable): optional normalisation function
            (can be specific to each pretrained network). When it is falsy
            the normalisation step is left out.
    Return:
        transform: albumentations.Compose that runs the optional
            normalisation followed by ``to_tensor``.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(album.Lambda(image=to_tensor))
    return album.Compose(steps)

In [10]:
# Collect the study folders whose probe slice is not 512x512; the inference
# loop later allocates fixed 512x512 buffers.
wront_resolution_list = []
for patient in tqdm(dicom_path.iterdir()):
    try:
        # Only header fields are needed, so stop reading before the pixel data.
        md = dicom.dcmread(patient / '10.dcm', stop_before_pixels=True)
    except Exception:
        # The probe slice could not be read (e.g. missing file): report the
        # study and carry on. KeyboardInterrupt is no longer swallowed.
        print(patient)
        continue
    if (md.Rows, md.Columns) != (512, 512):
        wront_resolution_list.append(patient)

387it [00:05, 77.12it/s]

/kaggle/input/rsna-2022-cervical-spine-fracture-detection/train_images/1.2.826.0.1.3680043.17166


2019it [00:28, 69.75it/s]


In [11]:
# Keep only the last path component (the text after the final '/') of every entry.
wront_resolution_list = list(map(os.path.basename, map(str, wront_resolution_list)))

In [12]:
# This study's probe slice could not be read during the resolution scan, so
# exclude it too. The membership guard keeps the cell idempotent: re-running
# it no longer appends a duplicate entry.
if '1.2.826.0.1.3680043.17166' not in wront_resolution_list:
    wront_resolution_list.append('1.2.826.0.1.3680043.17166')
wront_resolution_list

['1.2.826.0.1.3680043.22678',
 '1.2.826.0.1.3680043.23400',
 '1.2.826.0.1.3680043.8858',
 '1.2.826.0.1.3680043.17166']

In [13]:
# Load the per-study bounds table from the location configured in `bounds_path`
# (instead of repeating the literal path); the first CSV column becomes the index.
bounds_df = pd.read_csv(bounds_path, index_col=0)
bounds_df.head()

Unnamed: 0,StudyInstanceUID,TotalSlices,C1_lb,C1_hb,C2_lb,C2_hb,C3_lb,C3_hb,C4_lb,C4_hb,C5_lb,C5_hb,C6_lb,C6_hb,C7_lb,C7_hb
0,1.2.826.0.1.3680043.17625,239,30,96,35,122,72,137,90,158,109,180,123,198,143,220
1,1.2.826.0.1.3680043.3850,688,49,235,66,314,165,351,194,405,238,465,284,513,321,578
2,1.2.826.0.1.3680043.2286,333,34,125,41,162,92,182,112,210,136,240,157,264,180,295
3,1.2.826.0.1.3680043.14435,618,46,213,61,284,151,318,178,367,218,421,259,464,293,522
4,1.2.826.0.1.3680043.3442,336,34,126,41,164,93,183,113,212,137,242,158,266,181,297


In [14]:
# Keep only the rows from position 890 onwards.
# NOTE(review): 890 is a magic number (presumably the earlier rows were handled
# in a separate run -- confirm). Because the same name is rebound, re-running
# this cell drops a further 890 rows each time.
bounds_df = bounds_df.iloc[890:]

In [15]:
# Names of the entries in the segmentation folder with every '.nii' removed.
segm_uids = []
for nii_name in os.listdir(segm_path):
    segm_uids.append(nii_name.replace('.nii', ''))

In [16]:
def _containing_vertebra(label, patient_record, slice_number):
    """Walk from ``label`` to the vertebra whose [lb, hb] range contains ``slice_number``.

    Returns the matching label (1..7). Returns 0 or 8 when the walk runs off
    the lower or upper end of C1..C7. Returns ``label`` unchanged when the
    slice falls into a gap between two neighbouring ranges, because no range
    supports a better choice in that case.
    """
    candidate = int(label)
    direction = 0
    while 1 <= candidate <= 7:
        lb = patient_record[f'C{candidate}_lb']
        hb = patient_record[f'C{candidate}_hb']
        if lb > slice_number:
            step = -1  # slice lies before this range -> try the previous vertebra
        elif hb < slice_number:
            step = 1   # slice lies after this range -> try the next vertebra
        else:
            return candidate
        if direction and step != direction:
            # The walk would turn around: the slice sits between two ranges.
            return label
        direction = step
        candidate += step
    return candidate


def clear_mask(mask, patient_record, vertebra_number, min_pixels=50):
    """Clean up the predicted label map of a single slice, in place.

    Args:
        mask: 2-D integer label array; it is modified in place.
        patient_record: mapping that provides ``C{k}_lb`` / ``C{k}_hb`` for
            k = 1..7 (a row of the bounds table or a plain dict).
        vertebra_number: number of the slice, compared against the bounds.
        min_pixels: regions with fewer pixels than this are erased (set to 0).

    Returns:
        The same ``mask`` object.

    Fixes over the previous version:
        * the relabelling was a comparison (``==``) and therefore a no-op;
        * the direction test was inverted, which made the "move to the next
          vertebra" branch unreachable;
        * the background label is skipped by value instead of by position;
        * regions are selected from a snapshot, so a region that was just
          relabelled is never picked up again under its new label.
    """
    labels, counts = np.unique(mask, return_counts=True)
    snapshot = mask.copy()

    for label, count in zip(labels.tolist(), counts.tolist()):
        if label == 0:
            continue

        region = snapshot == label
        if count < min_pixels:
            mask[region] = 0
        elif 1 <= label <= 7:
            new_label = _containing_vertebra(label, patient_record, vertebra_number)
            if new_label != label:
                mask[region] = new_label
        # Labels outside 1..7 (e.g. 8) have no bounds and are left untouched.

    return mask

def get_slice_n_from_path(path):
    """Return, as an int, the second-to-last dot-separated field of the file name.

    The file name is everything after the last '/' of ``str(path)``;
    for '.../10.dcm' the result is 10.
    """
    filename = str(path).rpartition('/')[2]
    fields = filename.rsplit('.', 2)
    return int(fields[-2])

def load_dicom(path):
    """Read one DICOM file and return ``(rgb_image, dataset)``.

    The pixel array is shifted so that its minimum is 0, divided by its
    maximum (unless that maximum is 0), scaled to 0..255 as uint8 and
    expanded from gray to three channels.
    """
    # Source: https://www.kaggle.com/code/vslaykovsky/pytorch-effnetv2-vertebrae-detection-acc-0-95
    dataset = dicom.dcmread(path)
    # Override taken over from the source above.
    # NOTE(review): presumably needed for pixel decoding -- confirm.
    dataset.PhotometricInterpretation = 'YBR_FULL'

    pixels = dataset.pixel_array
    pixels = pixels - pixels.min()
    peak = pixels.max()
    if peak != 0:
        pixels = pixels / peak

    gray = (pixels * 255).astype(np.uint8)
    return cv.cvtColor(gray, cv.COLOR_GRAY2RGB), dataset

def save_matrix(mask, save_path):
    """Store ``mask`` at ``save_path`` as a sparse CSR matrix in .npz format.

    All leading axes are flattened into rows; the last axis becomes the columns.
    """
    n_columns = mask.shape[-1]
    flattened = mask.reshape(-1, n_columns)
    save_npz(save_path, csr_matrix(flattened))

# Create the output folder. exist_ok tolerates a folder left over from an
# earlier run, while genuine failures (e.g. permissions) now surface here
# instead of being silently ignored.
os.makedirs('/kaggle/working/dataset', exist_ok=True)

# The metadata file is written once the inference loop has finished;
# json_dict collects one entry per processed study.
json_dict_name = '/kaggle/working/dataset/metadata.json'
json_dict = {}

In [17]:
BATCH_SIZE = 32  # slices per forward pass
preprocess = get_preprocessing(preprocessing_fn)

model_time = 0  # seconds spent in the forward pass and in copying the predictions back

start_time = time()
for _, patient_record in tqdm(bounds_df.iterrows(), total=len(bounds_df)):
    # Address the column by name: an integer key on a label-indexed Series
    # only works through a positional fallback.
    patient = patient_record['StudyInstanceUID']

    # Skip studies flagged during the resolution scan and studies whose UID
    # appears in the segmentation folder.
    if patient in wront_resolution_list or patient in segm_uids:
        continue

    patient_path = dicom_path / patient

    list_slices = sorted(get_slice_n_from_path(file) for file in os.listdir(patient_path))
    num_slices = len(list_slices)

    # uint8 holds every class index and avoids a float64 volume per study.
    mask = np.zeros((512, 512, num_slices), dtype=np.uint8)

    for batch_start in range(0, num_slices, BATCH_SIZE):
        slices = list_slices[batch_start:batch_start + BATCH_SIZE]
        imgs = np.zeros((len(slices), 3, 512, 512), dtype=np.float32)
        for ind, slice_ in enumerate(slices):
            img, _dataset = load_dicom(patient_path / f'{slice_}.dcm')
            imgs[ind] = preprocess(image=img)['image']

        imgs = torch.from_numpy(imgs).to('cuda')

        model_start_time = time()
        with torch.no_grad(), torch.cuda.amp.autocast():
            predictions = model(imgs)
        # Arg-max on the device: only one uint8 label map per slice is copied
        # to the host instead of the full per-class tensor.
        predictions = predictions.argmax(dim=1).to(torch.uint8).cpu().numpy()
        mask[:, :, batch_start:batch_start + len(slices)] = \
            np.transpose(predictions, axes=[1, 2, 0])
        model_time += time() - model_start_time

    # NOTE(review): the slice number handed to clear_mask is the 1-based
    # position in the sorted list; it equals the DICOM file number only when
    # the files are numbered 1..N without gaps -- confirm.
    for i in range(num_slices):
        mask[:, :, i] = clear_mask(mask[:, :, i], patient_record, i + 1)

    save_matrix(mask, f'/kaggle/working/dataset/{patient}.npz')

    # For every non-background label, remember the 0-based slice indices
    # in which it occurs.
    slices_by_vertebra = {}
    for i in range(num_slices):
        for vertebra in np.unique(mask[:, :, i]).tolist():
            if vertebra != 0:
                slices_by_vertebra.setdefault(vertebra, []).append(i)
    json_dict[patient] = slices_by_vertebra

with open(json_dict_name, 'w') as json_file:
    json.dump(json_dict, json_file)

print(time() - start_time)

1129it [9:29:24, 30.26s/it]


34165.02242875099


In [18]:
import zipfile
    
def zipdir(path, ziph):
    """Add every file below ``path`` to the open archive ``ziph``.

    Archive names are taken relative to the parent of ``path``, so they
    start with the name of the folder itself.
    """
    archive_root = os.path.join(path, '..')
    for current_dir, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            full_path = os.path.join(current_dir, filename)
            ziph.write(full_path, os.path.relpath(full_path, archive_root))

# Pack the generated dataset folder into one deflate-compressed archive.
with zipfile.ZipFile(
    '/kaggle/working/dataset.zip',
    mode='w',
    compression=zipfile.ZIP_DEFLATED,
) as archive:
    zipdir('/kaggle/working/dataset', archive)

In [19]:
# mask_orig = nib.load('/kaggle/input/rsna-2022-cervical-spine-fracture-detection/segmentations/1.2.826.0.1.3680043.5783.nii')
# mask_orig_data = mask_orig.get_fdata()[:, :, ::-1]
# mask_orig_data.shape
# mask_orig_data[mask_orig_data > 7] = 8
# mask_orig_data = mask_orig_data.astype(np.uint8)

In [20]:
# for i in range(mask_orig_data.shape[2]):
#     a = np.unique(mask_orig_data[:, :, i])
#     b = np.unique(mask[:, :, i])
#     if not np.array_equal(a, b):
#         diff_r = np.setdiff1d(b, a, assume_unique=True)
#         diff_l = np.setdiff1d(a, b, assume_unique=True)
#         error_r = 0
#         error_l = 0
#         for val_r in diff_r:
#             error_r += np.count_nonzero(mask[:, :, i] == val_r)
        
#         for val_l in diff_l:
#             error_l += np.count_nonzero(mask_orig_data[:, :, i] == val_l)
            
#         print(i, diff_r, diff_l, error_r, error_l)

In [21]:
# for i in range(mask_orig_data.shape[2]):
#     a = np.unique(mask_orig_data[:, :, i])
#     b = np.unique(mask[:, :, i])
#     if np.any(a != b):
#         print(i+1, ' : ', a, b)