# Lumbar Spine Degenerative Classification using Deep Learning 💻⚕️

### Authors: Daniel Popov (315784173), Yotam Gershi (315784173)
*Computer Science Students*  
*Passionate about leveraging AI to solve challenging problems in the medical field* 👨‍⚕️

## Introduction

We are **Daniel Popov** and **Yotam Gershi**, computer science students passionate about leveraging artificial intelligence to tackle challenging problems in the medical field. In this project, we explore the **RSNA 2024 Lumbar Spine Degenerative Classification competition dataset**, which focuses on classifying degenerative changes in the lumbar spine using medical imaging.

The primary goal of this project is to explore the dataset thoroughly and develop deep learning models capable of accurately classifying various types of degenerative changes in the lumbar spine. This work aims to enhance the accuracy and efficiency of diagnostic workflows in clinical settings.


#### Motivation

Spinal degeneration is becoming increasingly prevalent worldwide, particularly within aging populations. Degenerative changes in the lumbar spine are often associated with chronic pain and diminished quality of life. There is a pressing need for automated tools to assist radiologists in evaluating spine health, improving diagnostic accuracy, and optimizing treatment planning. Our project aligns with this need, aiming to build AI-based tools that support healthcare professionals in managing these conditions more effectively.


#### Approach

This project will guide you through the following steps:

1. **Understanding the problem**: We start by analyzing the dataset's objectives and structure.
2. **Exploratory Data Analysis (EDA)**: We’ll conduct a comprehensive analysis of the CSV and image files to gain insights into the dataset, uncover patterns, detect anomalies, and understand class distributions. This step ensures we understand the data's nuances and are prepared to build effective models.
3. **Data Preprocessing**: We'll prepare the data for modeling, addressing issues such as missing values, class imbalance, and image normalization.
4. **Model Development**: Using state-of-the-art deep learning techniques, we will train and evaluate models to classify degenerative changes in lumbar spine images.
5. **Performance Evaluation**: We will assess our models using appropriate metrics and iterate on improvements.


#### Impact

Our work contributes to the growing field of AI in healthcare, aiming to make spinal health assessment more accurate, efficient, and accessible. By automating aspects of lumbar spine degeneration classification, this project seeks to reduce the workload of radiologists and improve patient outcomes.


Let’s get started! 🚀


### The Problem

Degenerative spine conditions significantly impact quality of life, often causing pain, reduced mobility, and diminished overall well-being. Accurate identification and assessment of these conditions using medical imaging are critical for developing effective treatment plans and improving patient outcomes.

The **RSNA 2024 Lumbar Spine Degenerative Classification Challenge** focuses on automating the detection and grading of three key types of degenerative conditions in the lumbar spine from medical images:

1. **Foraminal Narrowing**  
   This condition occurs when the foramina—the passageways through which spinal nerves exit the spinal canal—become compressed. This narrowing can occur on either the left or right side and often leads to significant nerve-related symptoms.

2. **Subarticular Stenosis**  
   Subarticular stenosis refers to the narrowing of the space beneath the articular processes of the spine, where nerve roots pass. This condition can result in nerve compression, causing pain and neurological symptoms depending on whether it occurs on the left or right side.

3. **Canal Stenosis**  
   Canal stenosis involves narrowing of the central spinal canal, which houses the spinal cord and nerve roots. The severity of symptoms can range from mild discomfort to significant neurological deficits, depending on the degree of compression.

Each of these conditions can appear at various levels within the lumbar spine, specifically around each vertebral disc. For example, the **L4/L5** level corresponds to the disc between the fourth (L4) and fifth (L5) lumbar vertebrae. The challenge requires predicting the degree of compression at these levels and classifying them as **normal**, **mild**, **moderate**, or **severe**.

The dataset includes MRI images of the lumbar spine, annotated by expert radiologists to indicate the presence and severity of these conditions. Participants are tasked with building machine learning models that can classify these conditions accurately, assisting radiologists in diagnosing degenerative spine conditions more efficiently and consistently.

By leveraging advanced computer vision and deep learning techniques, this challenge seeks to improve the diagnostic accuracy and reliability of spinal degeneration assessments, ultimately enhancing patient care.

