In [1]:
import sys
sys.path.insert(0, "..")


In [2]:
import torch
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup

import timm
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
import albumentations as A
from torch.utils.data import DataLoader


import cv2

from lib.dataloader import RSNA24DF

from lib.classification.dataloader import ClassificationDataLoader, DiscLevelLocs
from lib.patientInfo import DiscLevel, PatientInfo, Scan, Condition, Img, ImgLabel

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.12 (you have 1.4.11). Upgrade using: pip install --upgrade albumentations


In [3]:
AXIAL_HEIGHT=256
AXIAL_WIDTH=256

SAGITTAL_HEIGHT = 256
SAGITTAL_WIDHT = 256

MODEL_NAME = "tf_efficientnet_b5.ns_jft_in1k"
DEVICE = "cuda:0"

DEPTH=8

OUTPUT_DIM = 15
N_CLASSES = 5

WINDOW_HEIGHT_MM = 30

DATA_DIR = "../data"

N_FOLDS = 5
SEED = 41
N_WORKERS = 8



EPOCHS = 20
GRAD_ACC = 2
TGT_BATCH_SIZE = 32
BATCH_SIZE = TGT_BATCH_SIZE // GRAD_ACC
EARLY_STOPPING_EPOCH = 6
OUTPUT_DIR = "../models/classification/"

TEST_BATCH_SIZE = 1

N_FOLDS = 5
AUG_PROB = 0.75

LR = 2e-4 * TGT_BATCH_SIZE / 32
AUG = True

In [4]:

train_df = pd.read_csv(f"{DATA_DIR}/train.csv")
train_df = train_df[train_df["study_id"] != 3008676218]
train_label_coordinates_df = pd.read_csv(f"{DATA_DIR}/train_label_coordinates.csv")
train_series_descriptions_df = pd.read_csv(f"{DATA_DIR}/train_series_descriptions.csv")


rsna24DF = RSNA24DF(train_df, train_label_coordinates_df, train_series_descriptions_df, f"{DATA_DIR}/train_images")


# test_series_descriptions_df = pd.read_csv(f"{DATA_DIR}/test_series_descriptions.csv")

# rsna24DF = RSNA24DF(None, None, test_series_descriptions_df, f"{DATA_DIR}/test_images")

