# In [None]:
#############################################################
# MAML Implementation
#############################################################

class MAMLTransformer:
    """Transformer classifier meta-trained with first-order MAML.

    A single Keras model is shared by every method. ``self.meta_weights`` holds
    the meta-learned parameters; task adaptation temporarily overwrites the
    model's variables, and the meta-weights are reloaded before each new task.
    """

    def __init__(
        self,
        input_shape,
        n_way=5,
        k_shot=5,
        inner_lr=0.01,
        meta_lr=0.001,
        meta_batch_size=32,
        num_inner_updates=5,
        embed_dim=256,
        num_heads=8,
        ff_dim=512,
        num_transformer_blocks=4,
        mlp_units=(128, 64),
        dropout=0.1,
        name="maml_transformer"
    ):
        """Store hyper-parameters, build the model and snapshot its weights.

        Args:
            input_shape: shape of one example, excluding the batch axis.
            n_way: number of classes per task (size of the softmax output).
            k_shot: examples per class in a support set (stored only; not
                read by this class's methods).
            inner_lr: step size of the plain gradient steps in the inner loop.
            meta_lr: learning rate of the Adam meta-optimizer.
            meta_batch_size: maximum number of tasks used per meta-update.
            num_inner_updates: default number of inner-loop steps.
            embed_dim, num_heads, ff_dim, num_transformer_blocks: transformer
                hyper-parameters forwarded to the model's layers.
            mlp_units: widths of the hidden Dense layers of the classifier head.
            dropout: dropout rate for the transformer blocks and the head.
            name: name given to the Keras model.
        """
        self.input_shape = input_shape
        self.n_way = n_way
        self.k_shot = k_shot
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.meta_batch_size = meta_batch_size
        self.num_inner_updates = num_inner_updates
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_transformer_blocks = num_transformer_blocks
        # Copy so neither a shared default nor the caller's list can be
        # mutated through this instance.
        self.mlp_units = list(mlp_units)
        self.dropout = dropout
        self.name = name

        # Build the model
        self.model = self._build_model()

        # Optimizer used only for the outer (meta) update.
        self.meta_optimizer = keras.optimizers.Adam(learning_rate=self.meta_lr)

        # Meta-parameters: the weights every task adaptation starts from.
        self.meta_weights = self.model.get_weights()

    def _build_model(self):
        """Build the transformer-based model for few-shot learning.

        Returns:
            A functional ``keras.Model`` mapping inputs of ``self.input_shape``
            to ``self.n_way`` softmax probabilities.
        """
        inputs = layers.Input(shape=self.input_shape)

        # Project input features into the embedding space.
        x = layers.Dense(self.embed_dim)(inputs)

        # Flat (tabular) inputs become a sequence of a single token. A Keras
        # layer is used instead of tf.expand_dims so the functional graph is
        # built from Keras ops only (raw TF ops on KerasTensors fail under
        # Keras 3). Inputs that already carry a sequence axis are left as-is.
        if len(x.shape) == 2:
            x = layers.Reshape((1, self.embed_dim))(x)

        # Apply positional encoding
        x = PositionalEncoding(position=50, d_model=self.embed_dim)(x)

        # Transformer blocks
        for _ in range(self.num_transformer_blocks):
            x = TransformerBlock(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                ff_dim=self.ff_dim,
                rate=self.dropout
            )(x)

        # Pool the token axis away.
        x = layers.GlobalAveragePooling1D()(x)

        # MLP classification head
        for dim in self.mlp_units:
            x = layers.Dense(dim, activation="relu")(x)
            x = layers.Dropout(self.dropout)(x)

        # Softmax output: downstream losses use from_logits=False.
        outputs = layers.Dense(self.n_way, activation="softmax")(x)

        return keras.Model(inputs, outputs, name=self.name)

    def _reset_to_meta_weights(self):
        """Load the stored meta-parameters back into the shared model."""
        self.model.set_weights(self.meta_weights)

    def _adapt_in_place(self, support_X, support_y, num_inner_updates):
        """Run ``num_inner_updates`` plain gradient steps on a support set.

        The model's trainable variables are updated in place, starting from
        whatever weights the model currently holds; callers reset to the
        meta-weights first.
        """
        support_y_one_hot = tf.one_hot(support_y, depth=self.n_way)
        trainable = self.model.trainable_variables
        for _ in range(num_inner_updates):
            with tf.GradientTape() as tape:
                support_probs = self.model(support_X, training=True)
                support_loss = tf.reduce_mean(
                    keras.losses.categorical_crossentropy(
                        support_y_one_hot, support_probs
                    )
                )
            gradients = tape.gradient(support_loss, trainable)
            # Update the trainable variables directly. Pairing these gradients
            # with model.get_weights() (which also lists non-trainable weights)
            # could misalign the two lists and break set_weights().
            for var, grad in zip(trainable, gradients):
                if grad is not None:
                    var.assign_sub(self.inner_lr * grad)

    def train_on_batch(self, batch_of_tasks):
        """Perform one first-order MAML (FOMAML) meta-update.

        For each task, the model is reset to the meta-weights and adapted on
        the support set. The gradient of the query loss is then taken at the
        adapted weights. The mean of these gradients is applied to the
        meta-weights with the meta-optimizer. Second-order terms are not
        propagated (the original weight-copying implementation could not
        propagate them either).

        Args:
            batch_of_tasks: indexable sequence of task dicts with keys
                'support_X', 'support_y', 'query_X', 'query_y'. At most
                ``self.meta_batch_size`` tasks are used.

        Returns:
            Mean query loss over the tasks used (as a NumPy scalar).

        Raises:
            ValueError: if no task would be used for the update.
        """
        meta_batch_size = min(len(batch_of_tasks), self.meta_batch_size)
        if meta_batch_size <= 0:
            raise ValueError("batch_of_tasks must contain at least one task")

        trainable = self.model.trainable_variables
        summed_grads = [None] * len(trainable)
        total_meta_loss = tf.constant(0.0, dtype=tf.float32)

        for i in range(meta_batch_size):
            task = batch_of_tasks[i]

            # Every task adapts from the meta-weights, not from whatever state
            # a previous adapt_to_task() call left in the shared model.
            self._reset_to_meta_weights()
            self._adapt_in_place(
                task['support_X'], task['support_y'], self.num_inner_updates
            )

            query_y_one_hot = tf.one_hot(task['query_y'], depth=self.n_way)
            with tf.GradientTape() as query_tape:
                query_probs = self.model(task['query_X'], training=True)
                query_loss = tf.reduce_mean(
                    keras.losses.categorical_crossentropy(
                        query_y_one_hot, query_probs
                    )
                )
            query_grads = query_tape.gradient(query_loss, trainable)

            # Accumulate per-task gradients, keeping None for variables that
            # never received a gradient so the optimizer skips them.
            for idx, grad in enumerate(query_grads):
                if grad is None:
                    continue
                grad = tf.convert_to_tensor(grad)
                summed_grads[idx] = (
                    grad if summed_grads[idx] is None else summed_grads[idx] + grad
                )

            total_meta_loss += tf.cast(query_loss, tf.float32)

        # Apply the averaged gradient to the meta-weights themselves.
        self._reset_to_meta_weights()
        scale = 1.0 / meta_batch_size
        self.meta_optimizer.apply_gradients(
            [(g * scale, v) for g, v in zip(summed_grads, trainable) if g is not None]
        )

        # Update stored meta weights
        self.meta_weights = self.model.get_weights()

        return (total_meta_loss / meta_batch_size).numpy()

    def evaluate(self, tasks, num_inner_updates=None):
        """Evaluate the meta-model with per-task adaptation.

        Args:
            tasks: iterable of task dicts (same keys as in train_on_batch).
            num_inner_updates: number of inner updates to perform (if None,
                use self.num_inner_updates).

        Returns:
            (mean query accuracy, mean query loss) across tasks; both are NaN
            when ``tasks`` is empty. The model is left holding the meta-weights.
        """
        if num_inner_updates is None:
            num_inner_updates = self.num_inner_updates

        accuracies = []
        losses = []

        for task in tasks:
            self._reset_to_meta_weights()
            self._adapt_in_place(task['support_X'], task['support_y'], num_inner_updates)

            query_y = np.asarray(task['query_y'])
            query_probs = self.model(task['query_X'], training=False)
            pred_y = tf.argmax(query_probs, axis=1).numpy()
            accuracies.append(np.mean(pred_y == query_y))

            query_y_one_hot = tf.one_hot(query_y, depth=self.n_way)
            loss = keras.losses.categorical_crossentropy(query_y_one_hot, query_probs)
            losses.append(tf.reduce_mean(loss).numpy())

        # Leave the shared model in its meta state.
        self._reset_to_meta_weights()

        if not accuracies:
            return np.nan, np.nan
        return np.mean(accuracies), np.mean(losses)

    def adapt_to_task(self, support_X, support_y, num_inner_updates=None):
        """Adapt the shared model to a new task using its support set.

        Args:
            support_X: support set features
            support_y: support set labels (class indices in [0, n_way))
            num_inner_updates: number of inner updates to perform (if None,
                use self.num_inner_updates)

        Returns:
            ``self.model``, now holding the adapted weights. The meta-weights
            remain stored in ``self.meta_weights``.
        """
        if num_inner_updates is None:
            num_inner_updates = self.num_inner_updates

        self._reset_to_meta_weights()
        self._adapt_in_place(support_X, support_y, num_inner_updates)
        return self.model

    def save_meta_model(self, filepath):
        """Save the meta-model weights (not any task-adapted weights)."""
        self._reset_to_meta_weights()
        self.model.save_weights(filepath)

    def load_meta_model(self, filepath):
        """Load weights from ``filepath`` and adopt them as the meta-weights."""
        self.model.load_weights(filepath)
        self.meta_weights = self.model.get_weights()
