# Cloud Image Classification and Segmentation

This notebook demonstrates how to classify images as cloudy or clear with a pre-trained classifier, and then segment the clouds in the images classified as cloudy with a pre-trained segmentation model.

## Imports and Setup

In [None]:
import glob
import os
import warnings

import jax
import jax.numpy as jnp
from jax import random, jit
from flax.training import train_state
import optax
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
from tqdm import tqdm

# Utility functions for loading models and data
from utilities import load_model, open_fits_with_mask
from visualization import plot_image_preds
from cloudsifier.model import load_model as load_classifier_model



# Load the segmentation model
def load_segmentation_model(model_path):
    """Load and return the pre-trained segmentation model stored at ``model_path``.

    Thin wrapper around ``utilities.load_model``. The returned object is later
    handed to ``segment_image``, which reads ``.apply_fn`` and ``.params`` from it.
    """
    segmentation_model = load_model(model_path)
    return segmentation_model

# Load the pre-trained models.
# NOTE(review): the classifier checkpoint is in the Keras ``.keras`` format, while
# ``classify_image`` calls ``state.apply_fn`` / ``state.params`` (the Flax
# ``TrainState`` interface) -- confirm ``cloudsifier.model.load_model`` returns a
# compatible object.
classifier_state = load_classifier_model('models/image_classification_VGG8.keras')
segmentation_state = load_segmentation_model('models/doubleconv_64_128_256_Flax')

# Set paths to the data (replace the placeholder with the real test-image directory).
IMAGE_GLOB = 'path/to/test/images/*.fits'
image_paths = sorted(glob.glob(IMAGE_GLOB))
if not image_paths:
    # glob silently returns [] for a wrong or placeholder path; without this
    # warning every later cell runs on nothing and still reports success.
    warnings.warn(f"No FITS images matched {IMAGE_GLOB!r}; update the data path.")

# Function to classify images
@jit
def classify_image(state, image, threshold=0.5):
    """Return a boolean "cloud detected" prediction for a batch of images.

    Args:
        state: Model state exposing ``apply_fn`` and ``params``; it is invoked as
            ``state.apply_fn({'params': state.params}, image)``.
        image: Batched model input. ``classify_and_segment`` passes the image
            with a leading batch axis and a trailing channel axis added.
        threshold: Probability above which the image counts as cloudy.
            Defaults to 0.5, the previously hard-coded value.

    Returns:
        Boolean array ``sigmoid(logits) > threshold`` with the logits' shape.
    """
    logits = state.apply_fn({'params': state.params}, image)
    # Assumes a single-logit binary classifier (sigmoid, not softmax) -- TODO confirm.
    probs = jax.nn.sigmoid(logits)
    return probs > threshold

# Function to segment images
@jit
def segment_image(state, image, threshold=0.5):
    """Return a binary per-pixel cloud mask for a batch of images.

    Args:
        state: Model state exposing ``apply_fn`` and ``params``; it is invoked as
            ``state.apply_fn({'params': state.params}, image)``.
        image: Batched model input, as passed by ``classify_and_segment``.
        threshold: Probability above which a pixel is marked as cloud.
            Defaults to 0.5.

    Returns:
        Array with the probabilities' shape and dtype, holding 1 where
        ``sigmoid(logits) > threshold`` and 0 elsewhere. With the default
        threshold this equals the former ``jnp.round(probs)``, which rounds an
        exact 0.5 down to 0 (round-half-to-even).
    """
    logits = state.apply_fn({'params': state.params}, image)
    probs = jax.nn.sigmoid(logits)
    # Cast the boolean mask back to the float dtype so callers still get 0.0/1.0.
    return (probs > threshold).astype(probs.dtype)

# Classify one image, and segment it only when the classifier detects clouds
def classify_and_segment(image_path, classifier_state, segmentation_state):
    """Classify a single FITS image and, if it is cloudy, predict its cloud mask.

    Args:
        image_path: File passed to ``open_fits_with_mask``; only its first
            return value (the image) is used, the second is discarded.
        classifier_state: Model state forwarded to ``classify_image``.
        segmentation_state: Model state forwarded to ``segment_image``.

    Returns:
        ``(image, mask, is_cloud)``. ``image`` is the loaded image with all
        singleton axes squeezed out. ``mask`` is the squeezed output of
        ``segment_image``, or ``None`` when the classifier reports no cloud.
        ``is_cloud`` is a plain Python bool.
    """
    raw_image, _ = open_fits_with_mask(image_path)
    with_channel = jnp.expand_dims(raw_image, axis=-1)  # append channel axis
    batched = jnp.expand_dims(with_channel, axis=0)     # prepend batch axis of size 1

    # Truth-testing the classifier output requires it to hold exactly one
    # element -- presumably a single-logit output for a batch of one; TODO confirm.
    if not classify_image(classifier_state, batched):
        return with_channel.squeeze(), None, False

    predicted_mask = segment_image(segmentation_state, batched)
    return with_channel.squeeze(), predicted_mask.squeeze(), True


## Classify and Segment Images


In [None]:
# Classify (and, where cloudy, segment) every image. Each entry of `results`
# is the (image, mask_pred, is_cloud) tuple returned by classify_and_segment.
results = [
    classify_and_segment(path, classifier_state, segmentation_state)
    for path in tqdm(image_paths)
]

# Plot the first ten results. Predicted masks hold only 0/1 values (see
# segment_image), so the colour scale is pinned to [0, 1]; otherwise imshow
# autoscales each mask and an all-cloud mask is drawn in the same colour as an
# all-clear one.
for image, mask_pred, is_cloud in results[:10]:
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(image, cmap='gray')
    plt.title('Original Image')

    plt.subplot(1, 3, 2)
    if is_cloud:
        plt.imshow(mask_pred, cmap='jet', vmin=0, vmax=1)
        plt.title('Predicted Mask')
    else:
        # No mask is computed for clear images; label the panel instead of
        # leaving an unexplained empty axis.
        plt.axis('off')
        plt.title('No clouds detected')

    plt.subplot(1, 3, 3)
    plt.imshow(image, cmap='gray')
    if is_cloud:
        plt.imshow(mask_pred, cmap='jet', alpha=0.5, vmin=0, vmax=1)
    plt.title('Overlay')

    plt.show()


## Save Results

In [None]:
# Save every image, plus its predicted mask when the image was classified cloudy.
# Masks hold only 0/1 values (see segment_image); vmin/vmax are pinned so imsave
# does not autoscale each mask independently, which would save an all-cloud mask
# in the same colour as an all-clear one.
output_dir = 'output/segmented_images'
os.makedirs(output_dir, exist_ok=True)

for i, (image, mask_pred, is_cloud) in enumerate(results):
    image_file = os.path.join(output_dir, f'image_{i}.png')
    plt.imsave(image_file, image, cmap='gray')

    if is_cloud:
        mask_file = os.path.join(output_dir, f'mask_{i}.png')
        plt.imsave(mask_file, mask_pred, cmap='jet', vmin=0, vmax=1)

print("Results saved successfully.")

## Conclusion

This notebook classified each test image as cloudy or clear with a pre-trained classifier and segmented the clouds in the cloudy images with a pre-trained segmentation model. The first ten results were visualized, and all images and predicted masks were saved to `output/segmented_images` for further analysis.