In [5]:
class FinalDataLoader(ClassificationDataLoader):
    def __init__(self
                , patient_ids
                , phase
                , transformations=None
                , axial_height = AXIAL_HEIGHT
                , axial_width = AXIAL_WIDTH
                , sagittal_height = SAGITTAL_HEIGHT
                , sagittal_width = SAGITTAL_WIDHT
                , rsna24DF=rsna24DF
                , disc_level_locs_dir="../processed-data"
                , depth=DEPTH
                ) -> None:
        super().__init__(patient_ids=patient_ids 
                         , rsna24DF=rsna24DF
                         , transformations=transformations
                         , depth=depth 
                         , disc_level_locs_dir=disc_level_locs_dir
                         , phase="train"
                        )
        self.axial_height = axial_height
        self.axial_width = axial_width
        self.sagittal_height = sagittal_height
        self.sagittal_width = sagittal_width

    def _axial_t2_imgs(self, disc_level: DiscLevel, patient_info: PatientInfo, disc_level_locs: DiscLevelLocs) -> np.ndarray:

        # print("Sampling sagittal Axial IMGS")
        xs = np.zeros((self.axial_height, self.axial_width, self.depth))

        level_mm = disc_level_locs.disc_loc_mm[disc_level.to_int()]
        # print(f"{level_mm=}")

        near_imgs = []

        for img in patient_info.get_scans(Scan.AxialT2):
            if abs(img.dicom.ImagePositionPatient[2] - level_mm) <= WINDOW_HEIGHT_MM / 2:
                near_imgs.append(img)

        # select top 4 near imgs
        # print(f"{len(near_imgs)=}")
        near_imgs = sorted(near_imgs, key = lambda img: abs(img.dicom.ImagePositionPatient[2] - level_mm))[:self.depth]

        # sort these images from head to toe order
        near_imgs = sorted(near_imgs, key = lambda img: img.dicom.ImagePositionPatient[2])
        near_imgs.reverse()

        for idx, img in enumerate(near_imgs):
            # print(f"{img.dicom.ImagePositionPatient[2]=}")
            pixels = img.dicom.pixel_array
            xs[...,idx] = cv2.resize(pixels, (self.axial_height, self.axial_width),interpolation=cv2.INTER_CUBIC)

        return xs

    def _sagittal_t2_stir_imgs(self, disc_level: DiscLevel, patient_info: PatientInfo, disc_level_locs: DiscLevelLocs) -> np.ndarray:
        # print("\n\n\n\n")
        # print("Sampling sagittal T2/STIR IMGS")
        return self._sagittal_imgs(Scan.SagittalT2_STIR, disc_level, patient_info, disc_level_locs)

    def _sagittal_t1_imgs(self, disc_level: DiscLevel, patient_info: PatientInfo, disc_level_locs: DiscLevelLocs) -> np.ndarray:
        # print("\n\n\n\n")
        # print("Sampling sagittal T1 IMGS")
        return self._sagittal_imgs(Scan.SagittalT1, disc_level, patient_info, disc_level_locs)

    def _sagittal_imgs(self, scan_type: Scan, disc_level: DiscLevel, patient_info: PatientInfo, disc_level_locs: DiscLevelLocs) -> np.ndarray:
        
        xs = np.zeros((self.sagittal_height, self.sagittal_width, self.depth))
        level_mm = disc_level_locs.disc_loc_mm[disc_level.to_int()]
        # print(f"{level_mm=}")

        imgs = sorted([(img.dicom.ImagePositionPatient[0], img) for img in patient_info.get_scans(scan_type)], key = lambda x: x[0])

        if len(imgs) == 0:
            return xs
        
        img_idxs = np.unique(np.round(np.linspace(0, len(imgs) - 1, self.depth)).astype(int))


        for idx,img_idx in enumerate(img_idxs):

            # try:
            _,img = imgs[img_idx]
            # except Exception as e:
            #     print(f"{patient_info.patient_id=}")
            #     print(f"{img_idxs=}, {len(imgs)=}")
            

            # print(f"{img.dicom.ImagePositionPatient[0]=}")

            # crop image (width)
            pixels = img.dicom.pixel_array
            pixels = cv2.normalize(pixels, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
            _, binary_image = cv2.threshold(pixels, 50, 255, cv2.THRESH_BINARY)  
            contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            contour_areas = [cv2.contourArea(contour) for contour in contours]
            largest_contour_index = np.argmax(contour_areas)
            x, y, w, h = cv2.boundingRect(contours[largest_contour_index])
            cropped_img = img.dicom.pixel_array[:, x:x+w]
            # print(f"{cropped_img.shape=} (before center cropping)")
            
            window_width_px = WINDOW_HEIGHT_MM / img.dicom.PixelSpacing[0]
            level_y_coord = (img.dicom.ImagePositionPatient[2] - level_mm) / img.dicom.PixelSpacing[0]
            # assert (int(level_y_coord - window_width_px / 2) > 0)

            # print(f"center cropping range: {int(level_y_coord - window_width_px / 2)} to {int(level_y_coord + window_width_px / 2)} ")
            cropped_img = cropped_img[max(int(level_y_coord - window_width_px / 2), 0) : int(level_y_coord + window_width_px / 2) , :]
            
            # print(f"{cropped_img.shape=} (after center cropping)")
            
            xs[..., idx] =  cv2.resize(cropped_img, (self.sagittal_height, self.sagittal_width),interpolation=cv2.INTER_CUBIC)

        return xs


In [6]:
class ClassificationModel(nn.Module):
    # def __init__(self, output_dim=OUTPUT_DIM, n_classes=N_CLASSES, model_name=MODEL_NAME):
    #     super().__init__()
        
    #     self.initial_conv = nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3, padding=1, stride=1)
    #     self.output_dim = output_dim
    #     self.n_classes = n_classes
    #     model = torch.hub.load('facebookresearch/pytorchvideo', model_name, pretrained=True)
    #     self.feature_extractors = nn.ModuleList(
    #         [model.blocks[i] for i in range(len(model.blocks) - 1) ]
    #     )
    #     self.final_pool = nn.Sequential(
    #        nn.AvgPool3d(kernel_size=(8,7,7), stride=(1,1,1), padding=(0,0,0)), 
    #     nn.AdaptiveAvgPool3d(output_size=1)
    #     )
    #     self.classification_head = nn.Sequential(
    #        nn.Dropout(p=0.5, inplace=False),
    #        nn.Linear(2048, self.output_dim),
    #     )

    #     del model


    # def forward(self, x):
    #     x = self.initial_conv(x)
    #     for feature_extractor in self.feature_extractors:
    #         x = feature_extractor(x)
    #     x = self.final_pool(x)
    #     x = self.classification_head(x.squeeze())
    #     return x
        
    def __init__(self, model_name, in_c=DEPTH * 3, n_classes=15, pretrained=True, features_only=False):
        super().__init__()
        self.model = timm.create_model(
                                    model_name,
                                    pretrained=pretrained, 
                                    features_only=features_only,
                                    in_chans=in_c,
                                    num_classes=n_classes,
                                    global_pool='avg'
                                    )
    
    def forward(self, x):
        y = self.model(x)
        return y

In [7]:
skf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

In [8]:
transforms_train = A.Compose([
    # A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=AUG_PROB),
    # A.OneOf([
    #     A.OpticalDistortion(distort_limit=1.0),
    #     A.GridDistortion(num_steps=5, distort_limit=1.),
    #     A.ElasticTransform(alpha=3),
    # ], p=AUG_PROB),

    # A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=AUG_PROB),
    # A.Resize(AXIAL_HEIGHT, AXIAL_WIDTH),
    # A.CoarseDropout(max_holes=16, max_height=64, max_width=64, min_holes=1, min_height=8, min_width=8, p=AUG_PROB),    
    A.Normalize(mean=0.5, std=0.5)
])

transforms_val = A.Compose([
    # A.Resize(AXIAL_HEIGHT, AXIAL_WIDTH),
    A.Normalize(mean=0.5, std=0.5)
])



In [9]:

def train(model,dataloader, criterion, optimizer, scheduler=None):
    
    model.train()
    optimizer.zero_grad()
    training_loss = []
    
    for idx, (x, y) in enumerate(pbar := tqdm(dataloader)):
        
        x, y = x.to(DEVICE), y.long().to(DEVICE)
        
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            pred_y = model(x)
            pred_y = pred_y.reshape(-1, 3)
            loss = criterion(pred_y, y.reshape(-1))
            loss = loss / GRAD_ACC

            
        loss.backward()
        norm = nn.utils.clip_grad_norm(model.parameters(), 1.0)

        training_loss.append(loss.item() * GRAD_ACC)

        if (idx + 1) % GRAD_ACC == 0:
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()


        pbar.set_description(f"loss: {loss.item() * GRAD_ACC:.6f}, | norm: {norm:.4f}| training_loss: {np.mean(training_loss):.6f}")
        
    
    return np.mean(training_loss)

def validation(model, data_loader, criterion):
    
        model.eval()
    
        val_loss = []

        with torch.no_grad():
    
            for idx, (x, y) in enumerate(pbar := tqdm(data_loader)):
                
                x, y = x.to(DEVICE), y.long().to(DEVICE)

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    
                    pred_y = model(x)
                    pred_y = pred_y.reshape(-1, 3)
                    loss = criterion(pred_y, y.reshape(-1))

                val_loss.append(loss.item())
                pbar.set_description(f"current loss: {loss.item():.6f}, validation_loss: {np.mean(val_loss):.6f}")

        return np.mean(val_loss)

In [10]:

for fold, (trn_idx, val_idx) in enumerate(skf.split(range(len(train_df)))):
    print('#'*30)
    print(f'start fold{fold}')
    print('#'*30)
    print(len(trn_idx), len(val_idx))

    df_train = train_df.iloc[trn_idx]
    df_valid = train_df.iloc[val_idx]


    train_ds = FinalDataLoader(df_train["study_id"].unique(), transformations=transforms_train, phase="train")
    train_dl = DataLoader(
                train_ds,
                batch_size=BATCH_SIZE,
                shuffle=True,
                pin_memory=True,
                drop_last=True,
                num_workers=N_WORKERS
                )
    
    valid_ds = FinalDataLoader(df_valid["study_id"].unique(), transformations=transforms_val, phase="train")
    valid_dl = DataLoader(
                valid_ds,
                batch_size=BATCH_SIZE * 2,
                shuffle=False,
                pin_memory=True,
                drop_last=False,
                num_workers=N_WORKERS
                )
    
    model = ClassificationModel(model_name=MODEL_NAME).to(DEVICE)
    optimizer = AdamW(model.parameters(), lr=LR)
    
    weights = torch.tensor([1.0, 2.0, 4.0])
    criterion = nn.CrossEntropyLoss(weight=weights.to(DEVICE))
    
    warmup_steps = EPOCHS/10 * len(train_dl) // GRAD_ACC
    num_total_steps = EPOCHS * len(train_dl) // GRAD_ACC
    num_cycles = 0.475
    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=num_total_steps,
                                                num_cycles=num_cycles)

    
    best_val_loss   = 20.902649
    early_stop_counter = 0
    
    for epoch in range(1, EPOCHS + 1):
        print(f"EPOCH: {epoch}")
        _ = train(model, train_dl, criterion, optimizer, scheduler)
        val_loss = validation(model, valid_dl, criterion)

        print(f"val_loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            early_stop_counter = 0
            print(f"updating best_val_loss from {best_val_loss} to {val_loss}")
            
            best_val_loss = val_loss

            print(f"updated losses: {best_val_loss=}")

            print("Saving model....")
            fname = f'../models/classification//best_loc_model_fold-{fold}.pt'
            torch.save(model.state_dict(), fname)
            
        else:
            early_stop_counter += 1
            print(f"{EARLY_STOPPING_EPOCH - early_stop_counter} more epochs to train until early stopping")

        if early_stop_counter == EARLY_STOPPING_EPOCH:
            break
    break

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b5.ns_jft_in1k)


