In [None]:
from transformers import *
import torch
import time
import math
import os
import numpy as np
from sklearn.metrics import classification_report
from sklearn_crfsuite.metrics import flat_classification_report

In [None]:
def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'``.

    Args:
        s: Duration in seconds (int or float).

    Returns:
        str: Whole minutes and the leftover seconds; fractional parts are
        dropped by the ``%d`` conversions.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)

def timeSince(since, percent):
    """Report the elapsed time and a linear estimate of the time remaining.

    Args:
        since: Start timestamp, on the same clock as ``time.time()``.
        percent: Fraction of the work finished so far (the training loop
            passes ``done / total``). Must be non-zero.

    Returns:
        str: ``'<elapsed> (- <remaining>)'``, both parts formatted by
        ``asMinutes``.
    """
    elapsed = time.time() - since
    # Extrapolate the total duration from the fraction completed so far.
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

def get_free_gpu():
    """Return the index of the GPU that currently reports the most free memory.

    Queries ``nvidia-smi`` for the free memory of every GPU in machine-readable
    CSV form and captures the output in-process, so no scratch file is left in
    the working directory and no file handle stays open.

    Returns:
        int: Position of the GPU with the most free memory in the order
        ``nvidia-smi`` lists the devices.

    Raises:
        ValueError: If ``nvidia-smi`` reports no GPUs or a non-numeric value.
        subprocess.CalledProcessError: If ``nvidia-smi`` exits with an error.
        FileNotFoundError: If ``nvidia-smi`` is not installed.

    NOTE(review): the index is in nvidia-smi's enumeration order; this is
    assumed to match the ``cuda:<index>`` numbering PyTorch uses — confirm on
    machines that set CUDA_VISIBLE_DEVICES or CUDA_DEVICE_ORDER.
    """
    # Local import: the notebook's import cell does not bring in subprocess.
    import subprocess

    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
        stdout=subprocess.PIPE,
        universal_newlines=True,
        check=True,
    )
    # One line per GPU, each holding just the free memory as a number.
    memory_available = [int(line) for line in result.stdout.splitlines() if line.strip()]
    if not memory_available:
        raise ValueError('nvidia-smi did not report any GPU')
    return int(np.argmax(memory_available))

