### Versions

* **v1**: First submission

# Inference with fine-tuned fast.ai models (patch classifier + ROI classifier)


## Importing everything

In [1]:
# Copy the prebuilt nvjpeg2k shared object into the working directory; the import cell below does `import nvjpeg2k`.
!cp /kaggle/input/easy-load-the-image-with-nvjpeg2000/nvjpeg2k.so ./

In [2]:
#Installing libraries
# Every package is installed from a local wheel/archive under /kaggle/input or ../input,
# not from a package index.
!pip install /kaggle/input/rsna-2022-whl/pylibjpeg-1.4.0-py3-none-any.whl
!pip install /kaggle/input/rsna-2022-whl/python_gdcm-3.0.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install /kaggle/input/dicomsdl-offline-installer/dicomsdl-0.109.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
# timm is shipped as an archive: unpack it, then install using only the unpacked directory (--no-index).
!unzip -q ../input/timm-with-dependencies/timm_all -d timm-with-dependencies
!pip install --no-index --find-links timm-with-dependencies timm

Processing /kaggle/input/rsna-2022-whl/pylibjpeg-1.4.0-py3-none-any.whl
Installing collected packages: pylibjpeg
Successfully installed pylibjpeg-1.4.0
[0mProcessing /kaggle/input/rsna-2022-whl/python_gdcm-3.0.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Installing collected packages: python-gdcm
Successfully installed python-gdcm-3.0.15
[0mProcessing /kaggle/input/dicomsdl-offline-installer/dicomsdl-0.109.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
Installing collected packages: dicomsdl
Successfully installed dicomsdl-0.109.1
[0mLooking in links: timm-with-dependencies
Processing ./timm-with-dependencies/timm-0.6.12-py3-none-any.whl
Installing collected packages: timm
Successfully installed timm-0.6.12
[0m

In [3]:
#Importing libraries
from fastai.vision.all import *
from sklearn.model_selection import StratifiedKFold
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import StratifiedGroupKFold
from fastai.metrics import ActivationType

import os, gc
import time

import numpy as np
import pandas as pd
import cv2
import pydicom
import dicomsdl as dicom
import glob
import re
import nvjpeg2k

# Progress bar library imports
from tqdm import tqdm

# Parallel processing library imports
from joblib import Parallel, delayed, cpu_count

#Prevents throwing error if image is truncated
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

## Model loading

In [4]:
#https://www.kaggle.com/competitions/rsna-breast-cancer-detection/discussion/369267  
def pfbeta_torch(preds, labels, beta=1):
    """Probabilistic F-beta score computed from the class-1 column of `preds`.

    Args:
        preds: 2-D tensor with exactly two columns; column 1 is used as the
            positive-class score and clipped to [0, 1].
        labels: tensor compared against 1 and 0 to select positive/negative rows
            (assumed to be 1-D with one entry per row of `preds` -- TODO confirm).
        beta: weight of recall relative to precision.

    Returns:
        The score as a 0-dim tensor, or the float 0.0 when precision or recall
        is not strictly positive (this includes the NaN case of no positives).

    Raises:
        ValueError: if `preds` is not 2-D with two columns.
    """
    has_two_columns = preds.dim() == 2 and preds.shape[1] == 2
    if not has_two_columns:
        raise ValueError('Houston, we got a problem')

    positive_scores = preds[:, 1].clip(0, 1)
    n_positive = labels.sum()

    # "Probabilistic" counts: sums of scores rather than counts of hard decisions.
    true_pos_mass = positive_scores[labels == 1].sum()
    false_pos_mass = positive_scores[labels == 0].sum()

    precision = true_pos_mass / (true_pos_mass + false_pos_mass)
    recall = true_pos_mass / n_positive

    if not (precision > 0 and recall > 0):
        return 0.0

    beta_sq = beta * beta
    return (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall)

# https://www.kaggle.com/competitions/rsna-breast-cancer-detection/discussion/369886    
def pfbeta_torch_thresh(preds, labels):
    """Score `preds` with `pfbeta_torch` after binarising them via `optimize_preds`.

    `optimize_preds` is called with the labels, so the threshold is searched on
    the same data that is then scored.
    """
    return pfbeta_torch(optimize_preds(preds, labels), labels)

def optimize_preds(preds, labels=None, thresh=None, return_thresh=False, print_results=False):
    """Binarise `preds` with a threshold, searching for the best one if none is given.

    Args:
        preds: prediction tensor; it is cloned, so the caller's tensor is not modified.
        labels: ground-truth tensor. Required when `thresh` is None, because the
            threshold is then chosen by maximising `pfbeta_torch` over 101 evenly
            spaced candidates in [0, 1].
        thresh: explicit threshold. Only `None` means "search"; an explicit 0 / 0.0
            is now honoured instead of being treated as missing.
        return_thresh: if True, return the threshold instead of the binarised tensor.
        print_results: if True, print the threshold and, when labels are available,
            the score before and after thresholding.

    Returns:
        The threshold if `return_thresh` is True, otherwise a float tensor of 0/1
        values with the same shape as `preds`.

    Raises:
        ValueError: if both `labels` and `thresh` are None (previously this
            surfaced as a TypeError from comparing a tensor with None).
    """
    if thresh is None and labels is None:
        raise ValueError("optimize_preds needs either `labels` (to search a threshold) or an explicit `thresh`")

    preds = preds.clone()

    # The un-thresholded score is only needed for the report below.
    without_thresh = None
    if print_results and labels is not None:
        without_thresh = pfbeta_torch(preds, labels)

    if thresh is None:
        threshs = np.linspace(0, 1, 101)
        # float() normalises the mixed tensor / 0.0 return values of pfbeta_torch.
        f1s = [float(pfbeta_torch((preds > thr).float(), labels)) for thr in threshs]
        thresh = threshs[int(np.argmax(f1s))]

    preds = (preds > thresh).float()

    if print_results:
        if labels is not None:
            print(f'without optimization: {without_thresh}')
            pfbeta = pfbeta_torch(preds, labels)
            print(f'with optimization: {pfbeta}')
        print(f'best_thresh = {thresh}')
    if return_thresh:
        return thresh
    return preds

### Patch classifier

In [5]:
# Patch classifier: rebuild the learner and load its fine-tuned weights for inference.
#Settings
arch = "resnet18"
path = Path("/kaggle/input/generating-patches-0-5-ratio")

# Batch-level augmentations (the list is passed as batch_tfms below; it is also reused by the ROI cell).
# NOTE(review): the name `aug_transforms` presumably shadows the fastai helper of the same name
# brought in by the star import -- confirm and consider renaming.
aug_transforms = [Flip(), Dihedral(p=1.0), Contrast(0.1), Brightness(0.1), Zoom(), Warp(magnitude=0.1)]

# Callbacks: defined here but not passed to vision_learner in this cell.
cbs = [EarlyStoppingCallback(min_delta=0.001, patience=4)]

#Loss functions and metrics
loss_func=LabelSmoothingCrossEntropy()
metrics = [error_rate, RocAucBinary()]

#Datablock
# Label is 1 when the file path contains "MALIGNANT", otherwise 0.
# No splitter is given, so whatever default split DataBlock applies is used.

dls = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                 get_items=get_image_files, 
                 get_y=lambda x: 1 if "MALIGNANT" in str(x) else 0,
                 batch_tfms=[*aug_transforms, Normalize.from_stats(*imagenet_stats)],
               ).dataloaders(path, bs=64, shuffle=True)



# pretrained=False: the weights come from the checkpoint loaded on the next line.
patch_classifier = vision_learner(dls, arch, metrics=metrics, loss_func=loss_func, pretrained=False).to_fp16()
patch_classifier.load("/kaggle/input/models/patch_finetuned")

<fastai.learner.Learner at 0x7f18bd5ac890>

### ROI classifier

In [6]:
# ROI classifier: same recipe as the patch classifier, built on the ROI patch dataset.
# This cell overwrites arch, path, cbs, loss_func, metrics and dls from the previous cell
# and depends on `aug_transforms` defined there.
arch =  "resnet18"
path = Path("/kaggle/input/roi-patches")

# Callbacks: defined here but not passed to vision_learner in this cell.
cbs = [EarlyStoppingCallback(min_delta=0.001, patience=4),
      SaveModelCallback(with_opt=True)]

loss_func = LabelSmoothingCrossEntropy()
metrics = [error_rate, RocAucBinary()]

# label_smoothing_weights = torch.tensor([1,50]).float()
# if torch.cuda.is_available():
#     label_smoothing_weights = label_smoothing_weights.cuda()

# Label is 1 when the file path contains "MALIGNANT", otherwise 0.
dls = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                 get_items=get_image_files, 
                 get_y=lambda x: 1 if "MALIGNANT" in str(x) else 0,
                 batch_tfms=[*aug_transforms, Normalize.from_stats(*imagenet_stats)],
               ).dataloaders(path, bs=32, shuffle=True)

# pretrained=False: the weights come from the checkpoint loaded below.
roi_classifier = vision_learner(dls, arch, metrics=metrics, loss_func=loss_func, pretrained=False).to_fp16()

roi_classifier = roi_classifier.load("/kaggle/input/models/roi_finetuned")

# Drop the notebook-level name `dls`; both learners were constructed with their own dataloaders.
del dls
gc.collect()

92

## Optimization functions

## Data processing

In [7]:
#Processing functions

def _load_dicom(path: str):
    """Open a DICOM file with dicomsdl and return `(dataset, pixel_data)`."""
    dataset = dicom.open(path)
    pixels = dataset.pixelData()
    return dataset, pixels

def _windowing(scan, img):
    
    center = scan.WindowCenter
    width = scan.WindowWidth
    bits_stored = scan.BitsStored
    try:
        function = scan.VOILUTFunction
    except: 
        try:
            if scan.PixelIntensityRelationship == "LOG":
                function = "SIGMOID"
        except:
            function = "LINEAR"
    if type(scan.WindowWidth) in [list, pydicom.multival.MultiValue]:
        center = center[0]
        width = width[0] 
    y_range = float(2**bits_stored - 1)
    if function == 'SIGMOID':
        img = y_range / (1 + np.exp(-4 * (img - center) / width))
    else: # LINEAR
        center -= 0.5
        width -= 1
        below = img <= (center - width / 2)
        above = img > (center + width / 2)
        between = np.logical_and(~below, ~above)
        img[below] = 0
        img[above] = y_range
        img[between] = ((img[between] - center) / width + 0.5) * y_range
    return img

def _fix_photometric_inter(scan, img):

    if scan.PhotometricInterpretation == 'MONOCHROME1':
        return img.max() - img
    else:
        return img - img.min()

    return img
    
def _hist_eq(img):
    """Convert `img` to 8-bit and return its histogram-equalised version."""
    img_8bit = _convert_to_8bit(img)
    equalized = cv2.equalizeHist(img_8bit)
    return equalized

def _convert_to_8bit(img):
    return (img / img.max()*255).astype(np.uint8)

def _padresize_to_width(img, size):
        
    h, w  = img.shape

    # If the width of the image is less than the desired width
    if w < size[1]:
        # Add padding to the right side of the image to reach the desired width
        img = cv2.copyMakeBorder(img, 0, 0, 0, size[1] - w, cv2.BORDER_CONSTANT, value=(0, 0, 0))

    # If the width of the image is greater than the desired width
    if w > size[1]:
        # Resize the image to the desired width
        img = cv2.resize(img, (size[1], size[0]))
        # Resize the mask if provided with interpolation set to nearest to keep pixel values
    return img

def _resize_to_height(img, size):
    """Resize a 2-D image to height `size[0]`, scaling the width by the same factor.

    The new width is `int(size[0] / (h / w))`, i.e. truncated towards zero.
    """
    height, width = img.shape
    aspect = height / width
    target_height = size[0]
    target_width = int(target_height / aspect)

    # cv2.resize takes the destination size as (width, height).
    return cv2.resize(img, (target_width, target_height))

def _crop_roi(img):
    """Crop `img` to the bounding box of its largest foreground contour.

    The slice uses the contour's min/max coordinates as `[min:max]`, so the
    row and column at the maximum coordinate are not included.
    """
    contour = _find_contours(_binarize(img))

    xs = contour[:, :, 0]
    ys = contour[:, :, 1]
    left, right = np.min(xs), np.max(xs)
    top, bottom = np.min(ys), np.max(ys)

    return img[top:bottom, left:right]

def _remove_background(img):
    """Zero every pixel outside the largest foreground contour of `img`.

    The contour is filled with 1 (not 255) so that multiplying by the mask
    leaves the pixels inside it unchanged. With the previous 0/255 mask the
    pixels were scaled by 255, which overflows and wraps for integer images
    such as uint16.

    NOTE(review): for integer inputs this changes the pixel values handed to
    the later steps -- confirm the models were not fitted on the wrapped output.
    """
    bin_img = _binarize(img)
    contour = _find_contours(bin_img)

    # Binary 0/1 mask: filled interior of the largest contour.
    mask = np.zeros(bin_img.shape, np.uint8)
    cv2.drawContours(mask, [contour], -1, 1, cv2.FILLED)

    return img * mask

def _find_contours(bin_img):
    """Return the external contour of `bin_img` with the largest area.

    `max` raises ValueError when no contour is found.
    """
    all_contours, _hierarchy = cv2.findContours(bin_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    return max(all_contours, key=cv2.contourArea)

def _binarize(img):

    binarized = (img > (img.max()*0.05)).astype("uint8")

    return binarized

def _correct_side(img):

    col_sums_split = np.array_split(np.sum(img, axis=0), 2)
    left_col_sum = np.sum(col_sums_split[0])
    right_col_sum = np.sum(col_sums_split[1])

    if left_col_sum > right_col_sum: 
        return img
    else: 
        return np.fliplr(img)

In [8]:
def preprocess_image(path:str, size: tuple=(4096,2048), hist_eq: bool=True, 
                     window_size: tuple=(256,256), stride: tuple=(128,128), j2k=False):
    """Load one DICOM file, normalise it and cut it into patches.

    Args:
        path: path of the DICOM file.
        size: target (height, width) used by the resize and pad steps.
        hist_eq: equalise the histogram (True) or only convert to 8-bit (False).
        window_size: patch (height, width).
        stride: patch step (rows, columns).
        j2k: when true, read the file with pydicom and decode the pixel stream
            with the module-level `j2k_decoder`, which must exist before the
            call; otherwise the file is loaded with `_load_dicom`.

    Returns:
        `(patches, name, img)`: the patch list, the file name from
        `_get_file_name`, and the fully preprocessed 8-bit image.
    """
    if not j2k:
        scan, img = _load_dicom(path)
    else:
        scan = pydicom.dcmread(path)
        # The stream handed to the decoder starts at the first occurrence of this byte pattern.
        offset = scan.PixelData.find(b'\x00\x00\x00\x0C')
        jpeg_stream = bytearray(scan.PixelData[offset:])
        img = j2k_decoder.decode(jpeg_stream)

    # Intensity normalisation needs the DICOM metadata.
    img = _fix_photometric_inter(scan, _windowing(scan, img))

    # Geometry clean-up that only needs the pixels.
    for step in (_correct_side, _remove_background, _crop_roi):
        img = step(img)

    # Bring the image to the target size: height first, then width.
    for fit in (_resize_to_height, _padresize_to_width):
        img = fit(img, size)

    img = _hist_eq(img) if hist_eq else _convert_to_8bit(img)

    name = _get_file_name(path)
    patches = _generate_patches(img, window_size, stride)

    return patches, name, img

In [9]:
def _generate_patches(img, window_size, stride):
        
    patches = []
    for y in range(0, img.shape[0]-window_size[0], stride[0]):
        for x in range(0, img.shape[1]-window_size[1], stride[1]):
            # Extract the patch from the image
            patch = img[y:y+window_size[0], x:x+window_size[1]]
            patches.append(patch)
            
    return patches      

def _get_file_name(path): 
    name = path.split("/")[-1].strip(".dcm")
    return name

In [10]:
def _get_center_of_mass(heatmap, threshold: float=0.5, num_of_centers: int=1):

#     thresh = ((heatmap>threshold)*255).astype(np.uint8)
#     contours, _ = cv2.findContours(
#         thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
#     contours = sorted(contours, key=cv2.contourArea, reverse=True)

#     ctx = []
#     if contours:
#         for c in contours[:num_of_centers]:
#             M = cv2.moments(c)
#             if M['m00'] != 0:
#                 cx = int(M['m10']/M['m00'])
#                 cy = int(M['m01']/M['m00'])
#             else:
#                 for point in c[:num_of_centers]:
#                     cx, cy = point[0]
            
#             ctx.append([cx,cy])
            
#     else: return None
    
#     ctx = np.array(ctx)
#     ctx = ctx / heatmap.shape[::-1]
        
#     return ctx

    if heatmap.max() <= threshold:
        return None
    
    max_coords = np.array(np.unravel_index(np.argmax(heatmap), heatmap.shape), dtype=int)[:2]
    return  max_coords / heatmap.shape
    
def _get_roi(center, size=512, img_shape=(4096,2048)):

    cy, cx = (center*img_shape).astype(int)

    start_x = cx - size//2
    start_y = cy - size//2
    end_x = cx + size//2
    end_y = cy + size//2

    if start_x < 0:
        end_x -= start_x
        start_x = 0
    if start_y < 0:
        end_y -= start_y
        start_y = 0
    if end_x > img_shape[1]:
        start_x -= end_x - img_shape[1]
        end_x = img_shape[1]
    if end_y > img_shape[0]:
        start_y -= end_y - img_shape[0]
        end_y = img_shape[0]

    return start_x, start_y, end_x, end_y

def _save_roi(image, name, start_x, start_y, end_x, end_y, i: int=0):
    """Write the box `[start_y:end_y, start_x:end_x]` of `image` to `/kaggle/tmp/output/<name>.png`.

    `i` is accepted but not used, so repeated calls with the same `name`
    overwrite the same file.
    """
    rows = slice(start_y, end_y)
    cols = slice(start_x, end_x)
    cv2.imwrite(f"/kaggle/tmp/output/{name}.png", image[rows, cols])

def extract_and_save_roi(image, heatmap, name, threshold, num_of_centers: int=1, roi_size: int=512):
    """Find the heatmap peak and save the surrounding ROI of `image` to disk.

    Args:
        image: preprocessed image; its `.shape` is used as the ROI bounds.
        heatmap: patch-score heatmap for the same image.
        name: file stem used by `_save_roi`.
        threshold: minimum heatmap maximum required to save anything.
        num_of_centers: forwarded to `_get_center_of_mass`.
        roi_size: side length of the saved ROI in pixels.

    Returns:
        None. Nothing is written when no centre passes the threshold.

    The centre is normalised to a 2-D array of (row, col) rows before
    iterating. Previously, `num_of_centers > 1` iterated over the two scalars
    of a single centre and derived meaningless boxes from them.
    """
    img_shape = image.shape
    center = _get_center_of_mass(heatmap, threshold, num_of_centers)

    if center is None:
        return None

    # One row per centre; a single (row, col) pair becomes a 1x2 array.
    for i, ct in enumerate(np.atleast_2d(center)):
        start_x, start_y, end_x, end_y = _get_roi(ct, roi_size, img_shape)
        _save_roi(image, name, start_x, start_y, end_x, end_y, i)

In [11]:
def make_transfer_syntax_uid(df, dcm_dir):
    """Map each `machine_id` in `df` to the TransferSyntaxUID of one of its images.

    For every unique machine id, the first matching row's DICOM file
    (`<dcm_dir>/<patient_id>/<image_id>.dcm`) is read and its
    `file_meta.TransferSyntaxUID` is taken as that machine's value. Only one
    file per machine is inspected.
    """
    def _syntax_for(machine):
        first_row = df[df.machine_id == machine].iloc[0]
        dcm_path = f'{dcm_dir}/{first_row.patient_id}/{first_row.image_id}.dcm'
        return pydicom.dcmread(dcm_path).file_meta.TransferSyntaxUID

    return {machine: _syntax_for(machine) for machine in df.machine_id.unique()}

In [12]:
# Choose which split to run on.
# NOTE(review): TEST=True selects the *train* split and TEST=False the *test* split --
# presumably "TEST" means "dry-run on train data"; confirm the intent and consider renaming.
TEST = False

if TEST:
    dst = "train"
else: 
    dst = "test"

test_df = pd.read_csv(f"/kaggle/input/rsna-breast-cancer-detection/{dst}.csv")
dcm_dir = f"/kaggle/input/rsna-breast-cancer-detection/{dst}_images"

# One DICOM per machine_id is read to obtain a transfer syntax, which is then assigned
# to every row with that machine_id.
machine_id_to_transfer = make_transfer_syntax_uid(test_df, dcm_dir)
test_df.loc[:, 'i'] = np.arange(len(test_df))
test_df.loc[:, 'TransferSyntaxUID'] = test_df.machine_id.map(machine_id_to_transfer)

# Rows with this transfer syntax UID go to the j2k decoding path; everything else goes to the other path.
j2k_df = test_df[test_df.TransferSyntaxUID == '1.2.840.10008.1.2.4.90'].reset_index(drop=True)
non_j2k_df = test_df[test_df.TransferSyntaxUID != '1.2.840.10008.1.2.4.90'].reset_index(drop=True)

# Full file paths: <dcm_dir>/<patient_id>/<image_id>.dcm
j2k_paths = j2k_df.apply(lambda d: f'{dcm_dir}/{d.patient_id}/{d.image_id}.dcm', axis=1).to_numpy()
non_j2k_paths = non_j2k_df.apply(lambda d: f'{dcm_dir}/{d.patient_id}/{d.image_id}.dcm', axis=1).to_numpy()

# Only the path arrays are needed from here on.
del j2k_df, non_j2k_df
gc.collect()

0

In [13]:
%%time
# Pipeline configuration shared by the two inference loops below.
chunk_size = 32             # number of images processed per loop iteration
window_size = (256, 256) 
stride = (128, 128)
image_size = (4096,2048)    # passed to preprocess_image as `size` (height, width)
threshold = 0.5             # minimum heatmap maximum for an ROI to be saved
num_of_centers = 1 
roi_size = 1024
hist_eq = True

#test_images = glob.glob("/kaggle/input/rsna-breast-cancer-detection/test_images/*/*.dcm")
#chunks = [test_images[i:i+chunk_size] for i in range(0, len(test_images), chunk_size)]
#del test_images

# With the values above: (4096-256)/128 * (2048-256)/128 = 30 * 14 = 420 patches per image.
patches_per_image = int(((image_size[0]-window_size[0]) / stride[0]) * ((image_size[1]-window_size[1]) / stride[1]))

all_patches = []
img_names = []
heatmaps = []
imgs = []

# Output directory that _save_roi writes to.
os.makedirs("/kaggle/tmp/output/", exist_ok=True)

CPU times: user 1.3 ms, sys: 64 µs, total: 1.36 ms
Wall time: 1.25 ms


In [14]:
%%time

# Non-j2k images: preprocess in parallel, score patches, then save one ROI per image.
chunks = [non_j2k_paths[i:i+chunk_size] for i in range(0, len(non_j2k_paths), chunk_size)]
#del test_images

for chunk in tqdm(chunks):
    
    # Each preprocess_image call returns (patches, name, image); zip(*) regroups them per field.
    patches, names, images = zip(*(Parallel(n_jobs=-1)(delayed(preprocess_image)(img_path, image_size, hist_eq, window_size, stride) for img_path in chunk)))
    
    # Flatten the per-image patch lists into one list, keeping image order.
    patches = [patch for patches_part in patches for patch in patches_part]

    with patch_classifier.no_bar(): 
        prob_preds = patch_classifier.get_preds(dl=patch_classifier.dls.test_dl(patches, bs=512, device="cuda"))

    # Slice the flat predictions back into one block per image; this assumes every image
    # produced exactly patches_per_image patches.
    for i in range(0, len(patches), patches_per_image):

        # Column 1 is used as the heatmap value. The hard-coded 30 is the number of patch rows
        # for a 4096-high image with window 256 and stride 128; update it if those settings change.
        heatmap = prob_preds[0][i:i+patches_per_image,1].view(30,-1,1)
#         heatmap = heatmap.permute(2,0,1)
#         heatmap = F.avg_pool2d(heatmap, (2,2), stride=1, padding=1)
#         heatmap = heatmap.permute(1,2,0)
        heatmaps.append(heatmap.numpy().squeeze())

    patches = []
    gc.collect()

    assert len(names) == len(heatmaps) == len(images)

    _ = Parallel(n_jobs=-1)(delayed(extract_and_save_roi)(image, heatmap, name, threshold, num_of_centers, roi_size) for heatmap, name, image in zip(heatmaps, names, images))

#     for heatmap, name, image in zip(heatmaps, img_names, images):
#         extract_and_save_roi(image, heatmap, name, threshold, num_of_centers, roi_size)

    # Reset the per-chunk accumulators; `heatmaps` must be empty before the next chunk
    # for the length assertion above to hold.
    heatmaps = []
    names = []
    images = []
    gc.collect()

0it [00:00, ?it/s]

CPU times: user 4.82 ms, sys: 4 µs, total: 4.82 ms
Wall time: 5.37 ms





In [15]:
%%time

# j2k images: same flow as the previous loop, but images are preprocessed one at a time
# in this process with the nvjpeg2k decoder.
all_patches = []
names = []
images = []
heatmaps = []

chunks = [j2k_paths[i:i+chunk_size] for i in range(0, len(j2k_paths), chunk_size)]
# `j2k` is not read below: preprocess_image is called with the literal j2k=True.
j2k = True
# preprocess_image reads `j2k_decoder` as a module-level name, so it must be created before the loop.
j2k_decoder = nvjpeg2k.Decoder()
#del test_images

for chunk in tqdm(chunks):
    
    for img_path in chunk:
        patches, name, image = preprocess_image(img_path, image_size, hist_eq, window_size, stride, j2k=True)
        all_patches.extend(patches)
        images.append(image)
        names.append(name)

    with patch_classifier.no_bar(): 
        prob_preds = patch_classifier.get_preds(dl=patch_classifier.dls.test_dl(all_patches, bs=512, device="cuda"))

    # Slice the flat predictions back into one block per image; this assumes every image
    # produced exactly patches_per_image patches.
    for i in range(0, len(all_patches), patches_per_image):

        # Column 1 is used as the heatmap value. The hard-coded 30 is the number of patch rows
        # for a 4096-high image with window 256 and stride 128; update it if those settings change.
        heatmap = prob_preds[0][i:i+patches_per_image,1].view(30,-1,1)
#         heatmap = heatmap.permute(2,0,1)
#         heatmap = F.avg_pool2d(heatmap, (2,2), stride=1, padding=1)
#         heatmap = heatmap.permute(1,2,0)
        heatmaps.append(heatmap.numpy().squeeze())

    all_patches = []
    gc.collect()

    assert len(names) == len(heatmaps) == len(images)

    _ = Parallel(n_jobs=-1)(delayed(extract_and_save_roi)(image, heatmap, name, threshold, num_of_centers, roi_size) for heatmap, name, image in zip(heatmaps, names, images))

#     for heatmap, name, image in zip(heatmaps, img_names, images):
#         extract_and_save_roi(image, heatmap, name, threshold, num_of_centers, roi_size)

    # Reset the per-chunk accumulators before the next chunk.
    heatmaps = []
    names = []
    images = []
    gc.collect()

100%|██████████| 1/1 [00:10<00:00, 10.10s/it]

CPU times: user 3.45 s, sys: 1.33 s, total: 4.78 s
Wall time: 10.1 s





## Predicting on test

In [16]:
%%time

# Score every ROI image saved under /kaggle/tmp/output/ with the ROI classifier.
# preds_all is a list so several prediction tensors could be accumulated; only one is appended here.
preds_all = []
test_dl = roi_classifier.dls.test_dl(get_image_files(f'/kaggle/tmp/output/'), bs=64, device="cuda")

preds, _ = roi_classifier.get_preds(dl=test_dl)
preds_all.append(preds)

CPU times: user 92.9 ms, sys: 68.9 ms, total: 162 ms
Wall time: 314 ms


In [17]:
# Sum the prediction tensors in preds_all (a single one in this notebook) ...
preds = torch.zeros_like(preds_all[0])
for pred in preds_all:
    preds += pred

# preds /= NUM_SPLITS
# ... and turn them into hard 0/1 decisions with a fixed 0.8 cut-off.
preds = (preds>0.8).int()

# #preds = optimize_preds(preds, thresh=threshold)
# ROI files were saved as <name>.png, so the file stem is used as the image id.
image_ids = [path.stem for path in test_dl.items]

# Map image id -> column-1 decision. Images for which no ROI file was saved have no entry.
image_id2pred = defaultdict(lambda: 0)
for image_id, pred in zip(image_ids, preds[:, 1]):
    image_id2pred[int(image_id)] = pred.item()

<a id="section-three"></a>
# Making a submission

In [18]:
test_csv = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/test.csv')

# One entry per test row: images without a stored prediction default to 0.
prediction_ids = test_csv.prediction_id.tolist()
preds = [image_id2pred.get(image_id, 0) for image_id in test_csv.image_id]

# Collapse to one row per prediction_id, keeping the maximum image-level decision.
submission = (
    pd.DataFrame(data={'prediction_id': prediction_ids, 'cancer': preds})
    .groupby('prediction_id')
    .max()
    .reset_index()
)
submission.head()

Unnamed: 0,prediction_id,cancer
0,10008_L,0
1,10008_R,0


In [19]:
# Write the submission file without the DataFrame index column.
submission.to_csv('submission.csv', index=False)