# Training the UNet and TransUNet Models

## Mounting the Google Drive

In [1]:
# Mount Google Drive at /content/drive so the Drive-based paths used below exist.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Root folder of the mounted Drive; the dataset and output paths are joined onto it.
drive_path = '/content/drive/MyDrive'

### Import

In [16]:
import os
import numpy as np
import cv2
from glob import glob
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation ,MaxPool2D, Conv2DTranspose, Concatenate, Input, UpSampling2D, MultiHeadAttention, LayerNormalization, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger

### Seeding

In [5]:
# Set the PYTHONHASHSEED environment variable and seed NumPy and TensorFlow
# with the same value (42).
# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up,
# so assigning it here may not affect the already-running kernel — confirm.
os.environ['PYTHONHASHSEED'] = str(42)
np.random.seed(42)
tf.random.set_seed(42)

### Hyperparameters

In [6]:
# Number of samples per batch handed to tf_dataset.
batch_size=4
# Learning rate passed to the Adam optimizer.
lr=1e-4
# Number of epochs passed to every model.fit call.
epochs=5
# Image/mask height and width; used for the static tensor shapes in tf_parse
# and for the model input shape.
height=768
width=512

In [64]:
# Dataset root on Drive; load_data looks for train/ and valid/ folders beneath it.
dataset_path = os.path.join(drive_path, "Human Face Segmentation", "Original_Data")

In [65]:
# Output folder on Drive for checkpoints and the training log.
files_dir = os.path.join(drive_path, 'Colab Notebooks', 'Files', 'Original_Data')
# One checkpoint file per model (plain UNet, basic TransUNet, advanced TransUNet).
model_file_unet = os.path.join(files_dir, "unet-org.h5")
model_file_basic_trans = os.path.join(files_dir, "basic-trans-org.h5")
model_file_adv_trans = os.path.join(files_dir, "adv-trans-org.h5")
# CSV file handed to the CSVLogger callbacks.
log_file = os.path.join(files_dir, "log-aug.csv")

### Creating Folder

In [66]:
def create_dir(path):
    """Create directory ``path``, including missing parents, if it is absent.

    Calling it on an already existing directory is a no-op. ``exist_ok=True``
    replaces the separate existence check, so there is no window in which
    another process can create the directory between the check and the call.

    Args:
        path: directory path to create.
    """
    os.makedirs(path, exist_ok=True)

In [67]:
# Make sure the output folder exists before checkpoints and logs are written to it.
create_dir(files_dir)

## Building UNet Model

### Conv Block

