In [69]:
import torch

import datagen
from datagen.dataset import SemCorDataSet
import datasets # prevent circular module import exception

from transformers import AutoTokenizer, AutoConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from modelling.collator import BetterDataCollatorForWholeWordMask
from modelling.model import SynsetClassificationModel
from modelling.trainer import BetterTrainer
from lit_nlp.api import types as lit_types
from lit_nlp.api.dataset import Dataset
from lit_nlp.api.model import Model
from pathlib import Path
import numpy as np
import pandas as pd

In [70]:
def load_dataset(path: str):
    ds = datagen.SemCorDataSet.unpickle(Path(path).with_suffix(".pickle"))
    hf_ds = datasets.Dataset.load_from_disk(Path(path).with_suffix(".hf"))
    
    hf_ds = hf_ds.add_column("sense-labels", hf_ds["labels"])
    relevant_columns = [
        column
        for column in hf_ds.column_names
        if column not in ds.sentence_level.columns
    ]
    relevant_columns.extend(["sense-labels"])
    hf_ds.set_format(type="torch", columns=relevant_columns)
    
    return ds, hf_ds

In [71]:
def load_model(path: str):
    if torch.cuda.is_available():
        device = "cuda:0"
        print(f"CUDA found; running on {device}")
    else:
        device = "cpu"
        print(f"CUDA not found; running on {device}")
    
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    
    model = SynsetClassificationModel.from_pretrained(
        path,
        config=AutoConfig.from_pretrained(path, local_files_only=True),
        local_files_only=True,
        model_name=path,
        num_classes=2584,
    ).to(device)
    
    trainer = BetterTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=DataCollatorForLanguageModeling(tokenizer),
        args=TrainingArguments(
            output_dir='./saliency',
            remove_unused_columns=False,
            label_names=["labels", "sense-labels"]
        )
    )
    
    return trainer

Refer to https://pair-code.github.io/lit/setup/ for implementation details

In [72]:
class LIT_Dataset(Dataset):
    def __init__(self, path: str, tokenizer):
        ds, hf_ds = load_dataset(path)    
        self._examples = [
            {
                "sentence": sentence,
                'input_ids': row["input_ids"].numpy(),
                'attention_mask': row['attention_mask'].numpy(),
                'labels': row['labels'].numpy(),
                "sense-labels": row["sense-labels"].numpy(),
            }
            for row, sentence in zip(hf_ds, ds.sentence_level["sentence"])
        ]
        
        #print(sentences[0], tok_sentences[0], sense_idxs[0])
        #print(self._examples)

    def spec(self):
        return {
            "sentence": lit_types.TextSegment(),
            #"input_ids": lit_types.Embeddings(),
            #'attention_mask': lit_types.Embeddings(),
            #'labels': row['labels'].numpy(),
            #"sense-labels": row["sense-labels"].numpy(),
        }


class LIT_Model(Model):
    def __init__(self, path: str, ds: SemCorDataSet):
        self.trainer = load_model(path)
        self.ds = ds

    def input_spec(self):
        return {
            "sentence": lit_types.TextSegment(),
            #"sense-labels":  lit_types.CategoryLabel(),
            #'attention_mask': lit_types.Embeddings(),
            #'sense-labels': lit_types.Embeddings(),
            #'sentence': lit_types.TextSegment(),
        }

    def output_spec(self):
        return {
            "prediction": lit_types.CategoryLabel()
        }

    def predict(self, inputs):
        outputs = list()
        for i in inputs:
            copyi = i.copy()
            _ = copyi.pop("sentence")
            
            tensor_i = {k: torch.tensor(v).unsqueeze(dim=1) for k, v in copyi.items()}
            #print(tensor_i)
            
            loss, logits, labels = self.trainer.prediction_step(
                self.trainer.model, 
                tensor_i, 
                prediction_loss_only=False
            )
            #print(pred)
            
            #print(pred)
            #print(logits)
            #print(label_ids)

            sense_labels = tensor_i["sense-labels"][:]
            sense_labels[tensor_i["labels"] == -100] = -100

            # Get IDs
            masks_idx = sense_labels != -100
            predictions = np.argmax(logits, axis=-1)[masks_idx.flatten()]
            reference = sense_labels[masks_idx]
            
            sense_keys = list()
            
            sks = self.ds.all_sense_keys
            #print(sks.shape)
            
            if len(reference):
                for ref in reference.numpy():
                    # print(ref)
                    r = sks[sks["sense-key-idx"] == ref]
                    if not r["sense-key1"].shape[0]:
                        sense_keys.append(None)
                    else:
                        ref_sk = r["sense-key1"].iloc[0]
                        sense_keys.append(ref_sk)
            else:
                sense_keys.append(None)
            
            #print(sense_keys)
            if any(s is not None for s in sense_keys):
                outputs.append({"prediction": next(filter(None, sense_keys))})
            else:
                outputs.append({"prediction": "unknown"})
            
        print(outputs)
        return outputs
            

    def predict_minibatch(self, inputs):
        return self.trainer.predict(inputs)