##############################
start fold0
##############################
1579 395


INFO:timm.models._hub:[timm/tf_efficientnet_b5.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Converted input conv conv_stem pretrained weights from 3 to 24 channel(s)
INFO:timm.models._builder:Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


EPOCH: 1


  norm = nn.utils.clip_grad_norm(model.parameters(), 1.0)
loss: 1.070396, | norm: 6.0117| training_loss: 1.361337: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:25<00:00,  3.83it/s]
current loss: 1.237676, validation_loss: 1.266427: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.59it/s]


val_loss: 1.2664
updating best_val_loss from 20.902649 to 1.2664265999427209
updated losses: best_val_loss=1.2664265999427209
Saving model....
EPOCH: 2


loss: 0.842423, | norm: 3.1346| training_loss: 0.980758: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:26<00:00,  3.71it/s]
current loss: 0.789191, validation_loss: 1.010502: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.62it/s]


val_loss: 1.0105
updating best_val_loss from 1.2664265999427209 to 1.010501884497129
updated losses: best_val_loss=1.010501884497129
Saving model....
EPOCH: 3


loss: 0.654059, | norm: 1.8931| training_loss: 0.829734: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:25<00:00,  3.87it/s]
current loss: 0.659258, validation_loss: 0.834524: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00,  1.69it/s]


