In [1]:
import os
from glob import glob
import pandas as pd
import pydicom
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Root folder of the competition data. It keeps its trailing "/" because every
# path in this notebook is built as f"{rsna_dir}<relative path>".
rsna_dir = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/"
# rsna_dir already ends with "/", so no extra separator is added here (the
# previous pattern produced ".../classification//train_images/...").
# glob() is called without recursive=True, so "**" behaves like "*": the pattern
# matches the entries two levels below train_images/.
all_image_dirs = glob(f"{rsna_dir}train_images/**/*")

# **Part 1.1, EDA:**

Note: study_id 3637444890 is anomalous: series_id 3892989905 shows a neck scan, and series_id 3951475160 (Spinal Canal Stenosis diagnosis) has an error. This study is excluded from the data.

In [3]:
# Load the three competition tables: one row per series, one row per labelled
# coordinate, and one row per study with the severity labels.
train_series_descriptions = pd.read_csv(f'{rsna_dir}train_series_descriptions.csv')
train_label_coordinates = pd.read_csv(f'{rsna_dir}train_label_coordinates.csv')
train_y = pd.read_csv(f'{rsna_dir}train.csv')

# Studies whose train.csv row has no missing value in any column.
non_nans = list(train_y.dropna()['study_id'].unique())
# Study flagged as anomalous in the note above; it is dropped from every table.
anomalous_study_id = 3637444890

# Keep only fully-labelled studies and drop the anomalous one, table by table.
train_series_descriptions = train_series_descriptions[
    train_series_descriptions['study_id'].isin(non_nans)
    & (train_series_descriptions['study_id'] != anomalous_study_id)
]
train_label_coordinates = train_label_coordinates[
    train_label_coordinates['study_id'].isin(non_nans)
    & (train_label_coordinates['study_id'] != anomalous_study_id)
]
train_y = train_y[
    train_y['study_id'].isin(non_nans)
    & (train_y['study_id'] != anomalous_study_id)
]

# Number of series rows per study (used by the EDA printouts below).
train_series_descriptions_val_counts = train_series_descriptions['study_id'].value_counts()

# print("Brief explanation for each csv:")
# print(f"- train_series_descriptions.csv contains study/patient ids, the series numbers,\n  and the MRI image description (denoting the direction of scanning).")
# print("  Each row is for one MRI image")
# print("- train_label_coordinates.csv contains study/patient ids, the series numbers,\n  the specific instance number (denoting the nth slice in the 3D MRI image),")
# print("  spine levels (l1/l2, etc), and xy coordinates for the condition.")
# print("  Each row is for each condition+level diagnosis.")
# print("- train.csv contains a column with the study id's and 25 columns for each condition+spine level\n  whose entries are condition severities for predicting.")
# print("  Each row is for one patient\n")
# print(f"1. Unique MRI images names:\n {train_series_descriptions['series_description'].unique()} \n")
# print(f"2. Value counts for each image name:\n {train_series_descriptions['series_description'].value_counts()}\n")
# print(f"3. Value counts for each patient id in train_series_descriptions:\n {train_series_descriptions_val_counts}")
# print(f"Number of patients with more or less than 3 MRIs: {len(train_series_descriptions_val_counts[train_series_descriptions_val_counts!=3])}\n")
# print(f"4. Num patients: {len(train_series_descriptions['study_id'].unique())}\n")
# print(f"5. Value counts for each condition:\n {train_label_coordinates['condition'].value_counts()}\n")
# print(f"6. Value counts for each condition by level:\n {train_label_coordinates[['condition','level']].value_counts().sort_index()}\n")
# temp_df = train_series_descriptions[['study_id','series_description']].value_counts().sort_index()
# sagittal_df = temp_df[temp_df.index.get_level_values(1).isin(['Sagittal T1'])]
# axial_df = temp_df[temp_df.index.get_level_values(1).isin(['Axial T2'])]
# stir_df = temp_df[temp_df.index.get_level_values(1).isin(['Sagittal T2/STIR'])]
# print(f"7. Study IDs with more than 1 Sagittal T1 Scan:\n {sagittal_df[sagittal_df>1]}\n")
# print(f"8. Study IDs with more than 1 Axial T2 Scan:\n {axial_df[axial_df>1]}\n" )
# temp_df = train_label_coordinates
# temp_df['series_description'] = temp_df['series_id'].map(train_series_descriptions.drop(columns=['study_id']).set_index('series_id').to_dict()['series_description'])
# print(f"9. Distribution of scan type and condition diagnosis:\n {temp_df[['series_description','condition']].value_counts().sort_index()}\n")
# file_num_df = train_series_descriptions
# file_num_df = file_num_df.reset_index(drop=True)
# file_nums = [len(glob(f"{rsna_dir}train_images/{row['study_id']}/{row['series_id']}/**")) for idx,row in file_num_df.iterrows()]
# file_num_df = pd.concat([file_num_df,pd.Series(file_nums)],axis=1)
# file_num_df.columns = ['study_id','series_id','series_description','file_len']
# print(f"10. Number of image slices by MRI type (max,min,mean):")
# print(f"{file_num_df.groupby('series_description').max()['file_len']}\n")
# print(f"{file_num_df.groupby('series_description').min()['file_len']}\n")
# print(f"{file_num_df.groupby('series_description').mean()['file_len']}\n")


