# Preprocessing of MRNet and Data Augmentation

In [1]:
import os
import platform
from glob import glob
from scipy import ndimage
import SimpleITK as sitk
import numpy as np
import pandas as pd


In [None]:
mrnet_dataset_dir = 'Data/MRNet-v1.0'
# The split directories sit directly under the dataset root.
mrnet_train_path, mrnet_valid_path = (os.path.join(mrnet_dataset_dir, split) for split in ('train', 'valid'))
mrnet_planes = ['axial', 'coronal', 'sagittal']

In [None]:
# On Windows, rewrite the forward slashes of the hard-coded paths as backslashes.
if platform.system() == "Windows":
    mrnet_dataset_dir, mrnet_train_path, mrnet_valid_path = (
        p.replace('/', '\\') for p in (mrnet_dataset_dir, mrnet_train_path, mrnet_valid_path))

In [None]:
# Map each split name to its directory of per-plane .npy volumes.
mrnet_datasets = {'train': mrnet_train_path, 'valid': mrnet_valid_path}

In [None]:
# Label names; each one selects a '<split>-<label>.csv' file in the dataset root.
mrnet_labels = ['abnormal', 'acl', 'meniscus']

In [None]:
# TRAIN DATASET
# Each '<split>-<label>.csv' is headerless: column 0 is the case id (read as a
# string so it matches the .npy file stems), column 1 the integer label.
# os.path.join picks the platform separator, so no Windows/POSIX branching is needed.
mrnet_label_columns = {'abnormal': 'Abnormal', 'acl': 'ACL', 'meniscus': 'Meniscus'}
train_label_dfs = {}
for label in mrnet_labels:
    column = mrnet_label_columns[label]
    train_label_dfs[label] = pd.read_csv(os.path.join(mrnet_dataset_dir, f"train-{label}.csv"),
                                         header=None,
                                         names=['Case', column],
                                         dtype={'Case': str, column: np.int64})
train_abnormal_df = train_label_dfs['abnormal']
train_acl_df = train_label_dfs['acl']
train_meniscus_df = train_label_dfs['meniscus']

# One row per case with all three labels side by side.
train_df = pd.merge(train_abnormal_df, train_acl_df, on='Case').merge(train_meniscus_df, on='Case')

In [None]:
# VALID DATASET
# Each '<split>-<label>.csv' is headerless: column 0 is the case id (read as a
# string so it matches the .npy file stems), column 1 the integer label.
# os.path.join picks the platform separator, so no Windows/POSIX branching is needed.
mrnet_label_columns = {'abnormal': 'Abnormal', 'acl': 'ACL', 'meniscus': 'Meniscus'}
valid_label_dfs = {}
for label in mrnet_labels:
    column = mrnet_label_columns[label]
    valid_label_dfs[label] = pd.read_csv(os.path.join(mrnet_dataset_dir, f"valid-{label}.csv"),
                                         header=None,
                                         names=['Case', column],
                                         dtype={'Case': str, column: np.int64})
valid_abnormal_df = valid_label_dfs['abnormal']
valid_acl_df = valid_label_dfs['acl']
valid_meniscus_df = valid_label_dfs['meniscus']

# One row per case with all three labels side by side.
valid_df = pd.merge(valid_abnormal_df, valid_acl_df, on='Case').merge(valid_meniscus_df, on='Case')

In [None]:

def resize_3D_volume(vol, target_size=(30, 256, 256)):
    """
    Resample a 3D volume to a fixed shape using linear interpolation.

    All three axes are rescaled, not just the z-axis. Exams with different slice
    counts and in-plane sizes therefore all end up with the same
    (depth, height, width) shape, which standardises the depth of the MRI volumes.

    Args:
        vol (np.ndarray): 3D array of shape (Z, X, Y) holding an MRI volume
        target_size (tuple, optional): Desired (Z, X, Y) output shape

    Returns:
        np.ndarray: The resampled volume, with shape ``target_size``

    Raises:
        ValueError: If ``vol`` is not 3D, ``target_size`` does not have three
            entries, or any axis of ``vol`` is empty.
    """
    vol = np.asarray(vol)
    if vol.ndim != 3 or len(target_size) != 3:
        raise ValueError(
            f"expected a 3D volume and a 3-entry target_size, got shape {vol.shape} "
            f"and target_size {target_size!r}")
    if 0 in vol.shape:
        raise ValueError(f"cannot resize a volume with an empty axis: {vol.shape}")
    # Per-axis zoom = desired / current. scipy derives the output shape as
    # round(current * zoom), which gives back target_size exactly.
    zoom_factors = tuple(desired / current for desired, current in zip(target_size, vol.shape))
    return ndimage.zoom(vol, zoom_factors, order=1)


In [None]:

def denoise_3D_volume(vol, time_step=0.01, number_of_iterations=7):
    """
    Smooth an MRI volume with SimpleITK's curvature-flow filter.

    Args:
        vol (np.ndarray): MRI volume to denoise
        time_step (float, optional): Passed to ``sitk.CurvatureFlow`` as ``timeStep``
        number_of_iterations (int, optional): Passed to ``sitk.CurvatureFlow`` as
            ``numberOfIterations``

    Returns:
        np.ndarray: The denoised volume
    """
    # NumPy -> SimpleITK image -> filter -> back to NumPy.
    vol_sitk = sitk.GetImageFromArray(vol)
    denoised_vol_sitk = sitk.CurvatureFlow(vol_sitk,
                                           timeStep=time_step,
                                           numberOfIterations=number_of_iterations)
    return sitk.GetArrayFromImage(denoised_vol_sitk)

In [None]:


def efficient_bias_field_correction_volume(vol, shrink_factor=4):
    """
    Remove an estimated bias field from an MRI volume using N4.

    For speed, N4 is fitted on a copy of the volume (and of a Li-threshold mask)
    downsampled by ``shrink_factor`` on every axis. The resulting log bias field
    is then evaluated on the full-resolution grid and divided out of the
    original volume.

    Args:
        vol (np.ndarray): MRI volume to correct
        shrink_factor (int, optional): Downsampling factor applied on every axis
            before fitting N4

    Returns:
        np.ndarray: Bias-field-corrected volume (float64)
    """
    # Ref: https://medium.com/@alexandro.ramr777/how-to-do-bias-field-correction-with-python-156b9d51dd79
    # Ref: https://simpleitk.readthedocs.io/en/master/link_N4BiasFieldCorrection_docs.html
    vol_sitk = sitk.Cast(sitk.GetImageFromArray(vol), sitk.sitkFloat64)

    # Mask for N4: rescale intensities to [0, 255], then Li-threshold into a
    # binary image (insideValue=0, outsideValue=1).
    rescaled = sitk.RescaleIntensity(vol_sitk, 0, 255)
    foreground_mask = sitk.LiThreshold(rescaled, 0, 1)

    shrink = [shrink_factor] * vol_sitk.GetDimension()
    shrunk_vol = sitk.Shrink(vol_sitk, shrink)
    shrunk_mask = sitk.Shrink(foreground_mask, shrink)

    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    # Only the fitted bias-field model is used; the low-resolution corrected
    # image returned by Execute is discarded.
    bias_corrector.Execute(shrunk_vol, shrunk_mask)

    # Evaluate the log bias field on the full-resolution grid and divide it out.
    log_bias_field = sitk.Cast(bias_corrector.GetLogBiasFieldAsImage(vol_sitk), sitk.sitkFloat64)
    corrected_full_resolution = vol_sitk / sitk.Exp(log_bias_field)

    return sitk.GetArrayFromImage(corrected_full_resolution)


