## Step 0. Imports and Constants

In [None]:
# install packages
!pip3 install -U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
!pip3 install python-gdcm pylibjpeg pylibjpeg-libjpeg pydicom
!pip3 install ensemble-boxes

In [None]:
# import libraries
import os
import gc
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import torch
import torchvision as tv
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data.sampler import SequentialSampler
from ensemble_boxes import *
from tqdm import tqdm
import time
import warnings
warnings.filterwarnings('ignore')

In [None]:
# specify paths and parameters
PATH_MAIN = '../input/rsna-2022-cervical-spine-fracture-detection'
PATH_IMAGES = f'{PATH_MAIN}/train_images'  # one sub-directory of <slice_number>.dcm files per StudyInstanceUID
PATH_FRACTURE = f'{PATH_MAIN}/train.csv'
PATH_BOXES = f'{PATH_MAIN}/train_bounding_boxes.csv'
PATH_SEGMENTATION = '../input/rsna-2022-spine-fracture-detection-metadata/train_segmented.csv'  # to be changed

IMAGE_SHAPE = (512, 512)  # slices of any other size are resized to this by load_image
NUM_VERTEBRAE = 7
NUM_FOLDS = 5  # number of GroupKFold splits
SEED = 0
BATCH_SIZE = 4

# Seed PyTorch too (NumPy is seeded in the fold-split cell) so that the randomly
# initialised box predictor and the shuffled training DataLoader are reproducible.
torch.manual_seed(SEED);
BATCH_SIZE = 4

In [None]:
# Select the compute device: a CUDA GPU through PyTorch when one is visible,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

## Step 1. Load Data

In [None]:
def load_image(image_path):
    '''
    Load image data from given path. Return the image pixel array and the file metadata.
    
    The stored pixels go through the file's VOI LUT / windowing and are resized to
    IMAGE_SHAPE when needed. They are then min-max normalized to [0, 1] and
    inverted for MONOCHROME1 files, so that higher values are always brighter.
    Finally they are replicated to three channels.
    
    Parameters:
        image_path: str
            Path to the dicom file of the target image.
    
    Return:
        image: numpy.ndarray
            The float64 image pixel array with shape (H, W, C), C = 3.
            Pixel values range from 0 to 1 (a constant slice becomes all zeros).
        metadata: pydicom.dataset.FileDataset
            The metadata of the dicom file.
    '''
    # read the dicom file (metadata and pixel data)
    metadata = pydicom.dcmread(image_path)
    # apply the VOI LUT / windowing stored in the file; work in float from here on so
    # that the min-subtraction below cannot overflow integer pixel dtypes
    image = apply_voi_lut(metadata.pixel_array, metadata).astype(np.float64)
    if image.shape != IMAGE_SHAPE:
        image = cv2.resize(image, dsize=IMAGE_SHAPE, interpolation=cv2.INTER_CUBIC)
    # min-max normalize to [0, 1]; a constant slice would otherwise divide by zero
    # and yield an all-NaN image
    lo, hi = np.min(image), np.max(image)
    if hi > lo:
        image = (image - lo) / (hi - lo)
    else:
        image = np.zeros_like(image)
    # MONOCHROME1 stores inverted intensities (minimum value is displayed white);
    # flip so that every slice follows the MONOCHROME2 convention
    if getattr(metadata, 'PhotometricInterpretation', 'MONOCHROME2') == 'MONOCHROME1':
        image = 1.0 - image
    # replicate the single channel to the 3 channels expected by the detector
    image = np.stack([image] * 3, axis=-1)
    return image, metadata

In [None]:
# sanity check: load one example slice, report its shape and intensity range, then show it
image, metadata = load_image(f'{PATH_IMAGES}/1.2.826.0.1.3680043.17625/150.dcm')
print(f'image shape: {image.shape}')
print(f'pixel range: [{image.min()}, {image.max()}]')

fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
plt.show()

In [None]:
# Load the bounding-box annotations. ImageDataset below reads the columns
# StudyInstanceUID, slice_number, x, y, width and height, and treats each row
# as one (slice, box) sample.
df_boxes = pd.read_csv(PATH_BOXES)
df_boxes.head()