Note: Some studies contain duplicate Axial T2 or duplicate Sagittal T1 series, which need to be filtered out. For each study and scan type, we keep the series_id with the most image slices (i.e., the highest resolution).

# Part 1.2, Filtering the duplicates:

In [4]:
# Work on a copy of the series table with a clean 0..n-1 index so the file
# counts built below line up row by row.
file_num_df = train_series_descriptions.reset_index(drop=True)

# Number of files found in each series folder.
file_nums = [
    len(glob(f"{rsna_dir}train_images/{series_row['study_id']}/{series_row['series_id']}/**"))
    for _row_label, series_row in file_num_df.iterrows()
]
file_num_df = pd.concat([file_num_df, pd.Series(file_nums)], axis=1)
file_num_df.columns = ['study_id','series_id','series_description','file_len']

# For every (study, scan type) pair keep only the row with the largest file
# count; isin() keeps the surviving rows in their original order.
largest_series_rows = (
    file_num_df
    .groupby(['study_id','series_description'])['file_len']
    .idxmax()
    .unique()
)
file_num_df = file_num_df[file_num_df.index.isin(largest_series_rows)]
# file_num_df['series_id'].value_counts()[file_num_df['series_id'].value_counts()>1] --> Shows that each series id is unique

# Restrict the three working tables to the studies / series that survived.
study_ids = file_num_df['study_id'].unique()
series_ids = file_num_df['series_id'].unique()
kept_description_rows = (
    train_series_descriptions['study_id'].isin(study_ids)
    & train_series_descriptions['series_id'].isin(series_ids)
)
train_series_descriptions = train_series_descriptions[kept_description_rows]
kept_coordinate_rows = (
    train_label_coordinates['study_id'].isin(study_ids)
    & train_label_coordinates['series_id'].isin(series_ids)
)
train_label_coordinates = train_label_coordinates[kept_coordinate_rows]
train_y = train_y[train_y['study_id'].isin(study_ids)]

In [5]:
# print("Checking the new dfs:")
# print(f"1. Unique MRI images names:\n {train_series_descriptions['series_description'].unique()} \n")
# print(f"2. Value counts for each image name:\n {train_series_descriptions['series_description'].value_counts()}\n")
# print(f"3. Value counts for each patient id in train_series_descriptions:")
# print(train_series_descriptions["study_id"].value_counts())
# print(f"4. Num patients: {len(train_series_descriptions['study_id'].unique())}\n")
# print(f"5. Value counts for each condition:\n {train_label_coordinates['condition'].value_counts()}\n")
# print(f"6. Value counts for each condition by level:\n {train_label_coordinates[['condition','level']].value_counts().sort_index()}\n")
# temp_df = train_series_descriptions[['study_id','series_description']].value_counts().sort_index()
# sagittal_df = temp_df[temp_df.index.get_level_values(1).isin(['Sagittal T1'])]
# axial_df = temp_df[temp_df.index.get_level_values(1).isin(['Axial T2'])]
# stir_df = temp_df[temp_df.index.get_level_values(1).isin(['Sagittal T2/STIR'])]
# print(f"7. Study IDs with more than 1 Sagittal T1 Scan:\n {sagittal_df[sagittal_df>1]}\n")
# print(f"8. Study IDs with more than 1 Axial T2 Scan:\n {axial_df[axial_df>1]}\n" )
# temp_df = train_label_coordinates
# temp_df['series_description'] = temp_df['series_id'].map(train_series_descriptions.drop(columns=['study_id']).set_index('series_id').to_dict()['series_description'])
# print(f"9. Distribution of scan type and condition diagnosis:\n {temp_df[['series_description','condition']].value_counts().sort_index()}\n")
# file_num_df = train_series_descriptions
# file_num_df = file_num_df.reset_index(drop=True)
# file_nums = [len(glob(f"{rsna_dir}train_images/{row['study_id']}/{row['series_id']}/**")) for idx,row in file_num_df.iterrows()]
# file_num_df = pd.concat([file_num_df,pd.Series(file_nums)],axis=1)
# file_num_df.columns = ['study_id','series_id','series_description','file_len']
# print(f"10. Number of image slices by MRI type (max,min,mean,median):")
# print(f"{file_num_df.groupby('series_description').max()['file_len']}\n")
# print(f"{file_num_df.groupby('series_description').min()['file_len']}\n")
# print(f"{file_num_df.groupby('series_description').mean()['file_len']}\n")
# print(f"{file_num_df.groupby('series_description').median()['file_len']}\n")

# Part 1.3, Preparing Data:

