<a href="https://colab.research.google.com/github/Divak-ar/floorData/blob/master/UnetModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import tensorflow as tf
from google.colab import files
import json
import shutil
import glob

# Set parameters
IMG_HEIGHT = 512
IMG_WIDTH = 512
BATCH_SIZE = 4
EPOCHS = 60

# Check if running in Colab
def is_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

# Create directories for the project
def setup_directories():
    """Create directories for the project"""
    os.makedirs('dataset', exist_ok=True)
    os.makedirs('dataset/walls', exist_ok=True)
    os.makedirs('dataset/model', exist_ok=True)
    os.makedirs('dataset/test', exist_ok=True)
    os.makedirs('dataset/results', exist_ok=True)

    print("Directory structure created:")
    print("- dataset/walls: Upload your wall images here")
    print("- dataset/test: Upload your test images here")
    print("- dataset/model: Trained models will be saved here")
    print("- dataset/results: Results will be saved here")

# Upload files to Google Colab
def upload_files(target_dir='dataset/walls'):
    """Upload files to the target directory in Google Colab"""
    print(f"Please upload your images to {target_dir}...")
    uploaded = files.upload()

    # Move uploaded files to target directory
    for filename in uploaded.keys():
        dest_path = os.path.join(target_dir, filename)
        # Create the target directory if it doesn't exist
        os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        # Move the file
        shutil.move(filename, dest_path)

    print(f"Uploaded {len(uploaded)} files to {target_dir}")
    return list(uploaded.keys())

# Create a simplified U-Net model specifically for wall detection
def simple_unet(input_shape=(512, 512, 1)):
    """A simplified U-Net for wall detection"""
    inputs = Input(input_shape)

    # Encoder (downsampling)
    conv1 = Conv2D(32, 3, padding='same')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Activation('relu')(conv1)
    conv1 = Conv2D(32, 3, padding='same')(conv1)
    conv1 = BatchNormalization()(conv1)
    conv1 = Activation('relu')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, 3, padding='same')(pool1)
    conv2 = BatchNormalization()(conv2)
    conv2 = Activation('relu')(conv2)
    conv2 = Conv2D(64, 3, padding='same')(conv2)
    conv2 = BatchNormalization()(conv2)
    conv2 = Activation('relu')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, 3, padding='same')(pool2)
    conv3 = BatchNormalization()(conv3)
    conv3 = Activation('relu')(conv3)
    conv3 = Conv2D(128, 3, padding='same')(conv3)
    conv3 = BatchNormalization()(conv3)
    conv3 = Activation('relu')(conv3)

    # Decoder (upsampling)
    up2 = UpSampling2D(size=(2, 2))(conv3)
    up2 = concatenate([conv2, up2], axis=3)

    conv4 = Conv2D(64, 3, padding='same')(up2)
    conv4 = BatchNormalization()(conv4)
    conv4 = Activation('relu')(conv4)
    conv4 = Conv2D(64, 3, padding='same')(conv4)
    conv4 = BatchNormalization()(conv4)
    conv4 = Activation('relu')(conv4)

    up1 = UpSampling2D(size=(2, 2))(conv4)
    up1 = concatenate([conv1, up1], axis=3)

    conv5 = Conv2D(32, 3, padding='same')(up1)
    conv5 = BatchNormalization()(conv5)
    conv5 = Activation('relu')(conv5)
    conv5 = Conv2D(32, 3, padding='same')(conv5)
    conv5 = BatchNormalization()(conv5)
    conv5 = Activation('relu')(conv5)

    # Output layer
    outputs = Conv2D(1, 1, activation='sigmoid')(conv5)

    model = Model(inputs=inputs, outputs=outputs)
    return model

