In [None]:
# Mount Google Drive; checkpoints are written under /content/gdrive/MyDrive later on.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

# Setup

In [None]:
!git clone https://github.com/neuspell/neuspell
%cd neuspell

In [None]:
!pip install -e .

In [None]:
!pip install urllib3==1.25.4

In [None]:
!pip install folium==0.2.1

In [None]:
!pip install -r extras-requirements.txt

In [None]:
!pip install torch==1.6.0

In [None]:
!pip install transformers==4.1

In [None]:
import neuspell

In [None]:
%cd data/traintest
!python download_datafiles.py 

In [None]:
# Return to the repository root after downloading the data files.
%cd /content/neuspell

# Library

In [None]:
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import numpy as np

import time

from neuspell.seq_modeling.subwordbert import load_model
from neuspell.seq_modeling.helpers import load_data, train_validation_split, batch_accuracy_func
from neuspell.seq_modeling.helpers import get_tokens, progressBar
from neuspell.seq_modeling.helpers import batch_iter, labelize, tokenize, bert_tokenize_for_valid_examples

from neuspell.seq_modeling.helpers import load_vocab_dict, save_vocab_dict

# Train

## Load dataset and build vocabulary

In [None]:
# Load (correct, noisy) sentence pairs from the 1BLM files (clean file + its noised
# counterpart, by filename) and split them; 0.90 is presumably the train fraction.
# NOTE(review): training data is taken from the *test* split of 1BLM — confirm intended.
train_data = load_data('/content/neuspell/data/traintest/', 'test.1blm', 'test.1blm.noise.prob')
train_data, valid_data = train_validation_split(train_data, 0.90, seed=1)

# Build the vocabulary from the correct side (element 0) of each training pair.
vocab_ref = {}
vocab = get_tokens(
    [pair[0] for pair in train_data],
    keep_simple=True,
    min_max_freq=(2, float("inf")),
    topk=100000,
    intersect=vocab_ref,
    load_char_tokens=True,
)

# Drive is mounted at /content/gdrive (not /content/drive):
# save_vocab_dict('/content/gdrive/MyDrive/NLP/bert_vocab.pkl', vocab)

## Load model

In [None]:
import gc

# Release Python garbage and cached CUDA memory, presumably left by a previously loaded
# model when this cell is re-run.
gc.collect()
torch.cuda.empty_cache()

DEVICE = 'cuda'

# Single source of truth for the checkpoint name: the tokenizer helper defaults to this
# same variable, so model and tokenizer cannot drift apart.
bert_pretrained_name_or_path = "distilbert-base-cased"
model = load_model(vocab, bert_pretrained_name_or_path)
model = model.to(DEVICE)

VALID_BATCH_SIZE = 32
TRAIN_BATCH_SIZE = 32


## Freeze layers (optional, currently disabled)

In [None]:
# for layers in model.bert_model.encoder.layer[:9]:
#     for param in layers.parameters():
#         param.requires_grad = False

## Optimizer (Adam)

In [None]:
# Epoch range used by the training loop; both ends are inclusive.
START_EPOCH, N_EPOCHS = 1, 5

# NOTE(review): gradient accumulation is not applied by the training loop, so this value
# currently has no effect.
GRADIENT_ACC = 4
# Best-dev-accuracy trackers; nothing in this notebook updates them as written.
max_dev_acc = argmax_dev_acc = -1

In [None]:
# Plain Adam over every model parameter at a fixed learning rate.
#
# Disabled alternative (needs `pytorch_pretrained_bert.BertAdam`): weight decay 0.01 on all
# parameters except names containing 'bias', 'LayerNorm.bias' or 'LayerNorm.weight'
# (decay 0.0), lr=5e-5, warmup=0.1,
# t_total = int(len(train_data) / TRAIN_BATCH_SIZE / GRADIENT_ACC * N_EPOCHS).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)



## Training loop: train, validate and checkpoint each epoch

In [None]:
import transformers
from typing import List

# Shared tokenizer; stays None until bert_tokenize_for_valid_examples loads it on first use.
BERT_TOKENIZER = None
# Maximum BERT input length per sentence including [CLS] and [SEP]; the helpers below
# keep at most BERT_MAX_SEQ_LEN - 2 word pieces.
BERT_MAX_SEQ_LEN = 512

