In [1]:
import seaborn as sns

import matplotlib.pyplot as plt
import os
import time
import numpy as np
import glob
import json
import collections
import torch
import torch.nn as nn

import pydicom as dicom
import matplotlib.patches as patches

from matplotlib import animation, rc
import pandas as pd

import pydicom as dicom # dicom
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

# Load the competition CSV tables from the Kaggle input directory.
# train_path ends with '/', so os.path.join produces exactly the same paths
# as plain string concatenation would.
train_path = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/'

train = pd.read_csv(os.path.join(train_path, 'train.csv'))
label = pd.read_csv(os.path.join(train_path, 'train_label_coordinates.csv'))
train_desc = pd.read_csv(os.path.join(train_path, 'train_series_descriptions.csv'))
test_desc = pd.read_csv(os.path.join(train_path, 'test_series_descriptions.csv'))
sub = pd.read_csv(os.path.join(train_path, 'sample_submission.csv'))




Image Viewer

In [2]:
import pydicom
import matplotlib.pyplot as plt

# Function to generate image paths based on directory structure
def generate_image_paths(df, data_dir):
    """Collect the path of every file stored under each series directory.

    Args:
        df: Table-like object whose 'study_id' and 'series_id' columns are
            iterated in parallel.
        data_dir: Root directory laid out as <data_dir>/<study_id>/<series_id>/.

    Returns:
        A flat list of file paths, grouped by row of `df`; within one series
        the entries are in sorted (lexicographic) file-name order.

    Raises:
        OSError: If a series directory does not exist or cannot be listed.
    """
    image_paths = []
    for study_id, series_id in zip(df['study_id'], df['series_id']):
        series_dir = os.path.join(data_dir, str(study_id), str(series_id))
        # os.listdir returns entries in arbitrary order; sort so that the
        # result is reproducible across runs and file systems.
        image_paths.extend(
            os.path.join(series_dir, name) for name in sorted(os.listdir(series_dir))
        )
    return image_paths


# Function to open and display DICOM images
def display_dicom_images(image_paths):
    """Show at most the first three DICOM files of `image_paths` side by side.

    Each file is read with pydicom and its pixel array is drawn with the
    'bone' colormap in a 1x3 grid; axes are hidden.
    """
    plt.figure(figsize=(15, 5))  # wide canvas for three panels in one row
    panel = 0
    for dicom_path in image_paths[:3]:
        panel += 1
        dataset = pydicom.dcmread(dicom_path)
        plt.subplot(1, 3, panel)
        plt.imshow(dataset.pixel_array, cmap=plt.cm.bone)
        plt.title(f"Image {panel}")
        plt.axis('off')
    plt.show()

# Function to open and display DICOM images along with coordinates
def display_dicom_with_coordinates(image_paths, label_df):
    """Draw each DICOM image with the labelled coordinates of its series.

    Args:
        image_paths: Paths shaped like .../<study_id>/<series_id>/<file>; the
            two parent directory names must be integers.
        label_df: DataFrame with 'study_id', 'series_id', 'x' and 'y' columns.

    Returns:
        None. Does nothing when `image_paths` is empty.
    """
    if not image_paths:
        # plt.subplots cannot build a grid with zero columns.
        return

    # squeeze=False always yields a 2-D array of Axes, so a single image is
    # handled the same way as several.
    fig, axs = plt.subplots(1, len(image_paths), figsize=(18, 6), squeeze=False)

    for ax, path in zip(axs[0], image_paths):
        series_dir = os.path.dirname(path)
        series_id = int(os.path.basename(series_dir))
        study_id = int(os.path.basename(os.path.dirname(series_dir)))

        # Filter label coordinates for the current study and series.
        # NOTE(review): labels of every instance in the series are drawn, not
        # only those of the slice being shown -- confirm this is intended.
        filtered_labels = label_df[
            (label_df['study_id'] == study_id) & (label_df['series_id'] == series_id)
        ]

        ds = pydicom.dcmread(path)

        ax.imshow(ds.pixel_array, cmap='gray')
        ax.set_title(f"Study ID: {study_id}, Series ID: {series_id}")
        ax.axis('off')

        # 'ro' draws red circle markers only (no connecting line).
        ax.plot(filtered_labels['x'], filtered_labels['y'], 'ro', markersize=5)

    plt.tight_layout()
    plt.show()

# Load DICOM files from a folder
def load_dicom_files(path_to_folder):
    """Return the '.dcm' files of a folder ordered by their trailing number.

    The sort key is the integer after the last '-' in the file stem (the whole
    stem when it contains no '-'), so '2.dcm' comes before '10.dcm'.

    Raises:
        ValueError: If that part of a file name is not an integer.
    """
    def _trailing_number(file_path):
        stem, _extension = os.path.splitext(os.path.basename(file_path))
        return int(stem.split('-')[-1])

    dicom_paths = []
    for entry in os.listdir(path_to_folder):
        if entry.endswith('.dcm'):
            dicom_paths.append(os.path.join(path_to_folder, entry))
    return sorted(dicom_paths, key=_trailing_number)


# Generate image paths for train and test data.
# train_path already ends with '/', so these f-strings contain a doubled
# slash ('//train_images'); the strings are left as they are.
train_image_paths = generate_image_paths(train_desc, f'{train_path}/train_images')
test_image_paths = generate_image_paths(test_desc, f'{train_path}/test_images')


# Display the first three DICOM images
#display_dicom_images(train_image_paths)

# Pick one study and gather the first image (lowest trailing number) of each
# of its series. NOTE(review): nothing in this cell passes image_paths to
# display_dicom_with_coordinates -- confirm whether the call was dropped.
study_id = "100206310"
study_folder = f'{train_path}/train_images/{study_id}'

image_paths = []
for series_folder in os.listdir(study_folder):
    series_folder_path = os.path.join(study_folder, series_folder)
    dicom_files = load_dicom_files(series_folder_path)
    if dicom_files:
        image_paths.append(dicom_files[0])  # Add the first image from each series

Data processing

In [3]:
#################################################### Data processing ###########################################################

