In [None]:
import numpy as np # extra
import os # extra
import glob
import cv2 as cv
import matplotlib.pyplot as plt # extra
from pathlib import Path
from typing import Tuple
from woundcompute import image_analysis as ia
from woundcompute import segmentation as seg
from woundcompute import compute_values as com
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, Matern, RationalQuadratic
from woundcompute import post_process as pp
from skimage.measure import regionprops
from skimage.io import imread
from sklearn.metrics import r2_score
import re


def find_yaml_folders(base_path):
    """Return every directory under ``base_path`` that directly holds a ``.yaml`` file.

    The tree is walked with ``os.walk`` (``base_path`` itself included). The result
    is a duplicate-free list with no guaranteed order; callers sort it if they need
    stable indices.
    """
    matching_dirs = {
        dir_path
        for dir_path, _sub_dirs, file_names in os.walk(base_path)
        if any(name.endswith('.yaml') for name in file_names)
    }
    return list(matching_dirs)

# Data roots used in earlier runs — uncomment one (and comment out the active one) to switch datasets
# base_path1 = '/projectnb/lejlab2/quan/wound_healing/wound_compute_env/woundcompute/quan_playground/data/Ellie'
# base_path2 = '/projectnb/lejlab2/quan/wound_healing/wound_compute_env/woundcompute/quan_playground/data/Varun'
# base_path3 = '/projectnb/lejlab2/quan/wound_healing/wound_compute_env/woundcompute/quan_playground/data/Filip'
# base_path = '/projectnb/lejlab2/quan/wound_healing/wound_compute_env/woundcompute/quan_playground/data/20240909_R7P1_Viability/Sorted2/tissue_ai'
# NOTE(review): absolute cluster path — edit to point at your own data location
base_path = '/projectnb/lejlab2/quan/wound_healing/wound_compute_env/woundcompute/quan_playground/data/20240913_R7P2_Viability2/Sorted2/tissue_ai1'

# every folder under base_path holding a .yaml config; sorted so sample indices are stable between runs
yaml_folders = sorted(find_yaml_folders(base_path))

In [None]:
# functions

def load_raw_img_list(folders_list,sample_ind):
    """Load the ph1 image stack plus the per-sample helper functions for one sample.

    Args:
        folders_list: list of sample folder paths (each expected to hold the sample's .yaml config).
        sample_ind: index into ``folders_list`` selecting the sample.

    Returns:
        (img_list, zoom_fcn, thresh_fcn, file_name):
            img_list   -- result of ``ia.read_all_tiff`` on the config's "ph1_images_path"
            zoom_fcn   -- function chosen by ``com.select_zoom_function`` from the config
            thresh_fcn -- function chosen by ``seg.select_threshold_function`` from the config
            file_name  -- the last 8 characters of the folder path, used as a short sample label
    """
    sample_folder = folders_list[sample_ind]
    sample_label = sample_folder[-8:]
    config_dict, in_paths, _out_paths = ia.input_info_to_dicts(Path(sample_folder))
    zoom_fcn = com.select_zoom_function(config_dict)
    ph1_dir = in_paths["ph1_images_path"]
    thresh_fcn = seg.select_threshold_function(config_dict, False, False, True)
    ph1_images = ia.read_all_tiff(ph1_dir)
    return ph1_images, zoom_fcn, thresh_fcn, sample_label

def _natural_sort_key(path):
    """Sort key splitting a path into text and integer chunks so numbers compare numerically."""
    # re.split with a capturing group alternates text (even positions) and digit runs (odd positions)
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(re.split(r'(\d+)', path))]

def load_pillar_masks_list(folders_list,sample_ind):
    """Load the saved per-pillar masks (``segment_ph1/pillar_*.npy``) of one sample.

    The mask files are ordered with a natural (numeric-aware) sort, so ``pillar_10``
    comes after ``pillar_9``. For single-digit pillar indices this is identical to a
    plain string sort. The order matters because mask ``i`` is later paired with
    column ``i`` of the pillar tracker arrays.

    Args:
        folders_list: list of sample folder paths.
        sample_ind: index into ``folders_list`` selecting the sample.

    Returns:
        List of arrays loaded with ``np.load``, one per matching file; empty if none match.
    """
    segmented_folder = os.path.join(folders_list[sample_ind], 'segment_ph1')
    # escape the directory part so '[', ']', '*', '?' in folder names are matched literally
    pillar_pattern = os.path.join(glob.escape(segmented_folder), 'pillar_*.npy')
    pillar_paths = sorted(glob.glob(pillar_pattern), key=_natural_sort_key)
    # NOTE(review): allow_pickle=True lets a crafted .npy execute code on load; kept for
    # compatibility with masks saved as object arrays — only load trusted pipeline output.
    return [np.load(pil_path,allow_pickle=True) for pil_path in pillar_paths]

