In [None]:
import keras_cv
import tensorflow as tf
import matplotlib.pyplot as plt

# Load image

In [None]:
# Fetch the sample street scene (Keras caches the download as "street.jpg"),
# then decode the JPEG bytes straight into a 3-channel tensor.
image_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/segmentation_input.jpg"
image_path = tf.keras.utils.get_file("street.jpg", origin=image_url)
image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)

# Keep (height, width) so the predicted mask can be resized back to it later.
original_shape = tf.shape(image)[:2]

# Preprocess: resize to the model's 512x512 input size and add a batch dimension

In [None]:
# Resize to the 512x512 size used by the pretrained model and add a batch axis.
# NOTE: do NOT divide by 255 here. The KerasCV pretrained DeepLabV3+ backbone
# rescales pixels itself (the official KerasCV guide feeds raw [0, 255] images),
# so normalizing here as well would squash the input range a second time.
# tf.image.resize already returns float32, so no explicit cast is needed.
input_image = tf.image.resize(image, (512, 512))
input_image = tf.expand_dims(input_image, axis=0)  # Add batch dimension -> (1, 512, 512, 3)

# Load model

In [None]:
# Load DeepLabV3+ with pretrained weights through a KerasCV preset.
# The plain constructor takes a Backbone *instance* and has no `weights`
# argument, so pretrained Pascal VOC weights must come from `from_preset`.
# This preset bundles a ResNet50 backbone with Pascal VOC segmentation weights.
model = keras_cv.models.DeepLabV3Plus.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc",
    num_classes=21,  # For Pascal VOC
)

# Run prediction

In [None]:
# Per-pixel class scores for the batch; presumably (1, 512, 512, num_classes) — TODO confirm.
predictions = model.predict(input_image)
# Highest-scoring class id for every pixel of the single image in the batch.
segmentation = tf.argmax(predictions[0], axis=-1)
# Nearest-neighbour resizing keeps values as valid class ids (no blending between labels).
segmentation = tf.image.resize(
    tf.expand_dims(segmentation, -1), original_shape, method="nearest"
)
# Squeeze only the channel axis: a bare squeeze would also drop a spatial axis of size 1.
segmentation = tf.squeeze(segmentation, axis=-1).numpy().astype("uint8")

assert segmentation.shape == tuple(original_shape.numpy()), "mask must match the original image size"

# Display the original image next to the predicted segmentation mask

In [None]:
# Show the input photo next to its predicted label map.
# The colour scale is pinned to the full class-id range so a given class always
# gets the same colour, regardless of which classes appear in this image.
num_classes = 21  # must match the num_classes the model was built with

fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

# decode_jpeg already produced uint8 pixels, so no cast is needed for display.
ax_image.imshow(image.numpy())
ax_image.set_title("Original Image")
ax_image.axis("off")

mask_artist = ax_mask.imshow(
    segmentation,
    cmap="jet",
    vmin=0,
    vmax=num_classes - 1,
    interpolation="nearest",  # never blend neighbouring class ids when rendering
)
ax_mask.set_title("Segmented Output")
ax_mask.axis("off")
fig.colorbar(mask_artist, ax=ax_mask, label="class id")

fig.tight_layout()
plt.show()

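# Alternative approach: DeepLabV3 via TensorFlow Hub (kept for reference only — the cell below is fully commented out and does not run)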

In [None]:
# import tensorflow as tf
# import tensorflow_hub as hub
# import numpy as np
# import matplotlib.pyplot as plt
 
# # Load the DeepLabV3+ model from TensorFlow Hub
# model = hub.load("https://tfhub.dev/tensorflow/deeplabv3/1")
 
# # Load and preprocess the input image
# image_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/segmentation_input.jpg"
# image_path = tf.keras.utils.get_file("street.jpg", origin=image_url)
 
# img_raw = tf.io.read_file(image_path)                           # Read image bytes
# img = tf.image.decode_jpeg(img_raw)                             # Decode JPEG
# original_size = tf.shape(img)[:2]                               # Save original size for resizing output
# img = tf.image.convert_image_dtype(img, tf.uint8)               # Ensure image is uint8
# img_resized = tf.image.resize(img, [513, 513])                  # Resize to model's expected input
# img_tensor = tf.expand_dims(img_resized, 0)                     # Add batch dimension
 
# # Run the model
# result = model(img_tensor)
# segmentation_map = tf.argmax(result['semantic_pred'], axis=3)[0]  # Get class prediction map
 
# # Define a color map (Pascal VOC colormap — simplified)
# def create_pascal_label_colormap():
#     colormap = np.zeros((256, 3), dtype=int)
#     for i in range(256):
#         r, g, b = 0, 0, 0
#         c = i
#         for j in range(8):
#             r |= (c & 1) << (7 - j)
#             g |= ((c >> 1) & 1) << (7 - j)
#             b |= ((c >> 2) & 1) << (7 - j)
#             c >>= 3
#         colormap[i] = [r, g, b]
#     return colormap
 
# # Convert class map to color map
# colormap = create_pascal_label_colormap()
# segmentation_color = tf.gather(colormap, segmentation_map)
 
# # Resize segmentation map back to original image size
# segmentation_color = tf.image.resize(segmentation_color, original_size, method='nearest')
 
# # Plot the original image and segmentation result
# plt.figure(figsize=(10, 5))
# plt.subplot(1, 2, 1)
# plt.imshow(tf.image.decode_jpeg(img_raw))
# plt.title("Original Image")
# plt.axis('off')
 
# plt.subplot(1, 2, 2)
# plt.imshow(tf.cast(segmentation_color, tf.uint8))
# plt.title("Segmented Image")
# plt.axis('off')
# plt.tight_layout()
# plt.show()