In [1]:
import cv2
import numpy as np

class Edge:
    """One directed arc of a residual flow network.

    Attributes:
        to: Index of the vertex this arc points to.
        capacity: Remaining (residual) capacity of the arc.
        rev: Position of the paired reverse arc inside the adjacency list
            of ``to``.
    """

    def __init__(self, to, capacity, rev):
        # All three fields are plain, mutable attributes; the solver updates
        # ``capacity`` in place while pushing flow.
        self.to, self.capacity, self.rev = to, capacity, rev

class Graph:
    """Adjacency-list residual graph used by the max-flow solver.

    Attributes:
        size: Number of vertices; valid indices are ``0 .. size - 1``.
        graph: ``graph[u]`` is the list of ``Edge`` objects leaving ``u``
            (forward arcs and zero-capacity reverse arcs alike).
    """

    def __init__(self, N):
        self.size = N
        self.graph = [[] for _ in range(N)]

    def add_edge(self, fr, to, capacity):
        """Add a directed arc ``fr -> to`` plus its zero-capacity reverse arc.

        Arcs with non-positive capacity are ignored (nothing is added).

        Raises:
            IndexError: if ``fr`` or ``to`` is not a valid vertex index.
        """
        if capacity <= 0:
            return
        if not (0 <= fr < self.size) or not (0 <= to < self.size):
            raise IndexError(f"Edge from {fr} to {to} out of bounds (size {self.size})")
        # Each arc stores where its partner lives in the *other* endpoint's
        # list. For a self-loop both arcs are appended to the same list, so
        # the reverse arc ends up one slot after the forward arc.
        self_loop_shift = 1 if fr == to else 0
        fwd = Edge(to, capacity, len(self.graph[to]) + self_loop_shift)
        rev = Edge(fr, 0, len(self.graph[fr]))
        self.graph[fr].append(fwd)
        self.graph[to].append(rev)

In [2]:
class Dinic:
    """Dinic-style maximum-flow solver operating on a residual graph.

    The graph object must expose ``size`` (vertex count) and ``graph``
    (adjacency lists of edges with ``to``, ``capacity`` and ``rev``).
    Edge capacities are modified in place, so after ``max_flow`` the graph
    holds the residual network.
    """

    def __init__(self, graph, s, t):
        self.graph = graph
        self.s = s
        self.t = t
        self.level = [0] * graph.size      # BFS distance from s, -1 if unreached
        self.progress = [0] * graph.size   # per-vertex index of the next arc to try

    def bfs(self):
        """Label every vertex with its BFS distance from the source.

        Only arcs with positive residual capacity are followed.

        Returns:
            True if the sink was reached, False otherwise.
        """
        self.level = [-1] * self.graph.size
        self.level[self.s] = 0
        queue = [self.s]
        # The list is extended while being iterated, so it acts as a FIFO queue.
        for u in queue:
            next_level = self.level[u] + 1
            for e in self.graph.graph[u]:
                if e.capacity > 0 and self.level[e.to] < 0:
                    self.level[e.to] = next_level
                    queue.append(e.to)
        return self.level[self.t] >= 0

    def dfs(self, u, flow):
        """Push flow along one level-increasing path from ``u`` to the sink.

        Implemented with an explicit stack instead of recursion: augmenting
        paths can be as long as the number of vertices, which would exceed
        Python's recursion limit on large graphs.

        Args:
            u: Vertex to start from.
            flow: Upper bound on the amount to push.

        Returns:
            The amount pushed (the path bottleneck capped by ``flow``), or 0
            when no augmenting path remains in the current level graph.
        """
        if u == self.t:
            return flow
        adj = self.graph.graph
        level = self.level
        progress = self.progress
        nodes = [u]   # vertices on the current partial path
        path = []     # arcs joining consecutive entries of ``nodes``
        while nodes:
            v = nodes[-1]
            if v == self.t:
                pushed = min(flow, min(e.capacity for e in path))
                if pushed <= 0:
                    # Only possible when the caller passed a non-positive bound.
                    return 0
                for e in path:
                    e.capacity -= pushed
                    adj[e.to][e.rev].capacity += pushed
                return pushed
            edges = adj[v]
            # Advance the current-arc pointer to the next usable arc of v.
            while progress[v] < len(edges):
                e = edges[progress[v]]
                if e.capacity > 0 and level[v] < level[e.to]:
                    break
                progress[v] += 1
            else:
                # Dead end: retreat one vertex and discard the arc that led here.
                nodes.pop()
                if path:
                    path.pop()
                    progress[nodes[-1]] += 1
                continue
            nodes.append(e.to)
            path.append(e)
        return 0

    def max_flow(self, eps=1e-6):
        """Compute the maximum flow from ``s`` to ``t``.

        Args:
            eps: Augmentations of at most this size end the current phase
                and are not added to the returned total.

        Returns:
            The accumulated flow value.
        """
        flow = 0
        INF = float('inf')
        while self.bfs():
            # Fresh level graph: restart every current-arc pointer.
            self.progress = [0] * self.graph.size
            pushed = self.dfs(self.s, INF)
            while pushed > eps:
                flow += pushed
                pushed = self.dfs(self.s, INF)
        return flow