For more details on the dataset and challenge, visit the [RSNA 2024 Lumbar Spine Degenerative Classification competition page](https://www.kaggle.com/competitions/rsna-2024-lumbar-spine-degenerative-classification).


### Anatomical Overview 🦴

The spine is divided into four regions:

- **Cervical region**: Contains 7 vertebral bodies.
- **Thoracic region**: Contains 12 vertebral bodies.
- **Lumbar region**: Contains 5 vertebral bodies.
- **Sacral region**: Contains 3-5 fused vertebral bodies.


<div style="display: flex; justify-content: center; align-items: center;">
    <img src="https://prod-images-static.radiopaedia.org/images/53655832/Gray-square.001_big_gallery.jpeg" alt="Description" width="300"/>
</div>

Between each vertebral body in all regions (except the sacrum) lies a **vertebral disc**. These discs act as cushions, providing flexibility and absorbing shock.

Along the posterior aspect of each vertebral body lies the **spinal cord**, a vital structure that transmits signals between the brain and the rest of the body. 

At each vertebral body, **spinal nerves** exit the spinal cord through openings between the vertebral bodies called **foramina**. These nerves are responsible for transmitting sensory and motor information to and from different parts of the body.


<div style="display: flex; justify-content: center; align-items: center;">
    <img src="https://files.miamineurosciencecenter.com/media/filer_public_thumbnails/filer_public/78/1e/781e78be-8980-466f-8a82-83a5c8350770/herniated_disc_larger.jpg__720.0x600.0_q85_subject_location-360%2C300_subsampling-2.jpg" alt="Description" width="300"/>
</div>

Compression of the spinal cord or any of the spinal nerves can cause significant pain and discomfort to patients. Several factors can lead to such compression, including:

- **Bulging vertebral disc**: When the disc protrudes beyond its normal boundary, it can press on nearby nerves or the spinal cord.
- **Degenerative changes in the bones**: These changes can lead to the growth of bony protrusions (osteophytes) or compression of the vertebrae themselves.
- **Trauma**: Injuries to the spine can result in displacement or fractures that compress the spinal cord or nerves.
- **Thickening of ligaments**: Ligaments surrounding the spinal cord may thicken over time, contributing to reduced space and nerve compression.

### **Foraminal Narrowing**

The spinal cord has spinal nerves that exit the spinal canal through openings called **foramina**. The foramina are best viewed in the **sagittal plane**. Occasionally, these openings can become compressed, resulting in **foraminal narrowing**. This compression causes pain along the nerve distribution downstream of the affected area.

- **Left image**: A sagittal MRI slice where the foramina are visible. Crosshairs indicate where the foramina exit the spinal canal.  
- **Right image**: Grading criteria for the degree of compression (note: for this challenge, **Normal/Mild** is grouped into one label).


<div style="display: flex; justify-content: center; align-items: center;">
    <div style="margin-right: 50px; flex: 1; display: flex; justify-content: flex-end;">
        <img src="https://i.imgur.com/6c7erNM.png" alt="Image 1 Description" width="300"/>
    </div>
    <div style="margin-left: 50px; flex: 1; display: flex; justify-content: flex-start;">
        <img src="https://i.imgur.com/b1VGiN5.png" alt="Image 2 Description" width="300"/>
    </div>
</div>

### **Subarticular Stenosis**

**Subarticular stenosis** occurs due to compression of the spinal cord in the **subarticular zone**, which can be best visualized in the **axial plane**.

- **Left image**: A schematic illustrating the relevant anatomical zone.  
- **Right image**: Grading criteria for determining the degree of subarticular stenosis (note: for this challenge, **Normal/Mild** is grouped into one label).

<div style="display: flex; justify-content: center; align-items: center;">
    <div style="margin-right: 50px; flex: 1; display: flex; justify-content: flex-end;">
        <img src="https://files.miamineurosciencecenter.com/media/filer_public_thumbnails/filer_public/d5/08/d508ae6a-a4f2-4796-be9f-455f8df45fe1/herniation_zones.jpg__1700.0x1308.0_q85_subject_location-850%2C656_subsampling-2.jpg" alt="Image 1 Description" width="300"/>
    </div>
    <div style="margin-left: 50px; flex: 1; display: flex; justify-content: flex-start;">
        <img src="https://i.imgur.com/Usuxgge.png" alt="Image 2 Description" width="300"/>
    </div>
</div>

**Canal Stenosis**

### **Canal Stenosis**

**Canal stenosis** refers to impingement on the **spinal canal**, where the spinal cord travels. Impingement can result from:

- **Bulging vertebral disc**: Protrusion of the disc into the canal.
- **Trauma**: Injury that alters the alignment or structure of the spinal canal.
- **Bony osteophytes**: Outgrowths of the vertebral bodies caused by degenerative changes.
- **Ligamental thickening**: Thickening of the ligaments that run along the spinal canal.

The degree of compression is best assessed in the **axial plane**.

- **Left image**: Canal stenosis visible in the sagittal plane, providing an overview of its appearance.  
- **Right image**: Grading criteria for canal stenosis (note: for this challenge, **Normal/Mild** is grouped into one label).


<div style="display: flex; justify-content: center; align-items: center;">
    <div style="margin-right: 50px; flex: 1; display: flex; justify-content: flex-end;">
        <img src="https://prod-images-static.radiopaedia.org/images/940993/f7a8adca63efae788f621869cc21e8_big_gallery.jpg" alt="Image 1 Description" width="300"/>
    </div>
    <div style="margin-left: 50px; flex: 1; display: flex; justify-content: flex-start;">
        <img src="https://i.imgur.com/opjnAwl.png" alt="Image 2 Description" width="300"/>
    </div>
</div>

### Imaging Overview 🩻

MRI imaging of the spine can be performed in three planes: the **axial plane**, the **sagittal plane**, and the **coronal plane**. In our dataset, the two primary image types are from the axial and sagittal planes:

- **Axial Plane**: Captures horizontal slices (perpendicular to the spine) across the body from top to bottom. These images are useful for assessing the spine and surrounding structures in cross-section.
- **Sagittal Plane**: Captures vertical slices (parallel to the spine) from left to right. Sagittal images provide a side view of the spine, which is essential for evaluating the alignment and curvature of the spinal column.

MRI images are typically classified as either **T1-weighted** or **T2-weighted**:

- **T1-Weighted Images**: Highlight fat as bright. For example, the inner parts of bones, which contain fatty marrow, appear brighter on T1 images. These images are often used to evaluate the anatomy of the spine and the surrounding soft tissues.
- **T2-Weighted Images**: Highlight water as bright, making fluids like the cerebrospinal fluid (CSF) in the spinal canal appear brighter. These images are particularly useful for detecting abnormalities related to water content, such as inflammation, edema, and other pathological changes.

> Unlike CT images, MRI images are not standardized regarding pixel values. The intensity values in MRI images do not have a fixed scale and can vary between different scanners and settings. As a result, standardizing these images before analysis might be necessary depending on the approach.

The images in this dataset are stored in the **DICOM format** (Digital Imaging and Communications in Medicine), which is the standard format for storing and transmitting medical images. DICOM is widely used in hospitals and clinics for its ability to encapsulate both image data and metadata.

**Key Features of DICOM Images**:
- **Embedded Metadata**: Each DICOM file contains image data along with metadata, such as patient demographics, study descriptions, imaging modality, acquisition parameters, and more. This metadata is critical for accurate interpretation and analysis.
- **Multi-frame Support**: DICOM files can store a series of images (e.g., multiple slices from an MRI scan) within a single file, allowing for efficient management of large datasets.
- **Lossless Compression**: DICOM images often use lossless compression to preserve full fidelity, ensuring no critical information is lost during storage or transmission.
- **Interoperability**: The DICOM standard ensures compatibility across different imaging equipment and software, making it easier to manage and analyze medical images from various sources.

The dataset is organized into the following directory structure, which includes folders for training and test images, as well as several CSV files containing metadata and labels:

```
├── /kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/
    ├── test_images/
    │   ├── 1005139/
    │   │   └── 609308237/
    │   │       ├── 1.dcm
    │   │       └── ...
    │   └── ...
    ├── test_series_descriptions.csv
    ├── train_images/
    │   ├── 4003253/
    │   │   └── 702807833/
    │   │       ├── 1.dcm
    │   │       └── ...
    │   └── ...
    ├── train_label_coordinates.csv
    ├── train_series_descriptions.csv
    └── train.csv
```

### **Explanation of Key Components**

- **`train_images`**:  
  This directory contains the MRI images for the training set. The images are organized into subdirectories, where each subdirectory corresponds to a unique study ID. Within each patient-specific folder, there are additional subfolders representing different imaging series or studies. The actual MRI slices are stored as DICOM (`.dcm`) files, the standard format for medical imaging.

- **`test_images`**:  
  This directory contains the MRI images for the test set. The images are organized into subdirectories, where each subdirectory corresponds to a unique study ID. Within each patient-specific folder, there are additional subfolders representing different imaging series or studies. The actual MRI slices are stored as DICOM (`.dcm`) files, the standard format for medical imaging.

- **`train.csv`**:  
  This file contains patient study identifiers (`study_id`) and labels for various degenerative spine conditions at different lumbar levels. Each condition's severity is classified as **Normal/Mild**, **Moderate**, or **Severe** across multiple spinal regions.

- **`train_label_coordinates.csv`**:  
  This file contains detailed information about specific regions of interest (ROIs) within the training images. The coordinates provided in this file localize areas affected by degenerative conditions, helping to understand the spatial distribution and severity of these changes.

- **`train_series_descriptions.csv`**:  
  This CSV file includes metadata for each series in the training set. It provides additional information about the MRI sequences, such as the imaging plane (axial or sagittal), sequence type (T1-weighted or T2-weighted), and other acquisition parameters.

- **`test_series_descriptions.csv`**:  
  This CSV file includes metadata for each series in the test set. It provides additional information about the MRI sequences, such as the imaging plane (axial or sagittal), sequence type (T1-weighted or T2-weighted), and other acquisition parameters.


## Imports 📤

In [197]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np 
import pandas as pd 

import cv2
import pydicom
from PIL import Image
from IPython.display import Image as IPyImage, display

import os
import re
import glob
import random
import shutil
from tqdm import tqdm
from IPython.display import clear_output

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import animation, rc
import matplotlib.image as mpimg
from pathlib import Path
sns.set(style="whitegrid")

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

import timm

import yaml

import albumentations as A

from sklearn.model_selection import KFold
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split

from ultralytics import YOLO

In [2]:
# Applying a custom color scheme to the plots
import urllib.request
url = "https://raw.githubusercontent.com/h4pZ/rose-pine-matplotlib/main/themes/rose-pine-dawn.mplstyle"
save_path = "/Users/danipopov/Projects/RSNA2024/data/rose-pine-dawn.mplstyle"  # Include the file name
urllib.request.urlretrieve(url, save_path)
plt.style.use(save_path)

##  Exploratory Data Analysis (EDA) 🔍

### Load Data 🔄

Loading the dataset for the Lumbar Spine Degenerative Classification problem. We load the train, label coordinates, and series descriptions CSV files, which provide details about the spinal conditions and image metadata.

In [3]:
train_df = pd.read_csv("/Users/danipopov/Projects/RSNA2024/data/train.csv") 
label_coords_df = pd.read_csv("/Users/danipopov/Projects/RSNA2024/data/train_label_coordinates.csv")  
series_desc_df = pd.read_csv("/Users/danipopov/Projects/RSNA2024/data/train_series_descriptions.csv")  

### EDA for CSV Files 🕵📁

#### train.csv 📁

The `train.csv` file provides the primary labels for the training dataset. It includes a unique identifier for each patient's study (`study_id`) and columns representing various degenerative conditions at different spinal levels, such as spinal canal stenosis, neural foraminal narrowing, and subarticular stenosis, each labeled by specific lumbar spine levels (e.g., L1/L2, L2/L3, etc.). The values in each of these columns indicate the severity of the degeneration, categorized as `Normal/Mild`, `Moderate`, or `Severe`.

In [4]:
# Show five first row of the train_df
train_df.head(5)

In [5]:
train_df.shape

First we discover that we have 1975 rows and 10 columns, which means that we have 1975 studies with 25 labels for different conditions.
Let's check if there are any null values in the dataset.

In [6]:
train_df.info()

In [7]:
train_df.isnull().sum()

We can observe that we have some missing data, which means that for some `study_id`, we have missing labels for some conditions.

A heatmap is a great tool to effectively visualize missing data, showing where values are present or missing. Dark purple areas indicate data is available, while yellow lines highlight missing values, helping identify patterns to address before analysis.

In [8]:
plt.figure(figsize=(8,6))
sns.heatmap(train_df.isnull(), cbar=False, cmap='viridis', yticklabels=False)
plt.title('Heatmap of Missing Values')
plt.show()

The first thing we can observe is that, on average, the **missing rows have 2-3 missing values**, with some rows having more missing values. Most of the missing values occur in the `subarticular stenosis`. We will remember to address this later.

Next, we will visualize the distribution of severity labels (Normal/Mild, Moderate, and Severe) across foraminal narrowing, subarticular stenosis, and spinal canal stenosis.

In [9]:
# Exclude 'study_id' from the columns list
columns = [col for col in train_df.columns if col != 'study_id'] 

# Set up subplots for the three conditions
figure, axis = plt.subplots(1, 3, figsize=(12, 8))
# Loop through the conditions and plot each one
for idx, d in enumerate(['foraminal', 'subarticular', 'canal']):
    # Select diagnosis columns related to the current condition
    diagnosis = [col for col in train_df.columns if d in col]
    # Melt the DataFrame to convert columns into rows for easier plotting
    melted_df = pd.melt(train_df[diagnosis], value_vars=diagnosis, var_name='diagnosis', value_name='severity')
    # Countplot for the current diagnosis
    sns.countplot(data=melted_df, x='diagnosis', hue='severity', ax=axis[idx], palette='muted', hue_order=['Normal/Mild', 'Moderate', 'Severe'])
    axis[idx].set_title(f'{d.capitalize()} Distribution')
    axis[idx].tick_params(axis='x', rotation=90)
plt.tight_layout()
plt.show()

The majority of labels fall into the `Normal/Mild` category, highlighting a potential **class imbalance** that may impact the model's ability to detect the less frequent but clinically significant "Moderate" and "Severe" cases. To address this imbalance, we may apply techniques like class weighting, oversampling/undersampling, or stratified splitting during model training.

Another method to visualize the distribution of the "Normal/Mild," "Moderate," and "Severe" labels across various diagnoses is by utilizing pie charts. This approach may provide additional insights compared to our previous visualizations.

In [10]:
def plot_pie_charts(diagnosis_type, num_rows, num_cols, fig_size):

    figure, axis = plt.subplots(num_rows, num_cols, figsize=fig_size)
    
    # Filter diagnosis columns based on the given diagnosis type
    diagnosis_cols = [col for col in train_df.columns if diagnosis_type in col]
    
    # Loop through each diagnosis column and create a pie chart
    for idx, d in enumerate(diagnosis_cols):
        dff = train_df[d]
        value_counts = dff.value_counts()
        
        # Plot the pie chart
        axis[idx//num_cols, idx%num_cols].pie(value_counts, labels=value_counts.index, autopct='%1.1f%%')
        axis[idx//num_cols, idx%num_cols].set_title(d)
    
    # Adjust layout for better visibility
    plt.tight_layout()
    plt.show()

In [11]:
# Plot for foraminal narrowing
plot_pie_charts('foraminal', 5, 2, (15, 10))

In [12]:
# Plot for subarticular stenosis
plot_pie_charts('subarticular', 5, 2, (15, 10))

In [13]:
# Plot for spinal canal stenosis
plot_pie_charts('canal', 3, 2, (10, 10))

Our analysis indicates that across all diagnoses, including subarticular stenosis, foraminal narrowing, and canal stenosis, certain spine levels exhibit a small percentage of Severe cases. However, a significant imbalance is observed within the Moderate category, suggesting that both Severe and Moderate cases warrant further consideration during the data preprocessing phase.

Following our exploration of the label distribution across various conditions and the examination of the correlation matrix, we will now proceed to analyze the distribution of labels by spinal level. This analysis will enable us to assess how the severity of conditions (Normal/Mild, Moderate, Severe) is distributed across different spinal levels: L1/L2, L2/L3, L3/L4, L4/L5, and L5/S1. 

Understanding this distribution is crucial for identifying which spinal levels are most significantly impacted by degeneration and determining whether certain levels exhibit a higher prevalence of severe cases.

In [14]:
# Getting the columns of each level
l1_l2_cols = [col for col in columns if 'l1_l2' in col]
l2_l3_cols = [col for col in columns if 'l2_l3' in col]
l3_l4_cols = [col for col in columns if 'l3_l4' in col]
l4_l5_cols = [col for col in columns if 'l4_l5' in col]
l5_s1_cols = [col for col in columns if 'l5_s1' in col]

def plot_label_distribution(df, level, cols, title):
    # Melt the DataFrame to convert columns into rows for easier plotting
    melted_df = pd.melt(df[cols], value_vars=cols, var_name=f'{level}_diagnosis', value_name='severity')
    
    # Countplot
    plt.figure(figsize=(8, 6))
    sns.countplot(data=melted_df, x=f'{level}_diagnosis', hue='severity', palette='muted', hue_order=['Normal/Mild', 'Moderate', 'Severe'])
    plt.title(f'{title} Label Distribution')
    plt.xticks(rotation=45)
    plt.show()

In [15]:
# Plot for L1/L2
plot_label_distribution(train_df, 'L1/L2', l1_l2_cols, 'L1/L2')

In [16]:
# Plot for L2/L3
plot_label_distribution(train_df, 'L2/L3', l2_l3_cols, 'L2/L3')

In [17]:
# Plot for L3/L4
plot_label_distribution(train_df, 'L3/L4', l3_l4_cols, 'L3/L4')

In [18]:
# Plot for L4/L5
plot_label_distribution(train_df, 'L4/L5', l4_l5_cols, 'L4/L5')

In [19]:
# Plot for L5/S1
plot_label_distribution(train_df, 'L5/S1', l5_s1_cols, 'L5/S1')

The distribution of labels across all spinal levels (L1/L2 to L5/S1) reveals that the majority of cases are categorized as "Normal/Mild," highlighting a significant class imbalance. Moderate and Severe cases are considerably less frequent, with Severe cases being particularly rare. Notably, as we progress to lower spinal levels (e.g., L3/L4, L4/L5, and L5/S1), there is a **slight increase** in the proportion of Moderate and Severe cases, especially for conditions such as neural foraminal narrowing and subarticular stenosis. This imbalance, particularly in the lower spine, underscores the necessity of considering a model structure that can accommodate different spinal levels to effectively learn from the severe cases.


We will now proceed to the next tool for exploratory data analysis (EDA), the `correlation matrix`. This tool provides valuable insights into the relationships between various spinal conditions. A correlation value close to 1 indicates a strong positive relationship, while a value close to -1 signifies a strong negative relationship. These insights are instrumental in determining whether certain conditions are interrelated, thereby aiding in our understanding of the progression of spinal degeneration.

In [20]:
# Copy the df to convert categorical values to numeric
train_df_for_corr = train_df.copy()
train_df_for_corr.replace({"Normal/Mild": 0, "Moderate": 1, "Severe": 2}, inplace=True)

# Find all columns 
canal_columns = [col for col in columns if 'canal' in col]
foraminal_columns = [col for col in columns if 'foraminal' in col]
subarticular_columns = [col for col in columns if 'subarticular' in col]

# Define a function to plot heatmap
def plot_corr_heatmap(data, title, figsize=(10, 8)):
    corr_matrix = data.corr()
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
    plt.figure(figsize=figsize)
    sns.heatmap(corr_matrix, mask=mask, annot=True, fmt=".2f", cmap="coolwarm", square=True)
    plt.title(title)
    plt.show()

In [21]:
# Overall correlation matrix
plot_corr_heatmap(train_df_for_corr[columns], 'Correlation Matrix', figsize=(20, 15))

In [22]:
# Canal columns correlation matrix
plot_corr_heatmap(train_df_for_corr[canal_columns], 'Correlation Matrix For Canal Columns', figsize=(10, 5))

In [23]:
# Foraminal columns correlation matrix
plot_corr_heatmap(train_df_for_corr[foraminal_columns], 'Correlation Matrix For Foraminal Columns', figsize=(10, 5))

In [24]:
# Subarticular columns correlation matrix
plot_corr_heatmap(train_df_for_corr[subarticular_columns], 'Correlation Matrix For Subarticular Columns', figsize=(10, 5))

From the correlation matrices, we observe that certain spinal conditions demonstrate moderate to strong positive correlations, particularly within the same category (e.g., foraminal narrowing or subarticular stenosis across different levels). This indicates that degeneration at one spinal level is often associated with degeneration at adjacent levels. Furthermore, we note that within the subarticular and foraminal columns, there exists a correlation between the right and left sides. 

This observation will be taken into account when addressing missing labels and during subsequent preprocessing steps. Understanding these relationships is crucial for predicting the potential progression of degeneration throughout the spine.

Next, we will address the imbalance between the 'Severe' and 'Moderate' labels. First, we will count the number of rows that contain at least one occurrence of these labels and then sum the frequency of each label across the dataset. This analysis will enable us to quantify the extent of the imbalance between the two labels.

Additionally, we will check for rows containing missing values (NaNs) and save the indices of those rows. This will facilitate the handling of any potential data issues related to incomplete information during model training.

In [25]:
train_df_label_counts = train_df.copy()

# Find each rows contain severe or moderate cases
train_df_label_counts['contains_Severe'] = train_df.apply(lambda row: 'Severe' in row.values, axis=1)
train_df_label_counts['contains_Moderate'] = train_df.apply(lambda row: 'Moderate' in row.values, axis=1)

# How much severe and moderate cases in each row
train_df_label_counts['Severe_count'] = train_df.apply(lambda x: (x == 'Severe').sum(), axis=1)
train_df_label_counts['Moderate_count'] = train_df.apply(lambda x: (x == 'Moderate').sum(), axis=1)

# Sum how many rows contains severe and moderate cases
severe_count = train_df_label_counts['contains_Severe'].sum()
moderate_count = train_df_label_counts['contains_Moderate'].sum()

# Count how many rows contain the Severe and Moderate labels
print(f'How many rows contain the Severe label: {severe_count}, and the Moderate label: {moderate_count}')

From the results, we observe that 980 rows contain the severe label, while 1963 rows contain the moderate label. As noted at the beginning of the exploratory data analysis (EDA), the `train_df` consists of 1975 rows, indicating that nearly all rows have at least one moderate label. This suggests that almost all patients have at least one moderate condition and half of the patients have at least one secere label.

In [26]:
# Checking for rows with NaN values
rows_with_null_labels = train_df_label_counts.isna().any(axis=1)
num_rows_with_null = rows_with_null_labels.sum()

print(f'How many rows contain at least one Null value: {num_rows_with_null}')

We discovered that there are **185** rows containing at least one NaN value. This means we can choose to either remove them, which would reduce the amount of training and validation data and potentially affect model performance, or find a way to fill the NaN values.

In [27]:
# Saving the study_ids of rows with NaN values
indices_of_NaN = train_df_label_counts[train_df_label_counts.isna().any(axis=1)]['study_id'].tolist()

# Saving the study_ids of rows containing 'Severe'
indices_of_Severe = train_df_label_counts[train_df_label_counts.apply(lambda row: 'Severe' in row.values, axis=1)]['study_id'].tolist()

# Saving the study_ids of rows containing 'Moderate'
indices_of_Moderate = train_df_label_counts[train_df_label_counts.apply(lambda row: 'Moderate' in row.values, axis=1)]['study_id'].tolist()

# Find the study_ids that contain both 'Severe' and 'Moderate'
indices_of_Severe_and_Moderate = train_df_label_counts[train_df_label_counts.apply(lambda row: 'Severe' in row.values and 'Moderate' in row.values, axis=1)]['study_id'].tolist()

In [28]:
# Count how many rows have both Severe and Moderate
print(f'Number of rows with both Severe and Moderate: {len(indices_of_Severe_and_Moderate)}')

Next, we will examine the distribution of the 'Severe' and 'Moderate' labels within the dataset. Our objective is to understand the frequency of these labels in each row, which represents spinal conditions across various levels. We will visualize the counts of severe and moderate cases to further illustrate the extent of the label imbalance.

This analysis will provide valuable insights into the distribution of 'Severe' and 'Moderate' cases among patients.

In [29]:
severe_df = pd.DataFrame(train_df_label_counts['Severe_count'])

# Plot the Severe count distribution for each row
plt.figure(figsize=(8,4))
sns.histplot(severe_df, x='Severe_count', bins=50)
plt.title('Severe Count')
plt.xlabel('Number of severe count')
plt.ylabel('Count')
plt.show()

In [30]:
severe_df = pd.DataFrame(train_df_label_counts['Moderate_count'])

# Plot the Moderate count distribution for each row
plt.figure(figsize=(8,4))
sns.histplot(severe_df, x='Moderate_count', bins=30)
plt.title('Moderate Count')
plt.xlabel('Number of moderate count')
plt.ylabel('Count')
plt.show()

From the histograms, we observe that most rows have very few or no `Severe` labels (the majority having zero or one severe label), indicating a significant imbalance, with very few instances containing multiple 'Severe' conditions. On the other hand, the distribution of `Moderate` labels is more evenly spread, but there is still a noticeable imbalance, with many rows containing fewer moderate conditions. This further reinforces the need to address this imbalance during model training, possibly through resampling techniques or by adjusting class weights.

**Conclusion for train.csv analysis:** 💡📁

1. **Data Quality:**
   - 185 rows (out of 1974 total) contain at least one null value, indicating some missing data.

2. **Class Imbalance:**
   - Significant imbalance exists, with "Normal/Mild" being the most common category.
   - Out of 1974 studies:
     - 980 contain at least one Severe label
     - 1693 contain at least one Moderate label

3. **Spinal Level Distribution:**
   - "Normal/Mild" cases dominate across all spinal levels (L1/L2 to L5/S1).
   - Lower spinal levels show a slight increase in Moderate and Severe cases.

4. **Correlation Patterns:**
   - Moderate to strong positive correlations observed between spinal conditions, especially within the same type.
   - Right and left sides show notable correlation for subarticular and foraminal columns.

5. **Multi-level Involvement:**
   - 952 studies have both Severe and Moderate labels, indicating cases with varying severity across multiple spinal levels.

This analysis highlights the need to address class imbalance and missing data in subsequent preprocessing and modeling steps.

#### train_label_coordinates.csv 📁

The `train_label_coordinates.csv` file provides essential metadata for the localization of degenerative conditions in the images. It includes the patient's study ID (`study_id`), the series of images (`series_id`), and the specific image number within the series (`instance_number`). For each image, it provides coordinates (`x`, `y`) that can be useful for localization tasks such as identifying the specific areas of degeneration. 

In [31]:
label_coords_df.head(5)

In [32]:
label_coords_df.shape

We observe that we have 48,692 rows and 7 columns.

In [33]:
# Check for missing values
label_coords_df.isnull().sum()

Great! There are no missing values in the label_coords_df, so we don’t need to worry about any gaps in the data here.

In `train.csv`, we only had the `study_id` column, but now with the `train_label_coordinates.csv` file, we also have information about `series_id`. This allows us to explore how many unique series are associated with each study. 

We will group the data by `study_id` and count the number of unique `series_id` for each study. This step helps us understand the variability in MRI series across studies and may reveal patterns related to patient scans, such as studies with additional imaging series.


In [34]:
# Get the number of unique series_id for each study_id
unique_series_per_study = label_coords_df.groupby('study_id')['series_id'].nunique().reset_index(name='unique_series_count')

# Plot the distribution of unique series per study id
plt.figure(figsize=(8, 4))
sns.histplot(data=unique_series_per_study, x='unique_series_count', bins=20, kde=True)
plt.title('Distribution of Unique Series per Study ID')
plt.xlabel('Number of Unique Series per Study')
plt.ylabel('Frequency')
plt.show()

The distribution shows that most studies have exactly three unique series, with a small number of studies having four or five series. This suggests that the majority of studies follow a consistent data collection protocol, with three series per study. However, some studies include additional series, possibly indicating more complex cases or different imaging needs for certain patients.

Before moving forward with the exploration, we will check if we have the coordinates for each of the conditions and spinal levels by using a value count on the `study_id` in `train_label_coordinates.csv`. This will help us determine whether every study in the dataset has the expected number of condition labels. Studies with fewer labels may indicate missing information, which could affect model training. By identifying these cases, we can handle them appropriately in future steps.

In [35]:
label_coords_df['study_id'].value_counts()

In [36]:
# Get the sum of condition per study id 
sum_of_condition_per_study = label_coords_df['study_id'].value_counts()
value_counts_df = pd.DataFrame(sum_of_condition_per_study, columns=['count'])

# Plot the distribution of condition per study id 
plt.figure(figsize=(8,4))
sns.histplot(data=value_counts_df, x='count', bins=10)
plt.title('Distribution of Value Count for Study IDs')
plt.xlabel('Number of conditions per Study ID')
plt.ylabel('Frequency')
plt.show()

The histogram reveals that the majority of studies have exactly 25 conditions, which aligns with the expected number of spinal levels and conditions. However, a small number of studies have fewer than 25 conditions. This may indicate missing or incomplete data for certain patients, meaning that not all spinal levels have corresponding coordinates.

To address these cases later on, we will save the indices of all studies that do not have exactly 25 conditions. These studies might require special handling during the modeling process to ensure we do not introduce bias due to incomplete data.

If some of these cases involve the right or left side, we can complete them based on the opposite side’s coordinates by simply adjusting the value for the x-coordinate.

In [37]:
# Save the indices of study IDs that do not have 25 conditions
incomplete_study_ids = value_counts_df[value_counts_df['count'] != 25].index.to_list()

# Print the number of incomplete studies
print(f'Number of studies with fewer than 25 conditions: {len(incomplete_study_ids)}')

**We observed something interesting** there are 185 studies in the `train.csv` that have NaN values, and we also found 185 studies in the `train_label_coordinates.csv` with fewer than 25 conditions. Now, based on the study IDs we saved earlier, we can check if the missing data in `train.csv` aligns with the incomplete data in `train_label_coordinates.csv`. This will help us determine if the missing data is consistent across both datasets.

In [38]:
# Checking if the study_ids with missing values from train_df match with incomplete_study_ids
matching_study_ids = set(indices_of_NaN).intersection(set(incomplete_study_ids))

print(f'Number of matching study IDs between train_df and label_coords_df: {len(matching_study_ids)}')

We discovered that 177 study IDs match between `train.csv` and `train_label_coordinates.csv`, meaning most of the studies with missing data align across both datasets, but not all of them. We will now save the `matching_study_ids` for further analysis and also track the study IDs that do not match in each dataset.

In [39]:
# Save the non-matching study IDs for each dataset
non_matching_train_df_ids = set(indices_of_NaN) - set(matching_study_ids)
non_matching_label_coords_df_ids = set(incomplete_study_ids) - set(matching_study_ids)

# Print the results
print(f'Number of non-matching study IDs in train_df: {len(non_matching_train_df_ids)}')
print(f'Number of non-matching study IDs in label_coords_df: {len(non_matching_label_coords_df_ids)}')

# Save or use these lists for further analysis
matching_study_ids = list(matching_study_ids)
non_matching_train_df_ids = list(non_matching_train_df_ids)
non_matching_label_coords_df_ids = list(non_matching_label_coords_df_ids)

Next, we will proceed to examine the number of unique images (`instance numbers`) utilized per study ID. This analysis will assist us in understanding the variation in the quantity of images available for each study and may uncover interesting patterns, such as whether certain studies possess a greater volume of imaging data than others.

In [40]:
# Get the number of unique series_id for each study_id
unique_instance_per_study = label_coords_df.groupby(['study_id'])['instance_number'].nunique().reset_index(name='unique_instance_number_count')

# Plot distribution of instance number of images 
plt.figure(figsize=(8, 4))
sns.histplot(data=unique_instance_per_study, x='unique_instance_number_count',bins=50, kde=True)
plt.title('Distribution of Instance Numbers')
plt.xlabel('Number of Unique Instances')
plt.ylabel('Frequency')
plt.show()


The plot reveals that the number of unique instance numbers per study generally ranges between 8 and 14, with a peak around 10. This suggests that most studies have a similar number of images available, with only a few studies having a higher or lower number of unique instance numbers. The consistency in the number of images per study may indicate a standardized imaging protocol for most studies, while the studies with a higher or lower number of instances might reflect special cases or differing imaging requirements.



In [41]:
# Visualize x and y coordinates to detect outliers
plt.figure(figsize=(12, 6))
sns.scatterplot(x='x', y='y', data=label_coords_df)
plt.title('Scatter plot of X and Y coordinates')
plt.xlabel('X coordinate')
plt.ylabel('Y coordinate')
plt.show()

The scatter plot reveals a dense cluster of coordinates, indicating that most points fall within a certain range of `x` and `y` values. However, we can spot a few outliers, particularly at the lower end of the `x` and `y` scales. This visualization helps us confirm that the majority of coordinates are tightly grouped, with only a few deviations.

Now that we have completed exploring the `study_id`, `series_id`, and `instance_number` columns, we will move on to analyze the `x` and `y` coordinate distributions. These coordinates are crucial for localization tasks, and visualizing them will allow us to detect any potential outliers or patterns in the dataset.

In [42]:
plt.figure(figsize=(12, 6))
sns.scatterplot(x='x', y='y', hue='level', data=label_coords_df)
plt.title('Scatter plot of X and Y coordinates by Level')
plt.show()

In [43]:
plt.figure(figsize=(12, 8))
sns.scatterplot(x='x', y='y', hue='condition', data=label_coords_df)
plt.title('Scatter plot of X and Y coordinates by condition')
plt.show()

From the scatter plots grouped by spine level and condition, we observe that the coordinates are fairly well-distributed across the different levels and conditions. However, there are noticeable clusters in specific areas for certain spine levels and conditions, which may be related to the anatomical structure of the spine and the regions typically affected by each condition. 

There is also a possibility of outliers in the `x` and `y` coordinates. If this is the case, we will need to consider how much we can trust these coordinates for use in our model or how we might address these issues.

We will explore the distribution of the `x` and `y` coordinates to gain insights into the spatial characteristics of the dataset. Visualizing the distribution of these values will help us understand the range and common locations of the coordinates across different studies. This can also assist in detecting any outliers or anomalies in the spatial data.


In [44]:
# x coordinaten distribution
plt.figure(figsize=(8, 6))
sns.histplot(data=label_coords_df, x='x', kde=True)
plt.title('Distribution of x')
plt.xlabel('x values')
plt.ylabel('Frequency')
plt.show()

In [45]:
# y coordinaten distribution
plt.figure(figsize=(8, 6))
sns.histplot(data=label_coords_df, x='y', kde=True)
plt.title('Distribution of y')
plt.xlabel('y values')
plt.ylabel('Frequency')
plt.show()

The distribution of the `x` coordinates shows a clear concentration of values between 150 and 350, indicating that the majority of the spinal coordinates fall within this range. There is a noticeable peak around 200, suggesting that most of the degenerative areas are localized in this region. However, there are some **outliers** toward both lower and higher `x` values, which could represent edge cases or anomalies in the dataset.

Similarly, the distribution of `y` coordinates exhibits a comparable pattern, with most of the values concentrated between 150 and 400. The peak at around 200-250 for both `x` and `y` coordinates suggests that the dataset is capturing similar regions of interest across multiple studies. A few **outliers** can also be seen in the higher range, indicating rare cases where the coordinates deviate from the general trend. This visualization helps confirm the consistency in spatial localization across studies, with a few exceptional cases.


In [46]:
def add_severity(df):
    df['severity'] = None

    # Iterate over rows in the dataframe
    for idx, coor_row in df.iterrows():
        try:
            # Construct the patient_severity key
            patient_severity = f"{coor_row['condition'].lower().replace(' ', '_')}_{coor_row['level'].lower().replace('/','_')}"
            
            # Filter for the corresponding study_id
            study_row = train_df[train_df['study_id'] == int(coor_row['study_id'])]
            
            # Check if the key exists in train_df for that study
            if patient_severity in study_row.columns:
                severity = study_row[patient_severity].values[0]
                df.at[idx, 'severity'] = severity  # Update the severity
            else:
                print(f"Warning: {patient_severity} not found in train_df for study_id {coor_row['study_id']}.")
                df.at[idx, 'severity'] = 'Unknown'  # Set as 'Unknown' if no match found
            
        except Exception as e:
            print(f"Error processing study_id: {coor_row['study_id']} - {e}")
            df.at[idx, 'severity'] = 'Unknown'  # Set as 'Unknown' in case of an error
    
    return df

# Apply the function
label_coords_df_counts = label_coords_df.copy()
label_coords_df_counts = add_severity(label_coords_df_counts)
label_coords_df_counts.head()

In [47]:
plt.figure(figsize=(8, 8))
sns.kdeplot(x='x', y='y', data=label_coords_df_counts[label_coords_df_counts['severity'].isin(['Severe', 'Moderate'])], 
            hue='condition', fill=True, cmap="Reds", thresh=0.05)
plt.title('Density Plot of Severe and Moderate Condition Locations')
plt.show()

From the density plot, we can observe the most frequent areas of severe and moderate conditions.
These visualizations indicate that certain conditions tend to cluster around similar `x` and `y` 
coordinate locations, showing potential patterns in the dataset. 

This information might be 
valuable when training models that need to understand the spatial relationships of these 
conditions in medical images.

**Conclusion for train_label_coordinates.csv analysis:** 💡📁

1. **Data Quality:**
   - No missing values (NaN) were found in the `train_label_coordinates.csv` file..
   - Most study_ids are associated with 3 series_ids.
   - A smaller portion (approximately 300) have 4 series_ids.
   - Not all labels from `train.csv` have corresponding coordinates in this file.
   - **177 study_ids** with missing data in `train.csv` match those without complete coordinate sets (25 coordinates per condition) in this file.

2. **Study and Series Distribution:**
   - Most study_ids are associated with 3 series_ids.
   - A smaller portion (approximately 300) have 4 series_ids.
   - A very small number of studies have 5 series_ids.

3. **Instance Number Distribution:**
   - The majority of study_ids use 10-12 instance images for the conditions.
   - The overall range of instance numbers is between 6 and 17.

4. **Coordinate Distribution:**
   - X coordinates:
      - Concentrated between 150 and 350.
      - Notable peak around 200.
   - Y-coordinates:
      - Similar pattern, concentrated between 150 and 400.
      - Peak observed around 200-250.

5. **Implications:**
   - The consistent range of coordinates suggests a standardized imaging area across most studies.
   - Variations in series and instance numbers may indicate differences in imaging protocols or patient-specific factors.
   - The matching missing data between `train.csv` and this file suggests systematic issues in data collection or processing for certain studies.

This analysis provides insights into the structure and quality of the coordinate data, highlighting areas that may require attention in preprocessing and modeling stages, particularly regarding missing data and potential standardization of imaging protocols.

#### train_series_descriptions.csv 📁

The `train_series_descriptions.csv` file provides essential metadata for the description of each series of images (`series_id`). We will start by checking for missing values in the file and then compare the distribution of the series count per study ID with the one from the `label_coords_df` to ensure consistency across the datasets.

In [48]:
series_desc_df.head(5)

In [49]:
series_desc_df.shape

We observe that we have 6294 rows and 3 columns.

In [50]:
# Check for missing values
series_desc_df.isna().sum()

There are no missing values in the `train_series_descriptions.csv` file, which is great. Since this file contains metadata about the image series, we can now move on to checking the distribution of the series count per study ID and compare it with the distribution from `label_coords_df`.

In [51]:
# Group by study_id to get the count of series per study
series_per_study_df = series_desc_df.groupby(['study_id'])['series_id'].count().reset_index(name='series_count')

# Plot the distribution of series count per study
plt.figure(figsize=(8, 6))
sns.histplot(data=series_per_study_df, x='series_count', kde=True)
plt.title('Distribution of Series Count per Study')
plt.xlabel('Number of Series per Study')
plt.ylabel('Frequency')
plt.show()


The distribution of the series count per study in `series_desc_df` appears to be consistent with the distribution in `label_coords_df`. This confirms that both datasets follow a similar pattern regarding the number of series associated with each study.


Next, we will explore the `series_description` column in `series_desc_df`. This column provides categorical information about each series, and understanding its distribution will give us insights into the types of imaging series present in the dataset. We will visualize this using both a pie chart and a bar plot to see the relative proportions and absolute counts of each type.

In [52]:
# Count the frequency of each series_description
series_description_counts = series_desc_df['series_description'].value_counts()

# Plot a pie chart to visualize the distribution
plt.figure(figsize=(8, 6))
series_description_counts.plot.pie(autopct='%1.1f%%', startangle=90)
plt.title('Distribution of Series Description')
plt.ylabel('')  # Hide the y-label for aesthetics
plt.show()

In [53]:
# Bar plot to visualize the series description counts
plt.figure(figsize=(8, 6))
sns.barplot(x=series_description_counts.index, y=series_description_counts.values)
plt.title('Frequency of Series Descriptions')
plt.xlabel('Series Description')
plt.ylabel('Count')
plt.xticks(rotation=45, ha='right')  
plt.show()

The distribution of `series_description` across the dataset shows a fairly even split between three main categories: **Axial T2**, **Sagittal T1**, and **Sagittal T2/STIR**. From the pie chart, we observe that **Axial T2** constitutes the largest portion, representing around 37.2% of the total series. Meanwhile, **Sagittal T1** and **Sagittal T2/STIR** are almost equally represented, with both accounting for roughly 31.4% and 31.5% respectively.

The bar plot further confirms this distribution, showing that all three types of imaging series are well-represented in the dataset, with **Axial T2** slightly more prevalent. This balanced distribution suggests that the dataset contains a diverse range of imaging modalities, which could provide robust information for training models.

**Conclusion for train_series_descriptions.csv analysis:** 💡📁

1.  **Data Quality:**
    - No missing values in the train_series_descriptions.csv file, which is excellent for data integrity.

2.  **Series Distribution:**
    - The distribution of series count per study is consistent with the distribution observed in label_coords_df.
    - Most studies have 3 series, with a small portion (close to 300) having 4 series, and a very small portion having 5 series.

3.  **Series Description Distribution:**
    - The dataset shows a fairly even split between three main categories:
        - Axial T2: ~37.2% of the total series.
        - Sagittal T1: ~31.4% of the total series.
        - Sagittal T2/STIR: ~31.5% of the total series.

4.  **Imaging Modalities:**
    - The balanced distribution of imaging modalities (Axial T2, Sagittal T1, and Sagittal T2/STIR) suggests a diverse range of imaging data, which could provide robust information for training models.

5.  **Data Structure:**
    - The file contains essential metadata for each series of images, including study_id, series_id, and series_description.

This analysis reveals a well-structured dataset with a balanced distribution of imaging modalities. The absence of missing values and the consistency with other parts of the dataset (like label_coords_df) indicate good data quality. The even distribution of different imaging types (Axial T2, Sagittal T1, and Sagittal T2/STIR) suggests that the dataset contains a comprehensive range of spinal imaging data, which could be beneficial for developing robust machine learning models for spinal condition analysis.

With this, we can proceed to the EDA of the images 🩻🩻.

### EDA for Images 🕵🩻

In this section, we'll explore the DICOM (Digital Imaging and Communications in Medicine) images in our dataset. DICOM is the standard format for medical imaging.

Understanding these images is essential for our analysis, as they contain valuable information about patient health and potential abnormalities. By visualizing and examining the DICOM files, we'll gain insights into the structure of the spinal cord and assess the quality of our imaging data. This step is critical for building accurate models and deriving meaningful conclusions in our medical imaging project.

We'll use specialized tools to load and display these DICOM images, as well as explore the rich metadata they contain. This metadata provides crucial context about the imaging process, patient positioning, and other technical details that can influence our analysis.


In [54]:
# Load and display an image
image_path  = r'/Users/danipopov/Projects/RSNA2024/data/train_images/4003253/702807833/1.dcm'
ds = pydicom.dcmread(image_path)
pxl_arr = ds.pixel_array

def plot_dicom_img(dicom_array):

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Visualize the image with enhanced contrast using 'bone' colormap
    axes[0].imshow(dicom_array, cmap=plt.cm.bone)
    axes[0].axis('off')
    axes[0].set_title('Bone')

    # Visualize the image with enhanced contrast using 'gray' colormap
    axes[1].imshow(dicom_array, cmap='gray')
    axes[1].axis('off')
    axes[1].set_title('Gray')

    # Visualize the image with default colormap
    axes[2].imshow(dicom_array)
    axes[2].axis('off')
    axes[2].set_title('Default')

    plt.tight_layout()
    plt.show()

plot_dicom_img(pxl_arr)

In [55]:
# Display DICOM metadata
print(ds)

The DICOM file contains a wealth of metadata, We can see details such as:

- **Patient ID**: Identifies the patient associated with the scan.
- **Series Description**: Provides the imaging modality used (e.g., `T2`).
- **Slice Thickness**: Specifies the thickness of each MRI slice (`4.0 mm`).
- **Spacing Between Slices**: Indicates the distance between consecutive slices (`4.8 mm`).
- **Patient Position**: Describes the patient's position during the scan (`Head First-Supine`).
- **Image Dimensions**: The image is `640x640` pixels with a pixel spacing of `0.46875 mm`.
- **Photometric Interpretation**: The image is in grayscale (`MONOCHROME2`).

Understanding this metadata helps us assess the imaging quality and provides valuable context for future image analysis, especially when working with models like U-Net or ResNet, which may require certain image preprocessing steps such as resizing, normalization, and more.

Creating Metadata for Each Study ID will help us to streamline the process of displaying and organizing images for each study id, we will create a structured metadata dictionary. This metadata will help us easily access images based on their `study_id` and `series_id`. For each scan, the metadata object will contain:

- The path to the folder containing the images (`folder_path`).
- An array of Series Instance IDs that represent different series within the study(`series_ids`).
- An array of Series Descriptions that describe the imaging modality or scan type (e.g., T2, T1).

```
meta_df = {
    study_id: {
        'folder_path': ...,        # Path to the folder containing the images
        'series_ids': [ ... ],     # List of SeriesInstanceUIDs for the study
        'series_desc': [ ... ]     # List of descriptions for each series (T2, T1, etc.)
    },
    ...
}
```

In [56]:
# Getting a list of all the study IDs and paths to their images
images_dir_path = r'/Users/danipopov/Projects/RSNA2024/data/train_images'
study_id_list = os.listdir(images_dir_path)
study_id_paths = [(x, f"{images_dir_path}/{x}") for x in study_id_list]

# Initialize the metadata dictionary
meta_df = {}

# Process each study and its series
for study_id, study_folder_path in study_id_paths:
    series_ids = []
    series_descriptions = []
    
    # Get all the series IDs (folders) within the study folder
    try:
        series_folders = os.listdir(study_folder_path)
    except FileNotFoundError as e:
        print(f"Error: Folder not found for study {study_id}. Skipping this study.")
        continue  # Skip this study if the folder doesn't exist

    # Process each series in the study folder
    for series_id in series_folders:
        try:
            # Fetch the series description from the dataframe
            series_description = series_desc_df[series_desc_df['series_id'] == int(series_id)]['series_description'].iloc[0]
        except (IndexError, ValueError):
            # Handle cases where series_id is not found in the dataframe or can't be converted to int
            series_description = 'Unknown'

        # Append series ID and description to the lists
        series_ids.append(series_id)
        series_descriptions.append(series_description)
    
    # Add metadata for the current study_id
    meta_df[int(study_id)] = {
        'folder_path': study_folder_path,
        'series_ids': series_ids,
        'series_descriptions': series_descriptions
    }

In [57]:
# Example of meta_df object
keys = list(meta_df.keys())[:4]  # Get the first 4 keys
values = [meta_df[key] for key in keys]  # Get the values corresponding to the keys

# Display the key-value pairs
for key, value in zip(keys, values):
    print(f'Key: {key} \n')
    print(f'Value: {value} \n')

In the above output, we displayed examples from our `meta_df` structure. Each entry in this dictionary represents a study, identified by the `study_id`, along with its associated folder path, series IDs, and corresponding series descriptions. This metadata structure will facilitate efficient access to and exploration of the DICOM images for each study, allowing us to load and visualize the images based on their `study_id` and `series_id`. This will also streamline any future analysis involving individual series types, such as 'Axial T2' or 'Sagittal T1'. 

In [58]:
def display_images(patient_meta_data, max_images_per_row=4):
    # Extract folder path, series IDs, and descriptions from metadata
    file_path = patient_meta_data['folder_path']
    series_ids = patient_meta_data['series_ids']
    series_descriptions = patient_meta_data['series_descriptions']
    
    # Loop through each series to display images
    for idx, series_id in enumerate(series_ids):
        # Construct the path to the series folder
        series_id_path = os.path.join(file_path, series_id)
        # Find all DICOM files in the series directory
        images = [img for img in glob.glob(f'{series_id_path}/*.dcm')]
        num_images  = len(images)

        if num_images == 0:
            print(f"No images found in series: {series_descriptions[idx]} ({series_id})")
            continue  # Skip series if no images are found

        num_rows = (num_images + max_images_per_row - 1) // max_images_per_row
        # Create a grid of subplots
        fig, axes = plt.subplots(num_rows, max_images_per_row, figsize=(5, 1.5 * num_rows))
        # Flatten axes if multiple rows, or convert to list for single row
        if num_rows > 1:
            axes = axes.flatten()
        else:
            axes = [axes]

        # Loop through images and plot each one
        for i, image in enumerate(images):
            try:
                # Read DICOM file and extract the pixel array
                ds = pydicom.dcmread(str(image))
                pixel_array_numpy = ds.pixel_array
                ax = axes[i]
                ax.imshow(pixel_array_numpy, cmap=plt.cm.bone)
                ax.axis('off')
            except Exception as e:
                print(f"Error loading image: {image}, Error: {e}")     

        # Turn off unused subplots
        for i in range(num_images, len(axes)):
            axes[i].axis('off')

        plt.suptitle(series_descriptions[idx], fontsize=16)
        plt.tight_layout()
        plt.show()
        
# Usage Example:
key = keys[0]
patient_meta_data = meta_df[key]
display_images(patient_meta_data)

In the above function, we displayed all the images for a selected `study_id` and its associated series. Each series is visualized in a grid layout. This function will be highly useful for visualizing the medical images before further analysis.

We can notice that in each series, some images provide more insights into the anatomical structure and offer better visualization of the spine.

Additionally, using the DICOM data, we will check the distribution of pixel sizes for each image across different series descriptions.

To further explore the MRI data, it's useful to view the images in motion, particularly when analyzing slice-by-slice changes across a single series. By converting the series into an animated, we can simulate a 3D continuous scan, which helps in detecting patterns and abnormalities that may not be immediately obvious in static images.

In [59]:
def mri_3d(study_id, patient_meta_data, series_index=1):

    # Extract folder path, series IDs, and descriptions from metadata
    file_path = patient_meta_data['folder_path']
    series_id = patient_meta_data['series_ids'][series_index]
    series_description = patient_meta_data['series_descriptions'][series_index]

    def load_dicom(path):
        ds = pydicom.dcmread(path)
        img = ds.pixel_array
        img = cv2.normalize(img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        return img

    # Construct the path to the series folder
    series_id_path = os.path.join(file_path, series_id)

    # Get sorted list of DICOM image paths
    img_paths = glob.glob(f'{series_id_path}/*.dcm')
    img_paths = sorted(img_paths, key=lambda x: int(x.split('/')[-1].split('.')[0]))

    # Load images from paths
    imges = [load_dicom(img_path) for img_path in img_paths]

    # Set up the animation
    rc('animation', html='jshtml')

    fig = plt.figure(figsize=(6, 6))
    plt.axis('off')
    im = plt.imshow(imges[0])
    text = plt.text(0.05, 0.05, f'Slide {1}', transform=fig.transFigure, fontsize=16, color='darkblue')
    title = plt.title(f'Study ID: {study_id}\nSeries ID: {series_id}\nDescription: {series_description}',
                     fontsize=14, color='black')

    def animate(i):
        im.set_array(imges[i])
        text.set_text(f'Slide {i + 1}')
        return im, text

    anim = animation.FuncAnimation(fig, animate, frames=len(imges), blit=True)

    plt.close(fig)

    return anim

# Usage
key = keys[0]
patient_meta_data = meta_df[key]
display(mri_3d(key, patient_meta_data, series_index=0))

In [60]:
display(mri_3d(key, patient_meta_data, series_index=1))

In [61]:
display(mri_3d(key, patient_meta_data, series_index=2))

In [62]:
display(mri_3d(key, patient_meta_data, series_index=3))

After successfully creating 3D repersntion of the MRI sequences, we can now visualize the anatomical changes across the slices.

This dynamic representation helps us better understand the spatial progression of degeneration or abnormality. Moving forward, we will explore advanced image analyses by incorporating metadata such as slice coordinates and condition severity to detect correlations between slice location and the severity of degeneration.

We can observe in the dynamic representation that our previous idea about some images in the layout showing better representation is confirmed. These images indeed provide more detailed information about spinal degenerative problems.

As we recall from the *Introduction*, each of the spinal degenerative conditions can be best viewed from different types of images. Let's remind ourselves:

- **Foraminal Narrowing** is best viewed in the sagittal plane

- **Subarticular Stenosis** is best visualized in the axial plane

- **Canal Stenosis** is also best viewed from the axial plane


Next, we will move on to add the seires description and visualize all the coordinates given for a single image using the information in `train_label_coordinates.csv` and `meta_df`. This will help us understand if there's consistency in the values of the coordinates on a single image and potentially provide more insights into the data structure and annotation process.


In [63]:
def add_desc(df, meta_df):
    df['series_desc'] = None

    # Iterate over rows in the dataframe
    for idx, coor_row in df.iterrows():
        try:
            # Find the meta_df for the study_id
            meta_info = meta_df[int(coor_row['study_id'])]

            # Find the index of the series_id in the meta_info
            series_index = meta_info['series_ids'].index(str(coor_row['series_id']))

            # Get the corresponding series description
            series_desc = meta_info['series_descriptions'][series_index]

            # Update the series_desc column
            df.at[idx, 'series_desc'] = series_desc

        except KeyError:
            print(f"Error processing study_id: {coor_row['study_id']} - Study ID not found in meta_df")
            df.at[idx, 'series_desc'] = 'Unknown'
        except ValueError:
            print(f"Error processing study_id: {coor_row['study_id']} - Series ID not found in meta_df")
            df.at[idx, 'series_desc'] = 'Unknown'
        except Exception as e:
            print(f"Error processing study_id: {coor_row['study_id']} - {e}")
            df.at[idx, 'series_desc'] = 'Unknown'

    return df

# Apply the function
coords_with_desc = label_coords_df.copy()
coords_with_desc = add_desc(coords_with_desc, meta_df)
coords_with_desc.head(20)

In [64]:
def show_coor(images_dir_path, filtered_instances, study_id, series_id, instance_number):
    from matplotlib.colors import LogNorm
    import matplotlib.patches as patches

    lag = 20
    img_path = os.path.join(images_dir_path,
                                str(study_id),
                                str(series_id),
                                str(instance_number) + '.dcm')

    ds = pydicom.dcmread(img_path)
    fig, ax  = plt.subplots(figsize=(8, 6))

    ax.imshow(ds.pixel_array) # cmap ='CMRmap' or 'bone'
    ax.axis('off')
    # Create a legend
    legend_elements = []

    a = 25 * max(ds.pixel_array.shape)/640

    for _, row in filtered_instances.iterrows():
        x, y = row['x'], row['y']

        rect2 = patches.Rectangle((x - a, y - a), 2*a, 2*a, linewidth=2, edgecolor='white', facecolor='none')
        rect1 = patches.Rectangle((x - a, y - a), 2*a, 2*a, linewidth=2, facecolor='white', alpha = 0.25)

        ax.add_patch(rect2)
        ax.add_patch(rect1)

        # Add the condition to the legend
        legend_elements.append(patches.Patch(facecolor='none', edgecolor='r', ))

    # Add title
    title = f"Study: {study_id}, Series: {series_id}, Instance: {instance_number}"
    ax.set_title(title, fontsize=20)

    # Display additional columns:
    for _, row in filtered_instances.iterrows():
        text = f"level {row['level']}, {row['condition']}"
        ax.text(row['x'] + lag, row['y']+np.random.randint(-15, 15), text, fontsize=10, color='white', verticalalignment='center_baseline')

    plt.show()



filtered_instances = coords_with_desc[
    (coords_with_desc['study_id'] == 4003253) &
    (coords_with_desc['series_id'] == 702807833) &
    (coords_with_desc['instance_number'] == 8)
]

show_coor(images_dir_path, filtered_instances, 4003253, 702807833, 8)

In [65]:
filtered_instances = coords_with_desc[
    (coords_with_desc['study_id'] == 4003253) &
    (coords_with_desc['series_id'] == 1054713880) &
    (coords_with_desc['instance_number'] == 11)
]

show_coor(images_dir_path, filtered_instances, 4003253, 1054713880, 11)

In [66]:
filtered_instances = coords_with_desc[
    (coords_with_desc['study_id'] == 4003253) &
    (coords_with_desc['series_id'] == 2448190387) &
    (coords_with_desc['instance_number'] == 11)
]
show_coor(images_dir_path, filtered_instances, 4003253, 2448190387, 11)

**Big conclusion** I gained from this exploration of the coordinates on single images and all conditions is:

- **Images for Spinal Canal Stenosis** are maybe all in the sagittal plane.

- **Images for Subarticular Stenosis** can be best viewed in the axial plane, but without the **correct** image in the sagittal plane, we can't understand to which **disk-level** the axial plane image refers.

- **Images for Neural Foraminal Narrowing** appear to be in the sagittal plane, which is consistent with best practices for visualizing this condition.

- There seems to be consistency in coordinate placement across different spine levels for the same condition, which could be useful for developing automated detection algorithms.

- The coordinates for different conditions (e.g., Spinal Canal Stenosis vs. Neural Foraminal Narrowing) are placed at different locations within the image, reflecting the distinct anatomical areas affected by each condition.

- Some images show multiple coordinates for the same condition at different spine levels, indicating that a single image can be used to assess multiple levels simultaneously.

- The variation in y-coordinates across different spine levels (e.g., L1/L2 vs. L5/S1) suggests that the model or annotation process takes into account the natural curvature of the spine.

- The above image shows similarity in the left and right coordinates for the same condition and spine level. We can use this information if there are missing coordinates on one of the sides.

Let's explore the outlier values (minimum and maximum) for the `x` and `y` coordinates to see how they are visualized as annotations on the images.

In [67]:
outlier_instance = coords_with_desc[coords_with_desc['y'] == coords_with_desc['y'].min()]
show_coor(images_dir_path, outlier_instance, outlier_instance['study_id'].iloc[0], outlier_instance['series_id'].iloc[0], outlier_instance['instance_number'].iloc[0])

In [68]:
outlier_instance = coords_with_desc[coords_with_desc['y'] == coords_with_desc['y'].max()]
show_coor(images_dir_path, outlier_instance, outlier_instance['study_id'].iloc[0], outlier_instance['series_id'].iloc[0], outlier_instance['instance_number'].iloc[0])

In [69]:
outlier_instance = coords_with_desc[coords_with_desc['x'] == coords_with_desc['x'].min()]
show_coor(images_dir_path, outlier_instance, outlier_instance['study_id'].iloc[0], outlier_instance['series_id'].iloc[0], outlier_instance['instance_number'].iloc[0])

In [70]:
outlier_instance = coords_with_desc[coords_with_desc['x'] == coords_with_desc['x'].max()]
show_coor(images_dir_path, outlier_instance, outlier_instance['study_id'].iloc[0], outlier_instance['series_id'].iloc[0], outlier_instance['instance_number'].iloc[0])

Based on the visualizations of the minimum and maximum y-coordinate instances, we discoverd that in each `study_id` each `series_id` can have diffrent size for the pixels (like we see above). We will keep it in mind when working with diffrent `series_id` isnide the `study_id`.

Now that we've explored the coordinates on the MRI images, we will check our hypothesis
that **all images for Spinal Canal Stenosis are in the sagittal plane**. We will do this
using `coords_with_desc`

In [71]:
def count_and_plot_conditions(df):
    conditions = ['Spinal Canal Stenosis', 'Subarticular Stenosis', 'Neural Foraminal Narrowing']
    view_types = ['Sagittal', 'Axial']

    # Initialize a dictionary to store the counts
    counts = {cond: {view: 0 for view in view_types} for cond in conditions}

    # Count occurrences
    for _, row in df.iterrows():
        condition = row['condition']
        series_desc = row['series_desc']

        # Normalize condition names
        if 'Neural Foraminal Narrowing' in condition:
            condition = 'Neural Foraminal Narrowing'
        elif 'Subarticular Stenosis' in condition:
            condition = 'Subarticular Stenosis'
        else:
            condition = 'Spinal Canal Stenosis'

        if condition in conditions:
            view = 'Sagittal' if 'Sagittal' in series_desc else 'Axial'
            counts[condition][view] += 1

    # Convert counts to DataFrame for easier plotting
    plot_data = pd.DataFrame(counts).T

    # Create the bar plot
    ax = plot_data.plot(kind='bar', figsize=(12, 6), width=0.8)

    # Customize the plot
    plt.title('Condition Counts by View Type', fontsize=16)
    plt.xlabel('Condition', fontsize=12)
    plt.ylabel('Count', fontsize=12)
    plt.legend(title='View Type')
    plt.xticks(rotation=45, ha='right')

    # Add value labels on the bars
    for container in ax.containers:
        ax.bar_label(container, padding=3)

    plt.tight_layout()
    plt.show()

    return counts

# Usage
counts = count_and_plot_conditions(coords_with_desc)
print(counts)

Now we can confirm that our hypothesis is correct. The conclusions drawn from our analysis are:

-  **Spinal Canal Stenosis**: All images for this condition are indeed in the sagittal plane, which aligns with our initial hypothesis. However, this is not ideal as axial images are typically preferred for this condition. We will need to address this when building our model.

- **Subarticular Stenosis**: This condition is exclusively imaged in the axial plane, which is consistent with best practices.

- **Neural Foraminal Narrowing**: This condition is exclusively imaged in the sagittal plane, which is consistent with best practices.

These findings provide valuable insights into the imaging practices for different spinal conditions. They largely confirm our initial hypotheses and align with established best practices in radiological imaging for spine pathologies. However, the exclusive use of sagittal images for Spinal Canal Stenosis warrants further investigation to understand if this represents a limitation in the dataset or if there are specific cases where sagittal imaging is preferred for this condition.


Next, we remeber from the EDA of the `train_label_coordinates.csv` that no a lot of `study_id` has missing coordiantion so will try to find wich are missing and see if we can fill the coordiation if for example it's the left and right side in Axial and we have one sice coordiations

In [72]:
labels = [col for col in train_df.columns if col != 'study_id']
labels = set(labels)

def find_missing_coords(df, label_cols, meta_df):
    visited = set()
    missing_coords = {study_id: set() for study_id in meta_df.keys()}

    for study_id, study_df in tqdm(df.groupby('study_id'), desc="Processing studies"):
        if study_id in visited:
            continue

        visited.add(study_id)

        labels = set()
        for _, row in study_df.iterrows():
            condition = row['condition'].replace(' ', '_').lower()
            level = row['level'].replace('/', '_').lower()
            label = f"{condition}_{level}"
            labels.add(label)

        missing = label_cols - labels
        missing_coords[study_id] = missing

    return missing_coords

# Usaage
missing_coords = find_missing_coords(label_coords_df, labels, meta_df)

In [73]:
def plot_missing_labels(missing_coords):
    # Create a DataFrame with all possible labels
    all_labels = sorted(set().union(*missing_coords.values()))
    df = pd.DataFrame(index=missing_coords.keys(), columns=all_labels)

    # Fill the DataFrame
    for study_id, labels in missing_coords.items():
        df.loc[study_id, list(labels)] = True  # Convert set to list here

    # Fill NaN with False (present) and True with True (missing)
    df = df.fillna(False)

    # Create a heatmap
    plt.figure(figsize=(20, 10))
    sns.heatmap(df, cbar=False, cmap='viridis')
    plt.title('Missing Labels Across Study IDs')
    plt.xlabel('Labels')
    plt.ylabel('Study IDs')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print(f"Total number of study IDs with missing labels: {len(missing_coords)}")
    print("\nMost common missing labels:")
    label_counts = df.sum().sort_values(ascending=False)
    for label, count in label_counts.items():
        if count > 0:
            print(f"{label}: {count}")

# Usage
plot_missing_labels(missing_coords)

**Key Findings from the Heatmap and Missing Labels Analysis**

1. **Missing Labels Distribution**: Most study IDs that has missing labels have approximately 3 missing labels.

2. **Common Missing Labels**: Predominantly related to Subarticular Stenosis and Spinal Canal Stenosis, particularly at the L1/L2 and L2/L3 levels.

3. **Subarticular Stenosis Pattern**:
   - Both right and left sides of L1/L2 and L2/L3 are frequently missing.
   - All Subarticular Stenosis images are in the axial plane.
   - Potential to infer or estimate missing coordinates based on patterns from other disk levels (e.g., L3/L4, L4/L5, L5/S1).

4. **Spinal Canal Stenosis Pattern**:
   - Significant missing coordinates, especially at L1/L2 (70 missing) and L2/L3 (40 missing) levels.
   - All Spinal Canal Stenosis images are in the Sagittal T2/STIR plane.
   - Consistency in imaging plane may allow for interpolation or estimation of missing coordinates.

5. **Systematic Issue**: The pattern suggests a potential systematic issue in data collection or annotation, particularly for upper lumbar levels (L1/L2 and L2/L3).

**Conclusion for Images analysis:**  💡🩻

1. **Image Types and Consistency:**
    - Each `study_id` contains three types of MRI sequences: Sagittal T2/STIR, Sagittal T1, and Axial T2.
    - The number of images per `series_id` can vary, indicating potential inconsistencies in image acquisition or selection.

2. **Image Dimensions:**
    - Images within the same `study_id` can have different sizes, which may require standardization during preprocessing.
    - Some images have unusually large dimensions, potentially due to high-resolution scans or different acquisition protocols.

3. **Coordinate Annotations:**
    - Coordinates are available for different severity labels (Normal, Moderate, Severe) across various spinal conditions.
    - Not all images in `train_images` have coordinates.
    - Some images lack coordinate annotations, but it's possible to fill them using available coordinates for the same `instance_number` or `series_id`.

4. **Visualization Insights:**
    - 3D-like visualizations (Animations) of image slices provided a comprehensive view of spinal anatomy across different MRI sequences.
    - Displaying images with overlaid coordinates helped in understanding the annotation process and potential challenges in identifying specific spinal conditions.

5. **Condition-Specific Imaging:**
    - Spinal Canal Stenosis: Exclusively imaged in the sagittal plane (T2/STIR), which may not be ideal.
    - Subarticular Stenosis: Consistently imaged in the axial plane, aligning with best practices.
    - Neural Foraminal Narrowing: Imaged in the sagittal plane, consistent with standard practices.

6. **Data Quality and Consistency:**
    - The presence of outliers in coordinate data suggests the need for careful data cleaning.
    - Systematic missing data, particularly for upper lumbar levels (L1/L2 and L2/L3), indicates a potential bias in the dataset.


### 💡 Conclusions from the Exploratory Data Analysis 💡

1. **Data Structure and Composition**:
   - Dataset consists of MRI images in DICOM format and corresponding CSV files with labels and coordinates.
   - Each `study_id` contains three types of MRI sequences: Sagittal T2/STIR, Sagittal T1, and Axial T2.
   - Number of images per `series_id` varies, indicating potential inconsistencies in image acquisition or selection.

2. **Image Characteristics**:
   - Images within the same `study_id` can have different sizes, requiring standardization during preprocessing.
   - Some images have unusually large dimensions, potentially due to high-resolution scans or different acquisition protocols.

3. **Coordinate Annotations**:
   - Coordinates available for different severity labels (Normal, Moderate, Severe) across various spinal conditions.
   - Not all images in `train_images` have coordinates.
   - Some images lack coordinate annotations, but can potentially be filled using available coordinates for the same `instance_number` or `series_id`.

4. **Condition-Specific Imaging**:
   - Spinal Canal Stenosis: Exclusively imaged in sagittal plane (T2/STIR), which may not be ideal.
   - Subarticular Stenosis: Consistently imaged in axial plane, aligning with best practices.
   - Neural Foraminal Narrowing: Imaged in sagittal plane, consistent with standard practices.

5. **Missing Data and Outliers**:
   - Presence of outliers in coordinate data suggests need for careful data cleaning.
   - Systematic missing data, particularly for upper lumbar levels (L1/L2 and L2/L3), indicates potential bias in dataset.
   - Approximately 185 rows in `train_df` have missing values (NaN).

6. **Data Quality and Consistency**:
   - Dataset shows inconsistencies in image acquisition and annotation, needing address during preprocessing.
   - Need for standardization of image sizes and handling of missing coordinates.

7. **Class Imbalance**:
   - Imbalance between Normal/Mild, Moderate, and Severe cases, needs consideration in model development.

These findings provide crucial insights for our next steps. We will explore different approaches to fill the missing coordinates by preprocessing the data and choosing a model to train. This model will fill in the coordinates for all images inside the `train_label_coordinates` dataset. Afterward, we will be able to train efficiently by passing all available images and detect the ROI (Region of Interest).

Let's explore! 🔍🕵️

## Finding coordinates using YOLO


In [74]:
assert False, "Stop here"

To handle find all coordinates for each spine level and right, left side in axial plane, I will use the YOLOv8 model. 

For each plane (Sagittal T1, Sagittal T2/STIR, Axial T2) we will build a dataset for each plane and train a model to recognize the levels of the spine, and for the axial plane we will use the yolo model to detect the left and right side of the of the disk.

The dataset will have the following structure:

    ```
    dataset/
    ├── images/
    │   ├── train/
    │   │   ├── patient1_scan1.png    # Original MRI image, NO boxes drawn
    │   │   ├── patient1_scan2.png    # Original MRI image, NO boxes drawn
    │   └── val/
    │       ├── patient2_scan1.png    # Original MRI image, NO boxes drawn
    │       ├── patient2_scan2.png    # Original MRI image, NO boxes drawn
    ├── labels/
    │   ├── train/
    │   │   ├── patient1_scan1.txt    # Contains: "0 0.5 0.5 0.1 0.1"
    │   │   ├── patient1_scan2.txt    # Contains coordinates
    │   └── val/
    │       ├── patient2_scan1.txt    # Contains coordinates
    │       ├── patient2_scan2.txt    # Contains coordinates
    ```

The content of each label file (e.g., patient1_scan1.txt) will follow this format:
0 0.5 0.5 0.1 0.1   # Format: <class> <x_center> <y_center> <width> <height>

Before starting to build, we need all images to be the same size. We will do this by converting all the images to PNG format. We need to consider normalization and resizing all the images so they have the same size and pixel value range. My images are usually 640x640 or 320x320. One more thing to consider is that the coordinates for each image are based on the original image size, so if we change it, we will need to address the issue of adjusting the coordinate values.

In [75]:
sagt1_df = coords_with_desc[coords_with_desc['series_desc'] == 'Sagittal T1'].copy()
sagt2_df = coords_with_desc[coords_with_desc['series_desc'] == 'Sagittal T2/STIR'].copy()
axialt2_df = coords_with_desc[coords_with_desc['series_desc'] == 'Axial T2'].copy()

### Sagittal T1

In [75]:
class SagT1_YOLO:
    def __init__(self, df, images_dir, output_dir, img_size=384):
        """
        Initialize Sagittal T1 YOLO detector
        Args:
            df: DataFrame with annotations
            images_dir: Path to DICOM images
            output_dir: Path to save processed dataset
            img_size: Target image size for YOLO
        """
        self.df = df
        self.images_dir = images_dir
        self.output_dir = output_dir
        self.img_size = img_size
        
        # Create initial directory structure
        self.create_dataset_structure()
        
        # Process dataset
        self.instance_coords_df = self.process_instance_coordinates()
        self.processed_df = self.process_spine_dataset()
        
        # Create YAML and train model
        self.yaml_path = self.create_dataset_yaml()
        self.model, self.results = self.train_yolo()

    def create_dataset_structure(self):
        """Create YOLO dataset directory structure"""
        for split in ['train', 'val']:
            for subdir in ['images', 'labels']:
                path = os.path.join(self.output_dir, split, subdir)
                os.makedirs(path, exist_ok=True)

    def process_instance_coordinates(self):
        """
        Process coordinates for Sagittal T1 images with coordinate sharing across instances
        Returns DataFrame with coordinates for all instances, sharing information across the series
        """
        result_records = []

        # Group by series to process related instances
        for (study_id, series_id), series_data in self.df.groupby(['study_id', 'series_id']):
            # First, collect all coordinates for each level in the series
            series_level_coords = {}
            for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']:
                level_key = level.lower().replace('/', '_')
                level_data = series_data[series_data['level'] == level]

                if not level_data.empty:
                    right_coords = []
                    left_coords = []
                    instances_with_data = set()  # Track which instances have data for this level

                    for _, row in level_data.iterrows():
                        instances_with_data.add(row['instance_number'])
                        if 'Right' in row['condition']:
                            right_coords.append((row['x'], row['y']))
                        elif 'Left' in row['condition']:
                            left_coords.append((row['x'], row['y']))

                    # Calculate coordinates for this level
                    level_coords = {
                        'instances': instances_with_data,
                        'coords': {
                            'right': (np.mean([x for x, _ in right_coords]) if right_coords else None,
                                    np.mean([y for _, y in right_coords]) if right_coords else None),
                            'left': (np.mean([x for x, _ in left_coords]) if left_coords else None,
                                   np.mean([y for _, y in left_coords]) if left_coords else None)
                        }
                    }

                    # Calculate center coordinates if possible
                    if right_coords and left_coords:
                        level_coords['coords']['center'] = (
                            (level_coords['coords']['right'][0] + level_coords['coords']['left'][0]) / 2,
                            (level_coords['coords']['right'][1] + level_coords['coords']['left'][1]) / 2
                        )
                    elif right_coords:
                        level_coords['coords']['center'] = level_coords['coords']['right']
                    elif left_coords:
                        level_coords['coords']['center'] = level_coords['coords']['left']
                    else:
                        level_coords['coords']['center'] = (None, None)

                    series_level_coords[level_key] = level_coords

            # Get all unique instance numbers in the series
            all_instances = series_data['instance_number'].unique()

            # Create records for each instance, sharing coordinates across the series
            for instance_number in all_instances:
                record = {
                    'study_id': study_id,
                    'series_id': series_id,
                    'instance_number': instance_number
                }

                # Add source tracking
                record['coordinate_sources'] = {}

                # Add coordinates for all levels to this instance
                for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']:
                    level_key = level.lower().replace('/', '_')
                    if level_key in series_level_coords:
                        level_data = series_level_coords[level_key]

                        # Record which instances provided data for this level
                        record['coordinate_sources'][level_key] = list(level_data['instances'])

                        # Add right coordinates
                        if level_data['coords']['right'][0] is not None:
                            record[f'{level_key}_right_x'] = level_data['coords']['right'][0]
                            record[f'{level_key}_right_y'] = level_data['coords']['right'][1]

                        # Add left coordinates
                        if level_data['coords']['left'][0] is not None:
                            record[f'{level_key}_left_x'] = level_data['coords']['left'][0]
                            record[f'{level_key}_left_y'] = level_data['coords']['left'][1]

                        # Add center coordinates
                        if level_data['coords']['center'][0] is not None:
                            record[f'{level_key}_center_x'] = level_data['coords']['center'][0]
                            record[f'{level_key}_center_y'] = level_data['coords']['center'][1]

                result_records.append(record)

        # Convert to DataFrame
        result_df = pd.DataFrame(result_records)

        # Add metadata about coordinate availability
        result_df['available_levels'] = result_df.apply(
            lambda row: [
                level for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
                if not pd.isna(row.get(f"{level.lower().replace('/', '_')}_center_x"))
            ],
            axis=1
        )

        result_df['total_levels'] = result_df['available_levels'].apply(len)

        # Print statistics
        print("\nDataset Statistics:")
        print(f"Total series processed: {len(result_df['series_id'].unique())}")
        print(f"Total instances processed: {len(result_df)}")
        print("\nLevel availability:")
        for level in ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']:
            count = result_df[f'{level}_center_x'].notna().sum()
            print(f"{level}: {count} instances ({count/len(result_df)*100:.1f}%)")

        return result_df

    def create_yolo_annotation(self, row, image_width, image_height):
        """Create YOLO format annotations for Sagittal T1 images"""
        annotations = []
        box_width = 0.05
        box_height = 0.05
        
        level_map = {'l1_l2': 0, 'l2_l3': 1, 'l3_l4': 2, 'l4_l5': 3, 'l5_s1': 4}
        
        for level, idx in level_map.items():
            x_coord = row.get(f'{level}_center_x')
            y_coord = row.get(f'{level}_center_y')
            
            if pd.notna(x_coord) and pd.notna(y_coord):
                x_norm = x_coord / image_width
                y_norm = y_coord / image_height
                annotations.append(f"{idx} {x_norm:.6f} {y_norm:.6f} {box_width:.6f} {box_height:.6f}")
        
        return annotations

    def process_spine_dataset(self):
        """Process and save dataset in YOLO format"""
        # Split studies
        studies = self.instance_coords_df['study_id'].unique()
        train_studies, val_studies = train_test_split(studies, train_size=0.8, random_state=42)
        
        processed_counts = {'train': 0, 'val': 0}
        failed_cases = []
        
        for _, row in tqdm(self.instance_coords_df.iterrows(), desc="Processing Sagittal T1 images"):
            try:
                # Convert IDs to integers for path construction
                study_id = str(int(row['study_id']))
                series_id = str(int(row['series_id']))
                instance_number = str(int(row['instance_number']))
                
                # Construct image path
                img_path = os.path.join(self.images_dir, study_id, series_id, instance_number)
                
                # Try with and without .dcm extension
                if os.path.exists(img_path + '.dcm'):
                    img_path = img_path + '.dcm'
                elif not os.path.exists(img_path):
                    raise FileNotFoundError(f"Image not found: {img_path}")
                
                # Read and process image
                ds = pydicom.dcmread(img_path)
                image = ds.pixel_array
                h, w = image.shape
                
                # Create annotations
                annotations = self.create_yolo_annotation(row, w, h)
                if not annotations:
                    continue
                
                # Prepare image
                image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
                image_resized = cv2.resize(image_normalized, (self.img_size, self.img_size))
                
                # Save files
                is_train = row['study_id'] in train_studies
                split = 'train' if is_train else 'val'
                
                img_filename = f"{study_id}_{series_id}_{instance_number}.png"
                label_filename = f"{study_id}_{series_id}_{instance_number}.txt"
                
                cv2.imwrite(os.path.join(self.output_dir, split, 'images', img_filename), 
                           image_resized)
                with open(os.path.join(self.output_dir, split, 'labels', label_filename), 'w') as f:
                    f.write('\n'.join(annotations))
                
                # Add split information to row
                row['split'] = split
                processed_counts[split] += 1
                
            except Exception as e:
                failed_cases.append((study_id, series_id, instance_number, str(e)))
        
        print(f"\nProcessing Summary:")
        print(f"Training images: {processed_counts['train']}")
        print(f"Validation images: {processed_counts['val']}")
        
        if failed_cases:
            print("\nFailed cases:")
            for case in failed_cases:
                print(f"Study {case[0]}, Series {case[1]}, Instance {case[2]}: {case[3]}")
        
        return self.instance_coords_df

    def create_dataset_yaml(self):
        """Create YOLO dataset configuration file"""
        yaml_content = {
            'path': os.path.abspath(self.output_dir),
            'train': 'train/images',
            'val': 'val/images',
            'nc': 5,  # number of classes
            'names': {
                0: 'L1/L2',
                1: 'L2/L3',
                2: 'L3/L4',
                3: 'L4/L5',
                4: 'L5/S1'
            }
        }

        yaml_path = os.path.join(self.output_dir, 'dataset.yaml')
        with open(yaml_path, 'w') as f:
            yaml.dump(yaml_content, f, sort_keys=False)

        return yaml_path

    def train_yolo(self):
        """Train YOLO model for Sagittal T1 images"""
        try:
            model = YOLO('yolov8x.pt')
            
            config = {
                'data': self.yaml_path,
                'imgsz': self.img_size,
                'batch': 8,
                'epochs': 2, 
                'patience': 5,
                'device': 'mps',
                'workers': 4,
                'project': 'spine_detection',
                'name': 'sagittal_t1_yolo',
                'exist_ok': True,
                'pretrained': True,
                'optimizer': 'AdamW',
                'verbose': True,
                'seed': 42,
                'deterministic': True,
                'dropout': 0.1,
                'lr0': 0.001,
                'lrf': 0.01,
                'momentum': 0.937,
                'weight_decay': 0.0005,
                'warmup_epochs': 5,
                'warmup_momentum': 0.8,
                'box': 7.5,
                'cls': 0.5,
                'dfl': 1.5,
                'close_mosaic': 10,
                'amp': True,  
                'rect': True,  
                'multi_scale': True,  
                'val': True, 
            }
            
            results = model.train(**config)
            return model, results
            
        except Exception as e:
            print(f"Error training model: {str(e)}")
            return None, None

    def get_training_stats(self):
        """Get statistics about the processed dataset"""
        stats = {
            'total_series': len(self.instance_coords_df['series_id'].unique()),
            'total_instances': len(self.instance_coords_df),
            'train_images': len(self.processed_df[self.processed_df['split'] == 'train']),
            'val_images': len(self.processed_df[self.processed_df['split'] == 'val']),
            'level_coverage': {}
        }
        
        for level in ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']:
            count = self.instance_coords_df[f'{level}_center_x'].notna().sum()
            stats['level_coverage'][level] = {
                'count': count,
                'percentage': count/len(self.instance_coords_df)*100
            }
            
        return stats

In [76]:
# Sagittal T1
images_dir  = '/Users/danipopov/Projects/RSNA2024/data/train_images'
output_dir = '/Users/danipopov/Projects/RSNA2024/data/spine_dataset_sagt1'

sag_t1_model = SagT1_YOLO(
    df=sagt1_df,
    images_dir=images_dir,
    output_dir=output_dir,
    img_size=384
)

**Note:** In the notebook we train the model for 1 or 2 epochs, beacuse it takes a long time to train on my personal computer.
The acctual model was trained for 100 epochs in the Kaggle notebook and all the other yolo models as well, the function to save the model is below was used in the kaggle notebook to save the weights of the best model.

In [None]:
def save_trained_model(model, best_model_path, model_type, save_path='/kaggle/working/'):
    """
    Save the trained YOLO model with simple fixed naming
    
    Args:
        model: YOLO model object
        best_model_path: Path to the best model weights
        model_type: String indicating model type ('sag_t1', 'sag_t2', or 'axial_t2')
        save_path: Base path to save the model
    """
    # Define simple model names
    model_names = {
        'sag_t1': 'sagittal_t1_spine_detector.pt',
        'sag_t2': 'sagittal_t2_spine_detector.pt',
        'axial_t2': 'axial_t2_spine_detector.pt'
    }
    
    if model_type not in model_names:
        raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(model_names.keys())}")
    
    try:
        if not os.path.exists(best_model_path):
            print(f"Best model weights not found at {best_model_path}")
            return
            
        final_save_path = os.path.join(save_path, model_names[model_type])
        
        # Copy the model file
        shutil.copy(best_model_path, final_save_path)
        print(f"Model saved to {final_save_path}")
        
    except Exception as e:
        print(f"Error saving model: {str(e)}")

# Save models with appropriate type
save_trained_model(
    sag_t1_model.model, 
    '/kaggle/working/spine_detection/sagittal_t1_yolo/weights/best.pt',
    model_type='sag_t1'
)

#save_trained_model(
#    sag_t2_model.model, 
#    '/kaggle/working/spine_detection/sagittal_t2_yolo/weights/best.pt',
#    model_type='sag_t2'
#)

#save_trained_model(
#    axial_t2_model.model, 
#    '/kaggle/working/spine_detection/axial_t2_yolo/weights/best.pt',
#    model_type='axial_t2'
#)


Below we remove the files that we don't need anymore.


In [78]:
# Remove all the folders execpt the weights
shutil.rmtree("/Users/danipopov/Projects/RSNA2024/spine_detection/sagittal_t1_yolo")

Next we will test the model on a few images and see how it preforms and can be used to detect the spine levels.

In [188]:
def plot_spine_predictions(image_path, model_path, 
                         conf_threshold=0.25, iou_threshold=0.45, img_size=384):
    """
    Plot YOLO predictions for spine levels on a DICOM image
    
    Args:
        image_path: Path to DICOM image
        model_path: Path to saved YOLO model
        conf_threshold: Confidence threshold for predictions
        iou_threshold: IOU threshold for NMS
        img_size: Image size for model input
    """
    # Load model
    model = YOLO(model_path)
    
    # Read DICOM
    ds = pydicom.dcmread(image_path)
    image = ds.pixel_array
    
    # Normalize and resize
    image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    image_resized = cv2.resize(image_normalized, (img_size, img_size))
    
    # Convert grayscale to RGB
    image_rgb = np.stack([image_resized] * 3, axis=-1)
    
    # Create figure
    plt.figure(figsize=(8, 6))
    
    # Plot original image
    plt.subplot(1, 2, 1)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    
    # Plot image with predictions
    plt.subplot(1, 2, 2)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Predictions')
    
    # Get predictions
    results = model.predict(
        source=image_rgb,
        conf=conf_threshold,
        iou=iou_threshold
    )
    
    # Define colors for each level
    colors = ['red', 'green', 'blue', 'yellow', 'purple']
    level_names = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
    
    if results[0].boxes is not None:
        boxes = results[0].boxes.cpu().numpy()
        
        # Sort boxes by y-coordinate to display levels in order
        box_data = []
        for box in boxes:
            cls_id = int(box.cls[0])
            conf = box.conf[0]
            x1, y1, x2, y2 = box.xyxy[0]
            box_data.append((y1, cls_id, conf, x1, y1, x2, y2))
        
        box_data.sort()  # Sort by y1 coordinate
        
        # Plot each detection
        for i, (_, cls_id, conf, x1, y1, x2, y2) in enumerate(box_data):
            color = colors[cls_id]
            level_name = level_names[cls_id]
            
            # Draw bounding box
            plt.gca().add_patch(plt.Rectangle(
                (x1, y1), x2-x1, y2-y1,
                fill=False, color=color, linewidth=2
            ))
            
            # Add label
            plt.text(
                x2 + 5, (y1 + y2) / 2, 
                f'{level_name}: {conf:.2f}',
                color=color, fontsize=8, verticalalignment='center',
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none')
            )
            
            # Print detection info
            print(f"Found {level_name} with confidence {conf:.2f}")
    else:
        print("No detections found")
    
    plt.axis('off')
    plt.tight_layout()
    plt.show()


In [251]:
def process_multiple_images(image_paths, model_path, 
                          conf_threshold=0.25, iou_threshold=0.45):
    """
    Process multiple images and display their predictions, skipping images without predictions
    
    Args:
        image_paths: List of paths to DICOM images
        model_path: Path to saved YOLO model
        conf_threshold: Confidence threshold for predictions
        iou_threshold: IOU threshold for NMS
    """
    # First, filter images that have predictions
    valid_images = []
    model = YOLO(model_path)
    
    for img_path in image_paths:
        ds = pydicom.dcmread(img_path)
        image = ds.pixel_array
        image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        image_resized = cv2.resize(image_normalized, (384, 384))
        image_rgb = np.stack([image_resized] * 3, axis=-1)
        
        results = model.predict(
            source=image_rgb,
            conf=conf_threshold,
            iou=iou_threshold
        )
        
        if results[0].boxes is not None and len(results[0].boxes) > 0:
            valid_images.append(img_path)
    
    if not valid_images:
        print("No images with valid predictions found")
        return
    
    # Process only images with predictions
    for idx, img_path in enumerate(valid_images):
        if idx == 5:
            break
        plot_spine_predictions(img_path, model_path, conf_threshold, iou_threshold)

In [252]:
# Usage example for Sagittal T1:
image_paths = glob.glob('/Users/danipopov/Projects/RSNA2024/data/train_images/4646740/3486248476/*.dcm')
process_multiple_images(image_paths, '/Users/danipopov/Projects/RSNA2024/models/sagittal_t1_spine_detector.pt')

We can see that the model is able to detect the spine levels, but it is not perfect some images it detects more levels than it should and some it detects less and some it detects none.

So we will need to use some postprocessing to make sure that we get the correct number of levels for each image or to fill in the missing levels or remove images that have some levels that are not visible.

### Sagittal T2/STIR

In [83]:
class SagT2_YOLO:
    def __init__(self, df, images_dir, output_dir, img_size=384):
        """
        Initialize Sagittal T2/STIR YOLO detector
        Args:
            df: DataFrame with annotations
            images_dir: Path to DICOM images
            output_dir: Path to save processed dataset
            img_size: Target image size for YOLO
        """
        self.df = df
        self.images_dir = images_dir
        self.output_dir = output_dir
        self.img_size = img_size
        
        # Create initial directory structure
        self.create_dataset_structure()
        
        # Process dataset
        self.instance_coords_df = self.process_instance_coordinates()
        self.processed_df = self.process_spine_dataset()
        
        # Create YAML and train model
        self.yaml_path = self.create_dataset_yaml()
        self.model, self.results = self.train_yolo()

    def create_dataset_structure(self):
        """Create YOLO dataset directory structure"""
        for split in ['train', 'val']:
            for subdir in ['images', 'labels']:
                path = os.path.join(self.output_dir, split, subdir)
                os.makedirs(path, exist_ok=True)

    def process_instance_coordinates(self):
        """
        Process coordinates for Sagittal T2 images with coordinate sharing across instances
        Returns DataFrame with coordinates for all instances, sharing information across the series
        """
        result_records = []

        # Group by series to process related instances
        for (study_id, series_id), series_data in self.df.groupby(['study_id', 'series_id']):
            # First, collect all coordinates for each level in the series
            series_level_coords = {}
            for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']:
                level_data = series_data[series_data['level'] == level]

                if not level_data.empty:
                    coords = []
                    instances_with_data = set()  # Track which instances have data for this level

                    for _, row in level_data.iterrows():
                        instances_with_data.add(row['instance_number'])
                        coords.append((row['x'], row['y']))

                    # Calculate average coordinates for this level
                    if coords:
                        avg_x = np.mean([x for x, _ in coords])
                        avg_y = np.mean([y for _, y in coords])
                        
                        series_level_coords[level] = {
                            'instances': instances_with_data,
                            'coords': (avg_x, avg_y),
                            'original_coords': coords  # Keep original coordinates for reference
                        }

            # Get all unique instance numbers in the series
            all_instances = series_data['instance_number'].unique()

            # Create records for each instance
            for instance_number in all_instances:
                instance_data = series_data[series_data['instance_number'] == instance_number]
                
                record = {
                    'study_id': study_id,
                    'series_id': series_id,
                    'instance_number': instance_number,
                    'coordinate_sources': {}  # Track where coordinates came from
                }

                # Process each spinal level
                for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']:
                    level_key = level.replace('/', '_').lower()
                    
                    # Check if this instance has original coordinates for this level
                    instance_level_data = instance_data[instance_data['level'] == level]
                    
                    if not instance_level_data.empty:
                        # Use original coordinates for this instance
                        x = instance_level_data.iloc[0]['x']
                        y = instance_level_data.iloc[0]['y']
                        record['coordinate_sources'][level_key] = 'original'
                    elif level in series_level_coords:
                        # Use shared coordinates from series
                        x, y = series_level_coords[level]['coords']
                        record['coordinate_sources'][level_key] = 'shared'
                    else:
                        # No coordinates available
                        x = y = None
                        record['coordinate_sources'][level_key] = 'missing'
                    
                    record[f'{level_key}_x'] = x
                    record[f'{level_key}_y'] = y

                result_records.append(record)

        # Convert to DataFrame
        result_df = pd.DataFrame(result_records)

        # Add metadata about coordinate availability
        result_df['available_levels'] = result_df.apply(
            lambda row: [
                level for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
                if not pd.isna(row[f"{level.replace('/', '_').lower()}_x"])
            ],
            axis=1
        )
        
        result_df['total_levels'] = result_df['available_levels'].apply(len)

        # Print statistics
        print("\nDataset Statistics:")
        print(f"Total series processed: {len(result_df['series_id'].unique())}")
        print(f"Total instances processed: {len(result_df)}")
        print("\nLevel availability:")
        for level in ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']:
            count = result_df[f'{level}_x'].notna().sum()
            original_count = sum(result_df['coordinate_sources'].apply(
                lambda x: x.get(level, '') == 'original'
            ))
            shared_count = sum(result_df['coordinate_sources'].apply(
                lambda x: x.get(level, '') == 'shared'
            ))
            print(f"{level}: {count} instances ({count/len(result_df)*100:.1f}%)")
            print(f"  - Original: {original_count}")
            print(f"  - Shared: {shared_count}")

        return result_df

    def create_yolo_annotation(self, row, image_width, image_height):
        """Create YOLO format annotations for Sagittal T2 images"""
        annotations = []
        box_width = 0.05  # Relative box width
        box_height = 0.05  # Relative box height
        
        level_map = {
            'l1_l2': 0,
            'l2_l3': 1,
            'l3_l4': 2,
            'l4_l5': 3,
            'l5_s1': 4
        }
        
        for level, idx in level_map.items():
            x_coord = row.get(f'{level}_x')
            y_coord = row.get(f'{level}_y')
            
            if pd.notna(x_coord) and pd.notna(y_coord):
                x_norm = x_coord / image_width
                y_norm = y_coord / image_height
                annotations.append(f"{idx} {x_norm:.6f} {y_norm:.6f} {box_width:.6f} {box_height:.6f}")
        
        return annotations

    def process_spine_dataset(self):
        """Process and save dataset in YOLO format"""
        # Split studies
        studies = self.instance_coords_df['study_id'].unique()
        train_studies, val_studies = train_test_split(studies, train_size=0.8, random_state=42)
        
        processed_counts = {'train': 0, 'val': 0}
        failed_cases = []
        
        for _, row in tqdm(self.instance_coords_df.iterrows(), desc="Processing Sagittal T1 images"):
            try:
                # Convert IDs to integers for path construction
                study_id = str(int(row['study_id']))
                series_id = str(int(row['series_id']))
                instance_number = str(int(row['instance_number']))
                
                # Construct image path
                img_path = os.path.join(self.images_dir, study_id, series_id, instance_number)
                
                # Try with and without .dcm extension
                if os.path.exists(img_path + '.dcm'):
                    img_path = img_path + '.dcm'
                elif not os.path.exists(img_path):
                    raise FileNotFoundError(f"Image not found: {img_path}")
                
                # Read and process image
                ds = pydicom.dcmread(img_path)
                image = ds.pixel_array
                h, w = image.shape
                
                # Create annotations
                annotations = self.create_yolo_annotation(row, w, h)
                if not annotations:
                    continue
                
                # Prepare image
                image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
                image_resized = cv2.resize(image_normalized, (self.img_size, self.img_size))
                
                # Save files
                is_train = row['study_id'] in train_studies
                split = 'train' if is_train else 'val'
                
                img_filename = f"{study_id}_{series_id}_{instance_number}.png"
                label_filename = f"{study_id}_{series_id}_{instance_number}.txt"
                
                cv2.imwrite(os.path.join(self.output_dir, split, 'images', img_filename), 
                           image_resized)
                with open(os.path.join(self.output_dir, split, 'labels', label_filename), 'w') as f:
                    f.write('\n'.join(annotations))
                
                # Add split information to row
                row['split'] = split
                processed_counts[split] += 1
                
            except Exception as e:
                failed_cases.append((study_id, series_id, instance_number, str(e)))
        
        print(f"\nProcessing Summary:")
        print(f"Training images: {processed_counts['train']}")
        print(f"Validation images: {processed_counts['val']}")
        
        if failed_cases:
            print("\nFailed cases:")
            for case in failed_cases:
                print(f"Study {case[0]}, Series {case[1]}, Instance {case[2]}: {case[3]}")
        
        return self.instance_coords_df

    def create_dataset_yaml(self):
        """Create YOLO dataset configuration file"""
        yaml_content = {
            'path': os.path.abspath(self.output_dir),
            'train': 'train/images',
            'val': 'val/images',
            'nc': 5,  # number of classes
            'names': {
                0: 'L1/L2',
                1: 'L2/L3',
                2: 'L3/L4',
                3: 'L4/L5',
                4: 'L5/S1'
            }
        }

        yaml_path = os.path.join(self.output_dir, 'dataset.yaml')
        with open(yaml_path, 'w') as f:
            yaml.dump(yaml_content, f, sort_keys=False)

        return yaml_path

    def train_yolo(self):
        """Train YOLO model for Sagittal T1 images"""
        try:
            model = YOLO('yolov8x.pt')
            
            config = {
                'data': self.yaml_path,
                'imgsz': self.img_size,
                'batch': 8,
                'epochs': 1, 
                'patience': 5,
                'device': 'mps',
                'workers': 4,
                'project': 'spine_detection',
                'name': 'sagittal_t2_yolo',
                'exist_ok': True,
                'pretrained': True,
                'optimizer': 'AdamW',
                'verbose': True,
                'seed': 42,
                'deterministic': True,
                'dropout': 0.1,
                'lr0': 0.001,
                'lrf': 0.01,
                'momentum': 0.937,
                'weight_decay': 0.0005,
                'warmup_epochs': 5,
                'warmup_momentum': 0.8,
                'box': 7.5,
                'cls': 0.5,
                'dfl': 1.5,
                'close_mosaic': 10,
                'amp': True,  
                'rect': True,  
                'multi_scale': True,  
                'val': True, 
            }
                        
            results = model.train(**config)
            return model, results
            
        except Exception as e:
            print(f"Error training model: {str(e)}")
            return None, None

    def get_training_stats(self):
        """Get statistics about the processed dataset"""
        stats = {
            'total_series': len(self.instance_coords_df['series_id'].unique()),
            'total_instances': len(self.instance_coords_df),
            'train_images': len(self.processed_df[self.processed_df['split'] == 'train']),
            'val_images': len(self.processed_df[self.processed_df['split'] == 'val']),
            'level_coverage': {}
        }
        
        for level in ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']:
            count = self.instance_coords_df[f'{level}_center_x'].notna().sum()
            stats['level_coverage'][level] = {
                'count': count,
                'percentage': count/len(self.instance_coords_df)*100
            }
            
        return stats

In [84]:
# Sagittal T2/STIR
images_dir  = '/Users/danipopov/Projects/RSNA2024/data/train_images'
output_dir = '/Users/danipopov/Projects/RSNA2024/data/spine_dataset_segt2'

sag_t2_model = SagT2_YOLO(
    df=sagt2_df,
    images_dir=images_dir,
    output_dir=output_dir,
    img_size=384
)

In [85]:
# Remove all the folders execpt the weights
shutil.rmtree("/Users/danipopov/Projects/RSNA2024/spine_detection/sagittal_t2_yolo")

In [253]:
# Usage example for Sagittal T2:
image_paths = glob.glob('/Users/danipopov/Projects/RSNA2024/data/train_images/4646740/3666319702/*.dcm')
process_multiple_images(image_paths, '/Users/danipopov/Projects/RSNA2024/models/sagittal_t2_spine_detector.pt')

We can notice the yolo model for the Sagittal T2 is not very good, as it detects only 10% of the images.

### Axial T2

For the Axial T2, we want our yolo model to detect both left and right sides of the disk in the axial plane.

In [87]:
class AxialT2_YOLO:
    def __init__(self, df, images_dir, output_dir, img_size=384):
        """
        Initialize Axial T2 YOLO detector
        Args:
            df: DataFrame with annotations
            images_dir: Path to DICOM images
            output_dir: Path to save processed dataset
            img_size: Target image size for YOLO
        """
        self.df = df
        self.images_dir = images_dir
        self.output_dir = output_dir
        self.img_size = img_size
        
        # Create initial directory structure
        self.create_dataset_structure()
        
        # Process dataset
        self.processed_records = self.process_instance_coordinates()
        self.processed_df = self.process_spine_dataset()
        
        # Create YAML and train model
        self.yaml_path = self.create_dataset_yaml()
        self.model, self.results = self.train_yolo()

    def create_dataset_structure(self):
        """Create YOLO dataset directory structure"""
        for split in ['train', 'val']:
            for subdir in ['images', 'labels']:
                path = os.path.join(self.output_dir, split, subdir)
                os.makedirs(path, exist_ok=True)

    def process_instance_coordinates(self):
        """
        Process coordinates for Axial images, creating separate records for each level
        that has both left and right coordinates
        Returns list of records with coordinates for each complete level
        """
        result_records = []

        # Group by instance to process each image separately
        for (study_id, series_id, instance_number), instance_data in self.df.groupby(['study_id', 'series_id', 'instance_number']):
            # Process each level separately
            levels_data = {}
            
            for _, row in instance_data.iterrows():
                level = row['level']
                condition = row['condition']
                x, y = row['x'], row['y']
                
                if level not in levels_data:
                    levels_data[level] = {'left': None, 'right': None}
                
                if 'Left' in condition:
                    levels_data[level]['left'] = (x, y)
                elif 'Right' in condition:
                    levels_data[level]['right'] = (x, y)
            
            # Only create records for levels with both coordinates
            for level, coords in levels_data.items():
                if coords['left'] is not None and coords['right'] is not None:
                    record = {
                        'study_id': study_id,
                        'series_id': series_id,
                        'instance_number': instance_number,
                        'level': level,
                        'left_coord': coords['left'],
                        'right_coord': coords['right']
                    }
                    result_records.append(record)
        
        # Print statistics
        print("\nDataset Statistics:")
        total_complete = len(result_records)
        
        print(f"Total complete level pairs: {total_complete}")
        print("\nBy level statistics:")
        
        for level in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']:
            level_count = sum(1 for r in result_records if r['level'] == level)
            if total_complete > 0:
                percentage = (level_count / total_complete) * 100
            else:
                percentage = 0
            print(f"{level}: {level_count} complete pairs ({percentage:.1f}%)")
        
        return result_records

    def create_yolo_annotation(self, record, image_width, image_height):
        """
        Create YOLO format annotations for a specific level
        Returns list of annotations for left and right sides
        """
        annotations = []
        box_width = 0.05
        box_height = 0.05
        
        # Both coordinates must be present
        x, y = record['left_coord']
        x_norm = x / image_width
        y_norm = y / image_height
        annotations.append(f"0 {x_norm:.6f} {y_norm:.6f} {box_width:.6f} {box_height:.6f}")
        
        x, y = record['right_coord']
        x_norm = x / image_width
        y_norm = y / image_height
        annotations.append(f"1 {x_norm:.6f} {y_norm:.6f} {box_width:.6f} {box_height:.6f}")
        
        return annotations

    def process_spine_dataset(self):
        """Process and save dataset in YOLO format"""
        # Split studies
        studies = set(record['study_id'] for record in self.processed_records)
        train_studies, val_studies = train_test_split(list(studies), train_size=0.8, random_state=42)
        
        processed_counts = {'train': 0, 'val': 0}
        failed_cases = []
        
        for record in tqdm(self.processed_records, desc="Processing Axial images"):
            try:
                study_id = str(int(record['study_id']))
                series_id = str(int(record['series_id']))
                instance_number = str(int(record['instance_number']))
                level = record['level'].lower().replace('/', '_')
                
                # Construct image path
                img_path = os.path.join(self.images_dir, study_id, series_id, f"{instance_number}.dcm")
                if not os.path.exists(img_path):
                    img_path = os.path.join(self.images_dir, study_id, series_id, instance_number)
                    if os.path.exists(img_path + '.dcm'):
                        img_path = img_path + '.dcm'
                    else:
                        raise FileNotFoundError(f"Image not found: {img_path}")
                
                # Read and process image
                ds = pydicom.dcmread(img_path)
                image = ds.pixel_array
                h, w = image.shape
                
                # Create annotations
                annotations = self.create_yolo_annotation(record, w, h)
                
                # Prepare image
                image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
                image_resized = cv2.resize(image_normalized, (self.img_size, self.img_size))
                
                # Determine split
                is_train = record['study_id'] in train_studies
                split = 'train' if is_train else 'val'
                
                # Create unique filenames including level information
                base_filename = f"{study_id}_{series_id}_{instance_number}_{level}"
                img_filename = f"{base_filename}.png"
                label_filename = f"{base_filename}.txt"
                
                # Save files
                cv2.imwrite(os.path.join(self.output_dir, split, 'images', img_filename), 
                           image_resized)
                with open(os.path.join(self.output_dir, split, 'labels', label_filename), 'w') as f:
                    f.write('\n'.join(annotations))
                
                processed_counts[split] += 1
                
            except Exception as e:
                failed_cases.append((study_id, series_id, instance_number, str(e)))
        
        print(f"\nProcessing Summary:")
        print(f"Training images: {processed_counts['train']}")
        print(f"Validation images: {processed_counts['val']}")
        
        if failed_cases:
            print("\nFailed cases:")
            for case in failed_cases:
                print(f"Study {case[0]}, Series {case[1]}, Instance {case[2]}: {case[3]}")
        
        return pd.DataFrame(self.processed_records)

    def create_dataset_yaml(self):
        """Create YOLO dataset configuration file"""
        yaml_content = {
            'path': os.path.abspath(self.output_dir),
            'train': 'train/images',
            'val': 'val/images',
            'nc': 2,  # number of classes (left and right)
            'names': {
                0: 'left',
                1: 'right'
            }
        }

        yaml_path = os.path.join(self.output_dir, 'dataset.yaml')
        with open(yaml_path, 'w') as f:
            yaml.dump(yaml_content, f, sort_keys=False)

        return yaml_path

    def train_yolo(self):
        """Train YOLO model"""
        try:
            model = YOLO('yolov8x.pt')
            
            config = {
                'data': self.yaml_path,
                'imgsz': self.img_size,
                'batch': 8,
                'epochs': 1,
                'patience': 5,
                'device': 'mps',
                'workers': 8,
                'project': 'spine_detection',
                'name': 'axial_t2_yolo',
                'exist_ok': True,
                'pretrained': True,
                'optimizer': 'AdamW',
                'verbose': True,
                'seed': 42,
                'deterministic': True,
                'dropout': 0.2,
                'lr0': 0.001,
                'lrf': 0.01,
                'momentum': 0.937,
                'weight_decay': 0.0005,
                'warmup_epochs': 10,
                'warmup_momentum': 0.8,
                'box': 7.5,
                'cls': 0.5,
                'dfl': 1.5,
                'close_mosaic': 10,
                'amp': True
            }
            
            results = model.train(**config)
            return model, results
            
        except Exception as e:
            print(f"Error training model: {str(e)}")
            return None, None


In [88]:
# Axial T2
images_dir  = '/Users/danipopov/Projects/RSNA2024/data/train_images'
output_dir = '/Users/danipopov/Projects/RSNA2024/data/spine_dataset_axialt2'

axial_t2_model = AxialT2_YOLO(
    df=axialt2_df,
    images_dir=images_dir,
    output_dir=output_dir,
    img_size=384
)

In [89]:
# Remove all the folders execpt the weights
shutil.rmtree("/Users/danipopov/Projects/RSNA2024/spine_detection/axial_t2_yolo")
shutil.rmtree("/Users/danipopov/Projects/RSNA2024/data/spine_dataset_axialt2")

In [201]:
def plot_axial_predictions(image_path, model_path='axial_t2_spine_detector.pt', 
                         conf_threshold=0.25, iou_threshold=0.45, img_size=384):
    """
    Plot YOLO predictions for axial images showing left and right sides
    
    Args:
        image_path: Path to DICOM image
        model_path: Path to saved YOLO model
        conf_threshold: Confidence threshold for predictions
        iou_threshold: IOU threshold for NMS
        img_size: Image size for model input
    """
    # Load model
    model = YOLO(model_path)
    
    # Read DICOM
    ds = pydicom.dcmread(image_path)
    image = ds.pixel_array
    
    # Normalize and resize
    image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    image_resized = cv2.resize(image_normalized, (img_size, img_size))
    
    # Convert grayscale to RGB
    image_rgb = np.stack([image_resized] * 3, axis=-1)
    
    # Create figure
    plt.figure(figsize=(10, 7))
    
    # Plot original image
    plt.subplot(1, 2, 1)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    
    # Plot image with predictions
    plt.subplot(1, 2, 2)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Predictions')
    
    # Get predictions
    results = model.predict(
        source=image_rgb,
        conf=conf_threshold,
        iou=iou_threshold
    )
    
    # Define colors and names for left/right sides
    side_colors = {'left': 'red', 'right': 'blue'}
    side_names = {0: 'Left', 1: 'Right'}
    
    if results[0].boxes is not None:
        boxes = results[0].boxes.cpu().numpy()
        
        # Sort boxes by x-coordinate (left to right)
        box_data = []
        for box in boxes:
            cls_id = int(box.cls[0])
            conf = box.conf[0]
            x1, y1, x2, y2 = box.xyxy[0]
            box_data.append((x1, cls_id, conf, x1, y1, x2, y2))
        
        box_data.sort()  # Sort by x1 coordinate
        
        # Plot each detection
        for i, (_, cls_id, conf, x1, y1, x2, y2) in enumerate(box_data):
            color = side_colors['left'] if cls_id == 0 else side_colors['right']
            side_name = side_names[cls_id]
            
            # Draw bounding box
            plt.gca().add_patch(plt.Rectangle(
                (x1, y1), x2-x1, y2-y1,
                fill=False, color=color, linewidth=2
            ))
            
            # Add label
            plt.text(
                x2 + 5, (y1 + y2) / 2, 
                f'{side_name}: {conf:.2f}',
                color=color, fontsize=8, verticalalignment='center',
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none')
            )
            
            # Print detection info
            print(f"Found {side_name} side with confidence {conf:.2f}")
    else:
        print("No detections found")
    
    plt.axis('off')
    plt.tight_layout()
    plt.show()

In [249]:
def process_multiple_axial_images(image_paths, model_path='axial_t2_spine_detector.pt', 
                                conf_threshold=0.15, iou_threshold=0.3):
    """
    Process multiple axial images and display their predictions
    
    Args:
        image_paths: List of paths to DICOM images
        model_path: Path to saved YOLO model
        conf_threshold: Confidence threshold for predictions
        iou_threshold: IOU threshold for NMS
    """
    # First, filter images that have predictions
    valid_images = []
    model = YOLO(model_path)
    
    print("Analyzing images...")
    for img_path in tqdm(image_paths):
        ds = pydicom.dcmread(img_path)
        image = ds.pixel_array
        image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        image_resized = cv2.resize(image_normalized, (384, 384))
        image_rgb = np.stack([image_resized] * 3, axis=-1)
        
        results = model.predict(
            source=image_rgb,
            conf=conf_threshold,
            iou=iou_threshold
        )
        
        if results[0].boxes is not None and len(results[0].boxes) > 0:
            valid_images.append(img_path)
    
    if not valid_images:
        print("No images with valid predictions found")
        return
    
    print(f"\nFound {len(valid_images)} images with valid predictions")
    
    # Process only images with predictions
    for idx, img_path in enumerate(valid_images):
        if idx == 5:
            break
        print(f"\nProcessing image: {os.path.basename(img_path)}")
        plot_axial_predictions(img_path, model_path, conf_threshold, iou_threshold)

In [250]:
# Usage example for Axial T2:
image_paths = glob.glob('/Users/danipopov/Projects/RSNA2024/data/train_images/4646740/3201256954/*.dcm')
process_multiple_axial_images(image_paths, '/Users/danipopov/Projects/RSNA2024/models/axial_t2_spine_detector.pt')

Next we will want to see the roi's of spine levels and right and left sides of disk in axial plane. 

In [204]:
def plot_rois_sigattal(image_path, model_path, 
                            conf_threshold=0.25, iou_threshold=0.45, 
                            img_size=384, roi_size=64):
    """
    Plot Original Image, Image with predictions, and ROIs of the detections
    
    Args:
        image_path: Path to DICOM image
        model_path: Path to YOLO model
        conf_threshold: Confidence threshold for predictions
        iou_threshold: IOU threshold for NMS
        img_size: Size for model input
        roi_size: Size of ROI crops
    """
    # Load model
    model = YOLO(model_path)
    
    # Read DICOM
    ds = pydicom.dcmread(image_path)
    image = ds.pixel_array
    
    # Normalize and resize
    image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    image_resized = cv2.resize(image_normalized, (img_size, img_size))
    
    # Convert grayscale to RGB
    image_rgb = np.stack([image_resized] * 3, axis=-1)
    
    # Get predictions
    results = model.predict(
        source=image_rgb,
        conf=conf_threshold,
        iou=iou_threshold
    )
    
    # Define colors and names
    colors = ['red', 'green', 'blue', 'yellow', 'purple']
    level_names = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
    
    # Get number of detections for subplot layout
    num_detections = len(results[0].boxes) if results[0].boxes is not None else 0
    
    if num_detections == 0:
        print("No detections found")
        return
    
    # Create figure with larger size
    plt.figure(figsize=(15, 7))
    
    # Plot original image
    plt.subplot(1, 2, 1)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    
    # Plot image with predictions
    plt.subplot(1, 2, 2)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Predictions')
    plt.axis('off')
    
    if results[0].boxes is not None:
        boxes = results[0].boxes.cpu().numpy()
        
        # Sort boxes by y-coordinate
        box_data = []
        for box in boxes:
            cls_id = int(box.cls[0])
            conf = box.conf[0]
            x1, y1, x2, y2 = box.xyxy[0]
            box_data.append((y1, cls_id, conf, x1, y1, x2, y2))
        
        box_data.sort()  # Sort by y1 coordinate
    
        # Plot each detection with larger text and boxes
        for i, (_, cls_id, conf, x1, y1, x2, y2) in enumerate(box_data):
            color = colors[cls_id]
            level_name = level_names[cls_id]
            
            # Draw bounding box with thicker line
            plt.gca().add_patch(plt.Rectangle(
                (x1, y1), x2-x1, y2-y1,
                fill=False, color=color, linewidth=3
            ))
            
            # Add label with larger font and better positioning
            plt.text(
                x1, y1 - 10, 
                f'{level_name}: {conf:.2f}',
                color=color, fontsize=10, 
                bbox=dict(facecolor='white', alpha=0.8, edgecolor=color, pad=2)
            )
            
            print(f"Found {level_name} with confidence {conf:.2f}")
    
    plt.tight_layout()
    plt.show()
    
    # Create a separate figure for ROIs with larger size
    if num_detections > 0:
        # Calculate grid dimensions
        n_cols = min(3, num_detections)
        n_rows = (num_detections + n_cols - 1) // n_cols
        
        plt.figure(figsize=(15, 5 * n_rows))
        
        for i, (_, cls_id, conf, x1, y1, x2, y2) in enumerate(box_data):
            plt.subplot(n_rows, n_cols, i + 1)
            
            # Calculate ROI coordinates
            center_x = int((x1 + x2) / 2)
            center_y = int((y1 + y2) / 2)
            half_size = roi_size // 2
            
            # Extract ROI
            roi = image_resized[
                max(0, center_y - half_size):min(img_size, center_y + half_size),
                max(0, center_x - half_size):min(img_size, center_x + half_size)
            ]
            
            # Resize ROI if needed
            if roi.shape[0] != roi_size or roi.shape[1] != roi_size:
                roi = cv2.resize(roi, (roi_size, roi_size))
            
            plt.imshow(roi, cmap='gray')
            plt.title(f'ROI: {level_names[cls_id]}\nConfidence: {conf:.2f}')
            plt.axis('off')
        
        plt.tight_layout()
        plt.show()

In [205]:
# Usage example Sagittal T1
image_paths = glob.glob('/Users/danipopov/Projects/RSNA2024/data/train_images/4646740/3486248476/*.dcm')
plot_rois_sigattal(image_paths[0] ,'/Users/danipopov/Projects/RSNA2024/models/sagittal_t1_spine_detector.pt', conf_threshold=0.25, iou_threshold=0.45, img_size=384, roi_size=64)


In [206]:
# Usage example Sagittal T2/STIR
image_paths = glob.glob('/Users/danipopov/Projects/RSNA2024/data/train_images/4646740/3666319702/*.dcm')
plot_rois_sigattal(image_paths[4] ,'/Users/danipopov/Projects/RSNA2024/models/sagittal_t2_spine_detector.pt', conf_threshold=0.25, iou_threshold=0.45, img_size=384, roi_size=64)


In [207]:
def plot_rois_axial(image_path, model_path, 
                            conf_threshold=0.25, iou_threshold=0.45, 
                            img_size=384, roi_size=64):
    """
    Plot Original Image, Image with predictions, and ROIs of the detections
    
    Args:
        image_path: Path to DICOM image
        model_path: Path to YOLO model
        conf_threshold: Confidence threshold for predictions
        iou_threshold: IOU threshold for NMS
        img_size: Size for model input
        roi_size: Size of ROI crops
    """
    # Load model
    model = YOLO(model_path)
    
    # Read DICOM
    ds = pydicom.dcmread(image_path)
    image = ds.pixel_array
    
    # Normalize and resize
    image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    image_resized = cv2.resize(image_normalized, (img_size, img_size))
    
    # Convert grayscale to RGB
    image_rgb = np.stack([image_resized] * 3, axis=-1)
    
    # Get predictions
    results = model.predict(
        source=image_rgb,
        conf=conf_threshold,
        iou=iou_threshold
    )
    
    # Define colors and names
    colors = ['red', 'blue']
    level_names = ['Left', 'Right']
    
    # Get number of detections for subplot layout
    num_detections = len(results[0].boxes) if results[0].boxes is not None else 0
    
    if num_detections == 0:
        print("No detections found")
        return
    
    # Create figure with larger size
    plt.figure(figsize=(15, 7))
    
    # Plot original image
    plt.subplot(1, 2, 1)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    
    # Plot image with predictions
    plt.subplot(1, 2, 2)
    plt.imshow(image_resized, cmap='gray')
    plt.title('Predictions')
    plt.axis('off')
    
    if results[0].boxes is not None:
        boxes = results[0].boxes.cpu().numpy()
        
        # Sort boxes by y-coordinate
        box_data = []
        for box in boxes:
            cls_id = int(box.cls[0])
            conf = box.conf[0]
            x1, y1, x2, y2 = box.xyxy[0]
            box_data.append((y1, cls_id, conf, x1, y1, x2, y2))
        
        box_data.sort()  # Sort by y1 coordinate
    
        # Plot each detection with larger text and boxes
        for i, (_, cls_id, conf, x1, y1, x2, y2) in enumerate(box_data):
            color = colors[cls_id]
            level_name = level_names[cls_id]
            
            # Draw bounding box with thicker line
            plt.gca().add_patch(plt.Rectangle(
                (x1, y1), x2-x1, y2-y1,
                fill=False, color=color, linewidth=3
            ))
            
            # Add label with larger font and better positioning
            plt.text(
                x1, y1 - 10, 
                f'{level_name}: {conf:.2f}',
                color=color, fontsize=10, 
                bbox=dict(facecolor='white', alpha=0.8, edgecolor=color, pad=2)
            )
            
            print(f"Found {level_name} with confidence {conf:.2f}")
    
    plt.tight_layout()
    plt.show()
    
    # Create a separate figure for ROIs with larger size
    if num_detections > 0:
        # Calculate grid dimensions
        n_cols = 2
        n_rows = 1
        
        plt.figure(figsize=(15, 5 * n_rows))
        
        for i, (_, cls_id, conf, x1, y1, x2, y2) in enumerate(box_data):
            plt.subplot(n_rows, n_cols, i + 1)
            
            # Calculate ROI coordinates
            center_x = int((x1 + x2) / 2)
            center_y = int((y1 + y2) / 2)
            half_size = roi_size // 2
            
            # Extract ROI
            roi = image_resized[
                max(0, center_y - half_size):min(img_size, center_y + half_size),
                max(0, center_x - half_size):min(img_size, center_x + half_size)
            ]
            
            # Resize ROI if needed
            if roi.shape[0] != roi_size or roi.shape[1] != roi_size:
                roi = cv2.resize(roi, (roi_size, roi_size))
            
            plt.imshow(roi, cmap='gray')
            plt.title(f'ROI: {level_names[cls_id]}\nConfidence: {conf:.2f}')
            plt.axis('off')
        
        plt.tight_layout()
        plt.show()

In [208]:
# Usage example Axial T2
image_paths = glob.glob('/Users/danipopov/Projects/RSNA2024/data/train_images/4646740/3201256954/*.dcm')
plot_rois_axial(image_paths[1] ,'/Users/danipopov/Projects/RSNA2024/models/axial_t2_spine_detector.pt', conf_threshold=0.25, iou_threshold=0.45, img_size=384, roi_size=96)


### Dataset Organization for Model Training

Before training our models, we need to build a structured dataset. We will use the `train_df` DataFrame and our YOLO models to process each `study_id`. For each study, we will:

1. Select 5 representative images from each plane (Axial T2, Sagittal T1, and Sagittal T2/STIR)
2. Extract ROIs (Regions of Interest) for:
   - Spine levels in sagittal planes
   - Left and right sides of the disk in axial plane
3. Only include images where we can successfully extract all ROIs (5 spine levels plus left and right sides)

The dataset will be organized in the following directory structure:

```
images_dataset/
  └── study_id/
      ├── Axial_T2/
      │     ├── img_1/
      │     │     ├── 1.png
      │     │     └── 2.png
      │     │     └── 3.png
      │     │     └── 4.png
      │     │     └── 5.png
      │     │
      │     └── img_2/
      │           ├── 1.png
      │           └── 2.png
      │           └── 3.png
      │           └── 4.png
      │           └── 5.png
      └── Sagittal_T1/
            ├── img_1/
            │     ├── 1.png
            │     └── 2.png
            │     └── 3.png
            │     └── 4.png
            │     └── 5.png
            └── img_2/
                  ├── 1.png
                  └── 2.png
                  └── 3.png
                  └── 4.png
                  └── 5.png
        └── Sagittal_T2/
            ├── img_1/
            │     ├── 1.png
            │     └── 2.png
            │     └── 3.png
            │     └── 4.png
            │     └── 5.png
            └── img_2/
                  ├── 1.png
                  └── 2.png
                  └── 3.png
                  └── 4.png
                  └── 5.png

``'


In [79]:
def extract_rois(image, boxes, roi_size=96):
    rois = [] 
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
        roi = image[
            max(0, center_y - roi_size // 2):min(image.shape[0], center_y + roi_size // 2),
            max(0, center_x - roi_size // 2):min(image.shape[1], center_x + roi_size // 2)
        ]
        if roi.shape[0] != roi_size or roi.shape[1] != roi_size:
            roi = cv2.resize(roi, (roi_size, roi_size))
        rois.append(roi)
    return rois

In [80]:
def process_study(study_id, meta_df, src_path, dst_path, axial_model, sagittal_model_1, sagittal_model_2):
    study_meta = meta_df[int(study_id)]
    series_dict = {desc.replace('/', '_'): [] for desc in ['Axial T2', 'Sagittal T1', 'Sagittal T2/STIR']}
    
    # First, validate that we have enough valid images for all series
    valid_study = True
    
    for series_id, desc in zip(study_meta['series_ids'], study_meta['series_descriptions']):
        series_dict[desc.replace('/', '_')].append(series_id)
    
    for desc, series_list in series_dict.items():
        # Collect all images for this description
        all_images = []
        for series_id in series_list:
            series_path = os.path.join(src_path, study_id, series_id)
            all_images.extend(glob.glob(os.path.join(series_path, '*.dcm')))
        
        if len(all_images) == 0:
            valid_study = False
            print(f"No images found for {study_id} - {desc}")
            break
            
        # Sort images and get model for this description
        all_images.sort()
        if 'Axial' in desc:
            model = axial_model
            required_detections = 2  # Left and right
        elif 'T1' in desc:
            model = sagittal_model_1
            required_detections = 5  # All spine levels
        else:  # 'Sagittal T2/STIR'
            model = sagittal_model_2
            required_detections = 5  # All spine levels
            
        # Find valid images (those with all required detections)
        valid_images = []
        for img_path in all_images:
            ds = pydicom.dcmread(img_path)
            image = ds.pixel_array
            image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
            image_resized = cv2.resize(image_normalized, (384, 384))
            image_rgb = np.stack([image_resized] * 3, axis=-1)
            
            results = model.predict(source=image_rgb, conf=0.25, iou=0.45)
            
            if results[0].boxes is not None and len(results[0].boxes) == required_detections:
                valid_images.append(img_path)
        
        # Check if we have enough valid images
        if len(valid_images) < 5:
            valid_study = False
            print(f"Not enough valid images for {study_id} - {desc}. Found {len(valid_images)}, need 5")
            break
    
    # If study is not valid, return without creating any directories or saving images
    if not valid_study:
        return False
        
    # If we get here, we have enough valid images for all series
    # Now process and save the images
    for desc, series_list in series_dict.items():
        desc_path = os.path.join(dst_path, study_id, desc)
        os.makedirs(desc_path, exist_ok=True)
        
        all_images = []
        for series_id in series_list:
            series_path = os.path.join(src_path, study_id, series_id)
            all_images.extend(glob.glob(os.path.join(series_path, '*.dcm')))
            
        all_images.sort()
        
        if 'Axial' in desc:
            model = axial_model
            required_detections = 2
        elif 'T1' in desc:
            model = sagittal_model_1
            required_detections = 5
        else:
            model = sagittal_model_2
            required_detections = 5
            
        valid_images = []
        for img_path in all_images:
            ds = pydicom.dcmread(img_path)
            image = ds.pixel_array
            image_normalized = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
            image_resized = cv2.resize(image_normalized, (384, 384))
            image_rgb = np.stack([image_resized] * 3, axis=-1)
            
            results = model.predict(source=image_rgb, conf=0.25, iou=0.45)
            
            if results[0].boxes is not None and len(results[0].boxes) == required_detections:
                valid_images.append((img_path, image_resized, results[0].boxes))
                
        # Select 5 evenly spaced images from valid images
        indices = np.linspace(0, len(valid_images)-1, 5, dtype=int)
        selected_images = [valid_images[i] for i in indices]
        
        # Process selected images
        for idx, (img_path, image_resized, boxes) in enumerate(selected_images):
            rois = extract_rois(image_resized, boxes)
            
            img_folder = os.path.join(desc_path, f'img_{idx+1}')
            os.makedirs(img_folder, exist_ok=True)
            
            for roi_idx, roi in enumerate(rois):
                cv2.imwrite(os.path.join(img_folder, f'{roi_idx+1}.png'), roi)
    
    return True

In [81]:
def create_dataset(train_df, meta_df, src_path, dst_path, axial_model_path, sagittal_model_path_1, sagittal_model_path_2):
    axial_model = YOLO(axial_model_path)
    sagittal_model_1 = YOLO(sagittal_model_path_1)
    sagittal_model_2 = YOLO(sagittal_model_path_2)
    
    valid_studies = []
    for i, study_id in enumerate(tqdm(train_df['study_id'].unique())):
        if process_study(str(study_id), meta_df, src_path, dst_path, 
                        axial_model, sagittal_model_1, sagittal_model_2):
            valid_studies.append(study_id)
        
        if i % 1 == 0:  # Clear output every iteration
            clear_output(wait=True)
    
    print(f"Processed {len(valid_studies)} valid studies out of {len(train_df['study_id'].unique())}")
    return valid_studies
            
# Usage
src_path = '/Users/danipopov/Projects/RSNA2024/data/train_images'
dst_path = '/Users/danipopov/Projects/RSNA2024/data/images_dataset'
axial_model_path = '/Users/danipopov/Projects/RSNA2024/models/axial_t2_spine_detector.pt'
sagittal_model_path_1 = '/Users/danipopov/Projects/RSNA2024/models/sagittal_t1_spine_detector.pt'
sagittal_model_path_2 = '/Users/danipopov/Projects/RSNA2024/models/sagittal_t2_spine_detector.pt'

valid_studies = create_dataset(train_df, meta_df, src_path, dst_path, 
                             axial_model_path, sagittal_model_path_1, sagittal_model_path_2)

In [81]:
valid_studies = glob.glob(f'/Users/danipopov/Projects/RSNA2024/data/images_dataset/*')
print(f"Processed {len(valid_studies)} valid studies out of {len(train_df['study_id'].unique())}")

Now let's examine our newly created dataset, which includes 1,825 study IDs. For each study, we have extracted 5 images per plane (Sagittal T1, Sagittal T2/STIR, and Axial T2), along with their corresponding Regions of Interest (ROIs). This examination will help us validate the quality and consistency of our preprocessed data.

In [82]:
# Checking the image_dataset 
image_dir = '/Users/danipopov/Projects/RSNA2024/data/images_dataset'
study_dir = os.path.join(image_dir, '293713262')
# Sagittal T1
sagt1_dir = os.path.join(study_dir, 'Sagittal T1')
image_dir_sag1 = sorted(glob.glob(os.path.join(sagt1_dir, 'img_*')))
roi_dir_sag1 = sorted(glob.glob(os.path.join(image_dir_sag1[0], '*.png')))
# Sagittal T2/STIR
sagt2_dir = os.path.join(study_dir, 'Sagittal T2_STIR')
image_dir_sag2 = sorted(glob.glob(os.path.join(sagt2_dir, 'img_*')))
roi_dir_sag2 = sorted(glob.glob(os.path.join(image_dir_sag2[0], '*.png')))
# Axial T2
axial_dir = os.path.join(study_dir, 'Axial T2')
image_dir_axial = sorted(glob.glob(os.path.join(axial_dir, 'img_*')))
roi_dir_axial = sorted(glob.glob(os.path.join(image_dir_axial[0], '*.png')))

# YOLO model weights
axial_model_path = '/Users/danipopov/Projects/RSNA2024/models/axial_t2_spine_detector.pt'
sagittal_model_path_1 = '/Users/danipopov/Projects/RSNA2024/models/sagittal_t1_spine_detector.pt'
sagittal_model_path_2 = '/Users/danipopov/Projects/RSNA2024/models/sagittal_t2_spine_detector.pt'

In [84]:
def show_rois(rois_sagt1, rois_sagt2, rois_axialt2):
    # Create figure with 3 rows (one for each plane)
    plt.figure(figsize=(15, 12))
    
    # Plot Sagittal T1 ROIs
    plt.subplot(3, 1, 1)
    plt.suptitle('ROIs for Different Planes', fontsize=16, y=0.95)
    for i, roi_path in enumerate(rois_sagt1):
        plt.subplot(3, 5, i+1)
        roi = cv2.imread(roi_path, cv2.IMREAD_GRAYSCALE)
        plt.imshow(roi, cmap='gray')
        plt.title(f'L{i+1}/L{i+2}' if i < 4 else 'L5/S1', fontsize=10)
        plt.axis('off')
    plt.subplot(3, 5, 3).set_title('Sagittal T1', fontsize=12, pad=20)
    
    # Plot Sagittal T2/STIR ROIs
    for i, roi_path in enumerate(rois_sagt2):
        plt.subplot(3, 5, i+6)
        roi = cv2.imread(roi_path, cv2.IMREAD_GRAYSCALE)
        plt.imshow(roi, cmap='gray')
        plt.title(f'L{i+1}/L{i+2}' if i < 4 else 'L5/S1', fontsize=10)
        plt.axis('off')
    plt.subplot(3, 5, 8).set_title('Sagittal T2/STIR', fontsize=12, pad=20)
    
    # Plot Axial T2 ROIs
    for i, roi_path in enumerate(rois_axialt2):
        plt.subplot(3, 5, i+11)
        roi = cv2.imread(roi_path, cv2.IMREAD_GRAYSCALE)
        plt.imshow(roi, cmap='gray')
        plt.title('Left' if i == 0 else 'Right', fontsize=10)
        plt.axis('off')
    plt.subplot(3, 5, 13).set_title('Axial T2', fontsize=12, pad=20)
    
    plt.tight_layout()
    plt.show()

# Usage
show_rois(roi_dir_sag1, roi_dir_sag2, roi_dir_axial)

### Dataset Class and DataLoader

Now that we have created the dataset, we can proceed to the next step: building our dataset class and DataLoader with PyTorch for subsequent model training. This will provide an efficient way to load and batch our preprocessed images during the training process.

In [121]:
class SpineDataset(Dataset):
    def __init__(self, df, image_dir, transform=None):
        """
        Args:
            df: DataFrame containing labels
            image_dir: Root directory containing the ROI images
            transform: Optional transforms to apply to images
        """
        self.df = df
        self.image_dir = image_dir
        self.transform = transform
        self.series_descriptions = ['Axial T2', 'Sagittal T1', 'Sagittal T2_STIR']
        # Get list of valid study IDs (directories in image_dir)
        self.valid_studies = [d for d in os.listdir(image_dir) 
                            if os.path.isdir(os.path.join(image_dir, d))]
                
    def __len__(self):
        return len(self.valid_studies)
        
    def load_series_rois(self, study_path, desc):
        """
        Load ROIs for a specific series description
        Returns: numpy array of ROIs [num_images * num_rois_per_image, H, W]
        """
        series_path = os.path.join(study_path, desc)
        all_rois = []
        
        # Get all image folders (img_1, img_2, etc.)
        img_folders = sorted(glob.glob(os.path.join(series_path, 'img_*')))
        
        for img_folder in img_folders:
            # Get all ROIs for this image
            roi_paths = sorted(glob.glob(os.path.join(img_folder, '*.png')))
            for roi_path in roi_paths:
                roi = cv2.imread(roi_path, cv2.IMREAD_GRAYSCALE)
                if self.transform:
                    roi = self.transform(image=roi)['image']
                all_rois.append(roi)
                
        return np.stack(all_rois)

    def __getitem__(self, idx):
        study_id = self.valid_studies[idx]
        study_path = os.path.join(self.image_dir, study_id)
        
        # Get labels for this study
        labels = self.df[self.df['study_id'] == int(study_id)].iloc[0, 1:].values.astype(np.int64)
        
        # Load ROIs for each series
        all_rois = []
        for desc in self.series_descriptions:
            series_rois = self.load_series_rois(study_path, desc)
            all_rois.append(series_rois)
            
        # Stack all ROIs together
        rois = np.concatenate(all_rois, axis=0)
        
        # Convert to torch tensors
        rois = torch.FloatTensor(rois)
        labels = torch.LongTensor(labels)
        
        return rois, labels 

In [88]:
train_df.isna().sum()

We didn't forget that our train_df contains missing values and categorical data. Since we'll be using cross-entropy loss for training our models, we need to convert these categorical values to integers to represent each label for each study_id. Additionally, we'll fill the missing values with -100, which is the default padding value for cross-entropy loss.

In [95]:
# Create a copy of the training DataFrame
train_df_copy = train_df.copy()

# Print initial missing values statistics
print("\nMissing values before filling:")
print(train_df_copy.isna().sum())

# Fill missing values with -100
train_df_copy.fillna(-100, inplace=True)

# Map categorical values to integers
label_map = {'Normal/Mild': 1, 'Moderate': 2, 'Severe': 3}
train_df_copy = train_df_copy.replace(label_map)

# Verify no missing values remain
print("\nMissing values after filling:")
print(train_df_copy.isna().sum())

# Display first few rows of processed DataFrame
print("\nFirst few rows of processed DataFrame:")
display(train_df_copy.head())

In [134]:
# Example usage:
train_transform = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.75),
    
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
    ], p=0.75),
    
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.75),
    A.Normalize(mean=0.5, std=0.5),
])

test_transform = A.Compose([
    A.Normalize(mean=0.5, std=0.5),
])

dataset = SpineDataset(
    df=train_df_copy,
    image_dir='/Users/danipopov/Projects/RSNA2024/data/images_dataset',
    transform=train_transform
)

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

# Test the dataloader
images, labels = next(iter(dataloader))
print(f"Batch images shape: {images.shape}")
print(f"Batch labels shape: {labels.shape}")

Let's examine the first batch of our dataloader.

In [136]:
def show_batch(images, labels):
    """
    Display all ROIs from a batch
    Args:
        images: Tensor of shape [B, 60, 64, 64]
        labels: Tensor of shape [B, 25]
    """
    # Get first batch item [60, 64, 64]
    images = images[0]  

    # Create figure with subplots
    fig = plt.figure(figsize=(20, 12))
    plt.suptitle('ROIs from Different Planes', fontsize=16)
    
    # Plot Axial T2 ROIs (first 10 images: 5 pairs of left/right)
    for i in range(10):
        plt.subplot(6, 10, i + 1)
        plt.imshow(images[i].numpy(), cmap='gray')
        plt.title(f'Axial: {"Left" if i%2==0 else "Right"}\nImg {i//2+1}', fontsize=8)
        plt.axis('off')
    
    # Plot Sagittal T1 ROIs (next 25 images: 5 images × 5 levels)
    for i in range(25):
        plt.subplot(6, 10, i + 11)
        plt.imshow(images[i+10].numpy(), cmap='gray')
        level = f'L{i%5+1}/L{i%5+2}' if i%5 < 4 else 'L5/S1'
        plt.title(f'Sag T1: {level}\nImg {i//5+1}', fontsize=8)
        plt.axis('off')
    
    # Plot Sagittal T2 ROIs (last 25 images: 5 images × 5 levels)
    for i in range(25):
        plt.subplot(6, 10, i + 36)
        plt.imshow(images[i+35].numpy(), cmap='gray')
        level = f'L{i%5+1}/L{i%5+2}' if i%5 < 4 else 'L5/S1'
        plt.title(f'Sag T2: {level}\nImg {i//5+1}', fontsize=8)
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print labels for the batch
    print("Labels:", labels[0])

# Test the visualization
images, labels = next(iter(dataloader))
show_batch(images, labels)

Great! 😁 We can see that our DataLoader is working as expected, providing us with properly organized ROIs for each plane and their corresponding labels for each study_id.

Now we can proceed to the next step: designing our model architecture. Given the nature of our data (60 ROIs of size 64x64 per study) and our classification task, I will use the ResNet50 architecture from timm library.

timm (torch image models) library is a collection of state-of-the-art computer vision models for PyTorch from Hugging Face. It provides:
- Pre-trained models optimized for various tasks
- Easy-to-use interfaces for model customization
- Consistent API across different architectures
- Regular updates with new models and improvements

ResNet50 is an excellent choice for our project because:
1. **Powerful Architecture**:
   - Deeper network with 50 layers
   - Enhanced feature extraction capability
   - Bottleneck blocks for efficient computation
   - Strong skip connections to prevent vanishing gradients

2. **Feature Extraction**:
   - More sophisticated feature hierarchies
   - Better at capturing complex patterns
   - Proven success in medical imaging tasks
   - Can handle our 64x64 ROI size efficiently

3. **Practical Benefits**:
   - Better performance than ResNet18
   - Good balance of depth and computational cost
   - Still manageable training time
   - Extensive pretrained weights available
   - Well-documented with strong community support

### Model Architecture

In [137]:
# Create pretrained ResNet50 as starting point
model = timm.create_model('resnet50.a1_in1k', pretrained=True)

# Modify first conv layer to accept 60 channels
model.conv1 = nn.Conv2d(60, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Replace the final classification layer
# 75 because we have 25 labels and for each label we will want to predict 
# the probability of the label being present or absent (25 * 3 = 75)
model.fc = nn.Linear(model.fc.in_features, 75)

# Initialize the new layers
model.conv1.weight.data.normal_(0, 0.01)
model.fc.bias.data.fill_(0)
model.fc.weight.data.normal_(0, 0.01)

# All layers will be trainable
for name, param in model.named_parameters():
    param.requires_grad = True

### Evaluateion metrics

Before we proceed to the training process, we will need to define what metrics we will use to evaluate our model's performance on this dataset.

Metrics:
1. **Accuracy**: Overall correctness of the model's predictions
2. **Precision**: Proportion of true positives among all predicted positives
3. **Recall**: Proportion of true positives that were correctly identified
4. **F1-score**: Harmonic mean of precision and recall
5. **AUC-ROC**: Area under the Receiver Operating Characteristic curve
6. **Confusion Matrix**: Visualization of true vs. false positives and negatives for each class

These metrics will help us comprehensively evaluate our model's performance across different aspects of classification accuracy and reliability.

In [179]:
def calculate_metrics(preds, targets):
    """
    Calculate accuracy, precision, recall, and F1 score for a single column
    Args:
        preds: tensor of shape (batch_size, 3) - logits for one column
        targets: tensor of shape (batch_size,) - ground truth for one column
    Returns:
        accuracy, precision, recall, f1
    """
    device = preds.device
    preds = torch.argmax(preds, dim=1)  # Convert logits to predictions
    correct = (preds == targets).float()
    accuracy = correct.mean()
    
    precision = []
    recall = []
    f1 = []
    
    for class_id in range(3):  # For each class (Normal, Mild, Severe)
        true_positives = ((preds == class_id) & (targets == class_id)).float().sum()
        predicted_positives = (preds == class_id).float().sum()
        actual_positives = (targets == class_id).float().sum()
        
        # Calculate metrics
        class_precision = true_positives / predicted_positives if predicted_positives > 0 else torch.tensor(0.0, device=device)
        class_recall = true_positives / actual_positives if actual_positives > 0 else torch.tensor(0.0, device=device)
        class_f1 = 2 * (class_precision * class_recall) / (class_precision + class_recall) if (class_precision + class_recall) > 0 else torch.tensor(0.0, device=device)
        
        precision.append(class_precision)
        recall.append(class_recall)
        f1.append(class_f1)
    
    # Average metrics across classes
    precision = torch.stack(precision).mean()
    recall = torch.stack(recall).mean()
    f1 = torch.stack(f1).mean()
    
    return accuracy, precision, recall, f1

In [180]:
def generate_heatmap(preds, targets, conditions, levels, title):
    preds = preds.cpu().numpy()
    targets = targets.cpu().numpy()
    
    preds = preds.reshape(-1, 25, 3)
    preds = np.argmax(preds, axis=2)
    
    num_conditions = len(conditions)
    num_levels = len(levels)
    
    heatmap_data = np.zeros((num_conditions, num_levels))
    
    for i in range(num_conditions):
        for j in range(num_levels):
            idx = i * num_levels + j
            correct_predictions = (preds[:, idx] == targets[:, idx]).mean()
            heatmap_data[i, j] = correct_predictions
    
    plt.figure(figsize=(12, 8))
    sns.heatmap(heatmap_data, annot=True, cmap='YlGnBu', xticklabels=levels, yticklabels=conditions, fmt='.3f')
    plt.title(title)
    plt.xlabel('Disk Levels')
    plt.ylabel('Conditions')
    plt.tight_layout()
    return plt.gcf()

In [181]:
# Update MetricTracker to include all metrics
class MetricTracker:
    def __init__(self):
        self.metrics = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'train_precision': [],
            'val_precision': [],
            'train_recall': [],
            'val_recall': [],
            'train_f1': [],
            'val_f1': [],
            'train_auc': [],
            'val_auc': [],
            'learning_rates': []
        }
    
    def update(self, metric_name, value):
        self.metrics[metric_name].append(value)
    
    def save(self, fold, save_path):
        np.save(f'{save_path}/metrics_fold_{fold}.npy', self.metrics)

In [182]:
def plot_training_metrics(fold_num, plot_path=None):
    """Plot and save all training metrics for a given fold"""
    if plot_path is None:
        plot_path = '/Users/danipopov/Projects/RSNA2024/plots/training_metrics'
    os.makedirs(plot_path, exist_ok=True)
    
    metrics = np.load(f'/Users/danipopov/Projects/RSNA2024/metrics/metrics_fold_{fold_num}.npy', 
                     allow_pickle=True).item()
    
    fig, axes = plt.subplots(3, 2, figsize=(15, 18))
    
    # Loss plot
    axes[0,0].plot(metrics['train_loss'], label='Train')
    axes[0,0].plot(metrics['val_loss'], label='Validation')
    axes[0,0].set_title('Loss')
    axes[0,0].set_xlabel('Epoch')
    axes[0,0].legend()
    axes[0,0].grid(True)
    
    # Accuracy plot
    axes[0,1].plot(metrics['train_acc'], label='Train')
    axes[0,1].plot(metrics['val_acc'], label='Validation')
    axes[0,1].set_title('Accuracy')
    axes[0,1].set_xlabel('Epoch')
    axes[0,1].legend()
    axes[0,1].grid(True)
    
    # Precision plot
    axes[1,0].plot(metrics['train_precision'], label='Train')
    axes[1,0].plot(metrics['val_precision'], label='Validation')
    axes[1,0].set_title('Precision')
    axes[1,0].set_xlabel('Epoch')
    axes[1,0].legend()
    axes[1,0].grid(True)
    
    # Recall plot
    axes[1,1].plot(metrics['train_recall'], label='Train')
    axes[1,1].plot(metrics['val_recall'], label='Validation')
    axes[1,1].set_title('Recall')
    axes[1,1].set_xlabel('Epoch')
    axes[1,1].legend()
    axes[1,1].grid(True)
    
    # F1 Score plot
    axes[2,0].plot(metrics['train_f1'], label='Train')
    axes[2,0].plot(metrics['val_f1'], label='Validation')
    axes[2,0].set_title('F1 Score')
    axes[2,0].set_xlabel('Epoch')
    axes[2,0].legend()
    axes[2,0].grid(True)
    
    # Learning Rate plot
    axes[2,1].plot(metrics['learning_rates'], label='Learning Rate')
    axes[2,1].set_title('Learning Rate')
    axes[2,1].set_xlabel('Epoch')
    axes[2,1].legend()
    axes[2,1].grid(True)
    
    plt.suptitle(f'Training Metrics - Fold {fold_num}', y=1.02, fontsize=16)
    plt.tight_layout()
    
    # Save the plot
    plt.savefig(f'{plot_path}/training_metrics_fold_{fold_num}.png', 
                bbox_inches='tight', dpi=300)
    plt.close()

In [189]:
def generate_fold_visualizations(fold_num, base_path=None):
    """Generate and save all visualizations for a specific fold"""
    if base_path is None:
        base_path = '/Users/danipopov/Projects/RSNA2024/plots'
    
    # Load predictions
    pred_path = f'/Users/danipopov/Projects/RSNA2024/models/predictions/fold_{fold_num}_predictions.pt'
    predictions = torch.load(pred_path)
    outputs = predictions['outputs']
    labels = predictions['labels']
    
    # Create directories
    heatmap_path = f'{base_path}/heatmaps'
    roc_path = f'{base_path}/roc_curves'
    os.makedirs(heatmap_path, exist_ok=True)
    os.makedirs(roc_path, exist_ok=True)
    
    # Generate and save heatmap
    conditions = [
        'Spinal Canal Stenosis',
        'Left Neural Foraminal Narrowing',
        'Right Neural Foraminal Narrowing',
        'Left Subarticular Stenosis',
        'Right Subarticular Stenosis'
    ]
    levels = ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
    
    heatmap_fig = generate_heatmap(
        preds=outputs,
        targets=labels,
        conditions=conditions,
        levels=levels,
        title=f'Fold {fold_num} Performance Heatmap'
    )
    heatmap_fig.savefig(f'{heatmap_path}/heatmap_fold_{fold_num}.png')
    plt.close(heatmap_fig)
    
    # Generate ROC curves for each condition
    class_names = ['Normal', 'Mild', 'Severe']
    for i, condition in enumerate(conditions):
        plt.figure(figsize=(10, 8))
        for severity in range(3):
            start_idx = i * 3
            probs = torch.softmax(outputs[:, start_idx:start_idx+3], dim=1)
            true_labels = labels[:, i]
            
            # Calculate ROC curve
            from sklearn.metrics import roc_curve, auc
            fpr, tpr, _ = roc_curve(
                (true_labels == severity).cpu().numpy(),
                probs[:, severity].cpu().numpy()
            )
            roc_auc = auc(fpr, tpr)
            
            plt.plot(
                fpr, 
                tpr, 
                label=f'{class_names[severity]} (AUC = {roc_auc:.2f})'
            )
        
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curves for {condition} - Fold {fold_num}')
        plt.legend(loc="lower right")
        plt.grid(True)
        plt.savefig(f'{roc_path}/roc_curves_{condition.replace(" ", "_")}_fold_{fold_num}.png')
        plt.close()

In [190]:
def visualize_fold_results(fold_num, base_path=None):
    """Generate and save all visualizations for a specific fold"""
    if base_path is None:
        base_path = '/Users/danipopov/Projects/RSNA2024/plots'
    
    # Create base directory
    os.makedirs(base_path, exist_ok=True)
    
    # Generate training metrics plots
    metrics_path = f'{base_path}/training_metrics'
    os.makedirs(metrics_path, exist_ok=True)
    plot_training_metrics(fold_num, metrics_path)
    
    # Generate heatmap and ROC curves
    generate_fold_visualizations(fold_num, base_path)
    
    print(f"All visualizations for fold {fold_num} have been saved to {base_path}")

### Training 

In [185]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=15, min_delta=0, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [186]:
def train(model, train_df, image_dir, train_transform, test_transform, N_SPLITS, EPOCHS, device):
    """
    Train the model using N-fold cross validation with improved training loop
    Args:
        model: PyTorch model
        train_df: DataFrame with labels
        image_dir: Root directory containing all study folders
        transform: Albumentations transforms
        N_SPLITS: Number of folds for cross validation
        EPOCHS: Number of training epochs
        device: torch device
    """
    fold_results = []
    
    # Get all valid study IDs
    valid_studies = [d for d in os.listdir(image_dir) 
                    if os.path.isdir(os.path.join(image_dir, d))]
    
    weights = torch.tensor([1.0, 2.0, 4.0]).to(device)
    criterion = nn.CrossEntropyLoss(weight=weights).to(device)
    
    # Setup KFold cross validation
    skf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(valid_studies)):
        print('=' * 50)
        print(f"Training Fold {fold+1}/{N_SPLITS}")
        print('=' * 50)

        # Initialize metric tracker for this fold
        tracker = MetricTracker()

        # Create datasets
        train_dataset = SpineDataset(
            df=train_df,
            image_dir=image_dir,  
            transform= train_transform
        )
        val_dataset = SpineDataset(
            df=train_df,
            image_dir=image_dir,
            transform=test_transform  
        )
        
        # Set train/val splits
        train_dataset.valid_studies = [valid_studies[i] for i in train_idx]
        val_dataset.valid_studies = [valid_studies[i] for i in val_idx]
        
        # Create dataloaders
        train_loader = DataLoader(
            train_dataset, 
            batch_size=16,
            shuffle=True, 
            num_workers=0, 
            pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset, 
            batch_size=16, 
            shuffle=False,
            num_workers=0, 
            pin_memory=True
        )
        
        # Initialize model for this fold
        model = model.to(device)
        
        # Initialize new layers properly
        def init_weights(m):
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
        
        model.conv1.apply(init_weights)
        model.fc.apply(init_weights)
        
        # Optimizer with lower learning rate
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=0.0001,
            weight_decay=0.01
        )
        current_lr = optimizer.param_groups[0]['lr']

        # Cosine annealing scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,
            T_mult=2
        )
        
        # Early stopping
        early_stopping = EarlyStopping(
            patience=15,
            min_delta=1e-4,
            verbose=True
        )
        
        best_val_loss = float('inf')
        scaler = torch.cuda.amp.GradScaler()  # For mixed precision training
        
        # Inside the training function, modify the progress bar sections:

        # Training phase progress bar
        for epoch in range(EPOCHS):
            print(f"\nEpoch {epoch+1}/{EPOCHS}")
            
            # Training phase
            model.train()
            total_train_loss = 0
            train_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0}
            train_steps = len(train_loader)

            # Modified progress bar
            with tqdm(total=train_steps, desc=f'Training Epoch {epoch+1}', 
                    bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') as train_pbar:
                
                for idx, (images, labels) in enumerate(train_loader):
                    images = images.to(device)
                    labels = labels.to(device)
                    
                    optimizer.zero_grad()
                    
                    with torch.cuda.amp.autocast():
                        outputs = model(images)
                        loss = 0
                        batch_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0}

                        for col in range(labels.shape[1]):
                            pred = outputs[:, col*3: col*3+3]
                            ground_truth = labels[:, col]
                            loss += criterion(pred, ground_truth)
                        
                            # Calculate metrics for this column
                            acc, prec, rec, f1 = calculate_metrics(pred, ground_truth)
                            batch_metrics['accuracy'] += acc.item()
                            batch_metrics['precision'] += prec.item()
                            batch_metrics['recall'] += rec.item()
                            batch_metrics['f1'] += f1.item()
                        
                        # Average metrics across columns
                        for k in batch_metrics:
                            batch_metrics[k] /= labels.shape[1]
                            train_metrics[k] += batch_metrics[k]

                        loss /= labels.shape[1]
                    
                    scaler.scale(loss).backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    
                    total_train_loss += loss.item()
                    train_pbar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'acc': f'{batch_metrics["accuracy"]:.4f}'
                    })
                    train_pbar.update(1)
            
            # Average training metrics
            avg_train_loss = total_train_loss / train_steps
            for k in train_metrics:
                train_metrics[k] /= train_steps
            
            # Validation phase
            model.eval()
            total_val_loss = 0
            val_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0}
            val_steps = len(val_loader)
            
            all_val_outputs = []
            all_val_labels = []

            # Modified validation progress bar
            with tqdm(total=val_steps, desc=f'Validation Epoch {epoch+1}', 
                    bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') as val_pbar:
                
                with torch.no_grad():
                    for idx, (images, labels) in enumerate(val_loader):
                        images = images.to(device)
                        labels = labels.to(device)
                        
                        with torch.cuda.amp.autocast():
                            outputs = model(images)
                            loss = 0
                            batch_metrics = {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0}
                            
                            for col in range(labels.shape[1]):
                                pred = outputs[:, col*3: col*3+3]
                                ground_truth = labels[:, col]
                                loss += criterion(pred, ground_truth)
                            
                                # Calculate metrics for this column
                                acc, prec, rec, f1 = calculate_metrics(pred, ground_truth)
                                batch_metrics['accuracy'] += acc.item()
                                batch_metrics['precision'] += prec.item()
                                batch_metrics['recall'] += rec.item()
                                batch_metrics['f1'] += f1.item()
                            
                            for k in batch_metrics:
                                batch_metrics[k] /= labels.shape[1]
                                val_metrics[k] += batch_metrics[k]

                            loss /= labels.shape[1]
                            total_val_loss += loss.item()
                        
                        # Store predictions and labels for ROC-AUC and heatmap
                        all_val_outputs.append(outputs)
                        all_val_labels.append(labels)

                        val_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
                        val_pbar.update(1)
            
            avg_val_loss = total_val_loss / val_steps
            val_pbar.close()
            
            # Average validation metrics
            avg_val_loss = total_val_loss / val_steps
            for k in val_metrics:
                val_metrics[k] /= val_steps

            # Update metric tracker
            tracker.update('train_loss', avg_train_loss)
            tracker.update('val_loss', avg_val_loss)
            tracker.update('train_acc', train_metrics['accuracy'])
            tracker.update('val_acc', val_metrics['accuracy'])
            tracker.update('train_precision', train_metrics['precision'])
            tracker.update('val_precision', val_metrics['precision'])
            tracker.update('train_recall', train_metrics['recall'])
            tracker.update('val_recall', val_metrics['recall'])
            tracker.update('train_f1', train_metrics['f1'])
            tracker.update('val_f1', val_metrics['f1'])
            tracker.update('learning_rates', current_lr)

            # Learning rate scheduling (move this before metric tracking)
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']  # Update current_lr

            # Save metrics
            save_path = '/Users/danipopov/Projects/RSNA2024/metrics'
            os.makedirs(save_path, exist_ok=True)
            tracker.save(fold+1, save_path)
            
            # Generate and save heatmap at end of epoch
            if epoch == EPOCHS - 1 or early_stopping.early_stop:
                all_outputs = torch.cat(all_val_outputs)
                all_labels = torch.cat(all_val_labels)
                
                # Save predictions and labels for later visualization
                save_path = f'/Users/danipopov/Projects/RSNA2024/models/predictions'
                os.makedirs(save_path, exist_ok=True)
                torch.save({
                    'outputs': all_outputs,
                    'labels': all_labels
                }, f'{save_path}/fold_{fold+1}_predictions.pt')                
            
            # Print epoch results
            print(f'\nEpoch {epoch+1}:')
            print(f'Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
            print(f'Train Acc: {train_metrics["accuracy"]:.4f}, Val Acc: {val_metrics["accuracy"]:.4f}')
            print(f'Train Precision: {train_metrics["precision"]:.4f}, Val Precision: {val_metrics["precision"]:.4f}')
            print(f'Train Recall: {train_metrics["recall"]:.4f}, Val Recall: {val_metrics["recall"]:.4f}')
            print(f'Train F1: {train_metrics["f1"]:.4f}, Val F1: {val_metrics["f1"]:.4f}')
            print(f'Learning Rate: {current_lr:.6f}')
             
            # Early stopping check
            early_stopping(avg_val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break
            
            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_val_loss,
                    'metrics': {
                        'train_metrics': train_metrics,
                        'val_metrics': val_metrics
                    },
                }, f'/Users/danipopov/Projects/RSNA2024/models/fold_models/best_model_fold_{fold+1}.pth')
                print(f'Saved new best model with validation loss: {best_val_loss:.4f}')
        
        # Save final metrics for this fold
        save_path = '/Users/danipopov/Projects/RSNA2024/metrics'
        os.makedirs(save_path, exist_ok=True)
        tracker.save(fold+1, save_path)
        
        fold_results.append(best_val_loss)
        print(f'Fold {fold+1} Best Loss: {best_val_loss:.4f}')
        
        # Print final metrics for this fold
        print(f"\nFinal Metrics for Fold {fold+1}:")
        print(f"Validation Accuracy: {val_metrics['accuracy']:.4f}")
        print(f"Validation Precision: {val_metrics['precision']:.4f}")
        print(f"Validation Recall: {val_metrics['recall']:.4f}")
        print(f"Validation F1: {val_metrics['f1']:.4f}")
    
    print('\nTraining completed!')
    print(f'Average Best Loss across folds: {sum(fold_results)/len(fold_results):.4f}')
    
    return fold_results

In [192]:
# Train the model
results = train(
    model=model,
    train_df=train_df_copy,
    image_dir='/Users/danipopov/Projects/RSNA2024/data/images_dataset',
    train_transform=train_transform,
    test_transform=test_transform,
    N_SPLITS=3,
    EPOCHS=25,
    device='mps'
)

In [193]:
# Generate visualizations for each fold
for fold in range(3):
    visualize_fold_results(fold + 1)

### Analyzing the Results 🔍 📈📊

Great! Now that we've built and trained our model, let's evaluate its performance on the test dataset.

Before we proceed with testing, I want to address how we handled the data imbalance problem. We implemented two key strategies:
1. Used a weighted loss function to give more importance to underrepresented classes
2. Applied k-fold cross validation to make the model more robust and reduce overfitting

For comprehensive analysis, we saved several metrics during training:
1. Training and validation metrics for each epoch
2. Heatmaps showing performance across different spinal conditions and levels
3. ROC curves analyzing performance for each spinal degenerative condition

Let's examine these results in detail for each fold.

In [206]:
def display_result_image(image_path, title, figsize=(15, 15)):
    plt.figure(figsize=figsize)
    img = mpimg.imread(image_path)
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
    plt.show()


In [221]:
def display_fold_roc_curves(fold_number, image_paths, figsize=(14, 10)):
    n_images = len(image_paths)
    rows = (n_images + 2) // 3  
    cols = min(3, n_images)     
    
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    fig.suptitle(f'ROC Curves for Fold {fold_number}', fontsize=16)
    
    # Flatten axes array if we have multiple rows
    if rows > 1:
        axes = axes.flatten()
    
    # Plot each ROC curve
    for idx, img_path in enumerate(image_paths):
        if Path(img_path).exists():
            img = mpimg.imread(img_path)
            if rows == 1:
                ax = axes[idx] if cols > 1 else axes
            else:
                ax = axes[idx]
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f'Condition {idx+1}')
    
    # Remove empty subplots if any
    if rows > 1:
        for idx in range(len(image_paths), len(axes)):
            fig.delaxes(axes[idx])
    
    plt.tight_layout()
    plt.show()

#### Fold 1

In [210]:
# For training metrics subplot
display_result_image('plots/training_metrics/training_metrics_fold_1.png', 'Training Metrics for Fold 1')

Training Metrics Analysis - Fold 1 📈

The training results show several important trends across 25 epochs:

Loss
- The model shows strong convergence with the loss decreasing significantly in the first 5 epochs
- Training and validation losses align well, starting from ~0.6 and stabilizing around 0.42
- The close alignment between training and validation loss suggests no significant overfitting

Accuracy
- Validation accuracy quickly reaches ~77% and remains stable
- Training accuracy gradually improves to match validation accuracy
- Final accuracy for both training and validation converges at approximately 77%

Precision and Recall
- Precision improves from 31% to 38% over the training period
- Recall shows similar improvement, reaching approximately 41%
- Both metrics show steady improvement without significant oscillation
- The small gap between training and validation suggests good generalization

F1 Score
- F1 score, which balances precision and recall, improves from 31% to 38%
- The validation F1 score closely follows the training curve
- The steady increase indicates consistent improvement in overall model performance

Learning Rate
- Implements a cyclical learning rate strategy
- Peaks at 0.0001 with controlled decreases
- The cycling pattern helps avoid local minima and promotes better convergence

Overall Assessment
The model demonstrates stable training with good convergence and no significant overfitting. While the accuracy metrics are promising at 77%, the precision and recall metrics (around 38-41%) suggest room for improvement in handling class imbalance. The close alignment between training and validation metrics across all measures indicates good generalization capabilities.

In [209]:
# For heatmap subplot
display_result_image('plots/heatmaps/heatmap_fold_1.png', 'Heatmap for Fold 1')

Performance Heatmap Analysis - Fold 1 📊

The heatmap visualizes the model's performance across different spinal conditions and disk levels, revealing several key insights:

Spinal Canal Stenosis
- Shows consistently strong performance across all disk levels
- Highest accuracy at L5/S1 (0.977) and L1/L2 (0.956)
- Even at its lowest performance (L4/L5: 0.768), maintains good reliability
- Overall best-performing condition among all pathologies

Neural Foraminal Narrowing (Left & Right)
- Both sides show similar performance patterns
- Excellent performance in upper spine levels (L1/L2: ~0.967)
- Gradual decrease in accuracy moving down the spine
- Lower performance at L5/S1 (Left: 0.612, Right: 0.627)
- Right side slightly outperforms left side in most levels

Subarticular Stenosis (Left & Right)
- Shows moderate to good performance
- Strongest at upper levels (L1/L2: ~0.83)
- Lowest performance at L4/L5 (Left: 0.550, Right: 0.522)
- Slight improvement at L5/S1 compared to L4/L5
- Generally lower accuracy compared to other conditions

General Observations
1. Performance tends to be strongest at upper spine levels (L1/L2, L2/L3)
2. Most conditions show decreased accuracy at L4/L5
3. Bilateral conditions (left/right) show similar patterns
4. L4/L5 level consistently shows the lowest performance across all conditions
5. Model performs best with Spinal Canal Stenosis and struggles more with Subarticular Stenosis

This pattern suggests that the model is more reliable for central spinal conditions compared to lateral pathologies, and performs better in upper spinal regions compared to lower ones.

In [222]:
# For ROC curves subplot
roc_paths = [
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Left_Neural_Foraminal_Narrowing_fold_1.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Left_Subarticular_Stenosis_fold_1.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Right_Neural_Foraminal_Narrowing_fold_1.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Right_Subarticular_Stenosis_fold_1.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Spinal_Canal_Stenosis_fold_1.png'
]
display_fold_roc_curves(1, roc_paths)

ROC Curve Analysis - Fold 1📊

The ROC curves demonstrate the model's ability to distinguish between different severity levels across various spinal conditions. Here's a detailed analysis:

Left Neural Foraminal Narrowing
- Excellent performance with high AUC scores
- Mild cases: AUC = 0.92 (best performing)
- Severe cases: AUC = 0.90
- Shows strong early detection capability with steep initial curve rise

Right Neural Foraminal Narrowing
- Similar but slightly lower performance compared to left side
- Mild cases: AUC = 0.89
- Severe cases: AUC = 0.84
- Good discrimination ability for both severity levels

Left Subarticular Stenosis
- Good performance with balanced detection
- Mild cases: AUC = 0.88
- Severe cases: AUC = 0.80
- Shows better detection of mild cases compared to severe

Right Subarticular Stenosis
- Lowest performing among all conditions
- Mild cases: AUC = 0.80
- Severe cases: AUC = 0.77
- More gradual curve progression indicating lower confidence in predictions

Spinal Canal Stenosis
- Strong consistent performance
- Mild cases: AUC = 0.89
- Severe cases: AUC = 0.88
- Nearly identical performance for both severity levels
- Smooth curve progression indicating stable predictions

Key Observations
1. All conditions show AUC scores above 0.77, indicating good to excellent classification performance
2. Left-sided conditions generally perform better than right-sided ones
3. Neural Foraminal Narrowing shows the best overall performance
4. Mild cases are generally detected with higher accuracy than severe cases
5. Subarticular Stenosis shows the most room for improvement, particularly on the right side

The model demonstrates robust classification ability across all conditions, with particularly strong performance in detecting Neural Foraminal Narrowing. The consistently high AUC scores suggest reliable clinical applicability, though there's some variation in performance between different conditions and severity levels.

#### Fold 2

In [224]:
# For training metrics subplot
display_result_image('plots/training_metrics/training_metrics_fold_2.png', 'Training Metrics for Fold 2')

Training Metrics Analysis - Fold 2 📈

The training results show several important trends across 25 epochs:

Loss
- Similar to Fold 1, strong initial convergence in first 5 epochs
- Training and validation losses start from ~0.6 and stabilize around 0.40-0.42
- Slight divergence between training and validation loss after epoch 15, but difference remains small
- Training loss continues to decrease gradually while validation loss plateaus

Accuracy
- Validation accuracy reaches ~77% quickly and shows more fluctuation than Fold 1
- Training accuracy improves steadily to ~77%
- Final accuracy converges around 77% for both, matching Fold 1's performance
- More variance in validation accuracy compared to Fold 1

Precision and Recall
- Precision improves from 30% to 38% over the training period
- Recall increases from 33% to 41-42%
- More fluctuation in validation metrics compared to Fold 1
- Training metrics show steadier improvement than validation
- Final gap between training and validation slightly larger than in Fold 1

F1 Score
- F1 score improves from 30% to 38%
- More variance in validation F1 score compared to Fold 1
- Training F1 score shows steady improvement
- Final performance matches Fold 1 despite more fluctuation

Learning Rate
- Identical cyclical learning rate strategy as Fold 1
- Peaks at 0.0001 with controlled decreases
- Maintains consistent learning pattern across folds

Overall Assessment
Fold 2 shows similar overall performance to Fold 1 but with more fluctuation in validation metrics. The model still demonstrates good convergence, reaching comparable final metrics (77% accuracy, 38% precision/F1, 41% recall). The increased variance in validation metrics suggests this fold might have encountered more challenging examples, but the final performance remains stable. The slight divergence between training and validation metrics after epoch 15 bears monitoring but doesn't indicate severe overfitting.

In [225]:
# For heatmap subplot
display_result_image('plots/heatmaps/heatmap_fold_2.png', 'Heatmap for Fold 2')

Performance Heatmap Analysis - Fold 2 📊

The heatmap visualizes the model's performance across different spinal conditions and disk levels, revealing several key insights:

Spinal Canal Stenosis
- Maintains excellent performance across all disk levels
- Highest accuracy at L1/L2 (0.957) and L5/S1 (0.967)
- Lowest at L4/L5 (0.720), but still acceptable performance
- Shows very similar pattern to Fold 1, confirming consistency
- Remains the most reliable condition for detection

Neural Foraminal Narrowing (Left & Right)
- Both sides demonstrate strong performance patterns
- Left side shows exceptional performance at upper levels (L1/L2: 0.977)
- Right side slightly lower but still excellent (L1/L2: 0.952)
- Performance decreases in lower spine levels
- L5/S1 performance (Left: 0.613, Right: 0.638) matches Fold 1 pattern
- More balanced performance between left and right compared to Fold 1

Subarticular Stenosis (Left & Right)
- Good performance in upper spine levels
- Left side peaks at L1/L2 (0.863)
- Right side shows similar pattern (L1/L2: 0.849)
- Significant drop at L4/L5 (Left: 0.464, Right: 0.484)
- Recovers slightly at L5/S1 (Left: 0.673, Right: 0.688)
- Performance slightly better than Fold 1 for most levels

General Observations
1. Upper spine levels (L1/L2, L2/L3) consistently show strongest performance
2. L4/L5 remains the most challenging level across all conditions
3. Bilateral conditions show more balanced performance than in Fold 1
4. Overall pattern matches Fold 1, suggesting model stability
5. Slight improvement in Subarticular Stenosis detection compared to Fold 1

This fold confirms the model's strengths and challenges seen in Fold 1, with marginally better performance in some areas. The consistency between folds suggests the model has learned robust features for diagnosis, particularly for central spinal conditions and upper spinal regions.

In [226]:
# For ROC curves subplot
roc_paths = [
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Left_Neural_Foraminal_Narrowing_fold_2.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Left_Subarticular_Stenosis_fold_2.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Right_Neural_Foraminal_Narrowing_fold_2.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Right_Subarticular_Stenosis_fold_2.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Spinal_Canal_Stenosis_fold_2.png'
]
display_fold_roc_curves(2, roc_paths)

ROC Curve Analysis - Fold 2 📊

The ROC curves demonstrate the model's performance across different conditions and severity levels, showing some variations from Fold 1:

Left Neural Foraminal Narrowing
- Outstanding performance with improved AUC scores
- Mild cases: AUC = 0.93 (slight improvement from Fold 1)
- Severe cases: AUC = 0.91 (better than Fold 1)
- Very steep initial curve rise indicating excellent early detection
- More consistent performance between mild and severe cases

Right Neural Foraminal Narrowing
- Strong performance with slight improvement
- Mild cases: AUC = 0.90 (better than Fold 1)
- Severe cases: AUC = 0.83 (similar to Fold 1)
- Clear separation between mild and severe detection capabilities
- Good early detection rate for mild cases

Left Subarticular Stenosis
- Moderate to good performance
- Mild cases: AUC = 0.85 (slightly lower than Fold 1)
- Severe cases: AUC = 0.72 (lower than Fold 1)
- Larger gap between mild and severe detection
- More gradual curve progression compared to Neural Foraminal Narrowing

Right Subarticular Stenosis
- Shows similar challenges to Fold 1
- Mild cases: AUC = 0.73 (lower than Fold 1)
- Severe cases: AUC = 0.77 (matching Fold 1)
- More stepwise progression in the curves
- Unusual pattern where severe cases perform better than mild

Spinal Canal Stenosis
- Excellent and improved performance
- Mild cases: AUC = 0.92 (better than Fold 1)
- Severe cases: AUC = 0.93 (significant improvement)
- Very consistent performance between severity levels
- Sharp initial rise indicating strong detection confidence

Key Observations
1. Overall AUC scores range from 0.72 to 0.93, showing good to excellent performance
2. Neural conditions show more consistent performance across folds
3. Spinal Canal Stenosis shows marked improvement in Fold 2
4. Subarticular Stenosis remains the most challenging condition
5. Better balance between mild and severe cases in most conditions

The model maintains its strong performance in Fold 2, with some conditions showing improvement while others remain challenging. The consistency in Neural Foraminal Narrowing and Spinal Canal Stenosis detection across folds suggests robust learning for these conditions. The variation in Subarticular Stenosis performance between folds indicates this remains an area for potential improvement.

#### Fold 3

In [227]:
# For training metrics subplot
display_result_image('plots/training_metrics/training_metrics_fold_3.png', 'Training Metrics for Fold 3')

Training Metrics Analysis - Fold 3 📈

The training results show several important trends across 25 epochs:

Loss
- Strong initial convergence in first 5 epochs, similar to previous folds
- Training loss starts at ~0.575 and stabilizes around 0.39
- Validation loss plateaus around 0.43
- Slight divergence between training and validation loss after epoch 15
- More stable validation loss compared to Fold 2

Accuracy
- Both training and validation accuracy reach ~77%
- More fluctuation in validation accuracy compared to previous folds
- Training accuracy shows steady improvement, reaching 77.8%
- Validation accuracy shows more variance but maintains good performance
- Final convergence similar to Folds 1 and 2

Precision and Recall
- Precision improves from 32% to 39.5% for training
- Validation precision stabilizes around 38%
- Recall shows improvement from 35% to 43% for training
- Validation recall maintains around 40-41%
- Larger gap between training and validation metrics in later epochs

F1 Score
- Training F1 score shows steady improvement from 32% to 40%
- Validation F1 score stabilizes around 38%
- More pronounced divergence between training and validation after epoch 15
- Final performance comparable to previous folds

Learning Rate
- Maintains same cyclical learning rate strategy as previous folds
- Peaks at 0.0001 with controlled decreases
- Consistent pattern across all folds

Overall Assessment
Fold 3 shows similar overall performance to previous folds but with some notable differences:
1. Slightly higher final training metrics
2. More pronounced gap between training and validation performance
3. More fluctuation in validation metrics
4. Strong initial convergence maintained
5. Potential signs of mild overfitting in later epochs

While the model maintains good performance, the increased divergence between training and validation metrics suggests this fold might benefit from additional regularization or earlier stopping.

In [228]:
# For heatmap subplot
display_result_image('plots/heatmaps/heatmap_fold_3.png', 'Heatmap for Fold 3')

Performance Heatmap Analysis - Fold 3 📊

The heatmap for Fold 3 reveals performance patterns across spinal conditions and disk levels:

Spinal Canal Stenosis
- Maintains excellent performance consistent with previous folds
- Highest accuracy at L1/L2 (0.961) and L5/S1 (0.965)
- Notable drop at L4/L5 (0.701), the lowest performance point
- Strong performance at upper levels (L2/L3: 0.895)
- Pattern matches previous folds with slight variations

Neural Foraminal Narrowing (Left & Right)
- Left side shows exceptional performance at upper levels (L1/L2: 0.967)
- Right side maintains strong performance (L1/L2: 0.959)
- Both sides show gradual decrease moving down the spine
- L4/L5 performance (Left: 0.630, Right: 0.653) shows typical drop
- L5/S1 remains challenging (Left: 0.632, Right: 0.627)

Subarticular Stenosis (Left & Right)
- Good performance in upper spine regions
- Left side: L1/L2 (0.863), Right side: L1/L2 (0.870)
- Consistent performance between left and right at L2/L3 (~0.80)
- Significant drop at L4/L5 (Left: 0.503, Right: 0.549)
- Moderate recovery at L5/S1 (Left: 0.648, Right: 0.679)

General Observations
1. Upper spine levels (L1/L2, L2/L3) consistently show strongest performance
2. L4/L5 remains the most challenging level across all conditions
3. Performance pattern follows similar trends to previous folds
4. Right-sided conditions show slightly better performance at lower levels
5. Vertical gradient pattern clearly visible, showing decreasing performance down the spine

Comparison with Previous Folds
- Overall performance aligns well with Folds 1 and 2
- Slightly better consistency between left and right conditions
- L4/L5 performance remains a consistent challenge
- Upper spine level accuracy maintains high standards
- Subarticular Stenosis shows similar patterns to previous folds

The heatmap confirms the model's reliable performance patterns across different folds, suggesting robust learning of spinal pathology features, particularly in upper spinal regions.

In [229]:
# For ROC curves subplot
roc_paths = [
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Left_Neural_Foraminal_Narrowing_fold_3.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Left_Subarticular_Stenosis_fold_3.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Right_Neural_Foraminal_Narrowing_fold_3.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Right_Subarticular_Stenosis_fold_3.png',
    '/Users/danipopov/Projects/RSNA2024/plots/roc_curves/roc_curves_Spinal_Canal_Stenosis_fold_3.png'
]
display_fold_roc_curves(3, roc_paths)

ROC Curve Analysis - Fold 3 📊

The ROC curves illustrate the model's ability to differentiate between various severity levels across different spinal conditions. Here’s a detailed analysis:

Left Neural Foraminal Narrowing
- Excellent performance with high AUC scores
- Mild cases: AUC = 0.90
- Severe cases: AUC = 0.89
- Strong early detection capability with a steep initial curve rise
- Consistent performance across severity levels

Left Subarticular Stenosis
- Moderate performance with some challenges
- Mild cases: AUC = 0.84
- Severe cases: AUC = 0.73
- The curve shows a gradual increase, indicating some difficulty in distinguishing severe cases
- Larger gap between mild and severe detection capabilities

Right Neural Foraminal Narrowing
- Strong performance similar to the left side
- Mild cases: AUC = 0.90
- Severe cases: AUC = 0.83
- Good discrimination ability for both severity levels
- The curve indicates reliable detection, especially for mild cases

Right Subarticular Stenosis
- Lower performance compared to other conditions
- Mild cases: AUC = 0.76
- Severe cases: AUC = 0.73
- The curve shows a more gradual progression, indicating lower confidence in predictions
- More challenging to distinguish between severity levels

Spinal Canal Stenosis
- Consistent and strong performance
- Mild cases: AUC = 0.89
- Severe cases: AUC = 0.88
- Smooth curve progression indicating stable predictions
- High confidence in distinguishing between severity levels

Key Observations
1. AUC scores range from 0.73 to 0.90, indicating good to excellent classification performance.
2. Neural Foraminal Narrowing shows the best overall performance across all severity levels.
3. Subarticular Stenosis remains the most challenging condition, particularly for severe cases.
4. The model demonstrates robust classification ability, especially for mild cases.
5. Consistency in performance across folds suggests reliable clinical applicability.

Overall, the model maintains strong performance in Fold 3, with some conditions showing improvement while others highlight areas for potential enhancement. The consistent high AUC scores across most conditions indicate effective learning and generalization capabilities.

### Conclusions 💡

In this section, we recap the steps taken to build and train the model:

1. We began with the instance images and coordinates provided by the radiologist in the `train_label_coordinates.csv` file to build and train our YOLOv8 models for detecting disk levels and sides (left or right).
2. We then used the YOLO models to create a new dataset that iterates over all patient images, attempting to detect disk levels and sides. From the images where the models successfully detected these features, we selected 5 images per plane. For each image, we extracted the bounding boxes (64x64) of the disk levels or sides to obtain the Regions of Interest (ROIs).
3. We built a dataset class for training our model. For each study ID, we passed 60 images of ROIs across all planes: 50 images for the sagittal plane and 10 images for the axial plane. For each sagittal image, we extracted 5 ROIs of disk levels, and for the axial plane, we extracted 2 ROIs of disk sides.
4. We chose the ResNet50 architecture as the base model and added a new head for the classification task.
5. We trained the model for 25 epochs across 3 folds and obtained the following results:
    - **Loss**: The model shows strong initial convergence in the first 5 epochs, with training and validation losses starting from ~0.6 and stabilizing around 0.40-0.42.
    - **Accuracy**: The model quickly reaches ~77% accuracy, showing more fluctuation than Fold 1.
    - **Precision and Recall**: Precision improves from 30% to 38% over the training period, while recall increases from 33% to 41-42%.
    - **F1 Score**: The F1 score improves from 30% to 38%.
    - **Learning Rate**: The model maintains the same cyclical learning rate strategy as in Fold 1.

We notice that this approach yielded good results because we focused on helping the model learn the specific disks rather than the entire spine image, we notice that acroos all folds the model shows good and similar performance. We will also evaluate our model on the Kaggle test set to see how it performs against the competition metrics.