In [1]:
import pydicom
import glob, os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import cv2
from tqdm import tqdm
import re
import shutil
import os

In [2]:
rd ='/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'
rs='/kaggle/working'

In [3]:
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split(r'(\d+)', text) ]

In [4]:
dfc = pd.read_csv(f'{rd}/train_label_coordinates.csv')
df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
df.head()

Unnamed: 0,study_id,series_id,series_description
0,4003253,702807833,Sagittal T2/STIR
1,4003253,1054713880,Sagittal T1
2,4003253,2448190387,Axial T2
3,4646740,3201256954,Axial T2
4,4646740,3486248476,Sagittal T1


In [5]:
shutil.copy(f"{rd}/train_label_coordinates.csv", rs)
shutil.copy(f"{rd}/train_series_descriptions.csv", rs)
print("File moved to output directory.")

File moved to output directory.


In [6]:
dft = pd.read_csv(f'{rd}/train.csv')
dft = dft.iloc[:500]
dft.to_csv(f'{rs}/train_500.csv', index=False)

In [7]:
import os
import pandas as pd

def filter_study_ids(primary_csv, target_directory):

    # Ensure primary CSV file exists
    if not os.path.exists(primary_csv):
        print(f"Error: {primary_csv} not found. Skipping processing.")
        return

    # Load primary CSV and extract study_id values
    primary_df = pd.read_csv(primary_csv)

    if 'study_id' not in primary_df.columns:
        print(f"Error: 'study_id' column not found in {primary_csv}.")
        return

    # Identify missing and valid study_id values
    missing_study_ids = primary_df[primary_df.isna().any(axis=1)]['study_id'].unique()
    valid_study_ids = primary_df['study_id'].unique()

    print(f"Filtering CSV files in {target_directory}...")
    print(f"- Removing {len(missing_study_ids)} missing study_id values")
    print(f"- Keeping {len(valid_study_ids)} valid study_id values")

    # Process each CSV file in the target directory
    for filename in os.listdir(target_directory):
        if filename.endswith('.csv'):
            file_path = os.path.join(target_directory, filename)

            try:
                df = pd.read_csv(file_path)

                if 'study_id' in df.columns:
                    # Remove rows with missing study_id values & keep only valid ones
                    df_filtered = df[~df['study_id'].isin(missing_study_ids)]
                    df_filtered = df_filtered[df_filtered['study_id'].isin(valid_study_ids)]
                    
                    df_filtered.to_csv(file_path, index=False)
                    print(f"Updated {filename}: Removed {len(df) - len(df_filtered)} rows, kept valid study_id values.")
                else:
                    print(f"Warning: 'study_id' column not found in {filename}. Skipping this file.")

            except Exception as e:
                print(f"Error processing {filename}: {e}")

# Define file paths
primary_csv = f'{rs}/train_500.csv'
target_directory = '/kaggle/working/'

# Run the filtering process
filter_study_ids(primary_csv, target_directory)


Filtering CSV files in /kaggle/working/...
- Removing 42 missing study_id values
- Keeping 500 valid study_id values
Updated train_label_coordinates.csv: Removed 37246 rows, kept valid study_id values.
Updated train_500.csv: Removed 42 rows, kept valid study_id values.
Updated train_series_descriptions.csv: Removed 4826 rows, kept valid study_id values.


In [8]:
dfts=pd.read_csv(f'{rs}/train_series_descriptions.csv')
dfts

Unnamed: 0,study_id,series_id,series_description
0,4003253,702807833,Sagittal T2/STIR
1,4003253,1054713880,Sagittal T1
2,4003253,2448190387,Axial T2
3,4646740,3201256954,Axial T2
4,4646740,3486248476,Sagittal T1
...,...,...,...
1463,1099112122,919181265,Sagittal T1
1464,1099112122,1815821295,Sagittal T2/STIR
1465,1103373889,1437983577,Sagittal T2/STIR
1466,1103373889,1669196197,Sagittal T1


In [9]:
dfts['series_description'].value_counts()

series_description
Axial T2            551
Sagittal T1         459
Sagittal T2/STIR    458
Name: count, dtype: int64

In [10]:
def imread_and_imwirte(src_path, dst_path):
    dicom_data = pydicom.dcmread(src_path)
    image = dicom_data.pixel_array
    image = (image - image.min()) / (image.max() - image.min() +1e-6) * 255
    img = cv2.resize(image, (512, 512),interpolation=cv2.INTER_CUBIC)
    assert img.shape==(512,512)
    cv2.imwrite(dst_path, img)

In [11]:
st_ids= dfts['study_id'].unique()
st_ids[:3], len(st_ids)

(array([4003253, 4646740, 7143189]), 458)

In [12]:
desc = list(dfts['series_description'].unique())
desc

['Sagittal T2/STIR', 'Sagittal T1', 'Axial T2']

In [13]:
for idx, si in enumerate(tqdm(st_ids, total=len(st_ids))):
    pdf = dfts[dfts['study_id']==si]
    for ds in desc:
        ds_ = ds.replace('/', '_')
        pdf_ = pdf[pdf['series_description']==ds]
        os.makedirs(f'mri_png/{si}/{ds_}', exist_ok=True)
        allimgs = []
        for i, row in pdf_.iterrows():
            pimgs = glob.glob(f'{rd}/train_images/{row["study_id"]}/{row["series_id"]}/*.dcm')
            pimgs = sorted(pimgs, key=natural_keys)
            allimgs.extend(pimgs)
            
        if len(allimgs)==0:
            print(si, ds, 'has no images')
            continue

        if ds == 'Axial T2':
            for j, impath in enumerate(allimgs):
                dst = f'mri_png/{si}/{ds}/{j:03d}.png'
                imread_and_imwirte(impath, dst)
                
        elif ds == 'Sagittal T2/STIR':
            
            step = len(allimgs) / 10.0
            st = len(allimgs)/2.0 - 4.0*step
            end = len(allimgs)+0.0001
            for j, i in enumerate(np.arange(st, end, step)):
                dst = f'mri_png/{si}/{ds_}/{j:03d}.png'
                ind2 = max(0, int((i-0.5001).round()))
                imread_and_imwirte(allimgs[ind2], dst)
                
            assert len(glob.glob(f'mri_png/{si}/{ds_}/*.png'))==10
                
        elif ds == 'Sagittal T1':
            step = len(allimgs) / 10.0
            st = len(allimgs)/2.0 - 4.0*step
            end = len(allimgs)+0.0001
            for j, i in enumerate(np.arange(st, end, step)):
                dst = f'mri_png/{si}/{ds}/{j:03d}.png'
                ind2 = max(0, int((i-0.5001).round()))
                imread_and_imwirte(allimgs[ind2], dst)
                
            assert len(glob.glob(f'mri_png/{si}/{ds}/*.png'))==10

100%|██████████| 458/458 [10:03<00:00,  1.32s/it]


In [14]:
!zip -r /kaggle/working/output.zip /kaggle/working/


  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/mri_png/ (stored 0%)
  adding: kaggle/working/mri_png/597752094/ (stored 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/ (stored 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/008.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/007.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/001.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/005.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/002.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/009.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/004.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/000.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/003.png (deflated 0%)
  adding: kaggle/working/mri_png/597752094/Sagittal T1/006.png (deflated 0%)
  adding