In [2]:
import segmentation_models as sm
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.optimizers import Adam
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from cityscapesscripts.helpers.labels import id2label
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

Segmentation Models: using `keras` framework.


In [5]:
# Paths
CITYSCAPES_DIR = "./content/cityscapes"
IMG_DIR = os.path.join(CITYSCAPES_DIR, "leftImg8bit_trainvaltest/leftImg8bit")
MASK_DIR = os.path.join(CITYSCAPES_DIR, "gtFine_trainvaltest/gtFine")

IMG_SIZE = (512, 512)
NUM_CLASSES = 19  # Cityscapes has 19 semantic classes
BATCH_SIZE = 8  # Reduce batch size if kernel still crashes

# 🔹 **Function to Load Single Image**
def load_image(image_path):
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
    img = cv2.resize(img, IMG_SIZE)
    img = img / 255.0  # Normalize
    return img

# 🔹 **Function to Load Single Mask**
def load_mask(mask_path):
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # Read as grayscale
    mask_resized = cv2.resize(mask, IMG_SIZE, interpolation=cv2.INTER_NEAREST)

    # Convert 255 (ignore index) → -1 (to mask out in loss function)
    mask_resized[mask_resized == 255] = -1  

    return mask_resized.astype(np.int32)  # Keep as integers (not one-hot)

# 🔹 **Function to Load Data Efficiently Using a Generator**
def cityscapes_generator(split="train"):
    img_path = os.path.join(IMG_DIR, split)
    mask_path = os.path.join(MASK_DIR, split)

    for city in os.listdir(img_path):  # Cities contain images
        img_city_dir = os.path.join(img_path, city)
        mask_city_dir = os.path.join(mask_path, city)

        if not os.path.isdir(img_city_dir):  # Skip non-directories (e.g., .DS_Store)
            continue

        for img_file in os.listdir(img_city_dir):
            if img_file.endswith("_leftImg8bit.png"):
                img = load_image(os.path.join(img_city_dir, img_file))

                # Get corresponding mask
                mask_file = img_file.replace("_leftImg8bit.png", "_gtFine_labelTrainIds.png")
                mask = load_mask(os.path.join(mask_city_dir, mask_file))

                yield img, mask  # Use yield instead of storing everything in a list

# 🔹 **Convert Generator to `tf.data.Dataset`**
def get_cityscapes_dataset(split="train", batch_size=BATCH_SIZE):
    dataset = tf.data.Dataset.from_generator(
        lambda: cityscapes_generator(split),
        output_signature=(
            tf.TensorSpec(shape=(512, 512, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(512, 512), dtype=tf.int32)
        )
    )

    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)  # Optimize performance
    return dataset

# 🔹 **Create Train, Validation & Test Datasets**
train_dataset = get_cityscapes_dataset("train")
val_dataset = get_cityscapes_dataset("val")
test_dataset = get_cityscapes_dataset("test")

# 🔹 **Check the Dataset**
for img_batch, mask_batch in train_dataset.take(1):  # Check first batch
    print(f"Batch Image Shape: {img_batch.shape}, Batch Mask Shape: {mask_batch.shape}")

Batch Image Shape: (8, 512, 512, 3), Batch Mask Shape: (8, 512, 512)


In [15]:
import tensorflow.keras.backend as K

def masked_sparse_categorical_crossentropy(y_true, y_pred):
    """
    Custom loss function that ignores pixels with label -1 (255 in Cityscapes).
    """
    # 🔹 Create a mask: 1 for valid pixels, 0 for ignored pixels
    mask = K.cast(K.not_equal(y_true, -1), K.floatx())

    # 🔹 Compute sparse categorical crossentropy
    loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=False)

    # 🔹 Apply the mask (zero out ignored pixels)
    loss *= mask  # Ignore 255-labeled pixels

    # 🔹 Normalize by dividing only by the number of valid pixels
    return K.sum(loss) / K.sum(mask)

# 🔹 **Register the loss function before compiling**
from tensorflow.keras.utils import get_custom_objects
get_custom_objects().update({"masked_sparse_categorical_crossentropy": masked_sparse_categorical_crossentropy})


In [26]:
# Ensure compatibility with TensorFlow 2.x
sm.set_framework("tf.keras")

# Set backbone (EfficientNet or ResNet)
BACKBONE = "efficientnetb3"  # Try 'resnet50' for ResNet backbone

