In [74]:
import os
import numpy as np
import nibabel as nib
import operator
import matplotlib.pyplot as plt

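# Prepare thigh data

Pair each scan with its mask by case ID, pick the slice with the most voxels labelled 1, and save that image slice and its mask as `.npy` files (optionally with a PNG overlay for sanity checking).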

In [16]:
# Input root holding the 'masks/' and 'scans/' NIfTI sub-folders. Set the
# THIGH_NII_DIR environment variable to point elsewhere instead of editing this
# absolute path; os.path.join(..., '') guarantees the trailing separator that
# later cells rely on when they build paths with `root_dir + 'masks/'`.
root_dir = os.path.join(os.environ.get('THIGH_NII_DIR', '/home/donal/t/Donal/thigh_nii/'), '')
# Output root for the extracted .npy files.
out_dir = './data/'

In [24]:
def clean_ids(path):
    """Map integer case IDs to file paths for every entry in directory ``path``.

    The ID is parsed from each file name with a pattern chosen by the
    directory's own name:

    * a ``masks`` directory: the text after ``'_Sarc_'`` up to the next ``'_'``;
    * a ``scans`` directory: the text after ``'SARC_'`` up to the next ``'_'``.

    If that text still contains ``'.nii'`` (e.g. ``'12.nii.gz'``), everything
    from the first ``'.'`` is dropped before converting to ``int``.

    Parameters
    ----------
    path : str
        Directory to list. Its last path component must contain ``'masks'``
        or ``'scans'``. A trailing slash is optional.

    Returns
    -------
    dict[int, str]
        ``{case_id: path_to_file}``. If two files yield the same ID, the one
        listed later by ``os.listdir`` wins.

    Raises
    ------
    ValueError
        If the directory is neither a masks nor a scans directory, or if a
        file name does not yield an integer ID.
    """
    # Decide the pattern from the directory itself, not the whole path, so a
    # parent folder that happens to contain 'masks' cannot change the parsing.
    dir_name = os.path.basename(os.path.normpath(path))
    if 'masks' in dir_name:
        marker = '_Sarc_'
    elif 'scans' in dir_name:
        marker = 'SARC_'
    else:
        # Fail fast: without a known pattern the ID cannot be parsed.
        raise ValueError(f"Expected a 'masks' or 'scans' directory, got {path!r}")

    data = {}
    for file in os.listdir(path):
        id_ = file.split(marker)[-1].split('_')[0]
        if '.nii' in id_:
            # The ID was the last '_'-separated token, so it still carries
            # the extension (e.g. '12.nii.gz').
            id_ = id_.split('.')[0]
        # int() raises ValueError for non-matching names (e.g. stray files).
        data[int(id_)] = os.path.join(path, file)
    return data

In [28]:
mask_data = clean_ids(root_dir + 'masks/')
slice_data = clean_ids(root_dir + 'scans/')
# Two empty directories would pass the equality check below and silently
# process nothing, so require at least one case.
assert slice_data, f"No scans found under {root_dir + 'scans/'}"
# Every scan needs a mask with the same ID (and vice versa); report which IDs
# are unmatched instead of raising a bare AssertionError.
assert mask_data.keys() == slice_data.keys(), (
    f"IDs with a mask but no scan: {sorted(mask_data.keys() - slice_data.keys())}; "
    f"IDs with a scan but no mask: {sorted(slice_data.keys() - mask_data.keys())}"
)

In [31]:
# Show the case IDs parsed from the masks and scans directories side by side.
(mask_data.keys(), slice_data.keys())

(dict_keys([1, 3, 5, 6, 8, 9, 10, 14, 15, 19, 20, 24, 26, 28, 29, 30, 32, 33, 34, 36]),
 dict_keys([1, 3, 5, 6, 8, 9, 10, 14, 15, 19, 20, 24, 26, 28, 29, 30, 32, 33, 34, 36]))

In [85]:
def _stem(path):
    """Return the file name of ``path`` up to its first '.' ('a/case.nii.gz' -> 'case')."""
    return os.path.basename(path).split('.')[0]


def _label_counts_per_slice(mask, label=1):
    """Count voxels equal to ``label`` in each slice along the last axis.

    Returns ``{slice_index: count}`` in ascending slice order. Only slices with
    a non-zero count are included.
    """
    counts = {}
    for i in range(mask.shape[-1]):
        n = int(np.count_nonzero(mask[..., i] == label))
        if n:
            counts[i] = n
    return counts


def extract_slice(img_path, mask_path, plot=False, out_dir='./data/', sanity_dir='./sanity/'):
    """Extract the most-delineated slice of a volume and save it with its mask.

    Both files are loaded with ``nib.load(...).get_fdata()``. The slice is taken
    along the last axis and is the one with the most voxels equal to 1 in the
    mask. Image and mask slices are squeezed and rotated with
    ``np.rot90(..., k=-1)``. They are then saved as
    ``<out_dir>/slices/<stem>.npy`` and ``<out_dir>/masks/<stem>.npy``, where
    ``<stem>`` is the image file name up to its first '.'.

    Parameters
    ----------
    img_path, mask_path : str
        Paths to the image and mask files. They must load to arrays of the
        same shape.
    plot : bool
        If True, also write a PNG overlay of mask on image to ``sanity_dir``.
    out_dir : str
        Root directory for the ``slices/`` and ``masks/`` outputs. Directories
        are created if missing.
    sanity_dir : str
        Directory for the optional PNG overlay. Created if missing.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The saved ``(slice_, contour)`` pair.

    Raises
    ------
    AssertionError
        If image and mask shapes differ.
    ValueError
        If no voxel in the mask equals 1.
    """
    img = nib.load(img_path).get_fdata()
    mask = nib.load(mask_path).get_fdata()
    assert img.shape == mask.shape, "Mask and image shape don't match"

    all_counts = _label_counts_per_slice(mask)
    print(all_counts)
    if not all_counts:
        raise ValueError(f"No voxels labelled 1 in mask {mask_path!r}")
    # More than one slice may be delineated; keep the one with the largest
    # labelled area. max() returns the first maximum, i.e. the lowest index on ties.
    idx = max(all_counts, key=all_counts.get)
    print(idx)

    slice_ = np.rot90(np.squeeze(img[..., idx]), k=-1)
    contour = np.rot90(np.squeeze(mask[..., idx]), k=-1)

    # ---- SAVE ----
    name = _stem(img_path)
    slices_dir = os.path.join(out_dir, 'slices')
    masks_dir = os.path.join(out_dir, 'masks')
    os.makedirs(slices_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)
    np.save(os.path.join(slices_dir, f"{name}.npy"), slice_)
    np.save(os.path.join(masks_dir, f"{name}.npy"), contour)

    # ---- PLOT ----
    if plot:
        # Fixed intensity display window for the sanity image.
        # NOTE(review): 874-1274 presumably suits this scanner's intensity
        # range. TODO confirm.
        plot_slice = np.clip(slice_, a_min=874, a_max=1274)
        # NaN where the mask is 0 so that only the delineated region is overlaid.
        plot_contour = np.where(contour == 0, np.nan, 1)
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.axis('off')
        ax.imshow(plot_slice, cmap='gray')
        ax.imshow(plot_contour, alpha=0.5)
        os.makedirs(sanity_dir, exist_ok=True)
        fig.savefig(os.path.join(sanity_dir, f"{name}.png"))
        # Close this specific figure; in a loop over many cases, open figures
        # would otherwise accumulate.
        plt.close(fig)

    return slice_, contour

In [86]:
# Extract and save the delineated slice for every case; each scan is paired
# with the mask that has the same case ID.
for case_id, scan_path in slice_data.items():
    extract_slice(scan_path, mask_data[case_id], plot=True)

{242: 26402}
[242]
{230: 35185}
[230]
{252: 34565}
[252]
{144: 41905}
[144]
{229: 28036}
[229]
{189: 34790}
[189]
{241: 29357}
[241]
{229: 27171}
[229]
{218: 38143}
[218]
{144: 30742}
[144]
{250: 35879}
[250]
{250: 44252}
[250]
{239: 37848}
[239]
{242: 21011}
[242]
{131: 2138, 246: 22099}
246
{243: 30198}
[243]
{236: 35559}
[236]
{237: 11648}
[237]
{218: 37768}
[218]
{119: 15483, 230: 24956}
230