val_loss: 0.8345
updating best_val_loss from 1.010501884497129 to 0.8345239001971024
updated losses: best_val_loss=0.8345239001971024
Saving model....
EPOCH: 4


loss: 1.106116, | norm: 3.6012| training_loss: 0.780106: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:25<00:00,  3.83it/s]
current loss: 0.980719, validation_loss: 0.836258: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.62it/s]


val_loss: 0.8363
5 more epochs to train until early stopping
EPOCH: 5


loss: 0.619677, | norm: 1.5873| training_loss: 0.744581: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:25<00:00,  3.79it/s]
current loss: 0.848104, validation_loss: 0.797468: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.61it/s]


val_loss: 0.7975
updating best_val_loss from 0.8345239001971024 to 0.7974684605231652
updated losses: best_val_loss=0.7974684605231652
Saving model....
EPOCH: 6


loss: 0.646602, | norm: 2.0413| training_loss: 0.707565: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:25<00:00,  3.81it/s]
current loss: 0.667565, validation_loss: 0.857785: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00,  1.64it/s]


val_loss: 0.8578
5 more epochs to train until early stopping
EPOCH: 7


loss: 0.768049, | norm: 2.2335| training_loss: 0.668065: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:26<00:00,  3.68it/s]
current loss: 0.599330, validation_loss: 0.851151: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00,  1.64it/s]


