In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the third-party packages this notebook imports into the kernel's environment.
# NOTE(review): versions are not pinned, so a fresh environment may resolve to
# different releases than the ones this notebook was developed with — consider pinning.
%pip install transformers
%pip install wandb

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# NOTE: this rebinds the name `config` from the imported `config` module to the
# object returned by get_default_config(); the module is no longer reachable
# under that name until the `import config` line above is executed again.
config = config.get_default_config()
# Presumably authenticates with Weights & Biases using settings from `config` —
# see wandb_helper.login for the exact behaviour.
wandb_helper.login(config)
# Shared state object handed to data loading, sample generation and training below.
state = State(config)

In [None]:
# Load a slice of the training notebooks into `state`.
# Presumably this populates state.cur_train_nbs, which gen_learn_samples reads — TODO confirm.
# NOTE(review): whether the from_/to_ bounds are inclusive is decided by
# State.load_train_nbs_range and cannot be told from this notebook.
state.load_train_nbs_range(from_=15000, to_=20000)

In [None]:
from dataclasses import dataclass
from common import get_markdown_cells, get_code_cells, split_into_batches
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
import random
from tqdm import tqdm
import numpy as np
from state import State

from graph_model import Sample, SampleWithLabel

# Threshold used by gen_learn_samples when pairing a markdown cell with code:
# the true following code cell is used when random.uniform(0, 1) > negative_sample_prob,
# otherwise a code cell is drawn at random from the notebook. A random draw can
# still yield the true cell's text, in which case the sample is labelled positive.
negative_sample_prob = 0.5

# Sentinel meaning "no code cell follows": it is the code text for markdown cells
# located after the last code cell, and is also one of the random-draw candidates.
END_TOKEN = 'END'

def gen_learn_samples(state: State, seed=12345):
    """Build shuffled training batches of labelled (markdown, code) pairs.

    For every notebook in ``state.cur_train_nbs`` the cells are walked in
    reverse of their correct order (``state.df_orders[nb_id]``), tracking the
    source of the nearest code cell that follows the current position. Each
    markdown cell yields exactly one sample:

    * when ``random.uniform(0, 1) > negative_sample_prob`` the markdown is
      paired with the true following code cell (or ``END_TOKEN`` if none
      follows);
    * otherwise it is paired with a code cell drawn uniformly from the
      notebook's code cells plus ``END_TOKEN``.

    The label is 1 when the paired code text equals the true following code
    text (this includes random draws that happen to match), else 0.

    Args:
        state: Provides ``cur_train_nbs`` (frame whose first index level is the
            notebook id), ``df_orders`` (notebook id -> ordered cell ids) and
            ``config.batch_size``.
        seed: Seed for the global ``random`` module, making sample selection
            and the final batch shuffle reproducible. Note that this reseeds
            the process-wide generator.

    Returns:
        The result of ``split_into_batches`` applied to the samples sorted by
        combined markdown + code text length, shuffled in place. Presumably a
        list of batches of at most ``batch_size`` samples each — see
        ``common.split_into_batches``.
    """
    random.seed(seed)
    print('Generating sample on train nbs.')
    samples = []
    df = state.cur_train_nbs
    nbs = df.index.get_level_values(0).unique()
    print('Total nbs:', len(nbs))

    for nb_id in tqdm(nbs):
        nb = df.loc[nb_id]
        correct_order = state.df_orders[nb_id]

        # Candidate pool for random draws: every code cell id plus the sentinel.
        # It stays a list because random.choice needs a sequence.
        code_cells = get_code_cells(nb).tolist()
        code_cells.append(END_TOKEN)
        # Same ids as a set, so the per-cell membership test below is O(1)
        # instead of a linear scan of the list for every cell in the notebook.
        code_cell_ids = set(code_cells)

        # Source of the nearest code cell after the current position; stays
        # END_TOKEN for markdown cells that come after the last code cell.
        last_code_cell_text = END_TOKEN
        for cell_id in reversed(correct_order):
            if cell_id in code_cell_ids:
                last_code_cell_text = nb.loc[cell_id]['source']
                continue

            # Markdown cell: decide between the true pairing and a random one.
            # The order of RNG calls (uniform first, then choice only on the
            # random branch) is what keeps results reproducible for a seed.
            if random.uniform(0, 1) > negative_sample_prob:
                sample_code_cell = last_code_cell_text
            else:
                drawn_id = random.choice(code_cells)
                if drawn_id == END_TOKEN:
                    sample_code_cell = END_TOKEN
                else:
                    sample_code_cell = nb.loc[drawn_id]['source']

            # Compare texts rather than ids so that a random draw landing on
            # the true cell (or an identical one) is not labelled negative.
            label = 1 if sample_code_cell == last_code_cell_text else 0
            markdown = nb.loc[cell_id]['source']
            samples.append(SampleWithLabel(sample=Sample(markdown=markdown, code=sample_code_cell), label=label))

    # Sort by total text length before batching so samples of similar size sit
    # next to each other, then shuffle whole batches to randomise their order.
    samples.sort(key=lambda x: len(x.sample.markdown) + len(x.sample.code))
    result = split_into_batches(samples, state.config.batch_size)
    random.shuffle(result)
    return result


# Build the training batches from the notebooks currently loaded in `state`
# (default seed, so the result is reproducible for the same loaded notebooks).
dataset = gen_learn_samples(state)
# The length printed is that of the value returned by split_into_batches inside
# gen_learn_samples — presumably the number of batches, not of individual samples.
print('Dataset created! Len:', len(dataset))


In [None]:
import graph_model
from graph_model import MyGraphModel


# Construct the model, passing a previously saved checkpoint file name as
# `preload_state` (presumably restores weights — see MyGraphModel).
# The commented-out line below is the alternative without a checkpoint.
model = MyGraphModel(state, preload_state="graph-model-cur-final.bin")
# model = MyGraphModel(state)
model.to(state.device)


# Train on the generated batches, passing the saved optimizer-state file name
# that accompanies the checkpoint above; `save_to_wandb=True` is forwarded to
# graph_model.train. The commented-out call omits the optimizer state.
graph_model.train(state, model, dataset, save_to_wandb=True, optimizer_state="graph-model-cur-final.opt.bin")
# graph_model.train(state, model, dataset, save_to_wandb=True)
# NOTE(review): this message is printed only after graph_model.train(...) returns,
# i.e. it signals the end of this cell rather than model construction.
print('Model created')

In [None]:
# model.save("60k-bs")