In [None]:
import config
config = config.get_local_config()

In [None]:
!nvidia-smi

In [None]:
!pip install transformers

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import unixcoder

In [None]:
from common import reload_model

reload_model(config)

In [None]:
from state import State

state = State()
state.load_df_orders(config)
state.load_test_nbs(config)

In [None]:
!pip install wandb

In [None]:
from wandb_helper import init_wandb
import wandb_helper
wandb_helper.login(config)

In [None]:
# Upper bound on texts (markdowns + code cells) summed over all minibatches packed into one Batch.
max_batch_size = 60
# Maximum number of (markdown, following code) samples placed in one MiniBatch.
minibatch_size = 8
# Multiplier applied to markdown/code dot products before the cross-entropy loss
# (an inverse temperature; presumably tuned for normalized embeddings — TODO confirm).
default_mul = 1000
# Sentinel appended to every cell order; markdowns with no code cell after them are paired with it.
end_token = 'END'

from torch.nn import CrossEntropyLoss
from transformers import get_linear_schedule_with_warmup
from dataclasses import dataclass
from torch.optim import AdamW
import random
random.seed(787788)

@dataclass
class MiniBatch:
    """A group of markdown texts and the code texts they should be matched with.

    `markdowns[i]` belongs with `code[correct_idx[i]]`. Identical code strings
    are stored only once, so several markdowns may share one code index.
    """
    markdowns: list
    code: list
    correct_idx: list  # for each markdown, index of its matching entry in `code`
    max_len_cache: int = 0  # 0 means "not computed yet"

    def append(self, cur_markdown, cur_code):
        """Add one (markdown, code) pair, de-duplicating identical code texts."""
        self.markdowns.append(cur_markdown)
        try:
            idx = self.code.index(cur_code)
        except ValueError:
            self.code.append(cur_code)
            idx = len(self.code) - 1
        self.correct_idx.append(idx)
        # The new texts may be longer than the cached maximum; force a recompute.
        self.max_len_cache = 0

    def get_max_len(self):
        """Return the length (len() of the text) of the longest markdown or code text.

        The value is cached until the next `append`. An empty minibatch returns 0.
        """
        if self.max_len_cache == 0:
            self.max_len_cache = max(
                (len(t) for t in (*self.markdowns, *self.code)), default=0
            )
        return self.max_len_cache

    def cnt(self):
        """Total number of texts (markdowns plus distinct code texts)."""
        return len(self.markdowns) + len(self.code)
    
@dataclass
class Batch:
    """Several MiniBatch objects that are embedded together in one forward pass.

    `sum_cnt` is the running total of texts (markdowns + code) over `mini`.
    """
    mini: list
    sum_cnt: int

    def append(self, mini_batch):
        """Add `mini_batch` and account for its texts in `sum_cnt`."""
        self.mini.append(mini_batch)
        self.sum_cnt = self.sum_cnt + mini_batch.cnt()

    def get_all_tokens(self):
        """Tokenize every text of the batch with a single get_texts_tokens call.

        Texts are laid out minibatch by minibatch; within each minibatch the
        markdowns come first, followed by its code texts.
        """
        flat_texts = [
            text
            for mb in self.mini
            for text in (*mb.markdowns, *mb.code)
        ]
        return get_texts_tokens(flat_texts)
        
@dataclass
class Sample:
    """One training pair: a markdown cell and the code that follows it."""

    markdown: str  # source text of the markdown cell
    code: str  # source of the next code cell in the correct order, or the end sentinel
    


def _split_evenly(items, num_chunks):
    """Split `items` into `num_chunks` consecutive slices whose sizes differ by at most one.

    Same partition as `np.array_split(items, num_chunks)`: the first
    `len(items) % num_chunks` slices get one extra element. Working on plain
    lists avoids packing the Sample objects into a numpy object array.
    """
    base, extra = divmod(len(items), num_chunks)
    chunks = []
    start = 0
    for k in range(num_chunks):
        stop = start + base + (1 if k < extra else 0)
        chunks.append(items[start:stop])
        start = stop
    return chunks


