In [1]:
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import pandas as pd
from tqdm import tqdm
import logging
import seaborn as sn

from rsna_dataloader import *


In [2]:
class CustomResNet(nn.Module):
    def __init__(self, out_features=512, pretrained_weights=None):
        super(CustomResNet, self).__init__()
        self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        if pretrained_weights:
            self.model.load_state_dict(torch.load(pretrained_weights))
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features=num_ftrs, out_features=out_features)
        torch.nn.init.xavier_uniform(self.model.fc.weight)

    def forward(self, x):
        return self.model(x)


class FCHead(nn.Module):
    def __init__(self, drop_rate=0.1, num_classes=3):
        super(FCHead, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(256, 128),
            # nn.BatchNorm1d(256),
            nn.Dropout(drop_rate),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 128),
            nn.Dropout(drop_rate),
            nn.LeakyReLU(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.model(x)


class CustomLSTM(nn.Module):
    hidden_size = 256
    num_layers = 3

    def __init__(self, num_classes=3, num_levels=5, drop_rate=0.2, resnet_weights=None):
        super(CustomLSTM, self).__init__()
        self.cnn = CustomResNet(pretrained_weights=resnet_weights)
        self.lstm = nn.LSTM(input_size=512, hidden_size=self.hidden_size, dropout=drop_rate, num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=True)
        self.heads = [FCHead().to(device) for i in range(num_levels)]

    def forward(self, x_3d):
        hidden = None

        # Iterate over each frame of a video in a video of batch * frames * channels * height * width
        for t in range(x_3d.size(1)):
            x = self.cnn(x_3d[:, t])
            # Pass latent representation of frame through lstm and update hidden state
            out, hidden = self.lstm(x.unsqueeze(0), hidden)

            # Get the last hidden state (hidden is a tuple with both hidden and cell state in it)

        return [head(hidden[0][-1]) for head in self.heads]


In [3]:
model = torch.load("../models/resnet18_lstm_t2stir.pt")

In [4]:
transform_val = transforms.Compose([
    transforms.Lambda(lambda x: (x * 255).astype(np.uint8)),  # Convert back to uint8 for PIL
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

In [5]:
data_basepath = "../data/rsna-2024-lumbar-spine-degenerative-classification/"
training_data = retrieve_training_data(data_basepath)

In [6]:
test_loader = create_series_level_test_datasets_and_loaders(training_data, "Sagittal T2/STIR", transform_val, data_basepath + "train_images")

In [7]:
label_map = {'normal_mild': 0, 'moderate': 1, 'severe': 2}
def get_output_class(val):
    if val <= 0.33:
        return 0
    elif val <= 0.66:
        return 1
    else:
        return 2

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

CustomLSTM(
  (cnn): CustomResNet(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNo

In [11]:
preds = []
for image, label in tqdm(test_loader):
    image = image.to(device)
    preds.append((model(image), label))
    if len(preds) == 100:
        break

  5%|▌         | 99/1973 [00:38<12:14,  2.55it/s]


In [12]:
preds

[([tensor([[ 0.3289, -0.3943, -0.3013]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.4153, -0.6202, -0.8320]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.4056, -0.3372, -0.2990]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.2775, -0.1511, -0.1332]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.2998, -0.1235, -0.2097]], device='cuda:0',
          grad_fn=<AddmmBackward0>)],
  tensor([[[1, 0, 0],
           [1, 0, 0],
           [1, 0, 0],
           [1, 0, 0],
           [1, 0, 0]]], dtype=torch.int32)),
 ([tensor([[ 0.3998, -0.3008, -0.2594]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.5347, -0.5981, -0.9388]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.4067, -0.1408, -0.1969]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([[ 0.2047, -0.0862, -0.1990]], device='cuda:0',
          grad_fn=<AddmmBackward0>),
   tensor([

In [None]:
def scores_to_classes(arr):
    ret = []
    for i in range(0, len(arr)-1, 2):
        if arr[i] == 0:
            ret.append(0)
        elif arr[i] == 1:
            ret.append(2)
        else:
            ret.append(1)
            
    return ret

In [None]:
pred_classes = [(scores_to_classes(e[0][0]), scores_to_classes(e[1][0])) for e in preds]

In [None]:
pred_classes_0 = [e[0][4] for e in pred_classes]
pred_classes_0_ = [e[1][4] for e in pred_classes]

In [None]:
cf_matrix = confusion_matrix(pred_classes_0, pred_classes_0_)
df_cm = pd.DataFrame(cf_matrix)
plt.figure(figsize=(12, 7))
sn.heatmap(df_cm, annot=True)
plt.show()


In [None]:
training_data

In [None]:
training_data[training_data["condition"] == "Spinal Canal Stenosis"].groupby(["severity"]).count()