In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar10

In [None]:
# CIFAR-10 comes back as two (images, labels) pairs: the training split and the
# test split. Images are uint8 RGB arrays of shape (num_samples, 32, 32, 3).
train_pair, test_pair = cifar10.load_data()
x_train, y_train = train_pair
x_test, y_test = test_pair

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# - float32 is Keras' default float type; plain `/ 255.0` on uint8 yields float64,
#   which doubles memory for no benefit.
# - The uint8 guard keeps this cell idempotent. Re-running it must not divide by 255
#   a second time, because that would silently push every pixel mean below the 0.5
#   mask threshold used when deriving the segmentation targets.
if x_train.dtype == np.uint8:
    x_train = x_train.astype(np.float32) / 255.0
if x_test.dtype == np.uint8:
    x_test = x_test.astype(np.float32) / 255.0

In [4]:
# Synthetic segmentation target: a pixel is foreground (1) when the mean over its
# channels exceeds 0.5, otherwise background (0). The 0.5 cut-off presumes the
# images were already scaled to [0, 1].
y_train_seg, y_test_seg = (
    (images.mean(axis=-1) > 0.5).astype(int)
    for images in (x_train, x_test)
)

In [5]:
# Give the masks a trailing channel axis, (num_samples, height, width, 1), so they
# match the model's single-channel sigmoid output.
# Reshaping to an explicit 4-D shape (instead of indexing with np.newaxis) keeps the
# cell idempotent: re-running it leaves (N, H, W, 1) unchanged rather than stacking
# a fifth axis that would make model.fit fail.
y_train_seg = y_train_seg.reshape(y_train_seg.shape[:3] + (1,))
y_test_seg = y_test_seg.reshape(y_test_seg.shape[:3] + (1,))

In [9]:
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D,concatenate

In [13]:
def unet_model(input_size=(32, 32, 3), base_filters=32):
    """Build a small two-level U-Net for per-pixel binary segmentation.

    Args:
        input_size: Shape of one input image, ``(height, width, channels)``.
            Height/width may be ``None`` for variable-size inputs. When given,
            each must be divisible by 4: the encoder pools twice by 2, and the
            decoder's skip concatenations need the upsampled maps to match the
            encoder maps exactly.
        base_filters: Filter count of the first encoder level. Deeper levels use
            2x and 4x this value. The default of 32 reproduces the original
            32/64/128 architecture.

    Returns:
        An uncompiled Keras ``Model`` mapping an image to a ``(height, width, 1)``
        map of sigmoid probabilities.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 4.
    """
    # Fail fast with a clear message. Otherwise a bad size only surfaces later,
    # as a shape mismatch inside concatenate().
    for dim_name, dim in zip(("height", "width"), input_size[:2]):
        if dim is not None and dim % 4 != 0:
            raise ValueError(
                f"input {dim_name} must be divisible by 4 for this U-Net "
                f"(two 2x2 poolings), got {dim}"
            )

    inputs = Input(input_size)

    # Encoder (downsampling): c1 and c2 are kept for the skip connections.
    c1 = Conv2D(base_filters, (3, 3), activation='relu', padding='same')(inputs)
    p1 = MaxPooling2D((2, 2))(c1)
    c2 = Conv2D(base_filters * 2, (3, 3), activation='relu', padding='same')(p1)
    p2 = MaxPooling2D((2, 2))(c2)

    # Bottleneck at 1/4 of the input resolution.
    c3 = Conv2D(base_filters * 4, (3, 3), activation='relu', padding='same')(p2)

    # Decoder (upsampling): each stage doubles the resolution, then concatenates the
    # same-sized encoder features along the channel axis.
    u1 = UpSampling2D((2, 2))(c3)
    m1 = concatenate([u1, c2])
    c4 = Conv2D(base_filters * 2, (3, 3), activation='relu', padding='same')(m1)
    u2 = UpSampling2D((2, 2))(c4)
    m2 = concatenate([u2, c1])
    c5 = Conv2D(base_filters, (3, 3), activation='relu', padding='same')(m2)

    # 1x1 conv + sigmoid -> per-pixel foreground probability.
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c5)
    return Model(inputs, outputs)

In [15]:
# Build the U-Net and configure it for binary per-pixel segmentation:
# - Adam optimiser;
# - binary cross-entropy against the 0/1 masks;
# - accuracy reported over the mask pixels.
model = unet_model()
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Train for a fixed number of epochs, reporting loss/accuracy on the held-out split
# after every epoch.
# - Keeping the returned History lets the learning curves be plotted later without
#   re-running this multi-minute cell. Assigning it also stops the notebook from
#   echoing the History object's repr.
# NOTE(review): the test split doubles as validation data here. If val_* metrics are
# used to choose epochs or hyper-parameters, they are no longer an unbiased
# final estimate.
EPOCHS = 10
BATCH_SIZE = 32
history = model.fit(
    x_train,
    y_train_seg,
    validation_data=(x_test, y_test_seg),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)

Epoch 1/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m147s[0m 92ms/step - accuracy: 0.9358 - loss: 0.1491 - val_accuracy: 0.9886 - val_loss: 0.0313
Epoch 2/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 91ms/step - accuracy: 0.9881 - loss: 0.0301 - val_accuracy: 0.9931 - val_loss: 0.0194
Epoch 3/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 92ms/step - accuracy: 0.9923 - loss: 0.0196 - val_accuracy: 0.9949 - val_loss: 0.0142
Epoch 4/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 92ms/step - accuracy: 0.9933 - loss: 0.0165 - val_accuracy: 0.9955 - val_loss: 0.0121
Epoch 5/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m423s[0m 271ms/step - accuracy: 0.9951 - loss: 0.0125 - val_accuracy: 0.9967 - val_loss: 0.0096
Epoch 6/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 143ms/step - accuracy: 0.9956 - loss: 0.0113 - val_accuracy: 0.9966 - val_loss: 0.0

In [None]:
import matplotlib.pyplot as plt

N_SHOW = 5  # number of test images to visualise
pred = model.predict(x_test[:N_SHOW])

# Masks are drawn in gray with vmin/vmax pinned to [0, 1]. Without that, imshow
# rescales each mask on its own:
# - a uniform (e.g. all-foreground) mask renders solid black;
# - soft predicted probabilities get stretched to look like crisp 0/1 masks.
mask_kwargs = {"cmap": "gray", "vmin": 0, "vmax": 1}

# One figure per sample: input image, ground-truth mask, predicted mask.
for i in range(N_SHOW):
    fig, axes = plt.subplots(1, 3)
    panels = (
        ("Input Image", x_test[i], {}),
        ("Ground Truth Mask", y_test_seg[i].squeeze(), mask_kwargs),
        ("Predicted Mask", pred[i].squeeze(), mask_kwargs),
    )
    for ax, (title, image, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations