In [1]:
from glob import glob
from natsort import natsorted
import os
import pandas as pd
import re
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, f1_score, classification_report

from tqdm.notebook import tqdm
random_state=42

pd.set_option('display.max_rows', 5)

In [2]:
# set path to training set features
train_set_path = os.path.join("features", "extracted_from_GIF_output")

# set paths to test sets 
# for each of the test sets, training and inference will be performed
# if there is an overlap between a test set and the training set, 
# the corresponding cases will be removed from the training set
test_set_paths = [os.path.join("features", "extracted_from_nnUNet_prediction", "T1"),
                 os.path.join("features", "extracted_from_nnUNet_prediction", "T2"),
                 os.path.join("features", "extracted_from_nnUNet_prediction", "T1T2")]

In [3]:
# function to process the training and test sets
def process_set(csv_dfs, train_feature_df_mean=None):
    # get pairs of right/left labels (check if "Right" or "Left" in label name, and find the corresponding label)
    LRpairs = []
    for caseID in list(csv_dfs.keys())[0:]:
        curr_csv_df = csv_dfs[caseID]
        label_names = curr_csv_df["name"].values
        label_numbers = curr_csv_df["labelnumber"].values

        for name, n in zip(label_names, label_numbers):
            if "Right" in name:
                corresponding_left_name = name.replace("Right", "Left")
                try:
                    corresponding_left_labelnumber = curr_csv_df[curr_csv_df["name"]==corresponding_left_name]["labelnumber"].values[0]
                    LRpairs.append([n, corresponding_left_labelnumber])
                except:
                    print(f"Error: No corresponding left parcel found for caseID {caseID}, labelnumber {n}, name {name}")

            elif "Left" in name:
                corresponding_right_name = name.replace("Left", "Right")
                try:
                    corresponding_right_labelnumber = curr_csv_df[curr_csv_df["name"]==corresponding_right_name]["labelnumber"].values[0]
                    LRpairs.append([corresponding_right_labelnumber, n])
                except:
                    print(f"Error: No corresponding right parcel found for caseID {caseID}, labelnumber {n}, name {name}")

    LRpairs = sorted([y for y in set([tuple(x) for x in LRpairs])])
    RtoLdict = {r:l for (r,l) in LRpairs}
    LtoRdict = {l:r for (r,l) in LRpairs}

    # relabel left/right structures with ipsi/contra lateral labels
    # ipsilateral labels will be labelled with the up-to-here "right" label, contralateral with the up-to-here "left" label
    
    # load left-right list (required to convert right/left to ipsi/contra-lateral labels)
    LR_df = pd.read_csv(os.path.join("..", "TumorLR.csv"))
    csv_dfs_ipscon = {}

    for caseID in tqdm(list(csv_dfs.keys())[0:]):
        curr_csv_df = csv_dfs[caseID]
        curr_csv_df = curr_csv_df.set_index("labelnumber")
        label_names = curr_csv_df["name"].values
        label_numbers = curr_csv_df.index.values

        tumorside = LR_df[LR_df["patient_ID"] == caseID]["TumorLR"].values[0]
        ipscon_dict = {}
        for i, [ln, n] in enumerate(zip(label_names, label_numbers)):
            if "Left" in ln and tumorside=="r":
                new_name = ln.replace("Left", "Contralateral")
                new_labelnumber = n
                pass # don't do anything, since left labels are contralateral already
            elif "Right" in ln and tumorside=="r":
                new_name = ln.replace("Right", "Ipsilateral")
                new_labelnumber = n
                pass # don't do anything, since right labels are ipsilateral already
            elif "Left" in ln and tumorside=="l":
                # label should switch to the old right label, since it is ipsilateral
                new_name = ln.replace("Left", "Ipsilateral")
                new_labelnumber = LtoRdict[n]
            elif "Right" in ln and tumorside=="l":
                # label should switch to the old left label since it is contralateral
                new_name = ln.replace("Right", "Contralateral")
                new_labelnumber = RtoLdict[n]
            else:
                new_name = ln
                new_labelnumber = n

            # add new row to ipscon_dict
            ipscon_dict_row = {}

            ipscon_dict_row["labelnumber"] = new_labelnumber
            ipscon_dict_row["name"] = new_name
            ipscon_dict_row["vol"] = curr_csv_df.loc[n]["vol"]
            ipscon_dict_row["surfVS"] = curr_csv_df.loc[n]["surfVS"]
            ipscon_dict_row["distVS"] = curr_csv_df.loc[n]["distVS"]
            ipscon_dict_row["set"] = curr_csv_df.loc[n]["set"]

            ipscon_dict[i] = ipscon_dict_row

        # this dataframe is the same as csv_dfs, but with right/left labels replaced by ipsi/contra-lateral labels
        csv_dfs_ipscon[caseID] = pd.DataFrame.from_dict(ipscon_dict, orient='index')
        
        
    # read the ground truth Koos grades into a new dataframe
    df = pd.read_csv(os.path.join("..", "Koos_grades_groundtruth.csv"))
    df = df.rename(columns = {'final groundtruth':'Koos'})
    df_with_koos = df[df["Koos"].isin([1, 2, 3, 4, "1", "2", "3", "4"])].copy()

    df_with_koos['patient_ID'] = df_with_koos['patient_ID'].apply(lambda x: int(re.findall(r"vs_gk_([\d]+)", x)[0]))
    df_with_koos["Koos"] = df_with_koos["Koos"].astype(int)
    df_with_koos["patient_ID"] = df_with_koos["patient_ID"].astype(int)

    caseIDs = df_with_koos["patient_ID"].values

    # make sure no post-operative cases and no repeated cases (ID>=400) are included
    post_op_caseIDs = [11, 17, 26, 27, 41, 45, 49, 74, 75, 92, 96, 98, 100, 108, 114, 122, 128, 129, 135, 147, 150, 164, 171, 193, 204, 214, 218, 223, 225, 238, 246, 248, 249, 250, 259, 269, 275, 280, 283, 290, 297, 300, 340, 353, 354, 356, 375, 377, 380, 391, 2, 7, 47, 62, 80, 85, 89, 109, 127, 131, 134, 152, 161, 178, 196, 203, 209, 213, 231, 247, 253, 359, 360, 379]
    caseIDs = [c for c in caseIDs if c not in post_op_caseIDs and c < 400 and c in csv_dfs]

    koos = {caseID: df_with_koos[df_with_koos["patient_ID"] == caseID]["Koos"].values[0] for caseID in caseIDs}
    print(f"Found {len(koos)} cases with Koos grades.")

    # select features for random forest training
    featureIDs = \
    [(4, 'distVS'), # cerebellar vermal lobules 
     (7, 'surfVS'), # ipsilateral cerebellum
     (1, 'vol'), # tumour
     (6, 'distVS'), # cerebellar vermal lobules 
     (5, 'distVS'), # cerebellar vermal lobules 
     (2, 'distVS'), # Pons
     (8, 'distVS'), # contralateral cerebellum
     (0, 'surfVS'), # background
     (3, 'distVS'), # Brainstem
     (7, 'distVS'), # ipsilateral cerebellum
    ]
    featureIDs

    # create feature dataframe of selected features only
    feature_df = pd.DataFrame()
    
    for (n,t) in featureIDs:
        feature_vals = []
        for caseID in caseIDs:
            try:
                feature_vals.append(csv_dfs_ipscon[caseID][csv_dfs_ipscon[caseID]["labelnumber"] == n][t].values[0])
            except:
                # decide how to deal with missing values (can happen when a structure was not in the FOV)
                print("no value available:", f"{caseID=}", f"{n=}", f"{t=}")
                if t == "distVS":
                    feature_vals.append("fill mean")  # if distance has to be guessed, assume mean of all other corresponding values of other patients in training set
                elif t == "vol":
                    feature_vals.append("fill mean")  # if volume has to be guessed, assume mean of all other corresponding values of other patients in training set
                elif t == "surfVS":
                    feature_vals.append(0)  # contact surface is most likely 0 if the structure is not it the FOV
                else:
                    raise Exception("No rule defined for this missing feature.")
        
        if any(feature_vals):  # makes sure this feature is ignored if it's zero for all cases
            col_name = str(n)+","+t
            feature_df[col_name] = feature_vals

    # replace "fill mean" either with mean of current set (training) or passed mean (test set)
    assert(not feature_df.isnull().values.any()), "dataframe has nans"
    feature_df = feature_df.replace('fill mean',np.NaN)  # replace "fill mean" with nan
    if train_feature_df_mean is None:  # if no means have been passed (training set) 
        feature_df_mean = feature_df.mean()
        feature_df=feature_df.fillna(feature_df_mean)  # replace nan with mean of column
    else: # if means have been passed 
        feature_df_mean = feature_df.mean()
        feature_df=feature_df.fillna(train_feature_df_mean)  # replace nan with passed means
    
    # replace linear dataframe index with caseID
    feature_df.set_index(np.array(caseIDs), inplace=True)
    feature_df.index.name = "caseID"
    
    data = feature_df.to_numpy()
    target = np.array([koos[caseID] for caseID in koos])
    
    return data, target, feature_df_mean