In [None]:
def data_loader(x, y, batch_size=32, shuffle=True, drop_last=True, device=torch.device('cpu')):
    """Wrap paired tensors in a ``DataLoader`` after moving them to ``device``.

    Args:
        x: Input tensor; its first dimension indexes the samples.
        y: Target tensor with the same first dimension as ``x``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.
        drop_last: Whether a final, smaller batch is discarded.
        device: Device both tensors are moved to *before* batching, so the
            whole dataset lives on that device.

    Returns:
        torch.utils.data.DataLoader yielding ``(x_batch, y_batch)`` pairs.
    """
    features = x.to(device)
    targets = y.to(device)
    paired = torch.utils.data.TensorDataset(features, targets)
    loader_options = dict(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    return torch.utils.data.DataLoader(paired, **loader_options)

In [None]:
def get_tensor_data(path, tokenizer, all_labels, max_len=512):
    """Read a tab-separated tagging file and encode it as fixed-length tensors.

    Every line of the file must hold four tab-separated fields:
    ``text``, ``labels`` (whitespace-separated tag names), ``aid`` and ``sid``.
    Only ``text`` and ``labels`` are used here.

    Args:
        path: Path of the UTF-8 encoded data file.
        tokenizer: Object with a Hugging Face style ``encode`` method.
        all_labels: Sequence of tag names; a tag's position is its id. Must
            contain ``'X'``, which is used for the two special-token
            positions and for padding.
        max_len: Length every encoded sequence is padded/truncated to.

    Returns:
        ``(x, y)``: two tensors with one row of ``max_len`` ids per line —
        token ids and label ids respectively.

    Raises:
        KeyError: If a tag (or ``'X'``) is missing from ``all_labels``.
        ValueError: If a line does not have exactly four fields.

    Note:
        Label rows are built as ``X + labels + X + padding``, i.e. label ``k``
        is placed at token position ``k + 1``. This assumes the tokenizer
        emits exactly one token per whitespace-separated item of ``text`` —
        TODO confirm for the chosen tokenizer.
    """
    label2idx = {label: idx for idx, label in enumerate(all_labels)}
    pad_label = label2idx['X']
    # Two positions are reserved for the special tokens added by the tokenizer.
    max_labels = max_len - 2

    with open(path, 'r', encoding='utf8') as f:
        all_lines = [line.strip().split('\t') for line in f]

    x = []
    y = []

    for (text, label, aid, sid) in all_lines:
        # `padding='max_length'` is the non-deprecated spelling of
        # `pad_to_max_length=True`.
        curr_x = tokenizer.encode(text, padding='max_length', truncation=True, max_length=max_len)

        # Truncate the labels the same way the tokenizer truncates the text;
        # otherwise an over-long sequence yields a label row longer than
        # `max_len` and the rows can no longer be stacked into one tensor.
        label_ids = [label2idx[l] for l in label.split()][:max_labels]
        padding_len = max_labels - len(label_ids)
        curr_y = [pad_label] + label_ids + [pad_label] * (1 + padding_len)

        x.append(curr_x)
        y.append(curr_y)

    return torch.tensor(x), torch.tensor(y)

In [None]:
def evaluation(model, te_x, te_y, batch_size=32):
    """Predict labels for ``te_x`` and print a per-tag classification report.

    Relies on two notebook-level globals: ``avail_device`` (where the data is
    moved to) and ``all_labels`` (tag names; the first two entries are left
    out of the report).

    Args:
        model: Token-classification model whose output exposes ``.logits``.
        te_x: Tensor of token ids, one row per sequence.
        te_y: Tensor of gold label ids with the same layout as ``te_x``.
        batch_size: Number of sequences per forward pass.

    Returns:
        list[list[int]]: Predicted label id for every position of every
        sequence, in the order of ``te_x``.

    Side effects:
        Switches ``model`` to eval mode (it is not switched back), installs a
        process-wide ``'once'`` warning filter and prints the report.
    """
    # Local import: the notebook's import cell does not import `warnings`,
    # so relying on the module-level name raises NameError on a fresh kernel.
    import warnings

    warnings.filterwarnings(action='once')
    test_dataloader = data_loader(te_x, te_y, batch_size=batch_size, shuffle=False, drop_last=False, device=avail_device)
    all_pred = []
    all_real = []
    model.eval()

    for batch_x, batch_y in test_dataloader:
        with torch.no_grad():
            outputs = model(batch_x, labels=None)
            logits = outputs.logits.detach().cpu()
            curr_pred = torch.argmax(torch.softmax(logits, dim=-1), dim=-1).numpy().tolist()
            all_pred.extend(curr_pred)
            all_real.extend(batch_y.detach().cpu().numpy().tolist())
        torch.cuda.empty_cache()

    # Skip label ids 0 and 1 (the first two entries of `all_labels`).
    report_label_ids = list(range(2, len(all_labels)))
    print(flat_classification_report(all_real, all_pred, labels=report_label_ids, target_names=all_labels[2:]))

    return all_pred

### Initialize from pretrained

In [None]:
# Checkpoint to fine-tune. Alternatives tried in earlier runs are kept
# commented out for reference; the first three were loaded with BertTokenizer.
#pretrained = 'voidful/albert_chinese_xxlarge'
#pretrained = 'bert-base-chinese'
#pretrained = 'clue/roberta_chinese_base'
#tokenizer = BertTokenizer.from_pretrained(pretrained)

#pretrained = 'distilbert-base-multilingual-cased'
#pretrained = 'hfl/chinese-xlnet-base'
#pretrained = 'hfl/chinese-electra-base-discriminator'
pretrained = 'hfl/chinese-electra-large-discriminator'

# The data-loading cells pass `tokenizer` to get_tensor_data, so it has to be
# created here; with both tokenizer lines commented out the notebook fails
# with NameError on a fresh kernel.
tokenizer = AutoTokenizer.from_pretrained(pretrained)

In [None]:
# Loads the ELECTRA base model for `pretrained`. The result is not assigned,
# so it is only echoed as this cell's output; the model that is fine-tuned is
# created separately with AutoModelForTokenClassification.
# NOTE(review): presumably an architecture inspection / download check — it
# instantiates the checkpoint a second time, so consider removing the cell.
ElectraModel.from_pretrained(pretrained)

In [None]:
# Tag vocabulary; a tag's position in the list is its label id.
# 'X' and 'O' come first, followed by a B-/I- pair for every entity type.
all_labels = ['X', 'O'] + [
    '{}-{}'.format(prefix, entity)
    for entity in (
        'time', 'med_exam', 'name', 'location', 'family', 'ID',
        'clinical_event', 'profession', 'education', 'money',
        'contact', 'organization',
    )
    for prefix in ('B', 'I')
]

In [None]:
#config = AlbertConfig.from_pretrained(
#            pretrained,
#            architectures=["AlbertForTokenClassification"],
#            id2label={idx:ele for idx, ele in enumerate(all_labels)},
#            label2id={ele:idx for idx, ele in enumerate(all_labels)},
#            num_labels=len(all_labels),
#            return_dict=True
#)

#config = BertConfig.from_pretrained(
#            pretrained,
#            architectures=["bertForTokenClassification"],
#            id2label={idx:ele for idx, ele in enumerate(all_labels)},
#            label2id={ele:idx for idx, ele in enumerate(all_labels)},
#            num_labels=len(all_labels),
#            return_dict=True
#)

#config = RobertaConfig.from_pretrained(
#            pretrained,
#            architectures=["RobertaForTokenClassification"],
#            id2label={idx:ele for idx, ele in enumerate(all_labels)},
#            label2id={ele:idx for idx, ele in enumerate(all_labels)},
#            num_labels=len(all_labels),
#            return_dict=True
#)

#config = DistilBertConfig.from_pretrained(
#            pretrained,
#            architectures=["distilbertForTokenClassification"],
#            id2label={idx:ele for idx, ele in enumerate(all_labels)},
#            label2id={ele:idx for idx, ele in enumerate(all_labels)},
#            num_labels=len(all_labels),
#            return_dict=True
#)

#config = XLNetConfig.from_pretrained(
#            pretrained,
#            architectures=["XLNetForTokenClassification"],
#            id2label={idx:ele for idx, ele in enumerate(all_labels)},
#            label2id={ele:idx for idx, ele in enumerate(all_labels)},
#            num_labels=len(all_labels),
#            return_dict=True,
#            mem_len=1024
#)

# Configuration for the token-classification model: the checkpoint's own
# settings plus the tag vocabulary of this task (id <-> tag name mappings and
# the number of output classes). Keyword order is kept as-is because the
# overrides are applied in the order given.
config = ElectraConfig.from_pretrained(
    pretrained,
    architectures=["ElectraForTokenClassification"],
    id2label=dict(enumerate(all_labels)),
    label2id={tag: position for position, tag in enumerate(all_labels)},
    num_labels=len(all_labels),
    return_dict=True,
)

In [None]:
# Use the GPU with the most free memory when CUDA is available, else the CPU.
if torch.cuda.is_available():
    avail_device = torch.device("cuda:{}".format(get_free_gpu()))
else:
    avail_device = torch.device("cpu")

# Load the checkpoint with the task-specific config and move it to that device.
model = AutoModelForTokenClassification.from_pretrained(pretrained, config=config)
model = model.to(avail_device)

In [None]:
# Alternative kept for reference: a separate learning rate for parameters
# whose name contains 'classifier' and for all remaining parameters.
#optimizer_grouped_parameters = [
#    {'params': [p for n, p in model.named_parameters() if 'classifier' not in n], 'lr': 3e-5},
#    {'params': [p for n, p in model.named_parameters() if 'classifier' in n], 'lr': 1e-3}
#]
#optimizer = AdamW(optimizer_grouped_parameters)

# Active setup: one learning rate for every parameter of the model.
optimizer = AdamW(params=model.parameters(), lr=1e-4)

# No learning-rate schedule; the training loop skips scheduler.step() when
# this is None. Decay variant kept for reference:
#scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.002 * epoch))
scheduler = None

### Finetuning

In [None]:
# Fine-tuning hyper-parameters.
max_len = 512            # length every sequence is padded / truncated to
batch_size = 8           # sequences per forward pass
accumulation_steps = 1   # micro-batches accumulated per optimizer step

# Encode the training file and keep the full tensors on the training device.
x, y = get_tensor_data(
    path='TEA/train_data_char_sub_510.txt',
    tokenizer=tokenizer,
    all_labels=all_labels,
    max_len=max_len,
)
train_dataloader = data_loader(
    x=x,
    y=y,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    device=avail_device,
)

In [None]:
epochs = 10
print_percent = 20  # report the running loss every `print_percent` % of an epoch

# NOTE(review): the model is called without an attention mask although the
# inputs are padded to `max_len` — confirm this is intended before changing
# it, since padded positions are currently trained towards the 'X' label.
for i in range(epochs):
    print('===== epoch {} ====='.format(i+1))
    model.train()

    start = time.time()
    n_iters = len(train_dataloader)
    # Number of batches between two progress lines (at least one).
    print_every = max(1, int(n_iters * print_percent / 100))
    print_loss_total = 0

    for batch_iter, (_x, _y) in enumerate(train_dataloader):
        outputs = model(_x, labels=_y)
        loss = outputs.loss
        # Track the unscaled loss so the printed average does not shrink
        # when gradient accumulation is enabled.
        print_loss_total += loss.item()
        # Scale only the gradient contribution of this micro-batch.
        (loss / accumulation_steps).backward()

        # Step after every `accumulation_steps` micro-batches, and also on the
        # last batch of the epoch so leftover gradients are applied instead of
        # leaking into the next epoch.
        is_last_batch = (batch_iter+1) == n_iters
        if (batch_iter+1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            model.zero_grad()
            if scheduler is not None:
                scheduler.step()
                # Read the LR directly; building state_dict() every step
                # copies the whole optimizer state.
                print(f'curr_lr: {optimizer.param_groups[0]["lr"]}')

        if (batch_iter+1) % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, (batch_iter+1)/n_iters),
                                         (batch_iter+1), (batch_iter+1)/n_iters * 100, print_loss_avg))

        torch.cuda.empty_cache()

    # Per-epoch report computed on the training data itself (x, y).
    evaluation(model, x, y, batch_size=64)

