In [None]:
# Automatically reload edited local modules (config, state, graph_model,
# cosine_train, ...) before each cell executes, so changes to the .py files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the third-party dependencies into the kernel's environment.
# NOTE(review): these installs are unpinned, so re-running the notebook later may
# pull different versions; pin the versions actually used
# (e.g. `%pip install transformers==<version>`) for reproducibility.
%pip install transformers
%pip install wandb

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# Build the default configuration.
# NOTE(review): this rebinds the name `config` from the imported module to the
# config object, shadowing the module for every later cell. Re-running this cell
# still works only because the `import config` line above rebinds it first.
config = config.get_default_config()
# Log in to Weights & Biases using the config (presumably reads the API key /
# project from it — see wandb_helper.login).
wandb_helper.login(config)
# Shared runtime state used by the cells below (device, loaded training
# notebooks, ground-truth orders).
state = State(config)

In [None]:
# Number of training notebooks to evaluate on, presumably taken from the end
# of the training set (see State.load_train_nbs_tail). They become
# `state.cur_train_nbs`, which the evaluation cell iterates over.
N_EVAL_NOTEBOOKS = 50
state.load_train_nbs_tail(N_EVAL_NOTEBOOKS)

In [None]:
import graph_model
from graph_model import MyGraphModel


# Checkpoint file whose weights are loaded into the graph model before evaluation.
MODEL_CHECKPOINT = "graph3-model-epoch1.bin"

model = MyGraphModel(state, preload_state=MODEL_CHECKPOINT)
model.to(state.device)
# Switch to inference mode (assumes MyGraphModel is a torch.nn.Module — TODO confirm).
model.eval()
print('Model loaded')

In [None]:
# Inspect the model's softmax multiplier (`coef_mul`, used when placing markdown
# cells) and its `next_code_cells` attribute, one per line.
print(model.coef_mul, model.next_code_cells, sep='\n')

In [None]:
from cosine_train import end_token
import numpy as np
from common import sim
from dataclasses import dataclass
import math



def _weighted_median_index(weights):
    """Return the first index j minimising sum_i weights[i] * |i - j|.

    The objective f(j) is convex and piecewise linear, with
    f(j + 1) - f(j) = 2 * (weights[0] + ... + weights[j]) - sum(weights).
    Its first minimiser is therefore the first j whose running weight reaches
    half of the total, i.e. the weighted median. This is O(n) instead of the
    O(n^2) "score every position against every other" double loop.

    `weights` must be a non-empty sequence of non-negative floats.
    """
    half_total = sum(weights) / 2.0
    running = 0.0
    for j, w in enumerate(weights):
        running += w
        if running >= half_total:
            return j
    # Only reachable when the weights contain NaN (e.g. a NaN similarity); fall
    # back to the first slot, which is what an argmin over all-NaN scores picks.
    return 0


def find_best_cell_order3_with_emb(nb, embeddings, coef_mul, similarity=None):
    """Predict a notebook's cell order by slotting markdown cells between code cells.

    Code cells keep the relative order in which they appear in ``nb``. Each
    markdown cell is compared with every code cell plus a trailing ``'END'``
    slot. The similarities, multiplied by ``coef_mul``, are turned into softmax
    weights. The markdown cell is placed immediately before the slot that
    minimises the expected absolute slot distance under those weights (the
    weighted median).

    Args:
        nb: DataFrame for a single notebook, with an index level named
            ``'cell_id'`` and a ``'cell_type'`` column. Rows whose type is not
            ``'code'`` are treated as markdown.
        embeddings: mapping from cell id to embedding. When the notebook has
            markdown cells it must also contain the key ``'END'`` (presumably
            an "after the last code cell" embedding from ``get_nb_embeddings``
            — TODO confirm).
        coef_mul: multiplier applied to similarities before the softmax. Larger
            values make placement follow the single most similar slot.
        similarity: optional ``f(a, b)`` comparing two embeddings; defaults to
            ``common.sim``.

    Returns:
        List of cell ids (code and markdown) in predicted order, without the
        ``'END'`` sentinel. Markdown cells assigned to the same slot keep the
        order in which they appear in ``nb``.
    """
    if similarity is None:
        similarity = sim

    is_code = nb['cell_type'] == 'code'
    code_cell_ids = nb[is_code].reset_index(level='cell_id')['cell_id'].tolist()
    markdown_cell_ids = nb[~is_code].reset_index(level='cell_id')['cell_id'].tolist()

    # Nothing to place: avoid looking up any embeddings at all.
    if not markdown_cell_ids:
        return code_cell_ids

    slot_ids = code_cell_ids + ['END']
    # Look the slot embeddings up once instead of once per markdown cell.
    slot_embs = [embeddings[c] for c in slot_ids]

    # before_slot[k] = markdown cells to emit right before slot k, in encounter order.
    before_slot = [[] for _ in slot_ids]
    for m_cell_id in markdown_cell_ids:
        m_emb = embeddings[m_cell_id]
        sims = [similarity(m_emb, e) for e in slot_embs]
        max_sim = max(sims)
        # Softmax numerators shifted by the max for numerical stability; the
        # median search does not need them normalised.
        weights = [math.exp((s - max_sim) * coef_mul) for s in sims]
        before_slot[_weighted_median_index(weights)].append(m_cell_id)

    order = []
    for code_id, markdowns in zip(code_cell_ids, before_slot):
        order.extend(markdowns)
        order.append(code_id)
    # Markdown cells assigned to the 'END' slot go after the last code cell.
    order.extend(before_slot[-1])
    return order