def merge_subtokens(tokens: List):
    """Glue WordPiece continuation pieces ("##xyz") back onto the token before them.

    Returns the reconstructed words joined by single spaces. A leading "##" piece has
    nothing to attach to and raises IndexError.
    """
    words = []
    for piece in tokens:
        if not piece.startswith("##"):
            words.append(piece)
            continue
        words[-1] += piece[2:]
    return " ".join(words)

def _custom_bert_tokenize_sentence(text):
    """
    Tokenize one sentence with the shared BERT_TOKENIZER and record how many word pieces
    each word was split into.

    Returns (text, tokens, split_sizes):
        text: the word pieces merged back into space-separated words. The BERT tokenizer
            does more than split at whitespace, so this can differ from the input text.
        tokens: the word pieces, truncated to BERT_MAX_SEQ_LEN - 2 entries
        split_sizes: number of word pieces for each word of the returned text

    Raises AssertionError (with a descriptive message) if split sizes and words disagree.
    """
    tokens = BERT_TOKENIZER.tokenize(text)
    tokens = tokens[:BERT_MAX_SEQ_LEN - 2]  # 2 allowed for [CLS] and [SEP]
    # Start index of every word (pieces not starting with "##") plus an end sentinel;
    # consecutive differences are the per-word piece counts.
    idxs = np.array([idx for idx, token in enumerate(tokens) if not token.startswith("##")] + [len(tokens)])
    split_sizes = (idxs[1:] - idxs[0:-1]).tolist()
    # Rebuild the text from the pieces instead of reusing the input, so word boundaries
    # line up with split_sizes.
    text = merge_subtokens(tokens)
    n_words = len(text.split())
    # The message must be a value: print(...) returns None and leaves the error empty.
    assert len(split_sizes) == n_words, (
        f"split/word count mismatch: {len(tokens)} tokens, {len(split_sizes)} splits, "
        f"{n_words} words; split_sizes={split_sizes}; text={text!r}")
    return text, tokens, split_sizes

def _custom_bert_tokenize_sentences(list_of_texts):
    """
    Apply _custom_bert_tokenize_sentence to every text.

    Returns three parallel lists (texts, tokens, split_sizes). An empty input yields three
    empty lists; unzipping an empty result with zip(*...) would raise ValueError instead.
    """
    texts, tokens, split_sizes = [], [], []
    for text in list_of_texts:
        merged_text, text_tokens, text_splits = _custom_bert_tokenize_sentence(text)
        texts.append(merged_text)
        tokens.append(text_tokens)
        split_sizes.append(text_splits)
    return texts, tokens, split_sizes

def _simple_bert_tokenize_sentences(list_of_texts):
    """Re-tokenize each text with BERT_TOKENIZER (truncated like the noisy side) and merge
    the word pieces back into space-separated words; no split sizes are recorded."""
    merged = []
    for text in list_of_texts:
        pieces = BERT_TOKENIZER.tokenize(text)
        merged.append(merge_subtokens(pieces[:BERT_MAX_SEQ_LEN - 2]))
    return merged