# Load and preprocess data
def load_data_walls_only(walls_dir):
    """Load wall masks only for training"""
    masks = []

    # List all files in the directory
    wall_files = sorted(os.listdir(walls_dir))

    if not wall_files:
        print(f"No files found in {walls_dir}. Please upload some wall images.")
        return None

    print(f"Loading {len(wall_files)} files from {walls_dir}...")

    for wall_file in wall_files:
        # Skip hidden files
        if wall_file.startswith('.'):
            continue

        # Load wall mask
        wall_path = os.path.join(walls_dir, wall_file)
        mask = cv2.imread(wall_path, cv2.IMREAD_GRAYSCALE)

        # Skip if file couldn't be read
        if mask is None:
            print(f"Warning: Could not read file {wall_path}")
            continue

        mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT))
        mask = np.expand_dims(mask, axis=-1)
        mask = (mask > 127).astype(np.float32)  # Binarize

        # Use the mask as both input and output
        masks.append(mask)

    if len(masks) == 0:
        raise ValueError(f"No valid mask files found in {walls_dir}")

    return np.array(masks)

# Train the model
def train_simple_model(data_dir='dataset'):
    """Train the simple wall detection model"""
    # Load data
    walls_dir = os.path.join(data_dir, 'walls')

    # Create directories if they don't exist
    os.makedirs(os.path.join(data_dir, 'model'), exist_ok=True)

    # Check if walls directory has files
    if not os.listdir(walls_dir):
        print(f"No files found in {walls_dir}. Please upload wall images first.")
        return None, None

    print("Loading data...")
    masks = load_data_walls_only(walls_dir)

    if masks is None:
        return None, None

    print(f"Loaded {len(masks)} mask files")

    # Use masks as both input and output
    inputs = masks.copy()

    # Normalize inputs
    inputs = inputs.astype('float32') / 255.0

    # Split data
    X_train, X_val, Y_train, Y_val = train_test_split(
        inputs, masks, test_size=0.2, random_state=42
    )

    print(f"Training shapes: X={X_train.shape}, Y={Y_train.shape}")
    print(f"Validation shapes: X={X_val.shape}, Y={Y_val.shape}")

    # Create model
    print("Creating model...")
    model = simple_unet(input_shape=(IMG_HEIGHT, IMG_WIDTH, 1))

    # Print model summary
    model.summary()

    # Compile
    model.compile(
        optimizer=Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.MeanIoU(num_classes=2)]
    )

    # Callbacks
    model_path = os.path.join(data_dir, 'model', 'simple_walls_best.h5')
    callbacks = [
        ModelCheckpoint(
            model_path,
            save_best_only=True,
            monitor='val_mean_io_u',
            mode='max'
        ),
        EarlyStopping(
            patience=10,
            monitor='val_mean_io_u',
            mode='max',
            restore_best_weights=True
        ),
        ReduceLROnPlateau(
            factor=0.2,
            patience=5,
            min_lr=1e-6,
            monitor='val_mean_io_u',
            mode='max'
        )
    ]

    # Train
    print("Training model...")
    history = model.fit(
        X_train, Y_train,
        validation_data=(X_val, Y_val),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        callbacks=callbacks
    )

    # Save final model
    final_model_path = os.path.join(data_dir, 'model', 'simple_walls_final.h5')
    model.save(final_model_path)
    print(f"Model saved to {final_model_path}")

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper right')

    # Fix for the MeanIoU metric name issue
    iou_key = None
    val_iou_key = None

    # Find the correct keys for IoU metrics in history
    for key in history.history.keys():
        if 'io_u' in key.lower() and not key.startswith('val_'):
            iou_key = key
        elif 'io_u' in key.lower() and key.startswith('val_'):
            val_iou_key = key

    if iou_key and val_iou_key:
        plt.subplot(1, 2, 2)
        plt.plot(history.history[iou_key])
        plt.plot(history.history[val_iou_key])
        plt.title('Mean IoU')
        plt.ylabel('IoU')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='lower right')
    else:
        plt.subplot(1, 2, 2)
        plt.plot(history.history['accuracy'])
        plt.plot(history.history['val_accuracy'])
        plt.title('Model Accuracy')
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='lower right')

    plt.tight_layout()
    history_path = os.path.join(data_dir, 'training_history.png')
    plt.savefig(history_path)
    plt.show()

    print(f"Training history saved to {history_path}")

    return model, history

