In [2]:
# Install the required packages
! pip install tensorflow scikit-image tqdm matplotlib CairoSVG svglib reportlab keras --quiet

In [6]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy import ndimage
from skimage import io, transform
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import to_categorical


In [7]:
# Data Ingestion

def load_and_preprocess_image(file_path, target_size=(32, 32)):
    """
    Load and preprocess a single satellite image.

    Args:
        file_path (str): Path to the image file.
        target_size (tuple): Desired (height, width) of the output. Trailing axes
            of the decoded image (e.g. color channels) are kept as they are.

    Returns:
        numpy.ndarray: float32 array of shape (*target_size, *image.shape[2:]).

    Integer-typed images (e.g. 8-bit PNGs) come back scaled to [0, 1]: with its
    default ``preserve_range=False``, ``transform.resize`` converts integer input
    to floating point in [0, 1] before resampling. Float input keeps its range.
    """
    image = io.imread(file_path)
    # resize already returns integer images normalised to [0, 1], so no further
    # division by 255 is needed (dividing again would squeeze values into [0, 1/255]).
    resized = transform.resize(image, target_size, anti_aliasing=True)
    return resized.astype(np.float32)

def load_dataset(data_dir, target_size=(32, 32), num_channels=12):
    """
    Load all image stacks and their labels from the data directory.

    Args:
        data_dir (str): Path to the directory containing the dataset.
        target_size (tuple): Desired (height, width) of every loaded image.
        num_channels (int): Number of ``channel_<i>.png`` files expected per
            timestamp directory (i = 0 .. num_channels - 1).

    Returns:
        tuple: Two numpy arrays, (images, labels), aligned sample-for-sample.
            - images: shape (n_samples, *single_image_shape, num_channels), where
              single_image_shape is what load_and_preprocess_image returns for
              one channel file.
            - labels: shape (n_samples, *label_shape), one preprocessed
              ``label.png`` per sample.

    ``data_dir`` may be a single timestamp directory (it contains only
    ``channel_*`` files and ``label.png``) or a directory of timestamp
    sub-directories. A timestamp is used only when all channels AND its label
    are present, so images[i] always corresponds to labels[i].
    """
    images = []
    labels = []

    print(f"Searching for data in: {data_dir}")

    if not os.path.exists(data_dir):
        print(f"Error: Directory {data_dir} does not exist.")
        return np.array(images), np.array(labels)

    # Sorted so the sample order (and hence the random split downstream) is
    # reproducible across file systems.
    entries = sorted(os.listdir(data_dir))

    # Check if data_dir is a single timestamp directory or contains multiple timestamp directories
    if all(f.startswith('channel_') or f == 'label.png' for f in entries):
        directories = [data_dir]
    else:
        directories = [os.path.join(data_dir, d) for d in entries
                       if os.path.isdir(os.path.join(data_dir, d))]

    for directory in directories:
        timestamp_images = []
        for i in range(num_channels):
            img_path = os.path.join(directory, f'channel_{i}.png')
            if os.path.exists(img_path):
                img = load_and_preprocess_image(img_path, target_size)
                timestamp_images.append(img)
                print(f"Loaded channel {i}: {img_path}")
            else:
                print(f"Missing image: {img_path}")

        if len(timestamp_images) != num_channels:
            print(f"Incomplete channel set in {directory}. Found {len(timestamp_images)} channels instead of {num_channels}.")
            continue

        image_stack = np.stack(timestamp_images, axis=-1)
        print(f"Successfully created image stack with shape: {image_stack.shape}")

        label_path = os.path.join(directory, 'label.png')
        if not os.path.exists(label_path):
            # Dropping the stack keeps images and labels index-aligned.
            print(f"Missing label: {label_path}. Skipping this sample.")
            continue

        images.append(image_stack)
        labels.append(load_and_preprocess_image(label_path, target_size))
        print(f"Loaded label: {label_path}")

    print(f"Total image stacks loaded: {len(images)}")
    print(f"Total labels loaded: {len(labels)}")

    return np.array(images), np.array(labels)

