In [None]:
```json
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# GEM-based Named Entity Recognition Demo\n",
        "\n",
        "This notebook demonstrates the GEM-based NER model for continual learning. We load sample data, train sequentially on multiple domains, evaluate performance and catastrophic forgetting, visualize the training metrics, and run predictions on new text."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "# Imports: stdlib, then third-party, then project-local modules.\n",
        "import os\n",
        "import sys\n",
        "from pathlib import Path\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import yaml\n",
        "\n",
        "# Make the project root importable so `src.*` resolves. This assumes the\n",
        "# notebook's working directory is one level below the project root.\n",
        "PROJECT_ROOT = os.path.abspath('..')\n",
        "if PROJECT_ROOT not in sys.path:  # guard keeps the cell idempotent on re-run\n",
        "    sys.path.append(PROJECT_ROOT)\n",
        "\n",
        "from src.models.gem_ner import GEMNERAnalyzer, NERExample\n",
        "from src.utils.data_loader import load_ner_data\n",
        "\n",
        "%matplotlib inline"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load Configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "config_path = Path('../configs/gem_config.yaml')\n",
        "with config_path.open('r', encoding='utf-8') as f:\n",
        "    config = yaml.safe_load(f)\n",
        "\n",
        "# safe_load returns None for an empty file; fail here rather than with a\n",
        "# confusing TypeError in a later cell.\n",
        "if not isinstance(config, dict):\n",
        "    raise ValueError(f'{config_path} must contain a YAML mapping, got {type(config).__name__}')\n",
        "\n",
        "print('Configuration:')\n",
        "print(yaml.safe_dump(config, default_flow_style=False))"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Initialize Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "# Fail early with every missing key listed, instead of one KeyError at a time.\n",
        "MODEL_CONFIG_KEYS = ['model_name', 'memory_size', 'learning_rate', 'batch_size',\n",
        "                     'max_length', 'device', 'save_dir', 'verbose']\n",
        "missing_keys = [key for key in MODEL_CONFIG_KEYS if key not in config]\n",
        "if missing_keys:\n",
        "    raise KeyError(f'Config is missing required model keys: {missing_keys}')\n",
        "\n",
        "model = GEMNERAnalyzer(\n",
        "    model_name=config['model_name'],\n",
        "    memory_size=config['memory_size'],\n",
        "    learning_rate=config['learning_rate'],\n",
        "    batch_size=config['batch_size'],\n",
        "    max_length=config['max_length'],\n",
        "    device=config['device'],\n",
        "    save_dir=config['save_dir'],\n",
        "    verbose=config['verbose']\n",
        ")"
      ],
      "outputs": []
    },
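    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load Sample Data\n",
        "\n",
        "Load the train and test splits for the first two domains listed in the config (e.g. 'news' and 'medical'). Splits whose files are missing are not loaded."
      ]
    },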
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "data_dir = Path('../datasets/ner')\n",
        "N_DEMO_DOMAINS = 2  # use only the first few configured domains to keep the demo short\n",
        "domains = config['domains'][:N_DEMO_DOMAINS]\n",
        "train_datasets = {}\n",
        "test_datasets = {}\n",
        "\n",
        "# task_id is the position of the domain in `domains`; the training cell\n",
        "# derives it the same way via enumerate, so do not filter `domains` here.\n",
        "for task_id, domain in enumerate(domains):\n",
        "    train_file = data_dir / domain / 'train.json'\n",
        "    test_file = data_dir / domain / 'test.json'\n",
        "    if train_file.exists():\n",
        "        train_datasets[domain] = load_ner_data(str(train_file), task_id, domain)\n",
        "    else:\n",
        "        print(f'Warning: no training data for {domain} at {train_file}')\n",
        "    if test_file.exists():\n",
        "        test_datasets[domain] = load_ner_data(str(test_file), task_id, domain)\n",
        "    else:\n",
        "        print(f'Warning: no test data for {domain} at {test_file}')\n",
        "\n",
        "if not train_datasets:\n",
        "    raise FileNotFoundError(f'No training data found under {data_dir.resolve()} for domains {domains}')\n",
        "\n",
        "print('Loaded domains:', list(train_datasets.keys()))"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Train Model\n",
        "\n",
        "Train on each domain in sequence, one continual-learning task per domain, saving a checkpoint after each task."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "metrics_history = {}\n",
        "for task_id, domain in enumerate(domains):\n",
        "    # Keep task_id as the index into `domains` so it matches the task_id used\n",
        "    # when the data was loaded, even if an earlier domain is skipped.\n",
        "    if domain not in train_datasets:\n",
        "        print(f'Skipping {domain} (Task {task_id}): no training data loaded')\n",
        "        continue\n",
        "    print(f'Training on {domain} (Task {task_id})')\n",
        "    metrics = model.train_task(\n",
        "        train_data=train_datasets[domain],\n",
        "        task_id=task_id,\n",
        "        domain=domain,\n",
        "        epochs=config['epochs'],\n",
        "        validation_data=test_datasets.get(domain)  # None when no test split was loaded\n",
        "    )\n",
        "    metrics_history[domain] = metrics\n",
        "    checkpoint_path = model.save_checkpoint(f'gem_ner_{domain}_task_{task_id}.pt')\n",
        "    print(f'Checkpoint saved: {checkpoint_path}')"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Evaluate Model\n",
        "\n",
        "After training on all domains, evaluate on every domain's test set; lower scores on earlier tasks indicate catastrophic forgetting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "# NOTE(review): assumes evaluate_all_tasks returns {task_id: metrics}; verify against the model API.\n",
        "results = model.evaluate_all_tasks(test_datasets)\n",
        "for task_id, metrics in results.items():\n",
        "    domain = model.task_id_to_domain.get(task_id, 'unknown')\n",
        "    print(f'Evaluation on {domain} (Task {task_id}):')\n",
        "    print(f'  Accuracy: {metrics.accuracy:.4f}')\n",
        "    print(f'  F1 Score: {metrics.f1_score:.4f}')\n",
        "    print(f'  Loss: {metrics.loss:.4f}')"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Analyze Forgetting\n",
        "\n",
        "Measure catastrophic forgetting across domains."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "# NOTE(review): score semantics come from analyze_forgetting; presumably higher means more forgetting. Confirm.\n",
        "forgetting_scores = model.analyze_forgetting(test_datasets)\n",
        "print('Forgetting Scores:')\n",
        "for domain, score in forgetting_scores.items():\n",
        "    print(f'  {domain}: {score:.4f}')"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Visualize Results\n",
        "\n",
        "Plot the per-epoch accuracy and F1 score recorded while training on each domain, and save the figure under `../results/plots/`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "# savefig does not create missing directories, so make sure the target exists.\n",
        "plots_dir = Path('../results/plots')\n",
        "plots_dir.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
        "panels = [('accuracy', 'Accuracy'), ('f1_score', 'F1 Score')]  # (metric attribute, axis label)\n",
        "for ax, (attr, label) in zip(axes, panels):\n",
        "    for domain, epoch_metrics in metrics_history.items():\n",
        "        epochs = [m.epoch + 1 for m in epoch_metrics]  # +1 so the axis starts at epoch 1 (assumes m.epoch is 0-based)\n",
        "        values = [getattr(m, attr) for m in epoch_metrics]\n",
        "        ax.plot(epochs, values, label=f'{domain} {label}')\n",
        "    ax.set(xlabel='Epoch', ylabel=label, title=f'Training {label} per Domain')\n",
        "    if metrics_history:  # avoid matplotlib's 'No artists with labels' warning on an empty plot\n",
        "        ax.legend()\n",
        "\n",
        "fig.tight_layout()\n",
        "fig.savefig(plots_dir / 'gem_ner_training_metrics.png')\n",
        "plt.show()"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Predict on New Text\n",
        "\n",
        "Run the model on one sample sentence per domain."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "sample_texts = {\n",
        "    'news': 'Apple Inc. announced a new product launch in New York.',\n",
        "    'medical': 'The patient was diagnosed with diabetes mellitus.'\n",
        "}\n",
        "\n",
        "for domain, text in sample_texts.items():\n",
        "    # Do not fall back to task 0 for a domain the model was never trained on:\n",
        "    # that would silently run the prediction under another domain's task id.\n",
        "    task_id = model.domain_to_task_id.get(domain)\n",
        "    if task_id is None:\n",
        "        print(f'Skipping {domain}: the model was not trained on this domain')\n",
        "        continue\n",
        "    entities = model.predict_single(text, task_id, return_labels=True)\n",
        "    print(f'Predictions for {domain}:')\n",
        "    print(f'  Text: {text}')\n",
        "    print(f'  Entities: {entities}')"
      ],
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Summary\n",
        "\n",
        "Print the best accuracy, best F1 score, task ID, and label set reported for each domain."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "source": [
        "summary = model.get_domain_summary()\n",
        "print('Domain Summary:')\n",
        "for domain, perf in summary['domain_performance'].items():\n",
        "    print(f'  {domain}:')\n",
        "    print(f'    Best Accuracy: {perf[\"best_accuracy\"]:.4f}')\n",
        "    print(f'    Best F1: {perf[\"best_f1\"]:.4f}')\n",
        "    print(f'    Task ID: {perf[\"task_id\"]}')\n",
        "    print(f'    Labels: {perf[\"labels\"]}')"
      ],
      "outputs": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
```