In [59]:
import sys
import os
from pathlib import Path
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Polygon

sys.path.insert(0, '../data/SpineNet')
import spinenet
from spinenet import SpineNet, download_example_scan
from spinenet.io import load_dicoms_from_folder

spnt = SpineNet(device='cuda:0', verbose=True)



Loading Detection Model...
==> Loading model trained for 436 epochs...
Loading Appearance Model...
==> Loading model trained for 188 epochs...
Loading Context Model...
==> Loading model trained for 17 epochs...
Loading Grading Model...
==> Loading model trained for 2 epochs...


In [60]:
LEVELS = ["L1", "L2", "L3", "L4", "L5", "S1"]
COLORS = {
    "L1": "red",
    "L2": "blue",
    "L3": "green",
    "L4": "yellow",
    "L5": "white",
    "S1": "purple"
}

In [61]:
test_descs_path = "../data/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv"
test_images_path = "../data/rsna-2024-lumbar-spine-degenerative-classification/train_images"

In [62]:
test_descs = pd.read_csv(test_descs_path)
test_descs_filtered = test_descs[test_descs["series_description"] == "Sagittal T1"][:10]
test_descs

Unnamed: 0,study_id,series_id,series_description
0,4003253,702807833,Sagittal T2/STIR
1,4003253,1054713880,Sagittal T1
2,4003253,2448190387,Axial T2
3,4646740,3201256954,Axial T2
4,4646740,3486248476,Sagittal T1
...,...,...,...
6289,4287160193,1507070277,Sagittal T2/STIR
6290,4287160193,1820446240,Axial T2
6291,4290709089,3274612423,Sagittal T2/STIR
6292,4290709089,3390218084,Axial T2


