# I/ Loading CT scan  

In [None]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches

# Function to load a CT scan and its label volume
def get_image_label(img_number):
    """Load CT scan ``img_number`` and its label volume.

    Reads ``train-images/CT_<n>.nii.gz`` and
    ``train-images/labels/final/CT_<n>.nii.gz`` with nibabel and returns the
    pair ``(image, label)`` as produced by ``get_fdata()`` (image first).
    """
    image_path = f"train-images/CT_{img_number}.nii.gz"
    label_path = f"train-images/labels/final/CT_{img_number}.nii.gz"
    return nib.load(image_path).get_fdata(), nib.load(label_path).get_fdata()

In [None]:
# Load CT scan number `img_number`, then swap axes 1 and 2 of the image so the
# third axis indexes the slices browsed in the rest of the notebook.
# NOTE(review): load_label is not swapped -- swap it too before comparing it
# with anything derived from load_image.
img_number = 2
load_image, load_label = get_image_label(img_number)
load_image = np.swapaxes(load_image, 1, 2)



# II/ Data visualisation

In [None]:
# Function to visualise one frame (slice) of the CT scan
def visualise_img(img,tranche):
    """Show ``img[:, :, tranche]`` as a grayscale image.

    Parameters
    ----------
    img : ndarray
        Volume whose third axis indexes the slices.
    tranche : int
        Index of the slice to display.
    """
    _, axis = plt.subplots()
    axis.imshow(img[:, :, tranche], cmap='gray')
    plt.show()

# Number of slices along the third axis (the axis browsed by visualise_img)
tranche_number=load_image.shape[2]

In [None]:
#Interface
from ipywidgets import widgets, interact,fixed

tranche_number=load_image.shape[2]
# Valid slice indices are 0 .. tranche_number - 1. visualise_img takes no
# `threshold` argument, so no threshold slider is passed (interact would call
# visualise_img(threshold=...) and fail).
interact(visualise_img, img=fixed(load_image), tranche=(0,tranche_number-1,1));

# III/ Bone segmentation 

In [None]:
from scipy.signal import convolve2d
import cv2 as cv2
import skimage






## Function to denoise the CT scan using cv2.fastNlMeansDenoisingMulti ##
## Returns the denoised CT scan, divided by 250 (scale expected by get_segmentation_frame)
def denoise_NL(load_image,threshold,h):
    """Denoise the slices of ``load_image`` with OpenCV non-local means.

    For each ``k`` in ``range(load_image.shape[2] - 10)``:

    1. The slab ``load_image[:, :, k:k+10]`` is thresholded: values at or
       below ``threshold`` are set to 0.
    2. The slab is scaled so that its maximum becomes 250 and cast to uint8.
    3. ``cv2.fastNlMeansDenoisingMulti`` denoises the second image of the
       slab (slice ``k+1``) with a temporal window of 3.
    4. The result is scaled back to the input range, divided by 250, and
       stored in slice ``k+1`` of the output.

    Slice 0 and the last 9 slices are never written and stay 0.

    Parameters
    ----------
    load_image : ndarray
        3-D volume; the third axis indexes slices.
    threshold : float
        Intensities at or below this value are zeroed before denoising
        (assumed non-negative, as with the slider range used in this notebook).
    h : float
        Filter strength passed to OpenCV.

    Returns
    -------
    ndarray
        Float array with the same shape as ``load_image``.
    """
    denoise_matrix=np.zeros(load_image.shape)

    for k in range(load_image.shape[2]-10):
        window=load_image[:,:,k:(k+10)]
        window_max=np.max(window)
        # If nothing exceeds the threshold the NL-means input is all zeros and
        # so is its output: keep the zero slice and skip the expensive call.
        # This also avoids 250/0 (inf factor -> NaN -> undefined uint8 cast)
        # and negative scale factors.
        if window_max<=0 or not np.any(window>threshold):
            continue
        factor=250/window_max

        # We use frames k .. k+9 to denoise frame k+1
        imgs_to_denoise=np.where(window>threshold,window,0)*factor
        # (rows, cols, 10) -> (10, cols, rows): a stack of 10 images for OpenCV
        imgs_to_denoise=imgs_to_denoise.astype(np.uint8).T

        img=cv2.fastNlMeansDenoisingMulti(imgs_to_denoise,h=h,imgToDenoiseIndex=1,temporalWindowSize=3)
        denoise_matrix[:,:,k+1]=(img.astype(float)/factor/250).T
    return denoise_matrix

