In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
import PIL.ImageDraw
import torch


In [None]:
############### Matplotlib config
# Global rc settings: they affect every figure drawn after this cell has run.
plt.rc('image', cmap='gray')   # default colormap for images
plt.rc('grid', linewidth=0)    # zero-width grid lines
plt.rc('xtick', top=False, bottom=False, labelsize='large')   # no x tick marks
plt.rc('ytick', left=False, right=False, labelsize='large')   # no y tick marks
plt.rc('axes', facecolor='F8F8F8', titlesize="large", edgecolor='white')
plt.rc('text', color='a8151a')
plt.rc('figure', facecolor='F0F0F0')

# Matplotlib fonts: the `mpl-data/fonts/ttf` directory located next to the
# matplotlib.pyplot module file.
MATPLOTLIB_FONT_DIR = os.path.join(os.path.dirname(plt.__file__), "mpl-data/fonts/ttf")
################################################################################

In [None]:
def draw_ONE_bounding_box_on_image(image, ymin:int, xmin:int, ymax:int, xmax:int, 
                               color:str='red', thickness:int=1, display_str:str=None):
    """
    Draws one bounding box, and optionally a label, on an image in place.

    Bounding box coordinates are absolute.

    Args:
      image: a PIL.Image object. It is modified in place.
      ymin: top edge of the bounding box.
      xmin: left edge of the bounding box.
      ymax: bottom edge of the bounding box.
      xmax: right edge of the bounding box.
      color: color used for the box outline and the label. Default is red.
      thickness: line thickness. Default value is 1.
      display_str: optional text written just inside the top-left corner of
        the box. Nothing is written when it is None or empty.

    Returns:
      None.
    """
    draw = PIL.ImageDraw.Draw(image)

    left, right, top, bottom = xmin, xmax, ymin, ymax

    # Closed polyline: start at the top-left corner, walk around the box and
    # come back to the starting point.
    corners = [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)]
    draw.line(corners, width=thickness, fill=color)

    if display_str:
        # Shift by the line thickness so the text does not sit on the outline.
        draw.text((left + thickness, top + thickness), str(display_str), fill=color)


def draw_bounding_boxes_on_image(image, boxes_dict:dict, color_list:list=None):
  """
  Draws ground-truth and predicted bounding boxes on an image, in place.

  Args:
    image: PIL.Image. It is modified in place.
    boxes_dict: dict
      Maps a box-set name to an array-like of shape [N, 4]. Each row is read
      as (xmin, ymin, xmax, ymax), in absolute coordinates.
      The "true_bbox" set is drawn with color_list[0] and thickness 2; every
      other set is drawn with color_list[1] and thickness 1.
      Empty sets, and entries that are not arrays (e.g. None), are skipped.
    color_list: list, optional
      [true_color, predicted_color]. A missing entry falls back to 'red'.

  Raises:
    ValueError: if a non-empty box set is not a [N, 4] array
  """
  fallback_color = 'red'
  colors = color_list if color_list is not None else []
  true_color = colors[0] if len(colors) > 0 else fallback_color
  pred_color = colors[1] if len(colors) > 1 else fallback_color

  for key, raw_boxes in boxes_dict.items():
    if key == "true_bbox":
      color, thickness = true_color, 2
    else:
      color, thickness = pred_color, 1

    boxes = np.asarray(raw_boxes)
    # Nothing to draw for this set: move on to the next one instead of
    # abandoning the remaining sets.
    if boxes.ndim == 0 or boxes.size == 0:
      continue
    if boxes.ndim != 2 or boxes.shape[1] != 4:
      raise ValueError('Input must be of size [N, 4]')

    for xmin, ymin, xmax, ymax in boxes:
      draw_ONE_bounding_box_on_image(image,
                                     ymin, xmin,
                                     ymax, xmax,
                                     color=color, thickness=thickness)

def draw_bounding_boxes_on_image_array(image:np.ndarray, boxes:dict, color:list=[]):
  """
  Renders bounding boxes on top of one image given as a numpy array.

  The array is turned into a PIL image, pasted onto a fresh RGBA canvas of
  the same size, and the canvas is handed to draw_bounding_boxes_on_image.

  Args:
    image: numpy nd array
        A single image, in a form accepted by PIL.Image.fromarray.
    boxes: dict
        Dict containing predicted and true bboxes (absolute coords)
    color: list
        Colors forwarded to draw_bounding_boxes_on_image.

  Return :
    numpy nd array built from the RGBA canvas, boxes included
  """
  source = PIL.Image.fromarray(image)

  # Fresh canvas in RGBA mode; the source image is pasted onto it.
  canvas = PIL.Image.new("RGBA", source.size)
  canvas.paste(source)

  draw_bounding_boxes_on_image(canvas, boxes_dict=boxes, color_list=color)

  return np.array(canvas)

In [None]:
################################################################################
def display_digits_with_boxes(digits, predictions, labels, pred_bboxes, bboxes, title, nb_sample=10):
  """Displays a row of randomly sampled digits with their true and predicted boxes.

  For every sampled digit, the ground-truth box is drawn in white and the box
  predicted by every grid cell is drawn in red.

  Args:
    digits : tensor of shape (N,1,75,75)
        Raw images with normalized pixel values (from 0 to 1). `.numpy()` is
        called on the indexed result, so a torch.Tensor is expected.
    predictions : sequence of length N
        Predicted labels. Only its length is used here, to draw the sample
        indexes.
    labels : unused
        Kept so the signature stays compatible with existing callers.
    pred_bboxes : tensor of shape (N,S,S,5)
        Predicted bboxes (relative to cell coordinates). Only the first four
        channels of the last axis are used.
    bboxes : array-like of shape (N, 4)
        Ground-truth bboxes (relative to cell coordinates).
    title : str
        Title of the figure.
    nb_sample : int, default 10
        Number of digits to sample (with replacement) and display.
  """
  indexes = np.random.choice(len(predictions), size=nb_sample)

  # Rescale pixel values to un-normed values (from 0 -black- to 255 -white-)
  n_digits = digits[indexes].numpy() * 255.0
  n_digits = n_digits.reshape(nb_sample, 75, 75)

  ### shape : (nb_sample, S, S, 5)
  n_pred_bboxes = pred_bboxes[indexes]

  ### shape : (nb_sample, 4)
  n_bboxes = relative2absolute(torch.as_tensor(bboxes[indexes])).numpy()

  # Convert the predicted boxes once per grid cell (rows outer, columns inner)
  # rather than once per sample and per cell: the conversion does not depend
  # on which sample is being drawn.
  grid_rows, grid_cols = n_pred_bboxes.shape[1], n_pred_bboxes.shape[2]
  pred_bboxes_per_cell = [
      relative2absolute(n_pred_bboxes[:, cell_i, cell_j, :4]).numpy()
      for cell_i in range(grid_rows)
      for cell_j in range(grid_cols)
  ]

  # Set plot config
  fig = plt.figure(figsize=(20, 4))
  plt.title(title)
  plt.yticks([])
  plt.xticks([])

  for i in range(nb_sample):
    # The true box is listed first so that it is drawn before the predictions.
    bboxes_to_plot = {
        "true_bbox": [n_bboxes[i]],
        "pred_bbox": [cell_bboxes[i] for cell_bboxes in pred_bboxes_per_cell],
    }

    ax = fig.add_subplot(1, nb_sample, i+1)
    img_to_draw = draw_bounding_boxes_on_image_array(image=n_digits[i], boxes=bboxes_to_plot,
        color=["white", "red"])

    ax.set_xticks([])
    ax.set_yticks([])

    ax.imshow(img_to_draw)