# Data Preprocessing
While from imagenet might need little processing. To retrain Resnet-50 with adversarial inputs, they need to be generated. Training dataset has about 1M images, preprocessing, generates adversarial input for both training and testing datasets.
Corresponding adversarial tensorflow records are created for both datasets.

In [21]:
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input

# Constants
IMG_SIZE = 224
# BATCH_SIZE = 5000
BATCH_SIZE = 200
AUTOTUNE = tf.data.AUTOTUNE
EPOCHS = 5
tf.random.set_seed(5)

physical_gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", physical_gpus)

try:
    tf.keras.mixed_precision.set_global_policy('float32')
    tf.config.experimental.set_virtual_device_configuration(
        physical_gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=56320)]  # Limit RAM to 55GB to avoid starving PC
    )
    print("Using GPU with 55GB of memory")
except Exception as e:
    print(e)


Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Using GPU with 55GB of memory


In [22]:
# Load ImageNet data

def prepare_input_data(input):
    image = tf.cast(input['image'], tf.float32) # ResNet-50 used this
    image = tf.image.resize(input['image'], (IMG_SIZE, IMG_SIZE))
    image = preprocess_input(image)
    label = input['label']
    return image, label

# Big dataset for real work
# dataset, info = tfds.load(
#     'imagenet2012',
#     shuffle_files=False,
#     with_info=True,
#     data_dir='../datasets'
# )

# Smaller dataset for testing
dataset, info = tfds.load(
    'imagenette',
    shuffle_files=False,
    with_info=True,
    data_dir='../datasets'
)

# Dataset stats
print(f'Train image count: {info.splits['train'].num_examples}')
print(f'Test image count: {info.splits['validation'].num_examples}')

# Preprocess data
train_dataset = dataset['train'].map(prepare_input_data, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE).prefetch(AUTOTUNE)
test_dataset = dataset['validation'].map(prepare_input_data, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE).prefetch(AUTOTUNE)

Train image count: 9469
Test image count: 3925


In [23]:
# Load ResNet50 model
from tensorflow.keras.applications import ResNet50

# base_model = ResNet50(
#     include_top=True,
#     weights='imagenet',
#     input_shape=(IMG_SIZE, IMG_SIZE, 3),
#     pooling=None,
#     classes=1000,
#     classifier_activation='softmax'
# )
base_model = ResNet50(
    include_top=True,
    weights=None,  # Don't load pretrained weights since we're changing the output
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    pooling=None,
    classes=10,  # Imagenette has 10 classes
    classifier_activation='softmax'
)

In [24]:
# save_processed_dataset('../datasets/adversaries/test_dataset.tfrecord', adversarial_test_dataset)
# %reset_selective -f adversarial_test_dataset test_dataset
# import gc;gc.collect()

In [25]:
def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # get value from EagerTensor
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _create_adversary_with_pgd(model, images, labels, eps, eps_iter, nb_iter):
    """
    This generates adversarial images by iteratively applying a small
    perturbation in the direction of the gradient of the loss, and then
    projecting the result back into the epsilon-ball of the original image.

    Args:
        model (tf.keras.Model): The model to attack.
        images (tf.Tensor): The original, clean input images.
        labels (tf.Tensor): The true labels for the images.
        eps (float): The maximum perturbation (L-infinity norm).
        eps_iter (float): The step size for each attack iteration.
        nb_iter (int): The number of PGD iterations to perform.

    Returns:
        tf.Tensor: The generated adversarial images.
    """
    x_adv = tf.identity(images)
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

    for _ in range(nb_iter):
        with tf.GradientTape() as tape:
            tape.watch(x_adv)
            prediction = model(x_adv, training=False)
            loss = loss_object(labels, prediction)

        gradients = tape.gradient(loss, x_adv)
        signed_grad = tf.sign(gradients)
        x_adv = x_adv + eps_iter * signed_grad
        perturbation = tf.clip_by_value(x_adv - images, -eps, eps)
        x_adv = images + perturbation

    return x_adv

