In [None]:
import os

# Cap the native thread pools BEFORE numpy / cv2 / torch are imported.
# OpenMP, OpenBLAS and MKL typically read these variables when the shared
# library is first loaded, so setting them after the imports (as this cell
# used to do) does not limit the pools already created in this process.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import json
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from rsna2024.utils import natural_sort
import pydicom
import torch
import random
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
import logging
# Raise the albumentations logger level before the package is imported below.
logging.getLogger('albumentations').setLevel(logging.WARNING)
from albumentations.pytorch import ToTensorV2

# Turn off OpenCV's own threading and its OpenCL backend
# (presumably to avoid oversubscription next to data-loading workers).
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

from rsna2024.runner import Runner
from rsna2024.preproc.generate_tiles import get_tile
from rsna2024 import model as module_model

# Project root. Defaults to the original workstation path; export RSNA2024_ROOT
# to run the notebook from another checkout without editing this cell.
root_dir = os.environ.get('RSNA2024_ROOT', '/media/latlab/MR/projects/kaggle-rsna-2024')
data_dir = os.path.join(root_dir, 'data', 'raw')

# Series metadata. The ID columns are read as strings, so any study_id used to
# filter this table must be a string as well.
df_series = pd.read_csv(
    os.path.join(data_dir, 'train_series_descriptions.csv'),
    dtype={'study_id': 'str', 'series_id': 'str'},
)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def load_config(config_path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for that file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON files are UTF-8 by specification; do not depend on the locale's
    # default encoding, which differs between platforms.
    with open(config_path, encoding='utf-8') as f:
        return json.load(f)
    
def get_random_series_id(df_series, study_id, series_description):
    """Pick one series of the requested type for a study at random.

    Args:
        df_series: DataFrame with 'study_id', 'series_description' and
            'series_id' columns.
        study_id: Study whose series are considered.
        series_description: Series type to keep (compared for equality).

    Returns:
        One matching 'series_id' value drawn with the global ``random``
        module, or None when the study has no series of that type.
    """
    same_study = df_series['study_id'] == study_id
    same_type = df_series['series_description'] == series_description
    candidates = df_series.loc[same_study & same_type, 'series_id'].tolist()

    if not candidates:
        return None

    (chosen,) = random.sample(candidates, 1)
    return chosen

In [None]:
# Baseline model
model_name = 'comfy-galaxy-426'
full_model_name = f'rsna-2024-{model_name}'
model_dir = os.path.join(root_dir, 'models', full_model_name)

baseline_cfg = load_config(os.path.join(model_dir, 'config.json'))

# Run the baseline; `data` is the frame that the later cells iterate over.
baseline_preds, baseline_ys, data = Runner(baseline_cfg, model_name=full_model_name).predict()

# Split dim 1 of the predictions into (3, -1); the 3 presumably corresponds to
# the three classes weighted in the loss cells — TODO confirm.
baseline_preds = torch.tensor(baseline_preds).to(device).unflatten(1, [3, -1])
baseline_ys = torch.tensor(baseline_ys).to(device)

# Untouched copy, kept so the before/after losses can be compared later.
orig_baseline_preds = baseline_preds.clone()

In [None]:
# Keypoint detection
model_name = 'eager-voice-350'
full_model_name = f'rsna-2024-{model_name}'
model_dir = os.path.join(root_dir, 'models', full_model_name)

kp_cfg = load_config(os.path.join(model_dir, 'config.json'))

# Same Runner.predict() call as for the baseline, with the keypoint config.
kp_preds, kp_ys, kp_data = Runner(kp_cfg, model_name=full_model_name).predict()

# kp_preds is later indexed by the row position of `data`, so both runs must
# have produced exactly the same frame.
assert data.equals(kp_data)

In [None]:
# Spinal canal stenosis (Sagittal T2/STIR) tile prediction
model_name = 'honest-thunder-425'
full_model_name = 'rsna-2024-' + model_name
tiles_sagt2_model_dir = os.path.join(root_dir, 'models', full_model_name)
tiles_sagt2_cfg = load_config(os.path.join(tiles_sagt2_model_dir, 'config.json'))

# Every '*_best.pt' checkpoint in the model directory, in sorted order.
# NOTE(review): the tile loop picks model_list[fold - 1], so this relies on the
# lexicographic order matching the fold order; that breaks with 10+ folds
# unless the file names are zero-padded — confirm the naming scheme.
tiles_sagt2_model_path_list = sorted(
    os.path.join(tiles_sagt2_model_dir, x)
    for x in os.listdir(tiles_sagt2_model_dir)
    if x.endswith('_best.pt')
)
if not tiles_sagt2_model_path_list:
    # Fail here with a clear message instead of an IndexError in the tile loop.
    raise FileNotFoundError(f"No '*_best.pt' checkpoints found in {tiles_sagt2_model_dir}")

# Model class named in the config, looked up once for all checkpoints.
tiles_sagt2_model_cls = getattr(module_model, tiles_sagt2_cfg['model']['type'])

model_list = []
for path in tiles_sagt2_model_path_list:
    model = tiles_sagt2_model_cls(**tiles_sagt2_cfg['model']['args'])
    # map_location lets checkpoints saved from a GPU load on the CPU fallback
    # device too; without it torch.load tries to restore onto the saved device.
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()  # inference mode: fixed batch-norm statistics, no dropout
    model_list.append(model)

In [None]:
# Loop through images
# Seed the global RNG so the series drawn for studies with several
# Sagittal T2/STIR series (and therefore the reported losses) are reproducible.
random.seed(0)

# Start from the untouched baseline so that re-running this cell is idempotent.
baseline_preds = orig_baseline_preds.clone()
transform = ToTensorV2()

# Positions of the spinal-canal-stenosis outputs within the baseline outputs.
out_vars = pd.Series(baseline_cfg['out_vars'])
sagt2_idx = out_vars[out_vars.str.startswith('spinal_canal_stenosis')].index.values

# Loop-invariant lookups, hoisted out of the per-row / per-level loops.
tile_args = tiles_sagt2_cfg['dataset']['args']
img_dir = os.path.join(data_dir, 'train_images')

n_skipped = 0
for i, row in tqdm(enumerate(data.itertuples()), total=len(data)):
    # NOTE(review): df_series stores study_id as str; this lookup only matches
    # if row.study_id is a str as well — confirm the dtype of `data`.
    study_id = row.study_id
    series_id = get_random_series_id(df_series, study_id, 'Sagittal T2/STIR')

    if series_id is None:
        # No series of this type: the baseline prediction is kept for this row.
        n_skipped += 1
        continue

    # Model of the row's fold (folds presumably 1-based — TODO confirm).
    model = model_list[row.fold - 1]

    preds = []
    for level in range(5):
        # Keypoint = argmax of the level's heatmap, normalised by the heatmap size.
        heatmap = kp_preds[i][level]
        y_coord, x_coord = np.unravel_index(heatmap.argmax(), heatmap.shape)
        x_coord = x_coord / heatmap.shape[1]
        y_coord = y_coord / heatmap.shape[0]

        x = get_tile(
            study_id,
            series_id,
            x_coord,
            y_coord,
            img_dir=img_dir,
            img_num=tile_args['img_num'],
            prop=tile_args['proportion'],
            resolution=tile_args['resolution'],
            norm_coords=True,
        )
        x = transform(image=x)['image']
        x = x.unsqueeze(0)  # add the batch dimension

        with torch.no_grad():
            pred = model(x.to(device))
            preds.append(pred)

    # Stack the five per-level outputs and swap the first two axes so the result
    # lines up with baseline_preds[i][:, sagt2_idx]
    # (presumably (5, 3) -> (3, 5) — TODO confirm the model output shape).
    preds = torch.concat(preds).swapaxes(0, 1)

    # Overwrite baseline prediction
    baseline_preds[i][:, sagt2_idx] = preds

print(f'Rows without a Sagittal T2/STIR series (baseline kept): {n_skipped}/{len(data)}')

In [None]:
# Class-weighted cross-entropy (weights 1 / 2 / 4) over all outputs,
# before and after the overwrite of the baseline predictions.
criterion = CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 4.0]).to(device))
orig_baseline_loss = criterion(orig_baseline_preds, baseline_ys).item()
baseline_loss = criterion(baseline_preds, baseline_ys).item()