def prepare_data_for_model(images, labels, test_size=0.2, validation_split=0.2):
    """
    Prepare the data for model training, including train/val/test split.

    Args:
        images (numpy.ndarray): Input samples; the first axis indexes samples.
        labels (numpy.ndarray): Labels aligned with ``images`` along the first
            axis. Each sample's label is flattened into one row.
        test_size (float): Proportion of the dataset to include in the test split.
        validation_split (float): Proportion of the remaining (non-test) data to
            include in the validation split.

    Returns:
        tuple: Six numpy arrays, (X_train, X_val, X_test, y_train, y_val, y_test).
        Every y array has shape (n, prod(label_shape)).

    Raises:
        ValueError: if ``labels`` is empty or its sample count differs from ``images``.

    Small-dataset behaviour:
        - fewer than 3 samples: no split is possible, so the full data set is
          returned as train, validation AND test (evaluation is not held-out).
        - 3 to 9 samples: ``test_size`` and ``validation_split`` are overridden
          to give a 60/20/20 split.
    """
    print(f"Images shape: {images.shape}")
    print(f"Labels shape before reshape: {labels.shape}")
    print(f"Labels unique values: {np.unique(labels)}")

    if labels.size == 0:
        raise ValueError("Labels array is empty. Please check the data loading process.")

    # Without this check the n < 3 fallback below would silently return
    # misaligned arrays (train_test_split only catches it for n >= 3).
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same number of samples "
            f"(got {len(images)} and {len(labels)}).")

    labels = labels.reshape(labels.shape[0], -1)
    print(f"Labels shape after reshape: {labels.shape}")

    n_samples = len(images)
    if n_samples < 3:
        print(f"Warning: Only {n_samples} samples found. Splitting may not be possible.")
        return images, images, images, labels, labels, labels

    if n_samples < 10:
        print(f"Warning: Only {n_samples} samples found. Using a 60/20/20 split.")
        test_size = 0.2
        validation_split = 0.25  # 25% of 80% is 20% of the total

    X_train_val, X_test, y_train_val, y_test = train_test_split(
        images, labels, test_size=test_size, random_state=42)

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=validation_split, random_state=42)

    print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}")
    print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    return X_train, X_val, X_test, y_train, y_val, y_test

def _per_sample_weights(batch_labels, class_weights):
    """Average the class weight of every label value in each sample into one weight per sample."""
    flat = np.asarray(batch_labels).reshape(len(batch_labels), -1)
    # Look each distinct value up once instead of once per element.
    values, inverse = np.unique(flat.ravel(), return_inverse=True)
    # Values absent from class_weights (e.g. validation-only values) get a neutral 1.0.
    lookup = np.array([class_weights.get(v, 1.0) for v in values], dtype=np.float32)
    return lookup[inverse].reshape(flat.shape).mean(axis=1)


def data_generator(images, labels, class_weights, batch_size=32):
    """
    Generate shuffled batches of data with per-sample weights, forever.

    Args:
        images (numpy.ndarray): Input samples; the first axis indexes samples.
        labels (numpy.ndarray): Labels aligned with ``images`` along the first axis.
        class_weights (dict): Maps label values to weights.
        batch_size (int): Number of samples per batch (the last batch of each
            pass may be smaller).

    Yields:
        tuple: (batch_images, batch_labels, batch_weights)
            - batch_images: ``images`` rows of this batch
            - batch_labels: matching ``labels`` rows
            - batch_weights: float32 array of shape (len(batch),); each entry is
              the mean of ``class_weights`` over all label values of that sample
              (values missing from ``class_weights`` count as 1.0). This is the
              per-sample ``sample_weight`` shape Keras ``fit`` accepts.

    Raises:
        ValueError: on the first ``next()`` if ``images`` is empty (otherwise the
            generator would loop forever without yielding).

    Each pass over the data uses a fresh ``np.random.permutation``, so it is
    reproducible only if NumPy's global RNG is seeded.
    """
    num_samples = len(images)
    if num_samples == 0:
        raise ValueError("data_generator needs at least one sample.")

    while True:
        indices = np.random.permutation(num_samples)
        for start in range(0, num_samples, batch_size):
            batch_indices = indices[start:start + batch_size]
            batch_images = images[batch_indices]
            batch_labels = labels[batch_indices]
            yield batch_images, batch_labels, _per_sample_weights(batch_labels, class_weights)

# Model Architecture