def get_walls_coordinates(model, image_path, min_length=20, threshold=0.5):
    """
    Extract wall coordinates in (x1, y1, x2, y2, length) format

    Args:
        model: Trained wall detection model
        image_path: Path to input floor plan image
        min_length: Minimum length of wall segments to detect
        threshold: Threshold for binary mask prediction

    Returns:
        numpy array of shape (n, 5) with each row containing (x1, y1, x2, y2, length)
    """
    # Load and preprocess image
    img = cv2.imread(image_path)

    # Check if image was loaded successfully
    if img is None:
        print(f"Error: Could not read image at {image_path}")
        return None

    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img_gray = cv2.resize(img_gray, (IMG_WIDTH, IMG_HEIGHT))
    img_gray = np.expand_dims(img_gray, axis=-1)
    img_norm = img_gray.astype('float32') / 255.0

    # Predict mask
    pred_mask = model.predict(np.expand_dims(img_norm, axis=0))[0]
    binary_mask = (pred_mask > threshold).astype(np.uint8) * 255

    # Clean the mask
    kernel = np.ones((3, 3), np.uint8)
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    # Get the centerlines
    skeleton = skeletonize(binary_mask[:,:,0] > 0).astype(np.uint8) * 255

    # Use Hough transform to get line segments
    lines = cv2.HoughLinesP(
        skeleton,
        rho=1,
        theta=np.pi/180,
        threshold=10,
        minLineLength=min_length,
        maxLineGap=10
    )

    # Prepare output array
    wall_coords = []

    if lines is not None:
        for line in lines:
            x1, y1, x2, y2 = line[0]
            length = np.sqrt((x2-x1)**2 + (y2-y1)**2)
            wall_coords.append([x1, y1, x2, y2, length])

    return np.array(wall_coords) if wall_coords else np.empty((0, 5))

def save_wall_coordinates_csv(wall_coords, output_path):
    """
    Save wall coordinates to CSV file

    Args:
        wall_coords: numpy array of wall coordinates (x1, y1, x2, y2, length)
        output_path: Path to save the CSV file
    """
    # Create output directory if needed
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Save to CSV
    np.savetxt(
        output_path,
        wall_coords,
        delimiter=',',
        header='x1,y1,x2,y2,length',
        comments='',
        fmt='%d,%d,%d,%d,%.2f'
    )

    print(f"Wall coordinates saved to {output_path}")

    # If in Colab, download the file
    if is_colab():
        files.download(output_path)

def extract_wall_coordinates(pred_mask, min_length=20):
    """
    Extract wall coordinates as line segments from prediction mask

    Returns:
        List of wall segments with start/end points and length
    """
    # Ensure binary mask
    binary = (pred_mask > 0).astype(np.uint8)

    # Clean the mask
    kernel = np.ones((3,3), np.uint8)
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)

    # Get the centerlines
    skeleton = skeletonize(binary).astype(np.uint8) * 255

    # Use Hough transform to get line segments
    lines = cv2.HoughLinesP(
        skeleton,
        rho=1,
        theta=np.pi/180,
        threshold=10,
        minLineLength=min_length,
        maxLineGap=10
    )

    wall_segments = []
    if lines is not None:
        for i, line in enumerate(lines):
            x1, y1, x2, y2 = line[0]
            length = np.sqrt((x2-x1)**2 + (y2-y1)**2)

            wall_segments.append({
                "id": i,
                "points": [
                    {"x": int(x1), "y": int(y1)},
                    {"x": int(x2), "y": int(y2)}
                ],
                "length": float(length)
            })

    return wall_segments, skeleton

