# MARS IMAGE CLASSIFICATION

## Importing Dependencies

# In [None]:
import cv2
import numpy as np
import os
import pandas as pd
from scipy.fft import fft2, fftshift
from scipy.stats import skew, kurtosis
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


## Image And Labels Loading

# In [None]:
def load_images_and_labels(image_directory):
    """Load all images from a class-per-folder directory and label their patches.

    Every immediate sub-directory of ``image_directory`` is treated as one
    terrain class. Each image inside it is split into patches by
    ``load_and_preprocess_image`` and every patch inherits the numeric label
    of the folder it came from.

    Args:
        image_directory (str): Path to the directory containing one folder per class.

    Returns:
        tuple: ``(patches, labels, class_names)`` where ``patches`` is a list of
        image patches, ``labels`` is a list of integer labels of the same length,
        and ``class_names`` is the sorted list of class folder names such that
        ``class_names[label]`` is the name of that label.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp')  # Common image formats

    # Only sub-directories are classes. Stray files must not receive a label
    # index, otherwise class_names and the labels actually used go out of step.
    # Sorting makes the name -> label mapping reproducible (os.listdir order
    # is arbitrary).
    class_names = sorted(
        entry for entry in os.listdir(image_directory)
        if os.path.isdir(os.path.join(image_directory, entry))
    )
    class_label_map = {name: index for index, name in enumerate(class_names)}

    all_patches = []
    all_labels = []
    for terrain_class in class_names:
        class_dir = os.path.join(image_directory, terrain_class)
        class_label_numeric = class_label_map[terrain_class]
        for image_file in sorted(os.listdir(class_dir)):
            if not image_file.lower().endswith(image_extensions):
                continue
            image_path = os.path.join(class_dir, image_file)
            # The loader's own labels are placeholders; the folder name is the label.
            patches, _ = load_and_preprocess_image(image_path)
            if not patches:
                # Unreadable image (loader signalled failure) or image smaller
                # than a single patch: nothing to add.
                continue
            all_patches.extend(patches)
            all_labels.extend([class_label_numeric] * len(patches))
    return all_patches, all_labels, class_names

## Data Pre-processing

# In [None]:
def load_and_preprocess_image(image_path, patch_size=(64, 64)): # Patch size from paper
    """Load an image, convert it to grayscale, and cut it into patches.

    Patches are non-overlapping tiles taken row by row from the top-left
    corner; any remainder at the right/bottom edge that does not fill a whole
    patch is dropped.

    Args:
        image_path (str): Path to the image file.
        patch_size (tuple, optional): Size of patches (height, width). Defaults to (64, 64).

    Returns:
        tuple: ``(patches, labels)``. ``patches`` is a list of 2D grayscale
        arrays and ``labels`` is a list of the same length filled with the
        placeholder string ``"UnknownLabel"`` (the real label is assigned by
        the caller). Both lists are empty when the image cannot be read or is
        smaller than one patch.
    """
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        # Image could not be loaded. Return empty lists (not None) so the
        # result matches the documented type and callers can use len()/extend().
        return [], []

    img_gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    height, width = img_gray.shape
    patch_h, patch_w = patch_size

    # Non-overlapping tiling: the stride equals the patch size in each direction.
    patches = []
    for top in range(0, height - patch_h + 1, patch_h):
        for left in range(0, width - patch_w + 1, patch_w):
            patches.append(img_gray[top:top + patch_h, left:left + patch_w])

    # Placeholder labels, one per patch; this function has no label source.
    labels = ["UnknownLabel"] * len(patches)
    return patches, labels

## Extracting Features

# In [None]:
def extract_fft_features(patch, feature_vector_length=200, band_strategy="radial"):
    """Extract FFT-based features from an image patch.

    The 2D FFT magnitude spectrum is shifted so the zero frequency sits at the
    centre, then summarised according to ``band_strategy``:

    * ``"radial"``: the radius range ``[0, min(height, width) // 2)`` is split
      into ``feature_vector_length`` equal-width rings and the mean magnitude
      of each ring is one feature (0 for rings containing no pixel).
    * ``"rectangular"``: the spectrum is divided into a fixed grid of 20 row
      bands x 10 column bands and the mean magnitude of each cell is one
      feature (0 for empty cells).
    * ``"none"``: the flattened spectrum is used directly.

    Args:
        patch (np.array): 2D grayscale image patch.
        feature_vector_length (int, optional): Desired length of the FFT feature vector. Defaults to 200.
        band_strategy (str, optional): Strategy for band-based feature extraction
            ('radial', 'rectangular', 'none'). Defaults to "radial".

    Returns:
        np.array: 1D float feature vector with exactly ``feature_vector_length``
        entries; longer results are truncated and shorter ones zero-padded.

    Raises:
        ValueError: If ``band_strategy`` is not one of the supported names.
    """
    magnitude_spectrum = np.abs(fftshift(fft2(patch)))
    rows, cols = magnitude_spectrum.shape
    center_y, center_x = rows // 2, cols // 2

    if band_strategy == "radial":
        num_bands = feature_vector_length
        max_radius = min(rows, cols) // 2
        radii = np.linspace(0, max_radius, num_bands + 1)  # Band boundaries

        # Distance of every spectrum pixel from the centre, flattened.
        y_idx, x_idx = np.indices((rows, cols))
        radius = np.sqrt((x_idx - center_x) ** 2 + (y_idx - center_y) ** 2).ravel()

        # Band i holds radii[i] <= r < radii[i + 1]; pixels at or beyond
        # max_radius map to index num_bands and are discarded.
        band_index = np.searchsorted(radii, radius, side="right") - 1
        inside = band_index < num_bands

        band_sums = np.bincount(
            band_index[inside],
            weights=magnitude_spectrum.ravel()[inside],
            minlength=num_bands,
        )
        band_counts = np.bincount(band_index[inside], minlength=num_bands)
        # One mean per band; empty bands stay at 0 instead of dividing by zero.
        fft_features = np.divide(
            band_sums, band_counts, out=np.zeros(num_bands), where=band_counts > 0
        )

    elif band_strategy == "rectangular":
        num_bands_x = 10  # Bands along x-frequency
        num_bands_y = 20  # Bands along y-frequency, 10 * 20 = 200 features
        band_width_x = cols // num_bands_x
        band_height_y = rows // num_bands_y

        fft_features = []
        for i in range(num_bands_y):
            row_slice = slice(i * band_height_y, (i + 1) * band_height_y)
            for j in range(num_bands_x):
                # Column extent uses the band *width* on both ends.
                col_slice = slice(j * band_width_x, (j + 1) * band_width_x)
                band = magnitude_spectrum[row_slice, col_slice]
                fft_features.append(band.mean() if band.size > 0 else 0.0)

    elif band_strategy == "none":
        # Simple flatten-and-truncate baseline.
        fft_features = magnitude_spectrum.reshape(-1)

    else:
        raise ValueError(
            f"Unknown band_strategy {band_strategy!r}; "
            "expected 'radial', 'rectangular' or 'none'"
        )

    # Force exactly feature_vector_length entries so every patch yields a
    # vector of the same size regardless of strategy or patch size.
    fft_features = np.asarray(fft_features, dtype=float)[:feature_vector_length]
    if fft_features.size < feature_vector_length:
        fft_features = np.pad(fft_features, (0, feature_vector_length - fft_features.size))
    return fft_features

## Extracting Statistical Features

# In [None]:
def extract_statistical_features(patch):
    """Extract statistical features from an image patch in the spatial domain.

    Args:
        patch (np.array): 2D grayscale image patch.

    Returns:
        list: Six floats ``[mean, std, skewness, kurtosis, energy, entropy]``.
        Any statistic that is not finite (e.g. a NaN skewness reported for a
        constant patch) is replaced by 0.0 so the feature matrix stays
        usable by downstream estimators.
    """
    # Work in float64: squaring the raw integer pixels (e.g. uint8) would wrap
    # around and give a wrong energy.
    values = np.asarray(patch, dtype=np.float64).ravel()

    mean_val = np.mean(values)
    std_dev = np.std(values)
    patch_skewness = skew(values)
    patch_kurtosis = kurtosis(values)
    energy = np.sum(values ** 2)

    # Entropy from a 256-bin histogram over [0, 256). Bins are one unit wide,
    # so the density values are the per-bin probabilities.
    hist, _ = np.histogram(values, bins=256, range=[0, 256], density=True)
    entropy_val = -np.sum(hist * np.log2(hist + 1e-9))  # epsilon avoids log(0)

    features = [mean_val, std_dev, patch_skewness, patch_kurtosis, energy, entropy_val]
    return [float(value) if np.isfinite(value) else 0.0 for value in features]

## Extracting Features From The Patches

# In [None]:
def extract_features_from_patches(patches, fft_feature_vector_length=200, fft_band_strategy="radial"):
    """Build the combined FFT + statistical feature matrix for a list of patches.

    Args:
        patches (list): List of image patches (np.array).
        fft_feature_vector_length (int, optional): Length of FFT feature vector. Defaults to 200.
        fft_band_strategy (str, optional): FFT band strategy ('radial', 'rectangular', 'none'). Defaults to "radial".

    Returns:
        np.array: Combined feature matrix (num_patches x total_features), with
        the FFT features in the leading columns and the statistical features
        after them.
    """
    # One (spectral, statistical) pair per patch, computed in patch order.
    per_patch = [
        (
            extract_fft_features(
                current,
                feature_vector_length=fft_feature_vector_length,
                band_strategy=fft_band_strategy,
            ),
            extract_statistical_features(current),
        )
        for current in patches
    ]

    spectral_matrix = np.array([spectral for spectral, _ in per_patch])
    statistical_matrix = np.array([stats for _, stats in per_patch])

    # Column-wise join: each row remains one patch.
    return np.concatenate((spectral_matrix, statistical_matrix), axis=1)

## Model Training

# In [None]:
def train_random_forest(X_train, y_train, random_state=42):
    """Fit a Random Forest classifier on the given training data.

    Args:
        X_train (np.array): Training feature matrix.
        y_train (np.array): Training labels.
        random_state (int, optional): Seed passed to the classifier for reproducibility. Defaults to 42.

    Returns:
        RandomForestClassifier: The fitted model.
    """
    # fit() returns the estimator itself, so construction and training chain.
    return RandomForestClassifier(random_state=random_state).fit(X_train, y_train)

## Model Evaluation

# In [None]:
def evaluate_model(model, X_test, y_test, class_names=None):
    """Score a trained model on the test set, print the metrics and plot the confusion matrix.

    Args:
        model (RandomForestClassifier): Trained model.
        X_test (np.array): Test feature matrix.
        y_test (np.array): Test labels.
        class_names (list, optional): Class names used in the report and as
            heatmap tick labels. When omitted, the sorted unique values of
            ``y_test`` are used as tick labels. Defaults to None.

    Returns:
        tuple: (accuracy, classification_report_dict, confusion_matrix)
    """
    predictions = model.predict(X_test)

    accuracy = accuracy_score(y_test, predictions)
    print(f"Accuracy: {accuracy:.4f}")

    print("\nClassification Report:")
    # Dict form is returned to the caller; the text form is what gets printed.
    report = classification_report(
        y_test, predictions, target_names=class_names, output_dict=True
    )
    print(classification_report(y_test, predictions, target_names=class_names))

    print("\nConfusion Matrix:")
    cm = confusion_matrix(y_test, predictions)
    plt.figure(figsize=(8, 6))
    if class_names:
        tick_labels = class_names
    else:
        tick_labels = sorted(np.unique(y_test))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
    )
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.show()

    return accuracy, report, cm


## Main

# In [None]:
def main():
    """Run the Mars terrain classification pipeline end to end.

    Loads labelled patches, extracts features, makes a stratified 80/20
    train/test split, trains a Random Forest and evaluates it.
    """
    image_directory = "path/to/your/mars_images"  # **Replace with your actual image directory**

    patches, labels, class_names = load_images_and_labels(image_directory)
    if not patches:
        print("No patches loaded. Exiting.")
        return

    print("Extracting features...")
    X = extract_features_from_patches(
        patches,
        fft_feature_vector_length=200,
        fft_band_strategy="radial",  # Choose band strategy
    )
    y = labels

    # Stratify so each class keeps its proportion in both splits.
    split = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
    X_train, X_test, y_train, y_test = split

    print("Training Random Forest...")
    rf_model = train_random_forest(X_train, y_train)

    print("Evaluating Model...")
    evaluate_model(rf_model, X_test, y_test, class_names=class_names)

# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()