In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%pip install transformers
%pip install wandb

In [None]:
import config
from wandb_helper import init_wandb
import wandb_helper
import wandb
from state import State

# Build the default config object and use it for wandb login and the data State.
# NOTE(review): this rebinds the name `config` from the module to the config
# object; re-running only these lines (without `import config` above) would call
# get_default_config on the object instead of the module — rerun the whole cell.
config = config.get_default_config()
wandb_helper.login(config)
state = State(config)

In [None]:
# state.load_train_nbs_range(from_=0, to_=100000)
state.load_additional_data()

In [None]:
# Budget of texts (markdowns + code cells) per training batch: gen_batches starts a
# new Batch whenever adding the next minibatch would push sum_cnt past this value.
max_batch_size = 60
# Maximum number of markdown cells in one minibatch (all from the same notebook).
minibatch_size = 8
# Multiplier applied to the markdown/code dot-product scores before cross-entropy.
default_mul = 1000
# Sentinel "code cell" paired with markdown cells that have no code cell after them.
end_token = 'END'

from torch.nn import CrossEntropyLoss
from transformers import get_linear_schedule_with_warmup
from dataclasses import dataclass
from torch.optim import AdamW
from tqdm import tqdm
from common import get_markdown_cells
import unixcoder
import random
import torch
import numpy as np
# Seeds only Python's `random` (sample and batch shuffling); numpy/torch are not seeded here.
random.seed(787788)

@dataclass
class MiniBatch:
    """Markdown cells from one notebook plus the distinct code cells that follow them.

    `correct_idx[i]` is the index into `code` of the code text paired with
    `markdowns[i]`; identical code texts are stored only once and shared.
    """
    markdowns:list
    code:list
    correct_idx:list # for each markdown store idx in code
    max_len_cache:int  # memoised get_max_len() result; 0 means "not computed yet"

    def append(self, cur_markdown, cur_code):
        """Add a (markdown, following code) pair, de-duplicating the code text."""
        self.markdowns.append(cur_markdown)
        try:
            code_idx = self.code.index(cur_code)
        except ValueError:
            self.code.append(cur_code)
            code_idx = len(self.code) - 1
        self.correct_idx.append(code_idx)
        # Contents changed, so any previously cached length is stale.
        self.max_len_cache = 0

    def get_max_len(self):
        """Return the largest len() among all markdown and code texts (0 when empty)."""
        if self.max_len_cache == 0:
            self.max_len_cache = max((len(t) for t in self.markdowns + self.code), default=0)
        return self.max_len_cache

    def cnt(self):
        """Return the number of texts (markdowns + distinct code cells) to embed."""
        return len(self.markdowns) + len(self.code)
    
@dataclass
class Batch:
    """A group of minibatches processed together in one optimisation step."""
    mini:list     # minibatches in this batch, in insertion order
    sum_cnt:int   # running total of cnt() over every minibatch in `mini`

    def append(self, mini_batch):
        """Add a minibatch and update the running text count."""
        self.mini.append(mini_batch)
        self.sum_cnt += mini_batch.cnt()

    def get_all_tokens(self, model, state):
        """Tokenize every text in the batch via `model.get_texts_tokens(texts, state)`.

        Texts are laid out minibatch by minibatch, each contributing its
        markdowns followed by its code cells; train_on_batch slices the
        resulting embeddings using exactly this layout.
        """
        texts = []
        for mini in self.mini:
            texts.extend(mini.markdowns)
            texts.extend(mini.code)
        return model.get_texts_tokens(texts, state)
        
@dataclass
class Sample:
    """One training pair: a markdown cell and the code text that follows it."""

    markdown: str  # source text of the markdown cell
    code: str      # source text of the next non-markdown cell (or the END sentinel)
    


def _notebook_samples(nb, correct_order, markdown_cell_ids):
    """Pair each markdown cell in `correct_order` with the first non-markdown cell after it.

    `correct_order` is expected to end with `end_token`, so a markdown cell with
    no code cell after it is paired with `end_token`. Samples come back in
    notebook order.
    """
    def get_code(cell_id):
        if cell_id == end_token:
            return end_token
        return nb.loc[cell_id]['source']

    samples = []
    next_code_cell = None
    # Walk backwards so the nearest following non-markdown cell is always known:
    # one pass instead of rescanning the tail of the order for every markdown cell.
    for cell_id in reversed(correct_order):
        if cell_id in markdown_cell_ids:
            if next_code_cell is None:
                raise ValueError(f'markdown cell {cell_id!r} has no following code cell')
            samples.append(Sample(markdown=nb.loc[cell_id]['source'], code=get_code(next_code_cell)))
        else:
            next_code_cell = cell_id
    samples.reverse()
    return samples


def _pack_minibatches(minibatches):
    """Greedily pack minibatches, in order, into Batches of about max_batch_size texts.

    A new Batch is started whenever adding the next minibatch would push the
    current Batch's sum_cnt above max_batch_size.
    """
    batches = []
    for mb in minibatches:
        if not batches or batches[-1].sum_cnt + mb.cnt() > max_batch_size:
            batches.append(Batch(mini=[], sum_cnt=0))
        batches[-1].append(mb)
    return batches


