In [1]:
# import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import cv2
import math
import torch
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import numpy as np
import matplotlib.colorbar as colorbar

import requests
from io import BytesIO
from PIL import Image
import numpy as np
from maskrcnn_benchmark.utils import cv2_util
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

from moviepy.editor import VideoFileClip

# Enlarge the default figure size (width, height in inches) for inline plots.
pylab.rcParams['figure.figsize'] = (20 * 1.5, 12 * 1.5)

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.structures.keypoint import keypoints_to_heat_map
from maskrcnn_benchmark.modeling.roi_heads.keypoint_head.loss import project_keypoints_to_heatmap
from predictor import COCODemo

In [2]:
def load(img_path):
    """
    Read an image from ``img_path`` and return it as a BGR numpy array.

    Parameters
    ----------
    img_path
        Anything ``PIL.Image.open`` accepts (the notebook passes local
        file paths). Nothing is downloaded.

    Returns
    -------
    numpy.ndarray
        The image converted to 3-channel RGB and then channel-reversed,
        i.e. an (H, W, 3) array in BGR order.
    """
    # The context manager releases the file handle once the pixels have been
    # copied into the numpy array; the original left it open.
    with Image.open(img_path) as pil_image:
        rgb = np.array(pil_image.convert("RGB"))
    # convert to BGR format
    return rgb[:, :, [2, 1, 0]]

def load_video(video_path,frame_id=[0]):
    """
    Extract selected frames from a video file as BGR numpy arrays.

    Parameters
    ----------
    video_path
        Path handed to ``VideoFileClip``.
    frame_id
        Iterable of frame indices to extract. Each index ``i`` is converted
        to the timestamp ``i / fps`` before the frame is fetched. Pass
        ``None`` to extract ``int(fps * duration)`` consecutive frames
        starting at index 0. The default list is never mutated.

    Returns
    -------
    list of numpy.ndarray
        One array per requested frame, with the last axis reversed
        (RGB -> BGR).
    """
    video_clip = VideoFileClip(video_path)
    # try/finally guarantees the reader is closed even when a frame lookup
    # fails part-way through.
    try:
        if frame_id is None:
            nframes = int(video_clip.fps * video_clip.duration)
            frame_id = range(nframes)
        images = []
        for i in frame_id:
            image = video_clip.get_frame(i/video_clip.fps)
            # convert to BGR format
            images.append(np.array(image)[:, :, [2, 1, 0]])
    finally:
        video_clip.close()

    return images

def imshow(img):
    """Show a BGR image on the current matplotlib axes with the axes hidden.

    The y-axis of the current axes is inverted first, then the image is
    drawn with its last axis reversed (BGR -> RGB).
    """
    current_axes = plt.gca()
    current_axes.invert_yaxis()
    rgb_view = img[:, :, [2, 1, 0]]
    plt.imshow(rgb_view)
    plt.axis("off")
    
def sigmoid(x):
    """Return the logistic function 1 / (1 + e^-x) of a scalar ``x``.

    The negative branch uses the algebraically identical form
    e^x / (1 + e^x) so that ``math.exp`` is only ever called with a
    non-positive argument; the naive formula raises ``OverflowError`` for
    strongly negative inputs (roughly x < -709).
    """
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)
# define vectorized sigmoid
# otypes is given explicitly so the wrapper also accepts empty arrays
# (np.vectorize cannot infer the output type from zero elements).
sigmoid_v = np.vectorize(sigmoid, otypes=[float])

# num keypoints per animal, heatmap
def calculate_peaks(numparts, heatmap_avg, threshold=0.2):
    """Collect every heatmap pixel whose sigmoid exceeds ``threshold``.

    Parameters
    ----------
    numparts : int
        Number of part channels to scan; channels ``0 .. numparts-1`` of
        ``heatmap_avg`` are read.
    heatmap_avg : array indexed as ``[part, row, col]``
        Per-part score maps. Values are treated as logits: a pixel is kept
        when ``sigmoid(value) > threshold``.
    threshold : float, optional
        Probability cut-off, strictly between 0 and 1. Defaults to 0.2.

    Returns
    -------
    list of list of tuple
        One list per part. Each entry is ``(x, y, score, peak_id)`` where
        ``x`` is the column index, ``y`` the row index, ``score`` the raw
        map value at that pixel and ``peak_id`` a counter that runs across
        all parts.

    Raises
    ------
    ValueError
        If ``threshold`` is not strictly between 0 and 1.
    """
    if not 0.0 < threshold < 1.0:
        raise ValueError("threshold must lie strictly between 0 and 1")
    # sigmoid is strictly increasing, so sigmoid(v) > t  <=>  v > logit(t).
    # Comparing in logit space avoids a per-pixel Python call and cannot
    # overflow on very negative values.
    logit_cutoff = math.log(threshold / (1.0 - threshold))

    all_peaks = []
    peak_counter = 0
    for part in range(numparts):
        part_map = heatmap_avg[part, :, :]
        rows, cols = np.nonzero(np.asarray(part_map) > logit_cutoff)
        peaks = list(zip(cols, rows))  # note reverse: stored as (x, y) = (col, row)
        peaks_with_score_and_id = [
            (x, y, part_map[y, x], peak_counter + i)
            for i, (x, y) in enumerate(peaks)
        ]
        all_peaks.append(peaks_with_score_and_id)
        peak_counter += len(peaks)
    return all_peaks



