In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Move to the source directory; presumably so the project-local imports below
# (params, data, model_zoo, util, inference) resolve — TODO confirm.
# NOTE(review): the path is relative, so re-running this cell without a kernel
# restart changes directory a second time.
cd ../src

In [None]:
import os
import cv2
import json
import glob
import torch
import pydicom
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F

from collections import Counter
from tqdm.notebook import tqdm


# Widen pandas' display limits so wide frames are not truncated in cell outputs.
# Every option is addressed by its fully-qualified name: abbreviated patterns
# such as 'max_colwidth' only work while they match exactly one option.
pd.set_option('display.width', 5000)
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.max_columns', 100)

In [None]:
from params import *


from data.dataset import *
from data.transforms import *
from data.preparation import *
from data.sagital_to_axial import get_axial_coords

from model_zoo.models import define_model
from util.torch import load_model_weights
from util.plots import *

from inference.lvl1 import Config, predict

## Sagittal

### Inference

In [None]:
# Experiment folder holding config.json, the per-fold weights and the cached
# predictions of the sagittal keypoint model.
# NOTE: the second assignment overrides the first, so only the "aug" run is used.
EXP_FOLDER = "../logs/2024-08-28/13/"  # coatnet_rmlp_2_rw_384 50 ep new folds
EXP_FOLDER = "../logs/2024-08-28/24/"  # coatnet_rmlp_2_rw_384 aug 50 ep new folds
# NOTE(review): FOLD is not referenced in the visible cells; the loops below
# iterate over range(4) instead.
FOLD = 0

In [None]:
# Load the experiment configuration. The context manager closes the file
# handle deterministically instead of leaving it to the garbage collector.
with open(EXP_FOLDER + "config.json", "r") as config_file:
    config = Config(json.load(config_file))

# Dataframe of the images / coordinates the sagittal keypoint model runs on.
df = prepare_coords_data(config.coords_folder, use_ext=config.use_ext)

# Attach the cross-validation fold of every row. The merge is a left join on
# the shared columns, so rows missing from the folds file get fold -1 and are
# never selected by the `df['fold'] == fold` filters used below.
folds = pd.read_csv(config.folds_file)
df = df.merge(folds, how="left")
df['fold'] = df['fold'].fillna(-1)

In [None]:
# Run each fold's sagittal model on its own validation split and cache the
# predictions to EXP_FOLDER as pred_inf_{fold}.npy.
for fold in range(4):
    model = define_model(
        config.name,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        pooling=config.pooling,
        num_classes=config.num_classes,
        num_classes_aux=config.num_classes_aux,
        n_channels=config.n_channels,
        reduce_stride=config.reduce_stride,
        pretrained=False,
    )
    model = model.cuda().eval()

    weights = EXP_FOLDER + f"{config.name}_{fold}.pt"
    try:
        model = load_model_weights(model, weights, verbose=config.local_rank == 0)
    except FileNotFoundError:
        # Folds without a checkpoint are still skipped, but no longer silently:
        # no pred_inf_{fold}.npy is written for them, which the cells that
        # load these files need to know about.
        print(f"Weights not found, skipping fold {fold}: {weights}")
        continue

    # Validation rows of this fold, without augmentation.
    df_val = df[df['fold'] == fold].reset_index(drop=True)
    dataset = CoordsDataset(
        df_val,
        transforms=get_transfos(augment=False, resize=config.resize, use_keypoints=True),
    )

    preds, _ = predict(model, dataset, config.loss_config, batch_size=32, use_fp16=True)

    np.save(EXP_FOLDER + f"pred_inf_{fold}.npy", preds)

### Eval

In [None]:
# cc = pd.read_csv('../input/train_label_coordinates.csv')

# cc = cc.merge(
#     cc[["series_id", "level"]].groupby("series_id").count().reset_index(),
#     how="left",
#     on="series_id",
#     suffixes=("", "_count"),
# )