In [None]:
# Split the dataset into folds with Group K-Fold, grouping by StudyInstanceUID so
# that slices from the same patient never appear in the training and test set
# simultaneously. GroupKFold does not shuffle; the seed keeps `.sample(5)` reproducible.
np.random.seed(SEED)
group_kfold = GroupKFold(NUM_FOLDS)
folds = group_kfold.split(df_boxes, groups=df_boxes.StudyInstanceUID)
# split() yields *positional* indices, so collect them in a positional array
# rather than assigning through the label-based .loc
fold_ids = np.zeros(len(df_boxes), dtype=np.uint8)
for fold, (_, test_indices) in enumerate(folds):
    fold_ids[test_indices] = fold
# assigning the whole column sets the uint8 dtype directly, independent of column
# position (an `iloc[:, -1] = ...` assignment writes in place on pandas >= 2 and
# keeps the existing dtype)
df_boxes['Fold'] = fold_ids
print(f'fold indices: {sorted(df_boxes.Fold.unique().tolist())}')
df_boxes.sample(5)

In [None]:
class ImageDataset(torch.utils.data.Dataset):
    '''
    An image dataset to extract slice images and their bounding boxes.

    Each dataframe row is turned into one ``(image, target)`` pair in the format
    consumed by torchvision detection models (a single box with label 1).
    '''
    def __init__(self, df, image_dir):
        '''
        Initialize the image dataset.
        
        Parameters:
            df: pandas.core.frame.DataFrame
                The dataframe of bounding boxes, one box per row, with columns
                StudyInstanceUID, slice_number, x, y, width and height
                (box coordinates in the original slice's pixel space).
            image_dir: str
                The path to the image directory.
        '''
        super().__init__()
        self.df = df
        self.image_dir = image_dir
    
    def __len__(self):
        '''
        Length of the dataset
        
        Return:
            length: int
                Total number of slices (rows) in the dataset.
        '''
        return len(self.df)
    
    def __getitem__(self, idx):
        '''
        Retrieve the idx-th slice of the dataset and its corresponding labels.
        
        Parameters:
            idx: int
                Positional index of the slice in the dataset to retrieve.
        
        Return:
            image: torch.Tensor
                float32 tensor of shape (C, H, W) with values in [0, 1].
            target: dict
                'boxes':    float32 (1, 4) box [xmin, ymin, xmax, ymax] in the
                            pixel space of the returned (possibly resized) image,
                'labels':   int64 (1,), always 1,
                'image_id': int64 (1,), equal to idx,
                'area':     float32 (1,), box width * height after rescaling,
                'iscrowd':  int64 (1,), always 0.
        '''
        row = self.df.iloc[idx]

        # get path to the slice image
        patient_uid = row.StudyInstanceUID
        slice_number = int(row.slice_number)
        image_path = os.path.join(self.image_dir, patient_uid, f'{slice_number}.dcm')
        # load the slice image (load_image may resize it, e.g. the 768x768 study
        # 1.2.826.0.1.3680043.22678)
        image, metadata = load_image(image_path)  # shape (H, W, C)
        out_h, out_w = image.shape[:2]
        image = np.transpose(image, (2, 0, 1))  # shape (C, H, W)
        image = torch.as_tensor(image, dtype=torch.float32)

        # The boxes are annotated in the original slice's pixel space; rescale them
        # by the actual resize ratio rather than special-casing known studies.
        scale_x = out_w / metadata.Columns
        scale_y = out_h / metadata.Rows
        x = row.x * scale_x
        y = row.y * scale_y
        w = row.width * scale_x
        h = row.height * scale_y
        boxes = torch.as_tensor([[x, y, x + w, y + h]], dtype=torch.float32)

        target = {
            'boxes': boxes,
            # object labels: a single foreground class
            'labels': torch.ones((1,), dtype=torch.int64),
            # image unique id
            'image_id': torch.tensor([idx]),
            # bounding box area
            'area': torch.tensor([w * h], dtype=torch.float32),
            # suppose all instances are not crowd
            'iscrowd': torch.zeros((1,), dtype=torch.int64),
        }

        return image, target