In [None]:
from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells
from dataclasses import dataclass
import wandb
from graph_model import Sample
import numpy as np
from cosine_train import get_nb_embeddings

def predict_order(state: State, nb, model: MyGraphModel, use_sigmoid):
    """Return the predicted cell-id order for the single notebook `nb`.

    Cell embeddings come from `get_nb_embeddings`. Markdown cells are then
    slotted between code cells by `find_best_cell_order3_with_emb`, using the
    model's `coef_mul` as the softmax multiplier. `use_sigmoid` is accepted so
    existing call sites keep working, but it is not used here.
    """
    nb_embeddings = get_nb_embeddings(state, model, nb)
    softmax_scale = model.coef_mul
    return find_best_cell_order3_with_emb(nb, nb_embeddings, coef_mul=softmax_scale)

def test(state: State, model: MyGraphModel, use_sigmoid, save_to_wandb=True):
    """Evaluate `model` on every notebook in `state.cur_train_nbs` and report the merged score.

    For each notebook:
      - the predicted order from `predict_order` is scored with
        `metric.calc_nb_score` against `state.df_orders`;
      - the per-notebook scores are merged with `Score.merge`.

    When `save_to_wandb` is true:
      - a W&B run named ``'test-graph3-model-' + model.name`` is opened;
      - the running `cur_score` is logged after every notebook;
      - the run is always finished, even if prediction or scoring raises, so a
        failed evaluation does not leave a dangling run.

    Args:
        state: shared runtime state providing `cur_train_nbs` and `df_orders`.
        model: the model to evaluate; must expose `name`.
        use_sigmoid: forwarded to `predict_order`.
        save_to_wandb: whether to log the evaluation to Weights & Biases.

    Returns:
        The merged `Score` over all evaluated notebooks.
    """
    print('Start testing model:', model.name)
    if save_to_wandb:
        init_wandb(name='test-graph3-model-' + model.name)

    try:
        df = state.cur_train_nbs
        # Unique notebook ids from the first index level.
        nb_ids = df.index.get_level_values(0).unique()
        sum_scores = Score(0, 0)

        for nb_id in tqdm(nb_ids):
            nb = df.loc[nb_id]
            my_order = predict_order(state, nb, model, use_sigmoid=use_sigmoid)
            score = metric.calc_nb_score(
                my_order=my_order, correct_order=state.df_orders.loc[nb_id])

            sum_scores = Score.merge(sum_scores, score)

            if save_to_wandb:
                wandb.log({'my': sum_scores.cur_score})

        print(sum_scores)
    finally:
        # Close the run even when the loop fails part-way.
        if save_to_wandb:
            wandb.finish()

    return sum_scores


# Evaluate the loaded model on the notebooks selected above and log to W&B.
# The trailing ';' keeps the notebook from displaying the call's return value.
test(state, model, save_to_wandb=True, use_sigmoid=False);