# Define function to reshape a single row of the DataFrame
def reshape_row(row):
    """Turn one wide row into a long DataFrame with one line per label column.

    Every column except the identifier/coordinate columns is treated as
    '<condition words>_<level part>_<level part>'. The condition becomes the
    capitalised words joined by spaces, the level the two capitalised parts
    joined by '/', and the cell value the severity.

    Returns:
        DataFrame with columns 'study_id', 'condition', 'level', 'severity'.
    """
    non_label_columns = ['study_id', 'series_id', 'instance_number', 'x', 'y', 'series_description']
    study_ids, conditions, level_names, severities = [], [], [], []

    for column, value in row.items():
        if column in non_label_columns:
            continue
        tokens = column.split('_')
        condition = ' '.join(word.capitalize() for word in tokens[:-2])
        level = f"{tokens[-2].capitalize()}/{tokens[-1].capitalize()}"
        study_ids.append(row['study_id'])
        conditions.append(condition)
        level_names.append(level)
        severities.append(value)

    return pd.DataFrame({
        'study_id': study_ids,
        'condition': conditions,
        'level': level_names,
        'severity': severities,
    })


# Reshape every wide row of `train` into long form
# (study_id, condition, level, severity) and stack the pieces.
new_train_df = pd.concat([reshape_row(row) for _, row in train.iterrows()], ignore_index=True)

# Attach the labelled coordinates that share study, condition and level.
merged_df = pd.merge(new_train_df, label, on=['study_id', 'condition', 'level'], how='inner')

# Attach the series description. Joining on both keys keeps a single
# 'study_id' column. (A preceding merge on 'series_id' alone was overwritten
# immediately by this one and has been removed as dead work.)
final_merged_df = pd.merge(merged_df, train_desc, on=['series_id','study_id'], how='inner')

# Create the row_id column: <study_id>_<condition>_<level>, lower-cased with
# spaces and '/' replaced by '_'.
final_merged_df['row_id'] = (
    final_merged_df['study_id'].astype(str) + '_' +
    final_merged_df['condition'].str.lower().str.replace(' ', '_') + '_' +
    final_merged_df['level'].str.lower().str.replace('/', '_')
)

# Create the image_path column:
# <train_path>/train_images/<study_id>/<series_id>/<instance_number>.dcm
final_merged_df['image_path'] = (
    f'{train_path}/train_images/' + 
    final_merged_df['study_id'].astype(str) + '/' +
    final_merged_df['series_id'].astype(str) + '/' +
    final_merged_df['instance_number'].astype(str) + '.dcm'
)

# Root directory that holds the test-set images
base_path = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/test_images/'

def get_image_paths(row):
    """List the regular files of the series directory described by `row`.

    The directory is <base_path>/<study_id>/<series_id>. Returns an empty list
    when that path does not exist; sub-directories are not included.
    """
    series_path = os.path.join(base_path, str(row['study_id']), str(row['series_id']))
    if not os.path.exists(series_path):
        return []

    file_paths = []
    for entry in os.listdir(series_path):
        candidate = os.path.join(series_path, entry)
        if os.path.isfile(candidate):
            file_paths.append(candidate)
    return file_paths

# Mapping of series_description to the condition name(s) predicted from it.
# Two entries map a side ('left'/'right') to a condition; the third is a
# single condition string.
condition_mapping = {
    'Sagittal T1': {'left': 'left_neural_foraminal_narrowing', 'right': 'right_neural_foraminal_narrowing'},
    'Axial T2': {'left': 'left_subarticular_stenosis', 'right': 'right_subarticular_stenosis'},
    'Sagittal T2/STIR': 'spinal_canal_stenosis'
}

# Create a list to store the expanded rows
expanded_rows = []

# Expand the dataframe: one output row per (series, condition, image file).
# Series whose description is not in condition_mapping contribute no rows.
for index, row in test_desc.iterrows():
    image_paths = get_image_paths(row)
    conditions = condition_mapping.get(row['series_description'], {})
    # NOTE(review): a single condition is stored under both 'left' and
    # 'right', so each image of such a series is appended twice with the same
    # condition -- confirm that the duplication is intended.
    if isinstance(conditions, str):  # Single condition
        conditions = {'left': conditions, 'right': conditions}
    for side, condition in conditions.items():
        for image_path in image_paths:
            expanded_rows.append({
                'study_id': row['study_id'],
                'series_id': row['series_id'],
                'series_description': row['series_description'],
                'image_path': image_path,
                'condition': condition,
                # The level suffix is not part of row_id at this point.
                'row_id': f"{row['study_id']}_{condition}"
            })

# Create a new dataframe from the expanded rows
expanded_test_desc = pd.DataFrame(expanded_rows)

# Use the merged label table as the classification training frame. This binds
# a second name to the same DataFrame object; no copy is made.
train_data = final_merged_df


import os

# Define a function to check if a path exists
def check_exists(path):
    """Report whether `path` exists on disk (file or directory)."""
    is_present = os.path.exists(path)
    return is_present

# Define a function to check if a study ID directory exists
def check_study_id(row):
    """Report whether the train_images directory of the row's study exists."""
    return check_exists(f"{train_path}/train_images/{row['study_id']}")

# Define a function to check if a series ID directory exists
def check_series_id(row):
    """Report whether the row's series directory exists inside its study directory."""
    series_dir = f"{train_path}/train_images/{row['study_id']}/{row['series_id']}"
    return check_exists(series_dir)

# Define a function to check if an image file exists
def check_image_exists(row):
    """Report whether the file named by the row's 'image_path' exists."""
    return check_exists(row['image_path'])

# Flag every row by whether its study folder, series folder and image file
# are present on disk (one boolean column per check).
existence_checks = {
    'study_id_exists': check_study_id,
    'series_id_exists': check_series_id,
    'image_exists': check_image_exists,
}
for flag_name, checker in existence_checks.items():
    train_data[flag_name] = train_data.apply(checker, axis=1)

# Keep only rows that pass all three checks, then drop rows with any missing value.
all_present = (
    train_data['study_id_exists']
    & train_data['series_id_exists']
    & train_data['image_exists']
)
train_data = train_data[all_present].dropna()