In [63]:
def calculate_centers(data):
    centers = {}
    for item in data:
        level = item["predicted_label"]
        if level in LEVELS:
            average_polygon = item["average_polygon"]
            centroid_x = np.mean(average_polygon[:, 0])
            centroid_y = np.mean(average_polygon[:, 1])
            centroid_z = item["slice_nos"][len(item["slice_nos"])//2]
            centers[level] = (centroid_x, centroid_y, centroid_z)
    return centers

In [64]:
centers_per_study = {
    "study_id": [],
    "series_id": [],
    "x": [],
    "y": [],
    "instance_number": [],
    "level": []
}

In [65]:
for index, row in test_descs_filtered.iterrows():
    scan = load_dicoms_from_folder(f"{test_images_path}/{row['study_id']}/{row['series_id']}", require_extensions=False)
    num_slices = scan.volume.shape[-1]

    vert_dicts = spnt.detect_vb(scan.volume, scan.pixel_spacing)
    centers = calculate_centers(vert_dicts)
    
    for level in centers:
        centers_per_study["study_id"].append(row['study_id'])
        centers_per_study["series_id"].append(row['series_id'])
        centers_per_study["level"].append(level)
        
        centers_per_study["x"].append(centers[level][0])
        centers_per_study["y"].append(centers[level][1])
        centers_per_study["instance_number"].append(centers[level][2])

In [66]:
centers_per_study = pd.DataFrame.from_dict(centers_per_study)
centers_per_study

Unnamed: 0,study_id,series_id,x,y,instance_number,level
0,4003253,1054713880,171.570871,116.577009,8,L1
1,4003253,1054713880,169.672123,157.646825,8,L2
2,4003253,1054713880,166.796317,202.517857,8,L3
3,4003253,1054713880,168.193304,246.345536,8,L4
4,4003253,1054713880,180.032143,285.492411,7,L5
5,4003253,1054713880,205.034226,320.785714,7,S1
6,4646740,3486248476,202.648065,121.860398,11,L1
7,4646740,3486248476,191.529539,183.674256,9,L2
8,4646740,3486248476,186.141807,246.120076,8,L3
9,4646740,3486248476,187.07115,305.788156,9,L4


# Stage 1.5: Dump out bounding boxes csv

In [67]:
def convert_coords_to_patient(x, y, dicom_slice):            
    dX, dY = dicom_slice.PixelSpacing
    
    X = np.array(list(dicom_slice.ImageOrientationPatient[:3]) + [0]) * dY
    Y = np.array(list(dicom_slice.ImageOrientationPatient[3:]) + [0]) * dX

    S = np.array(list(dicom_slice.ImagePositionPatient) + [1])

    transform_matrix = np.array([Y, X, np.zeros(len(X)), S]).T
    # transform_matrix = transform_matrix @ transform_matrix_factor

    return (transform_matrix @ np.array([y, x, 0, 1]).T)

In [68]:
from pydicom import dcmread

test_images_basepath = "../data/rsna-2024-lumbar-spine-degenerative-classification/train_images"

patient_coords_dict = {
    "study_id": [],
    "level": [],
    "x": [],
    "y": [],
    "z": []
}

for index, group in centers_per_study.groupby("study_id"):
    for row_index, row in group.iterrows():
        dicom_slice_path = f"{test_images_basepath}/{row['study_id']}/{row['series_id']}/{row['instance_number']}.dcm"
        dicom_slice = dcmread(dicom_slice_path)
        coords = convert_coords_to_patient(row['x'], row['y'], dicom_slice)
        
        patient_coords_dict["study_id"].append(row['study_id'])
        patient_coords_dict["level"].append(row['level'])
        patient_coords_dict["x"].append(coords[0])
        patient_coords_dict["y"].append(coords[1])
        patient_coords_dict["z"].append(coords[2])
    
patient_coords = pd.DataFrame.from_dict(patient_coords_dict)
patient_coords

Unnamed: 0,study_id,level,x,y,z
0,4003253,L1,2.192236,57.39329,-373.897314
1,4003253,L2,2.417129,55.910159,-405.982332
2,4003253,L3,2.663579,53.663727,-441.036977
3,4003253,L4,2.900401,54.755404,-475.276522
4,4003253,L5,-1.697765,63.999043,-505.892668
5,4003253,S1,-1.529187,83.532135,-533.464894
6,4646740,L1,2.94656,46.128942,95.62545
7,4646740,L2,9.54656,40.627584,65.040448
8,4646740,L3,12.8466,37.961777,34.142755
9,4646740,L4,9.54656,38.421608,4.619467


In [69]:
#patient_coords.to_csv('/kaggle/working/coords_3d.csv', index=False)

In [70]:
patient_bounding_boxes_dict = {
    "study_id": [],
    "level": [],
    "x_min": [],
    "y_min": [],
    "z_min": [],
    "x_max": [],
    "y_max": [],
    "z_max": [],
}

for index, group in patient_coords.groupby("study_id"):
    ordered_group = group.sort_values(by="level", ascending=True)
    if len(ordered_group) != 6:
        continue
    for level_index in range(5):
        patient_bounding_boxes_dict["study_id"].append(ordered_group['study_id'].iloc[0])
        level_label = ordered_group['level'].iloc[level_index].lower() + "_" + ordered_group['level'].iloc[level_index + 1] 
        patient_bounding_boxes_dict["level"].append(level_label)
        
        # Middle vertebra points
        pt_0 = np.array(ordered_group.iloc[level_index][["x", "y", "z"]])
        pt_1 = np.array(ordered_group.iloc[level_index + 1][["x", "y", "z"]])
        
        # Distance vector to the next vertebra
        d_vec = np.array(pt_0 - pt_1)
        d_size = np.linalg.norm(d_vec)
        d_unit = d_vec / d_size
        
        
        # Get a pair of orthogonal vectors to find x and y boundary candidates
        orth_1 = np.random.randn(3).astype(np.float64)
        orth_1 = orth_1 - orth_1.dot(d_unit) * d_unit
        orth_1 = orth_1 / np.linalg.norm(orth_1)
        
        orth_1 = orth_1.astype(np.float64)
        d_unit = d_unit.astype(np.float64)
        
        orth_2 = np.cross(orth_1, d_unit)
        orth_2 = orth_2.astype(np.float64)
        
        orth_1 *= d_size
        orth_2 *= d_size
        
        # Get candidate points (10 of them, 2 per orthogonal per each vertebra center, and the centers themselves)
        c_pts = np.array([pt - vec for pt in (pt_0, pt_1) for vec in (orth_1, orth_2)] + 
                         [pt + vec for pt in (pt_0, pt_1) for vec in (orth_1, orth_2)] +
                         [pt_0, pt_1])
        
        # x_min and x_max are just the min and max from all this
        x_min = np.min(c_pts[:, 0])
        x_max = np.max(c_pts[:, 0])
                
        # y_max is going to be over the center ys
        # And we're going to get y_min by getting y_min over c_pts and then extending the y_min over center ys
        c_pts_y_min = np.min(c_pts[:, 1])
        c_pts_y_max = np.max(c_pts[:, 1])

        y_max = max(pt_0[1], pt_1[1])
        y_min = min(pt_0[1], pt_1[1])
        
        y_max += abs(c_pts_y_max - y_max) * 2
        y_min -= abs(c_pts_y_min - y_min) / 2
        
        # z_max and z_min will be the same as x_min and x_max
        z_min = np.min(c_pts[:, 2])
        z_max = np.max(c_pts[:, 2])    
        
        patient_bounding_boxes_dict["x_min"].append(x_min)
        patient_bounding_boxes_dict["y_min"].append(y_min)
        patient_bounding_boxes_dict["z_min"].append(z_min)
        patient_bounding_boxes_dict["x_max"].append(x_max)
        patient_bounding_boxes_dict["y_max"].append(y_max)
        patient_bounding_boxes_dict["z_max"].append(z_max)

patient_bounding_boxes = pd.DataFrame.from_dict(patient_bounding_boxes_dict)
patient_bounding_boxes

Unnamed: 0,study_id,level,x_min,y_min,z_min,x_max,y_max,z_max
0,4003253,l1_L2,-24.815985,42.422933,-407.351076,29.425349,111.342195,-372.52857
1,4003253,l2_L3,-32.479435,36.251634,-443.296855,37.560144,125.55853,-403.722454
2,4003253,l3_L4,-21.679099,41.495624,-476.212849,27.24308,103.427817,-440.10065
3,4003253,l4_L5,-32.732321,39.587189,-513.891909,33.934958,124.671902,-467.277282
4,4003253,l5_S1,-29.785708,52.566304,-549.778413,26.558756,129.26309,-489.57915
5,4646740,l1_L2,-22.164476,27.650596,56.478346,34.657596,98.03689,104.187552
6,4646740,l2_L3,-21.454963,22.427431,30.784835,43.848123,102.764965,68.398367
7,4646740,l3_L4,-19.332318,23.439551,1.488643,41.725478,96.510513,37.273579
8,4646740,l4_L5,-17.861117,25.132462,-29.11614,36.954237,98.386911,11.310892
9,4646740,l5_S1,-19.959952,34.196574,-61.727329,35.753072,103.906222,-7.909329


In [71]:
#patient_bounding_boxes.to_csv('/kaggle/working/bounding_boxes_3d.csv', index=False)

# Stage 2: Run inference on individual vertebrae

## Data Loading

In [72]:
import os

def retrieve_image_paths(base_path, study_id, series_id):
    series_dir = os.path.join(base_path, str(study_id), str(series_id))
    images = os.listdir(series_dir)
    image_paths = [os.path.join(series_dir, img) for img in images]
    return image_paths

In [73]:
import open3d as o3d
import pgzip
import os

def read_vertebral_levels_as_voxel_grids(dir_path,
                                         vertebral_levels: list[str],
                                         max_bounds: list[np.array],
                                         min_bounds: list[np.array],
                                         pcd_overall: o3d.geometry.PointCloud = None,
                                         cache_basepath="/kaggle/working/cached_3d",
                                         series_type_dict=None,
                                         downsampling_factor=1,
                                         voxel_size=(128, 128, 42),
                                        caching=False):
    ret = {}
    
    os.makedirs(cache_basepath, exist_ok=True)
    
    if pcd_overall is None:
        pcd_overall = read_study_as_pcd(dir_path,
                                        series_types_dict=series_type_dict,
                                        downsampling_factor=downsampling_factor,
                                        img_size=(voxel_size[0], voxel_size[2]),
                                        stack_slices_thickness=True,
                                        resize_slices=False)

    
    for index, vertebral_level in enumerate(vertebral_levels):
        cache_path = os.path.join(cache_basepath, f"cached_grid_{vertebral_level}_{voxel_size[0]}_{voxel_size[1]}_{voxel_size[2]}.npy.gz")
        f = None
        if caching and os.path.exists(cache_path):
            try:
                f = pgzip.PgzipFile(cache_path, "r")
                ret[vertebral_level] = np.load(f, allow_pickle=True)
                f.close()
            except Exception as e:
                print(dir_path, "\n", e)
                if f:
                    f.close()
                os.remove(cache_path)

        else:
            bbox = o3d.geometry.AxisAlignedBoundingBox(min_bound=min_bounds[index], max_bound=max_bounds[index])
            pcd_level = pcd_overall.crop(bbox)

            box = pcd_level.get_axis_aligned_bounding_box()

            max_b = np.array(box.get_max_bound())
            min_b = np.array(box.get_min_bound())

            pts = (np.array(pcd_level.points) - (min_b)) * (
                    (voxel_size[0] - 1, voxel_size[1] - 1, voxel_size[2] - 1) / (max_b - min_b))
            coords = np.round(pts).astype(np.int32)
            vals = np.array(pcd_level.colors, dtype=np.float16)

            grid = np.zeros((3, voxel_size[0], voxel_size[1], voxel_size[2]), dtype=np.float16)
            indices = coords[:, 0], coords[:, 1], coords[:, 2]

            np.maximum.at(grid[0], indices, vals[:, 0])
            np.maximum.at(grid[1], indices, vals[:, 1])
            np.maximum.at(grid[2], indices, vals[:, 2])
            
            if caching:
                f = pgzip.PgzipFile(cache_path, "w")
                np.save(f, grid)
                f.close()

            ret[vertebral_level] = grid

    return ret


In [74]:
def read_study_as_pcd(dir_path,
                      series_types_dict=None,
                      downsampling_factor=1,
                      resize_slices=True,
                      resize_method="nearest",
                      stack_slices_thickness=True,
                      img_size=(256, 256)):
    pcd_overall = o3d.geometry.PointCloud()

    for path in glob.glob(os.path.join(dir_path, "**/*.dcm"), recursive=True):
        dicom_slice = dcmread(path)

        series_id = os.path.basename(os.path.dirname(path))
        study_id = os.path.basename(os.path.dirname(os.path.dirname(path)))
        if series_types_dict is None or int(series_id) not in series_types_dict:
            series_desc = dicom_slice.SeriesDescription
        else:
            series_desc = series_types_dict[int(series_id)]
            series_desc = series_desc.split(" ")[-1]

        x_orig, y_orig = dicom_slice.pixel_array.shape
        if resize_slices:
            if resize_method == "nearest":
                img = np.expand_dims(cv2.resize(dicom_slice.pixel_array, img_size, interpolation=cv2.INTER_AREA), -1)
            elif resize_method == "maxpool":
                img_tensor = torch.tensor(dicom_slice.pixel_array).float()
                img = F.adaptive_max_pool2d(img_tensor.unsqueeze(0), img_size).numpy()
            else:
                raise ValueError(f"Invalid resize_method {resize_method}")
        else:
            img = np.expand_dims(np.array(dicom_slice.pixel_array), -1)
        x, y, z = np.where(img)

        downsampling_factor_iter = max(downsampling_factor, int(math.ceil(len(x) / 6e6)))

        index_voxel = np.vstack((x, y, z))[:, ::downsampling_factor_iter]
        grid_index_array = index_voxel.T
        pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(grid_index_array.astype(np.float64)))

        vals = np.expand_dims(img[x, y, z][::downsampling_factor_iter], -1)
        if series_desc == "T1":
            vals = np.pad(vals, ((0, 0), (0, 2)))
        elif series_desc == "T2":
            vals = np.pad(vals, ((0, 0), (1, 1)))
        elif series_desc == "T2/STIR":
            vals = np.pad(vals, ((0, 0), (2, 0)))
        else:
            raise ValueError(f"Unknown series desc: {series_desc}")

        pcd.colors = o3d.utility.Vector3dVector(vals.astype(np.float64))

        if resize_slices:
            transform_matrix_factor = np.matrix(
                [[0, y_orig / img_size[1], 0, 0],
                 [x_orig / img_size[0], 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]]
            )
        else:
            transform_matrix_factor = np.matrix(
                [[0, 1, 0, 0],
                 [1, 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]]
            )

        dX, dY = dicom_slice.PixelSpacing
        dZ = dicom_slice.SliceThickness

        X = np.array(list(dicom_slice.ImageOrientationPatient[:3]) + [0]) * dX
        Y = np.array(list(dicom_slice.ImageOrientationPatient[3:]) + [0]) * dY

        S = np.array(list(dicom_slice.ImagePositionPatient) + [1])

        transform_matrix = np.array([X, Y, np.zeros(len(X)), S]).T
        transform_matrix = transform_matrix @ transform_matrix_factor

        if stack_slices_thickness:
            for z in range(int(dZ)):
                pos = list(dicom_slice.ImagePositionPatient)
                if series_desc == "T2":
                    pos[-1] += z
                else:
                    pos[0] += z
                S = np.array(pos + [1])

                transform_matrix = np.array([X, Y, np.zeros(len(X)), S]).T
                transform_matrix = transform_matrix @ transform_matrix_factor

                pcd_overall += copy.deepcopy(pcd).transform(transform_matrix)

        else:
            pcd_overall += copy.deepcopy(pcd).transform(transform_matrix)

    return pcd_overall


