# Location generator for all lesions
The algorithm uses the MS FreeSurfer segmentations and the custom MS atlas (generated in *prepare_atlas_MNI.ipynb*) to add, to the metadata of each lesion, a label with the location where it belongs.
## Protocol to set a location
1. Regions are checked in the following order, and the first match is used (for all of them, 10):
- Ventricles
- GM (overlap of 50% of the lesion needed)
- Deep GM (overlap of 50% of the lesion needed)
- Cerebellum
- Brainstem
2. If the lesion matches none of the above, it is considered a WM lesion.

In [8]:
import sys
sys.path.append("..")
#from location import compute_locations
from config import *

In [9]:
from tqdm import tqdm
import pandas as pd
import nibabel as nib
import json
import multiprocessing as mp
from scipy.ndimage import binary_dilation
from scipy.ndimage.morphology import distance_transform_edt

from config import *
from utils import read_patients_metadata


# Segmentation label IDs grouped by anatomical region.
# NOTE(review): the IDs presumably follow the FreeSurfer aseg lookup table
# (FreeSurferColorLUT) — confirm against the segmentation actually used.
ventricles = [4, 14, 15, 43]
gm = [3, 42]
deepGM = [10, 11, 12, 13, 49, 50, 51, 52]
cerebellum = [7, 8, 46, 47]
brainstem = [16]

# Region groups in the order they are tested; group ``i`` is stored in the
# derived atlas with the value ``i + 1``.
final_labels = [ventricles, gm, deepGM, cerebellum, brainstem]

# ``names[0]`` is the fallback used when no group matches; ``names[i + 1]`` is
# the location name written for ``final_labels[i]``.
names = [
    "WM",
    "Periventricular",
    "GM",
    "Deep GM",
    "Cerebellum",
    "BrainStem",
]

# Per-group settings, aligned index by index with ``final_labels``:
# number of binary-dilation iterations applied to the group's region, and
# minimum fraction of the lesion volume that must overlap the region.
dilation = [5, 3, 0, 0, 0]
min_overlaps_PERC = [0, 0.5, 0.5, 0, 0]

def _build_ms_atlas(atlas):
    """Collapse a segmentation label volume into the coarse location atlas.

    Every group in ``final_labels`` is painted, in order, with the value
    ``index + 1``; voxels that belong to no group stay 0.
    """
    atlas_MS = np.zeros_like(atlas)
    for index, (sub_labels, n_dilations) in enumerate(zip(final_labels, dilation)):
        # A group is seeded only from voxels not yet claimed by an earlier group...
        region = np.isin(atlas, sub_labels) & (atlas_MS == 0)
        if n_dilations > 0:
            # ...but once dilated it may overwrite voxels of earlier groups.
            region = binary_dilation(region, iterations=n_dilations)
        atlas_MS[region] = index + 1
    return atlas_MS


def _classify_lesion(region_values, volume):
    """Return the index into ``names`` for a lesion.

    ``region_values`` holds the coarse-atlas value of every voxel of the
    lesion. Groups are tested in ``final_labels`` order and the first one
    that satisfies both overlap thresholds wins; 0 (WM) means no match.
    """
    for label, (_, min_fraction) in enumerate(zip(final_labels, min_overlaps_PERC)):
        num_overlap = int(np.count_nonzero(region_values == label + 1))
        # NOTE(review): `overlap_needed_abs` is not defined in this file;
        # presumably it is provided by `from config import *` — confirm.
        if num_overlap >= overlap_needed_abs and num_overlap / volume >= min_fraction:
            return label + 1
    return 0


