## Imports

In [None]:
# !pip install --upgrade pip
# !pip install scikit-multilearn
# !pip install visualkeras

In [None]:
# !pip install netCDF4
# !pip install zarr
# !pip install xarray
# !pip install tensorflow_addons
# !pip install h5netcdf
# !pip install tensorflow[and-cuda]
# !pip install tbparse

In [None]:
# ! gsutil -m cp -r dir gs://tropos_2/limassol /content
# !pip install tensorboard pandas


In [None]:
import tensorflow as tf

In [None]:
import netCDF4 as nc
import matplotlib.pyplot as plt
from google.cloud import storage
import os
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.colors as colors
import matplotlib.cbook as cbook
import math
import zarr
import xarray as xr
import datetime

In [None]:
from tensorflow import keras
from tensorflow.keras import layers
# import tensorflow_addons as tfa
from tensorflow.keras.utils import to_categorical
import tensorflow.keras.backend as K

In [None]:
import fsspec
import concurrent.futures
from tqdm.notebook import tqdm_notebook, trange, tqdm
from sklearn.metrics import confusion_matrix, balanced_accuracy_score
import seaborn as sns
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve
import plotly.express as px
import matplotlib.colors as colors

In [None]:
from typing import Union, List, Callable
import matplotlib.colors as mcolors

## Initializations

### Test tf and gpu compatibility

In [None]:
print("TensorFlow Version:", tf.__version__)
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if not gpus:
    # No visible GPU: print the usual causes to check.
    for hint in (
        "!!! TensorFlow cannot find any GPUs.",
        "!!! Ensure you have 'tensorflow' (GPU version), not 'tensorflow-cpu' installed.",
        "!!! Check CUDA/cuDNN installation and compatibility.",
    ):
        print(hint)
else:
    try:
        for gpu in gpus:
            print("Found GPU:", gpu)
            # Allocate GPU memory on demand instead of reserving it all up front.
            # TensorFlow raises RuntimeError if the devices are already initialized
            # (e.g. when this cell is re-run); that case is caught below.
            tf.config.experimental.set_memory_growth(gpu, True)
            print("  Memory growth enabled.")
        # Smoke test: a tiny matmul pinned to the first GPU.
        print("\nAttempting simple GPU operation...")
        with tf.device('/GPU:0'):
            a, b = (tf.constant([[1.0, 2.0], [3.0, 4.0]]),
                    tf.constant([[1.0, 1.0], [0.0, 1.0]]))
            c = tf.matmul(a, b)
        print("Simple GPU operation successful. Result tensor:\n", c.numpy())
    except RuntimeError as e:
        print("!!! Runtime Error during GPU setup or test:", e)
    except Exception as e:
        print("!!! An unexpected error occurred:", e)

### Model Parameters

In [None]:
model_runtime = None # @param
# Without an explicit run id, use the wall-clock time so each run gets a unique name.
model_runtime = model_runtime or datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model_run_name = '4_layer_large_filter_no_class0_mean1_group_and_dice_penalty1_data_globalnorm' # @param
model_name = '{}_{}'.format(model_run_name, model_runtime)

## Data

### Get Data

In [None]:
# Notebook switches (Colab form fields):
#   generate_data - True: rebuild the arrays from the NetCDF files and cache them as .npy;
#                   False: load the cached .npy arrays.
#   use_mask      - passed to create_feature_matrix to NaN-out quality-flagged values.
#   dataset_name  - file-name prefix of the cached arrays under dataset_for_training/.
#   train         - NOTE(review): not used in the visible cells; presumably gates training later.
generate_data = False # @param
use_mask = True # @param
dataset_name = 'data_w_mask_global_norm_3' # @param
train = False # @param

In [None]:
# Every file in the local "limassol" download directory.
local_filepath = "limassol"
glob_pattern = f"{local_filepath}/*"
files_glob = tf.io.gfile.glob(glob_pattern)
print(len(files_glob))

In [None]:
def load_nc_file(filename, engine="h5netcdf", *args, **kwargs) -> xr.Dataset:
    """Eagerly read a NetCDF file (local path or fsspec URL) into an xarray Dataset.

    `xr.load_dataset` reads all values into memory, so the returned Dataset stays
    usable after the underlying file handle is closed.

    Args:
        filename: any path/URL understood by fsspec.open.
        engine: xarray backend engine.
        *args, **kwargs: forwarded to xr.load_dataset. NOTE(review): load_dataset
            accepts only keyword options after the file, so extra positional
            arguments will fail there.
    """
    with fsspec.open(filename, mode="rb") as handle:
        return xr.load_dataset(handle, engine=engine, *args, **kwargs)

# load_nc_file('limassol/20170107_regridded_data_for_limassol.nc')

In [None]:
# Dataset attributes
_HEADS = features = [
    'polly_bsc_532',
    'polly_bsc_1064',
    'polly_att_bsc_532',
    'polly_att_bsc_1064',
    'polly_pardepol_532',
    'polly_voldepol_532',
    'polly_ang_532_1064',
    'model_pressure',
    'model_temperature'
]

_QUALITY_FLAGS = [
    'polly_bsc_532_quality_flag',
    'polly_bsc_1064_quality_flag',
    'polly_voldepol_532_quality_flag',
]

_CLASS_NAME = 'combined_target_classification'
_MODEL_CLASS_NAME = 'polly_target_classification'

# Define dataset sizes
times_min_dim = 960
heights_min_dim = 600

num_classes = 12

In [None]:
def filter_files(file, times_min_dim=960, heights_min_dim=600):
    """Return `file` when its grid is not exactly (times_min_dim, heights_min_dim), else None.

    Despite the "_min_" names, any size different from the expected one (larger
    or smaller) flags the file. Flagged files are printed with their sizes.
    """
    ds = load_nc_file(file)
    n_times = len(ds.coords['time'].values)
    n_heights = len(ds.coords['height'].values)
    wrong_grid = n_times != times_min_dim or n_heights != heights_min_dim
    if not wrong_grid:
        return None
    print(file, n_times, n_heights)
    return file

In [None]:
# Files whose (time, height) grid is not (960, 600). The list is hard-coded,
# presumably from a one-off run of this commented-out filter_files sweep:
#
# if generate_data:
#   with concurrent.futures.ThreadPoolExecutor(max_workers=64) as executor:
#     files_to_remove = list(executor.map(filter_files, files_glob))
#   print(len(files_to_remove))
_bad_dates = (
    '20171016', '20170508', '20170303', '20170308', '20170527', '20171019',
    '20180222', '20161225', '20170413', '20170304', '20170215', '20170514',
    '20170310', '20170524', '20170906', '20180121', '20171020', '20170401',
    '20170419', '20170511', '20170522', '20170101', '20161213', '20170320',
    '20180325', '20170210', '20170225', '20170331', '20180219', '20170213',
    '20170402',
)
files_to_remove = [f'limassol/{day}_regridded_data_for_limassol.nc' for day in _bad_dates]

31 files (listed above) do not match the expected grid of 960 time steps × 600 heights. We can either discard them or fill the missing times with NaNs; the next cell discards them when the dataset is regenerated.

In [None]:
# Remove problematic files (only relevant when regenerating the dataset).
# Rebuild the list rather than calling list.remove(): that raised ValueError for
# any listed file missing from the local copy, and on every re-run of this cell.
if generate_data:
  _excluded_files = {f for f in files_to_remove if f}
  files_glob = [f for f in files_glob if f not in _excluded_files]
len(files_glob)

In [None]:
# Dataset parameters
_TIMES =  times_min_dim
_HEIGHTS = heights_min_dim

_N_FEATURES = len(_HEADS)
_N_SAMPLES = len(files_glob)

_INPUT_SHAPE = (_TIMES, _HEIGHTS, _N_FEATURES*2)
_INPUT_SHAPE

In [None]:

def get_mask(ds, feature_name):
    """Boolean mask of invalid values for `feature_name`, built from quality flags.

    Each lidar feature is masked by the flag(s) of the channel(s) it derives from;
    derived quantities OR two flags together. Any other feature (e.g. model
    fields) gets an all-False mask with the variable's own shape.
    """
    flag_sources = {
        'polly_bsc_532': ('polly_bsc_532_quality_flag',),
        'polly_att_bsc_532': ('polly_bsc_532_quality_flag',),
        'polly_bsc_1064': ('polly_bsc_1064_quality_flag',),
        'polly_att_bsc_1064': ('polly_bsc_1064_quality_flag',),
        'polly_voldepol_532': ('polly_voldepol_532_quality_flag',),
        'polly_pardepol_532': ('polly_bsc_532_quality_flag', 'polly_voldepol_532_quality_flag'),
        'polly_ang_532_1064': ('polly_bsc_532_quality_flag', 'polly_bsc_1064_quality_flag'),
    }
    flag_names = flag_sources.get(feature_name)
    if flag_names is None:
        return np.array(np.zeros(ds.variables[feature_name].values.shape), dtype=bool)
    combined = np.array(ds.variables[flag_names[0]][:].values, dtype=bool)
    for extra_flag in flag_names[1:]:
        combined = combined | np.array(ds.variables[extra_flag][:].values, dtype=bool)
    return combined

