## Training the MaxViT-Tiny Model for Lens Finding

* You will need to install `tensorflow_addons` and `keras-cv-attention-models` (not `keras-cv`) the first time you run this notebook, as they are not included in the tensorflow-2.9.0 kernel.

In [1]:
# One-time setup: uncomment to install into this kernel's environment (%pip, not !pip).
# The version spec must be quoted: unquoted, the shell reads `>=1.3.4` as an output
# redirect, so the pin is silently dropped and a file named "=1.3.4" is created.
# %pip install tensorflow_addons
# %pip install "keras-cv-attention-models>=1.3.4"

In [2]:
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

import tensorflow as tf
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")
from tensorflow.keras import regularizers
from tensorflow.keras.applications import EfficientNetV2S
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, BatchNormalization as BatchNorm
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.utils import plot_model
from tensorflow.keras import layers, Model, Input
import keras_cv_attention_models

import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve
import tensorflow_addons as tfa
from datetime import date
import time

2025-07-07 13:43:01.892871: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-07 13:43:01.892914: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-07 13:43:01.937823: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-07 13:43:02.034734: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


In [3]:
# List the visible GPUs and their current memory use before training.
!nvidia-smi 

# Run the following in a terminal to monitor VRAM during training
# watch -n 0.5 nvidia-smi

Mon Jul  7 13:43:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.163.01             Driver Version: 550.163.01     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:03:00.0 Off |                    0 |
| N/A   31C    P0             57W /  400W |       5MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00

In [4]:
# Dataset roots. The JWST path keeps its trailing slash because file names are
# appended to it by plain string concatenation.
ethan_sim_jwst = '/global/cfs/projectdirs/deepsrch/jwst_sims/pristine_bright/'
clean_deluxe = '/global/cfs/projectdirs/cosmo/work/users/xhuang/dr10_1/Clean-Samples/TS40_deluxe_clean'

# Active training set.
data_path = clean_deluxe

 * `clean_deluxe` is our highest-quality sample. The Clean-Samples directory also contains TS40 Baseline, which has more samples, but some of its positive and negative candidates may be less clear or have additional noise.

In [9]:
# Alternative data source: JWST lens simulations, split into train/val here.
# NOTE(review): the clean_deluxe loading cell assigns the same xtrain/ytrain/xval/yval
# names, so whichever of the two ran last is what gets trained on.
path = ethan_sim_jwst

x0 = np.load(path + "images.npy")
y0 = np.load(path + "lensed.npy")
# Standardized values are written back into x0, so it must hold floats
# (an integer array would silently truncate them).
if not np.issubdtype(x0.dtype, np.floating):
    x0 = x0.astype(np.float32)

# Image pre-processing, per image: cap bright outliers at that image's 99th
# percentile, then standardize to zero mean / unit std. A constant image has
# std 0 and would become all-NaN, so it is only mean-subtracted instead.
for i in range(len(x0)):
    img = np.minimum(x0[i], np.percentile(x0[i], 99))
    img_std = np.std(img)
    x0[i] = (img - np.mean(img)) / (img_std if img_std > 0 else 1.0)

# Train/val split: shuffle indices with a fixed seed, cut them into N_FOLDS
# contiguous chunks and hold out chunk VAL_FOLD for validation.
N_FOLDS, VAL_FOLD = 5, 0
np.random.seed(15)
indices = np.arange(len(x0))
np.random.shuffle(indices)
fold_size = len(x0) // N_FOLDS
val_start, val_end = VAL_FOLD * fold_size, (VAL_FOLD + 1) * fold_size
val_inds = indices[val_start:val_end]
train_inds = np.concatenate([indices[:val_start], indices[val_end:]])

# Images become single-channel (N, IMG_SIZE, IMG_SIZE, 1) stacks, labels column vectors.
IMG_SIZE = 125  # expected side length of the simulated cutouts
xtrain = np.reshape(x0[train_inds], (len(train_inds), IMG_SIZE, IMG_SIZE, 1))
xval = np.reshape(x0[val_inds], (len(val_inds), IMG_SIZE, IMG_SIZE, 1))
ytrain = np.reshape(y0[train_inds], (len(train_inds), 1))
yval = np.reshape(y0[val_inds], (len(val_inds), 1))