# Function to denoise a single frame of the CT scan
def denoise_NL_single_frame(load_image,frame,threshold,h):
    """Denoise one slice with OpenCV non-local means (same recipe as denoise_NL).

    The slab ``load_image[:, :, frame:frame+10]`` is processed as follows:

    1. Values at or below ``threshold`` are zeroed.
    2. The slab is scaled so its maximum becomes 250 and cast to uint8.
    3. The second image of the slab, i.e. slice ``frame + 1``, is denoised
       with ``cv2.fastNlMeansDenoisingMulti``.

    Parameters
    ----------
    load_image : ndarray
        3-D volume; the third axis indexes slices.
    frame : int
        First slice of the 10-slice window.
    threshold : float
        Intensities at or below this value are zeroed before denoising.
    h : float
        Filter strength passed to OpenCV.

    Returns
    -------
    ndarray
        2-D float array of shape ``load_image.shape[:2]``: the denoised slice
        ``frame + 1``, scaled back to the input range and divided by 250.
        All zeros when no value in the window exceeds ``threshold``.
    """
    window=load_image[:,:,frame:(frame+10)]
    window_max=np.max(window)
    # Nothing above threshold -> NL-means input and output are all zeros.
    # Returning early also avoids 250/0 (inf factor -> NaN cast to uint8).
    if window_max<=0 or not np.any(window>threshold):
        return np.zeros(load_image.shape[:2])
    factor=250/window_max

    # Frames frame .. frame+9 are used to denoise frame+1
    imgs_to_denoise=np.where(window>threshold,window,0)*factor
    # (rows, cols, 10) -> (10, cols, rows): a stack of 10 images for OpenCV
    imgs_to_denoise=imgs_to_denoise.astype(np.uint8).T

    img=cv2.fastNlMeansDenoisingMulti(imgs_to_denoise,h=h,imgToDenoiseIndex=1,temporalWindowSize=3)
    return (img.astype(float)/factor/250).T

        




## Get the center coordinates of a certain label ##
def get_center_label(contour_image,label):
    """Return the mean index of the pixels equal to ``label``.

    Parameters
    ----------
    contour_image : ndarray
        Label image (any number of dimensions).
    label : scalar
        Label value to locate.

    Returns
    -------
    ndarray of shape (contour_image.ndim,)
        Mean coordinate along each axis, e.g. ``(row, col)`` for a 2-D image.
        All-NaN when ``label`` does not occur. This is the same value as
        before, but without NumPy's "Mean of empty slice" RuntimeWarning.
    """
    coordonates=np.where(contour_image==label)
    if coordonates[0].size==0:
        return np.full(contour_image.ndim,np.nan)
    return np.mean(coordonates, axis=1)