abs_gain = orig_baseline_loss - baseline_loss
rel_gain = 100 * abs_gain / orig_baseline_loss
print(f'Baseline loss: {orig_baseline_loss:.4f} -> {baseline_loss:.4f}')
print(f'Improvement: {abs_gain:.4f}, {rel_gain:.2f}%')

In [None]:
# Same class-weighted cross-entropy, restricted to the spinal-canal-stenosis
# outputs selected by sagt2_idx.
sagt2_criterion = CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 4.0]).to(device))
# .item() turns the 0-dim tensors into plain floats, consistent with the
# overall-loss cell, so nothing on the device is kept alive by these names.
sagt2_orig_baseline_loss = sagt2_criterion(
    orig_baseline_preds[..., sagt2_idx], baseline_ys[..., sagt2_idx]
).item()
sagt2_baseline_loss = sagt2_criterion(
    baseline_preds[..., sagt2_idx], baseline_ys[..., sagt2_idx]
).item()
print(f'Sagittal T2/STIR loss: {sagt2_orig_baseline_loss:.4f} -> {sagt2_baseline_loss:.4f}')
print(f'Improvement: {sagt2_orig_baseline_loss - sagt2_baseline_loss:.4f}, {100 * (sagt2_orig_baseline_loss - sagt2_baseline_loss) / sagt2_orig_baseline_loss:.2f}%')