In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the third-party packages this notebook imports into the kernel's environment.
# NOTE(review): versions are not pinned, so a later re-run may pull different releases.
%pip install transformers
%pip install wandb

In [None]:
# Build the run configuration.
# NOTE: after the second line the name `config` refers to the object returned
# by get_jarvis_config(), no longer to the `config` module itself.
import config
config = config.get_jarvis_config()

In [None]:
from wandb_helper import init_wandb
import wandb_helper
import wandb
# Log in through the project's wandb helper, passing it the run configuration.
# NOTE(review): presumably reads W&B credentials from `config` — confirm in wandb_helper.login.
wandb_helper.login(config)

In [None]:
from state import State

# Shared project state built from the configuration. Later cells read
# `state.cur_train_nbs`, `state.df_orders` and `state.device` from it.
state = State(config)

In [None]:
# NOTE(review): presumably loads 10 training notebooks into `state.cur_train_nbs`
# (the frame consumed by gen_learn_samples) — confirm in state.py.
state.load_train_nbs(10)

In [None]:
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaModel

from common import get_markdown_cells, get_code_cells


@dataclass
class LearnSample:
    """One training example: the text of a cell plus its regression target.

    `relative_position` is the cell's position inside its notebook expressed
    as a fraction of the notebook's length.
    """
    text: str
    relative_position: float