# Get the contour of each bone of the CT scan at a given frame.
# Returns a label image where every external contour is filled with its own label.
def get_segmentation_frame(img_after_thresh,frame,volume_min_pixel):
    """Label the filled external contours of one slice.

    Parameters
    ----------
    img_after_thresh : ndarray
        A 3-D volume (``frame`` selects ``[:, :, frame]``) or, when ``frame``
        is None, a single 2-D slice. A pixel is foreground when its value is
        >= 1, i.e. when it truncates to a non-zero integer.
    frame : int or None
        Slice index, or None when ``img_after_thresh`` is already 2-D.
    volume_min_pixel : int
        Filled regions with fewer pixels than this are erased (set to 0).

    Returns
    -------
    ndarray of uint8
        Label image: 0 is background. External contour ``i`` (0-based, in
        ``cv2.findContours`` order) is filled with value ``i + 1``.
    """
    slice_2d=img_after_thresh if frame is None else img_after_thresh[:,:,frame]
    # Binary 0/250 image built with a comparison. `250 * x.astype(np.uint8)`
    # wraps in uint8 arithmetic (128 * 250 % 256 == 0) and can drop pixels.
    cv_image=np.where(slice_2d>=1,250,0).astype(np.uint8)

    contours, _ = cv2.findContours(cv_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Fill each external contour with its own label 1, 2, ...
    # NOTE(review): the label image is uint8, so labels above 255 cannot be represented.
    label_umat=cv2.UMat(np.zeros_like(cv_image,dtype=np.uint8))
    for index,contour in enumerate(contours):
        value=1+index
        cv2.drawContours(label_umat,[contour],-1,(value,value,value),thickness=cv2.FILLED)
    contour_image=label_umat.get()

    # Erase regions smaller than volume_min_pixel, counting every label in one pass
    pixel_counts=np.bincount(contour_image.ravel(),minlength=len(contours)+1)
    too_small=pixel_counts<volume_min_pixel
    too_small[0]=False  # background is never erased
    contour_image[too_small[contour_image]]=0
    return contour_image

In [None]:
from matplotlib.colors import LinearSegmentedColormap
## Function to visualise the segmentation obtained for given threshold / h values
def choose_tresh_h(threshold,frame,load_image,h):
    """Preview the bone segmentation of one slice for a given threshold and h.

    ``denoise_NL_single_frame`` denoises slice ``frame + 1``, the second image
    of the window starting at ``frame``. The resulting mask is therefore drawn
    over ``load_image[:, :, frame + 1]``, the slice it was computed from.

    Parameters
    ----------
    threshold : float
        Intensity threshold applied before denoising.
    frame : int
        Start of the denoising window; the displayed slice is ``frame + 1``.
    load_image : ndarray
        CT volume whose third axis indexes slices.
    h : float
        NL-means filter strength.
    """
    img_denoised=denoise_NL_single_frame(load_image=load_image,frame=frame,threshold=threshold,h=h)
    mask=get_segmentation_frame(img_denoised,frame=None,volume_min_pixel=50)

    _, ax = plt.subplots()
    yellow_cmap = LinearSegmentedColormap.from_list("yellow_cmap", [(1, 1, 0), (1, 1, 0)], N=2)
    alpha_mask = np.where(mask != 0, 0.5, 0)
    # Small regions were erased, so labels 1..max have gaps; only visit labels
    # that are present (absent ones would give NaN centres).
    for label in np.unique(mask[mask != 0]):
        row, col = get_center_label(mask, label)
        ax.scatter(col, row)

    ax.imshow(load_image[:,:,frame+1],cmap='gray')
    ax.imshow(mask,cmap=yellow_cmap,alpha=alpha_mask)
    plt.show()
    

In [None]:
# Sliders for picking threshold and h. Every slider update re-runs NL-means
# denoising, so the frame slider is kept to the first 100 slices.
interact(
    choose_tresh_h,
    load_image=fixed(load_image),
    frame=(0, 100, 1),
    threshold=(200, 300, 1),
    h=(5, 20, 1),
);

In [None]:
## Set h and threshold to the values chosen with the sliders above
threshold=275
h=20

### Denoising of the CT scan; this can take between 5 and 10 mins
masks_image=np.zeros(load_image.shape)
masks_bones=np.zeros(load_image.shape)  # NOTE(review): not used anywhere else in the visible notebook
# Use the chosen h (the call previously hard-coded h=20)
img_after_thresh=denoise_NL(load_image=load_image,threshold=threshold,h=h)

print("Just finish denoising")

# Segment every slice of the denoised volume; regions under 50 pixels are dropped
for k in range(load_image.shape[2]):
    masks_image[:,:,k]=get_segmentation_frame(img_after_thresh,k,50)
   

In [None]:
from matplotlib.colors import LinearSegmentedColormap
# Function to visualise the segmentation of the whole CT scan, one frame at a time
def visu_frame_mask(load_image,mask,frame):
    """Overlay the segmentation ``mask[:, :, frame]`` on CT slice ``frame``.

    Non-zero labels are painted yellow at alpha 0.5. Each label present in the
    slice gets a scatter marker at its centre.
    """
    _, ax = plt.subplots()
    yellow_cmap = LinearSegmentedColormap.from_list("yellow_cmap", [(1, 1, 0), (1, 1, 0)], N=2)
    frame_mask = mask[:, :, frame]
    alpha_mask = np.where(frame_mask != 0, 0.5, 0)

    # Labels need not be consecutive (small regions are erased), so only visit
    # those present; absent labels would produce NaN centres and warnings.
    for label in np.unique(frame_mask[frame_mask != 0]):
        row, col = get_center_label(frame_mask, label)
        ax.scatter(col, row)

    ax.imshow(load_image[:,:,frame],cmap='gray')
    ax.imshow(frame_mask,cmap=yellow_cmap,alpha=alpha_mask)
    plt.show()

In [None]:
# Visualisation of the segmentation on the CT scan (last valid slice index is shape[2] - 1)
interact(visu_frame_mask, load_image=fixed(load_image),mask=fixed(masks_image) ,frame=(1,load_image.shape[2]-1,1));

# IV/ Tracking 


In [None]:
## Creation of function that will be used
from skimage.segmentation import active_contour
from skimage import io, color, draw 
def _refine_masks_with_snakes(previous_clusters,masks_previous_frame,load_image_current_frame):
    """Shared implementation of get_snake_contours and defusion.

    For every ``(label, x, y)`` in ``previous_clusters``:

    1. The external contour of ``label`` in ``masks_previous_frame`` is
       used as the initial snake.
    2. ``skimage.segmentation.active_contour`` evolves the snake on
       ``load_image_current_frame``.
    3. The converged polygon is filled with ``label`` in a new mask.

    Returns ``(masks_current_frame, label_assignement)``. The assignment maps
    every drawn label to itself.
    """
    # Keep the previous mask's dtype: tracking labels are unbounded counters
    # and would wrap around in a uint8 mask.
    masks_current_frame=np.zeros(masks_previous_frame.shape,dtype=masks_previous_frame.dtype)
    label_assignement={}

    for label,_,_ in previous_clusters:
        binary_imged=np.where(masks_previous_frame==label,250,0).astype(np.uint8)
        contours, _ = cv2.findContours(binary_imged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue  # label absent from the previous frame: nothing to propagate
        # NOTE(review): if the label is split into several pieces, only the
        # first external contour returned by OpenCV is propagated.
        contour=contours[0]
        # OpenCV contours are (x, y) = (col, row); active_contour works in (row, col)
        init_rc=contour.reshape(-1,2)[:,::-1].astype(float)
        snake=active_contour(load_image_current_frame,init_rc,alpha=0.002,beta=0.1,gamma=1)

        rr, cc = draw.polygon(snake[:,0], snake[:,1], masks_current_frame.shape)
        masks_current_frame[rr,cc]=label
        label_assignement[label]=label

    return masks_current_frame,label_assignement


def get_snake_contours(previous_clusters,masks_previous_frame,load_image_current_frame):
    """Propagate each previous-frame label onto the current frame with an active contour.

    Returns ``(masks_current_frame, label_assignement)``; see
    ``_refine_masks_with_snakes``.
    """
    return _refine_masks_with_snakes(previous_clusters,masks_previous_frame,load_image_current_frame)


def defusion(previous_clusters,masks_previous_frame,load_image_current_frame):
    """Re-segment the current frame from the previous contours after a fusion.

    Used by get_cluster_asignement_no_fusion when one current region overlaps
    several previous ones. Returns ``(masks_current_frame, label_assignement)``
    as described in ``_refine_masks_with_snakes``.
    """
    return _refine_masks_with_snakes(previous_clusters,masks_previous_frame,load_image_current_frame)

def get_volumes(masks,label,frame):
    """Return the number of pixels equal to ``label`` in slice ``masks[:, :, frame]``."""
    frame_slice = masks[:, :, frame]
    is_label = frame_slice == label
    return is_label.sum()

def get_cluster_asignement(previous_clusters,current_clusters,masks_previous_frame,masks_current_frame,distance_max):
    """Match each current-frame region to a previous-frame region by overlap.

    Parameters
    ----------
    previous_clusters, current_clusters : list of (label, x, y)
        Region labels with their centre coordinates.
    masks_previous_frame, masks_current_frame : ndarray
        Label images of the previous and the current slice.
    distance_max : float
        Only previous regions whose centre is strictly closer than this are
        considered.

    Returns
    -------
    dict
        ``{current_label: previous_label or None}``. The chosen previous label
        is the first one reaching the largest overlap fraction
        ``|current & previous| / |current|``. The value is None when no
        candidate overlaps.

    Prints ``"detected one fusion"`` when a current region overlaps more than
    20 % of its area with more than one previous region.
    """
    label_assignement={}
    for label,x,y in current_clusters:
        label_assignement[label]=None
        # Invariant over the inner loop: the current region and its pixel count
        in_current=masks_current_frame==label
        current_area=np.sum(in_current)
        best_overlap=0
        strong_matches=0
        for label_previous,x_previous,y_previous in previous_clusters:
            distance=np.sqrt((x-x_previous)**2+(y-y_previous)**2)
            if not distance<distance_max:
                continue
            overlap=np.sum(in_current&(masks_previous_frame==label_previous))/current_area
            if overlap>0.2:
                strong_matches+=1
            if best_overlap<overlap:
                best_overlap=overlap
                label_assignement[label]=label_previous
        if strong_matches>1:
            print("detected one fusion")
    return label_assignement


def get_cluster_asignement_no_fusion(previous_clusters,current_clusters,masks_previous_frame,masks_current_frame,load_image_previous_frame,load_image_current_frame,distance_max):
    """Like get_cluster_asignement, but re-segments the slice when a fusion occurs.

    Each current region is matched to the previous region with the largest
    overlap fraction ``|current & previous| / |current|``, among previous
    regions whose centre is strictly closer than ``distance_max``.

    Fusion handling: a fusion is a current region overlapping more than 20 %
    of its area with more than one previous region. On the first fusion
    found, the function:

    - prints ``"Using defusion"``;
    - calls ``defusion`` on the previous clusters and previous mask;
    - returns the mask and identity assignment that ``defusion`` produces.

    ``load_image_previous_frame`` is not used.

    Returns
    -------
    (dict, ndarray)
        ``(label_assignement, masks_current_frame)``: the label mapping
        (values None when unmatched) and the mask the mapping refers to.
    """
    label_assignement={}
    for label,x,y in current_clusters:
        label_assignement[label]=None
        # Invariant over the inner loop: the current region and its pixel count
        in_current=masks_current_frame==label
        current_area=np.sum(in_current)
        best_overlap=0
        strong_matches=0
        for label_previous,x_previous,y_previous in previous_clusters:
            distance=np.sqrt((x-x_previous)**2+(y-y_previous)**2)
            if not distance<distance_max:
                continue
            overlap=np.sum(in_current&(masks_previous_frame==label_previous))/current_area
            if overlap>0.2:
                strong_matches+=1
            if best_overlap<overlap:
                best_overlap=overlap
                label_assignement[label]=label_previous
        if strong_matches>1:
            print("Using defusion")
            masks_current_frame,label_assignement=defusion(previous_clusters,masks_previous_frame,load_image_current_frame)
            break
    return label_assignement, masks_current_frame

                                       

              

#get_snake_contours(previous_clusters,masks_previous_frame,load_image_current_frame)
        


In [None]:
# Tracking using naive association
def tracking(masks):
    """Propagate region labels slice by slice by overlap with the tracked previous slice.

    How the output is built:

    - Slice 1 of ``masks`` is copied as-is.
    - For each later slice, every region of ``masks[:, :, frame]`` takes the
      label chosen by get_cluster_asignement against the already-tracked
      slice ``frame - 1`` (distance_max=300).
    - A region with no match gets a fresh label from an increasing counter.
    - Slice 0 is left at 0.

    Prints one progress line per slice.

    Parameters
    ----------
    masks : ndarray
        3-D label volume (third axis = slices), e.g. from get_segmentation_frame.

    Returns
    -------
    ndarray
        Float label volume of the same shape with tracked labels.
    """
    masks_tracking=np.zeros(masks.shape)
    masks_tracking[:,:,1]=masks[:,:,1]
    labels=np.unique(masks[:,:,1][masks[:,:,1]!=0])
    # Counter for fresh labels; an empty first slice must not crash np.max
    max_label=np.max(labels)+1 if labels.size else 1

    for frame in range(2,masks.shape[2]):
        print(f"Iteration {frame} on {masks.shape[2]}")
        tracked_previous=masks_tracking[:,:,frame-1]
        current=masks[:,:,frame]

        # (label, x, y) with x, y the center of the label
        previous_clusters=[(label,*get_center_label(tracked_previous,label))
                           for label in np.unique(tracked_previous[tracked_previous!=0])]
        current_clusters=[(label,*get_center_label(current,label))
                          for label in np.unique(current[current!=0])]

        label_assignement=get_cluster_asignement(previous_clusters,current_clusters,tracked_previous,current,distance_max=300)

        for current_label,assigned_label in label_assignement.items():
            if assigned_label is None:
                max_label+=1
                assigned_label=max_label
            masks_tracking[:,:,frame][current==current_label]=assigned_label
    return masks_tracking


# Tracking that tries to correct errors of the naive algorithm: when a region of
# the current slice overlaps several tracked regions (a "fusion"), the slice is
# re-segmented from the previous contours with the snake algorithm (defusion).
def tracking_no_fusion(masks,load_image_denoised):
    """Slice-by-slice label tracking with snake-based correction of fused regions.

    Same scheme as ``tracking``, but matching uses
    get_cluster_asignement_no_fusion. When that function detects a fusion, it
    returns a re-segmented mask together with the matching labels, and those
    are written instead. Slice 1 is copied from ``masks``; slice 0 is left at 0.
    Prints one progress line per slice.

    Parameters
    ----------
    masks : ndarray
        3-D label volume (third axis = slices).
    load_image_denoised : ndarray
        Volume of the same shape; its slices are passed to the matcher and
        used by the snake refinement.

    Returns
    -------
    ndarray
        Float label volume of the same shape with tracked labels.
    """
    masks_tracking=np.zeros(masks.shape)
    masks_tracking[:,:,1]=masks[:,:,1]
    labels=np.unique(masks[:,:,1][masks[:,:,1]!=0])
    # Counter for fresh labels; an empty first slice must not crash np.max
    max_label=np.max(labels)+1 if labels.size else 1

    for frame in range(2,masks.shape[2]):
        print(f"Iteration {frame} on {masks.shape[2]}")
        tracked_previous=masks_tracking[:,:,frame-1]
        current=masks[:,:,frame]

        # (label, x, y) with x, y the center of the label
        previous_clusters=[(label,*get_center_label(tracked_previous,label))
                           for label in np.unique(tracked_previous[tracked_previous!=0])]
        current_clusters=[(label,*get_center_label(current,label))
                          for label in np.unique(current[current!=0])]

        label_assignement,mask_current_frame=get_cluster_asignement_no_fusion(
            previous_clusters,current_clusters,tracked_previous,current,
            load_image_denoised[:,:,frame-1],load_image_denoised[:,:,frame],distance_max=300)

        # mask_current_frame is either `current` or the defusion result; the
        # assignment keys refer to whichever was returned.
        for current_label,assigned_label in label_assignement.items():
            if assigned_label is None:
                max_label+=1
                assigned_label=max_label
            masks_tracking[:,:,frame][mask_current_frame==current_label]=assigned_label
    return masks_tracking


# Track bone labels through the volume, splitting fused regions with active contours
masks_tracking = tracking_no_fusion(
    masks=masks_image,
    load_image_denoised=img_after_thresh,
)





In [None]:
def visu_tracking(load_image,mask,frame):
    """Overlay the tracked labels of slice ``frame`` on the CT slice.

    - Every label present in ``mask[:, :, frame]`` gets a scatter marker at
      its centre.
    - The label image is drawn with alpha 0.5 where non-zero.
    - The colour map has 6 colours sampled from matplotlib's ``tab20c``.
    """
    _, ax = plt.subplots()
    num_labels = 6  # number of colours in the overlay colour map
    colors = plt.cm.tab20c(np.linspace(0, 1, num_labels))  # sampled from the 'tab20c' colormap
    custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=num_labels)

    frame_mask = mask[:, :, frame]
    alpha_mask = np.where(frame_mask != 0, 0.5, 0)

    # Tracking labels come from an increasing counter and are not consecutive:
    # visit only the labels present instead of every integer up to the max.
    for label in np.unique(frame_mask[frame_mask != 0]):
        row, col = get_center_label(frame_mask, label)
        ax.scatter(col, row)

    ax.imshow(load_image[:,:,frame],cmap='gray')
    ax.imshow(frame_mask,cmap=custom_cmap,alpha=alpha_mask)
    plt.show()



In [None]:
# Frame slider spans every slice of the tracked volume (0 .. shape[2] - 1)
interact(visu_tracking, load_image=fixed(load_image),mask=fixed(masks_tracking) ,frame=(0,masks_tracking.shape[2]-1,1));

In [None]:
def save_mask(mask,filename,affine=None,header=None):
    """Write ``mask`` to ``filename`` as a NIfTI-1 image.

    Parameters
    ----------
    mask : ndarray
        Label volume to save.
    filename : str
        Destination path (e.g. ``*.nii.gz``).
    affine : (4, 4) array-like, optional
        Voxel-to-world transform; defaults to the identity matrix. Pass the
        source scan's affine to keep its geometry.
    header : nibabel header, optional
        Forwarded to ``nib.Nifti1Image``.
    """
    if affine is None:
        affine=np.eye(4)
    nifti_img = nib.Nifti1Image(mask, affine, header=header)
    nib.save(nifti_img, filename)
# Swap axes 1 and 2 back (undoing the swap applied to load_image at load time)
# into a new name. masks_tracking itself is not rebound, so re-running this
# cell cannot flip the axes twice.
masks_tracking_to_save=masks_tracking.swapaxes(1,2)
save_mask(masks_tracking_to_save,'./masks_from_python/hand_2_bis_new.nii.gz')

In [None]:
def convert_mask_header(original_file,to_convert_file,output_file="./masks_from_python/hand_2_bis_hey.nii.gz"):
    """Re-save a mask with the geometry (affine and header) of its source CT scan.

    Parameters
    ----------
    original_file : int or str
        CT scan number; the reference image is
        ``train-images/CT_<original_file>.nii.gz``.
    to_convert_file : str
        Path of the NIfTI mask to convert.
    output_file : str, optional
        Destination path; defaults to
        ``./masks_from_python/hand_2_bis_hey.nii.gz``.
    """
    original=nib.load(f"train-images/CT_{original_file}.nii.gz")
    created_mask = nib.load(to_convert_file).get_fdata()
    # Pass the CT's own affine. nibabel rewrites the header's sform/qform from
    # the affine argument, so np.eye(4) would discard the scan geometry.
    nifti_img = nib.Nifti1Image(created_mask, original.affine, header=original.header)
    nib.save(nifti_img, output_file)




# NOTE(review): this converts 'hand_2_bis.nii.gz', while the save cell above
# writes 'hand_2_bis_new.nii.gz' -- confirm which file is intended.
convert_mask_header(
    original_file=2,
    to_convert_file='./masks_from_python/hand_2_bis.nii.gz',
)