In [75]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchio as tio
import torch.nn as nn
import pydicom

CONDITIONS = {
    "Sagittal T2/STIR": ["Spinal Canal Stenosis"],
    "Axial T2": ["Left Subarticular Stenosis", "Right Subarticular Stenosis"],
    "Sagittal T1": ["Left Neural Foraminal Narrowing", "Right Neural Foraminal Narrowing"],
}

LEVELS = ["l1_l2", "l2_l3", "l3_l4", "l4_l5", "l5_s1"]

class StudyPerVertebraLevelDataset(Dataset):
    def __init__(self,
                 base_path: str,
                 dataframe: pd.DataFrame,
                 bounds_dataframe: pd.DataFrame,
                 transform_3d=None,
                 vol_size=(128, 128, 128)
                ):
        self.base_path = base_path

        self.dataframe = (dataframe[['study_id', "series_id", "series_description"]]
                          .drop_duplicates())
        self.bounds_dataframe = bounds_dataframe

        self.subjects = self.dataframe[['study_id']].drop_duplicates().reset_index(drop=True)
        self.series = self.dataframe[["study_id", "series_id"]].drop_duplicates().groupby("study_id")[
            "series_id"].apply(list).to_dict()
        self.series_descs = {e[0]: e[1] for e in
                             self.dataframe[["series_id", "series_description"]].drop_duplicates().values}

        self.transform_3d = transform_3d
        self.vol_size = vol_size

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, index):
        curr = self.subjects.iloc[index % len(self.subjects)]

        study_path = os.path.join(self.base_path, str(curr["study_id"]))
        
        curr_bounds = self.bounds_dataframe[self.bounds_dataframe["study_id"] == curr["study_id"]].sort_values(by="level")
        
        study_images = read_vertebral_levels_as_voxel_grids(study_path,
                                                      vertebral_levels=LEVELS,
                                                      min_bounds=np.array(curr_bounds[['x_min', 'y_min', 'z_min']].values),
                                                      max_bounds=np.array(curr_bounds[['x_max', 'y_max', 'z_max']].values),
                                                      series_type_dict=self.series_descs,
                                                      voxel_size=self.vol_size
                                                    )
        
        ret = []
        for level in LEVELS:
            image = study_images[level]
            image = torch.FloatTensor(image)
            image = self.transform_3d(image)
            ret.append(image.to(torch.half))
            
        return torch.stack(ret), curr["study_id"]

