# Training a simple CNN model in Tensorflow for Tornado Detection

This notebook steps through how to train a simple CNN model using a subset of TorNet.

This will not produce a model with any skill, but simply provides a working end-to-end example of how to set up a data loader, build, and fit a model


In [1]:
import sys
# Uncomment if tornet isn't installed in your environment or in your path already
#sys.path.append('../')  

import os
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from tornet.data.tf.loader import create_tf_dataset 
from tornet.data.constants import ALL_VARIABLES

In [2]:
# Create basic dataloader
# This option loads directly from netcdf files, and will be slow and IO bound
# To speed up training, either
#     build as a tensorflow_dataset , (see tornet/data/tfds/tornet/README.md)
#     cache dataset first , or
#     use tf.data.Dataset.load on a pre-saved dataset

# Location of tornet
data_root = 'C:/Users/mjhig/tornet_2013'

# Get training data from 2018
data_type='train'
years = [2013,2014,2015,2016,2017,2018]

catalog_path = os.path.join(data_root,'catalog.csv')
if not os.path.exists(catalog_path):
    raise RuntimeError('Unable to find catalog.csv at '+data_root)
        
catalog = pd.read_csv(catalog_path,parse_dates=['start_time','end_time'])
catalog = catalog[catalog['type']==data_type]
catalog = catalog[catalog.start_time.dt.year.isin(years)]
catalog = catalog.sample(frac=1,random_state=1234)
file_list = [os.path.join(data_root,f) for f in catalog.filename]

ds = create_tf_dataset(file_list,variables=ALL_VARIABLES,n_frames=1) 

# (Optional) Save data for faster reloads (makes copy of data!)
#ds.save('tornet_sample.tfdataset') 


In [3]:
# If saved with ds.save(...), just load that model
#ds = tf.data.Dataset.load('tornet_sample.tfdataset')

In [4]:
# If data was registered in tensorflow_dataset, use that
# env variable TFDS_DATA_DIR should point to location of this resaved dataset
#import tensorflow_datasets as tfds
#import tornet.data.tfds.tornet.tornet_dataset_builder # registers 'tornet'

#data_type='train'
#years = [2018,]
#ds = tfds.load('tornet',split='+'.join(['%s-%d' % (data_type,y) for y in years]))

In [5]:
import tornet.data.preprocess as pp
from tornet.data import preprocess as tfpp

# Preprocess

# add 'coordinates' variable used by CoordConv layers
#ds = ds.map(lambda d: pp.add_coordinates(d,include_az=False,backend=tf))
     
# Take only last time frame
ds = ds.map(pp.remove_time_dim)

# Split sample into inputs,label
ds = ds.map(tfpp.split_x_y)

# (Optional) add sample weights
# weights={'wN':1.0,'w0':1.0,'w1':1.0,'w2':2.0,'wW':0.5}
# ds = ds.map(lambda x,y:  tfpp.compute_sample_weight(x,y,**weights) )

ds = ds.prefetch(tf.data.AUTOTUNE)
ds=ds.batch(32)

In [9]:
# Create a simple CNN model
# This normalizes data, concatenates along channel, and applies a Conv2D
import keras
from tornet.data.constants import CHANNEL_MIN_MAX
from tensorflow.keras.layers import (
    Conv2D, GlobalMaxPool2D,MaxPooling2D, Dense, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D
)
input_vars = ALL_VARIABLES

# TF convention is B,L,W,H
inputs = {v:keras.Input(shape=(120,240,2),name=v) for v in input_vars}

# Normalize inputs
norm_layers = []
for v in input_vars:
    min_max = np.array(CHANNEL_MIN_MAX[v]) # [2,]

    # choose mean,var to get approximate [-1,1] scaling
    var=((min_max[1]-min_max[0])/2)**2 # scalar
    var=np.array(2*[var,])    # [n_sweeps,]
    offset=(min_max[0]+min_max[1])/2    # scalar
    offset=np.array(2*[offset,]) # [n_sweeps,]
    
    norm_layers.append(
        keras.layers.Normalization(mean=offset, variance=var,
                                   name='Normalized_%s' % v)
    )

# Concatenate normed inputs along channel dimension
x=keras.layers.Concatenate(axis=-1,name='Concatenate1')(
        [l(inputs[v]) for l,v in zip(norm_layers,input_vars)]
        )

