# ResNet-50: Adversarial Training

In [1]:
import os
import sys
import glob
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.applications import ResNet50
import matplotlib.pyplot as plt
import seaborn as sns

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
tf.get_logger().setLevel("ERROR")

In [2]:
# Constants
IMG_SIZE = 224
BATCH_SIZE = 300
AUTOTUNE = tf.data.AUTOTUNE
EPOCHS = 5
INPUT_SHAPE=(224, 224, 3)

tf.random.set_seed(5)
dataset_dir = "../datasets"

# Change dataset_dir when run in google colab 
if 'google.colab' in sys.modules:
    from google.colab import drive

    drive.mount('/content/drive')
    dataset_dir = "/content/drive/Othercomputers/Big Mac/datasets"
    BATCH_SIZE = 430

physical_gpus = tf.config.list_physical_devices('GPU')
print("Using available GPUs: ", physical_gpus)

tf.keras.mixed_precision.set_global_policy('float32')

Using available GPUs:  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
# Load ImageNet2012 dataset
def prepare_input_data(input):
    image = tf.cast(input['image'], tf.float32)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = preprocess_input(image)
    label = input['label']
    return image, label

def make_dataset(ds):
    return (
        ds.map(prepare_input_data, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )


(train, validation, test), info = tfds.load(
    'imagenet2012_subset/10pct',
    split=['train', 'validation[:50%]', 'validation[50%:]'],
    shuffle_files=False,
    with_info=True,
    data_dir=dataset_dir
)

num_classes = info.features['label'].num_classes
class_names = info.features['label'].names

print(f"Train count: {info.splits['train'].num_examples}")
print(f"Validation count: {info.splits['validation[:50%]'].num_examples}")
print(f"Test count: {info.splits['validation[50%:]'].num_examples}")

train_dataset = make_dataset(train)
validation_dataset = make_dataset(validation)
test_dataset = make_dataset(test)

Train count: 128116
Validation count: 25000
Test count: 25000


In [4]:
# Load adversarial datasets

def _parse_image(input):
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.io.parse_single_example(input, feature_description)
    image_f16 = tf.io.parse_tensor(parsed_features['image'], out_type=tf.float16)
    label = parsed_features['label']
    image_f32 = tf.cast(image_f16, tf.float32)
    image_f32.set_shape([IMG_SIZE, IMG_SIZE, 3])
    return image_f32, label

def create_tf_dataset(file_paths):
    raw_dataset = tf.data.TFRecordDataset(file_paths, compression_type='GZIP')
    tf_dataset = raw_dataset.map(_parse_image).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    return tf_dataset

# Get all adversarial datasets for train, validation, and testing
train_file_paths = glob.glob(f'{dataset_dir}/adversaries/imagenet2012_subset/train-*.tfrec')
validation_file_paths = glob.glob(f'{dataset_dir}/adversaries/imagenet2012_subset/validation-*.tfrec')
test_file_paths = glob.glob(f'{dataset_dir}/adversaries/imagenet2012_subset/test-*.tfrec')

print(f"Loaded {len(train_file_paths)} TFrecord train files")
print(f"Loaded {len(validation_file_paths)} TFrecord validation files")
print(f"Loaded {len(test_file_paths)} TFrecord test files")

# Create a TFRecordDataset
adv_train_dataset = create_tf_dataset(train_file_paths)
adv_validation_dataset = create_tf_dataset(validation_file_paths)
adv_test_dataset = create_tf_dataset(test_file_paths)

Loaded 298 TFrecord train files
Loaded 59 TFrecord validation files
Loaded 59 TFrecord test files


In [5]:
# Merge Adversarial and Clean datasets for retraining
buffer_size = 8

merged_train_dataset = adv_train_dataset.concatenate(train_dataset)
shuffled_train_dataset = merged_train_dataset.shuffle(buffer_size=buffer_size)

merged_validation_dataset = adv_validation_dataset.concatenate(validation_dataset)
shuffled_validation_dataset = merged_validation_dataset.shuffle(buffer_size=buffer_size)

In [None]:
print("Training robust ResNet-50 model...\n")

base_model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=INPUT_SHAPE,
    classes=1000
)
base_model.trainable = False 

# Build a small classification head on top of the base model
inputs = tf.keras.Input(shape=INPUT_SHAPE)
x = base_model(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

robust_model = tf.keras.Model(inputs, outputs)

robust_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[
        'accuracy'
    ]
)

robust_model.fit(
    shuffled_train_dataset,
    verbose=1,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=shuffled_validation_dataset
)

print("Fine tuning with lower learning rate")
base_model.trainable = True

# Unfreeze last 10 layers of base model
for layer in base_model.layers[:-10]:
    layer.trainable = False

robust_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[
        'accuracy',
    ]
)

robust_model.fit(
    shuffled_train_dataset,
    verbose=1,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=shuffled_validation_dataset
)

robust_model.save(f"{dataset_dir}/models/robust_resnet50v2.keras")


Training robust ResNet-50 model...

Epoch 1/5
    856/Unknown [1m540s[0m 624ms/step - accuracy: 0.4505 - loss: 2.8553



[1m856/856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m648s[0m 750ms/step - accuracy: 0.5885 - loss: 1.8751 - val_accuracy: 0.5601 - val_loss: 2.3930
Epoch 2/5
[1m856/856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m682s[0m 793ms/step - accuracy: 0.7107 - loss: 1.2710 - val_accuracy: 0.5727 - val_loss: 2.8575
Epoch 3/5
[1m856/856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m658s[0m 765ms/step - accuracy: 0.7562 - loss: 1.1121 - val_accuracy: 0.5763 - val_loss: 3.4391
Epoch 4/5
[1m856/856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m639s[0m 742ms/step - accuracy: 0.7869 - loss: 1.0201 - val_accuracy: 0.5782 - val_loss: 4.0248
Epoch 5/5
[1m856/856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m639s[0m 743ms/step - accuracy: 0.8159 - loss: 0.9000 - val_accuracy: 0.5781 - val_loss: 4.7225
Fine tuning with lower learning rate
Epoch 1/5
[1m856/856[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m696s[0m 804ms/step - accuracy: 0.9169 - loss: 0.3135 - val_accuracy: 0.6185 - 