## Step 2. Data Loaders, Model Construction and Training

In [None]:
# Fold 0 is held out for validation; the remaining folds are used for training.
ds_train = ImageDataset(df_boxes[df_boxes.Fold != 0], PATH_IMAGES)
ds_test = ImageDataset(df_boxes[df_boxes.Fold == 0], PATH_IMAGES)

def collate_fn(batch):
    '''
    Transpose a batch of (image, target) pairs into an (images, targets) pair of tuples.

    The training loop passes lists of images and target dicts to the detection
    model, so samples are kept as tuples instead of being stacked.
    '''
    return tuple(zip(*batch))

# os.cpu_count() returns None when the count cannot be determined; fall back to
# loading in the main process in that case
num_workers = os.cpu_count() or 0

dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn
)

dl_test = torch.utils.data.DataLoader(
    ds_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn
)

In [None]:
def compute_iou(true, pred):
    '''
    Compute the Intersection over Union (IoU) of two boxes.

    Coordinates are treated as inclusive pixel indices: a box spanning
    columns x0..x1 is x1 - x0 + 1 pixels wide.

    Parameters:
        true: numpy.ndarray
            Ground-truth box as [xmin, ymin, xmax, ymax].
        pred: numpy.ndarray
            Predicted box as [xmin, ymin, xmax, ymax].

    Return:
        iou: float
            Intersection over Union (0.0 when the boxes do not overlap).
    '''
    def _overlap(lo_idx, hi_idx):
        # inclusive overlap length along one axis; negative when disjoint
        return min(true[hi_idx], pred[hi_idx]) - max(true[lo_idx], pred[lo_idx]) + 1

    overlap_w = _overlap(0, 2)
    if overlap_w < 0:
        return 0.0
    overlap_h = _overlap(1, 3)
    if overlap_h < 0:
        return 0.0

    intersection = overlap_w * overlap_h
    area_true = (true[2] - true[0] + 1) * (true[3] - true[1] + 1)
    area_pred = (pred[2] - pred[0] + 1) * (pred[3] - pred[1] + 1)
    return intersection / (area_true + area_pred - intersection)

def find_best_match(gts, pred, pred_idx, threshold = 0.5, form = 'pascal_voc', ious=None):
    """Return the index of the unmatched GT box that best matches ``pred``.

    The best match is the GT box with the highest IoU that reaches
    ``threshold``; on ties the lowest index wins. GT rows whose first
    coordinate is negative count as already matched and are skipped.

    Args:
        gts: (np.ndarray) GT boxes, one row per box
        pred: (sequence) Coordinates of the predicted box
        pred_idx: (int) Column of this prediction in the ``ious`` cache
        threshold: (float) Minimum IoU for a match
        form: (str) Format of the coordinates (not used here)
        ious: (np.ndarray or None) len(gts) x len(preds) IoU cache; entries < 0
              are computed with compute_iou and written back

    Return:
        (int) Index of the best match GT box (-1 if no match above threshold)
    """
    best_match_idx = -1
    best_match_iou = -np.inf

    for gt_idx, gt_box in enumerate(gts):
        if gt_box[0] < 0:
            continue  # this GT box has already been matched

        iou = ious[gt_idx][pred_idx] if ious is not None else -1
        if iou < 0:
            # not cached yet
            iou = compute_iou(gt_box, pred)
            if ious is not None:
                ious[gt_idx][pred_idx] = iou

        if iou >= threshold and iou > best_match_iou:
            best_match_iou = iou
            best_match_idx = gt_idx

    return best_match_idx

