In [1]:
import os
import json
import numpy as np
import pandas as pd
import torch
import random
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
import logging
import cv2

from scipy.optimize import minimize
from tqdm import tqdm
from multiprocessing import Pool

from rsna2024.runner import Runner
from rsna2024.utils import rsna_lumbar_metric

# Restrict OpenCV and the common numeric backends to a single thread each.
# NOTE(review): these environment variables are assigned after numpy/torch
# have already been imported above; some backends only read them at load
# time, so they may not affect this process — confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# Silence albumentations messages below WARNING.
logging.getLogger('albumentations').setLevel(logging.WARNING)

# CSV with the predicted label coordinates; the three numbers are part of the
# file name (presumably run identifiers of the models that produced it —
# TODO confirm).
coord_filename = 'train_label_coordinates_predicted_v2_{}_{}_{}.csv'.format(593, 654, 603)

# Project root. The original absolute path stays the default so existing
# setups keep working; set RSNA2024_ROOT_DIR to run the notebook elsewhere.
root_dir = os.environ.get('RSNA2024_ROOT_DIR', '/media/latlab/MR/projects/kaggle-rsna-2024')
data_dir = os.path.join(root_dir, 'data', 'raw')
img_dir = os.path.join(data_dir, 'train_images')

levels = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
sides = ['left', 'right']

