## Importing the required functions

In [5]:
from src.grabcut import *
from src.event_handler import *

## Defining the Main Loop

In [6]:
def run(filename: str, num_iters: int = 5, n_components: int = 5, refine: bool = True, should_output: bool = False, gamma: float = 50, iseightconn:bool = True):
    """
    Interactive main loop that implements GrabCut.

    Opens an OpenCV window on the image, lets the user annotate it with the
    mouse (handled by ``EventHandler``) and reacts to these keys:

    * ``Esc``   : close all windows and return the current result
    * ``0``     : subsequent strokes mark background
    * ``1``     : subsequent strokes mark foreground
    * ``r``     : reset the annotations, the mask and the output
    * ``Enter`` : run (or refine) the segmentation

    Input
    -----
    filename (str)       : Path to image
    num_iters (int)      : Number of GMM-fit / graph-cut iterations performed
                           when the event handler requests a full fit
    n_components (int)   : Forwarded to ``GrabCut.assign_and_learn_GMM``
    refine (bool)        : If True, re-learn the GMMs on every later Enter
                           press; if False, reuse the previously learned ones
    should_output (bool) : If True, plot the recorded energies after each
                           segmentation
    gamma (float)        : Forwarded to the ``GrabCut`` constructor
    iseightconn (bool)   : Forwarded to ``GrabCut.constructGCGraph``

    Output
    ------
    (output, img, Es) : the original image masked by the final segmentation,
                        the image currently shown in the input window, and
                        the list of energies recorded by ``calcEs``

    Raises
    ------
    OSError : if ``cv2.imread`` cannot load ``filename``
    """

    COLORS = {
    'BLACK' : [0,0,0],
    'RED'   : [0, 0, 255],
    'GREEN' : [0, 255, 0],
    'BLUE'  : [255, 0, 0],
    'WHITE' : [255,255,255]
    }

    DRAW_BG = {'color' : COLORS['BLACK'], 'val' : 0}
    DRAW_FG = {'color' : COLORS['WHITE'], 'val' : 1}

    FLAGS = {
        'RECT' : (0, 0, 1, 1),
        'DRAW_STROKE': False,         # flag for drawing strokes
        'DRAW_RECT' : False,          # flag for drawing rectangle
        'rect_over' : False,          # flag to check if rectangle is  drawn
        'rect_or_mask' : -1,          # flag for selecting rectangle or stroke mode
        'value' : DRAW_FG,            # drawing strokes initialized to mark foreground
    }

    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image file: {filename!r}")
    img2 = img.copy()                                # untouched copy used for all computations
    mask = np.zeros(img.shape[:2], dtype = np.uint8) # mask is a binary array with : 0 - background pixels
                                                     #                               1 - foreground pixels 
    output = np.zeros(img.shape, np.uint8)           # output image to be shown

    # Input and segmentation windows
    cv2.namedWindow('Input Image')

    EventObj = EventHandler(FLAGS, img, mask, COLORS)
    cv2.setMouseCallback('Input Image', EventObj.handler)
    cv2.moveWindow('Input Image', img.shape[1] + 10, 90)

    gc = GrabCut(0.3, gamma=gamma)
    gc.calculateBeta(img2)
    Es = []

    # No colour models exist until the first fit; tracked so that the
    # "reuse previous models" path cannot hit unbound names.
    bgdGMM = None
    fgdGMM = None

    while True:

        # Pull the current state from the event handler on every frame.
        # NOTE(review): this overwrites the local `mask`; the segmentation
        # result only survives if estimateSegmentation updates the handler's
        # mask in place - confirm against GrabCut.estimateSegmentation.
        img = EventObj.image
        mask = EventObj.mask
        FLAGS = EventObj.flags
        cv2.imshow('Segmented image', output)
        cv2.imshow('Input Image', img)

        k = cv2.waitKey(1)

        # key bindings
        if k == 27:
            # esc to exit
            cv2.destroyAllWindows()
            mask_final = gc.obtainFinalMask(mask)
            mask2 = np.where((mask_final == 1), 255, 0).astype('uint8')
            output = cv2.bitwise_and(img2, img2, mask = mask2)
            return output, img, Es

        elif k == ord('0'):
            # Strokes for background
            FLAGS['value'] = DRAW_BG

        elif k == ord('1'):
            # FG drawing
            FLAGS['value'] = DRAW_FG

        elif k == ord('r'):
            # reset everything
            FLAGS['RECT'] = (0, 0, 1, 1)
            FLAGS['DRAW_STROKE'] = False
            FLAGS['DRAW_RECT'] = False
            FLAGS['rect_or_mask'] = -1
            FLAGS['rect_over'] = False
            FLAGS['value'] = DRAW_FG
            img = img2.copy()
            mask = np.zeros(img.shape[:2], dtype = np.uint8)
            EventObj.image = img
            EventObj.mask = mask
            output = np.zeros(img.shape, np.uint8)

        elif k == 13:
            # Press carriage return to initiate segmentation

            EventObj.flags = FLAGS
            should_fit = EventObj.should_fit

            leftW, upleftW, upW, uprightW = gc.calcNweights(img2)

            if should_fit:
                # Full fit: start the energy history afresh and iterate.
                Es.clear()
                for i in tqdm.tqdm(range(num_iters)):
                    bgdGMM, fgdGMM = gc.assign_and_learn_GMM(img2, mask, n_components=n_components)
                    graph= gc.constructGCGraph(img2, mask, bgdGMM, fgdGMM, 9 * gc.gamma, leftW, upleftW, upW, uprightW, iseightconn=iseightconn)
                    mask = gc.estimateSegmentation(graph, mask)
                    # Energy is evaluated on the pristine image (img2), the same
                    # one the GMMs, neighbour weights and graph were built from;
                    # `img` is the display image handed to the event handler.
                    E = gc.calcEs(img2, mask, bgdGMM, fgdGMM, leftW, upleftW, upW, uprightW)
                    Es.append(E)
                EventObj.should_fit = False
            else:
                # Single refinement pass. The GMMs are re-learned when asked
                # to, and also when none exist yet (otherwise refine=False
                # with no prior fit would fail on unbound models).
                if refine or bgdGMM is None or fgdGMM is None:
                    bgdGMM, fgdGMM = gc.assign_and_learn_GMM(img2, mask, n_components=n_components)
                graph = gc.constructGCGraph(img2, mask, bgdGMM, fgdGMM, 9 * gc.gamma, leftW, upleftW, upW, uprightW, iseightconn=iseightconn)
                mask = gc.estimateSegmentation(graph, mask)
                E = gc.calcEs(img2, mask, bgdGMM, fgdGMM, leftW, upleftW, upW, uprightW)
                Es.append(E)

            if should_output:
                plt.plot(Es)
                plt.show()
            mask_final = gc.obtainFinalMask(mask)
            print("Done")

            # Turn the 0/1 mask into a 0/255 mask and cut the foreground out
            # of the original image for display on the next frame.
            mask2 = np.where((mask_final == 1), 255, 0).astype('uint8')
            output = cv2.bitwise_and(img2, img2, mask = mask2)


## Running a test

In [7]:
# Launch the interactive GrabCut session on a sample image, using the
# default values for every optional argument of run().
filename = 'llama.jpg'   # Path to image file
# run() only returns once Esc is pressed; it yields the masked output image,
# the image shown in the input window, and the list of recorded energies.
output, img, Es = run(filename)
# run() already destroys its windows on Esc, so this is just a safety net.
cv2.destroyAllWindows()

100%|██████████| 5/5 [00:58<00:00, 11.63s/it]


Done