In [None]:
def create_feature_matrix(
    files_list,
    use_mask: bool = False
    ):
    """Read every file in `files_list` into dense feature and label tensors.

    For each file, each variable in _HEADS is read, optionally masked with
    get_mask (flagged values -> NaN), cropped to (_TIMES, _HEIGHTS) and stored as
    one feature channel. The target (_CLASS_NAME) and current-model
    (_MODEL_CLASS_NAME) classifications are stored alongside.

    Files that raise while being read are reported and skipped. The returned
    arrays are trimmed to the files actually read, so no all-zero placeholder
    samples (label 0 everywhere) remain at the end.

    Args:
        files_list: paths accepted by load_nc_file.
        use_mask: NaN-out values flagged by the quality flags.

    Returns:
        (input_data, label_data, curr_model_label_data) with shapes
        (n_read, _TIMES, _HEIGHTS, _N_FEATURES), (n_read, _TIMES, _HEIGHTS, 1)
        and (n_read, _TIMES, _HEIGHTS, 1).
    """
    n_files = len(files_list)
    # Sized from files_list itself (was the global _N_SAMPLES, which only matches
    # when files_list is files_glob).
    input_data = np.zeros([n_files, _TIMES, _HEIGHTS, _N_FEATURES])
    label_data = np.zeros([n_files, _TIMES, _HEIGHTS, 1])
    curr_model_label_data = np.zeros([n_files, _TIMES, _HEIGHTS, 1])

    cnt_file = 0  # next free row; advanced only after a file was read completely
    for file_idx in trange(n_files):
        file = files_list[file_idx]
        try:
            ds = load_nc_file(file)
            for feature_idx, feature_name in enumerate(_HEADS):
                data = ds.variables[feature_name][:].values
                if use_mask:
                    mask = get_mask(ds, feature_name)
                    if mask.shape != data.shape:
                        print('Shape mismatch for: ', file, feature_name, 'Shapes: ', mask.shape, data.shape)
                    data = np.where(~mask, data, np.nan)
                    if np.all(np.isnan(data)):
                        print('All values are nan for: ', file, feature_name)
                input_data[cnt_file, :, :, feature_idx] = data[:_TIMES, :_HEIGHTS]
            # Labels are read inside the try as well, so one broken file is skipped
            # instead of aborting a long generation run.
            label_data[cnt_file] = (
                ds.variables[_CLASS_NAME].values[:_TIMES, :_HEIGHTS].reshape(_TIMES, _HEIGHTS, 1))
            curr_model_label_data[cnt_file] = (
                ds.variables[_MODEL_CLASS_NAME].values[:_TIMES, :_HEIGHTS].reshape(_TIMES, _HEIGHTS, 1))
        except Exception as err:  # was a bare except, which also swallowed KeyboardInterrupt
            # A partially written row is overwritten by the next file or trimmed below.
            print(f'Skipping {file}: {err!r}')
            continue
        cnt_file += 1

    if cnt_file < n_files:
        print(f'{n_files - cnt_file} of {n_files} files skipped.')
    return input_data[:cnt_file], label_data[:cnt_file], curr_model_label_data[:cnt_file]

In [None]:
# Build the (input, label, current-model label) arrays from the NetCDF files,
# or load the copies cached by a previous run.
_cache_prefix = f'dataset_for_training/{dataset_name}'
if generate_data:
  input_data, label_data, curr_model_label_data = create_feature_matrix(files_list=files_glob, use_mask=use_mask)
  # np.save does not create missing directories; create the cache directory first
  # so a long generation run is not lost on the first save.
  os.makedirs(os.path.dirname(f'{_cache_prefix}_input.npy'), exist_ok=True)
  np.save(f'{_cache_prefix}_input.npy', input_data)
  np.save(f'{_cache_prefix}_label.npy', label_data)
  np.save(f'{_cache_prefix}_curr_model_label.npy', curr_model_label_data)
else:
  input_data = np.load(f'{_cache_prefix}_input.npy')
  label_data = np.load(f'{_cache_prefix}_label.npy')
  curr_model_label_data = np.load(f'{_cache_prefix}_curr_model_label.npy')

In [None]:
for _label, _arr in (('label data', label_data), ('input data', input_data)):
  print(f'{_label}: {_arr.shape}')


### Prepare datasets

In [None]:
# --- Configuration ---
# Seed for the reproducible shuffle behind the train/valid/test split.
SEED = 42
# Fractions of the *total* dataset held out for test and validation; training gets the rest.
TEST_RATIO, VALID_RATIO = 0.2, 0.1

In [None]:
# --- Splitting Setup ---
n_samples = input_data.shape[0]
train_ratio = 1.0 - TEST_RATIO - VALID_RATIO

print(f'\nTotal number of samples: {n_samples}')
print(f'Splitting Ratios -> Train: {train_ratio:.2f}, Validation: {VALID_RATIO:.2f}, Test: {TEST_RATIO:.2f}')

# int() truncates, so the held-out splits round down and training absorbs the
# remainder; the three counts always add up to n_samples.
n_test, n_valid = (int(ratio * n_samples) for ratio in (TEST_RATIO, VALID_RATIO))
n_train = n_samples - n_test - n_valid

print(f'Calculated Samples -> Train: {n_train}, Validation: {n_valid}, Test: {n_test}')
print(f'Total allocated: {n_train + n_valid + n_test} (should equal {n_samples})')


In [None]:
# --- Random Splitting using Shuffled Indices ---
print(f"\nGenerating and shuffling indices with seed {SEED}...")
np.random.seed(SEED)
# Legacy np.random.permutation(int) is arange + in-place shuffle, so this gives
# the same order as shuffling np.arange(n_samples) with the same seed.
indices = np.random.permutation(n_samples)
print("Indices shuffled.")

In [None]:
# Cut the shuffled permutation into contiguous test / validation / train ranges.
valid_end = n_test + n_valid
test_indices, valid_indices, train_indices = (
    indices[:n_test], indices[n_test:valid_end], indices[valid_end:])

print(f"Indices used -> Test: {len(test_indices)}, Validation: {len(valid_indices)}, Train: {len(train_indices)}")

# Index every array with the same index set so inputs and labels stay paired.
# Only the test split is needed for the current-model labels.
print("Creating data splits using shuffled indices...")
x_train, y_train_int = input_data[train_indices], label_data[train_indices]
x_valid, y_valid_int = input_data[valid_indices], label_data[valid_indices]
x_test, y_test_int = input_data[test_indices], label_data[test_indices]
y_curr_model_test_int = curr_model_label_data[test_indices]
print("Data splits created.")


In [None]:
def class_variance(y_int, set='Train'):
    """Plot and return, per class, the fraction of samples that contain that class.

    Args:
        y_int: labels with a trailing singleton axis, e.g. (n, times, heights, 1).
        set: unused; kept for backward compatibility.

    Returns:
        array of length >= num_classes: (#samples containing class k) / n_samples.
    """
    n_samples = y_int.shape[0]
    if n_samples == 0:
        return np.zeros(num_classes)
    y_int_samples_squeezed = np.squeeze(y_int, axis=-1)

    all_unique_classes_per_sample = []
    for sample_labels in y_int_samples_squeezed:
        all_unique_classes_per_sample.extend(np.unique(sample_labels))

    fig = px.histogram(all_unique_classes_per_sample)
    fig.show()

    # The label arrays are float (created with np.zeros), and np.bincount only
    # accepts non-negative integers: float input raised TypeError here.
    present_ids = np.asarray(all_unique_classes_per_sample).astype(np.int64)
    class_presence_counts = np.bincount(present_ids, minlength=num_classes)
    return class_presence_counts / n_samples


In [None]:
def get_class_distribution(y_int):
    """List the class IDs present in each sample, concatenated over samples.

    A trailing singleton axis is squeezed first. Each sample (index along the
    first axis) contributes its sorted unique values once, so the result counts
    in how many samples a class occurs, not how many pixels it covers.

    Args:
        y_int: array-like of labels, e.g. (n, times, heights[, 1]).

    Returns:
        list of NumPy scalars.
    """
    labels = np.array(y_int)
    if labels.shape[-1] == 1:
        labels = np.squeeze(labels, axis=-1)
    return [
        class_id
        for sample in labels
        for class_id in np.unique(sample.flatten())
    ]