def gen_batches(all):
    """Turn notebooks into shuffled training batches of (markdown, next code) pairs.

    For every notebook id in `all`, each markdown cell is paired with the first
    non-markdown cell that follows it in the correct order (or `end_token` when
    none follows). The pairs are shuffled and split into MiniBatch objects of at
    most `minibatch_size` pairs. Minibatches are sorted by their longest text,
    greedily packed into Batch objects holding at most `max_batch_size` texts
    (a single oversized minibatch still gets its own Batch), and the batches
    are shuffled.

    Args:
        all: iterable of notebook ids present in both `df` and `df_orders`.

    Returns:
        list of Batch.
    """
    minibatches = []
    for nb_id in tqdm(all):
        nb = df.loc[nb_id]
        # Work on a copy: appending to the list stored in df_orders would mutate
        # the shared data, so later calls (and split_parts) would see extra end tokens.
        correct_order = list(df_orders.loc[nb_id]) + [end_token]
        # Set for O(1) membership tests; cell ids are .loc labels, hence hashable.
        markdown_cell_ids = set(get_markdown_cells(nb))

        def get_code(cell_id):
            if cell_id == end_token:
                return end_token
            return nb.loc[cell_id]['source']

        # Walk backwards remembering the most recent non-markdown cell. end_token is
        # last and never a markdown id, so every markdown has a successor.
        samples = []
        next_code_cell = None
        for cell_id in reversed(correct_order):
            if cell_id in markdown_cell_ids:
                samples.append(Sample(markdown=nb.loc[cell_id]['source'], code=get_code(next_code_cell)))
            else:
                next_code_cell = cell_id
        # Restore notebook order so the seeded shuffle matches a forward scan.
        samples.reverse()

        if not samples:
            # No markdown cells: nothing to learn, and splitting into 0 chunks would raise.
            continue

        random.shuffle(samples)
        num_chunks = (len(samples) + minibatch_size - 1) // minibatch_size

        for batch_samples in _split_evenly(samples, num_chunks):
            batch = MiniBatch(markdowns=[], code=[], correct_idx=[], max_len_cache=0)
            for sample in batch_samples:
                batch.append(sample.markdown, sample.code)
            minibatches.append(batch)

    print('Sorting minibatches')
    minibatches.sort(key=lambda x: x.get_max_len())
    print('Done sorting minibatches')

    batches = []
    for b in minibatches:
        if not batches or batches[-1].sum_cnt + b.cnt() > max_batch_size:
            batches.append(Batch(mini=[], sum_cnt=0))
        batches[-1].append(b)

    random.shuffle(batches)
    return batches

def train_on_batch(batch, model, optimizer, scheduler):
    """Run one optimisation step on `batch` and return the loss as a Python float.

    The model embeds every text of the batch in one call. For each minibatch,
    its markdown rows come first, followed by its code rows. Each markdown is
    scored against *all* code rows of the batch (dot product scaled by
    `default_mul`). Cross-entropy targets its own code cell, so code from other
    minibatches acts as in-batch negatives.
    """
    embeddings = model(batch.get_all_tokens())

    markdown_rows = []
    code_rows = []
    targets = []
    row = 0         # first embedding row of the current minibatch
    code_base = 0   # number of code texts contributed by earlier minibatches
    for mini in batch.mini:
        md_end = row + len(mini.markdowns)
        mini_end = row + mini.cnt()
        markdown_rows.extend(embeddings[row:md_end])
        code_rows.extend(embeddings[md_end:mini_end])
        # correct_idx is local to the minibatch; shift it into batch-wide code numbering.
        targets.extend(code_base + local_idx for local_idx in mini.correct_idx)
        row = mini_end
        code_base += len(mini.code)

    logits = torch.einsum("ab,cb->ac", torch.stack(markdown_rows), torch.stack(code_rows)) * default_mul
    target_tensor = torch.tensor(targets).to(device)

    loss = CrossEntropyLoss()(logits, target_tensor)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    return loss.item()


