In [3]:
import torch
import datasets
import datagen # prevent circular module import exception
from transformers import AutoTokenizer, AutoConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from modelling.collator import BetterDataCollatorForWholeWordMask
from modelling.model import SynsetClassificationModel
from modelling.trainer import BetterTrainer
from lit_nlp.api import types as lit_types
from lit_nlp.api.dataset import Dataset
from lit_nlp.api.model import Model
from pathlib import Path

In [4]:
def load_dataset(path: str):
    ds = datagen.SemCorDataSet.unpickle(Path(path).with_suffix(".pickle"))
    hf_ds = datasets.Dataset.load_from_disk(Path(path).with_suffix(".hf"))
    
    hf_ds = hf_ds.add_column("sense-labels", hf_ds["labels"])
    relevant_columns = [
        column
        for column in hf_ds.column_names
        if column not in ds.sentence_level.columns
    ]
    relevant_columns.append("sense-labels")
    hf_ds.set_format(type="torch", columns=relevant_columns)
    
    return hf_ds

In [5]:
def load_model(path: str):
    if torch.cuda.is_available():
        device = "cuda:0"
        print(f"CUDA found; running on {device}")
    else:
        device = "cpu"
        print(f"CUDA not found; running on {device}")
    
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    
    model = SynsetClassificationModel.from_pretrained(
        path,
        config=AutoConfig.from_pretrained(path, local_files_only=True),
        local_files_only=True,
        model_name=path,
        num_classes=2584,
    ).to(device)
    
    trainer = BetterTrainer(
        model=model,
        data_collator=DataCollatorForLanguageModeling(tokenizer),
        args=TrainingArguments(
            output_dir='./saliency',
            remove_unused_columns=False,
            label_names=["labels", "sense-labels"]
        )
    )
    
    return trainer

Refer to https://pair-code.github.io/lit/setup/ for implementation details

In [20]:
class LIT_Dataset(Dataset):
    def __init__(self, path: str):
        ds = load_dataset(path)
        
        self._examples = [{
            #'sentence_idx': row['sentence_idx'],
            #'sentence': row['sentence'],
            'input_ids': row['input_ids'],
            'attention_mask': row['attention_mask'].tolist(),
            'labels': row['labels'].tolist(),
            'sense-labels': row['sense-labels'].tolist(),
        } for row in ds]
        
        #self._examples = ds
        

    def spec(self):
        return {
            #'sentence_idx': lit_types.Scalar(),
            #'sentence': lit_types.TextSegment(),
            'input_ids': lit_types.TokenEmbeddings(),
            'attention_mask': lit_types.TokenEmbeddings(),
            'labels': lit_types.TokenEmbeddings(),
            'sense-labels': lit_types.TokenEmbeddings(),
        }

class LIT_Model(Model):
    def __init__(self, path: str):
        self._model = load_model(path)
        
    def input_spec(self):
        return {
            'input_ids': lit_types.TokenEmbeddings(),
            'attention_mask': lit_types.TokenEmbeddings(),
            'sense-labels': lit_types.TokenEmbeddings(),
        }
    
    def output_spec(self):
        return {
            'labels': lit_types.TokenEmbeddings(),
        }
        
    def predict(self, inputs):
        return self._model.predict(inputs)
    
    def predict_minibatch(self, inputs):
        return self._model.predict(inputs)

In [21]:
# Create the LIT widget with the model and dataset to analyze.
from lit_nlp import notebook

#datasets = {'sst_dev': glue.SST2Data('validation')}
lit_models = {'wsd': LIT_Model('./checkpoints-probing/roberta-probing+semcor/checkpoint-185900')}
lit_datasets = {'wsd': LIT_Dataset('./dataset/roberta+senseval2.pickle')}

widget = notebook.LitWidget(lit_models, lit_datasets, height=800)

Didn't find file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/added_tokens.json. We won't load it.
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/vocab.json
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/merges.txt
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/tokenizer.json
loading file None
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/special_tokens_map.json
loading file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/tokenizer_config.json
loading configuration file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/config.json
Model config RobertaConfig {
  "_name_or_path": "./checkpoints-probing/roberta-probing+semcor/checkpoint-185900",
  "architectures": [
    "SynsetClassificationModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "

CUDA not found; running on cpu


loading configuration file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/config.json
Model config RobertaConfig {
  "_name_or_path": "./checkpoints-probing/roberta-probing+semcor/checkpoint-185900",
  "architectures": [
    "SynsetClassificationModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.17.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

loading weights file ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900/pytorch_model.bin
Some weights of the model checkpoint at

Some weights of RobertaModel were not initialized from the model checkpoint at ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900 and are newly initialized: ['encoder.layer.5.output.dense.bias', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.2.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.7.output.dense.bias', 'encoder.layer.8.output.dense.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.4.output.dense.bias', 

All model checkpoint weights were used when initializing SynsetClassificationModel.

All the weights of SynsetClassificationModel were initialized from the model checkpoint at ./checkpoints-probing/roberta-probing+semcor/checkpoint-185900.
If your task is similar to the task the model of the checkpoint was trained on, you can already use SynsetClassificationModel for predictions without further training.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


TypeError: tensor([    0,   133,  1808,     9,   464,    12,  4506,   154,    16, 28178,
            7,     5,  2370,  2156,     8,  2156,   101,   144,  2370, 28178,
         2192,  2156, 45467, 37448,  4748,     7,     5,  1079,     9,     5,
          232,   479,     2,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1]) is not JSON serializable.

In [None]:
# Render the widget
widget.render()

In [13]:
ds = load_dataset('./dataset/roberta+senseval2.pickle')

In [17]:
ds[0]['attention_mask'].tolist()

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]