def bert_tokenize_for_valid_examples(batch_orginal_sentences, batch_noisy_sentences, bert_pretrained_name_or_path=bert_pretrained_name_or_path):
    """
    Tokenize a batch of (clean, noisy) sentence pairs for the sub-word BERT model. Pairs whose
    word counts no longer match after BERT tokenization are dropped.

    inputs:
        batch_orginal_sentences: List[str]
            the clean/target texts, used to make sure lengths of input and output are the
            same in the seq-modeling task
        batch_noisy_sentences: List[str]
            the noisy texts to tokenize and feed to BERT
        bert_pretrained_name_or_path: str
            huggingface name/path of the tokenizer. The default is bound to the notebook-level
            value when this function is defined. Only used on the FIRST call: the tokenizer
            is then cached in the global BERT_TOKENIZER, and later calls with a different
            name keep the cached one. A falsy value falls back to "distilbert-base-cased".
    outputs (only pairs that survived filtering, in input order):
        batch_orginal_sentences: List[str]  re-tokenized clean texts
        batch_noisy_sentences: List[str]    re-tokenized noisy texts
        batch_bert_dict: {"attention_mask", "input_ids"}
            2d tensors of shape (bs, max_len), padded with 0; when no pair survives both
            values are empty lists instead of tensors
        batch_splits: List[List[int]]
            #sub-tokens for each word of each noisy text
    """
    global BERT_TOKENIZER

    if BERT_TOKENIZER is None:  # gets initialized during the first call to this method
        # Fall back to the same checkpoint as the model when no name is given, instead of
        # passing a falsy name to from_pretrained.
        BERT_TOKENIZER = transformers.BertTokenizer.from_pretrained(
            bert_pretrained_name_or_path or "distilbert-base-cased")
        BERT_TOKENIZER.do_basic_tokenize = True
        # NOTE(review): in transformers 4.x this flag is read by BERT_TOKENIZER.basic_tokenizer;
        # setting it on the outer tokenizer may have no effect — confirm.
        BERT_TOKENIZER.tokenize_chinese_chars = False

    clean_texts = _simple_bert_tokenize_sentences(batch_orginal_sentences)
    noisy_texts, noisy_tokens, noisy_splits = _custom_bert_tokenize_sentences(batch_noisy_sentences)

    # Keep pairs whose word counts still agree. valid_idxs is increasing, so direct indexing
    # preserves order without the O(n^2) `idx in list` scan.
    valid_idxs = [idx for idx, (a, b) in enumerate(zip(clean_texts, noisy_texts))
                  if len(a.split()) == len(b.split())]
    batch_orginal_sentences = [clean_texts[idx] for idx in valid_idxs]
    batch_noisy_sentences = [noisy_texts[idx] for idx in valid_idxs]
    batch_tokens = [noisy_tokens[idx] for idx in valid_idxs]
    batch_splits = [noisy_splits[idx] for idx in valid_idxs]

    # token_type_ids are deliberately not produced (presumably because DistilBERT's forward
    # does not accept them — confirm before switching to a BERT checkpoint).
    batch_bert_dict = {"attention_mask": [], "input_ids": []}
    if valid_idxs:
        batch_encoded_dicts = [BERT_TOKENIZER.encode_plus(tokens) for tokens in batch_tokens]
        batch_bert_dict = {
            key: pad_sequence([torch.tensor(encoded[key]) for encoded in batch_encoded_dicts],
                              batch_first=True, padding_value=0)
            for key in ("attention_mask", "input_ids")
        }

    return batch_orginal_sentences, batch_noisy_sentences, batch_bert_dict, batch_splits