def create_lightning_cnn(input_shape=(32, 32, 12), num_classes=1024):
    """
    Build a small two-stage CNN for lightning detection in satellite imagery.

    Args:
        input_shape (tuple): Shape of one input sample, (height, width, channels).
        num_classes (int): Number of sigmoid output units.

    Returns:
        tensorflow.keras.Sequential: The model, not yet compiled (compile it
        separately, e.g. with compile_model).
    """
    stack = []

    # Two convolution stages (16 then 32 filters): Conv -> BatchNorm -> 2x2 max-pool.
    for stage, n_filters in enumerate((16, 32)):
        # Only the very first layer declares the input shape.
        shape_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        stack.extend([
            layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same', **shape_kwargs),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
        ])

    # Dense head ending in one independent sigmoid per output.
    stack.extend([
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='sigmoid'),
    ])

    return models.Sequential(stack)

def compile_model(model, learning_rate=1e-3, decay=0.9, decay_steps=1000):
    """
    Compile the Keras model with an inverse-time-decay Adam optimizer.

    Args:
        model (tensorflow.keras.Model): The model to compile (compiled in place).
        learning_rate (float): Initial learning rate for the optimizer.
        decay (float): ``decay_rate`` of the InverseTimeDecay schedule.
        decay_steps (int): ``decay_steps`` of the schedule, i.e. the number of
            optimizer steps that make up one unit of decay.

    Returns:
        tensorflow.keras.Model: The same model, compiled.

    The learning rate follows Keras' InverseTimeDecay (non-staircase):
        lr(step) = learning_rate / (1 + decay * step / decay_steps)
    Loss is binary crossentropy; metrics are accuracy, precision and recall.

    NOTE(review): the optimizer is built from a LearningRateSchedule, and Keras
    optimizers generally do not allow such a learning rate to be reassigned,
    which ReduceLROnPlateau (used in train_model) attempts when it fires —
    confirm this combination under the installed Keras version.
    """
    lr_schedule = optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=learning_rate,
        decay_steps=decay_steps,
        decay_rate=decay
    )
    optimizer = optimizers.Adam(learning_rate=lr_schedule)

    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])

    return model

# Model Training

def train_model(model, X_train, y_train, X_val, y_val, batch_size=128, epochs=10):
    """
    Train the model on the provided data.

    Args:
        model (tensorflow.keras.Model): The compiled model to train.
        X_train (numpy.ndarray): Training data.
        y_train (numpy.ndarray): Training labels.
        X_val (numpy.ndarray): Validation data (may be empty: training then
            runs without validation).
        y_val (numpy.ndarray): Validation labels.
        batch_size (int): Number of samples per gradient update.
        epochs (int): Number of epochs to train the model.

    Returns:
        tensorflow.keras.callbacks.History: Training history object.

    Raises:
        ValueError: if ``X_train`` is empty.

    Batches come from ``data_generator`` with "balanced" class weights computed
    from every label value in ``y_train``: n_total / (n_distinct * count_value).
    EarlyStopping (patience 50) and ReduceLROnPlateau (factor 0.2, patience 20)
    are attached; with the default 10 epochs neither can trigger.
    NOTE(review): ReduceLROnPlateau may be unable to change a learning rate
    that comes from a LearningRateSchedule (as set up by compile_model) — confirm.
    """
    if len(X_train) == 0:
        raise ValueError("X_train is empty; cannot train the model.")

    # Calculate class weights over all label values
    unique, counts = np.unique(y_train, return_counts=True)
    class_weights = dict(zip(unique, y_train.size / (len(unique) * counts)))

    # Validation labels may contain values never seen in training; give those a
    # neutral weight so the generator's per-value lookup cannot fail.
    val_weights = {value: 1.0 for value in np.unique(y_val)}
    val_weights.update(class_weights)

    early_stopping = EarlyStopping(patience=50, restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(factor=0.2, patience=20)

    # The generator yields ceil(n / batch_size) batches per pass (last one partial);
    # floor division would give 0 steps whenever n < batch_size.
    steps_per_epoch = int(np.ceil(len(X_train) / batch_size))
    train_gen = data_generator(X_train, y_train, class_weights, batch_size)

    validation_kwargs = {}
    if len(X_val) > 0:
        validation_kwargs['validation_data'] = data_generator(X_val, y_val, val_weights, batch_size)
        validation_kwargs['validation_steps'] = int(np.ceil(len(X_val) / batch_size))

    history = model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[early_stopping, reduce_lr],
        **validation_kwargs
    )

    return history

