In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the third-party packages this notebook needs into the kernel's environment.
# NOTE(review): versions are not pinned — consider pinning them for reproducibility.
%pip install transformers
%pip install wandb

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# NOTE: this rebinds the name `config` from the imported module to the object
# returned by get_default_config(); the module is no longer reachable as
# `config` after this line.
config = config.get_default_config()
# presumably authenticates with Weights & Biases using values from the config — confirm in wandb_helper
wandb_helper.login(config)
# Shared state object; later cells read `state.device`, `state.config`,
# `state.cur_train_nbs` and `state.df_orders` from it.
state = State(config)

In [None]:
# Load the slice of training notebooks that the evaluation below runs on.
# NOTE(review): whether from_/to_ are notebook indices or row offsets is defined in State — confirm.
state.load_train_nbs_range(from_=100000, to_=105000)
# Alternative loader kept for quick switching (currently disabled):
# state.load_train_nbs_tail(1000)

In [None]:
import graph_model
from graph_model import MyGraphModel


# Build the graph model, passing the checkpoint file name as `preload_state`,
# then move it to the device held by `state`.
graph3_model = MyGraphModel(state, preload_state="graph-model-100k-ncs=2.bin")
graph3_model.to(state.device)
print('Graph3 model loaded')

In [None]:
import unixcoder

# Restore the UniXcoder-based model from the given checkpoint file name.
# NOTE(review): device placement is presumably handled inside reload_model — confirm.
unixcoder_model = unixcoder.reload_model(state, "model-epoch1.5.bin")
print('Unixcoder model loaded')

In [None]:
import unixcoder
import simple_ensemble

# Choose the ensemble implementation according to the config flag; each variant
# is constructed with its own checkpoint file name passed as `state_dict`.
ensemble_model = (
    simple_ensemble.SimpleEnsembleModel(state, state_dict='simple-ensemble-model-1k-strange.bin')
    if state.config.use_simple_ensemble_model
    else unixcoder.EnsembleModel(state, state_dict='ensemble-model-1k.bin')
)
print('Ensemble model loaded')

In [None]:
from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells
from dataclasses import dataclass
import wandb
from graph_model import Sample
import numpy as np
import torch
import ensembles
from ensembles import gen_samples
from common import split_into_batches, OneCell

from cosine_train import end_token
from common import sim
import math

def get_probs_by_embeddings(embeddings, m_cell_id, code_cell_ids, coef_mul):
    """Convert markdown-to-candidate similarities into a probability distribution.

    The similarity (``sim``) between the embedding of ``m_cell_id`` and the
    embedding of every id in ``code_cell_ids`` is computed, scaled by
    ``coef_mul`` and passed through a softmax.

    Args:
        embeddings: object indexable by cell id that yields that cell's embedding.
        m_cell_id: id of the markdown cell to place.
        code_cell_ids: candidate ids to compare against (must be non-empty).
        coef_mul: multiplier applied to the similarities before exponentiation;
            larger values give a sharper distribution.

    Returns:
        A list of floats, one per entry of ``code_cell_ids`` in the same order,
        that sums to 1.

    Raises:
        ValueError: if ``code_cell_ids`` is empty (``max`` of an empty list).
    """
    anchor = embeddings[m_cell_id]
    similarities = [sim(anchor, embeddings[cell_id]) for cell_id in code_cell_ids]

    # Shift by the maximum before exponentiating: the largest exponent becomes
    # 0, which keeps exp() from overflowing for a positive coef_mul.
    top = max(similarities)
    weights = [math.exp((s - top) * coef_mul) for s in similarities]

    total = sum(weights)
    return [w / total for w in weights]


