## RSNA 2024 Lumbar Spine Degenerative Classification
- In this notebook, we classify the severity of degenerative lumbar spine conditions from DICOM images.
- Instead of building a model from scratch, we fine-tune a pre-trained ViT-style model (BEiT).
- Here are the steps:
1.  imports
2.  read data
3.  convert data to dataset
4.  import model
5.  prepare data, transform data and model args
6.  train


In [None]:
# Install the Hugging Face `evaluate` package, which is imported later in this notebook.
!pip install evaluate

In [None]:
import os, gc, sys, copy, pickle
from pathlib import Path
import glob
from tqdm.auto import tqdm
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import math
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pydicom
from PIL import Image as PILImage
from datasets import Dataset, Image,Features
import warnings
import torch
warnings.filterwarnings("ignore")
tqdm.pandas()
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
def seeding(SEED):
    """Seed the NumPy, ``random`` and PyTorch random number generators.

    Also writes ``SEED`` to the ``PYTHONHASHSEED`` environment variable.
    When CUDA is available, the CUDA generators are seeded as well and
    cuDNN is switched to benchmark mode with ``deterministic`` turned off.

    Args:
        SEED: Integer seed handed to every generator.
    """
    np.random.seed(SEED)
    random.seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    torch.manual_seed(SEED)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(SEED)
        # Favors speed over bit-exact reproducibility of cuDNN kernels.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = False
        cudnn.benchmark = True

    print('seeding done!!!')

