In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Switch the working directory to ../src: the relative paths used later
# ("../input/", "../output/") are resolved from there.
# Presumably this is also what makes the project-local imports (`params`,
# `data`, `util`) importable — confirm against the repository layout.
cd ../src

In [None]:
import sys
sys.path.append("../../SpineNet/")

In [None]:
import os
import re
import cv2
import json
import glob
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F

from pathlib import Path
from tqdm.notebook import tqdm
from collections import defaultdict
from matplotlib.patches import Polygon

import spinenet
from spinenet import SpineNet, download_example_scan
from spinenet.io import load_dicoms_from_folder

In [None]:
from params import *

from data.preparation import prepare_data

### Model

In [None]:
# !cp -r ../../SpineNet/spinenet/weights/grading/  ../output/

In [None]:
# Fetch the SpineNet weights. force=False is passed, so an existing download is
# presumably reused — confirm in the SpineNet documentation.
spinenet.download_weights(verbose=True, force=False)

# Build the SpineNet pipeline on the first CUDA device. `spnt` is used below for
# vertebra detection, IVD extraction and grading.
spnt = SpineNet(device='cuda:0', verbose=True)

In [None]:
# Root folder of the input data (relative to the working directory). Scans are
# read from DATA_PATH + "train_images/<study>/<series>/".
DATA_PATH = "../input/"

In [None]:
# Metadata table built by the project helper; the cells below read its
# `study_id`, `series_id` and `orient` columns.
df = prepare_data(DATA_PATH)

In [None]:
# Show the first row to check which columns are available.
df.head(1)

In [None]:
# When True, the inference loop draws the detected vertebrae on the middle slice
# of each scan and prints the argmax output of a few grading classes per level.
PLOT = True

In [None]:
# Names given to the outputs of the SpineNet grading model.
# The order is significant: the inference loop pairs entry i of this list with
# element i of the grading model's output.
SPINENET_CLASSES = [
    "Pfirrmann",
    "Narrowing",
    "CentralCanalStenosis",
    "Spondylolisthesis",
    "UpperEndplateDefect",
    "LowerEndplateDefect",
    "UpperMarrow",
    "LowerMarrow",
    "ForaminalStenosisLeft",
    "ForaminalStenosisRight",
    "Herniation",
]

In [None]:
# Run SpineNet on every non-axial series of `df`:
#   1. load the DICOM series,
#   2. detect the vertebral bodies,
#   3. extract the intervertebral-disc (IVD) volumes,
#   4. grade every kept level with the grading model.
# `all_preds` receives one {level_name: {class_name: model output}} dict per
# processed series, in the order of the non-axial rows of `df`.

# Debug limiter: the loop stops after the first processed row whose index is
# greater than this value. Set to None to process the whole dataset (required
# for the saved predictions to cover every study).
MAX_ROW_IDX = 10

all_preds = []
for idx in tqdm(range(len(df))):
    study = df['study_id'][idx]
    series = df['series_id'][idx]

    # Axial series are skipped entirely (no prediction is appended for them).
    if df['orient'][idx] == "Axial":
        continue

    scan = load_dicoms_from_folder(
        DATA_PATH + f"train_images/{study}/{series}/",
        require_extensions=False,
    )

    vert_dicts = spnt.detect_vb(scan.volume, scan.pixel_spacing)

    if PLOT:
        fig = plt.figure(figsize=(8, 8))
        # Only the middle slice (along the third axis of the volume) is drawn.
        slice_idx = scan.volume.shape[2] // 2

        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(scan.volume[:, :, slice_idx], cmap="gray")
        ax.set_title(f"Slice {slice_idx+1}")
        ax.axis("off")
        for vert_dict in vert_dicts:
            # Outline and label the vertebrae that have a polygon on this slice.
            if slice_idx in vert_dict["slice_nos"]:
                poly_idx = int(vert_dict["slice_nos"].index(slice_idx))
                poly = np.array(vert_dict["polys"][poly_idx])
                ax.add_patch(Polygon(poly, ec="y", fc="none"))
                ax.text(
                    np.mean(poly[:, 0]),
                    np.mean(poly[:, 1]),
                    vert_dict["predicted_label"],
                    c="y",
                    ha="center",
                    va="center",
                )

        # The title used to say "(all slices)" although a single slice is shown.
        fig.suptitle("Detected Vertebrae (middle slice)")
        plt.show()
        # Release the figure: one is created per scan inside the loop.
        plt.close(fig)

    ivd_dicts = spnt.get_ivds_from_vert_dicts(vert_dicts, scan.volume)
    # Drop every level whose name contains "T" (presumably the thoracic levels).
    ivd_dicts = [ivd_dict for ivd_dict in ivd_dicts if "T" not in ivd_dict['level_name']]

    preds = {}
    for ivd in ivd_dicts:
        if PLOT:
            # Debug output only; kept out of the progress bar when PLOT is off.
            print(ivd['volume'].shape)
        with torch.inference_mode():
            # Two leading axes are added in front of the 3D volume
            # (presumably batch and channel — confirm against the grading model).
            image = torch.tensor(ivd['volume'])[None, None, :, :, :].float().cuda()
            net_output = spnt.grading_model(image)
            # Output i is stored under the i-th class name; [0] takes the first
            # (only) element along the leading axis.
            preds[ivd['level_name']] = {c: net_output[i].cpu().numpy()[0] for i, c in enumerate(SPINENET_CLASSES)}
    all_preds.append(preds)

    if PLOT:
        for k in preds:
            print(k)
            # Indices 1, 2, 8, 9 = Narrowing, CentralCanalStenosis,
            # ForaminalStenosisLeft, ForaminalStenosisRight.
            for c_idx in [1, 2, 8, 9]:
                print(f" - {SPINENET_CLASSES[c_idx]}", np.argmax(preds[k][SPINENET_CLASSES[c_idx]]))

    if MAX_ROW_IDX is not None and idx > MAX_ROW_IDX:
        break