class MarkdownDataset(Dataset):
    """Torch dataset that tokenizes sample text with the `roberta-base` tokenizer.

    Indexing yields a tuple ``(input_ids, attention_mask, target)``: the first
    two are LongTensors built from the tokenizer output (padded / truncated to
    ``max_len`` tokens) and ``target`` is a one-element FloatTensor holding the
    sample's ``relative_position``.
    """

    def __init__(self, samples, max_len):
        """Keep the samples and load the pretrained tokenizer.

        Args:
            samples: indexable, sized collection of objects exposing ``text``
                and ``relative_position`` attributes.
            max_len: token length each encoded sample is padded/truncated to.
        """
        super().__init__()
        self.max_len = max_len
        self.samples = samples
        # TODO: need lower case?
        # NOTE(review): roberta-base is a cased checkpoint — confirm that
        # `do_lower_case=True` is intended and has any effect for this tokenizer.
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Encode the sample at ``index`` and return ``(ids, mask, target)``."""
        item = self.samples[index]

        encoded = self.tokenizer.encode_plus(
            item.text,
            None,  # single-sequence input: there is no paired text
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
        )

        target = torch.FloatTensor([item.relative_position])
        return (
            torch.LongTensor(encoded['input_ids']),
            torch.LongTensor(encoded['attention_mask']),
            target,
        )


def gen_learn_samples(state:State, seed=12345, max_len=512):
    """Build a shuffled MarkdownDataset from the markdown cells of the loaded training notebooks.

    For every notebook in ``state.cur_train_nbs`` the correct cell order is
    read from ``state.df_orders``. Each markdown cell becomes one LearnSample
    whose target is ``(pos + 0.5) / len(order)``, i.e. the centre of the cell's
    slot in the ordered notebook, scaled into (0, 1).

    Args:
        state: project state exposing ``cur_train_nbs`` (a frame whose first
            index level is the notebook id) and ``df_orders`` (indexable by
            notebook id, yielding that notebook's ordered cell ids).
        seed: value used to seed Python's global ``random`` module before the
            samples are shuffled. Note this reseeds the process-wide generator.
        max_len: token length forwarded to MarkdownDataset. Defaults to 512,
            the value that used to be hard-coded.

    Returns:
        A MarkdownDataset wrapping the shuffled samples.
    """
    random.seed(seed)
    print('Generating sample on train nbs.')
    samples = []
    df = state.cur_train_nbs
    nbs = df.index.get_level_values(0).unique()
    print('Total nbs:', len(nbs))

    for nb_id in tqdm(nbs):
        nb = df.loc[nb_id]
        correct_order = state.df_orders[nb_id]
        markdown_cells = get_markdown_cells(nb)
        n_cells = len(correct_order)  # loop-invariant, computed once per notebook
        for pos, cell_id in enumerate(correct_order):
            # NOTE(review): this relies on `in` testing membership over cell ids.
            # If get_markdown_cells returns a DataFrame, `in` checks column
            # labels instead of the index — confirm against common.py.
            if cell_id in markdown_cells:
                samples.append(LearnSample(
                    text=nb.loc[cell_id]['source'],
                    relative_position=(pos + 0.5) / n_cells,
                ))

    random.shuffle(samples)

    print('samples len:', len(samples))
    return MarkdownDataset(samples, max_len=max_len)


# Build the training dataset from the notebooks currently held by `state`
# and show the resulting object as the cell output.
dataset = gen_learn_samples(state)
display(dataset)

In [None]:

# Batch size and number of DataLoader worker processes.
BS = 32
NW = 7

# Shuffle the samples on each pass and drop the final incomplete batch so that
# every batch contains exactly BS samples.
train_loader = DataLoader(dataset, batch_size=BS, shuffle=True, num_workers=NW,
                          pin_memory=False, drop_last=True)

In [None]:
import torch
import torch.nn as nn
import numpy as np 
from transformers import RobertaTokenizer, RobertaModel
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

class MarkdownModel(nn.Module):
    """Pretrained `roberta-base` encoder followed by a single linear output unit."""

    def __init__(self):
        """Load the pretrained encoder and tokenizer and create the output layer."""
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        # Maps a 768-dimensional hidden vector to one scalar prediction.
        self.top = nn.Linear(768, 1)

    def forward(self, input_ids, attention_mask):
        """Return one scalar per sequence, computed from the first token's hidden state.

        Args:
            input_ids: token ids passed to the encoder as ``input_ids``.
            attention_mask: mask passed to the encoder as ``attention_mask``.

        Returns:
            Tensor whose last dimension is 1 (one value per input sequence).
        """
        encoder_out = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_out[0]
        # Position 0 along the sequence axis is used as the sequence summary.
        first_token = hidden_states[:, 0, :]
        return self.top(first_token)


def train(state, model, train_loader):
    """Fine-tune `model` for one pass over `train_loader`, logging each batch loss to W&B.

    Optimisation uses AdamW (lr 3e-5, eps 1e-8) with a linear schedule that has
    no warmup and spans exactly ``len(train_loader)`` steps, an L1 loss, and
    gradient-norm clipping at 1.0.

    The W&B run is always closed: normally on success, and with a non-zero
    exit code (so it is recorded as failed) when training raises or is
    interrupted. The original exception is re-raised.

    Args:
        state: object exposing ``device``; every batch tensor is moved there.
        model: module called as ``model(input_ids, attention_mask)``. It is not
            moved by this function, so it must already be on ``state.device``.
        train_loader: sized iterable yielding ``(ids, mask, target)`` batches.
    """
    print('start training...')
    np.random.seed(123)
    # The NumPy seed alone does not make the run repeatable: the DataLoader
    # shuffle order and the model's dropout draw from torch's generator.
    torch.manual_seed(123)
    learning_rate = 3e-5
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = len(train_loader))
    model.train()
    print('training... num batches:', len(train_loader))
    init_wandb(name='test-roberta-training')

    criterion = torch.nn.L1Loss()
    try:
        for ids, mask, target in tqdm(train_loader):
            optimizer.zero_grad()
            pred = model(ids.to(state.device), mask.to(state.device))
            loss = criterion(pred, target.to(state.device))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            wandb.log({'roberta_loss':loss.item()})
    except BaseException:
        # Close the run as failed instead of leaving it open; BaseException
        # also covers a manual kernel interrupt.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()



# Instantiate the model, move it to the device held by `state`, then run the
# training pass over `train_loader`.
model = MarkdownModel()
model.to(state.device)

train(state, model, train_loader)