In [76]:
transform_3d = tio.Compose([
    tio.RescaleIntensity([0, 1]),
])

In [77]:
def create_subject_level_testset_and_loader(df: pd.DataFrame,
                                             transform_3d,
                                             base_path: str,
                                             batch_size=1,
                                             num_workers=0):
    testset = StudyPerVertebraLevelDataset(base_path=test_images_basepath, dataframe=df, bounds_dataframe=patient_bounding_boxes, transform_3d=transform_3d)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return testset, test_loader

In [78]:
test_descs_valid = test_descs[test_descs["study_id"].isin(patient_bounding_boxes["study_id"])]
testset, test_loader = create_subject_level_testset_and_loader(test_descs_valid, transform_3d, test_images_basepath)

## Model Loading

In [79]:
import timm_3d
import torch

from spacecutter.losses import CumulativeLinkLoss
from spacecutter.models import LogisticCumulativeLink
from spacecutter.callbacks import AscensionCallback

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

class Classifier3dMultihead(nn.Module):
    def __init__(self,
                 backbone="efficientnet_lite0",
                 in_chans=1,
                 out_classes=5,
                 cutpoint_margin=0,
                 pretrained=False):
        super(Classifier3dMultihead, self).__init__()
        self.out_classes = out_classes

        self.backbone = timm_3d.create_model(
            backbone,
            features_only=False,
            drop_rate=0,
            drop_path_rate=0,
            pretrained=pretrained,
            in_chans=in_chans,
            global_pool="max",
        )
        if "efficientnet" in backbone:
            head_in_dim = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Sequential(
                nn.LayerNorm(head_in_dim),
                nn.Dropout(0),
            )

        elif "vit" in backbone or "coat" in backbone:
            self.backbone.head.drop = nn.Dropout(0)
            head_in_dim = self.backbone.head.fc.in_features
            self.backbone.head.fc = nn.Identity()

        self.heads = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(head_in_dim, 1),
                LogisticCumulativeLink(3)
            ) for i in range(out_classes)]
        )

        self.ascension_callback = AscensionCallback(margin=cutpoint_margin)

    def forward(self, x):
        feat = self.backbone(x)
        return torch.swapaxes(torch.stack([head(feat) for head in self.heads]), 0, 1)

    def _ascension_callback(self):
        for head in self.heads:
            self.ascension_callback.clip(head[-1])

