In [None]:
import os

import cv2
import numpy as np

from fastai.vision import *
from fastai.vision.all import *
from glob import glob
from matplotlib import pyplot as plt

In [None]:
def find_compontents(img):
  '''Return the connected-component statistics of a BGR image.

  The image is converted to greyscale and inverted (255 - grey), so dark ink
  becomes bright. It is then binarised with Otsu's method (THRESH_OTSU with the
  default THRESH_BINARY type). The stats matrix from
  cv2.connectedComponentsWithStats is returned. It has one row per label,
  including the zero-valued background label, and the columns are
  (left, top, width, height, area).
  '''
  inverted = 255 - cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  binary = cv2.threshold(inverted, 0, 255, cv2.THRESH_OTSU)[1]
  _, _, component_stats, _ = cv2.connectedComponentsWithStats(binary)
  return component_stats

In [None]:
def annotations(img, n):
  '''Returns an array of annotations (in format:
  [[x1, y1, x2, y2], digit]) given an image of a number.

  Every connected component found by find_compontents, except the one with
  the largest area (assumed to be the background), is treated as one digit.
  Components are ordered left-to-right by their left edge and paired with
  the characters of str(n) in reading order.

  Args:
    img: BGR image passed straight to find_compontents.
    n: the number drawn in the image; its decimal digits are the labels.

  Returns:
    Object ndarray of shape (k, 2). Column 0 holds an int array
    [x1, y1, x2, y2], where x2 = x1 + width and y2 = y1 + height.
    Column 1 holds the digit character.

  Raises:
    ValueError: if the number of components differs from the number of
      digits in str(n). Pairing them anyway would mislabel boxes.
  '''
  stats = find_compontents(img)
  bg_label = np.argmax(stats[:, 4])

  # Keep (left, top, width, height) for every non-background component.
  boxes = [row[:4] for label, row in enumerate(stats) if label != bg_label]
  # Digits read left-to-right, so order components by their left edge (x).
  boxes.sort(key=lambda box: box[0])

  digits = str(n)
  if len(boxes) != len(digits):
    raise ValueError(
        f'found {len(boxes)} components but {n} has {len(digits)} digits')

  # Build the object array explicitly. np.array on a ragged list of
  # (array, str) pairs raises ValueError on NumPy >= 1.24.
  result = np.empty((len(boxes), 2), dtype=object)
  for i, ((x, y, w, h), digit) in enumerate(zip(boxes, digits)):
    result[i, 0] = np.array([x, y, x + w, y + h], dtype=int)
    result[i, 1] = digit
  return result

In [None]:
def filename_to_info(filename):
  '''Extract the number label and output stem from an image filename.

  The stem is the base name up to its first '.', e.g. 'data/1234.png' gives
  '1234'. The stem must be an integer literal.

  Args:
    filename: str or os.PathLike. fastai's get_image_files yields
      pathlib.Path objects, which have no .split method.

  Returns:
    (n, name): the stem parsed as an int, and the stem itself as a str.

  Raises:
    ValueError: if the stem is not an integer literal.
  '''
  # os.fspath accepts both str and Path; basename also handles OS separators.
  name = os.path.basename(os.fspath(filename)).split('.')[0]
  n = int(name)

  return n, name

In [None]:
def create_dataset_from_images(name, in_path, out_path=None, fname_fn=None, size=None):
  '''Creates an object detection dataset given a directory and a filename function.

  Each image returned by get_image_files(in_path) goes through these steps:
    1. It is resized to `size`.
    2. It is annotated with annotations(img, n).
    3. It is written to out_path as <stem>.png (the resized image) plus
       <stem>.npy (its annotation array).
  Images that fail at any step are skipped and reported. The total skip
  count is printed at the end.

  Args:
    name: dataset name; the default output directory is f'{in_path}/{name}'.
    in_path: directory scanned for input images.
    out_path: output directory, created if it does not exist.
    fname_fn: callable(filename) -> (n, stem). Defaults to filename_to_info.
    size: (width, height) passed to cv2.resize. Defaults to (250, 150).
  '''
  if not out_path:
    out_path = f'{in_path}/{name}'
  if not fname_fn:
    fname_fn = filename_to_info
  if not size:
    size = (250, 150)

  # Without this, np.save raises and cv2.imwrite fails for every image.
  os.makedirs(out_path, exist_ok=True)

  # NOTE(review): the default out_path lives inside in_path. If get_image_files
  # recurses (fastai's default), a re-run will also pick up earlier outputs.
  images, skipped = get_image_files(in_path), 0

  for filename in images:
    try:
      # `stem` rather than `name`, to avoid shadowing the dataset-name parameter.
      n, stem = fname_fn(filename)

      # cv2.imread needs a str path and returns None (no exception) on failure.
      img = cv2.imread(str(filename))
      if img is None:
        raise ValueError('cv2.imread could not read the file')
      img = cv2.resize(img, tuple(size))
      data = annotations(img, n)

      # cv2.imwrite reports failure through its return value, not an exception.
      if not cv2.imwrite(f'{out_path}/{stem}.png', img):
        raise OSError(f'cv2.imwrite failed for {out_path}/{stem}.png')
      np.save(f'{out_path}/{stem}.npy', data)

    except Exception as e:
      # Deliberate best-effort: one bad image must not abort the whole dataset.
      skipped += 1
      print(f'Skipped: {filename} due to {e}')

  print(f'Total images skipped: {skipped}')