In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import model


class MetaLearner:
    """Reptile-style meta-learner that trains with a prototype-aware loss.

    Each meta-iteration samples an episode via ``get_data_fn`` and computes one
    prototype per class from the current model outputs. It then recompiles the
    model with ``custom_loss_fn(prototypes, alpha=...)`` and trains on the
    episode. Finally it moves the weights only part of the way from the
    pre-episode weights toward the adapted weights (Reptile meta-update).

    Args:
        build_model_fn: Model builder called as ``build_model_fn(num_classes, input_shape)``.
        custom_loss_fn: Loss factory called as ``custom_loss_fn(prototypes, alpha=alpha)``
            that returns a Keras ``loss(y_true, y_pred)`` (e.g. ``make_custom_loss``).
        num_classes: Number of classes (prototypes) per episode.
        input_shape: Shape of one input sample, without the batch dimension.
        meta_iters: Number of meta-training iterations.
        meta_step_size: Initial Reptile step size; it decays linearly toward 0.
        alpha: Forwarded to ``custom_loss_fn`` as ``alpha``.
    """

    def __init__(self, build_model_fn, custom_loss_fn, num_classes, input_shape, meta_iters=100, meta_step_size=0.3, alpha=0.5):
        self.build_model_fn = build_model_fn
        self.custom_loss_fn = custom_loss_fn
        self.meta_iters = meta_iters
        self.meta_step_size = meta_step_size
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.alpha = alpha

        # Use the injected builder. The old code hard-coded ``model.build_model``,
        # which ignored ``build_model_fn`` and ``input_shape``.
        self.model = self.build_model()
        # No prototypes exist before the first episode, so start with plain
        # cross-entropy; ``train`` recompiles with the prototype-aware loss.
        # The old lambda called the loss *factory* as ``(y_true, y_pred, alpha)``,
        # which raises TypeError as soon as Keras evaluates it.
        self.model.compile(loss='categorical_crossentropy', metrics=['accuracy'])

    def build_model(self):
        """Build a fresh model via ``build_model_fn(num_classes, input_shape)``."""
        return self.build_model_fn(self.num_classes, self.input_shape)

    def calculate_prototypes(self, embeddings, labels):
        """Return the per-class mean embedding as a NumPy array.

        Args:
            embeddings: Array-like of shape ``(n, ...)``.
            labels: Integer class index per row, shape ``(n,)``. Tensors are
                accepted and converted with ``np.asarray``.

        Returns:
            Array of shape ``(num_classes, ...)``. A class with no rows yields
            a NaN prototype (NumPy warns about the empty mean).
        """
        embeddings = np.asarray(embeddings)
        labels = np.asarray(labels)
        return np.array([
            embeddings[labels == class_idx].mean(axis=0)
            for class_idx in range(self.num_classes)
        ])

    def train(self, X_train, y_train, X_test, y_test, get_data_fn, K_shot, inner_epochs=100):
        """Run Reptile meta-training and return per-iteration validation accuracy.

        Args:
            X_train, y_train, X_test, y_test: Data forwarded unchanged to ``get_data_fn``.
            get_data_fn: Episode sampler called as
                ``get_data_fn(X_train, y_train, X_test, y_test, K_shot, split=True)``.
                It must return ``(train_data, test_data, proto_X, proto_y)``, where
                ``argmax`` of ``proto_y`` over axis 1 gives the class index
                (one-hot labels).
            K_shot: Shots per class, forwarded to ``get_data_fn``.
            inner_epochs: Epochs of inner-loop ``fit`` per episode (previously hard-coded to 100).

        Returns:
            List holding, for each meta-iteration, the best ``val_accuracy`` seen
            during that episode's inner loop.
        """
        accuracies = []

        for meta_iter in range(self.meta_iters):
            # Reptile step size decays linearly from meta_step_size toward 0.
            frac_done = meta_iter / self.meta_iters
            cur_step_size = (1 - frac_done) * self.meta_step_size
            old_weights = self.model.get_weights()

            train_data, test_data, proto_X, proto_y = get_data_fn(X_train, y_train, X_test, y_test, K_shot, split=True)

            # Step 1: prototypes from the current (pre-adaptation) model outputs.
            # NumPy labels keep the boolean indexing in calculate_prototypes purely NumPy.
            proto_embeddings = self.model.predict(proto_X, verbose=0)
            proto_labels = np.argmax(proto_y, axis=1)
            prototypes = self.calculate_prototypes(proto_embeddings, proto_labels)

            # Step 2: recompile with the configured loss factory for this episode.
            self.model.compile(
                loss=self.custom_loss_fn(prototypes, alpha=self.alpha),
                metrics=['accuracy']
            )

            # Step 3: inner-loop adaptation on one episode.
            result = self.model.fit(train_data, epochs=inner_epochs, validation_data=test_data, verbose=0)
            val_acc = np.max(result.history['val_accuracy'])
            accuracies.append(val_acc)

            print(f"Meta Iter {meta_iter+1}/{self.meta_iters} - Val Acc: {val_acc:.4f}, Max: {np.max(accuracies):.4f}")

            # Step 4: Reptile meta-update -- interpolate toward the adapted weights.
            new_weights = self.model.get_weights()
            self.model.set_weights([
                old + (new - old) * cur_step_size
                for old, new in zip(old_weights, new_weights)
            ])

        return accuracies


