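**About** : This notebook runs inference with a trained YOLOX extravasation detector, evaluates its predictions, and saves crops around the fused detections.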

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src/

In [None]:
import os
import torch

# Report the installed PyTorch build.
print(torch.__version__)
# Restrict this process to the first GPU, then print the name of that device.
# NOTE(review): get_device_name raises when no CUDA device is available.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.cuda.get_device_name(0)
print(device)

## Initialization

### Imports

In [None]:
import os
import re
import cv2
import sys
import glob
import yaml
import shutil
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from tqdm import tqdm

pd.set_option('display.width', 1000)
pd.set_option('display.max_columns', 30)
pd.set_option('max_colwidth', 100)

In [None]:
from params import *
from util.plots import *
from data.preparation import *
from util.metrics import compute_metrics
from inference.det import *
from util.wbf import fusion, iou
from util.boxes import Boxes, expand_boxes

## Data

In [None]:
# df_patient, df_img = prepare_data(DATA_PATH)

# df = pd.read_csv('../input/active_extravasation_bounding_boxes.csv')
# df = df.rename(columns={"pid": "patient_id", "series_id": "series", "instance_number": "instance"})
# df = df.merge(df_img)

In [None]:
# plt.figure(figsize=(20, 5))

# for i in range(1, 5):
#     plt.subplot(1, 4, i)

#     idx = np.random.choice(len(df))
#     img = cv2.imread(df['path'].values[idx])
#     boxes = df[["x1", "y1", "x2", "y2"]].values[idx]

#     plot_boxes(img, boxes, bbox_format="pascal_voc")

# plt.show()

## Inference

In [None]:
class Config:
    # Settings shared by the detection inference cells of this notebook.

    # Model family and box conventions
    selected_model = "yolo"
    bbox_format = "yolo"
    pred_format = "pascal_voc"

    # Experiment identification. Use fold = "fullfit" for the full-fit model.
    fold = 1
    version = "v2"
    exp = 2

    # Experiment name, used to locate the experiment file and the checkpoint
    name = (
        f"rsna_{version}_fold{fold}_{exp}"
        if fold != "fullfit"
        else f"rsna_{version}_fullfit_{exp}"
    )

    # Paths
    data_dir = f"../input/yolo/v1/{fold}_train/"
    cfg = f"../yolox/exps/{name}.py"
    # Other checkpoints in the same folder: best_ckpt.pth, epoch_5_ckpt.pth
    ckpt = f"../yolox/YOLOX_outputs/{name}/last_epoch_ckpt.pth"
    labels = ["extravasation"]

    # Model input size
    size = (384, 384)

    # NMS
    conf_thresh = 0.1
    iou_thresh = 0.5
    max_per_img = 1

    # Data loading and device
    num_workers = 8
    val_bs = 64
    device = "cuda"

In [None]:
# Build the YOLOX network from its experiment file and checkpoint,
# then wrap it with the inference settings held in Config.
model_marker = YoloXWrapper(
    retrieve_yolox_model(Config.cfg, Config.ckpt, size=Config.size),
    Config,
)

In [None]:
# One row per validation image of the YOLO-format dataset for the configured fold.
df = pd.DataFrame({"path": glob.glob(Config.data_dir + "images/valid/*")})

# Label files mirror the image tree: "images" -> "labels" and ".png" -> ".txt".
df['gt_path'] = df['path'].apply(lambda x: re.sub("images", "labels", x))
# The dot is escaped and the pattern anchored so that only the trailing file
# extension is rewritten; an unescaped "." matches any character anywhere in the path.
df['gt_path'] = df['gt_path'].apply(lambda x: re.sub(r"\.png$", ".txt", x))

# Identifiers are parsed from the "_"-separated file name:
# first token = patient, second to last = series, last (minus 4-char extension) = frame.
df['patient_id'] = df['path'].apply(lambda x: x.split('/')[-1].split('_')[0])
df['series'] = df['path'].apply(lambda x: x.split('_')[-2])
df['frame'] = df['path'].apply(lambda x: x.split('_')[-1][:-4])

# NOTE(review): the keys are still strings here, so this ordering is lexicographic.
df = df.sort_values(['patient_id', 'series', 'frame'], ignore_index=True)
# df = df.head(100)
df.head()

In [None]:
print('- Predict')
# Run the detector on the validation images, resized by the inference transforms.
transforms = get_transfos(size=Config.size)
dataset = InferenceDataset(df, transforms)
meter = predict(model_marker, dataset, Config, disable_tqdm=False)

print('\n- Update shapes')
# Rebuild the dataset without transforms and hand each sample's shape (third
# item of a sample) to the matching prediction.
# Presumably this rescales boxes to the untransformed image size; confirm in Boxes.update_shape.
dataset = InferenceDataset(df, None)
for sample_idx in range(len(dataset)):
    sample_shape = dataset[sample_idx][2]
    meter.preds[sample_idx].update_shape(sample_shape)

In [None]:
# When True the evaluation loop plots every sample; otherwise only every 500th one.
PLOT = False

In [None]:
print('- Evaluate')

# Maps "<label>@<r>" to the list of per-image recalls computed on the first r predictions.
recalls = {}
for idx in range(len(dataset)):
    img, gt, shape = dataset[idx]

    gt = Boxes(gt, (shape[0], shape[1]), bbox_format="yolo")

    # Keep only detections of class index 0
    is_target = meter.labels[idx] == 0
    pred = Boxes(meter.preds[idx]['pascal_voc'][is_target], (shape[0], shape[1]), bbox_format="pascal_voc")
    scores = meter.confidences[idx][is_target]

    for r in [1, 3, 5]:
        metrics = compute_metrics(pred['pascal_voc'][:r], gt['pascal_voc'])
        # setdefault creates the list on first use, so no exception has to be
        # swallowed to initialise the entry.
        recalls.setdefault(Config.labels[0] + f"@{r}", []).append(metrics['recall'])

    if PLOT or not (idx % 500):
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plot_boxes(img, gt['pascal_voc'], "pascal_voc")
        plt.title(f'{idx} - Truth - {df.patient_id[idx], df.series[idx], df.frame[idx]}')

        plt.subplot(1, 2, 2)
        plot_boxes(img, pred['pascal_voc'], "pascal_voc")
        # An image may have no detection left after filtering; do not index into an empty array.
        conf_txt = f"{scores[0]:.3f}" if len(scores) else "n/a"
        plt.title(f'Pred - conf={conf_txt}')
        plt.show()

print('\n')
for k, v in recalls.items():
    print(f'Recall {k}: {np.mean(v):.3f}')
#         break
#     break

## Fusion & Crops

In [None]:
def restrict_imgs(img_paths, max_len=600, margin=10):
    """
    Trims a sequence by dropping a leading portion and capping its length.

    Sequences with more than 400 items lose their first sixth, shorter ones their
    first eighth; in both cases `margin` items of that leading portion are kept.
    Only the last `max_len + margin` items of the result are then retained.

    Args:
        img_paths (sequence): Sliceable sequence (list, np.ndarray, ...) of paths or indices.
        max_len (int, optional): Number of trailing items kept, margin excluded. Defaults to 600.
        margin (int, optional): Extra items kept on the leading side. Defaults to 10.

    Returns:
        sequence: The restricted sequence, same type as the input.
        int: Number of items before restriction.
    """
    n_imgs = len(img_paths)

    leading_fraction = 6 if n_imgs > 400 else 8
    # Clamp to 0 in both cases: a negative start index would slice from the end
    # of the sequence and silently drop almost everything.
    start = max(n_imgs // leading_fraction - margin, 0)
    img_paths = img_paths[start:]

    img_paths = img_paths[- max_len - margin :]

    return img_paths, n_imgs

In [None]:
df_patient, df_img = prepare_data(DATA_PATH)

# Attach the fold assignment when prepare_data did not provide it.
if "fold" not in df_patient.columns:
    folds = pd.read_csv(DATA_PATH + "folds_4.csv")
    df_img = df_img.merge(folds)
    df_patient = df_patient.merge(folds)

df_img = df_img[["patient_id", "series", "instance", "frame", "path", "extravasation_injury", "fold"]]

# Keep the frames of the configured fold. The explicit copy makes the column
# assignment below act on an independent frame instead of a slice
# (avoids SettingWithCopyWarning).
df_img = df_img[df_img["fold"] == Config.fold].copy()

# Frame-level score: equal-weight average of column 1 of two saved prediction arrays.
# NOTE(review): assumes both arrays are row-aligned with df_img for this fold; confirm.
df_img['pred_extravasation'] = (
    0.5 * np.load(f"../logs/2023-09-20/36_r/pred_val_{Config.fold}.npy")[:, 1] +  0.5 *
    np.load(f"../logs/2023-10-05/13/pred_val_{Config.fold}.npy")[:, 1]
)

# Per (patient, series) maximum of label and score, merged back with the "_agg" suffix.
df_img_max = df_img[['patient_id', 'series', 'extravasation_injury', 'pred_extravasation']].groupby(['patient_id', 'series']).max().reset_index()
df_img = df_img.merge(df_img_max, on=['patient_id', 'series'], suffixes=("", "_agg"))

# Per (patient, series) count of positive frames and of frames scored above 0.2,
# merged back with the "_cum" suffix.
df_img_cum = df_img.copy()
df_img_cum['pred_extravasation'] = (df_img_cum['pred_extravasation'] > 0.2).astype(int)
df_img_cum = df_img_cum[['patient_id', 'series', 'extravasation_injury', 'pred_extravasation']].groupby(['patient_id', 'series']).sum().reset_index()
df_img = df_img.merge(df_img_cum, on=['patient_id', 'series'], suffixes=("", "_cum"))

In [None]:
# Keep frames of series whose best score exceeds 0.2 and that have more than 3 frames above 0.2.
# NOTE(review): the crop loop further down skips series with fewer than MIN_HITS (3) hits,
# i.e. keeps >= 3, whereas this filter keeps > 3; confirm which one is intended.
is_confident = df_img['pred_extravasation_agg'] > 0.2
has_enough_hits = df_img['pred_extravasation_cum'] > 3
df_img_f = df_img[is_confident & has_enough_hits]

# (number of kept series with extravasation, number of kept series)
n_positive_series = len(df_img_f[df_img_f['extravasation_injury'] == 1].series.unique())
n_kept_series = len(df_img_f.series.unique())
n_positive_series, n_kept_series

In [None]:
# The keys of `df` were parsed from file names as strings; cast them so they match df_img.
df[['patient_id', 'series', "frame"]] = df[['patient_id', 'series', "frame"]].astype(int)
# Left merge: every frame of df_img is kept; columns of `df` that clash get the "_yolo" suffix.
# NOTE(review): re-running this cell merges again and adds further suffixed columns.
df_img = df_img.merge(df, on=['patient_id', 'series', "frame"], how="left", suffixes=('', '_yolo'))

# Frames without a YOLO label file get an empty gt_path. The result is assigned
# back to the column: an in-place fillna on a column selection is chained
# assignment, which is deprecated and has no effect under pandas copy-on-write.
df_img['gt_path'] = df_img['gt_path'].fillna('')

df_img = df_img.reset_index(drop=True)

In [None]:
# Series-level table: maximum label and maximum frame score per series,
# computed on the filtered frames.
series_cols = ["series", "extravasation_injury", "pred_extravasation"]
df_img_g = df_img_f[series_cols].groupby('series').max()

In [None]:
# NOTE(review): `DEBUG`, `LOG_PATH`, `prepare_log_folder`, `save_config`, `create_logger`
# and `k_fold` are not imported by name in this notebook; presumably they come from the
# star imports. Confirm this cell belongs in this notebook.

# Defined up front so the k_fold call below does not raise NameError when DEBUG is set.
log_folder = None
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH)
    print(f"Logging results to {log_folder}")
    config_df = save_config(Config, log_folder + "config.json")
    create_logger(directory=log_folder, name="logs.txt")

preds, preds_aux = k_fold(Config, df_patient, df_img, log_folder=log_folder, run=None)

In [None]:
# Series-level ROC AUC of the maximum frame score against the series label.
roc_auc_score(df_img_g['extravasation_injury'], df_img_g['pred_extravasation'])

In [None]:
# errors = [
# #     48517, 17447,
#     3750, 24136, 39205, 29661, 53257, 3532
# ]  # No crop
# errors_idk = [4654, 14668, 20619, 60792, 16080, 54917, 16080, 27089, 58540, 15610, 15909, 12151]  # Wrong crop

# # removed = [29661, 53257, 63618, 20619, 60961, 63205, 48977, 58540, 48517, 41840]
# # removed = [29661, 53257, 63618, 20619, 60961, 63205, 48977, 58540, 48517, 41840]

# removed = [24136, 39205, 29661, 5104, 58540, 39222, 48517, 15610, 15786, 41840]
# [True, False, False, False, True, False, True, False, True, False]
# # [True, False, False, False, False, False, False, True, False, False]

# # [True, False, False, False, False, False, False, True, False, False]
# # [29661, 53257, 63618, 20619, 60961, 63205, 48977, 58540, 48517, 41840]

In [None]:
# Snapshot of the per-study hit flags produced by a previous run of the crop loop below.
# `hits` only exists once that loop has run, so fall back to an empty array on a
# fresh kernel instead of raising NameError.
was_hit = np.copy(np.array(hits)) if "hits" in globals() else np.array([], dtype=bool)

In [None]:
# Fused boxes are expanded with min_size = max_size = CROP_SIZE before cropping.
CROP_SIZE = 96

# When True, the crop loop plots studies whose fused prediction did not hit the ground truth.
PLOT = False

# Output folder for the saved crops (.npy files); created if missing.
SAVE_FOLDER = "../input/crops_extrav/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

# When True, the crop loop writes crops to SAVE_FOLDER.
SAVE = True

# Study filters used by the crop loop:
# a frame counts as a hit when its score is above TH_HITS,
TH_HITS = 0.2
# negative studies with fewer than MIN_HITS hits are skipped,
MIN_HITS = 3
# and negative studies whose maximum score is below MIN_PRED are skipped.
MIN_PRED = 0.2

In [None]:
# Checks that channels 0 and 1 of a 4-D `crops` array are identical.
# NOTE(review): `crops` is only defined inside the crop loop below, so this cell fails on a
# fresh kernel; it looks like a leftover debugging check and may need a 4-D array to work.
(crops[:, :, :, 0] == crops[:, :, :, 1]).all()

In [None]:
# Per-study pipeline: filter frames, run the detector, fuse boxes across frames,
# compare the fused prediction with the fused ground truth, save crops, optionally plot.

# One flag per processed study: True when a fused prediction overlaps a fused GT box (IoU > 0.1).
hits = []
ct = -1
paths, labels = [], []
# NOTE(review): `total` counts unique series while the loop iterates over
# (patient_id, series) groups; the progress bar total may be off.
for study_idx, ((patient_id, series), dfg) in tqdm(enumerate(df_img.groupby(['patient_id', 'series'])), total=len(df_img['series'].unique())):
#     if patient_id != 386:
#         continue
#     if series not in errors + errors_idk:
#         continue
#     if series != 24136:
#         continue

    dfg = dfg.reset_index(drop=True)
    # Label of the study before any frame is removed
    has_extrav_ = dfg.extravasation_injury.max()

    # restrict_imgs returns (kept items, original count); [0] gives the kept row positions
    dfg = dfg.iloc[restrict_imgs(np.arange(len(dfg)))[0]].reset_index(drop=True)
    has_extrav = dfg.extravasation_injury.max()

#     if not has_extrav:
#         continue
        
    print(f'\n- {patient_id, series} - Has extrav : {has_extrav}')

    if has_extrav_ and not has_extrav:
        print('Warning, removed extrav when restricting')
        
    # Skip negative studies with a low maximum score; positive ones are kept with a warning
    if dfg['pred_extravasation'].max() < MIN_PRED:
        if not has_extrav:
            print('- Skip low conf')
            continue
        else:
            print(f'- Warning, MIN_PRED={MIN_PRED} would skip extrav as maximum conf is {dfg.pred_extravasation.max()}')

    # Skip negative studies with too few frames above TH_HITS; positive ones are kept with a warning
    if (dfg['pred_extravasation'] > TH_HITS).sum() < MIN_HITS:
        if not has_extrav:
            print('- Skip low hits')
            continue
        else:
            print(f'- Warning, MIN_HITS={MIN_HITS} would skip extrav as hits@{TH_HITS} is {(dfg.pred_extravasation > TH_HITS).sum()}', end="")
            print(f'\t hits@{TH_HITS / 2} is {(dfg.pred_extravasation > TH_HITS / 2).sum()}')
        
    # Truncate to high score area
    # np.argmax on a boolean array returns the position of the first True;
    # the window spans first to last frame scored above 0.05, padded by 5 frames on each side.
    prev = len(dfg)
    start_frame = np.argmax(dfg['pred_extravasation'].values > 0.05)
    start_frame = max(0, start_frame - 5)
    end_frame = len(dfg) - np.argmax(dfg['pred_extravasation'].values[::-1] > 0.05)
    end_frame = min(len(dfg), end_frame + 5)
    
    # If the window dropped every positive frame of a positive study, retry with a lower threshold
    if has_extrav and not dfg.iloc[np.arange(start_frame, end_frame)].extravasation_injury.max():  # Lower threshold
        print('- Use th=0.01 to remove frames')
        start_frame = np.argmax(dfg['pred_extravasation'].values > 0.01)
        start_frame = max(0, start_frame - 5)
        end_frame = len(dfg) - np.argmax(dfg['pred_extravasation'].values[::-1] > 0.01)
        end_frame = min(len(dfg), end_frame + 5)

    dfg = dfg.iloc[np.arange(start_frame, end_frame)].reset_index(drop=True)

    has_extrav_ = dfg.extravasation_injury.max()
    if has_extrav and not has_extrav_:
        print('- Warning, removed extrav when truncating frames. Skipping.')
        continue

    
    print(f'- Use {len(dfg)} frames out of {prev}')
    
#     ct += 1
#     if was_hit[ct]:
#         continue
    
    # Inference
    transforms = get_transfos(size=Config.size)
    dataset = InferenceDataset(dfg, transforms)
    meter = predict(model_marker, dataset, Config, disable_tqdm=True)

    # Second dataset with a 384x384 pad / center-crop; its samples provide the shapes
    # passed to update_shape and the images that are cropped and plotted below.
    transforms = albu.Compose([
        albu.PadIfNeeded(always_apply=False, p=1.0, min_height=384, min_width=384),
        albu.CenterCrop(always_apply=False, p=1.0, height=384, width=384),
    ])
    dataset = InferenceDataset(dfg, transforms)
    for i in range(len(dataset)):
        shape = dataset[i][2]
        meter.preds[i].update_shape(shape)
    
    # Collect per-frame predictions (class index 0 only), confidences and GT boxes
    preds, confs, gts = [], [], []
    for idx in range(len(dataset)):
        img, gt, shape = dataset[idx] 
        pred = Boxes(meter.preds[idx]['pascal_voc'][meter.labels[idx] == 0], (shape[0], shape[1]), bbox_format="pascal_voc")
        pred = expand_boxes(pred, min_size=64, max_size=64)
        scores = meter.confidences[idx][meter.labels[idx] == 0]

        if len(gt):
            gts.append(Boxes(gt, (shape[0], shape[1]), bbox_format="yolo"))
        else:
            # Frames without GT get a single all-zero placeholder box
            gts.append(Boxes(np.zeros((1, 4)), (shape[0], shape[1]), bbox_format="yolo"))
            
        preds.append(pred)
        confs.append(scores)
        
    # Fuse boxes
    # `frames` holds one (start, end) frame range per fused box (indexed as frames[i, 0], frames[i, 1] below)
    boxes_fusion, confs_fusion, hits_fusion, frames = fusion(
        preds,
        confs,
        iou_threshold=0.1,
        skip_box_thr=0.05,
        conf_threshold=0.1,
        hits_threshold=2 if len(dataset) <= 100 else 3,
        merge=True,
        max_det=1,
    )
    if not len(boxes_fusion):
        print('-> No predictions')
#         continue
        
#     plt.plot(np.concatenate(confs))
#     plt.axhline(0.05, c="salmon")
#     plt.show()
        
    if len(boxes_fusion):
        boxes_fusion = expand_boxes(boxes_fusion, min_size=CROP_SIZE, max_size=CROP_SIZE)

    # GT
    # Same fusion applied to the ground-truth boxes, each with confidence 1
    if has_extrav:
        boxes_fusion_gt, _, _, frames_gt = fusion(
            gts,
            np.ones((len(gts), 1)),
            iou_threshold=0.1,
            skip_box_thr=0.01,
            conf_threshold=0.1,
            hits_threshold=1,
            merge=True,
        )
        boxes_fusion_gt = expand_boxes(boxes_fusion_gt, min_size=CROP_SIZE, max_size=CROP_SIZE)
        
        # NOTE(review): this check runs after expand_boxes was already called on the
        # value; confirm expand_boxes accepts a list.
        if isinstance(boxes_fusion_gt, list):
            print('-> Warning, No GT found ')
            continue

    # Get hits
    # IoU between each fused prediction and the fused GT boxes active on the same frames
    scores = []
    for i, conf in enumerate(confs_fusion):
        start, end = frames[i, 0], frames[i, 1]    
        for idx in range(start, end):
            gt = []
            if has_extrav:
                gt = [b for j, b in enumerate(boxes_fusion_gt['pascal_voc']) if (idx >= frames_gt[j, 0] and idx <= frames_gt[j, 1])]
                gt = np.array(gt)

            for box in gt:
                iou_score = iou(box, boxes_fusion['pascal_voc'][i])
                scores.append(iou_score)

    # A study is a hit when any prediction / GT pair overlaps with IoU > 0.1
    hit = np.max(scores) > 0.1 if len(scores) else False
    hits.append(hit)
    if hit:
        print('-> Correct prediction')
#         continue

    # Get crops & save
    # Positive studies are cropped around the fused GT boxes ("_extrav" files),
    # negative ones around the fused predictions ("_pred" files).
    if SAVE:
        if has_extrav:
            for i, frame in enumerate(frames_gt):
                x0, y0, x1, y1 = boxes_fusion_gt['pascal_voc'][i]
                
                # Widen short frame ranges to [mid - 4, mid + 3], then clamp to the dataset bounds
                if frame[1] - frame[0] < 5:
                    print(f'- Extend crop of size {frame[1] - frame[0] + 1}')
                    mid = (frame[1] + frame[0]) // 2
                    delta = 3
                    frame[0], frame[1] = mid - delta - 1, mid + delta
                frame[0] = max(frame[0], 0)
                frame[1] = min(frame[1], len(dataset) - 1)
                    
                crops = []
                for idx in range(frame[0], frame[1] + 1):
                    img, _, shape = dataset[idx] 
                    crop = img[y0: y1, x0: x1]
                    crops.append(crop)

    #                 plt.figure(figsize=(5, 5))
    #                 plt.imshow(crop)
    #                 plt.show()

                # NOTE(review): redundant, this branch is already inside `if SAVE`
                if SAVE:
                    name = f"{patient_id}_{series}_{i}_extrav.npy"
                    # Only the first channel is kept, cast to uint8
                    crops = np.array(crops).astype(np.uint8)[..., 0]
                    np.save(SAVE_FOLDER + name, crops)
                    print(f"- Saved crop {name} of frames {np.arange(frame[0], frame[1] + 1, dtype=int)}")
        else:
            for i, frame in enumerate(frames):
                x0, y0, x1, y1 = boxes_fusion['pascal_voc'][i]

                # Same widening / clamping as for the GT crops
                if frame[1] - frame[0] < 5:
                    print(f'- Extend crop of size {frame[1] - frame[0] + 1}')
                    mid = (frame[1] + frame[0]) // 2
                    delta = 3
                    frame[0], frame[1] = mid - delta - 1, mid + delta
                frame[0] = max(frame[0], 0)
                frame[1] = min(frame[1], len(dataset) - 1)
                    
                crops = []
                for idx in range(frame[0], frame[1] + 1):
                    img, _, shape = dataset[idx] 
                    crop = img[y0: y1, x0: x1]
                    crops.append(crop)

    #                 plt.figure(figsize=(5, 5))
    #                 plt.imshow(crop)
    #                 plt.show()

                # NOTE(review): redundant, this branch is already inside `if SAVE`
                if SAVE:
                    name = f"{patient_id}_{series}_{i}_pred.npy"
                    # Only the first channel is kept, cast to uint8
                    crops = np.array(crops).astype(np.uint8)[..., 0]
                    np.save(SAVE_FOLDER + name, crops)
                    print(f"- Saved crop {name} of frames {np.arange(frame[0], frame[1] + 1, dtype=int)}")
        
    # Plot
    # Only missed studies are plotted: 5 evenly spaced frames per fused prediction
    if PLOT and not hit:
        for i, conf in enumerate(confs_fusion):
            to_plot = np.linspace(frames[i, 0], frames[i, 1], 5, dtype=int)

            plt.figure(figsize=(25, 5))
            for plot_idx, idx in enumerate(to_plot):
    #             plt.figure(figsize=(5, 5))
                img, _, shape = dataset[idx] 

                if has_extrav:
                    gt = [b for i, b in enumerate(boxes_fusion_gt['pascal_voc']) if (idx >= frames_gt[i, 0] and idx <= frames_gt[i, 1])]
                    gt = np.array(gt)
                else:
                    gt = np.array([])

                merged_boxes = [b for i, b in enumerate(boxes_fusion['pascal_voc']) if ((frames[i, 0] <= idx) and (frames[i, 1] >= idx))]
    #             merged_boxes = preds[idx]['pascal_voc']



                plt.subplot(1, len(to_plot), plot_idx + 1)
                plot_boxes(img, gt, "pascal_voc", merged_boxes=merged_boxes)
                plt.title(f'Frame {idx} - conf {conf:.3f}')
            plt.show()


        # For missed positive studies, also plot 5 evenly spaced frames per fused GT box
        if has_extrav and (not hit):
            for i, frame in enumerate(frames_gt):
                to_plot = np.linspace(frames_gt[i, 0], frames_gt[i, 1], 5, dtype=int)

                plt.figure(figsize=(25, 5))
                for plot_idx, idx in enumerate(to_plot):
                    img, _, shape = dataset[idx] 

                    gt = [b for i, b in enumerate(boxes_fusion_gt['pascal_voc']) if (idx >= frames_gt[i, 0] and idx <= frames_gt[i, 1])]
                    gt = np.array(gt)

                    try:
                        merged_boxes = [b for i, b in enumerate(boxes_fusion['pascal_voc']) if ((frames[i, 0] <= idx) and (frames[i, 1] >= idx))]
                    except:
                        merged_boxes = None
                        
                    # NOTE(review): this overrides the value computed in the try block above,
                    # so the per-frame (unfused) predictions are what gets plotted.
                    merged_boxes = preds[idx]['pascal_voc']
                    

                    plt.subplot(1, len(to_plot), plot_idx + 1)
                    plot_boxes(img, gt, "pascal_voc", merged_boxes=merged_boxes)
                    plt.title(f'Frame {idx} - Pred conf : {confs[idx][0] :.3f}')
                plt.show()

#     if study_idx > 100:
#         break

In [None]:
# Number of entries currently in the crop folder.
len(os.listdir(SAVE_FOLDER))

### Full Inference

In [None]:
# Reload the full tables: df_img was restricted to one fold and extended with
# prediction columns in the previous section.
df_patient, df_img = prepare_data(DATA_PATH)

# Attach the fold assignment when prepare_data did not provide it.
if "fold" not in df_patient.columns:
    folds = pd.read_csv(DATA_PATH + "folds_4.csv")
    df_patient = df_patient.merge(folds)
    df_img = df_img.merge(folds)

In [None]:
# When True the full-inference loop plots every image; otherwise only every 10000th one.
PLOT = False

In [None]:
# Run the detector of each listed fold on that fold's validation frames and save,
# per image, the first predicted box and its confidence.
for fold in [0]:
    print(f'\n- Fold {fold}\n')
    # Point the shared Config at this fold's experiment file and checkpoint
    Config.fold = fold
    Config.name = f"rsna_{Config.version}_fold{fold}_{Config.exp}"
    Config.data_dir = f"../input/yolo/v1/{fold}_train/"
    Config.cfg = f"../yolox/exps/{Config.name}.py"
    Config.ckpt = f"../yolox/YOLOX_outputs/{Config.name}/last_epoch_ckpt.pth"

    model_marker = retrieve_yolox_model(Config.cfg, Config.ckpt, size=Config.size)
    model_marker = YoloXWrapper(model_marker, Config)

    print('\n- Predict')

    df_val = df_img[df_img['fold'] == fold].reset_index(drop=True)

    transforms = get_transfos(size=Config.size)
    dataset = InferenceDataset(df_val, transforms)
    meter = predict(model_marker, dataset, Config, disable_tqdm=False)

    print('\n- Save & viz')

    boxes = []
    scores = []
    for idx in range(len(dataset)):
        img_boxes = meter.preds[idx]['pascal_voc']
        img_confs = meter.confidences[idx]

        if len(img_confs):
            pred = img_boxes[0]
            score = img_confs[0]
        else:
            # No detection for this image: store an all-zero box with confidence 0
            # so the saved arrays keep one row per image instead of raising IndexError.
            pred = np.zeros(4)
            score = 0.0

        boxes.append(pred)
        scores.append(score)

        if PLOT or not (idx % 10000):
            img, gt, shape = dataset[idx]
            # Tensors are converted from channel-first to channel-last numpy arrays for plotting
            if isinstance(img, torch.Tensor):
                img = img.cpu().numpy().transpose(1, 2, 0)
            plt.figure(figsize=(5, 5))
            plot_boxes(img, pred[None], "pascal_voc")
            plt.title(f'Pred - conf={score:.3f}')
            plt.show()

    # np.save does not create missing directories
    os.makedirs('../output/', exist_ok=True)
    np.save(f'../output/boxes_{Config.name}.npy', np.array(boxes))
    np.save(f'../output/confs_{Config.name}.npy', np.array(scores))
    
#     break

Done ! 