In [None]:
# Reload edited project modules (config, state, graph_model, unixcoder, ...)
# automatically before each cell runs, so changes to their source are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Third-party packages needed by this notebook and the project modules it imports.
# TODO(review): pin versions (e.g. `%pip install -q easynmt==<version>`) so a
# fresh kernel reproduces the same environment.
%pip install transformers
%pip install wandb
%pip install langdetect
%pip install sentencepiece
%pip install easynmt

In [None]:
from easynmt import EasyNMT
# EasyNMT smoke test: load the OPUS-MT backend and translate English text to
# German, first a single sentence, then a list of sentences.
model = EasyNMT('opus-mt')

# One sentence -> German.
print(model.translate('This is a sentence we want to translate to German', target_lang='de'))

# Several sentences -> German in a single call.
sentences = [
    'You can define a list with sentences.',
    'All sentences are translated to your target language.',
    'Note, you could also mix the languages of the sentences.',
]
print(model.translate(sentences, target_lang='de'))

In [None]:
# The `config` module is imported under an alias so the name `config` can be
# bound to the loaded configuration object below without shadowing the module.
# Previously, re-running this cell would look up `get_default_config` on the
# config object instead of the module.
import config as config_lib
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

config = config_lib.get_default_config()
# Presumably authenticates with Weights & Biases before any run is started.
wandb_helper.login(config)
# Shared State object used by every later cell (e.g. `state.device`,
# `state.df_orders`, `state.load_one_nb`).
state = State(config)

In [None]:
#state.load_train_nbs_range(from_=104000, to_=105000)
#state.load_train_nbs_tail(1000)

In [None]:
import graph_model
from graph_model import MyGraphModel


# Graph model restored from its epoch-3 checkpoint, then moved onto the
# device recorded in `state`.
graph3_checkpoint = "graph-model-epoch3.bin"
graph3_model = MyGraphModel(state, preload_state=graph3_checkpoint)
graph3_model.to(state.device)
print('Graph3 model loaded')

In [None]:
import unixcoder

# UniXcoder-based model restored from its epoch-1.5 checkpoint via the
# project helper `unixcoder.reload_model`.
unixcoder_checkpoint = "model-epoch1.5.bin"
unixcoder_model = unixcoder.reload_model(state, unixcoder_checkpoint)
print('Unixcoder model loaded')

In [None]:
from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells
from dataclasses import dataclass
import wandb
from graph_model import Sample
import numpy as np
import torch
import ensembles
from ensembles import gen_samples
from common import split_into_batches, OneCell

from cosine_train import end_token
from common import sim
import math

def get_probs_by_embeddings(embeddings, m_cell_id, code_cell_ids, coef_mul):
    """Softmax over similarities between one markdown cell and candidate slots.

    Args:
        embeddings: indexable by cell id; must contain ``m_cell_id`` and every
            id in ``code_cell_ids``.
        m_cell_id: id of the markdown cell being placed.
        code_cell_ids: ordered candidate ids (in this notebook, the code cells
            followed by the ``'END'`` sentinel).
        coef_mul: multiplier applied to the shifted similarities before
            exponentiation (an inverse temperature: larger values sharpen the
            distribution).

    Returns:
        List of probabilities aligned with ``code_cell_ids`` that sums to 1
        (up to floating-point rounding).
    """
    query = embeddings[m_cell_id]
    similarities = [sim(query, embeddings[cid]) for cid in code_cell_ids]
    # Shift by the maximum so the largest term is exp(0) and nothing overflows.
    peak = max(similarities)
    weights = [math.exp((s - peak) * coef_mul) for s in similarities]
    norm = sum(weights)
    return [w / norm for w in weights]


def _best_insert_position(weights):
    """Return the first index ``j`` minimising ``sum_i |i - j| * weights[i]``.

    Runs in O(n) instead of the naive O(n^2) double loop. With
    ``prefix[j] = weights[0] + ... + weights[j]`` and ``total = sum(weights)``,
    neighbouring costs satisfy ``cost[j + 1] = cost[j] + 2 * prefix[j] - total``.
    Ties resolve to the smallest index, like ``scores.index(min(scores))``.

    Raises:
        ValueError: if ``weights`` is empty.
    """
    if not weights:
        raise ValueError("weights must be non-empty")
    total = sum(weights)
    cost = sum(i * w for i, w in enumerate(weights))  # cost at j == 0
    best_pos, best_cost = 0, cost
    prefix = 0.0
    for j in range(len(weights) - 1):
        prefix += weights[j]
        cost += 2 * prefix - total  # advance from cost[j] to cost[j + 1]
        if cost < best_cost:  # strict: keep the first minimiser on ties
            best_pos, best_cost = j + 1, cost
    return best_pos