def predict_walls(model, image_path, result_dir='dataset/results'):
    """Predict walls and extract line segments"""
    # Create results directory if it doesn't exist
    os.makedirs(result_dir, exist_ok=True)

    # Get base filename without extension
    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    # Load and preprocess image
    img = cv2.imread(image_path)
    # Check if image was loaded successfully
    if img is None:
        print(f"Error: Could not read image at {image_path}")
        return None, None

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))

    # For wall-to-wall model, we need a grayscale version
    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img_gray = np.expand_dims(img_gray, axis=-1)
    img_norm = img_gray.astype('float32') / 255.0

    # Predict mask
    pred_mask = model.predict(np.expand_dims(img_norm, axis=0))[0]
    binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255

    # Extract wall coordinates
    wall_segments, skeleton = extract_wall_coordinates(binary_mask)

    # Visualize
    plt.figure(figsize=(15, 10))

    # Original image
    plt.subplot(2, 2, 1)
    plt.imshow(img)
    plt.title("Original Floor Plan")
    plt.axis('off')

    # Predicted mask
    plt.subplot(2, 2, 2)
    plt.imshow(binary_mask[:,:,0], cmap='gray')
    plt.title("Predicted Wall Mask")
    plt.axis('off')

    # Skeleton
    plt.subplot(2, 2, 3)
    plt.imshow(skeleton, cmap='gray')
    plt.title("Wall Skeleton")
    plt.axis('off')

    # Wall lines
    plt.subplot(2, 2, 4)
    line_img = img.copy()
    for wall in wall_segments:
        p1 = wall['points'][0]
        p2 = wall['points'][1]
        cv2.line(line_img, (p1['x'], p1['y']), (p2['x'], p2['y']), (0, 0, 255), 2)

    plt.imshow(line_img)
    plt.title(f"Detected Wall Lines ({len(wall_segments)} segments)")
    plt.axis('off')

    plt.tight_layout()

    # Save visualization
    viz_path = os.path.join(result_dir, f"{base_filename}_visualization.png")
    plt.savefig(viz_path)
    plt.show()

    # Save wall data to JSON
    wall_data = {
        "walls": wall_segments,
        "imageWidth": IMG_WIDTH,
        "imageHeight": IMG_HEIGHT,
        "scale": 0.02  # 1 pixel = 2cm (example scale)
    }

    json_path = os.path.join(result_dir, f"{base_filename}_wall_coordinates.json")
    with open(json_path, 'w') as f:
        json.dump(wall_data, f, indent=2)

    print(f"Found {len(wall_segments)} wall segments")
    print(f"Wall coordinates saved to {json_path}")
    print(f"Visualization saved to {viz_path}")

    # If in Colab, download the files
    if is_colab():
        print("Downloading results...")
        files.download(json_path)
        files.download(viz_path)

    return wall_segments, binary_mask

# Simple direct wall extraction without deep learning
def extract_walls_directly(image_path, result_dir='dataset/results'):
    """Extract walls directly from binary wall image (no deep learning)"""
    # Create results directory if it doesn't exist
    os.makedirs(result_dir, exist_ok=True)

    # Get base filename without extension
    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    # Read image (assuming walls are white on black background)
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    # Check if image was loaded successfully
    if img is None:
        print(f"Error: Could not read image at {image_path}")
        return None

    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))

    # Binarize if needed
    _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

    # Skeletonize
    skeleton = skeletonize(binary > 0).astype(np.uint8) * 255

    # Detect lines
    lines = cv2.HoughLinesP(
        skeleton,
        rho=1,
        theta=np.pi/180,
        threshold=15,
        minLineLength=30,
        maxLineGap=10
    )

    # Process lines
    wall_segments = []
    if lines is not None:
        for i, line in enumerate(lines):
            x1, y1, x2, y2 = line[0]
            length = np.sqrt((x2-x1)**2 + (y2-y1)**2)

            wall_segments.append({
                "id": i,
                "points": [
                    {"x": int(x1), "y": int(y1)},
                    {"x": int(x2), "y": int(y2)}
                ],
                "length": float(length)
            })

    # Visualize
    rgb_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(img, cmap='gray')
    plt.title("Binary Wall Image")
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(skeleton, cmap='gray')
    plt.title("Wall Skeleton")
    plt.axis('off')

    plt.subplot(1, 3, 3)
    line_img = rgb_img.copy()
    for wall in wall_segments:
        p1 = wall['points'][0]
        p2 = wall['points'][1]
        cv2.line(line_img, (p1['x'], p1['y']), (p2['x'], p2['y']), (0, 0, 255), 2)

    plt.imshow(line_img)
    plt.title(f"Detected Wall Lines ({len(wall_segments)} segments)")
    plt.axis('off')

    plt.tight_layout()

    # Save visualization
    viz_path = os.path.join(result_dir, f"{base_filename}_direct_visualization.png")
    plt.savefig(viz_path)
    plt.show()

    # Save wall data
    wall_data = {
        "walls": wall_segments,
        "imageWidth": IMG_WIDTH,
        "imageHeight": IMG_HEIGHT,
        "scale": 0.02
    }

    json_path = os.path.join(result_dir, f"{base_filename}_direct_wall_coordinates.json")
    with open(json_path, 'w') as f:
        json.dump(wall_data, f, indent=2)

    print(f"Found {len(wall_segments)} wall segments")
    print(f"Wall coordinates saved to {json_path}")
    print(f"Visualization saved to {viz_path}")

    # If in Colab, download the files
    if is_colab():
        print("Downloading results...")
        files.download(json_path)
        files.download(viz_path)

    return wall_segments

