In [None]:
%pip install -q wandb keras_cv python-dotenv tensorflow-addons==0.22.0 # tensorboard_plugin_profile==2.8.0 

In [None]:
# Set RERUN = True to resume an existing W&B run instead of starting a new one.
RERUN = False
if RERUN:
    # These two names are only defined, and only read, when resuming.
    WANDB_ID = ""  # passed as `id=` to wandb.init when resuming
    MODEL_VERSION = ""  # v[0-9]

In [None]:
import gc
import math, re, os, pickle
import warnings
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import keras_cv
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_hub as hub
import wandb
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, Model, Sequential
from tqdm import tqdm
from wandb.keras import WandbCallback, WandbMetricsLogger

from config import CFG
from wandb_callback import WandbClfEvalCallback
# Silence library warnings to keep cell output readable.
warnings.simplefilter('ignore')
print(f"Tensorflow version {tf.__version__}")
AUTO = tf.data.experimental.AUTOTUNE

# `config.CFG` is imported as a callable and replaced by the object it returns.
# The guard keeps this cell re-runnable: on a second execution CFG is already
# the configuration object and must not be called again.
if callable(CFG):
    CFG = CFG()

# Absolute Kaggle paths, consistent with the '/kaggle/input/...' model paths
# and the '/kaggle/working/model' directory used further down the notebook.
root = Path('/kaggle')
inp = root / 'input'
out = root / 'working'

In [None]:
# Kaggle-only cell: read credentials from the notebook's attached secrets
# instead of hard-coding them.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
# Presumably what lets TensorFlow read the gs:// TFRecord paths built later — confirm.
user_secrets.set_tensorflow_credential(user_credential)
# Bucket/repo name that is interpolated into the gs:// dataset URLs.
CFG.GCS_REPO = user_secrets.get_secret("GCS_REPO")
os.environ['WANDB_API_KEY'] = user_secrets.get_secret("WANDB_API_KEY")

In [None]:
# Seed NumPy and TensorFlow from the single configured seed.
np.random.seed(CFG.SEED)
tf.random.set_seed(CFG.SEED)
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so setting it
# here presumably only affects child processes — confirm this is sufficient.
os.environ['PYTHONHASHSEED'] = str(CFG.SEED)

In [None]:
# Detect hardware: prefer a TPU when one can be resolved, otherwise use
# TensorFlow's default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:  # raised when no TPU is available
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

REPLICAS = strategy.num_replicas_in_sync
print("Number of accelerators: ", strategy.num_replicas_in_sync)

# Global batch size = per-replica batch size times the number of replicas.
CFG.BATCH_SIZE = CFG.BASE_BATCH_SIZE * REPLICAS

In [None]:
# NOTE(review): these two assignments repeat, with identical expressions, the
# end of the hardware-detection cell; the cell looks redundant — confirm and remove.
REPLICAS = strategy.num_replicas_in_sync
CFG.BATCH_SIZE = CFG.BASE_BATCH_SIZE * REPLICAS  # CFG.BASE_BATCH_SIZE * REPLICAS

## Visualization Utils

In [None]:
def batch_to_numpy_images_and_labels(data):
    """Convert an ``(images, labels)`` pair of tensors to NumPy values.

    Args:
        data: two-element sequence whose items both expose ``.numpy()``.

    Returns:
        ``(images, labels)``. When the label array has ``object`` dtype
        (byte strings, i.e. image IDs rather than class labels, as is the
        case for test data) the labels are replaced by a list of ``None``
        with one entry per image.
    """
    image_batch, label_batch = data
    np_images = image_batch.numpy()
    np_labels = label_batch.numpy()
    if np_labels.dtype != object:
        return np_images, np_labels
    # Only IDs are available: report "no label" for every image.
    return np_images, [None] * len(np_images)

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and say whether it was correct.

    Args:
        label: predicted class key, looked up in the global ``class_dict``.
        correct_label: ground-truth class key, or ``None`` when unknown.

    Returns:
        ``(title, correct)``. Without a ground truth the title is just the
        predicted class name and ``correct`` is ``True``. Otherwise the title
        is ``"<pred> [OK]"`` or ``"<pred> [NO→<truth>]"``.
    """
    if correct_label is None:
        return class_dict[label], True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected = 'OK', '', ''
    else:
        verdict, arrow, expected = 'NO', u"\u2192", class_dict[correct_label]
    title = "{} [{}{}{}]".format(class_dict[label], verdict, arrow, expected)
    return title, correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the current figure and return the next subplot slot.

    Args:
        image: NumPy image array; it is min-max rescaled to [0, 1] before display.
        title: title text; an empty string draws no title.
        subplot: ``(rows, cols, index)`` triple passed to ``plt.subplot``.
        red: draw the title in red (and slightly smaller) when ``True``.
        titlesize: base font size of the title.

    Returns:
        ``(rows, cols, index + 1)``, the slot for the following image.
    """
    lowest = image.min()
    value_range = image.max() - lowest
    if value_range > 0:
        # convert to [0, 1] for avoiding matplotlib warning
        image = (image - lowest) / value_range
    else:
        # A constant image has no range; dividing would produce 0/0 = NaN
        # pixels, so show it as a uniform black tile instead.
        image = np.zeros_like(image, dtype=float)
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(
            title,
            fontsize=int(titlesize) if not red else int(titlesize/1.2),
            color='red' if red else 'black',
            fontdict={'verticalalignment':'center'},
            pad=int(titlesize/1.5),
        )
    return (subplot[0], subplot[1], subplot[2]+1)