model = Classifier3dMultihead(backbone="coatnet_rmlp_3_rw_224", in_chans=3, out_classes=5).to(device)
model.load_state_dict(torch.load("../models/coatnet_rmlp_3_rw_224_128_vertebrae_tuned_fold_0/coatnet_rmlp_3_rw_224_128_vertebrae_tuned_fold_0_0.pt"))

<All keys matched successfully>

In [80]:
CONDITIONS = {
    "Sagittal T2/STIR": ["spinal_canal_stenosis"],
    "Axial T2": ["left_subarticular_stenosis", "right_subarticular_stenosis"],
    "Sagittal T1": ["left_neural_foraminal_narrowing", "right_neural_foraminal_narrowing"],
}

ALL_CONDITIONS = sorted(["spinal_canal_stenosis", "left_subarticular_stenosis", "right_subarticular_stenosis", "left_neural_foraminal_narrowing", "right_neural_foraminal_narrowing"])
LEVELS = ["l1_l2", "l2_l3", "l3_l4", "l4_l5", "l5_s1"]

results_df = pd.DataFrame({"row_id":[], "normal_mild": [], "moderate": [], "severe": []})

ALL_CONDITIONS


['left_neural_foraminal_narrowing',
 'left_subarticular_stenosis',
 'right_neural_foraminal_narrowing',
 'right_subarticular_stenosis',
 'spinal_canal_stenosis']