In [None]:
# Assemble the predictions into a table: one row per processed series, one
# column per disc level (each cell is a {class_name: output} dict, or NaN when
# the level was not found in that scan).
df_save = pd.DataFrame(all_preds)

# `all_preds` was filled in the order of the non-axial rows of `df`, so the
# identifiers are attached by position (index alignment on 0..n-1).
df_ = df[df['orient'] != "Axial"].reset_index(drop=True)
assert len(df_save) <= len(df_), "more predictions than non-axial series"
df_save["study_id"] = df_['study_id'].astype(str)
df_save["series_id"] = df_['series_id'].astype(str)

# reindex (rather than df_save[[...]]) so that a level that was never detected
# in any processed scan becomes an all-NaN column instead of raising KeyError.
df_save = df_save.reindex(
    columns=["study_id", "series_id", "L1-L2", "L2-L3", "L3-L4", "L4-L5", "L5-S1"]
)
df_save.columns = ["study_id", "series_id", "l1_l2", "l2_l3", "l3_l4", "l4_l5", "l5_s1"]
df_save.head()

In [None]:
# Turn the arrays stored in each level's prediction dict into plain Python lists
# so the CSV holds literals that can be parsed back. Non-dict cells (missing
# levels) are left untouched.
# np.asarray(...) keeps the cell safe to re-run: values already converted to
# lists by a previous execution no longer raise AttributeError on `.tolist()`.
for col in df_save.columns[2:]:
    df_save[col] = df_save[col].apply(
        lambda x: {k: np.asarray(v).tolist() for k, v in x.items()} if isinstance(x, dict) else x
    )

In [None]:
# Persist the predictions; the Eval section reloads this file.
df_save.to_csv('../output/spinenet_preds.csv', index=False)

### Eval

In [None]:
from data.preparation import *
from scipy.special import softmax
from sklearn.metrics import roc_auc_score

In [None]:
# Ground-truth labels, indexed below as y[f"{target}_{level}"].
# NOTE(review): the evaluation cells compare `y` and `ddf` row by row —
# confirm that both are in the same study order and have the same length.
y = prepare_data_lvl2(DATA_PATH)

In [None]:
import ast

# Reload the saved predictions. Each level column holds the text form of a
# {class_name: [values]} dict; missing levels (NaN) become an empty tuple so the
# later `isinstance(x, dict)` checks can tell them apart.
ddf = pd.read_csv('../output/spinenet_preds.csv')
for level in LEVELS_:
    # literal_eval only parses Python literals, whereas eval would execute any
    # expression found in the file.
    ddf[level] = ddf[level].fillna('()').apply(ast.literal_eval)
# Keep a single series per study: the first one encountered.
ddf = ddf.drop_duplicates(subset='study_id', keep="first").reset_index(drop=True)

In [None]:
# Attach the cross-validation fold of each study; the out-of-fold loop below
# reads ddf['fold'].
# The guard looks at `ddf` — the frame being merged into. It used to look at
# `df`, so a `fold` column on that unrelated frame would have skipped the merge
# and left `ddf` without folds.
if "fold" not in ddf.columns:
    folds = pd.read_csv("../input/train_folded_v1.csv")
    # No explicit key: pandas joins on the columns shared by both frames.
    ddf = ddf.merge(folds, how="left")