def generate_adversarial_dataset(filename,dataset, model, eps, pgd_steps, pgd_step_size):
    """
    Generates adversarial examples and saves them to a TFRecord file
    by serializing the raw float32 tensors.
    """
    options = tf.io.TFRecordOptions(compression_type="GZIP")
    num = 0
    with tf.io.TFRecordWriter(filename, options=options) as writer:
        for i, (images, labels) in enumerate(dataset):
            print(f"Batch {i+1}")
            # Generate the adversarial images (these are already preprocessed)
            adv_images = _create_adversary_with_pgd(
                model=model,
                images=images,
                labels=labels,
                eps=eps,
                eps_iter=pgd_step_size,
                nb_iter=pgd_steps
            )

            # Iterate through the batch to save each image/label pair
            for i in range(len(adv_images)):
                image_tensor = adv_images[i]
                label = labels[i]

                # 1. Cast the tensor to float16 to halve its size
                image_tensor_f16 = tf.cast(image_tensor, tf.float16)

                # 2. Serialize the smaller tensor
                image_bytes = tf.io.serialize_tensor(image_tensor_f16)
                # 2. Create the feature and write to the TFRecord file
                feature = {
                    'image': _bytes_feature(image_bytes), # Save the raw serialized tensor
                    'label': _int64_feature(label.numpy())
                }
                num += 1
                serialized_example = tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
                writer.write(serialized_example)

    print(f"Processed and saved: {num} images")


In [26]:
## Create adversarial dataset
EPSILON = 0.03
PGD_STEPS = 2
PGD_STEP_SIZE = 0.007
adversarial_test_file = '../datasets/adversaries/small_test_dataset.tfrec'
# generate_adversarial_dataset(
#     filename=adversarial_test_file,
#     dataset=test_dataset,
#     model=base_model,
#     eps=EPSILON,
#     pgd_steps=PGD_STEPS,
#     pgd_step_size=PGD_STEP_SIZE
# )
adversarial_train_file = '../datasets/adversaries/small_train_dataset.tfrec'
# generate_adversarial_dataset(
#     filename=adversarial_train_file,
#     dataset=train_dataset,
#     model=base_model,
#     eps=EPSILON,
#     pgd_steps=PGD_STEPS,
#     pgd_step_size=PGD_STEP_SIZE
# )

In [27]:
## Load data from file
def _parse_function(proto):
    """
    Parses a single example proto by deserializing the float16 tensor
    and casting it back to float32.
    """
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.io.parse_single_example(proto, feature_description)

    # 1. Deserialize the byte string back into a float16 tensor
    image_f16 = tf.io.parse_tensor(parsed_features['image'], out_type=tf.float16)
    label = parsed_features['label']

    # 2. Cast the image back to float32 for the model
    image_f32 = tf.cast(image_f16, tf.float32)

    # 3. Set the shape on the final float32 tensor
    image_f32.set_shape([IMG_SIZE, IMG_SIZE, 3])

    return image_f32, label

# Load the TFRecord file back into a dataset
loaded_test_dataset = tf.data.TFRecordDataset(adversarial_test_file, compression_type='GZIP')

# Map the parsing function across the dataset
parsed_test_dataset = loaded_test_dataset.map(_parse_function).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [None]:
# Compile model
# First, update the model compilation to include more metrics
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[
        'accuracy',
        # tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5_accuracy'),
        tf.keras.metrics.SparseCategoricalAccuracy(name='sparse_categorical_accuracy'),
        tf.keras.metrics.AUC(name='auc'),
        # tf.keras.metrics.Precision(average='macro', name='precision'),
        # tf.keras.metrics.Recall(average='macro', name='recall')
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall')
    ]
)


print("Training the model on clean data...\n")
history = base_model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=5,
    batch_size=BATCH_SIZE,
    verbose=1,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=2,
            restore_best_weights=True
        )
    ]
)

In [None]:
# Evaluate model on all datasets
print("Computing baseline metrics...\n")

train_metrics = base_model.evaluate(train_dataset, verbose=1)
test_metrics = base_model.evaluate(test_dataset, verbose=1)
noisy_metrics = base_model.evaluate(parsed_test_dataset, verbose=1)

# Extract metrics
# train_loss, train_acc, train_top5, train_auc, train_prec, train_rec = train_metrics
# test_loss, test_acc, test_top5, test_auc, test_prec, test_rec = test_metrics
# noisy_loss, noisy_acc, noisy_top5, noisy_auc, noisy_prec, noisy_rec = noisy_metrics

