In [None]:
# Install the notebook-only packages into the running kernel's environment
# (%pip, unlike !pip, targets the kernel's interpreter).
# bertviz is imported by the visualisation cells further down.
# NOTE(review): jupyterlab / ipywidgets are presumably needed for bertviz's
# interactive rendering — confirm; versions are unpinned, consider pinning.
%pip install bertviz
%pip install jupyterlab
%pip install ipywidgets

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported local
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# NOTE: this rebinds the name `config` from the imported module to the object
# returned by get_default_config(); after this line the module is no longer
# reachable under that name (re-running the whole cell re-imports it first).
config = config.get_default_config()
# Log in through the project's wandb helper using that configuration.
wandb_helper.login(config)
# Shared State object handed to the data-loading, model and test code below.
state = State(config)

In [None]:
# Load training notebooks into `state`.
# NOTE(review): presumably 100 is the number of notebooks taken from the tail
# of the training set and the result is what `state.cur_train_nbs` exposes —
# confirm against State.load_train_nbs_tail.
state.load_train_nbs_tail(100)

In [None]:
import roberta_model_wc
from roberta_model_wc import MyRobertaModel


# Build the model, passing the checkpoint file name as `preload_state`, and bind
# it to the notebook-global `model` (test2 below reads this global directly).
model = MyRobertaModel(state, preload_state="roberta-wc-model-30k-20cs-5e-5.bin")
# Alternative construction without a preloaded checkpoint (kept for reference):
# model = MyRobertaModel()
# Move the model to the device selected in `state`.
model.to(state.device)
print('Model loaded')

In [None]:
from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells
from dataclasses import dataclass
import wandb
from roberta_model_wc import LearnSample
from roberta_model_wc import MyRobertaModel
import torch
from bertviz import model_view, head_view

@dataclass
class OneCell:
    """A notebook cell paired with the numeric score used to sort it into position."""

    # Sort key; cells are ordered by ascending score.
    score: float
    # Identifier of the cell inside its notebook.
    cell_id: str
    # Kind of cell; predict_order uses 'markdown' and 'code'.
    cell_type: str


def predict_order(state: State, nb, model: MyRobertaModel):
    """Predict an ordering for all cells of one notebook.

    Code cells keep the order in which ``get_code_cells`` returns them; each
    markdown cell is slotted in among them according to the score that
    ``model.predict_two_step`` returns for it.

    Args:
        state: Passed through unchanged to ``model.predict_two_step``.
        nb: Notebook table indexed by cell id, with a ``'source'`` column.
        model: Model used to score the markdown cells.

    Returns:
        List of cell ids (markdown and code) sorted by ascending placement score.
    """
    code_ids = get_code_cells(nb)
    markdown_ids = get_markdown_cells(nb)

    # Every markdown sample is scored against the same list of code sources.
    code_sources = [nb.loc[code_id]['source'] for code_id in code_ids]

    samples = [
        LearnSample(text=nb.loc[md_id]['source'], relative_position=0.0, code_texts=code_sources)
        for md_id in markdown_ids
    ]
    scores = model.predict_two_step(state, samples)

    n_code = len(code_ids)
    placed = []

    # Markdown scores are scaled by the number of code cells so they share an
    # axis with the code positions (presumably the raw score is a relative
    # position in the notebook — TODO confirm against predict_two_step).
    for md_id, md_score in zip(markdown_ids, scores):
        placed.append(OneCell(cell_id=md_id, score=md_score * n_code, cell_type='markdown'))

    # The i-th code cell sits at i + 0.5, i.e. in the middle of its unit slot.
    for position, code_id in enumerate(code_ids):
        placed.append(OneCell(cell_id=code_id, score=position + 0.5, cell_type='code'))

    placed.sort(key=lambda cell: cell.score)
    return [cell.cell_id for cell in placed]