In [81]:
# Pre-populate results df
import glob
import os

study_ids = glob.glob("../data/rsna-2024-lumbar-spine-degenerative-classification/train_images/*")
study_ids = [os.path.basename(e) for e in study_ids]

results_df = pd.DataFrame({"row_id":[], "normal_mild": [], "moderate": [], "severe": []})
for study_id in study_ids:
    for condition in ALL_CONDITIONS:
        for level in LEVELS:
            row_id = f"{study_id}_{condition}_{level}"
            results_df = results_df._append({"row_id": row_id, "normal_mild": 1/3, "moderate": 1/3, "severe": 1/3}, ignore_index=True)


In [82]:
import torch
from torch.cuda.amp import autocast
import time
import copy

start_time = time.time()

visualize_mid_slices = False

with torch.no_grad():
    with autocast(dtype=torch.float16):
        model.eval()

        for images, study_id in test_loader:
            output = model(images.squeeze(0).to(device))
            for level_index, image in enumerate(images.squeeze(0)):
                if visualize_mid_slices:
                    plt.imshow(np.max(image.numpy()[1, :, :, 62:66], axis=2), cmap="gray")
                    plt.show()
                for condition_index, condition_out in enumerate(output[level_index]):
                    row_id = f"{study_id[0]}_{ALL_CONDITIONS[condition_index]}_{LEVELS[level_index]}"
                    
                    results_df.loc[results_df.row_id == row_id, 'normal_mild'] = condition_out.cpu().numpy()[0]
                    results_df.loc[results_df.row_id == row_id, 'moderate'] = condition_out.cpu().numpy()[1]
                    results_df.loc[results_df.row_id == row_id, 'severe'] = condition_out.cpu().numpy()[2]
                
print("--- %s seconds ---" % (time.time() - start_time))

--- 131.04566550254822 seconds ---


In [83]:
results_df

Unnamed: 0,row_id,normal_mild,moderate,severe
0,100206310_left_neural_foraminal_narrowing_l1_l2,0.333333,0.333333,0.333333
1,100206310_left_neural_foraminal_narrowing_l2_l3,0.333333,0.333333,0.333333
2,100206310_left_neural_foraminal_narrowing_l3_l4,0.333333,0.333333,0.333333
3,100206310_left_neural_foraminal_narrowing_l4_l5,0.333333,0.333333,0.333333
4,100206310_left_neural_foraminal_narrowing_l5_s1,0.333333,0.333333,0.333333
...,...,...,...,...
49370,998688940_spinal_canal_stenosis_l1_l2,0.333333,0.333333,0.333333
49371,998688940_spinal_canal_stenosis_l2_l3,0.333333,0.333333,0.333333
49372,998688940_spinal_canal_stenosis_l3_l4,0.333333,0.333333,0.333333
49373,998688940_spinal_canal_stenosis_l4_l5,0.333333,0.333333,0.333333


In [84]:
results_df.to_csv('submission.csv', index=False)