In [None]:
def normalise_volume_pixels(vol):
    """
    Min-max scale an MRI volume to the range [0, 1].

    Args:
        vol (np.ndarray): MRI volume

    Returns:
        np.ndarray: Normalised volume. A constant volume (max == min) maps to
        all zeros instead of the NaNs a 0/0 division would produce.
    """
    min_value = np.min(vol)
    value_range = np.max(vol) - min_value
    if value_range == 0:
        return np.zeros_like(vol, dtype=np.float64)
    return (vol - min_value) / value_range

In [None]:
def center_volume_pixels(vol):
    """
    Shift an MRI volume so its voxel intensities have zero mean.

    Args:
        vol (np.ndarray): MRI volume

    Returns:
        np.ndarray: ``vol`` minus its global mean intensity
    """
    return vol - np.mean(vol)

In [None]:

def standardise_volume_pixels(vol):
    """
    Z-score an MRI volume: subtract its mean and divide by its standard deviation.

    Args:
        vol (np.ndarray): MRI volume

    Returns:
        np.ndarray: Standardised volume. A constant volume (std == 0) is returned
        mean-centred (all zeros) instead of the NaNs a 0/0 division would produce.
    """
    mean_value = np.mean(vol)
    std_value = np.std(vol)
    centred = vol - mean_value
    if std_value == 0:
        return centred
    return centred / std_value


In [None]:
def preprocess_mri(mri_vol):
    """
    Run the full preprocessing pipeline on one MRI volume.

    Steps, in order: resize to a fixed shape, curvature-flow denoising, N4 bias
    field correction, min-max normalisation, mean centering, z-score
    standardisation.

    Args:
        mri_vol (np.ndarray): MRI volume

    Returns:
        np.ndarray: The preprocessed volume
    """
    pipeline = (
        resize_3D_volume,
        denoise_3D_volume,
        efficient_bias_field_correction_volume,
        normalise_volume_pixels,
        center_volume_pixels,
        standardise_volume_pixels,
    )
    # NOTE(review): min-max normalisation and centering are positive affine
    # intensity maps, and the final z-score is invariant to those, so the two
    # steps look redundant. Confirm before removing them.
    for step in pipeline:
        mri_vol = step(mri_vol)
    return mri_vol

In [None]:
def preprocess_mri_vols(cases, overwrite=False):
    """
    Preprocess MRNet volumes and save them under the 'Preprocessed_Data' directory.

    The output path is the input path with its first component replaced by
    'Preprocessed_Data' (e.g. Data/MRNet-v1.0/train/sagittal/0000.npy ->
    Preprocessed_Data/MRNet-v1.0/train/sagittal/0000.npy).

    Args:
        cases (list): Paths to .npy volumes in the MRNet dataset. The list is
            processed in sorted order and is not modified.
        overwrite (bool, optional): Re-process volumes whose output file already exists
    """
    # NOTE(review): assumes relative paths whose first component is the raw data
    # root; for an absolute POSIX path the leading '' component would be replaced.
    for case in sorted(cases):
        case_path = os.path.normpath(case).split(os.sep)
        case_path[0] = 'Preprocessed_Data'
        preprocessed_case_path = os.path.join(*case_path)

        if not overwrite and os.path.exists(preprocessed_case_path):
            continue  # already preprocessed: skip loading the raw volume entirely

        mri_vol = np.load(case).astype(np.float64)
        preprocessed_mri_vol = preprocess_mri(mri_vol)
        os.makedirs(os.path.join(*case_path[:-1]), exist_ok=True)
        np.save(preprocessed_case_path, preprocessed_mri_vol)

In [None]:

def random_horizontal_flip(vol):
    """
    Mirror an MRI volume along axis 2.

    Despite the name, nothing random happens here: the flip is always applied.

    Args:
        vol (np.ndarray): 3D MRI volume

    Returns:
        np.ndarray: The flipped volume (a reversed view of ``vol``, not a copy)
    """
    return vol[:, :, ::-1]


In [None]:

def random_rotation(vol, rotation_angles=(-2.0, -1.5, -1.0, 1.0, 1.5, 2.0), axes=(1, 0)):
    """
    Rotate an MRI volume by an angle picked uniformly at random from ``rotation_angles``.

    The output keeps the input shape (``reshape=False``). Values sampled from
    outside the volume take the nearest edge value (``mode='nearest'``).

    Args:
        vol (np.ndarray): MRI volume
        rotation_angles (sequence of float, optional): Candidate angles in degrees
        axes (tuple of int, optional): Plane of rotation, forwarded to
            ``scipy.ndimage.rotate``. The default (1, 0) matches SciPy's default,
            i.e. the plane spanned by the first two axes.

    Returns:
        np.ndarray: The rotated volume
    """
    # NOTE(review): for (Z, X, Y) volumes the default plane mixes the slice axis
    # with an in-plane axis; axes=(1, 2) would rotate within each slice instead.
    # Confirm which was intended before changing the default.
    rotation_angle = np.random.choice(rotation_angles)
    return ndimage.rotate(vol, rotation_angle, axes=axes, reshape=False, mode='nearest')

