From a159f6fbed90f7de28ea4535869ed8f63ec056cb Mon Sep 17 00:00:00 2001 From: LeonardoEmili Date: Sun, 21 Nov 2021 17:18:24 +0100 Subject: [PATCH] Update `WSDModel`, add `MFS` baseline, add unit testing --- conf/data/default_data.yaml | 65 ++++++-- conf/logging/wandb_logging.yaml | 10 ++ conf/model/default_model.yaml | 24 +++ conf/root.yaml | 8 +- conf/test/default_test.yaml | 3 + conf/train/default_train.yaml | 22 ++- pyproject.toml | 4 + requirements.txt | 7 +- setup.sh | 6 +- src/colab/settings.json | 14 ++ src/colab/setup.sh | 18 +++ src/colab/setup_colab.ipynb | 102 +++++++++++++ src/dataset.py | 212 +++++++++++++++++++------ src/layers/word_encoder.py | 77 ++++++++++ src/models/mfs.py | 26 ++++ src/models/wsd_model.py | 34 +++++ src/pl_data_modules.py | 100 +++++++++--- src/pl_modules.py | 81 +++++++--- src/readers/omsti_reader.py | 87 ----------- src/readers/raganato_reader.py | 190 +++++++++++++++++++++++ src/readers/semcor_reader.py | 68 --------- src/readers/wordnet_reader.py | 64 ++++++++ src/scripts/get-wsd-data.sh | 41 +++++ src/scripts/github_downloader.sh | 42 +++++ src/scripts/wordnet_extractor.py | 110 +++++++++++++ src/test.py | 44 ++++++ src/train.py | 44 ++++-- src/utils/torch_utilities.py | 238 +++++++++++++++++++++++++++++ src/utils/utilities.py | 196 ++++++++++++++++++++++++ tests/unit/datamodule_test.py | 35 +++++ tests/unit/dataset_test.py | 36 +++++ tests/unit/model_test.py | 95 ++++++++++++ tests/unit/raganato_reader_test.py | 55 +++++++ tests/unit/root.yaml | 123 +++++++++++++++ tests/unit/test_case.py | 9 ++ 35 files changed, 1999 insertions(+), 291 deletions(-) create mode 100644 conf/logging/wandb_logging.yaml create mode 100644 conf/test/default_test.yaml create mode 100644 pyproject.toml mode change 100755 => 100644 setup.sh create mode 100644 src/colab/settings.json create mode 100644 src/colab/setup.sh create mode 100644 src/colab/setup_colab.ipynb create mode 100644 src/layers/word_encoder.py create mode 100644 src/models/mfs.py create mode 100644 src/models/wsd_model.py delete mode 100644 src/readers/omsti_reader.py create mode 100644 src/readers/raganato_reader.py delete mode 100644 src/readers/semcor_reader.py create mode 100644 src/readers/wordnet_reader.py create mode 100755 src/scripts/get-wsd-data.sh create mode 100755 src/scripts/github_downloader.sh create mode 100644 src/scripts/wordnet_extractor.py create mode 100644 src/test.py create mode 100644 src/utils/torch_utilities.py create mode 100644 src/utils/utilities.py create mode 100644 tests/unit/datamodule_test.py create mode 100644 tests/unit/dataset_test.py create mode 100644 tests/unit/model_test.py create mode 100644 tests/unit/raganato_reader_test.py create mode 100644 tests/unit/root.yaml create mode 100644 tests/unit/test_case.py diff --git a/conf/data/default_data.yaml b/conf/data/default_data.yaml index b637615..0f48d3d 100644 --- a/conf/data/default_data.yaml +++ b/conf/data/default_data.yaml @@ -1,12 +1,53 @@ -train_path: 'data/train.tsv' -validation_path: 'data/validation.tsv' -test_path: 'data/test.tsv' - -train_ds: 'semcor' -semcor_data_path: 'data/WSD_Training_Corpora/SemCor/semcor.data.xml' -semcor_key_path: 'data/WSD_Training_Corpora/SemCor/semcor.gold.key.txt' -semcor_omsti_data_path: 'data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.data.xml' -semcor_omsti_key_path: 'data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.gold.key.txt' - -batch_size: 16 -num_workers: 1 +train_path: "data/train.tsv" +validation_path: "data/validation.tsv" +test_path: "data/test.tsv" + 
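+# Corpora used for the train/validation/test splits; each name is a key of
+# the "corpora" mapping defined below.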
+train_ds: "semcor" +val_ds: "semeval2007" +test_ds: "semeval2015" + +preprocessed_dir: "data/preprocessed/" +force_preprocessing: False +dump_preprocessed: True +use_synset_vocab: True + +wordnet: + glosses: "data/wordnet/means/glosses.json" + lemma_means: "data/wordnet/means/lemma_means.json" + lexeme_means: "data/wordnet/means/lexeme_means.json" + sense_means: "data/wordnet/means/sense_means.json" + +corpora: + semcor: + data_path: "data/WSD_Training_Corpora/SemCor/semcor.data.xml" + key_path: "data/WSD_Training_Corpora/SemCor/semcor.gold.key.txt" + semcor+omsti: + data_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.data.xml" + key_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.gold.key.txt" + omsti: + data_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.data.xml" + key_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.gold.key.txt" + semeval_all: + data_path: "data/WSD_Unified_Evaluation_Datasets/ALL/ALL.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/ALL/ALL.gold.key.txt" + semeval2007: + data_path: "data/WSD_Unified_Evaluation_Datasets/semeval2007/semeval2007.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt" + semeval2013: + data_path: "data/WSD_Unified_Evaluation_Datasets/semeval2013/semeval2013.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/semeval2013/semeval2013.gold.key.txt" + semeval2015: + data_path: "data/WSD_Unified_Evaluation_Datasets/semeval2015/semeval2015.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/semeval2015/semeval2015.gold.key.txt" + senseval2: + data_path: "data/WSD_Unified_Evaluation_Datasets/senseval2/senseval2.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/senseval2/senseval2.gold.key.txt" + senseval3: + data_path: "data/WSD_Unified_Evaluation_Datasets/senseval3/senseval3.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/senseval3/senseval3.gold.key.txt" + +batch_size: 32 +num_workers: 0 + +min_freq_senses: 1 +allow_multiple_senses: False diff --git a/conf/logging/wandb_logging.yaml b/conf/logging/wandb_logging.yaml new file mode 100644 index 0000000..39b7659 --- /dev/null +++ b/conf/logging/wandb_logging.yaml @@ -0,0 +1,10 @@ +log: False + +wandb_logger: + _target_: pytorch_lightning.loggers.WandbLogger + entity: LeonardoEmili + project: neural-wsd + +watch: + log: 'all' + log_freq: 100 diff --git a/conf/model/default_model.yaml b/conf/model/default_model.yaml index 205ed1a..26af091 100644 --- a/conf/model/default_model.yaml +++ b/conf/model/default_model.yaml @@ -1 +1,25 @@ tokenizer: 'bert-base-cased' +model_name: 'bert-base-cased' +learning_rate: 1e-3 +min_learning_rate: 1e-4 +language_model_learning_rate: 1e-5 +language_model_min_learning_rate: 1e-6 +language_model_weight_decay: 1e-4 +use_lemma_mask: False +use_lexeme_mask: False + +word_encoder: + _target_: src.layers.word_encoder.WordEncoder + fine_tune: False + word_dropout: 0.2 + model_name: ${model.model_name} + +sequence_encoder: lstm +lstm_encoder: + _target_: torch.nn.LSTM + input_size: 512 + hidden_size: 256 + bidirectional: True + batch_first: True + num_layers: 2 + dropout: 0.40 \ No newline at end of file diff --git a/conf/root.yaml b/conf/root.yaml index e82565f..bedb0bb 100644 --- a/conf/root.yaml +++ b/conf/root.yaml @@ -1,9 +1,15 @@ # Required to make the "experiments" dir the default one for the output of the models hydra: run: - dir: ./experiments/${train.model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + dir: 
./experiments/${model.model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} + +# Debug mode +debug: False +max_samples: 1000 defaults: - train: default_train - model: default_model - data: default_data + - logging: wandb_logging + - test: default_test diff --git a/conf/test/default_test.yaml b/conf/test/default_test.yaml new file mode 100644 index 0000000..da7dd74 --- /dev/null +++ b/conf/test/default_test.yaml @@ -0,0 +1,3 @@ +checkpoint_path: +latest_checkpoint_path: experiments/bert-base-cased/2021-11-16/23-06-26/default_name/epoch=2-step=3485.ckpt +use_latest: false diff --git a/conf/train/default_train.yaml b/conf/train/default_train.yaml index db36533..9c6ec2c 100644 --- a/conf/train/default_train.yaml +++ b/conf/train/default_train.yaml @@ -1,35 +1,33 @@ # reproducibility seed: 42 -# model name -model_name: default_name # used to name the directory in which model's checkpoints will be stored (experiments/model_name/...) +# experiment name +experiment_name: default_name # pl_trainer pl_trainer: _target_: pytorch_lightning.Trainer gpus: 1 - accumulate_grad_batches: 4 - gradient_clip_val: 10.0 - val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps - max_steps: 100_000 - # uncomment the lines below for training with mixed precision + max_epochs: 20 + fast_dev_run: False # precision: 16 # amp_level: O2 + # early stopping callback # "early_stopping_callback: null" will disable early stopping early_stopping_callback: _target_: pytorch_lightning.callbacks.EarlyStopping - monitor: val_loss - mode: min + monitor: val_f1_micro + mode: max patience: 50 # model_checkpoint_callback # "model_checkpoint_callback: null" will disable model checkpointing model_checkpoint_callback: _target_: pytorch_lightning.callbacks.ModelCheckpoint - monitor: val_loss - mode: min + monitor: val_f1_micro + mode: max verbose: True save_top_k: 5 - dirpath: experiments/${train.model_name} + dirpath: ${train.experiment_name}/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3b67022 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.black] +line-length = 120 +target-version = ['py36', 'py37', 'py38'] +include = '\.pyi?$' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 39c269e..828aa2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,8 @@ pytorch-lightning==1.2.5 torch==1.8.1 nltk==3.4.5 hydra-core==1.1.0.dev5 -wandb==0.10.31 -transformers==4.9.1 \ No newline at end of file +wandb==0.12.6 +transformers==4.12.3 +torchtext==0.9.1 +black==21.9b0 +python-dotenv==0.19.1 diff --git a/setup.sh b/setup.sh old mode 100755 new mode 100644 index b0677e1..5d4be5c --- a/setup.sh +++ b/setup.sh @@ -5,16 +5,16 @@ source ~/miniconda3/etc/profile.d/conda.sh # create conda env read -rp "Enter environment name: " env_name -read -rp "Enter python version (e.g. 3.7) " python_version +read -rp "Enter python version (e.g. 3.9.7) " python_version conda create -yn "$env_name" python="$python_version" conda activate "$env_name" # install torch read -rp "Enter cuda version (e.g. 
'10.1' or 'none' to avoid installing cuda support): " cuda_version if [ "$cuda_version" == "none" ]; then - conda install -y pytorch torchvision cpuonly -c pytorch + conda install -y pytorch cpuonly -c pytorch else - conda install -y pytorch torchvision cudatoolkit=$cuda_version -c pytorch + conda install -y pytorch cudatoolkit=$cuda_version -c pytorch fi # install python requirements diff --git a/src/colab/settings.json b/src/colab/settings.json new file mode 100644 index 0000000..cbc7c8a --- /dev/null +++ b/src/colab/settings.json @@ -0,0 +1,14 @@ +{ + "python.defaultInterpreterPath": "/root/miniconda3/envs/neural-wsd/bin/python", + "python.formatting.provider": "black", + "python.formatting.blackArgs": [ + "--line-length", + "120" + ], + "files.exclude": { + "**/.classpath": true, + "**/.project": true, + "**/.settings": true, + "**/.factorypath": true + } +} \ No newline at end of file diff --git a/src/colab/setup.sh b/src/colab/setup.sh new file mode 100644 index 0000000..3633d45 --- /dev/null +++ b/src/colab/setup.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# Downloads miniconda +wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +sh Miniconda3-latest-Linux-x86_64.sh -b +export PATH="/root/miniconda3/bin:${PATH}" +conda init + +# Creates the environment +echo "Creating the environment" +source ~/miniconda3/etc/profile.d/conda.sh +conda create -qyn neural-wsd python=3.9.7 +conda activate neural-wsd +pip install -r /content/neural-wsd/requirements.txt + +# Configure vscode and overwrite default settings +code --install-extension ms-python.python +cp /content/neural-wsd/src/colab/settings.json /root/.vscode-server/data/Machine/settings.json \ No newline at end of file diff --git a/src/colab/setup_colab.ipynb b/src/colab/setup_colab.ipynb new file mode 100644 index 0000000..aa3c831 --- /dev/null +++ b/src/colab/setup_colab.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Google Colab + VSCode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yxChFURdSJfQ", + "outputId": "09dde9a4-528c-4557-e1b2-2f4a7d0a0578" + }, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 381 + }, + "id": "aqXACjLhFX1C", + "outputId": "ffce36c3-93e9-41b1-90f7-4211ffbfd122" + }, + "outputs": [], + "source": [ + "!pip install -q colab_ssh python-dotenv --upgrade\n", + "\n", + "copy_env_from_gdrive = False\n", + "if copy_env_from_gdrive:\n", + " from google.colab import drive\n", + "\n", + " drive.mount(\"/content/drive\")\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()\n", + "\n", + "from colab_ssh import launch_ssh_cloudflared, init_git_cloudflared\n", + "\n", + "launch_ssh_cloudflared(password=os.getenv(\"CLOUDFLARED_PASSWORD\"))\n", + "\n", + "init_git_cloudflared(\n", + " repository_url=os.getenv(\"GITHUB_REPO_URL\"),\n", + " personal_token=os.getenv(\"GITHUB_PERSONAL_ACCESS_TOKEN\"),\n", + " branch=os.getenv(\"GITHUB_BRANCH\"),\n", + " email=os.getenv(\"GITHUB_EMAIL\"),\n", + " username=os.getenv(\"GITHUB_USERNAME\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-tTM1EMGJH0z", + "outputId": "8839b443-f0a6-4cc8-fea9-180b0b34002c" + }, + "outputs": 
[], + "source": [ + "# Install dependecies and configure bash\n", + "%%bash\n", + "source neural-wsd/src/colab/setup.sh\n", + "echo \"cd /content/neural-wsd/\" >> ~/.bashrc**\n", + "echo \"source ~/miniconda3/etc/profile.d/conda.sh\" >> ~/.bashrc**\n", + "echo \"conda activate neural-wsd\" >> ~/.bashrc**" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "setup_colab.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/src/dataset.py b/src/dataset.py index cb6ce29..6407763 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,17 +1,35 @@ +from collections import Counter, defaultdict +from operator import itemgetter from typing import * -from torch.utils.data import Dataset +import logging +import json +import os + +import hydra +import torch from omegaconf import DictConfig +from torch.utils.data import Dataset +from torchtext.vocab import Vocab +from tqdm import tqdm from transformers import AutoTokenizer -from src.readers.semcor_reader import SemCorReader + +from src.readers.wordnet_reader import WordNetReader +from src.readers.raganato_reader import * +from src.utils.utilities import * + +# A logger for this file +logger = logging.getLogger(__name__) class SenseAnnotatedDataset(Dataset): def __init__( self, conf: DictConfig, - tokenizer: AutoTokenizer, + tokenizer: Optional[AutoTokenizer] = None, + senses_vocab: Optional[Vocab] = None, name: str = "semcor", cached: Optional[Any] = None, + split: str = "train", ) -> None: """ Args: @@ -20,59 +38,159 @@ def __init__( name: the name of dataset cached: optional, initialize a dataset from pre-computed data """ - assert name in ["semcor", "semcor+omsti"] - self.conf = conf + self.conf: DictConfig = conf self.name = name - self.tokenizer = tokenizer - if name == "semcor": - self.data = SemCorReader.read( - data_path=conf.data.semcor_data_path, - key_path=conf.data.semcor_key_path, - cached=cached, - ) + self.split = split + self.tokenizer: AutoTokenizer = tokenizer + + self.load_dataset(cached) + self.tokenize() + + self.sense_vocabulary = senses_vocab + self.create_senses_means() - def dump_data(self, path: str) -> None: - if path.endswith(".pth"): - torch.save(self.data, output_path) + @property + def sense_vocabulary(self) -> Vocab: + return self._sense_vocabulary + + @sense_vocabulary.setter + def sense_vocabulary(self, vocabulary: Vocab): + self._sense_vocabulary = vocabulary or self.compute_sense_vocabulary() + + def __len__(self) -> int: + return len(self.preprocessed_data) + + def __getitem__(self, idx: int): + return self.preprocessed_data[idx] + + def load_dataset(self, cached: Optional[Any] = None) -> None: + kwargs = {"conf": self.conf, "cached": cached, "split": self.split, "merge_with_semcor": "semcor" in self.name} + if self.name == "semcor": + self.data = SemCorReader.read(**kwargs, **self.conf.data.corpora[self.name]) + elif "omsti" in self.name: + self.data = OMSTIReader.read(**kwargs, **self.conf.data.corpora[self.name]) + elif "semeval" in self.name or "senseval" in self.name: + self.data = SemEvalReader.read(**kwargs, **self.conf.data.corpora["semeval_all"], filter_ds=self.name) else: - result = {entry["sentence_id"]: entry for entry in self.data} - with open(path, "w") as writer: - json.dump(result, writer, indent=4, sort_keys=True) + raise ValueError(f"{self.name} Dataset not supported (e.g. 
try with semcor, semeval2007, ...)") + + @property + def features(self) -> List[str]: + return self.preprocessed_data[0].keys() + + def compute_sense_vocabulary(self) -> Vocab: + counter = Counter([instance_sense for sample in self.data for instance_sense in sample["instance_senses"]]) + return Vocab(counter, specials=["", ""]) - def load_data(self, path: str) -> Dict: + def lookup_indices(self, tokens: List[str]) -> List[int]: + return vocab_lookup_indices(vocabulary=self.sense_vocabulary, tokens=tokens) + + def tokenize(self) -> None: + """Preprocesses the entire dataset.""" + self.preprocessed_data = [] + for sample in tqdm(self.data, desc=f"Tokenizing {self.name}", total=len(self.data)): + sentence = sample["sentence"] + lemmas, pos_tags = [], [] + if len(sample["instance_lexemes"]) > 0: + lemmas, pos_tags = zip(*[lexeme.split("#") for lexeme in sample["instance_lexemes"]]) + sentence_tokenized, indexes = sentence_tokenizer(sentence, self.tokenizer) + self.preprocessed_data.append( + { + **sentence_tokenized, + "lengths": len(sentence), + "word_pieces_indexes": indexes, + "sentence_id": sample["sentence_id"], + "senses_indices": sample["instance_indices"], + "lexemes": sample["instance_lexemes"], + "lemmas": lemmas, + } + ) + assert len(self.data) == len(self.preprocessed_data), "Size mismatch in dataset.tokenize" + + def preprocess_labels(self) -> None: + """Apply vectorized labels using the correct output vocabulary.""" + msg = f"Creating labels for {self.name}" + # Show progress bar only when debugging (it's usually really quick) + for i, sample in tqdm(enumerate(self.data), desc=msg, total=len(self), disable=not self.conf.debug): + instance_senses = ( + [WordNetReader.sense_means()[sense] for sense in sample["instance_senses"]] + if self.conf.data.use_synset_vocab + else sample["instance_senses"] + ) + self.preprocessed_data[i]["senses"] = self.lookup_indices(instance_senses) + + @staticmethod + def load_data(path: str) -> Dict: with open(path, "r") as reader: return json.load(reader) - @classmethod - def from_cached(cls, conf: DictConfig, path): - *path, name, ext = re.split("/|\.", path) - if path.endswith(".pth"): - data: List = torch.load(path) - return cls(conf=conf, name=name, cached=data) + @staticmethod + def from_cached( + conf: DictConfig, + tokenizer: AutoTokenizer, + name: str = "semcor", + split: str = "train", + train_vocab: Optional[Vocab] = None, + ): + """Fetches dataset and vocabulary from file, if not available creates them.""" + base_path = os.path.join(hydra.utils.to_absolute_path("."), conf.data.preprocessed_dir, name) + ds_path = os.path.join(base_path, "dataset.pth") + vocab_path = os.path.join(base_path, "vocab.pth") + + # Loads pre-tokenized dataset and senses vocab + vocab = torch.load(vocab_path) if os.path.exists(vocab_path) and not conf.data.force_preprocessing else None + dataset = ( + torch.load(ds_path) + if os.path.exists(ds_path) and not conf.data.force_preprocessing + else SenseAnnotatedDataset(conf, name=name, tokenizer=tokenizer, split=split, senses_vocab=vocab) + ) + if vocab is None and (os.path.exists(ds_path) and not conf.data.force_preprocessing): + # The dataset (hence its vocabulary) is retrieved from file, but the vocab is not + logger.warning(f"Cannot load vocabulary from {vocab_path}, computing a new one for {split} split.") - assert not path.endswith(".json"), "Extension not supported" - data: dict = load_data() - return cls(conf=conf, name=name, cached=list(data.values())) + # Persists objects to files + if 
conf.data.dump_preprocessed: + os.makedirs(base_path, exist_ok=True) + torch.save(dataset, ds_path) + torch.save(dataset.sense_vocabulary, vocab_path) - def __len__(self) -> int: - return len(self.data) + # We can only use training vocabulary at inference time + dataset.sense_vocabulary = train_vocab - def indices_word_pieces(self, sentence: List[str]) -> List[int]: - indices = [] - for idx_word, word in enumerate(sentence): - word_tokenized = self.tokenizer.tokenize(word) - for _ in range(len(word_tokenized)): - indices.append(idx_word) - return indices + # Vectorize labels using the given training sense vocabulary + dataset.preprocess_labels() - def bert_tokenizer(self, sentence: List[str]): - sentence_tokenized = self.tokenizer(" ".join(sentence), return_tensors="pt") - indexes: List[int] = self.indices_word_pieces(sentence) - return sentence_tokenized, indexes + return dataset - def __getitem__(self, idx: int): - sentence = self.data[idx]["sentence"] - sentence_tokenized, indexes = self.bert_tokenizer(sentence) - self.data[idx]["bert_sentence"] = sentence_tokenized - self.data[idx]["word_pieces_indexes"] = indexes - return self.data[idx] + @property + def lemma_senses_means(self) -> defaultdict: + return defaultdict(list, self._lemma_senses_means) + + @property + def lexeme_senses_means(self) -> defaultdict: + return defaultdict(int, self._lexeme_senses_means) + + @property + def mfs_lexeme_sense_means(self) -> defaultdict: + means = {k: Counter(v).most_common(1)[0][0] for k, v in self._lexeme_senses_means.items()} + return defaultdict(int, means) + + def create_senses_means(self, keys: tuple[str] = ("instance_lexemes", "instance_senses")) -> None: + """Computes the mapping lexeme->sense.""" + self._lexeme_senses_means = dict() + for sample in self.data: + for lexeme, sense in zip(*itemgetter(*keys)(sample)): + if lexeme not in self._lexeme_senses_means: + self._lexeme_senses_means[lexeme] = [self.sense_vocabulary[sense]] + else: + self._lexeme_senses_means[lexeme].append(self.sense_vocabulary[sense]) + + self._lemma_senses_means = dict() + for lexeme, senses in self._lexeme_senses_means.items(): + lemma, pos_tag = lexeme.split("#") + if lemma not in self._lemma_senses_means: + self._lemma_senses_means[lemma] = set() + self._lemma_senses_means[lemma].update(senses) + + # map lemmas to candidate senses (i.e. 
useful for indexing tensors) + self._lemma_senses_means = {k: list(v) for k, v in self._lemma_senses_means.items()} diff --git a/src/layers/word_encoder.py b/src/layers/word_encoder.py new file mode 100644 index 0000000..38919de --- /dev/null +++ b/src/layers/word_encoder.py @@ -0,0 +1,77 @@ +from transformers import AutoModel, AutoConfig, logging +import torch.nn as nn +import torch + +from src.utils.torch_utilities import scatter_mean + + +class WordEncoder(nn.Module): + def __init__( + self, + word_dropout: float = 0.1, + word_projection_size: int = 512, + fine_tune: bool = False, + model_name: str = "bert-base-cased", + ): + super(WordEncoder, self).__init__() + + self.word_embedding = BertEmbedding(model_name=model_name, fine_tune=fine_tune) + if "base" in model_name: + word_embedding_size = 4 * 768 + else: + word_embedding_size = 4 * 1024 + + self.batch_normalization = nn.BatchNorm1d(word_embedding_size) + self.output = nn.Linear(word_embedding_size, word_projection_size) + self.word_dropout = nn.Dropout(word_dropout) + + # output size + self.word_embedding_size = word_projection_size + + def forward(self, word_ids, subword_indices=None, sequence_lengths=None): + word_embeddings = self.word_embedding(word_ids, sequence_lengths=sequence_lengths) + + # permute twice since batchnorm expects the temporal index on the last axis + word_embeddings = word_embeddings.transpose(1, 2) + word_embeddings = self.batch_normalization(word_embeddings) + word_embeddings = word_embeddings.transpose(1, 2) + + word_embeddings = self.output(word_embeddings) + word_embeddings = torch.sigmoid(word_embeddings) + word_embeddings = self.word_dropout(word_embeddings) + + # get word-level embeddings + word_embeddings = scatter_mean(word_embeddings, subword_indices, dim=1) + + return word_embeddings + + +class BertEmbedding(nn.Module): + """Wrapper of transformer's AutoModel class representing BERT word embedder.""" + + def __init__(self, model_name="bert-base-cased", fine_tune=False): + super(BertEmbedding, self).__init__() + self.fine_tune = fine_tune + config = AutoConfig.from_pretrained(model_name, output_hidden_states=True) + logging.set_verbosity_error() + self.bert = AutoModel.from_pretrained(model_name, config=config) + logging.set_verbosity_warning() + if not fine_tune: + self.bert.eval() + + def forward(self, word_ids, sequence_lengths=None): + timesteps = word_ids.shape[1] + device = "cuda" if word_ids.get_device() == 0 else "cpu" + # mask to avoid performing attention on padding token indices + attention_mask = torch.arange(timesteps, device=device).unsqueeze(0) < sequence_lengths.unsqueeze(1) + + if not self.fine_tune: + with torch.no_grad(): + # freeze bert's weights + word_embeddings = self.bert(input_ids=word_ids, attention_mask=attention_mask) + else: + word_embeddings = self.bert(input_ids=word_ids, attention_mask=attention_mask) + + # concatenate the last four layers of BERT + word_embeddings = torch.cat(word_embeddings[2][-4:], dim=-1) + return word_embeddings diff --git a/src/models/mfs.py b/src/models/mfs.py new file mode 100644 index 0000000..98d5891 --- /dev/null +++ b/src/models/mfs.py @@ -0,0 +1,26 @@ +from collections import defaultdict + +from pytorch_lightning.metrics.functional import f1 as f1_score +from torchtext.vocab import Vocab +import pytorch_lightning as pl +import torch + +from src.dataset import SenseAnnotatedDataset + + +class MFS(pl.LightningModule): + """Implementation of the simple Most Frequent Sense (MFS) baseline""" + + def __init__(self, n_classes: int, mfs_means: 
defaultdict) -> None: + super().__init__() + self.n_classes = n_classes + self.mfs_means = mfs_means + + def forward(self, x: dict) -> torch.Tensor: + return torch.tensor([self.mfs_means[l] for lexemes in x["lexemes"] for l in lexemes], device=self.device) + + def test_step(self, x: dict, batch_idx: int) -> dict: + annotation = self.forward(x) + labels = x["senses"][x["senses"] != 0] + f1_micro = f1_score(annotation, labels, num_classes=self.n_classes, average="micro") + return {"test_f1_micro": f1_micro} diff --git a/src/models/wsd_model.py b/src/models/wsd_model.py new file mode 100644 index 0000000..4a243a8 --- /dev/null +++ b/src/models/wsd_model.py @@ -0,0 +1,34 @@ +from hydra.utils import instantiate +from omegaconf import DictConfig +import torch.nn as nn +import torch + + +class WSDModel(nn.Module): + def __init__(self, conf: DictConfig, n_classes: int): + super().__init__() + self.conf = conf + self.n_classes = n_classes + self.word_encoder = instantiate(conf.model.word_encoder) + + if conf.model.sequence_encoder == "lstm": + self.sequence_encoder = instantiate(conf.model.lstm_encoder) + self.hidden_size = conf.model.lstm_encoder.hidden_size * conf.model.lstm_encoder.num_layers + else: + self.sequence_encoder = IdentityLayer() + self.hidden_size = self.word_encoder.word_embedding_size + + self.output_layer = torch.nn.Linear(self.hidden_size, self.n_classes) + + def forward(self, x: dict) -> torch.Tensor: + word_ids, subword_indices, lengths = x["input_ids"], x["word_pieces_indexes"], x["lengths"] + result = self.word_encoder(word_ids, subword_indices=subword_indices, sequence_lengths=lengths) + sequence_out, _ = self.sequence_encoder(result) # batch, seq_len, hidden state + return self.output_layer(sequence_out) + + +class IdentityLayer(nn.Module): + """Syntactic sugar to simplify the sequence encoder""" + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return x, None diff --git a/src/pl_data_modules.py b/src/pl_data_modules.py index 3f6653d..bcfbf7c 100644 --- a/src/pl_data_modules.py +++ b/src/pl_data_modules.py @@ -1,14 +1,19 @@ -from typing import Any, Union, List, Optional +import os +import subprocess +from collections import defaultdict +from functools import partial +from typing import * +import hydra +import pytorch_lightning as pl from omegaconf import DictConfig - -import torch from torch.utils.data import DataLoader -import pytorch_lightning as pl +from torchtext.vocab import Vocab +from transformers import AutoTokenizer +from src.readers.wordnet_reader import WordNetReader from src.dataset import SenseAnnotatedDataset - -from transformers import AutoTokenizer +from src.utils.utilities import * class BasePLDataModule(pl.LightningDataModule): @@ -56,18 +61,73 @@ def test_dataloader(self): def __init__(self, conf: DictConfig): super().__init__() self.conf = conf - - def prepare_data(self, *args, **kwargs): - print(os.getcwd()) - # os.system("bash download_dataset.sh") - - def setup(self, stage: Optional[str] = None): - self.tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) - self.train_dataset = SenseAnnotatedDataset( - self.conf, name="semcor", tokenizer=self.tokenizer + self.prepare_data() + self.setup() + + def prepare_data(self, *args, **kwargs) -> None: + base_path = hydra.utils.to_absolute_path(".") + if not os.path.exists(os.path.join(base_path + "/data/", "WSD_Training_Corpora/")): + subprocess.run(f"bash src/scripts/get-wsd-data.sh", shell=True, check=True, cwd=base_path) + + def setup(self, stage: Optional[str] = 
None) -> None: + self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + self.train_dataset = SenseAnnotatedDataset.from_cached( + self.conf, + tokenizer=self.tokenizer, + name=self.conf.data.train_ds, + split="train", + train_vocab=WordNetReader.vocabulary(self.conf) if self.conf.data.use_synset_vocab else None, + ) + self.valid_dataset = SenseAnnotatedDataset.from_cached( + self.conf, + tokenizer=self.tokenizer, + name=self.conf.data.val_ds, + split="validation", + train_vocab=self.train_dataset.sense_vocabulary, ) - # self.valid_dataset = SenseAnnotatedDataset(self.conf, name='semeval2007', tokenizer=self.tokenizer) - # self.test_dataset = SenseAnnotatedDataset(self.conf, name='semevalALL', tokenizer=self.tokenizer) + self.test_dataset = SenseAnnotatedDataset.from_cached( + self.conf, + tokenizer=self.tokenizer, + name=self.conf.data.test_ds, + split="test", + train_vocab=self.train_dataset.sense_vocabulary, + ) + + @property + def train_features(self) -> Tuple[str]: + return self.train_dataset.features + + @property + def sense_vocabulary(self) -> Vocab: + """Returns the output vocabulary to encode labels (i.e. training or WordNet).""" + return self.train_dataset.sense_vocabulary + + @property + def mfs_lexeme_means(self) -> defaultdict: + if self.conf.data.use_synset_vocab: + return WordNetReader.mfs_lexeme_means() + return self.train_dataset.mfs_lexeme_sense_means + + @property + def lexeme_means(self) -> defaultdict: + if self.conf.data.use_synset_vocab: + return WordNetReader.lexeme_means() + return self.train_dataset.lexeme_senses_means + + @property + def lemma_means(self) -> defaultdict: + if self.conf.data.use_synset_vocab: + return WordNetReader.lemma_means() + return self.train_dataset.lemma_senses_means + + @property + def collate_kwargs(self) -> dict[str, any]: + return { + "batch_keys": self.train_features, + "lemma_means": self.lemma_means if self.conf.model.use_lemma_mask else None, + "lexeme_means": self.lexeme_means if self.conf.model.use_lexeme_mask else None, + "output_dim": len(self.sense_vocabulary), + } def train_dataloader(self, *args, **kwargs) -> DataLoader: return DataLoader( @@ -75,6 +135,7 @@ def train_dataloader(self, *args, **kwargs) -> DataLoader: num_workers=self.conf.data.num_workers, batch_size=self.conf.data.batch_size, shuffle=True, + collate_fn=partial(collate_fn, **self.collate_kwargs), ) def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: @@ -83,6 +144,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]] num_workers=self.conf.data.num_workers, batch_size=self.conf.data.batch_size, shuffle=False, + collate_fn=partial(collate_fn, **self.collate_kwargs), ) def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: @@ -91,7 +153,5 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader] num_workers=self.conf.data.num_workers, batch_size=self.conf.data.batch_size, shuffle=False, + collate_fn=partial(collate_fn, **self.collate_kwargs), ) - - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: - raise NotImplementedError diff --git a/src/pl_modules.py b/src/pl_modules.py index 2306ae0..82507ac 100644 --- a/src/pl_modules.py +++ b/src/pl_modules.py @@ -1,38 +1,67 @@ -from typing import Any +from collections import defaultdict +from typing import Optional +from pytorch_lightning.metrics.functional import f1 as f1_score +from omegaconf import DictConfig +from torch import nn +import 
torch.nn.functional as F import pytorch_lightning as pl import torch +from src.layers.word_encoder import WordEncoder +from src.utils.torch_utilities import RAdam +from src.models.wsd_model import WSDModel + class BasePLModule(pl.LightningModule): - def __init__(self, conf, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.save_hyperparameters(conf) + def __init__( + self, + conf: DictConfig, + n_classes: int, + *args, + **kwargs, + ) -> None: + super().__init__() + self.conf = conf + self.n_classes = n_classes + self.loss_function = nn.CrossEntropyLoss() + self.save_hyperparameters({**dict(conf), "n_classes": n_classes}) + self.model = WSDModel(conf, n_classes=n_classes) - def forward(self, **kwargs) -> dict: - """ - Method for the forward pass. - 'training_step', 'validation_step' and 'test_step' should call - this method in order to compute the output predictions and the loss. + def _evaluate(self, x: dict[str, torch.Tensor], logits_: torch.Tensor, labels: torch.Tensor): + mask = labels != 0 + logits, labels = logits_[mask], labels[mask] + loss = F.cross_entropy(logits, labels) - Returns: - output_dict: forward output containing the predictions (output logits ecc...) and the loss if any. + if "sense_mask" in x: + logits_.masked_fill_(~x["sense_mask"], float("-inf")) - """ - output_dict = {} - return output_dict + annotation = torch.argmax(logits_[mask], dim=-1) + f1_micro = f1_score(annotation, labels, num_classes=self.n_classes, average="micro") + return f1_micro, loss - def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: - forward_output = self.forward(**batch) - self.log("loss", forward_output["loss"]) - return forward_output["loss"] + def _shared_step(self, x: dict[str, torch.Tensor]): + logits = self.model(x) + f1_micro, loss = self._evaluate(x, logits, labels=x["senses"]) + return f1_micro, loss - def validation_step(self, batch: dict, batch_idx: int) -> None: - forward_output = self.forward(**batch) - self.log("val_loss", forward_output["loss"]) + def training_step(self, x: dict, batch_idx: int) -> dict: + f1_micro, loss = self._shared_step(x) + metrics = {"f1_micro": f1_micro, "loss": loss} + self.log_dict(metrics, on_step=False, on_epoch=True) + return metrics - def test_step(self, batch: dict, batch_idx: int) -> Any: - raise NotImplementedError + def validation_step(self, x: dict, batch_idx: int) -> dict: + f1_micro, loss = self._shared_step(x) + metrics = {"val_f1_micro": f1_micro, "val_loss": loss} + self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) + return metrics + + def test_step(self, x: dict, batch_idx: int) -> dict: + f1_micro, loss = self._shared_step(x) + metrics = {"test_f1_micro": f1_micro, "test_loss": loss} + self.log_dict(metrics, on_step=False, on_epoch=True) + return metrics def configure_optimizers(self): """ @@ -51,5 +80,9 @@ def configure_optimizers(self): key whose value is a single LR scheduler or lr_dict. - Tuple of dictionaries as described, with an optional 'frequency' key. - None - Fit will run without any optimizer. 
+ loss avg 3.650 - f1 avg 00255 """ - raise NotImplementedError + + # return RAdam(self.parameters()) + optimizer = torch.optim.Adam(self.parameters(), lr=self.conf.model.learning_rate) + return optimizer diff --git a/src/readers/omsti_reader.py b/src/readers/omsti_reader.py deleted file mode 100644 index 77c63e0..0000000 --- a/src/readers/omsti_reader.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import * -import xml.etree.ElementTree as et -from omegaconf import DictConfig -from tqdm import tqdm -import hydra - - -class SemCorReader(object): - data = [] - inst2wn = dict() - - @classmethod - def read(cls, data_path: str, key_path: str, cached: Optional[Any] = None): - if cached and isinstance(cached, list): - cls.data = cached - - if not len(cls.data): - cls._load_key(path=key_path) - cls._load_xml(path=data_path) - return cls.data - - @classmethod - def _preprocess_omsti(cls, path: str): - with open(path, "r") as f: - omsti_corpus = f.readlines() - - xml_header = omsti_corpus[0] - idxs = [j for j, line in enumerate(omsti_corpus) if line.startswith(" bool: + """Validates the cache data.""" + return cached and isinstance(cached, list) + + @classmethod + def read( + cls, + data_path: str, + key_path: str, + conf: DictConfig, + cached: Optional[Any] = None, + split: str = "train", + **kwargs, + ) -> List[Dict[str, Any]]: + cls.data = [] + cls.inst2wn = dict() + + if RaganatoReader.is_valid_cache(cached): + cls.data = cached + return cls.data + + cls._load_key(path=key_path, conf=conf) + cls._load_xml(path=data_path, conf=conf, split=split, **kwargs) + return cls.data + + @classmethod + def _load_xml( + cls, + path: str, + conf: DictConfig, + pattern: str = "./text/sentence", + split: str = "train", + **kwargs, + ) -> List[Dict[str, Any]]: + """Parse the XML file and extract it into a data-driven format.""" + path = hydra.utils.to_absolute_path(path) + filter_ds = kwargs.get("filter_ds", None) + if filter_ds: + filter_ds = filter_ds.split("+") + root = et.parse(path).getroot() + *_, name = path.rsplit("/", 1) + iterator = root.findall(pattern) + total = len(iterator) if not conf.debug else min(len(iterator), conf.max_samples) + for id_sentence, sentence in enumerate(tqdm(iterator, desc=f"Loading {split} dataset {name}", total=total)): + corpus_id, _ = sentence.attrib["id"].split(".", 1) + if conf.debug and len(cls.data) >= conf.max_samples: + break + if filter_ds and corpus_id not in filter_ds: + # Filter out sentences from subcorpus (i.e. SemEval-ALL) + continue + words = [] + instance_idxs = [] + instance_ids = [] + instance_senses = [] + instance_lemma_pos = [] + for j, word in enumerate(sentence): + words.append(word.text) + if word.tag == "instance": + # Store id, lemma, POS tags for instances + _id, _lemma, _pos = ( + word.attrib["id"], + word.attrib["lemma"], + word.attrib["pos"], + ) + instance_idxs.append(j) + instance_ids.append(_id) + instance_senses.append(cls.inst2wn[_id]) + instance_lemma_pos.append(_lemma + "#" + _pos) + + cls.data.append( + { + "sentence_id": id_sentence, + "sentence": words, + "instance_indices": instance_idxs, # indexes of senses + "instance_ids": instance_ids, # "d000.s000.t000" + "instance_senses": instance_senses, + "instance_lexemes": instance_lemma_pos, + } + ) + + @classmethod + def _load_key(cls, path: str, conf: DictConfig): + """Extract the mapping to gold labels (i.e. 
WordNet ids).""" + path = hydra.utils.to_absolute_path(path) + with open(path, "r") as lines: + for line in lines: + instance_id, *wn_ids = line.rstrip().split() + # an instance may have more than one gold annotation + cls.inst2wn[instance_id] = wn_ids if conf.data.allow_multiple_senses else wn_ids[0] + + +class SemCorReader(RaganatoReader): + """RaganatoReader already implements all SemCorReader functionalities.""" + + pass + + +class SemEvalReader(RaganatoReader): + """RaganatoReader allows parsing the SemEval/Senseval datasets.""" + + @classmethod + def read( + cls, + data_path: str, + key_path: str, + conf: DictConfig, + cached: Optional[Any] = None, + split: str = "train", + **kwargs, + ) -> List[Dict[str, Any]]: + """SemEvalReader read method requires a value for parameter [filter_ds].""" + filter_ds = kwargs.get("filter_ds", None) + assert filter_ds is not None, "Called SemEvalReader.read without passing a value for [filter_ds]" + if "+" not in filter_ds: + assert ( + filter_ds in conf.data.corpora + ), f"{name} Dataset not supported (e.g. try with semcor, semeval2007, ...)" + # Override data_path and key_path and filter the semeval_all dataset + data_path = conf.data.corpora[filter_ds]["data_path"] + key_path = conf.data.corpora[filter_ds]["key_path"] + # No need to filter when using individual SemEval/Senseval datasets + kwargs["filter_ds"] = None + + return super(SemEvalReader, cls).read(data_path, key_path, conf=conf, cached=cached, split=split, **kwargs) + + +class OMSTIReader(RaganatoReader): + """OMSTIReader allows parsing the OMSTI dataset w/o performing the merge with SemCor.""" + + @classmethod + def read( + cls, + data_path: str, + key_path: str, + conf: DictConfig, + cached: Optional[Any] = None, + split: str = "train", + **kwargs, + ) -> List[Dict[str, Any]]: + if RaganatoReader.is_valid_cache(cached): + return super(OMSTIReader, cls).read(data_path, key_path, conf, cached, split, **kwargs) + + path_dict: Dict = OMSTIReader.preprocess_omsti(path=data_path) + omsti_data = super(OMSTIReader, cls).read(path_dict["omsti"], key_path, split=split) + if kwargs.get("merge_with_semcor", False): + # Test with SemCor + OMSTI + semcor_data = super(OMSTIReader, cls).read(path_dict["semcor"], key_path, split=split) + for k, sample in enumerate(semcor_data): + sample["sentence_id"] = len(omsti_data) + k + omsti_data += semcor_data + return omsti_data + + @staticmethod + def preprocess_omsti(path: str) -> Dict[str, str]: + """Splits the SemCor+OMSTI datasets into separate XML files.""" + path = hydra.utils.to_absolute_path(path) + with open(path, "r") as f: + omsti_corpus = f.readlines() + + # Split the original file looking at lines with tag + xml_header = omsti_corpus[0] + idxs = [j for j, line in enumerate(omsti_corpus) if line.startswith(" 'omsti' + # Extract corpus name from tag + sources = [re.findall(r"source=\"(.+?)\"", omsti_corpus[i])[0] for i in idxs] + path, ext = path.rsplit(".", 1) + output_paths = {sources2path[src]: f"{path}_{sources2path[src]}.{ext}" for src in sources} + # Python ranges expect [i, j), adding the last index to consider the full doc + if len(omsti_corpus) not in idxs: + idxs.append(len(omsti_corpus)) + + corpora = [xml_header + "".join(omsti_corpus[i:j]) for i, j in zip(idxs, idxs[1:])] + for output_path, corpus in zip(output_paths.values(), corpora): + with open(output_path, "w") as f: + f.write(corpus) + + return output_paths diff --git a/src/readers/semcor_reader.py b/src/readers/semcor_reader.py deleted file mode 100644 index 45eb20d..0000000 --- 
a/src/readers/semcor_reader.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import * -import xml.etree.ElementTree as et -from omegaconf import DictConfig -from tqdm import tqdm -import hydra - - -class SemCorReader(object): - data = [] - inst2wn = dict() - - @classmethod - def read(cls, data_path: str, key_path: str, cached: Optional[Any] = None): - if cached and isinstance(cached, list): - cls.data = cached - - if not len(cls.data): - cls._load_key(path=key_path) - cls._load_xml(path=data_path) - return cls.data - - @classmethod - def _load_xml(cls, path: str, pattern: str = "./text/sentence"): - """Parse the XML file and extract it into a data-driven format.""" - path = hydra.utils.to_absolute_path(path) - root = et.parse(path).getroot() - *_, name = path.rsplit("/", 1) - iterator = tqdm(root.findall(pattern), desc=f"Reading XML dataset {name}") - for id_sentence, sentence in enumerate(iterator): - words = [] - instance_idxs = [] - instance_ids = [] - instance_senses = [] - instance_lemma_pos = [] - for j, word in enumerate(sentence): - words.append(word.text) - if word.tag == "instance": - # Store id, lemma, POS tags for instances - _id, _lemma, _pos = ( - word.attrib["id"], - word.attrib["lemma"], - word.attrib["pos"], - ) - instance_idxs.append(j) - instance_ids.append(_id) - instance_senses.append(cls.inst2wn[_id]) - instance_lemma_pos.append(_lemma + "#" + _pos) - - cls.data.append( - { - "sentence_id": id_sentence, - "sentence": words, - "index": instance_idxs, - "instance_id": instance_ids, - "sense": instance_senses, - "lexeme": instance_lemma_pos, - } - ) - - @classmethod - def _load_key(cls, path: str): - """Extract the mapping to gold labels (i.e. WordNet ids).""" - path = hydra.utils.to_absolute_path(path) - with open(path, "r") as lines: - for line in lines: - instance_id, *wn_ids = line.rstrip().split() - # an instance may have more than one gold annotation - cls.inst2wn[instance_id] = wn_ids diff --git a/src/readers/wordnet_reader.py b/src/readers/wordnet_reader.py new file mode 100644 index 0000000..4ca7a95 --- /dev/null +++ b/src/readers/wordnet_reader.py @@ -0,0 +1,64 @@ +from collections import Counter, defaultdict +from operator import itemgetter +from typing import * + +from torchtext.vocab import Vocab +from omegaconf import DictConfig +import hydra + +from src.utils.utilities import vocab_lookup_indices, read_json_hydra +from src.readers.raganato_reader import RaganatoReader + + +class WordNetReader(object): + """A WordNet reader class that implements the Singleton pattern.""" + + _conf = None + _vocab = None + _glosses = None + _lexeme_means = None + _lemma_means = None + _sense_means = None + + @classmethod + def vocabulary(cls, conf: Optional[DictConfig] = None) -> Vocab: + cls._conf = cls._conf or conf + if cls._glosses is None: + cls._glosses = read_json_hydra(path=cls._conf.data.wordnet.glosses) + cls._vocab = cls._vocab or Vocab(Counter(cls._glosses.keys()), specials=["", ""]) + return cls._vocab + + @classmethod + def sense_means(cls, conf: Optional[DictConfig] = None) -> defaultdict: + cls._conf = cls._conf or conf + if cls._sense_means is None: + cls._sense_means = defaultdict(int, read_json_hydra(path=cls._conf.data.wordnet.sense_means)) + return cls._sense_means + + @classmethod + def mfs_lexeme_means(cls, conf: Optional[DictConfig] = None) -> defaultdict: + if cls._lexeme_means is None: + cls.lexeme_means(conf) + means = {k: Counter(v).most_common(1)[0][0] for k, v in cls._lexeme_means.items()} + cls._wn_mfs_lexeme_means = defaultdict(int, means) + 
return cls._wn_mfs_lexeme_means + + @classmethod + def lexeme_means(cls, conf: Optional[DictConfig] = None, vocab: Optional[Vocab] = None) -> defaultdict: + cls._conf = cls._conf or conf + cls._vocab = cls._vocab or vocab + if cls._lexeme_means is None: + lexeme_means = read_json_hydra(path=cls._conf.data.wordnet.lexeme_means) + lexeme_means = {k: vocab_lookup_indices(cls._vocab, v) for k, v in lexeme_means.items()} + cls._lexeme_means = defaultdict(int, lexeme_means) + return cls._lexeme_means + + @classmethod + def lemma_means(cls, conf: Optional[DictConfig] = None, vocab: Optional[Vocab] = None) -> defaultdict: + cls._conf = cls._conf or conf + cls._vocab = cls._vocab or vocab + if cls._lemma_means is None: + lemma_means = read_json_hydra(path=cls._conf.data.wordnet.lemma_means) + lemma_means = {k: vocab_lookup_indices(cls._vocab, v) for k, v in lemma_means.items()} + cls._lemma_means = defaultdict(int, lemma_means) + return cls._lemma_means diff --git a/src/scripts/get-wsd-data.sh b/src/scripts/get-wsd-data.sh new file mode 100755 index 0000000..2dbae1d --- /dev/null +++ b/src/scripts/get-wsd-data.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +DATASET_DIRECTORY="data/" +envfile="../.env" +mkdir -p $DATASET_DIRECTORY + +# Retrieve the train data for WSD (Raganato et al., 2017) +WSD_TRAIN_DATA_NAME="WSD_Training_Corpora" +WSD_TRAIN_URL="https://github.com/andreabac3/neural-wsd/releases/download/1.0/WSD_Training_Corpora.zip" +WSD_TRAIN_DATA_ZIP=$DATASET_DIRECTORY+$WSD_TRAIN_DATA_NAME".zip" +WSD_TRAIN_DATA=$DATASET_DIRECTORY+$WSD_TRAIN_DATA_NAME + +if [ -f "$envfile" ]; then + # Using GitHub API to fetch the WSD datasets (usually faster) + bash src/scripts/github_downloader.sh $WSD_TRAIN_DATA_NAME".zip" $WSD_TRAIN_DATA_ZIP +else + # Using plain download from origin servers whenever credentials are missing + WSD_TRAIN_URL="http://lcl.uniroma1.it/wsdeval/data/WSD_Training_Corpora.zip" + curl $WSD_TRAIN_URL -o $WSD_TRAIN_DATA_ZIP +fi + +unzip $WSD_TRAIN_DATA_ZIP -d $DATASET_DIRECTORY +rm $WSD_TRAIN_DATA_ZIP + +# Retrieve the evaluation data for WSD (Raganato et al., 2017) +WSD_EVAL_DATA_NAME="WSD_Unified_Evaluation_Datasets" +WSD_EVAL_URL="https://github.com/andreabac3/neural-wsd/releases/download/1.0/WSD_Unified_Evaluation_Datasets.zip" +WSD_EVAL_DATA_ZIP=$DATASET_DIRECTORY+$WSD_EVAL_DATA_NAME".zip" +WSD_EVAL_DATA=$DATASET_DIRECTORY+$WSD_EVAL_DATA_NAME + +if [ -f "$envfile" ]; then + # Using GitHub API to fetch the WSD datasets (usually faster) + bash src/scripts/github_downloader.sh $WSD_EVAL_DATA_NAME".zip" $WSD_EVAL_DATA_ZIP +else + # Using plain download from origin servers whenever credentials are missing + WSD_EVAL_URL="http://lcl.uniroma1.it/wsdeval/data/WSD_Unified_Evaluation_Datasets.zip" + curl $WSD_EVAL_URL -o $WSD_EVAL_DATA_ZIP +fi + +unzip $WSD_EVAL_DATA_ZIP -d $DATASET_DIRECTORY +rm $WSD_EVAL_DATA_ZIP \ No newline at end of file diff --git a/src/scripts/github_downloader.sh b/src/scripts/github_downloader.sh new file mode 100755 index 0000000..0125466 --- /dev/null +++ b/src/scripts/github_downloader.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# Script to download asset file from tag release using GitHub API v3. +# This code is mainly an adaptation to our env of the following stackoverflow thread. +# See: http://stackoverflow.com/a/35688093/55075 + +# Validate settings. 
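+# Credentials are read from ../.env, which is expected to define
+# GITHUB_PERSONAL_ACCESS_TOKEN, GITHUB_REPO_NAME, GITHUB_REPO_OWNER and
+# GITHUB_REPO_RELEASE_TAG.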
+GITHUB_API_TOKEN=`grep 'GITHUB_PERSONAL_ACCESS_TOKEN' ../.env | tr '=' '\n' | grep -v 'GITHUB_PERSONAL_ACCESS_TOKEN'` +repo=`grep 'GITHUB_REPO_NAME' ../.env | tr '=' '\n' | grep -v 'GITHUB_REPO_NAME'` +owner=`grep 'GITHUB_REPO_OWNER' ../.env | tr '=' '\n' | grep -v 'GITHUB_REPO_OWNER'` +tag=`grep 'GITHUB_REPO_RELEASE_TAG' ../.env | tr '=' '\n' | grep -v 'GITHUB_REPO_RELEASE_TAG'` + +[ !"$GITHUB_API_TOKEN" ] || { echo "Error: Please define GITHUB_API_TOKEN variable." >&2; exit 1; } +[ !"$repo" ] || { echo "Error: Please define GITHUB_REPO_NAME variable." >&2; exit 1; } +[ !"$owner" ] || { echo "Error: Please define GITHUB_REPO_OWNER variable." >&2; exit 1; } +[ !"$tag" ] || { echo "Error: Please define GITHUB_REPO_RELEASE_TAG variable." >&2; exit 1; } + +# Define variables. +GH_API="https://api.github.com" +GH_REPO="$GH_API/repos/$owner/$repo" +GH_TAGS="$GH_REPO/releases/tags/$tag" +AUTH="Authorization: token $GITHUB_API_TOKEN" +WGET_ARGS="--content-disposition --auth-no-challenge --no-cookie" +CURL_ARGS="-LJO#" + +# Define asset name and output path +name=$1 +output_path=$2 + +# Validate token. +curl -o /dev/null -sH "$AUTH" $GH_REPO || { echo "Error: Invalid repo, token or network issue!"; exit 1; } + +# Read asset tags. +response=$(curl -sH "$AUTH" $GH_TAGS) +# Get ID of the asset based on given name. +eval $(echo "$response" | grep -C3 "name.:.\+$name" | grep -w id | tr : = | tr -cd '[[:alnum:]]=') +[ "$id" ] || { echo "Error: Failed to get asset id, response: $response" | awk 'length($0)<100' >&2; exit 1; } +GH_ASSET="$GH_REPO/releases/assets/$id" + +# Download asset file. +echo "Downloading asset..." >&2 +curl -o "$output_path" $CURL_ARGS -H "Authorization: token $GITHUB_API_TOKEN" -H 'Accept: application/octet-stream' "$GH_ASSET" +echo "$0 done." >&2 \ No newline at end of file diff --git a/src/scripts/wordnet_extractor.py b/src/scripts/wordnet_extractor.py new file mode 100644 index 0000000..42310bd --- /dev/null +++ b/src/scripts/wordnet_extractor.py @@ -0,0 +1,110 @@ +""" +A simple script to scrape data from WordNet 3.1. database. 
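+
+It downloads and extracts the WordNet dict files, then writes glosses.json,
+lexeme_means.json, lemma_means.json and sense_means.json to the output
+directory (data/wordnet/means by default), matching the paths expected by
+conf/data/default_data.yaml.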
+""" + +import json +import os +from pathlib import Path + +POS_TAGS = ["adj", "adv", "noun", "verb"] +WN_POS_TAGS = {"adj": "ADJ", "adv": "ADV", "noun": "NOUN", "verb": "VERB"} + +WORDNET_URL = "https://wordnetcode.princeton.edu/wn3.1.dict.tar.gz" +WORDNET_FILE = "wn3.1.dict.tar.gz" + + +def is_synset_id(token: str, synset_list: list[str]) -> bool: + return token in synset_list + + +def write_json(dictionary: dict, path: str) -> None: + with open(path, "w") as writer: + json.dump(dictionary, writer, indent=4, sort_keys=True) + + +def is_header_line(line: str) -> bool: + return line.startswith(" ") + + +def parse_glosses(path: str, output_dir: str) -> set: + """Parses WordNet glosses and returns the set of synset ids.""" + synset_glosses = dict() + for pos in POS_TAGS: + with open(f"{path}data.{pos}", "r") as f_in: + for line in f_in: + if is_header_line(line): + continue + left, gloss = line.strip().split("|", 1) + synset_id, *_ = left.split(" ", 1) + synset_glosses[synset_id] = gloss + + write_json(synset_glosses, path=f"{output_dir}/glosses.json") + return set(synset_glosses.keys()) + + +def parse_lexeme_means(path: str, synset_ids: set[str], output_dir: str) -> dict[str, list[str]]: + """Builds the mapping lexeme -> possible synsets.""" + lexeme_means = dict() + for pos in POS_TAGS: + with open(f"{path}index.{pos}", "r") as f_in: + for line in f_in: + if is_header_line(line): + continue + lemma, *line = line.split(" ") + synsets = [token for token in line if is_synset_id(token, synset_ids)] + lexeme_means[f"{lemma}#{WN_POS_TAGS[pos]}"] = synsets + + write_json(lexeme_means, path=f"{output_dir}/lexeme_means.json") + return lexeme_means + + +def parse_lemma_means(lexeme_means: dict[str, list[str]], output_dir: str) -> None: + """Builds the mapping lemma -> possible synsets.""" + lemma_means = dict() + for lexeme, synsets in lexeme_means.items(): + lemma, pos = lexeme.split("#") + if lemma not in lemma_means: + lemma_means[lemma] = set() + lemma_means[lemma].update(synsets) + + lemma_means = {k: list(v) for k, v in lemma_means.items()} + write_json(lemma_means, path=f"{output_dir}/lemma_means.json") + + +def parse_sense_means(path: str, output_dir: str) -> None: + sense_means = dict() + with open(f"{path}index.sense", "r") as f_in: + for line in f_in: + sense, synset, *_ = line.strip().split(" ") + sense_means[sense] = synset + + write_json(sense_means, path=f"{output_dir}/sense_means.json") + + +def main(path: str = "data/wordnet/", output_dir: str = "data/wordnet/means") -> None: + + # download wordnet's database from origin source + if not os.path.exists(path): + Path(path).mkdir(parents=True) + os.system(f"curl {WORDNET_URL} -o {path}{WORDNET_FILE}") + os.system(f"tar -xf {path}{WORDNET_FILE} -C {path}") + os.system(f"rm {path}{WORDNET_FILE}") + os.makedirs(f"{path}/means/", exist_ok=True) + + path += "dict/" + + # ... fetch glosses and the collection of available synsets + synset_ids = parse_glosses(path, output_dir) + + # ... fetch the lexemes that are used by WordNet for indexing + lexeme_means = parse_lexeme_means(path, synset_ids, output_dir) + + # ... extend the indexing to simple lemmas (i.e. dropping the POS tag) + parse_lemma_means(lexeme_means, output_dir) + + # ... 
retrieve the mapping to go from senses to their synsets + parse_sense_means(path, output_dir) + + +if __name__ == "__main__": + main() diff --git a/src/test.py b/src/test.py new file mode 100644 index 0000000..e4bf868 --- /dev/null +++ b/src/test.py @@ -0,0 +1,44 @@ +import sys + +import hydra +import omegaconf +import pytorch_lightning as pl +from pytorch_lightning import Trainer + +from src.pl_data_modules import BasePLDataModule +from src.pl_modules import BasePLModule +from src.models.mfs import MFS +from src.utils.utilities import * + + +def test(conf: omegaconf.DictConfig) -> None: + if conf.debug: + print("Running in DEBUG mode.", file=sys.stderr) + + # reproducibility + pl.seed_everything(conf.train.seed) + + # data module declaration + pl_data_module = BasePLDataModule(conf) + + # main module declaration + output_classes = len(pl_data_module.sense_vocabulary) + checkpoint_path = get_checkpoint_path(conf) + pl_baseline = MFS(n_classes=output_classes, mfs_means=pl_data_module.mfs_lexeme_means) + pl_module = BasePLModule.load_from_checkpoint(checkpoint_path, conf=conf, n_classes=output_classes) + + # trainer + trainer: Trainer = hydra.utils.instantiate(conf.train.pl_trainer, gpus=gpus(conf)) + + # module test + trainer.test(pl_baseline, datamodule=pl_data_module) + trainer.test(pl_module, datamodule=pl_data_module) + + +@hydra.main(config_path="../conf", config_name="root") +def main(conf: omegaconf.DictConfig): + test(conf) + + +if __name__ == "__main__": + main() diff --git a/src/train.py b/src/train.py index 0fc72ba..ef4d132 100644 --- a/src/train.py +++ b/src/train.py @@ -1,55 +1,69 @@ -import omegaconf -import hydra +import sys -import pytorch_lightning as pl +from transformers import logging from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger +import pytorch_lightning as pl +import omegaconf +import hydra from src.pl_data_modules import BasePLDataModule from src.pl_modules import BasePLModule +from src.utils.utilities import * def train(conf: omegaconf.DictConfig) -> None: + if conf.debug: + print("Running in DEBUG mode.", file=sys.stderr) # reproducibility pl.seed_everything(conf.train.seed) # data module declaration pl_data_module = BasePLDataModule(conf) - pl_data_module.setup() - return # crash - dl = next(iter(pl_data_module.train_dataloader())) - return # main module declaration - pl_module = BasePLModule(conf) + output_classes = len(pl_data_module.sense_vocabulary) + pl_module = BasePLModule(conf, n_classes=output_classes) # callbacks declaration callbacks_store = [] if conf.train.early_stopping_callback is not None: - early_stopping_callback: EarlyStopping = hydra.utils.instantiate( - conf.train.early_stopping_callback - ) + early_stopping_callback: EarlyStopping = hydra.utils.instantiate(conf.train.early_stopping_callback) callbacks_store.append(early_stopping_callback) if conf.train.model_checkpoint_callback is not None: - model_checkpoint_callback: ModelCheckpoint = hydra.utils.instantiate( - conf.train.early_stopping_callback - ) + model_checkpoint_callback: ModelCheckpoint = hydra.utils.instantiate(conf.train.model_checkpoint_callback) callbacks_store.append(model_checkpoint_callback) + # logger + logger: WandbLogger = None + if conf.logging.log and not conf.debug: + wandb_login() + logger: WandbLogger = hydra.utils.instantiate(conf.logging.wandb_logger) + hydra.utils.log.info(f"W&B is now watching <{conf.logging.watch.log}>!") + logger.watch(pl_module, 
log=conf.logging.watch.log, log_freq=conf.logging.watch.log_freq) + # trainer trainer: Trainer = hydra.utils.instantiate( - conf.train.pl_trainer, callbacks=callbacks_store + conf.train.pl_trainer, callbacks=callbacks_store, gpus=gpus(conf), logger=logger ) # module fit trainer.fit(pl_module, datamodule=pl_data_module) + # store best model path + if conf.train.model_checkpoint_callback is not None: + update_latest_checkpoint_path(model_path=model_checkpoint_callback.best_model_path) + # module test trainer.test(pl_module, datamodule=pl_data_module) + if logger is not None: + logger.experiment.finish() + @hydra.main(config_path="../conf", config_name="root") def main(conf: omegaconf.DictConfig): diff --git a/src/utils/torch_utilities.py b/src/utils/torch_utilities.py new file mode 100644 index 0000000..b053e6b --- /dev/null +++ b/src/utils/torch_utilities.py @@ -0,0 +1,238 @@ +import math + +from torch.optim.optimizer import Optimizer +from typing import Optional +import torch + +# ==================================== +# Implementation of RAdam in torch before major release +# Source: https://github.com/SapienzaNLP/consec/blob/main/src/utils/optimizers.py + + +class RAdam(Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0, + degenerated_to_sgd=True, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + buffer=[[None, None, None] for _ in range(10)], + ) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma 
>= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + p_data_fp32.add_(-step_size * group["lr"], exp_avg) + p.data.copy_(p_data_fp32) + + return loss + + +# ==================================== +# Minimal version of ``scatter_mean`` +# Source: https://github.com/rusty1s/pytorch_scatter/ + + +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_add( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + return scatter_sum(src, index, dim, out, dim_size) + + +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0): + r""" + | + + .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ + master/docs/source/_figures/mean.svg?sanitize=true + :align: center + :width: 400px + + | + + Averages all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`.If multiple indices reference the same location, their + **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). + + For one-dimensional tensors, the operation computes + + .. math:: + \mathrm{out}_i = \mathrm{out}_i + \frac{1}{N_i} \cdot + \sum_j \mathrm{src}_j + + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. :math:`N_i` indicates the number of indices + referencing :math:`i`. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements to scatter. + dim (int, optional): The axis along which to index. + (default: :obj:`-1`) + out (Tensor, optional): The destination tensor. (default: :obj:`None`) + dim_size (int, optional): If :attr:`out` is not given, automatically + create output with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor is + returned. (default: :obj:`None`) + fill_value (int, optional): If :attr:`out` is not given, automatically + fill output tensor with :attr:`fill_value`. 
(default: :obj:`0`) + + :rtype: :class:`Tensor` + + .. testsetup:: + + import torch + + .. testcode:: + + from torch_scatter import scatter_mean + + src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) + index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) + out = src.new_zeros((2, 6)) + + out = scatter_mean(src, index, out=out) + + print(out) + + .. testoutput:: + + tensor([[0.0000, 0.0000, 4.0000, 3.0000, 1.5000, 0.0000], + [1.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000]]) + """ + out = scatter_add(src=src, index=index, dim=dim, out=out, dim_size=dim_size) + count = scatter_add(src=torch.ones_like(src), index=index, dim=dim, out=None, dim_size=out.size(dim)) + return out / count.clamp(min=1) diff --git a/src/utils/utilities.py b/src/utils/utilities.py new file mode 100644 index 0000000..3af1c18 --- /dev/null +++ b/src/utils/utilities.py @@ -0,0 +1,196 @@ +from collections import Counter, defaultdict +from operator import itemgetter +from dotenv import load_dotenv +from typing import * +import wandb +import os + +from omegaconf import DictConfig, OmegaConf, open_dict +from transformers.tokenization_utils_base import BatchEncoding +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer +from torch.optim import Optimizer +from torchtext.vocab import Vocab +import torch.nn.functional as F +import torch +import hydra +import yaml +import json + + +def read_json_hydra(path: str) -> any: + base_path = hydra.utils.to_absolute_path(".") + path = os.path.join(base_path, path) + with open(path, "r") as f_in: + result = json.load(f_in) + return result + + +def manual_training_step( + model: torch.nn.Module, + batch: BatchEncoding, + labels: torch.Tensor, + optimizer: Optimizer = torch.optim.Adam, + loss_fn: F.cross_entropy = F.cross_entropy, +) -> None: + """Performs a training step of the input `model` using a custom training configuration.""" + optim = optimizer(model.parameters()) + optim.zero_grad() + logits = model(batch) + loss = loss_fn(logits, labels) + loss.backward() + optim.step() + return {"logits": logits, "predictions": logits.argmax(-1)} + + +def get_batch_wordpiece_indices( + batch: Union[list[list[str]], list[str]], + tokenizer: AutoTokenizer, + padding: int = 0, + return_tensors: Optional[str] = None, +) -> list[Union[torch.tensor, list[int]]]: + transform_fn = lambda x: torch.tensor(x) if return_tensors == "pt" else x + return [transform_fn(get_wordpiece_indices(sentence, tokenizer)) for sentence in batch] + + +def get_wordpiece_indices( + sentence: list[str], + tokenizer: AutoTokenizer, +) -> Union[torch.tensor, list[int]]: + indices = [] + for idx_word, word in enumerate(sentence): + word_tokenized = tokenizer.tokenize(word) + for _ in range(len(word_tokenized)): + indices.append(idx_word) + return indices + + +def batch_tokenizer( + batch: Union[list[list[str]], list[str]], + tokenizer: AutoTokenizer, + padding: int = 0, +) -> BatchEncoding: + assert len(batch) > 0 + if isinstance(batch[0], str): + # The input batch should be already tokenized into words + batch = [sentence.split(" ") for sentence in batch] + lengths = torch.tensor([len(sentence) for sentence in batch]) + word_pieces_indexes = get_batch_wordpiece_indices(batch, tokenizer, padding=padding, return_tensors="pt") + batch = tokenizer(batch, return_tensors="pt", is_split_into_words=True, padding=True, add_special_tokens=False) + batch["lengths"] = lengths + batch["word_pieces_indexes"] = pad_sequence(word_pieces_indexes, batch_first=True, padding_value=padding) + return batch + + 
+def sentence_tokenizer(sentence: list[str], tokenizer: AutoTokenizer) -> tuple[dict[str, torch.Tensor], list]:
+    """Returns the input sentence after applying tokenization."""
+    sentence_tokenized = tokenizer(" ".join(sentence), return_tensors="pt", add_special_tokens=False)
+    indexes: list[int] = get_wordpiece_indices(sentence, tokenizer=tokenizer)
+    return {k: v[0] for k, v in sentence_tokenized.items()}, indexes
+
+
+def wandb_login() -> None:
+    """Weights and Biases login using the ``WANDB_KEY`` environment variable."""
+    load_dotenv()
+    wandb.login(key=os.getenv("WANDB_KEY"))
+
+
+def get_checkpoint_path(conf: DictConfig) -> str:
+    """Returns the model checkpoint path from the config file."""
+    checkpoint_path = conf.test.checkpoint_path
+    if conf.test.use_latest and "latest_checkpoint_path" in conf.test:
+        checkpoint_path = conf.test.latest_checkpoint_path
+    return hydra.utils.to_absolute_path(checkpoint_path)
+
+
+def update_latest_checkpoint_path(model_path: str, config_path: str = "conf/test/default_test.yaml") -> None:
+    """Updates the field ``latest_checkpoint_path`` in the test config when training is complete."""
+    base_path = hydra.utils.to_absolute_path(".")
+    config_path = os.path.join(base_path, config_path)
+    with open(config_path, "r") as stream:
+        try:
+            yaml_dict = yaml.safe_load(stream)
+            yaml_dict["latest_checkpoint_path"] = os.path.relpath(model_path, start=base_path)
+        except Exception as e:
+            print(e)
+            return
+    with open(config_path, "w") as stream:
+        yaml.dump(yaml_dict, stream)
+
+
+def gpus(conf: DictConfig) -> int:
+    """Utility to determine the number of GPUs to use."""
+    return conf.train.pl_trainer.gpus if torch.cuda.is_available() else 0
+
+
+def add_configuration_field(conf: DictConfig, field: str, value: Any) -> None:
+    """
+    Adds a new field to a struct config.
+    Docs: https://omegaconf.readthedocs.io/en/2.0_branch/usage.html#struct-flag
+    """
+    OmegaConf.set_struct(conf, True)
+    with open_dict(conf):
+        conf[field] = value
+
+
+def vocab_lookup_indices(vocabulary: Union[Vocab, defaultdict], tokens: List[str]) -> List[int]:
+    """Replacement for Vocab's ``lookup_indices`` method, which was introduced only in the latest version of torchtext."""
+    return [vocabulary[token] for token in tokens]
+
+
+def collate_fn(
+    batch: List[Tuple[Dict[str, Any]]],
+    batch_keys: Tuple[str, ...],
+    lemma_means: Optional[defaultdict] = None,
+    lexeme_means: Optional[defaultdict] = None,
+    output_dim: Optional[int] = None,
+    padding_value: int = 0,
+) -> Dict[str, torch.Tensor]:
+    """
+    A simple collate function used to provide batched input to our models. 
+ :param batch: a zipped list of tuples containing dataset values + :param batch_keys: a tuple of keys describing the values in the batch + :param padding_value: the value to use to pad input sequences + :return a batch of preprocessed input sentences + """ + # Unroll the list of tuples into a more useful dictionary with batched features + batch = dict(zip(batch_keys, zip(*[itemgetter(*batch_keys)(x) for x in batch]))) + word_pieces_indexes: List[torch.Tensor] = [torch.tensor(elem) for elem in batch["word_pieces_indexes"]] + # Create the output labels with sense identifiers at position senses_indices + senses_ = torch.zeros(len(batch["senses_indices"]), max(batch["lengths"]), dtype=torch.long) + for batch_idx, (_indices, _senses) in enumerate(zip(batch["senses_indices"], batch["senses"])): + senses_[batch_idx, _indices] = torch.tensor(_senses, dtype=torch.long) + + assert not (lemma_means and lexeme_means), "Specify either one among use_lemma_mask OR use_lexeme_mask" + + if lemma_means or lexeme_means: + assert output_dim, "Output dimension is required to create the ``sense_mask``" + if lemma_means: + batch["sense_mask"] = create_senses_mask(batch, "lemmas", lemma_means, senses_.shape, output_dim) + else: + batch["sense_mask"] = create_senses_mask(batch, "lexemes", lexeme_means, senses_.shape, output_dim) + + batch["input_ids"] = pad_sequence(batch["input_ids"], batch_first=True, padding_value=padding_value) + batch["attention_mask"] = pad_sequence(batch["attention_mask"], batch_first=True, padding_value=padding_value) + batch["word_pieces_indexes"] = pad_sequence(word_pieces_indexes, batch_first=True, padding_value=padding_value) + batch["senses"] = senses_ + batch["lengths"] = torch.tensor(batch["lengths"]) + + return batch + + +def create_senses_mask( + batch: list, + field: str, + means: defaultdict, + senses_dim: tuple[int, int], + output_dim: int, +) -> torch.Tensor: + """Computes the sense mask from the provided ``means`` to limit predictions to candidate senses only.""" + batch_sense_mask_idxs = [vocab_lookup_indices(means, tokens) for tokens in batch[field]] + batch_sense_mask = torch.ones((*senses_dim, output_dim), dtype=torch.bool) + for batch_idx, (sense_idxs, sense_mask_idxs) in enumerate(zip(batch["senses_indices"], batch_sense_mask_idxs)): + batch_sense_mask[batch_idx, sense_idxs] = False + for sense_idx, _sense_mask_idxs in zip(sense_idxs, sense_mask_idxs): + batch_sense_mask[batch_idx, sense_idx, _sense_mask_idxs] = True + return batch_sense_mask diff --git a/tests/unit/datamodule_test.py b/tests/unit/datamodule_test.py new file mode 100644 index 0000000..f187800 --- /dev/null +++ b/tests/unit/datamodule_test.py @@ -0,0 +1,35 @@ +from src.pl_data_modules import BasePLDataModule +from tests.unit.test_case import TestCase + + +class DataModuleTest(TestCase): + def test_initialization(self): + pl_data_module = BasePLDataModule(self.conf) + self.assertIsNotNone(pl_data_module) + + def test_batches_num(self): + pl_data_module = BasePLDataModule(self.conf) + batch_size = self.conf.data.batch_size + + self.assertEqual( + (len(pl_data_module.train_dataset) + batch_size - 1) // batch_size, + len( + pl_data_module.train_dataloader(), + ), + ) + self.assertEqual( + (len(pl_data_module.valid_dataset) + batch_size - 1) // batch_size, + len( + pl_data_module.val_dataloader(), + ), + ) + self.assertEqual( + (len(pl_data_module.test_dataset) + batch_size - 1) // batch_size, + len( + pl_data_module.test_dataloader(), + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/tests/unit/dataset_test.py b/tests/unit/dataset_test.py new file mode 100644 index 0000000..15cce0a --- /dev/null +++ b/tests/unit/dataset_test.py @@ -0,0 +1,36 @@ +from transformers import AutoTokenizer + +from src.dataset import SenseAnnotatedDataset +from tests.unit.test_case import TestCase + + +class DatasetTest(TestCase): + def test_initialization(self): + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + dataset = SenseAnnotatedDataset(self.conf, tokenizer=tokenizer) + self.assertIsNotNone(dataset) + + def test_loading_from_cache(self): + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + dataset = SenseAnnotatedDataset.from_cached(self.conf, tokenizer=tokenizer) + self.assertIsNotNone(dataset) + + def test_cache_equivalence(self): + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + dataset = SenseAnnotatedDataset(self.conf, tokenizer=tokenizer) + cached_dataset = SenseAnnotatedDataset.from_cached(self.conf, tokenizer=tokenizer) + self.assertGreaterEqual(len(dataset.data), 0) + self.assertEqual(len(dataset), len(cached_dataset)) + + def test_bert_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + dataset = SenseAnnotatedDataset.from_cached(self.conf, tokenizer=tokenizer) + for sample, preprocessed_sample in zip(dataset.data, dataset.preprocessed_data): + tokens = tokenizer(" ".join(sample["sentence"]), return_tensors="pt") + self.assertIn("input_ids", tokens) + self.assertIn("token_type_ids", tokens) + self.assertIn("attention_mask", tokens) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py new file mode 100644 index 0000000..4d4d474 --- /dev/null +++ b/tests/unit/model_test.py @@ -0,0 +1,95 @@ +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer +from torch import nn +import torch + +from src.utils.utilities import get_wordpiece_indices, batch_tokenizer, manual_training_step +from tests.unit.test_case import TestCase +from src.models.wsd_model import WSDModel + + +class ModelTest(TestCase): + mock_samples = [ + ( + "BERT was created in 2018 by Jacob Devlin and his colleagues from Google.", + torch.randint(0, 10, (10,)), # dummy labels using 10 classes only + ), + ( + "Google Search consists of a series of localized websites.", + torch.randint(0, 10, (10,)), # dummy labels using 10 classes only + ), + ( + "Natural-language understanding is considered an AI-hard problem.", + torch.randint(0, 10, (10,)), # dummy labels using 10 classes only + ), + ] + + def test_parameters_change(self): + n_classes = 10 + sentences, labels = zip(*self.mock_samples) + labels = pad_sequence(labels, batch_first=True, padding_value=0) + + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + batch = batch_tokenizer(batch=sentences, tokenizer=tokenizer) + + # Define the model and store initial parameters + model = WSDModel(self.conf, n_classes=n_classes) + params = [param for param in model.parameters() if param.requires_grad] + initial_params = [param.clone() for param in params] + + # Set the model in `training` mode and update the weights + manual_training_step(model=model, batch=batch, labels=labels) + + # Check if the weights are actually updated + for p0, p1 in zip(initial_params, params): + # using the more stable torch builtin function to check tensor equality + assert not torch.equal(p0, p1) + + def test_output_range(self): + n_classes = 10 + sentences, labels = 
zip(*self.mock_samples) + labels = pad_sequence(labels, batch_first=True, padding_value=0) + + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + batch = batch_tokenizer(batch=sentences, tokenizer=tokenizer) + + # Define the model and store initial parameters + model = WSDModel(self.conf, n_classes=n_classes) + batch_out = manual_training_step(model=model, batch=batch, labels=labels) + + self.assertGreaterEqual(torch.min(batch_out["predictions"]), torch.tensor(0)) + self.assertLess(torch.max(batch_out["predictions"]), torch.tensor(n_classes)) + + def test_nan_output(self): + n_classes = 10 + sentences, labels = zip(*self.mock_samples) + labels = pad_sequence(labels, batch_first=True, padding_value=0) + + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + batch = batch_tokenizer(batch=sentences, tokenizer=tokenizer) + + # Define the model and store initial parameters + model = WSDModel(self.conf, n_classes=n_classes) + batch_out = manual_training_step(model=model, batch=batch, labels=labels) + + self.assertFalse(batch_out["logits"].isnan().any()) + self.assertFalse(batch_out["predictions"].isnan().any()) + + def test_inf_output(self): + n_classes = 10 + sentences, labels = zip(*self.mock_samples) + labels = pad_sequence(labels, batch_first=True, padding_value=0) + + tokenizer = AutoTokenizer.from_pretrained(self.conf.model.tokenizer) + batch = batch_tokenizer(batch=sentences, tokenizer=tokenizer) + + # Define the model and store initial parameters + model = WSDModel(self.conf, n_classes=n_classes) + batch_out = manual_training_step(model=model, batch=batch, labels=labels) + + self.assertTrue(batch_out["logits"].isfinite().all()) + self.assertTrue(batch_out["predictions"].isfinite().all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/raganato_reader_test.py b/tests/unit/raganato_reader_test.py new file mode 100644 index 0000000..490c3e3 --- /dev/null +++ b/tests/unit/raganato_reader_test.py @@ -0,0 +1,55 @@ +from omegaconf import OmegaConf + +from src.readers.raganato_reader import SemCorReader, OMSTIReader, SemEvalReader +from tests.unit.test_case import TestCase + + +class RaganatoReaderTest(TestCase): + def test_omsti(self): + # samples = OMSTIReader.read(conf=self.conf, **self.conf.data.corpora["omsti"]) + # self.assertEqual(len(samples), 37176) + pass + + def test_semcor_omsti(self): + # samples = OMSTIReader.read(conf=self.conf, **self.conf.data.corpora["semcor+omsti"]) + # self.assertEqual(len(samples), 37176) + pass + + def test_semcor(self): + samples = SemCorReader.read(conf=self.conf, **self.conf.data.corpora["semcor"]) + self.assertEqual(len(samples), 37176) + self.assertEqual(type(samples), list) + + def test_semeval_all(self): + samples = SemEvalReader.read(conf=self.conf, **self.conf.data.corpora["semeval_all"], filter_ds="semeval_all") + self.assertEqual(len(samples), 1173) + self.assertEqual(type(samples), list) + + def test_semeval2007(self): + samples = SemEvalReader.read(conf=self.conf, **self.conf.data.corpora["semeval_all"], filter_ds="semeval2007") + self.assertEqual(len(samples), 135) + self.assertEqual(type(samples), list) + + def test_semeval2013(self): + samples = SemEvalReader.read(conf=self.conf, **self.conf.data.corpora["semeval_all"], filter_ds="semeval2013") + self.assertEqual(len(samples), 306) + self.assertEqual(type(samples), list) + + def test_semeval2015(self): + samples = SemEvalReader.read(conf=self.conf, **self.conf.data.corpora["semeval_all"], filter_ds="semeval2015") + 
self.assertEqual(len(samples), 138) + self.assertEqual(type(samples), list) + + def test_senseval2(self): + samples = SemEvalReader.read(conf=self.conf, **self.conf.data.corpora["semeval_all"], filter_ds="senseval2") + self.assertEqual(len(samples), 242) + self.assertEqual(type(samples), list) + + def test_senseval3(self): + samples = SemEvalReader.read(conf=self.conf, **self.conf.data.corpora["semeval_all"], filter_ds="senseval3") + self.assertEqual(len(samples), 352) + self.assertEqual(type(samples), list) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/root.yaml b/tests/unit/root.yaml new file mode 100644 index 0000000..b411d89 --- /dev/null +++ b/tests/unit/root.yaml @@ -0,0 +1,123 @@ +# Debug mode +debug: False +max_samples: 1000 + +model: + tokenizer: 'bert-base-cased' # 'bert-base-cased' + model_name: 'bert-base-cased' + learning_rate: 1e-3 # 5e-4 + min_learning_rate: 1e-4 + language_model_learning_rate: 1e-5 + language_model_min_learning_rate: 1e-6 + language_model_weight_decay: 1e-4 + use_lemma_mask: False + use_lexeme_mask: False + + word_encoder: + _target_: src.layers.word_encoder.WordEncoder + fine_tune: False + word_dropout: 0.2 + model_name: ${model.model_name} + + sequence_encoder: lstm + lstm_encoder: + _target_: torch.nn.LSTM + input_size: 512 + hidden_size: 256 + bidirectional: True + batch_first: True + num_layers: 2 + dropout: 0.40 + +test: + checkpoint_path: + latest_checkpoint_path: experiments/bert-base-cased/2021-11-13/16-39-09/default_name/epoch=0-step=580.ckpt + use_latest: false + +train: + # reproducibility + seed: 42 + + # experiment name + experiment_name: default_name + + # pl_trainer + pl_trainer: + _target_: pytorch_lightning.Trainer + gpus: 1 + #accumulate_grad_batches: 1 # 8 + #gradient_clip_val: 10.0 + #val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps + max_epochs: 20 + fast_dev_run: False + #max_steps: 100_000 + # uncomment the lines below for training with mixed precision + #precision: 16 + #amp_level: O2 + + # early stopping callback + # "early_stopping_callback: null" will disable early stopping + early_stopping_callback: + _target_: pytorch_lightning.callbacks.EarlyStopping + monitor: val_f1_micro + mode: max + patience: 50 + + # model_checkpoint_callback + # "model_checkpoint_callback: null" will disable model checkpointing + model_checkpoint_callback: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: val_f1_micro + mode: max + verbose: True + save_top_k: 5 + dirpath: ${train.experiment_name}/ + +data: + train_path: "data/train.tsv" + validation_path: "data/validation.tsv" + test_path: "data/test.tsv" + + train_ds: "semcor" + val_ds: "semeval2007" + test_ds: "semeval2015" + + preprocessed_dir: "data/preprocessed/" + force_preprocessing: False + dump_preprocessed: True + + corpora: + semcor: + data_path: "data/WSD_Training_Corpora/SemCor/semcor.data.xml" + key_path: "data/WSD_Training_Corpora/SemCor/semcor.gold.key.txt" + semcor+omsti: + data_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.data.xml" + key_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.gold.key.txt" + omsti: + data_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.data.xml" + key_path: "data/WSD_Training_Corpora/SemCor+OMSTI/semcor+omsti.gold.key.txt" + semeval_all: + data_path: "data/WSD_Unified_Evaluation_Datasets/ALL/ALL.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/ALL/ALL.gold.key.txt" + semeval2007: + data_path: 
"data/WSD_Unified_Evaluation_Datasets/semeval2007/semeval2007.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt" + semeval2013: + data_path: "data/WSD_Unified_Evaluation_Datasets/semeval2013/semeval2013.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/semeval2013/semeval2013.gold.key.txt" + semeval2015: + data_path: "data/WSD_Unified_Evaluation_Datasets/semeval2015/semeval2015.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/semeval2015/semeval2015.gold.key.txt" + senseval2: + data_path: "data/WSD_Unified_Evaluation_Datasets/senseval2/senseval2.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/senseval2/senseval2.gold.key.txt" + senseval3: + data_path: "data/WSD_Unified_Evaluation_Datasets/senseval3/senseval3.data.xml" + key_path: "data/WSD_Unified_Evaluation_Datasets/senseval3/senseval3.gold.key.txt" + + batch_size: 64 + num_workers: 0 + + min_freq_senses: 1 + allow_multiple_senses: False + diff --git a/tests/unit/test_case.py b/tests/unit/test_case.py new file mode 100644 index 0000000..d51bffc --- /dev/null +++ b/tests/unit/test_case.py @@ -0,0 +1,9 @@ +import unittest + +from omegaconf import OmegaConf, DictConfig + + +class TestCase(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conf: DictConfig = OmegaConf.load("tests/unit/root.yaml")