In [3]:
class BinaryGraphCutAlpha:
    """Seed-driven foreground/background segmentation by s-t minimum cut.

    Pixels are graph vertices ``0 .. h*w - 1`` (row-major); two extra
    vertices ``S`` and ``T`` act as source and sink. Neighbouring pixels are
    joined by contrast-sensitive links, and every pixel is joined to the
    terminals by costs derived from Gaussian intensity models of the seeds.

    Attributes:
        gray: Grayscale image rescaled to the range [0, 255].
        nlinks: List of ``(u, v, weight)`` neighbour links.
        fg_seeds / bg_seeds: Lists of ``(y, x)`` seed coordinates.
        labeling: Flat int32 array, 1 = foreground, 0 = background.
        D0 / D1: Per-pixel cost of labelling background / foreground
            (available after ``alpha_expansion`` has run).
    """

    # Capacity of a hard terminal link that pins a seed pixel to its label.
    _HARD_CAPACITY = 1e9

    def __init__(self, image, lam=1.0, sigma=1.0):
        """Prepare the grayscale image and the neighbour links.

        Args:
            image: 2-D grayscale or 3-D BGR image array.
            lam: Multiplier applied to every neighbour-link weight.
            sigma: Scale of the intensity difference in the link weight.
        """
        self.img = image.astype(np.float32)
        if self.img.ndim == 3:
            self.gray = cv2.cvtColor(self.img, cv2.COLOR_BGR2GRAY)
        else:
            self.gray = self.img.copy()
        self.gray = cv2.normalize(self.gray, None, 0, 255, cv2.NORM_MINMAX)
        self.h, self.w = self.gray.shape
        n_pixels = self.h * self.w
        self.N = n_pixels + 2      # pixels + source + sink
        self.S = n_pixels          # source vertex index
        self.T = n_pixels + 1      # sink vertex index
        self.lam = lam
        self.sigma = sigma
        self._compute_beta()
        self._build_nlinks()
        self.fg_seeds = []
        self.bg_seeds = []
        self.labeling = np.zeros(n_pixels, dtype=np.int32)

    def _idx(self, y, x):
        """Return the row-major vertex index of pixel ``(y, x)``."""
        return y * self.w + x

    def _compute_beta(self):
        """Set ``self.beta`` to ``1 / (2 * mean squared neighbour difference)``.

        The mean runs over all horizontally and vertically adjacent pixel
        pairs. An image without any such pair (1x1) uses a mean of 0 instead
        of producing NaN.
        """
        gray = np.asarray(self.gray, dtype=np.float64)
        sq_diffs = np.concatenate([
            ((gray[:, :-1] - gray[:, 1:]) ** 2).ravel(),   # right neighbours
            ((gray[:-1, :] - gray[1:, :]) ** 2).ravel(),   # lower neighbours
        ])
        mean_sq = sq_diffs.mean() if sq_diffs.size else 0.0
        self.beta = 1.0 / (2 * mean_sq + 1e-10)

    def _build_nlinks(self):
        """Fill ``self.nlinks`` with ``(u, v, weight)`` for each 4-neighbour.

        Every pixel lists all of its in-bounds neighbours, so each adjacent
        pair appears twice (once from either side).
        """
        offsets = ((0, 1), (1, 0), (-1, 0), (0, -1))
        inv_sigma_sq = 1.0 / (self.sigma ** 2)
        self.nlinks = []
        for y in range(self.h):
            for x in range(self.w):
                u = self._idx(y, x)
                Ip = float(self.gray[y, x])
                for dy, dx in offsets:
                    ny, nx = y + dy, x + dx
                    if not (0 <= ny < self.h and 0 <= nx < self.w):
                        continue
                    Iq = float(self.gray[ny, nx])
                    wgt = self.lam * np.exp(-self.beta * (Ip - Iq) ** 2 * inv_sigma_sq)
                    self.nlinks.append((u, self._idx(ny, nx), wgt))

    def set_seeds(self, fg_seeds, bg_seeds):
        """Replace both seed lists and stamp the seeds into ``labeling``.

        Background seeds are written last, so a coordinate present in both
        lists ends up labelled background.
        """
        self.fg_seeds = list(fg_seeds)
        self.bg_seeds = list(bg_seeds)
        for (y, x) in self.fg_seeds:
            self.labeling[self._idx(y, x)] = 1
        for (y, x) in self.bg_seeds:
            self.labeling[self._idx(y, x)] = 0

    def add_seeds(self, new_fg, new_bg):
        """Append seeds that are not stored yet and stamp their labels."""
        for incoming, stored, label in ((new_fg, self.fg_seeds, 1),
                                        (new_bg, self.bg_seeds, 0)):
            # A set makes the duplicate check O(1) per seed.
            known = set(stored)
            for (y, x) in incoming:
                if (y, x) in known:
                    continue
                known.add((y, x))
                stored.append((y, x))
                self.labeling[self._idx(y, x)] = label

    def _unary_costs(self):
        """Compute per-pixel label costs ``D0`` (background) and ``D1`` (foreground).

        Each cost is the negative log-likelihood of the pixel intensity under
        a Gaussian fitted to the corresponding seed intensities. If a seed
        list is empty, the statistics of the whole image are used for both
        mean and variance (the variance of an empty list would be NaN).

        The two costs of every pixel are then shifted by their minimum. This
        adds the same constant to both labels of a pixel, so the optimal
        labelling is unchanged, but it guarantees non-negative capacities;
        negative values would otherwise be silently dropped by
        ``Graph.add_edge``.
        """
        gray = np.asarray(self.gray, dtype=np.float64)
        flat = gray.ravel()   # row-major, matches _idx

        def gaussian_params(seeds):
            if seeds:
                vals = np.array([gray[y, x] for (y, x) in seeds], dtype=np.float64)
            else:
                vals = flat
            return vals.mean(), vals.var() + 1e-6

        mu_fg, var_fg = gaussian_params(self.fg_seeds)
        mu_bg, var_bg = gaussian_params(self.bg_seeds)
        cost_bg = (flat - mu_bg) ** 2 / (2 * var_bg) + 0.5 * np.log(2 * np.pi * var_bg)
        cost_fg = (flat - mu_fg) ** 2 / (2 * var_fg) + 0.5 * np.log(2 * np.pi * var_fg)
        shift = np.minimum(cost_bg, cost_fg)
        self.D0 = (cost_bg - shift).astype(np.float32)
        self.D1 = (cost_fg - shift).astype(np.float32)

    def _build_graph_alpha(self, alpha):
        """Build the flow network for one pass.

        Pixels that end up on the source side of the minimum cut take the
        label ``alpha``. For ``alpha == 1`` the source link carries ``D0``
        and the sink link ``D1``; for ``alpha == 0`` the roles are swapped.
        Seeds are pinned with hard links on *both* terminals so that neither
        pass can flip a user-provided seed.
        """
        G = Graph(self.N)
        # NOTE(review): self.nlinks already lists each neighbour pair from
        # both sides and every entry is inserted in both directions here, so
        # each arc effectively carries twice the weight computed in
        # _build_nlinks. Kept unchanged so existing ``lam`` values produce
        # the same segmentation.
        for u, v, wgt in self.nlinks:
            G.add_edge(u, v, wgt)
            G.add_edge(v, u, wgt)

        # Seed lookups as sets, built once instead of once per pixel.
        bg_idx = {self._idx(y, x) for (y, x) in self.bg_seeds}
        fg_idx = {self._idx(y, x) for (y, x) in self.fg_seeds} - bg_idx  # background wins conflicts

        if alpha == 1:
            source_cost, sink_cost = self.D0, self.D1
            source_seeds, sink_seeds = fg_idx, bg_idx
        else:
            source_cost, sink_cost = self.D1, self.D0
            source_seeds, sink_seeds = bg_idx, fg_idx

        for p in range(self.h * self.w):
            G.add_edge(self.S, p, source_cost[p])
            G.add_edge(p, self.T, sink_cost[p])
            if p in source_seeds:
                G.add_edge(self.S, p, self._HARD_CAPACITY)
            elif p in sink_seeds:
                G.add_edge(p, self.T, self._HARD_CAPACITY)
        return G

    def alpha_expansion(self, max_iter=10):
        """Alternate label-1 and label-0 passes until the labelling is stable.

        Each pass solves a max-flow problem, finds the vertices still
        reachable from the source in the residual graph, and assigns the
        pass label to those pixels.

        Args:
            max_iter: Maximum number of (label 1, label 0) sweeps.

        Returns:
            The labelling reshaped to ``(h, w)``.
        """
        self._unary_costs()
        n_pixels = self.h * self.w
        for _ in range(max_iter):
            changed = False
            for alpha in [1, 0]:
                G = self._build_graph_alpha(alpha)
                solver = Dinic(G, self.S, self.T)
                solver.max_flow()

                # Source side of the min cut = vertices reachable from S
                # through arcs with remaining capacity.
                visited = [False] * self.N
                visited[self.S] = True
                stack = [self.S]
                while stack:
                    u = stack.pop()
                    for e in G.graph[u]:
                        if e.capacity > 0 and not visited[e.to]:
                            visited[e.to] = True
                            stack.append(e.to)

                on_source_side = np.array(visited[:n_pixels], dtype=bool)
                to_flip = on_source_side & (self.labeling != alpha)
                if to_flip.any():
                    self.labeling[to_flip] = alpha
                    changed = True
            if not changed:
                break
        return self.labeling.reshape((self.h, self.w))