def plot_class_distribution_density(y_train, y_val, y_test, num_classes):
    """Overlay KDEs of per-sample class presence for the train/valid/test splits.

    Each sample contributes the set of class IDs it contains (get_class_distribution).
    One KDE per split is drawn with common_norm=False, so each curve is normalized
    on its own.

    NOTE(review): a later cell redefines this function (histogram version), which
    shadows this one for every subsequent call.

    Args:
        y_train (np.ndarray): Training dataset labels.
        y_val (np.ndarray): Validation dataset labels.
        y_test (np.ndarray): Test dataset labels.
        num_classes (int): Number of classes; sets the x-axis ticks.
    """
    fig, ax = plt.subplots(figsize=(12, 7), sharey=True)

    splits = (('Train Set', y_train), ('Validation Set', y_val), ('Test Set', y_test))
    palette = sns.color_palette('deep', n_colors=len(splits))

    frames = [
        pd.DataFrame({'values': get_class_distribution(labels)}).assign(Dataset=name)
        for name, labels in splits
    ]
    sns.kdeplot(data=pd.concat(frames), x='values', hue='Dataset', ax=ax,
                common_norm=False, alpha=0.8, linewidth=2, cut=0, palette=palette)

    ax.set_xlabel("Class ID", fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.set_xticks(range(num_classes))  # one tick per integer class ID
    ax.grid(axis='y', linestyle='--', alpha=0.9)

    plt.tight_layout()
    plt.show()

plot_class_distribution_density(y_train_int, y_valid_int, y_test_int, 12)

In [None]:
def plot_class_distribution_density(y_train, y_val, y_test, num_classes):
    """Plot, per split, how often each class ID occurs across samples.

    For each split, get_class_distribution yields the class IDs present in each
    sample. These are histogrammed into one unit-wide bin per class ID
    0..num_classes-1 with density=True (with unit-width bins each curve sums to 1)
    and drawn as a marker-and-line series per split.

    This definition replaces the KDE version from the previous cell.

    Args:
        y_train (np.ndarray): Training dataset labels.
        y_val (np.ndarray): Validation dataset labels.
        y_test (np.ndarray): Test dataset labels.
        num_classes (int): Number of classes; defines the bins and x-axis ticks.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    datasets = {
        'Train Set': y_train,
        'Validation Set': y_val,
        'Test Set': y_test
    }
    palette = sns.color_palette('deep', n_colors=len(datasets))

    # Edges at k - 0.5 .. k + 0.5 so each bin is centred on an integer class ID.
    # The former `bins=num_classes` split [min, max] of the observed IDs into
    # num_classes equal bins whose centres never lined up with the class IDs
    # (e.g. bin width 11/12 for IDs 0..11).
    edges = np.arange(num_classes + 1) - 0.5
    centers = np.arange(num_classes)

    for color, (name, data) in zip(palette, datasets.items()):
        density, _ = np.histogram(get_class_distribution(data), bins=edges, density=True)
        ax.plot(centers, density,
                marker='o',
                linestyle='-',
                ms=8,
                label=name,
                color=color)

    ax.set_xlabel("Class ID", fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.set_xticks(range(num_classes))
    ax.grid(axis='y', linestyle='--', alpha=0.9)
    ax.legend(title='Dataset')

    plt.tight_layout()
    plt.show()

plot_class_distribution_density(y_train_int, y_valid_int, y_test_int, num_classes)

### Preprocess dataset

In [None]:
# Per-feature means of the training split after clipping each feature to
# [0, its 99th percentile]; used later to impute NaNs.
x_train_clipped = x_train.copy()

for feat_idx in range(_N_FEATURES):
    channel = x_train_clipped[..., feat_idx]
    upper = np.nanpercentile(channel, 99)
    x_train_clipped[..., feat_idx] = np.clip(channel, 0, upper)

# Shape (1, 1, 1, n_features) so it broadcasts over (n, times, heights, features).
RAW_MEANS = np.nanmean(x_train_clipped, axis=(0, 1, 2))[np.newaxis, np.newaxis, np.newaxis, :]

In [None]:
# Compare one raw feature channel of one test sample with its log1p transform.
_demo_sample, _demo_feature = 6, 6
x_temp = np.clip(x_test[_demo_sample, ..., _demo_feature].copy(), 0, None)
# Impute NaNs with the training mean of the same feature.
x_temp = np.nan_to_num(x_temp, nan=RAW_MEANS[0, 0, 0, _demo_feature], posinf=None, copy=True)
raw_flat = x_temp.flatten()
log_flat = np.log1p(raw_flat)
values = np.concatenate([raw_flat, log_flat])
labels = ['Original'] * raw_flat.size + ['Log-Transformed'] * log_flat.size
combined_df = pd.DataFrame({'Pixel Value': values, 'Data Type': labels})

# Broken y-axis: the top panel shows the tall peak, the bottom panel the bulk.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
fig.subplots_adjust(hspace=0.05)
fig.set_size_inches(8, 6)
for _ax, _show_legend in ((ax1, True), (ax2, False)):
    sns.kdeplot(data=combined_df, x='Pixel Value', hue='Data Type', fill=True,
                common_norm=False, ax=_ax, legend=_show_legend)

ax1.set_ylim(15, 18)
ax2.set_ylim(0, 0.4)

# Hide the facing spines and keep tick labels only on the bottom panel.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)
ax2.xaxis.tick_bottom()

# Slanted "cut" markers at the break.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
ax2.set_xlim(-0.5, 7)

ax1.set_ylabel('')
ax2.set_ylabel('')
fig.supylabel('Density')
plt.show()

In [None]:
def preprocess(data, means=None, stds=None, training=True, with_log=True, fill_values=None):
  """Clip, impute, log-transform and standardize features; append NaN indicators.

  Steps, per feature channel (last axis):
    1. clip negative values to 0 (NaNs are kept);
    2. record NaN positions as an int16 0/1 indicator;
    3. replace NaNs with `fill_values` (broadcast over the last axis);
    4. optionally apply log1p;
    5. standardize with `means`/`stds`, computed over axes (0, 1, 2) when
       `training` is True, otherwise taken from the arguments;
    6. append the NaN indicators as extra channels.

  Args:
    data: array of shape (n, times, heights, n_features).
    means, stds: per-feature statistics of shape (1, 1, 1, n_features); required
      when `training` is False (pass the values returned for the training split).
    training: compute `means`/`stds` from `data` instead of using the arguments.
    with_log: apply log1p before standardizing.
    fill_values: NaN replacement broadcastable to `data`; defaults to the global
      RAW_MEANS (per-feature means of the clipped training split).

  Returns:
    (processed, means, stds); `processed` has shape (n, times, heights, 2 * n_features).

  NOTE(review): step 1 also zeroes legitimately negative inputs (e.g. the
  Angstrom exponent, or model_temperature if it is in deg C); confirm intended.
  """
  if fill_values is None:
    fill_values = RAW_MEANS
  data = np.clip(data, 0, None)
  x_nans = np.isnan(data).astype(np.int16)
  # np.clip already returned a fresh array, so impute in place (copy=False)
  # instead of making another full-size copy of a very large tensor.
  data = np.nan_to_num(data, nan=fill_values, posinf=None, copy=False)
  if with_log:
    data = np.log1p(data)
  if training:
    means = np.nanmean(data, axis=(0, 1, 2))[np.newaxis, np.newaxis, np.newaxis, :]
    stds = np.nanstd(data, axis=(0, 1, 2))[np.newaxis, np.newaxis, np.newaxis, :]
  # A constant feature has std 0; divide by 1 there instead of producing inf/NaN.
  stds = np.where(stds == 0, 1.0, stds)
  data = (data - means) / stds
  data = np.concatenate([data, x_nans], axis=3)
  return data, means, stds

# Fit the normalization statistics on the training split only, then reuse them for
# validation and test. NOTE: re-running this cell would preprocess the already
# preprocessed arrays a second time.
x_train, log_means, log_stds = preprocess(x_train)
x_valid, _, _ = preprocess(x_valid, means=log_means, stds=log_stds, training=False)
x_test, _, _ = preprocess(x_test, means=log_means, stds=log_stds, training=False)

In [None]:
# The splits were made with fancy indexing (copies), so the unsplit arrays can go.
print("Deleting original large arrays to free memory...")
del input_data, label_data, curr_model_label_data
import gc
gc.collect()

In [None]:
print(f"\nOne-hot encoding labels with num_classes={num_classes}...")

def _one_hot_f32(labels):
    """One-hot encode integer-valued labels as float32 with keras' to_categorical."""
    return tf.keras.utils.to_categorical(labels, num_classes=num_classes).astype(np.float32)

# float32 for compatibility with most models/losses
y_train = _one_hot_f32(y_train_int)
y_valid = _one_hot_f32(y_valid_int)
y_test = _one_hot_f32(y_test_int)
y_curr_model_test = _one_hot_f32(y_curr_model_test_int)
print("One-hot encoding complete.")

In [None]:
# Shapes of every split (previously x_test was printed twice and y_test never).
for _name, _arr in (
    ('x_train', x_train), ('y_train', y_train),
    ('x_test', x_test), ('y_test', y_test),
    ('x_valid', x_valid), ('y_valid', y_valid),
    ('y_curr_model_test', y_curr_model_test),
):
    print(f'{_name}:', np.shape(_arr))

In [None]:
# Visualize the label mask of one validation sample as a class-coloured image.
sample_to_show = 2  # index into the validation split

# Copy: y_valid[2] is a view, and the old per-pixel fix-up loop wrote into it,
# silently modifying the validation labels (it also overwrote the global
# `indices` array from the train/valid/test split).
y_valid_img = y_valid[sample_to_show].copy()

# One class per pixel. For one-hot rows argmax is exact; an all-zero row maps to
# class 0 as before. (A row with several active classes, which to_categorical
# output does not contain, now takes the first instead of a random one.)
class_ids = np.argmax(y_valid_img, axis=-1)

# One colour per class. plt.get_cmap replaces plt.cm.get_cmap (removed in
# Matplotlib 3.9); the colormap is no longer bound to `colors`, which shadowed the
# matplotlib.colors module imported at the top of the notebook.
class_cmap = plt.get_cmap('viridis', num_classes)
class_colors = class_cmap(np.linspace(0, 1, num_classes))
colored_image = class_colors[class_ids, :3].astype(np.float32)

fig, ax = plt.subplots(figsize=(10, 8))
# Transpose (times, heights) -> (heights, times): time on x, height on y.
image = ax.imshow(colored_image.transpose(1, 0, 2), origin='lower')
ax.set_xlabel("Time index")
ax.set_ylabel("Height index")
ax.axis('on')

# Discrete colorbar with one bin centred on each class ID.
cmap = mcolors.ListedColormap(class_colors)
bounds = np.arange(num_classes + 1) - 0.5
norm = mcolors.BoundaryNorm(bounds, cmap.N)

cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap, norm=norm),
                    ax=ax,
                    ticks=np.arange(num_classes),
                    spacing='proportional',
                    label='Class Number',
                    shrink=0.6)

cbar.ax.set_yticklabels([f'Class {i}' for i in range(num_classes)])

plt.tight_layout()
plt.show()