# dfg = (
#     cc[(cc.condition != "Spinal Canal Stenosis") & ~(cc.level_count.isin([5, 10]))]
#     .groupby(["study_id", "series_id"])
#     .agg(list)
#     .reset_index()
# )
# dfg.shape

# dfg = (
#     cc[(cc.condition == "Spinal Canal Stenosis") & (cc.level_count != 5)]
#     .sort_values("level")
#     .groupby(["study_id", "series_id"])
#     .agg(list)
#     .reset_index()
# )
# dfg.shape

In [None]:
# When True, the evaluation loop below displays every series whose mean
# keypoint error is above 5. The flag is reassigned by later cells.
PLOT = True

In [None]:
# Frames used by the evaluation loop below. Their schemas are defined by the
# project helpers and are not visible here.
df_gt = prepare_data()  # filtered by 'series_id' in the evaluation loop
df_sev = prepare_data_crop(DATA_PATH)  # read through its 'series_id', 'level' and 'target' columns
# Keypoints exported by SpineNet; the evaluation loop reads the last 10 columns
# of the middle row of each series — presumably 5 x then 5 y values, TODO confirm.
df_spinenet = pd.read_csv('../output/spinenet_kps.csv')

In [None]:
# Compare the cached sagittal predictions with the dataset targets.
# `ds` receives one value per series: the mean absolute difference between
# predicted and target coordinates, multiplied by 100 (a percentage if the
# coordinates are normalised to [0, 1] — TODO confirm), computed over the
# keypoints whose target is non-zero.
ds = []

