In [None]:
import os
import glob
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

# ------------------------------
# Config
# ------------------------------
IMAGE_SIZE = (256, 256)  # (height, width) passed to tf.image.resize for images and masks
BATCH_SIZE = 30          # samples per batch in both the training and validation datasets
SEED = 42                # seeds NumPy sampling and the train/validation split
NUM_SAMPLES = 5000   # total samples to use
BALANCE = True       # keep balanced positives & negatives

# Paths (update to your dataset paths)
# Images and masks are paired positionally after sorting the *.png files of each directory.
IMAGE_DIR = "/content/drive/MyDrive/archive (7)/siim-acr-pneumothorax/png_images"
MASK_DIR  = "/content/drive/MyDrive/archive (7)/siim-acr-pneumothorax/png_masks"

# File the ModelCheckpoint callback writes the best model to.
MODEL_NAME = "unet_pneumo_grayscale_5000.keras"

# ------------------------------
# Collect file paths
# ------------------------------
# Gather the PNG paths of both directories in sorted order. Image i is later
# paired with mask i purely by position.
# NOTE(review): this presumes both directories use matching file names - confirm.
x_all, y_all = (
    sorted(glob.glob(os.path.join(directory, "*.png")))
    for directory in (IMAGE_DIR, MASK_DIR)
)

n_images, n_masks = len(x_all), len(y_all)
print("Total images:", n_images)
print("Total masks:", n_masks)

# Fail early when either directory is empty or the two listings cannot be paired.
if not (x_all and y_all):
    raise ValueError("❌ No images/masks found. Check your dataset paths.")

if n_images != n_masks:
    raise ValueError(f"❌ Mismatch: {n_images} images vs {n_masks} masks")

# ------------------------------
# Balance positive/negative samples
# ------------------------------
if BALANCE:
    print("⚖️ Balancing dataset...")

    # Split mask indices into positives (any non-zero pixel) and negatives (all zero).
    pos_indices, neg_indices = [], []
    for i, mask_path in enumerate(y_all):
        mask = tf.image.decode_png(tf.io.read_file(mask_path), channels=1)
        # Use the maximum rather than the sum: reduce_sum keeps the uint8 dtype of
        # the decoded mask, so the total can wrap around to 0 and a positive mask
        # would be counted as negative.
        if tf.reduce_max(mask).numpy() > 0:
            pos_indices.append(i)   # positive mask
        else:
            neg_indices.append(i)   # negative mask

    print(f"Positives: {len(pos_indices)}, Negatives: {len(neg_indices)}")

    # Draw the same number of positives and negatives. An odd NUM_SAMPLES is
    # rounded down so both classes stay equal in size.
    half_samples = NUM_SAMPLES // 2
    expected_samples = 2 * half_samples

    if len(pos_indices) < half_samples:
        raise ValueError(f"❌ Not enough positive samples! Found {len(pos_indices)}, need {half_samples}")
    if len(neg_indices) < half_samples:
        raise ValueError(f"❌ Not enough negative samples! Found {len(neg_indices)}, need {half_samples}")

    np.random.seed(SEED)
    pos_selected = np.random.choice(pos_indices, size=half_samples, replace=False)
    neg_selected = np.random.choice(neg_indices, size=half_samples, replace=False)

    # Interleave the two classes so the later split does not see them in blocks.
    selected_indices = np.concatenate([pos_selected, neg_selected])
    np.random.shuffle(selected_indices)

    x_all = [x_all[i] for i in selected_indices]
    y_all = [y_all[i] for i in selected_indices]

else:
    np.random.seed(SEED)
    # Use every available pair when fewer than NUM_SAMPLES exist.
    expected_samples = min(NUM_SAMPLES, len(x_all))
    indices = np.random.choice(len(x_all), size=expected_samples, replace=False)
    x_all = [x_all[i] for i in indices]
    y_all = [y_all[i] for i in indices]

