# In [None]:
# ===============================
# LaneFusion-CV: U-Net Training
# ===============================

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

# -------------------------
# Paths
# -------------------------
# Directory scanned for the input road images (.jpg / .png).
# NOTE: absolute path on the author's Windows machine -- edit for your setup.
IMAGE_DIR = r"C:/Users/pc/Desktop/LaneFusion-CV AI Project/road_line_images/road_line_images"
# Directory holding the PNG masks; a mask is looked up by the image's file name.
MASK_DIR  = r"C:/Users/pc/Desktop/LaneFusion-CV AI Project/masks_png"  # output from XML->mask script
# Side length (pixels) that images and masks are resized to; also the model input size.
IMG_SIZE = 256

# -------------------------
# Data Loader Functions
# -------------------------
def load_image(path):
    """Read an image from disk and return it as a normalized RGB array.

    Args:
        path: Filesystem path of the image file.

    Returns:
        Float array of shape (IMG_SIZE, IMG_SIZE, 3) in RGB channel order
        with values scaled to the range [0, 1].

    Raises:
        OSError: If OpenCV cannot read or decode the file at ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an opaque error in cvtColor.
        raise OSError(f"Could not read image file: {path}")
    # OpenCV decodes to BGR; convert so the array matches matplotlib/Keras
    # expectations.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    return img / 255.0

def load_mask(path):
    """Read a segmentation mask from disk as a single-channel float array.

    Args:
        path: Filesystem path of the mask image.

    Returns:
        Float array of shape (IMG_SIZE, IMG_SIZE, 1) holding the mask's
        gray levels divided by 255.

    Raises:
        OSError: If OpenCV cannot read or decode the file at ``path``.
    """
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read mask file: {path}")
    # Nearest-neighbour keeps the original label values; the default bilinear
    # interpolation blends neighbouring pixels and produces fractional labels
    # along lane edges.
    mask = cv2.resize(
        mask, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST
    )
    mask = mask / 255.0
    # Add the trailing channel axis expected by the network output.
    return np.expand_dims(mask, axis=-1)

# -------------------------
# Pair Images + Masks
# -------------------------
# Collect candidate files. The extension test is case-insensitive so files
# such as "IMG_001.JPG" are not silently skipped.
image_files = sorted(
    os.path.join(IMAGE_DIR, f)
    for f in os.listdir(IMAGE_DIR)
    if f.lower().endswith(('.jpg', '.png'))
)
mask_files = sorted(
    os.path.join(MASK_DIR, f)
    for f in os.listdir(MASK_DIR)
    if f.lower().endswith(('.jpg', '.png'))
)

paired_images = []
paired_masks = []

# Keep only the images that have a PNG mask with the same base name.
for img_path in image_files:
    # splitext swaps just the final extension; str.replace would also rewrite
    # a ".jpg" that happens to occur earlier in the file name.
    mask_name = os.path.splitext(os.path.basename(img_path))[0] + ".png"
    mask_path = os.path.join(MASK_DIR, mask_name)
    if os.path.isfile(mask_path):
        paired_images.append(img_path)
        paired_masks.append(mask_path)

print("Total usable images:", len(paired_images))

# -------------------------
# Load Data
# -------------------------
# Fail early with a clear message instead of an obscure error further down
# when no image/mask pair was found.
if not paired_images:
    raise RuntimeError(
        "No image/mask pairs found; check IMAGE_DIR and MASK_DIR."
    )

# float32 halves the memory of the default float64 arrays; Keras trains in
# float32 anyway, so no precision is lost for the model.
X = np.array([load_image(p) for p in paired_images], dtype=np.float32)
y = np.array([load_mask(p) for p in paired_masks], dtype=np.float32)

# Train/Validation Split (80/20, fixed seed so the split is reproducible)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print("Train:", X_train.shape, "Val:", X_val.shape)

# -------------------------
# U-Net Model Definition
# -------------------------
def unet_model():
    """Build the (uncompiled) U-Net used for lane segmentation.

    The network takes an (IMG_SIZE, IMG_SIZE, 3) input and consists of:
      * three encoder stages (32, 64, 128 filters), each a pair of 3x3
        ReLU convolutions followed by max pooling;
      * a 256-filter bottleneck (another pair of 3x3 ReLU convolutions);
      * three decoder stages (128, 64, 32 filters), each upsampling,
        concatenating the matching encoder features, then applying a pair
        of 3x3 ReLU convolutions;
      * a final 1x1 sigmoid convolution producing a one-channel map.

    Returns:
        A Keras ``Model`` mapping the input tensor to the sigmoid output.
    """

    def conv_pair(tensor, filters):
        # Two stacked 3x3 convolutions with 'same' padding keep spatial size.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input((IMG_SIZE, IMG_SIZE, 3))

    # Encoder: remember each stage's pre-pooling features for the skip links.
    skips = []
    x = inputs
    for filters in (32, 64, 128):
        x = conv_pair(x, filters)
        skips.append(x)
        x = MaxPooling2D()(x)

    # Bottleneck
    x = conv_pair(x, 256)

    # Decoder: walk the saved encoder features from deepest to shallowest.
    for filters, skip in zip((128, 64, 32), reversed(skips)):
        x = UpSampling2D()(x)
        x = Concatenate()([x, skip])
        x = conv_pair(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)

    return Model(inputs, outputs)

# -------------------------
# Compile Model
# -------------------------
# Build the network, then configure training: Adam with learning rate 1e-4
# and binary cross-entropy, which matches the single-channel sigmoid output.
# NOTE(review): plain pixel accuracy can look high when most pixels are
# background -- consider also tracking an overlap metric such as IoU.
model = unet_model()
model.compile(optimizer=Adam(1e-4), loss='binary_crossentropy', metrics=['accuracy'])
# Print the layer-by-layer architecture and parameter counts.
model.summary()

# -------------------------
# Train Model
# -------------------------
# Train for 20 epochs in mini-batches of 8, evaluating on the held-out
# validation split; `history` keeps the object returned by fit().
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=20,
    batch_size=8
)

# -------------------------
# Save Model
# -------------------------
# Persist the trained model to an HDF5 (.h5) file; the path is relative, so
# it is written to the current working directory.
model.save("LaneFusionCV_UNET.h5")
print("Model saved successfully!")

# -------------------------
# Plot Sample Predictions
# -------------------------
def predict_sample(idx=0):
    """Plot one validation image next to its true and predicted masks.

    Args:
        idx: Index into ``X_val`` / ``y_val`` of the sample to display.

    The model output is thresholded at 0.5 to obtain a binary mask, and the
    three panels are shown side by side in a single matplotlib figure.
    """
    img = X_val[idx]
    mask = y_val[idx]

    # The model expects a batch, so add a leading batch axis and take the
    # single result back out.
    probabilities = model.predict(img[np.newaxis, ...])[0]
    pred = (probabilities > 0.5).astype(np.uint8)

    # (title, data, extra imshow keyword arguments) for each panel, in order.
    panels = (
        ("Original Image", img, {}),
        ("Ground Truth Mask", mask[:, :, 0], {'cmap': 'gray'}),
        ("Predicted Mask", pred[:, :, 0], {'cmap': 'gray'}),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, data, options) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(data, **options)
    plt.show()

# Show the first validation sample as a quick visual check of the trained model.
predict_sample(0)
