In [None]:
import torch
import datasets
import datagen # prevent circular module import exception
from transformers import AutoTokenizer, AutoConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from modelling.collator import BetterDataCollatorForWholeWordMask
from modelling.model import SynsetClassificationModel
from modelling.trainer import BetterTrainer
from lit_nlp.api import types as lit_types
from lit_nlp.api.dataset import Dataset
from lit_nlp.api.model import Model
from pathlib import Path
import numpy as np

In [None]:
def load_dataset(path: str):
    """Load a preprocessed WSD split and expose its token-level columns as torch tensors.

    Two sibling files are read, both derived from ``path`` by swapping its suffix:

    * ``<stem>.pickle`` -- a ``datagen.SemCorDataSet``; it is used here only to
      learn which columns are sentence-level (``ds.sentence_level.columns``) so
      they can be excluded from the torch format.
    * ``<stem>.hf`` -- a Hugging Face ``datasets.Dataset`` saved to disk.

    A copy of ``labels`` is added as ``sense-labels``. Presumably this keeps the
    original labels available after the MLM collator used by ``load_model``
    rewrites ``labels`` -- TODO confirm.

    Args:
        path: Path to either sibling file (any suffix is replaced).

    Returns:
        The Hugging Face dataset with torch formatting restricted to the
        non-sentence-level columns plus ``sense-labels``.
    """
    base = Path(path)
    # NOTE(review): `unpickle` presumably uses pickle, which can execute
    # arbitrary code -- only load dataset files from a trusted source.
    semcor = datagen.SemCorDataSet.unpickle(base.with_suffix(".pickle"))
    hf_ds = datasets.Dataset.load_from_disk(base.with_suffix(".hf"))

    hf_ds = hf_ds.add_column("sense-labels", hf_ds["labels"])
    relevant_columns = [
        column
        for column in hf_ds.column_names
        if column not in semcor.sentence_level.columns
    ]
    # After add_column, "sense-labels" is normally already in the list above;
    # only append it if a sentence-level column of that name filtered it out,
    # so it is never listed twice.
    if "sense-labels" not in relevant_columns:
        relevant_columns.append("sense-labels")
    hf_ds.set_format(type="torch", columns=relevant_columns)

    return hf_ds

