# Setup Section

In [None]:
# Extract the dataset archive.
# NOTE(review): later cells read /content/train/ and /content/test/, so this
# assumes the archive unpacks into /content (the cell's working directory).
!tar -xf /content/screw.tar

In [None]:
# Upgrade matplotlib: Axes.stairs (used in the loss-histogram cell) requires
# matplotlib >= 3.4.
# NOTE(review): the runtime may need a restart before the upgraded version
# is the one imported.
!pip install --upgrade matplotlib

In [None]:
# Mount Google Drive at /content/gdrive so figures can be saved under
# /content/gdrive/MyDrive (see image_path).
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
# Destination directory for the figures written with plt.savefig in the
# plotting cells.
# NOTE(review): savefig does not create missing directories, so this folder
# must already exist on the mounted Drive.
image_path = "/content/gdrive/MyDrive/colab_output/images"

In [None]:
import os
import pathlib
import pprint as pp
from glob import glob, iglob
from PIL import Image, ImageFilter
from typing import List, Union, Tuple, BinaryIO

from keras.api._v2.keras.layers import Conv3DTranspose

pp.PrettyPrinter(indent=4)
import pickle

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential, layers, losses, metrics
from tensorflow.keras.models import Model
from tensorflow.keras.utils import image_dataset_from_directory

# Data init

In [None]:
# Dataset roots. Class labels are inferred from the sub-directory names
# beneath each root.
data_root = pathlib.Path("/content")
train_dir = data_root / "train"
test_dir = data_root / "test"

In [None]:
# Spatial size every image is resized to on load.
img_height, img_width = 256, 256
# NOTE(review): batch_size is not passed to the loader below, which requests
# unbatched elements (batch_size=None).
batch_size = 320

# Unbatched (image, label) pairs, grayscale, in directory order.
# seed only matters for shuffling, which is disabled here.
x_train = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    image_size=(img_height, img_width),
    color_mode='grayscale',
    batch_size=None,
    shuffle=False,
    seed=142,
)

In [None]:
# Test split, loaded exactly like the training split: unbatched
# (image, label) pairs, grayscale, in directory order.
x_test = tf.keras.utils.image_dataset_from_directory(
    test_dir,
    color_mode='grayscale',
    image_size=(img_height, img_width),
    shuffle=False,
    batch_size=None,
    seed=142,
)

In [None]:
# Class names that image_dataset_from_directory inferred from the
# sub-directories of train_dir; the bare name on the last line displays them.
train_classes = x_train.class_names
train_classes

In [None]:
# Materialize every training image as one float16 array (labels dropped) and
# scale pixel values from [0, 255] to [0, 1].
x_2_train_np = np.array([element[0] for element in x_train.as_numpy_iterator()], dtype='float16')
x_2_train_final = x_2_train_np / 255
# -1 lets NumPy infer the number of images, so the cell keeps working if the
# dataset size changes; spatial size comes from the loader settings.
x_2_train_final = x_2_train_final.reshape(-1, img_height, img_width, 1)
x_2_train_final.shape

In [None]:
# Materialize every test image as one float16 array (labels dropped) and
# scale pixel values from [0, 255] to [0, 1].
x_2_test_np = np.array([element[0] for element in x_test.as_numpy_iterator()], dtype='float16')
x_2_test_final = x_2_test_np / 255
# -1 lets NumPy infer the number of images instead of hard-coding the count.
x_2_test_final = x_2_test_final.reshape(-1, img_height, img_width, 1)
x_2_test_final.shape

# Custom Metric and Loss

