In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%pip install transformers
%pip install wandb

In [None]:
# Project-local helpers (config, wandb_helper, state) plus the wandb client.
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# NOTE: this rebinds the name `config` from the imported module to the object
# returned by get_default_config(); after this line the module itself is no
# longer reachable as `config`. Re-running the whole cell is fine because the
# `import config` above restores the module binding first.
config = config.get_default_config()
# Presumably authenticates with Weights & Biases using settings from the config
# -- confirm in wandb_helper.login.
wandb_helper.login(config)
# Shared State object passed to every loader, model and training function below.
state = State(config)

In [None]:
# Load a slice of the training notebooks into `state`.
# Presumably this populates state.cur_train_nbs, which gen_all_samples() reads
# later -- confirm in State. Whether `to_` is inclusive is not visible here.
state.load_train_nbs_range(from_=100000, to_=101000)
# Alternative loader kept for reference (disabled):
# state.load_train_nbs_tail(1000)

In [None]:
import graph_model
from graph_model import MyGraphModel


# Build the graph model with preload_state="graph3-model-epoch1.bin"
# (presumably a checkpoint file to load weights from -- confirm in MyGraphModel)
# and move it to the device configured on `state`.
graph3_model = MyGraphModel(state, preload_state="graph3-model-epoch1.bin")
graph3_model.to(state.device)
print('Model loaded')

In [None]:
import unixcoder

# Obtain the second base model via unixcoder.reload_model, passing
# "model-epoch1.5.bin" (presumably a checkpoint path -- confirm in
# unixcoder.reload_model). It is fed to gen_all_samples() together with
# graph3_model.
unixcoder_model = unixcoder.reload_model(state, "model-epoch1.5.bin")
print('Unixcoder model loaded')

In [None]:
from re import S
from cosine_train import end_token
import numpy as np
from common import sim, get_probs_by_embeddings, get_best_pos_by_probs
from dataclasses import dataclass
import math

from ensembles import gen_samples

from tqdm import tqdm
import metric
from metric import Score
from common import get_code_cells, get_markdown_cells, split_into_batches
from dataclasses import dataclass
import wandb
import numpy as np
import torch
import random


def gen_all_samples(state: State, graph3_model: MyGraphModel, unixcoder_model):
    """Collect ensemble samples for every notebook in ``state.cur_train_nbs``.

    Both base models are put into eval mode, then ``gen_samples`` is invoked
    once per unique value of the first index level of ``state.cur_train_nbs``.
    Each call receives the matching rows of that frame and the matching entry
    of ``state.df_orders``.

    Args:
        state: Shared state; ``cur_train_nbs`` and ``df_orders`` are read.
        graph3_model: First base model, forwarded to ``gen_samples``.
        unixcoder_model: Second base model, forwarded to ``gen_samples``.

    Returns:
        A single flat list with the samples of all notebooks, in the order the
        notebook ids appear in the index.
    """
    graph3_model.eval()
    unixcoder_model.eval()
    print('Start generating sample points')

    notebooks = state.cur_train_nbs
    notebook_ids = notebooks.index.get_level_values(0).unique()

    collected = []
    for notebook_id in tqdm(notebook_ids):
        notebook = notebooks.loc[notebook_id]
        collected.extend(
            gen_samples(
                state,
                notebook,
                graph3_model,
                unixcoder_model,
                state.df_orders.loc[notebook_id],
            )
        )

    return collected

# Generate the samples once and keep them in the notebook-global `samples`,
# which the training and evaluation cells below reuse. Note that train()
# shuffles this list in place.
samples = gen_all_samples(state, graph3_model, unixcoder_model)


In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import ensembles
import simple_ensemble

