In [1]:

import os
import sys

# Put the parent directory first on the import path so the project package
# (`lib.*`, imported in the next cell) resolves.
sys.path.insert(0, "..")

# Debug aid: make CUDA kernel launches synchronous so errors surface at the
# offending call (this slows GPU work). It has to be set before torch starts
# CUDA, which is why it lives in the very first cell, ahead of `import torch`.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [2]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import timm
from tqdm import tqdm

import pickle

import cv2
import matplotlib.pyplot as plt

from dataclasses import dataclass

from lib.model import DetectionModel
from lib.dataloader import DetectionDataLoader, RSNA24DF
from lib.patientInfo import Img, Scan, PatientInfo

In [3]:


# ---- Paths ---------------------------------------------------------------
DATA_DIR = "../data"
OUTPUT_DIR = "../processed-data/test-data/"  # one pickle per study is written here

# ---- Model / device ------------------------------------------------------
MODEL_NAME = "tf_efficientnet_b5.ns_jft_in1k"
DEVICE = "cuda:0"
EPOCHS = 20  # not referenced by the inference cells in this notebook

# ---- Network input size (predictions are rescaled from this grid) --------
HEIGHT, WIDTH = 512, 512

# ---- Network params ------------------------------------------------------
N_CLASSES = 5  # one output per disc level
HIDDEN_DIM = 768
NORMALISE = 255.0

# ---- DICOM windowing (disabled) ------------------------------------------
WINDOW_WIDTH = WINDOW_CENTER = None
CHANGE_WINDOW = False



In [4]:
# Inference only needs the test split's series descriptions; RSNA24DF's first
# two arguments are left as None here (presumably the train label tables,
# which the test split does not have — TODO confirm against lib.dataloader).
test_series_descriptions_df = pd.read_csv(f"{DATA_DIR}/test_series_descriptions.csv")

rsna24df = RSNA24DF(
    None,
    None,
    test_series_descriptions_df,
    f"{DATA_DIR}/test_images",
)

In [5]:
class CanalSteneosisDataLoader(DetectionDataLoader):
    """DetectionDataLoader specialised to the Sagittal T2/STIR series.

    It decides which scans of a patient are used and how the disc-level labels
    are encoded as targets; everything else is delegated to the base class.
    (The class name keeps its original spelling because other cells use it.)
    """

    def __init__(self
                 , patient_ids
                 , positive_negative_ratio=0.5
                 , positive_augment_prob=0.25
                 , negative_augment_prob=0.15
                 , rsna24DF=rsna24df  # bound once, when the class is defined
                 , transformations=None
                 , height=HEIGHT
                 , width=WIDTH
                 , phase="train"
                 , window_width=WINDOW_WIDTH
                 , window_center=WINDOW_CENTER
                 , change_window=CHANGE_WINDOW
                 , normalise=NORMALISE
                 ) -> None:
        # A fresh list per instance: a literal `[]` default is created once and
        # would be shared (and any mutation seen) by every loader using it.
        if transformations is None:
            transformations = []

        super().__init__(patient_ids=patient_ids
                         , rsna24DF=rsna24DF
                         , transformations=transformations
                         , height=height
                         , width=width
                         , positive_negative_ratio=positive_negative_ratio
                         , positive_augment_prob=positive_augment_prob
                         , negative_augment_prob=negative_augment_prob
                         , phase=phase
                         , window_center=window_center
                         , window_width=window_width
                         , change_window=change_window
                         , normalise=normalise
                         )

    def _get_patient_scans(self, patient_info) -> list[Img]:
        """Return the patient's Sagittal T2/STIR images."""
        return patient_info.get_scans(Scan.SagittalT2_STIR)

    def _mk_target_array(self, x, labels) -> tuple[np.ndarray, np.ndarray]:
        """Encode labels as per-level presence flags and (x, y) coordinates.

        Args:
            x: the input image (not used by this encoding).
            labels: label objects exposing `.x`, `.y` and
                `.location.disc_level.to_int()`; the level index is assumed to
                lie in [0, N_CLASSES) — TODO confirm against lib.patientInfo.

        Returns:
            (y_class, y_loc): y_class has shape (N_CLASSES,) with 1.0 for every
            labelled level; y_loc has shape (N_CLASSES, 2) holding (x, y) for
            labelled levels and zeros elsewhere.
        """
        y_class, y_loc = np.zeros(N_CLASSES), np.zeros((N_CLASSES, 2))

        for label in labels:
            level = label.location.disc_level.to_int()
            y_class[level] = 1.0
            y_loc[level] = np.array([label.x, label.y])

        return y_class, y_loc