# Load Pretrained U-Net (or PSPNet, FPN)
model = sm.Unet(BACKBONE, 
                encoder_weights="imagenet", 
                classes=19, 
                activation="softmax",
               input_shape=(512, 512, 3)
               )

# All layers are trainable by default - FIX this
for layer in model.layers:
    layer.trainable = False  # ❄️ Freeze all layers

# Unfreeze only the final classification layers (fine-tuning)
for layer in model.layers[-10:]:  # Adjust this number if needed
    layer.trainable = True

# Compile Model
model.compile(optimizer=Adam(learning_rate=0.0001), 
              loss="masked_sparse_categorical_crossentropy", 
              metrics=["accuracy"])

model.summary()



Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 stem_conv (Conv2D)             (None, 256, 256, 40  1080        ['input_6[0][0]']                
                                )                                                                 
                                                                                                  
 stem_bn (BatchNormalization)   (None, 256, 256, 40  160         ['stem_conv[0][0]']              
                                )                                                           

                                                                                                  
 block2a_dwconv (DepthwiseConv2  (None, 128, 128, 14  1296       ['block2a_expand_activation[0][0]
 D)                             4)                               ']                               
                                                                                                  
 block2a_bn (BatchNormalization  (None, 128, 128, 14  576        ['block2a_dwconv[0][0]']         
 )                              4)                                                                
                                                                                                  
 block2a_activation (Activation  (None, 128, 128, 14  0          ['block2a_bn[0][0]']             
 )                              4)                                                                
                                                                                                  
 block2a_s

 )                              2)                                                                
                                                                                                  
 block2c_activation (Activation  (None, 128, 128, 19  0          ['block2c_bn[0][0]']             
 )                              2)                                                                
                                                                                                  
 block2c_se_squeeze (GlobalAver  (None, 192)         0           ['block2c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block2c_se_reshape (Reshape)   (None, 1, 1, 192)    0           ['block2c_se_squeeze[0][0]']     
                                                                                                  
 block2c_s

                                                                                                  
 block3b_se_reshape (Reshape)   (None, 1, 1, 288)    0           ['block3b_se_squeeze[0][0]']     
                                                                                                  
 block3b_se_reduce (Conv2D)     (None, 1, 1, 12)     3468        ['block3b_se_reshape[0][0]']     
                                                                                                  
 block3b_se_expand (Conv2D)     (None, 1, 1, 288)    3744        ['block3b_se_reduce[0][0]']      
                                                                                                  
 block3b_se_excite (Multiply)   (None, 64, 64, 288)  0           ['block3b_activation[0][0]',     
                                                                  'block3b_se_expand[0][0]']      
                                                                                                  
 block3b_p

 block4a_se_expand (Conv2D)     (None, 1, 1, 288)    3744        ['block4a_se_reduce[0][0]']      
                                                                                                  
 block4a_se_excite (Multiply)   (None, 32, 32, 288)  0           ['block4a_activation[0][0]',     
                                                                  'block4a_se_expand[0][0]']      
                                                                                                  
 block4a_project_conv (Conv2D)  (None, 32, 32, 96)   27648       ['block4a_se_excite[0][0]']      
                                                                                                  
 block4a_project_bn (BatchNorma  (None, 32, 32, 96)  384         ['block4a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4b_e

 block4c_drop (FixedDropout)    (None, 32, 32, 96)   0           ['block4c_project_bn[0][0]']     
                                                                                                  
 block4c_add (Add)              (None, 32, 32, 96)   0           ['block4c_drop[0][0]',           
                                                                  'block4b_add[0][0]']            
                                                                                                  
 block4d_expand_conv (Conv2D)   (None, 32, 32, 576)  55296       ['block4c_add[0][0]']            
                                                                                                  
 block4d_expand_bn (BatchNormal  (None, 32, 32, 576)  2304       ['block4d_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block4d_e

 block5a_expand_conv (Conv2D)   (None, 32, 32, 576)  55296       ['block4e_add[0][0]']            
                                                                                                  
 block5a_expand_bn (BatchNormal  (None, 32, 32, 576)  2304       ['block5a_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5a_expand_activation (Act  (None, 32, 32, 576)  0          ['block5a_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5a_dwconv (DepthwiseConv2  (None, 32, 32, 576)  14400      ['block5a_expand_activation[0][0]
 D)                                                              ']                               
          

                                                                                                  
 block5c_bn (BatchNormalization  (None, 32, 32, 816)  3264       ['block5c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5c_activation (Activation  (None, 32, 32, 816)  0          ['block5c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5c_se_squeeze (GlobalAver  (None, 816)         0           ['block5c_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5c_s

 )                                                                                                
                                                                                                  
 block5e_se_squeeze (GlobalAver  (None, 816)         0           ['block5e_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5e_se_reshape (Reshape)   (None, 1, 1, 816)    0           ['block5e_se_squeeze[0][0]']     
                                                                                                  
 block5e_se_reduce (Conv2D)     (None, 1, 1, 34)     27778       ['block5e_se_reshape[0][0]']     
                                                                                                  
 block5e_se_expand (Conv2D)     (None, 1, 1, 816)    28560       ['block5e_se_reduce[0][0]']      
          

 block6b_se_expand (Conv2D)     (None, 1, 1, 1392)   82128       ['block6b_se_reduce[0][0]']      
                                                                                                  
 block6b_se_excite (Multiply)   (None, 16, 16, 1392  0           ['block6b_activation[0][0]',     
                                )                                 'block6b_se_expand[0][0]']      
                                                                                                  
 block6b_project_conv (Conv2D)  (None, 16, 16, 232)  322944      ['block6b_se_excite[0][0]']      
                                                                                                  
 block6b_project_bn (BatchNorma  (None, 16, 16, 232)  928        ['block6b_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6b_d

                                )                                 'block6d_se_expand[0][0]']      
                                                                                                  
 block6d_project_conv (Conv2D)  (None, 16, 16, 232)  322944      ['block6d_se_excite[0][0]']      
                                                                                                  
 block6d_project_bn (BatchNorma  (None, 16, 16, 232)  928        ['block6d_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6d_drop (FixedDropout)    (None, 16, 16, 232)  0           ['block6d_project_bn[0][0]']     
                                                                                                  
 block6d_add (Add)              (None, 16, 16, 232)  0           ['block6d_drop[0][0]',           
          

                                                                                                  
 block6f_project_bn (BatchNorma  (None, 16, 16, 232)  928        ['block6f_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block6f_drop (FixedDropout)    (None, 16, 16, 232)  0           ['block6f_project_bn[0][0]']     
                                                                                                  
 block6f_add (Add)              (None, 16, 16, 232)  0           ['block6f_drop[0][0]',           
                                                                  'block6e_add[0][0]']            
                                                                                                  
 block7a_expand_conv (Conv2D)   (None, 16, 16, 1392  322944      ['block6f_add[0][0]']            
          

                                                                                                  
 top_conv (Conv2D)              (None, 16, 16, 1536  589824      ['block7b_add[0][0]']            
                                )                                                                 
                                                                                                  
 top_bn (BatchNormalization)    (None, 16, 16, 1536  6144        ['top_conv[0][0]']               
                                )                                                                 
                                                                                                  
 top_activation (Activation)    (None, 16, 16, 1536  0           ['top_bn[0][0]']                 
                                )                                                                 
                                                                                                  
 decoder_s

                                                                                                  
 decoder_stage3_upsampling (UpS  (None, 256, 256, 64  0          ['decoder_stage2b_relu[0][0]']   
 ampling2D)                     )                                                                 
                                                                                                  
 decoder_stage3_concat (Concate  (None, 256, 256, 20  0          ['decoder_stage3_upsampling[0][0]
 nate)                          8)                               ',                               
                                                                  'block2a_expand_activation[0][0]
                                                                 ']                               
                                                                                                  
 decoder_stage3a_conv (Conv2D)  (None, 256, 256, 32  59904       ['decoder_stage3_concat[0][0]']  
          

In [34]:
tf.config.run_functions_eagerly(True)  # Forces eager execution


In [None]:
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=30
)


Epoch 1/10
     59/Unknown - 1099s 19s/step - loss: 2.7452 - accuracy: 0.0296

In [None]:
# Evaluate on Validation Set
val_loss, val_acc = model.evaluate(val_dataset)
print(f"Validation Accuracy: {val_acc:.4f}")

# Test on a Sample Image
for test_img, test_mask in test_dataset.take(1):  # Take first test batch
    pred_mask = model.predict(test_img)
    pred_mask = np.argmax(pred_mask, axis=-1)  # Convert from one-hot to class indices
    break  # Only process one batch

# Display Results
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1); plt.imshow(test_img[0]); plt.title("Input Image")
plt.subplot(1, 3, 2); plt.imshow(test_mask[0]); plt.title("Ground Truth")
plt.subplot(1, 3, 3); plt.imshow(pred_mask[0]); plt.title("Predicted Mask")
plt.show()