def flush():
    """Run the garbage collector and, when CUDA is available, free cached GPU memory.

    On a CUDA machine this also resets the peak-memory statistics.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
# Wide table: one row per study_id plus the condition/level columns that are melted later.
df_train = pd.read_csv("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train.csv")
# Label coordinates: rows carry study_id, series_id, instance_number, condition and level.
df_train_desc = pd.read_csv("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_label_coordinates.csv")
# Build the DICOM file path as <train_images>/<study_id>/<series_id>/<instance_number>.dcm
df_train_desc['image_path'] = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/" + df_train_desc['study_id'].astype(str) +"/"+ df_train_desc['series_id'].astype(str) + "/"+ df_train_desc['instance_number'].astype(str) + ".dcm"
# Preview five random rows.
df_train_desc.sample(5)

In [None]:
# Reshape from wide to long: one row per (study_id, original column name, value).
df_train_melted = df_train.melt(id_vars=['study_id'], var_name='condition_level', value_name='value')

# Split 'condition_level' from the right on the last two underscores and keep the
# last two pieces (stored temporarily as 'conditions' and 'level').
df_train_melted[['conditions', 'level']] = df_train_melted['condition_level'].str.rsplit('_', n=2, expand=True).iloc[:, 1:]
# Everything before those two pieces is the condition name; underscores become
# spaces and the result is title-cased.
df_train_melted['condition'] = df_train_melted['condition_level'].apply(lambda x: '_'.join(x.split('_')[:-2])).str.replace("_"," ").str.title()
# Join the two trailing pieces, upper-cased, with "/" to form the level.
df_train_melted['level'] = df_train_melted['conditions'].str.upper() +"/"+ df_train_melted['level'].str.upper()
# Remove the original 'condition_level' column and the temporary 'conditions' column
df_train_melted = df_train_melted.drop(columns=['condition_level', 'conditions'])

# Rename the melted value column to 'severity'; it becomes the label later on
df_train_melted = df_train_melted.rename(columns={'value': 'severity'})
# Preview five random rows.
df_train_melted.sample(5)

In [None]:
# Attach the severity to each coordinate row by matching study, level and condition.
# A left join keeps every coordinate row; rows without a match get NaN severity.
df_final = df_train_desc.merge(df_train_melted, on = ["study_id","level","condition"],how = "left")
df_final

In [None]:
# Print column dtypes and non-null counts of the merged frame.
df_final.info()

In [None]:
def read_dicom_image(file_path, target_shape=(224, 224)):
    """Read a DICOM file and return it as a resized 8-bit PIL image.

    Pixel values are scaled so that the largest value maps to 255.

    Args:
        file_path: Path of the ``.dcm`` file to read.
        target_shape: Size passed unchanged to ``PIL.Image.resize``, which
            interprets it as ``(width, height)``.

    Returns:
        The resized ``PIL.Image.Image``.
    """
    dicom = pydicom.dcmread(file_path)
    # Work in float64, the same precision plain integer division would use.
    pixels = dicom.pixel_array.astype(np.float64)
    peak = pixels.max()
    if peak > 0:
        scaled = pixels / peak * 255
    else:
        # An all-zero (or non-positive) slice would otherwise divide by zero
        # and push NaN/inf into the uint8 cast.
        scaled = np.zeros_like(pixels)
    # Clip so negative values become 0 instead of wrapping around in uint8.
    image = np.clip(scaled, 0, 255).astype(np.uint8)
    pil_image = PILImage.fromarray(image)
    return pil_image.resize(target_shape)

In [None]:
# Drop rows that have no image path or no severity label (mutates df_final in place).
df_final.dropna(subset=['image_path', 'severity'], inplace=True)
image_paths = df_final['image_path'].values
labels = df_final['severity'].values
# Dataset with two columns: the DICOM path and its severity label.
dataset = Dataset.from_dict({"image_path": image_paths, "label": labels})

In [None]:
# Sort the distinct labels so that the label <-> id mapping is the same on
# every run; iterating a bare set of strings can yield a different order per
# interpreter session, which would silently change what each class id means.
labels_list = sorted(set(labels))
label2id = {label: idx for idx, label in enumerate(labels_list)}
id2label = {idx: label for label, idx in label2id.items()}
print(id2label, '\n\n', label2id)

In [None]:
# Notebook display of the Dataset object built above.
dataset

In [None]:
def converts(example):
    """Decode the DICOM at ``example['image_path']`` and store it under ``'image'``.

    The image is read with ``read_dicom_image`` at a 224x224 target shape.
    The (mutated) example dict is returned.
    """
    dicom_path = example['image_path']
    example['image'] = read_dicom_image(dicom_path, target_shape=(224, 224))
    return example


# Add the decoded 'image' column to every row of the dataset.
dataset = dataset.map(converts)

In [None]:
def display_images(dataset, num_rows=2, num_columns=5, figsize=(12, 10), max_title_length=30):
    """Show a random grid of images from ``dataset`` with their labels as titles.

    Args:
        dataset: Indexable collection whose items have ``"image"`` and
            ``"label"`` entries.
        num_rows: Number of grid rows.
        num_columns: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
        max_title_length: Titles longer than this many characters are cut off.
    """
    total_images = num_rows * num_columns

    # Pick only as many distinct random indices as there are grid cells.
    picked = random.sample(range(len(dataset)), k=min(total_images, len(dataset)))

    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so the grid can always be walked the same way.
    _, axes = plt.subplots(num_rows, num_columns, figsize=figsize, squeeze=False)

    # Hide every axis up front so cells left without an image stay blank.
    for ax in axes.flat:
        ax.axis('off')

    # axes.flat walks the grid row by row.
    for ax, idx in zip(axes.flat, picked):
        example = dataset[idx]
        ax.imshow(example["image"])
        title = str(example["label"])[:max_title_length]
        ax.set_title(title, wrap=True, fontsize='small')

    # Adjust spacing and layout
    plt.tight_layout()
    plt.show()

# Example usage
display_images(dataset)

### Training with BEiT

In [None]:
from sklearn.metrics import (accuracy_score,
                             roc_auc_score,
                             precision_score,
                             recall_score,
                             confusion_matrix,
                             classification_report,
                             f1_score)

from transformers import (TrainingArguments,
                          Trainer,
                          DefaultDataCollator)
from transformers import (BeitImageProcessor,
                          BeitForImageClassification,
                          ViTImageProcessor, 
                          ViTForImageClassification,
                          AutoImageProcessor, 
                          AutoModel)
import evaluate
import torch
from torchvision import transforms
from torchvision.transforms import (CenterCrop,
                                    Compose,
                                    Normalize,
                                    RandomRotation,
                                    RandomResizedCrop,
                                    RandomHorizontalFlip,
                                    RandomAdjustSharpness,
                                    Resize,
                                    ToTensor)

In [None]:
# Alternative checkpoints / processors that were tried:
#model_path = "/kaggle/working/beit-base"
#processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
#processor = AutoImageProcessor.from_pretrained('/kaggle/input/dinov2/pytorch/base/1/')
processor = BeitImageProcessor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')

# Input resolution and normalization statistics come from the processor config.
size = processor.size["height"]
image_mean = processor.image_mean
image_std = processor.image_std
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: resize, light augmentation, then tensor conversion and normalization.
_train_transforms = Compose([
    Resize((size, size)),
    RandomRotation(15),
    RandomAdjustSharpness(2),
    ToTensor(),
    normalize,
])

# Validation pipeline: same as training but without the random augmentations.
_val_transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    normalize,
])

def train_transforms(examples):
    """Add ``'pixel_values'`` to a batch by running the training pipeline on each image.

    Every image in ``examples['image']`` is converted to RGB before
    ``_train_transforms`` is applied. The batch dict is returned.
    """
    augmented = []
    for image in examples['image']:
        augmented.append(_train_transforms(image.convert("RGB")))
    examples['pixel_values'] = augmented
    return examples

def val_transforms(examples):
    """Add ``'pixel_values'`` to a batch by running the validation pipeline on each image.

    Every image in ``examples['image']`` is converted to RGB before
    ``_val_transforms`` is applied. The batch dict is returned.
    """
    processed = []
    for image in examples['image']:
        processed.append(_val_transforms(image.convert("RGB")))
    examples['pixel_values'] = processed
    return examples

In [None]:
# Hold out 20% of the rows for evaluation. The fixed seed makes the split
# identical on every run, so metrics from different runs are comparable.
# NOTE(review): the split is by row, so images of the same study can end up
# on both sides — confirm whether a study-level split is wanted.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_data = dataset['train']
test_data = dataset['test']

In [None]:
# Register the transform functions on each split: the augmenting pipeline
# for the training split and the plain pipeline for the held-out split.
train_data.set_transform(train_transforms)
test_data.set_transform(val_transforms)

In [None]:
def collate_fn(examples):
    """Collate a list of examples into the batch dict handed to the model.

    Returns:
        A dict with ``"pixel_values"`` (the per-example tensors stacked
        together) and ``"labels"`` (a tensor of class ids looked up in
        ``label2id``).
    """
    images = []
    class_ids = []
    for example in examples:
        images.append(example["pixel_values"])
        class_ids.append(label2id[example["label"]])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(class_ids),
    }

In [None]:
# Alternative checkpoints that were tried:
#model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", num_labels=3,ignore_mismatched_sizes=True)
#model = BeitForImageClassification.from_pretrained('/kaggle/input/dinov2/pytorch/base/1/')

# The number of classes and the label maps are derived from the data rather
# than hard-coded, so the classification head always matches label2id.
# ignore_mismatched_sizes lets the checkpoint's original head be replaced by
# a freshly initialised one of the new size.
model = BeitForImageClassification.from_pretrained(
    'microsoft/beit-base-patch16-224-pt22k-ft22k',
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
# Accuracy metric object from the `evaluate` library.
accuracy = evaluate.load("accuracy")


def compute_metricss(eval_pred):
    """Return ``{"accuracy": ...}`` computed with the ``evaluate`` accuracy metric.

    The predicted class is the argmax over axis 1 of ``eval_pred.predictions``;
    references are ``eval_pred.label_ids``.
    """
    predicted_ids = eval_pred.predictions.argmax(axis=1)
    result = accuracy.compute(predictions=predicted_ids, references=eval_pred.label_ids)
    return {"accuracy": result['accuracy']}


def compute_metrics(eval_pred):
    """Return ``{"accuracy": ...}`` for a ``(predictions, labels)`` pair.

    Args:
        eval_pred: Pair of raw model outputs (one row of class scores per
            sample) and the true class ids.

    Returns:
        A dict with a single ``accuracy`` entry.
    """
    predictions, labels = eval_pred
    predicted_ids = np.argmax(predictions, axis=1)
    # scikit-learn expects (y_true, y_pred); accuracy is symmetric, but the
    # conventional order avoids surprises if another metric is added here.
    return {"accuracy": accuracy_score(labels, predicted_ids)}

In [None]:
metric_name = "accuracy"
model_name = "Lumbar Spine Degenerative Classification"
num_train_epochs=1
args = TrainingArguments(
    output_dir=model_name,
    # The string "none" disables all reporting integrations; passing None
    # falls back to the library default instead of turning reporting off.
    report_to="none",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,#32
    per_device_eval_batch_size=8,
    num_train_epochs=num_train_epochs,
    weight_decay=0.02,
    warmup_steps=50,
    # Keep the 'image' / 'label' columns: the transforms and collate_fn read them.
    remove_unused_columns=False,
    save_strategy='epoch',
    load_best_model_at_end=True,
    # Choose the "best" checkpoint by the metric named above rather than by loss.
    metric_for_best_model=metric_name,
    save_total_limit=1, # save fewer checkpoints to limit used space
)


In [None]:
# Wire the model, training arguments, both splits, the batch collator and
# the metric function into a Trainer.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_data,
    eval_dataset=test_data,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

In [None]:
#model.save_pretrained("/kaggle/working/beit-base", from_pt=True)

In [None]:
# Evaluate once before any training step to get a baseline, then fine-tune.
trainer.evaluate()
trainer.train()
# Result recorded from an earlier run:
#Training Loss	Validation Loss	Accuracy
#0.430600	0.445599	0.815351

In [None]:
# Run the trained model on the held-out split and show the reported metrics.
outputs = trainer.predict(test_data)
outputs.metrics

In [None]:
# Notebook display of the held-out split.
test_data

In [None]:
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)
# Stored under its own name: `accuracy` is already bound to the `evaluate`
# metric object used by compute_metricss, and rebinding it to a float would
# break that function for the rest of the session.
test_accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')

# Display accuracy and F1 score
print(f"Accuracy: {test_accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
# Get the confusion matrix
cm = confusion_matrix(y_true, y_pred)
cm

In [None]:
# Join the entries of outputs.predictions along axis 0 and display the result.
# NOTE(review): if outputs.predictions is already a single 2-D array, this
# flattens it row by row into one long vector — confirm that is intended.
np.concatenate(outputs.predictions, axis=0)

### Prepare test dataset

In [None]:
def load_test_images(test_images_dir):
    """Collect the paths of all ``.dcm`` files below ``test_images_dir``.

    The directory tree is walked recursively. Paths are returned in sorted
    order so the result does not depend on the file system's listing order.

    Args:
        test_images_dir: Root directory to search.

    Returns:
        A sorted list of file paths; empty when nothing matches or the
        directory does not exist.
    """
    image_paths = []
    for root, _, files in os.walk(test_images_dir):
        image_paths.extend(
            os.path.join(root, name) for name in files if name.endswith(".dcm")
        )
    image_paths.sort()
    return image_paths

# Directory containing test images (the trailing slash matters: later code
# builds paths by string concatenation onto this value).
test_images_dir = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_images/"
# Every .dcm file found below that directory.
test_image_paths = load_test_images(test_images_dir)


In [None]:
# One row per test series; 'image_path' is first set to the series directory
# (<test_images_dir><study_id>/<series_id>), without a file name.
df_test = pd.read_csv("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_series_descriptions.csv")
df_test['image_path'] = test_images_dir + df_test['study_id'].astype(str) +"/"+ df_test['series_id'].astype(str) 
# One row per .dcm file found on disk.
dff =pd.DataFrame(test_image_paths)
dff.columns = ['image_path']
# Split each file path: the file stem becomes 'instance_number', and
# 'image_path' is reduced to the containing directory so it can serve as join key.
dff['instance_number'] = dff['image_path'].str.extract(r'/([^/]+)\.dcm$')
dff['image_path'] = dff['image_path'].str.replace(r'/([^/]+)\.dcm$', '', regex=True)
# Attach the series metadata to every file by matching on the directory path.
df_test_f = dff.merge(df_test, on = 'image_path', how = 'left')
# Restore the full file path.
df_test_f['image_path'] = df_test_f['image_path'] + "/" + df_test_f['instance_number'] +".dcm"
# Labels here are drawn at random from the known classes, not ground truth.
# NOTE(review): presumably placeholders so the label-based collate step has a
# value to look up — confirm.
df_test_f['label']=random.choices(list(label2id.keys()), k=len(df_test_f))
# Keep only 'image_path' and 'label'.
df_test_f.drop(['instance_number', 'study_id', 'series_id','series_description'], axis=1, inplace=True)
df_test_f

In [None]:
# Same preparation as for the training data: build a Dataset, decode each
# DICOM into an 'image' column, then register the validation transforms.
dataset_test = Dataset.from_pandas(df_test_f)
dataset_test = dataset_test.map(converts)
dataset_test.set_transform(val_transforms)

In [None]:
# Inspect the first test example.
dataset_test[0]

In [None]:
# Notebook display of the test Dataset object.
dataset_test

In [None]:
# Display the 'image' entry of the example at index 32 (requires at least 33 test images).
dataset_test[32]["image"]

In [None]:
#outputs = trainer.predict(dataset_test)

In [None]:
#y_true = outputs.label_ids
#y_pred = outputs.predictions.argmax(1)

helpers:

https://kaggle.com/code/samu2505/rsna-pytorch-train-lb-0-84-cv-0-54
https://www.kaggle.com/code/dima806/sea-animals-image-detection-vit