def test(state: State, num_code_cells: int = 5):
    """Print and display every loaded notebook that has a given number of code cells.

    Handy for finding small notebooks to inspect by hand (e.g. as input to
    ``test2``).

    Args:
        state: Object whose ``cur_train_nbs`` frame holds the loaded notebooks,
            indexed on level 0 by notebook id.
        num_code_cells: Only notebooks with exactly this many code cells are
            shown. Defaults to 5, the value that used to be hard-coded.

    Returns:
        None. Output is produced via ``print`` and ``display``.
    """
    print('Start testing')
    df = state.cur_train_nbs
    # Distinct notebook ids; named `nb_ids` to avoid shadowing the builtin `all`.
    nb_ids = df.index.get_level_values(0).unique()

    for nb_id in tqdm(nb_ids):
        nb = df.loc[nb_id]
        code_cells = get_code_cells(nb)

        if len(code_cells) != num_code_cells:
            continue

        markdown_cells = get_markdown_cells(nb)
        print('nb_id:', nb_id)
        print('cnt code:', len(code_cells))
        print('cnt markdown:', len(markdown_cells))
        display(nb)

def test2(state: State, nb_id, markdown_id, cnt_codes=10, include_layers=(4,)):
    """Visualise the model's attention for one markdown cell of one notebook.

    Displays the notebook, encodes the markdown cell together with the
    notebook's code cells, runs the encoder with attention outputs enabled and
    hands the attentions to bertviz ``model_view``.

    Uses the notebook-global ``model`` (it is not a parameter).

    Args:
        state: Object providing ``cur_train_nbs`` (loaded notebooks) and
            ``device`` (the device the model was moved to).
        nb_id: Id of the notebook to inspect.
        markdown_id: Id of the markdown cell within that notebook.
        cnt_codes: Forwarded to ``model.encode_sample``. Defaults to 10, the
            value that used to be hard-coded.
        include_layers: Layer indices forwarded to ``model_view``. Defaults to
            ``(4,)``, the value that used to be hard-coded.

    Returns:
        Whatever ``bertviz.model_view`` returns.
    """
    df = state.cur_train_nbs
    nb = df.loc[nb_id]

    display(nb)

    code_texts = [nb.loc[cell_id]['source'] for cell_id in get_code_cells(nb)]
    sample = LearnSample(text=nb.loc[markdown_id]['source'], relative_position=0.0, code_texts=code_texts)

    # Call kept from the original flow; its result is not needed here.
    # NOTE(review): if predict() does not leave the model in eval mode, dropout
    # will perturb the attention maps below — confirm.
    model.predict(state, [sample])

    encoded = model.encode_sample(sample, cnt_codes=cnt_codes)
    tokens = model.tokenizer.convert_ids_to_tokens(encoded['input_ids'])
    print('tokens:', tokens)

    # The model lives on state.device, so the inputs must be moved there too;
    # a CPU tensor fed to a GPU model raises a device-mismatch error.
    input_ids = torch.LongTensor([encoded['input_ids']]).to(state.device)
    # Pass the mask so that padding positions (if any) are excluded from attention.
    attention_mask = torch.LongTensor([encoded['attention_mask']]).to(state.device)

    # Inference only: no autograd graph is needed for visualisation.
    with torch.no_grad():
        res = model.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )

    # Bring the per-layer attention tensors back to the CPU for rendering.
    attention = tuple(layer_attention.cpu() for layer_attention in res['attentions'])

    return model_view(attention, tokens, include_layers=list(include_layers))



# Uncomment to list notebooks with the target number of code cells:
# test(state)
# Render the attention view for one hand-picked notebook / markdown cell pair.
test2(state, nb_id='df29fbc69f9fd7', markdown_id='789c0cd3')

In [None]:
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view
# Stand-alone bertviz demo on a public HuggingFace checkpoint, independent of
# the project model used above.
# NOTE(review): this cell rebinds the notebook-global `model` (and defines
# globals `tokenizer`, `attention`, `tokens`). test2() reads the global `model`,
# so re-running the test2 cell after this one will use the wrong model —
# consider distinct names or re-running the model-loading cell first.
utils.logging.set_verbosity_error()  # Suppress standard warnings

model_name = "microsoft/xtremedistil-l12-h384-uncased"  # Find popular HuggingFace models here: https://huggingface.co/models
input_text = "The cat sat on the mat"  
model = AutoModel.from_pretrained(model_name, output_attentions=True)  # Configure model to return attention values
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view