try:
    # Extract metrics with error handling
    metric_names = ['loss', 'accuracy', 'top_5_accuracy', 'sparse_categorical_accuracy', 'auc', 'precision', 'recall']
    metrics_dict = {
        'train': dict(zip(metric_names, train_metrics)),
        'test': dict(zip(metric_names, test_metrics)),
        'noisy': dict(zip(metric_names, noisy_metrics))
    }

    # Unpack metrics safely
    train_loss, train_acc = metrics_dict['train']['loss'], metrics_dict['train']['accuracy']
    train_top5, train_auc = metrics_dict['train']['top_5_accuracy'], metrics_dict['train']['auc']
    train_prec, train_rec = metrics_dict['train']['precision'], metrics_dict['train']['recall']

    test_loss, test_acc = metrics_dict['test']['loss'], metrics_dict['test']['accuracy']
    test_top5, test_auc = metrics_dict['test']['top_5_accuracy'], metrics_dict['test']['auc']
    test_prec, test_rec = metrics_dict['test']['precision'], metrics_dict['test']['recall']

    noisy_loss, noisy_acc = metrics_dict['noisy']['loss'], metrics_dict['noisy']['accuracy']
    noisy_top5, noisy_auc = metrics_dict['noisy']['top_5_accuracy'], metrics_dict['noisy']['auc']
    noisy_prec, noisy_rec = metrics_dict['noisy']['precision'], metrics_dict['noisy']['recall']

except Exception as e:
    print(f"Error extracting metrics: {e}")
    print("Available metrics:", base_model.metrics_names)
    raise

print("\n" + "="*50)
print("       Model Performance Analysis")
print("="*50 + "\n")

print("## Base Performance Metrics 📊")
print("**Evaluated on clean test dataset**\n")
print(f"* **Top-1 Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Top-5 Accuracy**: `{test_top5*100:.2f}%`")
print(f"* **AUC Score**: `{test_auc:.3f}`")
print(f"* **Precision**: `{test_prec:.3f}`")
print(f"* **Recall**: `{test_rec:.3f}`")
print(f"* **Loss**: `{test_loss:.3f}`")
print("\n" + "---")

print("\n## Generalization Analysis 🧠")
print("**Comparing training vs test performance**\n")
print(f"* **Training Accuracy**: `{train_acc*100:.2f}%`")
print(f"* **Test Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Generalization Gap**: `{(train_acc-test_acc)*100:.2f}%`")
print(f"* **Training Loss**: `{train_loss:.3f}`")
print(f"* **Test Loss**: `{test_loss:.3f}`")
print("> *A smaller gap indicates better generalization*")
print("\n" + "---")

print("\n## Adversarial Robustness 🛡️")
print("**Evaluating model stability against adversarial inputs**\n")
print(f"* **Clean Data Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Adversarial Data Accuracy**: `{noisy_acc*100:.2f}%`")
print(f"* **Robustness Gap**: `{(test_acc-noisy_acc)*100:.2f}%`")
print(f"* **Clean AUC**: `{test_auc:.3f}`")
print(f"* **Adversarial AUC**: `{noisy_auc:.3f}`")
print("> *Smaller gaps between clean and adversarial metrics indicate better robustness*")

In [None]:
# save_processed_dataset('../datasets/adversaries/train_dataset.tfrecord', adversarial_train_dataset)
# %reset_selective -f adversarial_train_dataset train_dataset
# import gc;gc.collect()

In [None]:
## Adversarial training
# Load the TFRecord file back into a dataset
loaded_train_dataset = tf.data.TFRecordDataset(adversarial_train_file, compression_type='GZIP')

# Map the parsing function across the dataset
parsed_train_dataset = loaded_train_dataset.map(_parse_function).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

history = base_model.fit(
    parsed_train_dataset,
    epochs=5, # Use a suitable number of epochs for your task
    validation_data=parsed_test_dataset,
    verbose=0
)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# First compile model with comprehensive metrics
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall')
    ]
)

# Evaluate model on all datasets
print("Computing comprehensive metrics...\n")

