# Sudoku Detector

Author: [Egor Makarenko](https://github.com/egormkn)

Neural sudoku detection and extraction tool based on the TensorFlow [image segmentation tutorial](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb).

In [1]:
# Mount Google Drive when running inside Colab. Outside Colab the import
# fails, so the mount is skipped and only a message is printed.
try:
  from google.colab import drive
  drive.mount('/content/drive')
except ImportError as e:
  print(f'{e}, skipping Google Drive mount')

No module named 'google.colab', skipping Google Drive mount


In [2]:
%cd drive/My\ Drive/sudoku

[Errno 2] No such file or directory: 'drive/My Drive/sudoku'
/home/egor/Documents/sudoku


In [None]:
!ls

In [None]:
import cv2
import glob
import math
import random
import skimage
import sklearn
import itertools
import collections
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from pathlib import Path
from os import listdir, path
from skimage import draw, io
from more_itertools import grouper
from IPython.display import clear_output
from sklearn.preprocessing import minmax_scale
from PIL import Image, ImageColor, ImageDraw, ImageFont, ImageOps

In [None]:
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(f'Physical GPUs: {len(gpus)}, Logical GPUs: {len(logical_gpus)}')
  except RuntimeError as e:
    # NOTE(review): presumably raised when the GPUs were already initialised
    # before this cell ran — confirm against the TensorFlow GPU guide.
    print(e)

In [None]:
def show_images(images, figsize=(15, 15)):
  """Plot a mapping of ``title -> image`` side by side in one row.

  Images with a trailing axis of length 1 are squeezed to 2-D first. Every
  image is drawn with the gray colormap and with the axes hidden.
  """
  plt.figure(figsize=figsize)
  total = len(images)
  for position, (title, picture) in enumerate(images.items(), start=1):
    single_channel = len(picture.shape) == 3 and picture.shape[2] == 1
    if single_channel:
      picture = np.squeeze(picture, axis=2)
    plt.subplot(1, total, position)
    plt.title(title)
    plt.imshow(picture, cmap='gray')
    plt.axis('off')
  plt.show()

### Sudoku puzzle

In [None]:
class Sudoku:
  """A random sudoku: a solved 9x9 grid and a puzzle made by blanking cells.

  Attributes:
    known: number of filled cells left in `puzzle`.
    puzzle: 9x9 list of lists of ints; 0 marks an empty cell.
    solution: 9x9 list of lists of ints holding the completed grid.
  """

  known = 0
  puzzle = None
  solution = None

  def pluck(self, puzzle, n=0):
    """Blank cells of `puzzle` in place until `n` remain or none can be removed.

    Cells are tried in random order. A cell is kept only when, in each of its
    column, row and 3x3 square, some *other* cell could also take its value;
    otherwise it is set to 0.

    Args:
      puzzle: 9x9 list of lists, modified in place.
      n: number of cells at which plucking stops.

    Returns:
      Tuple `(remaining, puzzle)`: how many of the 81 cells were not blanked
      by this call, and the same (mutated) grid.
    """
    def can_be_a(puz, i, j, c):
      # True if cell (i, j) could hold the value currently stored in cell number c
      i, j = int(i), int(j)
      c_row, c_col = divmod(c, 9)
      v = puz[c_row][c_col]
      if puz[i][j] == v: return True
      if puz[i][j] in range(1, 10): return False
      box_row, box_col = (i // 3) * 3, (j // 3) * 3
      for m in range(9):
        # Any cell other than c itself that already holds v in the column,
        # row or square of (i, j) rules the value out.
        if not (m == c_row and j == c_col) and puz[m][j] == v: return False
        if not (i == c_row and m == c_col) and puz[i][m] == v: return False
        sq_r, sq_c = box_row + m // 3, box_col + m % 3
        if not (sq_r == c_row and sq_c == c_col) and puz[sq_r][sq_c] == v: return False
      return True

    cells = set(range(81))
    cellsleft = cells.copy()
    while len(cells) > n and cellsleft:
      cell = random.choice(list(cellsleft))
      cellsleft.discard(cell)
      # Integer division throughout: comparing against `cell / 9` (a float)
      # never excluded the cell itself from its own column scan.
      cell_row, cell_col = divmod(cell, 9)
      # The square's column base comes from the cell's column, not its row
      box_row, box_col = (cell_row // 3) * 3, (cell_col // 3) * 3
      in_column = in_row = in_square = False

      for i in range(9):
        if i != cell_row and can_be_a(puzzle, i, cell_col, cell):
          in_column = True
        if i != cell_col and can_be_a(puzzle, cell_row, i, cell):
          in_row = True
        sq_r, sq_c = box_row + i // 3, box_col + i % 3
        if (sq_r, sq_c) != (cell_row, cell_col) and can_be_a(puzzle, sq_r, sq_c, cell):
          in_square = True

      if in_column and in_row and in_square:
        continue
      puzzle[cell_row][cell_col] = 0
      cells.discard(cell)

    return len(cells), puzzle

  def generate_puzzle(self, n=40, iterations=100):
    """Pluck copies of `self.solution` and keep the attempt with fewest cells left.

    Stops early as soon as an attempt leaves at most `n` cells.

    Args:
      n: desired number of known cells.
      iterations: maximum number of plucking attempts.

    Returns:
      Tuple `(known, puzzle)` of the best attempt (the first one on ties).

    Raises:
      ValueError: if `iterations` is smaller than 1.
    """
    if iterations < 1:
      raise ValueError('iterations must be at least 1')
    best_count, best_puzzle = None, None
    for _ in range(iterations):
      candidate = [row.copy() for row in self.solution]
      count, candidate = self.pluck(candidate, n)
      if best_count is None or count < best_count:
        best_count, best_puzzle = count, candidate
      if count <= n: break
    return best_count, best_puzzle

  def generate_solution(self):
    """Return a random, completely filled valid 9x9 grid.

    Cells are filled left to right, top to bottom with a random value still
    allowed by the row, column and square. When a cell has no candidate left
    the whole grid is discarded and generation starts over.
    """
    while True:
      solution = [[0] * 9 for _ in range(9)]
      rows = [set(range(1, 10)) for _ in range(9)]
      columns = [set(range(1, 10)) for _ in range(9)]
      squares = [set(range(1, 10)) for _ in range(9)]
      stuck = False
      for i in range(9):
        for j in range(9):
          row, column, square = rows[i], columns[j], squares[(i // 3) * 3 + j // 3]
          choices = row & column & square
          if not choices:
            stuck = True
            break
          choice = random.choice(list(choices))
          solution[i][j] = choice
          row.discard(choice)
          column.discard(choice)
          square.discard(choice)
        if stuck: break
      if not stuck: return solution

  def __init__(self, known=40, iterations=100):
    """Generate a solution and a puzzle aiming for `known` filled cells."""
    self.solution = self.generate_solution()
    self.known, self.puzzle = self.generate_puzzle(known, iterations)

  @staticmethod
  def _measure_text(canvas, text, font):
    """Return `(width, height, x_offset, y_offset)` of `text` drawn with `font`.

    Uses `ImageDraw.textsize` where it still exists and falls back to
    `ImageDraw.textbbox` on Pillow releases that dropped it.
    """
    if hasattr(canvas, 'textsize'):
      width, height = canvas.textsize(text, font)
      return width, height, 0, 0
    left, top, right, bottom = canvas.textbbox((0, 0), text, font=font)
    return right - left, bottom - top, left, top

  def image(self, size=(512, 512), padding=12, bgcolor='white', linecolor='green',
            linewidth=(1, 3), font=None, fontsize=30, fontcolor='red'):
    """Render the puzzle as an RGBA PIL image with a transparent margin.

    Args:
      size: (width, height) of the image in pixels.
      padding: transparent margin around the grid, in pixels.
      bgcolor: fill colour of the grid area.
      linecolor: colour of the grid lines.
      linewidth: (thin, thick) line widths; every third line is thick.
      font: path to a TrueType font, or None for PIL's default font.
      fontsize: font size used together with `font`.
      fontcolor: colour of the digits.

    Returns:
      The rendered `PIL.Image.Image`.
    """
    image = Image.new('RGBA', size, color='#00000000')
    font = ImageFont.truetype(font, fontsize) if font else ImageFont.load_default()
    # Named `canvas` so the imported `skimage.draw` module is not shadowed
    canvas = ImageDraw.Draw(image)

    width, height = size
    cell_width, cell_height = (np.asarray(size) - 2 * padding) / 9

    canvas.rectangle((padding, padding, width - padding, height - padding), fill=bgcolor)

    for i in range(10):
      x_pos = padding + cell_width * i
      y_pos = padding + cell_height * i
      thickness = linewidth[i % 3 == 0]  # bool index: False -> thin, True -> thick
      canvas.line((padding, y_pos, width - padding, y_pos), fill=linecolor, width=thickness)
      canvas.line((x_pos, padding, x_pos, height - padding), fill=linecolor, width=thickness)

    for x in range(9):
      for y in range(9):
        if not self.puzzle[y][x]: continue
        x_pos = padding + cell_width * x
        y_pos = padding + cell_height * y
        text = str(self.puzzle[y][x])
        text_width, text_height, x_off, y_off = self._measure_text(canvas, text, font)
        # Centre the glyph inside its cell
        text_pos = (x_pos + (cell_width - text_width) / 2 - x_off,
                    y_pos + (cell_height - text_height) / 2 - y_off)
        canvas.text(text_pos, text, fontcolor, font)

    return image


In [None]:
# Generate a puzzle aiming for 20 known cells and report how many were actually kept.
sudoku = Sudoku(20)
print(f'Known values: {sudoku.known}')
sudoku.puzzle

In [None]:
# Display the completed grid the puzzle was derived from.
sudoku.solution

In [None]:
# Render the puzzle using the TrueType font at fonts/arial.ttf.
sudoku.image(font='fonts/arial.ttf')

### Sudoku generator

In [None]:
def sudoku_generator(include_data=True, **kwargs):
  """Endlessly yield synthetic photos of random sudoku puzzles.

  Each sample renders a random `Sudoku`, warps it with a random perspective
  distortion, composites it over a random image from `backgrounds/*.jpg`
  and yields it together with the corners of the warped grid.

  Args:
    include_data: if truthy, the puzzle cells are yielded as a third item.
    **kwargs: forwarded to `Sudoku.image`.

  Yields:
    `(image, polygon, data)` or `(image, polygon)`. `image` is a uint8 RGB
    array of 512x512 pixels. `polygon` is an int32 (4, 2) array of
    (row, col) corners ordered top-left, top-right, bottom-right,
    bottom-left, the same order `convert_row` checks for the photo dataset.
    `data` is the 9x9 puzzle.

  Raises:
    FileNotFoundError: if no font or no background image is available.
  """
  fonts = glob.glob(path.join('fonts', '*.ttf'))
  backgrounds = glob.glob(path.join('backgrounds', '*.jpg'))
  if not fonts or not backgrounds:
    raise FileNotFoundError(
      "sudoku_generator needs at least one 'fonts/*.ttf' and one 'backgrounds/*.jpg' file")
  width, height = 512, 512
  padding = 12

  def shift(extent):
    # Corner displacement of at most one cell, biased towards zero
    return int(random.triangular(0, (extent - 2 * padding), 0) / 9)

  while True:
    bgcolor = tuple(int(random.triangular(220, 255, 255)) for _ in range(3))
    linecolor = tuple(random.randint(0, 50) for _ in range(3))
    linewidth = random.randint(1, 5)
    linewidth = (linewidth, linewidth + random.randint(0, 5))
    font = random.choice(fonts)
    fontsize = random.randint(25, 35)
    fontcolor = f'hsl({random.randint(0, 360)}, 100%, {int(random.triangular(0, 50, 10))}%)'
    # Close the file right after decoding: this loop never ends, so handles
    # left to the garbage collector would pile up.
    with Image.open(random.choice(backgrounds)) as source:
      bg_image = source.convert('RGBA')
    bg_width, bg_height = bg_image.size

    sudoku = Sudoku(known=random.randint(40, 60))
    data = sudoku.puzzle
    image = sudoku.image(size=(width, height), padding=padding, bgcolor=bgcolor, linecolor=linecolor,
                         linewidth=linewidth, font=font, fontcolor=fontcolor, fontsize=fontsize, **kwargs)

    # Grid corners as (x, y) points: top-left, bottom-left, bottom-right, top-right
    polygon = np.asarray([
      [         padding,         padding],
      [         padding, width - padding],
      [height - padding, width - padding],
      [height - padding,         padding]
    ])

    # Every corner is pushed towards the inside of the image
    distortion = np.asarray([
      [ shift(height),  shift(width)],
      [ shift(height), -shift(width)],
      [-shift(height), -shift(width)],
      [-shift(height),  shift(width)]
    ])

    # PIL wants the first 8 coefficients of the output -> input mapping
    coeffs = cv2.getPerspectiveTransform(
      np.float32(polygon + distortion),
      np.float32(polygon)
    ).flatten()[:8]

    bg_coeffs = cv2.getPerspectiveTransform(
      np.float32((polygon + distortion) / [height, width] * [bg_height, bg_width]),
      np.float32([[0, 0], [0, bg_width - 1], [bg_height - 1, bg_width - 1], [bg_height - 1, 0]])
    ).flatten()[:8]

    # Random transparent border so the grid varies in scale and position
    extend = random.randint(0, max(width, height))
    crop_w, crop_h = random.randint(0, extend), random.randint(0, extend)

    image = image.transform(image.size, Image.PERSPECTIVE, coeffs, Image.BICUBIC)
    bg_image = bg_image.transform(bg_image.size, Image.PERSPECTIVE, bg_coeffs, Image.BICUBIC)
    image = image.crop((-crop_w, -crop_h, width + (extend - crop_w), height + (extend - crop_h)))
    image = image.resize((width, height))
    polygon = (polygon + distortion + [crop_w, crop_h]) / [width + extend, height + extend] * [width, height]
    bg_image = bg_image.crop((bg_width * 0.2, bg_height * 0.2, bg_width * 0.8, bg_height * 0.8))
    bg_image = bg_image.resize(image.size)

    image = Image.alpha_composite(bg_image, image).convert('RGB')

    image = np.uint8(image)
    # (x, y) -> (row, col), then reorder TL, BL, BR, TR -> TL, TR, BR, BL
    polygon = np.flip(polygon, axis=1).astype(np.int32)
    polygon = polygon[[0, 3, 2, 1]]

    yield (image, polygon, data) if include_data else (image, polygon)


In [None]:
def is_grayscale(image):
  """Tell whether `image` has no channel axis or exactly one channel."""
  dims = image.shape
  if len(dims) < 3:
    return True
  return dims[2] == 1

def draw_polygon(image, polygon, color='#FF00FF', radius=None):
  """Return a 3-channel copy of `image` with the polygon outline and corner dots.

  Args:
    image: array-like image of shape (H, W), (H, W, 1) or (H, W, C).
    polygon: (N, 2) array-like of (row, col) vertices.
    color: colour string understood by `PIL.ImageColor.getrgb`; it is
      rescaled to [0, 1] when the image dtype is float32.
    radius: radius of the corner dots in pixels; defaults to 1/50 of the
      longer image side.

  Returns:
    A new numpy array; the input is not modified.
  """
  color = ImageColor.getrgb(color)
  if image.dtype == np.float32:
    color = np.asarray(color) / 255.0
  if not radius:
    radius = max(*image.shape[:2]) // 50
  if is_grayscale(image):
    image = np.asarray(image)
    if image.ndim == 2:
      # Without a channel axis the concatenation below would widen the image
      image = image[..., np.newaxis]
    image = np.concatenate((image,) * 3, axis=-1)
  else:
    image = np.copy(image)
  polygon = np.asarray(polygon)
  rr, cc = skimage.draw.polygon_perimeter(polygon[:, 0], polygon[:, 1], shape=image.shape, clip=True)
  image[rr, cc] = color
  # `circle` was replaced by `disk` in scikit-image; support both
  disk = getattr(skimage.draw, 'disk', None)
  for r, c in polygon:
    rr, cc = disk((r, c), radius) if disk else skimage.draw.circle(r, c, radius)
    # Keep dots that stick out of the frame on the border instead of failing
    rr = np.clip(rr, 0, image.shape[0] - 1)
    cc = np.clip(cc, 0, image.shape[1] - 1)
    image[rr, cc] = color
  return image

### Sudoku photos
Source: https://github.com/wichtounet/sudoku_dataset

In [None]:
def convert_row(row):
  """Turn one CSV outline row into a `(filepath, polygon)` series.

  The polygon lists the four corners as `[y, x]` pairs in the order p1..p4.
  The assertions require that order to be top-left, top-right, bottom-right,
  bottom-left relative to the centroid of the corners.
  """
  corners = [
    (row.p1_y, row.p1_x),
    (row.p2_y, row.p2_x),
    (row.p3_y, row.p3_x),
    (row.p4_y, row.p4_x),
  ]
  center_y = sum(y for y, _ in corners) / 4
  center_x = sum(x for _, x in corners) / 4
  (y1, x1), (y2, x2), (y3, x3), (y4, x4) = corners
  assert y1 < center_y and x1 < center_x
  assert y2 < center_y and x2 > center_x
  assert y3 > center_y and x3 > center_x
  assert y4 > center_y and x4 < center_x
  location = path.normpath(path.join('sudoku_dataset', row.filepath))
  outline = [list(corner) for corner in corners]
  return pd.Series([location, outline], index=['filepath', 'polygon'])

# Load the annotated outlines and convert every CSV row to (filepath, polygon).
df = pd.read_csv(path.join('sudoku_dataset', 'outlines_sorted.csv')).apply(convert_row, axis=1)
df

In [None]:
# Plain Python lists, as consumed by tf.data.Dataset.from_tensor_slices below.
paths, polygons = df['filepath'].tolist(), df['polygon'].tolist()

### Dataset

In [None]:
@tf.function
def preprocess_github(path, polygon):
  """Read and decode the image at `path`; the polygon is passed through.

  The decoded image is uint8 with three channels; its height and width stay
  dynamic.
  """
  encoded = tf.io.read_file(path)
  decoded = tf.io.decode_image(encoded, channels=3, expand_animations=False, dtype=tf.uint8)
  decoded.set_shape((None, None, 3))
  return decoded, polygon

# Dataset of real photos: file paths and outlines, decoded in parallel.
# deterministic=False lets elements come out in any order.
github_dataset = tf.data.Dataset.from_tensor_slices((paths, polygons)).map(preprocess_github, num_parallel_calls=tf.data.experimental.AUTOTUNE, deterministic=False)
github_dataset

In [None]:
@tf.function
def preprocess_generated(image, polygon):
  """Attach static shapes to a generated sample and return it unchanged."""
  static_shapes = (
    (image, (None, None, 3)),
    (polygon, (4, 2)),
  )
  for tensor, static_shape in static_shapes:
    tensor.set_shape(static_shape)
  return image, polygon

# Endless dataset of synthetic samples. args=[False] sets include_data to
# False so the generator yields only (image, polygon), matching output_types.
generated_dataset = tf.data.Dataset.from_generator(sudoku_generator, output_types=(tf.uint8, tf.int32), args=[False]).map(preprocess_generated, num_parallel_calls=tf.data.experimental.AUTOTUNE, deterministic=False)
generated_dataset

In [None]:
# Sanity check: show one sample from each source with its outline drawn on top.
for (i1, p1), (i2, p2) in tf.data.Dataset.zip((github_dataset, generated_dataset)).take(1):
  show_images({
    'Github dataset': draw_polygon(i1, p1),
    'Generated dataset': draw_polygon(i2, p2)
  })

### Data preprocessing

In [None]:
# Whether `preprocess` binarises images (only consulted when IMAGE_CHANNELS == 1).
APPLY_THRESHOLD = True
# Side length, in pixels, of the square images fed to the model.
IMAGE_SIZE = 224
# Number of channels of the images produced by `preprocess`.
IMAGE_CHANNELS = 3

def preprocess_image(image):
  """Binarise `image` with a Gaussian adaptive threshold, keeping its layout.

  Args:
    image: numpy array shaped (H, W, 3) RGB, (H, W, 1) or plain (H, W).

  Returns:
    The thresholded image with the same number of dimensions and channels
    as the input.
  """
  colored = not is_grayscale(image)
  # `is_grayscale` also accepts 2-D arrays, which have no axis 2 to squeeze
  has_channel_axis = len(image.shape) == 3
  if colored:
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  elif has_channel_axis:
    gray = np.squeeze(image, axis=2)
  else:
    gray = image
  binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 21, 3)
  if colored:
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
  if has_channel_axis:
    return np.expand_dims(binary, axis=2)
  return binary

def preprocess_polygon(image, polygon):
  """Map polygon vertices into the padded IMAGE_SIZE x IMAGE_SIZE frame.

  The longer image side is scaled to IMAGE_SIZE. The shorter side gets equal
  padding on both ends, so the vertices are shifted by half the padding
  along that axis. Returns int32 coordinates.
  """
  height, width = image.shape[:2]
  scale = IMAGE_SIZE / max(height, width)
  if height > width:
    padded_axis = np.array([0, 1])
  else:
    padded_axis = np.array([1, 0])
  shift = abs(height - width) / 2 * scale * padded_axis
  moved = polygon * scale + shift
  return moved.astype(np.int32)

@tf.function
def preprocess(image, polygon):
  """Resize/pad an image to IMAGE_SIZE x IMAGE_SIZE and rescale its polygon to match.

  Returns the uint8 image with static shape (IMAGE_SIZE, IMAGE_SIZE,
  IMAGE_CHANNELS) and the int32 polygon with static shape (4, 2).
  """
  # The polygon is rescaled from the original image size, so this runs before the resize
  polygon = tf.numpy_function(preprocess_polygon, [image, polygon], Tout=tf.int32)
  image = tf.cast(tf.image.resize_with_pad(image, IMAGE_SIZE, IMAGE_SIZE), tf.uint8)
  
  if IMAGE_CHANNELS == 1:
    image = tf.image.rgb_to_grayscale(image)
    # NOTE(review): thresholding is nested under the grayscale branch, so with
    # IMAGE_CHANNELS == 3 APPLY_THRESHOLD has no effect — confirm this is intended.
    if APPLY_THRESHOLD:
      image = tf.numpy_function(preprocess_image, [image], Tout=tf.uint8)
  
  # Pin the static shapes explicitly after the numpy_function calls
  image.set_shape((IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS))
  polygon.set_shape((4, 2))
  
  return image, polygon

# Mix real photos and generated samples, then bring every sample to the model's input size.
polygon_dataset = tf.data.experimental.sample_from_datasets([github_dataset, generated_dataset]).map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE, deterministic=False)
polygon_dataset

In [None]:
# Preview three preprocessed samples with their rescaled outlines.
show_images({ 
  f'Image {i}': draw_polygon(image, polygon) for i, (image, polygon) in enumerate(polygon_dataset.take(3)) 
})

### Convert polygons to masks

In [None]:
def create_mask(image, polygon):
  """Rasterise `polygon` into a boolean mask matching `image`'s height and width.

  Args:
    image: array whose first two dimensions give the mask size.
    polygon: (N, 2) array of (row, col) vertices.

  Returns:
    Boolean array of shape (H, W, 1), True inside the polygon.
  """
  shape = (*image.shape[:2], 1)
  # Builtin `bool`: the `np.bool` alias no longer exists in current NumPy
  mask = np.zeros(shape, dtype=bool)
  rr, cc = skimage.draw.polygon(polygon[:, 0], polygon[:, 1], shape)
  mask[rr, cc] = True
  return mask

def draw_mask(image, mask):
  """Return `image` with every pixel outside `mask` replaced by zero."""
  background = 0
  masked = np.where(mask, image, background)
  return masked

@tf.function
def apply_mask(image, polygon):
  """Replace the polygon label of a sample with its filled boolean mask."""
  mask = tf.numpy_function(create_mask, [image, polygon], Tout=tf.bool)
  # Pin the static shape explicitly after the numpy_function call
  mask.set_shape(((IMAGE_SIZE, IMAGE_SIZE, 1)))
  return image, mask

# Same samples as polygon_dataset, with the outline rasterised into a mask.
mask_dataset = polygon_dataset.map(apply_mask, num_parallel_calls=tf.data.experimental.AUTOTUNE, deterministic=False)
mask_dataset

In [None]:
# Preview two samples next to the same image with everything outside the mask blacked out.
for image, mask in mask_dataset.take(2):
  show_images({ 
    'Image': image,
    'Image + Mask': draw_mask(image, mask)  
  })

### Preparing train and test datasets

In [None]:
# Number of samples taken from the dataset for training.
TRAIN_SIZE = 5000
# Number of samples taken after the training ones for validation.
TEST_SIZE = 100
# Samples per batch for both training and validation.
BATCH_SIZE = 50
# Size of the shuffle buffer applied to the training set.
BUFFER_SIZE = 50
# True: train the segmentation model on masks; False: train on polygons.
USE_MASKS = True

def draw_output(image, output):
  """Overlay `output` on `image`: an outline for a (4, 2) polygon, a mask otherwise."""
  if output.shape == (4, 2):
    return draw_polygon(image, output)
  return draw_mask(image, output)

def augment_mask(image, mask):
  """Mask-specific augmentation step; currently it only pins the mask's static shape."""
  mask.set_shape((IMAGE_SIZE, IMAGE_SIZE, 1))
  return image, mask

def augment_polygon(image, polygon):
  """Polygon-specific augmentation step; currently it only pins the polygon's static shape."""
  polygon.set_shape((4, 2))
  return image, polygon

@tf.function
def augment(image, output):
  """Apply training-time augmentation to one (image, output) sample.

  Only the image is altered: it is re-encoded as JPEG with a random quality
  between 70 and 100.
  """
  image = tf.image.random_jpeg_quality(image, 70, 100)
  # USE_MASKS is a plain Python constant, so only one of the two helpers is called
  image, output = augment_mask(image, output) if USE_MASKS else augment_polygon(image, output)
  image.set_shape((IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS))
  return image, output

dataset = mask_dataset if USE_MASKS else polygon_dataset

# The first TRAIN_SIZE samples are augmented for training; the next TEST_SIZE are held out.
# NOTE(review): the source datasets are random and mapped with
# deterministic=False, so each pass over `dataset` presumably yields different
# samples — confirm that train and test are really disjoint.
train_dataset = dataset.take(TRAIN_SIZE).map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = dataset.skip(TRAIN_SIZE).take(TEST_SIZE)

# Preview a few training samples; the last one is kept as the fixed sample
# that `show_predictions` uses when it is called without a dataset.
for image, output in train_dataset.take(5):
  sample_image, sample_output = image, output
  show_images({ 'Image': image, 'Output': draw_output(image, output) })

# repeat() makes the training set endless, so fit() needs steps_per_epoch.
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat().prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
print(f'Train dataset: {train_dataset}')
print(f'Test dataset: {test_dataset}')

### Model structure

In [None]:
"""
Upsamples an input.
Conv2DTranspose => Batchnorm => Dropout => Relu
Args:
filters: number of filters
size: filter size
norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
apply_dropout: If True, adds the dropout layer
Returns:
Upsample Sequential Model
"""
def pix2pix_upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

  if norm_type.lower() == 'batchnorm':
    result.add(tf.keras.layers.BatchNormalization())
  elif norm_type.lower() == 'instancenorm':
    result.add(InstanceNormalization())

  if apply_dropout:
    result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result

In [None]:
# MobileNetV2 backbone without its classification head, used as the encoder.
mobilenet = tf.keras.applications.MobileNetV2(input_shape=[IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS], include_top=False)

In [None]:
def unet_model(output_channels):
  """Build a U-Net style segmentation model on top of the frozen MobileNetV2 encoder.

  Args:
    output_channels: number of channels of the final transposed convolution.

  Returns:
    A `tf.keras.Model` taking uint8 images of shape (IMAGE_SIZE, IMAGE_SIZE, 3).
    The output layer has no activation.
  """
  # Use the activations of these layers
  layer_names = [
    'block_1_expand_relu',   # 112x112
    'block_3_expand_relu',   # 56x56
    'block_6_expand_relu',   # 28x28
    'block_13_expand_relu',  # 14x14
    'block_16_project',      # 7x7
  ]
  layers = [mobilenet.get_layer(name).output for name in layer_names]

  # Create the feature extraction model
  down_stack = tf.keras.Model(inputs=mobilenet.input, outputs=layers)
  down_stack.trainable = False

  # Create the upsampling model
  up_stack = [
      pix2pix_upsample(512, 3),  # 7x7   -> 14x14
      pix2pix_upsample(256, 3),  # 14x14 -> 28x28
      pix2pix_upsample(128, 3),  # 28x28 -> 56x56
      pix2pix_upsample(64, 3),   # 56x56 -> 112x112
  ]
  
  inputs = tf.keras.layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], dtype = tf.uint8)
  x = tf.cast(inputs, tf.float32)
  # NOTE(review): this is the `mobilenet` preprocess_input although the encoder
  # is MobileNetV2 — presumably the same scaling; confirm.
  x = tf.keras.applications.mobilenet.preprocess_input(x)
  
  # Downsampling through the model
  skips = down_stack(x)
  x = skips[-1]
  # Deepest remaining feature map first, so the skips pair up with up_stack
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same')

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

# The target mask has a single channel and the loss is binary cross-entropy
# on logits, so the network emits one logit per pixel.
OUTPUT_CHANNELS = 1

mask_model = unet_model(OUTPUT_CHANNELS)
mask_model.compile(optimizer='adam',
                   loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                   metrics=['accuracy'])
mask_model.summary()
tf.keras.utils.plot_model(mask_model, show_shapes=True)

In [None]:
def custom_model(output_channels=None):
  """Build a regression model that predicts the four polygon corners.

  The deepest MobileNetV2 feature map ('block_16_project') is reduced by a
  1x1 convolution and two dense layers to 8 values, reshaped to (4, 2).

  Args:
    output_channels: unused; accepted so the signature matches `unet_model`.

  Returns:
    A `tf.keras.Model` mapping uint8 images of shape
    (IMAGE_SIZE, IMAGE_SIZE, 3) to a (4, 2) tensor.
  """
  inputs = tf.keras.layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3], dtype = tf.uint8)
  x = tf.cast(inputs, tf.float32)
  x = tf.keras.applications.mobilenet.preprocess_input(x)

  # Frozen encoder, built here because `down_stack` in `unet_model` is a
  # local variable and not reachable from this function.
  encoder = tf.keras.Model(inputs=mobilenet.input,
                           outputs=mobilenet.get_layer('block_16_project').output)
  encoder.trainable = False
  x = encoder(x)

  x = tf.keras.layers.Conv2D(10, (1, 1), activation='relu')(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(128)(x)
  x = tf.keras.layers.Dense(8)(x)
  x = tf.keras.layers.Reshape((4, 2))(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

# The polygon model is the regression head, not a second U-Net.
polygon_model = custom_model()
# Corner coordinates are a regression target, hence mean squared error.
# 'accuracy' stays in the metrics because EnoughCallback reads it from the logs.
polygon_model.compile(optimizer='adam',
                      loss=tf.keras.losses.MeanSquaredError(),
                      metrics=['accuracy'])
polygon_model.summary()
tf.keras.utils.plot_model(polygon_model, show_shapes=True)

In [None]:
# The model actually trained and evaluated below, matching the dataset chosen by USE_MASKS.
model = mask_model if USE_MASKS else polygon_model

### Model fitting

In [None]:
def show_predictions(dataset=None, num=1):
  """Display an image, its ground truth and the model's prediction side by side.

  Args:
    dataset: optional *batched* `tf.data.Dataset` of (images, outputs). The
      first sample of each of the first `num` batches is shown. When
      omitted, the module-level `sample_image` / `sample_output` are used.
    num: number of batches to take from `dataset`.
  """
  def as_drawable(prediction):
    # Segmentation output is raw logits (the loss uses from_logits=True):
    # logit > 0 means probability > 0.5. Polygons are drawn as they are.
    return prediction if prediction.shape == (4, 2) else prediction > 0

  if dataset is not None:
    for images, outputs in dataset.take(num):
      prediction = model.predict(images)
      # Plot a single sample; the whole batch cannot be shown as one image
      show_images({
        'Image': images[0],
        'Mask': outputs[0],
        'Prediction': draw_output(images[0], as_drawable(prediction[0]))
      })
  else:
    prediction = model.predict(sample_image[tf.newaxis, ...])
    show_images({
      'Image': sample_image,
      'Mask': sample_output,
      'Prediction': draw_output(sample_image, as_drawable(prediction[0]))
    })

In [None]:
# Prediction for the sample image with the model as it currently stands.
show_predictions()

In [None]:
class InfoCallback(tf.keras.callbacks.Callback):
  """Shows the prediction for the sample image at the end of every epoch."""

  def on_epoch_end(self, epoch, logs=None):
    print(f'\nEpoch {epoch+1} ended. Preparing prediction')
    # clear_output(wait=True)
    show_predictions()
    
class EnoughCallback(tf.keras.callbacks.Callback):
  """Stops training once the training accuracy exceeds 99.99%."""

  def on_epoch_end(self, epoch, logs=None):
    accuracy = (logs or {}).get('accuracy')
    # The metric is absent when the model was compiled without 'accuracy'
    if accuracy is not None and accuracy > 0.9999:
      print('\nReached 99.99% accuracy so stopping training!')
      self.model.stop_training = True

info_callback = InfoCallback()
enough_callback = EnoughCallback()

# Weights are checkpointed in separate directories for the mask and polygon
# models; the file name carries the zero-padded epoch number.
model_weights_dir = path.join('checkpoints', 'mask' if USE_MASKS else 'polygon')
model_weights = path.join(model_weights_dir, 'model-{epoch:04d}.ckpt')
save_callback = tf.keras.callbacks.ModelCheckpoint(filepath=model_weights, save_weights_only=True, verbose=1)


In [None]:
# Maximum number of epochs; EnoughCallback may stop training earlier.
EPOCHS = 30
# One epoch is one pass over TRAIN_SIZE samples of the endlessly repeating training set.
STEPS_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE
VALIDATION_STEPS = TEST_SIZE // BATCH_SIZE

print(f'Train size: {TRAIN_SIZE}, test size: {TEST_SIZE}')
print(f'BATCH_SIZE = {BATCH_SIZE}, VALIDATION_STEPS = {VALIDATION_STEPS}, STEPS_PER_EPOCH = {STEPS_PER_EPOCH}')

In [None]:
# Resume from the most recent checkpoint when there is one. The directory
# may exist without holding any checkpoint, in which case
# latest_checkpoint returns None and there is nothing to load.
checkpoint = tf.train.latest_checkpoint(model_weights_dir) if path.exists(model_weights_dir) else None
if checkpoint:
  model.load_weights(checkpoint)
  print(f'Loaded checkpoint: {checkpoint}')
else:
  print(f'No checkpoint found in {model_weights_dir}, starting from scratch')

In [None]:
# Train the selected model, validating on the held-out batches and running
# the checkpoint, preview and early-stop callbacks.
model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[save_callback, info_callback, enough_callback])

In [None]:
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

# EnoughCallback can stop training before EPOCHS is reached, so the x axis
# follows the number of epochs actually recorded.
epochs = range(len(loss))

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

In [None]:
accuracy = model_history.history['accuracy']
val_accuracy = model_history.history['val_accuracy']

# EnoughCallback can stop training before EPOCHS is reached, so the x axis
# follows the number of epochs actually recorded.
epochs = range(len(accuracy))

plt.figure()
plt.plot(epochs, accuracy, 'r', label='Training accuracy')
plt.plot(epochs, val_accuracy, 'bo', label='Validation accuracy')
plt.title('Training and Validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

In [None]:
# Predictions for three batches of the held-out set.
show_predictions(test_dataset, 3)

In [None]:
# Colab-only cell: upload photos and run the model on each of them.
from google.colab import files
from sklearn.preprocessing import minmax_scale

uploaded = files.upload()

for filename, content in uploaded.items():
  image = tf.io.decode_image(tf.constant(content), channels=3, expand_animations=False)
  # Feed the model what it was trained on: uint8 pixels in [0, 255]. The
  # model's input is declared uint8, so a float image in [0, 1] would be
  # cast to an almost black picture.
  image = tf.image.resize_with_pad(image, IMAGE_SIZE, IMAGE_SIZE, antialias=True)
  image = tf.cast(image, tf.uint8)

  pred_mask = model.predict(image[tf.newaxis, ...])
  show_images({'Image': image, 'Prediction': pred_mask[0]})

  # minmax_scale normalises each column of a 2-D array separately; flatten
  # first so the whole mask shares one scale.
  prediction = pred_mask[0].squeeze()
  mask = minmax_scale(prediction.ravel(), feature_range=(0, 255)).reshape(prediction.shape)

  plt.imshow(mask, cmap='gray')
  # Flush the figure so every uploaded file gets its own plot
  plt.show()

In [None]:
# Save the prediction for the sample image. predict() expects a batch, so a
# leading batch axis is added, as `show_predictions` does.
sample_prediction = model.predict(sample_image[tf.newaxis, ...])
sample_prediction = tf.keras.preprocessing.image.array_to_img(sample_prediction[0])
sample_prediction.save('prediction.png')
plt.imshow(sample_prediction, cmap='gray')