In [3]:
# Notebook-wide matplotlib defaults, applied through rcParams.update below.
params = {
    # Given as a single string: matplotlib >= 3.3 expects a string for this
    # key, and older releases also accept one.
    'text.latex.preamble': '\\usepackage{gensymb}',
    'image.origin': 'lower',
    'image.interpolation': 'nearest',
    'image.cmap': 'jet',
    'axes.grid': False,
    'savefig.dpi': 150,  # to adjust notebook inline plot size
    'axes.labelsize': 10, # fontsize for x and y labels (was 10)
    'axes.titlesize': 12,
    'font.size': 12, # was 10
    'legend.fontsize': 10, # was 10
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    #'text.usetex': True,
    # NOTE: this replaces the figure.figsize value set in the import cell.
    'figure.figsize': [20, 12],
    'font.family': 'serif',
}
matplotlib.rcParams.update(params)

In [4]:
def overlay_keypoints(image, predictions):
    """Draw the keypoints of every instance in ``predictions`` onto ``image``.

    ``predictions`` is used directly as the keypoint tensor, indexed as
    ``[instance, keypoint, coord]``. Only the first two coordinate columns
    are kept; a column of ones is appended as the score so that every
    keypoint passes the ``kp_thresh=0`` test in ``vis_keypoints_others``.

    Returns the image produced by the last ``vis_keypoints_others`` call,
    or ``image`` unchanged when there are no instances.
    """
    kps = predictions
    unit_scores = kps.new_ones((kps.size(0), kps.size(1)))
    kps_with_scores = torch.cat(
        (kps[:, :, 0:2], unit_scores[:, :, None]), dim=2
    ).numpy()
    for region in kps_with_scores:
        # vis_keypoints_others expects one row per coordinate, one column
        # per keypoint, hence the transpose.
        image = vis_keypoints_others(
            image,
            region.transpose((1, 0)),
            kp_thresh=0,
            kfun=BeeKeypoints,
        )
    return image


In [5]:
import cv2
import torch
from torchvision import transforms as T

from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark import layers as L
from maskrcnn_benchmark.utils import cv2_util

import numpy as np
import matplotlib.pyplot as plt
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints, BeeKeypoints, FlyKeypoints


def vis_keypoints_others(img, kps, kp_thresh=2, alpha=0.7, kfun=PersonKeypoints):
    """Visualizes keypoints (adapted from vis_one_image).

    Only the first endpoint of the first entry of ``kfun.CONNECTIONS`` is
    drawn, as a filled circle of radius 16, and the result is alpha-blended
    with the input image.

    Parameters
    ----------
    img : numpy.ndarray
        Image to draw on; it is copied, never modified.
    kps : array indexed as ``[row, keypoint]``
        Row 0 holds x, row 1 holds y and row 2 holds the score compared
        against ``kp_thresh``. Further rows are ignored.
    kp_thresh : float
        A keypoint is drawn only if its score is strictly greater.
    alpha : float
        Weight of the drawn layer in the blend (the image gets 1 - alpha).
    kfun : class
        Provides ``CONNECTIONS``, a sequence of keypoint index pairs.

    Returns
    -------
    numpy.ndarray
        The blended image.
    """
    img = img.copy()
    kp_lines = kfun.CONNECTIONS

    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
    colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

    # Perform the drawing on a copy of the image, to allow for blending.
    kp_mask = np.copy(img)

    # Draw the keypoints.
    for l in [0]:#range(len(kp_lines)):
        i1 = kp_lines[l][0]
        x1, y1 = kps[0, i1], kps[1, i1]
        # cv2.circle needs integer pixel coordinates; missing keypoints
        # (NaN/inf) cannot be converted and are skipped.
        if kps[2, i1] > kp_thresh and np.isfinite(x1) and np.isfinite(y1):
            center = (int(round(float(x1))), int(round(float(y1))))
            cv2.circle(
                kp_mask, center,
                radius=16, color=colors[l], thickness=-1, lineType=cv2.LINE_AA)

    # Blend the keypoints.
    return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0)

