In [6]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import (
    ResNet50,
    EfficientNetB0,
    resnet50,
    efficientnet
)
from tensorflow.keras.optimizers import Adam


# ---------------------------------------------------
# 1. Load the CIFAR-10 Dataset
# ---------------------------------------------------
# Number of target classes; `nums` carries the same value and is read by a later cell.
num_classes = 10
nums = 10

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Cast the images to float32 before they enter the tf.data pipelines.
x_train, x_test = x_train.astype("float32"), x_test.astype("float32")

# One-hot encode the labels to match the categorical cross-entropy loss used below.
y_train, y_test = (
    tf.keras.utils.to_categorical(y_train, num_classes),
    tf.keras.utils.to_categorical(y_test, num_classes),
)


In [11]:

# We'll define two separate pipelines:
#   - One for ResNet50
#   - One for EfficientNetB0

# ------------------------------
# 2. tf.data Pipeline for ResNet
# ------------------------------
# Shared by every tf.data pipeline in this notebook.
AUTOTUNE = tf.data.AUTOTUNE  # passed to map(num_parallel_calls=...) and prefetch(...)
batch_size = 32              # samples per batch for both training and evaluation

def preprocess_resnet(image, label):
    """Resize an image to 224x224 and run ResNet50's `preprocess_input` on it.

    Args:
        image: Image tensor handed over by the tf.data pipeline.
        label: Label paired with the image; passed through untouched.

    Returns:
        A `(preprocessed_image, label)` tuple.
    """
    resized = tf.image.resize(image, (224, 224))
    return resnet50.preprocess_input(resized), label

# Training split: shuffle the raw samples, then preprocess, batch and prefetch.
train_ds_resnet = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=50000)
    .map(preprocess_resnet, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

# Validation (test) split: same preprocessing, but no shuffling.
val_ds_resnet = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess_resnet, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)


# ------------------------------------------
# 4. Define Baseline Models to Fine-Tune
# ------------------------------------------
def create_resnet50_finetune(input_shape, num_classes):
    """Build an ImageNet-initialised ResNet50 with a frozen backbone and a new head.

    Args:
        input_shape: Image shape forwarded to `ResNet50(input_shape=...)`.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled Keras `Model` from the backbone input to class probabilities.
    """
    backbone = ResNet50(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    # The backbone is frozen, so only the head defined below is trainable.
    backbone.trainable = False

    # Head: flatten -> 256-unit ReLU -> dropout -> softmax.
    features = layers.Flatten()(backbone.output)
    hidden = layers.Dense(256, activation='relu')(features)
    hidden = layers.Dropout(0.3)(hidden)
    probabilities = layers.Dense(num_classes, activation='softmax')(hidden)

    return models.Model(inputs=backbone.input, outputs=probabilities)

# ----------------------------------------------
# 5. Fine-Tune ResNet50 on CIFAR-10
# ----------------------------------------------
# Build the frozen-backbone ResNet50 classifier and compile it for one-hot labels.
resnet_model = create_resnet50_finetune((224, 224, 3), num_classes)
resnet_model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# The dataset comes from tf.keras.datasets.cifar10 (10 classes), so the log
# messages name CIFAR-10; they previously said CIFAR-100.
print("\n--- Fine-Tuning ResNet50 on CIFAR-10 ---")
history_resnet = resnet_model.fit(
    train_ds_resnet,
    epochs=3,
    validation_data=val_ds_resnet,
    verbose=1
)

# Report the final metrics on the held-out split.
resnet_loss, resnet_acc = resnet_model.evaluate(val_ds_resnet, verbose=0)
print(f"ResNet50 - CIFAR-10 Accuracy: {resnet_acc:.4f}")



--- Fine-Tuning ResNet50 on CIFAR-100 ---
Epoch 1/3
Epoch 2/3
Epoch 3/3
ResNet50 - CIFAR-100 Accuracy: 0.9067


In [12]:
# ------------------------------
# 3. tf.data Pipeline for EfficientNet
# ------------------------------
def preprocess_efficientnet(image, label):
    """Resize an image to 224x224 and run EfficientNet's `preprocess_input` on it.

    Args:
        image: Image tensor handed over by the tf.data pipeline.
        label: Label paired with the image; passed through untouched.

    Returns:
        A `(preprocessed_image, label)` tuple.
    """
    resized = tf.image.resize(image, (224, 224))
    return efficientnet.preprocess_input(resized), label

# Training split: shuffle the raw samples, then preprocess, batch and prefetch.
train_ds_eff = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=50000)
    .map(preprocess_efficientnet, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

# Validation (test) split: same preprocessing, but no shuffling.
val_ds_eff = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess_efficientnet, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)



