# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, datasets

#  Load MNIST
(x_train, _), (x_test, _) = datasets.mnist.load_data()
# Scale uint8 pixels [0, 255] to float32 in [0, 1] and add a trailing channel
# axis. float32 (rather than the float64 a bare `/ 255.0` yields) halves memory
# and is the dtype Keras computes in anyway.
x_train = x_train[..., None].astype("float32") / 255.0
x_test = x_test[..., None].astype("float32") / 255.0


def make_blurry(images, factor=2):
    """Return a blurred copy of *images* as a NumPy array.

    Blur is simulated by resizing each image down by *factor* and back up to
    its original size with ``tf.image.resize`` (default bilinear method), which
    throws away high-frequency detail.

    Args:
        images: batch of images indexed as (N, H, W, C).
        factor: integer downsampling factor (>= 1); H and W are divided by it
            with floor division.

    Returns:
        NumPy array with the same shape as *images*. The tf.Tensor from
        ``tf.image.resize`` is converted so callers can use ndarray methods
        such as ``.squeeze()``, which EagerTensor does not provide.

    Raises:
        ValueError: if *factor* is less than 1.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor}")
    height, width = images.shape[1], images.shape[2]
    small = tf.image.resize(images, (height // factor, width // factor))
    return tf.image.resize(small, (height, width)).numpy()


#  Make blurry inputs (simulate "bad images")
x_train_blur = make_blurry(x_train)
x_test_blur = make_blurry(x_test)

# Simple U-Net
def simple_unet(input_shape, *, base_filters=32):
    """Build a one-level U-Net for image-to-image regression.

    Encoder: 3x3 conv at full resolution, 2x2 max-pool, 3x3 conv at half
    resolution. Decoder: stride-2 transposed conv back to full resolution,
    concatenated with the full-resolution encoder features (the U-Net skip
    connection), then a 1x1 conv with a sigmoid so each output pixel lies in
    (0, 1).

    Args:
        input_shape: shape of one input image as (H, W, C). H and W must be
            even (or None) so the upsampled decoder features line up with the
            skip connection.
        base_filters: channel count of the full-resolution convolutions; the
            half-resolution convolution uses twice as many. The default
            reproduces the original 32/64 widths.

    Returns:
        An uncompiled ``keras.Model`` mapping (batch, H, W, C) inputs to
        (batch, H, W, 1) outputs.

    Raises:
        ValueError: if a known spatial dimension of *input_shape* is odd.
    """
    # Pooling floors odd sizes, so the upsampled path would come back one
    # pixel short and could not be concatenated with the skip features.
    for dim in input_shape[:2]:
        if dim is not None and dim % 2:
            raise ValueError(
                f"simple_unet needs even spatial dimensions, got {input_shape}"
            )

    inputs = layers.Input(shape=input_shape)
    # Full-resolution encoder features, kept for the skip connection.
    skip = layers.Conv2D(base_filters, 3, activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D()(skip)
    x = layers.Conv2D(base_filters * 2, 3, activation='relu', padding='same')(x)
    x = layers.Conv2DTranspose(
        base_filters, 3, strides=2, padding='same', activation='relu'
    )(x)
    # Skip connection: reinject the fine spatial detail lost by pooling.
    x = layers.Concatenate()([x, skip])
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    return models.Model(inputs, outputs)

import matplotlib.pyplot as plt
import numpy as np

model = simple_unet(x_train.shape[1:])
# Pixel-wise binary cross-entropy is usable as a reconstruction loss because the
# targets were scaled to [0, 1] and the model ends in a sigmoid. 'accuracy' is
# meaningless for continuous targets, so track mean absolute error instead.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['mae'])

#  Train the U-Net: blurry images in, sharp originals as targets.
history = model.fit(x_train_blur, x_train, epochs=5, batch_size=64, validation_split=0.1)

n_show = 5
predicted = model.predict(x_test_blur[:n_show])

# One row per test sample: blurry input | sharp target | model reconstruction.
columns = (
    ("Blurry Input", x_test_blur),
    ("Ground Truth", x_test),
    ("Predicted (Deblurred)", predicted),
)
fig, axes = plt.subplots(n_show, len(columns), figsize=(8, 2 * n_show), squeeze=False)
for i in range(n_show):
    for ax, (title, images) in zip(axes[i], columns):
        # np.asarray accepts both NumPy arrays and tf EagerTensors; the latter
        # have no .squeeze() method of their own.
        ax.imshow(np.asarray(images[i]).squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
fig.tight_layout()
plt.show()