# Replace background (nan) with -3
x=keras.layers.Lambda(lambda x: tf.where(tf.math.is_nan(x),-3.0,x),name='ReplaceNan')(x)

# **Improved Convolutional Feature Extraction**
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)  # Normalize activations
x = MaxPooling2D((2, 2))(x)  # Reduce spatial size
x = Dropout(0.3)(x)  # Regularization

x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Dropout(0.3)(x)


# **Final Prediction Layers**
x = GlobalAveragePooling2D()(x)  # Global feature aggregation
x = Dense(128, activation='relu')(x)  # Fully connected layer
x = Dropout(0.3)(x)
x = Dense(64, activation='relu')(x)
x = Dropout(0.3)(x)

y = Dense(1, activation='sigmoid', name='TornadoLikelihood')(x)

model = keras.Model(inputs=inputs,outputs=y,name='TornadoDetector')

model.summary()

In [7]:
# # Build a test set
# # Basic loader
# data_type='test'
# years = [2014]

# catalog_path = os.path.join(data_root,'catalog.csv')
# if not os.path.exists(catalog_path):
#     raise RuntimeError('Unable to find catalog.csv at '+data_root)
        
# catalog = pd.read_csv(catalog_path,parse_dates=['start_time','end_time'])
# catalog = catalog[catalog['type']==data_type]
# catalog = catalog[catalog.start_time.dt.year.isin(years)]
# catalog = catalog.sample(frac=1,random_state=1234)
# file_list = [os.path.join(data_root,f) for f in catalog.filename]

# ds_test = create_tf_dataset(file_list,variables=ALL_VARIABLES) 


# preprocess
# ds_test = ds_test.map(lambda d: pp.add_coordinates(d,include_az=False,backend=tf))
ds_test = ds_test.map(pp.remove_time_dim)
ds_test = ds_test.map(tfpp.split_x_y)
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)   

history_df = pd.DataFrame(history.history)
# Start the plot at epoch 5. You can change this to get a different view.
history_df.loc[:,['Precision','Recall','pr_auc']].plot()

In [17]:
import optuna
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam, SGD,AdamW,Nadam,Adagrad
from tensorflow.keras.metrics import AUC, Precision, Recall
from tensorflow.keras.losses import BinaryCrossentropy,BinaryFocalCrossentropy
from tensorflow.keras.initializers import GlorotUniform

# Function to reset model weights
def reset_weights(model):
    """Reset model weights using GlorotUniform initialization."""
    for layer in model.layers:
        if hasattr(layer, 'kernel_initializer') and layer.weights:
            for weight in layer.weights:
                weight.assign(GlorotUniform()(weight.shape))

# Dice loss function
def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return 1 - (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
tf.random.set_seed(42)
# Define the objective function for Optuna
def objective(trial):
    # Sample hyperparameters
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
    #optimizer_name = trial.suggest_categorical('optimizer', ['adam', 'sgd','adamw','nadam','adagrad'])
    #loss_function = trial.suggest_categorical('loss_function', ['dice_loss', 'binary_crossentropy','binary_focal_crossentropy'])
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3,log=True)
    #alpha = trial.suggest_float('alpha', 0.25, 0.50, step=0.05)    
    #gamma = trial.suggest_float('gamma', 0, 5,step=0.5 )
    #class_weights = {0: 1, 1: trial.suggest_float('pos_weight', 1.0, 10.0, step=0.5)}
    
    optimizer = AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
    loss = BinaryFocalCrossentropy()
    # Select optimizer
    # if optimizer_name == 'adam':
    #     optimizer = Adam(learning_rate=learning_rate)
    # elif optimizer_name == 'sgd':
    #     optimizer = SGD(learning_rate=learning_rate, momentum=0.9)
    # elif optimizer_name == 'adamw':
    #     optimizer = AdamW(learning_rate=learning_rate, weight_decay=1e-4)
    # elif optimizer_name == 'nadam':
    #     optimizer = Nadam(learning_rate=learning_rate)
    # elif optimizer_name == 'adagrad':
    #     optimizer = Adagrad(learning_rate=learning_rate)

    # # Select loss function
    # if loss_function == 'dice_loss':
    #     loss = dice_loss
    # elif loss_function == 'binary_crossentropy':
    #     loss = BinaryCrossentropy()
    # elif loss_function == 'binary_focal_crossentropy':
    #     loss = BinaryFocalCrossentropy()
    
    


    # Reset model weights before training
    reset_weights(model)

    # Compile the model
    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=[
            AUC(curve='PR', name='pr_auc'),
            AUC(name='AUC'),
            Precision(name='Precision'),
            Recall(name='Recall')
        ]
    )

    # Early stopping
    early_stopping = EarlyStopping(
        monitor='pr_auc',
        patience=2,
        mode='max',
        restore_best_weights=True
    )

    # Train the model
    history = model.fit(
        ds,
        epochs=20,  # Limit epochs for tuning
        steps_per_epoch=10,
        callbacks=[early_stopping],
    )

    # Get the best PR AUC score
    best_pr_auc = max(history.history['pr_auc'])

    return best_pr_auc  # Optuna will maximize this