def create_efficientnet_finetune(input_shape, num_classes):
    """Build an ImageNet-initialised EfficientNetB0 with a frozen backbone and a new head.

    Args:
        input_shape: Image shape forwarded to `EfficientNetB0(input_shape=...)`.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled Keras `Model` from the backbone input to class probabilities.
    """
    backbone = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    # The backbone is frozen, so only the head defined below is trainable.
    backbone.trainable = False

    # Head: flatten -> 256-unit ReLU -> dropout -> softmax.
    features = layers.Flatten()(backbone.output)
    hidden = layers.Dense(256, activation='relu')(features)
    hidden = layers.Dropout(0.3)(hidden)
    probabilities = layers.Dense(num_classes, activation='softmax')(hidden)

    return models.Model(inputs=backbone.input, outputs=probabilities)



# ----------------------------------------------
# 6. Fine-Tune EfficientNetB0 on CIFAR-10
# ----------------------------------------------
# Build the frozen-backbone EfficientNetB0 classifier and compile it for one-hot labels.
eff_model = create_efficientnet_finetune((224, 224, 3), num_classes)
eff_model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# The dataset comes from tf.keras.datasets.cifar10 (10 classes), so the log
# messages name CIFAR-10; they previously said CIFAR-100.
print("\n--- Fine-Tuning EfficientNetB0 on CIFAR-10 ---")
history_eff = eff_model.fit(
    train_ds_eff,
    epochs=3,
    validation_data=val_ds_eff,
    verbose=1
)

# Report the final metrics on the held-out split.
eff_loss, eff_acc = eff_model.evaluate(val_ds_eff, verbose=0)
print(f"EfficientNetB0 - CIFAR-10 Accuracy: {eff_acc:.4f}")



--- Fine-Tuning EfficientNetB0 on CIFAR-100 ---
Epoch 1/3
Epoch 2/3
Epoch 3/3
EfficientNetB0 - CIFAR-100 Accuracy: 0.8962


In [None]:
################
################
#   GATING     #
################
################

In [13]:
import tensorflow as tf

