In [None]:
# 1. Imports
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from glob import glob

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses
from tensorflow.keras.callbacks import ModelCheckpoint

# ------------------------------
# 2. Paths
# ------------------------------
# Images are discovered as <IMAGE_FOLDER>/*.png; each image is paired with the
# label file <LABEL_FOLDER>/<same basename>.txt.
IMAGE_FOLDER = "kidney/images/"    # folder holding the input *.png images
LABEL_FOLDER = "kidney/labels/"    # folder holding the matching *.txt masks

# ------------------------------
# 3. Load Images and Labels
# ------------------------------
def load_image(path, size=(128,128)):
    """Read an image from disk, resize it and scale its pixels to [0, 1].

    Args:
        path: Path of the image file to read.
        size: Target size handed to ``cv2.resize`` (OpenCV order: width, height).

    Returns:
        Float array with three channels in OpenCV's BGR order, values in [0, 1].

    Raises:
        FileNotFoundError: If OpenCV cannot open or decode ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check cv2.resize fails with an opaque assertion error.
        raise FileNotFoundError(f"Could not read image file: {path}")
    img = cv2.resize(img, size)
    return img / 255.0

def load_mask(path, size=(128,128), src_shape=(256,256)):
    """Load a binary mask stored as text and resize it to ``size``.

    The label file is expected to hold ``src_shape[0] * src_shape[1]`` values
    (0/1), either flattened or already laid out as rows.

    Args:
        path: Path of the ``.txt`` label file.
        size: Target size handed to ``cv2.resize`` (OpenCV order: width, height).
        src_shape: ``(rows, cols)`` of the mask as stored in the file. Defaults
            to 256x256; pass another shape if the labels were exported at a
            different resolution.

    Returns:
        ``uint8`` array with a trailing channel axis of length 1.

    Raises:
        ValueError: If the number of values in the file does not match
            ``src_shape``.
    """
    mask = np.loadtxt(path, dtype=np.uint8)
    expected = src_shape[0] * src_shape[1]
    if mask.size != expected:
        raise ValueError(
            f"Mask {path} holds {mask.size} values, expected {expected} "
            f"for shape {tuple(src_shape)}"
        )
    mask = mask.reshape(tuple(src_shape))
    # Nearest-neighbour interpolation keeps the labels strictly binary.
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
    return np.expand_dims(mask, axis=-1)

# Pair every PNG in IMAGE_FOLDER with the label file of the same basename.
images, masks = [], []
skipped = []
for img_path in sorted(glob(os.path.join(IMAGE_FOLDER, "*.png"))):
    filename = os.path.splitext(os.path.basename(img_path))[0]
    label_path = os.path.join(LABEL_FOLDER, filename + ".txt")
    if not os.path.exists(label_path):
        # Keep track of unlabeled images instead of dropping them silently.
        skipped.append(img_path)
        continue
    images.append(load_image(img_path))
    masks.append(load_mask(label_path))

if skipped:
    print(f"Skipped {len(skipped)} image(s) without a matching label file")

if not images:
    # Fail here with a clear message rather than later inside the
    # train/validation split, which cannot handle an empty dataset.
    raise FileNotFoundError(
        f"No image/label pairs found in {IMAGE_FOLDER!r} / {LABEL_FOLDER!r}"
    )

images = np.array(images, dtype=np.float32)
masks = np.array(masks, dtype=np.float32)

print("Images:", images.shape)
print("Masks:", masks.shape)

# ------------------------------
# 4. Train-Test Split
# ------------------------------
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for validation; the fixed seed keeps the split
# reproducible between runs.
X_train, X_val, y_train, y_val = train_test_split(
    images,
    masks,
    test_size=0.2,
    random_state=42,
)

# ------------------------------
# 5. GM-UNet Model
# ------------------------------
def conv_block(x, filters):
    """Apply two consecutive Conv(3x3, same) -> BatchNorm -> ReLU stages.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of both convolutions.

    Returns:
        The Keras tensor produced by the second ReLU.
    """
    out = x
    for _ in range(2):
        out = layers.Conv2D(filters, 3, padding="same")(out)
        out = layers.BatchNormalization()(out)
        out = layers.ReLU()(out)
    return out

def attention_gate(x, g, filters):
    """Scale skip features ``x`` by an attention map derived from ``x`` and ``g``.

    Both inputs are projected to ``filters`` channels with 1x1 convolutions,
    summed, passed through ReLU and reduced to a single-channel sigmoid map
    that multiplies ``x``.

    Args:
        x: Skip-connection feature map to be gated.
        g: Gating feature map. If its spatial size is smaller than that of
            ``x`` by an integer factor, it is upsampled to match first.
        filters: Channel count of the intermediate 1x1 projections.

    Returns:
        ``x`` multiplied element-wise by the attention map.

    Raises:
        ValueError: If the spatial size of ``x`` is not an integer multiple of
            the spatial size of ``g``.
    """
    x_h, x_w = x.shape[1], x.shape[2]
    g_h, g_w = g.shape[1], g.shape[2]
    sizes_known = None not in (x_h, x_w, g_h, g_w)
    if sizes_known and (x_h, x_w) != (g_h, g_w):
        # The element-wise add below needs identical spatial sizes, so bring a
        # coarser gating signal up to the resolution of the skip features.
        if x_h % g_h or x_w % g_w:
            raise ValueError(
                f"Cannot align gating signal of size {(g_h, g_w)} "
                f"with skip features of size {(x_h, x_w)}"
            )
        g = layers.UpSampling2D(
            size=(x_h // g_h, x_w // g_w), interpolation="bilinear"
        )(g)

    g1 = layers.Conv2D(filters, 1)(g)
    x1 = layers.Conv2D(filters, 1)(x)
    psi = layers.Activation("relu")(layers.add([g1, x1]))
    psi = layers.Conv2D(1, 1, activation="sigmoid")(psi)
    return layers.multiply([x, psi])

def gm_unet(input_shape=(128,128,3), num_classes=1):
    """Build a U-Net with attention-gated skip connections.

    The encoder has four ``conv_block`` + max-pooling stages (64, 128, 256,
    512 filters) followed by a 1024-filter bottleneck. Each decoder stage
    upsamples with a stride-2 transposed convolution, gates the matching
    encoder features with ``attention_gate``, concatenates both and applies a
    ``conv_block``.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Height
            and width should be multiples of 16 so that the four pooling and
            upsampling steps line up.
        num_classes: Number of output channels; each uses a sigmoid activation.

    Returns:
        An uncompiled ``keras`` ``Model`` mapping images to per-pixel
        probabilities.
    """
    inputs = layers.Input(input_shape)
    encoder_filters = (64, 128, 256, 512)

    # Encoder: remember every pre-pooling feature map for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = conv_block(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D()(x)

    # Bottleneck
    x = conv_block(x, 1024)

    # Decoder + Attention (GM block)
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        up = layers.Conv2DTranspose(filters, 2, strides=2, padding="same")(x)
        # Gate with the upsampled tensor: it has the same spatial size as the
        # skip features, which the element-wise add inside the gate requires.
        # (Gating with the coarser pre-upsampling tensor fails on a shape
        # mismatch.)
        gated = attention_gate(skip, up, filters)
        x = layers.concatenate([up, gated])
        x = conv_block(x, filters)

    # Output
    outputs = layers.Conv2D(num_classes, 1, activation="sigmoid")(x)

    return models.Model(inputs, outputs)

# Build the network and configure it for binary segmentation.
model = gm_unet()
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-4),
    loss=losses.BinaryCrossentropy(),
    metrics=["accuracy"],
)

model.summary()

# ------------------------------
# 6. Training with Checkpoint
# ------------------------------
# Keep only the weights of the epoch with the lowest validation loss.
checkpoint = ModelCheckpoint(
    filepath="gm_unet_best.h5",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=8,
    epochs=20,
    validation_data=(X_val, y_val),
    callbacks=[checkpoint],
)

# Save the model as it stands after the last epoch (not necessarily the best).
model.save("gm_unet_final.h5")

# ------------------------------
# 7. Prediction + Reloading Model
# ------------------------------
from tensorflow.keras.models import load_model

# Reload the best checkpoint; compile=False because it is used for inference only.
best_model = load_model("gm_unet_best.h5", compile=False)

def predict_and_show(image, model):
    """Predict a mask for one image and plot input, mask and overlay.

    Args:
        image: Single preprocessed image without batch axis, as produced by
            ``load_image`` (three channels in OpenCV's BGR order).
        model: Object with a Keras-style ``predict`` method.

    The prediction is thresholded at 0.5 and only its first channel is shown.
    """
    pred = model.predict(np.expand_dims(image, axis=0))[0]
    pred_mask = (pred > 0.5).astype(np.uint8)

    # Images come from cv2.imread (BGR) but matplotlib interprets the last
    # axis as RGB, so flip the channels for display only; the model input
    # above is left untouched.
    if image.ndim == 3 and image.shape[-1] == 3:
        display = image[..., ::-1]
    else:
        display = image

    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1)
    plt.imshow(display)
    plt.title("Input")
    plt.subplot(1,3,2)
    plt.imshow(pred_mask[:,:,0], cmap="gray")
    plt.title("Prediction")
    plt.subplot(1,3,3)
    plt.imshow(display)
    plt.imshow(pred_mask[:,:,0], cmap="jet", alpha=0.5)
    plt.title("Overlay")
    plt.show()

# Example prediction: visualise the first validation sample with the reloaded
# best checkpoint.
predict_and_show(X_val[0], best_model)


In [None]:
from tensorflow.keras.models import load_model
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Load the trained model from the best-checkpoint file; compile=False skips
# restoring the optimizer/loss, which is sufficient for inference.
model = load_model("gm_unet_best.h5", compile=False)

def preprocess_image(image_path, size=(128,128)):
    """Read an image, resize it and scale its pixels to [0, 1].

    Args:
        image_path: Path of the image file to read.
        size: Target size handed to ``cv2.resize`` (OpenCV order: width, height).

    Returns:
        Float array with three channels in OpenCV's BGR order, values in [0, 1].

    Raises:
        FileNotFoundError: If OpenCV cannot open or decode ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None on failure; surface that as a clear error
        # instead of letting cv2.resize fail on an empty input.
        raise FileNotFoundError(f"Could not read image file: {image_path}")
    img = cv2.resize(img, size)
    return img / 255.0

def predict_single_image(image_path, model, size=(128,128)):
    """Run the model on one image file and plot input, mask and overlay.

    Args:
        image_path: Path of the image to segment.
        model: Object with a Keras-style ``predict`` method.
        size: Size the image is resized to before prediction
            (OpenCV order: width, height).

    The prediction is thresholded at 0.5 and only its first channel is shown.
    """
    # Preprocess
    img = preprocess_image(image_path, size)

    # Predict
    pred = model.predict(np.expand_dims(img, axis=0))[0]
    pred_mask = (pred > 0.5).astype(np.uint8)

    # preprocess_image yields BGR (OpenCV) while matplotlib expects RGB, so
    # reverse the channel axis for display only.
    display = img[..., ::-1]

    # Show
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1)
    plt.imshow(display)
    plt.title("Input Image")

    plt.subplot(1,3,2)
    plt.imshow(pred_mask[:,:,0], cmap="gray")
    plt.title("Predicted Mask")

    plt.subplot(1,3,3)
    plt.imshow(display)
    plt.imshow(pred_mask[:,:,0], cmap="jet", alpha=0.5)
    plt.title("Overlay")
    plt.show()

# Example usage: segment the file "1.PNG" (path relative to the current
# working directory) with the model loaded above.
predict_single_image("1.PNG", model)
