# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow_probability as tfp

# === Part 1: Second Meta-Learner Blueprints ===

def build_gnn_meta_learner(num_models, num_classes, hidden_dim=128, num_layers=2):
    """
    Graph-based meta learner: treats each base model output as a node,
    learns interactions via message passing.

    Each round applies a node-wise Dense transform, builds a dense
    "adjacency" from the row-softmaxed (unscaled) dot-product similarity
    between node features, and replaces every node by the
    adjacency-weighted sum of all node features.

    Every step is expressed as a Keras layer rather than a raw ``tf.*`` op so
    the graph can be built both where symbolic Keras tensors accept TF ops
    and where they only accept Keras layers (Keras 3).

    Args:
        num_models: number of base models, i.e. number of graph nodes.
        num_classes: length of each node's prediction vector; also the
            size of the output distribution.
        hidden_dim: width of the node-wise Dense transform in each round.
        num_layers: number of message-passing rounds. With 0 rounds the raw
            inputs are mean-pooled directly.

    Input shape: (batch, num_models, num_classes)
    Output: (batch, num_classes), softmax-normalized

    Returns:
        A Keras ``Model`` named "GNN_MetaLearner".
    """
    inputs = layers.Input(shape=(num_models, num_classes), name="gnn_inputs")
    x = inputs
    for _ in range(num_layers):
        # node-wise transform (Dense acts on the last axis of each node)
        node_feats = layers.Dense(hidden_dim, activation='relu')(x)
        # learned adjacency: pairwise similarity node_feats @ node_feats^T,
        # shape (batch, num_models, num_models)
        attn_scores = layers.Dot(axes=(2, 2))([node_feats, node_feats])
        # normalize each node's incoming weights
        attn_weights = layers.Softmax(axis=-1)(attn_scores)
        # message passing: attn_weights @ node_feats,
        # shape (batch, num_models, hidden_dim)
        x = layers.Dot(axes=(2, 1))([attn_weights, node_feats])
    # readout: mean over the node axis (axis 1)
    readout = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', name="gnn_meta_output")(readout)
    return Model(inputs, outputs, name="GNN_MetaLearner")