def compute_location_matches_patient(db, patient, replace = False):
    """Assign an anatomical location to every lesion of one patient.

    The patient's segmentation is reduced to the groups in ``final_labels``
    and saved next to the original with an ``_MS`` suffix. Each lesion found
    in the lesion metadata JSON is then given a ``"location"`` entry (one of
    ``names``) and a ``"location_distance"`` entry (always 0), and the JSON
    is rewritten.

    Parameters
    ----------
    db
        Key into ``AVAILABLE_DATASETS``.
    patient : int
        Patient number; formatted as a zero-padded 3-digit subject id.
    replace : bool
        When False, lesions that already have a ``"location"`` are left alone.

    Returns
    -------
    tuple or None
        ``(db, patient)`` on success, ``None`` when a required file is missing.
    """
    dataset = AVAILABLE_DATASETS[db]
    subject = f"{patient:03d}"

    # we load the FS segmentation
    paths = dataset.get(return_type="filename", subject=subject, **CONTRASTS["MASK_FREESURFER"])
    if len(paths) == 0:
        print(f"[ERROR] Patient {patient} skipped: freesurfer mask missing.")
        return None
    atlas_path = paths[0]
    atlas_im = nib.load(atlas_path)
    atlas_MS = _build_ms_atlas(atlas_im.get_fdata().astype(int))

    atlas_path2 = atlas_path.replace(".nii.gz", "_MS.nii.gz")
    # If the suffix is not ".nii.gz" the replace is a no-op; never overwrite
    # the input segmentation in that case.
    if atlas_path2 != atlas_path:
        new_atlas = nib.Nifti1Image(atlas_MS.astype(float), atlas_im.affine)
        nib.save(new_atlas, atlas_path2)

    pipeline = DERIVATIVES["LESIONS"]["pipeline"].format(PATCH_SIZE[0], PATCH_SIZE[1], PATCH_SIZE[2])

    try:
        json_lesions_path = dataset.get(return_type="filename", subject=subject, scope=pipeline, extension='json')[0]
        mask_path = dataset.get(return_type="filename", subject=subject, **CONTRASTS["MASK"])[0]
    except Exception:
        print(f"[ERROR] Patient {patient} skipped: file missing.")
        return None

    mask = nib.load(mask_path).get_fdata().astype(int)

    # READ
    with open(json_lesions_path) as inp:
        pat_metadata = json.load(inp)

    for les_key, les_metadata in pat_metadata.items():
        if not replace and "location" in les_metadata:
            continue
        les = int(les_key)

        # Coarse-atlas value of every voxel of this lesion. Comparing these
        # values with a group id is equivalent to testing a per-group
        # distance map for zeros, without computing any distance transform.
        region_values = atlas_MS[mask == les]

        if region_values.size == 0:
            # The lesion id does not appear in the mask: nothing can overlap,
            # so it falls back to WM.
            print(f"[ERROR] Lesion not found in mask: {db} - {patient} - {les}")
            final_label = 0
        else:
            volume = int(les_metadata["volume"])
            if volume <= 0:
                # Avoid dividing by zero on a bad metadata entry: use the
                # number of lesion voxels found in the mask instead.
                volume = int(region_values.size)
            final_label = _classify_lesion(region_values, volume)

        les_metadata["location"] = names[final_label]
        les_metadata["location_distance"] = 0

    # Write to a temporary file and swap it in, so a failed dump cannot
    # destroy the existing metadata.
    tmp_path = json_lesions_path + ".tmp"
    with open(tmp_path, "w") as outfile:
        json.dump(pat_metadata, outfile)
    os.replace(tmp_path, json_lesions_path)
    return db, patient


def compute_locations(datasets = None, cpus = 6, replace = False):
    """Run ``compute_location_matches_patient`` for every subject in parallel.

    Parameters
    ----------
    datasets : list or None
        Keys into ``AVAILABLE_DATASETS`` to process; ``None`` processes
        every index in ``range(len(AVAILABLE_DATASETS))``.
    cpus : int
        Number of worker processes in the pool.
    replace : bool
        Forwarded to ``compute_location_matches_patient``.

    Raises
    ------
    Exception
        Any exception raised inside a worker is re-raised here by ``get()``.
    """
    if datasets is None:
        datasets = list(range(len(AVAILABLE_DATASETS)))

    def callback(result):
        # Workers return None for skipped patients; only report real results.
        if result is not None:
            print(f"[END] Patient {result[0]} - {result[1]} finished...")

    # The context manager terminates the workers if anything below raises,
    # so a failing patient does not leave pool processes behind.
    with mp.Pool(cpus) as pool:
        processes = []
        for db in datasets:
            dataset = AVAILABLE_DATASETS[db]
            for patient in dataset.get_subjects():
                processes.append(
                    pool.apply_async(
                        compute_location_matches_patient,
                        args=(db, int(patient), replace),
                        callback=callback,
                    )
                )
        for p in processes:
            p.get()
        pool.close()
        pool.join()
    

In [6]:
#compute_locations(datasets = [DATASET_BASEL_ID, ], cpus = 2, replace = True)
#compute_locations(datasets = [DATASET_CHUV_ID, ], cpus = 1)

In [7]:
# Run the location labelling for a single patient (dataset key 0, patient 57),
# recomputing the label even for lesions that already have a "location".
compute_location_matches_patient(0, 57, replace = True)

1000 - 0 - True
1000 - 1 - False
1000 - 2 - False
1000 - 3 - True
1000 - 4 - True
1000
1001 - 0 - True
1001
1002 - 0 - True
1002
2000 - 0 - True
2000
2001 - 0 - True
2001


(0, 57)