In [17]:
def conv_block(inputs, num_filters, strides=1):
    """Run two consecutive Conv2D(3x3) -> BatchNormalization -> ReLU stages.

    Args:
        inputs: input feature tensor.
        num_filters: number of filters of both convolutions.
        strides: stride used by *both* convolutions (default 1).

    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = inputs
    for _ in range(2):
        features = Conv2D(num_filters, 3, strides=strides, padding='same')(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

### Encoder Block

In [18]:
def encoder_block(inputs, num_filters):
    """Apply ``conv_block`` and a 2x2 max-pooling to ``inputs``.

    Returns:
        A ``(skip, pooled)`` pair: the ``conv_block`` output and its
        max-pooled version.
    """
    skip = conv_block(inputs, num_filters)
    pooled = MaxPool2D((2,2))(skip)
    return skip, pooled

### Decoder Block

In [19]:
def decoder_block(inputs, skip, num_filters):
    """Upsample ``inputs``, concatenate ``skip`` and refine with ``conv_block``.

    Args:
        inputs: tensor coming from the previous (coarser) stage.
        skip: encoder tensor concatenated after the upsampling.
        num_filters: filters of the transposed convolution and the conv block.

    Returns:
        The output of ``conv_block`` applied to the concatenated tensor.
    """
    upsampled = Conv2DTranspose(num_filters, (2,2), strides=2, padding='same')(inputs)
    merged = Concatenate()([upsampled, skip])
    return conv_block(merged, num_filters)

### UNet

In [20]:
def build_unet(input_shape):
    """Assemble the UNet model.

    Four encoder stages (64, 128, 256, 512 filters), a 1024-filter bridge and
    four decoder stages (512, 256, 128, 64 filters) that consume the encoder
    skips in reverse order, followed by a 1x1 sigmoid convolution.

    Args:
        input_shape: shape passed to the Keras ``Input`` layer.

    Returns:
        A Keras ``Model`` named ``"UNET"`` with a single-channel output.
    """
    inputs = Input(input_shape)

    # Encoder: keep every stage's pre-pooling output as a skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bridge
    x = conv_block(x, 1024)

    # Decoder: deepest skip first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = Conv2D(1,1,padding='same', activation='sigmoid')(x)

    return Model(inputs, outputs, name="UNET")

### TransUNet Encoder Block

In [21]:
def transformer_encoder(inputs, num_heads=4, ff_dim=64, num_transformer_layers=12, dropout_rate=0.1):
    """Stack ``num_transformer_layers`` attention + feed-forward layers.

    Each layer applies self-attention with a residual connection and layer
    normalization, then two 1x1 convolutions (``ff_dim`` filters with ReLU,
    then back to the channel count of ``inputs``) with a second residual
    connection and layer normalization, and finally dropout.

    Args:
        inputs: input feature tensor; its last dimension sets the output
            channel count.
        num_heads: number of attention heads.
        ff_dim: filters of the first feed-forward convolution.
        num_transformer_layers: how many layers to stack.
        dropout_rate: rate of the dropout applied at the end of every layer.

    Returns:
        The tensor after the last layer (``inputs`` itself if no layer is
        stacked).
    """
    out_channels = inputs.shape[-1]
    hidden = inputs
    for _layer in range(num_transformer_layers):
        # Self-attention sub-layer with residual connection.
        attended = MultiHeadAttention(num_heads=num_heads, key_dim=64)(hidden, hidden)
        hidden = LayerNormalization(epsilon=1e-6)(attended + hidden)

        # Position-wise feed-forward sub-layer built from 1x1 convolutions.
        projected = Conv2D(filters=ff_dim, kernel_size=1, activation='relu')(hidden)
        projected = Conv2D(filters=out_channels, kernel_size=1)(projected)
        hidden = LayerNormalization(epsilon=1e-6)(projected + hidden)

        hidden = Dropout(dropout_rate)(hidden)

    return hidden

In [22]:
def trans_encoder(input_shape):
    """Create the input tensor and the encoder output of the basic TransUNet.

    Args:
        input_shape: shape passed to the Keras ``Input`` layer.

    Returns:
        ``(inputs, encoder_output)`` where ``encoder_output`` is the last
        entry of ``encoder_blocks``.

    NOTE(review): every entry of ``encoder_blocks`` is computed from the same
    stem tensor ``x`` instead of from the previous entry, and only
    ``encoder_blocks[-1]`` is returned. As written, the returned tensor is
    ``transformer_encoder`` applied directly to the stem output; the strided
    ``conv_block`` calls are created but never feed into it. The inline
    comments suggest a sequential downsampling pipeline was intended. Chaining
    the entries changes the output resolution, so it has to be done together
    with the upsampling in ``trans_decoder`` — TODO confirm and fix jointly.
    """
    inputs = Input(shape=input_shape)

    # Initial convolution layer
    x=conv_block(inputs, num_filters=64)

    # Encoder blocks with transformers
    encoder_blocks = [
        conv_block(x, 64),
        conv_block(x, 128, strides=2),  # Downsample
        transformer_encoder(x),  # 12 transformer layers
        conv_block(x, 256, strides=2),  # Downsample
        transformer_encoder(x),  # 12 transformer layers
        conv_block(x, 512,
                   strides=2),  # Downsample
        transformer_encoder(x),  # 12 transformer layers
    ]

    return inputs, encoder_blocks[-1]

### TransUNet Decoder Block

In [37]:
def trans_decoder(encoder_output, num_classes=2):
    """Attach the segmentation head of the basic TransUNet.

    Args:
        encoder_output: feature tensor produced by the encoder.
        num_classes: number of output channels. ``1`` produces a single
            probability map (sigmoid); larger values produce a per-pixel
            class distribution (softmax).

    Returns:
        The output tensor of the final 1x1 convolution, with ``num_classes``
        channels.
    """
    x = encoder_output

    # Decoder blocks with upsampling
    # NOTE(review): every stage below is applied to the same tensor `x` and the
    # results are never fed forward, so the final convolution still consumes
    # `encoder_output` directly. Chaining them would upsample 16x, which only
    # matches an encoder that actually downsamples 16x — TODO fix together
    # with trans_encoder.
    decoder_blocks = [
        Conv2D(512, 3, activation='relu', padding='same')(x),
        UpSampling2D()(x),
        Conv2D(256, 3, activation='relu', padding='same')(x),
        UpSampling2D()(x),
        Conv2D(128, 3, activation='relu', padding='same')(x),
        UpSampling2D()(x),
        Conv2D(64, 3, activation='relu', padding='same')(x),
        UpSampling2D()(x),
    ]

    # Final segmentation output.
    # Softmax normalizes across the channel axis, so with a single channel it
    # would always output 1.0; a one-channel head needs a sigmoid instead.
    head_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    output = Conv2D(num_classes, 1, activation=head_activation)(x)

    return output

## TransUNet Model

In [47]:
def basic_trans_unet(input_shape, num_classes):
    """Build the basic TransUNet from ``trans_encoder`` and ``trans_decoder``.

    Args:
        input_shape: shape forwarded to ``trans_encoder``.
        num_classes: number of output channels forwarded to ``trans_decoder``.

    Returns:
        A ``tf.keras.Model`` mapping the encoder input to the decoder output.
    """
    model_inputs, encoded = trans_encoder(input_shape)
    segmentation = trans_decoder(encoded, num_classes)
    return tf.keras.Model(inputs=model_inputs, outputs=segmentation)

## Nikhil's TransUNet

In [48]:
def trans_unet_encoder(inputs, num_filters):
    """Encoder stage: ``conv_block``, 12 transformer layers, 2x2 max-pooling.

    Returns:
        A ``(skip, pooled)`` pair: the transformer output and its max-pooled
        version.
    """
    conv_features = conv_block(inputs, num_filters)
    skip = transformer_encoder(conv_features, num_transformer_layers=12)
    pooled = MaxPool2D((2,2))(skip)
    return skip, pooled

def trans_unet_decoder(inputs, skip, num_filters):
    """Decoder stage of the advanced TransUNet.

    The stage is the same transposed-convolution / concatenate / conv-block
    sequence as ``decoder_block``, so it delegates to that function instead of
    keeping a second copy of the code.

    Args:
        inputs: tensor coming from the previous (coarser) stage.
        skip: encoder tensor concatenated after the upsampling.
        num_filters: filters of the transposed convolution and the conv block.

    Returns:
        The output tensor of ``decoder_block``.
    """
    return decoder_block(inputs, skip, num_filters)

def adv_trans_unet(input_shape):
    """Assemble the UNet variant whose encoder stages contain transformer layers.

    Four ``trans_unet_encoder`` stages (64, 128, 256, 512 filters), a
    1024-filter ``conv_block`` bridge and four ``trans_unet_decoder`` stages
    that consume the encoder skips in reverse order, followed by a 1x1
    convolution producing one output channel.

    Args:
        input_shape: shape passed to the Keras ``Input`` layer.

    Returns:
        A Keras ``Model`` named ``"TRANSUNET"`` with a single-channel output.
    """
    inputs = Input(input_shape)

    # Encoder
    s1, p1 = trans_unet_encoder(inputs, 64)
    s2, p2 = trans_unet_encoder(p1, 128)
    s3, p3 = trans_unet_encoder(p2, 256)
    s4, p4 = trans_unet_encoder(p3, 512)

    # Bridge
    b1 = conv_block(p4, 1024)

    # Decoder
    d1 = trans_unet_decoder(b1, s4, 512)
    d2 = trans_unet_decoder(d1, s3, 256)
    d3 = trans_unet_decoder(d2, s2, 128)
    d4 = trans_unet_decoder(d3, s1, 64)

    # Softmax normalizes across channels, so on this one-channel head it would
    # output 1.0 for every pixel. Use a sigmoid, as build_unet does.
    outputs = Conv2D(1,1,padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name="TRANSUNET")
    return model

## Dice Coefficient Metric

In [60]:
def dice_coefficient(y_true, y_pred, smooth=1):
    """Compute the smoothed Dice coefficient over all elements of the inputs.

    ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor (assumed floating point).
        smooth: constant added to numerator and denominator so the ratio is
            defined when both masks are empty.

    Returns:
        A scalar tensor with the Dice coefficient.
    """
    # The data pipeline yields float64 masks while model outputs are typically
    # float32; align the dtypes so the product below does not fail when the
    # function is called directly (outside Keras' metric wrapper).
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    # Sum of both mask totals (the Dice denominator, not a set union).
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2.0 * intersection + smooth) / (total + smooth)

### Dataset Pipeline

In [25]:
#Loading the training and validation dataset
def load_data(path):
    """Collect the sorted image and mask file paths of both dataset splits.

    Args:
        path: dataset root containing ``train/`` and ``valid/``, each with an
            ``images/`` and a ``masks/`` sub-folder.

    Returns:
        ``((train_x, train_y), (valid_x, valid_y))`` — four sorted lists of
        file paths.
    """
    def _listing(split, kind):
        # Every entry directly inside <path>/<split>/<kind>, sorted by name.
        return sorted(glob(os.path.join(path, split, kind, "*")))

    train_pair = (_listing("train", "images"), _listing("train", "masks"))
    valid_pair = (_listing("valid", "images"), _listing("valid", "masks"))
    return train_pair, valid_pair

In [26]:
# Reading Images
def read_image(path):
    """Load an image with OpenCV and divide its values by 255.0.

    Args:
        path: image file path, either ``bytes`` (the form in which
            ``tf.numpy_function`` passes strings) or ``str``.

    Returns:
        The array returned by ``cv2.imread(path, cv2.IMREAD_COLOR)`` divided
        by 255.0.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    if isinstance(path, bytes):
        path = path.decode()
    x = cv2.imread(path,cv2.IMREAD_COLOR)
    if x is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the division below fails with an opaque TypeError.
        raise OSError(f"Could not read image file: {path}")
    return x/255.0