In [None]:
class ssim_metric(tf.keras.metrics.Metric):
  """Streaming structural-dissimilarity metric: ``1 - mean(SSIM)``.

  Per-image SSIM is computed with ``tf.image.ssim(..., max_val=1)``, so
  images are expected to be scaled to [0, 1]. The metric keeps a running
  sum of per-image SSIM values and the number of images seen. ``result()``
  reports one minus their mean, the same ``1 - SSIM`` quantity that
  ``ssim_loss`` minimizes. Before any update, ``result()`` returns 1.0.
  """

  def __init__(self, name="ssim_metric", **kwargs):
    super().__init__(name=name, **kwargs)
    # Running sum of per-image SSIM values.
    self.ssim = self.add_weight(name="ssim", initializer="zeros")
    # Number of images that contributed to the running sum.
    self.total_samples = self.add_weight(name="total_samples",
                                         initializer="zeros",
                                         dtype="int32")

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulate the SSIM of every image pair in the batch.

    ``sample_weight`` is accepted for API compatibility but ignored.
    """
    # One SSIM value per image. Stay in tensor ops (no ``.numpy()``) so this
    # also works when Keras traces it into a graph during ``fit``.
    per_image = tf.image.ssim(y_true, y_pred, max_val=1.0)
    per_image = tf.cast(per_image, self.ssim.dtype)
    self.ssim.assign_add(tf.reduce_sum(per_image))
    self.total_samples.assign_add(tf.size(per_image, out_type=tf.int32))

  def result(self):
    # divide_no_nan yields 0 (so result 1.0) before any sample is seen.
    mean_ssim = tf.math.divide_no_nan(
        self.ssim, tf.cast(self.total_samples, self.ssim.dtype))
    return tf.subtract(1., mean_ssim)

  def reset_state(self):
    self.ssim.assign(0.)
    self.total_samples.assign(0)

In [None]:
class ssim_loss(tf.keras.losses.Loss):
  """Structural-dissimilarity loss: ``1 - SSIM`` per image (``max_val=1``)."""

  @tf.function
  def call(self, y_true, y_pred):
    # One SSIM score per image in the batch; Keras applies the reduction.
    similarity = tf.image.ssim(y_true, y_pred, max_val=1)
    return 1. - similarity


# Model Definition

## Convolutional Building Blocks

In [None]:
def conv2d_block(x, filters, kernel_size=3, reps:int=2, pooling:bool=False, **kwargs):
  """Append a pre-activation encoder block to the Keras tensor ``x``.

  Applies ``reps`` rounds of BatchNormalization -> ReLU -> SeparableConv2D
  (``padding="same"``, no bias). The first round uses ``strides=2`` and the
  rest use ``strides=1``; this overrides any ``strides`` given in ``kwargs``.
  All other ``kwargs`` are forwarded to every SeparableConv2D. When
  ``pooling`` is true, a stride-2 MaxPooling2D (pool size ``kernel_size``)
  is appended.

  NOTE: the residual shortcut is currently disabled, so the block returns
  the plain convolution stack.
  """
  conv_kwargs = dict(kwargs)
  for round_idx in range(reps):
    # Down-sample only on the block's first convolution.
    conv_kwargs['strides'] = 2 if round_idx == 0 else 1
    normed = layers.BatchNormalization()(x)
    activated = layers.Activation("relu")(normed)
    x = layers.SeparableConv2D(filters, kernel_size, padding="same",
                               use_bias=False, **conv_kwargs)(activated)

  if pooling:
    x = layers.MaxPooling2D(kernel_size, strides=2, padding="same")(x)
  return x

In [None]:
def conv2d_T_block(x, filters, kernel_size=3, reps:int=2, **kwargs):
  """Append a pre-activation decoder block to the Keras tensor ``x``.

  Applies ``reps`` rounds of BatchNormalization -> ReLU -> Conv2DTranspose
  (``padding="same"``, no bias). The first round uses ``strides=2``, which
  up-samples, and the rest use ``strides=1``; this overrides any ``strides``
  in ``kwargs``. All other ``kwargs`` are forwarded to every Conv2DTranspose.

  NOTE: the residual shortcut is currently disabled.
  """
  conv_kwargs = dict(kwargs)
  for step in range(reps):
    # Up-sample only on the block's first transposed convolution.
    conv_kwargs['strides'] = 2 if step == 0 else 1
    normed = layers.BatchNormalization()(x)
    activated = layers.Activation("relu")(normed)
    x = layers.Conv2DTranspose(filters, kernel_size, padding="same",
                               use_bias=False, **conv_kwargs)(activated)
  return x

## Model Definition (clf_model init)

In [None]:
def get_model(input_shape, filter_blocks:List, rescaling:bool=False, **kwargs):
  """Build the (uncompiled) convolutional autoencoder.

  Layout:
    * an optional ``Rescaling(1/255)`` layer;
    * a 5x5 stem Conv2D with ``filter_blocks[0]`` filters;
    * one ``conv2d_block`` per entry of ``filter_blocks``;
    * one ``conv2d_T_block`` per entry, in reverse order;
    * a 3x3 Conv2D with a single sigmoid output channel.

  Args:
    input_shape: Shape of one input image, without the batch dimension.
    filter_blocks: Filter count of each encoder block (mirrored by the decoder).
    rescaling: If true, prepend a ``Rescaling(1/255)`` layer.
    **kwargs: Forwarded to BOTH ``conv2d_block`` and ``conv2d_T_block``.
      NOTE(review): an encoder-only option such as ``pooling`` would also be
      forwarded through ``conv2d_T_block`` to Conv2DTranspose; confirm
      before using it.

  Returns:
    A ``tf.keras.Model`` mapping ``inputs`` to the sigmoid reconstruction.
  """
  inputs = tf.keras.Input(shape=input_shape)
  x = layers.Rescaling(1./255)(inputs) if rescaling else inputs
  x = layers.Conv2D(filter_blocks[0], kernel_size=5, padding='same', use_bias=False)(x)

  # Encoder: each block halves the spatial size.
  for n_filters in filter_blocks:
    x = conv2d_block(x, n_filters, **kwargs)
  # Decoder: mirror the encoder widths, each block doubling the spatial size.
  for n_filters in reversed(filter_blocks):
    x = conv2d_T_block(x, n_filters, **kwargs)

  outputs = layers.Conv2D(1, 3, activation='sigmoid', padding='same')(x)
  return tf.keras.Model(inputs, outputs)

In [None]:
# Filter counts of the five encoder blocks; the decoder mirrors them.
filters = [32, 64, 128, 256, 512]
# (height, width) of the prepared training images plus one grayscale channel.
input_shape = (*x_2_train_final.shape[1:-1], 1)
print(input_shape)
clf_model = get_model(filter_blocks=filters, input_shape=input_shape)
clf_model.summary()

# Model Compile, Fit

In [None]:
# Optimize the SSIM-based loss; MSE and Poisson are reported for monitoring.
clf_model.compile(
    optimizer="adam",
    loss=ssim_loss(),
    metrics=["MeanSquaredError", "Poisson"],
)

In [None]:
# Autoencoder training: each split is both the input and the target.
history = clf_model.fit(
    x_2_train_final,
    x_2_train_final,
    epochs=28,
    batch_size=32,
    validation_data=(x_2_test_final, x_2_test_final),
)

# History Graphing

In [None]:
# Plot every training metric next to its validation counterpart, one subplot
# per metric, on a two-column grid sized to the number of metrics.
train_metrics = [key for key in history.history if not key.startswith("val_")]
n_cols = 2
n_rows = max(1, -(-len(train_metrics) // n_cols))  # ceiling division
fig, ax = plt.subplots(n_rows, n_cols, figsize=(8, 3 * n_rows), squeeze=False)
axes = ax.ravel()
for axis, this_metric in zip(axes, train_metrics):
  axis.plot(history.history[this_metric], label="train")
  val_key = f"val_{this_metric}"
  if val_key in history.history:
    axis.plot(history.history[val_key], label="validation")
  axis.set_title(this_metric.title())
  axis.set(xlabel="Epochs", ylabel=this_metric)
  axis.legend()
# Hide grid cells left empty when the metric count is odd.
for axis in axes[len(train_metrics):]:
  axis.axis("off")
# Lay out before saving so the saved file uses the tight layout too.
fig.tight_layout()

# NOTE(review): run_count is not defined in the visible notebook; fall back
# to 0 so saving does not raise NameError. Confirm the intended counter.
try:
  run_count
except NameError:
  run_count = 0
fig.savefig(f"{image_path}/ssim_output_graph_{run_count:03d}.png")
plt.show()

# Inspect Model Image Output 

In [None]:
def img_gen(path_list, rand_samp:bool=False, index:int=0):
  """Yield one preprocessed image and its class label from ``path_list``.

  Args:
    path_list: Sequence of image paths laid out as ``<root>/<label>/<file>``.
    rand_samp: If true, pick a uniformly random path; otherwise use ``index``.
    index: Position in ``path_list`` used when ``rand_samp`` is false.

  Yields:
    ``(image, label)``, where:
      * ``image`` is a float32 array of shape ``(1, 256, 256, 1)`` scaled to
        [0, 1];
      * ``label`` is the name of the image's parent directory.

  Raises:
    ValueError: if ``path_list`` is empty.
    IndexError: if ``index`` is out of range.
  """
  if len(path_list) == 0:
    raise ValueError("path_list is empty")
  if rand_samp:
    # randint's upper bound is exclusive, so len(path_list) keeps every path,
    # including the last one, eligible.
    index = np.random.randint(0, len(path_list))
  path = path_list[index]
  label = os.path.basename(os.path.dirname(path))
  # Convert to 8-bit single-channel 'L' to match the grayscale training data.
  # The context manager closes the underlying file.
  with Image.open(path) as img:
    dense = np.asarray(img.convert('L').resize((256, 256)), dtype=np.float32)
  dense = dense / 255
  yield dense.reshape((1, 256, 256, 1)), label


In [None]:
# Running index embedded in the filename of each saved reconstruction figure
# (ssim_output_NNN.png); the inspection cell increments it after every save.
img_c = 9

In [None]:
# Reconstruct randomly chosen training images. Inputs go on the top row
# (titled with their class); reconstructions go on the bottom row (titled
# with their 1 - SSIM error).
n_samples = 10
# Glob once, outside the loop. "**" without recursive=True matches exactly
# one directory level, i.e. /content/train/<class>/*.png.
img_paths = glob("/content/train/**/*.png")
fig, ax = plt.subplots(2, n_samples, figsize=(15, 4))
for i in range(n_samples):
  img_in, label = next(img_gen(img_paths, rand_samp=True))
  img_out = clf_model.predict(img_in)
  # Inputs are scaled to [0, 1] and the model ends in a sigmoid, so both are
  # already non-negative; compare them directly.
  ssim_value = tf.image.ssim(img_in, img_out, max_val=1)[0].numpy()
  print(ssim_value)

  ax_in, ax_out = ax[0, i], ax[1, i]
  ax_in.imshow(img_in[0, :, :, 0], cmap='gray')
  ax_in.set_title(label)
  ax_in.axis('off')
  ax_out.imshow(img_out[0, :, :, 0], cmap='gray')
  ax_out.set_title(f"{1 - ssim_value:.4f}")
  ax_out.axis('off')
fig.tight_layout()
plt.savefig(f"{image_path}/ssim_output_{img_c:03d}.png")
img_c += 1
plt.show()


# Define Prediction Functions

In [None]:
def get_img(path):
  """Load one image for inference.

  Returns:
    ``(image, label)``, where:
      * ``image`` is a float32 array of shape ``(1, 256, 256, 1)`` scaled to
        [0, 1];
      * ``label`` is the name of the image's parent directory.
  """
  # Derive the label from the path itself; it must not fall back to a
  # notebook-global ``label`` left over from another cell.
  label = os.path.basename(os.path.dirname(path))
  # Convert to 8-bit single-channel 'L' to match the grayscale training data.
  # The context manager closes the file.
  with Image.open(path) as img:
    dense = np.asarray(img.convert('L').resize((256, 256)), dtype=np.float32)
  dense = dense / 255
  dense = dense.reshape((1, 256, 256, 1))
  return dense, label


def get_distributions(model, pattern="/content/test/**/*.png"):
  """Score every matching image by how well ``model`` reconstructs it.

  Args:
    model: Keras model applied to each image as loaded by ``get_img``
      (shape ``(1, 256, 256, 1)``); its output is compared with the input.
    pattern: glob pattern selecting the images (defaults to the test set).

  Returns:
    A list of ``[ssim_error, mse, label]`` rows, one per image, in sorted
    path order, where:
      * ``ssim_error`` is ``1 - SSIM``;
      * ``mse`` is the mean squared pixel error;
      * ``label`` is the image's parent directory name.
  """
  result = []
  # Sorted for a reproducible row order: later cells slice rows by position.
  for path in sorted(glob(pattern)):
    img, _ = get_img(path)
    # Take the label straight from the path so rows are always labelled
    # with their own class directory.
    label = os.path.basename(os.path.dirname(path))
    pred = model.predict(img, verbose=0)
    # True mean squared error (the squared difference, not a difference of
    # squares).
    mse = np.mean((img - pred) ** 2)
    ssim = tf.image.ssim(img, pred, max_val=1)
    ssim = 1 - ssim[0].numpy()
    result.append([ssim, mse, label])
  return result

# Get Model Predictions

In [None]:
# Score every test image with the trained model; each row is
# [1 - SSIM, mse, label] as built by get_distributions.
loss_distributions = get_distributions(clf_model)

In [None]:
import pandas as pd

# Tabulate the per-image scores; rows from get_distributions are
# [1 - SSIM, mse, label].
losses_df = pd.DataFrame(loss_distributions, columns=["ssim_error", "mse", "label"])
print(losses_df.shape)
# Numeric columns only, as a float32 (n_images, 2) array for the histograms.
losses = losses_df[["ssim_error", "mse"]].to_numpy(dtype=np.float32)
losses_df.info()


## Graphing Prediction Loss

In [None]:
# Compare the error distributions of the first `split` scored images against
# the remaining ones, for both error measures (columns of `losses`).
# NOTE(review): this assumes the first `split` rows of loss_distributions
# form one group (e.g. the normal class); confirm against the test layout.
split = 40
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
for col, title in enumerate(["1 - SSIM", "mse"]):
  counts, edges = np.histogram(losses[:split, col], bins=12)
  ax[col].stairs(counts, edges, hatch='...', label=f"first {split}")
  counts, edges = np.histogram(losses[split:, col], bins=20)
  # Different hatch so the two groups can be told apart.
  ax[col].stairs(counts, edges, hatch='///', label="remaining")
  ax[col].set_title(title)
  ax[col].set(xlabel="Error", ylabel="Count")
  ax[col].legend()
fig.tight_layout()
plt.show()
