In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the local .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the third-party packages into the kernel's environment.
# NOTE(review): versions are not pinned — consider pinning them so a fresh
# run reproduces the same environment.
%pip install transformers
%pip install wandb

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# NOTE: this rebinds the name `config` from the imported module to the object
# returned by get_default_config(); the module itself is no longer reachable
# under that name after this line.
config = config.get_default_config()
# Log in to Weights & Biases via the project helper (how the config is used
# for that is defined in wandb_helper — not visible here).
wandb_helper.login(config)
# Shared state object built from the config. Later cells read state.device,
# state.config.batch_size, state.cur_train_nbs and state.df_orders from it.
state = State(config)

In [None]:
# Load the training notebooks in the range 100000..105000 into `state`
# (presumably this populates state.cur_train_nbs, which gen_all_samples
# reads — verify in State). The commented-out call is an alternative loader.
state.load_train_nbs_range(from_=100000, to_=105000)
# state.load_train_nbs_tail(1000)

In [None]:
import graph_model
from graph_model import MyGraphModel


# Build the graph model with its weights preloaded from the epoch-1 checkpoint
# file, then move it to the device held by `state`.
graph3_model = MyGraphModel(state, preload_state="graph3-model-epoch1.bin")
graph3_model.to(state.device)
print('Model loaded')

In [None]:
import unixcoder

# Load the unixcoder model from its checkpoint file via the project helper.
# NOTE(review): unlike graph3_model there is no explicit .to(state.device)
# here — presumably reload_model handles device placement; confirm.
unixcoder_model = unixcoder.reload_model(state, "model-epoch1.5.bin")
print('Unixcoder model loaded')

In [None]:
from re import S
from cosine_train import end_token
import numpy as np
from common import sim, get_probs_by_embeddings, get_best_pos_by_probs
from dataclasses import dataclass
import math

@dataclass
class Sample:
    """One ensemble training example, produced per markdown cell.

    Instances are created by ``gen_nb_samples`` and consumed by ``train``,
    which feeds ``text`` and the numeric features to the ensemble model and
    regresses against ``target_pos``.
    """

    # Source text of the markdown cell.
    text: str
    # Position estimate obtained from the graph3 model's embeddings.
    graph3_pos: float
    # Position estimate intended to come from the unixcoder embeddings.
    unix_pos: float
    # Number of cells in the notebook (markdown + code).
    total_cells: int
    # Number of non-code cells in the notebook.
    md_cells: int
    # Number of code cells in the notebook.
    code_cells: int
    # Fraction of the notebook's cells that are code cells.
    part_code_cells: float
    # Regression target: the cell's relative rank in the correct order,
    # scaled by the number of code cells.
    target_pos: float



def gen_nb_samples(nb, graph3_embeddings, unix_embeddings, correct_order):
    """Build one ``Sample`` per markdown cell of a single notebook.

    Args:
        nb: DataFrame with the notebook's cells. It must carry an index level
            named ``cell_id`` and the columns ``cell_type`` and ``source``.
        graph3_embeddings: embeddings of the notebook's cells from the graph3
            model, passed through to ``get_probs_by_embeddings``.
        unix_embeddings: embeddings of the notebook's cells from the unixcoder
            model, passed through to ``get_probs_by_embeddings``.
        correct_order: ordered collection of the notebook's cell ids; it must
            support ``len()`` and ``.index(cell_id)``.
            NOTE(review): the caller passes ``state.df_orders.loc[nb_id]`` —
            confirm that this is a plain list and not a pandas Series.

    Returns:
        A list of ``Sample`` objects, one per non-code cell, in the order the
        cells appear in ``nb``. Empty when the notebook has no such cells.

    Raises:
        ValueError: if a markdown cell id is missing from ``correct_order``
            (when ``correct_order`` is a list).
    """
    # NOTE(review): a later cell runs `from graph_model import Sample`, which
    # rebinds the global name `Sample` used below — confirm both classes
    # accept the same keyword arguments, or drop that import.
    code_rows = nb[nb['cell_type'] == 'code'].reset_index(level='cell_id')
    md_rows = nb[nb['cell_type'] != 'code'].reset_index(level='cell_id')

    n_md = len(md_rows)
    n_code = len(code_rows)
    if n_md == 0:
        # Nothing to emit; this also avoids dividing by zero for an empty notebook.
        return []
    n_total = n_md + n_code
    code_fraction = n_code / n_total

    # Candidate positions: every code cell plus a sentinel for "after the last one".
    code_cell_ids = code_rows['cell_id'].values.tolist()
    code_cell_ids.append('END')

    order_len = len(correct_order)

    samples = []
    # Read id and text from the same filtered frame, so no index lookup on `nb` is needed.
    for m_cell_id, text in zip(md_rows['cell_id'].values, md_rows['source'].values):
        graph_sims_probs = get_probs_by_embeddings(graph3_embeddings, m_cell_id, code_cell_ids, 25.0)
        unix_sims_probs = get_probs_by_embeddings(unix_embeddings, m_cell_id, code_cell_ids, 1000.0)

        graph3_pos = get_best_pos_by_probs(graph_sims_probs)
        # Each model's position must come from its own probabilities.
        unix_pos = get_best_pos_by_probs(unix_sims_probs)

        # Relative rank in the true order, rescaled to the code-cell count.
        target_pos = correct_order.index(m_cell_id) / order_len * n_code
        samples.append(Sample(
            text=text,
            graph3_pos=graph3_pos,
            unix_pos=unix_pos,
            total_cells=n_total,
            md_cells=n_md,
            code_cells=n_code,
            part_code_cells=code_fraction,
            target_pos=target_pos,
        ))

    return samples

In [None]:
from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells, split_into_batches
from dataclasses import dataclass
import wandb
from graph_model import Sample
import numpy as np
import torch
import random

