In [6]:
import datagen # prevent circular module import exception

from lit_nlp.api.dataset import Dataset
from lit_nlp.api.model import Model
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from modelling.collator import BetterDataCollatorForWholeWordMask
from modelling.model import SynsetClassificationModel
from modelling.trainer import BetterTrainer

def load_model(path: str, num_classes: int = 2584) -> "BetterTrainer":
    """Load a saved checkpoint from ``path`` and wrap it in a ``BetterTrainer``.

    The tokenizer, config and model weights are all read from ``path`` with
    ``local_files_only=True``.

    Two notebook-level globals must be defined before this is called:
    ``base_model_name`` (forwarded to the model as ``model_name``) and
    ``device`` (the target passed to ``.to``).

    Args:
        path: Local location of the saved tokenizer, config and weights.
        num_classes: Forwarded to ``SynsetClassificationModel.from_pretrained``.
            Defaults to 2584, the value that was previously hard-coded.

    Returns:
        A ``BetterTrainer`` wrapping the loaded model. Despite the function
        name, the return value is the trainer, not the bare model.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True)
    config = AutoConfig.from_pretrained(path, local_files_only=True)

    # Load the weights from the same checkpoint as the tokenizer and config,
    # so the function depends only on its ``path`` argument for the location.
    model = SynsetClassificationModel.from_pretrained(
        path,
        config=config,
        local_files_only=True,
        model_name=base_model_name,
        num_classes=num_classes,
    ).to(device)

    # NOTE(review): no output_dir is passed; some transformers releases
    # require one -- confirm against the installed version.
    training_args = TrainingArguments(
        remove_unused_columns=False,
        label_names=["labels", "sense-labels"],
    )

    return BetterTrainer(
        model=model,
        data_collator=DataCollatorForLanguageModeling(tokenizer),
        args=training_args,
    )

class LIT_Dataset(Dataset):
    """Placeholder subclass of the LIT ``Dataset``; adds no behaviour of its own yet."""

class LIT_Model(Model):
    """LIT ``Model`` subclass that delegates prediction to the object from ``load_model``.

    NOTE(review): only ``predict`` is overridden here. Confirm whether the LIT
    ``Model`` base class also requires ``input_spec`` / ``output_spec`` before
    this class can be instantiated.
    """

    def __init__(self, path: str):
        """Call ``load_model(path)`` and keep the result for later predictions."""
        # Despite the attribute name, this stores whatever load_model returns
        # (a trainer object), not a bare model.
        self._model = load_model(path)

    def predict(self, inputs):
        """Pass ``inputs`` to the wrapped object's ``predict`` and return its result unchanged."""
        wrapped = self._model
        return wrapped.predict(inputs)

In [None]:
# Build the LIT widget from the dataset and model to analyze.
# NOTE(review): this cell wires up the GLUE SST-2 example classes from
# lit_nlp.examples, not the LIT_Model / LIT_Dataset classes defined in this
# notebook -- confirm that is intended.
from lit_nlp import notebook
from lit_nlp.examples.datasets import glue
from lit_nlp.examples.models import glue_models

datasets = {
    'sst_dev': glue.SST2Data('validation'),
}
models = {
    'sst_tiny': glue_models.SST2Model('./'),
}

widget = notebook.LitWidget(models, datasets, height=800)

In [None]:
# Render the LIT widget. `widget` is created in the preceding cell, which
# must have been run first.
widget.render()