In [None]:
def augment_mri_vols(dataset, labels, aug_flip_prob=0.95, overwrite=False):
    """
    Create augmented copies of sagittal MRNet volumes to rebalance the ACL label.

    Each exam with ACL == 1 gets three augmented copies: ``-aug-0`` (horizontal
    flip), ``-aug-1`` (rotation) and ``-aug-2`` (flip then rotation). An exam with
    ACL == 0 gets a single ``-aug-0`` copy (flip then rotation), and only with
    probability ``1 - aug_flip_prob``. Each copy goes through ``preprocess_mri``
    and is saved as ``Preprocessed_Data/<...>/sagittal/aug/<case>-aug-<k>.npy``.
    The labels of all augmented copies, including ones that already existed on
    disk, are written to ``train-aug.csv`` or ``valid-aug.csv`` next to the
    ``dataset`` directory.

    Args:
        dataset (str): Path to either the train or valid MRNet dataset
        labels (pandas.DataFrame): Labels with columns 'Case', 'Abnormal', 'ACL', 'Meniscus'
        aug_flip_prob (float, optional): Probability of *not* augmenting an
            ACL-negative exam
        overwrite (bool, optional): Regenerate augmented volumes that already exist

    Raises:
        ValueError: If a volume has no matching row in ``labels``.
    """
    plane = 'sagittal'
    label_columns = ['Abnormal', 'ACL', 'Meniscus']
    # (flip, rotate) recipe for each augmented copy; the list index is the copy number k.
    acl_positive_recipes = [(True, False), (False, True), (True, True)]
    acl_negative_recipes = [(True, True)]

    aug_labels_list = []
    for case in sorted(glob(os.path.join(dataset, plane, '*.npy'))):
        case_path = os.path.normpath(case).split(os.sep)
        orig_sagittal = os.path.join(*case_path)
        case_id, ext = os.path.splitext(case_path[-1])

        # Augmented copies are stored under Preprocessed_Data/<...>/<plane>/aug/.
        case_path[0] = 'Preprocessed_Data'
        aug_dir = os.path.join(*case_path[:-1], 'aug')

        # The labels are the same for all augmented copies of one exam.
        matches = labels.loc[labels['Case'] == case_id, label_columns]
        if matches.empty:
            raise ValueError(f"No labels found for case {case_id!r} ({case})")
        case_labels = matches.values.tolist()[0]
        acl_label = case_labels[1]

        # Draw once per exam, whatever its label, so the random stream stays the same
        # regardless of which exams are ACL-positive.
        draw = np.random.rand()
        if acl_label == 1:
            # ACL tears are the minority class, so every ACL-positive exam is augmented.
            recipes = acl_positive_recipes
        elif acl_label == 0 and draw >= aug_flip_prob:
            recipes = acl_negative_recipes
        else:
            continue

        mri_vol = None  # loaded lazily, at most once per exam
        for aug_ind, (flip, rotate) in enumerate(recipes):
            aug_case = f"{case_id}-aug-{aug_ind}"
            aug_sagittal = os.path.join(aug_dir, aug_case + ext)
            if overwrite or not os.path.exists(aug_sagittal):
                if mri_vol is None:
                    mri_vol = np.load(orig_sagittal).astype(np.float64)
                aug_mri_vol = mri_vol
                if flip:
                    aug_mri_vol = random_horizontal_flip(aug_mri_vol)
                if rotate:
                    aug_mri_vol = random_rotation(aug_mri_vol)
                preprocessed_aug_mri_vol = preprocess_mri(aug_mri_vol)
                os.makedirs(aug_dir, exist_ok=True)
                np.save(aug_sagittal, preprocessed_aug_mri_vol)
            # Record the label even when the file already existed. Otherwise a re-run
            # with overwrite=False would write a CSV missing existing augmented samples.
            aug_labels_list.append([aug_case] + case_labels)

    aug_df = pd.DataFrame(aug_labels_list, columns=labels.columns)
    csv_file_path = os.path.normpath(dataset).split(os.sep)
    split_name = csv_file_path[-1]
    if split_name in ('train', 'valid'):
        aug_df.to_csv(os.path.join(*csv_file_path[:-1], f"{split_name}-aug.csv"))
    print(f"For {dataset.upper()} dataset we have {len(aug_labels_list)} augmented samples.")

In [None]:
def preprocess_mri_vols_for_plane(dataset, plane, overwrite=False):
    """
    Preprocess every .npy volume of one MRNet split and plane.

    Args:
        dataset (str): Path to either train or valid MRNet dataset
        plane (str): MRNet dataset plane: 'axial', 'coronal' or 'sagittal'
        overwrite (bool, optional): Forwarded to ``preprocess_mri_vols``; re-process
            volumes that already have a preprocessed file
    """
    # os.path.join uses the platform separator, replacing the Windows/POSIX branch.
    cases = glob(os.path.join(dataset, plane, '*.npy'))
    preprocess_mri_vols(cases, overwrite=overwrite)
    print(f"For {dataset.upper()} {plane} plane we have {len(cases)} samples.")

In [None]:
# Preprocess only sagittal plane
preprocess_mri_vols_for_plane(mrnet_datasets['train'], 'sagittal')

In [None]:
# Preprocess only sagittal plane
preprocess_mri_vols_for_plane(mrnet_datasets['valid'], 'sagittal')

In [None]:
# Augment the training split using train_df labels; the labels of the augmented
# copies are written to train-aug.csv (the split directory's name selects the file).
augment_mri_vols(mrnet_datasets['train'], train_df)

In [None]:
# Augment the validation split too (written to valid-aug.csv).
# NOTE(review): if these augmented copies are later pooled with their source exams
# and re-split, copies of one exam can land in both training and evaluation sets.
# Confirm this is intended.
augment_mri_vols(mrnet_datasets['valid'], valid_df)

## MRNet Exploratory Data Analysis

In [None]:
import os
import platform
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
data_dir = 'Data/MRNet-v1.0'
train_data_path = os.path.join(data_dir, 'train')
valid_data_path = os.path.join(data_dir, 'valid')

# The EDA cells below iterate over `datasets` and `planes`. Define them here so
# this section does not rely on names from the preprocessing section, which uses
# `mrnet_datasets` / `mrnet_planes`.
datasets = {'train': train_data_path, 'valid': valid_data_path}
planes = ['axial', 'coronal', 'sagittal']

In [None]:
# Report how many .npy exam volumes each split has in every plane.
for dataset, path in datasets.items():
    print(f"\nTotal exams in {dataset.upper()}")
    for plane in planes:
        plane_dir = os.path.join(path, plane)
        print(f"{plane:8} plane : {len(glob(f'{plane_dir}/*.npy'))}")

In [None]:
def get_slices_per_exam(exams):
    """
    Count the slices (size of axis 0) of each MRNet exam volume.

    Args:
        exams (list): Paths to .npy exam volumes

    Returns:
        np.ndarray: Number of slices per exam, in the same order as ``exams``
    """
    # mmap_mode='r' only reads the .npy header to get the shape, instead of
    # loading every full volume into memory.
    return np.asarray([np.load(exam, mmap_mode='r').shape[0] for exam in exams])