In [6]:
model = DetectionModel(MODEL_NAME, n_classes=N_CLASSES, coord_dim=2, hidden_dim=HIDDEN_DIM).to(DEVICE)

# map_location loads the tensors straight onto DEVICE, so a checkpoint saved
# from a different GPU index (or on CPU) still loads on this machine.
CHECKPOINT_PATH = "../models/sagittial_t2_stir/best_loc_model_fold-0.pt"
state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()  # disable dropout / use running batch-norm statistics for inference
print("Done loading weights")

Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


Done loading weights


In [10]:
# Study ids present in the test split. Computed from the frame rather than
# reading `patient_ids`, which is only defined in a later cell and would raise
# NameError on Restart & Run All.
test_series_descriptions_df["study_id"].unique()

array([44036939])

In [11]:
# (rows, columns) of the test series descriptions.
(len(test_series_descriptions_df), len(test_series_descriptions_df.columns))

(3, 3)

In [7]:
# One entry per test study; `ds[i]` corresponds to `patient_ids[i]`
# (get_disc_level_loc relies on this alignment). No augmentation at inference.
patient_ids = test_series_descriptions_df.study_id.unique()
ds = CanalSteneosisDataLoader(
    patient_ids,
    transformations=[],
    phase="pred",
)

In [8]:
@dataclass
class DiscLevelLocs:
    """Per-patient disc-level prediction produced by get_disc_level_loc.

    Instances are pickled to OUTPUT_DIR, so field names and order are kept
    stable.
    """
    # (N_CLASSES, 2) array of (x, y) per level, rescaled to the source DICOM's pixel grid
    disc_pixel_loc: np.ndarray
    # index (into the patient's scan list of `img_type`) of the slice chosen for each level
    img_idxs: np.ndarray
    # series that `img_idxs` refers to
    img_type: Scan
    # per-level z position derived from ImagePositionPatient[2] and PixelSpacing[0]
    disc_loc_mm: np.ndarray
    # study id; in practice an element of a numpy array, so possibly np.int64 rather than int
    patient_id: int
    

def get_disc_level_loc(patient_idx, threshold=0.4):
    """Locate every disc level for one patient on the Sagittal T2/STIR series.

    Runs the detection model on all slices of the patient. For each of the
    N_CLASSES levels it picks the slice with the highest predicted
    probability, then rescales that slice's predicted (x, y) from the network
    input grid (WIDTH x HEIGHT) back to the DICOM's native pixel grid. Finally
    it converts y into a patient-space z position.

    Args:
        patient_idx: index into the module-level `ds` dataset; the same index
            is used on `patient_ids` for the returned id.
        threshold: minimum acceptable best-slice probability. Levels whose best
            probability falls below it trigger a UserWarning; the returned
            locations are not filtered.

    Returns:
        DiscLevelLocs for the patient.
    """
    import warnings

    xs, imgs = ds[patient_idx]
    xs = torch.from_numpy(xs).to(DEVICE)
    num_imgs = xs.shape[0]

    with torch.no_grad():
        pred_class, pred_loc = model(xs)
        pred_class = pred_class.sigmoid().cpu()
        pred_loc = pred_loc.reshape(num_imgs, N_CLASSES, 2).cpu()

    # For each level: best probability over slices, and which slice it came from.
    best_prob, best_img_idx = pred_class.max(dim=0)
    # Advanced indexing returns a copy, so rescaling disc_locs in place below
    # leaves pred_loc untouched.
    disc_locs = pred_loc[best_img_idx, torch.arange(N_CLASSES), :]

    low_conf_levels = [lvl for lvl, p in enumerate(best_prob.tolist()) if p < threshold]
    if low_conf_levels:
        warnings.warn(
            f"patient {patient_ids[patient_idx]}: levels {low_conf_levels} have "
            f"best probability below {threshold}"
        )

    disc_loc_mm = []
    for level, img_idx in enumerate(best_img_idx.tolist()):
        dcm = imgs[img_idx].dicom
        h, w = dcm.pixel_array.shape

        # Network-input grid -> native pixel grid.
        disc_locs[level, 0] *= w / WIDTH   # x
        disc_locs[level, 1] *= h / HEIGHT  # y

        # z = slice origin z - row spacing * row index. This assumes rows run
        # towards -z (usual sagittal orientation) — TODO confirm via
        # ImageOrientationPatient. Computed in float64 Python floats rather than
        # float32 0-d tensors.
        disc_loc_mm.append(
            float(dcm.ImagePositionPatient[2])
            - float(dcm.PixelSpacing[0]) * float(disc_locs[level, 1])
        )

    return DiscLevelLocs(
        disc_pixel_loc=disc_locs.numpy(),
        disc_loc_mm=np.array(disc_loc_mm),
        img_idxs=best_img_idx.numpy(),
        img_type=Scan.SagittalT2_STIR,
        patient_id=patient_ids[patient_idx],
    )