# Read the identifier columns as strings so they are not parsed as integers.
coord_df = pd.read_csv(
    os.path.join(root_dir, 'data', 'processed', coord_filename),
    dtype={'study_id': 'str', 'series_id': 'str'},
)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def load_config(config_path):
    """Load a JSON configuration file and return the parsed object.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON content (whatever ``json.load`` yields for the file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(config_path, encoding='utf-8') as f:
        return json.load(f)


def get_metric(y_true, y_pred):
    """Score predictions with the project's ``rsna_lumbar_metric`` module.

    Args:
        y_true: Label tensor. Entries equal to -100 are replaced by 0 on a
            copy before scoring (presumably -100 marks missing labels —
            TODO confirm); the caller's tensor is left untouched.
        y_pred: Predictions, forwarded unchanged to
            ``rsna_lumbar_metric.prepare_data``.

    Returns:
        Whatever ``rsna_lumbar_metric.score`` returns for the prepared data.
    """
    # Work on a copy so the sentinel replacement never leaks to the caller.
    labels = y_true.clone()
    sentinel_mask = labels == -100
    labels[sentinel_mask] = 0

    solution, submission = rsna_lumbar_metric.prepare_data(labels, y_pred)
    return rsna_lumbar_metric.score(
        solution=solution,
        submission=submission,
        row_id_column_name='row_id',
        any_severe_scalar=1.0,
    )


# Run names of the trained models, keyed by "<target>" or "<target>_<n>".
# Each key selects the sub-directory under <root_dir>/models that holds the
# model's config.json. The printed report further down labels the three
# groups as resnet, swin and convnext (in this order).
model_names = dict(
    # first group
    spinal='rsna-2024-giddy-monkey-1266',
    foraminal='rsna-2024-hardy-voice-1244',
    subarticular='rsna-2024-fiery-meadow-1254',
    **{'global': 'rsna-2024-dashing-spaceship-1252'},  # `global` is a keyword
    # second group
    spinal_2='rsna-2024-leafy-river-1268',
    foraminal_2='rsna-2024-snowy-oath-1251',
    subarticular_2='rsna-2024-hearty-spaceship-1256',
    global_2='rsna-2024-cool-frost-1378',
    # third group
    spinal_3='rsna-2024-splendid-glade-1421',
    foraminal_3='rsna-2024-blooming-gorge-1250',
    subarticular_3='rsna-2024-smooth-resonance-1422',
    global_3='rsna-2024-radiant-tree-1423',
)

## ROI Models
### Spinal

In [None]:
# Collect the predictions of the three spinal models, one tensor per model.
spinal_preds_list = []
for model_type in ('spinal', 'spinal_2', 'spinal_3'):
    model_name = model_names[model_type]
    config_path = os.path.join(root_dir, 'models', model_name, 'config.json')
    cfg = load_config(config_path)
    spinal_preds, spinal_ys, spinal_data = Runner(cfg, model_name=model_name).predict(
        df_coordinates=coord_df
    )
    # View the flat output as (5, 5, n, 3), move axis 1 to the end and merge
    # the leading axes: result is (5 * n, 3, 5).
    # NOTE(review): the two leading 5s are presumably folds and levels — confirm.
    spinal_preds = np.transpose(spinal_preds.reshape(5, 5, -1, 3), (0, 2, 3, 1))
    spinal_preds = spinal_preds.reshape(-1, 3, 5)
    # Same re-ordering for the labels: (5, 5, n) -> (5 * n, 5).
    spinal_ys = np.transpose(spinal_ys.reshape(5, 5, -1), (0, 2, 1)).reshape(-1, 5)
    spinal_preds_list.append(torch.tensor(spinal_preds).to(device))
# Only the labels returned by the last model are kept.
spinal_ys = torch.tensor(spinal_ys).to(device)

### Foraminal

In [None]:
# Collect the predictions of the three foraminal models, one tensor per model.
foraminal_preds_list = []
for model_type in ('foraminal', 'foraminal_2', 'foraminal_3'):
    model_name = model_names[model_type]
    config_path = os.path.join(root_dir, 'models', model_name, 'config.json')
    cfg = load_config(config_path)
    foraminal_preds, foraminal_ys, foraminal_data = Runner(cfg, model_name=model_name).predict(
        df_coordinates=coord_df
    )
    # View as (5, 5, n, 2, 3) and permute to (5, n, 3, 2, 5); the trailing
    # (2, 5) pair is then flattened into 10 columns: result is (5 * n, 3, 10).
    # NOTE(review): the size-2 axis is presumably left/right — confirm.
    foraminal_preds = np.transpose(
        foraminal_preds.reshape(5, 5, -1, 2, 3), (0, 2, 4, 3, 1)
    ).reshape(-1, 3, 10)
    # Labels: (5, 5, n, 2) -> (5, n, 2, 5) -> (5 * n, 10).
    foraminal_ys = np.transpose(
        foraminal_ys.reshape(5, 5, -1, 2), (0, 2, 3, 1)
    ).reshape(-1, 10)
    foraminal_preds_list.append(torch.tensor(foraminal_preds).to(device))
# Only the labels returned by the last model are kept.
foraminal_ys = torch.tensor(foraminal_ys).to(device)

### Subarticular

In [None]:
# Collect the predictions of the three subarticular models, one tensor per model.
subarticular_preds_list = []
for model_type in ('subarticular', 'subarticular_2', 'subarticular_3'):
    model_name = model_names[model_type]
    config_path = os.path.join(root_dir, 'models', model_name, 'config.json')
    cfg = load_config(config_path)
    subarticular_preds, subarticular_ys, subarticular_data = Runner(
        cfg, model_name=model_name
    ).predict(df_coordinates=coord_df)
    # View as (5, 5, n, 2, 3) and permute to (5, n, 3, 2, 5); the trailing
    # (2, 5) pair is then flattened into 10 columns: result is (5 * n, 3, 10).
    # NOTE(review): the size-2 axis is presumably left/right — confirm.
    subarticular_preds = np.transpose(
        subarticular_preds.reshape(5, 5, -1, 2, 3), (0, 2, 4, 3, 1)
    ).reshape(-1, 3, 10)
    # Labels: (5, 5, n, 2) -> (5, n, 2, 5) -> (5 * n, 10).
    subarticular_ys = np.transpose(
        subarticular_ys.reshape(5, 5, -1, 2), (0, 2, 3, 1)
    ).reshape(-1, 10)
    subarticular_preds_list.append(torch.tensor(subarticular_preds).to(device))
# Only the labels returned by the last model are kept.
subarticular_ys = torch.tensor(subarticular_ys).to(device)

### Global improvement

In [5]:
# For each model index, join the spinal (5), foraminal (10) and subarticular
# (10) prediction columns along the last axis into one 25-column tensor.
preds_list = []
for model_idx in range(len(spinal_preds_list)):
    per_condition = [
        spinal_preds_list[model_idx],
        foraminal_preds_list[model_idx],
        subarticular_preds_list[model_idx],
    ]
    preds_list.append(torch.concatenate(per_condition, axis=-1))

# Labels are joined in the same column order as the predictions.
ys = torch.concatenate([spinal_ys, foraminal_ys, subarticular_ys], axis=-1)

## Global ROI model

In [None]:
# Collect the predictions of the three global models, one tensor per model.
global_preds_list = []
for model_type in ('global', 'global_2', 'global_3'):
    model_name = model_names[model_type]
    config_path = os.path.join(root_dir, 'models', model_name, 'config.json')
    cfg = load_config(config_path)
    preds_global, ys_global, data_global = Runner(cfg, model_name=model_name).predict(
        df_coordinates=coord_df
    )
    # View as (5, 5, n, 3, 5) and move axis 1 to the end -> (5, n, 3, 5, 5);
    # the trailing (5, 5) pair is flattened into 25 columns: (5 * n, 3, 25).
    preds_global = np.transpose(
        preds_global.reshape(5, 5, -1, 3, 5), (0, 2, 3, 4, 1)
    ).reshape(-1, 3, 25)
    # Labels: (5, 5, n, 5) -> (5, n, 5, 5) -> (5 * n, 25).
    ys_global = np.transpose(ys_global.reshape(5, 5, -1, 5), (0, 2, 3, 1)).reshape(-1, 25)
    global_preds_list.append(torch.tensor(preds_global).to(device))
# Only the labels returned by the last model are kept.
ys_global = torch.tensor(ys_global).to(device)

### Ensemble

In [None]:
# Sanity check: the split models and the global models are all scored against
# `ys` below, so the labels assembled from the split models must be identical
# (same shape and values) to the labels returned by the global models.
assert torch.equal(ys, ys_global)


def get_loss(ys, preds):
    """Return the class-weighted cross-entropy of ``preds`` against ``ys``.

    Args:
        ys: Target tensor passed to ``CrossEntropyLoss`` as the target.
        preds: Prediction tensor passed to ``CrossEntropyLoss`` as the input.

    Returns:
        The loss as a Python float, using class weights 1, 2 and 4 for the
        three classes.
    """
    class_weights = torch.tensor([1.0, 2.0, 4.0]).to(device)
    criterion = CrossEntropyLoss(weight=class_weights)
    loss = criterion(preds, ys)
    return loss.item()


def select_list_indices(lst, indices):
    """Return a new list with the elements of ``lst`` at ``indices``.

    Args:
        lst: Indexable sequence to pick from.
        indices: Iterable of positions; order and repetitions are preserved.

    Returns:
        A list ``[lst[i0], lst[i1], ...]`` in the order given by ``indices``.

    Raises:
        IndexError: If any position is out of range for ``lst``.
    """
    picked = []
    for position in indices:
        picked.append(lst[position])
    return picked


# Report section. Model order in every list is resnet, swin, convnext
# (index 0, 1, 2), as stated by the printed headers.

# Weighted cross-entropy of each split model on its own target group.
print('split models: resnet, swin, convnext')
print('\nspinal')
for model_preds in spinal_preds_list:
    print(get_loss(spinal_ys, model_preds))
print('\nforaminal')
for model_preds in foraminal_preds_list:
    print(get_loss(foraminal_ys, model_preds))
print('\nsubarticular')
for model_preds in subarticular_preds_list:
    print(get_loss(subarticular_ys, model_preds))

# Competition metric of every single model on all 25 columns.
print('\nall (combined) models: split resnet, swin, convnext and global resnet, swin, convnext')
for model_preds in preds_list:
    print(get_metric(ys, model_preds))
for model_preds in global_preds_list:
    print(get_metric(ys, model_preds))

# Unweighted averages. Each group prints three numbers:
# split models only, global models only, split + global together.
print('\nresnet + swin + convnext')
print(get_metric(ys, sum(preds_list) / 3))
print(get_metric(ys, sum(global_preds_list) / 3))
print(get_metric(ys, (sum(preds_list) + sum(global_preds_list)) / 6))

print('\nresnet + swin')
print(get_metric(ys, sum(preds_list[:2]) / 2))
print(get_metric(ys, sum(global_preds_list[:2]) / 2))
print(get_metric(ys, (sum(preds_list[:2]) + sum(global_preds_list[:2])) / 4))

print('\nswin + convnext')
print(get_metric(ys, sum(preds_list[1:]) / 2))
print(get_metric(ys, sum(global_preds_list[1:]) / 2))
# Both families must use the [1:] slice (swin, convnext). The previous
# `global_preds_list[:1]` picked only the global resnet, so three tensors
# were summed but still divided by 4.
print(get_metric(ys, (sum(preds_list[1:]) + sum(global_preds_list[1:])) / 4))

print('\nresnet + convnext')
print(get_metric(ys, sum(select_list_indices(preds_list, [0, 2])) / 2))
print(get_metric(ys, sum(select_list_indices(global_preds_list, [0, 2])) / 2))
print(
    get_metric(
        ys,
        (
            sum(select_list_indices(preds_list, [0, 2]))
            + sum(select_list_indices(global_preds_list, [0, 2]))
        )
        / 4,
    )
)

### Optimize ensemble weighting

In [None]:
# Fix the seed so the random starting points below are reproducible.
np.random.seed(42)

# CPU copies of everything the objective needs.
preds_list_cpu = [model_preds.cpu() for model_preds in preds_list]
global_preds_list_cpu = [model_preds.cpu() for model_preds in global_preds_list]
ys_cpu = ys.cpu()

n_jobs = 36  # worker processes in the pool
w_n = 6  # number of ensemble weights (3 split + 3 global models)
N = 100  # number of random starting points

# Draw N random weight vectors and scale each one to sum to 1.
w0 = []
for _ in range(N):
    start = np.random.rand(w_n)
    w0.append(start / start.sum())


def objective(weights):
    """Return the metric of the weighted ensemble for the given weights.

    Args:
        weights: One weight per model, ordered as the split models in
            ``preds_list_cpu`` followed by the global models in
            ``global_preds_list_cpu``.

    Returns:
        ``get_metric`` evaluated on the weighted sum of all model predictions.
    """
    # Stack the per-model predictions on a new trailing axis so that a
    # matrix-vector product with the weights yields their weighted sum.
    stacked = torch.stack([*preds_list_cpu, *global_preds_list_cpu], dim=-1)
    weight_vector = torch.tensor(weights, dtype=torch.float).cpu()
    blended = stacked @ weight_vector
    return get_metric(ys_cpu, blended)


def minimize_loss(w0):
    """Minimise ``objective`` starting from the weight vector ``w0``.

    Every weight is bounded to [0, 1] and the weights are constrained to sum
    to 1. No ``method`` is passed, so scipy selects the solver itself.

    Args:
        w0: Initial weight vector of length ``w_n``.

    Returns:
        Tuple ``(best_value, best_weights)`` taken from the optimiser result.
    """
    box_bounds = [(0, 1) for _ in range(w_n)]
    sum_to_one = {'type': 'eq', 'fun': lambda w: w.sum() - 1}
    result = minimize(
        objective,
        w0,
        bounds=box_bounds,
        constraints=[sum_to_one],
        options={'maxiter': 10000},
    )
    return result.fun, result.x


# Run one optimisation per starting point in parallel. The pool is used as a
# context manager so its worker processes are shut down once `map` (which
# blocks until all results are in) has returned; previously the pool was
# never closed and its workers stayed alive for the kernel's lifetime.
with Pool(n_jobs) as pool:
    res = pool.map(minimize_loss, w0)

# Unzip [(loss, weights), ...] into two parallel tuples.
losses, weights = list(zip(*res))