In [4]:
all_results = {}
for idx, test_set_path in enumerate(test_set_paths):
    print("----------------------------------------------------------------")
    print(test_set_path)
    
    print("load training set...")
    # load extracted GIF features from origial GIF output (with manual VS segmentation)
    csv_files = natsorted(glob(os.path.join(train_set_path, "*.csv")))

    csv_dfs = {}  # save loaded dataframes in dictionary
    for f in csv_files:
        caseID = int(re.findall(r"vs_gk_(\d+)", f)[0])
        csv_df = pd.read_csv(f)
        csv_df['set'] = 'train'
        csv_dfs[caseID] = csv_df

    # remove csv entries of the current test set from the training set
    csv_files_test = natsorted(glob(os.path.join(test_set_path, "*.csv")))

    for f in csv_files_test:
        caseID = int(re.findall(r"vs_gk_(\d+)", f)[0])
        del csv_dfs[caseID]
    
    # process the training set
    X_train, y_train_groundtruth, train_feature_df_mean = process_set(csv_dfs)

    # train random forest
    clf = RandomForestClassifier(n_estimators=100000, max_depth=5, bootstrap=True, min_samples_leaf=2)
    clf.fit(X_train,y_train_groundtruth)
    
    print("load test set...")
    # load test set
    csv_test_dfs = {}
    for f in csv_files_test:
        caseID = int(re.findall(r"vs_gk_(\d+)", f)[0])
        csv_df = pd.read_csv(f)
        csv_df['set'] = 'test'
        csv_test_dfs[caseID] = csv_df
    
    # make sure no overlap between training and test set
    assert([c for c in csv_test_dfs if c in csv_dfs] == []), "Error: found training cases in test set"
    
    X_test, y_test_groundtruth, test_feature_df_mean = process_set(csv_test_dfs, train_feature_df_mean=train_feature_df_mean)

    # predict with trained random forest
    y_pred=clf.predict(X_test)

    print("calculate metrics...")
    # print confusion matrices and classification report
    all_results[idx] = (y_test_groundtruth, y_pred)
    print("confusion matrix:")
    confusion_mat = confusion_matrix(y_test_groundtruth, y_pred)
    class_report = classification_report(y_test_groundtruth, y_pred, )
    print(confusion_mat)
    print(class_report)