# Resampling: oversample every severity class up to the size of the largest one.
# NOTE(review): this happens before the train/validation split made in
# create_datasets_and_loaders, so copies of one row can end up on both sides
# of that split -- confirm whether that is acceptable.

from sklearn.utils import resample

class_counts = train_data['severity'].value_counts()
print("Original class counts:\n", class_counts)

# Size of the largest class
max_count = class_counts.max()

# Resample each smaller class (with replacement) so all classes are balanced.
# The pieces are gathered in a list and concatenated once, instead of growing
# a DataFrame inside the loop starting from an empty frame.
balanced_parts = []

for severity in class_counts.index:
    class_data = train_data[train_data['severity'] == severity]
    if len(class_data) < max_count:
        class_data = resample(class_data, replace=True, n_samples=max_count, random_state=42)
    balanced_parts.append(class_data)

# pd.concat rejects an empty list, so fall back to an empty frame.
balanced_data = pd.concat(balanced_parts) if balanced_parts else pd.DataFrame()

# Shuffle the balanced rows and renumber them 0..n-1.
train_data = balanced_data.sample(frac=1, random_state=42).reset_index(drop=True)


Original class counts:
 severity
Normal/Mild    37626
Moderate        7950
Severe          3081
Name: count, dtype: int64


Loading Data

In [4]:
####################################################### Loading Data ###########################################################

import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
# import torchvision.transforms as transforms
import torch
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
from PIL import Image
import cv2

def load_dicom(path):
    """Read a DICOM file and return its pixels min-max scaled to uint8.

    The minimum is shifted to 0 and, unless the image is constant, the maximum
    is scaled to 255.

    Args:
        path: Path of the DICOM file.

    Returns:
        numpy array of dtype uint8 with the same shape as the pixel array.
    """
    # dcmread replaces the deprecated `read_file` alias and matches the other
    # readers in this notebook; the local name also no longer shadows the
    # `dicom` module alias.
    dataset = pydicom.dcmread(path)
    data = dataset.pixel_array
    data = data - np.min(data)
    peak = np.max(data)
    if peak != 0:  # a constant image stays all zeros instead of dividing by 0
        data = data / peak
    data = (data * 255).astype(np.uint8)
    return data




# Define a custom dataset class
class CustomDataset(Dataset):
    """Dataset yielding a cropped image patch and its severity label.

    Each item is the region around the labelled (x, y) point of one row,
    histogram-equalised, resized to 224x224, replicated to three channels and
    scaled to [0, 1].

    Args:
        dataframe: DataFrame with 'image_path', 'severity', 'x' and 'y' columns.
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return (image, label): a float32 tensor of shape (3, 224, 224) and
        the raw 'severity' value of row number `index`."""
        # Positional lookups: correct even when the frame's index is not 0..n-1.
        image_path = self.dataframe['image_path'].iloc[index]
        image = load_dicom(image_path)
        label = self.dataframe['severity'].iloc[index]

        # NOTE(review): load_dicom already returns uint8 in [0, 255], so this
        # multiplication wraps around modulo 256. It is kept unchanged because
        # the other datasets in this notebook apply the same step and the
        # models are trained on its output -- confirm before changing it.
        image = (image * 255).astype(np.uint8)

        # Centre of the crop
        x = round(self.dataframe['x'].iloc[index])
        y = round(self.dataframe['y'].iloc[index])

        # Half-size of the crop: a tenth of the image size.
        # NOTE(review): gap_x is derived from shape[0] (rows) and gap_y from
        # shape[1] (columns); the two differ only for non-square images.
        # Kept as is for the same reason as above.
        gap_x = round(image.shape[0] / 10)
        gap_y = round(image.shape[1] / 10)

        # Use a black image when the crop window leaves the image.
        if y-gap_y < 0 or y+gap_y > image.shape[0] or x-gap_x < 0 or x+gap_x > image.shape[1]:
            image = np.zeros((256, 256), dtype=np.uint8)
        else:
            image = image[y-gap_y : y+gap_y, x-gap_x : x+gap_x]

        image = cv2.equalizeHist(image)

        image = Image.fromarray(image).resize((224, 224), Image.BILINEAR)
        image = np.array(image)

        # Convert to 3 channels by repeating the grayscale plane
        image = np.stack([image] * 3, axis=-1) if image.ndim == 2 else image

        image = image.transpose((2, 0, 1))  # Change to (C, H, W) format
        image = torch.tensor(image, dtype=torch.float32) / 255.0  # scale to [0, 1]

        return image, label

# Function to create datasets and dataloaders for each series description
def create_datasets_and_loaders(df, series_description, batch_size=20):
    """Build train/validation DataLoaders for one series description.

    Rows of `df` whose 'series_description' equals the given value are split
    80/20 (random_state=42); both parts are re-indexed from 0 and wrapped in
    CustomDataset. Only the training loader shuffles.

    Returns:
        (trainloader, valloader, number of training rows, number of validation rows)
    """
    matches_series = df['series_description'] == series_description
    train_df, val_df = (
        part.reset_index(drop=True)
        for part in train_test_split(df[matches_series], test_size=0.2, random_state=42)
    )

    trainloader = DataLoader(CustomDataset(train_df), batch_size=batch_size, shuffle=True)
    valloader = DataLoader(CustomDataset(val_df), batch_size=batch_size, shuffle=False)

    return trainloader, valloader, len(train_df), len(val_df)

# Build the train/validation loaders for each series description ...
trainloader_t1, valloader_t1, len_train_t1, len_val_t1 = create_datasets_and_loaders(train_data, 'Sagittal T1')
trainloader_t2, valloader_t2, len_train_t2, len_val_t2 = create_datasets_and_loaders(train_data, 'Axial T2')
trainloader_t2stir, valloader_t2stir, len_train_t2stir, len_val_t2stir = create_datasets_and_loaders(train_data, 'Sagittal T2/STIR')

# ... and index them by series description: (train, validation) pairs.
dataloaders = {
    'Sagittal T1': (trainloader_t1, valloader_t1),
    'Axial T2': (trainloader_t2, valloader_t2),
    'Sagittal T2/STIR': (trainloader_t2stir, valloader_t2stir),
}