val_loss: 0.8512
4 more epochs to train until early stopping
EPOCH: 8


loss: 0.800804, | norm: 2.6480| training_loss: 0.615684: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:26<00:00,  3.75it/s]
current loss: 0.600349, validation_loss: 0.897478: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00,  1.64it/s]


val_loss: 0.8975
3 more epochs to train until early stopping
EPOCH: 9


loss: 0.657627, | norm: 2.6590| training_loss: 0.587942: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:25<00:00,  3.83it/s]
current loss: 0.565228, validation_loss: 0.872453: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.60it/s]


val_loss: 0.8725
2 more epochs to train until early stopping
EPOCH: 10


loss: 0.635659, | norm: 2.8735| training_loss: 0.529170: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:26<00:00,  3.76it/s]
current loss: 0.553116, validation_loss: 0.980774: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.58it/s]


val_loss: 0.9808
1 more epochs to train until early stopping
EPOCH: 11


loss: 0.302529, | norm: 1.9913| training_loss: 0.465565: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:26<00:00,  3.75it/s]
current loss: 0.815569, validation_loss: 0.880021: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.61it/s]

val_loss: 0.8800
0 more epochs to train until early stopping





In [None]:

for epoch in range(1, EPOCHS + 1):
    print(f"EPOCH: {epoch}")
    _ = train(model, train_dl, criterion, optimizer, scheduler)
    val_loss = validation(model, valid_dl, criterion)

    print(f"val_loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        early_stop_counter = 0
        print(f"updating best_val_loss from {best_val_loss} to {val_loss}")
        
        best_val_loss = val_loss

        print(f"updated losses: {best_val_loss=}")

        print("Saving model....")
        fname = f'../models/classification//best_loc_model_fold-{fold}.pt'
        torch.save(model.state_dict(), fname)
        
    else:
        early_stop_counter += 1
        print(f"{EARLY_STOPPING_EPOCH - early_stop_counter} more epochs to train until early stopping")

    if early_stop_counter == EARLY_STOPPING_EPOCH:
        break

In [11]:

model = ClassificationModel(model_name=MODEL_NAME).to(DEVICE)

model.load_state_dict(torch.load("../models/classification/best_loc_model_fold-0.pt"))
model.eval()

ds = FinalDataLoader(test_series_descriptions_df["study_id"].unique(), transformations=transforms_val, phase="pred", disc_level_locs_dir="../processed-data/test-data")

Using cache found in /home/paradox/.cache/torch/hub/facebookresearch_pytorchvideo_main


In [12]:

preds = []
for x in ds[0]:
    with torch.no_grad():
        x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0).to(DEVICE)
        y = model(x)
        print(y)
        y = y.reshape(-1, 3).softmax(dim=1).cpu().numpy()
        preds.append(y.copy())

level_mm=62.20833
level_mm=62.20833
level_mm=62.20833
tensor([ 0.4057, -0.4424, -0.0332,  0.0242,  0.2285, -0.3126, -0.0453,  0.1302,
        -0.2160, -0.1779,  0.0373,  0.1771, -0.1856,  0.0184,  0.3195],
       device='cuda:0')