def run_train_all_new():
    """Train the markdown->code matching model on every notebook in `df`.

    Builds batches with gen_batches, reloads the model weights from a fixed
    checkpoint, then makes one pass over the batches with AdamW and a linear
    learning-rate schedule (no warmup). After every step the smoothed loss is
    logged to wandb. A checkpoint is saved every `save_every_s` seconds, the
    loop stops early after `max_run_s` seconds, and the final weights are saved
    under the tag "final".
    """
    import time  # used for checkpoint/timeout bookkeeping; no other cell imports it

    print('Start training')
    notebook_ids = df.index.get_level_values(0).unique()
    print('Start generating batches...')
    batches = gen_batches(notebook_ids)
    print('Generated batches:', len(batches))

    # NOTE(review): the setup cell calls reload_model(config); confirm that calling it
    # without `config` matches reload_model's current signature.
    reload_model(preload="model-100k-mul1000-all-bs60-8.bin")
    model = Model(unixcoder_model)
    model.zero_grad()
    model.train()

    learning_rate = 3e-5
    epochs = 1  # only sizes the LR schedule; the loop below makes a single pass
    steps = len(batches)

    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=steps * epochs)

    start_time = time.time()
    last_saved_time = start_time
    save_every_s = 40 * 60
    max_run_s = 12 * 3600

    # Keep "lr=3e-5" in the run name in sync with learning_rate above.
    init_wandb(name=(str(len(notebook_ids) / 1000) + "k,lr=3e-5,mul-" + str(default_mul) + ",all-bs=" + str(max_batch_size) + "-" + str(minibatch_size)))
    # Exponential moving average of the loss; it starts at 0, so early values are biased low.
    w_loss = 0.0

    try:
        for step, batch in enumerate(tqdm(batches)):
            cur_loss = train_on_batch(batch, model, optimizer, scheduler)

            w_loss = w_loss * 0.95 + cur_loss * 0.05
            wandb.log({'loss': w_loss})

            cur_time = time.time()
            if cur_time - last_saved_time > save_every_s:
                last_saved_time = cur_time
                save_model(model, step)

            if cur_time - start_time > max_run_s:
                print('Finishing early because of timeout')
                break
    finally:
        # Close the wandb run even if a step raises, so the next run starts cleanly.
        wandb.finish()
    save_model(model, "final")
  
    
# Executing this cell starts a full training run (bounded by max_run_s = 12 h inside the function).
run_train_all_new()

In [None]:
load_train_nbs(50)

In [None]:
def num_test_inputs():
    """Return the number of distinct values in the first index level of `test_df` (one per notebook)."""
    notebook_ids = test_df.index.get_level_values(0).unique()
    return len(notebook_ids)

# Earlier guard, kept for reference: if num_test_inputs() != 4:
# NOTE(review): `or True` makes this condition always true, so this cell always calls
# load_train_nbs(10_000) and trains, even in interactive mode. The cell defining
# run_train_all_new also calls it, so running both cells trains twice.
if not is_interactive_mode() or True:
    print('Going to generate model...')
    # Author's timing note: 2000 notebooks took about half an hour (presumably of training — confirm).
    load_train_nbs(10_000)
    run_train_all_new()

# save_results()

In [None]:
!mv model-final.bin model-epoch2-10k.bin

In [None]:
!/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-epoch2-10k.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-54321.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-43403.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-37926.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-32433.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-26960.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-21506.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-16067.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-15334.bin
# !/home/gdrive upload -p 12wE_l-hW_ScKnP9l-cpWWRTtuC5fZD7c model-10691.bin

In [None]:
# !cp model-20312.bin drive/MyDrive/ai4-code/

In [None]:
!pip install -U sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer, util, InputExample, evaluation, losses
from torch.utils.data import DataLoader

In [None]:
from dataclasses import dataclass
from typing import List
import random

@dataclass
class Part:
  """A maximal run of consecutive cells of the same type, in notebook order."""

  is_code: bool  # True for a run of code cells, False otherwise
  ids: List[str]  # cell ids in the run, in notebook order

