### Import Required Libraries

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

!pip install tensorflow
import tensorflow as tf
from tensorflow.keras.utils import Sequence

# Define Dataset Paths and Classes
dataset_path = '../AI4MARS/msl/'  # Update this to your dataset location
images_path = os.path.join(dataset_path, 'images', 'edr')  # Raw Mars images
labels_path = os.path.join(dataset_path, 'labels', 'train')  # Crowdsourced training labels

# Define terrain classes from dataset info.
# Maps class ID -> label colour. main_process converts labels to RGB before
# matching them against these values, so treat every entry as an (R, G, B)
# triple.
# Insertion order matters: map_labels applies the entries in this order and a
# later match overwrites an earlier one.
# NOTE(review): Soil (0) and NULL (255) share (0, 0, 0), so a black label pixel
# is always mapped to 255 and class 0 can never come from a colour match.
# Confirm the dataset's real label encoding and give these classes distinct
# values.
class_colors = {
    0: (0, 0, 0),       # Soil
    1: (255, 255, 255), # Bedrock
    2: (128, 128, 128), # Sand
    3: (255, 0, 0),     # Big Rock
    255: (0, 0, 0)      # NULL
}
# Human-readable name for each class ID used in class_colors.
class_names = {
    0: "Soil",
    1: "Bedrock",
    2: "Sand",
    3: "Big Rock",
    255: "Unlabeled"
}


### Data Loading

In [None]:
def load_data(images_path, labels_path):
    """
    Pair image files with label files that share the same basename.

    Args:
        images_path: Directory containing images; files ending in ``.jpg`` or
            ``.png`` (case-insensitive) are considered.
        labels_path: Directory containing label masks; files ending in
            ``.png`` (case-insensitive) are considered.

    Returns:
        Tuple ``(image_files, label_files)`` of equal-length lists of bare
        filenames (not full paths) such that ``image_files[i]`` and
        ``label_files[i]`` have the same basename. Pairs follow the sorted
        order of the image filenames. If several files in one directory share
        a basename (e.g. ``x.jpg`` and ``x.png``), only the first in sorted
        order is used, so the two lists always stay aligned.
    """
    def _index_by_basename(directory, extensions):
        # basename -> filename, in sorted-filename order; setdefault keeps the
        # first file for each basename so duplicates cannot shift the pairing.
        index = {}
        for name in sorted(os.listdir(directory)):
            if name.lower().endswith(extensions):
                index.setdefault(os.path.splitext(name)[0], name)
        return index

    images = _index_by_basename(images_path, ('.jpg', '.png'))
    labels = _index_by_basename(labels_path, ('.png',))

    # Labels are looked up per image, so both lists share one ordering.
    matching_basenames = [base for base in images if base in labels]
    image_files = [images[base] for base in matching_basenames]
    label_files = [labels[base] for base in matching_basenames]

    print(f"Loaded {len(image_files)} matching image-label pairs.")
    return image_files, label_files

def preprocess_image(image):
    """
    Resize an image to 512x512 and scale its pixel values by 1/255.

    Args:
        image: Array accepted by ``cv2.resize`` (e.g. ``cv2.imread`` output).

    Returns:
        The resized image divided by 255.0 as a float array (values in
        [0, 1] for uint8 input).
    """
    target_size = (512, 512)  # (width, height), the order cv2.resize expects
    return cv2.resize(image, target_size) / 255.0

def preprocess_label(label):
    """
    Resize a label image to 512x512 without blending label values.

    Nearest-neighbour interpolation copies existing pixel values instead of
    averaging neighbours, so class colours or IDs stay exact after resizing.
    """
    target_size = (512, 512)
    resized = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
    return resized