----------------------------------------------------------------
features/extracted_from_nnUNet_prediction/T1
load training set...


  0%|          | 0/2 [00:00<?, ?it/s]

Found 2 cases with Koos grades.
load test set...


  0%|          | 0/1 [00:00<?, ?it/s]

Found 1 cases with Koos grades.
calculate metrics...
confusion matrix:
[[0 1]
 [0 0]]
              precision    recall  f1-score   support

           3       0.00      0.00      0.00       1.0
           4       0.00      0.00      0.00       0.0

    accuracy                           0.00       1.0
   macro avg       0.00      0.00      0.00       1.0
weighted avg       0.00      0.00      0.00       1.0

----------------------------------------------------------------
features/extracted_from_nnUNet_prediction/T2
load training set...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  0%|          | 0/2 [00:00<?, ?it/s]

Found 2 cases with Koos grades.
load test set...


  0%|          | 0/1 [00:00<?, ?it/s]

Found 1 cases with Koos grades.
calculate metrics...
confusion matrix:
[[0 1]
 [0 0]]
              precision    recall  f1-score   support

           3       0.00      0.00      0.00       1.0
           4       0.00      0.00      0.00       0.0

    accuracy                           0.00       1.0
   macro avg       0.00      0.00      0.00       1.0
weighted avg       0.00      0.00      0.00       1.0

----------------------------------------------------------------
features/extracted_from_nnUNet_prediction/T1T2
load training set...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  0%|          | 0/2 [00:00<?, ?it/s]

Found 2 cases with Koos grades.
load test set...


  0%|          | 0/1 [00:00<?, ?it/s]

Found 1 cases with Koos grades.
calculate metrics...
confusion matrix:


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[0 1]
 [0 0]]
              precision    recall  f1-score   support

           3       0.00      0.00      0.00       1.0
           4       0.00      0.00      0.00       0.0

    accuracy                           0.00       1.0
   macro avg       0.00      0.00      0.00       1.0
weighted avg       0.00      0.00      0.00       1.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
