In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50, EfficientNetB0, resnet50, efficientnet
from tensorflow.keras.optimizers import Adam

# ---------------------------------------------------
# 1. Load the TF-Flowers Dataset
# ---------------------------------------------------
# TF-Flowers ships a single 'train' split, so slice it: the first 80% is used
# for fitting and the remaining 20% is held out (exposed as `dataset_test`).
(dataset_train, dataset_test), info = tfds.load(
    name='tf_flowers',
    split=['train[:80%]', 'train[80%:]'],
    with_info=True,
    as_supervised=True,
)

# Label cardinality as reported by the dataset metadata.
num_classes = info.features['label'].num_classes

# ---------------------------------------------------
# 2. Preprocess Data for ResNet and EfficientNet
# ---------------------------------------------------
AUTOTUNE = tf.data.AUTOTUNE  # passed to map()/prefetch() in the pipelines below
batch_size = 32
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE before being fed to the models

def preprocess_resnet(image, label):
    """Prepare one (image, label) example for the ResNet50 model.

    The image is resized to IMG_SIZE x IMG_SIZE and passed through
    `resnet50.preprocess_input`; the label is one-hot encoded over
    `num_classes`.

    Returns:
        A `(preprocessed_image, one_hot_label)` tuple.
    """
    target_size = (IMG_SIZE, IMG_SIZE)
    resized = tf.image.resize(image, target_size)
    one_hot_label = tf.one_hot(label, num_classes)
    return resnet50.preprocess_input(resized), one_hot_label

def preprocess_efficientnet(image, label):
    """Prepare one (image, label) example for the EfficientNetB0 model.

    Resizes the image to IMG_SIZE x IMG_SIZE, applies
    `efficientnet.preprocess_input`, and one-hot encodes the label over
    `num_classes`.

    Returns:
        A `(preprocessed_image, one_hot_label)` tuple.
    """
    model_input = efficientnet.preprocess_input(
        tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    )
    target = tf.one_hot(label, num_classes)
    return model_input, target

# Create tf.data pipelines
SHUFFLE_BUFFER = 1000  # number of raw examples buffered when shuffling training data


def _make_pipeline(dataset, preprocess_fn, shuffle=False):
    """Build a preprocessed, batched and prefetched tf.data pipeline.

    Args:
        dataset: Source dataset yielding `(image, label)` pairs.
        preprocess_fn: Function mapped over every `(image, label)` pair.
        shuffle: When True, shuffle the examples before preprocessing. With
            tf.data's default settings the order is re-drawn on every pass,
            so each epoch sees differently composed batches. Leave False for
            evaluation data so its order stays fixed.

    Returns:
        The dataset mapped with `preprocess_fn`, batched by `batch_size`
        and prefetched.
    """
    if shuffle:
        # Shuffle before map/batch so individual examples (not whole batches)
        # are permuted.
        dataset = dataset.shuffle(SHUFFLE_BUFFER)
    return (dataset
            .map(preprocess_fn, num_parallel_calls=AUTOTUNE)
            .batch(batch_size)
            .prefetch(AUTOTUNE))


# Only the training streams are shuffled; without this every epoch would
# present the same batches in the same order.
train_ds_resnet = _make_pipeline(dataset_train, preprocess_resnet, shuffle=True)
test_ds_resnet = _make_pipeline(dataset_test, preprocess_resnet)

train_ds_eff = _make_pipeline(dataset_train, preprocess_efficientnet, shuffle=True)
test_ds_eff = _make_pipeline(dataset_test, preprocess_efficientnet)