def bounding_box(points):
    """Return ``[x_min, y_min, x_max, y_max]`` for an iterable of (x, y) points.

    Points with a NaN coordinate are ignored: Python's ``min``/``max``
    give order-dependent answers when NaN is present, so a single missing
    keypoint could otherwise corrupt the box. If every point is NaN the
    result is four NaNs.

    Raises
    ------
    ValueError
        If ``points`` is empty.
    """
    pts = [(x, y) for x, y in points]
    if not pts:
        raise ValueError("bounding_box() requires at least one point")
    valid = [(x, y) for x, y in pts if not (math.isnan(x) or math.isnan(y))]
    if not valid:
        nan = float("nan")
        return [nan, nan, nan, nan]
    x_coordinates, y_coordinates = zip(*valid)
    return [min(x_coordinates), min(y_coordinates), max(x_coordinates), max(y_coordinates)]

def get_centroid(coord):
    """Return the midpoint ``[cx, cy]`` of a box given as
    ``[x_min, y_min, x_max, y_max]`` (only indices 0-3 are read)."""
    center_x = (coord[0] + coord[2]) / 2
    center_y = (coord[1] + coord[3]) / 2
    return [center_x, center_y]

def get_dist(p1, p2):
    """Return the Euclidean distance between 2-D points ``p1`` and ``p2``."""
    delta_x = p2[0] - p1[0]
    delta_y = p2[1] - p1[1]
    return math.hypot(delta_x, delta_y)


In [6]:
# Image directory and annotation file of the split being evaluated.
# Swap in the commented "train" paths to run on the training split instead.
base = "../tools/datasets/bee/validation/"
#base = "../tools/datasets/bee/train/"
base_val = "../tools/datasets/bee/annotations/validation.json"
#base_val = "../tools/datasets/bee/annotations/train_bee_annotations2018_nondup.json"

import json
with open(base_val) as f:
    data_an = json.load(f)

# File names in the order they are listed under the annotation file's
# 'images' key.
test_files = [str(p['file_name']) for p in data_an['images']]

print(test_files)
print(len(test_files))

# Every listed image loaded as a BGR array, in the same order as test_files.
test_images = [load(base + file) for file in test_files]



['000000051501.jpg', '000000051602.jpg', '000000051656.jpg', '000000051770.jpg', '000000052209.jpg', '000000052239.jpg', '000000052248.jpg', '000000052271.jpg', '000000052328.jpg', '000000052361.jpg', '000000052413.jpg', '000000052563.jpg', '000000052583.jpg', '000000052587.jpg', '000000052738.jpg', '000000052799.jpg', '000000052861.jpg', '000000053184.jpg', '000000053247.jpg', '000000053341.jpg', '000000053573.jpg', '000000053614.jpg', '000000053776.jpg', '000000053802.jpg', '000000053827.jpg', '000000053842.jpg', '000000053857.jpg', '000000053866.jpg', '000000053890.jpg', '000000053916.jpg']
30


In [7]:
#MSE sleap
import h5py
filename = '../tools/sleap_bottomup_5.slp'
# The .slp file is opened as an HDF5 container; each dataset is read fully
# into memory ([()]) so it stays usable after the file is closed.
with h5py.File(filename, 'r') as f:
    # Names of the top-level groups/datasets in the file.
    a_group_key = list(f.keys())
    frames, instances, points, pred_points = [
        f[name][()] for name in ('frames', 'instances', 'points', 'pred_points')
    ]




In [8]:
# Rebuild per-image keypoint arrays and bounding boxes from the SLEAP tables.
NUM_IMAGES = 30     # number of frames processed
KPTS_PER_BEE = 5    # rows consumed from pred_points per instance

# Number of instances in each frame, computed once and reused below.
# NOTE(review): columns 3 and 4 of `frames` look like the start/end instance
# indices of the frame (the printed "nb" values are cumulative) -- confirm
# against the SLEAP file format.
bees_in_imgs = []
for img_idx in range(NUM_IMAGES):
    if img_idx == 0:
        bees_in_imgs.append(frames[0][4])
    else:
        bees_in_imgs.append(frames[img_idx][4]-frames[img_idx][3])

