In [14]:
import torch

import datagen
from datagen.dataset import SemCorDataSet
import datasets # prevent circular module import exception

from transformers import AutoTokenizer, AutoConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from modelling.collator import BetterDataCollatorForWholeWordMask
from modelling.model import SynsetClassificationModel
from modelling.trainer import BetterTrainer
from lit_nlp.api import types as lit_types
from lit_nlp.api.dataset import Dataset
from lit_nlp.api.model import Model
from pathlib import Path
import numpy as np
import pandas as pd

In [15]:
def load_dataset(path: str):
    ds = datagen.SemCorDataSet.unpickle(Path(path).with_suffix(".pickle"))
    hf_ds = datasets.Dataset.load_from_disk(Path(path).with_suffix(".hf"))
    
    hf_ds = hf_ds.add_column("sense-labels", hf_ds["labels"])
    relevant_columns = [
        column
        for column in hf_ds.column_names
        if column not in ds.sentence_level.columns
    ]
    relevant_columns.extend(["sense-labels"])
    hf_ds.set_format(type="torch", columns=relevant_columns)
    
    return ds, hf_ds

In [16]:
def load_model(path: str):
    if torch.cuda.is_available():
        device = "cuda:0"
        print(f"CUDA found; running on {device}")
    else:
        device = "cpu"
        print(f"CUDA not found; running on {device}")
    
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    
    model = SynsetClassificationModel.from_pretrained(
        path,
        config=AutoConfig.from_pretrained(path, local_files_only=True),
        local_files_only=True,
        model_name=path,
        num_classes=2584,
    ).to(device)
    
    trainer = BetterTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=DataCollatorForLanguageModeling(tokenizer),
        args=TrainingArguments(
            output_dir='./saliency',
            remove_unused_columns=False,
            label_names=["labels", "sense-labels"]
        )
    )
    
    return trainer

In [17]:
ds = SemCorDataSet.unpickle("dataset/roberta+senseval2.pickle")
VOCAB = ds.all_sense_keys['sense-key1'].tolist()

Refer to https://pair-code.github.io/lit/setup/ for implementation details

In [27]:
class LIT_Dataset(Dataset):
    def __init__(self, path: str, tokenizer):
        ds, hf_ds = load_dataset(path)    
        self._examples = [
            {
                "sentence": sentence,
                'input_ids': row["input_ids"].numpy(),
                'attention_mask': row['attention_mask'].numpy(),
                'labels': row['labels'].numpy(),
                "sense-labels": row["sense-labels"].numpy(),
            }
            for row, sentence in zip(hf_ds, ds.sentence_level["sentence"])
        ]
        
        #print(sentences[0], tok_sentences[0], sense_idxs[0])
        #print(self._examples)

    def spec(self):
        return {
            "sentence": lit_types.TextSegment(),
            "input_ids": lit_types.Embeddings(),
            'attention_mask': lit_types.Embeddings(),
            'labels': lit_types.CategoryLabel(),
            "sense-labels": lit_types.CategoryLabel(),
        }


class LIT_Model(Model):
    def __init__(self, path: str, ds: SemCorDataSet):
        self.trainer = load_model(path)
        self.ds = ds

    def input_spec(self):
        return {
            #
            # "input_ids": lit_types.TextSegment(),
            "sense-labels":  lit_types.CategoryLabel(),
            'attention_mask': lit_types.Embeddings(),
            'labels': lit_types.CategoryLabel()
        }

    def output_spec(self):
        return {
            "probs": lit_types.MulticlassPreds(vocab=VOCAB, parent='sense-labels', null_idx=-100),
            # "tokens": lit_types.Tokens(parent="sentence")
        }

    def predict_minibatch(self, inputs):
        outputs = list()
        for i in inputs:
            copyi = i.copy()
            _ = copyi.pop("sentence")
            
            tensor_i = {k: torch.tensor(v).unsqueeze(dim=1) for k, v in copyi.items()}
            #print(tensor_i)
            
            loss, logits, labels = self.trainer.prediction_step(
                self.trainer.model, 
                tensor_i, 
                prediction_loss_only=False
            )
            
            probas = torch.nn.functional.softmax(logits, dim=-1).squeeze().cpu().numpy()
            
            outputs.append({'probs': probas,
                            'tokens': i["input_ids"]})
        #print(outputs)
        return outputs
#             #print(pred)
            
#             #print(pred)
#             #print(logits)
#             #print(label_ids)

#             sense_labels = tensor_i["sense-labels"][:]
#             sense_labels[tensor_i["labels"] == -100] = -100

#             # Get IDs
#             masks_idx = sense_labels != -100
#             predictions = np.argmax(logits.cpu().numpy(), axis=-1)[masks_idx.flatten()]
#             reference = sense_labels[masks_idx]
            