In [None]:
def plot_slices_per_exam(dataset, planes=('axial', 'coronal', 'sagittal')):
    """
    Plot, for each plane, the distribution of slice counts across MRNet exams.

    Also prints the min / max / mean slice count per plane. A plane with no
    .npy files is reported and its subplot is left empty.

    Args:
        dataset (str): Path to either train or valid MRNet dataset
        planes (sequence of str, optional): Planes to plot, one subplot each
    """
    fig, axes = plt.subplots(1, len(planes), figsize=(20, 6), squeeze=False)
    for ax, plane in zip(axes[0], planes):
        num_slices = get_slices_per_exam(glob(os.path.join(dataset, plane, '*.npy')))
        if num_slices.size:
            print(f"For {dataset.upper()} {plane} plane min : {num_slices.min()}, max : {num_slices.max()}, avg : {num_slices.mean()}")
            sns.histplot(num_slices, stat='density', ax=ax, kde=True)
        else:
            # .min()/.max() raise on an empty array, so report and skip the plot.
            print(f"For {dataset.upper()} {plane} plane no exams were found")
        ax.set_title(f"{plane.title()} Plane")

In [None]:
# TRAIN DATASET
# Slice-count distribution per plane (expects `datasets` to map split names to directories).
plot_slices_per_exam(datasets['train'])

In [None]:
# VALID DATASET
# Slice-count distribution per plane for the validation split.
plot_slices_per_exam(datasets['valid'])

In [None]:
# Label names shared by the train and valid sections; each selects a '<split>-<label>.csv' file.
label_categories = 'abnormal acl meniscus'.split()

In [None]:
# Headerless per-label CSVs: column 0 is the case id (kept as a string), column 1 the label.
label_to_column = {'abnormal': 'Abnormal', 'acl': 'ACL', 'meniscus': 'Meniscus'}
train_label_frames = {}
for label in label_categories:
    column = label_to_column[label]
    train_label_frames[label] = pd.read_csv(f"{data_dir}/train-{label}.csv",
                                            header=None,
                                            names=['Case', column],
                                            dtype={'Case': str, column: np.int64})
train_abnormal_df = train_label_frames['abnormal']
train_acl_df = train_label_frames['acl']
train_meniscus_df = train_label_frames['meniscus']

In [None]:
# Class counts for each training label.
train_abnormal_df['Abnormal'].value_counts()

In [None]:
train_acl_df['ACL'].value_counts()

In [None]:
train_meniscus_df['Meniscus'].value_counts()

In [None]:
# One row per case with Abnormal, ACL and Meniscus side by side (inner joins on 'Case').
train_df = pd.merge(train_abnormal_df, train_acl_df, on='Case').merge(train_meniscus_df, on='Case')

In [None]:
train_df

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi=80)
fig.suptitle('MRNet Train : Total Samples in each Class')

# One count plot per label column, with bar heights written on the bars.
for axis, column in zip(ax, ['Abnormal', 'ACL', 'Meniscus']):
    sns.countplot(data=train_df, x=column, ax=axis)
    axis.bar_label(axis.containers[0])
    axis.set_xlabel(f'{column} Class')
    axis.set_ylabel('Count of Samples')

plt.show()

In [None]:
# Share of each ACL label value in the training split.
acl_counts = train_df['ACL'].value_counts()
fig, ax = plt.subplots(figsize=(10, 6), dpi=80)
ax.pie(x=acl_counts, labels=acl_counts.index, autopct='%.1f%%')
ax.set_title('MRNet Train : Pie Chart of ACL Class Imbalance')
plt.show()

In [None]:
# VALID DATASET
# Headerless per-label CSVs: column 0 is the case id (kept as a string), column 1 the label.
label_to_column = {'abnormal': 'Abnormal', 'acl': 'ACL', 'meniscus': 'Meniscus'}
valid_label_frames = {}
for label in label_categories:
    column = label_to_column[label]
    valid_label_frames[label] = pd.read_csv(f"{data_dir}/valid-{label}.csv",
                                            header=None,
                                            names=['Case', column],
                                            dtype={'Case': str, column: np.int64})
valid_abnormal_df = valid_label_frames['abnormal']
valid_acl_df = valid_label_frames['acl']
valid_meniscus_df = valid_label_frames['meniscus']

In [None]:
# Class counts for each validation label.
valid_abnormal_df['Abnormal'].value_counts()

In [None]:
valid_acl_df['ACL'].value_counts()

In [None]:
valid_meniscus_df['Meniscus'].value_counts()

In [None]:
# One row per case with Abnormal, ACL and Meniscus side by side (inner joins on 'Case').
valid_df = pd.merge(valid_abnormal_df, valid_acl_df, on='Case').merge(valid_meniscus_df, on='Case')

In [None]:
valid_df

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi=80)
fig.suptitle('MRNet Valid : Total Samples in each Class')

# One count plot per label column, with bar heights written on the bars.
for axis, column in zip(ax, ['Abnormal', 'ACL', 'Meniscus']):
    sns.countplot(data=valid_df, x=column, ax=axis)
    axis.bar_label(axis.containers[0])
    axis.set_xlabel(f'{column} Class')
    axis.set_ylabel('Count of Samples')

plt.show()

In [None]:
# Share of each ACL label value in the validation split.
acl_counts = valid_df['ACL'].value_counts()
fig, ax = plt.subplots(figsize=(10, 6), dpi=80)
ax.pie(x=acl_counts, labels=acl_counts.index, autopct='%.1f%%')
ax.set_title('MRNet Valid : Pie Chart of ACL Class Imbalance')
plt.show()

In [None]:
# This display cell comes before the cell that builds `full_df`, so running the
# notebook top to bottom raised NameError here. Build the combined frame first.
full_df = pd.concat([train_df, valid_df], ignore_index=True)
full_df

### Combining the Train and Valid datasets

In [None]:
# Stack the train and valid label tables into one frame with a fresh 0..n-1 index.
full_df = pd.concat((train_df, valid_df), ignore_index=True)

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi=80)
fig.suptitle('MRNet \n\n Total Samples in each Class', fontsize=14)

# One count plot per label column over train + valid, with bar heights on the bars.
for axis, column in zip(ax, ['Abnormal', 'ACL', 'Meniscus']):
    sns.countplot(data=full_df, x=column, ax=axis)
    axis.bar_label(axis.containers[0])
    axis.set_xlabel(f'{column} Class')
    axis.set_ylabel('Count of Samples')

fig.tight_layout()
plt.show()

In [None]:
# Share of each ACL label value over train + valid.
acl_counts = full_df['ACL'].value_counts()
fig, ax = plt.subplots(figsize=(10, 6), dpi=80)
fig.suptitle('MRNet', fontsize=14)
ax.pie(x=acl_counts, labels=acl_counts.index, autopct='%.1f%%')
ax.set_title('Pie Chart of ACL Class Imbalance', pad=10)
plt.show()

In [None]:
# Number and percentage of training cases for every (Abnormal, ACL, Meniscus) combination.
grouped = (train_df.groupby(['Abnormal', 'ACL', 'Meniscus'])
           .size()
           .reset_index(name='Total Cases'))
