In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from collections import defaultdict
from functools import  partial
from queue import PriorityQueue, Queue
import warnings

import torch
import sam2
from sam2.build_sam import build_sam2, build_sam2_hf
from sam2.sam2_image_predictor import SAM2ImagePredictor as SAM

import cv2
import numpy as np
from PIL import  Image
import scipy
import skimage

import FastGeodis as geo

In [2]:
import sys

# Show the interpreter's current recursion limit, then raise it so that the
# recursive tree helpers defined later in this notebook can walk long paths.
# NOTE(review): a limit this high disables Python's recursion guard; runaway
# recursion may then exhaust the native stack instead of raising RecursionError.
print(sys.getrecursionlimit())
sys.setrecursionlimit(10 ** 9)

3000


# Inference preparation

In [3]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)); works on scalars and NumPy arrays."""
    return 1 / (1 + np.exp(-x))


def gaussian(x, mean, std):
    """Bell curve exp(-z**2 / 2) / (2.5 * std) with z = (x - mean) / std.

    The denominator uses the constant 2.5 (close to sqrt(2 * pi)), exactly as
    the rest of the notebook expects.
    """
    z = (x - mean) / std
    return np.exp(-(z ** 2) / 2) / (2.5 * std)

In [4]:
def get_bbox(bound, pt, patch_size):
    """Place a patch of ``patch_size`` around ``pt`` without leaving the image.

    Args:
        bound: image extent as ``(rows, cols)``.
        pt: point of interest as ``(row, col)``.
        patch_size: patch extent as ``(rows, cols)``.

    Returns:
        ``(x, y, patch_size[0], patch_size[1])`` where ``x`` is the first
        column and ``y`` the first row of the patch. The patch is centred on
        ``pt`` when possible and shifted inwards near the borders.
    """
    def _origin(limit, center, size):
        # Centre the window, keep it inside [0, limit - size]; the outer
        # max() also covers images smaller than the patch, where
        # ``limit - size`` is negative and would otherwise yield a negative
        # (wrap-around) slice start.
        return max(0, min(limit - size, center - size // 2))

    x = _origin(bound[1], pt[1], patch_size[1])
    y = _origin(bound[0], pt[0], patch_size[0])
    return x, y, patch_size[0], patch_size[1]

In [5]:
def valid_pts(pts, bbox):
    """Project points into a patch and keep only those that fall inside it.

    Args:
        pts: ``(N, D)`` array-like of points in image coordinates, ordered
            like the patch extents in ``bbox[2:]``.
        bbox: ``(x, y, size_0, size_1)`` as returned by ``get_bbox``; the
            first two entries are reversed to obtain the patch origin in the
            same axis order as ``pts``.

    Returns:
        Array of the surviving points expressed relative to the patch origin.
        The input is never modified. An empty input is returned unchanged.
    """
    pts = np.asarray(pts)
    if pts.shape[0] == 0:
        return pts

    # Work on a shifted copy: the original subtracted in place and silently
    # changed the caller's array.
    local = pts - np.asarray(bbox[:2])[::-1]
    extent = np.asarray(bbox[2:2 + local.shape[1]])
    # Index 0 is a valid pixel coordinate, hence ``>= 0`` (not ``> 0``).
    inside = np.all((local >= 0) & (local < extent), axis=1)
    return local[inside]

In [6]:
def compose_prompts(positive, negative):
    """Merge positive and negative points into one prompt array with labels.

    Both inputs are ``(N, 2)`` arrays; the returned coordinates have their two
    columns swapped. Labels are 1 for positive points and 0 for negative ones,
    positives first.
    """
    n_pos, n_neg = positive.shape[0], negative.shape[0]
    ones = np.ones(n_pos)

    # Without negatives there is nothing to merge.
    if n_neg == 0:
        return positive[:, ::-1].copy(), ones

    stacked = np.concatenate([positive, negative], axis=0)
    labels = np.concatenate((ones, np.zeros(n_neg)))
    return stacked[:, ::-1].copy(), labels

In [7]:
# A wrapper for Sam inference
class SamInferer:
    def __init__(self, cfg = "", 
                    ckpt: str = "",
                    model_id: str | None = None,
                    patch_size = [512, 512],
                    roi=[128, 128],
                    root_area=500,
                    max_roots=3,
                    patience = 5,
                    min_dis = 10,
                    back_off = 5,
                    alpha=0.1, 
                    beta=0.5, 
                    post_act=True, 
                    min_length=5,
                    kernel_size=3,
                    fill_kernel_size=7,
                    neg_dis=15,
                    pos_sampling_grid=3,
                    neg_sampling_grid=6,
                    thresh=0.75,
                    decay=0.5,
                    ):
        if model_id is None:
            self.predictor = SAM(build_sam2(cfg, ckpt))
        else: 
            self.predictor = SAM(build_sam2_hf(model_id))
        # Guidance and queries
        self.pos = np.zeros([0, 2], dtype=int)
        self.pos_sampling_grid = pos_sampling_grid
        self.hist = np.zeros([0, 2], dtype=float)
        self.patience = patience
        self.min_dis = min_dis
        self.back_off = 5
        self.queue = PriorityQueue()

        # Negative sampling
        self.neg = np.zeros([0, 2], dtype=int)
        self.neg_dis = neg_dis
        self.neg_sampling_grid = neg_sampling_grid
        # We gonna prioritize long flow over short ones
        self.root = None

        # Context related
        self.step = 0
        self.a_mask = None
        self.b_mask = None
        self.image = None 
        self.label = None
        self.alpha = alpha
        self.beta = beta
        self.decay = decay
        self.post_act = post_act
        self.weight = None
        self.logits = None # Post-sigmoid or pre-sigmoid dependent
        self.var = None
        # kernel configuration
        self.patch_size = np.array(patch_size)
        self.w_kernel = [cv2.getGaussianKernel(patch_size[0], roi[0]), cv2.getGaussianKernel(patch_size[1], roi[1])]
        self.w_kernel = (self.w_kernel[0] / self.w_kernel[0][0, 0]) * (self.w_kernel[1] / self.w_kernel[1][0, 0]).T
        self.w_kernel /= self.w_kernel.sum() 
        # TO prevent vanishing
        self.w_kernel /= self.w_kernel.min()
        # Uncertainty modelling
        
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        fill_kernel_size = (fill_kernel_size, fill_kernel_size) if isinstance(fill_kernel_size, int) else fill_kernel_size
        self.close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
        self.fill_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, fill_kernel_size)
        self.stable_thresh = thresh
        
        # Flow skeleton
        self.atlas = None
        self.graph = defaultdict(list)
        self.min_length = min_length
        self.parent = dict()
        self.roi = None
        self.graph_root = np.zeros([0, 2], dtype=int)
        self.root_area = root_area
        self.max_roots = max_roots

    def pop(self):
        if self.queue.qsize() == 0:
            print("Empty queue !!!")
            return None
        return self.queue.get()[1]
        
    def compose_prompts(self):
        def valid_pts(pts):
            pts -= self.root[None, :]
            check = np.ones(pts.shape[0])
            for i in range(pts.shape[1]):
                check *= np.where((pts[:, i] > 0) & (pts[:, i] < self.patch_size[i]), 1, 0)
            return pts[np.where(check > 0)]
        
        valid_pos = valid_pts(self.pos.copy())
        pos_label = np.ones(valid_pos.shape[0])
        if self.neg.shape[0] == 0:
            return valid_pos[:, ::-1].copy(), pos_label
        # Generated locally so no need for projection
        valid_neg = self.neg
        neg_label = np.zeros(valid_neg.shape[0])
        # Must be in xy format
        return np.concatenate([valid_pos, valid_neg], axis=0)[:, ::-1].copy(), np.concatenate([pos_label, neg_label], axis=0)

    def read(self, image_path, channels=3):
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        
        if img.shape[-1] > channels: 
            self.label = img[..., channels]
        # Sharpen for better sense of boundary
        # hsv_image = cv2.cvtColor(img[..., :channels], cv2.COLOR_BGR2HSV_FULL)
        # unsharp = cv2.GaussianBlur(hsv_image, (3, 3), 0)
        # hsv_image = 2 * hsv_image - unsharp
        # self.image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB_FULL)
        self.image = img[..., :channels][..., ::-1]
        self.a_mask = np.zeros(self.image.shape[:2], dtype=np.uint8)
        self.b_mask = np.zeros(self.image.shape[:2], dtype=np.uint8)
        self.hist = np.zeros_like(self.b_mask)
        self.logits = np.zeros(self.image.shape[:2])
        self.var = np.zeros_like(self.logits)
        self.weight = np.full(self.image.shape[:2], 1e-6)
        self.beta = np.full(self.image.shape[:2], self.beta) * self.weight
        self.output = np.zeros_like(self.a_mask)
        self.box = self.image.shape[:2]

        self.atlas = np.zeros_like(self.b_mask)
    
    # Generate negative prior
    def negative_sampling(self, debug = False):
        # Prepare
        pts = (self.pos - self.root).round().astype(int)
        dst = self.root + self.patch_size
        a_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.neg_dis, self.neg_dis))
        grad_kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (5, 5))
        
        mask = self.b_mask[self.root[0]:dst[0], self.root[1]:dst[1]].copy()        
        gradient = cv2.morphologyEx(cv2.dilate(mask, a_kernel, iterations=5), cv2.MORPH_GRADIENT, grad_kernel)
        for pt in pts:
            cv2.circle(gradient, pt[::-1], self.neg_dis * 3, 0, -1)

        alpha_mask = scipy.ndimage.binary_fill_holes(self.a_mask[self.root[0]: dst[0], self.root[1]:dst[1]]).astype(int).astype(np.uint8)
        alpha_mask = cv2.dilate(alpha_mask,  a_kernel, 3)

        negative_field = gradient * (1 - alpha_mask)
        # Discretize to 5 bin
        output = grid_sampling(negative_field.astype(float), grid=self.neg_sampling_grid, alpha=0.01)
        dis = np.triu(np.linalg.norm(output[None, :] - output[:, None, :], axis=2))
        dis[dis == 0] = 1e5
        drop = np.where(dis < self.neg_dis)[0]
        accepted_neg = [i for i in range(output.shape[0]) if i not in drop]
        output = output[accepted_neg]
        
        return {'pts': output} if not debug else {'pts': output, 'field': negative_field, 'b': alpha_mask, 'a': gradient}


    def roi_(self, x, y):
        if self.roi is None:
            self.roi = np.array([y, x, y + self.patch_size[0], x + self.patch_size[1]])
            return 
        self.roi[2] = max(self.roi[2], y + self.patch_size[0])
        self.roi[0] = min(self.roi[0], y)
        self.roi[3] = max(self.roi[3], x + self.patch_size[1])
        self.roi[1] = min(self.roi[1], x)
        
    def infer(self, debug=False):
        prompt = self.pop()
        if prompt is None:
            return {'ret': False}

        # if self.check_hist(prompt): 
        #     print(f"Same position got infered for too many times, Nothing changed")
        #     return {'ret': False}
        
        
        x, y, _, _ = get_bbox(self.box, prompt, self.patch_size)
        self.roi_(x, y)
        # Change base for referencing and post-processing.
        self.root = np.array([y, x])
        dst = self.root + self.patch_size
        
        # Check for validity
        
        # if self.check_hist(prompt):
        #     print("Same spot has been inferenced too many times, skipping")
        #     return {'ret': False}
            
        # Image in RGB format
        patch = self.image[self.root[0]:dst[0], self.root[1]:dst[1]].copy()
        self.predictor.set_image(patch)
        # Negative sampling
        neg = self.negative_sampling(debug=debug)
        self.neg = neg['pts']
        # Positive prompting
        input_mask = cv2.resize(self.b_mask[self.root[0]: dst[0], self.root[1]:dst[1]], (256, 256))
        self.pos = np.array(prompt)[None, :]
        # self.pos = np.concatenate([self.graph_root, self.pos], axis=0)
        
        annotation, a_label = self.compose_prompts()
        # print(annotation)
        # a_mask = self.logits[self.root[0]: dst[0], self.root[1]:dst[1]] / self.weight[self.root[0]: dst[0], self.root[1]:dst[1]]
        # if not self.post_act:
        #     a_mask = sigmoid(a_mask)
        # # Quantile for recall
        # # Always lower bound it for measure
        # score = max(0.4, np.quantile(a_mask, 0.98))
        # a_mask = np.where(a_mask > score, 1, 0).astype(np.float32)
        
        # a_mask = cv2.resize(a_mask, (256, 256), interpolation=cv2.INTER_LINEAR)
        # print(a_mask.max())
        # Mind that mask input must halve the size, for size matching
        print(input_mask.shape, a_label.shape, annotation.shape )
        masks, scores, logits = self.predictor.predict(point_coords=annotation, 
                                                        point_labels=a_label, 
                                                        mask_input=input_mask[None, :] if input_mask.max() > 0 else None, 
                                                        multimask_output=False)
        # cv2.resize(logits[0], self.patch_size, interpolation=cv2.INTER_LINEAR)
        return {'ret': True,
                'input': patch,
                'mask': masks[0], 
                'score': scores[0], 
                'logit': cv2.resize(logits[0], self.patch_size, interpolation=cv2.INTER_LINEAR), 
                'pts': annotation, 
                'label': a_label,
                'inp_mask': self.b_mask[self.root[0]: dst[0], self.root[1]:dst[1]].copy(),
                'prompt': prompt,
                'negative': neg}
    
    # Allow pipeline injection
    def morphology_centers(self, segmentation, weight, minArea=2400, minW=7.):
        # From an unknown respected lad
        def get_center_of_mass(cnt):
            M = cv2.moments(cnt)
            cx = int(M['m10']/M['m00'])
            cy = int(M['m01']/M['m00'])
            return cy, cx
        label_map = skimage.measure.label(segmentation, connectivity=2)
        labels, counts = np.unique(label_map, return_counts=True)
        # print(counts)
        centers = []
        output_segmentation=
        for label, count in zip(labels[1:min(len(labels), 1 + self.max_roots)], counts[1:min(len(labels), 1 + self.max_roots)]): 
            if count > minArea: 
                mask = (label_map == label).astype(np.uint8)
                # Force the isle to be stable
                if weight[mask == 1].mean() <= minW:
                    continue
                cnt = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0][0]
                centers.append((get_center_of_mass(cnt)))
        return {'centers': np.array(centers), 
                'mask': }
    
    def iter(self, seg_res=None, debug=False):
        self.step += 1
        if seg_res is None:
            seg_res = self.infer(debug=debug) 
        if seg_res['ret'] is False:
            return seg_res
        dst = self.root + self.patch_size

        # Updating primitives
        print(self.patch_size, seg_res['mask'].shape, seg_res['mask'].dtype)
        # Never let it lower than alpha
        score = max(seg_res['score'], self.alpha) 
        if score < self.alpha:
            warnings.warn(f"Model confidence {score} is hazardous, please make prompt to escape uncertainty")
        print(f"Score {score}")
        print(f"With {seg_res['label'].sum()} positives and {seg_res['label'].shape[0] - seg_res['label'].sum()} negatives")
        # Clean background from static gain
        prob_map = sigmoid(seg_res['logit']) 
        quantiles = [0.8, 1]
        sv = [0, ]
        dv = [0, 0, 1]
        for val in quantiles:
            sv.append(np.quantile(prob_map[prob_map > sv[-1]], val))
        print(dict(zip(sv, dv)))
        prob_map = np.interp(prob_map, sv, dv)

        # Weight ensemble
        weight = score * self.w_kernel
        # When weight is low, skip it to the next pred
        ensemble_kernel = self.w_kernel / self.w_kernel.max() * (1 - self.decay) * (self.weight[self.root[0]: dst[0], self.root[1]:dst[1]] > 1)
        # Well, there is a case where confidence score is too low, so add tolerance 
        self.beta[self.root[0]: dst[0], self.root[1]:dst[1]] += score * weight

        # Update logits
        self.var = ensemble_kernel * self.var[self.root[0]: dst[0], self.root[1]:dst[1]] + (1 - ensemble_kernel) * np.maximum((prob_map - self.logits[self.root[0]: dst[0], self.root[1]:dst[1]]) ** 2, 0.09)
        
        if self.post_act:
            self.logits[self.root[0]: dst[0], self.root[1]:dst[1]] += prob_map * weight
        else: 
            self.logits[self.root[0]: dst[0], self.root[1]:dst[1]] += seg_res['logit'] * weight
        self.weight[self.root[0]: dst[0], self.root[1]:dst[1]] += weight
        
        # Subtract to get gain properties
        root = self.roi[:2]
        dst = self.roi[2:]
        beta = self.beta[root[0]: dst[0], root[1]:dst[1]] / self.weight[root[0]: dst[0], root[1]:dst[1]]
        std = np.sqrt(self.var[root[0]: dst[0], root[1]:dst[1]])
        prob_map = self.logits[root[0]: dst[0], root[1]:dst[1]] / self.weight[root[0]: dst[0], root[1]:dst[1]]
        prob_map = sigmoid(prob_map) if not self.post_act else prob_map
        confidence = 1 - gaussian(beta, prob_map, std)
        # Get beta_mask

        quantized_mask = np.round(np.power(prob_map, 1 - (prob_map - beta) / (1 - beta)) * 200)
        quantized_mask[quantized_mask > 255] = 255
        beta_mask = cv2.adaptiveThreshold(quantized_mask.astype(np.uint8), 
                                            1, 
                                            cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                                            cv2.THRESH_BINARY, 
                                            151, 
                                            -10)
        beta_mask = prune(beta_mask, min_size=25)
        
        self.b_mask[root[0]: dst[0], root[1]:dst[1]] = np.maximum(self.b_mask[root[0]: dst[0], root[1]:dst[1]], beta_mask) 
        
        possible = (prob_map >= self.alpha).astype(np.uint8)
        # possible -= possible * self.b_mask[root[0]: dst[0], root[1]:dst[1]]
        possible = cv2.morphologyEx(possible, cv2.MORPH_OPEN, self.close_kernel, 2)
        possible = scipy.ndimage.binary_fill_holes(possible).astype(np.uint8)
        possible = prune(possible, min_size=50)
        possible = smooth_mask(possible, 7, 3)
        self.a_mask[root[0]: dst[0], root[1]:dst[1]] = cv2.bitwise_or(self.a_mask[root[0]: dst[0], root[1]:dst[1]], possible)
        possible = np.maximum(possible, self.b_mask[root[0]: dst[0], root[1]:dst[1]])
        possible = cv2.morphologyEx(possible, cv2.MORPH_CLOSE, self.fill_kernel, 2)
        # Stabilize skeleton
        thin, thin_w = morpholgy_thinning(possible, return_weight=True)
        print("Thin:", possible.shape, thin_w.shape, thin.shape, possible.max())
        
        stable = cv2.morphologyEx(thin_w + self.atlas[root[0]: dst[0], root[1]:dst[1]], cv2.MORPH_DILATE + cv2.MORPH_CLOSE, self.fill_kernel, 5)
        stable = prune(stable, min_size=25)
        # stable = smooth_mask(stable, 7, 3)
        # stable = getLargestCC(stable).astype(int).astype(np.uint8)
        # Skeletonize pruned mask
        
        print("Stable:", stable.shape)
        skeleton = skimage.morphology.skeletonize(stable)
        # skeleton = getLargestCC(skeleton).astype(int).astype(np.uint8)
        # Flow map generation output['dist']
        print(score, skeleton.shape, prob_map.shape)
        w_map = (skeleton * prob_map) / score
        w_map[self.b_mask[root[0]: dst[0], root[1]:dst[1]] * skeleton == 1] = 1
        w_map[w_map > 1] = 1 

        # Update root to center of mass
        self.graph_root = self.morphology_centers(self.b_mask[root[0]: dst[0], root[1]:dst[1]], minArea=self.root_area) + root[None, :]
        
        # import IPython; IPython.embed()
        # Given the fact that the prompting 
        total_branches = []
        canvas = np.zeros_like(thin_w)
        for idx in range(self.graph_root.shape[0]):
            w_pts = np.array(np.where(w_map == 1)).T
            # print(w_pts.shape[0)
            dist = np.linalg.norm(w_pts - (self.graph_root[idx] - root), axis=1)
            nn = w_pts[np.argmin(dist)]
            # Flow porting
            # Update available mask
            ref_mask = self.b_mask[root[0]: dst[0], root[1]:dst[1]].copy()
            tree = dfs_tree(w_map.copy(), 
                            ref_mask, 
                            tuple(nn), 
                            alpha=0.01, 
                            thresh=0.8)
            
            # Ordering from leaves to roots
            branches = longest_path_branching(tree['dfs_tree'], tuple(nn))
            valid_branches = [branches[i] for i in range(len(branches)) if len(branches[i]) >= self.min_length]
            total_branches += valid_branches
            for branch in valid_branches:
                x, y = branch[self.back_off]
                # Only get outgoing vertexes, 
                # if it loops into the main stream
                # Then its a hole and already solved by fill holes
                # if ref_mask[x, y] == 0: 
                #     # Roll back to prevent overflow
                self.add_queue(root + branch[self.back_off], prior = prob_map[x, y])
                # else: 
                #     print("Point in mask already")
                # Draw on canvas for mask extraction later
                for node in branch[3:]:
                    # canvas[node[0], node[1]] = thin_w[node[0], node[1]]
                    canvas[node[0], node[1]] = 1
            
            self.b_mask[root[0]: dst[0], root[1]:dst[1]] = np.maximum(self.b_mask[root[0]: dst[0], root[1]:dst[1]], canvas)
            roi = (skimage.segmentation.flood_fill(self.b_mask[root[0]: dst[0], root[1]:dst[1]], 
                                                   tuple(nn), 
                                                   new_value=2,
                                                   connectivity=2) == 2).astype(np.uint8)
            self.hist[root[0]: dst[0], root[1]:dst[1]] += cv2.dilate(roi, self.fill_kernel, iterations=5).astype(np.uint8)
        self.atlas[root[0]: dst[0], root[1]:dst[1]] = np.maximum(self.atlas[root[0]: dst[0], root[1]:dst[1]], canvas)
        # Prepare for distance transform, the mask is from softmax, what can it possibly be ? 
        # image = torch.from_numpy(canvas).float().cuda()[None, None, :]
        # mask = torch.from_numpy(prob_map).float().cuda()[None, None, :]
        
        # dist = geo.GSF2d(image, mask, theta=1., v=8, lamb=.8, iter = 4).cpu().numpy()[0, 0]
        # dist[possible == 0] = dist.min()
        # accepted_mask = (dist > np.quantile(dist[dist > dist.min()], 0.2)).astype(np.uint8)
        # accepted_mask = getLargestCC(accepted_mask).astype(np.uint8)
        metrics = eval(self.b_mask[root[0]: dst[0], root[1]:dst[1]].copy(), self.label[root[0]: dst[0], root[1]:dst[1]])
        # del image
        # del mask
                # self.update_graph(branch)

        if debug:
            return {'ret': True,
                    'infer': seg_res,
                    'confidence': confidence,
                    'beta': self.b_mask[root[0]: dst[0], root[1]:dst[1]].copy(),
                    'prob_map': prob_map,
                    'stable': stable,
                    'thin': thin_w, 
                    'branches': total_branches,
                    'possible': possible,
                    'dist': dist,
                    'canvas': canvas,
                    'metrics': metrics
                    }
        else: 
            return {'ret': True,
                    'infer': seg_res,
                    'branches': valid_branches,
                    'metrics': metrics}

    def update_graph(self, path):
        # Inverse sampling from root to leaves
        for i in range(len(path) - 1, 0, -1):
            self.bi_add(tuple(self.root + path[i-1]), tuple(self.root + path[i]))

    def bi_add(self, src, dst):
        self.graph[src].append(dst)
        self.graph[dst].append(src)

    def check_hist(self, pt):
        return self.hist[pt[0], pt[1]] >= self.patience
    
    def add_queue(self, pt: list, prior: float = 1, isroot: bool =False):
        if isroot:
            # self.pos = np.concatenate([self.pos, np.array([pt])], axis=0)
            self.queue.put((prior, pt))
            self.graph_root = np.array(pt)[None, :]
            return 
        
        if self.check_hist(pt): 
            print(f"Same position got infered for too many times, skipping {pt}")
            return
            
        if pt[0] < 400:
            print(f"{pt} goes out of ROI")
            return
            
        score = float(2 + prior + self.step * 0.1 + np.random.randn() * 0.01)
        entry = (score, pt)
        try:
            self.queue.put(entry)
        except:
            pass


# Bounding box Processing

In [8]:
def compute_iou(box1, box2):
    """Intersection over Union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0 when the union has no area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Overlap extent along each axis, floored at zero for disjoint boxes.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union_area = area_a + area_b - inter_area

    if union_area > 0:
        return inter_area / union_area
    return 0