print(f"✅ Using {len(x_all)} samples (balanced={BALANCE})")
# Explicit check instead of `assert` (which is stripped under -O), compared with
# the count that was actually requested above rather than the raw NUM_SAMPLES.
if len(x_all) != expected_samples or len(y_all) != expected_samples:
    raise RuntimeError(
        f"Expected {expected_samples} images, got {len(x_all)} images and {len(y_all)} masks"
    )

# ------------------------------
# Train/Val split
# ------------------------------
# Hold out 20% of the image/mask pairs for validation; the fixed random_state
# makes the split repeatable.
split = train_test_split(x_all, y_all, test_size=0.2, random_state=SEED)
x_train, x_val, y_train, y_val = split

print("Train:", len(x_train), "Val:", len(x_val))

# ------------------------------
# Image parsing (GRAYSCALE)
# ------------------------------
def load_image(path):
    """Load a PNG as a single-channel float32 image resized to IMAGE_SIZE.

    Args:
        path: Path of the PNG file (string or scalar string tensor).

    Returns:
        Tensor of shape (IMAGE_SIZE[0], IMAGE_SIZE[1], 1) with the pixel values
        divided by 255.
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_png(raw_bytes, channels=1)  # force grayscale
    resized = tf.image.resize(pixels, IMAGE_SIZE)
    return tf.cast(resized, tf.float32) / 255.0

def load_mask(path):
    """Load a PNG mask and binarise it to {0.0, 1.0} at IMAGE_SIZE.

    Args:
        path: Path of the mask PNG (string or scalar string tensor).

    Returns:
        float32 tensor of shape (IMAGE_SIZE[0], IMAGE_SIZE[1], 1) that is 1.0
        where the resized mask value exceeds 127 and 0.0 elsewhere.
    """
    decoded = tf.image.decode_png(tf.io.read_file(path), channels=1)
    # Nearest-neighbour resizing avoids blending mask values at region borders.
    resized = tf.image.resize(decoded, IMAGE_SIZE, method="nearest")
    foreground = tf.greater(resized, 127)
    return tf.cast(foreground, tf.float32)

def augment(img, mask):
    """Apply the same random flips/rotation to an image and its mask.

    Each of the three transforms (horizontal flip, vertical flip, rotation by
    1-3 quarter turns) is applied independently with probability 0.5, always
    to both tensors so they stay aligned.

    Args:
        img: Image tensor.
        mask: Mask tensor matching ``img``.

    Returns:
        Tuple ``(img, mask)`` after augmentation.
    """
    # Horizontal flip
    apply_hflip = tf.random.uniform(()) > 0.5
    if apply_hflip:
        img, mask = tf.image.flip_left_right(img), tf.image.flip_left_right(mask)

    # Vertical flip
    apply_vflip = tf.random.uniform(()) > 0.5
    if apply_vflip:
        img, mask = tf.image.flip_up_down(img), tf.image.flip_up_down(mask)

    # Rotation by a random multiple of 90 degrees (maxval is exclusive: 1, 2 or 3)
    apply_rotation = tf.random.uniform(()) > 0.5
    if apply_rotation:
        quarter_turns = tf.random.uniform(shape=[], minval=1, maxval=4, dtype=tf.int32)
        img, mask = (
            tf.image.rot90(img, k=quarter_turns),
            tf.image.rot90(mask, k=quarter_turns),
        )

    return img, mask

def parse_pair(image_path, mask_path, augment_data=False):
    """Load one image/mask pair, optionally augmenting it.

    Args:
        image_path: Path of the image PNG.
        mask_path: Path of the matching mask PNG.
        augment_data: When true, pass the loaded pair through ``augment``.

    Returns:
        Tuple ``(img, mask)`` of tensors.
    """
    img, mask = load_image(image_path), load_mask(mask_path)
    if not augment_data:
        return img, mask
    return augment(img, mask)

# ------------------------------
# Datasets
# ------------------------------
# Training pipeline: shuffle the file paths first, then decode and augment.
# Shuffling before `map` lets the buffer span the whole training set (a uniform
# shuffle) while holding only path strings instead of decoded image tensors.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(len(x_train), seed=SEED, reshuffle_each_iteration=True)
train_ds = train_ds.map(lambda x, y: parse_pair(x, y, True), num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Validation pipeline: fixed order and no augmentation.
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.map(lambda x, y: parse_pair(x, y, False), num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Pull one batch to confirm the pipeline yields single-channel images.
for imgs, masks in train_ds.take(1):
    print("✅ Image batch:", imgs.shape, "Mask batch:", masks.shape)
    if imgs.shape[-1] != 1:
        raise ValueError(f"Expected 1 channel, got {imgs.shape[-1]}")

# ------------------------------
# U-Net Model (Grayscale Input)
# ------------------------------
def conv_block(x, filters):
    """Apply two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of each convolution.

    Returns:
        Output Keras tensor of the second stage.
    """
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, 3, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
    return x

def encoder_block(x, filters):
    """Run a conv block and downsample its output by 2x2 max pooling.

    Returns:
        Tuple ``(features, pooled)``: the conv block output (kept for the skip
        connection) and its max-pooled version (input of the next level).
    """
    features = conv_block(x, filters)
    pooled = tf.keras.layers.MaxPooling2D((2, 2))(features)
    return features, pooled

def decoder_block(x, skip, filters):
    """Upsample ``x`` by 2, concatenate the skip tensor, and apply a conv block.

    Args:
        x: Tensor coming from the deeper level.
        skip: Encoder tensor concatenated after the upsampled ``x``.
        filters: Number of channels for the conv block.

    Returns:
        Output Keras tensor of the conv block.
    """
    upsampled = tf.keras.layers.UpSampling2D((2, 2))(x)
    merged = tf.keras.layers.Concatenate()([upsampled, skip])
    return conv_block(merged, filters)

def build_unet(input_shape=(256, 256, 1)):
    """Build a four-level U-Net with a single-channel sigmoid output.

    Args:
        input_shape: Shape of one input sample, without the batch dimension.

    Returns:
        A ``tf.keras.Model`` named "UNet_Grayscale".
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Encoder: remember each level's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bridge
    x = conv_block(x, 1024)

    # Decoder: consume the skips from deepest to shallowest.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    # Output: 1x1 convolution giving a per-pixel value in (0, 1).
    outputs = tf.keras.layers.Conv2D(1, 1, activation="sigmoid")(x)

    return tf.keras.Model(inputs, outputs, name="UNet_Grayscale")