def create_custom_cnn(input_shape=(32, 12, 4, 12), num_classes=1024):
    """
    Create a custom CNN model for lightning detection.

    Args:
        input_shape (tuple): Shape of one input sample, excluding the batch axis.
            The first two entries are the spatial (height, width) axes; all
            remaining trailing axes are merged into a single channel axis, e.g.
            (32, 12, 4, 12) -> (32, 12, 48). A 2-D shape gets one channel.
        num_classes (int): Number of sigmoid output units.

    Returns:
        tensorflow.keras.Model: Uncompiled functional model (compile it
        separately, e.g. with compile_model).

    Architecture: three Conv/BatchNorm stages (32, 64, 128 filters; the first
    two followed by 2x2 max-pooling), global average pooling, then a 512-unit
    dense layer with BatchNorm and dropout before the sigmoid output.
    """
    inputs = layers.Input(shape=input_shape)

    # Conv2D needs (height, width, channels): fold every axis after the spatial
    # ones into channels. Derived from input_shape rather than hard-coded, so the
    # model works for any input size (np.prod of an empty tuple is 1).
    height, width = input_shape[0], input_shape[1]
    channels = int(np.prod(input_shape[2:], dtype=np.int64))
    x = layers.Reshape((height, width, channels))(inputs)

    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalAveragePooling2D()(x)

    x = layers.Dense(512, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)

    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = models.Model(inputs=inputs, outputs=outputs)
    return model

# Performance Visualization

def plot_training_history(history):
    """
    Plot the training history of the model.

    Args:
        history (tensorflow.keras.callbacks.History): History object returned by model.fit().

    Creates two stacked subplots, accuracy then loss, each showing the training
    curve and, when the history contains it, the matching validation curve.
    """
    metrics = history.history
    fig, axes = plt.subplots(2, 1, figsize=(10, 12))

    for ax, key, name in zip(axes, ('accuracy', 'loss'), ('Accuracy', 'Loss')):
        ax.plot(metrics[key], label=f'Training {name}')
        # val_* entries only exist when fit() was given validation data.
        val_key = f'val_{key}'
        if val_key in metrics:
            ax.plot(metrics[val_key], label=f'Validation {name}')
        ax.set_title(f'Model {name}')
        ax.set_ylabel(name)
        ax.set_xlabel('Epoch')
        ax.legend()

    plt.tight_layout()
    plt.show()

def plot_precision_recall_curve(y_true, y_pred):
    """
    Draw the precision-recall curve of the predictions.

    Both arrays are flattened, so every output element is scored as an
    independent binary decision.

    Args:
        y_true (numpy.ndarray): Ground-truth labels.
        y_pred (numpy.ndarray): Predicted probabilities or scores.
    """
    flat_true, flat_scores = y_true.ravel(), y_pred.ravel()
    precision, recall, _thresholds = precision_recall_curve(flat_true, flat_scores)

    _fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(recall, precision, label='Precision-Recall Curve')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    plt.show()

def plot_roc_curve(y_true, y_pred):
    """
    Draw the ROC curve of the predictions and report its area under the curve.

    Both arrays are flattened, so every output element is scored as an
    independent binary decision. The AUC is shown in the legend.

    Args:
        y_true (numpy.ndarray): Ground-truth labels.
        y_pred (numpy.ndarray): Predicted probabilities or scores.
    """
    fpr, tpr, _thresholds = roc_curve(y_true.ravel(), y_pred.ravel())
    area = auc(fpr, tpr)

    _fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {area:.2f})')
    ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend()
    plt.show()