def compute_liou(box1, box2):
    """Fraction of ``box1`` covered by ``box2`` (boxes are ``(x1, y1, x2, y2)``).

    This is the intersection area divided by the area of ``box1`` only, so
    the measure is not symmetric. Returns 0 when ``box1`` has no area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Overlap extent along each axis, floored at zero for disjoint boxes.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = overlap_w * overlap_h

    reference_area = (ax2 - ax1) * (ay2 - ay1)
    if reference_area > 0:
        return inter_area / reference_area
    return 0
    
def non_max_suppression(boxes, iou_threshold=0.5):
    """Greedy NMS over ``(x1, y1, x2, y2)`` boxes, largest area first.

    The largest remaining box is kept and every other box whose IoU with it
    reaches ``iou_threshold`` is discarded; this repeats until no box is
    left. Returns the kept boxes in the order they were selected.
    """
    if len(boxes) == 0:
        return []

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    remaining = sorted(boxes, key=_area, reverse=True)
    kept = []
    while remaining:
        current, *others = remaining
        kept.append(current)
        remaining = [candidate for candidate in others
                     if compute_iou(current, candidate) < iou_threshold]
    return kept

# Output processing

In [9]:
def smooth_mask(mask, kernel_size, sigma):
    """Redraw a mask from its contours after Gaussian-smoothing their coordinates.

    Every contour found in ``mask`` has each coordinate column convolved
    ('valid' mode) with a 1-D Gaussian of ``kernel_size`` taps and standard
    deviation ``sigma``; the rounded contours are then filled with value 1
    on an empty canvas shaped like ``mask``.
    """
    taps = cv2.getGaussianKernel(kernel_size, sigma).squeeze()
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    smoothed = []
    for contour in contours:
        columns = [np.convolve(taps, contour[:, 0, axis], mode='valid')
                   for axis in range(contour.shape[-1])]
        points = np.stack(columns, axis=-1).round().astype(np.int32)
        # Back to OpenCV's (N, 1, 2) contour layout.
        smoothed.append(points[:, None, :])

    canvas = np.zeros_like(mask)
    cv2.drawContours(canvas, smoothed, -1, 1, -1)
    return canvas

In [10]:
def getLargestCC(segmentation):
    """Boolean mask of the largest connected component (connectivity 2).

    Raises AssertionError when the segmentation has no labelled component.
    """
    labels = skimage.measure.label(segmentation, connectivity=2)
    assert labels.max() != 0  # assume at least 1 CC
    # Skip bin 0 when ranking sizes, then shift the index back to a label id.
    sizes = np.bincount(labels.flat)
    biggest = np.argmax(sizes[1:]) + 1
    return labels == biggest

In [11]:
def prune(mask, min_size=5):
    """Remove 8-connected components smaller than ``min_size`` pixels.

    Returns a uint8 array shaped like ``mask`` holding 1 on every surviving
    component and 0 elsewhere.
    """
    _, label_im = cv2.connectedComponents(mask.astype(np.uint8), connectivity=8, ltype=cv2.CV_16U)
    labels, counts = np.unique(label_im, return_counts=True)
    # The first (smallest) label is skipped, as in a background/foreground split.
    survivors = [label for label, count in zip(labels[1:], counts[1:]) if count >= min_size]

    output = np.zeros(mask.shape, np.uint8)
    output[np.isin(label_im, survivors)] = 1
    return output

In [12]:
def morpholgy_thinning(mask, return_weight=False):
    """Morphological skeleton of a uint8 mask by repeated erosion.

    At each round the mask is eroded with a 3x3 cross and the part of the
    eroded image removed by a 3x3 elliptical opening is added to the
    skeleton. ``weight`` accumulates the skeleton after every round.

    Returns ``thin`` or, when ``return_weight`` is true, ``(thin, weight)``.
    If a single erosion already empties the mask, the mask itself is
    returned (with an all-zero weight).
    """
    cross = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    weight = np.zeros(mask.shape, dtype=np.uint8)

    # Early stopping: nothing survives one erosion.
    if cv2.countNonZero(cv2.erode(mask, cross)) == 0:
        return (mask, weight) if return_weight else mask

    thin = np.zeros(mask.shape, dtype='uint8')
    current = mask
    # Loop until erosion leads to an empty set.
    while cv2.countNonZero(current) != 0:
        eroded = cv2.erode(current, cross)
        # What the opening removes from the eroded image.
        residue = eroded - cv2.morphologyEx(eroded, cv2.MORPH_OPEN, ellipse)
        # Union with everything collected so far.
        thin = cv2.bitwise_or(residue, thin)
        # Keep the cumulative sum for weighting.
        weight += thin
        current = eroded.copy()

    return (thin, weight) if return_weight else thin


In [13]:
def dfs_tree(mask, start):
    """Depth-first spanning tree over the 8-connected pixels equal to 1.

    Args:
        mask: 2-D array; a pixel is traversable when ``mask[r, c] == 1``.
        start: ``(row, col)`` of the root (not checked against ``mask``).

    Returns:
        ``(tree, parent)``: ``tree`` maps a node to the list of its children
        and ``parent`` maps a node to its parent. Every reached node other
        than the root appears as a child exactly once, and the two mappings
        always agree.
    """
    rows, cols = mask.shape[:2]
    directions = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]
    visited = set()
    parent = {}
    tree = defaultdict(list)

    # Each stack entry remembers which node discovered it. The edge is only
    # committed when the entry is actually expanded; committing at discovery
    # time attached a pixel to every neighbour that saw it first.
    stack = [(tuple(start), None)]
    while stack:
        node, via = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        if via is not None:
            parent[node] = via
            tree[via].append(node)

        x, y = node
        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if 0 <= nx < rows and 0 <= ny < cols and mask[nx, ny] == 1 and (nx, ny) not in visited:
                stack.append(((nx, ny), node))

    return tree, parent

In [14]:
def longest_path(tree, start):
    """Longest root-to-leaf path of a tree given as ``{node: [children]}``.

    Args:
        tree: mapping from a node to the list of its children; nodes without
            an entry are treated as leaves (the mapping is not modified).
        start: root of the search.

    Returns:
        The list of nodes from ``start`` down to the deepest leaf. On ties
        the path through the earliest-listed child wins.

    The traversal is iterative, so the path length is not limited by the
    interpreter's recursion limit.
    """
    # One record per visited occurrence: (node, index of the parent record).
    records = [(start, -1)]
    depths = [1]
    pending = [0]
    best = 0

    while pending:
        idx = pending.pop()
        if depths[idx] > depths[best]:
            # Strict comparison + pre-order traversal = first deepest node wins.
            best = idx
        first_child = len(records)
        for child in tree.get(records[idx][0], ()):
            records.append((child, idx))
            depths.append(depths[idx] + 1)
        # Push in reverse so that the first child is expanded first.
        pending.extend(range(len(records) - 1, first_child - 1, -1))

    # Walk the parent links back to the root.
    path = []
    while best != -1:
        node, best = records[best]
        path.append(node)
    path.reverse()
    return path

In [15]:
def negative_field(logit_map, distance=15, beta=0.5, alpha=0.05):
    """Ring-shaped region around the confident mask, outside the plausible one.

    The confident mask (``logit_map > beta``) is dilated five times and its
    morphological gradient taken; pixels covered by the hole-filled,
    three-times-dilated plausible mask (``logit_map > alpha``) are removed.
    ``distance`` sets the elliptical kernel size (an int is used for both
    sides).
    """
    kernel_shape = (distance, distance) if isinstance(distance, int) else distance
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_shape)

    confident = (logit_map > beta).astype(int).astype(np.uint8)
    ring = cv2.morphologyEx(cv2.dilate(confident, kernel, iterations=5), cv2.MORPH_GRADIENT, kernel)

    plausible = scipy.ndimage.binary_fill_holes(np.where(logit_map > alpha, 1, 0))
    plausible = cv2.dilate(plausible.astype(int).astype(np.uint8), kernel, iterations=3)

    return ring - ring * plausible


In [16]:
def grid_sampling(mask, grid=8, alpha=0.1):
    """Draw at most one random positive pixel from each cell of a coarse grid.

    Args:
        mask: 2-D array; pixels ``> 0`` are candidates.
        grid: grid size as an int or a pair; cell origins are the first
            ``grid - 1`` points of ``linspace(0, 1, grid)`` scaled by the
            mask shape, and each cell spans ``mask.shape // grid`` pixels.
        alpha: a cell is sampled only if its mean value is at least ``alpha``.

    Returns:
        ``(K, 2)`` int array of sampled coordinates in mask space
        (``K`` may be 0).
    """
    if isinstance(grid, int):
        grid = (grid, grid)
    x, y = np.linspace(0, 1, grid[0])[:-1], np.linspace(0, 1, grid[1])[:-1]
    patch_size = np.array(mask.shape[:2]) // grid
    mesh = np.floor(np.stack(np.meshgrid(x, y), axis=-1).reshape(-1, 2) * mask.shape[:2]).astype(int)

    samples = []
    for src in mesh:
        dst = src + patch_size
        cell = mask[src[0]:dst[0], src[1]:dst[1]]
        if cell.size == 0 or np.mean(cell) < alpha:
            continue
        candidates = np.argwhere(cell > 0)
        if candidates.shape[0] == 0:
            continue
        samples.append(candidates[np.random.randint(0, high=candidates.shape[0])] + src)

    # Rejected cells are skipped outright; the former -1 sentinel combined
    # with a ``> 0`` filter also threw away valid samples lying on row 0.
    return np.array(samples, dtype=int).reshape(-1, 2)

# Geometry Graph processing

In [None]:
# Offsets of the 8 neighbours of a pixel, one (d_axis0, d_axis1) pair per row.
directions = np.array([
    [-1, 0], [-1, -1], [0, -1], [1, -1],
    [1, 0], [1, 1], [0, 1], [-1, 1],
])

In [None]:
def unilateral_dfs_tree(g_mask, inp_mask, start, weight = 1, alpha=0.1, thresh=0.95, context_size=3, dis_map=None):
    """Cheapest-path tree over a confidence map, tracking recent move directions.

    Pixels are expanded in order of accumulated cost, where stepping onto a
    pixel costs ``(1 - g_mask) * weight``. Only pixels with
    ``g_mask > alpha`` whose ``status`` does not exceed the current pixel's
    are entered. ``status`` is incremented each time a step crosses
    ``thresh``.

    Args:
        g_mask: 2-D confidence map.
        inp_mask: 2-D mask read at pixels that could not expand any neighbour.
        start: ``(row, col)`` of the root.
        weight: multiplier of the step cost.
        alpha: minimum confidence of a traversable pixel.
        thresh: confidence level whose crossing marks a border.
        context_size: number of most recent moves kept per path.
        dis_map: optional array that receives, for every in-bounds neighbour
            examined, the mean of the recent moves leading to it
            (presumably shaped ``(rows, cols, 2)`` - TODO confirm).

    Returns:
        Dict with ``'dfs_tree'`` (node -> children), ``'parent'``, ``'cost'``,
        ``'border'``, ``'leaves'``, ``'status'``, ``'begin'`` and ``'dis_map'``.
    """
    rows, cols = g_mask.shape[:2]
    frontier = PriorityQueue()
    # Entries are (cost, (recent_moves, (row, col))); the seed has no moves
    # yet. (A bare ``(0, start)`` does not match the unpacking below.)
    frontier.put((0, ([], tuple(start))))

    # Dijikstra
    status = np.zeros(g_mask.shape, dtype=int)

    # This to indicate strictly incremental path
    g_mask = np.abs(g_mask - 0.001)
    cost = np.full_like(g_mask, 1e4)

    px, py = start
    cost[px, py] = 0
    leaves = set()
    border = set()
    begin = start
    parent = {}
    tree = defaultdict(list)

    # NOTE(review): the original computed a mean move direction and a subset
    # of ``directions`` agreeing with it, but never used that subset; all
    # eight directions are explored here, exactly as the loop did.
    while not frontier.empty():
        state, (moves, (x, y)) = frontier.get()
        expanded = 0
        # Keep room for one new move without mutating the popped list.
        history = moves[-(context_size - 1):] if context_size > 1 else []
        for dx, dy in directions.tolist():
            di_state = history + [(dx, dy)]
            nx, ny = x + dx, y + dy

            # Check bounds before touching any array (negative indices wrap).
            if not (0 <= nx < rows and 0 <= ny < cols):
                continue

            if dis_map is not None:
                dis_map[nx, ny] = np.mean(di_state, axis=0)

            if g_mask[nx, ny] > alpha and status[nx, ny] <= status[x, y]:
                # Update on inverse confidence
                step_cost = state + (1 - g_mask[nx, ny]) * weight
                if cost[nx, ny] > step_cost:
                    # Set cost as uncertainty gain
                    expanded += 1
                    cost[nx, ny] = step_cost
                    # Start by largest margin
                    frontier.put((cost[nx, ny], (di_state, (nx, ny))))
                    # Erase entry from other branch
                    if (nx, ny) in parent:
                        tree[parent[(nx, ny)]].remove((nx, ny))

                    parent[(nx, ny)] = (x, y)
                    # Determine that they have gone out of beta mask
                    # Odd for out-going, Even for in-going.
                    if (g_mask[nx, ny] - thresh) * (g_mask[x, y] - thresh) < 0:
                        if status[nx, ny] == 0:
                            border.add((nx, ny))
                        status[nx, ny] = status[x, y] + 1
                    else:
                        status[nx, ny] = status[x, y]
                    tree[(x, y)].append((nx, ny))

        if expanded == 0:
            # Force leaves to be sink
            status[x, y] += inp_mask[x, y] % 2
            leaves.add((x, y))
            # Continual of flow
            if inp_mask[x, y] == 1:
                begin = (x, y)
    return {'dfs_tree': tree,
            'parent': parent,
            'cost': cost,
            'border': border,
            'leaves': leaves,
            'status': status,
            'begin': begin,
            'dis_map': dis_map}

In [17]:
def dfs_tree(g_mask, inp_mask, start, weight = 1, alpha=0.1, thresh=0.95):
    """Grow a least-cost traversal tree over ``g_mask`` starting at ``start``.

    Pixels are expanded in order of accumulated cost (Dijkstra-style, using a
    priority queue) over the 8-neighbourhood.  Stepping onto a pixel costs
    ``(1 - g_mask[pixel]) * weight``; only pixels whose value exceeds ``alpha``
    are entered.

    ``status`` counts how many times the path to a pixel crossed the ``thresh``
    level (odd: currently out-going, even: in-going).  A pixel may only be
    entered from a pixel whose status is at least as large as its own.

    Args:
        g_mask: 2-D confidence map indexed as ``[row, col]``.
        inp_mask: array indexed like ``g_mask``; its parity is added to the
            status of every leaf, and a leaf where it equals 1 becomes ``begin``.
        start: ``(row, col)`` seed pixel.
        weight: multiplier applied to every step cost.
        alpha: minimum (shifted) confidence for a pixel to be traversable.
        thresh: confidence level whose crossing increments ``status``.

    Returns:
        dict with keys ``dfs_tree`` (node -> list of children), ``parent``
        (child -> parent), ``cost``, ``border`` (pixels where the first
        crossing happened), ``leaves``, ``status`` and ``begin``.
    """
    n_rows, n_cols = g_mask.shape[:2]

    # Number of threshold crossings accumulated on the way to each pixel.
    status = np.zeros(g_mask.shape, dtype=int)

    # Shift the map slightly so a confidence of exactly 1 still has a strictly
    # positive step cost, which keeps paths strictly increasing in cost.
    g_mask = np.abs(g_mask - 0.001)
    cost = np.full_like(g_mask, 1e4)
    seed_x, seed_y = start
    cost[seed_x, seed_y] = 0

    frontier = PriorityQueue()
    frontier.put((0, start))

    leaves, border = set(), set()
    begin = start
    parent = {}
    tree = defaultdict(list)
    neighbours = [(-1, 0),(-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]

    while not frontier.empty():
        state, (x, y) = frontier.get()
        relaxed = 0
        for dx, dy in neighbours:
            nx, ny = x + dx, y + dy
            if not (0 <= nx < n_rows and 0 <= ny < n_cols):
                continue
            if not (g_mask[nx, ny] > alpha and status[nx, ny] <= status[x, y]):
                continue
            # Cost is the accumulated uncertainty (inverse confidence).
            candidate = state + (1 - g_mask[nx, ny]) * weight
            if not cost[nx, ny] > candidate:
                continue

            relaxed += 1
            cost[nx, ny] = candidate
            frontier.put((cost[nx, ny], (nx, ny)))

            # The pixel now hangs off (x, y); detach it from any previous parent.
            if (nx, ny) in parent:
                tree[parent[(nx, ny)]].remove((nx, ny))
            parent[(nx, ny)] = (x, y)

            # A sign change around `thresh` means the step crossed the mask
            # boundary: odd status = out-going, even status = in-going.
            crossed = (g_mask[nx, ny] - thresh) * (g_mask[x, y] - thresh) < 0
            if crossed:
                if status[nx, ny] == 0:
                    border.add((nx, ny))
                status[nx, ny] = status[x, y] + 1
            else:
                status[nx, ny] = status[x, y]
            tree[(x, y)].append((nx, ny))

        if relaxed == 0:
            # Nothing was improved from here: treat the pixel as a leaf/sink.
            status[x, y] += inp_mask[x, y] % 2
            leaves.add((x, y))
            # A leaf lying on the input mask becomes the continuation point.
            if inp_mask[x, y] == 1:
                begin = (x, y)

    return {'dfs_tree': tree,
            'parent': parent,
            'cost': cost,
            'border': border,
            'leaves': leaves,
            'status': status,
            'begin': begin}

In [18]:
# Post-processing: split the traversal tree into its longest path and the remaining side branches.
def longest_path_branching(tree, start):
    """Split a rooted tree into its longest root-to-leaf path and side branches.

    The tree is walked in post-order.  At every node the longest path coming
    up from its children is carried on towards the root, while every other
    child path is emitted as a branch.  Each path is listed leaf-first and
    ends with the node it forks from (the main path ends with ``start``).

    The walk is iterative, so the length of a path is not limited by the
    interpreter's recursion depth, and ``tree`` is only read: missing keys are
    treated as leaves and no entries are inserted into a ``defaultdict``.

    Args:
        tree: mapping ``node -> list of child nodes``.
        start: root node of the walk.

    Returns:
        list of paths (each a list of nodes): all side branches in post-order,
        followed by the longest path as the last element.
    """
    def children_of(node):
        # `.get` avoids both KeyError on plain dicts and key insertion on defaultdicts.
        return tree.get(node) or ()

    root_children = children_of(start)
    if len(root_children) == 0:
        return [[start]]

    visited = {start}
    branches = []
    # One frame per node being expanded: (node, iterator over children, child paths).
    frames = [(start, iter(root_children), [])]

    while True:
        node, pending, paths = frames[-1]
        descended = False
        for child in pending:
            if child in visited:
                continue
            grandchildren = children_of(child)
            if len(grandchildren) == 0:
                # A leaf contributes the two-node path [leaf, node] directly.
                paths.append([child, node])
                continue
            visited.add(child)
            frames.append((child, iter(grandchildren), []))
            descended = True
            break
        if descended:
            continue

        # All children of `node` are done: keep the longest path, emit the rest.
        frames.pop()
        if paths:
            # Stable sort: equally long paths keep their child order.
            paths.sort(key=lambda p: -len(p))
            longest = paths[0]
            branches.extend(paths[1:])
        else:
            longest = [node]

        if not frames:
            return branches + [longest]

        # Hand the surviving path to the parent, extended by the parent itself.
        parent_node, _, parent_paths = frames[-1]
        longest.append(parent_node)
        parent_paths.append(longest)

# Visualization

In [19]:
# Fix NumPy's global RNG so the colours drawn by show_mask(random_color=True) are reproducible.
np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders = True):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: array whose last two dimensions are (height, width); it is cast
            to ``uint8`` before use.
        ax: object exposing ``imshow`` (e.g. a Matplotlib axes).
        random_color: pick a random RGB colour (alpha 0.6) instead of the
            default blue.
        borders: additionally outline the external contours of the mask.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    height, width = mask.shape[-2:]
    binary = mask.astype(np.uint8)
    # Broadcast the colour over every mask pixel -> (height, width, 4).
    overlay = binary.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        import cv2
        outlines, _ = cv2.findContours(binary,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Light polygon approximation to smooth the contours.
        outlines = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
        overlay = cv2.drawContours(overlay, outlines, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: (N, 2) array; column 0 is plotted on x, column 1 on y.
        labels: length-N array of 0/1 labels aligned with ``coords``.
        ax: object exposing ``scatter``.
        marker_size: marker area passed to ``scatter`` as ``s``.
    """
    # Positives are drawn first, then negatives.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels==label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline ``box`` on ``ax`` with an unfilled green rectangle.

    Args:
        box: sequence ``(x0, y0, x1, y1)`` of opposite corners.
        ax: object exposing ``add_patch``.
    """
    origin = (box[0], box[1])
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(origin, width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Show every candidate mask over ``image`` in its own 10x10 figure.

    Args:
        image: background image passed to ``plt.imshow``.
        masks: iterable of masks, paired element-wise with ``scores``.
        scores: per-mask scores; shown in the title when there is more than one.
        point_coords: optional prompt points drawn with ``show_points``
            (requires ``input_labels``).
        box_coords: optional box drawn with ``show_box``.
        input_labels: labels for ``point_coords``.
        borders: forwarded to ``show_mask``.
    """
    for rank, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()
        show_mask(mask, axes, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)
        if box_coords is not None:
            # boxes
            show_box(box_coords, axes)
        if len(scores) > 1:
            plt.title(f"Mask {rank}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

# Metrics

In [20]:
def eval(pred, label):
    """Compute overlap metrics between a predicted and a reference binary mask.

    Both inputs are binarised (non-zero = foreground), so 0/1, 0/255 and
    boolean masks are all accepted.  Ratios with an empty denominator are
    reported as 0.0 instead of producing NaN.

    NOTE: the name shadows the builtin ``eval`` inside this notebook; it is
    kept because other cells call it under this name.

    Args:
        pred: predicted mask.
        label: ground-truth mask of the same shape.

    Returns:
        dict with ``dice``, ``iou``, ``acc`` (pixel accuracy), ``precision``
        (TP / predicted positives), ``recall`` (TP / actual positives) and
        ``f1`` (harmonic mean of precision and recall).
    """
    pred_fg = np.asarray(pred).astype(bool)
    label_fg = np.asarray(label).astype(bool)

    def ratio(num, den):
        return float(num) / float(den) if den else 0.0

    tp = int(np.logical_and(pred_fg, label_fg).sum())
    union = int(np.logical_or(pred_fg, label_fg).sum())
    n_pred = int(pred_fg.sum())
    n_label = int(label_fg.sum())

    # |pred| + |label| == intersection + union, so this is the usual Dice score.
    dice = 2 * tp / (tp + union + 1e-6)
    iou = ratio(tp, union)
    acc = ratio(int((pred_fg == label_fg).sum()), label_fg.size)
    # Previously TP / pred.sum() was reported as "recall"; that ratio is precision.
    precision = ratio(tp, n_pred)
    recall = ratio(tp, n_label)
    f1 = ratio(2 * precision * recall, precision + recall)
    return {'dice': dice,
            'iou': iou,
            'acc': acc,
            'precision': precision,
            'recall': recall,
            'f1': f1}
    

# Testing

## Paths

In [21]:
%env CUDA_VISIBLE_DEVICES=3

env: CUDA_VISIBLE_DEVICES=3


In [22]:
%pwd

'/work/hpc/potato/airc/notebooks'

In [23]:
# Input image plus SAM2 checkpoint/config locations used by the cells below.
# NOTE(review): `checkpoint` is an absolute, machine-specific path — consider a configurable base directory.
image_path = "../data/v2/2015.png"
checkpoint = "/work/hpc/potato/sam/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
# NOTE(review): this names the "small" variant while `checkpoint`/`model_cfg` are "base_plus",
# and it is not passed to the SamInferer(...) call in this notebook — confirm which is intended.
model_id =  "facebook/sam2.1-hiera-small"
# y, x, h, w = 2400, 0, 1200, 1200
# pts = np.array([570, 1660])
# Seed point handed to param.add_queue(..., isroot=True); presumably (row, col) — TODO confirm.
first_pts = (1900, 2025)
patch_size = [512, 512]

In [24]:
%ls /work/hpc/potato/sam/sam2/checkpoints

[0m[01;32mdownload_ckpts.sh[0m*         sam2.1_hiera_large.pt  sam2.1_hiera_tiny.pt
sam2.1_hiera_base_plus.pt  sam2.1_hiera_small.pt


## First attempt

In [25]:
# # A wrapper for Sam inference
# class SamInferer:
#     def __init__(self, cfg = "", 
#                     ckpt: str = "", 
#                     patch_size = [512, 512],
#                     patience = 5,
#                     min_dis = 10,
#                     alpha=0.1, 
#                     beta=0.5, 
#                     post_act=True, 
#                     min_length=5,
#                     kernel_size=3,
#                     fill_kernel_size=7,
#                     neg_dis=15,
#                     sampling_grid=6,
#                     thresh=0.75,
#                     decay=0.5
#                     ):

In [26]:
# Build the inference wrapper from the local config + checkpoint defined above.
# Every keyword below overrides a SamInferer constructor default for this experiment.
param = SamInferer(cfg=model_cfg, 
                    ckpt=checkpoint,
                    roi=[256, 256],
                    root_area=1200,
                    max_roots=2,
                    min_dis=15,
                    patience=4,
                    beta=0.6,
                    alpha=0.2,
                    decay=0.3,
                    patch_size=patch_size,
                    fill_kernel_size=5,
                    thresh=0.75,
                    min_length=50,
                    back_off=5,
                    neg_dis=20,
                    )

In [27]:
iterations = 0

In [28]:
param.read(image_path)

In [29]:
param.add_queue(first_pts, isroot=True)

In [30]:
# plt.imshow(skimage.morphology.skeletonize(output['possible']))
# Run one inference step and render the result for the current region of interest.
# This cell is meant to be re-run: every execution advances `iterations` by one.
iterations += 1
print(f"Iteration {iterations}")
output = param.iter(debug=True)
print(param.roi)
if output['ret'] is True:
    # param.roi holds the two corners of the region; crop everything to it.
    src, dst = param.roi[:2], param.roi[2:]
    image = param.image[src[0]:dst[0], src[1]:dst[1], ::-1]
    # logit = param.logits[src[0]:dst[0], src[1]:dst[1]] / param.weight[src[0]:dst[0], src[1]:dst[1]]
    logit = output['prob_map']
    p_mask = param.b_mask[src[0]:dst[0], src[1]:dst[1]].copy()
    # p_mask = (p_mask >= np.quantile(p_mask[p_mask > 0], 0.5)).astype(np.uint8)
    # Plotting: blend the image towards white by the probability map,
    # then black out pixels already present in the binary mask.
    plot = image[:, :, ::-1] * (1- logit[..., None]) / 255 + logit[..., None]
    plot = plot * (1 - p_mask[..., None])
    # plot = plot.astype(np.uint8)
    # plt.imshow(plot)
    # plt.colorbar()
    # 
    cv2.imwrite("/work/hpc/potato/airc/data/viz/closed.jpg", (np.repeat(output['thin'][:, :, None] / output['thin'].max(), 3, axis=-1) * 255).astype(int).astype(np.uint8))
    branches = [np.array(branch) for branch in output['branches']]
    # Shift prompt points from patch coordinates into ROI coordinates (x, y order).
    annotation = output['infer']['pts'] + (param.root - param.roi[:2])[::-1]
    a_label = output['infer']['label'][:, None]
    # Green for positive prompts, red for negative ones.
    color = [0, 1, 0, 0.5] * a_label + [1, 0, 0, 0.5] * (1 - a_label)
    # plt.scatter(annotation[:, 0], annotation[:, 1], s=50, c=color, marker='*')
    # plot = image.astype(float) / 255
    for pt, c in zip(annotation, color):
        cv2.circle(plot, pt, 5, c, -1)
    cv2.circle(plot, (np.array(output['infer']['prompt']) - src)[::-1], 10, [1, 1, 0, 0.5], -1)
    # plt.title(f"Score: {output['infer']['score']}")
    # pt = output['infer']['prompt']
    # plt.scatter(pt[1], pt[0])
    cmap = plt.get_cmap('hsv')
    for i, branch in enumerate(branches):
        c_val = cmap(float(i) / len(branches))
        # print(c_val)
        # rgb = (int(c_val[0] * 255), int(c_val[1] * 255), int(c_val[2] * 255))
        cv2.polylines(plot, [branch[:, ::-1]], False, c_val, 2)
    cv2.imwrite(f"/work/hpc/potato/airc/data/viz/iteration_{iterations}_v4.jpg", (plot * 255)[..., ::-1])
    # Create the figure before drawing: calling plt.figure() after imshow/title
    # opens a new, empty figure and the requested size never applies to the plot.
    plt.figure(figsize=(10,6))
    plt.imshow(plot)
    plt.title(f"Metrics: {output['metrics']}")
else:
    print("Hehe")

Iteration 1
(256, 256) (1,) (1, 2)
[512 512] (512, 512) float32
Score 0.8727840781211853
With 1.0 positives and 0.0 negatives
{0: 0, np.float32(0.11066823): 0, np.float32(0.98896164): 1}


ValueError: operands could not be broadcast together with shapes (512,512) (0,0) 

In [None]:
# plt.imshow(cv2.threshold((logit * 255).astype(int).astype(np.uint8), 20, 1, cv2.THRESH_OTSU + cv2.THRESH_BINARY)[1])
# plt.colorbar()
# Experimental soft-mask: weight the ROI probabilities by an exponential kernel around the
# local beta threshold, then combine adaptive and global thresholding into a binary mask.
# Relies on `src`, `dst`, `logit` and `image` from the iteration cell above.
gain = np.sqrt(np.pi)
# `self` does not exist at cell level (NameError); the inferer instance is `param`.
std = param.beta.mean() * param.w_kernel.mean() / scipy.ndimage.gaussian_filter(param.weight[src[0]:dst[0], src[1]:dst[1]], 30)
dirac = lambda x, mean, std:  np.exp((x - mean) / (std * 1.44)) / (2.5 * std)
beta = param.beta[src[0]:dst[0], src[1]:dst[1]] / param.weight[src[0]:dst[0], src[1]:dst[1]]

# NOTE(review): `gain` and `p_logit` are not used in this cell — confirm whether later cells need them.
p_logit = skimage.restoration.denoise_wavelet(logit[..., None], channel_axis=-1, rescale_sigma=True)
softmask = dirac(logit ** 1.5, beta, std) * (logit > param.alpha)
# Rescale to [0, 1] so fixed thresholds and the 8-bit conversion are meaningful.
s = np.interp(softmask, (softmask.min(), softmask.max()), (0, 1))
pseudo_mask = (s > 0.4).astype(np.uint8)
cand = np.minimum(s * 255, 255).astype(np.uint8)
# softmask = sigmoid(cand)

# softmask[softmask > 255] = 255
mask = cv2.adaptiveThreshold(cand, 1, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 9, -5) * (softmask > param.alpha)
# most_confidence =
# mask = cv2.threshold(cand.astype(np.uint8), 200, 1, cv2.THRESH_OTSU + cv2.THRESH_BINARY)[1]
# radius = 15
# footprint = skimage.morphology.disk(radius)
# local_otsu = skimage.filters.rank.otsu(softmask.astype(np.uint8), footprint)
# mask = (softmask > local_otsu).astype(np.uint8)
# mask = prune(mask, min_size=50)
mask = np.maximum(mask, pseudo_mask)
# The 4th positional parameter of cv2.morphologyEx is `dst`, not `iterations`;
# pass the iteration count by keyword so two opening passes are actually applied.
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones([3, 3], dtype=np.uint8), iterations=2)

# NOTE(review): the panel titles do not match what is drawn (ax1 shows the image crop,
# ax3 the global-threshold pseudo mask, ax4 the combined mask) — confirm the intended labels.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=[8, 6])
im = ax1.imshow(image[..., ::-1])
ax1.set_title("Original Logit")
fig.colorbar(im, ax=ax1)

im=ax2.imshow(s)
ax2.set_title("Softmask")
fig.colorbar(im, ax=ax2)

im=ax3.imshow(pseudo_mask)
ax3.set_title("Adaptive Thresholding")
fig.colorbar(im, ax=ax3)

ax4.imshow(mask)
ax4.set_title("Beta Map")
plt.show()

In [None]:
# Locate the strongest soft-mask response.
# NOTE(review): np.argmax returns a C-order flat index but it is unravelled with order='F';
# on a square array this yields the transposed coordinate (col, row) — confirm this is intended.
index = np.argmax(softmask)
pos = np.unravel_index(index, softmask.shape, order='F')
print(pos)

In [None]:
pts = output['infer']['pts']
plt.imshow(output['infer']['input'])
plt.axis('off')
# plt.scatter(pts[:, 0], pts[:, 1], marker='s', s=5, color=(1, 0, 1))
# plt.colorbar()

In [None]:
output['infer']['negative'].keys()

In [None]:
x, y = param.root
plt.imshow(param.image[x:x+512, y:y+512])

In [None]:
# Outline the current binary mask on the ROI crop.
# The drawing goes into its own variable: rebinding `output` here would replace the
# result dict of param.iter() that later cells index (output['canvas'], output['thin'], ...).
cnts, _ = cv2.findContours(p_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
contour_viz = cv2.drawContours(image.copy(), cnts, -1, (0,0,255), 3)
plt.imshow(contour_viz[..., ::-1])
plt.axis('off')

In [None]:
mask = param.label[src[0]:dst[0], src[1]:dst[1], None] 
pred = param.b_mask[src[0]:dst[0], src[1]:dst[1], None]
# Red for FP, Blue for FN
# fp = pred * (1 - mask) * np.array([0, 0, .2])[None, None, :]
# fn = mask * (1 - pred)  * np.array([.2, 0, 0])[None, None, :]
# masks = pred
# save_img = image * (1 - masks) + 255 * masks


In [None]:
# print(pred.shape, image.shape)
cv2.imwrite("/work/hpc/potato/airc/data/viz/image.jpg", param.image[..., ::-1])

In [None]:
cv2.imwrite("/work/hpc/potato/airc/data/viz/seg_image.jpg", image)
# cv2.imwrite("/work/hpc/potato/airc/data/viz/seg_pred.jpg", pred * 255)
# cv2.imwrite("/work/hpc/potato/airc/data/viz/seg_label.jpg", mask * 255)

In [None]:
# param.b_mask[src[0]:dst[0], src[1]:dst[1]] = mask

In [None]:
# Unsharp masking in HSV space: sharpened = 2 * original - smoothed.
# NOTE(review): other cells treat param.image as RGB (it is channel-reversed before
# cv2.imwrite), so COLOR_BGR2HSV may be the wrong conversion code — confirm.
hsv_image = cv2.cvtColor(param.image, cv2.COLOR_BGR2HSV)
unsharp = cv2.bilateralFilter(hsv_image,9,50,75)
# cv2.addWeighted saturates the result; plain `2 * hsv_image - unsharp` on 8-bit data
# wraps around modulo 256 and corrupts bright/dark pixels.
hsv_image = cv2.addWeighted(hsv_image, 2.0, unsharp, -1.0, 0)

In [None]:
# mask = logit > 0.2
# kernel = np.ones([3, 3], dtype=np.uint8)
# skeleton = skimage.morphology.skeletonize(param.label)
# plt.imshow(param.image[..., ::-1])
# plt.imshow(image[..., ::-1])
plt.imshow(hsv_image[src[0]:dst[0], src[1]:dst[1], 2])
# keypoints = np.argmax(logit, axis=1)
# print(keypoints.shape,  logit.shape, logit.max())
# pts = [logit[i, pt] for i, pt in enumerate(keypoints)]
# chosen = np.argmax(pts)
# pt = keypoints[chosen]
# print((chosen, pt) + src, logit[chosen, pt], src)
# plt.colorbar()
# plt.imshow(hsv_image[src[0]:dst[0], src[1]:dst[1], 0] )

In [None]:
param.graph_root = (453, 1488)

In [None]:
bgr_image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
plt.imshow(bgr_image[src[0]:dst[0], src[1]:dst[1]])

In [None]:
# Rebinds `logit` to the full-size map (earlier cells used an ROI crop under the same name)
# and overwrites param.b_mask with a fresh threshold against param.beta.
# NOTE(review): divides by param.weight with no zero guard — presumably weight > 0 everywhere; confirm.
logit = param.logits / param.weight
param.b_mask = (logit ** 1.5 > param.beta).astype(np.uint8)

In [None]:
# Run FastGeodis GSF2d on the GPU with the normalised canvas and the logit map as inputs,
# then show the resulting map; the title is its 0.2-quantile over values above the minimum.
# Requires a CUDA device. NOTE(review): `x` rebinds the name used for param.root coordinates elsewhere.
# x = torch.from_numpy(param.a_mask[src[0]:dst[0], src[1]:dst[1]].copy()).float().cuda()
x = torch.from_numpy(logit).float().cuda()
canvas = torch.from_numpy(output['canvas']).float().cuda()
# Scale the canvas so its maximum is 1.
canvas /= canvas.max()
# [None, None, :] adds two leading axes (presumably batch and channel — TODO confirm).
dist = geo.GSF2d(canvas[None, None, :], x[None, None, :], theta=1., v=8, lamb=.8, iter = 4).cpu().numpy()[0, 0]
# plt.imshow(dist > np.quantile(dist[dist > 0], 0.2))
plt.imshow(dist)
plt.colorbar()
plt.title(f"{ np.quantile(dist[dist > dist.min()], 0.2)}")

In [None]:
# plt.imshow(skimage.morphology.skeletonize(output['possible']))

# Patch-level view of the last step: blended probability overlay, prompt points and branches.
# The patch spans param.root .. param.root + param.patch_size
# (assumes param.root is a NumPy array so `+` is element-wise — TODO confirm).
src, dst = param.root, param.root + param.patch_size
image = param.image[src[0]:dst[0], src[1]:dst[1]]
logit = param.logits[src[0]:dst[0], src[1]:dst[1]] / param.weight[src[0]:dst[0], src[1]:dst[1]]
# Plotting
plot = image[:, :, ::-1] * (1- logit[..., None]) / 255 + logit[..., None]
plot = plot * (1 - param.b_mask[src[0]:dst[0], src[1]:dst[1]][..., None])
plt.imshow(plot)
# plt.colorbar()
# Writes the normalised `thin` map to the same closed.jpg file that other cells also write.
cv2.imwrite("/work/hpc/potato/airc/data/viz/closed.jpg", (np.repeat(output['thin'][:, :, None] / output['thin'].max(), 3, axis=-1) * 255).astype(int).astype(np.uint8))
branches = [np.array(branch) for branch in output['branches']]
# Prompt position relative to the patch origin.
prompt = output['infer']['prompt'] - param.root
annotation = output['infer']['pts']
a_label = output['infer']['label'][:, None]
# Blue for label 1, red for label 0.
color = [0, 0, 1, 0.5] * a_label + [1, 0, 0, 0.5] * (1 - a_label)
plt.title(f"Score: {output['infer']['score']}")
plt.scatter(annotation[:, 0], annotation[:, 1], s=50, c=color, marker='*')
plt.scatter(prompt[1], prompt[0], s=100, c=[1, 1, 0, 1], marker='X')
# pt = output['infer']['prompt']
# plt.scatter(pt[1], pt[0])
# Branch points are (row, col); plot expects (x, y).
for branch in branches:
    plt.plot(branch[:, 1], branch[:, 0])

In [None]:
# Show the pixels whose probability lies above the p-quantile of the last inference.
p = 0.98
# prob_map must exist before its quantile is taken; on a fresh kernel the previous
# order raised NameError (or silently reused a stale prob_map from an earlier run).
prob_map = sigmoid(output['infer']['logit'])
score = np.quantile(prob_map, p)
plt.imshow(prob_map > score)
plt.title(f"{p}-quantile {score}")

In [None]:
param.queue

In [None]:
dist = param.pos[:, None, :] - param.pos[None, :, :]
loss = np.linalg.norm(dist, axis=2)

In [None]:
loss

In [None]:
print(param.pos)

In [None]:
plt.imshow(logit ** 2)
plt.title(f"Score {output['infer']['score']}")
plt.colorbar()

In [None]:
plt.imshow(logit > 0.458)

In [None]:
plt.imshow(logit > param.alpha)

In [None]:
print(logit.min(), logit.max())

In [None]:
ret = logit ** 0.3
plt.imshow(image[:, :, ::-1] * (1- ret[..., None]) / 255 + ret[..., None])

plt.scatter(pt[1], pt[0], s=0.5, c=(1, 1, 1))

In [None]:
# Advance the inferer by one step (non-debug) and plot the patch overlay with its branches.
# NOTE(review): near-duplicate of the patch plotting cell above — a shared helper would avoid drift.
output = param.iter()
src, dst = param.root, param.root + param.patch_size
image = param.image[src[0]:dst[0], src[1]:dst[1]]
logit = param.logits[src[0]:dst[0], src[1]:dst[1]] / param.weight[src[0]:dst[0], src[1]:dst[1]]
# Plotting
plot = image[:, :, ::-1] * (1- logit[..., None]) / 255 + logit[..., None]
plot = plot * (1 - param.b_mask[src[0]:dst[0], src[1]:dst[1]][..., None])
plt.imshow(plot)
# plt.colorbar()
cv2.imwrite("/work/hpc/potato/airc/data/viz/closed.jpg", (np.repeat(output['thin'][:, :, None] / output['thin'].max(), 3, axis=-1) * 255).astype(int).astype(np.uint8))
branches = [np.array(branch) for branch in output['branches']]
# Prompt position relative to the patch origin, drawn as (x, y) = (col, row).
pt = output['infer']['prompt'] - param.root
plt.scatter(pt[1], pt[0])
for branch in branches: 
    # Number of points in this branch.
    print(branch.shape[0])
    plt.plot(branch[:, 1], branch[:, 0])

In [None]:
plt.imshow(mask)
plt.title(f"Score: {score}")

In [None]:
prob = sigmoid(logit)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
possible = cv2.morphologyEx(np.where(prob > param.alpha, 1, 0).astype(int).astype(np.uint8), cv2.MORPH_CLOSE, kernel, 1)
possible = scipy.ndimage.binary_fill_holes(possible)
possible = getLargestCC(possible).astype(int).astype(np.uint8)
plt.imshow(possible)

In [None]:
thin, thin_w = morpholgy_thinning(possible, return_weight=True)

In [None]:
plt.imshow(thin_w * (thin_w < thin_w.max() * 0.75))
plt.title(f"Max:{thin_w.max()} ")
plt.colorbar()

In [None]:
# Keep skeleton pixels whose weight is positive but below 75% of the maximum, then dilate them.
# The structuring element was previously built and discarded, so the dilation silently
# reused whatever `kernel` an earlier cell had left behind; bind it and use it explicitly.
cross_kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (5, 5))
stable = ((thin_w > 0) & (thin_w < thin_w.max() * 0.75)).astype(int).astype(np.uint8)
# `iterations` must be passed by keyword: the 4th positional parameter is `dst`.
closed = cv2.morphologyEx(stable, cv2.MORPH_DILATE, cross_kernel, iterations=1)
cv2.imwrite("/work/hpc/potato/airc/data/viz/closed.jpg", np.repeat(closed[:, :, None], 3, axis=-1) * 255)
plt.imshow(closed)

In [None]:
# Skeletonise the dilated mask and keep only its largest connected component.
skeleton = skimage.morphology.skeletonize(closed)
# NOTE(review): this first image is drawn over by the second imshow below in the same axes.
plt.imshow(skeleton)
main_branch = getLargestCC(skeleton).astype(int).astype(np.uint8)
plt.imshow(main_branch)

In [None]:
w_map = (main_branch * prob) / score
w_map[mask * main_branch == 1] = 1
cv2.imwrite("/work/hpc/potato/airc/data/viz/output.jpg", (np.repeat(w_map[:, :, None], 3, axis=2) * 255).astype(int).astype(np.uint8))
plt.imshow(w_map)
plt.colorbar()

The logit map has the same shape as the mask prompt, which is half the size of the image input.

In [None]:
# Always start at stable mask
w_pts = np.array(np.where(w_map == 1)).T
print(w_pts.shape)
dist = np.linalg.norm(w_pts - res['pts'][-1] / 2, axis=1)
nn = w_pts[np.argmin(dist)]
print(nn, res['pts'][-1] / 2)

In [None]:
dfs_res = dfs_tree(w_map.copy(), res['inp_mask'], tuple(nn), alpha=0.01, thresh=0.8)

In [None]:
# longest_path_branching returns every side branch followed by the longest path as the
# last element, so a two-name unpacking only works (and swaps the roles) for exactly two paths.
*branches, path = longest_path_branching(dfs_res['dfs_tree'], tuple(nn))

In [None]:
path = np.array(path)
branches = [np.array(branch) for branch in branches]

In [None]:
# Show the traversal cost map with the main path, the side branches and the leaves on top.
leaves = np.array(list(dfs_res['leaves']))
# Work on a copy so the cost array stored in dfs_res is not modified by the display clean-up.
cost_map = dfs_res['cost'].copy()
# Unreached pixels keep the 1e4 initial cost from dfs_tree; zero them for display.
cost_map[cost_map > 100] = 0
plt.plot(path[:, 1], path[:, 0])
for branch in branches: 
    plt.plot(branch[:, 1], branch[:, 0])
# The 0.3 power compresses the range so low costs remain visible.
plt.imshow(cost_map ** 0.3)
plt.scatter(leaves[:, 1], leaves[:, 0], s=0.2)
plt.colorbar()

In [None]:
# Keep only the branches that contain more than 10 points.
accepted_branch = [branch for branch in branches if len(branch) > 10]

In [None]:
accepted_branch

In [None]:
len(accepted_branch[0])