In [None]:
# Train for epochs START_EPOCH..N_EPOCHS. After every epoch, run validation and save a
# checkpoint (model + optimizer state) to Google Drive.
for epoch_id in range(START_EPOCH, N_EPOCHS + 1):

    print(f"In epoch: {epoch_id}")

    # ---- train: loss and backprop ----
    train_loss = 0.
    train_acc = 0.
    train_acc_count = 0.
    n_train_batches = 0  # batches actually trained on (skipped batches are not counted)
    batch_acc = float("nan")  # most recent sampled training accuracy
    print("train_data size: {}".format(len(train_data)))

    train_data_iter = batch_iter(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    nbatches = int(np.ceil(len(train_data) / TRAIN_BATCH_SIZE))

    for batch_id, (batch_labels, batch_sentences) in enumerate(train_data_iter):
        optimizer.zero_grad()
        st_time = time.time()

        # set batch data for bert (drops pairs whose word counts disagree after tokenization)
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(
            batch_labels, batch_sentences, bert_pretrained_name_or_path=bert_pretrained_name_or_path)
        if len(batch_labels_) == 0:
            print("################")
            print("Not training the following lines due to pre-processing mismatch: \n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        n_train_batches += 1

        batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}

        # set batch data for others
        batch_labels, batch_lengths = labelize(batch_labels, vocab)
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        # forward
        model.train()
        loss = model(batch_bert_inp, batch_bert_splits, targets=batch_labels)
        batch_loss = loss.cpu().detach().numpy()
        train_loss += batch_loss

        # backward + step after every batch (gradient accumulation and clipping are not used)
        loss.backward()
        optimizer.step()

        # Sample a no-grad training accuracy on the first trained batch and then every 10000
        # trained batches. Counting trained (not raw) batches guarantees batch_acc and
        # train_acc_count are set before the progress bar reads them, even if early batches
        # are skipped.
        if (n_train_batches - 1) % 10000 == 0:
            train_acc_count += 1
            model.eval()
            with torch.no_grad():
                _, batch_predictions = model(batch_bert_inp, batch_bert_splits, targets=batch_labels)
            model.train()

            ncorr, ntotal = batch_accuracy_func(batch_predictions,
                                                batch_labels.cpu().detach().numpy(),
                                                batch_lengths.cpu().detach().numpy())
            batch_acc = ncorr / ntotal
            train_acc += batch_acc

        # update progress
        progressBar(batch_id + 1,
                    nbatches,
                    ["batch_time", "batch_loss", "avg_batch_loss", "batch_acc", "avg_batch_acc"],
                    [time.time() - st_time, batch_loss, train_loss / n_train_batches, batch_acc,
                     train_acc / train_acc_count])

    print(f"\nEpoch {epoch_id} train_loss: {train_loss / max(n_train_batches, 1)}")

    # Optional per-epoch weights-only save (disabled); Drive is mounted at /content/gdrive:
    # torch.save(model.state_dict(), '/content/gdrive/MyDrive/NLP/' + "bert_epoch_" + str(epoch_id) + '.pt')

    # ---- validation ----
    valid_loss = 0.
    valid_acc = 0.
    n_valid_batches = 0
    print("valid_data size: {}".format(len(valid_data)))

    valid_data_iter = batch_iter(valid_data, batch_size=VALID_BATCH_SIZE, shuffle=False)
    n_valid_total = int(np.ceil(len(valid_data) / VALID_BATCH_SIZE))

    model.eval()
    for batch_id, (batch_labels, batch_sentences) in enumerate(valid_data_iter):
        st_time = time.time()

        # set batch data for bert; a batch with no surviving pairs has list (not tensor)
        # values in batch_bert_inp and must be skipped before calling .to()
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(
            batch_labels, batch_sentences, bert_pretrained_name_or_path=bert_pretrained_name_or_path)
        if len(batch_labels_) == 0:
            print("################")
            print("Not validating the following lines due to pre-processing mismatch: \n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        n_valid_batches += 1

        batch_bert_inp = {k: v.to(DEVICE) for k, v in batch_bert_inp.items()}

        # set batch data for others
        batch_labels, batch_lengths = labelize(batch_labels, vocab)
        batch_lengths = batch_lengths.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        # forward
        with torch.no_grad():
            batch_loss, batch_predictions = model(batch_bert_inp, batch_bert_splits, targets=batch_labels)
        # plain float, so the running sum is a Python number whatever the model returns
        batch_loss = float(batch_loss)
        valid_loss += batch_loss

        # compute accuracy in numpy
        ncorr, ntotal = batch_accuracy_func(batch_predictions,
                                            batch_labels.cpu().detach().numpy(),
                                            batch_lengths.cpu().detach().numpy())
        batch_acc = ncorr / ntotal
        valid_acc += batch_acc

        # update progress
        progressBar(batch_id + 1,
                    n_valid_total,
                    ["batch_time", "batch_loss", "avg_batch_loss", "batch_acc", "avg_batch_acc"],
                    [time.time() - st_time, batch_loss, valid_loss / n_valid_batches, batch_acc,
                     valid_acc / n_valid_batches])
    model.train()

    avg_valid_loss = valid_loss / max(n_valid_batches, 1)
    avg_valid_acc = valid_acc / max(n_valid_batches, 1)
    print(f"\nEpoch {epoch_id} valid_loss: {avg_valid_loss}")
    # NOTE: despite the .hdf5 suffix this is a torch.save pickle; load it with torch.load.
    torch.save({
            'epoch': epoch_id,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': valid_loss},
        f'/content/gdrive/MyDrive/mdistilepoch:{epoch_id}valid_acc{avg_valid_acc}.hdf5')

  

## Evaluation

In [None]:
from tqdm import tqdm
from neuspell.seq_modeling.evals import get_metrics

def untokenize_without_unks(batch_predictions, batch_lengths, vocab, batch_clean_sentences, backoff="pass-through"):
    """
    Convert predicted token ids back into whitespace-joined sentences, replacing the
    unknown token according to ``backoff``.

    batch_predictions: per-sentence sequences of predicted token ids
    batch_lengths: number of words to keep from each prediction
    vocab: dict with "idx2token", "token2idx" and "unk_token"
    batch_clean_sentences: reference strings, split at whitespace. A caller may pass the
        noisy input here, in which case "pass-through" copies the input word.
    backoff: what to emit where the model predicts the unknown token:
        "pass-through" -> the word at the same position in batch_clean_sentences
        "neutral"      -> the placeholder "a"

    Returns a list of strings, one per sentence.

    Raises ValueError for an unknown backoff or mismatched batch sizes. These are explicit
    checks rather than asserts, so they still run under `python -O`.
    """
    if backoff not in ("neutral", "pass-through"):
        raise ValueError(f"selected backoff strategy not implemented: {backoff}")
    if not (len(batch_predictions) == len(batch_lengths) == len(batch_clean_sentences)):
        raise ValueError(
            f"batch size mismatch: {len(batch_predictions)} predictions, "
            f"{len(batch_lengths)} lengths, {len(batch_clean_sentences)} sentences")
    idx2token = vocab["idx2token"]
    unktoken = vocab["token2idx"][vocab["unk_token"]]
    batch_clean_words = [sent.split() for sent in batch_clean_sentences]

    def _backoff_word(clean_words, position):
        # replacement for an unknown-token prediction at `position`
        return clean_words[position] if backoff == "pass-through" else "a"

    return [
        " ".join(idx2token[idx] if idx != unktoken else _backoff_word(clean_words, i)
                 for i, idx in enumerate(pred[:length]))
        for pred, length, clean_words in zip(batch_predictions, batch_lengths, batch_clean_words)
    ]

def batch_iter(data, batch_size, shuffle):
    """
    Lazily yield ``(labels, sentences)`` batches from a list of ``(label, text)`` items.

    The last batch may be smaller than ``batch_size``. With ``shuffle`` the order comes from
    numpy's global RNG. NOTE(review): this redefines the ``batch_iter`` imported from
    neuspell.seq_modeling.helpers for every cell executed after this one.
    """
    order = list(range(len(data)))
    if shuffle:
        np.random.shuffle(order)
    total = int(np.ceil(len(data) / batch_size))
    for start in range(0, total * batch_size, batch_size):
        chunk = order[start:start + batch_size]
        yield ([data[k][0] for k in chunk], [data[k][1] for k in chunk])

def model_inference(model, data, topk, device, batch_size=16, vocab_=None):
    """
    Run the spell corrector over ``data`` and print token-level correction metrics.

    model: an instance of SubwordBert
    data: list of tuples, with each tuple consisting of correct and incorrect
            sentence string (would be split at whitespaces)
    topk: how many of the topk softmax predictions are considered for metrics calculations
    device: device the model and inputs are moved to (e.g. 'cuda')
    batch_size: number of sentence pairs per forward pass
    vocab_: vocabulary used for labelizing/untokenizing; when None the notebook-level
            ``vocab`` is used

    Returns a dict with the summed counts "corr2corr", "corr2incorr", "incorr2corr" and
    "incorr2incorr" as reported by get_metrics.
    """
    # Use a distinct local name: assigning to `vocab` here would make it local and raise
    # UnboundLocalError whenever vocab_ is None.
    active_vocab = vocab_ if vocab_ is not None else vocab
    print("###############################################")
    inference_st_time = time.time()
    _corr2corr, _corr2incorr, _incorr2corr, _incorr2incorr = 0, 0, 0, 0
    valid_loss = 0.
    n_batches_done = 0  # batches actually evaluated (skipped ones excluded)
    print("data size: {}".format(len(data)))
    n_batches_total = int(np.ceil(len(data) / batch_size))
    data_iter = batch_iter(data, batch_size=batch_size, shuffle=False)
    model.eval()
    model.to(device)
    for batch_labels, batch_sentences in tqdm(data_iter, total=n_batches_total):
        # set batch data for bert
        batch_labels_, batch_sentences_, batch_bert_inp, batch_bert_splits = bert_tokenize_for_valid_examples(
            batch_labels, batch_sentences, bert_pretrained_name_or_path=bert_pretrained_name_or_path)
        if len(batch_labels_) == 0:
            print("################")
            print("Not predicting the following lines due to pre-processing mismatch: \n")
            print([(a, b) for a, b in zip(batch_labels, batch_sentences)])
            print("################")
            continue
        batch_labels, batch_sentences = batch_labels_, batch_sentences_
        batch_bert_inp = {k: v.to(device) for k, v in batch_bert_inp.items()}
        # set batch data for others (lengths stay on the CPU)
        batch_labels_ids, batch_lengths = labelize(batch_labels, active_vocab)
        batch_labels_ids = batch_labels_ids.to(device)
        # forward; batch_predictions is (batch_size, batch_max_seq_len) if topk == 1,
        # else (batch_size, batch_max_seq_len, topk)
        try:
            with torch.no_grad():
                batch_loss, batch_predictions = model(batch_bert_inp, batch_bert_splits,
                                                      targets=batch_labels_ids, topk=topk)
        except RuntimeError:
            print(f"batch_bert_inp:{len(batch_bert_inp.keys())},batch_labels_ids:{batch_labels_ids.shape}")
            raise  # keep the original error and traceback
        valid_loss += float(batch_loss)
        batch_lengths = batch_lengths.cpu().detach().numpy()
        # based on topk, obtain either strings of batch_predictions or lists of candidate tokens
        if topk == 1:
            batch_predictions = untokenize_without_unks(batch_predictions, batch_lengths, active_vocab,
                                                        batch_sentences)
        else:
            # The multi-candidate variant is not defined in this notebook; use neuspell's.
            from neuspell.seq_modeling.helpers import untokenize_without_unks2
            batch_predictions = untokenize_without_unks2(batch_predictions, batch_lengths, active_vocab,
                                                         batch_sentences, topk=None)
        corr2corr, corr2incorr, incorr2corr, incorr2incorr = \
            get_metrics(batch_labels, batch_sentences, batch_predictions, check_until_topk=topk, return_mistakes=False)
        _corr2corr += corr2corr
        _corr2incorr += corr2incorr
        _incorr2corr += incorr2corr
        _incorr2incorr += incorr2incorr
        n_batches_done += 1

        # free per-batch tensors before the next batch
        del batch_loss, batch_predictions, batch_labels_ids, batch_lengths, batch_bert_inp
        torch.cuda.empty_cache()

    n_tokens = _corr2corr + _corr2incorr + _incorr2corr + _incorr2incorr
    n_incorrect = _incorr2corr + _incorr2incorr
    accuracy = (_corr2corr + _incorr2corr) / n_tokens if n_tokens else float("nan")
    word_correction_rate = _incorr2corr / n_incorrect if n_incorrect else float("nan")

    print(f"\nvalid_loss: {valid_loss / max(n_batches_done, 1)}")
    print("total inference time for this data is: {:4f} secs".format(time.time() - inference_st_time))
    print("###############################################")
    print("")
    print("")
    print("total token count: {}".format(n_tokens))
    print(
        f"_corr2corr:{_corr2corr}, _corr2incorr:{_corr2incorr}, _incorr2corr:{_incorr2corr}, _incorr2incorr:{_incorr2incorr}")
    print(f"accuracy is {accuracy}")
    print(f"word correction rate is {word_correction_rate}")
    print("###############################################")
    return {"corr2corr": _corr2corr, "corr2incorr": _corr2incorr,
            "incorr2corr": _incorr2corr, "incorr2incorr": _incorr2incorr}

In [None]:
# Evaluate on BEA-60k (top-1 predictions, 16 pairs per batch) using the model_inference
# defined in this notebook (neuspell.seq_modeling.subwordbert has its own version).
test_data = load_data('/content/neuspell/data/traintest/', 'test.bea60k', 'test.bea60k.noise')
predicted_result = model_inference(model, test_data, topk=1, device='cuda', batch_size=16, vocab_=vocab)

In [None]:

# Evaluate on JFLEG with the same settings as the BEA-60k run.
test_data = load_data('/content/neuspell/data/traintest/', 'test.jfleg', 'test.jfleg.noise')
predicted_result = model_inference(model, test_data, topk=1, device='cuda', batch_size=16, vocab_=vocab)