# Overview

This notebook provides a pipeline for preprocessing, visualizing, and saving MRI slices from the MRNet dataset, focusing on the coronal plane. Only exams labelled positive for a task (label 1 in that task's CSV) are used.

Key steps include:

- **Data Preprocessing**: Each MRI slice is independently rescaled to the intensity range [-1, 1] and resized to 256x256 pixels.
- **Dataset Construction**: A custom PyTorch `Dataset` class loads each positive-case volume, then extracts and preprocesses the five slices around its center.
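- **Slice Counting**: The number of slices is counted per task (ACL, abnormal, meniscus) and in total. Because each task is processed independently, an exam that is positive for several tasks is counted (and saved) once per task.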
- **Saving Preprocessed Slices**: Processed slices are saved as PNG images for both training and validation sets, facilitating their use in model training.

This workflow standardizes MRI data, making it ready for machine learning tasks.


##### Importing Libraries

In [None]:
import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import torch
import torchio as tio
import matplotlib.pyplot as plt

##### Counting Slices

In [2]:

# Per-slice preprocessing pipeline:
#   1. rescale intensities to [-1, 1];
#   2. resample to a 256 x 256 x 1 grid. The trailing 1 is the singleton z axis
#      that MRNetSinglePlaneDataset.__getitem__ appends to each 2-D slice before
#      calling this pipeline; Resize itself does not add any dimension.
preprocessing_transforms = tio.Compose(
    [
        tio.RescaleIntensity(out_min_max=(-1, 1)),
        tio.Resize((256, 256, 1)),
    ]
)

class MRNetSinglePlaneDataset(Dataset):
    """Positive-case MRNet exams for one task and one plane, as stacks of central slices.

    Reads ``{root_dir}/{split}-{task}.csv`` (no header; columns ``id, label``),
    keeps only rows with ``label == 1`` and loads each volume from
    ``{root_dir}/{split}/{plane}/{id zero-padded to 4 digits}.npy``.

    Each item is a dict with:
      * ``'data'``  - tensor stacking ``num_center_slices`` preprocessed slices
        taken around the middle index of the volume's first axis;
      * ``'label'`` - 1-element float32 tensor with the exam's label;
      * ``'id'``    - the zero-padded exam id (file name without extension).

    Args:
        root_dir: dataset root directory.
        task: task name used to pick the CSV file (e.g. ``'acl'``).
        plane: sub-folder of the split that holds the volumes.
        split: split name (e.g. ``'train'`` or ``'valid'``).
        preprocessing_transforms: callable applied to each slice, which is
            passed as a ``(1, H, W, 1)`` array; ``None`` leaves slices unchanged.
        num_center_slices: number of consecutive slices to take around the
            centre of each volume (default 5).
    """

    def __init__(self, root_dir, task, plane='coronal', split='train',
                 preprocessing_transforms=None, num_center_slices=5):
        super().__init__()
        if num_center_slices < 1:
            raise ValueError("num_center_slices must be >= 1")
        self.task = task
        self.plane = plane
        self.root_dir = root_dir
        self.split = split
        self.preprocessing_transforms = preprocessing_transforms
        self.num_center_slices = num_center_slices

        # Load labels and zero-pad ids so they match the .npy file names.
        self.records = self._get_annotations()
        self.records['id'] = self._remap_id_to_match_folder_name()

        # Keep only positive cases, then renumber rows so that position i in
        # `self.labels` and `self.paths` refers to the same exam.
        self.records = self.records[self.records['label'] == 1].reset_index(drop=True)
        # Labels must be read *after* filtering. Reading them from the
        # unfiltered table misaligns label[i] with paths[i].
        self.labels = self.records['label'].tolist()
        self.paths = self._get_file_paths()

    def _get_file_paths(self):
        """Return the .npy path of every kept record, in record order."""
        return [
            os.path.join(self.root_dir, self.split, self.plane, f'{exam_id}.npy')
            for exam_id in self.records['id'].tolist()
        ]

    def _remap_id_to_match_folder_name(self):
        """Left-pad each id with zeros to at least 4 characters (e.g. 7 -> '0007')."""
        return self.records['id'].map(lambda i: str(i).zfill(4))

    def _get_annotations(self):
        """Read the header-less ``{split}-{task}.csv`` into columns ``id`` and ``label``."""
        csv_file = os.path.join(self.root_dir, f'{self.split}-{self.task}.csv')
        return pd.read_csv(csv_file, header=None, names=['id', 'label'])

    def _center_slice_indices(self, num_slices):
        """Indices of ``num_center_slices`` consecutive slices centred on ``num_slices // 2``.

        Indices are clipped to ``[0, num_slices - 1]``. A volume shorter than the
        window therefore repeats its edge slices, instead of raising IndexError
        or wrapping around through negative indexing.
        """
        center = num_slices // 2
        start = center - self.num_center_slices // 2
        indices = np.arange(start, start + self.num_center_slices)
        return np.clip(indices, 0, num_slices - 1)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        plane_path = self.paths[index]
        volume = np.load(plane_path).astype('float32')

        slices = []
        for i in self._center_slice_indices(volume.shape[0]):
            # (H, W) -> (1, H, W, 1): channel axis first, singleton z axis last.
            slice_data = volume[i][None, :, :, None]
            if self.preprocessing_transforms is not None:
                slice_data = self.preprocessing_transforms(slice_data)
            slice_data = slice_data.squeeze(-1)  # drop the z axis -> (1, H, W)
            slices.append(torch.as_tensor(slice_data))

        data = torch.stack(slices)
        label = torch.tensor([float(self.labels[index])], dtype=torch.float32)
        sample_id = os.path.splitext(os.path.basename(plane_path))[0]

        return {'data': data, 'label': label, 'id': sample_id}

# Dataset location and the subset of MRNet processed in this notebook.
root_dir = "MRnet-v1.0"                  # MRNet root folder (CSV files + split sub-folders)
plane = 'coronal'                        # only the coronal plane is used
tasks = ['acl', 'abnormal', 'meniscus']  # one label CSV per task

def count_total_slices(dataset):
    """Return the total number of slices produced by ``dataset``.

    Every sample is loaded (so any preprocessing in ``__getitem__`` runs), and
    the sizes of the first dimension of each sample's ``'data'`` are summed.
    An empty dataset yields 0.
    """
    return sum(dataset[idx]['data'].shape[0] for idx in range(len(dataset)))

# For every task, build the train and valid datasets and tally how many slices
# each one yields; also keep running totals over all tasks.
total_train_slices = 0
total_valid_slices = 0

for task in tasks:
    train_dataset = MRNetSinglePlaneDataset(
        root_dir=root_dir,
        task=task,
        plane=plane,
        split='train',
        preprocessing_transforms=preprocessing_transforms,
    )
    valid_dataset = MRNetSinglePlaneDataset(
        root_dir=root_dir,
        task=task,
        plane=plane,
        split='valid',
        preprocessing_transforms=preprocessing_transforms,
    )

    # Counting loads and preprocesses every volume in both splits.
    train_slices = count_total_slices(train_dataset)
    valid_slices = count_total_slices(valid_dataset)
    total_train_slices += train_slices
    total_valid_slices += valid_slices

    print(f"Task: {task}")
    print(f"Total number of slices in training set: {train_slices}")
    print(f"Total number of slices in validation set: {valid_slices}")

print(f"Total number of slices across all tasks in training set: {total_train_slices}")
print(f"Total number of slices across all tasks in validation set: {total_valid_slices}")


Task: acl
Total number of slices in training set: 1040
Total number of slices in validation set: 270
Task: abnormal
Total number of slices in training set: 4565
Total number of slices in validation set: 475
Task: meniscus
Total number of slices in training set: 1985
Total number of slices in validation set: 260
Total number of slices across all tasks in training set: 7590
Total number of slices across all tasks in validation set: 1005


##### Visualizing and Saving Slices

In [5]:

def show_slices(slices):
    """Display a sequence of image slices side by side in a single row.

    Args:
        slices: sequence of tensors. Each is converted with ``.numpy()`` and
            squeezed, so singleton dimensions such as ``(1, H, W)`` are fine.

    An empty sequence draws nothing. ``plt.subplots(1, 0)`` would raise instead.
    """
    slices = list(slices)
    if not slices:
        return
    # squeeze=False keeps `axes` 2-D even for a single slice. Without it,
    # plt.subplots(1, 1) returns a bare Axes object that cannot be indexed.
    _, axes = plt.subplots(1, len(slices), figsize=(15, 15), squeeze=False)
    for ax, image in zip(axes[0], slices):
        ax.imshow(image.numpy().squeeze(), cmap="gray")
        ax.axis('off')
    plt.show()

def visualize_sample(dataset, index=2):
    """Print the ID and label of ``dataset[index]``, then pass all its slices to ``show_slices``.

    Args:
        dataset: indexable dataset whose items are dicts with ``'data'``
            (tensor of slices along dim 0), ``'label'`` (1-element tensor) and ``'id'``.
        index: position of the sample to show.
    """
    sample = dataset[index]
    sample_id = sample['id']
    sample_label = sample['label'].item()

    print("Sample ID:", sample_id)
    print("Sample label:", sample_label)

    # Iterating over a tensor walks its first dimension: one element per slice.
    show_slices(list(sample['data']))

# Dataset location and processed subset (same values as in the counting cell).
root_dir = "MRnet-v1.0"                  # MRNet root folder
plane = 'coronal'                        # only the coronal plane is used
tasks = ['acl', 'abnormal', 'meniscus']  # one label CSV per task

def save_slices(dataset, save_dir, task):
    """Write every slice of every sample in ``dataset`` to ``save_dir`` as a PNG.

    Files are named ``{task}_{id}_slice_{j}.png``. Here ``id`` is the sample's
    ``'id'`` entry and ``j`` is the slice position within its ``'data'`` stack.
    Images are written with ``plt.imsave`` using the gray colormap.

    Args:
        dataset: indexable dataset whose items are dicts with ``'data'``
            (tensor of slices along dim 0) and ``'id'``.
        save_dir: output directory; created (with parents) if missing.
        task: prefix for the file names.

    Returns:
        int: number of PNG files written. Callers that ignore the return value
        are unaffected.
    """
    # exist_ok=True replaces the racy "check, then create" pattern.
    os.makedirs(save_dir, exist_ok=True)
    n_saved = 0
    for idx in range(len(dataset)):
        sample = dataset[idx]
        sample_id = sample['id']
        for j, slice_tensor in enumerate(sample['data']):
            out_path = os.path.join(save_dir, f'{task}_{sample_id}_slice_{j}.png')
            plt.imsave(out_path, slice_tensor.numpy().squeeze(), cmap='gray')
            n_saved += 1
    return n_saved

# For every task: build the train/valid datasets, report their slice counts,
# and save all slices as PNG files into one directory per split.
# Counting reuses `count_total_slices` from the counting cell instead of
# re-implementing the same per-sample loop here.
total_train_slices = 0
total_valid_slices = 0

train_save_dir = os.path.join(root_dir, 'train_slices_raw')
valid_save_dir = os.path.join(root_dir, 'valid_slices_raw')

for task in tasks:
    train_dataset = MRNetSinglePlaneDataset(root_dir=root_dir, task=task, plane=plane, split='train', preprocessing_transforms=preprocessing_transforms)
    valid_dataset = MRNetSinglePlaneDataset(root_dir=root_dir, task=task, plane=plane, split='valid', preprocessing_transforms=preprocessing_transforms)

    # NOTE(review): counting loads and preprocesses every volume, and
    # save_slices loads them all again. Consider caching if this gets slow.
    train_slices = count_total_slices(train_dataset)
    valid_slices = count_total_slices(valid_dataset)

    total_train_slices += train_slices
    total_valid_slices += valid_slices

    print(f"Task: {task}")
    print(f"Total number of slices in training set: {train_slices}")
    print(f"Total number of slices in validation set: {valid_slices}")

    # Save slices to their respective directories.
    print(f"Saving {train_slices} slices for training set of task {task}...")
    save_slices(train_dataset, train_save_dir, task)
    print(f"Saving {valid_slices} slices for validation set of task {task}...")
    save_slices(valid_dataset, valid_save_dir, task)

print(f"Total number of slices across all tasks in training set: {total_train_slices}")
print(f"Total number of slices across all tasks in validation set: {total_valid_slices}")
print(f"Train slices saved to {train_save_dir}")
print(f"Valid slices saved to {valid_save_dir}")


Task: acl
Total number of slices in training set: 1040
Total number of slices in validation set: 270
Saving 1040 slices for training set of task acl...
Saving 270 slices for validation set of task acl...
Task: abnormal
Total number of slices in training set: 4565
Total number of slices in validation set: 475
Saving 4565 slices for training set of task abnormal...
Saving 475 slices for validation set of task abnormal...
Task: meniscus
Total number of slices in training set: 1985
Total number of slices in validation set: 260
Saving 1985 slices for training set of task meniscus...
Saving 260 slices for validation set of task meniscus...
Total number of slices across all tasks in training set: 7590
Total number of slices across all tasks in validation set: 1005
Train slices saved to C:\Users\ASUS\Documents\Uobd\project\Datasets\MRNet-v1.0\train_slices_raw
Valid slices saved to C:\Users\ASUS\Documents\Uobd\project\Datasets\MRNet-v1.0\valid_slices_raw