In [73]:
# Create the LIT widget with the model and dataset to analyze.
from lit_nlp import notebook
tr = SemCorDataSet.unpickle("dataset/semcor4roberta.pickle")
lm = LIT_Model('./checkpoints-probing/roberta-probing+semcor/checkpoint-185900', tr)
lit_models = {'wsd': lm}
#datasets = {'sst_dev': glue.SST2Data('validation')}


Didn't find file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/added_tokens.json. We won't load it.
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/vocab.json
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/merges.txt
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/tokenizer.json
loading file None
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/special_tokens_map.json
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/tokenizer_config.json
loading configuration file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/config.json
Model config RobertaConfig {
  "_name_or_path": "./checkpoints-probing/roberta-probing+semcor/checkpoint-185900",
  "architectures": [
    "SynsetClassificationModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "

CUDA not found; running on cpu


loading configuration file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/config.json
Model config RobertaConfig {
  "_name_or_path": "./checkpoints-probing/roberta-probing+semcor/checkpoint-185900",
  "architectures": [
    "SynsetClassificationModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.17.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

loading weights file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/pytorch_model.bin
Some weights of the model checkpoint at

Some weights of RobertaModel were not initialized from the model checkpoint at ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900 and are newly initialized: ['encoder.layer.5.attention.self.query.bias', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.8.intermediate.dense.weight', 'pooler.dense.bias', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.la

All model checkpoint weights were used when initializing SynsetClassificationModel.

All the weights of SynsetClassificationModel were initialized from the model checkpoint at ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900.
If your task is similar to the task the model of the checkpoint was trained on, you can already use SynsetClassificationModel for predictions without further training.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [74]:
lit_datasets = {'wsd': LIT_Dataset('dataset/semeval2007-roberta.pickle', lm.trainer.tokenizer)}

In [75]:
# Render the widget
widget = notebook.LitWidget(lit_models, lit_datasets, height=1600, width=900)
widget.render()

[{'prediction': 'long%3:00:02::'}, {'prediction': 'benefit%1:21:00::'}, {'prediction': 'have%2:42:00::'}, {'prediction': 'employee%1:18:00::'}, {'prediction': 'become%2:42:01::'}, {'prediction': 'increased%3:00:00::'}, {'prediction': 'make%2:41:00::'}, {'prediction': 'commercial_enterprise%1:04:00::'}, {'prediction': 'measure%2:31:01::'}, {'prediction': 'reach%2:38:06::'}, {'prediction': 'method%1:09:00::'}, {'prediction': 'other%3:00:00::'}, {'prediction': 'technique%1:09:00::'}, {'prediction': 'unknown'}, {'prediction': 'government%1:14:00::'}, {'prediction': 'government%1:14:00::'}, {'prediction': 'bigger%5:00:00:large:00'}, {'prediction': 'unknown'}, {'prediction': 'try%2:41:00::'}, {'prediction': 'have%2:40:00::'}, {'prediction': 'accident%1:11:01::'}, {'prediction': 'new%3:00:00::'}, {'prediction': 'unknown'}, {'prediction': 'unknown'}, {'prediction': 'hour%1:28:00::'}, {'prediction': 'sign%2:41:00::'}, {'prediction': 'be%2:42:03::'}, {'prediction': 'unknown'}, {'prediction': 'en

In [76]:
#ds, _ = load_dataset('dataset/semeval2007-roberta.pickle')
#ds.token_level[ds.token_level["sense-key-idx1"] == 8]

In [77]:
tr = SemCorDataSet.unpickle("dataset/semcor4roberta.pickle")
tr.all_sense_keys["sense-key1"].unique().shape
#sens2, _ = load_dataset('dataset/senseval2-roberta.pickle')
#sens3, _ = load_dataset('dataset/senseval3-roberta.pickle')

127.0.0.1 - - [29/Mar/2022 03:03:16] "GET /? HTTP/1.1" 200 1406
127.0.0.1 - - [29/Mar/2022 03:03:16] "GET /main.js HTTP/1.1" 200 1809942
127.0.0.1 - - [29/Mar/2022 03:03:16] "POST /get_info? HTTP/1.1" 200 15343
127.0.0.1 - - [29/Mar/2022 03:03:16] "GET /static/favicon.png HTTP/1.1" 200 13257
127.0.0.1 - - [29/Mar/2022 03:03:16] "POST /get_dataset?dataset_name=wsd HTTP/1.1" 200 287089


[]


127.0.0.1 - - [29/Mar/2022 03:03:16] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 200 62


(2584,)

In [78]:
semeval2007 = SemCorDataSet.unpickle('dataset/semeval2007-roberta.pickle')
print(semeval2007.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, semeval2007.all_sense_keys, on="sense-key1")

(354, 2)


Unnamed: 0,sense-key1,sense-key-idx_x,sense-key-idx_y
0,program%1:09:01::,4,126
1,become%2:42:01::,6,120
2,have%2:42:00::,8,188
3,attempt%1:04:00::,13,224
4,make%2:41:00::,14,150
...,...,...,...
193,lean%2:38:00::,2295,218
194,necessity%1:17:00::,2337,22
195,resemble%2:42:00::,2348,253
196,laugh%2:29:00::,2404,316


In [79]:
semeval2013 = SemCorDataSet.unpickle('dataset/semeval2013-roberta.pickle')
print(semeval2013.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, semeval2013.all_sense_keys, on="sense-key1")

(731, 2)


Unnamed: 0,sense-key1,sense-key-idx_x,sense-key-idx_y
0,result%1:11:00::,16,564
1,commercial_enterprise%1:04:00::,19,263
2,publication%1:10:00::,22,454
3,paper%1:27:00::,30,116
4,communication%1:10:01::,33,358
...,...,...,...
262,editor%1:18:00::,2506,240
263,punishment%1:04:00::,2514,423
264,play%1:04:05::,2559,165
265,planet%1:17:00::,2565,81


In [80]:
semeval2015 = SemCorDataSet.unpickle('dataset/semeval2015-roberta.pickle')
print(semeval2015.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, semeval2015.all_sense_keys, on="sense-key1")

(509, 2)


Unnamed: 0,sense-key1,sense-key-idx_x,sense-key-idx_y
0,long%3:00:02::,0,161
1,let%2:41:00::,5,248
2,attempt%1:04:00::,13,320
3,measure%2:31:01::,15,12
4,result%1:11:00::,16,255
...,...,...,...
240,milligram%1:23:00::,2451,83
241,patient%1:18:00::,2512,67
242,cell%1:03:00::,2517,72
243,substance%1:03:00::,2527,446


In [81]:
senseval2 = SemCorDataSet.unpickle('dataset/senseval2-roberta.pickle')
print(senseval2.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, senseval2.all_sense_keys, on="sense-key1")

(1182, 2)


Unnamed: 0,sense-key1,sense-key-idx_x,sense-key-idx_y
0,program%1:09:01::,4,1024
1,let%2:41:00::,5,896
2,become%2:42:01::,6,254
3,rather%4:02:02::,7,998
4,increased%3:00:00::,12,1077
...,...,...,...
558,trip%2:36:00::,2486,489
559,test%1:04:00::,2494,473
560,patient%1:18:00::,2512,812
561,cell%1:03:00::,2517,453


In [82]:
senseval3 = SemCorDataSet.unpickle('dataset/senseval3-roberta.pickle')
print(senseval3.all_sense_keys.shape)
pd.merge(tr.all_sense_keys, senseval3.all_sense_keys, on="sense-key1")

(1072, 2)


Unnamed: 0,sense-key1,sense-key-idx_x,sense-key-idx_y
0,long%3:00:02::,0,796
1,benefit%1:21:00::,3,603
2,let%2:41:00::,5,988
3,have%2:42:00::,8,162
4,attempt%1:04:00::,13,687
...,...,...,...
506,photograph%1:06:00::,2472,1006
507,earthquake%1:11:00::,2491,798
508,prefer%2:41:00::,2511,525
509,rub%2:35:00::,2550,172


[]


127.0.0.1 - - [29/Mar/2022 03:04:20] "POST /get_interpretations?model=wsd&dataset_name=wsd&interpreter=metrics HTTP/1.1" 200 62