# Test the model on all images in the test directory
def test_model(model, test_dir='dataset/test', result_dir='dataset/results'):
    """Test the model on all images in the test directory"""
    # Create results directory if it doesn't exist
    os.makedirs(result_dir, exist_ok=True)

    # Get list of test images
    test_images = glob.glob(os.path.join(test_dir, '*.*'))

    # Check if there are any test images
    if not test_images:
        print(f"No test images found in {test_dir}. Please upload some test images first.")
        return

    print(f"Testing model on {len(test_images)} images...")

    # Process each test image
    all_results = {}

    for img_path in test_images:
        print(f"Processing {os.path.basename(img_path)}...")
        wall_segments, _ = predict_walls(model, img_path, result_dir)

        if wall_segments:
            all_results[os.path.basename(img_path)] = {
                "segments_count": len(wall_segments),
                "segments": wall_segments
            }

    # Save summary of results
    summary_path = os.path.join(result_dir, 'test_results_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    print(f"Test results saved to {summary_path}")

    # If in Colab, download the summary
    if is_colab():
        print("Downloading test results summary...")
        files.download(summary_path)

    return all_results

# Example usage
def get_wall_coordinates_simple_format(model, image_path, result_dir='dataset/results'):
    """
    Extract wall coordinates in simple (x1,y1,x2,y2,length) format
    and save to CSV

    Uses the existing extract_wall_coordinates function
    """
    # Create results directory if it doesn't exist
    os.makedirs(result_dir, exist_ok=True)

    # Get base filename without extension
    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    # Load and preprocess image
    img = cv2.imread(image_path)
    # Check if image was loaded successfully
    if img is None:
        print(f"Error: Could not read image at {image_path}")
        return None

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))

    # For wall-to-wall model, we need a grayscale version
    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img_gray = np.expand_dims(img_gray, axis=-1)
    img_norm = img_gray.astype('float32') / 255.0

    # Predict mask
    pred_mask = model.predict(np.expand_dims(img_norm, axis=0))[0]
    binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255

    # Extract wall coordinates using existing function
    wall_segments, skeleton = extract_wall_coordinates(binary_mask)

    # Convert to simple format (x1,y1,x2,y2,length)
    wall_coords = []
    for wall in wall_segments:
        x1 = wall['points'][0]['x']
        y1 = wall['points'][0]['y']
        x2 = wall['points'][1]['x']
        y2 = wall['points'][1]['y']
        length = wall['length']
        wall_coords.append([x1, y1, x2, y2, length])

    wall_coords = np.array(wall_coords) if wall_coords else np.empty((0, 5))

    # Save to CSV
    csv_path = os.path.join(result_dir, f"{base_filename}_wall_coordinates.csv")
    np.savetxt(
        csv_path,
        wall_coords,
        delimiter=',',
        header='x1,y1,x2,y2,length',
        comments='',
        fmt='%d,%d,%d,%d,%.2f'
    )

    print(f"Found {len(wall_coords)} wall segments")
    print(f"Wall coordinates saved to {csv_path}")

    # If in Colab, download the file
    if is_colab():
        print("Downloading CSV file...")
        files.download(csv_path)

    return wall_coords