for fold in range(4):
    preds = np.load(EXP_FOLDER + f"pred_inf_{fold}.npy")
    df_val = df[df['fold'] == fold].reset_index(drop=True)

    dataset = CoordsDataset(df_val, transforms=get_transfos(augment=False, use_keypoints=True))

    for idx in tqdm(range(len(dataset))):
        study = df_val['study_id'][idx]
        series = df_val['series_id'][idx]

        # Only the targets are taken from the dataset; the displayed image is
        # the middle frame of the full stack loaded below.
        _, y, _ = dataset[idx]

        imgs = np.load(f'../input/npy/{study}_{series}.npy')
        frame = len(imgs) // 2

        # Min-max normalise the frame for display.
        img = imgs[frame]
        img = (img - img.min()) / (img.max() - img.min())

        # SpineNet keypoints are optional. An unknown series gives an empty
        # frame (IndexError) and an unexpected layout fails the reshape
        # (ValueError). Anything else, KeyboardInterrupt included, is no
        # longer swallowed.
        try:
            spinenet_coords = df_spinenet[df_spinenet['series_id'] == series]
            spinenet_coords = spinenet_coords.values[len(spinenet_coords) // 2, -10:].reshape(2, 5).T
        except (IndexError, ValueError):
            spinenet_coords = None

        p = preds[idx].reshape(-1, 2)

        # Keypoints whose target is all zeros are treated as missing and excluded.
        annotated = y.sum(-1) > 0
        d = np.abs(p - y.numpy()) * 100
        d = d[annotated].mean()
        ds.append(d)

        # Only the worst cases are displayed.
        if PLOT and d > 5:
            y = y[annotated]

            print(study, series)

            plt.figure(figsize=(8, 8))
            plt.imshow(img, cmap="gray")
            plt.scatter(y[:, 0] * img.shape[1], y[:, 1] * img.shape[0], marker="x", label="truth")
            plt.scatter(p[:, 0] * img.shape[1], p[:, 1] * img.shape[0], marker="x", label="pred")
            if spinenet_coords is not None:
                plt.scatter(spinenet_coords[:, 0], spinenet_coords[:, 1], marker="x", label="spinenet")
            plt.title(f'Dist = {d:.2f} - series {series}')
            plt.axis(False)
            plt.legend()
            plt.show()

            print(p[:, 0] * img.shape[1], p[:, 1] * img.shape[0])

    # NOTE: unconditional, as in the original cell — only the first fold is
    # evaluated, so `ds` and the statistics derived from it cover fold 0 only.
    break


In [None]:
# Count the evaluated series whose mean error exceeds each threshold.
errors = np.array(ds)
for threshold in (5, 4, 3):
    print(f'Images with error > {threshold}%:', (errors > threshold).sum())

In [None]:
# Distribution of the per-series errors, with the mean marked and labelled.
mean_error = np.mean(ds)

sns.histplot(ds)
plt.axvline(mean_error, c="salmon")
plt.text(mean_error, 100, f"   mean={mean_error:.3f}", color="salmon")
plt.show()

### Crop

In [None]:
# Half-size of the crop taken around each predicted keypoint, added to the
# predicted coordinates before they are scaled by the image width / height.
DELTA = 0.1

# SAVE: write one crop per keypoint to SAVE_FOLDER.
# PLOT: display the crop corners for the first series of each fold, then stop
# processing that fold.
SAVE = True
PLOT = False

# Output folder for the crops; created up front so np.save cannot fail on a
# missing directory.
SAVE_FOLDER = f"../input/coords_crops_{DELTA}_/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

In [None]:
# Cut one box per predicted keypoint out of every frame of the stack and save
# it as {study_series}_{level}.npy in SAVE_FOLDER.
for fold in range(4):
    pred_val = np.load(EXP_FOLDER + f"pred_inf_{fold}.npy")
    df_val = df[df['fold'] == fold].reset_index(drop=True)

    for idx in tqdm(range(len(df_val))):
        # File name without its 4-character extension, e.g. "{study}_{series}".
        study_series = df_val["img_path"][idx].split('/')[-1][:-4]
        imgs_path = DATA_PATH + "npy/" + study_series + ".npy"
        imgs = np.load(imgs_path)
        height, width = imgs.shape[1], imgs.shape[2]

        # Copy: reshape returns a view, and the PLOT branch rescales `preds`
        # in place, which would otherwise write through to `pred_val`.
        preds = pred_val[idx].reshape(-1, 2).copy()

        # Boxes as (x0, y0, x1, y1): keypoint -/+ DELTA, scaled to pixels.
        crops = np.concatenate([preds, preds], -1)
        crops[:, [0, 1]] -= DELTA
        crops[:, [2, 3]] += DELTA
        crops[:, [0, 2]] *= width
        crops[:, [1, 3]] *= height
        crops = crops.astype(int)

        # Clip to the image. A keypoint closer than DELTA to a border gives a
        # negative start index, which numpy slicing would interpret as
        # "counted from the end" and return an empty or misplaced crop.
        crops[:, [0, 2]] = np.clip(crops[:, [0, 2]], 0, width)
        crops[:, [1, 3]] = np.clip(crops[:, [1, 3]], 0, height)

        if SAVE:
            for i, (x0, y0, x1, y1) in enumerate(crops):
                crop = imgs[:, y0: y1, x0: x1].copy()
                np.save(SAVE_FOLDER + f'{study_series}_{LEVELS_[i]}.npy', crop)

        if PLOT:
            preds[:, 0] *= width
            preds[:, 1] *= height

            plt.figure(figsize=(8, 8))
            plt.imshow(imgs[len(imgs) // 2], cmap="gray")
            plt.scatter(preds[:, 0], preds[:, 1], marker="x", label="center")
            plt.scatter(crops[:, 0], crops[:, 1], marker="x", label="top-left")
            plt.scatter(crops[:, 2], crops[:, 3], marker="x", label="bot-right")
            plt.title(study_series)
            plt.axis(False)
            plt.legend()
            plt.show()
            # Leaves the inner loop only: one series is shown per fold.
            break

Done ! 

## Axial

In [None]:
# Study ids kept at hand for debugging.
# NOTE(review): the filter that uses this list is commented out in the axial
# loop below, so it currently has no effect.
ref_studies = [
    # 113758629,
    # 13317052, 60612428, 74294498, 142991438, 
    # 168833126, 189360935, 58813022, 1115952008, 959290081,
    2388577668  # bugged
]

# PLOT: display the crops and keypoints, stopping after a few series of the
# first fold. SAVE: write every axial crop to SAVE_FOLDER.
PLOT = False
SAVE = True

# Half-size of the axial crop, as a fraction of the image width / height.
SIZE = 0.15
SAVE_FOLDER = f"../input/crops_ax_{SIZE}/"

os.makedirs(SAVE_FOLDER, exist_ok=True)

In [None]:
# Experiment folder of the axial keypoint model.
EXP_FOLDER_AX = "../logs/2024-08-26/3/"

# NOTE: this overwrites the sagittal `config` loaded earlier; from here on
# `config` describes the axial experiment. The context manager closes the file
# handle instead of leaving it open.
with open(EXP_FOLDER_AX + "config.json", "r") as config_file:
    config = Config(json.load(config_file))

# One model per fold, kept in fold order so models_ax[fold] is the model whose
# validation split is `fold`.
models_ax = []
for fold in range(4):
    model = define_model(
        config.name,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        pooling=config.pooling,
        num_classes=config.num_classes,
        num_classes_aux=config.num_classes_aux,
        n_channels=config.n_channels,
        reduce_stride=config.reduce_stride,
        pretrained=False,
    )
    model = model.cuda().eval()

    weights = EXP_FOLDER_AX + f"{config.name}_{fold}.pt"
    model = load_model_weights(model, weights, verbose=config.local_rank == 0)
    models_ax.append(model)

In [None]:
# Rebuild the coordinates frame (this time with the helper's default
# arguments) and attach the folds listed in the current config's folds file.
df = prepare_coords_data()

folds = pd.read_csv(config.folds_file)
# Left join, then give fold -1 to the rows that have no fold.
df = df.merge(folds, how="left").assign(
    fold=lambda merged: merged["fold"].fillna(-1)
)

In [None]:
# Frame passed to get_axial_coords in the axial loop below.
df_ = prepare_data()
# Annotated coordinates; the axial loop reads its 'series_id',
# 'instance_number', 'level', 'x' and 'y' columns.
df_coords = pd.read_csv(DATA_PATH + "train_label_coordinates.csv")

In [None]:
# For every sagittal series: project the predicted keypoints onto the matching
# axial series, pick one axial frame per level, locate the keypoints on that
# frame with the axial model and save a crop around them.
coords = []
axial_coords = []

for fold in range(4):
    # NOTE(review): this reads "pred_val_{fold}.npy" while the sagittal
    # inference cell writes "pred_inf_{fold}.npy" — confirm which file is meant.
    preds_coords = np.load(EXP_FOLDER + f"pred_val_{fold}.npy")
    df_val = df[df["fold"] == fold].reset_index(drop=True)

    for idx in tqdm(range(len(df_val))):
        study = df_val["study_id"][idx]
        series = df_val["series_id"][idx]

        # if not study in ref_studies:
        #     continue

        # Get axial projection
        p = preds_coords[idx].reshape(-1, 2)

        img = cv2.imread(df_val["img_path"][idx])
        h, w, _ = img.shape

        world_point, assigned_level, closest_z, df_axial = get_axial_coords(
            study,
            series,
            p.copy(),
            h,
            w,
            df_,
            "../input/train_images/",
        )

        if closest_z.max() == 0:  # Fix
            # Every level landed on frame 0: shift the last world coordinate so
            # its mean matches the mean axial projection, then assign again.
            world_point[:, -1] -= (world_point[:, -1].mean() - df_axial.projection.mean())
            world_point, assigned_level, closest_z, df_axial = get_axial_coords(
                study,
                series,
                p.copy(),
                h,
                w,
                df_,
                "../input/train_images/",
                world_point=world_point
            )

        # Evaluate
        # Local names end in _ax so the `df_gt` / `preds` globals used by the
        # sagittal evaluation cells are not overwritten by this loop.
        series_ax = df_axial["series_id"].values[0]
        df_gt_ax = df_coords[df_coords["series_id"] == series_ax].reset_index(drop=True)
        df_gt_ax = df_gt_ax[["instance_number", "level", "x", "y"]].groupby("level").mean().sort_index()
        gt = df_gt_ax["instance_number"].values.flatten()
        preds_ax_frames = df_axial["instance_number"].values[closest_z]

        # Frame-selection error, only defined when all 5 levels are annotated.
        # NOTE: `mae` is computed per series but not aggregated in this cell.
        if len(df_gt_ax) == 5:
            mae = np.abs(gt - preds_ax_frames).mean()
        else:
            mae = 0

        # Locate disk
        imgs = np.load(f'../input/npy2/{study}_{series_ax}.npy')

        with torch.no_grad():
            x = torch.tensor(imgs[closest_z].astype(np.float32)).cuda()

            # Per-frame min-max normalisation. A constant frame has a zero
            # range and would turn into NaNs; its divisor is replaced by 1 so
            # it normalises to zeros. Other frames are unaffected.
            min_ = x.amin((-1, -2), keepdim=True)
            max_ = x.amax((-1, -2), keepdim=True)
            range_ = max_ - min_
            range_ = torch.where(range_ > 0, range_, torch.ones_like(range_))
            x = (x - min_) / range_
            x = F.interpolate(
                x.unsqueeze(1).repeat(1, 3, 1, 1),
                config.resize,
                mode="bilinear",
            )

            # Two (x, y) keypoints per frame, as fractions of the image size.
            preds_ax = models_ax[fold](x)[0].sigmoid().detach().cpu().numpy().reshape(x.size(0), 2, 2)

        preds_ax[:, :, 0] *= imgs.shape[2]
        preds_ax[:, :, 1] *= imgs.shape[1]

        # Crop
        # The half-size in pixels is the same for every level.
        dx, dy = int(imgs.shape[2] * SIZE), int(imgs.shape[1] * SIZE)

        crop_imgs = []
        for i in range(5):
            # Frames [z - 3, z + 3), clamped to the stack.
            fs = max(closest_z[i] - 3, 0)
            fe = min(closest_z[i] + 3, len(imgs))

            # Box centred on the mean of the two predicted keypoints, clamped
            # to the image.
            xc, yc = preds_ax[i].mean(0).astype(int)
            x0, x1 = max(xc - dx, 0), min(xc + dx, imgs.shape[2])
            y0, y1 = max(yc - dy, 0), min(yc + dy, imgs.shape[1])

            crop = imgs[fs: fe, y0: y1, x0: x1]
            crop_imgs.append(crop[len(crop) // 2])

            if SAVE:
                np.save(SAVE_FOLDER + f"{study}_{series_ax}_{LEVELS_[i]}.npy", crop.copy())

        # Plot
        if PLOT:
            plt.figure(figsize=(25, 5))
            for i in range(5):
                plt.subplot(1, 5, i + 1)
                plt.imshow(crop_imgs[i], cmap="gray")
            plt.show()

            # display(df_gt_ax)
            # plot_coords(
            #     world_point,
            #     assigned_level,
            #     closest_z,
            #     h,
            #     w,
            #     df_axial,
            #     title=f"Study {study} - Series {series_ax}",
            # )

            df_lvl = df_coords[df_coords["series_id"] == series_ax].reset_index(drop=True)
            plt.figure(figsize=(25, 5))
            for i in range(5):
                plt.subplot(1, 5, i + 1)

                img = imgs[closest_z[i]]

                df_lvl_ = df_lvl[df_lvl["level"] == LEVELS[i]]
                x_gt, y_gt = df_lvl_['x'].values, df_lvl_["y"].values

                plt.scatter(preds_ax[i, :, 0], preds_ax[i, :, 1], label="pred")
                plt.scatter(x_gt, y_gt, label="truth", marker="x")

                plt.imshow(img, cmap="gray")
                plt.legend()
                plt.axis(False)
                # Positional lookup, like the frame selection above: closest_z
                # holds positions in df_axial, not index labels.
                plt.title(str(df_axial["instance_number"].values[closest_z[i]]))

            plt.show()
            if idx > 5:
                break
    if PLOT:
        break