In [8]:
def interactive_segment(image_path, size=(128, 128)):
    """Run an interactive scribble-based segmentation session in OpenCV windows.

    Controls (keyboard focus on the ``seg`` window):
        f / b  - switch to foreground / background scribbling
        r      - clear all scribbles
        c      - segment using the current scribbles as the only seeds
        a      - add the current scribbles to the existing seeds and re-segment
        q      - quit

    Args:
        image_path: Path of the image to load.
        size: ``(width, height)`` the image is resized to before segmentation.

    Raises:
        OSError: if the image cannot be read.
    """
    img_orig = cv2.imread(image_path)
    if img_orig is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read image: {image_path!r}")
    # Resize to the working resolution (128x128 by default). The flag must be
    # passed by keyword: the third positional parameter of cv2.resize is
    # ``dst``, not ``interpolation``.
    img = cv2.resize(img_orig, size, interpolation=cv2.INTER_AREA)
    bcut = BinaryGraphCutAlpha(img, lam=1.0, sigma=10.0)
    fg_seeds, bg_seeds = [], []
    display = img.copy()
    drawing = False
    mode = 'fg'
    h, w = img.shape[:2]

    def paint(x, y):
        """Record a seed at (x, y) and mark it on the display image."""
        if not (0 <= x < w and 0 <= y < h):
            return
        if mode == 'fg':
            fg_seeds.append((y, x))
            cv2.circle(display, (x, y), 3, (0, 255, 0), -1)
        else:
            bg_seeds.append((y, x))
            cv2.circle(display, (x, y), 3, (0, 0, 255), -1)

    def draw(event, x, y, flags, param):
        nonlocal drawing
        if event == cv2.EVENT_LBUTTONDOWN:
            drawing = True
            paint(x, y)  # a plain click (no drag) also places a seed
        elif event == cv2.EVENT_LBUTTONUP:
            drawing = False
        elif event == cv2.EVENT_MOUSEMOVE and drawing:
            paint(x, y)

    def segment_and_show():
        """Run the segmentation and show the mask and the masked image."""
        seg = bcut.alpha_expansion(max_iter=5)
        seg8 = (seg * 255).astype(np.uint8)
        cv2.imshow('segmentation', seg8)
        mask_3ch = cv2.cvtColor(seg8, cv2.COLOR_GRAY2BGR)
        fg_extracted = cv2.bitwise_and(img, mask_3ch)
        cv2.imshow('foreground', fg_extracted)

    # enforce fixed window size to match image dimensions
    cv2.namedWindow('seg', cv2.WINDOW_NORMAL)
    cv2.resizeWindow('seg', w, h)
    cv2.setMouseCallback('seg', draw)

    try:
        while True:
            cv2.imshow('seg', display)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('f'):
                mode = 'fg'
            elif key == ord('b'):
                mode = 'bg'
            elif key == ord('r'):
                fg_seeds.clear()
                bg_seeds.clear()
                display = img.copy()
            elif key == ord('c'):
                print("Performing alpha-expansion segmentation...")
                bcut.set_seeds(fg_seeds, bg_seeds)
                segment_and_show()
            elif key == ord('a'):
                print("Adding new seeds and updating segmentation...")
                bcut.add_seeds(fg_seeds, bg_seeds)
                segment_and_show()
            elif key == ord('q'):
                break
    finally:
        # Close the windows even if segmentation raised.
        cv2.destroyAllWindows()

In [10]:
# Image to segment; a relative path, so it is looked up from the notebook's
# working directory.
image_path = "chest_ct.jpg"
# Opens the OpenCV windows and blocks until 'q' is pressed in the 'seg' window.
interactive_segment(image_path)

Performing alpha-expansion segmentation...
