In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
import os
import shapely
import math
# import geopandas as gpd

import skimage as ski
from skimage.morphology import dilation, square
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from skimage.transform import rescale, resize, downscale_local_mean
from skimage.restoration import inpaint

In [2]:
# Label source to process; the path setup in the next cell accepts "labeled" or "pred".
run = "pred"

In [3]:
# Input/output locations for the selected `run`, built with os.path.join so the
# notebook is not tied to Windows path separators.
if run == "labeled":
    # Annotated data lives next to the current working directory.
    parent_dir = os.path.abspath(os.pardir)
    kp_dir = os.path.join(parent_dir, "udder_labels", "labels", "keypoints")
    sg_dir = os.path.join(parent_dir, "udder_labels", "labels", "segments")
    im_dir = os.path.join(parent_dir, "udder_labels", "frames_tolabel_depth")
    out_dir = "results_watershed_labeled2"
elif run == "pred":
    kp_dir = os.path.join("pred_labels", "keypoints")
    sg_dir = os.path.join("pred_labels", "segments")
    im_dir = "depth_images"
    out_dir = "results_watershed_predicted"
else:
    raise ValueError(f"unknown run {run!r}; expected 'labeled' or 'pred'")

# One sub-directory per cow. Stray files are skipped because the loops below call
# os.listdir() on every entry; sorted() keeps the processing order reproducible.
cow_list = sorted(
    entry for entry in os.listdir(im_dir)
    if os.path.isdir(os.path.join(im_dir, entry))
)

# Hide spines and ticks on all image panels.
rc = {"axes.spines.left" : False,
      "axes.spines.right" : False,
      "axes.spines.bottom" : False,
      "axes.spines.top" : False,
      "xtick.bottom" : False,
      "xtick.labelbottom" : False,
      "ytick.labelleft" : False,
      "ytick.left" : False}

plt.rcParams.update(rc)
dil_factor = 30   # side length (px) of the square used to dilate the teat markers
ratio_limit = 4   # max accepted largest/smallest segment-area ratio
iter_limit = 10   # max iterations for point separation and for marker growth
cols = 6          # subplot columns per cow figure
col = ".r"        # matplotlib format string for keypoint markers (red dots)

In [4]:
def mk_dir(dirpath):
    """Create ``dirpath`` (and any missing parents) if it does not exist yet.

    ``os.makedirs(..., exist_ok=True)`` avoids the race between an
    ``os.path.exists`` check and ``os.mkdir``, and also works for nested paths.
    Raises ``FileExistsError`` if ``dirpath`` exists but is not a directory.
    """
    os.makedirs(dirpath, exist_ok=True)
        
def area_ratio(labels):
    """Return the ratio of the largest to the smallest labelled segment area.

    Segments are the integer labels ``1 .. labels.max()``. Label 0 and any
    negative values are ignored (in this notebook 0 marks background and
    watershed lines). A label number in that range with no pixels counts as
    area 0; the ratio is then ``math.inf`` instead of a ``ZeroDivisionError``.

    Parameters
    ----------
    labels : array-like of int
        Label image, e.g. the output of ``watershed``.

    Returns
    -------
    float
        ``max(area) / min(area)`` over labels ``1 .. labels.max()``.

    Raises
    ------
    ValueError
        If ``labels`` contains no positive label.
    """
    labels = np.asarray(labels)
    n_labels = int(labels.max()) if labels.size else 0
    if n_labels < 1:
        raise ValueError("area_ratio: labels contain no positive label")
    flat = labels.ravel()
    # Single pass over the image instead of one boolean mask per label;
    # minlength keeps empty label numbers as zero-area entries.
    areas = np.bincount(flat[flat > 0].astype(np.intp), minlength=n_labels + 1)[1:]
    smallest = areas.min()
    if smallest == 0:
        return math.inf
    return float(areas.max() / smallest)

def get_angle(right_kp, left_kp):
    """Angle (radians, via arctan2) of the vector pointing from left_kp to right_kp."""
    dx = right_kp[0] - left_kp[0]
    dy = right_kp[1] - left_kp[1]
    return np.arctan2(dy, dx)
def get_center(right_kp, left_kp):
    """Return the midpoint of two (x, y) keypoints as a tuple of floats."""
    # The centroid of two points is their midpoint, so no geometry object is
    # needed (the previous shapely version referenced an unimported MultiPoint).
    return (
        (float(right_kp[0]) + float(left_kp[0])) / 2,
        (float(right_kp[1]) + float(left_kp[1])) / 2,
    )