print("\n🏗️ Building U-Net model for grayscale images...")
# The input is IMAGE_SIZE plus one trailing grayscale channel.
model = build_unet(input_shape=(*IMAGE_SIZE, 1))
model.summary()

# ------------------------------
# Loss Functions
# ------------------------------
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Return 1 minus the soft Dice score computed over all flattened elements.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor of the same total size.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor ``1 - dice``.
    """
    truth = tf.reshape(y_true, [-1])
    pred = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    dice = (2. * overlap + smooth) / (total + smooth)
    return 1 - dice

def bce_dice_loss(y_true, y_pred):
    """Return binary cross-entropy plus the Dice loss of the same inputs.

    The scalar Dice term is added (broadcast) to whatever tensor
    ``tf.keras.losses.binary_crossentropy`` returns.
    """
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return bce + dice

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Return the soft Dice score computed over all flattened elements.

    Predictions are used as-is (no thresholding), so the value is the soft
    overlap between ``y_true`` and ``y_pred``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor of the same total size.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor with the Dice score.
    """
    truth = tf.reshape(y_true, [-1])
    pred = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    return (2. * overlap + smooth) / (total + smooth)

# ------------------------------
# Compile & Train
# ------------------------------
# Adam with a fixed starting learning rate; the combined BCE + Dice loss is
# optimised while accuracy and the soft Dice score are reported.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=bce_dice_loss,
    metrics=["accuracy", dice_coefficient],
)