# Run Optuna optimization
#ds=ds.batch(32)
#study = optuna.create_study(direction='maximize')
#study.optimize(objective, n_trials=100)
model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=[
            AUC(curve='PR', name='pr_auc'),
            AUC(name='AUC'),
            Precision(name='Precision'),
            Recall(name='Recall')
        ]
    )

    # Early stopping
    early_stopping = EarlyStopping(
        monitor='pr_auc',
        patience=2,
        mode='max',
        restore_best_weights=True
    )




[I 2025-01-30 13:54:24,936] A new study created in memory with name: no-name-fdd19d78-33dd-4537-a5c9-402af605eeda


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 3s/step - AUC: 0.6459 - Precision: 0.0495 - Recall: 0.7117 - loss: 0.2158 - pr_auc: 0.0609
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 3s/step - AUC: 0.4557 - Precision: 0.1590 - Recall: 0.1691 - loss: 0.1346 - pr_auc: 0.1153
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 3s/step - AUC: 0.5053 - Precision: 0.1423 - Recall: 0.1915 - loss: 0.0911 - pr_auc: 0.0927
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 3s/step - AUC: 0.6214 - Precision: 0.0727 - Recall: 0.0130 - loss: 0.0760 - pr_auc: 0.0839     
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 2s/step - AUC: 0.6182 - Precision: 0.2404 - Recall: 0.1761 - loss: 0.0914 - pr_auc: 0.1440
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 3s/step - AUC: 0.6308 - Precision: 0.1590 - Recall: 0.0698 - loss: 0.0661 - pr_auc: 0.0804


[I 2025-01-30 13:57:05,833] Trial 0 finished with value: 0.14369924366474152 and parameters: {'learning_rate': 7.15907924760212e-05, 'weight_decay': 4.899194157588933e-05}. Best is trial 0 with value: 0.14369924366474152.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.5959 - Precision: 0.0396 - Recall: 0.8005 - loss: 0.4466 - pr_auc: 0.0465
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 3s/step - AUC: 0.5941 - Precision: 0.1152 - Recall: 0.4060 - loss: 0.1731 - pr_auc: 0.1877
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 3s/step - AUC: 0.4955 - Precision: 0.0443 - Recall: 0.0925 - loss: 0.1256 - pr_auc: 0.0747
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 3s/step - AUC: 0.5920 - Precision: 0.0972 - Recall: 0.2077 - loss: 0.1072 - pr_auc: 0.0769


[I 2025-01-30 13:58:52,999] Trial 1 finished with value: 0.12156407535076141 and parameters: {'learning_rate': 0.00010884038631971741, 'weight_decay': 6.942824905610785e-05}. Best is trial 0 with value: 0.14369924366474152.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 3s/step - AUC: 0.3740 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1832 - pr_auc: 0.0281
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.4673 - Precision: 0.3901 - Recall: 0.1388 - loss: 0.1221 - pr_auc: 0.1727
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.6012 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0822 - pr_auc: 0.0730
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.6906 - Precision: 0.0606 - Recall: 0.0288 - loss: 0.0891 - pr_auc: 0.1078
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.4682 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1197 - pr_auc: 0.0750
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 4s/step - AUC: 0.6591 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0