In [None]:

def cosine_distance(a, b):
    """Cosine distance ``1 - cos(a, b)`` along the last axis.

    Both inputs are L2-normalised along ``axis=-1`` first, so only their
    direction matters. The inputs broadcast against each other exactly as in an
    elementwise product.
    """
    a_unit, b_unit = (tf.math.l2_normalize(t, axis=-1) for t in (a, b))
    cosine_similarity = tf.reduce_sum(tf.multiply(a_unit, b_unit), axis=-1)
    return 1 - cosine_similarity

def mahalanobis_distance(x, y, cov_inv):
    """Squared Mahalanobis distance ``(x - y)^T cov_inv (x - y)``, reduced over the last axis.

    No square root is taken, so the result is the *squared* distance.
    ``tf.matmul`` needs ``x - y`` to be at least rank 2, e.g. shape ``(n, d)``
    with ``cov_inv`` of shape ``(d, d)`` (assumed usage; not called in this notebook).
    """
    delta = x - y
    return tf.reduce_sum(tf.matmul(delta, cov_inv) * delta, axis=-1)



def calculate_prototypes(embeddings, labels, num_classes):
    """Calculate the mean embedding (prototype) of each class.

    Args:
        embeddings: NumPy array or ``tf.Tensor`` of shape ``(n, ...)``.
        labels: Integer class index per row, shape ``(n,)`` (array or tensor).
        num_classes: Number of prototypes to compute (classes ``0 .. num_classes - 1``).

    Returns:
        ``tf.Tensor`` of shape ``(num_classes, ...)``. A class with no rows gets
        a NaN prototype (mean over an empty set).
    """
    prototypes = []
    for class_idx in range(num_classes):
        # tf.boolean_mask accepts both NumPy arrays and tf.Tensors, whereas
        # ``embeddings[mask]`` boolean indexing fails when ``embeddings`` is a tensor.
        class_embeddings = tf.boolean_mask(embeddings, tf.equal(labels, class_idx))
        prototypes.append(tf.reduce_mean(class_embeddings, axis=0))
    return tf.stack(prototypes)