# Matching (train, validation) row counts.
lengths = {
    'Sagittal T1': (len_train_t1, len_val_t1),
    'Axial T2': (len_train_t2, len_val_t2),
    'Sagittal T2/STIR': (len_train_t2stir, len_val_t2stir),
}


import matplotlib.pyplot as plt

# Function to visualize a batch of images
def visualize_batch(dataloader):
    """Show the first batch of `dataloader` as a row of labelled images.

    Expects batches of (images, labels) with images in (C, H, W) layout.
    """
    images, labels = next(iter(dataloader))
    # squeeze=False keeps a 2-D Axes array even for a one-image batch, where
    # plt.subplots would otherwise return a single Axes that cannot be indexed.
    fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)
    for ax, img, lbl in zip(axes[0], images, labels):
        ax.imshow(img.permute(1, 2, 0))  # Convert to HWC for visualization
        ax.set_title(f"Label: {lbl}")
        ax.axis('off')
    plt.show()

Data Visualization

In [5]:
################################ Data Visualization ##############################################

# # Visualize samples from each dataloader
# print("Visualizing Sagittal T1 samples")
# visualize_batch(valloader_t1)
# print("Visualizing Axial T2 samples")
# visualize_batch(trainloader_t2)
# print("Visualizing Sagittal T2/STIR samples")
# visualize_batch(trainloader_t2stir)


Model

In [6]:
################################ Model ###########################################################

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

class CustomResNet(nn.Module):
    """ResNet-18 whose final layer is replaced by a two-layer classifier head.

    Args:
        num_classes: Size of the output layer.
        pretrained_model_path: Optional path of a ResNet-18 state dict that is
            loaded into the backbone before the head is replaced.
    """

    def __init__(self, num_classes=3, pretrained_model_path=None):
        super().__init__()
        self.model = models.resnet18(weights=None)  # weights are loaded manually below
        if pretrained_model_path:
            # map_location="cpu" lets the checkpoint load on machines without
            # the device it was saved from; callers move the model with .to().
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        num_ftrs = self.model.fc.in_features
        # The layer order is kept exactly, so state-dict keys of this head
        # (fc.2, fc.5) stay the same.
        self.model.fc = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return the head's output for the image batch `x`."""
        return self.model(x)

# One classifier per series description, each starting from the same ResNet-18 weights
sagittal_t1_model = CustomResNet(num_classes=3, pretrained_model_path="/kaggle/input/resnet/resnet18-5c106cde.pth").to(device)
axial_t2_model = CustomResNet(num_classes=3, pretrained_model_path="/kaggle/input/resnet/resnet18-5c106cde.pth").to(device)
sagittal_t2stir_model = CustomResNet(num_classes=3, pretrained_model_path="/kaggle/input/resnet/resnet18-5c106cde.pth").to(device)

# Models keyed by series description for easy access
model_dics = {
    'Sagittal T1': sagittal_t1_model,
    'Axial T2': axial_t2_model,
    'Sagittal T2/STIR': sagittal_t2stir_model,
}

# Make every layer trainable
for classifier in model_dics.values():
    for param in classifier.model.parameters():
        param.requires_grad = True

# Loss shared by the three classifiers
criterion = nn.CrossEntropyLoss()