### Evaluation

In [None]:
# Encode the dev file the same way as the training data; keep the file order
# (no shuffling, no dropped batch) so predictions line up with the file's lines.
te_x, te_y = get_tensor_data(
    path='TEA/dev_data_char_sub_510.txt',
    tokenizer=tokenizer,
    all_labels=all_labels,
    max_len=max_len,
)
test_dataloader = data_loader(
    x=te_x,
    y=te_y,
    batch_size=16,
    shuffle=False,
    drop_last=False,
    device=avail_device,
)

In [None]:
# Predicted label id for every position of every dev sequence, in file order.
all_pred = []
model.eval()

for batch_iter, (_x, _) in enumerate(test_dataloader):
    with torch.no_grad():
        outputs = model(_x, labels=None)
        # Move the scores to the CPU before converting them to Python lists.
        cpu_logits = outputs.logits.detach().cpu()
        class_probabilities = torch.softmax(cpu_logits, dim=-1)
        curr_pred = class_probabilities.argmax(dim=-1).numpy().tolist()
        all_pred.extend(curr_pred)
    torch.cuda.empty_cache()

### Generate result tsv

In [None]:
# Re-read the dev file as raw tab-separated fields so that predictions can be
# mapped back to the original tokens and article ids.
with open('TEA/dev_data_char_sub_510.txt', mode='r', encoding='utf8') as f:
    test_lines = [raw_line.strip().split('\t') for raw_line in f]