def gen_batches(all, state: State):
    """Build shuffled training batches from the notebooks whose ids are in `all`.

    For each notebook, every markdown cell is paired with the next non-markdown
    cell in its correct order (or `end_token`). The pairs are shuffled and split
    evenly into minibatches of at most `minibatch_size` markdowns. Minibatches
    are sorted by their longest text so consecutive (similar-length) ones share
    a Batch, packed greedily by `max_batch_size`, and the Batches are shuffled.

    Args:
        all: iterable of notebook ids to use.
        state: provides `cur_train_nbs` (presumably cells indexed by notebook id,
            then cell id, with a 'source' column) and `df_orders` (correct cell
            order per notebook id).

    Returns:
        list of Batch.
    """
    df = state.cur_train_nbs
    df_orders = state.df_orders

    minibatches = []
    for nb_id in tqdm(all):
        nb = df.loc[nb_id]
        # Copy before adding the sentinel: appending to the list returned by .loc
        # would modify the order stored in state.df_orders, so repeated calls
        # would accumulate extra END markers.
        correct_order = list(df_orders.loc[nb_id]) + [end_token]
        markdown_cell_ids = get_markdown_cells(nb)

        samples = _notebook_samples(nb, correct_order, markdown_cell_ids)
        random.shuffle(samples)
        if not samples:
            continue

        num_chunks = (len(samples) + minibatch_size - 1) // minibatch_size
        # array_split spreads samples evenly: chunk sizes differ by at most one.
        for batch_samples in np.array_split(samples, num_chunks):
            batch = MiniBatch(markdowns=[], code=[], correct_idx=[], max_len_cache=0)
            for sample in batch_samples:
                batch.append(sample.markdown, sample.code)
            minibatches.append(batch)

    print('Sorting minibatches')
    minibatches.sort(key=lambda x: x.get_max_len())
    print('Done sorting minibatches')

    batches = _pack_minibatches(minibatches)
    random.shuffle(batches)
    return batches

def train_on_batch(batch, model, optimizer, scheduler, state: State):
    """Run one optimisation step on `batch` and return the loss as a Python float.

    Every markdown embedding is scored against every code embedding in the
    whole batch (dot product scaled by default_mul). The target for a markdown
    is its own minibatch's correct code cell, offset into the batch-wide code
    list, so code cells from other minibatches act as negatives.
    """
    embeddings = model(batch.get_all_tokens(model, state))

    md_rows, code_rows, targets = [], [], []
    offset = 0       # start of the current minibatch's rows in `embeddings`
    code_offset = 0  # index of the current minibatch's first code cell in `code_rows`
    for mb in batch.mini:
        n_md = len(mb.markdowns)
        end = offset + mb.cnt()
        # Row layout per minibatch (see Batch.get_all_tokens): markdowns, then code.
        md_rows.extend(embeddings[offset:offset + n_md])
        code_rows.extend(embeddings[offset + n_md:end])
        targets.extend(code_offset + idx for idx in mb.correct_idx)
        offset = end
        code_offset += len(mb.code)

    similarity = torch.einsum("ab,cb->ac", torch.stack(md_rows), torch.stack(code_rows)) * default_mul
    target_tensor = torch.tensor(targets).to(state.device)

    loss = CrossEntropyLoss()(similarity, target_tensor)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    return loss.item()


def run_train_all_new(state: State, *, checkpoint="model-epoch1.5.bin", save_name="cur-final",
                      learning_rate=3e-5, epochs=1, run_name="unix-train-additional-data"):
    """Fine-tune the model loaded from `checkpoint` on all notebooks in state.cur_train_nbs.

    Args:
        state: data/device holder; its cur_train_nbs index level 0 supplies notebook ids.
        checkpoint: file name passed to unixcoder.reload_model.
        save_name: name passed to model.save once training finishes.
        learning_rate: AdamW learning rate (linear warmup over 5% of steps, then decay).
        epochs: passes over the generated batches; batches are reshuffled after the first.
        run_name: wandb run name.

    Logs an exponential moving average (factor 0.95) of the loss and the current
    learning rate to wandb after every step.
    """
    print('Start training')
    nb_ids = state.cur_train_nbs.index.get_level_values(0).unique()

    unixcoder_model = unixcoder.reload_model(state, checkpoint)
    model = unixcoder.Model(unixcoder_model)
    model.zero_grad()
    model.train()

    print('Start generating batches...')
    batches = gen_batches(nb_ids, state)
    print('Generated batches:', len(batches))

    total_steps = len(batches) * epochs
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    # num_warmup_steps is an integer step count: warm up over the first 5% of steps.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(0.05 * total_steps), num_training_steps=total_steps)

    init_wandb(name=run_name)
    w_loss = 0.0
    try:
        # Loop over epochs so the schedule (sized by steps * epochs) matches the work done.
        for epoch in range(epochs):
            if epoch > 0:
                random.shuffle(batches)
            for batch in tqdm(batches):
                cur_loss = train_on_batch(batch, model, optimizer, scheduler, state)

                w_loss = w_loss * 0.95 + cur_loss * 0.05
                wandb.log({'loss': w_loss, 'learning_rate': scheduler.get_last_lr()[0]})
    finally:
        # Close the wandb run even if a training step raises.
        wandb.finish()
    model.save(save_name)
  

# Launch the training run on the data loaded into `state` above.
run_train_all_new(state)    