total_cases = grouped['Total Cases'].sum()
grouped['Percentage'] = np.round(grouped['Total Cases'] / total_cases * 100, 2)
print(grouped)

In [None]:
# Rows per label combination; after grouping, the only remaining column is 'Case',
# so .count() gives the number of cases in each combination.
train_occurence_df = train_df.groupby(['Abnormal', 'ACL', 'Meniscus']).count()

In [None]:
# Percentage of all training cases in each combination, rounded to 2 decimals.
train_occurence_df['Percent'] = np.round((train_occurence_df['Case']/train_occurence_df['Case'].sum())*100, 2)

In [None]:
train_occurence_df

In [None]:
# Uses valid_df (columns Case, Abnormal, ACL, Meniscus).

# Group by the three label columns and count the validation cases in each combination
grouped = valid_df.groupby(['Abnormal', 'ACL', 'Meniscus']).size().reset_index(name='Total Cases')

# Total number of validation cases
total_cases = grouped['Total Cases'].sum()

# Percentage of all validation cases in each combination, rounded to 2 decimals
grouped['Percentage'] = np.round((grouped['Total Cases'] / total_cases) * 100, 2)

# Display the grouped DataFrame
print(grouped)

In [None]:
# Rows per label combination; after grouping, the only remaining column is 'Case',
# so .count() gives the number of cases in each combination.
valid_occurence_df = valid_df.groupby(['Abnormal', 'ACL', 'Meniscus']).count()

In [None]:
# Percentage of all validation cases in each combination, rounded to 2 decimals.
valid_occurence_df['Percent'] = np.round((valid_occurence_df['Case'] / valid_occurence_df['Case'].sum()) * 100, 2)

In [None]:
valid_occurence_df

### Combining the Train and Valid datasets

In [None]:
# Uses full_df (train + valid; columns Case, Abnormal, ACL, Meniscus).

# Group by the three label columns and count the cases in each combination
grouped = full_df.groupby(['Abnormal', 'ACL', 'Meniscus']).size().reset_index(name='Total Cases')

# Total number of cases over train + valid
total_cases = grouped['Total Cases'].sum()

# Percentage of all cases in each combination, rounded to 2 decimals
grouped['Percentage'] = np.round((grouped['Total Cases'] / total_cases) * 100, 2)

# Display the grouped DataFrame
print(grouped)

## After Pre-processing and Data Augmentation

In [None]:
# Raw and preprocessed MRNet roots, each with its train/valid split directories.
mrnet_dataset_dir = 'Data/MRNet-v1.0'
mrnet_train_path, mrnet_valid_path = (os.path.join(mrnet_dataset_dir, split) for split in ('train', 'valid'))

preprocessed_mrnet_dataset_dir = 'Preprocessed_Data/MRNet-v1.0'
preprocessed_mrnet_train_path, preprocessed_mrnet_valid_path = (
    os.path.join(preprocessed_mrnet_dataset_dir, split) for split in ('train', 'valid'))

mrnet_planes = ['axial', 'coronal', 'sagittal']

# On Windows, turn the forward slashes of the hard-coded roots into backslashes.
if platform.system() == "Windows":
    (mrnet_dataset_dir, mrnet_train_path, mrnet_valid_path,
     preprocessed_mrnet_dataset_dir, preprocessed_mrnet_train_path, preprocessed_mrnet_valid_path) = (
        p.replace('/', '\\') for p in (mrnet_dataset_dir, mrnet_train_path, mrnet_valid_path,
                                      preprocessed_mrnet_dataset_dir, preprocessed_mrnet_train_path,
                                      preprocessed_mrnet_valid_path))

In [None]:
# Map each split name to its raw data directory.
mrnet_datasets = {'train': mrnet_train_path, 'valid': mrnet_valid_path}

In [None]:
# Label names; each one selects a '<split>-<label>.csv' file in the dataset root.
mrnet_labels = ['abnormal', 'acl', 'meniscus']

In [None]:
# TRAIN DATASET
# Each '<split>-<label>.csv' is headerless: column 0 is the case id (read as a
# string so it matches the .npy file stems), column 1 the integer label.
# os.path.join picks the platform separator, so no Windows/POSIX branching is needed.
mrnet_label_columns = {'abnormal': 'Abnormal', 'acl': 'ACL', 'meniscus': 'Meniscus'}
train_label_dfs = {}
for label in mrnet_labels:
    column = mrnet_label_columns[label]
    train_label_dfs[label] = pd.read_csv(os.path.join(mrnet_dataset_dir, f"train-{label}.csv"),
                                         header=None,
                                         names=['Case', column],
                                         dtype={'Case': str, column: np.int64})
train_abnormal_df = train_label_dfs['abnormal']
train_acl_df = train_label_dfs['acl']
train_meniscus_df = train_label_dfs['meniscus']

# One row per case with all three labels side by side.
train_df = pd.merge(train_abnormal_df, train_acl_df, on='Case').merge(train_meniscus_df, on='Case')

In [None]:
# VALID DATASET
# Each '<split>-<label>.csv' is headerless: column 0 is the case id (read as a
# string so it matches the .npy file stems), column 1 the integer label.
# os.path.join picks the platform separator, so no Windows/POSIX branching is needed.
mrnet_label_columns = {'abnormal': 'Abnormal', 'acl': 'ACL', 'meniscus': 'Meniscus'}
valid_label_dfs = {}
for label in mrnet_labels:
    column = mrnet_label_columns[label]
    valid_label_dfs[label] = pd.read_csv(os.path.join(mrnet_dataset_dir, f"valid-{label}.csv"),
                                         header=None,
                                         names=['Case', column],
                                         dtype={'Case': str, column: np.int64})
valid_abnormal_df = valid_label_dfs['abnormal']
valid_acl_df = valid_label_dfs['acl']
valid_meniscus_df = valid_label_dfs['meniscus']

# One row per case with all three labels side by side.
valid_df = pd.merge(valid_abnormal_df, valid_acl_df, on='Case').merge(valid_meniscus_df, on='Case')

In [None]:
# AUGMENTED TRAIN LABELS
# augment_mri_vols writes this CSV with DataFrame.to_csv, which includes the
# index as the first column; hence index_col=0.
# os.path.join picks the platform separator, so no Windows/POSIX branching is needed.
train_aug_df = pd.read_csv(os.path.join(mrnet_dataset_dir, "train-aug.csv"),
                           index_col=0,
                           dtype={'Case': str, 'Abnormal': np.int64, 'ACL': np.int64, 'Meniscus': np.int64})