train_metrics = base_model.evaluate(train_dataset, verbose=0)
test_metrics = base_model.evaluate(test_dataset, verbose=0)
noisy_metrics = base_model.evaluate(parsed_test_dataset, verbose=0)

# Extract all metrics
train_loss, train_acc, train_top5, train_auc, train_prec, train_rec = train_metrics
test_loss, test_acc, test_top5, test_auc, test_prec, test_rec = test_metrics
noisy_loss, noisy_acc, noisy_top5, noisy_auc, noisy_prec, noisy_rec = noisy_metrics

# Create figure for multiple plots
plt.style.use('seaborn')
fig = plt.figure(figsize=(15, 10))

# 1. Accuracy Comparison
plt.subplot(2, 2, 1)
metrics = ['Top-1 Acc', 'Top-5 Acc', 'AUC']
clean_scores = [test_acc*100, test_top5*100, test_auc*100]
noisy_scores = [noisy_acc*100, noisy_top5*100, noisy_auc*100]

x = range(len(metrics))
width = 0.35

plt.bar([i - width/2 for i in x], clean_scores, width, label='Clean Data', color='skyblue')
plt.bar([i + width/2 for i in x], noisy_scores, width, label='Adversarial Data', color='lightcoral')
plt.ylabel('Percentage (%)')
plt.title('Performance Metrics Comparison')
plt.xticks(x, metrics)
plt.legend()

# 2. Precision-Recall Plot
plt.subplot(2, 2, 2)
metrics = ['Precision', 'Recall']
clean_scores = [test_prec, test_rec]
noisy_scores = [noisy_prec, noisy_rec]

x = range(len(metrics))
plt.bar([i - width/2 for i in x], clean_scores, width, label='Clean Data', color='skyblue')
plt.bar([i + width/2 for i in x], noisy_scores, width, label='Adversarial Data', color='lightcoral')
plt.ylabel('Score')
plt.title('Precision-Recall Comparison')
plt.xticks(x, metrics)
plt.legend()

# 3. Loss Comparison
plt.subplot(2, 2, 3)
plt.bar(['Training', 'Testing', 'Adversarial'], 
        [train_loss, test_loss, noisy_loss],
        color=['green', 'skyblue', 'lightcoral'])
plt.ylabel('Loss')
plt.title('Loss Comparison Across Datasets')

# 4. Robustness Gap
plt.subplot(2, 2, 4)
plt.bar(['Generalization Gap', 'Robustness Gap'], 
        [(train_acc - test_acc)*100, (test_acc - noisy_acc)*100],
        color=['skyblue', 'lightcoral'])
plt.ylabel('Gap Percentage (%)')
plt.title('Model Gaps Analysis')

plt.tight_layout()
plt.show()

# Print detailed metrics report
print("\n" + "="*50)
print("       Comprehensive Model Evaluation Results")
print("="*50 + "\n")

print("## Base Performance Metrics 📊")
print("**Evaluated on clean test dataset**\n")
print(f"* **Top-1 Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Top-5 Accuracy**: `{test_top5*100:.2f}%`")
print(f"* **AUC Score**: `{test_auc:.3f}`")
print(f"* **Precision**: `{test_prec:.3f}`")
print(f"* **Recall**: `{test_rec:.3f}`")
print(f"* **Loss**: `{test_loss:.3f}`")
print("\n" + "---")

print("\n## Generalization Analysis 🧠")
print("**Comparing training vs test performance**\n")
print(f"* **Training Accuracy**: `{train_acc*100:.2f}%`")
print(f"* **Test Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Generalization Gap**: `{(train_acc-test_acc)*100:.2f}%`")
print(f"* **Training Loss**: `{train_loss:.3f}`")
print(f"* **Test Loss**: `{test_loss:.3f}`")
print("> *A smaller gap indicates better generalization*")
print("\n" + "---")

print("\n## Adversarial Robustness 🛡️")
print("**Evaluating model stability against adversarial inputs**\n")
print(f"* **Clean Data Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Adversarial Data Accuracy**: `{noisy_acc*100:.2f}%`")
print(f"* **Robustness Gap**: `{(test_acc-noisy_acc)*100:.2f}%`")
print(f"* **Clean AUC**: `{test_auc:.3f}`")
print(f"* **Adversarial AUC**: `{noisy_auc:.3f}`")
print("> *Smaller gaps between clean and adversarial metrics indicate better robustness*")