def load_pillars_locations(folders_list,sample_ind):
    """Load the tracked pillar positions of one sample.

    Reads ``track_pillars_ph1/pillar_tracker_x.txt`` and ``pillar_tracker_y.txt``
    from the sample folder. Downstream code indexes the results as
    ``[frame, pillar]`` (e.g. ``pos[0, :]`` is the first frame).

    Args:
        folders_list: list of sample folder paths.
        sample_ind: index into ``folders_list`` selecting the sample.

    Returns:
        (x_positions, y_positions): two arrays that are always at least 2D.
    """
    pillars_loc_folder = os.path.join(folders_list[sample_ind], 'track_pillars_ph1')
    # ndmin=2 keeps a single-row (single-frame) file as shape (1, n_pillars) instead of
    # letting loadtxt squeeze it to 1D, which would break the [frame, pillar] indexing
    pillars_x_locations = np.loadtxt(os.path.join(pillars_loc_folder,'pillar_tracker_x.txt'), ndmin=2)
    pillars_y_locations = np.loadtxt(os.path.join(pillars_loc_folder,'pillar_tracker_y.txt'), ndmin=2)
    return pillars_x_locations,pillars_y_locations

def extract_mask_image(binary_image):
    """Crop ``binary_image`` to the bounding box of its largest external contour.

    Contours come from ``cv.findContours`` with ``RETR_EXTERNAL`` and
    ``CHAIN_APPROX_SIMPLE``; "largest" means largest ``cv.contourArea``.

    Raises:
        ValueError: if no contour is found.
    """
    found_contours, _hierarchy = cv.findContours(binary_image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
    if not found_contours:
        raise ValueError("No contours found in the binary image.")
    biggest_contour = max(found_contours, key=cv.contourArea)
    col0, row0, box_w, box_h = cv.boundingRect(biggest_contour)
    return binary_image[row0:row0 + box_h, col0:col0 + box_w]

def determine_quadrant(point:np.ndarray,img_w:float,img_h:float)->int:
    """Classify an (x, y) point into one of four image quadrants split at the image center.

    Numbering (smaller y is "top"): 0 = top-right, 1 = top-left, 2 = bottom-left,
    3 = bottom-right. A point on the vertical center line counts as right, and a
    point on the horizontal center line counts as bottom. Anything matching none
    of the first three cases (including NaN coordinates) falls through to 3.
    """
    px, py = point
    half_w = img_w / 2
    half_h = img_h / 2
    on_left = px < half_w
    on_right = px >= half_w
    in_top = py < half_h
    in_bottom = py >= half_h

    if in_top:
        if on_right:
            return 0  # top-right
        if on_left:
            return 1  # top-left
    elif on_left and in_bottom:
        return 2  # bottom-left
    return 3  # bottom-right (and fall-through)

def rearrange_pillars_indexing(pillar_masks,pillars_x_pos:np.ndarray,pillars_y_pos:np.ndarray):
    """Reorder pillar masks and tracked positions so that pillar index == image quadrant.

    After reordering: p0 is top-right, p1 top-left, p2 bottom-left, p3 bottom-right
    (numbering from ``determine_quadrant``). The assignment is decided from each
    pillar's position in the first tracked frame (row 0).

    Args:
        pillar_masks: sequence of 2D masks, mask ``i`` belonging to column ``i`` of
            the position arrays; the image size is read from the first mask.
        pillars_x_pos: 2D array indexed ``[frame, pillar]`` of pillar x positions.
        pillars_y_pos: 2D array indexed ``[frame, pillar]`` of pillar y positions.

    Returns:
        (rearranged_pillar_masks, rearranged_px_pos, rearranged_py_pos)

    Raises:
        ValueError: if the frame-0 pillar centers do not occupy each of the four
            quadrants exactly once (previously this silently overwrote columns and
            left one column all zeros).
    """
    img_h,img_w = pillar_masks[0].shape
    # frame-0 (x, y) center of every pillar, one row per original pillar index
    pillar_centers = np.column_stack((pillars_x_pos[0,:], pillars_y_pos[0,:]))
    quad_indices = [determine_quadrant(pc,img_w,img_h) for pc in pillar_centers]

    # the reordering is only meaningful when the pillars form a permutation of the quadrants
    if sorted(quad_indices) != [0, 1, 2, 3]:
        raise ValueError(
            "expected exactly one pillar in each of the 4 image quadrants at frame 0, "
            f"got quadrant assignment {quad_indices}"
        )

    # new_order[q] is the original index of the pillar lying in quadrant q
    new_order = np.argsort(np.array(quad_indices))
    rearranged_px_pos = pillars_x_pos[:, new_order]
    rearranged_py_pos = pillars_y_pos[:, new_order]
    rearranged_pillar_masks = [pillar_masks[og_ind] for og_ind in new_order]
    return rearranged_pillar_masks,rearranged_px_pos,rearranged_py_pos

def smooth_with_GPR(s: np.ndarray, kernel=None) -> Tuple[np.ndarray, object]:
    """Smooth a per-frame signal with Gaussian-process regression over the frame index.

    The GP is fit on the non-NaN entries of ``s`` (x = frame index 0..n-1) and then
    evaluated at every frame, so NaN gaps are filled with the posterior mean.

    Args:
        s: 1D signal, one value per frame; NaN entries are treated as missing.
        kernel: optional scikit-learn kernel. ``None`` (default) uses
            ``1.0 * Matern(nu=2.5) + 1.0 * WhiteKernel``, exactly as before.
            Alternatives tried earlier:
            1.0 * RBF(length_scale=1.0, length_scale_bounds=(1e-5, 1e1)) + 1.0 * WhiteKernel(noise_level=1.0, noise_level_bounds=(1e-5, 1e1))
            1.0 * RationalQuadratic(length_scale=1.0, alpha=1.0, length_scale_bounds=(1e-5, 1e1), alpha_bounds=(1e-05, 100000.0)) + 1.0 * WhiteKernel(noise_level=1.0, noise_level_bounds=(1e-5,1e1))

    Returns:
        (yhat, posterior_kernel): the posterior mean at every frame as a 1D array,
        and ``model.kernel_`` (the kernel with fitted hyperparameters).

    Raises:
        ValueError: if every entry of ``s`` is NaN (nothing to fit).
    """
    if kernel is None:
        # length_scale is in frame units; its upper bound caps it at 10 frames
        kernel = 1.0 * Matern(length_scale=1.0,length_scale_bounds=(1e-5,1e1),nu=2.5) + 1.0 * WhiteKernel(noise_level=1.0, noise_level_bounds=(1e-5,1e1))
    s = np.asarray(s)
    num_frames = s.shape[0]
    xhat = np.arange(num_frames).reshape(-1, 1)
    ydata = s.reshape(-1, 1)
    observed = ~np.isnan(ydata.reshape(-1))
    if not observed.any():
        raise ValueError("smooth_with_GPR: input contains only NaN values; nothing to fit.")

    model = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
    model.fit(xhat[observed], ydata[observed])
    yhat = model.predict(xhat)
    return yhat.reshape(-1), model.kernel_

def compute_relative_pillars_dist(all_pillars_x_pos:np.ndarray,all_pillars_y_pos:np.ndarray):
    """Euclidean distance between pillar pairs for every frame.

    Positions are indexed ``[frame, pillar]`` and must have at least 4 pillar columns.

    Returns:
        Tuple of six 1D arrays (one value per frame), in this order:
        p0-p1, p2-p3, p1-p2, p0-p3, p1-p3, p0-p2.
    """
    def pair_distance(first, second):
        # per-frame straight-line distance between pillar columns `first` and `second`
        delta_x = all_pillars_x_pos[:, second] - all_pillars_x_pos[:, first]
        delta_y = all_pillars_y_pos[:, second] - all_pillars_y_pos[:, first]
        return np.sqrt(delta_x**2 + delta_y**2)

    pillar_pairs = ((0, 1), (3, 2), (1, 2), (0, 3), (3, 1), (0, 2))
    return tuple(pair_distance(first, second) for first, second in pillar_pairs)

In [None]:
# load in all samples + pillar masks + pillar locations
# (everything is kept in memory at once)

NUM_SAMPLES = 31                       # number of sorted yaml_folders entries to consider
EXCLUDED_SAMPLE_INDS = {0, 3, 5, 23}   # yaml_folders indices skipped on purpose — TODO: document why

if len(yaml_folders) < NUM_SAMPLES:
    raise ValueError(
        f'expected at least {NUM_SAMPLES} sample folders under base_path, found {len(yaml_folders)}'
    )

img_list_all=[]
all_p_x_loc = [] # pillar x locations of all samples
all_p_y_loc = [] # pillar y locations of all samples
file_names_all=[]
all_pillar_masks_lists=[] # pillar masks of all samples
loaded_folder_inds=[] # yaml_folders index of each loaded sample (list position != folder index because of the skips)
for file_ind in range(NUM_SAMPLES):
    if file_ind in EXCLUDED_SAMPLE_INDS:
        continue
    img_list,zoom_fcn,thresh_fcn,fn = load_raw_img_list(yaml_folders,file_ind)
    pillars_x_locations,pillars_y_locations=load_pillars_locations(yaml_folders,file_ind)
    pillars_mask_list=load_pillar_masks_list(yaml_folders,file_ind)

    # append only after every load succeeded so all per-sample lists stay aligned
    loaded_folder_inds.append(file_ind)
    file_names_all.append(fn)
    img_list_all.append(img_list)
    all_p_x_loc.append(pillars_x_locations)
    all_p_y_loc.append(pillars_y_locations)
    all_pillar_masks_lists.append(pillars_mask_list)

In [None]:
samp_ind = 14 # sample to inspect: a position in the *loaded* lists (skipped folders are not counted)

# gather the data of the sample in question
print(f'file_names={file_names_all[samp_ind]}')
samp_im_list, samp_px_loc, samp_py_loc, samp_pil_masks = (
    img_list_all[samp_ind],
    all_p_x_loc[samp_ind],
    all_p_y_loc[samp_ind],
    all_pillar_masks_lists[samp_ind],
)

# reorder the pillars so that p0 = top right, p1 = top left, p2 = bottom left, p3 = bottom right
samp_pil_masks,samp_px_loc,samp_py_loc = rearrange_pillars_indexing(samp_pil_masks,samp_px_loc,samp_py_loc)

In [None]:
# this block of code exists with the purpose to confirm the pillar indexing via visualization
# if you don't run this, it's fine

# extract the "template" of the pillars using the pillar masks + frame
# template: the part of the image that is the pillar
pillar_masks_only = []
pillar_templates_shape = []
for pm in samp_pil_masks:
    pm_only = extract_mask_image(pm.astype(np.uint8))
    pillar_masks_only.append(pm_only)
    pillar_templates_shape.append(pm_only.shape)  # numpy order: (rows, cols) == (height, width)

# crop a template-sized window around each tracked pillar center, for every frame
templates_all_frames = []
for img_ind, cur_frame in enumerate(samp_im_list):
    frame_h, frame_w = cur_frame.shape[:2]
    cur_pil_x_pos = samp_px_loc[img_ind,:]
    cur_pil_y_pos = samp_py_loc[img_ind,:]

    templates_cur_frame = []
    for pil_ind, (center_x, center_y) in enumerate(zip(cur_pil_x_pos, cur_pil_y_pos)):
        # height spans y (rows) and width spans x (columns); the original code had these swapped
        y_dist, x_dist = pillar_templates_shape[pil_ind]

        # clamp to the frame: a negative start index would make numpy count from the far edge
        left = max(int(center_x - x_dist/2), 0)
        right = min(int(center_x + x_dist/2), frame_w)
        top = max(int(center_y - y_dist/2), 0)
        bot = min(int(center_y + y_dist/2), frame_h)

        templates_cur_frame.append(cur_frame[top:bot,left:right])

    templates_all_frames.append(templates_cur_frame)

In [None]:
# Optional visual check of the pillar indexing (safe to skip): shows frame 0, the chosen
# frame, and the chosen frame's four pillar crops, so the p0..p3 quadrant order can be checked by eye.

frame_ind = 0

cur_frame_templates = templates_all_frames[frame_ind]
frame0_img = samp_im_list[0]

fig,ax = plt.subplots(nrows=3,ncols=2,figsize=(7,10))

# (axes, image, title) per panel: full images on the top row, pillar crops filled row by row below
panels = [
    (ax[0,0], frame0_img, 'frame0 img'),
    (ax[0,1], samp_im_list[frame_ind], f'frame{frame_ind} image'),
]
panels += [(ax[1 + k // 2, k % 2], cur_frame_templates[k], f'pillar {k}') for k in range(4)]

for panel_ax, panel_img, panel_title in panels:
    panel_ax.imshow(panel_img,cmap='gist_gray')
    panel_ax.axis('off')
    panel_ax.set_title(panel_title)

plt.suptitle(f'sample: {file_names_all[samp_ind]}')
plt.tight_layout()

## Relative pillar positions (pairwise distances between pillars over time)

In [None]:
# pairwise distances between the quadrant-ordered pillars: one array per pillar pair, one value per frame
rel_dist_by_pair = compute_relative_pillars_dist(samp_px_loc,samp_py_loc)
p0p1_dist,p2p3_dist,p1p2_dist,p0p3_dist,p1p3_dist,p0p2_dist = rel_dist_by_pair

cur_rel_dist = list(rel_dist_by_pair)

In [None]:
# plot the raw relative distance between pillars per frame (dots) and the GPR-smoothed curve (red);
# the fitted GP kernel of each pair is printed
fig,ax = plt.subplots(ncols=2,nrows=3,figsize=(10,12))

title_names=['p0p1 distance','p2p3 distance','p1p2 distance','p0p3 distance','p1p3 distance','p0p2 distance']
for rd_ind,rel_pillars_dist in enumerate(cur_rel_dist):
    frame_ind_steps = np.arange(len(rel_pillars_dist))

    smoothed_dist,post_kern = smooth_with_GPR(rel_pillars_dist)
    frame_ind_smoothed = np.arange(len(smoothed_dist))
    print(post_kern)

    # 3x2 grid: consecutive pairs share a row, even index -> left column, odd -> right column
    row_ind, col_ind = divmod(rd_ind, 2)
    cur_ax = ax[row_ind, col_ind]
    cur_ax.scatter(frame_ind_steps,rel_pillars_dist,s=8,label='raw')
    cur_ax.plot(frame_ind_smoothed,smoothed_dist,c='red',label='GPR')
    cur_ax.set_title(title_names[rd_ind])
    cur_ax.set_xlabel('frame index')
    cur_ax.set_ylabel('distance')
    cur_ax.grid(True,ls=':')
    cur_ax.legend()  # the labels above were previously set but never shown

plt.suptitle(f'sample: {file_names_all[samp_ind]}')
plt.tight_layout()

## (Not used) Smooth pillar positions first, then compute relative distances — Emma is fine with skipping this step

In [None]:
# sm_px_loc_all = np.zeros_like(samp_px_loc)
# sm_py_loc_all = np.zeros_like(samp_py_loc)
# for ind in range(samp_px_loc.shape[1]):
#     cur_x_pos = samp_px_loc[:,ind]
#     cur_y_pos = samp_py_loc[:,ind]

#     sm_x_pos,_=smooth_with_GPR(cur_x_pos)
#     sm_y_pos,_=smooth_with_GPR(cur_y_pos)

#     sm_px_loc_all[:,ind] = sm_x_pos
#     sm_py_loc_all[:,ind] = sm_y_pos

# p0p1_sm_dist,p2p3_sm_dist,p1p2_sm_dist,p0p3_sm_dist,p1p3_sm_dist,p0p2_sm_dist = compute_relative_pillars_dist(sm_px_loc_all,sm_py_loc_all)

# cur_rel_sm_dist = [p0p1_sm_dist,p2p3_sm_dist,p1p2_sm_dist,p0p3_sm_dist,p1p3_sm_dist,p0p2_sm_dist]

# fig,ax = plt.subplots(ncols=2,nrows=3,figsize=(10,12))

# title_names=['p0p1 dist w/ smoothed pos','p2p3 dist w/ smoothed pos','p1p2 dist w/ smoothed pos',\
#              'p0p3 dist w/ smoothed pos','p1p3 dist w/ smoothed pos','p0p2 dist w/ smoothed pos']
# for rd_sm_ind,rel_pillars_sm_dist in enumerate(cur_rel_sm_dist):
#     frame_ind_steps = np.linspace(0,len(rel_pillars_sm_dist)-1,len(rel_pillars_sm_dist),dtype=int)

#     smoothed_dist_sm,post_kern = smooth_with_GPR(rel_pillars_sm_dist)
#     frame_ind_smoothed = np.linspace(0,len(smoothed_dist_sm)-1,len(smoothed_dist_sm),dtype=int)
#     print(post_kern)
    
#     if rd_sm_ind<=1:
#         x_ind = 0
#     elif rd_sm_ind>1 and rd_sm_ind<=3:
#         x_ind = 1
#     else:
#         x_ind = 2
    
#     if rd_sm_ind%2==0:
#         ax[x_ind,0].scatter(frame_ind_steps,rel_pillars_sm_dist,s=8)
#         ax[x_ind,0].set_title(title_names[rd_sm_ind])
#         ax[x_ind,0].plot(frame_ind_smoothed,smoothed_dist_sm,c='red',label='GPR')
#         ax[x_ind,0].grid('on',ls=':')
#     else:
#         ax[x_ind,1].scatter(frame_ind_steps,rel_pillars_sm_dist,s=8)
#         ax[x_ind,1].set_title(title_names[rd_sm_ind])
#         ax[x_ind,1].plot(frame_ind_smoothed,smoothed_dist_sm,c='red',label='GPR')
#         ax[x_ind,1].grid('on',ls=':')