level_mm=34.433823
level_mm=34.433823
level_mm=34.433823
tensor([ 0.3217, -0.3847, -0.0195, -0.0971,  0.2153, -0.1836, -0.1216,  0.1495,
        -0.0982, -0.1911,  0.0737,  0.2167, -0.2191,  0.0416,  0.3105],
       device='cuda:0')
level_mm=20.595108
level_mm=20.595108
level_mm=20.595108
tensor([ 0.3081, -0.4020,  0.0328, -0.1800,  0.2407, -0.1573, -0.2032,  0.1702,
        -0.0549, -0.2865,  0.1053,  0.2727, -0.3081,  0.0647,  0.3911],
       device='cuda:0')
level_mm=0.5409241
level_mm=0.5409241
level_mm=0.5409241
tensor([ 0.2847, -0.4143,  0.0322, -0.2370,  0.2510, -0.1299, -0.2485,  0.1846,
        -0.0375, -0.3041,  0.1014,  0.3077, -0.3226,  0.0703,  0.4030],
       device='cuda:0')
level_mm=-20.475632
level_mm=-20.475632
level_mm=-20.475632
tensor([ 0.

In [17]:
df = {}

for i in range(5):
    disc_level = DiscLevel.from_int(i)
    disc_level_preds = preds[i]
    for condition, pred in zip(Condition.all_conditions(), disc_level_preds):
        loc = Location(disc_level=disc_level, condition=condition).to_str()
        df[f"44036939_{loc}"] = {"normal_mild": pred[0], "moderate": pred[1], "severe": pred[2]}

In [18]:
from lib.patientInfo import Location

In [24]:
pd.DataFrame(df).transpose()

Unnamed: 0,normal_mild,moderate,severe
44036939_spinal_canal_stenosis_l1_l2,0.482407,0.206579,0.311014
44036939_left_neural_foraminal_narrowing_l1_l2,0.340065,0.417134,0.242801
44036939_right_neural_foraminal_narrowing_l1_l2,0.329501,0.392703,0.277795
44036939_left_subarticular_stenosis_l1_l2,0.272749,0.338263,0.388988
44036939_right_subarticular_stenosis_l1_l2,0.257492,0.315776,0.426732
44036939_spinal_canal_stenosis_l2_l3,0.453645,0.223834,0.322521
44036939_left_neural_foraminal_narrowing_l2_l3,0.304526,0.41618,0.279294
44036939_right_neural_foraminal_narrowing_l2_l3,0.299836,0.393211,0.306953
44036939_left_subarticular_stenosis_l2_l3,0.262683,0.342335,0.394982
44036939_right_subarticular_stenosis_l2_l3,0.250241,0.324781,0.424978


In [20]:
preds

[array([[0.48240703, 0.20657907, 0.31101385],
        [0.34006482, 0.41713405, 0.2428011 ],
        [0.32950148, 0.39270347, 0.27779502],
        [0.27274927, 0.33826306, 0.3889877 ],
        [0.25749156, 0.31577626, 0.4267322 ]], dtype=float32),
 array([[0.45364544, 0.22383353, 0.322521  ],
        [0.304526  , 0.41618007, 0.2792939 ],
        [0.2998361 , 0.39321098, 0.30695292],
        [0.262683  , 0.34233522, 0.39498177],
        [0.25024068, 0.32478124, 0.42497802]], dtype=float32),
 array([[0.44425666, 0.21840213, 0.33734122],
        [0.28201714, 0.42950064, 0.2884822 ],
        [0.2768153 , 0.40212578, 0.321059  ],
        [0.23646207, 0.34988976, 0.41364822],
        [0.22401266, 0.32523805, 0.45074928]], dtype=float32),
 array([[0.43976662, 0.21859288, 0.3416404 ],
        [0.26724026, 0.43532407, 0.29743564],
        [0.26477334, 0.40827805, 0.3269486 ],
        [0.23020862, 0.34532288, 0.42446855],
        [0.21990104, 0.3257624 , 0.45433658]], dtype=float32),
 array([[0.4