In [6]:
def saving_pngs(dcm_files,save_dir):
    """Convert DICOM files to 512x512 8-bit PNG images.

    Each file is read with pydicom, resized to 512x512, min-max scaled to the
    0-255 range and written to ``{save_dir}/{j:03d}.png``, where ``j`` is the
    position of the file in ``dcm_files``.

    Args:
        dcm_files: Ordered list of DICOM file paths; the order determines the
            PNG file names.
        save_dir: Output directory. It is created when missing; nothing is
            created when ``dcm_files`` is empty.
    """
    if not dcm_files:
        return
    # Create the output folder once instead of re-checking it for every slice.
    os.makedirs(save_dir, exist_ok=True)

    for j, dcm_file in enumerate(dcm_files):
        image = pydicom.dcmread(dcm_file).pixel_array
        # Cubic interpolation when the image height is at most 512, area
        # interpolation when it is larger.
        interpolation = cv2.INTER_CUBIC if image.shape[0] <= 512 else cv2.INTER_AREA
        resized = cv2.resize(image, (512, 512), interpolation=interpolation)

        # Scale in float64 so the subtraction cannot wrap around for narrow
        # integer pixel types; the 1e-6 term avoids dividing by zero on a
        # constant image.
        resized = resized.astype(np.float64)
        lowest, highest = resized.min(), resized.max()
        scaled = (resized - lowest) / (highest - lowest + 1e-6) * 255

        # Hand OpenCV an explicit 8-bit image rather than relying on its
        # implicit conversion of floating-point arrays.
        cv2.imwrite(f"{save_dir}/{j:03d}.png", np.rint(scaled).astype(np.uint8))
def keyFunc(e):
    """Sort key: the integer encoded in a file name.

    Takes the text after the last "/" in ``e``, drops its final four
    characters (e.g. ".dcm") and converts the remainder to ``int``.
    """
    file_name = e.rsplit('/', 1)[-1]
    stem = file_name[:-4]
    return int(stem)

In [7]:
# Output sub-folder (under cvt_png/<study_id>/) for each exported scan type.
# Series with any other description are skipped.
png_subdir_by_description = {
    "Axial T2": "Axial T2",
    "Sagittal T1": "Sagittal T1",
    "Sagittal T2/STIR": "Sagittal T2",  # folder name drops the "/STIR" suffix
}
# Series with at least this many slices are subsampled down to exactly this many.
max_slices_per_series = 10

for idx,row in tqdm(train_series_descriptions.iterrows(), total=train_series_descriptions.shape[0]):
    png_subdir = png_subdir_by_description.get(row['series_description'])
    if png_subdir is None:
        continue

    # Slices of the series, ordered by the number in their file name.
    dcm_files = glob(f"{rsna_dir}train_images/{row['study_id']}/{row['series_id']}/**")
    dcm_files.sort(key=keyFunc)
    num_files = len(dcm_files)

    if num_files >= max_slices_per_series:
        # Pick evenly spaced slices: the first slice of each of
        # max_slices_per_series equal-length intervals.
        interval_len = num_files/max_slices_per_series
        dcm_indexes = [int(np.floor(i*interval_len)) for i in range(max_slices_per_series)]
        dcm_files = [dcm_files[index] for index in dcm_indexes]

    saving_pngs(dcm_files, f"cvt_png/{row['study_id']}/{png_subdir}")

100%|██████████| 5367/5367 [22:46<00:00,  3.93it/s]


# Continuing to training -> https://www.kaggle.com/code/conradtrey/rsna2024-lsdc-part-2-training/edit

In [8]:
# Observing images Axial T2:

# print(f"Examining the extreme case: 4096820034"+"\n")
# print(train_series_descriptions[train_series_descriptions['study_id']==4096820034],"\n")

# def display_images(images, title, max_images_per_row=4):
#     # Calculate the number of rows needed
#     num_images = len(images)
#     num_rows = (num_images + max_images_per_row - 1) // max_images_per_row  # Ceiling division

#     # Create a subplot grid
#     fig, axes = plt.subplots(num_rows, max_images_per_row, figsize=(5, 1.5 * num_rows))
    
#     # Flatten axes array for easier looping if there are multiple rows
#     if num_rows > 1:
#         axes = axes.flatten()
#     else:
#         axes = [axes]  # Make it iterable for consistency

#     # Plot each image
#     for idx, image in enumerate(images):
#         ax = axes[idx]
#         ax.imshow(image, cmap='gray')  # Assuming grayscale for simplicity, change cmap as needed
#         ax.axis('off')  # Hide axes

#     # Turn off unused subplots
#     for idx in range(num_images, len(axes)):
#         axes[idx].axis('off')
#     fig.suptitle(title, fontsize=16)

#     plt.tight_layout()

# def keyFunc(e:str):
#     return int(e.split('/')[-1][:-4])

# for idx,row in train_series_descriptions[(train_series_descriptions['study_id']==4096820034) & (train_series_descriptions['series_description']=="Axial T2")].iterrows():
#     image_dir = f"{rsna_dir}train_images/{row['study_id']}/{row['series_id']}/**"
#     dcm_dirs = glob(image_dir)
#     dcm_dirs.sort(key=keyFunc)
#     slices = [pydicom.dcmread(path).pixel_array for path in dcm_dirs]
#     display_images(slices, f"{row['series_id']},{len(slices)}")

# ***Main finding: the duplicate series contain more or less similar images, so choose the one with the most samples.***