@torch.no_grad()
def predict_order(state: State, nb, graph3_model: MyGraphModel, unixcoder_model, graph3_embeddings, unix_embeddings, graph_weight, unix_coef_mul=1000.0):
    """Predict the full cell order of one notebook.

    Code cells keep the order returned by ``get_code_cells(nb)``. Every
    markdown cell gets a probability distribution over the slots
    "before code cell j" (j = 0..n-1) plus a final ``'END'`` slot. The
    distribution mixes the graph-model and UniXcoder similarity softmaxes.
    The markdown is placed in the slot that minimises the expected absolute
    slot distance under that distribution.

    Args:
        state: unused here; kept for interface compatibility.
        nb: the notebook's cells, as passed to ``get_code_cells`` /
            ``get_markdown_cells``.
        graph3_model: only ``graph3_model.coef_mul`` is read (graph softmax
            multiplier).
        unixcoder_model: unused here; kept for interface compatibility.
        graph3_embeddings, unix_embeddings: per-cell-id embeddings; must also
            contain the ``'END'`` sentinel.
        graph_weight: weight of the graph-model distribution; the UniXcoder
            distribution gets ``1 - graph_weight``.
        unix_coef_mul: softmax multiplier for the UniXcoder similarities
            (previously hard-coded to 1000.0).

    Returns:
        List of cell ids in predicted order.
    """
    # Derive both the code-cell positions and the probability slots from the
    # same list, so slot j always means "just before code cell j".
    code_cell_ids = list(get_code_cells(nb))
    candidate_ids = code_cell_ids + ['END']

    # Code cell at position p gets score p + 0.5; a markdown in slot j gets
    # score j, which sorts it between code cells j - 1 and j.
    cells = [OneCell(score=pos + 0.5, cell_id=cell_id, cell_type="code")
             for pos, cell_id in enumerate(code_cell_ids)]

    for cell_id in get_markdown_cells(nb):
        graph_probs = get_probs_by_embeddings(graph3_embeddings, cell_id, candidate_ids, graph3_model.coef_mul)
        unix_probs = get_probs_by_embeddings(unix_embeddings, cell_id, candidate_ids, unix_coef_mul)
        mixed_probs = [g * graph_weight + u * (1 - graph_weight) for g, u in zip(graph_probs, unix_probs)]
        best_pos = _best_insert_position(mixed_probs)
        cells.append(OneCell(score=best_pos, cell_id=cell_id, cell_type="markdown"))

    # list.sort is stable, so markdowns sharing a slot keep their input order.
    cells.sort(key=lambda c: c.score)
    return [c.cell_id for c in cells]


@torch.no_grad()
def test(state: State, graph3_model: MyGraphModel, unixcoder_model, graph_weight=0.5, bad_score_threshold=0.8):
    """Score every notebook in ``state.cur_train_nbs`` and log the weak ones to W&B.

    Each notebook is ordered with ``predict_order`` using embeddings from both
    models. Notebooks whose ``score.cur_score`` is below
    ``bad_score_threshold`` are collected into a W&B table logged as
    ``"bad runs"`` in a run named ``'save-bad-nbs'``.

    NOTE: a later cell in this notebook redefines ``test`` with a different
    signature (single-notebook debugging), which shadows this definition.

    Args:
        state: provides ``cur_train_nbs`` (notebook id as first index level)
            and ``df_orders`` (correct order per notebook id).
        graph3_model, unixcoder_model: models switched to eval mode and used
            to embed each notebook.
        graph_weight: mixing weight passed to ``predict_order``.
        bad_score_threshold: notebooks scoring below this are logged.
    """
    graph3_model.eval()
    unixcoder_model.eval()
    print('Start testing model:')

    df = state.cur_train_nbs
    nb_ids = df.index.get_level_values(0).unique()

    wandb_data = []
    for nb_id in tqdm(nb_ids):
        nb = df.loc[nb_id]
        graph3_embeddings = graph_model.get_nb_embeddings(state, graph3_model, nb)
        unix_embeddings = unixcoder.get_nb_embeddings(state, unixcoder_model, nb)

        my_orders = predict_order(state, nb, graph3_model, unixcoder_model, graph3_embeddings, unix_embeddings, graph_weight=graph_weight)
        correct_order = state.df_orders.loc[nb_id]
        score = metric.calc_nb_score(my_order=my_orders, correct_order=correct_order)
        if score.cur_score < bad_score_threshold:
            wandb_data.append([
                nb_id,
                score.cur_score,
                len(correct_order),
                len(get_markdown_cells(nb)),
                len(get_code_cells(nb)),
                ",".join(my_orders),
            ])

    init_wandb(name='save-bad-nbs')
    try:
        my_table = wandb.Table(columns=["nb_id", "score", "cnt cells", "cnt markdowns", "cnt code", "order"], data=wandb_data)
        wandb.log({"bad runs": my_table})
    finally:
        # Always close the run, even if building/logging the table fails.
        wandb.finish()