# Keep only the checkpoint with the highest validation Dice.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    MODEL_NAME,
    monitor="val_dice_coefficient",
    mode="max",
    save_best_only=True,
    verbose=1,
)

# Stop after 10 epochs without a better validation Dice and roll back to the best weights.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_dice_coefficient",
    mode="max",
    patience=10,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 5 stagnant epochs. No monitor is given, so this
# callback uses its default monitored quantity rather than val_dice_coefficient.
reduce_lr_cb = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=1,
)

# Per-epoch metrics are written to a CSV file.
csv_logger_cb = tf.keras.callbacks.CSVLogger('training_log.csv')

callbacks = [checkpoint_cb, early_stop_cb, reduce_lr_cb, csv_logger_cb]

# ------------------------------
# Train
# ------------------------------
# Single source of truth for the epoch count: used by the banner and by model.fit
# so the two can never disagree.
EPOCHS = 60

banner = "=" * 60
total_samples = len(x_train) + len(x_val)
balance_note = " (Balanced)" if BALANCE else ""

# Banner values are derived from the actual configuration instead of literals,
# so it stays truthful when NUM_SAMPLES, IMAGE_SIZE or BALANCE are changed.
print("\n" + banner)
print(f"🚀 Starting Training on {total_samples} Grayscale Images{balance_note}")
print(banner)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Input shape: ({IMAGE_SIZE[0]}, {IMAGE_SIZE[1]}, 1) - Grayscale")
print(f"Epochs: {EPOCHS}")
print(banner + "\n")

history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
    verbose=1
)

print("\n" + banner)
print("✅ Training completed! Model saved as:", MODEL_NAME)
print(banner)

# ------------------------------
# Training Summary
# ------------------------------
# "Final" values are those of the last epoch that ran; the best validation Dice
# is reported separately because the last epoch is not necessarily the best one.
metrics = history.history
print("\n📊 Training Summary:")
print(f"   Final Train Loss: {metrics['loss'][-1]:.4f}")
print(f"   Final Val Loss: {metrics['val_loss'][-1]:.4f}")
print(f"   Final Train Dice: {metrics['dice_coefficient'][-1]:.4f}")
print(f"   Final Val Dice: {metrics['val_dice_coefficient'][-1]:.4f}")
print(f"   Best Val Dice: {max(metrics['val_dice_coefficient']):.4f}")
print("\n📁 Saved files:")
print(f"   - {MODEL_NAME}")
print("   - training_log.csv")
print(banner)

Total images: 12047
Total masks: 12047
⚖️ Balancing dataset...
Positives: 2664, Negatives: 9383
✅ Using 5000 samples (balanced=True)
Train: 4000 Val: 1000
✅ Image batch: (30, 256, 256, 1) Mask batch: (30, 256, 256, 1)

🏗️ Building U-Net model for grayscale images...



🚀 Starting Training on 5000 Grayscale Images (Balanced)
Training samples: 4000
Validation samples: 1000
Batch size: 30
Input shape: (256, 256, 1) - Grayscale
Epochs: 60

Epoch 1/60
[1m134/134[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8s/step - accuracy: 0.9523 - dice_coefficient: 0.0154 - loss: 1.2653
Epoch 1: val_dice_coefficient improved from -inf to 0.01245, saving model to unet_pneumo_grayscale_5000.keras
[1m134/134[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1442s[0m 10s/step - accuracy: 0.9525 - dice_coefficient: 0.0154 - loss: 1.2646 - val_accuracy: 0.9940 - val_dice_coefficient: 0.0125 - val_loss: 1.2582 - learning_rate: 1.0000e-04
Epoch 2/60
[1m134/134[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.9916 - dice_coefficient: 0.0234 - loss: 1.0938
Epoch 2: val_dice_coefficient did not improve from 0.01245
[1m134/134[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m249s[0m 2s/step - accuracy: 0.9916 - dice_coefficient: 0.0234 - loss