In [9]:
# Run inference for every test study and write one pickle per study.
# open(..., "wb") does not create directories, so make sure OUTPUT_DIR exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for idx in tqdm(range(len(ds))):
    disc_level_locs = get_disc_level_loc(idx)

    out_path = os.path.join(OUTPUT_DIR, f"{disc_level_locs.patient_id}.pkl")
    with open(out_path, "wb") as f:
        pickle.dump(disc_level_locs, f)

  return F.conv2d(input, weight, bias, self.stride,
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.13it/s]


In [None]:
# Inspect the first study: per-level pixel locations, z positions and the
# index of the slice chosen for each level.
disc_level_locs = get_disc_level_loc(0)

(
    disc_level_locs.disc_pixel_loc,
    disc_level_locs.disc_loc_mm,
    disc_level_locs.img_idxs,
)

In [None]:
import pandas as pd

In [None]:
# Same CSV as test_series_descriptions_df, read via the configured DATA_DIR
# instead of a duplicated hard-coded path.
test_df = pd.read_csv(f"{DATA_DIR}/test_series_descriptions.csv")

In [None]:
# First rows of the test series descriptions.
test_df.head(n=5)

In [None]:
# Load the study whose predictions are in `disc_level_locs`, rather than a
# hard-coded id, so the overlay below always matches the predictions.
patient_info = PatientInfo.from_df(
    int(disc_level_locs.patient_id),
    None,
    test_series_descriptions_df,
    None,
    f"{DATA_DIR}/test_images/",
    patient_type="test",
)

In [None]:
# Draw the predicted disc-level centres on every slice chosen for at least one level.
for idx, img in enumerate(patient_info.get_scans(disc_level_locs.img_type)):
    level_mask = disc_level_locs.img_idxs == idx
    if not level_mask.any():
        continue

    slice_u8 = cv2.normalize(img.dicom.pixel_array, None, alpha=0, beta=255,
                             norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    # Go to 3 channels so the pink RGB colour is really drawn. On a
    # single-channel image cv2.circle uses only the first colour component.
    overlay = cv2.cvtColor(slice_u8, cv2.COLOR_GRAY2RGB)

    slice_locs = disc_level_locs.disc_pixel_loc[level_mask, :]
    print(idx, slice_locs)
    for loc in slice_locs:
        centre = (int(loc[0]), int(loc[1]))
        cv2.circle(overlay, centre, 10, (255, 192, 203), 2)

    plt.imshow(overlay)
    plt.title(f"slice {idx} — levels {level_mask.nonzero()[0].tolist()}")
    plt.axis("off")
    plt.show()

In [None]:
# z coordinate (ImagePositionPatient[2]) of every axial T2 slice, to compare
# against the disc z positions predicted on the sagittal series.
axial_slices = patient_info.get_scans(Scan.AxialT2)
for axial_img in axial_slices:
    z_pos = axial_img.dicom.ImagePositionPatient[2]
    print(z_pos)