In [None]:
# AUGMENTED VALID LABELS
# augment_mri_vols writes this CSV with DataFrame.to_csv, which includes the
# index as the first column; hence index_col=0.
# os.path.join picks the platform separator, so no Windows/POSIX branching is needed.
valid_aug_df = pd.read_csv(os.path.join(mrnet_dataset_dir, "valid-aug.csv"),
                           index_col=0,
                           dtype={'Case': str, 'Abnormal': np.int64, 'ACL': np.int64, 'Meniscus': np.int64})

In [None]:
# Original train/valid labels plus the labels of every augmented copy, re-indexed 0..n-1.
full_mrnet_df = pd.concat([train_df, valid_df, train_aug_df, valid_aug_df], ignore_index=True)

In [None]:
# Total number of samples after augmentation.
len(full_mrnet_df)

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5), dpi=80)
fig.suptitle('MRNet \n\n Total Samples in each Class', fontsize=14)

# One count plot per label column after augmentation, with bar heights on the bars.
for axis, column in zip(ax, ['Abnormal', 'ACL', 'Meniscus']):
    sns.countplot(data=full_mrnet_df, x=column, ax=axis)
    axis.bar_label(axis.containers[0])
    axis.set_xlabel(f'{column} Class')
    axis.set_ylabel('Count of Samples')

fig.tight_layout()
plt.show()

In [None]:
# Share of each ACL label value once the augmented samples are included.
acl_counts = full_mrnet_df['ACL'].value_counts()
fig, ax = plt.subplots(figsize=(10, 6), dpi=80)
fig.suptitle('MRNet', fontsize=14)
ax.pie(x=acl_counts, labels=acl_counts.index, autopct='%.1f%%')
ax.set_title('Pie Chart of reduced ACL Class Imbalance', pad=10)
plt.show()

# **Training Model**

In [None]:
import os
import platform
import pandas as pd
import numpy as np
from glob import glob

In [None]:
from tensorflow import keras

In [None]:
from sklearn.model_selection import train_test_split

## MRNet Sagittal Plane

In [None]:
mrnet_dataset_dir = 'Data/MRNet-v1.0'
mrnet_train_path = os.path.join(mrnet_dataset_dir, 'train')
mrnet_valid_path = os.path.join(mrnet_dataset_dir, 'valid')

mrnet_preprocessed_dataset_dir = 'Preprocessed_Data/MRNet-v1.0'
mrnet_preprocessed_train_path = os.path.join(mrnet_preprocessed_dataset_dir, 'train')
mrnet_preprocessed_valid_path = os.path.join(mrnet_preprocessed_dataset_dir, 'valid')

mrnet_planes = ['axial', 'coronal', 'sagittal']

In [None]:
# For running code on Windows
if platform.system() == "Windows":
    mrnet_dataset_dir = mrnet_dataset_dir.replace('/', '\\')
    mrnet_train_path = mrnet_train_path.replace('/', '\\')
    mrnet_valid_path = mrnet_valid_path.replace('/', '\\')
    
    mrnet_preprocessed_dataset_dir = mrnet_preprocessed_dataset_dir.replace('/', '\\')
    mrnet_preprocessed_train_path = mrnet_preprocessed_train_path.replace('/', '\\')
    mrnet_preprocessed_valid_path = mrnet_preprocessed_valid_path.replace('/', '\\')

In [None]:
mrnet_datasets = { 'train' : mrnet_train_path, 'valid' : mrnet_valid_path}
mrnet_classes = ['abnormal', 'acl', 'meniscus']

In [None]:
# TRAIN DATASET
for label in mrnet_classes:
    if platform.system() == "Windows":
        if label == 'abnormal':
            train_abnormal_df = pd.read_csv(f"{mrnet_dataset_dir}\\train-{label}.csv",
                                            header=None,
                                            names=['Case','Abnormal'],
                                            dtype={'Case':str, 'Abnormal':np.int64})
        elif label == 'acl':
            train_acl_df = pd.read_csv(f"{mrnet_dataset_dir}\\train-{label}.csv",
                                            header=None,
                                            names=['Case','ACL'],
                                            dtype={'Case':str, 'ACL':np.int64})
        if label == 'meniscus':
            train_meniscus_df = pd.read_csv(f"{mrnet_dataset_dir}\\train-{label}.csv",
                                            header=None,
                                            names=['Case','Meniscus'],
                                            dtype={'Case':str, 'Meniscus':np.int64})
    else:
        if label == 'abnormal':
            train_abnormal_df = pd.read_csv(f"{mrnet_dataset_dir}/train-{label}.csv",
                                            header=None,
                                            names=['Case','Abnormal'],
                                            dtype={'Case':str, 'Abnormal':np.int64})
        elif label == 'acl':
            train_acl_df = pd.read_csv(f"{mrnet_dataset_dir}/train-{label}.csv",
                                            header=None,
                                            names=['Case','ACL'],
                                            dtype={'Case':str, 'ACL':np.int64})
        if label == 'meniscus':
            train_meniscus_df = pd.read_csv(f"{mrnet_dataset_dir}/train-{label}.csv",
                                            header=None,
                                            names=['Case','Meniscus'],
                                            dtype={'Case':str, 'Meniscus':np.int64})
            
mrnet_train_df = pd.merge(train_abnormal_df, train_acl_df, on='Case').merge(train_meniscus_df, on='Case')

In [None]:
mrnet_train_df

In [None]:
# VALID DATASET
for label in mrnet_classes:
    if platform.system() == "Windows":
        if label == 'abnormal':
            valid_abnormal_df = pd.read_csv(f"{mrnet_dataset_dir}\\valid-{label}.csv",
                                            header=None,
                                            names=['Case','Abnormal'],
                                            dtype={'Case':str, 'Abnormal':np.int64})
        elif label == 'acl':
            valid_acl_df = pd.read_csv(f"{mrnet_dataset_dir}\\valid-{label}.csv",
                                            header=None,
                                            names=['Case','ACL'],
                                            dtype={'Case':str, 'ACL':np.int64})
        if label == 'meniscus':
            valid_meniscus_df = pd.read_csv(f"{mrnet_dataset_dir}\\valid-{label}.csv",
                                            header=None,
                                            names=['Case','Meniscus'],
                                            dtype={'Case':str, 'Meniscus':np.int64})
    else:
        if label == 'abnormal':
            valid_abnormal_df = pd.read_csv(f"{mrnet_dataset_dir}/valid-{label}.csv",
                                            header=None,
                                            names=['Case','Abnormal'],
                                            dtype={'Case':str, 'Abnormal':np.int64})
        elif label == 'acl':
            valid_acl_df = pd.read_csv(f"{mrnet_dataset_dir}/valid-{label}.csv",
                                            header=None,
                                            names=['Case','ACL'],
                                            dtype={'Case':str, 'ACL':np.int64})
        if label == 'meniscus':
            valid_meniscus_df = pd.read_csv(f"{mrnet_dataset_dir}/valid-{label}.csv",
                                            header=None,
                                            names=['Case','Meniscus'],
                                            dtype={'Case':str, 'Meniscus':np.int64})

