In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50, EfficientNetB0, resnet50, efficientnet
from tensorflow.keras.optimizers import Adam

# ---------------------------------------------------
# 1. Load the EuroSAT Dataset
# ---------------------------------------------------
# Percentage of the TFDS 'train' split used for training; the remainder is
# held out. NOTE: only a train/test split is created here -- there is no
# separate validation split, so the held-out slice is the only data the model
# never trains on.
TRAIN_PERCENT = 90
(dataset_train, dataset_test), info = tfds.load(
    'eurosat/rgb',
    split=[f'train[:{TRAIN_PERCENT}%]', f'train[{TRAIN_PERCENT}%:]'],
    as_supervised=True,  # yield (image, label) tuples instead of feature dicts
    with_info=True       # also return DatasetInfo, used for the class count
)

# Number of classes in EuroSAT (read from the dataset metadata)
num_classes = info.features['label'].num_classes

# ---------------------------------------------------
# 2. Preprocess Data for ResNet and EfficientNet
# ---------------------------------------------------
IMG_SIZE = 224                # every image is resized to IMG_SIZE x IMG_SIZE
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE   # let tf.data choose parallelism / prefetch depth

def preprocess_resnet(image, label):
    """Turn one (image, label) example into ResNet50-ready model inputs.

    The image is resized to IMG_SIZE x IMG_SIZE and normalised with
    ``resnet50.preprocess_input``. The class-index label is expanded into a
    one-hot vector of length ``num_classes``, as required by the
    categorical cross-entropy loss.
    """
    target_hw = (IMG_SIZE, IMG_SIZE)
    encoded_label = tf.one_hot(label, num_classes)
    model_ready = resnet50.preprocess_input(tf.image.resize(image, target_hw))
    return model_ready, encoded_label

def preprocess_efficientnet(image, label):
    """Turn one (image, label) example into EfficientNetB0-ready model inputs.

    The image is resized to IMG_SIZE x IMG_SIZE and passed through
    ``efficientnet.preprocess_input``. The class-index label is expanded
    into a one-hot vector of length ``num_classes``.
    """
    target_hw = (IMG_SIZE, IMG_SIZE)
    encoded_label = tf.one_hot(label, num_classes)
    model_ready = efficientnet.preprocess_input(tf.image.resize(image, target_hw))
    return model_ready, encoded_label

# Create tf.data pipelines.
# Training examples are shuffled so that successive epochs do not see the
# batches in one fixed order (Dataset.shuffle reshuffles on every iteration by
# default). Evaluation pipelines keep their original order.
SHUFFLE_BUFFER = 1000


def _build_pipeline(dataset, preprocess_fn, shuffle=False):
    """Return a batched, prefetched pipeline that maps *preprocess_fn* over *dataset*.

    When *shuffle* is true, the raw examples are shuffled BEFORE the map. The
    shuffle buffer therefore holds the original, not-yet-resized images
    rather than the resized IMG_SIZE x IMG_SIZE tensors.
    """
    if shuffle:
        dataset = dataset.shuffle(SHUFFLE_BUFFER)
    return (dataset
            .map(preprocess_fn, num_parallel_calls=AUTOTUNE)
            .batch(batch_size)
            .prefetch(AUTOTUNE))


train_ds_resnet = _build_pipeline(dataset_train, preprocess_resnet, shuffle=True)
test_ds_resnet = _build_pipeline(dataset_test, preprocess_resnet)

train_ds_eff = _build_pipeline(dataset_train, preprocess_efficientnet, shuffle=True)
test_ds_eff = _build_pipeline(dataset_test, preprocess_efficientnet)