In [None]:
# Regroup the per-sentence rows into one token list / prediction list per
# article. The position of an article in `combine_test_lines` is later used
# as its article id, so list index and article id must stay in sync.
combine_test_lines = []
combine_test_pred = []

curr_line = []
curr_pred = []
curr_aid = 0

for i in range(len(test_lines)):
    article, _, aid, sid = test_lines[i]
    article_id = int(aid)

    if article_id < curr_aid:
        raise ValueError(
            'article ids must be non-decreasing: line {} has id {} after id {}'.format(i, article_id, curr_aid)
        )

    # Close the current article and advance until the list index matches the
    # new article id. Advancing one id at a time (instead of assuming ids grow
    # by exactly 1) keeps index == id even when an id is missing from the
    # file; the missing id then gets an empty article.
    while curr_aid < article_id:
        combine_test_lines.append(curr_line)
        combine_test_pred.append(curr_pred)
        curr_line = []
        curr_pred = []
        curr_aid += 1

    tokens = article.split()
    curr_line.extend(tokens)
    # Drop the first and last position of the prediction row, then keep one
    # prediction per token (the rest of the row is padding).
    curr_pred.extend(all_pred[i][1:-1][:len(tokens)])
    assert len(curr_line) == len(curr_pred), 'line {}: fewer predictions than tokens'.format(i)

# Flush the last article.
if len(curr_line) and len(curr_pred):
    combine_test_lines.append(curr_line)
    combine_test_pred.append(curr_pred)

In [None]:
# Decode the per-token B-/I- predictions of every article into entity rows.
# `upload_tsv` starts with the header row; each decoded entity adds one row of
# strings: article id, start index, exclusive end index, joined text, type.
# Positions are token indices within the article's combined token list.
upload_tsv = [['article_id', 'start_position', 'end_position', 'entity_text', 'entity_type']]

