# Deep Learning Model Architecture Exploration and Performance Evaluation

#### See the `data_EDA_and_CML_benchmarking.ipynb` notebook for parts 1 and 2, which include the deep learning dataset preparation and CML benchmarking, respectively

## 3.1 Model Architecture Exploration: Justification

##### Overall, the four initial deep learning models implemented in the `data_EDA_and_CML_benchmarking.ipynb` notebook (FCN, CNN, ResNet, and RNN) performed poorly. Among them, the CNN had the highest accuracy, exceeding 25%. While this value is still low, it suggests that convolutional feature extraction is the most promising starting point, so we will focus on the three CNN-based architectures listed below:  

1. VGG16 with Fine-Tuning (a deep CNN)
* *Why?* VGG16 is a deep CNN with 16 weight layers (13 convolutional and 3 fully connected) that uses stacks of small 3x3 convolutional filters to capture increasingly complex visual features. Because the `PHIPS_CrystalHabitAI_Dataset.nc` image dataset is relatively small, starting from weights pre-trained on ImageNet and fine-tuning them on our images lets VGG16 adapt to this classification task without learning every feature from scratch. Its depth, combined with fine-tuning, may help overcome the low accuracy of the initial models by learning more intricate patterns specific to our ice crystal images.

2. InceptionV3 (a different variation of a deep CNN)
* *Why?* This architecture excels at multi-scale feature learning: its Inception modules apply convolutional filters of several sizes in parallel, capturing visual information at different scales within the same layer. Despite its depth, InceptionV3 is computationally efficient thanks to techniques such as factorized convolutions and dimension reduction, making it suitable for complex datasets without excessive computational cost. Its multi-scale design can extract richer and more diverse features than simpler models, which may substantially improve classification accuracy on the `PHIPS_CrystalHabitAI_Dataset.nc` image dataset.

3. Convolutional Recurrent Neural Network (CRNN) with Attention Mechanism (a hybrid of CNN and RNN)
* *Why?* A CRNN combines convolutional layers for spatial feature extraction with recurrent layers (such as LSTM or GRU) that model dependencies across a sequence. Because our inputs are single images rather than time series, the sequence is formed from the CNN feature map, with each spatial position treated as one step, so the recurrent layer can relate features from different parts of a crystal image. Adding an attention mechanism lets the model weight the most relevant positions more heavily, which may help it learn important features and improve classification. Lastly, this hybrid goes beyond the single-type architectures tested so far and may capture patterns and relationships in our ice crystal images that those models missed.

#### By using these DL architectures, we will address the low performance of the initial DL models by leveraging deeper networks, advanced feature extraction techniques, and innovative combinations of neural network types tailored to our image classification task. 

## 3.2 Imports and Environment Setup

In [8]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import time
import os

In [40]:
# TensorFlow and Keras
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D, 
                                     GlobalAveragePooling2D, Input, SimpleRNN, LSTM, TimeDistributed, 
                                     Bidirectional, Attention)
from tensorflow.keras.applications import VGG16, InceptionV3
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import Loss
from tensorflow.keras.preprocessing.image import smart_resize
from tensorflow.keras.layers import Concatenate, Resizing, Reshape, Permute, Multiply, Activation

In [10]:
# Sklearn for metrics
from sklearn.metrics import (classification_report, confusion_matrix, accuracy_score, 
                             f1_score, precision_score, recall_score, mean_squared_error, roc_curve, auc)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [11]:
# Set random seed for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## 3.3 Data Loading and Preprocessing
##### Loading, preprocessing, and splitting are organized in a `DatasetLoader` class

In [18]:
# Define a class for loading and preprocessing the dataset
class DatasetLoader:
    def __init__(self, file_path):
        self.file_path = file_path

    def load_data(self):
        # Load the dataset using xarray
        ds = xr.open_dataset(self.file_path)
        images = ds['image_array'].values  # Shape: (samples, height, width)
        labels = ds['label'].values        # Shape: (samples,)
        temps = ds['temperature'].values   # Shape: (samples,)
        return images, labels, temps

    def preprocess_data(self, images, labels):
        # Encode string labels into integers
        label_encoder = LabelEncoder()
        labels_encoded = label_encoder.fit_transform(labels)
        num_classes = len(np.unique(labels_encoded))

        # One-hot encode the labels
        labels_one_hot = to_categorical(labels_encoded, num_classes)

        # Expand dimensions of images for channels (grayscale images)
        images_expanded = np.expand_dims(images, axis=-1)  # Shape: (samples, height, width, 1)

        # Normalize images to [0, 1]
        images_normalized = images_expanded / 255.0

        return images_normalized, labels_one_hot, labels_encoded, num_classes, label_encoder

    def split_data(self, images, labels_encoded, labels_one_hot):
        # Split data into training, validation, and test sets
        X_train, X_temp, y_train_encoded, y_temp_encoded, y_train_one_hot, y_temp_one_hot = train_test_split(
            images, labels_encoded, labels_one_hot, test_size=0.2, random_state=42, stratify=labels_encoded)
        X_val, X_test, y_val_encoded, y_test_encoded, y_val_one_hot, y_test_one_hot = train_test_split(
            X_temp, y_temp_encoded, y_temp_one_hot, test_size=0.5, random_state=42, stratify=y_temp_encoded)

        return (X_train, y_train_encoded, y_train_one_hot), \
               (X_val, y_val_encoded, y_val_one_hot), \
               (X_test, y_test_encoded, y_test_one_hot)

# Instantiate the DatasetLoader and load the data
data_loader = DatasetLoader('/Users/valeriagarcia/Desktop/ESS569_Snowflake_Classification/PHIPS_CrystalHabitAI_Dataset.nc')
images, labels, temps = data_loader.load_data()
images, labels_one_hot, labels_encoded, num_classes, label_encoder = data_loader.preprocess_data(images, labels)
(X_train, y_train_encoded, y_train_one_hot), \
(X_val, y_val_encoded, y_val_one_hot), \
(X_test, y_test_encoded, y_test_one_hot) = data_loader.split_data(images, labels_encoded, labels_one_hot)
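`split_data` splits images and labels but not `temps`, while the physics-informed loss in Section 3.5 needs each sample's temperature. One way to keep them aligned, sketched below as an assumption rather than the notebook's method, is to split an index array and use it to slice all three arrays. `split_with_temperatures` is a hypothetical helper. Because the partition produced by `train_test_split` depends only on the number of samples, `stratify`, and `random_state`, it should reproduce the same 80/10/10 split as `split_data`.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def split_with_temperatures(images, labels_encoded, temps, seed=42):
    """Hypothetical helper: stratified 80/10/10 split that also carries temperatures."""
    idx = np.arange(len(labels_encoded))
    # Same test_size/random_state/stratify choices as DatasetLoader.split_data
    train_idx, rest_idx = train_test_split(
        idx, test_size=0.2, random_state=seed, stratify=labels_encoded)
    val_idx, test_idx = train_test_split(
        rest_idx, test_size=0.5, random_state=seed, stratify=labels_encoded[rest_idx])
    # Slice every array with the same indices so samples stay aligned
    return {name: (images[i], labels_encoded[i], temps[i])
            for name, i in (("train", train_idx), ("val", val_idx), ("test", test_idx))}
```

With the notebook's arrays, `split_with_temperatures(images, labels_encoded, temps)["train"]` would return `(X_train, y_train_encoded, temps_train)`; one-hot labels can be rebuilt with `to_categorical`.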

## 3.4 Data Augmentation
##### Here, we create a data augmentation generator (`train_datagen`) for the training data that applies random transformations (rotations up to 20 degrees, horizontal and vertical shifts up to 10% of the image size, horizontal and vertical flips, and zooms up to 10%) to increase the diversity of the dataset during training.

In [19]:
# Define data augmentation for training data
train_datagen = ImageDataGenerator(
    rescale=1.0,  # Images are already normalized
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.1
)

# No augmentation for validation and test data (rescale=1.0 leaves the already-normalized values unchanged)
val_datagen = ImageDataGenerator(rescale=1.0)
test_datagen = ImageDataGenerator(rescale=1.0)

# Create data generators
train_generator = train_datagen.flow(X_train, y_train_one_hot, batch_size=32)
val_generator = val_datagen.flow(X_val, y_val_one_hot, batch_size=32)
test_generator = test_datagen.flow(X_test, y_test_one_hot, batch_size=32, shuffle=False)
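The generators above yield only images and labels, so the per-sample temperatures needed by the physics-informed loss are not available after augmentation. A minimal NumPy sketch of a batch generator that keeps temperatures aligned with their images is shown below. `augmented_batches_with_temps` is a hypothetical helper, and it implements only random flips, not the rotations, shifts, and zooms of `train_datagen`.

```python
import numpy as np

def augmented_batches_with_temps(x, y, temps, batch_size=32, seed=42):
    """Hypothetical helper: yield (x_batch, y_batch, t_batch) with random flips.

    x has shape (samples, height, width, channels); temperatures are indexed
    with the same permutation as the images, so they stay aligned.
    """
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        order = rng.permutation(n)  # reshuffle every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            xb = x[idx].copy()
            # Flip each image horizontally (width axis) or vertically (height axis) with probability 0.5
            hflip = rng.random(len(idx)) < 0.5
            vflip = rng.random(len(idx)) < 0.5
            xb[hflip] = xb[hflip][:, :, ::-1]
            xb[vflip] = xb[vflip][:, ::-1, :]
            yield xb, y[idx], temps[idx]
```

As an alternative, the tf.keras 2.x documentation for `ImageDataGenerator.flow` describes passing `x` as a tuple whose second element is passed through with each batch unmodified; whether that behavior holds should be checked against the installed version.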


## 3.5 Physics-Informed Loss Function with Probabilistic Class Likelihoods

##### In the cloud microphysics community, it is well understood from laboratory studies that different ice crystal habits tend to grow within specific ranges of temperature and relative humidity. An example of the different temperature regimes is provided in Varcie et al. 2024:
* *polycrystalline growth layer* (growth of polycrystals) - may occur when the ambient temperature is below -18˚C
* *dendritic growth layer* (growth of dendrites) - may occur when temperature is warmer than or equal to -18˚C and less than or equal to -12˚C
* *plate growth layer* (growth of plates) - may occur where temperature is warmer than -12˚C and less than -8˚C
* *needle growth layer* (growth of needles) - may occur where temperatures are warmer than or equal to -8˚C and less than -3˚C
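The regime boundaries listed above can be encoded as a simple lookup, which is handy for checking which layer a given observation temperature falls into. The sketch below is illustrative only: `growth_layer` is a hypothetical name, and it encodes just the thresholds, not whether the other growth conditions were met.

```python
import math

def growth_layer(temp_c):
    """Hypothetical helper: map a temperature (deg C) to the growth layer listed above.

    Returns None for missing temperatures or for temperatures outside the listed layers.
    """
    if temp_c is None or math.isnan(temp_c):
        return None
    if temp_c < -18:
        return "polycrystalline"  # below -18 C
    if temp_c <= -12:
        return "dendritic"        # -18 C <= T <= -12 C
    if temp_c < -8:
        return "plate"            # -12 C < T < -8 C
    if temp_c < -3:
        return "needle"           # -8 C <= T < -3 C
    return None                   # -3 C or warmer: not covered by the listed layers
```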

##### As our dataset only contains temperature in the metadata, we will focus on leveraging temperature and the temperature regimes above to create a custom loss function. Note that the above temperature ranges refer to layers in which certain ice crystals *may* grow, assuming other conditions, such as high ice/water supersaturation, are met. Moreover, particles that grow at colder temperatures may still be observed at warmer temperatures because they sediment into warmer layers.

##### Objective: Incorporate temperature-dependent class probabilities into the loss function to guide the model based on physical principles while allowing for natural variability.
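Written out, the objective implemented by the code in this section is roughly the following (our notation, reconstructed from the code rather than stated by the source). Here $\mu_k$ and $\sigma_k$ are the mean and standard deviation of the observed temperatures for class $k$, $T_i$ is the temperature of sample $i$, $y_{ik}$ its one-hot label, $\hat{y}_{ik}$ the predicted probability, $B$ the batch size, $K$ the number of classes, $m_i = 1$ if sample $i$'s true class has temperature statistics and $0$ otherwise, $\lambda = 0.1$, and $\epsilon = 10^{-7}$. In the code, classes without statistics use $1/K$ in place of the Gaussian density, and samples with a missing temperature receive a uniform target.

$$
p^{\mathrm{phys}}_{ik} = \frac{\mathcal{N}(T_i;\,\mu_k,\sigma_k)}{\sum_{j=1}^{K}\mathcal{N}(T_i;\,\mu_j,\sigma_j)},
\qquad
\mathcal{N}(T;\,\mu,\sigma) = \frac{1}{\sigma\sqrt{2\pi}}\exp\!\left(-\frac{(T-\mu)^2}{2\sigma^2}\right)
$$

$$
\mathcal{L} = -\frac{1}{B}\sum_{i=1}^{B}\sum_{k=1}^{K} y_{ik}\log\hat{y}_{ik}
+ \lambda\,\frac{\sum_{i=1}^{B} m_i\,\mathrm{KL}\!\left(p^{\mathrm{phys}}_{i}\,\middle\|\,\hat{y}_{i}\right)}{\sum_{i=1}^{B} m_i + \epsilon},
\qquad
\mathrm{KL}(p\,\|\,q) = \sum_{k=1}^{K} p_k \log\frac{p_k}{q_k}
$$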

In [32]:
#### Create a mapping from class labels to their corresponding temperatures ####

# Get unique class labels
unique_classes = np.unique(labels_encoded)

# Initialize a dictionary to hold temperatures for each class
class_temperatures = {class_idx: [] for class_idx in unique_classes}

# Populate the dictionary
for idx, class_idx in enumerate(labels_encoded):
    temp = temps[idx]
    class_temperatures[class_idx].append(temp)

# Print the mapping of integer labels to original labels
for class_idx, class_label in enumerate(label_encoder.classes_):
    print(f"Class {class_idx}: {class_label}")

print()

#### Calculate the mean and standard deviation for the temperatures in each class ####

# Initialize the dictionary to hold temperature statistics for each class
class_temperature_stats = {}

# Compute mean and standard deviation for each class, ignoring NaNs
for class_idx in unique_classes:
    temps_ = np.array(class_temperatures[class_idx], dtype=float)
    
    # Handle case where all temps are NaN (checked first to avoid a "Mean of empty slice" RuntimeWarning)
    if temps_.size == 0 or np.all(np.isnan(temps_)):
        print(f"Class {class_idx}: All temperature values are NaN. Cannot compute mean and std.")
        class_temperature_stats[class_idx] = {'mean': np.nan, 'std': np.nan}
    else:
        # Compute mean and (population) standard deviation while ignoring NaNs
        mean_temp = np.nanmean(temps_)
        std_temp = np.nanstd(temps_)
        class_temperature_stats[class_idx] = {'mean': mean_temp, 'std': std_temp}
        print(f"Class {class_idx}: Mean Temp = {mean_temp:.2f}, Std Temp = {std_temp:.2f}")

Class 0: aggregate
Class 1: bullet_rosette
Class 2: capped_column
Class 3: column
Class 4: dendrite
Class 5: graupel
Class 6: needle
Class 7: plate
Class 8: polycrystal
Class 9: side_plane
Class 10: tiny

Class 0: Mean Temp = -11.28, Std Temp = 4.97
Class 1: All temperature values are NaN. Cannot compute mean and std.
Class 2: Mean Temp = -11.40, Std Temp = 2.89
Class 3: Mean Temp = -13.08, Std Temp = 4.77
Class 4: Mean Temp = -14.17, Std Temp = 2.99
Class 5: Mean Temp = -11.17, Std Temp = 4.74
Class 6: Mean Temp = -3.14, Std Temp = 1.18
Class 7: Mean Temp = -9.33, Std Temp = 4.09
Class 8: Mean Temp = -13.55, Std Temp = 4.28
Class 9: Mean Temp = -9.21, Std Temp = 4.76
Class 10: Mean Temp = -8.36, Std Temp = 3.61




In [35]:
# The get_expected_probs() function calculates the expected class probabilities for each sample in a batch based on its temperature, using Gaussian distributions derived from the class temperature statistics.
# This function will be called within the physics-informed loss function to obtain the expected probabilities based on temperature, which are then used to compute the physics term (e.g., KL divergence).

def get_expected_probs(temperature_batch, num_classes):
    # Initialize expected probabilities array
    expected_probs = np.zeros((len(temperature_batch), num_classes), dtype=np.float32)
    
    for i, temp in enumerate(temperature_batch):
        total_prob = 0.0
        probs = np.zeros(num_classes, dtype=np.float32)
        
        # Handle NaN temperatures by assigning uniform probabilities
        if np.isnan(temp):
            # Assign uniform probability if temperature is NaN
            probs[:] = 1.0 / num_classes
        else:
            for class_idx in range(num_classes):
                mean = class_temperature_stats[class_idx]['mean']
                std = class_temperature_stats[class_idx]['std']
                
                # Handle classes with NaN mean or std by assigning uniform probability
                if np.isnan(mean) or np.isnan(std):
                    prob = 1.0 / num_classes
                else:
                    # Gaussian probability density function
                    prob = np.exp(-0.5 * ((temp - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
                probs[class_idx] = prob
                total_prob += prob
            
            # Normalize probabilities to sum to 1
            if total_prob > 0:
                probs /= total_prob
            else:
                # If total_prob is zero (unlikely), assign uniform probabilities
                probs[:] = 1.0 / num_classes
        
        expected_probs[i] = probs
    
    return expected_probs
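The per-sample Python loop in `get_expected_probs` can become slow for large batches. A possible vectorized NumPy equivalent is sketched below. `expected_probs_vectorized` is a hypothetical name, and it takes the class means and standard deviations as arrays instead of reading the global `class_temperature_stats` dictionary. It keeps the loop's conventions: classes without statistics contribute 1/K before normalization, and NaN temperatures get a uniform row.

```python
import numpy as np

def expected_probs_vectorized(temps, means, stds):
    """Hypothetical vectorized sketch of get_expected_probs.

    temps: (batch,) temperatures in deg C (may contain NaN)
    means, stds: (num_classes,) per-class temperature statistics (may contain NaN)
    Returns a (batch, num_classes) float32 array whose rows sum to 1.
    """
    t = np.asarray(temps, dtype=np.float64)[:, None]    # (batch, 1)
    mu = np.asarray(means, dtype=np.float64)[None, :]   # (1, num_classes)
    sd = np.asarray(stds, dtype=np.float64)[None, :]    # (1, num_classes)
    num_classes = mu.shape[1]
    uniform = 1.0 / num_classes

    with np.errstate(invalid="ignore", divide="ignore"):
        # Gaussian density of each temperature under each class
        dens = np.exp(-0.5 * ((t - mu) / sd) ** 2) / (sd * np.sqrt(2 * np.pi))
        # Classes without statistics get 1/K, mirroring the loop version
        dens = np.where(np.isnan(mu) | np.isnan(sd), uniform, dens)
        total = dens.sum(axis=1, keepdims=True)
        ok = total > 0  # False for zero or NaN totals
        probs = np.where(ok, dens / np.where(ok, total, 1.0), uniform)

    # NaN temperatures get a uniform distribution, as in the loop version
    probs[np.isnan(t[:, 0])] = uniform
    return probs.astype(np.float32)
```

The `means` and `stds` arrays could be built with `np.array([class_temperature_stats[k]['mean'] for k in range(num_classes)])` (and likewise for `'std'`), and the function wrapped with `tf.numpy_function` in the same way as the loop version.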

In [34]:
# The temperature statistics are all NaN for the bullet rosette category (class 1),
# so samples whose true class has missing temperature statistics are excluded from the physics term.

# Define the physics-informed loss function
def physics_informed_loss(y_true, y_pred, temperature, lambda_weight=0.1):
    # Standard categorical cross-entropy loss (mean over the batch)
    cce = tf.keras.losses.CategoricalCrossentropy()
    loss = cce(y_true, y_pred)
    
    # Identify samples not belonging to classes with missing temperature stats
    class_indices = tf.argmax(y_true, axis=1)
    valid_class_mask = tf.constant([
        not np.isnan(class_temperature_stats[i]['mean']) for i in range(num_classes)
    ], dtype=tf.bool)
    sample_mask = tf.cast(tf.gather(valid_class_mask, class_indices), tf.float32)
    
    # Compute expected probabilities based on temperature (runs get_expected_probs in NumPy).
    # Note: samples with a NaN temperature receive a uniform target.
    expected_probs = tf.numpy_function(
        func=get_expected_probs,
        inp=[temperature, num_classes],
        Tout=tf.float32
    )
    # tf.numpy_function drops static shape information; restore it
    expected_probs = tf.reshape(expected_probs, [-1, num_classes])
    
    # Physics term: per-sample KL(expected || predicted), clipped to avoid log(0).
    # Computed directly so it works with or without tf.keras.losses.Reduction (removed in Keras 3).
    p = tf.clip_by_value(expected_probs, 1e-7, 1.0)
    q = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-7, 1.0)
    physics_term = tf.reduce_sum(p * tf.math.log(p / q), axis=-1)
    
    # Apply the mask to exclude invalid samples
    physics_term = physics_term * sample_mask
    
    # Compute mean physics term over valid samples
    total_valid_samples = tf.reduce_sum(sample_mask) + 1e-7  # Avoid division by zero
    physics_term = tf.reduce_sum(physics_term) / total_valid_samples
    
    # Total loss with weighting factor (lambda_weight defaults to 0.1; adjust as needed)
    total_loss = loss + lambda_weight * physics_term
    
    return total_loss
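Because this loss needs a third argument (temperature), it cannot be passed directly to `model.compile`, which is why the model classes below leave compilation to a custom training loop. A minimal sketch of one training step with `tf.GradientTape` follows. `make_train_step` is a hypothetical helper, and the demo uses a stand-in loss with the same `(y_true, y_pred, temperature)` signature so the block runs on its own; in the notebook, `physics_informed_loss` and temperature batches aligned with the image batches would be passed instead.

```python
import numpy as np
import tensorflow as tf

def make_train_step(model, optimizer, loss_fn):
    """Hypothetical helper: build a training step whose loss also receives temperatures.

    Assumes loss_fn(y_true, y_pred, temperature) returns a scalar tensor.
    """
    def train_step(x_batch, y_batch, t_batch):
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            loss = loss_fn(y_batch, y_pred, t_batch)
        # Frozen layers (trainable=False) are not in trainable_variables, so they stay fixed
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
        return loss
    return train_step

# Stand-in loss with the same signature; it ignores temperature so the demo is self-contained
def stand_in_loss(y_true, y_pred, temperature):
    return tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_true, y_pred))

# Toy data and model (shapes are illustrative, not the PHIPS image size)
rng = np.random.default_rng(0)
x = rng.random((16, 8, 8, 1)).astype("float32")
y = tf.keras.utils.to_categorical(rng.integers(0, 3, size=16), 3)
t = rng.uniform(-20.0, -3.0, size=16).astype("float32")

inputs = tf.keras.Input(shape=(8, 8, 1))
outputs = tf.keras.layers.Dense(3, activation="softmax")(tf.keras.layers.Flatten()(inputs))
toy_model = tf.keras.Model(inputs, outputs)

step = make_train_step(toy_model, tf.keras.optimizers.Adam(learning_rate=1e-2), stand_in_loss)
losses = [float(step(x, y, t)) for _ in range(20)]
```

Wrapping `train_step` in `tf.function` may speed it up; `tf.numpy_function` inside the loss is compatible with graph execution.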

## 3.6 Model Definitions

##### Here, we define the DL models to be used (VGG16 with fine-tuning, InceptionV3, and CRNN with attention), with the input adjustments each one needs. None of the models takes temperature as an input; temperature is instead passed to the loss function during training.

#### A. VGG16 with Fine-Tuning
**Implementation details:**
* Pre-trained VGG16 Model: Utilize the VGG16 model pre-trained on ImageNet.
* Input Adjustments: Convert grayscale images to RGB by repeating the single channel three times.
* Output Layer: Adjust the final dense layer to match the number of classes.
* Temperature Handling: Temperature data is not fed into the model but provided to the loss function during training.
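One detail the implementation below does not handle: Keras' ImageNet VGG16 weights were trained with inputs prepared by `tf.keras.applications.vgg16.preprocess_input`, which expects RGB pixels in 0-255, reorders them to BGR, and subtracts per-channel ImageNet means. `VGG16Model` instead feeds the [0, 1] grayscale images directly, which may limit how well the frozen features transfer. The NumPy sketch below mirrors that preprocessing for grayscale inputs. `gray01_to_vgg16_input` is a hypothetical name, and because it returns 3-channel arrays it assumes a model variant whose input already has 3 channels (i.e., without the `Concatenate` step).

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe" preprocessing mode
IMAGENET_BGR_MEANS = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def gray01_to_vgg16_input(images):
    """Hypothetical helper: (N, H, W, 1) in [0, 1] -> (N, H, W, 3) zero-centered BGR, 0-255 scale."""
    rgb = np.repeat(images.astype(np.float32) * 255.0, 3, axis=-1)  # replicate gray channel
    bgr = rgb[..., ::-1]  # a no-op for gray images, kept to mirror preprocess_input
    return bgr - IMAGENET_BGR_MEANS
```

Calling `tf.keras.applications.vgg16.preprocess_input(np.repeat(images * 255.0, 3, axis=-1))` should give the same result.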

In [37]:
class VGG16Model:
    def __init__(self, input_shape, num_classes):
        self.input_shape = input_shape  # Shape: (height, width, channels)
        self.num_classes = num_classes
        self.model = self.build_model()

    def build_model(self):
        # Input layer for images
        inputs = Input(shape=self.input_shape, name='image_input')

        # Convert grayscale to RGB by repeating channels
        x = Concatenate(axis=-1)([inputs, inputs, inputs])  # Shape: (height, width, 3)

        # Load VGG16 with ImageNet pre-trained weights and without the top (classification) layers
        base_model = VGG16(weights='imagenet', include_top=False, input_tensor=x)

        # Freeze base model layers for initial training (to keep pre-trained features during training)
        for layer in base_model.layers:
            layer.trainable = False

        # Add custom layers on top (more specifically, adds GlobalAveragePooling2D, a dense layer with 256 units and ReLU activation, and an output layer matching the number of classes with softmax activation)
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(256, activation='relu')(x)
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        # Construct the model
        model = Model(inputs=inputs, outputs=outputs, name='VGG16Model')

        return model

    def compile_model(self):
        # Compilation will be handled in the training step using a custom training loop
        pass  # No action needed here

#### B. InceptionV3
**Implementation details:**
* Pre-trained InceptionV3 Model: Utilize the InceptionV3 model pre-trained on ImageNet.
* Input Adjustments: Convert grayscale images to RGB. Also resize images to the expected input size for InceptionV3 (e.g., 299x299).
* Output Layer: Adjust the final dense layer to match the number of classes.
* Temperature Handling: Temperature data is not fed into the model but provided to the loss function during training.
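Similarly, `tf.keras.applications.inception_v3.preprocess_input` scales pixels from 0-255 to [-1, 1], whereas `InceptionV3Model` feeds images in [0, 1]. For inputs already in [0, 1] the equivalent mapping is 2x - 1, sketched in NumPy below; `gray01_to_inception_input` is a hypothetical name.

```python
import numpy as np

def gray01_to_inception_input(images):
    """Hypothetical helper: (N, H, W, 1) in [0, 1] -> (N, H, W, 3) scaled to [-1, 1]."""
    scaled = images.astype(np.float32) * 2.0 - 1.0  # same result as x / 127.5 - 1 on 0-255 data
    return np.repeat(scaled, 3, axis=-1)            # replicate the gray channel to RGB
```

Because this mapping treats all channels identically, it could instead be applied inside the model, for example with a `Rescaling(scale=2.0, offset=-1.0)` layer placed before the `Concatenate`.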

In [39]:
class InceptionV3Model:
    def __init__(self, input_shape, num_classes):
        self.input_shape = input_shape  # Original image shape
        self.num_classes = num_classes
        self.model = self.build_model()

    def build_model(self):
        # Input layer for images
        inputs = Input(shape=self.input_shape, name='image_input')

        # Resize images to 299x299 as expected by InceptionV3
        x = Resizing(299, 299)(inputs)

        # Convert grayscale to RGB
        x = Concatenate(axis=-1)([x, x, x])  # Shape: (299, 299, 3)

        # Load pre-trained InceptionV3 model without the top layer
        base_model = InceptionV3(weights='imagenet', include_top=False, input_tensor=x)

        # Freeze base model layers for initial training
        for layer in base_model.layers:
            layer.trainable = False

        # Add custom layers on top
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(256, activation='relu')(x)
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        # Construct the model
        model = Model(inputs=inputs, outputs=outputs, name='InceptionV3Model')

        return model

    def compile_model(self):
        # Compilation will be handled during training with the custom loss
        pass

#### C. CRNN with Attention Mechanism
**Implementation details:**
* Convolutional Layers: Extract spatial features from images.
* Recurrent Layers (LSTM): Treat the spatial positions of the CNN feature map as a sequence and capture dependencies across them.
* Attention Mechanism: Enhance the model's focus on relevant features.
* Input Adjustments: Use the grayscale images directly.
* Output Layer: Adjust the final dense layer to match the number of classes.
* Temperature Handling: Temperature data is not fed into the model but provided to the loss function during training.

In [None]:
from tensorflow.keras.layers import RepeatVector, Lambda  # needed by the attention block below

class CRNNModel:
    def __init__(self, input_shape, num_classes, lstm_units=64):
        self.input_shape = input_shape  # Shape: (height, width, channels)
        self.num_classes = num_classes
        self.lstm_units = lstm_units
        self.model = self.build_model()

    def build_model(self):
        # Input layer for images
        inputs = Input(shape=self.input_shape, name='image_input')

        # Convolutional layers
        x = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
        x = MaxPooling2D(pool_size=(2, 2))(x)

        x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

        # Prepare data for LSTM: each spatial position of the feature map becomes one timestep
        shape = x.shape
        x = Reshape((shape[1] * shape[2], shape[3]))(x)  # Shape: (batch_size, timesteps, features)

        # LSTM layer
        x = LSTM(self.lstm_units, return_sequences=True)(x)  # Shape: (batch_size, timesteps, lstm_units)

        # Attention mechanism: score each timestep, softmax over timesteps, then take the weighted sum
        attention = Dense(1, activation='tanh')(x)             # (batch_size, timesteps, 1)
        attention = Flatten()(attention)                       # (batch_size, timesteps)
        attention = Activation('softmax')(attention)           # weights sum to 1 over timesteps
        attention = RepeatVector(self.lstm_units)(attention)   # (batch_size, lstm_units, timesteps)
        attention = Permute([2, 1])(attention)                 # (batch_size, timesteps, lstm_units)
        x = Multiply()([x, attention])
        x = Lambda(lambda xin: tf.reduce_sum(xin, axis=1))(x)  # (batch_size, lstm_units)

        # Output layer
        outputs = Dense(self.num_classes, activation='softmax', name='output')(x)

        # Construct the model
        model = Model(inputs=inputs, outputs=outputs, name='CRNNModel')

        return model

    def compile_model(self):
        # Compilation will be handled during training with the custom loss
        pass
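The attention block in `CRNNModel` scores each timestep with a `Dense(1, activation='tanh')` layer, normalizes the scores with a softmax over timesteps, and returns the attention-weighted sum of the LSTM outputs. The NumPy sketch below reproduces that computation so the weights can be inspected outside Keras (both function names are hypothetical). `crnn_timesteps` shows that, with two 2x2 max-pooling layers and the default `'valid'` padding, an H x W image becomes (H // 4) * (W // 4) timesteps.

```python
import numpy as np

def attention_pool(h, w, b=0.0):
    """Hypothetical NumPy sketch of the CRNN attention pooling.

    h: (batch, timesteps, units) LSTM outputs
    w: (units,) kernel of the Dense(1, activation='tanh') scoring layer; b: its bias
    Returns context vectors (batch, units) and attention weights (batch, timesteps).
    """
    scores = np.tanh(h @ w + b)                          # (batch, timesteps)
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize the softmax
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)     # softmax over timesteps
    context = (h * alpha[:, :, None]).sum(axis=1)        # weighted sum over timesteps
    return context, alpha

def crnn_timesteps(height, width):
    """Timesteps after two 2x2 max-pooling layers (padding='valid'), as in CRNNModel."""
    return (height // 2 // 2) * (width // 2 // 2)
```

With all-zero scoring weights every timestep receives the same weight, so the context reduces to the mean LSTM output; learned weights let the model emphasize some positions over others.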