#test(state, graph3_model, unixcoder_model)



In [None]:
# Attach an EasyNMT ('opus-mt') instance to `state`.
# NOTE(review): the attribute is spelled `easymnt`, not `easynmt`. Presumably
# project code reads it under this exact name, so check the consumers before
# renaming. The demo cell above already built an instance with the same
# arguments in `model`, so this cell loads a second copy of the model.
state.easymnt = EasyNMT('opus-mt')

In [None]:
import time
    
@torch.no_grad()
def test(state: State, graph3_model: MyGraphModel, unixcoder_model, nb_id, graph_weight=1.0):
    """Debug the ordering pipeline on a single notebook and print its score.

    Loads ``nb_id`` via ``state.load_one_nb`` and prints how long the graph
    model's ``get_nb_embeddings`` takes. It then predicts the order with
    ``predict_order`` and prints the ``metric.calc_nb_score`` result.

    NOTE: this definition shadows the earlier batch
    ``test(state, graph3_model, unixcoder_model)`` from a previous cell.

    Args:
        nb_id: id of the notebook to evaluate (must exist in ``state.df_orders``).
        graph_weight: mixing weight passed to ``predict_order``. The default of
            1.0 uses only the graph-model distribution, as before.
    """
    graph3_model.eval()
    unixcoder_model.eval()
    print('Start testing model:')

    nb = state.load_one_nb(nb_id)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # (clock adjustments) and is not meant for measuring durations.
    start = time.perf_counter()
    graph3_embeddings = graph_model.get_nb_embeddings(state, graph3_model, nb)
    print('time:', time.perf_counter() - start)
    unix_embeddings = unixcoder.get_nb_embeddings(state, unixcoder_model, nb)

    my_orders = predict_order(state, nb, graph3_model, unixcoder_model, graph3_embeddings, unix_embeddings, graph_weight=graph_weight)
    score = metric.calc_nb_score(
        my_order=my_orders, correct_order=state.df_orders.loc[nb_id])
    print(score)
    
# Single-notebook debug run. The commented toggle below presumably disables
# HTML cleaning during preprocessing; uncomment to compare.
#state.config.clean_html = False
debug_nb_id = 'd1ef1808a5d61a'
test(state, graph3_model, unixcoder_model, nb_id=debug_nb_id)

In [None]:
from language import detect_nb_lang
# Spot-check notebook-level language detection on a hand-picked set of
# notebook ids; prints one line per notebook.
nb_ids = [
    '3a1da227ab17a0',
    '961b63b04a9637',
    'dbfe8b7e5b2226',
    'da8f893b2ec723',
    'd1ef1808a5d61a',
    '2a829655be3fab',
    '241f611e38f2b2',
    '6207afb54d4159',
    '97db9116552332',
    '2fa216ddc5c8cd',
    '5ca09e62dd4dd0',
]
# nb_ids = ['97db9116552332']  # narrow to one notebook when debugging
for nb_id in nb_ids:
    nb = state.load_one_nb(nb_id)
    lang = detect_nb_lang(nb)
    print('nb_id:', nb_id,
          'lang:', lang)

In [None]:
from language import Translator, get_translator

# Spot-check the Russian -> English translator from `language.get_translator`;
# the cell's displayed output is the translation result.
marian_ru_en = get_translator('ru', 'en')
ru_sample = ['что слишком сознавать — это болезнь, настоящая, полная болезнь.']
marian_ru_en.translate(ru_sample)