# ---------------------------------------------------
# 3. Define Dynamic Channel Gate Layer
# ---------------------------------------------------
class DynamicChannelGate(tf.keras.layers.Layer):
    """Learnable per-channel gate on the last axis of the input.

    Each channel ``c`` is multiplied by ``sigmoid(gate_params[c])``. The
    parameters start at 1.0, so every channel is initially scaled by
    sigmoid(1) ~= 0.731, and training can learn to suppress or emphasise
    individual channels.

    Args:
        num_channels: Size of the input's last axis (number of gated channels).
        name: Optional layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``dtype``,
            ``trainable``), so the layer can be rebuilt from ``get_config``.
    """

    def __init__(self, num_channels, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_channels = int(num_channels)
        # add_weight registers the variable with Keras. It is then reported in
        # trainable_weights, receives gradients in fit() and is included in
        # saved weights. A bare tf.Variable attribute is not guaranteed to be
        # tracked this way (Keras 3 only tracks its own Variable type).
        self.gate_params = self.add_weight(
            name="gate_params",
            shape=(self.num_channels,),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs, training=None):
        # The gate has shape (num_channels,), so it broadcasts against the
        # last axis of inputs of any rank (e.g. (batch, H, W, C) or
        # (batch, C)). The cast keeps the multiply valid under mixed precision.
        gate = tf.cast(tf.sigmoid(self.gate_params), inputs.dtype)
        return inputs * gate

    def get_config(self):
        """Return the layer config, including ``num_channels``, so the layer can be deserialised."""
        config = super().get_config()
        config.update({"num_channels": self.num_channels})
        return config

# ---------------------------------------------------
# 4. Define ResNet50 Model with Dynamic Gating
# ---------------------------------------------------
def create_resnet_dynamic_gating(input_shape, num_classes, dropout_rate=0.3,
                                 weights="imagenet"):
    """Build a ResNet50 classifier with a learnable channel gate before the head.

    The ResNet50 trunk (no top) is left fully trainable, so ``fit`` fine-tunes
    every backbone layer. Its final feature map passes through a
    ``DynamicChannelGate``, global average pooling, dropout, and a softmax
    ``Dense`` classifier.

    Args:
        input_shape: Image shape ``(H, W, C)`` fed to the backbone.
        num_classes: Number of output classes.
        dropout_rate: Dropout applied to the pooled features (default 0.3).
        weights: Backbone initialisation forwarded to ``ResNet50``:
            ``"imagenet"`` (default), ``None`` for random init, or a path to
            a weights file.

    Returns:
        An uncompiled Keras ``Model`` named "ResNet_DynamicGating".
    """
    base_model = ResNet50(weights=weights, include_top=False, input_shape=input_shape)
    # Full fine-tuning: no backbone layers are frozen.
    base_model.trainable = True

    x = base_model.output
    # Gate the channels of the backbone's last feature map.
    gating_layer = DynamicChannelGate(num_channels=x.shape[-1], name="dynamic_gating")
    x = gating_layer(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return models.Model(inputs=base_model.input, outputs=outputs,
                        name="ResNet_DynamicGating")

# ---------------------------------------------------
# 5. Define EfficientNetB0 Model with Dynamic Gating
# ---------------------------------------------------
def create_efficientnet_dynamic_gating(input_shape, num_classes, dropout_rate=0.3,
                                       weights="imagenet"):
    """Build an EfficientNetB0 classifier with a learnable channel gate before the head.

    The EfficientNetB0 trunk (no top) is left fully trainable, so ``fit``
    fine-tunes every backbone layer. Its final feature map passes through a
    ``DynamicChannelGate``, global average pooling, dropout, and a softmax
    ``Dense`` classifier.

    Args:
        input_shape: Image shape ``(H, W, C)`` fed to the backbone.
        num_classes: Number of output classes.
        dropout_rate: Dropout applied to the pooled features (default 0.3).
        weights: Backbone initialisation forwarded to ``EfficientNetB0``:
            ``"imagenet"`` (default), ``None`` for random init, or a path to
            a weights file.

    Returns:
        An uncompiled Keras ``Model`` named "EfficientNet_DynamicGating".
    """
    base_model = EfficientNetB0(weights=weights, include_top=False, input_shape=input_shape)
    # Full fine-tuning: no backbone layers are frozen.
    base_model.trainable = True

    x = base_model.output
    # Gate the channels of the backbone's last feature map.
    gating_layer = DynamicChannelGate(num_channels=x.shape[-1], name="dynamic_gating")
    x = gating_layer(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return models.Model(inputs=base_model.input, outputs=outputs,
                        name="EfficientNet_DynamicGating")

# ---------------------------------------------------
# 6. Compile and Train the Models
# ---------------------------------------------------
# Fine-tuning hyper-parameters shared by both backbones.
EPOCHS = 3
LEARNING_RATE = 1e-4


def _compile_and_fit(model, train_ds, eval_ds, title):
    """Compile *model* (Adam + categorical cross-entropy) and fine-tune it.

    Prints *title* once compilation is done, trains for EPOCHS epochs with
    *eval_ds* as ``validation_data``, and returns the Keras ``History``.
    """
    model.compile(optimizer=Adam(learning_rate=LEARNING_RATE),
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    print(title)
    return model.fit(train_ds, epochs=EPOCHS, validation_data=eval_ds, verbose=1)


# ResNet50 Model
resnet_model = create_resnet_dynamic_gating((IMG_SIZE, IMG_SIZE, 3), num_classes)
history_resnet = _compile_and_fit(
    resnet_model, train_ds_resnet, test_ds_resnet,
    "\n--- Fine-Tuning Dynamic-Gating ResNet50 on EuroSAT ---")

# EfficientNetB0 Model
eff_model = create_efficientnet_dynamic_gating((IMG_SIZE, IMG_SIZE, 3), num_classes)
history_eff = _compile_and_fit(
    eff_model, train_ds_eff, test_ds_eff,
    "\n--- Fine-Tuning Dynamic-Gating EfficientNetB0 on EuroSAT ---")

# NOTE: the test pipelines were also passed as validation_data above.
# These "final" accuracies should therefore essentially repeat the last
# epoch's val_accuracy. They are not an independent estimate if that metric
# was used to choose hyper-parameters or epochs.

# Evaluate ResNet50
loss_resnet, acc_resnet = resnet_model.evaluate(test_ds_resnet, verbose=0)
print(f"ResNet50 Final Accuracy: {acc_resnet:.4f}")

# Evaluate EfficientNetB0
loss_eff, acc_eff = eff_model.evaluate(test_ds_eff, verbose=0)
print(f"EfficientNetB0 Final Accuracy: {acc_eff:.4f}")


--- Fine-Tuning Dynamic-Gating ResNet50 on EuroSAT ---
Epoch 1/3
[1m760/760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m322s[0m 348ms/step - accuracy: 0.8802 - loss: 0.3599 - val_accuracy: 0.9385 - val_loss: 0.1944
Epoch 2/3
[1m760/760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m226s[0m 298ms/step - accuracy: 0.9751 - loss: 0.0795 - val_accuracy: 0.9189 - val_loss: 0.2858
Epoch 3/3
[1m760/760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m229s[0m 301ms/step - accuracy: 0.9854 - loss: 0.0480 - val_accuracy: 0.9630 - val_loss: 0.1111

--- Fine-Tuning Dynamic-Gating EfficientNetB0 on EuroSAT ---
Epoch 1/3
[1m760/760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m208s[0m 190ms/step - accuracy: 0.7946 - loss: 0.7250 - val_accuracy: 0.9726 - val_loss: 0.0764
Epoch 2/3
[1m760/760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 140ms/step - accuracy: 0.9702 - loss: 0.0976 - val_accuracy: 0.9793 - val_loss: 0.0552
Epoch 3/3
[1m760/760[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[