Computing comprehensive metrics...



ValueError: Shape must be rank 1 but is rank 0 for '{{node in_top_k/InTopKV2}} = InTopKV2[T=DT_INT32](resnet50_1/predictions_1/Softmax, ArgMax_1, in_top_k/InTopKV2/k)' with input shapes: [?,1000], [], [].

# bbbb

In [None]:
import os
# 0 = all logs, 1 = filter INFO, 2 = filter INFO & WARNING, 3 = filter INFO, WARNING & ERROR
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # use "3" to hide even ERROR logs

import tensorflow as tf

# Silence TensorFlow's Python logger as well
tf.get_logger().setLevel("ERROR")

# Silence absl logs that TF uses
# import absl.logging
# absl.logging.set_verbosity(absl.logging.ERROR)

# # If on TF 1.x APIs:
# try:
#     tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# except Exception:
#     pass

import tensorflow_datasets as tfds
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.applications import ResNet50
import matplotlib.pyplot as plt
import seaborn as sns

# Constants
IMG_SIZE = 224
BATCH_SIZE = 200
AUTOTUNE = tf.data.AUTOTUNE
EPOCHS = 5 # Changed to a lower number for demonstration if retraining is needed
tf.random.set_seed(5)

physical_gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", physical_gpus)

try:
    tf.keras.mixed_precision.set_global_policy('float32') # Ensured float32 policy as per the original notebook
    if physical_gpus: # Check if GPUs are available before setting virtual device
        tf.config.experimental.set_virtual_device_configuration(
            physical_gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=56320)]  # Limit RAM to 55GB to avoid starving PC
        )
        print("Using GPU with 55GB of memory")
except Exception as e:
    print(e)

# Load ImageNet data
def prepare_input_data(input):
    image = tf.cast(input['image'], tf.float32)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) # Corrected to use the casted image
    image = preprocess_input(image)
    label = input['label']
    return image, label

dataset, info = tfds.load(
    'imagenette',
    shuffle_files=False,
    with_info=True,
    data_dir='../datasets'
)
# dataset, info = tfds.load(
#     'imagenet2012',
#     shuffle_files=False,
#     with_info=True,
#     data_dir='../datasets'
# )

print(f'Train image count: {info.splits["train"].num_examples}')
print(f'Test image count: {info.splits["validation"].num_examples}')

train_dataset = dataset['train'].map(prepare_input_data, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE).prefetch(AUTOTUNE)
test_dataset = dataset['validation'].map(prepare_input_data, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Load ResNet50 model
base_model = ResNet50(
    include_top=True,
    weights=None,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    pooling=None,
    classes=10,
    classifier_activation='softmax'
)

# base_model = ResNet50(
#     include_top=True,
#     weights='imagenet',
#     input_shape=(IMG_SIZE, IMG_SIZE, 3),
#     pooling=None,
#     classes=1000,
#     classifier_activation='softmax'
# )

# Functions for adversarial data generation and loading
def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _create_adversary_with_pgd(model, images, labels, eps, eps_iter, nb_iter):
    x_adv = tf.identity(images)
    # Use from_logits=False because classifier_activation='softmax' means model outputs probabilities
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    for _ in range(nb_iter):
        with tf.GradientTape() as tape:
            tape.watch(x_adv)
            prediction = model(x_adv, training=False)
            loss = loss_object(labels, prediction)

        gradients = tape.gradient(loss, x_adv)
        signed_grad = tf.sign(gradients)
        x_adv = x_adv + eps_iter * signed_grad
        perturbation = tf.clip_by_value(x_adv - images, -eps, eps)
        x_adv = images + perturbation

    return x_adv

def generate_adversarial_dataset(filename, dataset, model, eps, pgd_steps, pgd_step_size):
    options = tf.io.TFRecordOptions(compression_type="GZIP")
    num = 0
    with tf.io.TFRecordWriter(filename, options=options) as writer:
        for i, (images, labels) in enumerate(dataset):
            print(f"Batch {i+1}")
            adv_images = _create_adversary_with_pgd(
                model=model,
                images=images,
                labels=labels,
                eps=eps,
                eps_iter=pgd_step_size,
                nb_iter=pgd_steps
            )

            for i in range(len(adv_images)):
                image_tensor = adv_images[i]
                label = labels[i]
                image_tensor_f16 = tf.cast(image_tensor, tf.float16)
                image_bytes = tf.io.serialize_tensor(image_tensor_f16)
                feature = {
                    'image': _bytes_feature(image_bytes),
                    'label': _int64_feature(label.numpy())
                }
                num += 1
                serialized_example = tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
                writer.write(serialized_example)
    print(f"Processed and saved: {num} images")

def _parse_function(proto):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.io.parse_single_example(proto, feature_description)
    image_f16 = tf.io.parse_tensor(parsed_features['image'], out_type=tf.float16)
    label = parsed_features['label']
    image_f32 = tf.cast(image_f16, tf.float32)
    image_f32.set_shape([IMG_SIZE, IMG_SIZE, 3])
    return image_f32, label

# Create adversarial dataset (uncomment to run generation)
EPSILON = 0.03
PGD_STEPS = 2
PGD_STEP_SIZE = 0.007
adversarial_test_file = '../datasets/adversaries/small_test_dataset.tfrec'
adversarial_train_file = '../datasets/adversaries/small_train_dataset.tfrec'

loaded_test_dataset = tf.data.TFRecordDataset(adversarial_test_file, compression_type='GZIP')
parsed_test_dataset = loaded_test_dataset.map(_parse_function).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top_5_accuracy'),
    ]
)