# Clip the standardized pixels to [-1, 1], i.e. to +/-1 std of each image.
xtrain = np.clip(xtrain, -1, 1)
xval = np.clip(xval, -1, 1)

In [5]:
data_path = clean_deluxe

# Pre-split arrays for the active dataset; labels are reshaped to column vectors.
xtrain, ytrain = (
    np.load(f"{data_path}/train_x.npy"),
    np.load(f"{data_path}/train_y.npy").reshape(-1, 1),
)
xval, yval = (
    np.load(f"{data_path}/val_x.npy"),
    np.load(f"{data_path}/val_y.npy").reshape(-1, 1),
)

# Clip pixels to [-1, 1], then give (N, H, W) stacks a trailing channel axis.
xtrain, xval = np.clip(xtrain, -1, 1), np.clip(xval, -1, 1)
xtrain = xtrain[..., np.newaxis] if xtrain.ndim == 3 else xtrain
xval = xval[..., np.newaxis] if xval.ndim == 3 else xval

# Sanity report on what was loaded.
print("xtrain type:", type(xtrain))
print("xtrain shape:", getattr(xtrain, 'shape', 'No shape'))
print("xtrain dtype:", getattr(xtrain, 'dtype', 'No dtype'))

print("ytrain type:", type(ytrain))
print("ytrain shape:", getattr(ytrain, 'shape', 'No shape'))
print("ytrain dtype:", getattr(ytrain, 'dtype', 'No dtype'))

# Peek at the first sample; report (rather than raise) if indexing fails.
try:
    print("Sample xtrain[0] shape:", xtrain[0].shape)
    print("Sample ytrain[0]:", ytrain[0])
except Exception as err:
    print("Error accessing sample:", err)

xtrain type: <class 'numpy.ndarray'>
xtrain shape: (94887, 101, 101, 3)
xtrain dtype: float32
ytrain type: <class 'numpy.ndarray'>
ytrain shape: (94887, 1)
ytrain dtype: float64
Sample xtrain[0] shape: (101, 101, 3)
Sample ytrain[0]: [0.]


In [6]:
# Input pipelines. Built under the CPU device, presumably so the in-memory numpy
# arrays stay in host memory rather than being copied onto a GPU.
with tf.device('/CPU:0'):
    def ensure_rgb(x):
        """Return `x` with 3 channels: a single-channel (H, W, 1) image is tiled to RGB."""
        if x.shape.rank == 3 and x.shape[-1] == 1:
            x = tf.image.grayscale_to_rgb(x)
        return x

    def preprocess(x, y):
        """Training transform: RGB, resize to 224x224, random flips and a random rotation.

        The augmentation helps prevent overfitting in training.
        """
        x = ensure_rgb(x)
        x = tf.image.resize(x, [224, 224])
        x = tf.image.random_flip_left_right(tf.image.random_flip_up_down(x))
        rg = tf.random.uniform(shape=[], minval=0, maxval=2 * np.pi, dtype=tf.float32)
        x = tfa.image.rotate(x, angles=rg, fill_mode='reflect')
        return x, y

    def preprocess_val(x, y):
        """Validation transform: RGB and resize to 224x224 only (no augmentation)."""
        x = ensure_rgb(x)
        x = tf.image.resize(x, [224, 224])
        return x, y

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO

    batch_size = 512
    # Shuffle the raw samples (e.g. 101x101x3 for clean_deluxe) BEFORE the resize/augment
    # map. Shuffling after it made the full-size buffer hold every upscaled 224x224x3
    # float32 image (~600 KB each, ~57 GB for 94,887 samples) instead of ~12 GB of raw
    # inputs. repeat() comes after shuffle, so each epoch is still reshuffled and
    # re-augmented.
    train = (tf.data.Dataset.from_tensor_slices((xtrain, ytrain))
             .shuffle(len(ytrain), reshuffle_each_iteration=True, seed=42)
             .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
             .repeat()
             .batch(batch_size)
             .prefetch(tf.data.AUTOTUNE)).with_options(options)

    validate = (tf.data.Dataset.from_tensor_slices((xval, yval))
                .shuffle(len(yval))
                .map(preprocess_val, num_parallel_calls=tf.data.AUTOTUNE)
                .repeat()
                .batch(batch_size)
                .prefetch(tf.data.AUTOTUNE)).with_options(options)