def split_parts(nb_id):
  """Group the cells of notebook `nb_id` into maximal same-type runs.

  Walks the cells in their correct order (`df_orders[nb_id]`) and starts a new
  Part whenever `cell_type` changes.

  Returns:
    list of Part in notebook order (empty when the order is empty).
  """
  correct_order = df_orders[nb_id]
  nb = df.loc[nb_id]
  parts = []
  run_type = None
  for cell_id in correct_order:
    cell_type = nb.loc[cell_id]['cell_type']
    if parts and cell_type == run_type:
      parts[-1].ids.append(cell_id)
    else:
      run_type = cell_type
      parts.append(Part(cell_type == 'code', [cell_id]))
  return parts

def only_code_parts(parts):
  """Return the parts whose `is_code` is truthy, preserving order."""
  return [part for part in parts if part.is_code]

def only_markdown_parts(parts):
  """Return the parts whose `is_code` is falsy (non-code runs), preserving order."""
  return [part for part in parts if not part.is_code]

@dataclass
class MmDataset:
  """Markdown-markdown text pairs, split by whether both cells share a run."""

  same_group: List[List[str]]  # [text_a, text_b] pairs drawn from one markdown run
  diff_group: List[List[str]]  # [text_a, text_b] pairs drawn from two different markdown runs

def generate_mm_dataset():
  """Build markdown/markdown text pairs from every notebook in `df`.

  Each run of consecutive markdown cells ("part") yields at most one pair:
    * if the part has two or more cells and a random bit is set: two distinct
      cells of that part, stored in `same_group`;
    * otherwise: a random cell of this part and a random cell of a randomly
      drawn markdown part, stored in `diff_group` only when the drawn part is a
      different one (drawing the same part produces no pair).
  Randomness comes from the module-level `random` state.

  Returns:
    MmDataset with the collected source-text pairs.
  """
  print('Generating markdown-markdown dataset')

  notebook_ids = df.index.get_level_values(0).unique()
  same_group, diff_group = [], []

  for nb_id in tqdm(notebook_ids):
    nb = df.loc[nb_id]
    markdown_parts = only_markdown_parts(split_parts(nb_id))
    part_count = len(markdown_parts)
    for idx, part in enumerate(markdown_parts):
      # `and` short-circuits: a random bit is only drawn for multi-cell parts.
      if len(part.ids) > 1 and random.getrandbits(1):
        first_id, second_id = random.sample(part.ids, 2)
        same_group.append([nb.loc[first_id]['source'], nb.loc[second_id]['source']])
        continue
      other = random.randint(0, part_count - 1)
      if other == idx:
        continue
      first_id = random.choice(part.ids)
      second_id = random.choice(markdown_parts[other].ids)
      diff_group.append([nb.loc[first_id]['source'], nb.loc[second_id]['source']])

  return MmDataset(same_group, diff_group)
  


# dataloader = generate_mm_dataset()



In [None]:
def mm_dataloader():
  """Wrap a freshly generated markdown-markdown dataset in a shuffled DataLoader.

  Same-group pairs get label 1.0 and different-group pairs label 0.0. Batch size is 8.
  """
  mm_dataset = generate_mm_dataset()
  labelled = [(pair, 1.0) for pair in mm_dataset.same_group] + [(pair, 0.0) for pair in mm_dataset.diff_group]
  examples = [InputExample(texts=pair, label=label) for pair, label in labelled]
  return DataLoader(examples, shuffle=True, batch_size=8)

In [None]:
# model = SentenceTransformer('nq-distilbert-base-v1') # TODO: change model name

In [None]:
from sklearn.metrics.pairwise import paired_cosine_distances
import seaborn as sns

