# In [15]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import EfficientNetB0

# Transformer Encoder Block
class TransformerEncoder(Model):
    """One transformer encoder block: self-attention, then a feed-forward net.

    Each of the two sub-layers is wrapped in a residual addition, and layer
    normalization is applied after that addition.

    Args:
        embed_dim: Output width of the feed-forward network; also passed to
            ``MultiHeadAttention`` as ``key_dim``.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embed_dim,
        )
        # Two-layer position-wise network: expand to ff_dim, project back.
        feed_forward_layers = [
            layers.Dense(ff_dim, activation='relu'),
            layers.Dense(embed_dim),
        ]
        self.ffn = tf.keras.Sequential(feed_forward_layers)
        self.add1 = layers.Add()
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.add2 = layers.Add()
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        """Apply the block to ``inputs`` and return the normalized result.

        The same tensor is used as query and value (self-attention). The
        feed-forward output has width ``embed_dim``, so the second residual
        addition expects ``inputs`` to have a matching last dimension.
        """
        # Sub-layer 1: self-attention + residual + norm.
        attended = self.attention(inputs, inputs)
        normed_attention = self.norm1(self.add1([inputs, attended]))

        # Sub-layer 2: feed-forward + residual + norm.
        transformed = self.ffn(normed_attention)
        return self.norm2(self.add2([normed_attention, transformed]))

# CNN-Transformer Interaction Module
class CTIModule(Model):
    """Fuse a CNN feature map with a transformer token sequence.

    The CNN feature map is globally average-pooled and projected to
    ``transformer_dim``; the transformer tokens are averaged over axis 1 and
    projected to ``cnn_dim``. The two projections are then added element-wise.

    Args:
        cnn_dim: Output width of the projection applied to the pooled
            transformer features.
        transformer_dim: Output width of the projection applied to the pooled
            CNN features.

    Raises:
        ValueError: If ``cnn_dim`` and ``transformer_dim`` cannot be added
            together (they differ and neither is 1).
    """

    def __init__(self, cnn_dim, transformer_dim):
        super().__init__()
        # The two projections are summed in call(), so their widths must be
        # broadcast-compatible. Fail here with a clear message instead of
        # with a shape error in the middle of a forward pass.
        if cnn_dim != transformer_dim and 1 not in (cnn_dim, transformer_dim):
            raise ValueError(
                "CTIModule adds its two projections, so cnn_dim "
                f"({cnn_dim}) and transformer_dim ({transformer_dim}) "
                "must be equal (or one of them must be 1)."
            )
        self.cnn_to_transformer = layers.Dense(transformer_dim)
        self.transformer_to_cnn = layers.Dense(cnn_dim)
        # Built once here rather than re-instantiated on every call().
        self.cnn_pool = layers.GlobalAveragePooling2D()

    def call(self, cnn_features, transformer_features):
        """Return the fused features with a singleton axis at position 1.

        Args:
            cnn_features: 4-D CNN feature map; assumed channels-last
                ``(batch, H, W, C)`` as produced by the backbone.
            transformer_features: Token sequence; averaged over axis 1
                (assumed ``(batch, tokens, dim)``).

        Returns:
            Tensor of shape ``(batch, 1, dim)`` holding the sum of the two
            projected summaries.
        """
        # CNN branch: spatial average -> projection -> add a token axis.
        cnn_flat = self.cnn_pool(cnn_features)
        cnn_transformed = self.cnn_to_transformer(cnn_flat)
        cnn_transformed = tf.expand_dims(cnn_transformed, axis=1)

        # Transformer branch: token average -> projection.
        transformer_pooled = tf.reduce_mean(transformer_features, axis=1)
        transformer_transformed = self.transformer_to_cnn(transformer_pooled)

        fused_features = cnn_transformed + tf.expand_dims(transformer_transformed, axis=1)
        return fused_features

# Hybrid Model with split + CTIModule + Conv1D manual block
class HybridModelSplitCTI(Model):
    """EfficientNetB0 backbone with a split transformer / Conv1D head.

    The backbone feature map is flattened into tokens and projected to
    ``embed_dim``. The first half of the tokens goes through a transformer
    encoder, the second half through a Conv1D + pooling + dense block. The
    transformer output is fused with the raw CNN feature map by a
    ``CTIModule``, concatenated with the Conv1D branch output, and classified.

    Args:
        num_classes: Number of output classes (softmax width).
        embed_dim: Token embedding width; also used for both CTI projection
            sizes. Defaults to 128.
        num_heads: Attention heads in the transformer encoder. Defaults to 8.
        ff_dim: Hidden width of the transformer feed-forward network.
            Defaults to 256.
        image_shape: Input image shape passed to EfficientNetB0. Defaults to
            ``(224, 224, 3)``.
        cnn_weights: ``weights`` argument for EfficientNetB0. Defaults to
            ``"imagenet"``; pass ``None`` to skip loading pretrained weights.
    """

    def __init__(
        self,
        num_classes,
        *,
        embed_dim=128,
        num_heads=8,
        ff_dim=256,
        image_shape=(224, 224, 3),
        cnn_weights="imagenet",
    ):
        super().__init__()
        self.cnn = EfficientNetB0(
            include_top=False,
            input_shape=image_shape,
            weights=cnn_weights,
        )

        # Projection to embed_dim
        self.project_to_embed = layers.Dense(embed_dim)

        self.transformer_encoder = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=num_heads,
            ff_dim=ff_dim,
        )

        # Manual efficient block with Conv1D
        self.manual_efficient_conv1d = layers.Conv1D(64, kernel_size=3, activation='relu', padding='same')
        self.manual_efficient_pool = layers.GlobalAveragePooling1D()
        self.manual_efficient_dense = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu')
        ])

        # CTI Module integration; both sides use embed_dim so the module's
        # internal addition lines up.
        self.cti_module = CTIModule(cnn_dim=embed_dim, transformer_dim=embed_dim)

        self.classifier = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(num_classes, activation='softmax')
        ])

    def call(self, inputs, return_all=False):
        """Run the forward pass.

        Args:
            inputs: Batch of images matching the backbone's input shape.
            return_all: When True, return a dict of every intermediate tensor
                (keys: ``cnn_features``, ``flatten_tokens_proj``, ``q1``,
                ``q1_out``, ``q2``, ``q2_conv``, ``q2_pooled``, ``q2_out``,
                ``cti_out``, ``cti_out_pooled``, ``fused``, ``final_output``)
                instead of only the class probabilities.

        Returns:
            The softmax output of shape ``(batch, num_classes)``, or the dict
            described above when ``return_all`` is True.
        """
        cnn_features = self.cnn(inputs)  # (batch, H, W, 1280)

        # Flatten the spatial grid into a token axis and project.
        batch_size = tf.shape(cnn_features)[0]
        flatten_tokens = tf.reshape(cnn_features, [batch_size, -1, cnn_features.shape[-1]])  # (batch, tokens, 1280)
        flatten_tokens_proj = self.project_to_embed(flatten_tokens)  # (batch, tokens, embed_dim)

        # Split tokens; with an odd token count q2 receives the extra token.
        total_tokens = tf.shape(flatten_tokens_proj)[1]
        split_point = total_tokens // 2

        q1 = flatten_tokens_proj[:, :split_point, :]  # Q1 -> Transformer
        q2 = flatten_tokens_proj[:, split_point:, :]  # Q2 -> Manual block

        # Process Q1 through transformer
        q1_out = self.transformer_encoder(q1)

        # Process Q2 through manual efficient block with Conv1D
        q2_conv = self.manual_efficient_conv1d(q2)  # (batch, tokens, 64)
        q2_pooled = self.manual_efficient_pool(q2_conv)  # (batch, 64)
        q2_out = self.manual_efficient_dense(q2_pooled)  # (batch, 128)

        # CTI module fusion (cnn_features + transformer features)
        cti_out = self.cti_module(cnn_features, q1_out)
        cti_out_pooled = tf.reduce_mean(cti_out, axis=1)

        # Fuse CTI output with manual efficient output
        fused = tf.concat([cti_out_pooled, q2_out], axis=-1)

        output = self.classifier(fused)

        if return_all:
            return {
                'cnn_features': cnn_features,
                'flatten_tokens_proj': flatten_tokens_proj,
                'q1': q1,
                'q1_out': q1_out,
                'q2': q2,
                'q2_conv': q2_conv,
                'q2_pooled': q2_pooled,
                'q2_out': q2_out,
                'cti_out': cti_out,
                'cti_out_pooled': cti_out_pooled,
                'fused': fused,
                'final_output': output
            }
        return output

# === Instantiate and print summary ===
num_classes = 8  # Replace with your dataset class count
input_shape = (224, 224, 3)

# Wrap the subclassed model in a functional Model so summary() can be
# printed from a symbolic input.
inputs = tf.keras.Input(shape=input_shape)
model_cti = HybridModelSplitCTI(num_classes=num_classes)
outputs = model_cti(inputs)
final_model_cti = Model(inputs=inputs, outputs=outputs)

final_model_cti.summary()

# === Test intermediate outputs ===

# Random batch of 8 images, used only to check tensor shapes.
batch = tf.random.normal((8,) + input_shape)
outputs_all = model_cti(batch, return_all=True)

for key in outputs_all:
    value = outputs_all[key]
    print(f"{key}: shape {value.shape}")


# Output of the shape check above:
# cnn_features: shape (8, 7, 7, 1280)
# flatten_tokens_proj: shape (8, 49, 128)
# q1: shape (8, 24, 128)
# q1_out: shape (8, 24, 128)
# q2: shape (8, 25, 128)
# q2_conv: shape (8, 25, 64)
# q2_pooled: shape (8, 64)
# q2_out: shape (8, 128)
# cti_out: shape (8, 1, 128)
# cti_out_pooled: shape (8, 128)
# fused: shape (8, 256)
# final_output: shape (8, 8)
