**About**: This notebook runs inference with the trained sagittal coords model and crop model to produce pseudo-label predictions.

In [None]:
# Reload edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Work from the source folder so the project imports and the relative "../input",
# "../output" and "../logs" paths used below resolve.
# Guarded so that re-running this cell does not climb one more directory up.
if os.path.basename(os.getcwd()) != "src":
    os.chdir("../src")

### Imports

In [None]:
import os
import gc
import re
import sys
import cv2
import glob
import json
import torch
import shutil
import warnings
import numpy as np
import pandas as pd
import torch.nn.functional as F
import matplotlib.pyplot as plt

from collections import Counter
from tqdm.notebook import tqdm
from joblib import Parallel, delayed
from scipy.special import softmax

In [None]:
from util.torch import load_model_weights
from util.plots import plot_mask, add_rect
from util.metrics import rsna_loss

from data.processing import process_and_save
from data.transforms import get_transfos
from data.dataset import CropDataset, CoordsDataset
from data.preparation import prepare_data_crop

from inference.seg import get_crops
from inference.dataset import ImageInfDataset, FeatureInfDataset, SafeDataset
from inference.lvl1 import predict, Config

from model_zoo.models import define_model
from model_zoo.models_lvl2 import define_model as define_model_2
from model_zoo.models_seg import define_model as define_model_seg
from model_zoo.models_seg import convert_3d

from params import CLASSES_SEG, MODES, LEVELS_, SEVERITIES, LEVELS, CLASSES_CROP

### Params

In [None]:
# Run-wide switches.
DEBUG = False  # True: keep only the first 3 series and show sanity-check plots
FOLD = 0  # fold index used to pick the crop-model weights and to filter the comp data

ROOT_DATA_DIR = "../input/seg_npy/imgs/"
SAVE_FOLDER = "../output/spider_pl/"

# ROOT_DATA_DIR = "../input/npy2/"
# SAVE_FOLDER = "../output/comp_pl/"

# Output root ("") plus one sub-folder per kind of artefact.
for sub_folder in ("", "npy/", "mid/", "csv/"):
    os.makedirs(SAVE_FOLDER + sub_folder, exist_ok=True)

In [None]:
# One row per .npy volume found in the input folder.
npy_paths = glob.glob(ROOT_DATA_DIR + "*.npy")
df_meta = pd.DataFrame({'img_path': npy_paths})
df_meta['orient'] = "Sagittal"
# Series identifier = file name without its 4-character ".npy" suffix.
df_meta['study_series'] = df_meta["img_path"].map(lambda path: path.split('/')[-1][:-4])

if DEBUG:
    df_meta = df_meta.head(3)

if ROOT_DATA_DIR == "../input/npy2/":  # comp data, filtering needed
    series_desc = pd.read_csv('../input/train_series_descriptions.csv')
    series_desc['study_series'] = (
        series_desc['study_id'].astype(str) + "_" + series_desc['series_id'].astype(str)
    )
    df_meta = df_meta[["study_series", "img_path"]].merge(
        series_desc, how="left", on="study_series"
    )
    # Orientation is the first word of the series description.
    df_meta['orient'] = df_meta['series_description'].map(lambda desc: desc.split(' ')[0])

    folds = pd.read_csv('../input/train_folded_v1.csv')
    df_meta = df_meta.merge(folds)

    # Keep sagittal series that are not in fold FOLD, then drop the helper columns.
    keep_rows = (df_meta['fold'] != FOLD) & (df_meta['orient'] == "Sagittal")
    df_meta = (
        df_meta[keep_rows]
        .drop(columns=['study_id', 'series_id', 'series_description', 'fold'])
        .reset_index(drop=True)
    )

print(f'-> Pseudo labeling {len(df_meta)} series with fold {FOLD}')

In [None]:
# Inference settings. BATCH_SIZE, USE_FP16 and NUM_WORKERS are passed to predict() by the
# crop-model cell; BATCH_SIZE_2 is not referenced in the visible cells.
BATCH_SIZE = 32
BATCH_SIZE_2 = 512
USE_FP16 = True

# os.cpu_count() returns None when the CPU count cannot be determined;
# fall back to 0 so the setting is always an int.
NUM_WORKERS = os.cpu_count() or 0

In [None]:
# Coords model: (experiment folder, weights-file suffix).
COORDS_FOLDERS = {
    "sag": ("../logs/2024-08-29/0/", "fullfit_0"),  # fullfit for simplicity
}

# Crop model: (experiment folder, folds whose weights are averaged, crop sub-folder).
CROP_EXP_FOLDERS = {
    "crop": ("../output/2024-09-13_7/", [FOLD], "crops_0.1"),
}