# Main function to run in Google Colab
def run_wall_segmentation_tool_colab():
    """Main function to run the wall segmentation tool in Google Colab"""
    # Check if running in Colab
    if not is_colab():
        print("This function is designed to run in Google Colab.")
        print("If you're not in Colab, please use the other functions directly.")
        return

    # Setup directories
    setup_directories()

    # Display menu
    print("\n--- Wall Segmentation Tool for Google Colab ---")
    print("1. Upload Wall Images")
    print("2. Upload Test Images")
    print("3. Train Model")
    print("4. Test Model on Test Images")
    print("5. Extract Walls Directly (No Deep Learning)")
    print("6. Download Model")
    print("7. Upload and Load Existing Model")
    print("8. Extract Wall Coordinates (x1,y1,x2,y2,length) Format")
    print("9. Exit")

    while True:
        choice = input("\nEnter your choice (1-9): ")

        if choice == "1":
            # Upload wall images
            upload_files('dataset/walls')

        elif choice == "2":
            # Upload test images
            upload_files('dataset/test')

        elif choice == "3":
            # Train model
            print("Training model on uploaded wall images...")
            model, _ = train_simple_model()

        elif choice == "4":
            # Test model
            model_path = os.path.join('dataset', 'model', 'simple_walls_best.h5')
            if not os.path.exists(model_path):
                print("No trained model found. Please train the model first.")
                continue

            try:
                model = load_model(model_path)
                print("Model loaded successfully.")
                test_model(model)
            except Exception as e:
                print(f"Error loading model: {e}")

        elif choice == "5":
            # Extract walls directly
            # First, upload a binary wall image
            print("Please upload a binary wall image...")
            uploaded = upload_files('dataset/test')

            if uploaded:
                image_path = os.path.join('dataset', 'test', uploaded[0])
                print(f"Processing {image_path}...")
                extract_walls_directly(image_path)

        elif choice == "6":
            # Download model
            model_path = os.path.join('dataset', 'model', 'simple_walls_best.h5')
            if os.path.exists(model_path):
                print("Downloading trained model...")
                files.download(model_path)
            else:
                print("No trained model found. Please train the model first.")

        elif choice == "7":
            # Upload and load existing model
            print("Please upload your trained model (.h5 file)...")
            uploaded = upload_files('dataset/model')

            if uploaded:
                model_path = os.path.join('dataset', 'model', uploaded[0])
                try:
                    model = load_model(model_path)
                    print("Model loaded successfully.")

                    # Ask if they want to test it
                    test_choice = input("Do you want to test the model on your test images? (y/n): ")
                    if test_choice.lower() == 'y':
                        test_model(model)
                except Exception as e:
                    print(f"Error loading model: {e}")

        elif choice == "8":
        # Extract wall coordinates in (x1,y1,x2,y2,length) format
        # First check if model exists
          model_path = os.path.join('dataset', 'model', 'simple_walls_best.h5')

          if not os.path.exists(model_path):
              # Ask if they want to upload a model
              upload_choice = input("No trained model found. Do you want to upload one? (y/n): ")
              if upload_choice.lower() == 'y':
                  uploaded = upload_files('dataset/model')
                  if uploaded:
                      model_path = os.path.join('dataset', 'model', uploaded[0])
                  else:
                      print("No model uploaded. Returning to menu.")
                      continue
              else:
                  print("A model is required for wall coordinate extraction. Returning to menu.")
                  continue

          # Now check for test images or ask to upload one
          test_dir = 'dataset/test'
          test_images = glob.glob(os.path.join(test_dir, '*.*'))

          if not test_images:
              print("No test images found. Please upload an image to process.")
              uploaded = upload_files(test_dir)
              if not uploaded:
                  print("No images uploaded. Returning to menu.")
                  continue
              test_images = [os.path.join(test_dir, uploaded[0])]

          # Let user select an image if multiple are available
          if len(test_images) > 1:
              print("\nAvailable test images:")
              for i, img_path in enumerate(test_images):
                  print(f"{i+1}. {os.path.basename(img_path)}")

              img_choice = input(f"Select image to process (1-{len(test_images)}): ")
              try:
                  img_idx = int(img_choice) - 1
                  if 0 <= img_idx < len(test_images):
                      image_path = test_images[img_idx]
                  else:
                      print("Invalid selection. Using the first image.")
                      image_path = test_images[0]
              except ValueError:
                  print("Invalid input. Using the first image.")
                  image_path = test_images[0]
          else:
              image_path = test_images[0]

          # Load model
          try:
              model = load_model(model_path)
              print(f"Model loaded successfully from {model_path}")
              print(f"Processing image: {os.path.basename(image_path)}")

              # Extract and save coordinates
              wall_coords = get_wall_coordinates_simple_format(model, image_path)

              if wall_coords is not None and wall_coords.size > 0:
                  print("First 5 walls:")
                  for i, wall in enumerate(wall_coords[:5]):
                      print(f"Wall {i+1}: x1={int(wall[0])}, y1={int(wall[1])}, x2={int(wall[2])}, y2={int(wall[3])}, length={wall[4]:.2f}")
              else:
                  print("No walls detected in the image.")

          except Exception as e:
              print(f"Error processing image: {e}")

        elif choice == "9":
            # Exit
            print("Exiting...")
            break

        else:
            print("Invalid choice. Please try again.")