In [None]:
# Inspect the reloaded predictions after the fold merge.
ddf.head()

In [None]:
# Label families evaluated below; each is combined with a level name to form a
# column key of `y` (f"{target}_{level}").
tgts = [
    "spinal_canal_stenosis",
    "left_neural_foraminal_narrowing",
    "right_neural_foraminal_narrowing",
    "right_subarticular_stenosis",
    "left_subarticular_stenosis",
]

In [None]:
from sklearn.linear_model import *
from util.metrics import disk_auc

In [None]:

# Exploratory scan: for every target and level, check whether the softmaxed
# output of a single SpineNet class separates one label value (0, 1 or 2) from
# the rest, and print the (class, level, value, AUC) combinations above 0.8.

# The features depend only on (level, class), not on the target, so they are
# built once up front instead of once per target.
single_class_fts = {}
for level in LEVELS_:
    level_preds = ddf[level].tolist()
    # Reference prediction used to size the zero-filled rows of missing levels.
    # It is the first row that actually holds a dict: the very first row may be
    # a missing level, which used to raise a TypeError here.
    ref = next((p for p in level_preds if isinstance(p, dict)), None)
    if ref is None:
        # No prediction at all for this level: nothing to evaluate.
        continue
    for c in SPINENET_CLASSES:
        zeros = [0] * len(ref[c])
        logits = np.array([p[c] if isinstance(p, dict) else zeros for p in level_preds])
        single_class_fts[level, c] = softmax(logits, -1)

for tgt_ in tgts:
    print(f'\n-> {tgt_}\n')
    for level in LEVELS_:
        print()

        tgt = y[f"{tgt_}_{level}"]

        for c in SPINENET_CLASSES:
            fts = single_class_fts.get((level, c))
            if fts is None:
                continue

            # Column i of the class output is scored against "label == i".
            for i in range(min(3, fts.shape[1])):
                is_value = np.asarray(tgt == i)
                if np.unique(is_value).size < 2:
                    # AUC is undefined when the value is absent (or universal);
                    # roc_auc_score would raise.
                    continue
                s = roc_auc_score(is_value, fts[:, i])
                if s > 0.8:
                    print(c, level, i, s)

In [None]:
# Out-of-fold evaluation: for every target and level, fit a logistic regression
# on the concatenated softmaxed outputs of all SpineNet classes, predict each
# fold with a model trained on the other folds, and score the result with the
# project's `disk_auc`.

# The feature matrix depends only on the level, so it is built once per level
# instead of once per (target, level).
level_fts = {}
for level in LEVELS_:
    level_preds = ddf[level].tolist()
    # Reference prediction used to size the zero-filled rows of missing levels.
    # It is the first row that actually holds a dict: the very first row may be
    # a missing level, which used to raise a TypeError here.
    ref = next((p for p in level_preds if isinstance(p, dict)), None)
    if ref is None:
        continue
    per_class = []
    for c in SPINENET_CLASSES:
        zeros = [0] * len(ref[c])
        logits = np.array([p[c] if isinstance(p, dict) else zeros for p in level_preds])
        per_class.append(softmax(logits, -1))
    level_fts[level] = np.concatenate(per_class, -1)

for tgt_ in tgts:
    print(f'\n-> {tgt_}\n')
    for level in LEVELS_:
        if level not in level_fts:
            print(tgt_, level, "skipped: no SpineNet prediction for this level")
            continue
        fts = level_fts[level]

        tgt = y[f"{tgt_}_{level}"].values
        # NOTE(review): sized with len(y) but indexed with ddf's row positions —
        # assumes `y` and `ddf` are aligned row by row.
        pred_oof = np.zeros((len(y), 3))

        for fold in range(4):
            model = LogisticRegression(C=0.5)
            train_idx = ddf[ddf['fold'] != fold].index.values
            val_idx = ddf[ddf['fold'] == fold].index.values

            y_train = tgt[train_idx]
            # Negative labels are excluded from training (presumably "missing").
            labelled = y_train >= 0
            model.fit(fts[train_idx][labelled], y_train[labelled])

            # predict_proba returns one column per class seen during training;
            # write each column at its class index so that a fold missing a
            # class leaves that column at 0 instead of raising a shape error.
            class_cols = model.classes_.astype(int)
            pred_oof[val_idx[:, None], class_cols] = model.predict_proba(fts[val_idx])

        s = disk_auc(tgt, pred_oof)
        print(tgt_, level, s)

Done!