# print("Training baseline model...\n")
# base_model.fit(train_dataset, verbose=1)

print("Computing baseline metrics...\n")

train_metrics = base_model.evaluate(train_dataset, verbose=1)
test_metrics = base_model.evaluate(test_dataset, verbose=1)
noisy_metrics = base_model.evaluate(parsed_test_dataset, verbose=1)

# Extract metrics
try:
    metric_names = ['loss', 'accuracy', 'top_5_accuracy']
    metrics_dict = {
        'train': dict(zip(metric_names, train_metrics)),
        'test': dict(zip(metric_names, test_metrics)),
        'noisy': dict(zip(metric_names, noisy_metrics))
    }

    train_loss, train_acc, train_top5 = metrics_dict['train'].values()
    test_loss, test_acc, test_top5 = metrics_dict['test'].values()
    noisy_loss, noisy_acc, noisy_top5 = metrics_dict['noisy'].values()

except Exception as e:
    print(f"Error extracting metrics: {e}")

print("## Base Performance Metrics 📊")
print("**Evaluated on clean test dataset**\n")
print(f"* **Top-1 Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Top-5 Accuracy**: `{test_top5*100:.2f}%`")
print(f"* **Loss**: `{test_loss:.3f}`")
print("\n" + "---")

print("\n## Generalization Analysis 🧠")
print("> *A smaller gap indicates better generalization*")
print("\n" + "---")

print("\n## Adversarial Robustness 🛡️")
print("**Evaluating model stability against adversarial inputs**\n")
print(f"* **Clean Data Accuracy**: `{test_acc*100:.2f}%`")
print(f"* **Adversarial Data Accuracy**: `{noisy_acc*100:.2f}%`")
print(f"* **Robustness Gap**: `{(test_acc-noisy_acc)*100:.2f}%`")
print("> *Smaller gaps between clean and adversarial metrics indicate better robustness*")

# Adversarial training
loaded_train_dataset = tf.data.TFRecordDataset(adversarial_train_file, compression_type='GZIP')
parsed_train_dataset = loaded_train_dataset.map(_parse_function).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print("Training robust model...\n")
base_model.fit(parsed_train_dataset, verbose=1)

train_metrics_adv = base_model.evaluate(parsed_train_dataset, verbose=1)
test_metrics_adv = base_model.evaluate(test_dataset, verbose=1)
noisy_metrics_adv = base_model.evaluate(parsed_test_dataset, verbose=1)