2025-07-07 13:43:48.437174: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:236] Using CUDA malloc Async allocator for GPU: 0
2025-07-07 13:43:48.439508: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38366 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:03:00.0, compute capability: 8.0
2025-07-07 13:43:48.440557: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:236] Using CUDA malloc Async allocator for GPU: 1
2025-07-07 13:43:48.442096: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 38366 MB memory:  -> device: 1, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:41:00.0, compute capability: 8.0
2025-07-07 13:43:48.442359: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:236] Using CUDA malloc Async allocator for GPU: 2
2025-07-07 13:43:48.443724: I tensorflow/core/common_runtime/gpu/gpu_d

In [6]:
### To resume training, set epoch to last save epoch and set LR to last known LR.
import shutil

START_EPOCH = 0
lr_stopped_at = 0.0

run_name = "F1"
today = date.today()
d1 = today.strftime("%d_%m_%Y")
# WARNING: a fresh run (START_EPOCH == 0) deletes any existing save_dir, so a second
# run on the same day with the same run_name overwrites the first. Change run_name
# to keep both.

parent_dir = "_Time_Trials"
save_dir = parent_dir + "/" + d1 + run_name

if START_EPOCH == 0:
    # Pure-Python equivalents of mkdir / rm -rf / mkdir: no shell interpolation of the
    # paths, and no error when parent_dir already exists.
    os.makedirs(parent_dir, exist_ok=True)
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir)
    print("CREATED DIRECTORY")

mkdir: cannot create directory ‘_Time_Trials’: File exists
CREATED DIRECTORY


In [7]:
# Halve the learning rate whenever val_loss has not improved for 8 epochs,
# never going below 1e-7.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=8, factor=0.5, min_lr=1e-7, verbose=1)

# Short alias for the Keras metrics module, used when compiling the model.
metrics = tf.keras.metrics

In [8]:
# Data-parallel training over all visible GPUs. Cross-replica reductions use
# ReductionToOneDevice instead of MirroredStrategy's default cross-device ops.
_cross_device_ops = tf.distribute.ReductionToOneDevice()
strategy = tf.distribute.MirroredStrategy(cross_device_ops=_cross_device_ops)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


In [9]:
def create_maxvit(lr=1e-5):
    """Build and compile a MaxViT-Tiny binary classifier for 224x224x3 inputs.

    The ImageNet-pretrained MaxViT_Tiny backbone (keras_cv_attention_models) is kept
    fully trainable, its 1000-way classifier is dropped, and a Dense(64, relu) ->
    Dense(1, sigmoid) head is added.

    Args:
        lr: Adam learning rate.

    Returns:
        The compiled Keras model. Its metrics log as 'auc', 'precision' and 'recall'
        ('val_*' on validation), both at a 0.9 threshold for precision/recall.
    """
    inputlayer = Input(shape=(224, 224, 3))
    base_model = keras_cv_attention_models.maxvit.MaxViT_Tiny(pretrained='imagenet', pretrained_base=True)
    base_model.trainable = True
    # Tap the layer just before the 1000-logit ImageNet classifier to get a headless backbone.
    headless_output = Model(inputs=base_model.input, outputs=base_model.layers[-2].output)

    x = headless_output(inputlayer)
    x = Dense(64, activation='relu')(x)
    # Output layer stays float32 under the global mixed_float16 policy.
    output = Dense(1, activation='sigmoid', dtype='float32')(x)

    model = tf.keras.models.Model(inputs=inputlayer, outputs=output)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
        # Explicit names: unnamed metrics get auto-suffixed ('auc_1', ...) when the model
        # is rebuilt in the same kernel, which silently breaks the monitor='val_auc'
        # checkpoint and every cell that reads the 'val_auc' CSV column.
        metrics=[
            metrics.AUC(num_thresholds=1000, name='auc'),
            metrics.Precision(0.9, name='precision'),
            metrics.Recall(0.9, name='recall'),
        ],
    )
    return model

