# Training a simple CNN model in `keras` (>3.0) for Tornado Detection

This notebook steps through how to train a simple CNN model using a subset of TorNet.

This will not produce a model with any skill; it simply provides a working end-to-end example of how to set up a data loader and then build and fit a model.


In [2]:
import os
# Select the Keras backend. This has to be set before `keras` is imported.
# NOTE: later cells of this notebook call `tf.*` ops directly (CoordConv2D) and import
# `tornet.data.tf`, so only the 'tensorflow' backend is expected to work here.
os.environ["KERAS_BACKEND"] = "tensorflow"

In [3]:
import sys
# Uncomment if tornet isn't installed in your environment or in your path already
#sys.path.append('../')  

import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import keras

from tornet.data.tf.loader import create_tf_dataset 
from tornet.data.constants import ALL_VARIABLES

In [32]:
# keras accepts most data loaders (tensorflow, torch).
# A pure keras data loader, with necessary preprocessing steps for the cnn baseline, is provided
from tornet.data.keras.loader import KerasDataLoader
# Root directory of the TorNet files. Set the TORNET_ROOT environment variable to point
# at your own copy of the data; the previously hard-coded local path is kept as the
# fallback so existing setups keep working unchanged.
data_root = os.environ.get('TORNET_ROOT', "C:/Users/mjhig/tornet_2013")

# Settings shared by the training and validation loaders (only `years` differs).
loader_kwargs = dict(
    data_root=data_root,
    data_type='train',
    workers=4,
    batch_size=8,
    select_keys=ALL_VARIABLES,
    use_multiprocessing=True,
)

# Training data: years 2013-2017 of the 'train' split.
ds = KerasDataLoader(years=[2013, 2014, 2015, 2016, 2017], **loader_kwargs)

# Validation data: 2018 is held out from training; it is also read from the
# 'train' split (data_type='train'), not from the 'test' split.
ds_val = KerasDataLoader(years=[2018], **loader_kwargs)


In [5]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv2D

# Registered under its own package name so that a model containing this layer can be
# saved and re-loaded without passing `custom_objects`, and so that the registration
# cannot shadow a same-named layer registered by another library.
# (`keras` is imported in the imports cell at the top of the notebook.)
@keras.saving.register_keras_serializable(package="tornet_notebook")
class CoordConv2D(Layer):
    """Conv2D applied to the input plus two radial coordinate channels.

    For a channels-last input of shape (batch, H, W, C) the layer appends
    two extra channels before the convolution:

    * ``r``     - distance from the image centre, computed on a grid that runs
      from -1 to 1 along both spatial axes;
    * ``r_inv`` - ``1 / (r + 1e-6)``.

    The concatenated (batch, H, W, C + 2) tensor is then passed to a regular
    ``Conv2D``.

    Args:
        out_channels: Number of filters of the wrapped ``Conv2D``.
        kernel_size: Kernel size passed to ``Conv2D``.
        padding: Padding mode passed to ``Conv2D``.
        strides: Strides passed to ``Conv2D``.
        activation: Activation passed to ``Conv2D``.
        **kwargs: Forwarded to the base ``Layer`` constructor.

    Note:
        When H and W are both odd the grid contains the exact centre (r == 0),
        where ``r_inv`` evaluates to 1e6.
    """

    def __init__(self, out_channels, kernel_size, padding='same', strides=1, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.strides = strides
        self.activation = activation
        self.conv = None  # created in build()
        self.height = None
        self.width = None

    def build(self, input_shape):
        """Create (and build) the wrapped Conv2D once the input shape is known."""
        _, height, width, channels = input_shape
        # Static spatial size; may be None for variable-sized inputs.
        self.height = height
        self.width = width
        self.conv = Conv2D(self.out_channels, self.kernel_size, strides=self.strides,
                           padding=self.padding, activation=self.activation)
        # Build the sub-layer here rather than lazily on first call, so its weights
        # exist as soon as this layer is built (e.g. when restoring a saved model).
        # The convolution sees the two extra coordinate channels.
        if channels is not None:
            self.conv.build((None, height, width, channels + 2))

    def call(self, inputs):
        """Append the r / r_inv channels to `inputs` and apply the convolution."""
        dynamic_shape = tf.shape(inputs)
        batch_size = dynamic_shape[0]  # batch size is only known at run time
        # Prefer the static size recorded in build(); fall back to the run-time
        # shape when the spatial dimensions were not known at build time.
        height = self.height if self.height is not None else dynamic_shape[1]
        width = self.width if self.width is not None else dynamic_shape[2]

        # Coordinate grid spanning [-1, 1] along both spatial axes.
        x_range = tf.linspace(-1.0, 1.0, width)
        y_range = tf.linspace(-1.0, 1.0, height)
        X, Y = tf.meshgrid(x_range, y_range)  # each (H, W)
        r = tf.sqrt(X**2 + Y**2)
        r_inv = 1.0 / (r + 1e-6)  # epsilon avoids division by zero at the centre

        # (H, W, 2) -> (batch, H, W, 2); channel order is [r, r_inv].
        coords = tf.stack([r, r_inv], axis=-1)
        coords = tf.tile(coords[tf.newaxis, ...], [batch_size, 1, 1, 1])
        # Match the input dtype so the concat also works for non-float32 inputs.
        coords = tf.cast(coords, inputs.dtype)

        # (batch, H, W, channels + 2)
        x = tf.concat([inputs, coords], axis=-1)
        return self.conv(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({
            "out_channels": self.out_channels,
            "kernel_size": self.kernel_size,
            "padding": self.padding,
            "strides": self.strides,
            "activation": self.activation,
        })
        return config


In [33]:
# Samples in the training loader divided by the batch size, i.e. the number of
# training batches (shown as a float).
ds.num_batches * ds.batch_size / ds.batch_size

9548.0

In [34]:
# Number of validation batches (shown as a float). Uses the validation loader's own
# batch size throughout instead of mixing in the training loader's.
(ds_val.num_batches * ds_val.batch_size) / ds_val.batch_size

1920.0

In [6]:
# Create a simple CNN model
# This normalizes data, concatenates along channel, and applies a Conv2D
from tornet.data.constants import CHANNEL_MIN_MAX
from tornet.models.keras.layers import FillNaNs
from tensorflow.keras.layers import Dropout, BatchNormalization

input_vars = ALL_VARIABLES  # radar variables fed to the network

# Keras/TF tensors are channels-last: (batch, height, width, channels).
# Every variable gets its own (120, 240, 2) input; the last axis holds the 2 sweeps.
inputs = {}
for v in input_vars:
    inputs[v] = keras.Input(shape=(120, 240, 2), name=v)

# One Normalization layer per variable, configured from its (min, max) range so
# that the range maps approximately onto [-1, 1].
norm_layers = []
for v in input_vars:
    min_max = np.array(CHANNEL_MIN_MAX[v])  # [2,] -> (min, max)
    lo, hi = min_max[0], min_max[1]

    var = np.full(2, ((hi - lo) / 2) ** 2)  # [n_sweeps,] squared half-range
    offset = np.full(2, (lo + hi) / 2)      # [n_sweeps,] mid-point of the range

    norm_layers.append(
        keras.layers.Normalization(mean=offset, variance=var,
                                   name='Normalized_%s' % v)
    )

# Normalize each input and stack the results along the channel axis.
normed = [layer(inputs[v]) for layer, v in zip(norm_layers, input_vars)]
x = keras.layers.Concatenate(axis=-1, name='Concatenate1')(normed)

# Background pixels are NaN; replace them with -3 (outside the ~[-1, 1] data range).
x = FillNaNs(fill_val=-3, name='ReplaceNan')(x)

# Convolutional trunk: three CoordConv2D + Dropout stages (the third one downsamples) ...
for n_filters, stride in ((32, 1), (64, 1), (128, 2)):
    x = CoordConv2D(n_filters, (3, 3), strides=stride, padding='same', activation='relu')(x)
    x = Dropout(0.1)(x)

# ... followed by a final downsampling stage that also applies batch normalization
# to the activated convolution output.
x = CoordConv2D(256, (3, 3), strides=2, padding='same', activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.1)(x)

# Head: 1x1 convolution to a single-channel map, global pooling, small dense classifier.
x = keras.layers.Conv2D(1, 1, padding='same', activation=None, name='TornadoLikelihood')(x)
# NOTE(review): this layer is an *average* pool although it is named 'GlobalMaxPool';
# the name is left untouched here - confirm which of the two was intended.
x = keras.layers.GlobalAveragePooling2D(name='GlobalMaxPool')(x)
x = keras.layers.Dense(64, activation='relu')(x)
y = keras.layers.Dense(1, activation='sigmoid')(x)  # probability output, not a logit

model = keras.Model(inputs=inputs, outputs=y, name='TornadoDetector')

model.summary()




In [None]:
# ==================== 🌟 Imports 🌟 ==================== #
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.losses import BinaryFocalCrossentropy
from tensorflow.keras.metrics import AUC, BinaryAccuracy, Precision, Recall
from tensorflow.keras.optimizers import AdamW

# Metric watched by both callbacks below. 'pr_auc' is the *training* precision-recall
# AUC: `fit` is called with validation_data=None, so no 'val_pr_auc' entry is
# produced by Keras itself.
monitor = 'pr_auc'

# ==================== Callbacks ==================== #
# Stop when the monitored metric has not improved for 4 epochs and roll the model
# back to the best weights seen so far. mode='max' because a higher PR AUC is better.
early_stopping = EarlyStopping(
    monitor=monitor,
    mode='max',
    patience=4,
    restore_best_weights=True,
)

# Halve the learning rate after 2 epochs without improvement of the monitored
# metric, never going below 1e-6.
reduce_lr = ReduceLROnPlateau(
    monitor=monitor,
    mode='max',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
    verbose=1,
)

# ==================== 🎯 Custom Loss 🎯 ==================== #
# ⚖️ Weighted binary cross-entropy to handle class imbalance
def weighted_binary_crossentropy(pos_weight):
    """Build a binary cross-entropy loss that up-weights the positive class.

    Args:
        pos_weight: Multiplier applied to the loss of positive (label 1)
            elements; negative elements keep weight 1.

    Returns:
        A ``loss(y_true, y_pred)`` function returning the scalar mean of the
        weighted element-wise cross-entropy. ``y_pred`` must hold
        probabilities (not logits) and have the same shape as ``y_true``.
    """
    def loss(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        # Clip so log() stays finite for predictions of exactly 0 or 1.
        eps = 1e-7
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)

        # The cross-entropy is computed element-wise so that it keeps the shape of
        # y_true. `binary_crossentropy` would average over the last axis and return
        # shape (batch,); multiplying that by a (batch, 1) weight broadcasts to
        # (batch, batch) and loses the per-sample weighting.
        positive_term = pos_weight * y_true * tf.math.log(y_pred)
        negative_term = (1.0 - y_true) * tf.math.log(1.0 - y_pred)
        return tf.reduce_mean(-(positive_term + negative_term))
    return loss

import tensorflow as tf
import keras.backend as K



# ==================== 🚀 Optimizer & Metrics 🚀 ==================== #
# ⚡ AdamW optimizer with weight decay for regularization
# AdamW: Adam with decoupled weight decay, used here as regularization.
opt = AdamW(weight_decay=1e-4, learning_rate=1e-4)

# Metrics reported during training. The `name` of each metric is the key under
# which it appears in the logs (e.g. 'pr_auc' is what the callbacks monitor).
metrics = [
    AUC(curve='PR', name='pr_auc'),   # precision-recall AUC, suited to imbalanced data
    AUC(name='AUC'),                  # ROC AUC
    BinaryAccuracy(name='accuracy'),  # overall accuracy
    Precision(name='precision'),      # precision of the positive class
    Recall(name='recall'),            # recall of the positive class
]

# ==================== Compile Model ==================== #
# Plain (unweighted) binary cross-entropy; from_logits=False because the model
# ends in a sigmoid and therefore outputs probabilities.
model.compile(
    optimizer=opt,
    metrics=metrics,
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
)

# ==================== 🏃‍♂️ Train Model 🏃‍♂️ ==================== #
# 🚀 Training the model with callbacks for early stopping and adaptive LR

from tqdm import tqdm
import time
from tensorflow.keras.callbacks import Callback
import numpy as np

# Custom callback with progress updates and validation metrics per step
class ValidationProgressWithMetricsCallback(Callback):
    """Run a manual validation pass, with a progress bar, at the end of every epoch.

    Args:
        validation_data: Iterable yielding ``(x, y)`` batches.
        val_steps: Maximum number of batches evaluated per epoch.

    The final values are printed and, when Keras supplies a ``logs`` dict, also
    stored in it under ``val_<metric>`` so callbacks that run later (including the
    training history) can see them.

    Note:
        Relies on the Keras 3 behaviour that ``test_on_batch`` accumulates the
        compiled metrics across calls until ``reset_metrics()`` is called, so the
        value returned for the last batch is already the aggregate over all
        validation batches.
    """

    def __init__(self, validation_data, val_steps):
        super().__init__()
        self.validation_data = validation_data
        self.val_steps = val_steps

    def on_epoch_end(self, epoch, logs=None):
        print(f"\nStarting validation for epoch {epoch + 1}...")
        start_time = time.time()

        # The metric objects are shared with training and still hold this epoch's
        # training state; clear them so the values below describe validation only.
        # (`logs` was computed before this callback runs, so it is not affected.)
        self.model.reset_metrics()

        running_metrics = {}  # aggregate over the validation batches seen so far
        batches_seen = 0

        # The context manager closes the bar even if evaluation raises.
        with tqdm(total=self.val_steps, desc=f"Validation (Epoch {epoch + 1})", ncols=120) as pbar:
            for x_val, y_val in self.validation_data:
                if batches_seen >= self.val_steps:
                    break

                # Metrics are stateful, so this is already the running aggregate -
                # it must not be averaged a second time.
                batch_logs = self.model.test_on_batch(x_val, y_val, return_dict=True)
                running_metrics = {k: float(v) for k, v in batch_logs.items()}
                batches_seen += 1

                metrics_display = " - ".join([f"{k}: {v:.4f}" for k, v in running_metrics.items()])
                pbar.set_postfix_str(metrics_display)
                pbar.update(1)

        if batches_seen == 0:
            # Nothing to report (empty loader or val_steps <= 0).
            print(f"\nValidation skipped for epoch {epoch + 1}: no validation batches were produced.\n")
            return

        final_metrics_display = " - ".join([f"{k}: {v:.4f}" for k, v in running_metrics.items()])

        print(f"\nValidation completed in {time.time() - start_time:.2f} seconds for epoch {epoch + 1}.")
        print(f"Final Validation Metrics: {final_metrics_display}\n")

        # Expose the results to later callbacks without overwriting existing entries.
        if logs is not None:
            for k, v in running_metrics.items():
                logs.setdefault(f"val_{k}", v)


# Example validation steps
# Number of validation batches evaluated after each epoch by the custom callback.
val_steps = 100

# Manual validation pass (progress bar + metrics) executed at the end of every epoch.
val_progress = ValidationProgressWithMetricsCallback(ds_val, val_steps)

# Train for up to 20 epochs of 88 batches each. Keras' built-in validation is
# switched off (validation_data=None) because the callback above handles it.
history = model.fit(
    ds,
    epochs=20,
    steps_per_epoch=88,
    validation_data=None,
    callbacks=[early_stopping, reduce_lr, val_progress],
)

# ==================== 📈 Final Notes 📈 ==================== #
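# - The model above is compiled with plain BinaryCrossentropy; to weight the positive class,
#   compile with loss=weighted_binary_crossentropy(pos_weight) and tune 'pos_weight' (e.g., 10, 20) for PR AUC.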
# - Experiment with 'BinaryFocalCrossentropy' if PR AUC remains low:
#     model.compile(optimizer=opt, loss=BinaryFocalCrossentropy(gamma=2.0), metrics=metrics)
# - Check PR AUC thresholds post-training for best classification cutoff.


Epoch 1/20


In [31]:
# Build a test set
# Test loader over the 'test' split for all years used above (2013-2018).
# `select_keys` uses ALL_VARIABLES - the same list the model inputs and the
# train/validation loaders are built from - so the loader always yields exactly
# the inputs the model expects.
ds_test = KerasDataLoader(
    data_root=data_root,
    data_type='test',
    years=[2013, 2014, 2015, 2016, 2017, 2018],
    batch_size=8,
    workers=4,
    select_keys=ALL_VARIABLES,
    use_multiprocessing=True,
)


In [None]:
# Replace `model` with the model stored in 'tornado_detector_baseline.keras'.
# This overwrites the model built/trained in the cells above.
model = keras.models.load_model(
    'tornado_detector_baseline.keras',
)

In [None]:
# Same as the previous cell, with compilation of the loaded model requested explicitly.
model = keras.models.load_model(
    'tornado_detector_baseline.keras',
    compile=True,
)

In [32]:
# Evaluate
import tornet.metrics.keras.metrics as km
# Metrics used for evaluation. Note that the metric named 'AUC' is the
# precision-recall AUC (curve='pr'), not the ROC AUC.
metrics = [
    keras.metrics.AUC(curve='pr', name='AUC'),
    keras.metrics.Precision(),
    keras.metrics.Recall(),
]
from tensorflow.keras.optimizers import AdamW

# Re-compile with the evaluation metrics only (no optimizer or loss is passed here).
# Alternative kept for reference:
#model.compile(optimizer=AdamW(learning_rate=.001),loss=BinaryFocalCrossentropy(),metrics=metrics)
model.compile(metrics=metrics)

# Evaluate on the first 88 test batches only (a subset, to keep the demo fast).
model.evaluate(ds_test, steps=88, return_dict=True)

[1m88/88[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 271ms/step - AUC: 0.1982 - loss: 1.4086 - precision_1: 0.0000e+00 - recall_1: 0.0000e+00


{'AUC': 0.18733206391334534,
 'loss': 1.072451114654541,
 'precision_1': 0.0,
 'recall_1': 0.0}

In [None]:
# NOTE(review): Keras models do not appear to provide a `load()` method, so this
# call presumably raises AttributeError - confirm the intent (e.g. `load_weights`
# on a weights file, or `keras.models.load_model` as in the cells above).
model.load()