try:
    metrics_dict_adv = {
        'train': dict(zip(metric_names, train_metrics_adv)),
        'test': dict(zip(metric_names, test_metrics_adv)),
        'noisy': dict(zip(metric_names, noisy_metrics_adv))
    }

    train_loss_adv, train_acc_adv, train_top5_adv = metrics_dict_adv['train'].values()
    test_loss_adv, test_acc_adv, test_top5_adv = metrics_dict_adv['test'].values()
    noisy_loss_adv, noisy_acc_adv, noisy_top5_adv = metrics_dict_adv['noisy'].values()

except Exception as e:
    print(f"Error extracting metrics after adversarial training: {e}")
    print("Available metrics:", base_model.metrics_names)
    raise

# Create figure for multiple plots - After adversarial training
plt.style.use('seaborn-v0_8-darkgrid') # Updated style for better visuals
fig = plt.figure(figsize=(18, 6)) # Larger figure size

# 1. Accuracy Comparison (Post-Adversarial Training)
plt.subplot(1, 3, 1)
metrics = ['Top-1 Acc', 'Top-5 Acc']
clean_scores_post_adv = [test_acc_adv*100, test_top5_adv*100]
noisy_scores_post_adv = [noisy_acc_adv*100, noisy_top5_adv*100]

x = range(len(metrics))
width = 0.35

plt.bar([i - width/2 for i in x], clean_scores_post_adv, width, label='Clean Data', color='mediumseagreen')
plt.bar([i + width/2 for i in x], noisy_scores_post_adv, width, label='Adversarial Data', color='salmon')
plt.ylabel('Percentage (%)')
plt.title('Accuracy Comparison (Post-Adversarial Training)')
plt.xticks(x, metrics)
plt.legend()
plt.ylim(0, 100) # Set y-limit for better comparison

# 2. Loss Comparison (Post-Adversarial Training)
plt.subplot(1, 3, 2)
plt.bar(['Training', 'Testing', 'Adversarial'],
        [train_loss_adv, test_loss_adv, noisy_loss_adv],
        color=['steelblue', 'mediumseagreen', 'salmon'])
plt.ylabel('Loss')
plt.title('Loss Comparison Across Datasets (Post-Adversarial Training)')

# 3. Robustness Gap (Post-Adversarial Training)
plt.subplot(1, 3, 3)
plt.bar(['Generalization Gap', 'Robustness Gap'],
        [(train_acc_adv - test_acc_adv)*100, (test_acc_adv - noisy_acc_adv)*100],
        color=['steelblue', 'salmon'])
plt.ylabel('Gap Percentage (%)')
plt.title('Model Gaps Analysis (Post-Adversarial Training)')

plt.tight_layout()
plt.show()

# Print detailed metrics report (Post-Adversarial Training)
print("\n" + "="*50)
print("       Comprehensive Model Evaluation Results (Post-Adversarial Training)")
print("="*50 + "\n")

print("## Base Performance Metrics 📊")
print("**Evaluated on clean test dataset**\n")
print(f"* **Top-1 Accuracy**: `{test_acc_adv*100:.2f}%`")
print(f"* **Top-5 Accuracy**: `{test_top5_adv*100:.2f}%`")
print(f"* **Loss**: `{test_loss_adv:.3f}`")
print("\n" + "---")

print("\n## Generalization Analysis 🧠")
print("**Comparing training vs test performance**\n")
print(f"* **Training Accuracy**: `{train_acc_adv*100:.2f}%`")
print(f"* **Test Accuracy**: `{test_acc_adv*100:.2f}%`")
print(f"* **Generalization Gap**: `{(train_acc_adv-test_acc_adv)*100:.2f}%`")
print(f"* **Training Loss**: `{train_loss_adv:.3f}`")
print(f"* **Test Loss**: `{test_loss_adv:.3f}`")
print("> *A smaller gap indicates better generalization*")
print("\n" + "---")

print("\n## Adversarial Robustness 🛡️")
print("**Evaluating model stability against adversarial inputs**\n")
print(f"* **Clean Data Accuracy**: `{test_acc_adv*100:.2f}%`")
print(f"* **Adversarial Data Accuracy**: `{noisy_acc_adv*100:.2f}%`")
print(f"* **Robustness Gap**: `{(test_acc_adv-noisy_acc_adv)*100:.2f}%`")
print("> *Smaller gaps between clean and adversarial metrics")