In [10]:
# Keep only the weights of the epoch with the highest validation AUC.
checkpoint = ModelCheckpoint(
    f"{save_dir}/chkpt.h5",
    monitor='val_auc',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Per-epoch metrics, appended so that a resumed run extends the same CSV.
csv_logger = CSVLogger(f"{save_dir}/training_history.csv", separator=',', append=True)

callbacks = [checkpoint, csv_logger, reduce_lr]

In [None]:
with strategy.scope():

    train_dist = strategy.experimental_distribute_dataset(train)
    val_dist = strategy.experimental_distribute_dataset(validate)

    ### Distributed Training ###

    if START_EPOCH > 0:
        # Resume: the only weights written during training are the best-val_auc
        # checkpoint, so continue from it at the last known learning rate.
        if lr_stopped_at <= 0:
            raise ValueError("Set lr_stopped_at to the last known learning rate before resuming.")
        model = create_maxvit(lr=lr_stopped_at)
        model.load_weights(f"{save_dir}/chkpt.h5")
    else:
        model = create_maxvit()

    print("Number of devices: {}".format(strategy.num_replicas_in_sync))
    train_start = time.time()
    print(f'Start: {train_start}')

    # No batch_size argument: the datasets are already batched, and the Keras docs say
    # not to pass batch_size for dataset inputs. initial_epoch continues epoch
    # numbering (and the 160-epoch budget) from START_EPOCH when resuming.
    history = model.fit(
        train_dist,
        validation_data=val_dist,
        initial_epoch=START_EPOCH,
        epochs=160,
        steps_per_epoch=len(ytrain) // batch_size,
        validation_steps=len(yval) // batch_size,
        callbacks=callbacks,
        verbose=1,
    )

    train_end = time.time()
    print(f'Total time running: {train_end - train_start}')

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

2025-06-18 21:20:24.404471: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:422] fused(ShuffleDatasetV3:2,RepeatDataset:3): Filling up shuffle buffer (this may take a while): 70900 of 94887
2025-06-18 21:20:26.892476: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] Shuffle buffer filled.
2025-06-18 21:20:28.062270: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8903
2025-06-18 21:20:28.062371: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8903
2025-06-18 21:20:28.075060: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8903
2025-06-18 21:20:28.086737: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8903


Epoch 1: val_auc improved from -inf to 0.51548, saving model to _Time_Trials/18_06_2025F1/chkpt.h5
Epoch 2/160
Epoch 2: val_auc did not improve from 0.51548
Epoch 3/160
Epoch 3: val_auc did not improve from 0.51548
Epoch 4/160
Epoch 4: val_auc did not improve from 0.51548
Epoch 5/160
Epoch 5: val_auc did not improve from 0.51548
Epoch 6/160
Epoch 6: val_auc did not improve from 0.51548
Epoch 7/160
Epoch 7: val_auc did not improve from 0.51548
Epoch 8/160
Epoch 8: val_auc did not improve from 0.51548
Epoch 9/160
Epoch 9: val_auc did not improve from 0.51548
Epoch 10/160
Epoch 10: val_auc did not improve from 0.51548
Epoch 11/160
Epoch 11: val_auc did not improve from 0.51548
Epoch 12/160
Epoch 12: val_auc did not improve from 0.51548
Epoch 13/160
Epoch 13: val_auc did not improve from 0.51548
Epoch 14/160
Epoch 14: val_auc did not improve from 0.51548
Epoch 15/160
Epoch 15: val_auc did not improve from 0.51548
Epoch 16/160
Epoch 16: val_auc did not improve from 0.51548
Epoch 17/160
 34/

In [None]:
# Save the weights currently in memory (the end-of-run state when executed right after training).
endrun_weights_path = f'{save_dir}/endrun.h5'
model.save_weights(endrun_weights_path)

## Visualize Results

In [None]:
# CSV written by the CSVLogger callback for this run.
path_results = "{}/training_history.csv".format(save_dir)