mrnet_valid_df = pd.merge(valid_abnormal_df, valid_acl_df, on='Case').merge(valid_meniscus_df, on='Case')

In [None]:
mrnet_valid_df

In [None]:
# AUGMENTED TRAIN LABELS
if platform.system() == "Windows":
    mrnet_train_aug_df = pd.read_csv(f"{mrnet_dataset_dir}\\train-aug.csv",
                                     index_col=0,
                                     dtype={'Case':str, 'Abnormal':np.int64, 'ACL':np.int64, 'Meniscus':np.int64})
else:
    mrnet_train_aug_df = pd.read_csv(f"{mrnet_dataset_dir}/train-aug.csv",
                                     index_col=0,
                                     dtype={'Case':str, 'Abnormal':np.int64, 'ACL':np.int64, 'Meniscus':np.int64})

In [None]:
mrnet_train_aug_df

In [None]:
# AUGMENTED VALID LABELS
if platform.system() == "Windows":
    mrnet_valid_aug_df = pd.read_csv(f"{mrnet_dataset_dir}\\valid-aug.csv",
                                     index_col=0,
                                     dtype={'Case':str, 'Abnormal':np.int64, 'ACL':np.int64, 'Meniscus':np.int64})
else:
    mrnet_valid_aug_df = pd.read_csv(f"{mrnet_dataset_dir}/valid-aug.csv",
                                     index_col=0,
                                     dtype={'Case':str, 'Abnormal':np.int64, 'ACL':np.int64, 'Meniscus':np.int64})

In [None]:
mrnet_valid_aug_df

In [None]:
# We are working only with Sagittal plane

# TRAIN
if platform.system() == "Windows":
    mrnet_sagittal_train_files = glob(mrnet_preprocessed_train_path+"\\sagittal\\*.npy")
else:
    mrnet_sagittal_train_files = glob(mrnet_preprocessed_train_path+"/sagittal/*.npy")
mrnet_sagittal_train_files.sort()

# VALID
if platform.system() == "Windows":
    mrnet_sagittal_valid_files = glob(mrnet_preprocessed_valid_path+"\\sagittal\\*.npy")
else:
    mrnet_sagittal_valid_files = glob(mrnet_preprocessed_valid_path+"/sagittal/*.npy")
mrnet_sagittal_valid_files.sort()

# AUGMENTED TRAIN
if platform.system() == "Windows":
    mrnet_sagittal_train_aug_files = glob(mrnet_preprocessed_train_path+"\\sagittal\\aug\\*.npy")
else:
    mrnet_sagittal_train_aug_files = glob(mrnet_preprocessed_train_path+"/sagittal/aug/*.npy")
mrnet_sagittal_train_aug_files.sort()

# AUGMENTED VALID
if platform.system() == "Windows":
    mrnet_sagittal_valid_aug_files = glob(mrnet_preprocessed_valid_path+"\\sagittal\\aug\\*.npy")
else:
    mrnet_sagittal_valid_aug_files = glob(mrnet_preprocessed_valid_path+"/sagittal/aug/*.npy")
mrnet_sagittal_valid_aug_files.sort()

In [None]:
print(len(mrnet_sagittal_train_files))
print(len(mrnet_sagittal_valid_files))
print(len(mrnet_sagittal_train_aug_files))
print(len(mrnet_sagittal_valid_aug_files))

In [None]:
mrnet_filenames = []
mrnet_filenames.extend(mrnet_sagittal_train_files)
mrnet_filenames.extend(mrnet_sagittal_valid_files)
mrnet_filenames.extend(mrnet_sagittal_train_aug_files)
mrnet_filenames.extend(mrnet_sagittal_valid_aug_files)
mrnet_filenames.sort()

In [None]:
len(mrnet_filenames)

In [None]:
len(mrnet_train_df)+len(mrnet_valid_df)+len(mrnet_train_aug_df)+len(mrnet_valid_aug_df)

In [None]:
mrnet_full_df = pd.concat([mrnet_train_df, mrnet_valid_df, mrnet_train_aug_df, mrnet_valid_aug_df], ignore_index=True)

In [None]:
mrnet_full_df

In [None]:


def get_correct_labels_mrnet(filenames, labels_dataframe, label_column='ACL'):
    """Look up the label of each MRI volume file in the MRNet labels table.

    A file's case id is the part of its basename before the first '.', e.g.
    ``.../sagittal/0001.npy`` -> ``'0001'``. It is matched against the
    ``'Case'`` column of ``labels_dataframe``. If a case id appears in several
    rows, the first row wins.

    Args:
        filenames (list): List of filenames of the MRI scans
        labels_dataframe (pd.DataFrame): Dataframe with a 'Case' column and one
            column per label (all MRNet cases and labels)
        label_column (str, optional): Column whose value is returned for each
            file. Defaults to 'ACL'.

    Returns:
        list: Label of every entry of ``filenames``, in the same order

    Raises:
        KeyError: If a file's case id is not present in ``labels_dataframe``.
    """
    # Build the case -> label map once instead of filtering the whole DataFrame
    # for every file. setdefault keeps the FIRST row of a duplicated case id.
    case_to_label = {}
    for case, label in zip(labels_dataframe['Case'].tolist(),
                           labels_dataframe[label_column].tolist()):
        case_to_label.setdefault(case, label)

    labels = []
    for file in filenames:
        name = os.path.basename(os.path.normpath(file))
        case_name = name.split('.')[0]
        try:
            labels.append(case_to_label[case_name])
        except KeyError:
            raise KeyError(
                f"No entry for case {case_name!r} (from {file!r}) in labels_dataframe"
            ) from None
    return labels

In [None]:
mrnet_labels = get_correct_labels_mrnet(mrnet_filenames, mrnet_full_df)

In [None]:
mrnet_filenames[:5]

In [None]:
mrnet_labels[:5]

In [None]:
# Quick check of counts of samples for each case
[[x, mrnet_labels.count(x)] for x in set(mrnet_labels)]

## Prior to training

In [None]:
BATCH_SIZE = 8
EPOCHS = 100

In [None]:
# Splitting into train, test and validation

X, X_test, y, y_test = train_test_split(mrnet_filenames, 
                                        mrnet_labels, 
                                        test_size=0.1, 
                                        random_state=610, 
                                        shuffle=True, 
                                        stratify=mrnet_labels)

X_train, X_valid, y_train, y_valid = train_test_split(X,
                                                      y,
                                                      train_size=0.7, 
                                                      random_state=610, 
                                                      shuffle=True, 
                                                      stratify=y)