# A separate SGD optimizer per model; weight_decay adds L2 regularization
optimizers = {
    series_name: torch.optim.SGD(net.model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    for series_name, net in model_dics.items()
}
optimizer_sagittal_t1 = optimizers['Sagittal T1']
optimizer_axial_t2 = optimizers['Axial T2']
optimizer_sagittal_t2stir = optimizers['Sagittal T2/STIR']

# Count the trainable parameters of one classifier
trainable_params = sum(
    weight.numel() for weight in sagittal_t1_model.parameters() if weight.requires_grad
)
print(f"Number of parameters: {trainable_params}")

cuda


Training

In [7]:
############################# Training ###################################################3

# Severity name -> class index
label_map = {name: class_index for class_index, name in enumerate(('Normal/Mild', 'Moderate', 'Severe'))}

# Define the number of epochs
num_epochs = 15

# One StepLR schedule per optimizer: multiply the learning rate by 0.1 every 7 steps
schedulers = {
    series_name: torch.optim.lr_scheduler.StepLR(series_optimizer, step_size=7, gamma=0.1)
    for series_name, series_optimizer in optimizers.items()
}
scheduler_sagittal_t1 = schedulers['Sagittal T1']
scheduler_axial_t2 = schedulers['Axial T2']
scheduler_sagittal_t2stir = schedulers['Sagittal T2/STIR']


## This is pretrained model ##

from tqdm import tqdm

# Training loop for all models: in every epoch each classifier is trained for
# one pass over its training loader and then evaluated on its validation loader.
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 20)
    
    for model_name, model in model_dics.items():
        
        trainloader = dataloaders[model_name][0]
        valloader = dataloaders[model_name][1]
        optimizer = optimizers[model_name]
        scheduler = schedulers[model_name]
        
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        correct_predictions = 0  # plain Python int, also valid for an empty loader
        samples_seen = 0
        
        train_progress_bar = tqdm(trainloader, desc=f"{model_name} Train")
        for images, labels in train_progress_bar:
            # Labels arrive as severity strings; map them to class indices.
            labels = torch.tensor([label_map[label] for label in labels]).to(device)
            images = images.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            _, preds = torch.max(outputs, 1)
            correct_predictions += (preds == labels).sum().item()
            samples_seen += batch_size
            
            # Running accuracy over the samples processed so far (dividing by
            # the full dataset size under-reports it until the last batch).
            train_progress_bar.set_postfix(loss=loss.item(), acc=correct_predictions / samples_seen)
        
        # max(..., 1) avoids a division by zero for an empty dataset.
        train_size = max(len(trainloader.dataset), 1)
        epoch_loss = running_loss / train_size
        epoch_acc = correct_predictions / train_size
        
        print(f"{model_name} Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        scheduler.step()
        
        # ---- validation pass (no gradient tracking) ----
        model.eval()
        val_running_loss = 0.0
        val_correct_predictions = 0
        val_samples_seen = 0
        
        val_progress_bar = tqdm(valloader, desc=f"{model_name} Val")
        with torch.no_grad():
            for images, labels in val_progress_bar:
                labels = torch.tensor([label_map[label] for label in labels]).to(device)
                images = images.to(device)
                
                outputs = model(images)

                loss = criterion(outputs, labels)
                
                batch_size = images.size(0)
                val_running_loss += loss.item() * batch_size
                _, preds = torch.max(outputs, 1)
                val_correct_predictions += (preds == labels).sum().item()
                val_samples_seen += batch_size
                
                val_progress_bar.set_postfix(loss=loss.item(), acc=val_correct_predictions / val_samples_seen)
        
        val_size = max(len(valloader.dataset), 1)
        val_loss = val_running_loss / val_size
        val_acc = val_correct_predictions / val_size
        
        print(f"{model_name} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

# load_dir = '/kaggle/input/pretrained-model'

# for model_name, model in model_dics.items():
#     modified_string = model_name.replace(' ', '_').replace('/', '_')
#     load_path = os.path.join(load_dir, f"resNet18_{modified_string}.pth")
#     if os.path.exists(load_path):
#         model.load_state_dict(torch.load(load_path))
#         model.eval()
#         print("load!")
#     else:
#         print(f"Model file not found: {load_path}")


load!
load!
load!


key point model

In [8]:
###################################################################################################################
###################################################################################################################
######################################### key point model #########################################################

# Pivot the label table: one row per (study, series, instance, condition),
# with one x column and one y column per level.
df_pivot = label.pivot_table(index=['study_id', 'series_id', 'instance_number', 'condition'],
                          columns='level',
                          values=['x', 'y'],
                          aggfunc='first').reset_index()

# Flatten the two-level column labels by joining their non-empty parts:
# ('x', 'L1/L2') -> 'x_L1/L2', ('study_id', '') -> 'study_id'.
df_pivot.columns = ['_'.join(part for part in col if part) for col in df_pivot.columns.values]

# Attach the series description of every (study, series) pair.
# NOTE(review): `merged_df` and `train_data` are reused here for the keypoint
# tables, replacing the classification frames built in earlier cells.
merged_df = pd.merge(df_pivot, train_desc, on=['study_id', 'series_id'], how='left')

# Create the image_path column
merged_df['image_path'] = (
    f'{train_path}/train_images/' + 
    merged_df['study_id'].astype(str) + '/' +
    merged_df['series_id'].astype(str) + '/' +
    merged_df['instance_number'].astype(str) + '.dcm'
)

# Training frame for the keypoint models
train_data = merged_df

Data processing

In [9]:
#################################################### Data processing ###########################################################


# Normalise the column labels: trim whitespace, spaces -> underscores
train_data.columns = train_data.columns.str.strip().str.replace(' ', '_')

# Coordinate columns are identified by their 'x_' / 'y_' prefix
x_cols = [name for name in train_data.columns if name[:2] == 'x_']
y_cols = [name for name in train_data.columns if name[:2] == 'y_']

# Per (study, series, description): average every coordinate column and keep
# the first image path of the group.
aggregations = dict.fromkeys(x_cols + y_cols, 'mean')
aggregations['image_path'] = 'first'
grouped_df = train_data.groupby(
    ['study_id', 'series_id', 'series_description'], as_index=False
).agg(aggregations)

# Discard groups that have any missing value
train_data = grouped_df.dropna()

Loading Data

In [10]:
####################################################### Loading Data ###########################################################

# Define a custom dataset class
class CustomDataset_keypoint(Dataset):
    """Dataset yielding a whole image and its five level keypoints.

    Each item is the image histogram-equalised, resized to 224x224,
    replicated to three channels and scaled to [0, 1], together with the
    (x, y) keypoints rescaled to the 224x224 frame.

    Args:
        dataframe: DataFrame with 'image_path' and, for each level in
            L1/L2 .. L5/S1, the columns 'x_<level>' and 'y_<level>'.
    """

    # Levels in the order their coordinates appear in the target vector
    _LEVELS = ('L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1')

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return (image, keypoints): a float32 tensor of shape (3, 224, 224)
        and a float32 tensor of 10 values (x1, y1, ..., x5, y5)."""
        # Positional lookups: correct even when the frame's index is not 0..n-1.
        image_path = self.dataframe['image_path'].iloc[index]
        image = load_dicom(image_path)

        keypoints = [
            (self.dataframe[f'x_{level}'].iloc[index], self.dataframe[f'y_{level}'].iloc[index])
            for level in self._LEVELS
        ]

        # NOTE(review): load_dicom already returns uint8 in [0, 255], so this
        # multiplication wraps around modulo 256. It is kept unchanged because
        # the other datasets in this notebook apply the same step and the
        # models are trained on its output -- confirm before changing it.
        image = (image * 255).astype(np.uint8)
        h, w = image.shape
        image = cv2.equalizeHist(image)

        image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
        # Scale the coordinates by the same factors as the image
        keypoints_resized = [(x * 224 / w, y * 224 / h) for x, y in keypoints]

        # Convert to 3 channels by repeating the grayscale plane
        image = np.stack([image] * 3, axis=-1) if image.ndim == 2 else image

        image = image.transpose((2, 0, 1))  # Change to (C, H, W) format
        image = torch.tensor(image, dtype=torch.float32) / 255.0  # scale to [0, 1]
        keypoints_resized = torch.tensor(keypoints_resized, dtype=torch.float32).view(-1)

        return image, keypoints_resized

# Function to create datasets and dataloaders for each series description
def create_datasets_and_loaders_keypoint(df, series_description, batch_size=20):
    """Build train/validation keypoint DataLoaders for one series description.

    Rows of `df` whose 'series_description' equals the given value are split
    80/20 (random_state=42); both parts are re-indexed from 0 and wrapped in
    CustomDataset_keypoint. Only the training loader shuffles.

    Returns:
        (trainloader, valloader, number of training rows, number of validation rows)
    """
    matches_series = df['series_description'] == series_description
    train_df, val_df = (
        part.reset_index(drop=True)
        for part in train_test_split(df[matches_series], test_size=0.2, random_state=42)
    )

    trainloader = DataLoader(CustomDataset_keypoint(train_df), batch_size=batch_size, shuffle=True)
    valloader = DataLoader(CustomDataset_keypoint(val_df), batch_size=batch_size, shuffle=False)

    return trainloader, valloader, len(train_df), len(val_df)

# Build the keypoint train/validation loaders for each series description ...
trainloader_t1, valloader_t1, len_train_t1, len_val_t1 = create_datasets_and_loaders_keypoint(train_data, 'Sagittal T1')
trainloader_t2, valloader_t2, len_train_t2, len_val_t2 = create_datasets_and_loaders_keypoint(train_data, 'Axial T2')
trainloader_t2stir, valloader_t2stir, len_train_t2stir, len_val_t2stir = create_datasets_and_loaders_keypoint(train_data, 'Sagittal T2/STIR')

# ... and index them by series description: (train, validation) pairs.
dataloaders = {
    'Sagittal T1': (trainloader_t1, valloader_t1),
    'Axial T2': (trainloader_t2, valloader_t2),
    'Sagittal T2/STIR': (trainloader_t2stir, valloader_t2stir),
}

# Matching (train, validation) row counts.
lengths = {
    'Sagittal T1': (len_train_t1, len_val_t1),
    'Axial T2': (len_train_t2, len_val_t2),
    'Sagittal T2/STIR': (len_train_t2stir, len_val_t2stir),
}

import matplotlib.pyplot as plt

# Function to visualize a batch of images
def visualize_batch(dataloader):
    """Show the first batch of a keypoint loader with its keypoints overlaid.

    Expects batches of (images, keypoints) as produced by
    CustomDataset_keypoint: images in (C, H, W) layout and keypoints as a flat
    (x1, y1, x2, y2, ...) vector in image pixel coordinates.

    NOTE: this definition replaces the classification `visualize_batch`
    defined earlier in the notebook.
    """
    images, keypoints = next(iter(dataloader))
    # squeeze=False keeps a 2-D Axes array even for a one-image batch.
    fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)
    for ax, img, points in zip(axes[0], images, keypoints):
        ax.imshow(img.permute(1, 2, 0))  # Convert to HWC for visualization
        xy = points.view(-1, 2).numpy()
        ax.plot(xy[:, 0], xy[:, 1], 'ro', markersize=5)  # red markers, no line
        ax.axis('off')
    plt.show()

Model

In [11]:
################################ Model ###########################################################

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm

# Same device selection as in the classification model cell: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class KeypointModel(nn.Module):
    """ResNet-50 regressor that outputs an (x, y) pair per keypoint.

    Args:
        num_keypoints: Number of keypoints; the output has 2 * num_keypoints values.
        pretrained_model_path: Optional path of a ResNet-50 state dict that is
            loaded into the backbone before the final layer is replaced.
    """

    def __init__(self, num_keypoints=5, pretrained_model_path=None):
        super().__init__()
        self.resnet = models.resnet50(weights=None)  # weights are loaded manually below
        if pretrained_model_path:
            # map_location="cpu" lets the checkpoint load on machines without
            # the device it was saved from; callers move the model with .to().
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            self.resnet.load_state_dict(state_dict)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_keypoints * 2)

    def forward(self, x):
        """Return the regressed coordinates for the image batch `x`."""
        return self.resnet(x)

# One keypoint regressor per series description, each starting from the same ResNet-50 weights
sagittal_t1_model_keypoint = KeypointModel(num_keypoints=5, pretrained_model_path="/kaggle/input/resnet/resnet50.pth").to(device)
axial_t2_model_keypoint = KeypointModel(num_keypoints=5, pretrained_model_path="/kaggle/input/resnet/resnet50.pth").to(device)
sagittal_t2stir_model_keypoint = KeypointModel(num_keypoints=5, pretrained_model_path="/kaggle/input/resnet/resnet50.pth").to(device)

# Models keyed by series description for easy access
keypoints_models = {
    'Sagittal T1': sagittal_t1_model_keypoint,
    'Axial T2': axial_t2_model_keypoint,
    'Sagittal T2/STIR': sagittal_t2stir_model_keypoint,
}

# Make every layer trainable
for regressor in keypoints_models.values():
    for param in regressor.parameters():
        param.requires_grad = True

# Separate optimizers with L2 regularization: SGD for the two sagittal models, Adam for the axial one
optimizer_sagittal_t1_keypoint = torch.optim.SGD(sagittal_t1_model_keypoint.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
optimizer_axial_t2_keypoint = torch.optim.Adam(axial_t2_model_keypoint.parameters(), lr=0.001, weight_decay=1e-4)
optimizer_sagittal_t2stir_keypoint = torch.optim.SGD(sagittal_t2stir_model_keypoint.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

keypoints_optimizers = {
    'Sagittal T1': optimizer_sagittal_t1_keypoint,
    'Axial T2': optimizer_axial_t2_keypoint,
    'Sagittal T2/STIR': optimizer_sagittal_t2stir_keypoint,
}

# One plateau scheduler per optimizer: cut the learning rate tenfold after
# 5 epochs without improvement of the monitored value
keypoints_schedulers = {
    series_name: torch.optim.lr_scheduler.ReduceLROnPlateau(series_optimizer, mode='min', factor=0.1, patience=5, verbose=True)
    for series_name, series_optimizer in keypoints_optimizers.items()
}
scheduler_sagittal_t1_keypoint = keypoints_schedulers['Sagittal T1']
scheduler_axial_t2_keypoint = keypoints_schedulers['Axial T2']
scheduler_sagittal_t2stir_keypoint = keypoints_schedulers['Sagittal T2/STIR']

# Count the trainable parameters of one regressor
trainable_params = sum(
    weight.numel() for weight in sagittal_t1_model_keypoint.parameters() if weight.requires_grad
)
print(f"Number of parameters: {trainable_params}")

Training

In [12]:
############################# Training ###################################################3


# Training loop for all keypoint models: in every epoch each regressor is
# trained for one pass over its training loader, evaluated on its validation
# loader, and its scheduler is stepped with the validation loss.
# Note: `num_epochs` and `criterion` are rebound here, replacing the values
# used by the classification training cell.
num_epochs = 27
criterion = nn.SmoothL1Loss() 

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 20)
    
    for model_name, model in keypoints_models.items():
        
        # Set the model to training mode
        model.train()
        running_loss = 0.0
        
        trainloader = dataloaders[model_name][0]
        valloader = dataloaders[model_name][1]
        optimizer = keypoints_optimizers[model_name]
        scheduler = keypoints_schedulers[model_name]
        
        for images, keypoints in trainloader:
            keypoints = keypoints.to(device)
            images = images.to(device)
            
            # Clear gradients left over from the previous batch
            optimizer.zero_grad()
            
            outputs = model(images)
            loss = criterion(outputs, keypoints)
            
            loss.backward()
            optimizer.step()
            
            # Weight the batch loss by the batch size so the epoch value is a per-sample mean
            running_loss += loss.item() * images.size(0)
        
        epoch_loss = running_loss / len(trainloader.dataset)
        
        print(f"{model_name} Train Loss: {epoch_loss:.4f}")
        
        # Validation step
        model.eval()
        val_running_loss = 0.0
        
        with torch.no_grad():
            for images, keypoints in valloader:
                # Reshape targets to (batch, 10); the training pass above uses them without this reshape
                keypoints = keypoints.view(-1, 10).to(device)
                images = images.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, keypoints)
                
                val_running_loss += loss.item() * images.size(0)
        
        val_loss = val_running_loss / len(valloader.dataset)
        
        print(f"{model_name} Val Loss: {val_loss:.4f}")
        # ReduceLROnPlateau is stepped with the monitored value
        scheduler.step(val_loss)

# load_dir = '/kaggle/input/pretrained-model'
# os.makedirs(load_dir, exist_ok=True)

# for model_name, model in keypoints_models.items():
#     modified_string = model_name.replace(' ', '_').replace('/', '_')
#     load_path = os.path.join(load_dir, f"keypoints_resNet50_{modified_string}.pth")
#     model.load_state_dict(torch.load(load_path))
#     print("load")

load
load
load


Inference

In [13]:
###################################################################################################################
###################################################################################################################
###################################################################################################################

##################################### Inference ##################################################################

# Level suffixes, in the order in which they are cycled through by row index below
levels = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']

# Function to update row_id with levels
def update_row_id(row, levels):
    """Build the ``<study_id>_<condition>_<level>`` identifier for one row.

    The level is chosen by cycling through ``levels`` using the row's index
    label (``row.name``), so the label must be an integer.

    Args:
        row: A DataFrame row providing ``study_id`` and ``condition``.
        levels: Sequence of level names to cycle through.

    Returns:
        The composed row id string.
    """
    slot = row.name % len(levels)
    chosen_level = levels[slot]
    return "{}_{}_{}".format(row['study_id'], row['condition'], chosen_level)

def update_level(row, levels):
    """Return the level assigned to ``row`` by cycling ``levels`` over ``row.name``.

    Args:
        row: A DataFrame row whose index label (``row.name``) is an integer.
        levels: Sequence of level names to cycle through.

    Returns:
        The element of ``levels`` at position ``row.name % len(levels)``.
    """
    slot = row.name % len(levels)
    return levels[slot]

# Give every row of expanded_test_desc a level-qualified row_id and a
# matching 'level' column; both helpers cycle through `levels` by row label.
expanded_test_desc['row_id'] = expanded_test_desc.apply(update_row_id, axis=1, args=(levels,))
expanded_test_desc['level'] = expanded_test_desc.apply(update_level, axis=1, args=(levels,))

# Define a custom test dataset class
class TestDataset(Dataset):
    """Test-time dataset that yields one keypoint-centred crop per DataFrame row.

    For each row the DICOM at ``image_path`` is loaded, the keypoints model
    registered for the row's ``series_description`` is run on a 224x224 copy,
    and its output is read as five (x, y) pairs ordered like ``LEVELS``. A
    window centred on the pair for the row's ``level`` is cropped from the
    full-resolution image, histogram-equalized and resized to 224x224.

    Args:
        dataframe: Frame with ``image_path``, ``series_description`` and
            ``level`` columns. Rows are addressed by position.
        keypoints_models: Mapping from series description to keypoints model.
    """

    # Order in which the (x, y) pairs are read from the keypoints model output.
    LEVELS = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']

    def __init__(self, dataframe, keypoints_models):
        self.dataframe = dataframe
        self.keypoints_models = keypoints_models

    def __len__(self):
        """Return the number of rows (and therefore samples)."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return the cropped image for the row at position ``index``.

        Returns:
            A float32 tensor of shape (3, 224, 224) scaled by 1/255.

        Raises:
            ValueError: If no keypoints model is registered for the row's
                series description.
        """
        # Positional lookup: samplers hand out 0..len-1, which only coincides
        # with the index labels when the frame has a default RangeIndex.
        row = self.dataframe.iloc[index]

        image = load_dicom(row['image_path'])

        # Fall back to a black image when the DICOM could not be loaded.
        if image is None:
            image = np.zeros((256, 256), dtype=np.uint8)

        # Scale to uint8 (required by cv2.equalizeHist) and equalize.
        # NOTE(review): assumes load_dicom returns values in [0, 1] - confirm.
        image = (image * 255).astype(np.uint8)
        image = cv2.equalizeHist(image)
        h, w = image.shape

        # Build the keypoints-model input: 224x224, 3 identical channels,
        # scaled by 1/255, CHW layout with a leading batch dimension.
        model_input = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
        model_input = np.stack([model_input] * 3, axis=-1)
        model_input = model_input.astype(np.float32) / 255.0
        model_input = np.transpose(model_input, (2, 0, 1))
        model_input = np.expand_dims(model_input, axis=0)
        model_input = torch.from_numpy(model_input).to(device)

        series_description = row['series_description']

        keypoints_model = self.keypoints_models.get(series_description, None)
        if keypoints_model is None:
            raise ValueError(f"Model for series_description '{series_description}' not found.")

        keypoints_model.eval()
        with torch.no_grad():
            outputs = keypoints_model(model_input)

        outputs = outputs.squeeze().cpu().numpy()

        # Pick the (x, y) pair belonging to this row's level.
        level_index = self.LEVELS.index(row['level'])
        x = outputs[2 * level_index]
        y = outputs[2 * level_index + 1]

        # Map the prediction from the 224x224 model input back to the
        # original image resolution.
        x = x * w / 224
        y = y * h / 224

        # Half-extent of the crop window: a tenth of each image dimension.
        gap_x = w / 10
        gap_y = h / 10

        # Use a black image when the crop window would leave the image bounds.
        if y - gap_y < 0 or y + gap_y > h or x - gap_x < 0 or x + gap_x > w:
            image = np.zeros((256, 256), dtype=np.uint8)
        else:
            image = image[round(y - gap_y):round(y + gap_y), round(x - gap_x):round(x + gap_x)]

        image = cv2.equalizeHist(image)
        image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)

        # Convert to 3 channels (RGB)
        image = np.stack([image] * 3, axis=-1) if image.ndim == 2 else image

        image = image.transpose((2, 0, 1))  # Change to (C, H, W) format
        image = torch.tensor(image, dtype=torch.float32) / 255.0  # Convert to tensor and normalize

        return image


# Wrap the expanded test frame in a dataset and iterate it one image at a
# time, in frame order (predict_test_data pairs batch i with row i).
test_dataset = TestDataset(dataframe=expanded_test_desc, keypoints_models=keypoints_models)
testloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)


# Function to get the model based on series_description
def get_model(series_description):
    """Look up the model registered in ``model_dics`` for ``series_description``.

    Returns:
        The matching entry of ``model_dics``, or ``None`` when there is none.
    """
    model = model_dics.get(series_description)
    return model

# Function to make predictions on the test data
def predict_test_data(testloader, expanded_test_desc):
    """Classify every test image with the model for its series description.

    The i-th batch yielded by ``testloader`` is paired with the i-th row
    (by position) of ``expanded_test_desc``, so the loader must yield exactly
    one image per batch, in DataFrame order (``batch_size=1``,
    ``shuffle=False``).

    Args:
        testloader: Iterable of image batches with a leading batch dimension
            of size 1.
        expanded_test_desc: DataFrame with a ``series_description`` column.

    Returns:
        A tuple ``(normal_mild_probs, moderate_probs, severe_probs,
        predictions)`` of four lists with one entry per batch. The first three
        hold the softmax probability of class 0, 1 and 2 as floats; the last
        holds the full probability tensor. Entries are ``None`` where
        ``get_model`` found no model for the row's series description.
    """
    predictions = []
    normal_mild_probs = []
    moderate_probs = []
    severe_probs = []

    # Put every registered classifier in eval mode once, up front.
    for model in model_dics.values():
        model.eval()

    with torch.no_grad():
        for idx, images in enumerate(tqdm(testloader)):
            images = images.to(device)
            series_description = expanded_test_desc.iloc[idx]['series_description']
            model = get_model(series_description)

            # Compare against None explicitly: container modules such as
            # nn.Sequential define __len__, so truthiness is not a safe
            # "model exists" check.
            if model is None:
                normal_mild_probs.append(None)
                moderate_probs.append(None)
                severe_probs.append(None)
                predictions.append(None)
                continue

            outputs = model(images)
            # Softmax over the class dimension, then drop the batch dimension.
            probs = torch.softmax(outputs, dim=1).squeeze(0)
            normal_mild_probs.append(probs[0].item())
            moderate_probs.append(probs[1].item())
            severe_probs.append(probs[2].item())
            predictions.append(probs)

    return normal_mild_probs, moderate_probs, severe_probs, predictions



# Run the classifiers over every test row (one image per row).
normal_mild_probs, moderate_probs, severe_probs, test_predictions = predict_test_data(
    testloader, expanded_test_desc
)

# Attach the per-class probabilities to the test DataFrame.
expanded_test_desc['normal_mild'] = normal_mild_probs
expanded_test_desc['moderate'] = moderate_probs
expanded_test_desc['severe'] = severe_probs

submission = expanded_test_desc[["row_id", "normal_mild", "moderate", "severe"]]

# Collapse rows sharing a row_id by taking the per-class maximum.
grouped_submission = submission.groupby('row_id').max().reset_index()

# Rescale each row so the three class columns sum to 1 again.
prob_columns = ['normal_mild', 'moderate', 'severe']
row_totals = grouped_submission[prob_columns].sum(axis=1)
grouped_submission[prob_columns] = grouped_submission[prob_columns].div(row_totals, axis=0)

print(grouped_submission)

100%|██████████| 194/194 [00:08<00:00, 23.81it/s]

                                             row_id  normal_mild  moderate  \
0    44036939_left_neural_foraminal_narrowing_l1_l2     0.400475  0.378718   
1    44036939_left_neural_foraminal_narrowing_l2_l3     0.374593  0.281383   
2    44036939_left_neural_foraminal_narrowing_l3_l4     0.054124  0.740683   
3    44036939_left_neural_foraminal_narrowing_l4_l5     0.735425  0.097878   
4    44036939_left_neural_foraminal_narrowing_l5_s1     0.090160  0.577792   
5         44036939_left_subarticular_stenosis_l1_l2     0.503563  0.235323   
6         44036939_left_subarticular_stenosis_l2_l3     0.464731  0.331803   
7         44036939_left_subarticular_stenosis_l3_l4     0.189192  0.290436   
8         44036939_left_subarticular_stenosis_l4_l5     0.421516  0.159078   
9         44036939_left_subarticular_stenosis_l5_s1     0.637418  0.240354   
10  44036939_right_neural_foraminal_narrowing_l1_l2     0.400475  0.378718   
11  44036939_right_neural_foraminal_narrowing_l2_l3     0.374593