def get_orientation(right_kp, left_kp):
    """Return -1 when right_kp has a smaller x than left_kp ("up"), otherwise 1 ("down")."""
    return -1 if right_kp[0] < left_kp[0] else 1

def sep_points(right_kp, left_kp, udder_shp, box, limit, step=10):
    """Push a pair of keypoints apart while each stays inside the udder outline.

    Each iteration proposes moving ``right_kp`` by ``step`` pixels and
    ``left_kp`` by ``step`` pixels in the opposite direction. Each move is
    accepted only if the new position lies inside ``udder_shp``. Iteration
    stops when the points are at least half of the smaller box side apart,
    after ``limit`` iterations, or as soon as neither point can move.

    Parameters
    ----------
    right_kp, left_kp : array-like
        (x, y) pixel coordinates of the two keypoints.
    udder_shp : shapely.Polygon
        Udder outline in (x, y) pixel coordinates.
    box : numpy.ndarray
        2x2 array whose row 1 holds the box (width, height) in pixels.
    limit : int
        Maximum number of iterations.
    step : float, optional
        Distance in pixels each point is moved per iteration (default 10).

    Returns
    -------
    tuple of numpy.ndarray
        ``(right_kp, left_kp)`` floored to integer pixel coordinates.
    """
    right_kp = np.asarray(right_kp, dtype=float)
    left_kp = np.asarray(left_kp, dtype=float)
    # Target separation does not change inside the loop.
    min_sep = min(box[1, 0], box[1, 1]) / 2
    k = get_orientation(right_kp, left_kp)
    for _ in range(limit):
        if np.linalg.norm(right_kp - left_kp) >= min_sep:
            break
        angle = get_angle(right_kp, left_kp)
        # NOTE(review): for k == 1 this uses -angle, i.e. the y component is
        # mirrored relative to the line joining the points; confirm intended.
        offset = step * np.array([np.cos(-k * angle), np.sin(-k * angle)])
        new_right = right_kp + offset
        new_left = left_kp - offset
        moved = False
        # Only accept moves that keep the point inside the udder.
        if udder_shp.contains(shapely.Point(new_right)):
            right_kp = new_right
            moved = True
        if udder_shp.contains(shapely.Point(new_left)):
            left_kp = new_left
            moved = True
        if not moved:
            # State is unchanged, so every remaining iteration would be identical.
            break
    return (np.floor(right_kp).astype(int), np.floor(left_kp).astype(int))

In [5]:
# Split every udder into four segments with a marker-based watershed on the
# inpainted depth image, and save one overview figure per cow to `out_dir`.
os.makedirs(out_dir, exist_ok=True)  # savefig below fails if the folder is missing