def test():
  """Evaluate the global `model` on a freshly generated markdown-markdown dataset.

  Both sides of every pair are encoded with `model.encode`. The function
  prints the mean cosine similarity of same-group pairs, then that of
  different-group pairs, then the threshold in [0.10, 0.89] whose rule
  "cos > threshold => same group" gives the best accuracy.

  Returns:
    DataFrame of the 50 pairs whose cosine similarity is furthest from their
    label (the worst predictions), with columns s1, s2, lbl, cos, delta.
  """
  # NOTE(review): relies on a global `model` with a SentenceTransformer-style `encode`;
  # the cell that would create it is commented out — confirm it exists before calling.
  mm_dataset = generate_mm_dataset()
  pairs = mm_dataset.same_group + mm_dataset.diff_group

  first = [x[0] for x in pairs]
  second = [x[1] for x in pairs]

  embeddings1 = model.encode(first, batch_size=8, show_progress_bar=True, convert_to_numpy=True)
  embeddings2 = model.encode(second, batch_size=8, show_progress_bar=True, convert_to_numpy=True)

  sz1 = len(mm_dataset.same_group)
  labels = [1] * sz1 + [0] * len(mm_dataset.diff_group)

  cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2)

  print(cosine_scores[:sz1].mean())
  print(cosine_scores[sz1:].mean())

  tmp_df = (
    pd.DataFrame({"s1": first, "s2": second, "lbl": labels, "cos": cosine_scores})
    .assign(delta=lambda x: np.abs(x.lbl - x.cos))
    # Largest |label - cosine| first, so head() really shows the bad predictions.
    .sort_values("delta", ascending=False)
  )

  # Accuracy of "cos > threshold => same group" for thresholds 0.10 .. 0.89.
  lbl = tmp_df.lbl.to_numpy()
  cos = tmp_df.cos.to_numpy()
  scores = [(x / 100, ((cos > x / 100) == lbl).mean()) for x in range(10, 90)]

  # max() returns the first best threshold, as the original stable reverse sort did.
  best_threshold, best_accuracy = max(scores, key=lambda s: s[1])
  print("Best accuracy of {} using threshold {}".format(best_accuracy, best_threshold))

  print("\nBad predictions - ")
  return tmp_df.head(50)
            


# test()
# print('End')

In [None]:
# train_loss = losses.CosineSimilarityLoss(model)

# train_dataloader = mm_dataloader()

# #Tune the model
# model.fit(train_objectives=[(train_dataloader, train_loss)], 
#           epochs=1, 
#           warmup_steps= 0 )

In [None]:
# test()

In [None]:
!pip install langdetect

In [None]:
from torch.nn import CrossEntropyLoss
from transformers import get_linear_schedule_with_warmup
from langdetect import detect

def detect_lang(x):
  """Return the language code produced by langdetect's `detect(x)`, or 'other' if it fails.

  langdetect raises, for example, on text with no detectable features. Any
  such error is mapped to 'other' so that bulk processing keeps going.
  """
  try:
    return detect(x)
  except Exception:
    # Catch Exception rather than using a bare `except:` so that KeyboardInterrupt /
    # SystemExit still stop a long loop over all notebooks.
    return 'other'

def run_train_all():
    """Sample one markdown cell per notebook and return its text, language and tokens.

    Despite its name, this function does not train anything. For every
    notebook in `df` it picks a random markdown cell, lower-cases its source,
    runs detect_lang on it, and tokenizes it with the unixcoder tokenizer.
    Notebooks without markdown cells are skipped, because random.choice would
    raise on them.

    Returns:
        pandas DataFrame with columns 'text', 'lang' and 'tokenized', one row per
        notebook that has at least one markdown cell.
    """
    notebook_ids = df.index.get_level_values(0).unique()
    texts = []
    tokenized = []
    langs = []
    for nb_id in tqdm(notebook_ids):
        nb = get_nb_by_id(nb_id)
        markdown_cell_ids = get_markdown_cells(nb)
        # len() instead of truthiness: the result may be a pandas Index.
        if len(markdown_cell_ids) == 0:
            continue
        cell_id = random.choice(markdown_cell_ids)
        t = nb.loc[cell_id]['source'].lower()
        texts.append(t)
        langs.append(detect_lang(t))
        tokenized.append(unixcoder_model.tokenizer.tokenize(t))

    return pd.DataFrame(data={'text': texts, 'lang': langs, 'tokenized': tokenized})

 
# run_train_all()