[I 2025-01-30 14:01:56,951] Trial 2 finished with value: 0.11834824830293655 and parameters: {'learning_rate': 0.005753367831528095, 'weight_decay': 0.0008466108293832122}. Best is trial 0 with value: 0.14369924366474152.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 3s/step - AUC: 0.3512 - Precision: 0.0162 - Recall: 0.3502 - loss: 0.3133 - pr_auc: 0.0273
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.5113 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1302 - pr_auc: 0.1019
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.5695 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0791 - pr_auc: 0.0954
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 3s/step - AUC: 0.5529 - Precision: 0.0202 - Recall: 0.0065 - loss: 0.0855 - pr_auc: 0.0712     
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 3s/step - AUC: 0.6070 - Precision: 0.2290 - Recall: 0.0634 - loss: 0.0858 - pr_auc: 0.1304
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 3s/step - AUC: 0.6888 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.07

[I 2025-01-30 14:05:37,759] Trial 3 finished with value: 0.14099299907684326 and parameters: {'learning_rate': 0.00026767477221638646, 'weight_decay': 0.00020365838439429022}. Best is trial 0 with value: 0.14369924366474152.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 3s/step - AUC: 0.2586 - Precision: 0.0205 - Recall: 0.3502 - loss: 0.3544 - pr_auc: 0.0240
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 3s/step - AUC: 0.4523 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1524 - pr_auc: 0.0900
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.6166 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0879 - pr_auc: 0.0799
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 3s/step - AUC: 0.6041 - Precision: 0.0640 - Recall: 0.0418 - loss: 0.0941 - pr_auc: 0.0900
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.4936 - Precision: 0.0851 - Recall: 0.0751 - loss: 0.1270 - pr_auc: 0.0883
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.4171 - Precision: 0.0270 - Recall: 0.0292 - loss: 0.0996 - pr_auc: 

[I 2025-01-30 14:08:47,110] Trial 4 finished with value: 0.13527169823646545 and parameters: {'learning_rate': 0.0020074505111367965, 'weight_decay': 0.00017051354676053028}. Best is trial 0 with value: 0.14369924366474152.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 3s/step - AUC: 0.4057 - Precision: 0.0249 - Recall: 0.3502 - loss: 0.5558 - pr_auc: 0.0295
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.5482 - Precision: 0.0145 - Recall: 0.0184 - loss: 0.1335 - pr_auc: 0.1028
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.5360 - Precision: 0.0297 - Recall: 0.0222 - loss: 0.1301 - pr_auc: 0.0687   
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.6220 - Precision: 0.1867 - Recall: 0.2754 - loss: 0.1194 - pr_auc: 0.1370
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 3s/step - AUC: 0.5780 - Precision: 0.0644 - Recall: 0.0745 - loss: 0.1214 - pr_auc: 0.1330
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.6190 - Precision: 0.0887 - Recall: 0.1118 - loss: 0.0751 - pr_auc: 0.0870


[I 2025-01-30 14:11:54,598] Trial 5 finished with value: 0.16851019859313965 and parameters: {'learning_rate': 0.005536548442418818, 'weight_decay': 1.5242765005950263e-05}. Best is trial 5 with value: 0.16851019859313965.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 4s/step - AUC: 0.5028 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0877 - pr_auc: 0.0364
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 4s/step - AUC: 0.5468 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1060 - pr_auc: 0.1030
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.4591 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0728 - pr_auc: 0.0504
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.5527 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0710 - pr_auc: 0.0807
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 4s/step - AUC: 0.7204 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0701 - pr_auc: 0.2009
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 4s/step - AUC: 0.6897 - Precision: 0.0000e+00 - Recall: 0.00

[I 2025-01-30 14:16:05,661] Trial 6 finished with value: 0.18140867352485657 and parameters: {'learning_rate': 0.0002281875239379196, 'weight_decay': 0.0008319142931225579}. Best is trial 6 with value: 0.18140867352485657.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 4s/step - AUC: 0.3498 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0811 - pr_auc: 0.0293
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.4742 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1239 - pr_auc: 0.0885
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 4s/step - AUC: 0.4547 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0813 - pr_auc: 0.0505
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 6s/step - AUC: 0.6521 - Precision: 0.1364 - Recall: 0.0104 - loss: 0.0717 - pr_auc: 0.1234    
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.6022 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0848 - pr_auc: 0.1178
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.7592 - Precision: 0.0000e+00 - Recall: 0.0000e+

