In [None]:
import cv2
from matplotlib import pyplot as plt

import numpy as np

from fastai.vision import get_image_files

from google.colab.patches import cv2_imshow

from glob import glob

import os

In [None]:
def segment_image(img):
  '''Split an image into connected components.

  The image is converted to greyscale and inverted (255 - grey), binarised
  with Otsu's threshold, and labelled with connected-component analysis.

  Args:
    img: image array handed to cv2.cvtColor with cv2.COLOR_BGR2GRAY.

  Returns:
    (labeled, stats): the label image and the per-component statistics
    array, exactly as produced by cv2.connectedComponentsWithStats.
  '''
  inverted = 255 - cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

  # cv2.threshold returns (threshold_used, binary_image); only the image matters here.
  binary = cv2.threshold(inverted, 0, 255, cv2.THRESH_OTSU)[1]

  # The call yields four values; the second and third are the label image and stats.
  components = cv2.connectedComponentsWithStats(binary)

  return components[1], components[2]

In [None]:
def get_digit_mask(img, n):
  '''Build a per-pixel label mask for the digits of n found in img.

  The image is segmented with segment_image. The component with the largest
  value in stats column 4 (the area column, cv2.CC_STAT_AREA) is treated as
  the background and labelled 10. The remaining components are ordered by
  their left coordinate (stats column 0) and paired, left to right, with the
  characters of str(n); each paired component is labelled with its digit.

  Pairing uses zip, so it stops at the shorter sequence: surplus components
  (or surplus digits) are ignored, and unpaired pixels keep the value 0.

  Args:
    img: image array accepted by segment_image.
    n: the number shown in the image; str(n) supplies the digit labels.

  Returns:
    An integer array shaped like the label image: 10 on the background and
    the digit value on each paired component.
  '''
  labels, stats = segment_image(img)
  background = np.argmax(stats[:, 4])

  # (left, label) for every non-background component, ordered left to right.
  # Sorting on the left coordinate only keeps ties in label order.
  foreground = [
    (left, label)
    for label, left in enumerate(stats[:, 0])
    if label != background
  ]
  foreground.sort(key=lambda pair: pair[0])

  result = (labels == background) * 10

  for digit, (_, label) in zip(str(n), foreground):
    result += (labels == label) * int(digit)

  return result

In [None]:
def filename_to_info(filename):
  '''Extract the number and base name encoded in an image filename.

  The base name is the last path component up to its first '.', and it must
  be an integer literal, e.g. 'imgs/123.png' -> (123, '123').

  Args:
    filename: path of the image, as a str or any os.PathLike object (for
      example a pathlib.Path).

  Returns:
    (n, name): the integer value of the base name and the base name itself.
    int() discards leading zeros, so '007.png' gives n == 7 and name == '007'.

  Raises:
    ValueError: if the base name cannot be parsed as an integer.
  '''
  # Path objects have no .split(); os.fspath turns them into plain strings
  # and leaves str input untouched.
  path = os.fspath(filename)

  name = os.path.basename(path).split('.')[0]
  n = int(name)

  return n, name

In [None]:
def create_dataset_from_images(name, in_path, out_path=None, fname_fn=None, size=None):
  '''Creates a segmentation dataset given a directory and a filename function.

  Every image listed by get_image_files(in_path) is resized and written to
  out_path as '<stem>.png', together with its digit mask '<stem>_mask.png'
  produced by get_digit_mask. Images that fail at any step are reported with
  a 'Skipped: ...' line and counted; the total is printed at the end.

  Args:
    name: dataset name; only used to build the default output directory.
    in_path: directory that is scanned for images.
    out_path: output directory. Defaults to f'{in_path}/{name}'. It is
      created if it does not exist.
    fname_fn: callable mapping a filename to (n, stem), where n is the number
      shown in the image and stem is the output file stem. Defaults to
      filename_to_info. A custom callable receives each item exactly as
      get_image_files returned it.
    size: dsize argument for cv2.resize, i.e. (width, height). Defaults to
      (250, 150).

  Returns:
    None. Progress is reported through print.
  '''
  out_path = out_path or f'{in_path}/{name}'
  size = size or (250, 150)
  if not fname_fn:
    # The default parser works on strings, while get_image_files may yield
    # path objects, so convert before delegating.
    fname_fn = lambda path: filename_to_info(os.fspath(path))

  # List the inputs first so a wrong in_path still fails loudly, then make
  # sure the target exists: cv2.imwrite returns False instead of raising
  # when the directory is missing.
  images = get_image_files(in_path)
  os.makedirs(out_path, exist_ok=True)

  skipped = 0

  for filename in images:
    try:
      n, stem = fname_fn(filename)

      # cv2.imread returns None (no exception) for unreadable files.
      img = cv2.imread(str(filename))
      if img is None:
        raise OSError('image could not be read')
      img = cv2.resize(img, size)

      # Labels are in 0..10, so the 8-bit cast is lossless and gives the PNG
      # encoder a depth it supports.
      mask = get_digit_mask(img, n).astype(np.uint8)

      outputs = (
        (f'{out_path}/{stem}.png', img),
        (f'{out_path}/{stem}_mask.png', mask),
      )
      for target, data in outputs:
        if not cv2.imwrite(target, data):
          raise OSError(f'could not write {target}')

    except Exception as e:
      # Deliberate best-effort: one bad image must not abort the whole run.
      skipped += 1
      print(f'Skipped: {filename} due to {e}')

  print(f'Total images skipped: {skipped}')