In [27]:
# Reading Mask
def read_mask(path):
    """Load a mask as grayscale, divide by 255.0 and add a channel axis.

    Args:
        path: mask file path, either ``bytes`` (the form in which
            ``tf.numpy_function`` passes strings) or ``str``.

    Returns:
        The array returned by ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``
        divided by 255.0, with a trailing axis of size 1 appended.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    if isinstance(path, bytes):
        path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if x is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the division below fails with an opaque TypeError.
        raise OSError(f"Could not read mask file: {path}")
    x = x/255.0
    return np.expand_dims(x, axis=-1)

In [28]:
# tf.data pipeline
def tf_parse(x, y):
    """Load one image/mask pair inside a ``tf.data`` pipeline.

    Args:
        x: tensor holding the image file path.
        y: tensor holding the mask file path.

    Returns:
        ``(image, mask)`` float64 tensors with static shapes
        ``[height, width, 3]`` and ``[height, width, 1]``.
    """
    def _load_pair(image_path, mask_path):
        # Masks go through read_mask (grayscale + channel axis), not read_image.
        return read_image(image_path), read_mask(mask_path)

    image, mask = tf.numpy_function(_load_pair, [x, y], [tf.float64, tf.float64])
    # numpy_function loses static shape information; restore it from the globals.
    image.set_shape([height, width, 3])
    mask.set_shape([height, width, 1])
    return image, mask


def tf_dataset(x,y,batch=8):
    """Build a batched, prefetched ``tf.data.Dataset`` of image/mask pairs.

    Args:
        x: sequence of image file paths.
        y: sequence of mask file paths, aligned with ``x``.
        batch: batch size.

    Returns:
        A dataset yielding ``(images, masks)`` batches parsed by ``tf_parse``.
    """
    return (
        tf.data.Dataset.from_tensor_slices((x, y))
        .map(tf_parse, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch)
        .prefetch(tf.data.AUTOTUNE)
    )

## Training

In [68]:
(train_x, train_y), (valid_x, valid_y) = load_data(dataset_path)
print(f"Train : {len(train_x)} - {len(train_y)}")
print(f"Valid : {len(valid_x)} - {len(valid_y)}")

# Images and masks are paired purely by their position in the sorted listings,
# so fail early with a clear message if the listings cannot line up.
assert train_x, f"No training images found under {dataset_path}"
assert len(train_x) == len(train_y), "Training image/mask counts differ"
assert len(valid_x) == len(valid_y), "Validation image/mask counts differ"

Train : 108 - 108
Valid : 13 - 13


In [69]:
# Build the training and validation pipelines with the configured batch size.
# NOTE(review): tf_dataset applies no shuffling, so training batches are served
# in sorted file order every epoch — confirm this is intended.
train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)

In [70]:
# Sanity check: print the shape of every validation batch.
for batch_images, batch_masks in valid_dataset:
    print(batch_images.shape, batch_masks.shape)

(4, 768, 512, 3) (4, 768, 512, 1)
(4, 768, 512, 3) (4, 768, 512, 1)
(4, 768, 512, 3) (4, 768, 512, 1)
(1, 768, 512, 3) (1, 768, 512, 1)


In [71]:
# Three-channel input of the configured size, shared by all three models.
input_shape = (height, width,3)
unet_model = build_unet(input_shape)

In [81]:
# One output channel, matching the single-channel masks produced by read_mask.
basic_trans_model = basic_trans_unet(input_shape, num_classes=1)

In [73]:
# UNet variant with transformer layers inside each encoder stage.
adv_trans_model = adv_trans_unet(input_shape)

In [83]:
# Give every model its own Adam instance. An optimizer holds per-variable state
# and the current learning rate, so sharing one instance would let
# ReduceLROnPlateau reductions made while training one model carry over to the
# next, and recent Keras versions reject an optimizer that was already built
# for another model's variables.
for model_to_compile in (unet_model, basic_trans_model, adv_trans_model):
    model_to_compile.compile(
        loss='binary_crossentropy',
        optimizer=tf.keras.optimizers.Adam(lr),
        metrics=['accuracy', dice_coefficient],
    )

In [84]:
def build_callbacks(model_file, log_path):
    """Return a fresh list of training callbacks for one model.

    Args:
        model_file: path where ``ModelCheckpoint`` stores the best model.
        log_path: CSV file written by ``CSVLogger``.

    Returns:
        ``[ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping]``,
        newly instantiated so no callback state is shared between models.
    """
    return [
        ModelCheckpoint(model_file, verbose=1, save_best_only=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4),
        CSVLogger(log_path),
        EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=False)
    ]


# CSVLogger overwrites its target file by default, so pointing all three
# trainings at `log_file` would keep only the last model's log. Derive one log
# file per model from `log_file` instead.
log_root, log_ext = os.path.splitext(log_file)
callbacks_unet = build_callbacks(model_file_unet, f"{log_root}-unet{log_ext}")
callbacks_basic_trans = build_callbacks(model_file_basic_trans, f"{log_root}-basic-trans{log_ext}")
callbacks_adv_trans = build_callbacks(model_file_adv_trans, f"{log_root}-adv-trans{log_ext}")

In [None]:
# Train the plain UNet, validating on valid_dataset after every epoch.
unet_model.fit(train_dataset,
          validation_data=valid_dataset,
          epochs=epochs,
          callbacks=callbacks_unet
          )

In [None]:
# Train the basic TransUNet with the same data and epoch count.
basic_trans_model.fit(train_dataset,
          validation_data=valid_dataset,
          epochs=epochs,
          callbacks=callbacks_basic_trans
          )

Epoch 1/5


In [None]:
# Train the advanced TransUNet with the same data and epoch count.
adv_trans_model.fit(train_dataset,
          validation_data=valid_dataset,
          epochs=epochs,
          callbacks=callbacks_adv_trans
          )