#             sense_keys = list()
            
#             sks = self.ds.all_sense_keys
#             #print(sks.shape)
            
#             if len(reference):
#                 for ref in reference.numpy():
#                     # print(ref)
#                     r = sks[sks["sense-key-idx"] == ref]
#                     if not r["sense-key1"].shape[0]:
#                         sense_keys.append(None)
#                     else:
#                         ref_sk = r["sense-key1"].iloc[0]
#                         sense_keys.append(ref_sk)
#             else:
#                 sense_keys.append(None)
            
#             #print(sense_keys)
#             if any(s is not None for s in sense_keys):
#                 outputs.append({"prediction": next(filter(None, sense_keys))})
#             else:
#                 outputs.append({"prediction": "unknown"})
            
#         # print(outputs)
#         return outputs


In [28]:
# Create the LIT widget with the model and dataset to analyze.
from lit_nlp import notebook
tr = SemCorDataSet.unpickle("dataset/roberta+senseval2.pickle")
lm = LIT_Model('./out/checkpoints/roberta-probing+semcor/checkpoint-185900', tr)
lit_models = {'wsd': lm}
#datasets = {'sst_dev': glue.SST2Data('validation')}


Didn't find file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/added_tokens.json. We won't load it.
loading file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/vocab.json
loading file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/merges.txt
loading file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/tokenizer.json
loading file None
loading file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/special_tokens_map.json
loading file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/tokenizer_config.json
loading configuration file ./out/checkpoints/roberta-probing+semcor/checkpoint-185900/config.json
Model config RobertaConfig {
  "_name_or_path": "./out/checkpoints/roberta-probing+semcor/checkpoint-185900",
  "architectures": [
    "SynsetClassificationModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 

CUDA found; running on cuda:0


Some weights of the model checkpoint at ./out/checkpoints/roberta-probing+semcor/checkpoint-185900 were not used when initializing RobertaModel: ['mlmodel.encoder.layer.11.attention.output.dense.bias', 'mlmodel.encoder.layer.2.attention.self.value.bias', 'mlmodel.embeddings.word_embeddings.weight', 'mlmodel.encoder.layer.0.attention.output.dense.bias', 'mlmodel.encoder.layer.6.attention.self.value.bias', 'mlmodel.pooler.dense.weight', 'mlmodel.encoder.layer.10.intermediate.dense.weight', 'mlmodel.encoder.layer.7.intermediate.dense.weight', 'mlmodel.encoder.layer.3.attention.self.value.weight', 'mlmodel.encoder.layer.0.output.dense.bias', 'mlmodel.encoder.layer.0.attention.output.dense.weight', 'classifier.1.weight', 'mlmodel.encoder.layer.10.output.dense.weight', 'mlmodel.encoder.layer.3.attention.self.key.bias', 'mlmodel.encoder.layer.7.output.dense.bias', 'mlmodel.encoder.layer.2.attention.output.LayerNorm.bias', 'mlmodel.pooler.dense.bias', 'mlmodel.encoder.layer.8.attention.output.

Some weights of RobertaModel were not initialized from the model checkpoint at ./out/checkpoints/roberta-probing+semcor/checkpoint-185900 and are newly initialized: ['encoder.layer.1.attention.self.value.bias', 'encoder.layer.6.output.dense.bias', 'embeddings.token_type_embeddings.weight', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.10.attention.self.key.weight', 'embeddings.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'embeddings.word_embeddings.weight

All model checkpoint weights were used when initializing SynsetClassificationModel.

All the weights of SynsetClassificationModel were initialized from the model checkpoint at ./out/checkpoints/roberta-probing+semcor/checkpoint-185900.
If your task is similar to the task the model of the checkpoint was trained on, you can already use SynsetClassificationModel for predictions without further training.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [29]:
lit_datasets = {'wsd': LIT_Dataset('dataset/roberta+senseval2.pickle', lm.trainer.tokenizer)}

In [30]:
# Render the widget
widget = notebook.LitWidget(lit_models, lit_datasets, height=1600, width=900)
widget.render()

127.0.0.1 - - [29/Mar/2022 22:49:40] "GET /? HTTP/1.1" 200 1406
127.0.0.1 - - [29/Mar/2022 22:49:40] "GET /main.js HTTP/1.1" 200 1809942
127.0.0.1 - - [29/Mar/2022 22:49:40] "POST /get_info? HTTP/1.1" 200 42940
127.0.0.1 - - [29/Mar/2022 22:49:40] "GET /static/favicon.png HTTP/1.1" 200 13257
127.0.0.1 - - [29/Mar/2022 22:49:40] "POST /get_dataset?dataset_name=wsd HTTP/1.1" 200 433457
127.0.0.1 - - [29/Mar/2022 22:50:10] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 1194695655
  pred_spec.vocab.index(label) if label in pred_spec.vocab else -1
ERROR:absl:Uncaught error: operands could not be broadcast together with shapes (82,2584) (1182,)  

 Traceback (most recent call last):
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 191, in __call__
    return self._ServeCustomHandler(request, clean_path, environ)(
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/be

127.0.0.1 - - [29/Mar/2022 22:50:41] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 500 2357
127.0.0.1 - - [29/Mar/2022 22:50:41] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936852
127.0.0.1 - - [29/Mar/2022 22:50:41] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936852
ERROR:absl:Uncaught error: operands could not be broadcast together with shapes (82,2584) (1182,)  

 Traceback (most recent call last):
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 191, in __call__
    return self._ServeCustomHandler(request, clean_path, environ)(
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 176, in _ServeCustomHandler
    return self._handlers[clean_path](self, request, environ)
  File "/home/tkriege

127.0.0.1 - - [29/Mar/2022 22:50:42] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 500 2357
127.0.0.1 - - [29/Mar/2022 22:50:42] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936694
127.0.0.1 - - [29/Mar/2022 22:50:43] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936694
ERROR:absl:Uncaught error: operands could not be broadcast together with shapes (82,2584) (1182,)  

 Traceback (most recent call last):
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 191, in __call__
    return self._ServeCustomHandler(request, clean_path, environ)(
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 176, in _ServeCustomHandler
    return self._handlers[clean_path](self, request, environ)
  File "/home/tkriege

127.0.0.1 - - [29/Mar/2022 22:50:43] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 500 2357
127.0.0.1 - - [29/Mar/2022 22:50:43] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936707
127.0.0.1 - - [29/Mar/2022 22:50:43] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936707
ERROR:absl:Uncaught error: operands could not be broadcast together with shapes (82,2584) (1182,)  

 Traceback (most recent call last):
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 191, in __call__
    return self._ServeCustomHandler(request, clean_path, environ)(
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 176, in _ServeCustomHandler
    return self._handlers[clean_path](self, request, environ)
  File "/home/tkriege

127.0.0.1 - - [29/Mar/2022 22:50:54] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 500 2357
127.0.0.1 - - [29/Mar/2022 22:50:55] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936881
127.0.0.1 - - [29/Mar/2022 22:50:55] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936881
ERROR:absl:Uncaught error: operands could not be broadcast together with shapes (82,2584) (1182,)  

 Traceback (most recent call last):
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 191, in __call__
    return self._ServeCustomHandler(request, clean_path, environ)(
  File "/home/tkrieger/.cache/pypoetry/virtualenvs/bert-wsd-uula_XMy-py3.8/lib/python3.8/site-packages/lit_nlp/lib/wsgi_app.py", line 176, in _ServeCustomHandler
    return self._handlers[clean_path](self, request, environ)
  File "/home/tkriege

127.0.0.1 - - [29/Mar/2022 22:52:10] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 500 2357
127.0.0.1 - - [29/Mar/2022 22:52:10] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936573
127.0.0.1 - - [29/Mar/2022 22:52:10] "POST /get_preds?model=wsd&dataset_name=wsd&requested_types=MulticlassPreds HTTP/1.1" 200 4936573


In [None]:
#ds, _ = load_dataset('dataset/semeval2007-roberta.pickle')
#ds.token_level[ds.token_level["sense-key-idx1"] == 8]

In [None]:
tr = SemCorDataSet.unpickle("dataset/semcor4roberta.pickle")
tr.all_sense_keys["sense-key1"].unique().shape
#sens2, _ = load_dataset('dataset/senseval2-roberta.pickle')
#sens3, _ = load_dataset('dataset/senseval3-roberta.pickle')

In [None]:
semeval2007 = SemCorDataSet.unpickle('dataset/semeval2007-roberta.pickle')
print(semeval2007.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, semeval2007.all_sense_keys, on="sense-key1")

In [None]:
semeval2013 = SemCorDataSet.unpickle('dataset/semeval2013-roberta.pickle')
print(semeval2013.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, semeval2013.all_sense_keys, on="sense-key1")

In [None]:
semeval2015 = SemCorDataSet.unpickle('dataset/semeval2015-roberta.pickle')
print(semeval2015.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, semeval2015.all_sense_keys, on="sense-key1")

In [None]:
senseval2 = SemCorDataSet.unpickle('dataset/senseval2-roberta.pickle')
print(senseval2.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, senseval2.all_sense_keys, on="sense-key1")

In [None]:
senseval3 = SemCorDataSet.unpickle('dataset/senseval3-roberta.pickle')
print(senseval3.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, senseval3.all_sense_keys, on="sense-key1")