In [None]:
def plot_xarray_channels_list(data_array, channels_list):
    """Plot selected channels of one (time, height, channel) sample, 3 panels per row.

    Each channel is drawn transposed (time on x, height on y) with its colour range
    clipped to the 10th-90th percentiles. Channels below len(_HEADS) are titled
    with their _HEADS name. The next len(_HEADS) channels are titled as the
    NaN-indicator channels that preprocess() appends.

    Args:
        data_array: array indexable as data_array[..., channel]; despite the name,
            a plain NumPy array works (the notebook passes one).
        channels_list: iterable of integer channel indices to plot.
    """
    channels = list(channels_list)
    num_channels = len(channels)
    if num_channels == 0:
        return
    cols = min(3, num_channels)
    # Enough rows for every channel (was fixed at 3, so > 9 channels crashed).
    num_rows = -(-num_channels // cols)

    fig, axes = plt.subplots(num_rows, cols, figsize=(20, 10 * num_rows / 3), squeeze=False)

    n_heads = len(_HEADS)
    for i, channel in enumerate(channels):
        ax = axes[i // cols, i % cols]
        channel_data = np.asarray(data_array[..., channel])
        # Scalar limits (np.percentile with a one-element list returned an array).
        vmin, vmax = np.percentile(channel_data, [10, 90])
        image = ax.imshow(channel_data.T, origin='lower', cmap='viridis',
                          vmin=vmin, vmax=vmax, aspect='equal')
        # Title by the channel actually plotted, not by its position in the list.
        if channel < n_heads:
            title = _HEADS[channel]
        elif channel < 2 * n_heads:
            title = f'{_HEADS[channel - n_heads]} (NaN mask)'
        else:
            title = f'channel {channel}'
        ax.set_title(title)
        ax.set_xlabel("Time index")
        ax.set_ylabel("Height index")
        fig.colorbar(image, ax=ax, label='Pixel Value', shrink=1)

    # Blank out grid cells left over when num_channels is not a multiple of cols.
    for ax in axes.flat[num_channels:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# The 9 preprocessed feature channels of validation sample 2.
temp_ds = x_valid[2, ..., :10]
plot_xarray_channels_list(temp_ds, range(9))


In [None]:
# Calculate class weights: inverse pixel frequency, normalized to mean 1, class 0 ignored.
if 'y_train' not in locals():
    raise NameError("y_train NumPy array must be loaded before calculating class weights.")

num_classes = y_train.shape[-1] # Should be 12
print(f"Calculating class weights for {num_classes} classes using the full y_train NumPy array...")

# Training pixels per class (one-hot summed over samples, times and heights).
class_pixel_counts = np.sum(y_train, axis=(0, 1, 2))

print(f"Total pixel counts per class: {class_pixel_counts}")
print(f"Shape of counts: {class_pixel_counts.shape}") # Should be (12,)

# Kept in case later cells use it; no longer needed for the weights.
epsilon = np.finfo(float).eps

# Inverse frequency for classes that occur; classes absent from y_train get 0.
# The former 1 / (0 + eps) ~ 4.5e15 for an absent class dominated the mean below
# and collapsed every other weight to ~0.
present = class_pixel_counts > 0
class_weights = np.zeros(class_pixel_counts.shape,
                         dtype=np.result_type(class_pixel_counts.dtype, np.float32))
class_weights[present] = 1.0 / class_pixel_counts[present]

# Scale so the present classes' weights average 1. This is identical to the old
# normalization when every class occurs; the old "* 200" factor cancelled out here.
if present.any():
    class_weights = class_weights / np.mean(class_weights[present])

class_weights[0] = 0 # We don't care about class 0 (unknown)

print(f"\nCalculated class weights (mean-normalized inverse frequency, class 0 zeroed): {class_weights}")
print(f"Class weights shape: {class_weights.shape}") # Should be (12,)
print(f"Min weight: {np.min(class_weights)}, Max weight: {np.max(class_weights)}")

### Prepare patches

In [None]:
# Batch size; choose by GPU memory (e.g. 2, 4, 8, 16).
BATCH_SIZE = 16
# Buffers as large as each split give a full shuffle. The validation buffer is
# not used by the (unshuffled) validation pipeline below.
SHUFFLE_BUFFER_SIZE_TRAIN, SHUFFLE_BUFFER_SIZE_VALID = len(x_train), len(x_valid)

In [None]:
def data_generator(x_data, y_data):
    """Lazily yield (x, y) sample pairs cast to float32, one per leading-axis index.

    Used with tf.data.Dataset.from_generator so the large NumPy arrays are
    streamed sample by sample instead of being converted to one big tensor.
    """
    as_f32 = np.float32
    for sample_idx in range(x_data.shape[0]):
        x_sample = x_data[sample_idx].astype(as_f32)
        y_sample = y_data[sample_idx].astype(as_f32)
        yield x_sample, y_sample

# Element spec of one generator sample: preprocessed input (features plus their
# NaN-indicator channels) and one-hot labels.
_sample_grid = (_TIMES, _HEIGHTS)
output_signature = (
    tf.TensorSpec(shape=(*_sample_grid, 2 * _N_FEATURES), dtype=tf.float32),
    tf.TensorSpec(shape=(*_sample_grid, num_classes), dtype=tf.float32),
)

print("Creating datasets using from_generator...")

In [None]:
# Training Dataset: shuffle samples (before batching, reshuffled every epoch),
# batch, repeat indefinitely, prefetch.
train_dataset_gen = (
    tf.data.Dataset.from_generator(
        lambda: data_generator(x_train, y_train),
        output_signature=output_signature,
    )
    .shuffle(SHUFFLE_BUFFER_SIZE_TRAIN, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)
print("Train dataset created from generator.")

In [None]:
# Validation Dataset: finite and unshuffled.
valid_dataset_gen = tf.data.Dataset.from_generator(
    lambda: data_generator(x_valid, y_valid),
    output_signature=output_signature
)
valid_dataset_gen = valid_dataset_gen.batch(BATCH_SIZE)
# This cell used to also run `train_dataset_gen = train_dataset_gen.repeat()` (a
# copy-paste slip). It never touched the validation set and re-wrapped the
# already infinite training pipeline on every run of this cell, so it was removed.
valid_dataset_gen = valid_dataset_gen.prefetch(tf.data.AUTOTUNE)
print("Validation dataset created from generator.")

In [None]:
# Test Dataset: fixed order (no shuffle, no repeat), so batches follow x_test/y_test order.
test_dataset_gen = (
    tf.data.Dataset.from_generator(
        lambda: data_generator(x_test, y_test),
        output_signature=output_signature,
    )
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
print("Test dataset created from generator.")

In [None]:
print("\nChecking one batch from generator dataset:")
for x_batch_check, y_batch_check in train_dataset_gen.take(1):
    print("Input batch shape:", x_batch_check.shape)
    print("Output batch shape:", y_batch_check.shape)

In [None]:
# Sanity-check the validation pipeline: fetch a few batches and warn if batches,
# or samples within a batch, look identical (a sign of a broken generator).
print("Fetching a few batches from validation dataset to check diversity...")
fetched_info = []
num_to_check = 5 # Check first 5 batches

# Cardinality is usually unknown for from_generator datasets; -1 means "unknown" here.
val_cardinality = tf.data.experimental.cardinality(valid_dataset_gen)
num_batches_total = val_cardinality.numpy() if val_cardinality!=tf.data.experimental.UNKNOWN_CARDINALITY else -1
print(f"Validation dataset cardinality (num batches): {num_batches_total if num_batches_total >= 0 else 'Unknown'}")
if 0 < num_batches_total < num_to_check:
    num_to_check = num_batches_total # Don't try to check more batches than exist

if num_batches_total == 0:
    print("!!! Warning: Validation dataset appears to have zero batches!")
else:
    try:
        for i, (x_batch, y_batch) in enumerate(valid_dataset_gen.take(num_to_check)):
            y_np = tf.cast(y_batch, tf.float32).numpy()
            info = (x_batch.shape, y_batch.shape, tf.reduce_mean(x_batch).numpy(), tf.reduce_mean(tf.cast(y_batch, tf.float32)).numpy())
            print(f"  Batch {i}: X shape {info[0]}, Y shape {info[1]}, X mean {info[2]:.4f}, Y mean {info[3]:.10f}")
            fetched_info.append(info[2:]) # Store means for comparison

            # Compare the first sample with every other one. The old check indexed
            # sample 3 directly and raised IndexError for batches with fewer than
            # 4 samples (e.g. a short final batch), aborting the whole check.
            if y_np.shape[0] > 1 and any(np.array_equal(y_np[0], y_np[k]) for k in range(1, y_np.shape[0])):
                print('!!! Warning: Samples within batch are identical')
        # Simplistic check: identical statistics across batches suggest repeated data.
        if len(fetched_info) > 1 and len(set(fetched_info)) == 1:
            print("!!! Warning: Multiple validation batches fetched seem identical! Check generator logic.")
        elif len(fetched_info) > 0:
            print("Validation batches seem diverse (based on mean/sum).")
        else:
            print("Could not fetch validation batches to check diversity.")

    except Exception as e:
        print(f"!!! Error occurred while fetching validation batches: {e}")
        import traceback
        traceback.print_exc()

## Unet

In [None]:
from IPython.display import clear_output
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import classification_report
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

class DisplayAndSaveCallback(Callback):
    def __init__(self, epochs_num_to_print = 2):
        self.epochs_num_to_print = epochs_num_to_print

    def on_epoch_end(self, epoch, logs=None):
        if epoch%self.epochs_num_to_print==0:
            # print predicted test results:
            y_pred_1 = model.predict(x_test)
            y_pred_onehot_1 = tf.one_hot(tf.argmax(y_pred_1[:,:,:,:], axis=3), y_pred_1[:,:,:,:].shape[3])
            print(classification_report(y_test.reshape([-1,12]), y_pred_onehot_1.reshape([-1,12])))
            # save current model:
            model_save_path = f'{model_name}/{str(epoch)}'
            model.save(model_save_path)
            print("model saved:", model_save_path)


class DisplayAndSaveCallback_2(Callback):
    """Periodically print a test-set classification report and save a ``.keras`` file.

    On every epoch where ``epoch % epochs_num_to_print == 0`` (including epoch 0)
    the model being trained predicts ``x_test``. Its arg-max prediction is one-hot
    encoded and compared with the one-hot ``y_test``. The model is then saved to
    ``model_name + str(epoch) + '.keras'``.

    :param epochs_num_to_print: report/save period in epochs.
    :param model_name: path prefix of the saved files (epoch and '.keras' are appended).
    :param n_classes: width of the one-hot label vectors used in the report.

    NOTE(review): reads the notebook globals ``x_test`` and ``y_test``; they must
    exist before training starts.
    """

    def __init__(self, epochs_num_to_print = 2, model_name='default_name', n_classes=12):
        super().__init__()  # set up Callback's own state (model/params slots)
        self.epochs_num_to_print = epochs_num_to_print
        self.model_name = model_name
        # Honour the argument (it used to be overwritten with a constant 12).
        self.n_classes = n_classes

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.epochs_num_to_print != 0:
            return
        # self.model is the model this callback is attached to (set by fit()).
        y_pred = np.asarray(self.model.predict(x_test))
        # Hard one-hot of the arg-max so it has the same layout as one-hot y_test.
        y_pred_onehot = np.eye(y_pred.shape[-1], dtype=np.float32)[np.argmax(y_pred, axis=-1)]
        print(classification_report(np.reshape(y_test, [-1, self.n_classes]),
                                    y_pred_onehot.reshape([-1, self.n_classes])))
        # save current model:
        model_save_path = self.model_name+str(epoch)+'.keras'
        self.model.save(model_save_path)
        print("model saved:", model_save_path)

In [None]:
import tensorflow as tf
# Ensure necessary layers are imported
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Conv2DTranspose, concatenate, Rescaling, Lambda, Cropping2D, BatchNormalization, Activation
from tensorflow.keras.models import Model

def build_unet(img_shape, n_classes=12, encoder_filters=(64, 128, 256, 512),
               encoder_dropout=0.1, bottleneck_dropout=0.2, decoder_dropout=0.2):
    """Build a U-Net that outputs per-pixel softmax class probabilities.

    Encoder: for each entry ``f`` of ``encoder_filters``:
      - two 3x3 Conv-BN-ReLU layers with ``f`` filters (dropout after the first);
      - 2x2 max pooling with ``padding='same'``, so odd sizes round up.
    Bottleneck: the same double convolution with ``2 * encoder_filters[-1]`` filters.
    Decoder: for each level in reverse:
      - a stride-2 transposed convolution (+BN, ReLU);
      - a crop of any excess rows/columns so it matches the skip connection;
      - concatenation with the skip;
      - a double convolution.
    Output: a final 1x1 softmax convolution with ``n_classes`` channels.

    :param img_shape: input shape without the batch axis (the training cell passes
        ``(img_width, img_height, img_channels)``).
    :param n_classes: number of output classes.
    :param encoder_filters: filters per encoder level, shallowest first; its length
        sets the depth. The default reproduces the original 4-level network.
    :param encoder_dropout: dropout rate inside each encoder block.
    :param bottleneck_dropout: dropout rate inside the bottleneck block.
    :param decoder_dropout: dropout rate inside each decoder block.
    :return: an uncompiled ``Model``.
    :raises ValueError: if ``encoder_filters`` is empty.
    """
    encoder_filters = list(encoder_filters)
    if not encoder_filters:
        raise ValueError("encoder_filters must contain at least one filter count")

    # input layer shape is equal to patch image size
    inputs = Input(shape=img_shape)
    x = inputs
    skips = []  # encoder outputs before pooling, shallowest level first

    # Contraction path (layers are created in the same order as before, so
    # auto-generated layer names are unchanged for the default configuration).
    for f in encoder_filters:
        x = Conv2D(f, (3, 3), kernel_initializer='he_normal', padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Dropout(encoder_dropout)(x)
        x = Conv2D(f, (3, 3), kernel_initializer='he_normal', padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        skips.append(x)
        x = MaxPooling2D((2, 2), padding='same')(x)

    # Bottleneck: double convolution with twice the deepest encoder width.
    bottleneck_filters = encoder_filters[-1] * 2
    x = Conv2D(bottleneck_filters, (3, 3), kernel_initializer='he_normal', padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(bottleneck_dropout)(x)
    x = Conv2D(bottleneck_filters, (3, 3), kernel_initializer='he_normal', padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Expansive path: deepest level first, e.g. 512, 256, 128, 64 by default.
    for level in reversed(range(len(encoder_filters))):
        f = encoder_filters[level]
        skip_connection = skips[level]

        x_upsampled = Conv2DTranspose(f, (2, 2), strides=(2, 2), padding='same')(x)
        x_upsampled = BatchNormalization()(x_upsampled)
        x_upsampled = Activation('relu')(x_upsampled)

        # 'same' pooling rounds odd sizes up, so after x2 upsampling the map can be
        # one row/column larger than the skip; crop the excess. Unknown (None)
        # dims cannot be cropped statically and are left untouched.
        upsampled_h, upsampled_w = x_upsampled.shape[1], x_upsampled.shape[2]
        skip_h, skip_w = skip_connection.shape[1], skip_connection.shape[2]
        crop_h = max(0, upsampled_h - skip_h) if None not in (upsampled_h, skip_h) else 0
        crop_w = max(0, upsampled_w - skip_w) if None not in (upsampled_w, skip_w) else 0

        if crop_h > 0 or crop_w > 0:
            crop_name = f'crop_upsampled_to_skip_{f}'
            if encoder_filters.count(f) > 1:
                crop_name += f'_level{level}'  # keep layer names unique
            x_upsampled = Cropping2D(cropping=(
                (crop_h // 2, crop_h - crop_h // 2),  # top, bottom
                (crop_w // 2, crop_w - crop_w // 2)   # left, right
            ), name=crop_name)(x_upsampled)

        x = concatenate([x_upsampled, skip_connection])
        x = Conv2D(f, (3, 3), kernel_initializer='he_normal', padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Dropout(decoder_dropout)(x)
        x = Conv2D(f, (3, 3), kernel_initializer='he_normal', padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Per-pixel class probabilities.
    outputs = Conv2D(filters=n_classes, kernel_size=(1, 1), activation="softmax")(x)
    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Input geometry for build_unet; the training cell builds the model with
# img_shape=(img_width, img_height, img_channels).
img_width = _TIMES
img_height = _HEIGHTS
img_channels = 2 * _N_FEATURES

# Synchronous data-parallel training: one model replica per visible device.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

In [None]:
def jaccard_index(y_true, y_pred, smooth=1.0):
    """Soft (differentiable) Jaccard index / IoU pooled over the whole batch.

    Both tensors are flattened, so every sample, pixel and class contributes to
    a single score::

        (sum(t * p) + smooth) / (sum(t) + sum(p) - sum(t * p) + smooth)

    :param y_true: ground-truth masks (e.g. one-hot); any numeric dtype.
    :param y_pred: predicted probabilities, same shape as ``y_true``.
    :param smooth: additive smoothing; keeps the ratio defined (equal to 1) when
        both inputs are all zeros.
    :return: scalar tensor.
    """
    # Cast like the loss/metric functions in this notebook do, so integer masks
    # or float16 predictions neither fail nor get promoted to float64.
    y_true_f = K.flatten(tf.cast(y_true, tf.float32))
    y_pred_f = K.flatten(tf.cast(y_pred, tf.float32))
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)

In [None]:
def multiclass_weighted_squared_dice_loss(class_weights: Union[list, np.ndarray, tf.Tensor], epsilon: float = 1e-6):
    """
    Weighted squared Dice loss.
    Used as loss function for multi-class image segmentation with one-hot encoded masks.
    :param class_weights: Class weight coefficients (Union[list, np.ndarray, tf.Tensor], len=<N_CLASSES>)
    :param epsilon: Small constant added to numerator and denominator so the ratio
        stays finite if the weighted denominator is 0.
    :return: Weighted squared Dice loss function (Callable[[tf.Tensor, tf.Tensor], tf.Tensor])
    """
    # Always work in float32. Mixing float32 predictions with, e.g., a float64
    # ndarray of weights either fails or silently promotes the whole loss to
    # float64, depending on whether TF's NumPy behaviour is enabled.
    class_weights = tf.cast(class_weights, tf.float32)

    def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """
        Compute weighted squared Dice loss.
        :param y_true: True masks (tf.Tensor, shape=(<BATCH_SIZE>, <IMAGE_HEIGHT>, <IMAGE_WIDTH>, <N_CLASSES>))
        :param y_pred: Predicted masks (tf.Tensor, shape=(<BATCH_SIZE>, <IMAGE_HEIGHT>, <IMAGE_WIDTH>, <N_CLASSES>))
        :return: Weighted squared Dice loss (tf.Tensor, shape=(None,))
        """
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        axis_to_reduce = tuple(range(1, K.ndim(y_pred)))  # all axes but the batch axis
        # class_weights broadcasts along the last (class) axis.
        numerator = 2. * K.sum(y_true * y_pred * class_weights, axis=axis_to_reduce)
        denominator = K.sum((y_true**2 + y_pred**2) * class_weights, axis=axis_to_reduce)

        return 1 - (numerator + epsilon) / (denominator + epsilon)

    return loss

In [None]:
# Class-index groups used by the group-confusion loss and metric. With the class
# names in labels1, indices 3-6 are the aerosol classes and 7-11 the cloud classes.
AEROSOL_INDICES = tf.constant([3, 4, 5, 6], dtype=tf.int32)
CLOUD_INDICES = tf.constant([7, 8, 9, 10, 11], dtype=tf.int32)

def Dice_plus_GroupConfusion_Loss(
  class_weights: Union[List, np.ndarray, tf.Tensor],
  group_penalty_factor: float = 0.5, # Hyperparameter to weigh the group penalty
  epsilon: float = 1e-6, # Small constant to prevent division by zero
  aerosol_indices: Union[List, tf.Tensor] = AEROSOL_INDICES,
  cloud_indices: Union[List, tf.Tensor] = CLOUD_INDICES,
) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
  """
  Combined loss: Weighted squared Dice loss + Group Confusion Penalty.
  Penalizes probability mass that is put on the aerosol group where the truth is
  a cloud class, and vice versa.

  :param class_weights: Class weight coefficients for Dice loss (one per class).
  :param group_penalty_factor: Weighting factor for the group confusion penalty.
  :param epsilon: Small constant for numerical stability.
  :param aerosol_indices: Class indices forming the aerosol group.
  :param cloud_indices: Class indices forming the cloud group.
  :return: Combined loss function ``loss(y_true, y_pred)``. Keep that inner
      name: the model-loading cell refers to the saved loss as 'loss'.
  """
  # Cast unconditionally. A float64 ndarray or tensor would otherwise make the
  # loss run in float64 (or fail) when multiplied with float32 predictions.
  class_weights = tf.cast(class_weights, tf.float32)
  aerosol_indices = tf.cast(aerosol_indices, tf.int32)
  cloud_indices = tf.cast(cloud_indices, tf.int32)

  def loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    """
    Compute combined loss.
    :param y_true: True masks (tf.Tensor, shape=(<BATCH_SIZE>, <H>, <W>, <N_CLASSES>), one-hot)
    :param y_pred: Predicted masks (tf.Tensor, shape=(<BATCH_SIZE>, <H>, <W>, <N_CLASSES>), softmax probabilities)
    :return: Combined loss (tf.Tensor, shape=(None,))
    """
    y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(y_true, tf.float32)

    # --- Weighted squared Dice component, reduced over all non-batch axes ---
    axis_to_reduce = tuple(range(1, K.ndim(y_pred)))
    numerator = 2. * K.sum(y_true * y_pred * class_weights, axis=axis_to_reduce)
    # For one-hot y_true, y_true == y_true**2, so the cheaper form is used.
    denominator = K.sum((y_true + y_pred**2) * class_weights, axis=axis_to_reduce)
    dice_loss = 1.0 - (numerator + epsilon) / (denominator + epsilon)

    # --- Group confusion penalty ---
    # Per pixel: summed predicted probability of each group, and whether the
    # true class belongs to it.
    y_pred_aerosol_probs = K.sum(tf.gather(y_pred, aerosol_indices, axis=-1), axis=-1)
    y_pred_cloud_probs = K.sum(tf.gather(y_pred, cloud_indices, axis=-1), axis=-1)
    y_true_is_aerosol = K.sum(tf.gather(y_true, aerosol_indices, axis=-1), axis=-1)
    y_true_is_cloud = K.sum(tf.gather(y_true, cloud_indices, axis=-1), axis=-1)

    # Probability mass on the *other* group.
    pixel_group_penalty = (y_true_is_aerosol * y_pred_cloud_probs) + (y_true_is_cloud * y_pred_aerosol_probs)
    # Average over the spatial axes -> one value per sample.
    sample_group_penalty = K.mean(pixel_group_penalty, axis=tuple(range(1, K.ndim(pixel_group_penalty))))

    # Both terms are per-sample, shape [BATCH_SIZE].
    return dice_loss + (group_penalty_factor * sample_group_penalty)

  return loss

In [None]:
def aerosol_cloud_confusion(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
  """
  Metric: how much predicted probability lands in the wrong aerosol/cloud group.

  For every pixel whose true class is an aerosol class, the predicted cloud-group
  probability is counted; for every true cloud pixel, the predicted aerosol-group
  probability. Other pixels contribute 0. Lower is better.

  :param y_true: True masks, shape (batch, H, W, n_classes), one-hot encoded.
  :param y_pred: Predicted softmax probabilities, same shape.
  :return: Per-sample mean confusion, shape (batch,); Keras averages it further
      over batches and the epoch.
  """
  truth = tf.cast(y_true, tf.float32)
  probs = tf.cast(y_pred, tf.float32)

  def group_total(t, indices):
    # Collapse the class axis to the summed value over one class group.
    return K.sum(tf.gather(t, indices, axis=-1), axis=-1)

  wrong_group_mass = (
      group_total(truth, AEROSOL_INDICES) * group_total(probs, CLOUD_INDICES)
      + group_total(truth, CLOUD_INDICES) * group_total(probs, AEROSOL_INDICES)
  )

  non_batch_axes = tuple(range(1, K.ndim(wrong_group_mass)))
  return K.mean(wrong_group_mass, axis=non_batch_axes)

## Train

In [None]:
if train:
  # NOTE(review): `Kr` is not referenced in this part of the notebook, and
  # `tensorflow.python.keras` is TensorFlow's private legacy Keras package that
  # may be missing in newer TF builds -- confirm it is still needed.
  import tensorflow.python.keras.backend as Kr
  %load_ext tensorboard
  # PARAMETERS
  # Weight of the aerosol/cloud group-confusion term added to the Dice loss
  # (passed to Dice_plus_GroupConfusion_Loss in the training cell).
  group_penalty_factor = 1

In [None]:
if train:
  print(model_name)

  # Make sure the run directory (and any missing parents) exists before the
  # callbacks try to write model files into it.
  if os.path.exists(model_run_name):
      print(f"Directory '{model_run_name}' already exists.")
  else:
      try:
          os.makedirs(model_run_name)
          print(f"Directory '{model_run_name}' created successfully.")
      except OSError as err:
          print(f"Error creating directory '{model_run_name}': {err}")

In [None]:
if train:
  # Launch TensorBoard on this model's logs; the TensorBoard callback in the
  # training cell writes under logs/{model_name}/fit/.
  %tensorboard --logdir logs/{model_name}

In [None]:
if train:
  with strategy.scope():
    model = build_unet(img_shape=(img_width, img_height, img_channels), n_classes = num_classes)
    model.summary()  # prints the summary itself and returns None

    # Keep the best model of this run inside the run directory. The monitored key
    # must be something the model actually logs: it is compiled with the loss plus
    # the `jaccard_index` and `aerosol_cloud_confusion` metrics only. A monitor of
    # "val_accuracy" therefore never exists, and ModelCheckpoint skips every save.
    model_checkpoint_filepath = f'{model_run_name}/Saved_Model.keras'
    checkpoint = tf.keras.callbacks.ModelCheckpoint(model_checkpoint_filepath, monitor="val_jaccard_index", verbose=1, save_best_only=True, mode="max")

    # Stop training if validation loss has not improved for 20 epochs (only
    # checked from epoch 50 on) and restore the best weights seen.
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=20, verbose=1, mode="min", restore_best_weights=True, start_from_epoch=50)

    # log training console output to csv
    csv_logger = tf.keras.callbacks.CSVLogger('csv_logger', separator=",", append=False)

    # create list of callbacks
    log_dir = f"logs/{model_name}/fit/"
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    epochs_num_to_print_pred = 20  # print a test-set report and save the model every n epochs
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                                  patience=10, min_lr=5e-7, verbose=1)
    callbacks_list = [checkpoint, csv_logger, tensorboard_callback,
                      DisplayAndSaveCallback_2(epochs_num_to_print_pred, f'{model_run_name}/model_'),
                      reduce_lr, early_stopping]

    # One step per full batch plus one for a trailing partial batch.
    steps_per_epoch = math.ceil(len(x_train) / BATCH_SIZE)
    print(f'steps per epoch = {steps_per_epoch}')

    # compile model
    optimizer = keras.optimizers.Adam(learning_rate=5e-3)
    model.compile(optimizer=optimizer, loss=Dice_plus_GroupConfusion_Loss(class_weights, group_penalty_factor), metrics=[jaccard_index, aerosol_cloud_confusion])

    # Training phase
    history = model.fit(train_dataset_gen, epochs=500, validation_data=valid_dataset_gen,
                        callbacks=callbacks_list, verbose=1, validation_steps=4,
                        steps_per_epoch=steps_per_epoch)

    # Save model (weights are the restored best ones if early stopping fired)
    model_save_path = f'{model_run_name}/unet_last_epoch.keras'
    model.save(model_save_path)
    print("model saved:", model_save_path)


In [None]:
if train:
  # Classification report on the validation set.
  # NOTE(review): predictions come from valid_dataset_gen but targets from
  # y_valid. The report is only meaningful if the generator is finite and yields
  # each validation sample exactly once, in y_valid's order (fit() above limits
  # validation to validation_steps=4) -- confirm.
  with strategy.scope():
    y_pred = model.predict(valid_dataset_gen)
    n_pred_classes = y_pred.shape[-1]
    # Hard one-hot of the arg-max, matching the one-hot layout of y_valid.
    y_pred_onehot = np.eye(n_pred_classes, dtype=np.float32)[np.argmax(y_pred, axis=-1)]
    print('Classification report for validation dataset on model')
    print(classification_report(np.reshape(y_valid, [-1, n_pred_classes]), y_pred_onehot.reshape([-1, n_pred_classes])))

## Results Analysis

### Import model

In [None]:
# Display name for each class index (labels1[i] names class i).
labels1 = [
    '0-No Class',
    '1-Clean atmosphere',
    '2-Non-typed particles/low conc',
    '3-Aerosol: small',
    '4-Aerosol: large,spherical',
    '5-Aerosol: mixture, partly non-spherical',
    '6-Aerosol: large, non-spherical',
    '7-Cloud: non-typed',
    '8-Cloud: water droplets',
    '9-Cloud: likely water droplets',
    '10-Cloud: ice crystals',
    '11-Cloud: likely ice crystals',
]

In [None]:
import keras
# The reloaded model is only used for inference (model.predict) in the cells
# below, so the training configuration is not restored (compile=False).
# The saved loss is the closure named 'loss' returned by
# Dice_plus_GroupConfusion_Loss(...). Mapping 'loss' to the factory itself would
# make a compiled model call the factory as if it were the loss, which breaks
# evaluate()/fit() on the reloaded model.
model = keras.saving.load_model(
    f'{model_run_name}/unet_last_epoch.keras',
    custom_objects={
        'jaccard_index': jaccard_index,
        'aerosol_cloud_confusion': aerosol_cloud_confusion},
    compile=False, safe_mode=False)

In [None]:
%load_ext tensorboard
# Collect log directories whose path starts with 'logs/' + model_run_name and
# take the lexicographically last one (presumably the most recent run if the
# names end in a sortable timestamp).
# NOTE(review): `latest_run` and the `files_glob` alias are not used -- the
# TensorBoard command below points at a hard-coded run. np.sort(...)[-1] raises
# IndexError when nothing matches.
get_model_run = files_glob = tf.io.gfile.glob('logs/' + model_run_name + "*")
latest_run = np.sort(get_model_run)[-1]
%tensorboard --logdir logs/{'4_layer_large_filter_no_class0_mean1_group_and_dice_penalty1_data_globalnorm_20250522-133149'} --port=6010

In [None]:
import matplotlib.pyplot as plt
from tbparse import SummaryReader

# TensorBoard log directory of the (hard-coded) run to analyse.
log_dir = 'logs/4_layer_large_filter_no_class0_mean1_group_and_dice_penalty1_data_globalnorm_20250522-133149'

# Parse the TensorBoard event files under log_dir with tbparse.
reader = SummaryReader(log_dir)

In [None]:
log_dir = 'logs/4_layer_large_filter_no_class0_mean1_group_and_dice_penalty1_data_globalnorm_20250522-133149'
reader = SummaryReader(log_dir)
df_tensors = reader.tensors

TARGET_TAG = 'epoch_loss'
# A stable sort keeps rows that share a step in the order the reader produced
# them; the default (quicksort) may swap such ties, which would scramble the
# alternating split below.
df_filtered = (df_tensors[df_tensors['tag'] == TARGET_TAG]
               .sort_values('step', kind='stable')
               .reset_index(drop=True))

# Even rows -> training, odd rows -> validation.
# NOTE(review): assumes each step has exactly one training row followed by one
# validation row in the reader's output -- confirm.
df_set1 = df_filtered.iloc[0::2].copy()
df_set2 = df_filtered.iloc[1::2].copy()

# Used as pandas' EWM `alpha`: larger means *less* smoothing (roughly
# 1 - TensorBoard's smoothing slider).
smoothing_weight = 0.3

df_set1['smoothed_value'] = df_set1['value'].ewm(alpha=smoothing_weight).mean()
df_set2['smoothed_value'] = df_set2['value'].ewm(alpha=smoothing_weight).mean()

target_step = 146
# Raw values at the marked step (NaN if that step was not logged, instead of an IndexError).
val1_at_step = next(iter(df_set1.loc[df_set1['step'] == target_step, 'value']), np.nan)
val2_at_step = next(iter(df_set2.loc[df_set2['step'] == target_step, 'value']), np.nan)

plt.figure(figsize=(7, 4))

plt.plot(df_set1['step'], df_set1['smoothed_value'], label='Training', color='#01153E')
plt.plot(df_set2['step'], df_set2['smoothed_value'], label='Validation', color='#40E0D0')

# Mark the step of interest.
plt.axvline(x=target_step, color='red', linestyle=':', linewidth=2, label=f'Step {target_step}')

plt.ylim(top=0.5)
plt.xlabel('Step')
plt.ylabel(TARGET_TAG)
plt.legend(['Training','Validation'],fontsize=8)
plt.show()

In [None]:

TARGET_TAG = 'epoch_aerosol_cloud_confusion'
# A stable sort keeps rows that share a step in the order the reader produced
# them; the default (quicksort) may swap such ties, which would scramble the
# alternating split below.
df_filtered = (df_tensors[df_tensors['tag'] == TARGET_TAG]
               .sort_values('step', kind='stable')
               .reset_index(drop=True))

# Even rows -> training, odd rows -> validation.
# NOTE(review): assumes each step has exactly one training row followed by one
# validation row in the reader's output -- confirm.
df_set1 = df_filtered.iloc[0::2].copy()
df_set2 = df_filtered.iloc[1::2].copy()

# Used as pandas' EWM `alpha`: larger means *less* smoothing.
smoothing_weight = 0.4

df_set1['smoothed_value'] = df_set1['value'].ewm(alpha=smoothing_weight).mean()
df_set2['smoothed_value'] = df_set2['value'].ewm(alpha=smoothing_weight).mean()

target_step = 146
# Raw values at the marked step (NaN if that step was not logged, instead of an IndexError).
val1_at_step = next(iter(df_set1.loc[df_set1['step'] == target_step, 'value']), np.nan)
val2_at_step = next(iter(df_set2.loc[df_set2['step'] == target_step, 'value']), np.nan)

plt.figure(figsize=(7, 4))

plt.plot(df_set1['step'], df_set1['smoothed_value'], label='Training', color='#01153E')
plt.plot(df_set2['step'], df_set2['smoothed_value'], label='Validation', color='#40E0D0')

# Mark the step of interest.
plt.axvline(x=target_step, color='red', linestyle=':', linewidth=2, label=f'Step {target_step}')

plt.ylim(top=0.008, bottom=0.001)
plt.xlabel('Step')
plt.ylabel(TARGET_TAG)
plt.legend(['Training','Validation'],fontsize=8)
plt.show()

In [None]:

TARGET_TAG = 'epoch_jaccard_index'
# A stable sort keeps rows that share a step in the order the reader produced
# them; the default (quicksort) may swap such ties, which would scramble the
# alternating split below.
df_filtered = (df_tensors[df_tensors['tag'] == TARGET_TAG]
               .sort_values('step', kind='stable')
               .reset_index(drop=True))

# Even rows -> training, odd rows -> validation.
# NOTE(review): assumes each step has exactly one training row followed by one
# validation row in the reader's output -- confirm.
df_set1 = df_filtered.iloc[0::2].copy()
df_set2 = df_filtered.iloc[1::2].copy()

# Used as pandas' EWM `alpha`: larger means *less* smoothing, so 0.01 smooths heavily.
smoothing_weight = 0.01

df_set1['smoothed_value'] = df_set1['value'].ewm(alpha=smoothing_weight).mean()
df_set2['smoothed_value'] = df_set2['value'].ewm(alpha=smoothing_weight).mean()

target_step = 146
# Raw values at the marked step (NaN if that step was not logged, instead of an IndexError).
val1_at_step = next(iter(df_set1.loc[df_set1['step'] == target_step, 'value']), np.nan)
val2_at_step = next(iter(df_set2.loc[df_set2['step'] == target_step, 'value']), np.nan)

plt.figure(figsize=(7, 4))

plt.plot(df_set1['step'], df_set1['smoothed_value'], label='Training', color='#01153E')
plt.plot(df_set2['step'], df_set2['smoothed_value'], label='Validation', color='#40E0D0')

# Mark the step of interest.
plt.axvline(x=target_step, color='red', linestyle=':', linewidth=2, label=f'Step {target_step}')

plt.xlabel('Step')
plt.ylabel(TARGET_TAG)
plt.legend(['Training','Validation'],fontsize=8)
plt.show()

In [None]:
# Predict the test set and print a per-class report (one-hot vs one-hot).
# NOTE(review): rows of y_pred are compared with rows of y_test, so
# test_dataset_gen must yield every test sample once, in y_test's order -- confirm.
y_pred = model.predict(test_dataset_gen)
n_pred_classes = y_pred.shape[-1]
# Hard one-hot of the arg-max (NumPy array, so it no longer depends on TF's
# NumPy behaviour for `.reshape`).
y_pred_onehot = np.eye(n_pred_classes, dtype=np.float32)[np.argmax(y_pred, axis=-1)]
print('Classification report test dataset')
print(classification_report(np.reshape(y_test, [-1, n_pred_classes]), y_pred_onehot.reshape([-1, n_pred_classes]), target_names=labels1))

In [None]:
# Collapse one-hot maps to integer class maps (class index per pixel).
Y_TEST = np.argmax(y_test, axis=3)
Y_PRED = np.argmax(y_pred_onehot, axis=3)

# Fix the label set so the matrix is always num_classes x num_classes and its
# rows/columns line up with labels1. Without `labels=`, classes that never occur
# are dropped and every later tick label shifts.
results = confusion_matrix(Y_TEST.ravel(), Y_PRED.ravel(), normalize='true',
                           labels=np.arange(num_classes))

In [None]:
def plot_confusion_matrix(results, labels1, fmt='.1f'):
  """Show a confusion matrix as an annotated heatmap in percent.

  :param results: square matrix with rows = true class and columns = predicted
      class, holding fractions (e.g. confusion_matrix(..., normalize='true'));
      it is multiplied by 100 for display.
  :param labels1: one tick label per row/column of `results`, in order.
  :param fmt: format spec for the cell annotations. seaborn's default ('.2g')
      would render 100 as '1e+02'.
  """
  import warnings

  results = np.asarray(results)
  n_labels = len(labels1)
  if results.shape != (n_labels, n_labels):
    # The tick labels would not line up with the matrix; say so instead of
    # silently mislabelling the plot.
    warnings.warn(f"confusion matrix shape {results.shape} does not match {n_labels} labels")

  plt.figure(figsize=[16,8])
  ax = sns.heatmap(results*100, annot=True, fmt=fmt, cmap='Blues')

  # One tick per label, placed at the centre of each heatmap cell.
  tick_positions = np.arange(n_labels) + 0.5
  ax.set_xticks(tick_positions)
  ax.set_yticks(tick_positions)
  ax.set_xticklabels(labels1, rotation=45, ha="right",
                      rotation_mode="anchor", fontsize=10)
  ax.set_yticklabels(labels1, rotation=0, ha="right",
                      rotation_mode="anchor", fontsize=10)

  ax.set_title('Confusion Matrix [%]\n\n', fontsize=15)
  ax.set_xlabel('\nPredicted Values', fontsize=10)
  ax.set_ylabel('Actual Values ', fontsize=10)

  plt.show()

plot_confusion_matrix(results, labels1)

### Analyze results based on height

In [None]:
# Mean IoU per height index n (over Y_TEST[:, :, n]), excluding class 0.
# IoU is computed from the row-normalised ('true') confusion matrix, so every
# true class weighs the same regardless of its pixel count.
mean_iou_per_row = np.zeros(heights_min_dim)
non_0_y_test = np.zeros(heights_min_dim)

for n in range(heights_min_dim):
  y_test_row = Y_TEST[:, :, n]
  y_pred_row = Y_PRED[:, :, n]

  # Ground-truth pixels at this height that are not class 0.
  y_test_non_0_count = np.count_nonzero(y_test_row)
  non_0_y_test[n] = y_test_non_0_count

  cm_row = confusion_matrix(y_test_row.ravel(), y_pred_row.ravel(), normalize='true', labels=np.arange(num_classes))
  tp = np.diag(cm_row)
  union = cm_row.sum(axis=0) + cm_row.sum(axis=1) - tp  # tp + fp + fn
  # A class absent from both truth and prediction has union == 0. Mark it NaN
  # so nanmean skips it instead of counting it as IoU 0.
  with np.errstate(divide='ignore', invalid='ignore'):
    iou_per_class = np.where(union > 0, tp / union, np.nan)
  relevant_iou_scores = iou_per_class[1:]  # exclude class 0
  mean_iou_per_row[n] = (np.nanmean(relevant_iou_scores)
                         if np.any(~np.isnan(relevant_iou_scores)) else np.nan)


In [None]:
# --- Plotting: mean IoU per height index, coloured by non-class-0 pixel count ---
x_axis_coords = np.arange(heights_min_dim)

fig = px.scatter(
    x=x_axis_coords,
    y=mean_iou_per_row,
    color=non_0_y_test,
    color_continuous_scale=px.colors.sequential.Viridis,
    labels={'color': 'Count'},
    width=1000,
    height=500,
)
fig.update_xaxes(title_text="Height Index")
fig.update_yaxes(title_text="Jaccard Index")
fig.show()


### Analyze actual images

In [None]:
class MidpointNormalize(mpl.colors.Normalize):
    """Piecewise-linear normalisation: vmin -> 0, vcenter -> 0.5, vmax -> 1.

    Values below vmin / above vmax are extrapolated to -inf / +inf rather than
    clipped. If vmin/vmax are None they are filled from the first data passed in
    (as Normalize does). If vcenter is None, the midpoint of vmin and vmax is used.
    Based on the matplotlib colormap-normalization example; matplotlib's
    ``colors.TwoSlopeNorm`` provides the same mapping.
    """

    def __init__(self, vmin=None, vmax=None, vcenter=None, clip=False):
        self.vcenter = vcenter
        super().__init__(vmin, vmax, clip)

    def _breakpoints(self):
        # Data values mapped to 0, 0.5 and 1.
        vcenter = self.vcenter
        if vcenter is None:
            vcenter = 0.5 * (self.vmin + self.vmax)
        return [self.vmin, vcenter, self.vmax]

    def __call__(self, value, clip=None):
        # Fill in any missing limits from the data (no-op once both are set).
        self.autoscale_None(value)
        return np.ma.masked_array(np.interp(value, self._breakpoints(), [0, 0.5, 1.],
                                            left=-np.inf, right=np.inf))

    def inverse(self, value):
        return np.interp(value, [0, 0.5, 1], self._breakpoints(), left=-np.inf, right=np.inf)

In [None]:
# One colour per class index 0-11 (custom_colors[i] is the colour of class i),
# used as the discrete class colormap when plotting segmentations.
custom_colors = (
    # Background / clear air and non-typed: grays
    ['#f0f0f0',    # 0: No Class (very light gray)
     '#d9d9d9',    # 1: Clean atmosphere (light gray)
     '#969696']    # 2: Non-typed particles (medium gray)
    # Aerosol group 3-6: warm colours, yellow -> red
    + ['#fee08b', '#fdae61', '#f46d43', '#d73027']
    # Cloud group 7-11: cool colours (teal, blues, purples)
    + ['#66c2a5', '#3288bd', '#4393c3', '#5e4fa2', '#9e9ac8']
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.ndimage import gaussian_filter # For heatmap density

def _confusion_error_maps(y_true_image: np.ndarray, y_pred_image: np.ndarray):
    """Return (group_confusion, typed_confusion) uint8 masks for one label map.

    group_confusion is 1 where the true class is an aerosol class and the
    predicted one a cloud class, or vice versa. typed_confusion is 1 where the
    prediction is wrong and both the true and the predicted class are
    aerosol/cloud classes. Both masks have the shape of the inputs.
    """
    aerosol = np.asarray(AEROSOL_INDICES)
    cloud = np.asarray(CLOUD_INDICES)
    typed = np.concatenate([aerosol, cloud])

    is_true_aerosol = np.isin(y_true_image, aerosol)
    is_true_cloud = np.isin(y_true_image, cloud)
    is_pred_aerosol = np.isin(y_pred_image, aerosol)
    is_pred_cloud = np.isin(y_pred_image, cloud)

    group_confusion = (is_true_aerosol & is_pred_cloud) | (is_true_cloud & is_pred_aerosol)
    both_typed = np.isin(y_true_image, typed) & np.isin(y_pred_image, typed)
    typed_confusion = (y_true_image != y_pred_image) & both_typed
    return group_confusion.astype(np.uint8), typed_confusion.astype(np.uint8)


def plot_segmentation_results(
    y_true_all: np.ndarray,
    y_pred_all: np.ndarray,
    image_index: int,
    class_labels: list = range(0, num_classes),
    figsize: tuple = (34, 8),
    group_confusion_sigma: float = 1.0, # Sigma for Gaussian filter for density heatmap
    confusion_sigma: float = 0.5, # Sigma for Gaussian filter for density heatmap
    with_suptitle = True
):
    """
    Plot error-density heatmaps next to the ground-truth and predicted class maps.

    Panels, left to right:
      1. Gaussian-smoothed density of aerosol<->cloud group confusions.
      2. Gaussian-smoothed density of any misclassification where both the true
         and the predicted class are aerosol/cloud classes.
      3. Ground-truth class map.
      4. Predicted class map.
    Images are transposed for display (axis 1 of each label map runs vertically),
    and every y-axis is inverted so index 0 is drawn at the bottom.

    Args:
        y_true_all: ground-truth integer labels, shape (num_samples, image_height, image_width).
        y_pred_all: predicted integer labels, same shape.
        image_index: index of the sample to plot. Out-of-range indices print an error and return.
        class_labels: tick labels for the class colorbar, one per class index.
        figsize: figure size.
        group_confusion_sigma: Gaussian sigma for the group-confusion density.
        confusion_sigma: Gaussian sigma for the general-confusion density.
        with_suptitle: whether to draw a figure title with the image index.
    """
    if not (0 <= image_index < y_true_all.shape[0]):
        print(f"Error: image_index {image_index} is out of bounds. Please choose an index between 0 and {y_true_all.shape[0] - 1}.")
        return

    y_true_image = y_true_all[image_index, ...]
    y_pred_image = y_pred_all[image_index, ...]
    num_classes = len(class_labels)

    group_confusion_error_map, confusion_error_map = _confusion_error_maps(y_true_image, y_pred_image)

    # Smooth the binary error masks into densities; a larger sigma gives broader blobs.
    group_confusion_density_map = gaussian_filter(group_confusion_error_map.astype(float), sigma=group_confusion_sigma)
    confusion_density_map = gaussian_filter(confusion_error_map.astype(float), sigma=confusion_sigma)

    # Discrete class colormap: bin edges at half-integers, so integer k maps to custom_colors[k].
    cmap = mcolors.ListedColormap(custom_colors)
    norm = mcolors.BoundaryNorm(np.arange(-0.5, num_classes, 1), cmap.N)

    fig, axs = plt.subplots(1, 4, figsize=figsize)
    if with_suptitle:
      fig.suptitle(f"Segmentation Result for Image #{image_index}", fontsize=15, y=0.8)

    # (title, x-label, image, imshow options) for each panel.
    panels = [
        ("Aerosol-Cloud Group Confusion Density Heatmap", "Time (X-axis / Width)", group_confusion_density_map, dict(cmap='hot')),
        ("Confusion heatmap", "Time (X-axis)", confusion_density_map, dict(cmap='hot')),
        ("Ground Truth (Test Labels)", "Time (X-axis)", y_true_image, dict(cmap=cmap, norm=norm)),
        ("Model Prediction", "Time (X-axis)", y_pred_image, dict(cmap=cmap, norm=norm)),
    ]
    images = []
    for ax, (title, xlabel, image, imshow_options) in zip(axs, panels):
        images.append(ax.imshow(image.T, aspect='equal', **imshow_options))
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Height (Y-axis)")
        # imshow draws row 0 at the top; invert so height index 0 is at the bottom.
        ax.invert_yaxis()

    # --- Class colorbar in its own axes to the right of the panels ---
    ax_position = axs[1].get_position()
    fig.subplots_adjust(right=0.85)
    cbar_ax = fig.add_axes([
        0.87,
        ax_position.y0 + 0.01,
        0.02,
        ax_position.height - 0.02
    ])
    cbar = fig.colorbar(images[2], cax=cbar_ax)  # ground-truth image
    cbar.set_ticks(np.arange(num_classes))  # centre of each colour bin
    cbar.set_ticklabels(class_labels)
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label("Class Labels", rotation=270, labelpad=15)

    plt.show()

plot_segmentation_results(Y_TEST, Y_PRED, 10)

### Find case studies

In [None]:
mean_iou_per_row = np.zeros(heights_min_dim)
non_0_y_test = np.zeros(heights_min_dim)

def get_height_analysis_graph(y_test, y_pred, image_num):
  """Scatter-plot mean IoU (classes 1..n-1) against height index for one image.

  For every height index n, a row-normalised confusion matrix of
  y_test[image_num, :, n] vs y_pred[image_num, :, n] is built. The per-class IoU
  is then averaged over the classes that occur in that row (in truth or
  prediction), excluding class 0. Points are coloured by the number of non-zero
  ground-truth pixels at that height.

  NOTE: results are written into the module-level arrays `mean_iou_per_row` and
  `non_0_y_test`, overwriting their previous contents, and then plotted.

  :param y_test: integer ground-truth class maps, indexed as [image, :, height].
  :param y_pred: integer predicted class maps, same layout.
  :param image_num: index of the image to analyse.
  """
  labels = np.arange(num_classes)
  for n in range(heights_min_dim):
    y_test_row = y_test[image_num, :, n]
    y_pred_row = y_pred[image_num, :, n]

    non_0_y_test[n] = np.count_nonzero(y_test_row)

    cm_row = confusion_matrix(y_test_row.ravel(), y_pred_row.ravel(), normalize='true', labels=labels)
    tp = np.diag(cm_row)
    union = cm_row.sum(axis=0) + cm_row.sum(axis=1) - tp  # tp + fp + fn
    # Classes absent from both truth and prediction (union == 0) become NaN so
    # nanmean skips them instead of counting them as IoU 0.
    with np.errstate(divide='ignore', invalid='ignore'):
      iou_per_class = np.where(union > 0, tp / union, np.nan)
    relevant_iou_scores = iou_per_class[1:]  # exclude class 0
    mean_iou_per_row[n] = (np.nanmean(relevant_iou_scores)
                           if np.any(~np.isnan(relevant_iou_scores)) else np.nan)

  # --- Plotting ---
  y_axis_coords = np.arange(heights_min_dim)

  fig = px.scatter(x=mean_iou_per_row, y=y_axis_coords,
                  title='Mean IoU vs. Height (Excl Class 0)',
                  width=1000, height=500,
                  color=non_0_y_test,
                  color_continuous_scale=px.colors.sequential.Viridis,
                  labels={'color': 'Count'}
                  )
  fig.update_xaxes(title_text="Jaccard Value (iou)")
  fig.update_yaxes(title_text="Height Index")

  fig.show()

In [None]:
# Case study: image 36, evaluated on the first 7 classes only.
image_num = 36
n_case_labels = 7  # leading classes shown for this image
plot_segmentation_results(Y_TEST, Y_PRED, image_num, with_suptitle=False)

# Pin `labels` so the matrix is always n_case_labels x n_case_labels and lines up
# with the tick labels passed to plot_confusion_matrix, even if a class is absent.
results_36 = confusion_matrix(
    Y_TEST[image_num, ...].reshape(-1),
    Y_PRED[image_num, ...].reshape(-1),
    labels=np.arange(n_case_labels),
    normalize='true',
)
plot_confusion_matrix(results_36, labels1[:n_case_labels])

n_onehot = y_test.shape[-1]  # one-hot depth (previously hard-coded as 12)
print(classification_report(
    y_test[image_num, ...].reshape([-1, n_onehot])[..., :n_case_labels],
    y_pred_onehot[image_num, ...].reshape([-1, n_onehot])[..., :n_case_labels],
    target_names=labels1[:n_case_labels]))

In [None]:
# Case study: image 12, evaluated on all classes.
image_num = 12
plot_segmentation_results(Y_TEST, Y_PRED, image_num, with_suptitle=False)

# Pin `labels` so the matrix always has one row/column per entry of labels1,
# matching the tick labels even if some class is absent from this image.
results_12 = confusion_matrix(
    Y_TEST[image_num, ...].reshape(-1),
    Y_PRED[image_num, ...].reshape(-1),
    labels=np.arange(len(labels1)),
    normalize='true',
)
plot_confusion_matrix(results_12, labels1)

n_onehot = y_test.shape[-1]  # one-hot depth (previously hard-coded as 12)
print(classification_report(
    y_test[image_num, ...].reshape([-1, n_onehot]),
    y_pred_onehot[image_num, ...].reshape([-1, n_onehot]),
    target_names=labels1))

In [None]:
# Lidar signal attenuation with thick cloud
image_num = 70
plot_segmentation_results(Y_TEST, Y_PRED, image_num, with_suptitle=False)

# Pin `labels` so the matrix always has one row/column per entry of labels1,
# matching the tick labels even if some class is absent from this image.
results_70 = confusion_matrix(
    Y_TEST[image_num, ...].reshape(-1),
    Y_PRED[image_num, ...].reshape(-1),
    labels=np.arange(len(labels1)),
    normalize='true',
)
plot_confusion_matrix(results_70, labels1)

n_onehot = y_test.shape[-1]  # one-hot depth (previously hard-coded as 12)
print(classification_report(
    y_test[image_num, ...].reshape([-1, n_onehot]),
    y_pred_onehot[image_num, ...].reshape([-1, n_onehot]),
    target_names=labels1))

In [None]:
# Mixed phase clouds