In [1]:
import pandas as pd
import numpy as np
from PIL import Image
import os, io, random

from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Dropout, Activation, Add, Multiply, ReLU
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.losses import BinaryFocalCrossentropy
import tensorflow.keras.backend as K

import datetime
import matplotlib.pyplot as plt

In [2]:
BATCH = 16
RESIZE = 128
EPS = 1e-7

In [3]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """
    Batch-wise Dice coefficient: returns mean Dice across the batch.
    y_true, y_pred: tensors shaped [B, H, W, C] (C usually 1).
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Sum over spatial dims per sample
    axes = tf.range(1, tf.rank(y_pred))  # [H,W,C]
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    sums = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
    dice = (2.0 * intersection + smooth) / (sums + smooth)
    # dice shape: [B, C] -> average over channels then batch
    dice = tf.reduce_mean(dice, axis=-1)
    return tf.reduce_mean(dice)


class CombinedSegmentationLoss(tf.keras.losses.Loss):
    """
    Combined loss:
      total_loss = w_dice * (1 - Dice)
                 + w_ft   * FocalTversky
                 + w_edge * EdgeLoss

    Implemented in a batch-vectorized and numerically stable way.
    """

    def __init__(
        self,
        alpha=0.7,        # Tversky alpha (FP weight)
        beta=0.3,         # Tversky beta (FN weight)  --> increase beta to penalize FN more
        ft_gamma=0.75,    # Focal Tversky focusing factor
        focal_alpha=0.95, # Internal Focal BCE (optional)
        focal_gamma=1.5,
        w_dice=0.6,
        w_ft=0.3,
        w_edge=0.1,
        from_logits=False,
        name="combined_segmentation_loss",
    ):
        super().__init__(name=name)
        # Tversky / Focal Tversky params
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.ft_gamma = float(ft_gamma)
        # optional focal BCE (not used in final sum by default; kept if you want)
        self.focal_bce = BinaryFocalCrossentropy(alpha=focal_alpha, gamma=focal_gamma, from_logits=from_logits)
        # weights
        self.w_dice = float(w_dice)
        self.w_ft = float(w_ft)
        self.w_edge = float(w_edge)
        self.from_logits = from_logits

    def _prepare_predictions(self, y_pred):
        """If logits provided, convert to probabilities."""
        if self.from_logits:
            y_pred = tf.nn.sigmoid(y_pred)
        # ensure float32
        return tf.cast(y_pred, tf.float32)

    def call(self, y_true, y_pred):
        """
        y_true, y_pred: tensors [B, H, W, C] (C=1 typically)
        returns scalar loss (mean over batch).
        """
        y_pred = self._prepare_predictions(y_pred)
        y_true = tf.cast(y_true, tf.float32)

        # ---- Dice Loss (batch-wise, stable) ----
        axes = tf.range(1, tf.rank(y_pred))  # spatial + channel dims
        intersection = tf.reduce_sum(y_true * y_pred, axis=axes)    # shape [B, C]
        sum_ytrue = tf.reduce_sum(y_true, axis=axes)
        sum_ypred = tf.reduce_sum(y_pred, axis=axes)
        dice_score = (2.0 * intersection + EPS) / (sum_ytrue + sum_ypred + EPS)  # [B, C]
        dice_score = tf.reduce_mean(dice_score, axis=-1)  # per-sample (avg over channels)
        dice_loss = 1.0 - dice_score                       # [B]

        # ---- Focal Tversky Loss (batch-wise) ----
        # compute TP, FP, FN per sample and channel
        TP = intersection
        FP = tf.reduce_sum((1.0 - y_true) * y_pred, axis=axes)
        FN = tf.reduce_sum(y_true * (1.0 - y_pred), axis=axes)

        # Tversky per sample/channel
        tversky = (TP + EPS) / (TP + self.alpha * FP + self.beta * FN + EPS)  # [B, C]
        tversky = tf.reduce_mean(tversky, axis=-1)  # [B]
        focal_tversky = tf.pow((1.0 - tversky), self.ft_gamma)  # [B]

        # ---- Edge loss (gradient-based) ----
        # tf.image.sobel_edges returns shape [B, H, W, C, 2]
        # compute gradient magnitude per-pixel and per-channel
        sobel_pred = tf.image.sobel_edges(y_pred)   # [B,H,W,C,2]
        sobel_true = tf.image.sobel_edges(y_true)
        # compute magnitude: sqrt(dx^2 + dy^2)
        # sobel[...,0] is dy, sobel[...,1] is dx (both shape [B,H,W,C])
        mag_pred = tf.sqrt(tf.square(sobel_pred[..., 0]) + tf.square(sobel_pred[..., 1]) + EPS)
        mag_true = tf.sqrt(tf.square(sobel_true[..., 0]) + tf.square(sobel_true[..., 1]) + EPS)
        # absolute difference and mean per sample
        edge_diff = tf.abs(mag_true - mag_pred)   # [B,H,W,C]
        edge_loss_per_sample = tf.reduce_mean(edge_diff, axis=axes)  # [B]
        # edge_loss_per_sample is mean difference across H,W,C

        # ---- combine losses per sample and then average ----
        total_per_sample = (
            self.w_dice * dice_loss
            + self.w_ft * focal_tversky
            + self.w_edge * edge_loss_per_sample
        )

        # final scalar
        return tf.reduce_mean(total_per_sample)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            "alpha": self.alpha,
            "beta": self.beta,
            "ft_gamma": self.ft_gamma,
            "w_dice": self.w_dice,
            "w_ft": self.w_ft,
            "w_edge": self.w_edge,
            "from_logits": self.from_logits,
        })
        return cfg


In [4]:
# def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=0.75):
#     y_true_f = K.flatten(y_true)
#     y_pred_f = K.flatten(y_pred)
    
#     TP = K.sum(y_true_f * y_pred_f)
#     FP = K.sum((1 - y_true_f) * y_pred_f)
#     FN = K.sum(y_true_f * (1 - y_pred_f))
    
#     tversky = (TP + 1e-6) / (TP + alpha * FP + beta * FN + 1e-6)
#     return K.pow((1 - tversky), gamma)


# def dice_coefficient(y_true, y_pred, smooth=1e-6):
#     y_true_f = tf.cast(tf.reshape(y_true, [-1]), tf.float32)
#     y_pred_f = tf.cast(tf.reshape(y_pred, [-1]), tf.float32)
#     y_pred_f = tf.clip_by_value(y_pred_f, smooth, 1)
#     y_true_f = tf.clip_by_value(y_true_f, smooth, 1)
#     intersection = tf.reduce_sum(y_true_f * y_pred_f)
#     return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)


# Combined Dice + Focal Loss
# class DiceFocalLoss(tf.keras.losses.Loss):
#     def __init__(self, alpha=0.95, gamma=1.5):
#         super().__init__()
#         self.focal = BinaryFocalCrossentropy(alpha=alpha, gamma=gamma)

#     def call(self, y_true, y_pred):
#         y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
#         dice_loss = 1 - dice_coefficient(y_true, y_pred)
#         focal_loss = self.focal(y_true, y_pred)
#         return dice_loss + 0.01 * focal_loss


# def combined_loss(y_true, y_pred):
#     return 0.7 * dice_loss(y_true, y_pred) + 0.3 * focal_tversky_loss(y_true, y_pred)


# def edge_loss(y_true, y_pred):
#     sobel_x = tf.image.sobel_edges(y_pred)[..., 0]
#     sobel_y = tf.image.sobel_edges(y_pred)[..., 1]
#     edges_pred = tf.sqrt(tf.square(sobel_x) + tf.square(sobel_y))

#     sobel_x_t = tf.image.sobel_edges(y_true)[..., 0]
#     sobel_y_t = tf.image.sobel_edges(y_true)[..., 1]
#     edges_true = tf.sqrt(tf.square(sobel_x_t) + tf.square(sobel_y_t))
    
#     return tf.reduce_mean(tf.abs(edges_true - edges_pred))


# def total_loss(y_true, y_pred):
#     return 0.6 * combined_loss(y_true, y_pred) + 0.4 * edge_loss(y_true, y_pred)

In [5]:
def load_tif_data(data_dir):
    images, masks = [], []
    for patient_dir in os.listdir(data_dir):
        patient_path = os.path.join(data_dir, patient_dir)
        if not os.path.isdir(patient_path):
            continue
        # Get all .tif files in the patient directory
        tif_files = [f for f in os.listdir(patient_path) if f.endswith('.tif') and 'mask' not in f]
        for img_file in tif_files:
            img_path = os.path.join(patient_path, img_file)
            images.append(img_path)
            # Construct corresponding mask filename
            mask_file = img_file.replace('.tif', '_mask.tif')
            mask_path = os.path.join(patient_path, mask_file)
            masks.append(mask_path)
            if not os.path.exists(mask_path):
                continue  # Skip if mask file doesn’t exist
    
    return pd.DataFrame({'img': images, 'mask': masks})

In [6]:
df = pd.read_csv('data_addrs.csv')

In [7]:
df

Unnamed: 0,img,mask
0,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
1,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
2,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
3,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
4,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
...,...,...
3924,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
3925,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
3926,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
3927,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...


In [8]:
lbl = []

for each in df['mask'].tolist():
    im = Image.open(each)
    imarray = np.array(im)
    if np.sum(imarray) == 0:
        lbl.append(0)
    else:
        lbl.append(1)

df['label'] = pd.Series(lbl)
dff = df[df['label'] == 1].copy(deep=True)
dff.drop(['label'], axis=1, inplace=True)


In [9]:
dff.reset_index(drop=True).head()

Unnamed: 0,img,mask
0,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
1,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
2,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
3,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...
4,D:/Github/Brain Tumor MRI Segmentation/dataset...,D:/Github/Brain Tumor MRI Segmentation/dataset...


In [10]:
# df['label'].value_counts()

In [11]:
# idx = 58
# masks = dff['mask'].tolist()
# images = dff['img'].tolist()
# mask = masks[idx]
# image = images[idx]
# maskg = Image.open(mask)
# mask = np.array(maskg)
# # plt.imshow(maskg)
# imageg = Image.open(image)
# image = np.array(imageg)
# # plt.imshow(imageg)

# fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# ax[0].imshow(maskg, cmap='gray')
# ax[0].set_title("Mask")

# ax[1].imshow(imageg, cmap='gray')
# ax[1].set_title("Image")

# plt.show()

In [12]:
def preprocess(image, mask):
    return random_crop_with_zoomout(image, mask, target_size=(RESIZE, RESIZE), p=0.3)


def random_crop_with_zoomout(image, mask, target_size=(RESIZE, RESIZE), p=0.5):
    """
    Apply crop+zoomout augmentation randomly with probability p.
    """
    # Generate random number
    do_augment = tf.random.uniform([], 0, 1.0) < p

    def augmented():
        return mask_aware_crop(image, mask, target_size)

    def identity():
        # Just resize without augmentation
        image_resized = tf.image.resize(image, target_size, method='bilinear')
        mask_resized = tf.image.resize(mask, target_size, method='nearest')
        return image_resized, mask_resized

    return tf.cond(do_augment, augmented, identity)


def mask_aware_crop(image, mask, target_size=(RESIZE, RESIZE)):
    # mask = tf.cast(mask, tf.int32)
    row, col = tf.shape(mask)[0], tf.shape(mask)[1]

    # Find nonzero indices (True positions)
    non_zero = tf.where(mask > 0)

    # Get bounding box coordinates
    r_min = tf.reduce_min(non_zero[:, 1])
    r_max = tf.reduce_max(non_zero[:, 1])
    c_min = tf.reduce_min(non_zero[:, 0])
    c_max = tf.reduce_max(non_zero[:, 0])

    # Compute distances to borders
    rmindist = tf.cast(r_min, tf.int32)
    cmindist = tf.cast(c_min, tf.int32)
    rmaxdist = tf.cast(col, tf.int32) - tf.cast(r_max, tf.int32)
    cmaxdist = tf.cast(row, tf.int32) - tf.cast(c_max, tf.int32)

    zoomout = tf.minimum(tf.minimum(rmindist, rmaxdist),
                        tf.minimum(cmindist, cmaxdist))

    zoomout_factor = tf.where(zoomout > 50, 0.3, 0.8)
    zoomout = tf.cast(tf.cast(zoomout, tf.float32) * zoomout_factor, tf.int32)

    # Expanded box coordinates (clamped to image size)
    rmin = tf.cast(tf.maximum(tf.cast(r_min, tf.int64) - tf.cast(zoomout, tf.int64), 0), tf.int32)
    cmin = tf.cast(tf.maximum(tf.cast(c_min, tf.int64) - tf.cast(zoomout, tf.int64), 0), tf.int32)
    rmax = tf.cast(tf.minimum(tf.cast(r_max, tf.int64) + tf.cast(zoomout, tf.int64), tf.cast(col, tf.int64)), tf.int32)
    cmax = tf.cast(tf.minimum(tf.cast(c_max, tf.int64) + tf.cast(zoomout, tf.int64), tf.cast(row, tf.int64)), tf.int32)

    # Crop image and mask
    cropped_mask = mask[cmin:cmax, rmin:rmax]
    cropped_img = image[cmin:cmax, rmin:rmax]

    resized_img = tf.image.resize(cropped_img, target_size, method='bilinear')
    # cropped_mask = tf.expand_dims(cropped_mask, axis=-1)
    resized_mask = tf.image.resize(cropped_mask, target_size, method='nearest')
    # resized_mask = tf.squeeze(resized_mask, axis=-1)
    
    return resized_img, resized_mask

In [13]:
x_tr, x_val = train_test_split(dff, test_size=0.3, random_state=42)
x_val, x_ts = train_test_split(x_val, test_size=0.5, random_state=42)

print(f"Train: {len(x_tr)}, Val: {len(x_val)}, Test: {len(x_ts)}")

Train: 961, Val: 206, Test: 206


In [14]:
def _load_tif_image(path, ch):
    img = Image.open(path.decode())
    
    if int(ch) == 3:
        img = img.convert('RGB')
    else:
        img = img.convert('L')
    
    img = img.resize((RESIZE, RESIZE))
    img = np.array(img).astype(np.float32) / 255.0

    # If grayscale (2D), expand to [H, W, 1]
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)
    return img


def load_image_and_mask(image_path, mask_path):
    # Load image
    image = tf.numpy_function(_load_tif_image, [image_path, 3], tf.float32)
    image.set_shape([RESIZE, RESIZE, 3])  # or [RESIZE, RESIZE, 1]

    # Load mask (assume single channel)
    mask = tf.numpy_function(_load_tif_image, [mask_path, 1], tf.float32)
    mask.set_shape([RESIZE, RESIZE, 1])

    return image, mask

In [15]:
trds = tf.data.Dataset.from_tensor_slices((x_tr['img'], x_tr['mask']))
vlds = tf.data.Dataset.from_tensor_slices((x_val['img'], x_val['mask']))
tsds = tf.data.Dataset.from_tensor_slices((x_ts['img'], x_ts['mask']))


trds = (trds
	.shuffle(1024)
	.map(load_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
	.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
	.cache()
	.batch(BATCH)
	.prefetch(tf.data.AUTOTUNE)
)

vlds = (vlds
	.shuffle(1024)
	.map(load_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
	.cache()
	.batch(BATCH)
	.prefetch(tf.data.AUTOTUNE)
)

tsds = (tsds
	.shuffle(1024)
	.map(load_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
	.cache()
	.batch(BATCH)
	.prefetch(tf.data.AUTOTUNE)
)

In [16]:
for a, b in trds.take(1):
    print(a)
    print()
    print(b)
    

tf.Tensor(
[[[[0.23921569 0.         0.        ]
   [0.23921569 0.00392157 0.        ]
   [0.23921569 0.00784314 0.        ]
   ...
   [0.23921569 0.00784314 0.        ]
   [0.23921569 0.00392157 0.        ]
   [0.23921569 0.         0.        ]]

  [[0.23921569 0.00392157 0.        ]
   [0.23921569 0.00784314 0.        ]
   [0.23921569 0.00784314 0.        ]
   ...
   [0.23529412 0.00784314 0.        ]
   [0.23921569 0.00392157 0.        ]
   [0.23921569 0.00392157 0.        ]]

  [[0.23921569 0.00392157 0.        ]
   [0.23921569 0.00784314 0.        ]
   [0.23921569 0.00784314 0.        ]
   ...
   [0.23529412 0.00392157 0.        ]
   [0.23921569 0.00392157 0.        ]
   [0.23921569 0.00392157 0.        ]]

  ...

  [[0.23921569 0.         0.        ]
   [0.23921569 0.         0.        ]
   [0.23529412 0.00392157 0.        ]
   ...
   [0.23921569 0.00784314 0.        ]
   [0.23921569 0.00392157 0.        ]
   [0.23921569 0.         0.        ]]

  [[0.23921569 0.         0.      

In [17]:
# for img, mask in trds.take(3):
#     plt.subplot(1,2,1)
#     plt.imshow(img[0])
#     plt.subplot(1,2,2)
#     plt.imshow(mask[0,...,0], cmap='gray')
#     plt.show()

In [18]:
def attention_gate(x, g, inter_channels):
    """
    Attention gate: x -> encoder feature (skip), g -> decoder gating signal.
    Returns: x * attention_map (same shape as x).
    """
    # 1x1 conv to reduce channels
    theta_x = Conv2D(inter_channels, kernel_size=1, strides=1, padding='same')(x)
    phi_g = Conv2D(inter_channels, kernel_size=1, strides=1, padding='same')(g)

    add_xg = Add()([theta_x, phi_g])
    act = ReLU()(add_xg)

    psi = Conv2D(1, kernel_size=1, strides=1, padding='same')(act)
    psi = Activation('sigmoid')(psi)   # attention coefficients in [0,1]

    # broadcast multiply: attention map (H,W,1) * x (H,W,C)
    attn_out = Multiply()([x, psi])
    return attn_out


def TumorSegNet_withAttention(input_shape=(256, 256, 3)):
    inputs = Input(input_shape, name='Input_Layer')

    # Encoder
    c1 = Conv2D(64, 3, activation='relu', padding='same', name='Conv1_1')(inputs)
    c1 = Conv2D(64, 3, activation='relu', padding='same', name='Conv1_2')(c1)
    p1 = MaxPooling2D((2, 2), name='MaxPool1')(c1)

    c2 = Conv2D(128, 3, activation='relu', padding='same', name='Conv2_1')(p1)
    c2 = Conv2D(128, 3, activation='relu', padding='same', name='Conv2_2')(c2)
    p2 = MaxPooling2D((2, 2), name='MaxPool2')(c2)

    c3 = Conv2D(256, 3, activation='relu', padding='same', name='Conv3_1')(p2)
    c3 = Conv2D(256, 3, activation='relu', padding='same', name='Conv3_2')(c3)
    p3 = MaxPooling2D((2, 2), name='MaxPool3')(c3)

    # Bottleneck
    c4 = Conv2D(512, 3, activation='relu', padding='same', name='Bottleneck_Conv1')(p3)
    c4 = Conv2D(512, 3, activation='relu', padding='same', name='Bottleneck_Conv2')(c4)
    c4 = Dropout(0.5, name='Bottleneck_Dropout')(c4)

    # Decoder - level 3 (connects to c3)
    u3 = UpSampling2D((2, 2), name='UpSample3')(c4)
    # attention gate for c3 using gating signal u3
    att3 = attention_gate(c3, u3, inter_channels=128)  # inter_channels ~ half of skip channels
    u3 = concatenate([u3, att3], name='Concat3')
    c5 = Conv2D(256, 3, activation='relu', padding='same', name='Conv5_1')(u3)
    c5 = Conv2D(256, 3, activation='relu', padding='same', name='Conv5_2')(c5)

    # Decoder - level 2 (connects to c2)
    u2 = UpSampling2D((2, 2), name='UpSample2')(c5)
    att2 = attention_gate(c2, u2, inter_channels=64)
    u2 = concatenate([u2, att2], name='Concat2')
    c6 = Conv2D(128, 3, activation='relu', padding='same', name='Conv6_1')(u2)
    c6 = Conv2D(128, 3, activation='relu', padding='same', name='Conv6_2')(c6)

    # Decoder - level 1 (connects to c1)
    u1 = UpSampling2D((2, 2), name='UpSample1')(c6)
    att1 = attention_gate(c1, u1, inter_channels=32)
    u1 = concatenate([u1, att1], name='Concat1')
    c7 = Conv2D(64, 3, activation='relu', padding='same', name='Conv7_1')(u1)
    c7 = Conv2D(64, 3, activation='relu', padding='same', name='Conv7_2')(c7)

    outputs = Conv2D(1, 1, activation='sigmoid', name='Output_Layer')(c7)

    return Model(inputs, outputs, name='TumorSegNet_Attention')


In [19]:
# Instantiate and compile model
model = TumorSegNet_withAttention(input_shape=(RESIZE, RESIZE, 3))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
            loss=CombinedSegmentationLoss(from_logits=False),
            metrics=[dice_coefficient, 'accuracy'])

# TensorBoard callback
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq='epoch',
    profile_batch=0
)

# Custom callback to log sample predictions to TensorBoard
class PredictionLogger(tf.keras.callbacks.Callback):
    def __init__(self, val_images, val_masks):
        super(PredictionLogger, self).__init__()
        self.val_images = val_images[:5]  # Log first 5 validation images
        self.val_masks = val_masks[:5]

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.val_images)
        preds = (preds > 0.5).astype(np.uint8)
        
        # Create figure with images, true masks, and predictions
        fig = plt.figure(figsize=(15, 5 * len(self.val_images)))
        for i in range(len(self.val_images)):
            plt.subplot(len(self.val_images), 3, i * 3 + 1)
            plt.imshow(self.val_images[i].squeeze(), cmap='gray')
            plt.title('MRI Image')
            plt.axis('off')
            plt.subplot(len(self.val_images), 3, i * 3 + 2)
            plt.imshow(self.val_masks[i].squeeze(), cmap='gray')
            plt.title('True Mask')
            plt.axis('off')
            plt.subplot(len(self.val_images), 3, i * 3 + 3)
            plt.imshow(preds[i].squeeze(), cmap='gray')
            plt.title(f'Predicted Mask (Epoch {epoch + 1})')
            plt.axis('off')
        
        # Log figure to TensorBoard
        writer = tf.summary.create_file_writer(log_dir)
        with writer.as_default():
            tf.summary.image("Validation Predictions", plot_to_image(fig), step=epoch)
        plt.close(fig)

# Convert matplotlib figure to TensorBoard image format
def plot_to_image(figure):
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(figure)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    image = tf.expand_dims(image, 0)
    return image

# Other callbacks
callbacks = [
    EarlyStopping(patience=5, restore_best_weights=True, monitor='val_dice_coefficient', mode='max'),
    ModelCheckpoint('TumorSegNet_attbest.keras', save_best_only=True, monitor='val_dice_coefficient'),
    tensorboard_callback,
    # PredictionLogger(x_val['img'], masks_val)
]

# Train the model
history = model.fit(
    trds,
    validation_data=vlds,
    epochs=20,
    callbacks=callbacks
)

Epoch 1/20
[1m61/61[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m352s[0m 6s/step - accuracy: 0.2134 - dice_coefficient: 0.0969 - loss: 0.8333 - val_accuracy: 0.8378 - val_dice_coefficient: 0.2222 - val_loss: 0.7393
Epoch 2/20
[1m61/61[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m367s[0m 6s/step - accuracy: 0.8077 - dice_coefficient: 0.2358 - loss: 0.7232 - val_accuracy: 0.9674 - val_dice_coefficient: 0.2052 - val_loss: 0.7299
Epoch 3/20
[1m61/61[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m364s[0m 6s/step - accuracy: 0.9230 - dice_coefficient: 0.3536 - loss: 0.6170 - val_accuracy: 0.9765 - val_dice_coefficient: 0.5033 - val_loss: 0.4771
Epoch 4/20
[1m61/61[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m359s[0m 6s/step - accuracy: 0.9571 - dice_coefficient: 0.5173 - loss: 0.4722 - val_accuracy: 0.9775 - val_dice_coefficient: 0.5419 - val_loss: 0.4401
Epoch 5/20
[1m61/61[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m365s[0m 6s/step - accuracy: 0.9632 - dice_coefficient: 

In [None]:
model.save('TumorSegNet_attention.keras')