def map_labels(label, *, colors=None, fill_value=0):
    """
    Map an RGB label image to a 2-D array of class IDs.

    Args:
        label: Array of shape (H, W, C) with C >= 3. Only the first three
            channels are compared, so RGBA input also works. Colours are
            matched exactly, in RGB order.
        colors: Optional mapping ``class_id -> (R, G, B)``; defaults to the
            module-level ``class_colors``.
        fill_value: Class ID given to pixels whose colour matches no entry.
            The default of 0 keeps the historical behaviour, where unknown
            colours are reported as class 0.

    Returns:
        ``np.uint8`` array of shape (H, W).

    Raises:
        ValueError: If ``label`` is not a 3-D array with at least 3 channels.

    Note:
        Entries are applied in ``colors`` iteration order, and a later match
        overwrites an earlier one. If two classes share a colour, the last
        one wins.
    """
    if colors is None:
        colors = class_colors

    label = np.asarray(label)
    if label.ndim != 3 or label.shape[-1] < 3:
        raise ValueError(
            f"Expected an (H, W, 3+) colour label image, got shape {label.shape}"
        )
    rgb = label[..., :3]  # ignore any alpha channel

    label_mapped = np.full(rgb.shape[:2], fill_value, dtype=np.uint8)
    for class_id, color in colors.items():
        mask = np.all(rgb == np.asarray(color), axis=-1)  # exact RGB match
        label_mapped[mask] = class_id

    return label_mapped

### Visualization

In [None]:
def visualize_terrain(image, label):
    """
    Visualize the original image, segmentation mask, and blended overlay.

    Args:
        image: BGR uint8 image as returned by ``cv2.imread``. It must be
            uint8 so ``cv2.addWeighted`` can blend it with the uint8 overlay.
        label: 2-D array of class IDs (keys of ``class_colors``). If its
            height/width differ from ``image``, it is resized with
            nearest-neighbour interpolation so class IDs are never blended.
    """
    # Boolean indexing below needs the mask and image to share height/width.
    if label.shape[:2] != image.shape[:2]:
        label = cv2.resize(
            label, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST
        )

    # class_colors holds RGB triples, but the image (and so the overlay) is BGR.
    # Reverse each colour so that after the BGR->RGB conversion used for
    # display, every class appears in its intended colour.
    overlay = np.zeros_like(image, dtype=np.uint8)
    for class_id, color in class_colors.items():
        overlay[label == class_id] = color[::-1]

    # Blend overlay with the original image
    blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)

    # Plot results
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title("Original Image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(label, cmap="tab10")
    plt.title("Segmentation Mask")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB))
    plt.title("Blended Overlay")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

### Obstacle Detection