def visualize_predictions(y_true, y_pred, num_samples=5, image_shape=None):
    """
    Visualize the true labels and model predictions side by side.

    Args:
        y_true (numpy.ndarray): True labels, one (possibly flattened) sample per row.
        y_pred (numpy.ndarray): Predicted labels or probabilities, aligned with y_true.
        num_samples (int): Number of distinct random samples to show (capped at
            the number of available samples).
        image_shape (tuple, optional): Shape each sample is reshaped to for
            display. Defaults to (side, side) when the per-sample size is a
            perfect square (e.g. 1024 -> (32, 32)); must be given otherwise.

    Raises:
        ValueError: if ``image_shape`` is None and the per-sample size is not a
            perfect square.
    """
    n_available = y_true.shape[0]
    if n_available == 0 or num_samples <= 0:
        print("No samples to visualize.")
        return

    if image_shape is None:
        size = int(np.prod(y_true.shape[1:]))
        side = int(round(np.sqrt(size)))
        if side * side != size:
            raise ValueError(
                f"Cannot infer a square image from {size} values per sample; "
                f"pass image_shape explicitly.")
        image_shape = (side, side)

    # Distinct samples, never more than exist.
    chosen = np.random.choice(n_available, size=min(num_samples, n_available), replace=False)

    # squeeze=False keeps `axes` 2-D so axes[row, col] also works for one row.
    fig, axes = plt.subplots(len(chosen), 2, figsize=(12, 4 * len(chosen)), squeeze=False)

    for row, idx in enumerate(chosen):
        for col, (data, kind) in enumerate(((y_true, 'True'), (y_pred, 'Predicted'))):
            ax = axes[row, col]
            ax.imshow(data[idx].reshape(image_shape), cmap='binary')
            ax.set_title(f'{kind} - Sample {row + 1}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Main Execution

if __name__ == "__main__":
    # Main execution block for the lightning detection model.
    #
    # Runs the full pipeline:
    #   1. Data loading and preprocessing
    #   2. Model creation and compilation
    #   3. Model training
    #   4. Model evaluation
    #   5. Visualization of results
    #
    # Inside a Jupyter kernel __name__ is also "__main__", so this block runs
    # whenever the cell is executed.

    # Seed Python, NumPy and TensorFlow RNGs (batch shuffling, weight init,
    # dropout) so a Restart & Run All reproduces the same results.
    tf.keras.utils.set_random_seed(42)

    # Data Ingestion
    data_dir = 'path/Data/Clean/timestamp_1'
    target_size = (32, 12)
    batch_size = 32

    print(f"Contents of {data_dir}:")
    for root, dirs, files in os.walk(data_dir):
        level = root.replace(data_dir, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{subindent}{f}")

    print("\nLoading dataset...")
    images, labels = load_dataset(data_dir, target_size)

    print(f"\nNumber of samples: {len(images)}")
    print(f"Shape of first image stack: {images[0].shape if len(images) > 0 else 'No images'}")
    print(f"Shape of first label: {labels[0].shape if len(labels) > 0 else 'No labels'}")

    if len(images) == 0 or len(labels) == 0:
        print("Error: No images or labels were loaded. Please check the data directory and file structure.")
        # exit() shuts down a Jupyter kernel; SystemExit only aborts this cell
        # (and still exits with status 1 when run as a script).
        raise SystemExit(1)

    print("\nPreparing data for model...")
    X_train, X_val, X_test, y_train, y_val, y_test = prepare_data_for_model(images, labels)

    print("X_train shape:", X_train.shape)
    print("y_train shape:", y_train.shape)
    print("X_val shape:", X_val.shape)
    print("y_val shape:", y_val.shape)

    print(f"\nTraining samples: {len(X_train)}")
    print(f"Validation samples: {len(X_val)}")
    print(f"Test samples: {len(X_test)}")

    # Model Creation and Training
    input_shape = X_train.shape[1:]  # Dynamically set the input shape
    num_classes = y_train.shape[1]

    model = create_custom_cnn(input_shape, num_classes)
    model = compile_model(model)

    print("Training model...")
    # Pass the configured batch size; otherwise train_model's default is used.
    history = train_model(model, X_train, y_train, X_val, y_val, batch_size=batch_size)

    # Model Evaluation
    print("Evaluating model...")
    test_loss, test_accuracy, test_precision, test_recall = model.evaluate(X_test, y_test)
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall: {test_recall:.4f}")

    # Generate predictions on test set
    y_pred = model.predict(X_test)

    # Visualize Results
    plot_training_history(history)
    plot_precision_recall_curve(y_test, y_pred)
    plot_roc_curve(y_test, y_pred)
    visualize_predictions(y_test, y_pred)

    # Print best outcomes (validation metrics exist only if validation ran)
    if 'val_accuracy' in history.history:
        print(f"Best validation accuracy: {max(history.history['val_accuracy']):.4f}")
    if 'val_loss' in history.history:
        print(f"Best validation loss: {min(history.history['val_loss']):.4f}")

Contents of path/Data/Clean/timestamp_1:
timestamp_1/
    channel_0.png
    channel_1.png
    channel_10.png
    channel_11.png
    channel_2.png
    channel_3.png
    channel_4.png
    channel_5.png
    channel_6.png
    channel_7.png
    channel_8.png
    channel_9.png
    label.png

Loading dataset...
Searching for data in: path/Data/Clean/timestamp_1
Loaded channel 0: path/Data/Clean/timestamp_1\channel_0.png
Loaded channel 1: path/Data/Clean/timestamp_1\channel_1.png
Loaded channel 2: path/Data/Clean/timestamp_1\channel_2.png
Loaded channel 3: path/Data/Clean/timestamp_1\channel_3.png
Loaded channel 4: path/Data/Clean/timestamp_1\channel_4.png
Loaded channel 5: path/Data/Clean/timestamp_1\channel_5.png
Loaded channel 6: path/Data/Clean/timestamp_1\channel_6.png
Loaded channel 7: path/Data/Clean/timestamp_1\channel_7.png
Loaded channel 8: path/Data/Clean/timestamp_1\channel_8.png
Loaded channel 9: path/Data/Clean/timestamp_1\channel_9.png
Loaded channel 10: path/Data/Clean/timestam

In [None]:
def post_process_predictions(predictions, threshold=0.5, min_area=10):
    """
    Post-process the model's predictions to identify significant lightning areas.

    Args:
        predictions (numpy.ndarray): Raw prediction scores for one sample,
            already arranged on the grid in which regions should be found
            (e.g. 2-D). Any number of dimensions is accepted.
        threshold (float): Scores strictly greater than this count as lightning.
        min_area (int): Connected components with fewer than this many
            elements are discarded.

    Returns:
        numpy.ndarray: uint8 mask with the same shape as ``predictions``; 1 marks
        elements of components whose size is at least ``min_area``.
    """
    # Apply threshold to get binary prediction
    binary_pred = (np.asarray(predictions) > threshold).astype(np.uint8)

    # Connected components with scipy's default structuring element
    # (only directly adjacent neighbours, i.e. 4-connectivity in 2-D).
    labeled, num_features = ndimage.label(binary_pred)
    if num_features == 0:
        return binary_pred

    # Size of every component in a single pass (index 0 is the background),
    # instead of one full-array scan per component.
    areas = np.bincount(labeled.ravel())
    too_small = areas < min_area
    too_small[0] = False
    binary_pred[too_small[labeled]] = 0

    return binary_pred

def get_lightning_regions(binary_pred):
    """
    Get the bounding boxes of lightning regions.

    Args:
        binary_pred (numpy.ndarray): Binary mask of lightning areas, with any
            number of dimensions.

    Returns:
        list: One tuple per connected region, in label (raster) order: the
        minimum index along each axis followed by the inclusive maximum index
        along each axis. For a 2-D mask this is
        (min_row, min_col, max_row, max_col).
    """
    labeled, _ = ndimage.label(binary_pred)
    regions = []

    # find_objects gives one tuple of slices per label; slice stops are exclusive.
    for box in ndimage.find_objects(labeled):
        if box is None:
            continue
        mins = tuple(s.start for s in box)
        maxs = tuple(s.stop - 1 for s in box)
        regions.append(mins + maxs)

    return regions

# Example usage: post-process each test-set prediction and report its lightning regions.
# Relies on `model`, `X_test` and `labels` defined by the main execution cell above.
y_pred = model.predict(X_test)

# Each prediction is a flat vector (prepare_data_for_model flattened the labels),
# so restore the per-sample label shape before searching for connected regions;
# on the flat vector, components would only be runs along a single axis.
# NOTE(review): if labels carry a trailing channel axis, regions are found over
# that axis too — confirm this is the intended grid.
label_shape = labels.shape[1:]

for i, pred in enumerate(y_pred):
    # Post-process the prediction on its label grid
    processed_pred = post_process_predictions(pred.reshape(label_shape), threshold=0.5, min_area=10)

    # Get lightning regions
    lightning_regions = get_lightning_regions(processed_pred)

    print(f"Sample {i + 1}:")
    if lightning_regions:
        print(f"  Found {len(lightning_regions)} significant lightning areas:")
        for j, region in enumerate(lightning_regions):
            print(f"    Region {j + 1}: {region}")
    else:
        print("  No significant lightning areas detected.")
    print()