# The list index of an article is written out as its article id.
for aid in range(len(combine_test_lines)):
    curr_line = combine_test_lines[aid]
    curr_pred = combine_test_pred[aid]

    # State of the entity currently being assembled.
    entity_idxs = []
    entity_text = []
    entity_type = None

    for i in range(len(curr_line)):
        curr_text = curr_line[i]
        curr_pred_token = curr_pred[i]
        # 'B-time' -> prefix 'B', type ['time']; 'O' / 'X' -> prefix only, type [].
        curr_prefix, *curr_type = all_labels[curr_pred_token].split('-')

        # --- consume the current token ---
        if curr_prefix == 'B':
            # A 'B' tag adds the token and fixes the entity type. The lists are
            # not cleared here; that already happened when the previous token
            # looked ahead and saw this 'B'.
            entity_idxs.append(i)
            entity_text.append(curr_text)
            entity_type = curr_type[0]

        elif curr_prefix == 'I':
            # An 'I' tag extends whatever is being collected; its own type is
            # not checked here (type mismatches are handled by the look-ahead).
            entity_idxs.append(i)
            entity_text.append(curr_text)

        # --- look ahead to decide whether the entity ends here ---
        if i == len(curr_line)-1:
            # Last token of the article: emit the pending entity, if any.
            if len(entity_idxs) and len(entity_text) and entity_type is not None:
                upload_tsv.append([str(aid), str(entity_idxs[0]), str(entity_idxs[-1]+1), ''.join(entity_text), entity_type])
        else:
            next_pred_token = curr_pred[i+1]
            next_prefix, *next_type = all_labels[next_pred_token].split('-')

            # The entity ends when the next tag is O/X/B, or is an 'I' tag of a
            # different type than the one being collected (including None).
            if next_prefix in {'O', 'X', 'B'} or (next_prefix == 'I' and next_type[0] != entity_type):
                # Emit the pending entity; skipped when nothing with a known
                # type has been collected.
                if len(entity_idxs) and len(entity_text) and entity_type is not None:
                    upload_tsv.append([str(aid), str(entity_idxs[0]), str(entity_idxs[-1]+1), ''.join(entity_text), entity_type])

                # Reset the state for whatever follows.
                if next_prefix in {'O', 'X', 'B'}:
                    entity_idxs = []
                    entity_text = []
                    entity_type = None
                else:
                    # Next tag is an 'I' of another type: start a new entity of
                    # that type which already contains the *current* token.
                    # NOTE(review): this makes the new entity begin one token
                    # before the 'I' tag (and overlap the entity just emitted,
                    # if any) — confirm this is the intended heuristic rather
                    # than an off-by-one.
                    entity_idxs = [i]
                    entity_text = [curr_text]
                    entity_type = next_type[0]

In [None]:
#with open('./bert_pytorch_rebuild_result/baseline_3.tsv', 'w', encoding='utf8') as f:
#    f.write('\n'.join(['\t'.join(l) for l in upload_tsv]))

#with open('./bert_pytorch_rebuild_result/baseline_9_e40.tsv', 'w', encoding='utf8') as f:
#    f.write('\n'.join(['\t'.join(l) for l in upload_tsv]))

# Write the decoded entities as a tab-separated file (header row first, no
# trailing newline). Create the result directory first so a missing folder
# does not discard the results of a finished training run.
output_path = './bert_pytorch_rebuild_result/baseline_10_e30.tsv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf8') as f:
    f.write('\n'.join(['\t'.join(l) for l in upload_tsv]))

### recorded

1. baseline_0 = bert-base-chinese / tf / max_len 512 / batch 4 / lr 3e-5 / 100 epoch
2. baseline_1 = bert-base-chinese / tf / max_len 128 / batch 16 / lr 2e-5 / 100 epoch
3. baseline_2 = bert-base-chinese / huggingface / max_len 512 / batch 4 / lr 2e-5 / 40 epoch / lr 1e-5 / 20 epoch /
4. baseline_3 = bert-base-chinese / huggingface / max_len 512 / batch 4 / lr 2e-5 / 40 epoch / lr 1e-5 / 20 epoch / + Chinese number conversion



10. baseline_9_e20 = distilbert-base-multilingual-cased / max_len 512 / batch 4 / lr 5e-6 / 20 epoch
10. baseline_9_e40 = distilbert-base-multilingual-cased / max_len 512 / batch 4 / lr 2e-6 / 20 ~ 40 epoch


11. baseline_10_e20 = electra-base / max_len 512 / batch 8 / lr 1e-4 / 20 epoch
11. baseline_10_e30 = electra-base / max_len 512 / batch 8 / lr 1e-4 / 30 epoch