[I 2025-01-30 14:20:01,254] Trial 7 finished with value: 0.14377830922603607 and parameters: {'learning_rate': 0.0009311990832162779, 'weight_decay': 2.6490328408202273e-05}. Best is trial 6 with value: 0.18140867352485657.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 3s/step - AUC: 0.3918 - Precision: 0.0387 - Recall: 1.0000 - loss: 0.9562 - pr_auc: 0.0289
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.5783 - Precision: 0.1014 - Recall: 1.0000 - loss: 0.8524 - pr_auc: 0.1276
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.6093 - Precision: 0.0583 - Recall: 1.0000 - loss: 0.7831 - pr_auc: 0.0863
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 3s/step - AUC: 0.5683 - Precision: 0.0622 - Recall: 0.9896 - loss: 0.7206 - pr_auc: 0.0909
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.5180 - Precision: 0.0786 - Recall: 0.9437 - loss: 0.7024 - pr_auc: 0.1140
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 3s/step - AUC: 0.4484 - Precision: 0.0512 - Recall: 1.0000 - loss: 0.6854 - pr_auc: 0.0501
Epoch 7/2

[I 2025-01-30 14:23:37,866] Trial 8 finished with value: 0.1434554159641266 and parameters: {'learning_rate': 1.0397431521607119e-05, 'weight_decay': 0.0003727795298546375}. Best is trial 6 with value: 0.18140867352485657.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 3s/step - AUC: 0.5132 - Precision: 0.0274 - Recall: 0.4095 - loss: 0.2264 - pr_auc: 0.0510
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.5311 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1371 - pr_auc: 0.0975
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.5664 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0810 - pr_auc: 0.0672
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.6853 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.0685 - pr_auc: 0.0976
Epoch 5/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 4s/step - AUC: 0.4774 - Precision: 0.0493 - Recall: 0.0319 - loss: 0.1127 - pr_auc: 0.0744
Epoch 6/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.5661 - Precision: 0.0652 - Recall: 0.0357 - loss: 0.0670 - 

[I 2025-01-30 14:26:40,809] Trial 9 finished with value: 0.11065394431352615 and parameters: {'learning_rate': 0.0003911189219363936, 'weight_decay': 0.00011210640545586569}. Best is trial 6 with value: 0.18140867352485657.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.3481 - Precision: 0.0319 - Recall: 0.5202 - loss: 0.3751 - pr_auc: 0.0448
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.4927 - Precision: 0.0941 - Recall: 0.5728 - loss: 0.2668 - pr_auc: 0.1036
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 3s/step - AUC: 0.5405 - Precision: 0.0464 - Recall: 0.3914 - loss: 0.2194 - pr_auc: 0.0619
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 3s/step - AUC: 0.4455 - Precision: 0.0278 - Recall: 0.2166 - loss: 0.1805 - pr_auc: 0.0503


[I 2025-01-30 14:28:40,555] Trial 10 finished with value: 0.08534997701644897 and parameters: {'learning_rate': 2.266859176132713e-05, 'weight_decay': 0.0008815963139915432}. Best is trial 6 with value: 0.18140867352485657.


Epoch 1/20




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 3s/step - AUC: 0.5160 - Precision: 0.1250 - Recall: 0.3502 - loss: 0.1761 - pr_auc: 0.0487
Epoch 2/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.5290 - Precision: 0.0000e+00 - Recall: 0.0000e+00 - loss: 0.1474 - pr_auc: 0.1343
Epoch 3/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 3s/step - AUC: 0.5321 - Precision: 0.0909 - Recall: 0.0083 - loss: 0.0730 - pr_auc: 0.1002     
Epoch 4/20
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 4s/step - AUC: 0.6117 - Precision: 0.1394 - Recall: 0.0714 - loss: 0.0825 - pr_auc: 0.0895
Epoch 5/20


In [38]:
# Evaluate
import tornet.metrics.keras.metrics as km
metrics = [keras.metrics.AUC(curve='PR',name='PRAUC'),
           km.BinaryAccuracy(from_logits=True,name='BinaryAccuracy'), 
           km.Precision(from_logits=True,name='Precision'),
           ]
model.compile(metrics=metrics)

# steps=10 for demo purposes
model.evaluate(ds_test,steps=10)




[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 1s/step - BinaryAccuracy: 0.1216 - PRAUC: 0.1009 - Precision: 0.1216 - loss: 0.8199


[0.8348420858383179,
 0.0908624678850174,
 0.10625000298023224,
 0.10625000298023224]