class DynamicChannelGate(tf.keras.layers.Layer):
    """
    A learnable gating mechanism to dynamically prune (or re-expand) channels.

    Each channel owns one trainable logit; the effective gate is
    sigmoid(logit) and therefore lies strictly between 0 and 1. The logits are
    initialised to 1.0, so every gate starts at sigmoid(1.0) ~= 0.73.

    Args:
        num_channels: Number of channels C on the last axis of the input.
        name: Optional layer name.
        **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `dtype`,
            `trainable`); required so the layer can be rebuilt from its config.
    """
    def __init__(self, num_channels, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_channels = int(num_channels)
        # Register the logits through add_weight instead of a bare tf.Variable:
        # Keras 3 does not auto-track tf.Variable attributes, so a raw variable
        # would be absent from trainable_weights and never be trained or saved.
        self.gate_params = self.add_weight(
            name="gate_params",
            shape=(self.num_channels,),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs, training=None):
        """
        Inputs shape: (batch, H, W, C)
        gate_params shape: (C,)
        Returns: inputs * gate, shape (batch, H, W, C)
        """
        # Sigmoid keeps each gate between 0 and 1
        gate = tf.sigmoid(self.gate_params)
        # Reshape to broadcast across (batch, H, W, C)
        gate = tf.reshape(gate, (1, 1, 1, -1))
        return inputs * gate

    def get_config(self):
        """Return the layer config, including `num_channels`, for serialization."""
        config = super().get_config()
        config.update({"num_channels": self.num_channels})
        return config

    
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0

def create_efficientnet_dynamic_gating(input_shape, num_classes):
    """
    Build an ImageNet-initialised EfficientNetB0 followed by a DynamicChannelGate
    and a pooled softmax classification head.

    Args:
        input_shape: Image shape forwarded to `EfficientNetB0(input_shape=...)`.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled Keras `Model` named "EFN_DynamicGating".
    """
    # Feature extractor without the stock classification layers.
    backbone = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    # The whole backbone stays trainable here; freeze some layers if desired.
    backbone.trainable = True

    # Final feature map; its last axis gives the number of channels to gate.
    feature_map = backbone.output  # shape: (batch, 7, 7, channels) for EFN-B0
    gate = DynamicChannelGate(num_channels=feature_map.shape[-1], name="dynamic_gating")
    gated = gate(feature_map)

    # Classification head: global average pool -> dropout -> softmax.
    pooled = layers.GlobalAveragePooling2D()(gated)
    pooled = layers.Dropout(0.3)(pooled)  # dropout rate is tunable
    probabilities = layers.Dense(num_classes, activation='softmax')(pooled)

    return models.Model(inputs=backbone.input, outputs=probabilities, name="EFN_DynamicGating")


import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# Create the dynamic-gating EfficientNet.
# `num_classes` is used instead of the duplicate `nums` constant (both are 10)
# so this cell matches the other model-building cells.
model = create_efficientnet_dynamic_gating((224, 224, 3), num_classes=num_classes)

# Compile the model (tweak hyperparameters as needed)
model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# The dataset comes from tf.keras.datasets.cifar10 (10 classes), so the log
# message names CIFAR-10; it previously said CIFAR-100.
print("\n--- Fine-Tuning Dynamic-Gating EfficientNetB0 on CIFAR-10 ---")
history = model.fit(
    train_ds_eff,             # EfficientNet training pipeline built earlier
    epochs=3,                 # increase for better results
    validation_data=val_ds_eff,
    verbose=1
)

# Evaluate
loss, acc = model.evaluate(val_ds_eff, verbose=0)
print(f"Final Accuracy with Dynamic Gating: {acc:.4f}")



--- Fine-Tuning Dynamic-Gating EfficientNetB0 on CIFAR-100 ---
Epoch 1/3
Epoch 2/3
Epoch 3/3
Final Accuracy with Dynamic Gating: 0.9648


In [14]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50, resnet50
from tensorflow.keras.optimizers import Adam

###############################################################################
# 1. Data: CIFAR-10 loading & preprocessing for ResNet50
###############################################################################
num_classes = 10
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Images become float32; labels become one-hot vectors of length num_classes.
x_train, x_test = x_train.astype("float32"), x_test.astype("float32")
y_train, y_test = (
    tf.keras.utils.to_categorical(y_train, num_classes),
    tf.keras.utils.to_categorical(y_test, num_classes),
)

def preprocess_resnet(image, label):
    """Resize to 224x224 and run ResNet50's `preprocess_input`; the label is unchanged.

    Args:
        image: Image tensor from the tf.data pipeline (mapped after `.batch(...)`
            in this cell, so it receives whole batches).
        label: Label tensor paired with `image`.

    Returns:
        A `(preprocessed_image, label)` tuple.
    """
    resized = tf.image.resize(image, (224, 224))
    return resnet50.preprocess_input(resized), label

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

# Training split: shuffle samples, batch them, then preprocess each batch.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(50000).batch(batch_size)
train_ds = train_ds.map(preprocess_resnet, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

# Test split: same batching and preprocessing, without shuffling.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(batch_size)
test_ds = test_ds.map(preprocess_resnet, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

###############################################################################
# 2. DynamicChannelGate Layer
###############################################################################
class DynamicChannelGate(tf.keras.layers.Layer):
    """
    A learnable gating mechanism to dynamically prune (or re-expand) channels.

    Each channel owns one trainable logit; the effective gate is
    sigmoid(logit) and therefore lies strictly between 0 and 1. The logits are
    initialised to 1.0, so every gate starts at sigmoid(1.0) ~= 0.73.

    Args:
        num_channels: Number of channels C on the last axis of the input.
        name: Optional layer name.
        **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `dtype`,
            `trainable`); required so the layer can be rebuilt from its config.
    """
    def __init__(self, num_channels, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_channels = int(num_channels)
        # Register the logits through add_weight instead of a bare tf.Variable:
        # Keras 3 does not auto-track tf.Variable attributes, so a raw variable
        # would be absent from trainable_weights and never be trained or saved.
        self.gate_params = self.add_weight(
            name="gate_params",
            shape=(self.num_channels,),
            initializer="ones",
            trainable=True,
        )

    def call(self, inputs, training=None):
        """Scale each channel of `inputs` (batch, H, W, C) by its sigmoid gate."""
        gate = tf.sigmoid(self.gate_params)          # shape (C,)
        gate = tf.reshape(gate, (1, 1, 1, -1))       # broadcast to (1, 1, 1, C)
        return inputs * gate

    def get_config(self):
        """Return the layer config, including `num_channels`, for serialization."""
        config = super().get_config()
        config.update({"num_channels": self.num_channels})
        return config

###############################################################################
# 3. Create a ResNet50 model with Dynamic Gating
###############################################################################
def create_resnet_dynamic_gating(input_shape, num_classes):
    """Build an ImageNet-initialised ResNet50 with a DynamicChannelGate and softmax head.

    Args:
        input_shape: Image shape forwarded to `ResNet50(input_shape=...)`.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled Keras `Model` named "ResNet_DynamicGating".
    """
    backbone = ResNet50(include_top=False, weights="imagenet", input_shape=input_shape)
    # The whole backbone is left trainable for deeper fine-tuning.
    backbone.trainable = True

    feature_map = backbone.output
    # One gate per channel of the final feature map (typically 2048 for ResNet50).
    gate = DynamicChannelGate(num_channels=feature_map.shape[-1], name="dynamic_gating")
    gated = gate(feature_map)

    # Head: global average pool -> dropout -> softmax.
    pooled = layers.GlobalAveragePooling2D()(gated)
    pooled = layers.Dropout(0.3)(pooled)
    probabilities = layers.Dense(num_classes, activation="softmax")(pooled)

    return models.Model(inputs=backbone.input, outputs=probabilities, name="ResNet_DynamicGating")

###############################################################################
# 4. Build, Train, and Evaluate the Model
###############################################################################
# Build and compile the gated ResNet50 for one-hot labels.
model = create_resnet_dynamic_gating((224, 224, 3), num_classes)
model.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(learning_rate=1e-4),
    metrics=["accuracy"],
)

print("\n--- Fine-Tuning Dynamic-Gating ResNet50 on CIFAR-10 ---")
# The test split doubles as the validation set during training.
history = model.fit(train_ds, validation_data=test_ds, epochs=3, verbose=1)

# Final metrics on the test split.
loss, acc = model.evaluate(test_ds, verbose=0)
print(f"\nFinal Accuracy with Dynamic Gating: {acc:.4f}")



--- Fine-Tuning Dynamic-Gating ResNet50 on CIFAR-10 ---
Epoch 1/3
Epoch 2/3
Epoch 3/3

Final Accuracy with Dynamic Gating: 0.9304