def build_tabnet_meta_learner(num_models, num_classes, feature_dim=128, num_steps=3, relaxation=1.5):
    """
    Simplified TabNet: sequential feature masking + shared feature transformer.

    At every step a step-specific Dense layer produces a softmax mask over
    all input features, the mask is raised to the power ``1 / relaxation``,
    multiplied element-wise with the inputs, and the masked features are
    passed through one Dense block whose weights are shared across steps.
    The per-step outputs are summed and fed to a softmax head.

    Args:
        num_models: number of base models feeding the meta learner.
        num_classes: number of classes each base model predicts; also the
            size of the output distribution.
        feature_dim: width of the shared feature-transformer Dense block.
        num_steps: number of masking steps; must be at least 1.
        relaxation: the mask is raised to ``1 / relaxation``; must be
            non-zero.

    Input flattened: (batch, num_models * num_classes)
    Output: (batch, num_classes), softmax-normalized

    Returns:
        A Keras ``Model`` named "TabNet_MetaLearner".

    Raises:
        ValueError: if ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    input_dim = num_models * num_classes
    mask_exponent = 1.0 / relaxation

    inputs = layers.Input(shape=(input_dim,), name="tabnet_inputs")
    # one transformer instance => weights shared by every step
    shared_block = layers.Dense(feature_dim, activation='relu')

    step_features = []
    for step in range(num_steps):
        # compute mask
        mask = layers.Dense(input_dim, activation='softmax', name=f"mask_{step}")(inputs)
        # relax mask for exploration; the ** operator is overloaded on
        # symbolic Keras tensors, so no raw tf.* call is needed here
        mask = mask ** mask_exponent
        # apply mask
        masked_x = layers.Multiply()([inputs, mask])
        # feature transformer
        step_features.append(shared_block(masked_x))

    # Add() needs at least two inputs, so a single step is passed through as is
    if len(step_features) == 1:
        aggregated = step_features[0]
    else:
        aggregated = layers.Add()(step_features)
    outputs = layers.Dense(num_classes, activation='softmax', name="tabnet_meta_output")(aggregated)
    return Model(inputs, outputs, name="TabNet_MetaLearner")


def build_moe_meta_learner(num_models, num_classes, expert_units=128, num_experts=4, k=2):
    """
    Mixture of Experts meta learner with k-sparse gating.

    A linear gate scores every expert; only the ``k`` highest-scoring
    experts per sample keep their softmax weight, and the kept weights are
    renormalized to sum to 1. Each expert is a two-layer MLP that emits
    ``num_classes`` logits; the gate-weighted sum of the expert logits goes
    through a final softmax.

    Args:
        num_models: number of base models feeding the meta learner.
        num_classes: number of classes each base model predicts; also the
            size of the output distribution.
        expert_units: hidden width of every expert MLP.
        num_experts: number of experts.
        k: number of experts kept per sample; must satisfy
            ``1 <= k <= num_experts``.

    Input: (batch, num_models * num_classes)
    Output: (batch, num_classes), softmax-normalized

    Returns:
        A Keras ``Model`` named "MoE_MetaLearner". The final softmax layer
        is named "moe_meta_output".

    Raises:
        ValueError: if ``k`` is outside ``[1, num_experts]``. (k == 0 would
            zero every gate weight and the renormalization would divide
            0 by 0.)
    """
    if not 1 <= k <= num_experts:
        raise ValueError(
            f"k must satisfy 1 <= k <= num_experts; got k={k}, num_experts={num_experts}"
        )

    def _mix_experts(tensors):
        # tensors = [gate_logits, expert_0_logits, ..., expert_{E-1}_logits]
        logits, *per_expert = tensors
        # k-sparse selection: multi-hot mask of the k largest gate logits
        top_k = tf.math.top_k(logits, k=k, sorted=False)
        mask = tf.reduce_sum(tf.one_hot(top_k.indices, depth=num_experts), axis=1)
        # keep the selected experts' softmax mass, renormalize to sum to 1
        gate_weights = tf.nn.softmax(logits) * mask
        gate_weights = gate_weights / tf.reduce_sum(gate_weights, axis=-1, keepdims=True)
        stack = tf.stack(per_expert, axis=-1)  # (batch, num_classes, num_experts)
        weighted = tf.matmul(stack, tf.expand_dims(gate_weights, -1))  # (batch, num_classes, 1)
        return tf.squeeze(weighted, axis=-1)  # (batch, num_classes)

    inputs = layers.Input(shape=(num_models * num_classes,), name="moe_inputs")
    # gating network
    gate_logits = layers.Dense(num_experts, name="gate_logits")(inputs)
    # experts
    expert_outputs = []
    for i in range(num_experts):
        hidden = layers.Dense(expert_units, activation='relu', name=f"expert_{i}")(inputs)
        expert_outputs.append(layers.Dense(num_classes, name=f"expert_out_{i}")(hidden))
    # The raw TF ops run inside a Lambda layer so they receive concrete
    # tensors; symbolic Keras tensors cannot always be passed to tf.* directly.
    mixed = layers.Lambda(_mix_experts, name="moe_mixture")([gate_logits, *expert_outputs])
    outputs = layers.Activation('softmax', name="moe_meta_output")(mixed)
    return Model(inputs, outputs, name="MoE_MetaLearner")


# === Part 2: Meta-Meta Learner with Gumbel-Softmax Routing ===

def build_meta_meta_learner(meta_models, num_classes, temperature=0.5):
    """
    Combines multiple meta learners via Gumbel-Softmax routing.

    A Dense layer on the flattened input produces one routing logit per
    meta learner. A relaxed one-hot (Gumbel-Softmax) sample drawn from those
    logits weights the meta learners' outputs, and the weighted sum is
    passed through a softmax. A fresh sample is drawn on every forward
    pass; there is no separate deterministic inference path.

    Args:
        meta_models: non-empty sequence of Keras models that output
            (batch, num_classes). All of them are called on the same input
            tensor, whose shape is taken from ``meta_models[0].input_shape``,
            so they must all accept that shape.
        num_classes: not read by this function; kept so existing callers
            keep working.
        temperature: temperature of the relaxed one-hot distribution.

    Returns:
        A Keras ``Model`` named "MetaMetaLearner" whose output layer is
        named "final_output".

    Raises:
        ValueError: if ``meta_models`` is empty.
    """
    meta_models = list(meta_models)
    if not meta_models:
        raise ValueError("meta_models must contain at least one model")

    def _gumbel_mix(tensors):
        # tensors = [routing_logits, meta_out_0, ..., meta_out_{n-1}]
        logits, *member_outputs = tensors
        # stack outputs: shape (batch, num_classes, n_meta)
        stack = tf.stack(member_outputs, axis=-1)
        # Gumbel-Softmax sampling: weights has shape (batch, n_meta)
        gumbel = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=logits)
        weights = gumbel.sample()
        # The weights must become a column vector (batch, n_meta, 1) for the
        # batched product with (batch, num_classes, n_meta); expanding at
        # axis=1 would give (batch, 1, n_meta), which only multiplies when
        # n_meta == 1.
        weights_col = tf.expand_dims(weights, axis=-1)
        mixed = tf.matmul(stack, weights_col)  # (batch, num_classes, 1)
        return tf.squeeze(mixed, axis=-1)  # (batch, num_classes)

    # aggregate inputs: assume all meta_models share same input
    meta_inputs = layers.Input(shape=meta_models[0].input_shape[1:], name="meta_meta_input")
    # get meta outputs
    meta_outputs = [m(meta_inputs) for m in meta_models]
    # routing logits
    routing_logits = layers.Dense(len(meta_models), name="routing_logits")(layers.Flatten()(meta_inputs))
    # Sampling and mixing run inside a Lambda layer so the distribution is
    # built from concrete tensors rather than symbolic Keras tensors.
    mixed = layers.Lambda(_gumbel_mix, name="gumbel_routing")([routing_logits, *meta_outputs])
    # NOTE(review): if the meta learners already emit softmax probabilities,
    # this second softmax flattens the mixture -- confirm it is intended.
    final_out = layers.Activation('softmax', name="final_output")(mixed)
    return Model(meta_inputs, final_out, name="MetaMetaLearner")


# === Part 3: Training Loop with Dynamic Meta-Learner Weighting ===

def train_meta_stack(base_data, labels, base_models, meta_models, meta_meta_model,
                     epochs=10, batch_size=32, lr=1e-3):
    """
    Train the meta learners and the meta-meta arbiter on frozen base models.

    For every batch the base models are run in inference mode, their
    predictions are concatenated along the last axis and fed to
    ``meta_meta_model``; the sparse categorical cross-entropy of its output
    is minimized with Adam over the trainable variables of ``meta_models``
    and ``meta_meta_model``. Base-model variables are never updated.

    Args:
        base_data: raw inputs to feed base models
        labels: integer class ids aligned with ``base_data`` (consumed by
            SparseCategoricalCrossentropy / SparseCategoricalAccuracy).
        base_models: models called as ``m(x, training=False)``.
        meta_models: [transformer_meta, second_meta]. Only their trainable
            variables are used here; they are not called directly, so they
            receive gradients only if ``meta_meta_model`` uses them.
        meta_meta_model: final arbiter; called on the concatenated base
            predictions and expected to return class probabilities (the
            loss is created with its default ``from_logits=False``).
        epochs: number of passes over the data.
        batch_size: number of samples per batch.
        lr: Adam learning rate.

    Returns:
        None. Progress is printed every 100 steps and once per epoch.
    """
    opt = tf.keras.optimizers.Adam(lr)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

    # metrics
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
    # reset_state() is the current metric API; reset_states() is the legacy
    # spelling and is only looked up when the current one is missing.
    reset_metric = getattr(train_acc, "reset_state", None) or train_acc.reset_states

    # The dataset is unshuffled and identical every epoch, so build it once.
    dataset = tf.data.Dataset.from_tensor_slices((base_data, labels)).batch(batch_size)
    trainable_models = [*meta_models, meta_meta_model]

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for step, (x_batch, y_batch) in enumerate(dataset):
            # Base models are frozen: run them outside the tape so their
            # activations are not recorded for differentiation.
            base_preds = [m(x_batch, training=False) for m in base_models]
            # concatenate base preds along the class axis for the meta input
            meta_input = tf.concat(base_preds, axis=-1)

            with tf.GradientTape() as tape:
                # meta-meta output
                final_pred = meta_meta_model(meta_input, training=True)
                # compute loss
                loss = loss_fn(y_batch, final_pred)

            # Collect variables after the forward pass (models may create
            # them lazily) and de-duplicate by identity: when the arbiter
            # contains the meta models, the same variable is reachable
            # through both and must not be updated twice in one step.
            seen_ids = set()
            trainable_vars = []
            for mm in trainable_models:
                for var in mm.trainable_variables:
                    if id(var) not in seen_ids:
                        seen_ids.add(id(var))
                        trainable_vars.append(var)

            grads = tape.gradient(loss, trainable_vars)
            # Variables that did not take part in the loss have a None gradient.
            opt.apply_gradients(
                [(g, v) for g, v in zip(grads, trainable_vars) if g is not None]
            )

            # update metrics
            train_acc.update_state(y_batch, final_pred)
            if step % 100 == 0:
                print(f"Step {step}: loss = {loss.numpy():.4f}, acc = {train_acc.result().numpy():.4f}")
        print(f"Epoch {epoch+1} Accuracy: {train_acc.result().numpy():.4f}")
        reset_metric()