In [6]:
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [2]:
# Sea-lion classes, in the fixed order used for every per-class count list
# produced below (e.g. the list returned by count_inside).
sealion_types = [
    "adult_males",
    "subadult_males",
    "adult_females",
    "juveniles",
    "pups",
]

In [27]:
def count_inside(rect, dots, margin=12):
    """Count, for each entry of ``sealion_types``, the dots inside a shrunken rectangle.

    Parameters
    ----------
    rect : tuple
        ``(x_start, y_start, x_end, y_end)``; a dot ``(x, y)`` is tested against
        these bounds after moving every edge inward by ``margin``.
    dots : mapping
        Maps a class name to an iterable of ``(x, y)`` points. Classes absent
        from the mapping contribute a count of zero.
    margin : int
        Dots closer than this to the rectangle border are not counted. The
        accepted interval is half-open: ``start + margin <= v < end - margin``.

    Returns
    -------
    list of int
        One count per class, in the order of ``sealion_types``.
    """
    left, top, right, bottom = rect
    lo_x, hi_x = left + margin, right - margin
    lo_y, hi_y = top + margin, bottom - margin

    def _in_interior(point):
        px, py = point
        return lo_x <= px < hi_x and lo_y <= py < hi_y

    return [
        sum(1 for pt in (dots[kind] if kind in dots else []) if _in_interior(pt))
        for kind in sealion_types
    ]

In [28]:
def _window_starts(length, block, step):
    """Return window start offsets along one axis; the last window ends exactly at ``length``.

    Returns an empty list when ``block`` is larger than ``length`` (no window fits).
    """
    last = length - block  # last valid start: window [last, length) touches the edge
    if last < 0:
        return []
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        # Add one window flush with the edge so the final rows/cols are covered.
        starts.append(last)
    return starts


def generate_train_counts(im_train, im_mask, dots, block_shape=(224, 224), stride=(160, 160),
                          threshold_mask=0.9):
    """Yield ``(patch, counts)`` pairs from a sliding window over a masked training image.

    The image is first masked with ``im_mask`` via ``cv2.bitwise_and`` (pixels
    where the mask is zero become black). Windows of ``block_shape`` are placed
    every ``stride`` pixels. One extra window per axis is placed flush with the
    bottom/right edge, so the whole image is covered. A window is yielded only
    if the fraction of usable mask area in it exceeds ``threshold_mask``.

    Parameters
    ----------
    im_train : np.ndarray
        Image indexed as ``[row, col, channel]`` (e.g. as read by ``cv2.imread``).
    im_mask : np.ndarray
        Single-channel mask with the same height and width as ``im_train``.
        255 marks usable pixels and 0 marks masked ones.
        NOTE(review): the notebook loads the mask from a JPEG, so values may not
        be exactly 0/255. ``cv2.bitwise_and`` keeps any nonzero pixel. Confirm
        whether the mask should be binarised first.
    dots : mapping
        Per-class dot coordinates, forwarded to ``count_inside``.
    block_shape : tuple of int
        Window size as ``(height, width)``.
    stride : tuple of int
        Step between window starts as ``(dy, dx)``.
    threshold_mask : float
        Minimum fraction (exclusive) of unmasked area needed to yield a window.

    Yields
    ------
    (np.ndarray, list of int)
        The masked image patch (a view into the masked image) and the per-class
        counts returned by ``count_inside`` for that window.
        Nothing is yielded if the image is smaller than ``block_shape``.
    """
    im_train_masked = cv2.bitwise_and(im_train, im_train, mask=im_mask)

    h, w = im_train.shape[:2]
    block_h, block_w = block_shape
    ys = _window_starts(h, block_h, stride[0])
    xs = _window_starts(w, block_w, stride[1])
    patch_area = block_h * block_w

    for y in ys:
        y_end = y + block_h
        for x in xs:
            x_end = x + block_w
            # Sum / 255 is the number of white pixels for a 0/255 mask; divide
            # by the window area to get the usable fraction.
            n_white_pixels = np.sum(im_mask[y:y_end, x:x_end]) / 255
            percentage_inside = n_white_pixels / patch_area
            if percentage_inside > threshold_mask:
                im_train_patch = im_train_masked[y:y_end, x:x_end]
                counts = count_inside((x, y, x_end, y_end), dots)
                yield im_train_patch, counts

In [29]:
# Load one training sample: RGB image, grayscale usable-area mask, and dot labels.
SEALION_DIR = "../data/sealion"
IMAGE_ID = 101

train_path = "{d}/Train/{i}.jpg".format(d=SEALION_DIR, i=IMAGE_ID)
mask_path = "{d}/TrainMask/{i}.jpg".format(d=SEALION_DIR, i=IMAGE_ID)
dots_path = "{d}/TrainDots/{i}.pkl".format(d=SEALION_DIR, i=IMAGE_ID)

im_train = cv2.imread(train_path)
im_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None (it does not raise) on a missing or unreadable file.
# Without this check, the failure would only surface later inside cv2.bitwise_and.
if im_train is None:
    raise FileNotFoundError("could not read image: " + train_path)
if im_mask is None:
    raise FileNotFoundError("could not read mask: " + mask_path)
if im_train.shape[:2] != im_mask.shape[:2]:
    raise ValueError("image {} and mask {} sizes differ".format(im_train.shape[:2], im_mask.shape[:2]))

# NOTE: pickle.load can execute arbitrary code; only load dot files from a trusted source.
with open(dots_path, "rb") as ifile:
    dots = pickle.load(ifile)

In [30]:
# Save every patch holding strictly between MIN_COUNT and MAX_COUNT labelled
# animals in total (i.e. 6-9, since counts are integers) for visual inspection.
MIN_COUNT, MAX_COUNT = 5, 10
out_dir = "../data/gg"
# cv2.imwrite returns False (it does not raise) when the target directory is
# missing, so create it up front and check each write.
os.makedirs(out_dir, exist_ok=True)

for i, (patch, counts) in enumerate(generate_train_counts(im_train, im_mask, dots)):
    if MIN_COUNT < np.sum(counts) < MAX_COUNT:
        out_path = "{d}/{p}.jpg".format(d=out_dir, p=i)
        if not cv2.imwrite(out_path, patch):
            raise IOError("failed to write patch: " + out_path)