def prototype_loss(embedding, labels, prototypes, distance="l2"):
    """Mean distance between each embedding and the prototype of its true class.

    Args:
        embedding: Rank-2 tensor ``(batch, dim)``; ``make_custom_loss`` passes ``y_pred``.
        labels: Targets of shape ``(batch, num_classes)``; the true class is the
            ``argmax`` over the last axis (one-hot labels).
        prototypes: Array or tensor of shape ``(num_classes, dim)``.
        distance: ``"l2"`` (default, Euclidean), ``"l1"`` (Manhattan) or
            ``"cosine"`` (``1 - cosine similarity`` via ``cosine_distance``).

    Returns:
        Scalar tensor: the batch mean of the true-class distances.

    Raises:
        ValueError: If ``distance`` is not one of the supported names.
    """
    # Prototypes may arrive as NumPy arrays of another float dtype; align them
    # with the embeddings before any arithmetic.
    prototypes = tf.cast(prototypes, embedding.dtype)
    emb = embedding[:, None, :]      # (batch, 1, dim)
    protos = prototypes[None, :, :]  # (1, num_classes, dim)

    if distance == "l2":
        distances = tf.norm(emb - protos, axis=-1)
    elif distance == "l1":
        distances = tf.reduce_sum(tf.abs(emb - protos), axis=-1)
    elif distance == "cosine":
        # cosine_distance L2-normalises both sides itself, so the prototypes
        # need no extra normalisation here.
        distances = cosine_distance(emb, protos)
    else:
        raise ValueError(f"Unknown distance {distance!r}; expected 'l2', 'l1' or 'cosine'.")

    class_indices = tf.argmax(labels, axis=-1)
    # Select distances[i, class_indices[i]]. tf.gather with batch_dims=1 accepts
    # the int64 indices from argmax. The previous tf.stack of an int32 tf.range
    # with these int64 indices failed with a dtype mismatch.
    true_distances = tf.gather(distances, class_indices, batch_dims=1)
    return tf.reduce_mean(true_distances)

def make_custom_loss(prototypes, alpha=0.5):
    """Build a Keras loss ``alpha * cross_entropy + (1 - alpha) * prototype_loss``.

    The prototypes are captured in a closure because they stay fixed for a
    whole task/episode; build a new loss whenever the prototypes change.

    Args:
        prototypes: Per-class prototypes, passed through to ``prototype_loss``.
        alpha: Weight of the categorical cross-entropy term; ``1 - alpha``
            weights the prototype term.

    Returns:
        A ``loss(y_true, y_pred)`` callable suitable for ``model.compile``.
    """
    def loss(y_true, y_pred):
        weighted_ce = alpha * tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        weighted_proto = (1 - alpha) * prototype_loss(y_pred, y_true, prototypes)
        return weighted_ce + weighted_proto

    return loss


In [None]:
# MetaLearner configuration:
#   build_model_fn - model builder; MetaLearner.build_model calls it as
#                    build_model_fn(num_classes, input_shape)
#   custom_loss_fn - prototype-aware loss generator
#   num_classes    - e.g. 2
#   input_shape    - e.g. (63,) or (3, 21, 1)
#   meta_iters     - number of meta-training loops
#   meta_step_size - Reptile meta step
#   alpha          - balances classification vs prototype loss
# NOTE(review): build_model, num_classes and X_train must come from earlier
# cells that are not shown here; confirm they exist on a fresh kernel.
meta = MetaLearner(
    build_model_fn=build_model,
    custom_loss_fn=make_custom_loss,
    num_classes=num_classes,
    input_shape=X_train.shape[1:],
    meta_iters=100,
    meta_step_size=0.3,
    alpha=0.7,
)

In [None]:
# Episodic meta-training:
#   X_train / y_train - full training set, labels one-hot encoded
#   X_test / y_test   - meta-validation set
#   get_data_fn       - K-shot episodic data function
#   K_shot            - few-shot setting (5-shot here)
# NOTE(review): get_data, X_train, y_train, X_test and y_test are defined in
# cells not shown here; confirm they exist on a fresh kernel.
accuracies = meta.train(
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    get_data_fn=get_data,
    K_shot=5,
)

In [None]:
# Best per-episode validation accuracy across meta-iterations. Iterations are
# numbered from 1 so the x-axis matches the "Meta Iter k/N" training log.
fig, ax = plt.subplots()
meta_iterations = range(1, len(accuracies) + 1)
ax.plot(meta_iterations, accuracies)
ax.set(
    xlabel="Meta Iteration",
    ylabel="Validation Accuracy",
    title="Reptile Meta-Learning Accuracy",
)
ax.grid(True)
plt.show()