In [None]:
# Enable IPython's autoreload extension; mode 2 reloads modified modules
# before each cell is executed, so edits to the project modules imported
# below (config, state, graph_model, ...) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Install third-party packages into the running kernel's environment.
# wandb is used for experiment logging below; transformers is presumably
# needed by graph_model (not visible here) — confirm.
%pip install transformers
%pip install wandb

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# NOTE(review): this rebinds the name `config` from the imported module to the
# object it returns, so the `config` module is no longer reachable by name.
config = config.get_default_config()
# Log in to Weights & Biases; presumably reads credentials/settings from config.
wandb_helper.login(config)
# Shared state object passed to data loading, the model and evaluation below.
state = State(config)

In [None]:
# Load training notebooks into `state`; presumably the last 1000 (per the
# method name) — confirm. test() below reads them via state.cur_train_nbs.
state.load_train_nbs_tail(1000)

In [None]:
import graph_model
from graph_model import MyGraphModel


# Build the graph model, initialised from "graph-model-15k.bin" (presumably a
# saved checkpoint — confirm), and move it to the device held by `state`.
model = MyGraphModel(state, preload_state="graph-model-15k.bin")
model.to(state.device)
print('Model loaded')

In [None]:
from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells
from dataclasses import dataclass
import wandb
from graph_model import Sample
import numpy as np

@dataclass()
class OneCell:
    """A notebook cell paired with the key used to place it in the output order.

    Lists of these are sorted by ``score`` to obtain the predicted cell order.
    """

    # Sort key: (possibly fractional) position of the cell in the notebook.
    score: float
    # Identifier of the cell, as used in the notebook DataFrame index.
    cell_id: str
    # Kind of cell, e.g. 'markdown' or 'code'.
    cell_type: str
    
# Scaling factor applied to the scores before the softmax in
# predict_pos_by_scores2; larger values make the softmax sharper (closer to
# a hard argmax), smaller values spread the weight across positions.
MULT = 5.0
    
# TODO: this softmax-based position estimate doesn't work well; investigate which option works better.
def predict_pos_by_scores2(pred, mult=None):
    """Return the softmax-weighted expected position for one markdown cell.

    ``pred`` holds one score per candidate position. The scores are scaled by
    ``mult`` and turned into a probability distribution with a softmax; the
    expected index under that distribution is returned.

    Args:
        pred: 1-D sequence of scores (list, numpy array, ...).
        mult: softmax sharpness; ``None`` (the default) uses the module
            constant ``MULT``.

    Returns:
        Expected position as a float in ``[0, len(pred) - 1]``, or ``0.0`` for
        empty input.
    """
    if mult is None:
        mult = MULT
    logits = np.asarray(pred, dtype=np.float64) * mult
    if logits.size == 0:
        return 0.0
    # Shift by the maximum so the largest exponent is exp(0) = 1: the softmax is
    # mathematically unchanged, but large scores can no longer overflow to
    # inf / inf = nan as they could with mean-centring.
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return float(np.dot(np.arange(len(logits)), weights))
    

def predict_pos_by_scores(scores):
    """Return the index of the highest score (the first one on ties).

    predict_order uses this to pick, for one markdown cell, the code-cell slot
    it matches best. The running maximum starts at -inf rather than a fixed
    sentinel, so all-negative scores (e.g. raw logits) are still ranked
    correctly. NaN scores never compare greater and are therefore never picked.

    Args:
        scores: iterable of mutually comparable scores.

    Returns:
        Index of the best score as an ``int``; ``0`` when ``scores`` is empty
        or contains no comparable value.
    """
    best_score = float('-inf')
    best_pos = 0
    for pos, score in enumerate(scores):
        # Strict '>' keeps the earliest position when scores tie.
        if score > best_score:
            best_score = score
            best_pos = pos
    return best_pos