def calculate_precision(gts, preds, threshold = 0.5, form = 'coco', ious=None):
    """Calculates precision for GT - prediction pairs at one threshold.

    Predictions are visited in the given order. Each one is matched to the
    unmatched GT box with the highest IoU >= threshold (see find_best_match),
    and the matched GT row is overwritten with -1. The score is
    TP / (TP + FP + FN).

    Args:
        gts: (np.ndarray or List[List[Union[int, float]]]) Coordinates of the
             available ground-truth boxes, shape (n_gt, 4). A numpy array is
             modified in place (matched rows set to -1); pass a copy if it is
             needed afterwards.
        preds: (List[List[Union[int, float]]]) Coordinates of the predicted boxes,
               sorted by confidence value (descending)
        threshold: (float) IoU threshold for a match
        form: (str) Format of the coordinates (forwarded, currently unused)
        ious: (np.ndarray) len(gts) x len(preds) matrix for storing calculated ious.

    Return:
        (float) Precision

    Raises:
        ZeroDivisionError: if there are neither GT boxes nor predictions.
    """
    # accept lists as well as arrays; reshape turns "no GT boxes" into a (0, 4) array
    gts = np.asarray(gts).reshape(-1, 4)
    tp = 0
    fp = 0

    for pred_idx in range(len(preds)):
        best_match_gt_idx = find_best_match(gts, preds[pred_idx], pred_idx,
                                            threshold=threshold, form=form, ious=ious)

        if best_match_gt_idx >= 0:
            # True positive: the predicted box matches a gt box with an IoU above the threshold.
            tp += 1
            # Mark the matched GT box so it cannot be matched again
            gts[best_match_gt_idx] = -1
        else:
            # False positive: the predicted box had no associated gt box.
            fp += 1

    # False negative: gt boxes left unmatched. Use the same "first coordinate < 0"
    # sentinel as find_best_match; the coordinate sum would miss valid boxes whose
    # coordinates sum to <= 0 (e.g. [0, 0, 0, 0]).
    fn = int((gts[:, 0] >= 0).sum())

    return tp / (tp + fp + fn)

def calculate_image_precision(gts, preds, thresholds = (0.5, ), form = 'coco'):
    """Average calculate_precision over several IoU thresholds for one image.

    Args:
        gts: (np.ndarray) Ground-truth boxes; a fresh copy is scored at every
             threshold, so the caller's array is left untouched
        preds: (List[List[Union[int, float]]]) Coordinates of the predicted boxes,
               sorted by confidence value (descending)
        thresholds: (Sequence[float]) IoU thresholds to average over
        form: (str) Format of the coordinates (forwarded)

    Return:
        (float) Mean precision over the thresholds (0.0 for no thresholds)
    """
    # IoUs do not depend on the threshold, so one cache (-1.0 = not computed yet)
    # is shared by all thresholds
    iou_cache = np.full((len(gts), len(preds)), -1.0)
    n_thresholds = len(thresholds)

    return sum(
        (
            calculate_precision(gts.copy(), preds, threshold=threshold,
                                form=form, ious=iou_cache) / n_thresholds
            for threshold in thresholds
        ),
        0.0,
    )

In [None]:
def run_wbf(predictions, image_index, image_size=512, iou_thr=0.55, skip_box_thr=0.5, weights=None):
    '''
    Fuse the boxes of several prediction sets for one image with Weighted Boxes
    Fusion (ensemble_boxes.weighted_boxes_fusion).

    Parameters:
        predictions: list
            One entry per prediction set. Each entry is a list of per-image output
            dicts with 'boxes' and 'scores' tensors, as returned by the detection
            model in the validation loop.
        image_index: int
            Which image of the batch to fuse.
        image_size: int
            Side length of the square image in pixels. Boxes are divided by
            image_size - 1 for WBF (which works on normalized coordinates) and
            scaled back afterwards.
        iou_thr: float
            IoU above which boxes are fused.
        skip_box_thr: float
            Boxes scoring below this are discarded.
        weights: list of float or None
            Per-prediction-set weights forwarded to WBF.

    Return:
        boxes: numpy.ndarray
            Fused boxes (N, 4) in pixel coordinates.
        scores: numpy.ndarray
            Fused confidence scores (N,).
        labels: numpy.ndarray
            Fused labels (N,).
    '''
    boxes = [p[image_index]['boxes'].detach().cpu().numpy() / (image_size - 1) for p in predictions]
    scores = [p[image_index]['scores'].detach().cpu().numpy() for p in predictions]
    # single-class detector: every box gets label 1
    labels = [np.ones(s.shape[0]) for s in scores]
    # forward the caller's weights (previously hard-coded to None, ignoring the parameter)
    boxes, scores, labels = weighted_boxes_fusion(
        boxes, scores, labels, weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr
    )
    boxes = boxes * (image_size - 1)
    return boxes, scores, labels