def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of images in a square-ish grid.

    Args:
        databatch: ``(images, labels)`` pair of tensors, as accepted by
            ``batch_to_numpy_images_and_labels``. Labels that are image IDs
            are treated as "no label".
        predictions: optional sequence of predicted class keys, one per image.
            When given, each title shows the prediction and whether it matches
            the label.

    Images that do not fit into the ``rows x cols`` grid are not drawn.
    An empty batch draws nothing.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)

    # An empty batch would give rows == 0 and a division by zero below.
    if len(images) == 0:
        return

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
    shown = rows * cols
    shown_labels = labels[:shown]

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols, FIGSIZE))

    # magic formula tested to work from 1x1 to 10x10 images (loop-invariant)
    dynamic_titlesize = FIGSIZE*SPACING/max(rows, cols)*40+3

    # display
    for i, (image, label) in enumerate(zip(images[:shown], shown_labels)):
        title = '' if label is None else class_dict[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # layout: no gaps when nothing is titled, otherwise leave room for titles
    plt.tight_layout()
    untitled = predictions is None and all(label is None for label in shown_labels)
    if untitled:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    """Show a confusion matrix with optional F1 / precision / recall text.

    Args:
        cmat: matrix handed to ``Axes.matshow``.
        score: F1 score, or ``None`` to omit it.
        precision: precision, or ``None`` to omit it.
        recall: recall, or ``None`` to omit it.

    NOTE(review): the summary text is anchored at the fixed coordinate
    x=101, which looks tied to one particular class count — confirm.
    NOTE(review): ``class_dict`` itself is passed as the tick labels; if it
    is a dict this iterates its keys rather than the class names — confirm.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')

    tick_positions = range(len(class_dict))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(class_dict, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(class_dict, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Collect only the metrics that were supplied, in a fixed order.
    pieces = []
    if score is not None:
        pieces.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        pieces.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        pieces.append('\nrecall = {:.3f} '.format(recall))
    titlestring = "".join(pieces)

    if titlestring != "":
        text_style = {'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'}
        ax.text(101, 1, titlestring, fontdict=text_style)
    plt.show()

def display_training_curves(training, validation, title, subplot):
    """Plot a training curve and a validation curve in one subplot.

    Args:
        training: per-epoch values for the training split.
        validation: per-epoch values for the validation split.
        title: metric name, used for the plot title and the y-axis label.
        subplot: integer subplot code for ``plt.subplot``. When its last digit
            is 1 a new figure is created first.
    """
    is_first_panel = (subplot % 10 == 1)
    if is_first_panel:  # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()

    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for curve in (training, validation):
        ax.plot(curve)
    ax.set_title(f'model {title}')
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

In [None]:
def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    Each name must contain ``-<digits>.``; the digits are read as the number
    of items stored in that file.

    Args:
        filenames: iterable of file names / paths.

    Returns:
        The NumPy sum of all per-file counts.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for name in filenames:
        per_file_counts.append(int(count_pattern.search(name).group(1)))
    return np.sum(per_file_counts)

def decode_image(image_data, smallest_side, CFG):
    """Decode a JPEG and return a randomly cropped float32 image.

    Steps: decode to 3 channels, cast to float32 (values stay in the 0-255
    range), crop/pad to a ``smallest_side`` square, resize to
    ``CFG.PRECROP_SIZE`` with Lanczos-5, then take a random crop of
    ``CFG.CROP_SIZE``.

    Args:
        image_data: encoded JPEG bytes (scalar string tensor).
        smallest_side: side length of the square crop/pad.
        CFG: configuration providing ``PRECROP_SIZE`` and ``CROP_SIZE``.

    Returns:
        Float32 image tensor of shape ``[*CFG.CROP_SIZE, 3]``.
    """
    decoded = tf.image.decode_jpeg(image_data, channels=3)  # uint8 in [0, 255]
    as_float = tf.cast(decoded, tf.float32)
    square = tf.image.resize_with_crop_or_pad(as_float, smallest_side, smallest_side)
    precrop = tf.image.resize(square, size=CFG.PRECROP_SIZE, method="lanczos5")
    return tf.image.random_crop(precrop, size=[*CFG.CROP_SIZE, 3])

def read_labeled_tfrecord(example, CFG, return_id):
    """Parse one serialized TFRecord example into image, label and optional ID.

    Args:
        example: serialized ``tf.train.Example``.
        CFG: configuration forwarded to ``decode_image``.
        return_id: when true, also return the image ID string.

    Returns:
        ``(image, label)`` or ``(image, label, id)``; ``label`` is int32.
    """
    scalar_string = tf.io.FixedLenFeature([], tf.string)
    scalar_int = tf.io.FixedLenFeature([], tf.int64)
    feature_description = {
        "image/encoded": scalar_string,
        "image/id": scalar_string,
        "image/meta/width": scalar_int,
        "image/meta/height": scalar_int,
        "image/class/label": scalar_int,
    }
    parsed = tf.io.parse_single_example(example, feature_description)

    # The shorter image side defines the square that decode_image crops/pads to.
    width = tf.cast(parsed['image/meta/width'], tf.int32)
    height = tf.cast(parsed['image/meta/height'], tf.int32)
    smallest_side = tf.minimum(width, height)

    image = decode_image(parsed["image/encoded"], smallest_side, CFG)
    label = tf.cast(parsed["image/class/label"], tf.int32)

    if return_id:
        image_id = tf.cast(parsed['image/id'], tf.string)
        return image, label, image_id
    return image, label

def load_dataset(filenames, CFG, order=False, return_id=False):
    """Build an unbatched dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file name(s).
        CFG: configuration forwarded to ``read_labeled_tfrecord``.
        order: when false, files are read in parallel and determinism is
            switched off so that order-altering optimizations are allowed;
            when true, files are read sequentially.
        return_id: also yield the image ID with every example.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, label)`` or
        ``(image, label, id)``.
    """
    if order:
        records = tf.data.TFRecordDataset(filenames, num_parallel_reads=None)
    else:
        records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        unordered = tf.data.Options()
        unordered.experimental_deterministic = False
        records = records.with_options(unordered)

    def _parse(serialized):
        return read_labeled_tfrecord(serialized, CFG, return_id)

    return records.map(_parse, num_parallel_calls=AUTO)

def basic_augment(image, CFG=CFG):
    """Apply exactly one randomly chosen augmentation family to an image.

    A single uniform random number ``num`` is drawn and compared against the
    thresholds ``CFG.FLIP``, ``CFG.JITTER`` and ``CFG.ROTATE`` in that order,
    so the thresholds act as increasing cut-offs and only the first matching
    branch runs:

    * ``num < CFG.FLIP``: random left-right and random up-down flip.
    * ``num < CFG.JITTER``: random brightness, contrast, saturation and hue
      (the limits are hard-coded; the CFG names in the trailing comments are
      not read).
    * ``num < CFG.ROTATE``: rotation by an angle drawn from ``CFG.ROTATE_LIM``
      with reflected borders.
    * otherwise: ``random_masking``.

    Args:
        image: The input image.
        CFG: configuration object; the default is the module-level ``CFG``
            captured when the function is defined.

    Returns:
        Augmented image.
    """
    num = tf.random.uniform([])
    # Random horizontal and vertical flips
    if num < CFG.FLIP:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

    # Color Jitter
    elif num < CFG.JITTER:
        image = tf.image.random_brightness(image, max_delta=0.9)#CFG.JITTER_BRIGHT_LIM)
        image = tf.image.random_contrast(image, 0.6, 1.4)#*CFG.JITTER_CONTRAST_LIMS)
        image = tf.image.random_saturation(image, 0.6, 1.4)#*CFG.JITTER_SAT_LIMS)
        image = tf.image.random_hue(image, max_delta=0.15)#CFG.JITTER_HUE_LIM)

    # Rotation
    elif num < CFG.ROTATE:
        angles = tf.random.uniform([], *CFG.ROTATE_LIM)
        image = tfa.image.rotate(image, angles, fill_mode='reflect')

    # Fallback: random cut-out masking (called with its own default CFG)
    else:
        image = random_masking(image)
    return image

def random_masking(image, CFG=CFG):
    """Erase random rectangles from one image with ``tfa.image.random_cutout``.

    The number of cut-outs is drawn from ``CFG.COUNT`` and the relative patch
    size from ``CFG.SIZE``; the size is divided by ``sqrt(count)``. All
    cut-outs made by one call share a single random fill value.

    NOTE(review): ``count`` comes from ``tf.random.uniform`` without a dtype,
    so it is presumably a float tensor that is then handed to ``tf.range`` —
    confirm this is intended.
    NOTE(review): ``int(...)`` / ``float(...)`` are applied to tensors below,
    which presumably only works when executing eagerly — confirm before using
    this inside a ``tf.data`` ``map``.

    Args:
        image: a single image (a batch axis is added and removed internally).
        CFG: configuration object providing ``COUNT`` and ``SIZE``; the default
            is the module-level ``CFG`` captured at definition time.

    Returns:
        The masked image, without the temporary batch axis.
    """
    original_shape = tf.shape(image)

    count = tf.random.uniform([], *CFG.COUNT)

    erase_size = tf.random.uniform([], *CFG.SIZE)
    erase_size = tf.math.divide(erase_size, tf.math.sqrt(count))
    erase_value = tf.random.uniform([], 0., 255.)

    # Subtracting the remainder modulo 2 makes both mask sides even numbers.
    mask1 = int(erase_size * float(original_shape[0])) - (int(erase_size * float(original_shape[0])) % 2)
    mask2 = int(erase_size * float(original_shape[1])) - (int(erase_size * float(original_shape[1])) % 2)

    # random_cutout is called on a batch of one, hence the expand/squeeze pair.
    image = tf.expand_dims(image, axis=0)
    for k in tf.range(count):
        image = tfa.image.random_cutout(
            image,
            mask_size=(mask1, mask2),
            constant_values=erase_value
        )
    image = tf.squeeze(image, 0)
    return image

def augment(images, labels, CFG):
    """Run ``basic_augment`` on every image of a batch; labels pass through.

    Args:
        images: batch of images, mapped element-wise.
        labels: labels belonging to ``images``; returned unchanged.
        CFG: accepted for signature compatibility; not read here.

    Returns:
        ``(augmented_images, labels)``.
    """
    augmented = tf.map_fn(basic_augment, elems=images)
    return augmented, labels

# Shared RandAugment layer used by get_batched_dataset for training batches.
# It is configured for pixel values in the 0-255 range; the number of
# augmentations per image and their magnitude come from CFG.
rand_augment = keras_cv.layers.RandAugment(
    value_range=(0, 255), augmentations_per_image=CFG.PER_IMAGE, magnitude=CFG.MAG
)

def get_batched_dataset(filenames, CFG, train=False, order=None):
    """Build the repeated, batched and prefetched dataset fed to ``fit``.

    Args:
        filenames: TFRecord file name(s) to read.
        CFG: configuration providing ``BATCH_SIZE`` and ``AUGMENT``.
        train: when true (and ``CFG.AUGMENT`` is truthy) every batch is passed
            through the module-level ``rand_augment`` layer.
        order: whether to read files in deterministic order. Defaults to
            ``not train``, i.e. ordered for evaluation, unordered for training.

    Returns:
        An infinitely repeating ``tf.data.Dataset`` of ``(images, labels)``
        batches of exactly ``CFG.BATCH_SIZE`` examples.
    """
    if order is None:
        order = not train
    dataset = load_dataset(filenames, CFG, order=order)
    # dataset = dataset.cache() # This dataset fits in RAM
    dataset = dataset.repeat()
    dataset = dataset.batch(CFG.BATCH_SIZE, drop_remainder=True)
    # Logical `and`, not bitwise `&`: `&` gives the wrong answer (or raises)
    # as soon as CFG.AUGMENT is not a plain bool.
    if train and CFG.AUGMENT:
        # RandAugment was configured for 0-255 values, so cast to uint8 first.
        dataset = dataset.map(
            lambda x, y: (rand_augment(tf.cast(x, tf.uint8)), y),
            num_parallel_calls=AUTO,
        )
    dataset = dataset.prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

In [None]:

# CFG.EPOCHS = 30

# TFRecord folder per stored image size; None selects the 'MO3' set.
GCS_PATH_SELECT = {
    192: f'gs://{CFG.GCS_REPO}/tfrecords-jpeg-192x192',
    224: f'gs://{CFG.GCS_REPO}/tfrecords-jpeg-224x224',
    331: f'gs://{CFG.GCS_REPO}/tfrecords-jpeg-331x331',
    512: f'gs://{CFG.GCS_REPO}/tfrecords-jpeg-512x512',
    None: f'gs://{CFG.GCS_REPO}/tfrecords-jpeg-MO3'
}
GCS_PATH = GCS_PATH_SELECT[CFG.GCS_IMAGE_SIZE]

filenames = tf.io.gfile.glob(f"{GCS_PATH}/train*.tfrec")

if RERUN:
    # NOTE: validation_filenames / test_filenames are restored from the W&B run
    # config in the `wandb.init` cell further down; that cell has to be
    # executed before this one when resuming.
    held_out = set(validation_filenames) | set(test_filenames)
    # Keep a list (not a set): later cells index into training_filenames.
    training_filenames = [name for name in filenames if name not in held_out]
else:
    # One whole file is held out for testing, then 10% of the remaining files
    # for validation.
    trainval_filenames, test_filenames = train_test_split(filenames, test_size=1, shuffle=True)
    training_filenames, validation_filenames = train_test_split(trainval_filenames, test_size=0.1, shuffle=True)

num_train = count_data_items(training_filenames)
num_val = count_data_items(validation_filenames)
num_test = count_data_items(test_filenames)

# Step counts must be whole numbers, so convert the floored quotient to int.
# NOTE(review): CFG.BATCH_SIZE already includes the REPLICAS factor, so the
# extra `// REPLICAS` shortens every epoch by that factor — confirm intended.
validation_steps = int(num_val / CFG.BATCH_SIZE // REPLICAS)
steps_per_epoch = int(num_train / CFG.BATCH_SIZE // REPLICAS)
TOTAL_STEPS = int(steps_per_epoch * (CFG.EPOCHS - 1))

# pickle can execute arbitrary code on load: only use a trusted class_dict file.
with open(inp / 'class-dict/class_dict_NEW3.pkl', 'rb') as class_dict_file:
    class_dict = pickle.load(class_dict_file)
train_df = pd.read_csv(inp / 'train-csv/train_with_MO3.csv')

In [None]:
if CFG.DEBUG:
    # Peek at training data: re-batch the training stream into groups of 20
    # images and keep an iterator so the next cell can page through them.
    training_dataset = (
        get_batched_dataset(training_filenames, CFG, train=True)
        .unbatch()
        .batch(20)
    )
    train_batch = iter(training_dataset)

In [None]:
if CFG.DEBUG:
    # Each execution of this cell shows the next group of images.
    peek_batch = next(train_batch)
    display_batch_of_images(peek_batch)

In [None]:
if CFG.DEBUG:
    # Print the shapes of three batches per split, then the labels of the
    # last batch that was read.
    shape_checks = (
        ("Training data shapes:", "Training data label examples:", training_filenames, True),
        ("Validation data shapes:", "Validation data label examples:", validation_filenames, False),
    )
    for shapes_header, labels_header, split_files, is_train in shape_checks:
        print(shapes_header)
        for image, label in get_batched_dataset(split_files, CFG, train=is_train).take(3):
            print(image.numpy().shape, label.numpy().shape)
        print(labels_header, label.numpy())

In [None]:
# Local Kaggle model directories, keyed by the CFG.MODEL name.
MODEL_PATHS = {
    'swin-224': '/kaggle/input/swin/tensorflow2/large-patch4-window7-224-fe/1',
    'swin-384': '/kaggle/input/swin/tensorflow2/large-patch4-window12-384-fe/1',
    'convnext-large-21k-224': '/kaggle/input/convnext/tensorflow2/large-21k-1k-224-fe/1',
    'convnext-large-1k-224': '/kaggle/input/convnext/tensorflow2/large-1k-224-fe/1',
    'convnext-xlarge-21k-224': '/kaggle/input/convnext/tensorflow2/xlarge-21k-1k-224-fe/1',
    'convnext-xlarge-21k-384': '/kaggle/input/convnext/tensorflow2/xlarge-21k-1k-384-fe/1',
    'vit-l16-fe': '/kaggle/input/vision-transformer/tensorflow2/vit-l16-fe/1',
    'vit-b8-fe': '/kaggle/input/vision-transformer/tensorflow2/vit-b8-fe/1',
    'bit': '/kaggle/input/bit/tensorflow2/m-r152x4/1',
}

# Fail here, with the list of valid names, instead of leaving `model_path`
# undefined and hitting a NameError when the model is loaded.
if CFG.MODEL not in MODEL_PATHS:
    raise ValueError(
        f"Unknown CFG.MODEL {CFG.MODEL!r}; expected one of {sorted(MODEL_PATHS)}"
    )
model_path = MODEL_PATHS[CFG.MODEL]

In [None]:
if RERUN:
    # Resuming: load the previously saved model instead of building a new one.
    # NOTE: the fallback below uses `run`, which is created by the `wandb.init`
    # cell; that cell has to be executed before this one when resuming.
    options = tf.saved_model.LoadOptions(
        experimental_io_device="/job:localhost"
    )
    try:
        with strategy.scope():
            final_model = tf.keras.models.load_model(f'/content/drive/MyDrive/Mushroom-Classifier/artifacts/run_{WANDB_ID}_model:{MODEL_VERSION}', options=options)

    # load_model reports a missing path as OSError (FileNotFoundError is only a
    # subclass of it), so catch the base class to make the fallback reachable.
    except OSError:
        artifact = run.use_artifact(f'g-broughton/Mushroom-Classifier/run_{WANDB_ID}_model:{MODEL_VERSION}', type='model')
        artifact_dir = artifact.download()

        with strategy.scope():
            final_model = tf.keras.models.load_model(artifact_dir, options=options)

else:
    with strategy.scope():
        # Cast to float32 and apply 'torch'-mode ImageNet preprocessing
        # inside the model.
        img_adjust_layer = layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*CFG.CROP_SIZE, 3])
        pretrained_model = tf.keras.models.load_model(model_path)

        model = Sequential([
            img_adjust_layer,
            pretrained_model,
        ])

        # Wrap the feature extractor with an explicit input and a softmax head.
        input_image = layers.Input(shape=model.input_shape[1:])
        model_output = model(input_image)
        output = layers.Dense(len(class_dict), activation='softmax')(model_output)

        # Create the final model
        final_model = Model(inputs=input_image, outputs=output)

        top3 = tf.keras.metrics.SparseTopKCategoricalAccuracy(3, name='top-3-accuracy')

    # Per-class weights inversely proportional to class frequency in train_df.
    class_weights = compute_class_weight('balanced', classes=np.unique(train_df['class_id']), y=train_df['class_id'])
    # NOTE(review): the input pipeline yields integer class ids and `top3` is a
    # sparse metric, while CategoricalFocalCrossentropy is documented for
    # one-hot targets — confirm the labels are one-hot encoded before the loss
    # or wrap the loss accordingly.
    final_model.compile(
        optimizer=tf.keras.optimizers.AdamW(
            learning_rate=.001,
            epsilon=CFG.EPSILON,
            weight_decay=CFG.DECAY,
            beta_1=CFG.BETA1,
            beta_2=CFG.BETA2
        ),
        loss=tf.keras.losses.CategoricalFocalCrossentropy(alpha=class_weights, label_smoothing=0.1),  # 'sparse_categorical_crossentropy',
        metrics=['accuracy', top3]
    )
    final_model.summary()

In [None]:
if RERUN:
    # Resume the existing run and restore the file split it was trained with.
    run = wandb.init(
        project="Mushroom-Classifier",
        tags=[CFG.OPT, CFG.LR_SCHED, str(CFG.CROP_SIZE[0])],
        resume='allow',
        id=WANDB_ID,
        dir="../",
    )

    validation_filenames = run.config['VAL_FILENAMES']
    test_filenames = run.config['TEST_FILENAMES']
else:
    # Log only the listed CFG attributes as the run configuration.
    config = wandb.helper.parse_config(
        CFG, include=(
            'SEED', 'BATCH_SIZE', 'EPOCHS', 'PRECROP_SIZE', 'LR_START', 'with_MO',
            'LR_MAX', 'LR_RAMPEP', 'LR_SUSEP', 'LR_MIN', 'EPSILON', 'DECAY', 'BETA1', 'BETA2',
            # 'JITTER_BRIGHT_LIM', 'JITTER_CONTRAST_LIMS', 'JITTER_SAT_LIMS', 'JITTER_HUE_LIM', 'ROTATE_LIM', 'SIZE', 'COUNT',
            # 'PER_IMAGE' and 'LR_DECAY' are separate keys; without the comma
            # Python concatenates them into one non-existent name.
            'LR_SCHED', 'MODEL', 'MAG', 'PER_IMAGE', 'LR_DECAY',
        )
    )
    run = wandb.init(
        # reinit=True,
        project="Mushroom-Classifier",
        tags=[CFG.OPT, CFG.LR_SCHED, str(CFG.CROP_SIZE[0])],
        config=config,
        dir="../",
    )
    # Values computed in this notebook; the file lists are what a later
    # RERUN reads back to rebuild the same split.
    wandb.config.update({
        "STEPS_PER_EPOCH": steps_per_epoch,
        "TOTAL_STEPS": TOTAL_STEPS,
        "TEST_FILENAMES": test_filenames,
        "VAL_FILENAMES": validation_filenames,
        'LOSS': 'focal',
        'LABEL_SMOOTHING': 0.1,
        "ALPHA": "sklearn compute class weights",
        "GAMMA": 2,
    })

In [None]:
# CFG.LR_SUSEP = 0
# CFG.LR_DECAY = 0.85
# CFG.LR_MIN = 0.000001
# CFG.LR_MAX = 0.000001

def get_lr_callback(batch_size=8, plot=False):
    """Create the per-epoch learning-rate scheduler callback.

    The schedule has three phases, all configured through ``CFG``:
    a linear warm-up from ``LR_START`` to ``LR_MAX * batch_size`` over
    ``LR_RAMPEP`` epochs, an optional plateau of ``LR_SUSEP`` epochs, and a
    decay towards ``LR_MIN`` whose shape is chosen by ``LR_SCHED``
    (``'ExponentialWarmup'`` or ``'CosineWarmup'``).

    Args:
        batch_size: multiplier applied to ``CFG.LR_MAX``.
        plot: when true, plot the schedule for ``CFG.EPOCHS`` epochs.

    Returns:
        A ``tf.keras.callbacks.LearningRateScheduler``.

    Raises:
        ValueError: from the schedule function, when the decay phase is
            reached and ``CFG.LR_SCHED`` is not one of the two known names.
    """
    lr_start   = CFG.LR_START  # 0.000001
    lr_max     = CFG.LR_MAX * batch_size #CFG.LR_MAX * batch_size
    lr_min     = CFG.LR_MIN  # 0.0000001
    lr_ramp_ep = CFG.LR_RAMPEP  # 5
    lr_sus_ep  = CFG.LR_SUSEP  # 0
    lr_decay   = CFG.LR_DECAY  # 0.9

    def lrfn(epoch):
        if epoch <= lr_ramp_ep:
            if lr_ramp_ep == 0:
                # No warm-up configured: start at the peak rate instead of
                # dividing by zero.
                lr = lr_max
            else:
                lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start

        elif epoch < lr_ramp_ep + lr_sus_ep:
            lr = lr_max

        elif CFG.LR_SCHED=='ExponentialWarmup':
            lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min

        elif CFG.LR_SCHED=='CosineWarmup':
            decay_total_epochs = CFG.EPOCHS - lr_ramp_ep - lr_sus_ep
            decay_epoch_index = epoch - lr_ramp_ep - lr_sus_ep
            phase = math.pi * decay_epoch_index / decay_total_epochs
            # NOTE(review): a standard cosine decay uses 0.5 here; with 0.4 the
            # decay phase starts at 80% of (lr_max - lr_min) — confirm intended.
            cosine_decay = 0.4 * (1 + math.cos(phase))
            lr = (lr_max - lr_min) * cosine_decay + lr_min

        else:
            # Previously `lr` was simply left unassigned here.
            raise ValueError(
                f"Unknown CFG.LR_SCHED {CFG.LR_SCHED!r}; "
                "expected 'ExponentialWarmup' or 'CosineWarmup'"
            )

        return lr

    if plot:
        epochs = np.arange(CFG.EPOCHS)
        plt.figure(figsize=(10,5))
        plt.plot(epochs, [lrfn(epoch) for epoch in epochs], marker='o')
        plt.xlabel('epoch'); plt.ylabel('learning rate')
        plt.title('Learning Rate Scheduler')
        plt.show()

    lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=False)
    return lr_callback

_=get_lr_callback(CFG.BATCH_SIZE, plot=True )

In [None]:
options = tf.saved_model.SaveOptions(
        experimental_io_device="/job:localhost"
    )
# Ordered validation examples for the W&B evaluation tables.
val_data = load_dataset(validation_filenames, CFG, order=True).take(51200).batch(1024).prefetch(AUTO)

modeldir = out / 'model'
modeldir.mkdir(exist_ok=True, parents=True)

callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=CFG.ES_PATIENCE,
        verbose=1,
        mode="max",
        restore_best_weights=True,
    ),
    wandb.keras.WandbMetricsLogger(log_freq="epoch"),
#     wandb.keras.WandbModelCheckpoint(
#         str(modeldir),#CFG.CKPT_DIR,  # .h5 for weights, dir for whole model
#         monitor="val_accuracy",
#         mode="max",
#         verbose=1,
#         save_best_only=True,
#         save_weights_only=False,
#         options=options,
#         initial_value_threshold=0.8,
#         save_freq=TOTAL_STEPS,
#     ),
    get_lr_callback(CFG.BATCH_SIZE),

    # tf.keras.callbacks.TensorBoard(
    #     log_dir = logs,
    #     histogram_freq = 1,
    #     # profile_batch = '500,520'
    # )
]

# An `if` statement cannot appear inside a list literal, so the optional
# callback is appended afterwards. The flag is read from CFG when it defines
# RERUN and falls back to the notebook-level RERUN switch otherwise.
if getattr(CFG, 'RERUN', RERUN):
    callbacks.append(
        WandbClfEvalCallback(
            validation_data=val_data,
            data_table_columns=["idx", "image", "species"],
            pred_table_columns=["image", "species", "predicted species", "score"],
            num_batches=50,
            class_dict=class_dict
        )
    )

In [None]:
# Pick the training input and the branch-specific fit arguments first, then
# issue one shared `fit` call.
if RERUN:
    # Resumed runs fit on 32 examples from the first training file and start
    # at `initial_epoch=CFG.EPOCHS`.
    tfiles = training_filenames[0]
    fit_inputs = get_batched_dataset(tfiles, CFG, train=True).unbatch().take(32).batch(32)
    branch_fit_args = {"initial_epoch": CFG.EPOCHS}
else:
    fit_inputs = get_batched_dataset(training_filenames, CFG, train=True)
    branch_fit_args = {"steps_per_epoch": steps_per_epoch}

history = final_model.fit(
    fit_inputs,
    epochs=CFG.EPOCHS,
    validation_data=get_batched_dataset(validation_filenames, CFG, train=False),
    validation_steps=validation_steps,
    callbacks=callbacks,
    **branch_fit_args,
)
# wandb.finish()

In [None]:
# Save through the local host, mirroring the `experimental_io_device` used by
# the LoadOptions when the model is read back in a resumed run.
model_save_options = tf.saved_model.SaveOptions(
    experimental_io_device="/job:localhost"
)
final_model.save(modeldir, options=model_save_options)

model_checkpoint_artifact = wandb.Artifact(
    f"run_{wandb.run.id}_model", type="model"
)
# Check the directory that was actually written, not a hard-coded path.
if not os.path.isdir(modeldir):
    raise FileNotFoundError(f"No such file or directory {modeldir}")
model_checkpoint_artifact.add_dir(str(modeldir))
# `aliases` is documented as a list of strings.
wandb.log_artifact(model_checkpoint_artifact, aliases=['latest'])

# wandb.finish()

## Get image ID, Label and Model Predictions

In [None]:
def process_and_save_predictions(model, dataset, batch_size, num_samples):
    """Run the model over a dataset and collect per-example predictions.

    Despite the name nothing is written to disk here; the caller saves the
    returned frame.

    Args:
        model: object exposing ``predict_on_batch``.
        dataset: iterable of ``(images, labels, ids)`` batches.
        batch_size: batch size of ``dataset``; only used for the progress bar.
        num_samples: number of examples in ``dataset``; only used for the
            progress bar.

    Returns:
        ``pandas.DataFrame`` with the columns ``Prediction``, ``Actual Label``,
        ``Score`` (the highest predicted value), ``ids``, ``Species`` and
        ``Species Pred`` (class names looked up in the global ``class_dict``).
    """
    predictions = []
    actual_labels = []
    score_list = []
    id_list = []

    # Round up so the final, partially filled batch is counted as well.
    total_batches = math.ceil(num_samples / batch_size)

    for image_data, labels, ids in tqdm(dataset, total=total_batches):
        # Predict
        preds = model.predict_on_batch(image_data)
        pred_idx = tf.math.argmax(preds, axis=-1)
        scores = tf.math.reduce_max(preds, axis=-1)

        # Append predictions and labels
        predictions.extend(pred_idx.numpy().tolist())
        actual_labels.extend(labels.numpy().tolist())
        score_list.extend(scores.numpy().tolist())
        # String tensors come back as bytes; decode them so the CSV and the
        # logged table hold plain text instead of b'...' representations.
        id_list.extend(
            raw_id.decode('utf-8') if isinstance(raw_id, bytes) else raw_id
            for raw_id in ids.numpy().tolist()
        )

    # Create DataFrame
    results_df = pd.DataFrame({
        'Prediction': predictions,
        'Actual Label': actual_labels,
        'Score': score_list,
        'ids': id_list,
    })

    results_df['Species'] = results_df['Actual Label'].map(class_dict)
    results_df['Species Pred'] = results_df['Prediction'].map(class_dict)
    return results_df

In [None]:
# One entry per split: (files, example count, output name, prediction batch size).
prediction_jobs = [
    (test_filenames, num_test, "test_df.csv", 256),
    (training_filenames, num_train, "train_df.csv", 1024),
    (validation_filenames, num_val, "val_df.csv", 1024),
]

for filenames, num, csvname, batch in prediction_jobs:
    ds = load_dataset(filenames, CFG, return_id=True).batch(batch).prefetch(AUTO)
    df = process_and_save_predictions(final_model, ds, batch, num)
    # Log the predictions to W&B and keep a CSV copy in the output directory.
    table = wandb.Table(dataframe=df)
    wandb.log({csvname: table})
    df.to_csv(out / csvname, index=False)

wandb.finish()