def train(state: State, samples, ensemble_model, save_to_wandb=True):
    """Run one pass of L1-loss training of ``ensemble_model`` over ``samples``.

    The samples are shuffled, split into batches of ``state.config.batch_size``
    and fed through ``ensembles.predict``; the ``'preds'`` entry of its result
    is regressed against each sample's ``target_pos``. A fresh AdamW optimizer
    and a cosine-annealing schedule spanning exactly this pass are created on
    every call.

    Args:
        state: Shared state; ``config.batch_size`` and ``device`` are read.
        samples: List of sample objects exposing ``target_pos``. NOTE: the list
            is shuffled IN PLACE, so the caller's ordering changes.
        ensemble_model: Model to optimize; its parameters are updated in place.
        save_to_wandb: When true, a run named ``'train-ensemble'`` is started
            via ``init_wandb``, the per-batch loss is logged as
            ``'ensemble-loss'`` and the run is finished at the end.

    Side effects:
        Reseeds the global ``random`` generator with a fixed seed on every call.
    """
    ensemble_model.zero_grad()
    ensemble_model.train()
    random.seed(787788)
    # Report the model actually being trained (the original printed the name of
    # an unrelated notebook global). Fall back to the class name because not
    # every ensemble model is known to expose a `name` attribute.
    model_name = getattr(ensemble_model, 'name', type(ensemble_model).__name__)
    print('Start training ensemble model:', model_name)
    if save_to_wandb:
        init_wandb(name='train-ensemble')

    # In-place shuffle is kept on purpose: with the fixed seed above, repeated
    # calls on the same list still yield a different order each time.
    random.shuffle(samples)
    batches = split_into_batches(samples, state.config.batch_size)

    optimizer = AdamW(ensemble_model.parameters(), lr=3e-5, eps=1e-8)
    scheduler = CosineAnnealingLR(optimizer, T_max=len(batches))

    criterion = torch.nn.L1Loss()
    # `batch` (not `samples`) so the loop does not clobber the parameter.
    for batch in tqdm(batches):
        # Gradients must be cleared every step; otherwise backward() keeps
        # adding to the gradients of all previous batches.
        optimizer.zero_grad()

        pred = ensembles.predict(state, ensemble_model, batch)['preds']

        target = torch.tensor([s.target_pos for s in batch]).to(state.device)

        loss = criterion(pred, target)

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(ensemble_model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        if save_to_wandb:
            wandb.log({'ensemble-loss': loss.item()})

    if save_to_wandb:
        wandb.finish()


# Instantiate the ensemble model: SimpleEnsembleModel when
# state.config.use_simple_ensemble_model is truthy, otherwise
# unixcoder.EnsembleModel. No checkpoint argument is passed here.
ensemble_model = simple_ensemble.SimpleEnsembleModel(state) if state.config.use_simple_ensemble_model else unixcoder.EnsembleModel(state)
# Training is not started from this cell; it is launched from its own cell further down.
# train(state, samples, ensemble_model, save_to_wandb=True)
print('Done')


In [None]:
import ensembles

@torch.no_grad()
def calc_ensemble_dataset_score(state: State, samples, ensemble_model, save_to_wandb=True):
    """Print the summed and mean per-batch L1 loss of ``ensemble_model``.

    Runs with gradients disabled and the model in eval mode. ``samples`` is
    split into batches of ``state.config.batch_size``; for each batch the
    ``'preds'`` entry returned by ``ensembles.predict`` is compared with the
    samples' ``target_pos`` using L1 loss.

    Args:
        state: Shared state; ``config.batch_size`` and ``device`` are read.
        samples: Sample objects exposing ``target_pos``; not modified.
        ensemble_model: Model to evaluate.
        save_to_wandb: When true, a run named ``'test-ensemble-dataset'`` is
            started via ``init_wandb``, the running mean loss is logged as
            ``'test-ensemble-loss'`` after every batch, and the run is
            finished at the end.

    Note:
        The reported average is the mean of per-batch means, so a smaller
        final batch carries the same weight as a full one. With no batches
        the average is reported as ``nan`` instead of raising
        ``ZeroDivisionError``.
    """
    ensemble_model.eval()
    if save_to_wandb:
        init_wandb(name='test-ensemble-dataset')

    batches = split_into_batches(samples, state.config.batch_size)

    sum_losses = 0.0

    criterion = torch.nn.L1Loss()
    # `batch` (not `samples`) so the loop does not clobber the parameter.
    for it, batch in enumerate(tqdm(batches)):
        pred = ensembles.predict(state, ensemble_model, batch)

        target = torch.tensor([s.target_pos for s in batch]).to(state.device)

        loss = criterion(pred['preds'], target)
        sum_losses += loss.item()

        if save_to_wandb:
            # Running mean over the batches processed so far.
            wandb.log({'test-ensemble-loss': sum_losses / (1 + it)})

    if save_to_wandb:
        wandb.finish()

    num_batches = len(batches)
    # Guard against an empty sample set, which would otherwise divide by zero.
    av_loss = sum_losses / num_batches if num_batches else float('nan')
    print('sum losses:', sum_losses)
    print('av loss:', av_loss)
        
# calc_ensemble_dataset_score(state, samples, ensemble_model, save_to_wandb=True)

In [None]:
# One training pass over `samples`. With save_to_wandb=True, train() calls
# init_wandb(name='train-ensemble') and logs the loss of every batch.
train(state, samples, ensemble_model, save_to_wandb=True)

In [None]:
# Score the model on the same `samples` it was just trained on (not held-out
# data), logging the running mean loss to W&B.
calc_ensemble_dataset_score(state, samples, ensemble_model, save_to_wandb=True)

In [None]:
# Persist the model through its own save() method. '1k' is presumably a tag
# for the 1000-notebook run; where and under what file name it is written is
# decided inside save() -- confirm there.
ensemble_model.save('1k')

In [None]:
# Replace the notebook-global `ensemble_model` with a unixcoder.EnsembleModel
# built with state_dict='ensemble-model-test.bin'. This discards the instance
# trained above and, unlike the cell that first created `ensemble_model`,
# ignores state.config.use_simple_ensemble_model.
# NOTE(review): whether this file is the one written by save('1k') is not
# visible here -- confirm.
ensemble_model = unixcoder.EnsembleModel(state, state_dict='ensemble-model-test.bin')

In [None]:
# Score the reloaded model on `samples`; results are only printed (no W&B run).
calc_ensemble_dataset_score(state, samples, ensemble_model, save_to_wandb=False)

In [None]:
# Ten rounds of train-then-evaluate on the same `samples`, without W&B logging.
# Each train() call reseeds `random` with the same fixed seed, reshuffles
# `samples` in place and builds a new optimizer and LR scheduler, so the
# cosine schedule restarts from the initial learning rate every round.
# The score is computed on the samples just trained on, not on held-out data.
for x in range(10):
    train(state, samples, ensemble_model, save_to_wandb=False)
    calc_ensemble_dataset_score(state, samples, ensemble_model, save_to_wandb=False)