In [None]:
#############################################################
# Training and Evaluation
#############################################################

class MAMLTrainer:
    """Drive meta-training, early stopping, evaluation and plotting for a MAML model.

    Based on how it is used below, ``maml_model`` must provide
    ``train_on_batch(tasks)``, ``evaluate(tasks) -> (accuracy, loss)``,
    ``adapt_to_task(support_X, support_y, num_inner_updates=...)``,
    ``save_meta_model(path)``, a ``model`` attribute with
    ``get_weights()`` / ``set_weights()``, and a ``meta_weights`` attribute.

    Args:
        maml_model: The meta-learner to train.
        train_tasks: Indexable collection of training tasks; each
            meta-iteration samples distinct tasks from it.
        val_tasks: Tasks used for periodic validation and early stopping.
        test_tasks: Optional tasks evaluated once after training finishes.
        meta_epochs: Maximum number of meta-iterations (one meta-batch each).
        meta_batch_size: Tasks per meta-batch (capped at ``len(train_tasks)``).
        eval_interval: Validate after every ``eval_interval`` meta-iterations.
        early_stopping_patience: Consecutive validations without an accuracy
            improvement before training stops.
        log_dir: Directory for the best checkpoint and saved figures; created
            if it does not exist.
    """

    def __init__(
        self,
        maml_model,
        train_tasks,
        val_tasks,
        test_tasks=None,
        meta_epochs=10000,
        meta_batch_size=16,
        eval_interval=100,
        early_stopping_patience=10,
        log_dir='logs'
    ):
        self.maml_model = maml_model
        self.train_tasks = train_tasks
        self.val_tasks = val_tasks
        self.test_tasks = test_tasks
        self.meta_epochs = meta_epochs
        self.meta_batch_size = meta_batch_size
        self.eval_interval = eval_interval
        self.early_stopping_patience = early_stopping_patience
        self.log_dir = log_dir

        # Create log directory if it doesn't exist
        os.makedirs(log_dir, exist_ok=True)

        # Training history: 'train_loss' has one entry per meta-iteration;
        # the 'val_*' lists have one entry per validation.
        self.history = {
            'train_loss': [],
            'val_accuracy': [],
            'val_loss': [],
            # Index into 'train_loss' at which each validation ran, so the
            # validation curves can share the training-loss x-axis.
            'val_iteration': [],
        }

        # Early-stopping state. Starting at -inf (not 0) guarantees the first
        # validation is always recorded as the best, so best_weights and the
        # checkpoint exist even if the first accuracies are exactly 0.
        self.best_val_accuracy = float('-inf')
        self.patience_counter = 0
        self.best_weights = None

    def train(self):
        """Run meta-training with periodic validation and early stopping.

        Each meta-iteration samples distinct training tasks (without
        replacement) and calls ``maml_model.train_on_batch``. Every
        ``eval_interval`` iterations the model is validated. An improvement in
        validation accuracy resets the patience counter, snapshots the
        weights, and saves ``<log_dir>/best_model.h5``. Training stops once
        ``early_stopping_patience`` validations pass without improvement.
        At the end, the best weights (if any) are restored and, when
        ``test_tasks`` is set, evaluated once on the test tasks.

        Returns:
            dict: ``self.history``, with per-iteration ``'train_loss'`` and
            per-validation ``'val_accuracy'``, ``'val_loss'`` and
            ``'val_iteration'`` lists.
        """
        print("Starting meta-training...")

        start_time = time.time()
        # Loop-invariant: never request more distinct tasks than exist.
        batch_size = min(self.meta_batch_size, len(self.train_tasks))

        for epoch in range(self.meta_epochs):
            # Sample batch of tasks
            batch_indices = np.random.choice(
                len(self.train_tasks),
                batch_size,
                replace=False
            )
            batch_of_tasks = [self.train_tasks[i] for i in batch_indices]

            # Train on batch of tasks
            loss = self.maml_model.train_on_batch(batch_of_tasks)
            self.history['train_loss'].append(loss)

            # Evaluate periodically
            if (epoch + 1) % self.eval_interval == 0:
                val_accuracy, val_loss = self.maml_model.evaluate(self.val_tasks)
                self.history['val_accuracy'].append(val_accuracy)
                self.history['val_loss'].append(val_loss)
                # setdefault keeps this working if history was replaced by a
                # dict that lacks the key.
                self.history.setdefault('val_iteration', []).append(
                    len(self.history['train_loss']) - 1
                )

                elapsed_time = time.time() - start_time
                print(f"Epoch {epoch+1}/{self.meta_epochs} - "
                      f"Loss: {loss:.4f} - "
                      f"Val Accuracy: {val_accuracy:.4f} - "
                      f"Val Loss: {val_loss:.4f} - "
                      f"Time: {elapsed_time:.2f}s")

                # Check for early stopping
                if val_accuracy > self.best_val_accuracy:
                    self.best_val_accuracy = val_accuracy
                    self.patience_counter = 0
                    self.best_weights = self.maml_model.model.get_weights()
                    # Save best model
                    self.maml_model.save_meta_model(os.path.join(self.log_dir, 'best_model.h5'))
                else:
                    self.patience_counter += 1

                if self.patience_counter >= self.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        # Restore best weights
        if self.best_weights is not None:
            self.maml_model.model.set_weights(self.best_weights)
            self.maml_model.meta_weights = self.best_weights

        print(f"Meta-training completed in {time.time() - start_time:.2f}s")

        # Final evaluation on test set if available
        if self.test_tasks is not None:
            test_accuracy, test_loss = self.maml_model.evaluate(self.test_tasks)
            print(f"Test Accuracy: {test_accuracy:.4f} - Test Loss: {test_loss:.4f}")

        return self.history

    def _val_iterations(self):
        """Return the x-positions (indices into 'train_loss') of validation points.

        Uses the indices recorded during training when they are present and
        consistent. Otherwise it assumes the fixed cadence used by ``train``:
        validation after every ``eval_interval``-th iteration, i.e. at
        indices ``eval_interval - 1``, ``2 * eval_interval - 1``, ...

        Returns:
            numpy.ndarray: One x-position per entry in ``history['val_loss']``.
        """
        n_val = len(self.history['val_loss'])
        recorded = self.history.get('val_iteration')
        if recorded is not None and len(recorded) == n_val:
            return np.asarray(recorded)
        return self.eval_interval * np.arange(1, n_val + 1) - 1

    def visualize_training(self):
        """Plot training/validation loss and validation accuracy.

        The figure is saved to ``<log_dir>/training_history.png`` and then shown.
        """
        # Validation points share the training-loss x-axis (meta-iteration index).
        val_x = self._val_iterations()

        plt.figure(figsize=(15, 5))

        # Plot training loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='Training Loss')
        plt.plot(val_x, self.history['val_loss'], 'r-', label='Validation Loss')
        plt.title('Meta-Learning Loss')
        plt.xlabel('Meta-Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        # Plot validation accuracy
        plt.subplot(1, 2, 2)
        plt.plot(val_x, self.history['val_accuracy'], 'g-', label='Validation Accuracy')
        plt.title('Few-Shot Classification Accuracy')
        plt.xlabel('Meta-Iterations')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'training_history.png'), dpi=300)
        plt.show()

    def plot_adaptation_curve(self, task, updates_range=(0, 1, 2, 3, 5, 10)):
        """Plot query-set accuracy against the number of inner-loop updates.

        Args:
            task: Mapping with 'support_X', 'support_y', 'query_X' and
                'query_y' entries. query_y is assumed to hold integer class
                labels (not one-hot); TODO confirm against the task builder.
            updates_range: Iterable of inner-update counts to try. It is a
                tuple to avoid a shared mutable default; lists are also accepted.

        Returns:
            list: Query accuracy for each entry of ``updates_range``.
                The figure is saved to ``<log_dir>/adaptation_curve.png``.
        """
        support_X = task['support_X']
        support_y = task['support_y']
        query_X = task['query_X']
        # Flatten so that (N, 1)-shaped labels compare element-wise with the
        # (N,) predictions instead of broadcasting to an N x N matrix.
        query_y = np.ravel(task['query_y'])

        accuracies = []

        for updates in updates_range:
            # Adapt model to task with different number of gradient updates
            adapted_model = self.maml_model.adapt_to_task(support_X, support_y, num_inner_updates=updates)

            # Evaluate on query set
            query_logits = adapted_model(query_X, training=False)
            pred_y = tf.argmax(query_logits, axis=1).numpy()
            accuracy = np.mean(pred_y == query_y)
            accuracies.append(accuracy)

        plt.figure(figsize=(10, 6))
        plt.plot(list(updates_range), accuracies, 'o-', linewidth=2)
        plt.title('Adaptation Performance vs Gradient Steps')
        plt.xlabel('Number of Gradient Updates')
        plt.ylabel('Query Set Accuracy')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.savefig(os.path.join(self.log_dir, 'adaptation_curve.png'), dpi=300)
        plt.show()

        return accuracies