# ---------------------------------------------------
# 3. Define Dynamic Channel Gate Layer
# ---------------------------------------------------
class DynamicChannelGate(tf.keras.layers.Layer):
    """Learnable per-channel gate applied to a 4-D feature map.

    Holds one scalar parameter per channel. In `call`, the parameters are
    squashed with a sigmoid and multiplied onto the last axis of the input,
    so each channel is scaled by a value in (0, 1).

    The parameters are created with `add_weight` so that Keras registers them
    in the layer's `trainable_weights`; a bare `tf.Variable` attribute is not
    reliably tracked across Keras versions, which would leave the gate frozen
    at its initial value.

    Args:
        num_channels: Size of the channel (last) axis of the inputs.
        name: Optional layer name.
        **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `trainable`,
            `dtype`); needed so the layer can be rebuilt from its config.
    """

    def __init__(self, num_channels, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_channels = int(num_channels)
        # Initialised to ones, i.e. every gate starts at sigmoid(1).
        self.gate_params = self.add_weight(
            name="gate_params",
            shape=(self.num_channels,),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs, training=None):
        """Scale `inputs` channel-wise by `sigmoid(gate_params)`.

        `training` is accepted for API compatibility and is not used.
        """
        gate = tf.sigmoid(self.gate_params)
        # Reshape to (1, 1, 1, C) so the gate broadcasts over the leading
        # three axes of the input.
        gate = tf.reshape(gate, (1, 1, 1, -1))
        return inputs * gate

    def get_config(self):
        """Return the layer config, including `num_channels`, for serialization."""
        config = super().get_config()
        config.update({"num_channels": self.num_channels})
        return config

# ---------------------------------------------------
# 4. Define ResNet50 Model with Dynamic Gating
# ---------------------------------------------------
def create_resnet_dynamic_gating(input_shape, num_classes):
    """Build a ResNet50 classifier with a channel gate on its final feature map.

    Args:
        input_shape: Shape passed to `ResNet50` as `input_shape`.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled Keras model named "ResNet_DynamicGating" whose head is
        DynamicChannelGate -> GlobalAveragePooling2D -> Dropout(0.3) ->
        Dense(softmax).
    """
    backbone = ResNet50(include_top=False, weights="imagenet", input_shape=input_shape)
    backbone.trainable = True  # the whole backbone is fine-tuned, not frozen

    feature_map = backbone.output
    gated = DynamicChannelGate(
        num_channels=feature_map.shape[-1],
        name="dynamic_gating",
    )(feature_map)
    pooled = layers.GlobalAveragePooling2D()(gated)
    dropped = layers.Dropout(0.3)(pooled)
    probabilities = layers.Dense(num_classes, activation="softmax")(dropped)

    return models.Model(
        inputs=backbone.input,
        outputs=probabilities,
        name="ResNet_DynamicGating",
    )

# ---------------------------------------------------
# 5. Define EfficientNetB0 Model with Dynamic Gating
# ---------------------------------------------------
def create_efficientnet_dynamic_gating(input_shape, num_classes):
    """Build an EfficientNetB0 classifier with a channel gate on its final feature map.

    Args:
        input_shape: Shape passed to `EfficientNetB0` as `input_shape`.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled Keras model named "EfficientNet_DynamicGating".
    """
    backbone = EfficientNetB0(include_top=False, weights="imagenet", input_shape=input_shape)
    backbone.trainable = True  # all backbone weights are updated during fine-tuning

    features = backbone.output
    # Classification head, applied in order to the backbone's output.
    head = [
        DynamicChannelGate(num_channels=features.shape[-1], name="dynamic_gating"),
        layers.GlobalAveragePooling2D(),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation="softmax"),
    ]
    tensor = features
    for head_layer in head:
        tensor = head_layer(tensor)

    return models.Model(
        inputs=backbone.input,
        outputs=tensor,
        name="EfficientNet_DynamicGating",
    )

# ---------------------------------------------------
# 6. Compile and Train the Models
# ---------------------------------------------------
# Shared training hyper-parameters.
LEARNING_RATE = 1e-4
MAX_EPOCHS = 50               # upper bound; early stopping usually ends sooner
EARLY_STOPPING_PATIENCE = 5   # epochs without val_loss improvement before stopping


def _compile_and_fit(model, train_ds, val_ds):
    """Compile `model` and fine-tune it with early stopping.

    The model is compiled with Adam (LEARNING_RATE), categorical
    cross-entropy and an accuracy metric, then trained for at most
    MAX_EPOCHS. Training stops once `val_loss` has not improved for
    EARLY_STOPPING_PATIENCE epochs, and the weights from the best epoch are
    restored so later evaluation does not use an over-fitted final epoch.

    Args:
        model: Uncompiled Keras model.
        train_ds: Batched training dataset.
        val_ds: Batched dataset used for validation during training.

    Returns:
        The `History` object returned by `model.fit`.
    """
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
    )
    return model.fit(
        train_ds,
        epochs=MAX_EPOCHS,
        validation_data=val_ds,
        callbacks=[early_stopping],
        verbose=1,
    )


# ResNet50 Model
resnet_model = create_resnet_dynamic_gating((IMG_SIZE, IMG_SIZE, 3), num_classes)

print("\n--- Fine-Tuning Dynamic-Gating ResNet50 on TF-Flowers ---")
history_resnet = _compile_and_fit(resnet_model, train_ds_resnet, test_ds_resnet)

# EfficientNetB0 Model
eff_model = create_efficientnet_dynamic_gating((IMG_SIZE, IMG_SIZE, 3), num_classes)

print("\n--- Fine-Tuning Dynamic-Gating EfficientNetB0 on TF-Flowers ---")
history_eff = _compile_and_fit(eff_model, train_ds_eff, test_ds_eff)

# NOTE: the accuracies below are measured on the same held-out 20% split that
# was used as validation data (and for early stopping) during training, so
# they are not an independent test estimate.

# Evaluate ResNet50
loss_resnet, acc_resnet = resnet_model.evaluate(test_ds_resnet, verbose=0)
print(f"ResNet50 Final Accuracy: {acc_resnet:.4f}")

# Evaluate EfficientNetB0
loss_eff, acc_eff = eff_model.evaluate(test_ds_eff, verbose=0)
print(f"EfficientNetB0 Final Accuracy: {acc_eff:.4f}")

Downloading and preparing dataset 218.21 MiB (download: 218.21 MiB, generated: 221.83 MiB, total: 440.05 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.1. Subsequent calls will reuse this data.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step

--- Fine-Tuning Dynamic-Gating ResNet50 on TF-Flowers ---
Epoch 1/50
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 738ms/step - accuracy: 0.7159 - loss: 0.7164 - val_accuracy: 0.8910 - val_loss: 0.4920
Epoch 2/50
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 308ms/step - accuracy: 0.9791 - loss: 0.0809 - val_accuracy: 0.9455 - val_loss: 0.1574
Epoch 3/50
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 314ms/step - accuracy: 0.9980 - loss: 0.0170 - val_accuracy: 0.9550 - val_loss: 0.1556
Epoch 4/50
[1m92/92[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 319ms/step - accur