In [None]:
def detect_obstacles(image, label, obstacle_classes, min_area=20):
    """
    Detect and highlight obstacles in the image based on the label mask.

    Steps:
        1. Pixels whose class ID is in ``obstacle_classes`` form a binary
           mask.
        2. The mask is lightly blurred and its external contours are found.
        3. A bounding box is drawn for every contour whose area exceeds
           ``min_area``.

    The mask and the annotated image are both plotted.

    Args:
        image: BGR uint8 image (``cv2.imread`` output). Boxes are drawn on a
            copy; the input is not modified.
        label: 2-D array of class IDs with the same height/width as ``image``.
        obstacle_classes: Iterable of class IDs to treat as obstacles (a list,
            tuple, set or array).
        min_area: Contours with area <= ``min_area`` pixels are ignored.

    Returns:
        List of ``(x, y, w, h)`` bounding boxes that were drawn. The list is
        empty if no pixel belongs to an obstacle class.
    """
    # Debug output about the incoming class-ID mask.
    print(f"Mapped label shape: {label.shape}")
    print(f"Unique mapped label values: {np.unique(label)}")

    # np.isin treats a Python set as a single object rather than a collection,
    # so normalise the classes to a list first.
    obstacle_mask = np.isin(label, list(obstacle_classes)).astype(np.uint8)

    if np.count_nonzero(obstacle_mask) == 0:
        print("Obstacle mask is empty. Skipping detection.")
        return []

    # Use a fresh figure so the mask is not drawn onto whatever figure is current.
    plt.figure()
    plt.imshow(obstacle_mask, cmap="gray")
    plt.title("Obstacle Mask")
    plt.axis("off")
    plt.show()

    # On a 0/1 uint8 mask the blurred values round back to 0 or 1, which
    # removes isolated pixels and smooths ragged edges before contour search.
    blurred_mask = cv2.GaussianBlur(obstacle_mask, (5, 5), 0)

    contours, _ = cv2.findContours(blurred_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_image = image.copy()
    boxes = []
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area > min_area:  # Filter small regions
            x, y, w, h = cv2.boundingRect(cnt)
            print(f"Bounding Box: x={x}, y={y}, w={w}, h={h}, area={area}")
            # (255, 0, 0) is blue in the BGR colour order of result_image.
            cv2.rectangle(result_image, (x, y), (x + w, y + h), (255, 0, 0), 2)
            boxes.append((x, y, w, h))

    plt.figure(figsize=(8, 8))
    plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
    plt.title("Detected Obstacles with Bounding Boxes")
    plt.axis("off")
    plt.show()

    return boxes

### Limit Processing for Debugging

In [None]:
def main_process(images_path, labels_path, max_files=10):
    """
    Visualise terrain and detect obstacles for up to ``max_files`` image/label pairs.

    Args:
        images_path: Directory of raw images (passed to ``load_data``).
        labels_path: Directory of colour label masks (passed to ``load_data``).
        max_files: Maximum number of pairs to process. ``None`` processes
            every matched pair; values <= 0 process none.

    Pairs that cannot be read are reported and skipped.
    """
    image_files, label_files = load_data(images_path, labels_path)

    pairs = list(zip(image_files, label_files))
    if max_files is not None:
        pairs = pairs[:max(0, max_files)]

    for image_file, label_file in pairs:
        image = cv2.imread(os.path.join(images_path, image_file))
        label = cv2.imread(os.path.join(labels_path, label_file), cv2.IMREAD_COLOR)

        # Check before any further cv2 call: cvtColor(None) would raise.
        if image is None or label is None:
            print(f"Error loading {image_file} or {label_file}. Skipping...")
            continue

        # OpenCV decodes as BGR; class_colors are compared as RGB.
        label = cv2.cvtColor(label, cv2.COLOR_BGR2RGB)

        label_mapped = map_labels(label)
        print(f"Unique values in mapped label: {np.unique(label_mapped)}")

        visualize_terrain(image, label_mapped)
        detect_obstacles(image, label_mapped, obstacle_classes=[0, 1, 2, 3])

### Model Training: Data Generator

In [None]:

class MarsDataGenerator(Sequence):
    """
    Keras ``Sequence`` that yields ``(images, masks)`` batches for segmentation.

    ``images_list`` and ``labels_list`` are parallel lists of file paths
    readable by ``cv2.imread``. ``load_data`` returns bare filenames, so join
    them with their directories first.

    Each sample is built as follows:

    * image: ``preprocess_image`` output (512x512, values scaled by 1/255).
    * mask: the label is converted BGR->RGB (as ``main_process`` does), then
      mapped to class IDs with ``map_labels``. It is resized with
      ``preprocess_label`` (nearest-neighbour), so it has the same 512x512
      size as the image.

    Extra keyword arguments are forwarded to the Keras base class.
    """

    def __init__(self, images_list, labels_list, batch_size=8, **kwargs):
        super().__init__(**kwargs)
        if len(images_list) != len(labels_list):
            raise ValueError(
                f"images_list ({len(images_list)}) and labels_list "
                f"({len(labels_list)}) must have the same length"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.images = images_list
        self.labels = labels_list
        self.batch_size = batch_size

    def __len__(self):
        # Ceiling division: the final partial batch is yielded, not dropped.
        return (len(self.images) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_images = self.images[start:stop]
        batch_labels = self.labels[start:stop]

        X, Y = [], []
        for img_path, lbl_path in zip(batch_images, batch_labels):
            raw_image = cv2.imread(img_path)
            raw_label = cv2.imread(lbl_path, cv2.IMREAD_COLOR)
            if raw_image is None or raw_label is None:
                raise FileNotFoundError(f"Could not read {img_path} or {lbl_path}")

            X.append(preprocess_image(raw_image))
            # class_colors are RGB triples; OpenCV decodes as BGR.
            label = map_labels(cv2.cvtColor(raw_label, cv2.COLOR_BGR2RGB))
            # Resize the class-ID mask (nearest-neighbour) to match the image.
            Y.append(preprocess_label(label))

        return np.array(X), np.array(Y)

### Process Dataset

In [None]:
# Entry point: run the processing pipeline on the configured dataset paths.
if __name__ == "__main__":
    main_process(images_path=images_path, labels_path=labels_path, max_files=100)