# Entry point
if __name__ == "__main__":
    # If running in Colab, run the Colab-specific function
    if is_colab():
        run_wall_segmentation_tool_colab()
    else:
        # If running locally, provide a simple interface
        print("Wall Segmentation Tool")
        print("1. Train model")
        print("2. Extract walls directly")
        print("3. Predict with trained model")
        print("4. Exit")

        choice = input("Enter your choice (1-4): ")

        if choice == "1":
            data_dir = input("Enter data directory path (or press Enter for default): ")
            if data_dir:
                train_simple_model(data_dir)
            else:
                train_simple_model()

        elif choice == "2":
            image_path = input("Enter image path: ")
            if image_path:
                extract_walls_directly(image_path)
            else:
                print("Error: Image path required")

        elif choice == "3":
            model_path = input("Enter model path: ")
            image_path = input("Enter image path: ")

            if model_path and image_path:
                try:
                    model = load_model(model_path)
                    predict_walls(model, image_path)
                except Exception as e:
                    print(f"Error: {e}")
            else:
                print("Error: Both model and image paths required")

        elif choice == "4":
            print("Exiting...")
        else:
            print("Invalid choice. Please run again.")

Directory structure created:
- dataset/walls: Upload your wall images here
- dataset/test: Upload your test images here
- dataset/model: Trained models will be saved here
- dataset/results: Results will be saved here

--- Wall Segmentation Tool for Google Colab ---
1. Upload Wall Images
2. Upload Test Images
3. Train Model
4. Test Model on Test Images
5. Extract Walls Directly (No Deep Learning)
6. Download Model
7. Upload and Load Existing Model
8. Extract Wall Coordinates (x1,y1,x2,y2,length) Format
9. Exit

Enter your choice (1-9): 8

Available test images:
1. 00001000.jpg
2. 00010057.jpg
3. 00010062.jpg
4. 00010058.jpg
5. 00010063.jpg
6. 00010090.jpg
7. 00010059.jpg
Select image to process (1-7): 4




Model loaded successfully from dataset/model/simple_walls_best.h5
Processing image: 00010058.jpg




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 804ms/step
Found 44 wall segments
Wall coordinates saved to dataset/results/00010058_wall_coordinates.csv
Downloading CSV file...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

First 5 walls:
Wall 1: x1=256, y1=441, x2=410, y2=408, length=157.50
Wall 2: x1=340, y1=195, x2=440, y2=174, length=102.18
Wall 3: x1=74, y1=254, x2=225, y2=220, length=154.78
Wall 4: x1=425, y1=103, x2=471, y2=316, length=217.91
Wall 5: x1=222, y1=68, x2=407, y2=29, length=189.07

Enter your choice (1-9): 9
Exiting...