def predict_order(state: State, nb, model: MyGraphModel):
    """Predict the order of all cells in one notebook.

    Code cells keep their original relative order. Each markdown cell is
    scored by ``model`` against every code cell plus a sentinel ``'END'``
    slot. It is then inserted just before the best-matching code cell, or
    after the last code cell when ``'END'`` wins.

    Args:
        state: shared State, passed through to ``model.predict``.
        nb: rows of one notebook, indexed by cell id, with a ``'source'`` column.
        model: its ``predict(state, samples)`` is expected to return one score
            per Sample, in the order the samples were given.

    Returns:
        List of cell ids in predicted order.
    """
    code_cells = get_code_cells(nb)
    markdown_cells = get_markdown_cells(nb)

    # Nothing to place: the code cells already are the answer, and this avoids
    # calling the model with an empty batch. len() is used rather than
    # truthiness because the helpers may return a pandas Index.
    if len(markdown_cells) == 0:
        return list(code_cells)

    code_texts = [nb.loc[cell_id]['source'] for cell_id in code_cells] + ['END']
    slots = len(code_texts)

    # One Sample per (markdown, slot) pair, markdown-major, so the scores of
    # markdown i occupy scores[i * slots:(i + 1) * slots].
    samples = []
    for markdown_cell in markdown_cells:
        markdown_text = nb.loc[markdown_cell]['source']  # look up once per markdown
        samples.extend(Sample(code=code, markdown=markdown_text) for code in code_texts)

    scores = model.predict(state, samples)
    positions = [predict_pos_by_scores(scores[slots * i:slots * (i + 1)])
                 for i in range(len(markdown_cells))]

    # Code cell k sorts at k + 0.5 and a markdown assigned to slot k sorts at k,
    # so the markdown lands immediately before code cell k. list.sort is stable,
    # so markdowns sharing a slot keep their original relative order.
    cells = [OneCell(cell_id=cell_id, score=pos, cell_type='markdown')
             for cell_id, pos in zip(markdown_cells, positions)]
    cells += [OneCell(cell_id=cell_id, score=pos + 0.5, cell_type='code')
              for pos, cell_id in enumerate(code_cells)]
    cells.sort(key=lambda cell: cell.score)

    return [cell.cell_id for cell in cells]


def test(state: State, model: MyGraphModel, save_to_wandb=True, max_nbs=11):
    """Evaluate ``model`` on the currently loaded training notebooks.

    For each notebook, the cell order is predicted with ``predict_order``. It
    is scored against ``state.df_orders`` with ``metric.calc_nb_score``. The
    per-notebook scores are accumulated with ``Score.merge``, and the
    aggregate is printed at the end.

    Args:
        state: provides ``cur_train_nbs`` (first index level = notebook id) and
            ``df_orders`` (correct order, looked up by notebook id).
        model: model to evaluate; its ``name`` labels the W&B run.
        save_to_wandb: when True, open a W&B run, log the running score after
            every notebook, and finish the run afterwards (even on error).
        max_nbs: evaluate at most this many notebooks, in the order they
            appear in the index (default 11); ``None`` evaluates all of them.
    """
    print('Start testing model:', model.name)
    if save_to_wandb:
        init_wandb(name='test-graph-model-'+model.name)
    try:
        df = state.cur_train_nbs
        nb_ids = df.index.get_level_values(0).unique()
        if max_nbs is not None:
            # Limit before iterating so tqdm's total matches the work actually done.
            nb_ids = nb_ids[:max_nbs]
        sum_scores = Score(0, 0)

        for nb_id in tqdm(nb_ids):
            nb = df.loc[nb_id]
            my_order = predict_order(state, nb, model)
            score = metric.calc_nb_score(
                my_order=my_order, correct_order=state.df_orders.loc[nb_id])

            sum_scores = Score.merge(sum_scores, score)

            if save_to_wandb:
                wandb.log({'my': sum_scores.cur_score})
        print(sum_scores)
    finally:
        if save_to_wandb:
            # Close the run even if evaluation raised, so it is not left open
            # to collect logs from later cells.
            wandb.finish()


# Evaluate the loaded model locally, without logging to Weights & Biases.
test(state, model, save_to_wandb=False)
