# Visualizing convnet filters

In [1]:
import functools
from keras import layers
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.XceptionBackbone")
class XceptionBackbone(Backbone):
    """Xception core network with hyperparameters.

    This class implements a Xception backbone as described in
    [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357).

    Most users will want the pretrained presets available with this model. If
    you are creating a custom backbone, this model provides customizability
    through the `stackwise_conv_filters` and `stackwise_pooling` arguments. This
    backbone assumes the same basic structure as the original Xception mode:
    * Residuals and pre-activation everywhere but the first and last block.
    * Conv layers for the first block only, separable conv layers elsewhere.

    Args:
        stackwise_conv_filters: list of list of ints. Each outermost list
            entry represents a block, and each innermost list entry a conv
            layer. The integer value specifies the number of filters for the
            conv layer.
        stackwise_pooling: list of bools. A list of booleans per block, where
            each entry is true if the block should includes a max pooling layer
            and false if it should not.
        image_shape: tuple. The input shape without the batch size.
            Defaults to `(None, None, 3)`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. If unspecified, the Keras default will be used.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Examples:
    ```python
    input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3))

    # Pretrained Xception backbone.
    model = keras_hub.models.Backbone.from_preset("xception_41_imagenet")
    model(input_data)

    # Randomly initialized Xception backbone with a custom config.
    model = keras_hub.models.XceptionBackbone(
        stackwise_conv_filters=[[32, 64], [64, 128], [256, 256]],
        stackwise_pooling=[True, True, False],
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        stackwise_conv_filters,
        stackwise_pooling,
        image_shape=(None, None, 3),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        if len(stackwise_conv_filters) != len(stackwise_pooling):
            raise ValueError("All stackwise args should have the same length.")

        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        num_blocks = len(stackwise_conv_filters)

        # Layer shorcuts with common args.
        norm = functools.partial(
            layers.BatchNormalization,
            axis=channel_axis,
            dtype=dtype,
        )
        act = functools.partial(
            layers.Activation,
            activation="relu",
            dtype=dtype,
        )
        conv = functools.partial(
            layers.Conv2D,
            kernel_size=(3, 3),
            use_bias=False,
            data_format=data_format,
            dtype=dtype,
        )
        sep_conv = functools.partial(
            layers.SeparableConv2D,
            kernel_size=(3, 3),
            padding="same",
            use_bias=False,
            data_format=data_format,
            dtype=dtype,
        )
        point_conv = functools.partial(
            layers.Conv2D,
            kernel_size=(1, 1),
            strides=(2, 2),
            padding="same",
            use_bias=False,
            data_format=data_format,
            dtype=dtype,
        )
        pool = functools.partial(
            layers.MaxPool2D,
            pool_size=(3, 3),
            strides=(2, 2),
            padding="same",
            data_format=data_format,
            dtype=dtype,
        )

        # === Functional Model ===
        image_input = layers.Input(shape=image_shape)
        x = image_input  # Intermediate result.

        # Iterate through the blocks.
        for block_i in range(num_blocks):
            first_block, last_block = block_i == 0, block_i == num_blocks - 1
            block_filters = stackwise_conv_filters[block_i]
            use_pooling = stackwise_pooling[block_i]

            # Save the block input as a residual.
            residual = x
            for conv_i, filters in enumerate(block_filters):
                # First block has post activation and strides on first conv.
                if first_block:
                    prefix = f"block{block_i + 1}_conv{conv_i + 1}"
                    strides = (2, 2) if conv_i == 0 else (1, 1)
                    x = conv(filters, strides=strides, name=prefix)(x)
                    x = norm(name=f"{prefix}_bn")(x)
                    x = act(name=f"{prefix}_act")(x)
                # Last block has post activation.
                elif last_block:
                    prefix = f"block{block_i + 1}_sepconv{conv_i + 1}"
                    x = sep_conv(filters, name=prefix)(x)
                    x = norm(name=f"{prefix}_bn")(x)
                    x = act(name=f"{prefix}_act")(x)
                else:
                    prefix = f"block{block_i + 1}_sepconv{conv_i + 1}"
                    # The first conv in second block has no activation.
                    if block_i != 1 or conv_i != 0:
                        x = act(name=f"{prefix}_act")(x)
                    x = sep_conv(filters, name=prefix)(x)
                    x = norm(name=f"{prefix}_bn")(x)

            # Optional block pooling.
            if use_pooling:
                x = pool(name=f"block{block_i + 1}_pool")(x)

            # Sum residual, first and last block do not have a residual.
            if not first_block and not last_block:
                prefix = f"block{block_i + 1}_residual"
                filters = x.shape[channel_axis]
                # Match filters with a pointwise conv if needed.
                if filters != residual.shape[channel_axis]:
                    residual = point_conv(filters, name=f"{prefix}_conv")(
                        residual
                    )
                    residual = norm(name=f"{prefix}_bn")(residual)
                x = layers.Add(name=f"{prefix}_add", dtype=dtype)([x, residual])

        super().__init__(
            inputs=image_input,
            outputs=x,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.stackwise_conv_filters = stackwise_conv_filters
        self.stackwise_pooling = stackwise_pooling
        self.image_shape = image_shape
        self.data_format = data_format

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_conv_filters": self.stackwise_conv_filters,
                "stackwise_pooling": self.stackwise_pooling,
                "image_shape": self.image_shape,
            }
        )
        return config


@keras_hub_export("keras_hub.layers.XceptionImageConverter")
class XceptionImageConverter(ImageConverter):
    """Image converter for Xception models that handles legacy parameters."""
    
    backbone_cls = XceptionBackbone
    
    def __init__(self, antialias=None, **kwargs):
        """Initialize XceptionImageConverter.
        
        Args:
            antialias: Legacy parameter that is ignored for compatibility.
            **kwargs: Arguments passed to the parent ImageConverter.
        """
        # Remove antialias from kwargs if it exists (for compatibility)
        kwargs.pop('antialias', None)
        super().__init__(**kwargs)
    
    @classmethod
    def from_config(cls, config):
        """Create layer from config, handling legacy antialias parameter."""
        # Remove antialias from config if it exists
        config = config.copy()
        config.pop('antialias', None)
        return super().from_config(config)
    
    def get_config(self):
        """Get configuration dict, excluding antialias for compatibility."""
        config = super().get_config()
        # Remove antialias if it exists in the config
        config.pop('antialias', None)
        return config


# SOLUTION 1: Direct base class patching (most reliable)
def patch_image_converter_for_antialias():
    """Permanently patch ImageConverter to handle antialias parameter."""
    
    # Store original methods
    if not hasattr(ImageConverter, '_original_init'):
        ImageConverter._original_init = ImageConverter.__init__
        ImageConverter._original_from_config = ImageConverter.from_config
    
    def patched_init(self, antialias=None, **kwargs):
        """Patched init that ignores antialias parameter."""
        kwargs.pop('antialias', None)  # Remove antialias if present
        ImageConverter._original_init(self, **kwargs)
    
    @classmethod
    def patched_from_config(cls, config):
        """Patched from_config that ignores antialias parameter."""
        config = config.copy()
        config.pop('antialias', None)  # Remove antialias if present
        return ImageConverter._original_from_config(config)
    
    # Apply patches
    ImageConverter.__init__ = patched_init
    ImageConverter.from_config = patched_from_config
    print("✓ Successfully patched ImageConverter to handle antialias parameter")


def create_xception_preprocessor_with_preset():
    """Create Xception preprocessor from preset after patching."""
    import keras_hub
    
    # Apply the patch first
    patch_image_converter_for_antialias()
    
    # Now load the preset
    preprocessor = keras_hub.layers.ImageConverter.from_preset(
        "hf://keras/xception_41_imagenet",
        image_size=(180, 180),
    )
    print("✓ Successfully loaded preprocessor from preset!")
    return preprocessor


# SOLUTION 2: Manual creation with exact preset settings
def create_xception_preprocessor_manual():
    """Create Xception preprocessor manually with exact preset settings."""
    # These are the exact values from the preset config in the error message
    preprocessor = ImageConverter(
        image_size=(180, 180),
        scale=0.00784313725490196,  # Exact value from preset
        offset=-1.0,
        interpolation="bilinear",
        crop_to_aspect_ratio=True,
        pad_to_aspect_ratio=False,
        bounding_box_format="yxyx",
        name="image_converter"
    )
    print("✓ Created ImageConverter manually with exact preset settings!")
    return preprocessor


# SOLUTION 3: Simple Keras preprocessing pipeline
def create_simple_xception_preprocessor():
    """Create a simple preprocessing pipeline for Xception."""
    import keras
    
    preprocessor = keras.Sequential([
        keras.layers.Resizing(180, 180, interpolation="bilinear", crop_to_aspect_ratio=True),
        keras.layers.Rescaling(scale=0.00784313725490196, offset=-1.0)  # Exact preset values
    ], name="xception_preprocessor")
    
    print("✓ Created simple Keras preprocessing pipeline!")
    return preprocessor


# SOLUTION 4: Complete Xception model loading workaround
def load_complete_xception_model():
    """Load the complete Xception model and extract preprocessor settings."""
    import keras_hub
    
    try:
        # Try to load just the backbone first
        backbone = keras_hub.models.XceptionBackbone.from_preset("hf://keras/xception_41_imagenet")
        print("✓ Successfully loaded Xception backbone!")
        
        # Create compatible preprocessor
        preprocessor = create_simple_xception_preprocessor()
        
        return backbone, preprocessor
        
    except Exception as e:
        print(f"✗ Could not load backbone: {e}")
        return None, create_simple_xception_preprocessor()


# Auto-select best solution
def get_xception_preprocessor():
    """Automatically get a working Xception preprocessor."""
    
    print("Attempting to resolve Xception preprocessor...\n")
    
    # Try Solution 1: Preset with patch
    try:
        preprocessor = create_xception_preprocessor_with_preset()
        return preprocessor, "preset_with_patch"
    except Exception as e:
        print(f"✗ Preset loading failed: {str(e)[:100]}...\n")
    
    # Try Solution 2: Manual creation
    try:
        preprocessor = create_xception_preprocessor_manual()
        return preprocessor, "manual_imageconverter"
    except Exception as e:
        print(f"✗ Manual ImageConverter failed: {str(e)[:100]}...\n")
    
    # Fallback to Solution 3: Simple pipeline
    try:
        preprocessor = create_simple_xception_preprocessor()
        return preprocessor, "simple_pipeline"
    except Exception as e:
        print(f"✗ Simple pipeline failed: {str(e)[:100]}...\n")
        raise Exception("All preprocessing solutions failed!")


# Usage
print("=== Xception Preprocessor Setup ===\n")

try:
    preprocessor, method = get_xception_preprocessor()
    print(f"\n✓ Success using method: {method}")
    print(f"Preprocessor type: {type(preprocessor)}")
    
    # Test the preprocessor
    import numpy as np
    test_image = np.random.randint(0, 255, (1, 224, 224, 3), dtype=np.uint8)
    
    processed = preprocessor(test_image)
    print(f"\n✓ Preprocessing test successful!")
    print(f"  Input shape: {test_image.shape}")
    print(f"  Output shape: {processed.shape}")
    print(f"  Output range: [{np.min(processed):.4f}, {np.max(processed):.4f}]")
    print(f"  Expected range for Xception: [-1.0, 1.0]")
    
    # Check if output is in correct range for Xception
    if -1.1 <= np.min(processed) <= -0.9 and 0.9 <= np.max(processed) <= 1.1:
        print("  ✓ Output range looks correct for Xception!")
    else:
        print("  ⚠ Output range may not be optimal for Xception")
        
except Exception as e:
    print(f"✗ All solutions failed: {e}")
    print("\nTry updating your packages:")
    print("pip install --upgrade keras-hub tensorflow")

  from .autonotebook import tqdm as notebook_tqdm


=== Xception Preprocessor Setup ===

Attempting to resolve Xception preprocessor...

✓ Successfully patched ImageConverter to handle antialias parameter
✓ Successfully loaded preprocessor from preset!

✓ Success using method: preset_with_patch
Preprocessor type: <class 'keras_hub.src.layers.preprocessing.image_converter.ImageConverter'>
✗ All solutions failed: Exception encountered when calling ImageConverter.call().

[1mcannot compute Mul as input #1(zero-based) was expected to be a uint8 tensor but is a float tensor [Op:Mul] name: [0m

Arguments received by ImageConverter.call():
  • inputs=array([[[[149, 218,  66],
         [165,   0, 151],
         [ 12, 184, 232],
         ...,
         [ 19, 161, 226],
         [ 30, 163, 191],
         [ 29,   3, 132]],

        [[ 87,  37, 110],
         [162,   7, 134],
         [ 32, 112, 159],
         ...,
         [242, 178, 227],
         [ 75, 105,  23],
         [162, 108,  95]],

        [[225, 253,  14],
         [ 50, 152, 163],
  

Could not follow the tutorial, it seems like xception_41_imagenet and it's backbone no longer exista in newer version of Keras. I ended up copying the definition of XceptionBackbone straight from source: https://github.com/keras-team/keras-hub/blob/v0.21.1/keras_hub/src/models/xception/xception_backbone.py#L10
and model straight from hugging_face: https://huggingface.co/keras/xception_41_imagenet

In [2]:
model = XceptionBackbone.from_preset("hf://keras/xception_41_imagenet")

In [3]:
preprocessor = XceptionImageConverter.from_preset(
    "hf://keras/xception_41_imagenet",
    image_size=(180, 180),
)

In [4]:
import keras 
for layer in model.layers:
    if isinstance(layer, (keras.layers.Conv2D, keras.layers.SeparableConv2D)):
        print(layer.name)

block1_conv1
block1_conv2
block2_sepconv1
block2_sepconv2
block2_residual_conv
block3_sepconv1
block3_sepconv2
block3_residual_conv
block4_sepconv1
block4_sepconv2
block4_residual_conv
block5_sepconv1
block5_sepconv2
block5_sepconv3
block6_sepconv1
block6_sepconv2
block6_sepconv3
block7_sepconv1
block7_sepconv2
block7_sepconv3
block8_sepconv1
block8_sepconv2
block8_sepconv3
block9_sepconv1
block9_sepconv2
block9_sepconv3
block10_sepconv1
block10_sepconv2
block10_sepconv3
block11_sepconv1
block11_sepconv2
block11_sepconv3
block12_sepconv1
block12_sepconv2
block12_sepconv3
block13_sepconv1
block13_sepconv2
block13_residual_conv
block14_sepconv1
block14_sepconv2


In [5]:
layer_name = "block3_sepconv1"
layer = model.get_layer(name=layer_name)
feature_extractor = keras.Model(inputs=model.input, outputs=layer.output)

In [6]:
img_path = keras.utils.get_file(
    fname="cat.jpg", origin="https://img-datasets.s3.amazonaws.com/cat.jpg"
)

def get_img_array(img_path, target_size):
    img = keras.utils.load_img(img_path, target_size=target_size)
    array = keras.utils.img_to_array(img)
    array = np.expand_dims(array, axis=0)
    return array

img_tensor = get_img_array(img_path, target_size=(180, 180))

In [7]:
activation = feature_extractor(preprocessor(img_tensor))

In [8]:
from keras import ops

def compute_loss(image, filter_index):
    activation = feature_extractor(image)
    filter_activation = activation[:, 2:-2, 2:-2, filter_index]
    return ops.mean(filter_activation)