class Averager:
    '''Running arithmetic mean of the values passed to ``send`` (e.g. per-iteration losses).'''

    def __init__(self):
        self.reset()

    def send(self, value):
        '''Add one observation to the running total.'''
        self.iterations += 1
        self.current_total += value

    @property
    def value(self):
        '''Mean of the observations sent since the last reset (0 when there are none).'''
        if not self.iterations:
            return 0
        return 1.0 * self.current_total / self.iterations

    def reset(self):
        '''Forget all observations.'''
        self.current_total = 0.0
        self.iterations = 0.0

In [None]:
# class 0 is the background; ImageDataset labels every box as class 1
num_classes = 2

# Faster R-CNN (ResNet-50 FPN) with its default pretrained weights; replace the
# box predictor with a freshly initialised one sized for num_classes
model = tv.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model = model.to(device)

# optimise only the parameters that require gradients
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# no LR schedule for now; e.g. torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
lr_scheduler = None

num_epochs = 10

In [None]:
loss_hist = Averager()
best_val = None
PATIENCE = 2  # early stop patience: epochs without improvement before stopping
patience = PATIENCE
model_path = 'best_model.pth'
# IoU thresholds 0.50, 0.55, ..., 0.75 over which the per-image score is averaged
iou_thresholds = [x for x in np.arange(0.5, 0.76, 0.05)]

for epoch in range(num_epochs):
    start_time = time.time()  # start timer
    loss_hist.reset()  # init averager
    model.train()  # train mode
    for images, targets in tqdm(dl_train):
        images = [image.to(device) for image in images]
        # move targets unchanged: ImageDataset already yields float32 boxes/area and
        # int64 labels/ids (casting everything to float64 is unnecessary)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())
        loss_hist.send(losses.item())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    # update the learning rate
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Validation: fuse each image's predicted boxes with WBF, then score them against
    # the ground truth as TP / (TP + FP + FN), averaged over iou_thresholds.
    validation_image_precisions = []
    model.eval()
    with torch.no_grad():  # inference only: do not build autograd graphs
        for images, targets in tqdm(dl_test):
            images = [image.to(device) for image in images]
            predictions = [model(images)]

            for i, target in enumerate(targets):
                boxes, scores, labels = run_wbf(predictions, image_index=i)
                # 511 = last pixel index of the 512x512 slices
                boxes = boxes.astype(np.int32).clip(min=0, max=511)
                preds_sorted = boxes[np.argsort(scores)[::-1]]
                gt_boxes = target['boxes'].cpu().numpy().astype(np.int32)
                image_precision = calculate_image_precision(
                    gt_boxes,  # ground truth first, predictions second
                    preds_sorted,
                    thresholds=iou_thresholds,
                    form='coco'
                )
                validation_image_precisions.append(image_precision)

    val_iou = np.mean(validation_image_precisions)
    print(
        f"Epoch #{epoch+1} loss: {loss_hist.value}",
        "Validation IOU: {0:.4f}".format(val_iou),
        f"Time: {time.time() - start_time:.1f}s"
    )
    # `is None` (not `not best_val`) so that a legitimate score of 0.0 is not treated as unset
    if best_val is None or val_iou >= best_val:
        if best_val is None:
            print("Saving model")  # first epoch: any score is the best so far
        else:
            print("Saving model as IOU is increased from", best_val, "to", val_iou)
        best_val = val_iou
        patience = PATIENCE  # new best: reset patience
        torch.save(model, model_path)  # save the whole model (not only its state_dict)
    else:
        patience -= 1
        if patience == 0:
            print('Early stopping. Best Validation IOU: {:.3f}'.format(best_val))
            break