Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8b8b9f1
commit 1d0c575
Showing
1 changed file
with
199 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,169 +1,200 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "VAMPIRE_AGNews.ipynb", | ||
"version": "0.3.2", | ||
"provenance": [], | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"accelerator": "GPU" | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/allenai/vampire/blob/colab/colab/VAMPIRE_AGNews.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "wEFivVCCpEUn", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!git clone https://github.com/allenai/vampire\n" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "rMDq7NeCrHB3", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"%cd vampire" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "f52p11vJpJW8", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!pip install -r requirements.txt" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "JXC-bTeIptAF", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!python -m spacy download en" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "mTii7MynrbLn", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!SEED=42 python -m pytest -v --color=yes vampire" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "_pHcbntzqHuV", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!sh scripts/download_ag.sh" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "zSVpbMX1rQl7", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!curl -Lo ag.tar https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/vampire_preprocessed_example.tar\n", | ||
"!tar -xvf ag.tar -C examples/\n", | ||
"!rm ag.tar" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "d_oLBgaYtKsL", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!DATA_DIR=\"$(pwd)/examples/ag\" VOCAB_SIZE=30000 LAZY=1 python -m scripts.train \\\n", | ||
" --config training_config/vampire.jsonnet \\\n", | ||
" --serialization-dir model_logs/vampire \\\n", | ||
" --environment VAMPIRE \\\n", | ||
" --device 0 -o" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "D0T6rYRLtTU_", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"!DATA_DIR=\"$(pwd)/examples/ag\" VAMPIRE_DIR=\"$(pwd)/model_logs/vampire\" VAMPIRE_DIM=81 THROTTLE=200 EVALUATE_ON_TEST=0 python -m scripts.train \\\n", | ||
" --config training_config/classifier.jsonnet \\\n", | ||
" --serialization-dir model_logs/clf \\\n", | ||
" --environment CLASSIFIER \\\n", | ||
" --device 0" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "5H6K6A_S8fZs", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"" | ||
], | ||
"execution_count": 0, | ||
"outputs": [] | ||
} | ||
] | ||
} | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text", | ||
"id": "view-in-github" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/allenai/vampire/blob/colab/colab/VAMPIRE_AGNews.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### VAMPIRE Example: AG News Corpus" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In this notebook, we run through the example in the README. Since VAMPIRE is a low resource method, it can be run on the CPU or Colab GPU. Before starting, open this notebook in a Colab environment, and connect to a GPU instance." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Clone the repository and cd into working directory:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "wEFivVCCpEUn" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!git clone https://github.com/allenai/vampire\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "rMDq7NeCrHB3" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%cd vampire" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Install requirements:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "f52p11vJpJW8" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install -r requirements.txt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "JXC-bTeIptAF" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!python -m spacy download en" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"All tests should pass:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "mTii7MynrbLn" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!SEED=42 python -m pytest -v --color=yes vampire" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Download preprocessed AG News corpus, ready to run with VAMPIRE:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "zSVpbMX1rQl7" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!curl -Lo ag.tar https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/vampire_preprocessed_example.tar\n", | ||
"!tar -xvf ag.tar -C examples/\n", | ||
"!rm ag.tar" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Run VAMPIRE:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "d_oLBgaYtKsL" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!DATA_DIR=\"$(pwd)/examples/ag\" VOCAB_SIZE=30000 LAZY=1 python -m scripts.train \\\n", | ||
" --config training_config/vampire.jsonnet \\\n", | ||
" --serialization-dir model_logs/vampire \\\n", | ||
" --environment VAMPIRE \\\n", | ||
" --device 0 -o" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"After VAMPIRE has trained, we can run a downstream classifier on the AG News corpus with just 200 examples. We report an average of 83.9% accuracy in the paper over five seeds under this setting:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab": {}, | ||
"colab_type": "code", | ||
"id": "D0T6rYRLtTU_" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!DATA_DIR=\"$(pwd)/examples/ag\" VAMPIRE_DIR=\"$(pwd)/model_logs/vampire\" VAMPIRE_DIM=81 THROTTLE=200 EVALUATE_ON_TEST=0 python -m scripts.train \\\n", | ||
" --config training_config/classifier.jsonnet \\\n", | ||
" --serialization-dir model_logs/clf \\\n", | ||
" --environment CLASSIFIER \\\n", | ||
" --device 0" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"include_colab_link": true, | ||
"name": "VAMPIRE_AGNews.ipynb", | ||
"provenance": [], | ||
"version": "0.3.2" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |