In [None]:
from pathlib import Path

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

from src.dataloader import DataLoader
from src.model import EnokeeConfig, EnokeeEncoder
from src.tokenizer import LUKETokenizer
from src.optimizer import MultiOptimizer
from src.utils import (get_num_param_and_model_size, classification_metrics,
                       load_checkpoint, save_checkpoint)

# Seed the CPU and CUDA random number generators so runs are repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Floating-point type the model is cast to before training.
dtype = torch.bfloat16

In [None]:
def train(output_dir, dataloader, model, optimizer, epochs,
          scheduler=None, logger=None, previous_state=(0, 0),
          clip_val=5, save_every=1000, log_every=10, accum_iter=16):
    """Run the training loop with gradient accumulation, logging and checkpointing.

    The model is moved to the module-level ``device``/``dtype`` and switched to
    train mode. Batches whose tokenization raises ``IndexError`` are reported
    and skipped; they do not advance the global step counter.

    Args:
        output_dir: Location handed to ``save_checkpoint``.
        dataloader: Iterable yielding ``(sentences, spans, targets)`` batches;
            must expose a ``batch_size`` attribute.
        model: Module called as ``model(**inputs)`` that exposes
            ``config.num_entities``.
        optimizer: Object providing ``step()`` and ``zero_grad()``.
        epochs: Training stops once the epoch counter reaches this value.
        scheduler: Optional; its ``step()`` is called after every optimizer step.
        logger: Optional object with ``add_scalar(tag, value, step)``.
        previous_state: ``(step, epoch)`` counters to resume from.
        clip_val: Maximum gradient norm passed to ``clip_grad_norm_``.
        save_every: A checkpoint is written every ``save_every`` steps.
        log_every: Averaged metrics are logged every ``log_every`` steps.
        accum_iter: Number of batches whose gradients are accumulated before
            each optimizer step.
    """
    print("INFO: Using device {}".format(str(device)))
    print("INFO: Starting training, press CTRL+C to stop")

    # setup model
    model.to(device, dtype)
    model.train()
    get_num_param_and_model_size(model)

    # setup
    step, epoch = previous_state
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1, label_smoothing=0.15)
    tokenizer = LUKETokenizer()
    metrics = []

    # train loop
    while epoch < epochs:
        pbar = tqdm(dataloader, desc=f"[EPOCH {epoch+1}/{epochs}]")
        for sentences, spans, targets in pbar:
            # tokenize; a failing batch must really be skipped, otherwise the
            # forward pass would run on undefined or stale inputs from the
            # previous batch paired with the current targets
            try:
                inputs = tokenizer(sentences, spans).to(device)
            except IndexError:
                print(f"ERROR: Index error in batch {step}, skipping batch")
                continue

            # reshape targets
            targets = targets.flatten().to(device)

            # forward pass
            outputs = model(**inputs)
            outputs = outputs.view(-1, model.config.num_entities)

            # loss
            # NOTE(review): the loss is not divided by accum_iter, so the
            # accumulated gradient is a sum over batches rather than a mean;
            # left unchanged because clip_val and the learning rates may have
            # been tuned against this scale - confirm.
            loss = criterion(outputs, targets)
            loss.backward()

            # gradient accumulation and optimizer step
            if (step + 1) % accum_iter == 0:
                clip_grad_norm_(model.parameters(), clip_val)
                optimizer.step()
                optimizer.zero_grad()
                # update learning rate
                if scheduler is not None:
                    scheduler.step()

            # batch metrics, computed only over positions enabled in the
            # entity attention mask
            mask = inputs['entity_attention_mask'].bool().flatten()
            y_true = torch.masked_select(targets, mask)
            y_pred = torch.masked_select(outputs.argmax(1), mask)
            accuracy = torch.sum(y_true == y_pred)/len(y_true)
            batch_loss = loss.item()
            metrics.append([batch_loss, accuracy.item()])

            # log the metrics averaged since the previous logging step; the
            # window is reset even without a logger so the list stays bounded
            if (step + 1) % log_every == 0:
                if logger is not None:
                    # distinct names: do not overwrite the current batch loss
                    mean_loss, mean_acc = torch.mean(torch.Tensor(metrics), dim=0)
                    logger.add_scalar("Loss/Train", mean_loss.item(), step)
                    logger.add_scalar("Accuracy/Train", mean_acc.item(), step)
                metrics = []

            # save checkpoints
            if (step + 1) % save_every == 0:
                save_checkpoint(output_dir, step, epoch, dataloader, model, optimizer, scheduler)

            # update progress bar
            # NOTE(review): iterating over tqdm already advances the bar by one
            # per batch, so this adds batch_size on top of that - confirm that
            # this matches what len(dataloader) reports.
            pbar.set_postfix({"loss": batch_loss})
            pbar.update(dataloader.batch_size)

            # global step
            step += 1

        # global epoch
        epoch += 1

In [None]:
def main(output_dir, dataset_path, default_output_dir, batch_size, epochs, dataset_len=0, accum_iter=32):
    """Build the model, optimizer and dataloader, then run training.

    Exactly one of two start modes is used:

    * ``output_dir`` is given: training resumes from the checkpoint that
      ``load_checkpoint`` reads from that directory.
    * otherwise ``dataset_path`` is given: a fresh run is started, the
      dataloader is built from the dataset and ``default_output_dir`` is
      emptied and used as the output directory.

    Args:
        output_dir: Checkpoint directory to resume from, or ``None``.
        dataset_path: Dataset file for a fresh run, or ``None``.
        default_output_dir: Output directory used for a fresh run.
        batch_size: Batch size given to the dataloader.
        epochs: Passed through to ``train``.
        dataset_len: Passed to ``DataLoader`` as ``nrows``.
        accum_iter: Gradient accumulation interval passed to ``train``.

    Raises:
        FileNotFoundError: ``dataset_path`` does not exist.
        ValueError: Neither ``output_dir`` nor ``dataset_path`` was provided.
    """
    import shutil  # only needed to clear sub-directories of a stale output dir

    # initialise dataloader, model, optimizer and (optionally) scheduler
    step = 0
    epoch = 0
    dataloader = None
    config = EnokeeConfig(num_entities=14065, d_ff=1024, n_layers=1,
                          finetune=True, base_model_id="distilroberta-base")
    model = EnokeeEncoder(config).to(device=device)
    # two parameter groups: a small learning rate for the base model and a
    # larger one for the layers stacked on top of it
    optimizer = MultiOptimizer(
        torch.optim.Adam(model.base_model.parameters(), lr=1e-5),
        torch.optim.Adam(list(model.attention.parameters()) +
                         list(model.layers.parameters()) +
                         list(model.classifier.parameters()), lr=5e-4)
    )
    scheduler = None

    # load checkpoints if exist
    if output_dir is not None:
        print("INFO: Loading checkpoints")
        output_dir = Path(output_dir)
        (step,
         epoch,
         dataloader_state_dict,
         model_state_dict,
         optimizer_state_dict,
         scheduler_state_dict,) = load_checkpoint(output_dir, device)
        # load dataloader_state_dict
        dataloader = DataLoader.from_state_dict(dataloader_state_dict)
        dataloader.batch_size = batch_size
        # load model_state_dict (strict=False: mismatched keys are ignored)
        model.load_state_dict(model_state_dict, strict=False)
        # load optimizer_state_dict
        optimizer.load_state_dict(optimizer_state_dict)
        # load scheduler_state_dict
        if scheduler is not None and scheduler_state_dict is not None:
            scheduler.load_state_dict(scheduler_state_dict)

    elif dataset_path is not None:
        print("INFO: Loading dataset")
        dataset_path = Path(dataset_path)
        output_dir = Path(default_output_dir)
        # create dataloader first, so a bad dataset path never wipes old outputs
        if not dataset_path.exists():
            raise FileNotFoundError("Dataset does not exist at the provided path")
        dataloader = DataLoader(dataset_path, batch_size=batch_size, nrows=dataset_len)
        # remove previous checkpoints / tensorboard files, if there are any
        stale_entries = list(output_dir.iterdir()) if output_dir.is_dir() else []
        if stale_entries:
            print("INFO: existing checkpoints found, removing...")
            for entry in stale_entries:
                if entry.is_dir() and not entry.is_symlink():
                    shutil.rmtree(entry)
                else:
                    entry.unlink()
        # leave an empty output directory ready for checkpoints and logs
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        raise ValueError(
            "No arguments provided, run `python train.py --help` to list arguments"
        )

    # initialise summary writer
    logger = SummaryWriter(output_dir)
    # train
    try:
        train(output_dir, dataloader, model, optimizer, epochs,
              scheduler=scheduler, logger=logger, previous_state=(step, epoch),
              save_every=1000, log_every=100, accum_iter=accum_iter)
    except KeyboardInterrupt:
        print("Stopped training.")
    finally:
        # always flush and close the writer, not only on CTRL+C
        logger.close()

In [None]:
# Start a fresh run: no checkpoint directory is passed, so the dataset at
# dataset_path is loaded and outputs are written under default_output_dir.
main(
    output_dir=None,
    dataset_path="./data/zelda.jsonl.bz2",
    default_output_dir="./output",
    dataset_len=22_535_558,
    batch_size=64,
    epochs=2,
    accum_iter=8,
)