In [None]:
[[x, y_train.count(x)] for x in set(y_train)]

In [None]:
[[x, y_valid.count(x)] for x in set(y_valid)]

In [None]:
[[x, y_test.count(x)] for x in set(y_test)]

In [None]:

def compute_class_weights(y_train):
    """Compute 'balanced' class weights for a list of labels.

    Each class ``c`` gets the weight ``n_samples / (n_classes * count(c))``.
    This is the same formula as scikit-learn's
    ``compute_class_weight(class_weight='balanced')``, computed directly with
    NumPy so the function does not depend on a separately imported
    ``class_weight`` module. Rarer classes therefore receive larger weights.

    Args:
        y_train (list): List of labels

    Returns:
        dict: A dictionary of labels and their corresponding class weights
        (keys and values as plain Python scalars). Empty input gives ``{}``.
    """
    labels = np.asarray(y_train)
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts.astype(np.float64))
    return dict(zip(classes.tolist(), weights.tolist()))

In [None]:
mrnet_class_weights = compute_class_weights(y_train)

In [None]:
mrnet_class_weights

## MRNet Model 

In [None]:
model_name = 'MRNet_Model'
MRNet_Model3 = models.mri_model_3(model_name, 2)
MRNet_Model3.compile(optimizer=keras.optimizers.Adam(learning_rate=utils.model_lr_schedule()),
                     loss='binary_crossentropy', 
                     metrics=['accuracy'])
MRNet_Model3.summary()

In [None]:


def _shuffle_in_unison(filenames, labels, seed):
    """Return shuffled copies of ``filenames`` and ``labels`` under one shared permutation.

    Reproduces ``sklearn.utils.shuffle(filenames, labels, random_state=seed)``:
    ``np.random.RandomState(seed)`` shuffles ``np.arange(n)`` and both sequences
    are re-indexed with it, so every file keeps its own label.
    """
    order = np.arange(len(filenames))
    np.random.RandomState(seed).shuffle(order)
    return [filenames[k] for k in order], [labels[k] for k in order]


def batch_generator(filenames, labels, batch_size, seed=610):
    '''
    Endlessly yield batches of MRI volumes and labels, loading files lazily.

    Only ``batch_size`` volumes are read from disk per step, which keeps RAM
    usage bounded for large datasets. The data is shuffled deterministically
    before the first batch, and reshuffled whenever fewer than ``batch_size``
    samples remain in the current pass. That incomplete tail is skipped, so
    one pass yields ``len(filenames) // batch_size`` full batches. If
    ``len(filenames) < batch_size``, every batch is the whole (reshuffled)
    dataset.

    Args:
        filenames (list): List of file paths to the MRI (.npy) volumes
        labels (list): List of corresponding labels of the MRI
        batch_size (int): Batch size (must be >= 1)
        seed (int, optional): Base seed; pass ``e`` (0, 1, ...) is shuffled
            with ``seed + 69 + e``. Defaults to 610.

    Yields:
        tuple: (images, labels) as NumPy arrays. Each loaded volume gets an
        extra axis inserted at position 3 for 3D convolutions.

    Raises:
        ValueError: On the first ``next()``, if ``filenames`` is empty, if
            ``filenames`` and ``labels`` differ in length, or if
            ``batch_size`` < 1 (each of these would otherwise loop forever or
            mis-pair files and labels).
    '''
    filenames = list(filenames)
    labels = list(labels)
    n_samples = len(filenames)
    if n_samples == 0:
        raise ValueError("batch_generator: 'filenames' is empty")
    if len(labels) != n_samples:
        raise ValueError(f"batch_generator: got {n_samples} filenames but {len(labels)} labels")
    if batch_size < 1:
        raise ValueError(f"batch_generator: batch_size must be >= 1, got {batch_size}")

    epoch = 0
    filenames, labels = _shuffle_in_unison(filenames, labels, seed + 69 + epoch)  # Shuffle at the start
    start = 0
    while True:
        stop = start + batch_size
        batch_images = []
        for file in filenames[start:stop]:
            mri_vol = np.load(file)
            # Insert a channel axis at position 3 (the trailing axis for a 3-D volume)
            # to make it compatible with 3D convolutions
            batch_images.append(np.expand_dims(mri_vol, axis=3))
        yield (np.array(batch_images), np.array(labels[start:stop]))
        start = stop
        if start + batch_size > n_samples:
            # Not enough samples left for a full batch: drop the tail, reshuffle, restart
            start = 0
            epoch += 1
            filenames, labels = _shuffle_in_unison(filenames, labels, seed + 69 + epoch)



In [None]:

def model_callback_checkpoint(model_name, model_store_path='Models'):
    """Create a ModelCheckpoint callback that keeps only the best model.

    The model is written to ``<model_store_path>/<model_name>/<model_name>.h5``,
    with every '/' turned into a backslash on Windows. Because
    ``save_best_only=True``, it is only overwritten when the monitored quantity
    (ModelCheckpoint's default ``monitor``) improves.

    Args:
        model_name (str): Name of the model; used for both the sub-folder and the file name
        model_store_path (str, optional): Root folder for saved models. Defaults to 'Models'.

    Returns:
        keras.callbacks.ModelCheckpoint: Keras checkpoint callback to store the best model
    """
    posix_style_path = f"{model_store_path}/{model_name}/{model_name}.h5"
    checkpoint_path = (posix_style_path.replace('/', '\\')
                       if platform.system() == "Windows"
                       else posix_style_path)
    return keras.callbacks.ModelCheckpoint(checkpoint_path, save_best_only=True)


def model_callback_earlystopping():
    """Create an EarlyStopping callback that watches the validation loss.

    Configured with ``monitor="val_loss"``, ``patience=10``, ``verbose=1`` and
    ``restore_best_weights=True``. Training stops after 10 epochs without
    ``val_loss`` improvement, and the best epoch's weights are restored.

    Returns:
        keras.callbacks.EarlyStopping: Keras earlystopping callback for monitoring Validation Loss
    """
    settings = dict(monitor="val_loss",
                    patience=10,
                    verbose=1,
                    restore_best_weights=True)
    return keras.callbacks.EarlyStopping(**settings)


In [None]:
%%time
with tf.device('/device:GPU:0'):
    history = MRNet_Model3.fit(batch_generator(X_train, y_train, BATCH_SIZE),
                               steps_per_epoch=len(X_train)//BATCH_SIZE,
                               epochs=EPOCHS,
                               validation_data=batch_generator(X_valid, y_valid, BATCH_SIZE),
                               validation_steps=len(X_valid)//BATCH_SIZE, 
                               shuffle=True,
                               class_weight=mrnet_class_weights,
                               verbose=1,
                               callbacks=[utils.model_callback_checkpoint(model_name), utils.model_callback_earlystopping()])