# kps_bu[i] has shape (instances in image i, KPTS_PER_BEE, 2) holding (x, y).
# This assumes pred_points stores KPTS_PER_BEE consecutive rows per instance,
# in instance order -- TODO confirm.
kps_bu = []
for img_idx in range(NUM_IMAGES):
    # Instances belonging to earlier frames; used as the read offset.
    num_before = 0 if img_idx == 0 else frames[(img_idx-1)][4]
    num_bees_per_img = bees_in_imgs[img_idx]
    print("nb", num_before)
    print("num kps", int(num_bees_per_img*KPTS_PER_BEE))
    print("range", int(num_before*KPTS_PER_BEE+num_bees_per_img*KPTS_PER_BEE) - int(num_before*KPTS_PER_BEE))
    kps_bu.append(np.zeros((num_bees_per_img,KPTS_PER_BEE,2)))
    current_idx = int(num_before*KPTS_PER_BEE)
    for pred_bee in range(0,num_bees_per_img):
        for kpt in range(0,KPTS_PER_BEE):
            # Fields 0 and 1 of a pred_points row are used as x and y.
            x = pred_points[current_idx][0]
            y = pred_points[current_idx][1]
            kps_bu[img_idx][pred_bee][kpt][0] = x
            kps_bu[img_idx][pred_bee][kpt][1] = y
            current_idx = current_idx + 1

# bboxes_bu[i] has shape (instances in image i, 4): one box per instance as
# returned by bounding_box().
bboxes_bu = []
for im_idx in range(NUM_IMAGES):
    bb_in_im = np.zeros((bees_in_imgs[im_idx],4))
    for bee in range(0, bees_in_imgs[im_idx]):
        bb_in_im[bee] = bounding_box(kps_bu[im_idx][bee])
    bboxes_bu.append(bb_in_im)

print("BBOXESS: ",bboxes_bu)

nb 0
num kps 105
range 105
nb 21
num kps 70
range 70
nb 35
num kps 50
range 50
nb 45
num kps 60
range 60
nb 57
num kps 55
range 55
nb 68
num kps 50
range 50
nb 78
num kps 70
range 70
nb 92
num kps 85
range 85
nb 109
num kps 85
range 85
nb 126
num kps 75
range 75
nb 141
num kps 50
range 50
nb 151
num kps 50
range 50
nb 161
num kps 80
range 80
nb 177
num kps 65
range 65
nb 190
num kps 110
range 110
nb 212
num kps 90
range 90
nb 230
num kps 100
range 100
nb 250
num kps 80
range 80
nb 266
num kps 80
range 80
nb 282
num kps 55
range 55
nb 293
num kps 85
range 85
nb 310
num kps 95
range 95
nb 329
num kps 80
range 80
nb 345
num kps 65
range 65
nb 358
num kps 90
range 90
nb 376
num kps 100
range 100
nb 396
num kps 90
range 90
nb 414
num kps 80
range 80
nb 430
num kps 70
range 70
nb 444
num kps 50
range 50
BBOXESS:  [array([[ 780.84759521,   18.76969147,  874.33813477,  119.15652466],
       [ 950.36358643,   24.93842506, 1008.75592041,   26.1102829 ],
       [1249.74816895,   85.22297668, 1524

In [42]:
# Spot-check the reconstructed keypoints of image index 11.
kps_image_11 = kps_bu[11]
print("KPS !!", kps_image_11)

KPS !! [[[1357.23669434   62.67543411]
  [1369.65063477   98.62140656]
  [1393.06140137   14.62785244]
  [          nan           nan]
  [          nan           nan]]

 [[1178.40905762  -46.32356644]
  [1178.50109863  193.53616333]
  [1178.38903809   86.04069519]
  [1142.53210449  229.6275177 ]
  [1202.27587891  206.25422668]]

 [[ 228.88000488  -34.41603088]
  [ 205.5124054   145.77064514]
  [ 240.63044739   85.27181244]
  [ 193.30725098  145.59761047]
  [ 241.18821716  181.5308075 ]]

 [[ 481.34341431  181.48057556]
  [ 289.48028564  253.09501648]
  [ 397.50704956  217.18490601]
  [ 265.33605957  229.59033203]
  [ 289.31848145  277.64782715]]

 [[1237.45825195  290.5697937 ]
  [1477.75427246  242.55613708]
  [1393.9753418   254.42765808]
  [1502.20092773  278.70819092]
  [1502.22058105  206.34121704]]

 [[2210.18359375  325.82595825]
  [2258.12060547  373.11312866]
  [2223.03857422  253.41899109]
  [2258.13378906  110.04924774]
  [2234.203125     97.27648926]]

 [[ 481.73999023  481

1 of 8: 100/0/standard 