In [None]:
def load_model(path: str, num_classes: int = 2584, output_dir: str = "./saliency"):
    """Load a local SynsetClassificationModel checkpoint wrapped in a BetterTrainer.

    Tokenizer, config and weights are all read from ``path`` with
    ``local_files_only=True``, so no network access is attempted.

    Args:
        path: Local checkpoint directory.
        num_classes: Size of the synset classification head. The default (2584)
            is the value previously hard-coded here; it must match the checkpoint.
        output_dir: ``TrainingArguments.output_dir`` for the wrapped trainer.

    Returns:
        A ``BetterTrainer`` whose label names are ``["labels", "sense-labels"]``.
        Its collator is ``DataCollatorForLanguageModeling`` with default
        arguments (``mlm=True``), which masks tokens at random for every batch.
        Repeated ``predict`` calls on the same input can therefore differ.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print(f"CUDA found; running on {device}")
    else:
        device = "cpu"
        print(f"CUDA not found; running on {device}")

    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)

    model = SynsetClassificationModel.from_pretrained(
        path,
        config=AutoConfig.from_pretrained(path, local_files_only=True),
        local_files_only=True,
        model_name=path,
        num_classes=num_classes,
    ).to(device)

    trainer = BetterTrainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=DataCollatorForLanguageModeling(tokenizer),
        args=TrainingArguments(
            output_dir=output_dir,
            # Keep "sense-labels" (not a model forward() argument name per se)
            # in the batches so the trainer can return it as a label.
            remove_unused_columns=False,
            # Order matters: predictions' label_ids[0] is "labels",
            # label_ids[1] is "sense-labels".
            label_names=["labels", "sense-labels"],
        ),
    )

    return trainer

Refer to the LIT setup guide (https://pair-code.github.io/lit/setup/) for details on implementing the `Dataset` and `Model` wrappers defined below.

In [None]:
class LIT_Dataset(Dataset):
    """LIT dataset wrapper around a split loaded with ``load_dataset``.

    Every example holds the four token-level columns below, converted from the
    torch tensors produced by ``load_dataset`` into plain Python lists.
    """

    # Columns copied into every LIT example, in spec order.
    _FIELDS = ("input_ids", "attention_mask", "labels", "sense-labels")

    def __init__(self, path: str):
        hf_rows = load_dataset(path)

        # The raw-text columns ('sentence', 'sentence_idx') are deliberately
        # left out: adding 'sentence' was observed to crash tensor creation.
        examples = []
        for record in hf_rows:
            examples.append({field: record[field].tolist() for field in self._FIELDS})
        self._examples = examples

    def spec(self):
        # NOTE(review): lit_types.Embeddings is intended for float vectors, but
        # these fields hold integer token ids / labels -- TODO confirm the
        # appropriate LIT types.
        return {field: lit_types.Embeddings() for field in self._FIELDS}

class LIT_Model(Model):
    """LIT model wrapper around the trainer returned by ``load_model``.

    For each input example, ``predict`` returns the decoded argmax prediction
    at the example's first position that is both MLM-masked and carries a
    sense label.
    """

    # Id decoded when an example has no masked, sense-labelled position.
    # NOTE(review): looks like an arbitrary placeholder -- confirm intent.
    _FALLBACK_ID = 1

    def __init__(self, path: str):
        self._model = load_model(path)

    def input_spec(self):
        return {
            'input_ids': lit_types.Embeddings(),
            'attention_mask': lit_types.Embeddings(),
            'labels': lit_types.Embeddings(),
            # NOTE(review): predict() also needs 'sense-labels' (it reads
            # label_ids[1]), but it is not declared here, so LIT cannot flag
            # datasets that lack it.
        }

    def output_spec(self):
        return {
            'prediction': lit_types.CategoryLabel(),
        }

    def predict(self, inputs):
        """Run the wrapped trainer on ``inputs`` and return one dict per example.

        Args:
            inputs: Iterable of LIT examples (dicts of token-level lists).

        Returns:
            A list with one ``{'prediction': str}`` per input; empty input
            returns an empty list without calling the trainer.
        """
        inputs = list(inputs)
        if not inputs:
            return []

        pred = self._model.predict(inputs)
        logits = pred.predictions
        # Order follows label_names in load_model: ["labels", "sense-labels"].
        masked_labels_all, sense_labels_all = pred.label_ids[0], pred.label_ids[1]

        # NOTE(review): the slicing below assumes logits are flattened to
        # (num_examples * seq_len, num_classes) with the same seq_len for every
        # example -- TODO confirm, especially across differently padded batches.
        n = logits.shape[0] // len(inputs)

        output = []
        for i in range(len(inputs)):
            # Positions the MLM collator did not mask (label -100) are ignored.
            # np.where builds a new array, so pred.label_ids is not mutated.
            sense_labels = np.where(masked_labels_all[i] == -100, -100, sense_labels_all[i])
            lossable = sense_labels != -100
            probas = logits[n * i:n * (i + 1)]
            prediction = np.argmax(probas, axis=-1)[lossable.flatten()]
            token_id = prediction[0] if len(prediction) else self._FALLBACK_ID
            # NOTE(review): `token_id` is an index into the model's logits; if
            # those are synset-class logits, decoding it as a tokenizer id gives
            # an unrelated token -- confirm the intended label mapping.
            output.append({'prediction': self._model.tokenizer.decode(token_id)})

        return output

    def predict_minibatch(self, inputs):
        """Same contract as ``predict``: a list with one output dict per example.

        Previously this returned the trainer's raw ``PredictionOutput``, which
        does not match the per-example dicts declared by ``output_spec``.
        """
        return self.predict(inputs)

In [None]:
# Build the LIT widget for the WSD model and its evaluation split.
from lit_nlp import notebook

# Local checkpoint to analyze and the split shown in the widget.
MODEL_CHECKPOINT = './checkpoints-probing/roberta-probing+semcor/checkpoint-185900'
EVAL_DATASET = './dataset/roberta+senseval2.pickle'

lit_models = {'wsd': LIT_Model(MODEL_CHECKPOINT)}
lit_datasets = {'wsd': LIT_Dataset(EVAL_DATASET)}
widget = notebook.LitWidget(lit_models, lit_datasets, height=800)

In [None]:
# Render the LIT widget built in the previous cell (`widget`).
widget.render()

In [None]:
# Load the same evaluation split directly (outside LIT) for inspection.
ds = load_dataset(path='./dataset/roberta+senseval2.pickle')

In [None]:
# Display the loaded Hugging Face dataset (its repr lists features and row count).
ds