# Fail early when an experiment folder is missing.
for name, (folder, *_) in CROP_EXP_FOLDERS.items():
    assert os.path.exists(folder), f"Crop model not found: {name}"
for name, (folder, *_) in COORDS_FOLDERS.items():
    assert os.path.exists(folder), f"Coords model not found: {name}"

## Preparation

In [None]:
# Output root used by the preparation loop below (alias of SAVE_FOLDER).
save_folder = SAVE_FOLDER
# When True, a PNG of each volume's middle frame is also written under "mid/".
save_middle_frame = True

In [None]:
for row_idx in tqdm(range(len(df_meta))):
    # Load the volume, reverse it along axis 1 and store a contiguous copy.
    volume = np.load(df_meta['img_path'][row_idx])[:, ::-1].copy()
    study_series = df_meta['study_series'][row_idx]
    np.save(save_folder + f"npy/{study_series}.npy", volume)

    if not save_middle_frame:
        continue

    # Middle frame, clipped to its [0, 98] percentile range.
    frame = volume[len(volume) // 2]
    low, high = np.percentile(frame.flatten(), [0, 98])
    frame = np.clip(frame, low, high)

    # Min-max rescale (skipped for a constant frame) and convert to uint8.
    frame_min, frame_max = frame.min(), frame.max()
    frame = frame - frame_min
    if frame_max != frame_min:
        frame = frame / (frame_max - frame_min)
    frame = (frame * 255).astype(np.uint8)
    cv2.imwrite(save_folder + f"mid/{study_series}.png", frame)

    if DEBUG:
        plt.imshow(frame, cmap="gray")
        plt.title(study_series)
        plt.axis(False)
        plt.show()


## Sagittal Coords

In [None]:
# Sagittal series only, restricted to the first 6 columns of the metadata.
sagittal_rows = df_meta["orient"] == "Sagittal"
df_sag = df_meta.loc[sagittal_rows].reset_index(drop=True)
df_sag = df_sag[list(df_sag.columns[:6])]

# Point at the middle-frame PNGs written above; `target` is a (5, 2) array of ones per row.
df_sag = df_sag.assign(
    img_path=SAVE_FOLDER + "mid/" + df_sag["study_series"] + ".png",
    target=[np.ones((5, 2)) for _ in range(len(df_sag))],
)

df_sag.head(3)

In [None]:
# Rebuild the sagittal coords model from the config saved with the experiment.
# The context manager closes the config file as soon as it has been parsed.
with open(COORDS_FOLDERS['sag'][0] + "config.json", "r") as config_file:
    config_sag = Config(json.load(config_file))

model_sag = define_model(
    config_sag.name,
    drop_rate=config_sag.drop_rate,
    drop_path_rate=config_sag.drop_path_rate,
    pooling=config_sag.pooling,
    num_classes=config_sag.num_classes,
    num_classes_aux=config_sag.num_classes_aux,
    n_channels=config_sag.n_channels,
    reduce_stride=config_sag.reduce_stride,
    pretrained=False,
)
model_sag = model_sag.cuda().eval()

# Weights file is named "<model name>_<suffix>.pt" inside the experiment folder.
weights = COORDS_FOLDERS['sag'][0] + f"{config_sag.name}_{COORDS_FOLDERS['sag'][1]}.pt"
model_sag = load_model_weights(model_sag, weights, verbose=1)

In [None]:
%%time
transfos = get_transfos(augment=False, resize=config_sag.resize, use_keypoints=True)
dataset = CoordsDataset(df_sag, transforms=transfos)
dataset = SafeDataset(dataset)

preds_sag, _ = predict(model_sag, dataset, config_sag.loss_config, batch_size=32, use_fp16=True)

In [None]:
# Margin(s) added on each side of a predicted center, as a fraction of the image size.
DELTAS = [0.1]  #, 0.15]

# One output folder per margin.
for margin in DELTAS:
    os.makedirs(SAVE_FOLDER + f"crops_{margin}", exist_ok=True)

In [None]:
for idx in tqdm(range(len(df_sag))):
    study_series = df_sag["study_series"][idx]
    imgs = np.load(SAVE_FOLDER + "npy/" + study_series + ".npy")
    height, width = imgs.shape[1], imgs.shape[2]

    # One (x, y) center per row, handled as fractions of the image width / height.
    preds = preds_sag[idx].reshape(-1, 2).copy()

    for delta in DELTAS:  # , 0.15
        # Boxes are (x0, y0, x1, y1): center -/+ delta, clipped to [0, 1], then in pixels.
        crops = np.concatenate([preds, preds], -1)
        crops[:, [0, 1]] -= delta
        crops[:, [2, 3]] += delta
        crops = crops.clip(0, 1)

        crops[:, [0, 2]] *= width
        crops[:, [1, 3]] *= height
        crops = crops.astype(int)

        img_crops = []
        for i, (x0, y0, x1, y1) in enumerate(crops):
            crop = imgs[:, y0: y1, x0: x1].copy()
            if crop.shape[1] < 1 or crop.shape[2] < 1:
                # Empty box (e.g. center clipped onto a border): keep the whole volume.
                crop = imgs.copy()

            np.save(SAVE_FOLDER + f"crops_{delta}/{study_series}_{LEVELS_[i]}.npy", crop)
            img_crops.append(crop[len(crop) // 2])

        if DEBUG:
            # Scale a copy for plotting: `preds` must stay unscaled, otherwise the
            # next delta of this loop would build its boxes from pixel coordinates.
            centers = preds * [width, height]

            plt.figure(figsize=(8, 8))
            plt.imshow(imgs[len(imgs) // 2], cmap="gray")
            plt.scatter(centers[:, 0], centers[:, 1], marker="x", label="center")
            plt.title(study_series)
            plt.axis(False)
            plt.legend()
            plt.show()

            plt.figure(figsize=(20, 4))
            for i, img_crop in enumerate(img_crops):
                plt.subplot(1, len(img_crops), i + 1)
                plt.imshow(img_crop, cmap="gray")
                plt.axis(False)
                plt.title(LEVELS[i])
            plt.show()

In [None]:
# Debug-only check: the crops written above must be identical to reference crops.
# NOTE(review): `EVAL` and `DEBUG_DATA_DIR` are not defined in the visible part of this
# notebook, so this cell raises NameError as soon as DEBUG is True - confirm their origin.
if DEBUG and not EVAL:
    ref_folder = DEBUG_DATA_DIR + "coords_crops_0.1/"
    df_ref = prepare_data_crop(ROOT_DATA_DIR, ref_folder).head(10)

    # Literal replacement: the folder is a path, not a regex ("." would match any character).
    df_ref['img_path_2'] = df_ref['img_path'].apply(
        lambda x: x.replace(ref_folder, SAVE_FOLDER + "crops_0.1/")
    )

    for i in range(len(df_ref)):
        cref = np.load(df_ref['img_path'][i])
        c = np.load(df_ref['img_path_2'][i])
        # array_equal also returns False (instead of failing) when the shapes differ.
        assert np.array_equal(cref, c), f"Crop mismatch: {df_ref['img_path_2'][i]}"

## Crop models

In [None]:
# Start from the series metadata with placeholder target / coords columns, and give every
# series the full list of levels so that explode() yields one row per (series, level).
n_series = len(df_meta)
df = df_meta.assign(
    target=0,
    coords=0,
    level=[LEVELS for _ in range(n_series)],
    level_=[LEVELS_ for _ in range(n_series)],
)
df = df.explode(["level", "level_"]).reset_index(drop=True)

# File name of the crop saved for that series and level.
df["img_path_"] = df["study_series"] + "_" + df["level_"] + ".npy"

In [None]:
crop_fts = {}  # NOTE(review): never filled or read in the visible cells
for mode in tqdm(CROP_EXP_FOLDERS, total=len(CROP_EXP_FOLDERS)):
    exp_folder, folds, crop_folder = CROP_EXP_FOLDERS[mode]
    print(f"- Model {mode} - {exp_folder}")

    # Context manager so the config file handle is closed right after parsing.
    with open(exp_folder + "config.json", "r") as config_file:
        config = Config(json.load(config_file))

    # Select the rows this model applies to and attach the `side` column.
    # NOTE(review): the non-"crop" branches read `weighting`, `study_id` and `series_id`,
    # which the visible cells do not add to `df` - confirm before enabling those modes.
    if mode == "crop":
        df_mode = df[df['orient'] == "Sagittal"].reset_index(drop=True)
        df_mode["side"] = "Center"
    elif "scs" in mode:
        df_mode = df[df['orient'] == "Sagittal"]
        df_mode = df_mode[df_mode["weighting"] == "T2"].reset_index(drop=True)
        df_mode["side"] = "Center"
    elif "nfn" in mode or "ss" in mode:
        orient = "Sagittal" if "nfn" in mode else "Axial"
        df_mode = df[df['orient'] == orient].reset_index(drop=True)
        # Every row carries both sides so that explode() yields one row per side
        # (a bare 2-element list cannot be assigned to a column of arbitrary length).
        df_mode["side"] = [["Right", "Left"] for _ in range(len(df_mode))]
        df_mode = df_mode.explode("side").reset_index(drop=True)
        df_mode = df_mode.sort_values(
            ["study_id", "series_id", "side", "level"],
            ascending=[True, True, False, True],
            ignore_index=True
        )
    else:
        # Without this, an unknown mode would silently reuse the previous df_mode.
        raise ValueError(f"Unsupported crop model mode: {mode}")

    df_mode['img_path'] = SAVE_FOLDER + crop_folder + "/" + df_mode["img_path_"]

    transfos = get_transfos(augment=False, resize=config.resize, crop=config.crop)
    dataset = CropDataset(
        df_mode,
        targets="target",
        transforms=transfos,
        frames_chanel=config.frames_chanel,
        n_frames=config.n_frames,
        stride=config.stride,
        train=False,
        load_in_ram=False,
    )
    dataset = SafeDataset(dataset)

    model = define_model(
        config.name,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        pooling=config.pooling,
        head_3d=config.head_3d,
        n_frames=config.n_frames,
        num_classes=config.num_classes,
        num_classes_aux=config.num_classes_aux,
        n_channels=config.n_channels,
        reduce_stride=config.reduce_stride,
        pretrained=False,
    )
    model = model.cuda().eval()

    # Softmax over the last axis for each fold's weights, then average across folds.
    preds = []
    for fold in folds:
        weights = exp_folder + f"{config.name}_{fold}.pt"
        model = load_model_weights(model, weights, verbose=1)

        pred, _ = predict(
            model,
            dataset,
            config.loss_config,
            batch_size=BATCH_SIZE,
            use_fp16=USE_FP16,
            num_workers=NUM_WORKERS,
        )
        pred = softmax(pred, -1)
        preds.append(pred)

    preds = np.mean(preds, 0)

    # One column per (class, index j): pred_<class>_<j>.
    for i, tgt in enumerate(CLASSES_CROP):
        for j in range(3):
            df_mode[f'pred_{tgt}_{j}'] = preds[:, i, j]

In [None]:
# Write the predictions of the last processed crop model next to the other outputs.
preds_csv_path = SAVE_FOLDER + f"preds_crop_{FOLD}.csv"
df_mode.to_csv(preds_csv_path, index=False)
print("-> Saved PL preds to", preds_csv_path)
df_mode.head(10)

In [None]:
# Ground-truth crop labels, keyed by "<study_id>_<series_id>" like the predictions.
df_gt = prepare_data_crop("../input/")
study_ids = df_gt['study_id'].astype(str)
series_ids = df_gt['series_id'].astype(str)
df_gt['study_series'] = study_ids + "_" + series_ids

# Attach the fold assignment and keep the rows that are not in fold FOLD.
folds = pd.read_csv('../input/train_folded_v1.csv')
df_gt = df_gt.merge(folds).loc[lambda frame: frame['fold'] != FOLD]

In [None]:
# Collect the pred_<class>_<j> columns written by the crop-model cell into one array per row.
# They are selected by name: a positional slice such as columns[-16:-1] depends on whether
# `pred` already exists, so on a fresh kernel it drops the last prediction column and is
# only correct when this cell is run a second time.
pred_cols = [f'pred_{tgt}_{j}' for tgt in CLASSES_CROP for j in range(3)]
df_mode['pred'] = df_mode.apply(
    lambda r: r[pred_cols].values.reshape(-1, 3), axis=1
)

# Drop a `pred` column left by a previous run so the merge cannot create pred_x / pred_y.
df_gt = df_gt.drop(columns='pred', errors='ignore').merge(
    df_mode[['study_series', 'pred', 'level']], on=["study_series", "level"], how="left"
)

In [None]:
# for i, row in df_gt.iterrows():
#     print(row.target)
#     print(row.pred)
#     print(
#         log_loss(row.target, row.pred, labels=[0, 1, 2])
#     )
#     break

In [None]:
from sklearn.metrics import log_loss


def _row_log_loss(row):
    """Log loss between one row's `target` and `pred`, with labels fixed to 0, 1, 2."""
    return log_loss(row.target, row.pred, labels=[0, 1, 2])


df_gt['error'] = df_gt.apply(_row_log_loss, axis=1)

In [None]:
import seaborn as sns
# Histogram of the per-row log loss stored in the `error` column.
sns.histplot(df_gt['error'])

In [None]:
# Frequency of each value found in the first entry of `target`.
df_gt['target'].map(lambda tgt: tgt[0]).value_counts()

In [None]:
# Rows of one specific study, highest log loss first.
study_rows = df_gt['study_id'].eq(1879696087)
df_gt.loc[study_rows].sort_values(by='error', ascending=False)

In [None]:
# All ground-truth rows ranked from highest to lowest log loss.
df_gt.sort_values('error', ascending=False)

In [None]:
# Fraction of rows whose log loss is above 2.
df_gt['error'].gt(2).mean()

Done!