@torch.no_grad()
def predict_order(state: State, nb, graph3_model: MyGraphModel, unixcoder_model, graph3_embeddings, unix_embeddings, graph_weight):
    """Predict a full cell order for one notebook by blending two embedding models.

    Code cells keep the order returned by ``get_code_cells(nb)``; the code cell
    at position ``p`` gets sort score ``p + 0.5``. Every sample produced by
    ``gen_samples`` contributes one markdown cell (``md_cell_id``) with an
    integer score ``k``, which sorts it directly before the code cell at
    position ``k``. The extra ``'END'`` candidate (``k == number of code
    cells``) means "after all code cells".

    For each markdown cell, a distribution over the candidate slots is computed
    per model with ``get_probs_by_embeddings`` and blended as
    ``graph_weight * graph + (1 - graph_weight) * unix``. The chosen slot is the
    one with the smallest expected absolute distance under that blend.

    Args:
        state: shared state; ``state.config.batch_size`` is used for batching.
        nb: notebook DataFrame with a ``cell_type`` column and a ``cell_id``
            index level.
        graph3_model: graph model; its ``coef_mul`` scales graph similarities.
        unixcoder_model: passed through to ``gen_samples``.
        graph3_embeddings: indexable by cell id; must cover every code cell id,
            every markdown cell id and ``'END'``.
        unix_embeddings: same contract as ``graph3_embeddings``.
        graph_weight: blend weight of the graph distribution.

    Returns:
        List of cell ids sorted by predicted position.
    """
    code_rows = nb[nb['cell_type'] == 'code'].reset_index(level='cell_id')
    # Candidate slots: one per code cell plus the trailing 'END' sentinel.
    # NOTE(review): slot indices assume these rows are in the same order as
    # get_code_cells(nb) — confirm.
    candidate_ids = code_rows['cell_id'].values.tolist() + ['END']

    # Half-integer scores leave an integer gap before each code cell for markdown.
    placed = [
        OneCell(score=position + 0.5, cell_id=code_id, cell_type="code")
        for position, code_id in enumerate(get_code_cells(nb))
    ]

    all_samples = gen_samples(state, nb, graph3_model, unixcoder_model, correct_order=None)

    for batch in split_into_batches(all_samples, state.config.batch_size):
        for sample in batch:
            md_id = sample.md_cell_id

            graph_probs = get_probs_by_embeddings(graph3_embeddings, md_id, candidate_ids, graph3_model.coef_mul)
            unix_probs = get_probs_by_embeddings(unix_embeddings, md_id, candidate_ids, 1000.0)
            blended = [g * graph_weight + u * (1 - graph_weight) for g, u in zip(graph_probs, unix_probs)]

            # expected_dist[slot] = sum over src of |src - slot| * blended[src],
            # accumulated in ascending src order.
            expected_dist = []
            for slot in range(len(blended)):
                acc = 0.0
                for src, prob in enumerate(blended):
                    acc += abs(src - slot) * prob
                expected_dist.append(acc)

            # First slot with the minimal expected distance wins ties.
            best_slot = min(range(len(expected_dist)), key=expected_dist.__getitem__)

            placed.append(OneCell(score=best_slot, cell_id=md_id, cell_type="markdown"))

    # sorted() is stable, so cells with equal scores keep their insertion order.
    return [cell.cell_id for cell in sorted(placed, key=lambda cell: cell.score)]

@torch.no_grad()
def test(state: State, graph3_model: MyGraphModel, unixcoder_model, ensemble_model, save_to_wandb=True):
    """Evaluate the graph/UniXcoder blend over a sweep of blend weights.

    For every notebook in ``state.cur_train_nbs`` (grouped by the first index
    level) the embeddings of both models are computed once. ``predict_order``
    is then run for each ``graph_weight`` in ``0, 1/20, ..., 1``; each
    prediction is scored with ``metric.calc_nb_score`` against
    ``state.df_orders`` and merged into a running ``Score`` kept separately
    per weight. The final score of every weight is printed at the end.

    Args:
        state: shared state providing ``cur_train_nbs`` and ``df_orders``.
        graph3_model: graph embedding model (switched to eval mode).
        unixcoder_model: UniXcoder embedding model (switched to eval mode).
        ensemble_model: only switched to eval mode and named in the start
            message; it takes no part in the predictions made here.
        save_to_wandb: when true, a wandb run is started, the running score of
            the middle weight (0.5) is logged after each notebook under the
            key ``'my'``, and the run is finished on exit.

    Returns:
        None.
    """
    graph3_model.eval()
    unixcoder_model.eval()
    ensemble_model.eval()
    print('Start testing model:', ensemble_model.name)
    if save_to_wandb:
        init_wandb(name='test-ensemble-graph3-unix')

    df = state.cur_train_nbs
    nb_ids = df.index.get_level_values(0).unique()

    CNT_COEFS = 20

    def get_coef(coef_id):
        # Map the integer sweep index onto a blend weight in [0, 1].
        return coef_id / (CNT_COEFS)

    # One independent running score per blend weight.
    sum_scores = {coef_id: Score(0, 0) for coef_id in range(CNT_COEFS + 1)}

    try:
        for nb_id in tqdm(nb_ids):
            nb = df.loc[nb_id]
            correct_order = state.df_orders.loc[nb_id]
            # Embeddings do not depend on the blend weight: compute them once per notebook.
            graph3_embeddings = graph_model.get_nb_embeddings(state, graph3_model, nb)
            unix_embeddings = unixcoder.get_nb_embeddings(state, unixcoder_model, nb)

            for coef_id in range(CNT_COEFS + 1):
                my_order = predict_order(
                    state, nb, graph3_model, unixcoder_model,
                    graph3_embeddings, unix_embeddings,
                    graph_weight=get_coef(coef_id),
                )
                # Score inside the sweep so every weight is evaluated; scoring
                # after the loop would only ever see the last weight.
                score = metric.calc_nb_score(my_order=my_order, correct_order=correct_order)
                sum_scores[coef_id] = Score.merge(sum_scores[coef_id], score)

            if save_to_wandb:
                # Track the equal-weight blend (coefficient 0.5) as the running metric.
                wandb.log({'my': sum_scores[CNT_COEFS // 2].cur_score})

        for coef_id in range(CNT_COEFS + 1):
            print('score of ', get_coef(coef_id), ':', sum_scores[coef_id].cur_score)
    finally:
        # Close the wandb run even if evaluation fails or is interrupted.
        if save_to_wandb:
            wandb.finish()


# Run the evaluation defined above on the loaded notebooks, with wandb logging enabled.
test(state, graph3_model, unixcoder_model, ensemble_model, save_to_wandb=True)