for cow in cow_list:
    # sorted() makes the panel order reproducible (os.listdir order is arbitrary).
    filenames = sorted(os.listdir(os.path.join(im_dir, cow)))
    if not filenames:
        continue  # plt.subplots(0, cols) would raise
    rows = int(math.ceil(len(filenames) / cols))
    fig, axs = plt.subplots(rows, cols, figsize=(12, 8), tight_layout=True, frameon=False)
    axs = axs.flat
    fig.suptitle(f"cowID: {cow}")

    for ax, file in zip(axs, filenames):
        label = file.replace(".tif", ".txt")
        udder = ski.io.imread(os.path.join(im_dir, cow, file))
        # image shape is in (y, x) order
        im_size = udder.shape
        im_height = im_size[0]
        im_width = im_size[1]

        # Zero-valued pixels are treated as missing and filled by inpainting.
        miss_mask = udder == 0
        inp_udder = inpaint.inpaint_biharmonic(udder, miss_mask)

        # Segment file: first token is dropped, then normalized (x, y) pairs.
        # split() without an argument tolerates trailing newlines/double spaces.
        with open(os.path.join(sg_dir, label), "r") as f:
            mask = np.array([float(v) for v in f.read().split()][1:])
        mask = mask.reshape((-1, 2))

        # Keypoint file: token 0 dropped, tokens 1-4 box, then 4 rows of (x, y, v).
        with open(os.path.join(kp_dir, label), "r") as f:
            data = [float(v) for v in f.read().split()]
        box = np.array(data[1:5]).reshape((2, 2))
        points = np.array(data[5:]).reshape((4, 3))
        # Scale the box to pixels, then shift row 0 by half the size (center -> corner).
        box[:, 0] *= im_width
        box[:, 1] *= im_height
        box[0] -= box[1] / 2

        points[:, 0] *= im_width
        points[:, 1] *= im_height

        # polygon2mask expects (row, col) = (y, x); shapely expects (x, y).
        polygon = [[y * im_height, x * im_width] for x, y in mask]
        polygon2 = [[x * im_width, y * im_height] for x, y in mask]
        mask2 = ski.draw.polygon2mask(im_size, polygon)
        masked_udder = inp_udder * mask2
        udder_shp = shapely.Polygon(polygon2)

        lf_kp = points[0, :2]
        rf_kp = points[1, :2]
        lb_kp = points[2, :2]
        rb_kp = points[3, :2]

        # Push each teat pair apart inside the udder (presumably so the dilated
        # markers stay separate). Rows become: rf, lf, rb, lb.
        points2 = np.round(points, 0).astype(int)
        points2[0, :2], points2[1, :2] = sep_points(rf_kp, lf_kp, udder_shp, box, iter_limit)
        points2[2, :2], points2[3, :2] = sep_points(rb_kp, lb_kp, udder_shp, box, iter_limit)

        # Keep marker coordinates inside the image: out-of-range indices raise,
        # and negative ones would silently wrap to the opposite edge.
        points2[:, 0] = np.clip(points2[:, 0], 0, im_width - 1)
        points2[:, 1] = np.clip(points2[:, 1], 0, im_height - 1)

        # One seed pixel per teat, dilated into square markers.
        mask1 = np.zeros(im_size)
        mask1[points2[:, 1], points2[:, 0]] = True
        mask1 = dilation(mask1, square(dil_factor))
        markers, _ = ndi.label(mask1)
        labels = watershed(masked_udder, markers=markers, mask=mask2, watershed_line=True)
        ratio = area_ratio(labels)

        # Grow the markers until the segment areas are balanced enough, keeping
        # the last segmentation that still has at least 4 segments.
        cnt = 0
        while ratio > ratio_limit and cnt < iter_limit:
            mask1 = dilation(mask1, square(10))
            markers, _ = ndi.label(mask1)
            labels2 = watershed(masked_udder, markers=markers, mask=mask2, watershed_line=True)
            if np.max(labels2) < 4:
                break
            labels = labels2
            ratio = area_ratio(labels)
            cnt += 1

        ax.imshow(labels, cmap=plt.cm.nipy_spectral)
        # markers: right front, left front, right back, left back
        ax.plot(points2[:, 0], points2[:, 1], col)

    fig.savefig(os.path.join(out_dir, str(cow) + ".png"))
    plt.close(fig)

In [6]:
# Overview of the raw depth frames with keypoints and the segment outline
# overlaid; one figure per cow saved as "<cow>_2.png" in `out_dir`.
os.makedirs(out_dir, exist_ok=True)  # savefig below fails if the folder is missing

for cow in cow_list:
    # sorted() makes the panel order reproducible (os.listdir order is arbitrary).
    filenames = sorted(os.listdir(os.path.join(im_dir, cow)))
    if not filenames:
        continue  # plt.subplots(0, cols) would raise
    rows = int(math.ceil(len(filenames) / cols))
    fig, axs = plt.subplots(rows, cols, figsize=(12, 8), tight_layout=True, frameon=False)
    axs = axs.flat
    fig.suptitle(f"cowID: {cow}")

    for ax, file in zip(axs, filenames):
        label = file.replace(".tif", ".txt")
        udder = ski.io.imread(os.path.join(im_dir, cow, file))
        # image shape is in (y, x) order; scale converts normalized (x, y) to pixels
        im_size = udder.shape
        im_height = im_size[0]
        im_width = im_size[1]
        scale = np.array([im_width, im_height])

        # Segment file: first token is dropped, then normalized (x, y) pairs.
        with open(os.path.join(sg_dir, label), "r") as f:
            mask = np.array([float(v) for v in f.read().split()][1:])
        mask = mask.reshape((-1, 2))
        polygon = mask * scale
        # Repeat the first vertex so the drawn outline is closed.
        polygon = np.vstack([polygon, polygon[:1]])

        # Keypoint file: tokens 0-4 skipped here, then 4 rows of (x, y, v).
        with open(os.path.join(kp_dir, label), "r") as f:
            data = [float(v) for v in f.read().split()]
        points = np.array(data[5:]).reshape((4, 3))
        points[:, :2] *= scale

        ax.imshow(udder, cmap=plt.cm.nipy_spectral)
        # NOTE(review): the watershed cell names rows 0-3 lf, rf, lb, rb — confirm.
        ax.plot(points[:, 0], points[:, 1], col)
        ax.plot(polygon[:, 0], polygon[:, 1])

    fig.savefig(os.path.join(out_dir, str(cow) + "_2.png"))
    plt.close(fig)