@torch.no_grad()
def gen_samples(state: State, nb, graph3_model: MyGraphModel, unixcoder_model, correct_order):
    """Embed one notebook with both models and turn the result into samples.

    Runs with gradient tracking disabled. The graph3 embeddings are computed
    first, then the unixcoder embeddings; both are handed to
    ``gen_nb_samples`` together with the notebook and its correct cell order.

    Returns:
        Whatever ``gen_nb_samples`` returns for this notebook.
    """
    return gen_nb_samples(
        nb,
        graph_model.get_nb_embeddings(state, graph3_model, nb),
        unixcoder.get_nb_embeddings(state, unixcoder_model, nb),
        correct_order,
    )

def gen_all_samples(state: State, graph3_model: MyGraphModel, unixcoder_model):
    """Generate ensemble training samples for every notebook held in ``state``.

    Both models are switched to eval mode. Each distinct value of the first
    index level of ``state.cur_train_nbs`` is treated as one notebook; its
    rows and its entry in ``state.df_orders`` are passed to ``gen_samples``.

    Args:
        state: shared state providing ``cur_train_nbs`` and ``df_orders``.
        graph3_model: graph model used to embed the notebook cells.
        unixcoder_model: unixcoder model used to embed the notebook cells.

    Returns:
        A flat list with the samples of all notebooks, in notebook order.
    """
    graph3_model.eval()
    unixcoder_model.eval()
    print('Start generating sample points')

    df = state.cur_train_nbs
    nb_ids = df.index.get_level_values(0).unique()

    samples = []
    for nb_id in tqdm(nb_ids):
        # gen_samples computes the embeddings itself; gen_nb_samples expects
        # precomputed embeddings and cannot be called with the models.
        samples.extend(gen_samples(
            state,
            df.loc[nb_id],
            graph3_model,
            unixcoder_model,
            correct_order=state.df_orders.loc[nb_id],
        ))

    return samples




# TODO: do I need this?
# state.config.batch_size = 30
# Run both models over every loaded notebook and collect the ensemble
# training samples; `samples` is consumed by train() in a later cell.
samples = gen_all_samples(state, graph3_model, unixcoder_model)




In [None]:

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

def train(state: State, samples, ensemble_model, save_to_wandb=True):
    """Train the ensemble model for one pass over ``samples``.

    The model is called with the markdown texts and per-sample feature
    tensors. Its output is treated as one weight per base prediction and
    combined with ``(graph3_pos, unix_pos)`` by a row-wise dot product; the
    result is regressed against ``target_pos`` with an L1 loss.

    Args:
        state: shared state; ``state.config.batch_size`` and ``state.device``
            are read.
        samples: sequence of sample objects exposing ``text``, ``graph3_pos``,
            ``unix_pos``, ``total_cells``, ``md_cells``, ``code_cells``,
            ``part_code_cells`` and ``target_pos``. It is not modified.
        ensemble_model: module called as ``ensemble_model(texts, features)``.
            NOTE(review): its output is assumed to have shape (batch, 2) on
            ``state.device`` — confirm against ``unixcoder.EnsembleModel``.
        save_to_wandb: when true, the per-batch loss is logged to a wandb run
            named ``train-ensemble``.

    Returns:
        None. The model is updated in place.
    """
    ensemble_model.train()
    random.seed(787788)
    # Report the model that is being trained, not an unrelated global model.
    model_name = getattr(ensemble_model, 'name', type(ensemble_model).__name__)
    print('Start training ensemble model:', model_name)
    if save_to_wandb:
        init_wandb(name='train-ensemble')

    try:
        # Shuffle a copy so the caller's list keeps its order.
        shuffled = list(samples)
        random.shuffle(shuffled)
        batches = split_into_batches(shuffled, state.config.batch_size)

        optimizer = AdamW(ensemble_model.parameters(), lr=3e-5, eps=1e-8)
        scheduler = CosineAnnealingLR(optimizer, T_max=max(len(batches), 1))
        criterion = torch.nn.L1Loss()

        for batch in tqdm(batches):
            # Gradients must be cleared every step, otherwise they accumulate
            # across batches.
            optimizer.zero_grad()

            texts = [s.text for s in batch]
            # NOTE(review): features are passed as a list of CPU tensors, as
            # before — presumably the model stacks and moves them; confirm.
            additional_features = [
                torch.FloatTensor([s.graph3_pos, s.unix_pos, s.total_cells,
                                   s.md_cells, s.code_cells, s.part_code_cells])
                for s in batch
            ]
            # einsum needs a tensor operand (a Python list of tensors fails),
            # placed on the same device as the model output.
            to_mul = torch.tensor(
                [[s.graph3_pos, s.unix_pos] for s in batch],
                dtype=torch.float32, device=state.device)

            pred = ensemble_model(texts, additional_features)
            pred = torch.einsum("ab,ab->a", pred, to_mul)

            target = torch.tensor(
                [s.target_pos for s in batch],
                dtype=torch.float32, device=state.device)

            loss = criterion(pred, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(ensemble_model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            if save_to_wandb:
                wandb.log({'ensemble-loss': loss.item()})
    finally:
        # Close the run even when training fails, so no run is left open.
        if save_to_wandb:
            wandb.finish()


# Wrap a unixcoder model created without a checkpoint (state_dict=None) in
# the ensemble model, then train it on the generated samples.
# NOTE(review): the model is not explicitly moved to state.device here —
# presumably reload_model / EnsembleModel handle placement; confirm.
ensemble_model = unixcoder.EnsembleModel(
    unixcoder.reload_model(state, state_dict=None))
# train() requires the model as its third positional argument.
train(state, samples, ensemble_model, save_to_wandb=True)
