# Training your own model

This notebook will walk you through training your own model using [seq2rel](https://github.com/JohnGiorgi/seq2rel).

## 🔧 Install the prerequisites

In [None]:
# The colab environment comes with py3.7, but several dependencies require py>=3.8 (like NumPy).
# This can be removed if Colab ever updates python to >=3.8 in its environment.
# For the solution, see: https://stackoverflow.com/q/60775160/6578628
# For the issue tracking Colab's python update, see: https://github.com/googlecolab/colabtools/issues/1880
!wget -O mini.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_4.8.2-Linux-x86_64.sh
!chmod +x mini.sh
!bash ./mini.sh -b -f -p /usr/local
!pip install ipykernel

In [None]:
!pip install git+https://github.com/JohnGiorgi/seq2rel.git
!pip install git+https://github.com/JohnGiorgi/seq2rel-ds.git

## 📖 Preparing a dataset

Datasets are tab-separated files, where each example is contained on its own line. The first column contains the text, and the second column contains the relations. Relations themselves must be serialized to strings.

Take the following example, which expresses a _gene-disease association_ (`"@GDA@"`) between _ESR1_ (`"@GENE@"`) and _schizophrenia_ (`"@DISEASE@"`):

```
Variants in the estrogen receptor alpha (ESR1) gene and its mRNA contribute to risk for schizophrenia. estrogen receptor alpha ; ESR1 @GENE@ schizophrenia @DISEASE@ @GDA@
```
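
As a rough illustration of this format, the sketch below builds the example line by hand. The helper names (`serialize_entity`, `serialize_relation`, `to_tsv_line`) are hypothetical and are not part of seq2rel or seq2rel-ds; the layout (coreferent mentions joined by ` ; `, followed by the entity type token, with the relation type token last) is inferred only from the example above.

```python
def serialize_entity(mentions, entity_type):
    """Join coreferent mentions with ' ; ' and append the entity type token."""
    return f"{' ; '.join(mentions)} @{entity_type}@"


def serialize_relation(entities, relation_type):
    """Serialize (mentions, entity_type) pairs, then append the relation type token."""
    serialized = " ".join(serialize_entity(m, t) for m, t in entities)
    return f"{serialized} @{relation_type}@"


def to_tsv_line(text, relations):
    """Build one dataset line: text, a tab, then the serialized relations."""
    if "\t" in text or "\n" in text:
        raise ValueError("text must not contain tabs or newlines")
    return f"{text}\t{' '.join(relations)}"


text = (
    "Variants in the estrogen receptor alpha (ESR1) gene and its mRNA "
    "contribute to risk for schizophrenia."
)
relation = serialize_relation(
    [(["estrogen receptor alpha", "ESR1"], "GENE"), (["schizophrenia"], "DISEASE")],
    "GDA",
)
print(to_tsv_line(text, [relation]))
```

For supported corpora such as GDA, seq2rel-ds produces these files for you, so a helper like this is only relevant when preparing your own data.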

For convenience, we provide a second package, [seq2rel-ds](https://github.com/JohnGiorgi/seq2rel-ds), which makes it easy to generate data in this format for various popular corpora. In this tutorial, we will preprocess and train on the [GDA corpus](https://www.researchgate.net/publication/332411712_RENET_A_Deep_Learning_Approach_for_Extracting_Gene-Disease_Associations_from_Literature).

> See [our paper](https://aclanthology.org/2022.bionlp-1.2/) for more details on serializing relations.

In [3]:
preprocessed_datadir = "gda"

# Here, we hold out only 1% of training data so that validation is quick.
# In the paper, we hold out 20%, which is the default value for --valid-size.
!seq2rel-ds gda main "$preprocessed_datadir" --valid-size 0.01

✔ Downloaded the corpus.
✔ Preprocessed the training data.
✔ Preprocessed the test data.
ℹ Holding out 1.00% of the training data as a validation set.
✔ Preprocessed data saved to /content/gda.


Let's confirm that our dataset looks as expected.

In [4]:
!ls "$preprocessed_datadir"  # This directory should contain three files, train.tsv, valid.tsv, and test.tsv

test.tsv  train.tsv  valid.tsv


In [5]:
!wc -l "$preprocessed_datadir/train.tsv"  # This file should contain 28899 lines

28899 gda/train.tsv


In [6]:
!head -n 1 "$preprocessed_datadir/train.tsv"  # This should be a single tab-separated example

The fractalkine receptor CX3CR1 is involved in liver fibrosis due to chronic hepatitis C infection. BACKGROUND/AIMS: The chemokine receptor CX3CR1 and its specific ligand fractalkine (CX3CL1) are known to modulate inflammatory and fibroproliferative diseases. Here we investigate the role of CX3CR1/fractalkine in HCV-induced liver fibrosis. METHODS: A genotype analysis of CX3CR1 variants was performed in 211 HCV-infected patients. Hepatic expression of CX3CR1 was studied in HCV-infected livers and isolated liver cell populations by RT-PCR and immunohistochemistry. The effects of fractalkine on mRNA expression of profibrogenic genes were determined in isolated hepatic stellate cells (HSC) and CX3CR1 genotypes were related to intrahepatic TIMP-1 mRNA levels. RESULTS: The intrahepatic mRNA expression of CX3CR1 correlates with the stage of HCV-induced liver fibrosis (P=0.03). The CX3CR1 coding variant V249I is associated with advanced liver fibrosis, independent of the T280M variant (P=0.00
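
If you would rather check the whole file than eyeball the first line, a small sanity check along these lines may help. This is a minimal sketch, assuming only what is described above: one example per line and exactly two tab-separated columns. The function name `check_tsv` is hypothetical, and the check deliberately allows an empty relations column, since whether that occurs depends on the corpus.

```python
from pathlib import Path


def check_tsv(path):
    """Return (num_examples, bad_line_numbers) for a two-column TSV file."""
    num_examples = 0
    bad_lines = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.rstrip("\n")
            if not line:
                continue  # Ignore blank lines, e.g. at the end of the file
            num_examples += 1
            columns = line.split("\t")
            # Expect exactly two columns and a non-empty text column
            if len(columns) != 2 or not columns[0].strip():
                bad_lines.append(lineno)
    return num_examples, bad_lines


# Example usage in this notebook:
# num_examples, bad_lines = check_tsv(Path(preprocessed_datadir) / "train.tsv")
```

An empty `bad_lines` list means every non-blank line has the expected two-column shape.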

## 🏃 Training the model

Once you have collected the dataset, you can initiate a training session with the [`allennlp train`](https://docs.allennlp.org/main/api/commands/train/) command. An experiment is configured using a [Jsonnet](https://jsonnet.org/) config file. Let's take a look at the config used to train the model on the [GDA corpus](https://link.springer.com/chapter/10.1007/978-3-030-17083-7_17) used in [our paper](https://aclanthology.org/2022.bionlp-1.2/):

In [7]:
config_filepath = "gda.jsonnet"
!wget -nc https://raw.githubusercontent.com/JohnGiorgi/seq2rel/main/training_config/gda.jsonnet -O {config_filepath}
with open(config_filepath, "r") as f:
    print(f.read())

--2022-04-15 16:10:29--  https://raw.githubusercontent.com/JohnGiorgi/seq2rel/main/training_config/gda.jsonnet
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7940 (7.8K) [text/plain]
Saving to: ‘gda.jsonnet’


2022-04-15 16:10:29 (83.1 MB/s) - ‘gda.jsonnet’ saved [7940/7940]


// The pretrained model to use as encoder. This is a reasonable default for biomedical text.
// Should be a registered name in the Transformers library (see https://huggingface.co/models) 
// OR a path on disk to a serialized transformer model.
local model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext";

// These are reasonable defaults.
local max_length = 512;       // Max length of input text
local max_steps = 96;         // Max number of decoding steps

loca

The only additional information we need to provide is the path to the train and validation sets and the size of the dataset, which is needed to set the learning rate warmup ratio. We can do this with the environment variables `train_data_path`, `valid_data_path` and `dataset_size`:
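
For intuition on why the dataset size matters here: a warmup schedule is usually expressed as a fraction of the total number of optimizer steps, and that total depends on how many examples there are. The sketch below shows this arithmetic only. The batch size, number of epochs, and warmup ratio are placeholder assumptions, not values read from `gda.jsonnet`, and the config may compute its schedule differently.

```python
import math


def warmup_steps(dataset_size, batch_size, num_epochs, warmup_ratio):
    """Return (total_steps, warmup_steps) for a simple linear-warmup schedule."""
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    total_steps = steps_per_epoch * num_epochs
    return total_steps, round(total_steps * warmup_ratio)


# Placeholder values for illustration only (not taken from the config)
total, warmup = warmup_steps(
    dataset_size=28899, batch_size=4, num_epochs=30, warmup_ratio=0.1
)
print(f"total steps: {total}, warmup steps: {warmup}")
```

If the dataset size were wrong, the warmup phase would end too early or too late relative to the real length of training.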

In [8]:
from pathlib import Path

train_data_path = Path(preprocessed_datadir) / "train.tsv"
valid_data_path = Path(preprocessed_datadir) / "valid.tsv"
# Get the number of examples in the train set
dataset_size = len(train_data_path.read_text().strip().split("\n"))

Because training the model on the entire dataset takes ~12hrs, we will train on a fraction of the dataset. We can modify the parameters in the config file via the `--overrides` argument (but you can also modify them in your config file directly, if you prefer):

> Note: Trained on so few examples, this model will not converge and will obtain an F1-score close to zero. To actually train a model to good performance in a reasonable amount of time, you will likely need more powerful hardware than Colab provides (we used a V100-32GB in our paper) and _at least_ a few hundred training examples.

In [11]:
import json

overrides = {
    # validate after the last epoch only
    "trainer.callbacks.0.validation_start": 29,
    # load only a fraction of the train set
    "dataset_reader.max_instances": 16
}
# Necessary to pass to the allennlp train command without errors
overrides = json.dumps(overrides).replace('"', "'")
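
To see what is actually handed to `--overrides`, the standalone sketch below prints the string before and after the quote replacement. The likely reason for the replacement (an assumption on our part, based on how the command in the next cell is written) is that the value is interpolated inside double quotes on the command line, so embedded double quotes would end the argument early; Jsonnet accepts single-quoted strings, so the swapped version still parses.

```python
import json

overrides = {
    "trainer.callbacks.0.validation_start": 29,
    "dataset_reader.max_instances": 16,
}

as_json = json.dumps(overrides)
as_cli = as_json.replace('"', "'")

print(as_json)  # Standard JSON with double-quoted keys
print(as_cli)   # Same content with single quotes, safe inside "..." in a shell command
```

Note that a blanket replacement like this would also alter any double quotes inside string values, so it is only safe for simple overrides such as these.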

In [12]:
# This should train and evaluate in ~3min.
!train_data_path="$train_data_path" \
valid_data_path="$valid_data_path" \
dataset_size="$dataset_size" \
allennlp train "$config_filepath" \
    --serialization-dir "output" \
    --overrides "$overrides" \
    --include-package "seq2rel" \
    -f

2022-04-15 16:10:54,398 - INFO - allennlp.common.plugins - Plugin allennlp_models available
2022-04-15 16:10:55,611 - INFO - allennlp.common.params - evaluation = None
2022-04-15 16:10:55,611 - INFO - allennlp.common.params - include_in_archive = None
2022-04-15 16:10:55,611 - INFO - allennlp.common.params - random_seed = 13370
2022-04-15 16:10:55,611 - INFO - allennlp.common.params - numpy_seed = 1337
2022-04-15 16:10:55,612 - INFO - allennlp.common.params - pytorch_seed = 133
2022-04-15 16:10:55,612 - INFO - allennlp.common.checks - Pytorch version: 1.11.0+cu102
2022-04-15 16:10:55,612 - INFO - allennlp.common.params - type = default
2022-04-15 16:10:55,613 - INFO - allennlp.common.params - dataset_reader.type = seq2rel
2022-04-15 16:10:55,613 - INFO - allennlp.common.params - dataset_reader.max_instances = 16
2022-04-15 16:10:55,613 - INFO - allennlp.common.params - dataset_reader.manual_distributed_sharding = False
2022-04-15 16:10:55,613 - INFO - allennlp.common.params - dataset_r

The best model checkpoint (measured by micro-F1 score on the validation set), vocabulary, configuration, and log files will be saved to the directory given by `--serialization-dir` (here, `output`). This can be changed to any directory you like.
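
Once training finishes, a quick way to inspect the results is to list the serialization directory and read the metrics file. This is a minimal sketch that assumes AllenNLP has written a `metrics.json` file there, as it typically does at the end of a run; the helper name `summarize_run` is hypothetical, and the metric names you will see depend on the model.

```python
import json
from pathlib import Path


def summarize_run(serialization_dir):
    """Return (sorted file names, metrics dict) for a training output directory."""
    out = Path(serialization_dir)
    files = sorted(p.name for p in out.iterdir()) if out.is_dir() else []
    metrics_path = out / "metrics.json"
    # Fall back to an empty dict if training has not produced metrics yet
    metrics = json.loads(metrics_path.read_text()) if metrics_path.is_file() else {}
    return files, metrics


files, metrics = summarize_run("output")
print(files)
for name, value in sorted(metrics.items()):
    print(f"{name}: {value}")
```

If the directory does not exist yet, both the file list and the metrics come back empty rather than raising an error.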


## ♻️ Conclusion

That's it! In this notebook, we covered how to prepare a dataset for training the model. We then briefly covered configuring and running a training session. Please see [our paper](https://aclanthology.org/2022.bionlp-1.2/) and [repo](https://github.com/JohnGiorgi/seq2rel) for more details, and don't hesitate to open an issue if you have any trouble!