In [None]:
# Per-epoch training / validation AUC plus the running best validation AUC.
# `metric` is appended to the column names, presumably for runs whose metric was
# auto-suffixed by Keras (e.g. "auc_1"); leave it "" otherwise.
metric = ""
df = pd.read_csv(path_results)
train_auc = df["auc" + metric]
val_auc = df["val_auc" + metric]
val_auc_max = val_auc.max()
auc_max = train_auc.max()
largest = val_auc.cummax()  # running best val AUC, O(n) instead of re-scanning every prefix
plt.plot(train_auc, label='Peak train = {:.5f}'.format(auc_max))
plt.plot(val_auc, label='Peak val = {:.5f}'.format(val_auc_max))
plt.plot(largest, label='max')
plt.legend()
auc_bottom = 0.95
if min(auc_max, val_auc_max) < auc_bottom:
    # A peak lies below the zoom window; widen it instead of drawing an empty plot.
    auc_bottom = max(0.0, min(train_auc.min(), val_auc.min()) - 0.01)
plt.ylim(auc_bottom, 1)
plt.show()

In [None]:
# Per-epoch training / validation loss, zoomed to [0, 0.1]; legends show each lowest value.
val_loss = df["val_loss"].min()
train_loss = df["loss"].min()
plt.plot(df["loss"], label=f'Train loss: {train_loss:.4f}')
plt.plot(df["val_loss"], label=f'Val loss: {val_loss:.4f}')
loss_top = 0.1
if max(train_loss, val_loss) > loss_top:
    # Even the best epoch lies above the zoom window; widen it so the curves are visible.
    loss_top = 1.05 * max(df["loss"].max(), df["val_loss"].max())
plt.ylim(top=loss_top, bottom=0)
plt.legend()
plt.show()

In [None]:
# Validation predictions from (1) the weights currently in memory (end of the run
# above) and (2) this run's best-val_auc checkpoint. The model's Input is 224x224x3,
# so the raw xval arrays cannot be fed directly: reuse preprocess_val (RGB + resize,
# no augmentation) on an unshuffled, non-repeating pipeline so the order matches yval.
with tf.device('/CPU:0'):
    val_eval = (tf.data.Dataset.from_tensor_slices((xval, yval))
                .map(preprocess_val, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(batch_size))
endrun_preds = model.predict(val_eval)
model.load_weights(f'{save_dir}/chkpt.h5')
best_preds = model.predict(val_eval)

In [None]:
# ROC curves on the validation set: end-of-run weights (top row) vs. the best-val_auc
# checkpoint (bottom row); the right column zooms into the top-left corner.
# Labels and predictions are flattened from (N, 1) to (N,) for sklearn.
y_true = np.ravel(yval)
auc_roc = roc_auc_score(y_true, np.ravel(endrun_preds))
auc_roc2 = roc_auc_score(y_true, np.ravel(best_preds))

fig, axs = plt.subplots(2, 2, figsize=(12, 12))
roc_panels = [
    (np.ravel(endrun_preds), auc_roc, 'Endrun ROC curve'),
    (np.ravel(best_preds), auc_roc2, 'Best val_auc ROC curve'),
]
for row, (preds, area, title) in enumerate(roc_panels):
    fpr, tpr, thresholds = roc_curve(y_true, preds)
    for col in (0, 1):
        ax = axs[row, col]
        if col == 1:
            ax.set_xlim(0, 0.2)
            ax.set_ylim(0.8, 1)
        # The model is MaxViT (create_maxvit), not EfficientNet.
        ax.plot(fpr, tpr, label='MaxViT (area = {:.5f})'.format(area))
        ax.set_xlabel('False positive rate')
        ax.set_ylabel('True positive rate')
        ax.set_title(title if col == 0 else title + ' (zoomed in at top left)')
        ax.legend(loc='best')

plt.show()

In [None]:
# Point the evaluation cells that follow at a previously trained run.
# NOTE(review): this rebinds save_dir; re-running save_dir-based cells afterwards
# (e.g. saving endrun.h5) would write into this directory.
REFERENCE_RUN_DIR = '/global/homes/b/bkauf/Clean_Training/0.999_f_deluxe'
save_dir = REFERENCE_RUN_DIR

In [None]:
# CSV training history of the run in save_dir.
path_results = "{}/training_history.csv".format(save_dir)

In [None]:
# Per-epoch training / validation AUC plus the running best validation AUC.
# `metric` is appended to the column names, presumably for runs whose metric was
# auto-suffixed by Keras (e.g. "auc_1"); leave it "" otherwise.
metric = ""
df = pd.read_csv(path_results)
train_auc = df["auc" + metric]
val_auc = df["val_auc" + metric]
val_auc_max = val_auc.max()
auc_max = train_auc.max()
largest = val_auc.cummax()  # running best val AUC, O(n) instead of re-scanning every prefix
plt.plot(train_auc, label='Peak train = {:.5f}'.format(auc_max))
plt.plot(val_auc, label='Peak val = {:.5f}'.format(val_auc_max))
plt.plot(largest, label='max')
plt.legend()
auc_bottom = 0.95
if min(auc_max, val_auc_max) < auc_bottom:
    # A peak lies below the zoom window; widen it instead of drawing an empty plot.
    auc_bottom = max(0.0, min(train_auc.min(), val_auc.min()) - 0.01)
plt.ylim(auc_bottom, 1)
plt.show()

In [None]:
# Per-epoch training / validation loss, zoomed to [0, 0.1]; legends show each lowest value.
val_loss = df["val_loss"].min()
train_loss = df["loss"].min()
plt.plot(df["loss"], label=f'Train loss: {train_loss:.4f}')
plt.plot(df["val_loss"], label=f'Val loss: {val_loss:.4f}')
loss_top = 0.1
if max(train_loss, val_loss) > loss_top:
    # Even the best epoch lies above the zoom window; widen it so the curves are visible.
    loss_top = 1.05 * max(df["loss"].max(), df["val_loss"].max())
plt.ylim(top=loss_top, bottom=0)
plt.legend()
plt.show()

In [None]:
# Validation predictions for the run stored in save_dir: its end-of-run weights
# (endrun.h5, the file the save-weights cell writes) and its best-val_auc checkpoint.
# Without loading endrun.h5, "endrun" would just be whatever weights are in memory
# (after the earlier evaluation cell, a different run's best checkpoint).
# The model's Input is 224x224x3, so the raw xval arrays cannot be fed directly:
# reuse preprocess_val on an unshuffled, non-repeating pipeline so order matches yval.
with tf.device('/CPU:0'):
    val_eval = (tf.data.Dataset.from_tensor_slices((xval, yval))
                .map(preprocess_val, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(batch_size))
endrun_path = f'{save_dir}/endrun.h5'
if os.path.exists(endrun_path):
    model.load_weights(endrun_path)
else:
    print(f"WARNING: {endrun_path} not found; 'endrun' predictions use the weights currently in memory")
endrun_preds = model.predict(val_eval)
model.load_weights(f'{save_dir}/chkpt.h5')
best_preds = model.predict(val_eval)

In [None]:
# ROC curves on the validation set: end-of-run weights (top row) vs. the best-val_auc
# checkpoint (bottom row); the right column zooms into the top-left corner.
# Labels and predictions are flattened from (N, 1) to (N,) for sklearn.
y_true = np.ravel(yval)
auc_roc = roc_auc_score(y_true, np.ravel(endrun_preds))
auc_roc2 = roc_auc_score(y_true, np.ravel(best_preds))

fig, axs = plt.subplots(2, 2, figsize=(12, 12))
roc_panels = [
    (np.ravel(endrun_preds), auc_roc, 'Endrun ROC curve'),
    (np.ravel(best_preds), auc_roc2, 'Best val_auc ROC curve'),
]
for row, (preds, area, title) in enumerate(roc_panels):
    fpr, tpr, thresholds = roc_curve(y_true, preds)
    for col in (0, 1):
        ax = axs[row, col]
        if col == 1:
            ax.set_xlim(0, 0.2)
            ax.set_ylim(0.8, 1)
        # The model is MaxViT (create_maxvit), not EfficientNet.
        ax.plot(fpr, tpr, label='MaxViT (area = {:.5f})'.format(area))
        ax.set_xlabel('False positive rate')
        ax.set_ylabel('True positive rate')
        ax.set_title(title if col == 0 else title + ' (zoomed in at top left)')
        ax.legend(loc='best')

plt.show()