In [None]:
import random
import numpy as np
import pandas as pd
from ast import literal_eval as load
from datasets import Dataset
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers.models.bert.modeling_bert import BertForTokenClassification
from transformers import BertTokenizer, BertConfig, AdamW, get_cosine_schedule_with_warmup
from ast import literal_eval as load


In [None]:
# Load the annotated corpus. The 'words' and 'sent_tag' columns are stored as
# Python-literal strings and are parsed back into lists with literal_eval; they
# are then renamed to the 'text' / 'labels' names used by the rest of the notebook.
df = pd.read_csv(
    'data.csv',
    index_col=False,
    converters={'words': load, 'sent_tag': load},
).rename(columns={'words': 'text', 'sent_tag': 'labels'})


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint identifier shared by the tokenizer, config and model loaders.
model_path = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
tokenizer = BertTokenizer.from_pretrained(model_path)


In [None]:
# Every encoded example is truncated / padded to exactly this many tokens.
maxlen = 512
# Vocabulary id of the tokenizer's padding token (single-token form returns an int).
pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
# Label value that CrossEntropyLoss skips by default; used to mask non-scored positions.
ignore_label_id = CrossEntropyLoss().ignore_index


def convert(df, tok=None, max_len=None, pad_id=None, ignore_id=None):
    """Encode sentence-segmented documents as fixed-length BERT inputs.

    Each row of ``df`` holds a list of sentences in ``'text'`` (each sentence a
    list of words) and one label per sentence in ``'labels'``. Words are
    word-piece tokenized, every sentence is terminated with the tokenizer's
    separator token, and the sentences are concatenated. Only the separator of
    each sentence carries that sentence's label; every other position gets
    ``ignore_id``. Sequences are truncated and padded to ``max_len``.

    Args:
        df: DataFrame with ``'text'`` and ``'labels'`` columns. Rows are read
            positionally, so any index (for example a shuffled one) is fine.
        tok: tokenizer exposing ``tokenize``, ``convert_tokens_to_ids`` and
            ``sep_token``; defaults to the module-level ``tokenizer``.
        max_len: output sequence length; defaults to the module-level ``maxlen``.
        pad_id: token id used for padding; defaults to ``pad_token_id``.
        ignore_id: label for positions the loss must skip; defaults to
            ``ignore_label_id``.

    Returns:
        DataFrame (RangeIndex) with columns ``input_ids``, ``attention_mask``
        and ``labels``; each cell is a list of length ``max_len``.

    Raises:
        ValueError: if a row has fewer labels than sentences.
    """
    tok = tokenizer if tok is None else tok
    max_len = maxlen if max_len is None else max_len
    pad_id = pad_token_id if pad_id is None else pad_id
    ignore_id = ignore_label_id if ignore_id is None else ignore_id

    input_ids_ls, attention_mask_ls, labels_ls = [], [], []
    # Iterate positionally: the former ``df.loc[i, ...]`` looked rows up by
    # index *label*, which fails (KeyError / wrong rows) on non-RangeIndex frames.
    for row_pos, (sents, labels) in enumerate(zip(df['text'], df['labels'])):
        if len(labels) < len(sents):
            raise ValueError(
                f'row {row_pos}: {len(sents)} sentences but only {len(labels)} labels')
        tokens, label_ids = [], []
        for sent, label in zip(sents, labels):
            sent_tokens = [piece for word in sent for piece in tok.tokenize(word)]
            sent_tokens.append(tok.sep_token)
            tokens.extend(sent_tokens)
            # Only the trailing separator of the sentence is scored.
            label_ids.extend([ignore_id] * (len(sent_tokens) - 1) + [label])

        # Truncate; sentences whose separator falls past max_len lose their label.
        tokens = tokens[:max_len]
        label_ids = label_ids[:max_len]
        input_ids = list(tok.convert_tokens_to_ids(tokens))
        padding_length = max_len - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * padding_length
        input_ids.extend([pad_id] * padding_length)
        label_ids.extend([ignore_id] * padding_length)
        assert len(input_ids) == max_len
        assert len(attention_mask) == max_len
        assert len(label_ids) == max_len
        input_ids_ls.append(input_ids)
        attention_mask_ls.append(attention_mask)
        labels_ls.append(label_ids)
    return pd.DataFrame({'input_ids': input_ids_ls,
                         'attention_mask': attention_mask_ls,
                         'labels': labels_ls})


In [None]:
# Split by paper so that no paper appears in both sets: the first 30 unique
# papers (in order of appearance) are held out for evaluation.
ids = df['paper'].unique()
# NOTE(review): this 90% boundary is computed but never used; the split below
# is a fixed 30 papers. Confirm which split is intended.
bound = int(len(ids) * 0.9)
eval_ids, train_ids = ids[:30], ids[30:]
train_df = (
    df.set_index("paper")
    .loc[train_ids]
    .reset_index()
    .sample(frac=1, random_state=1)
)
eval_df = df.set_index("paper").loc[eval_ids].reset_index()


In [None]:
# Encode both splits, expose them as torch tensors, and wrap them in loaders
# (only the training loader reshuffles each epoch).
tokenized_train = Dataset.from_pandas(convert(train_df))
tokenized_train.set_format(type="torch")
tokenized_eval = Dataset.from_pandas(convert(eval_df))
tokenized_eval.set_format(type="torch")
train_loader = DataLoader(tokenized_train, batch_size=4, shuffle=True)
eval_loader = DataLoader(tokenized_eval, batch_size=4)


In [None]:
def F1(ref, pred):
    """Print and return the F1 score for label 1 as the positive class.

    ``ref`` and ``pred`` are parallel nested sequences (rows of labels);
    iteration follows the shape of ``pred``. Pairs other than (pred, ref) =
    (1, 1), (1, 0) or (0, 1) do not affect the counts. Precision, recall and
    F1 are all reported as 0 when there are no true positives.
    """
    tp = fp = fn = 0
    for row, pred_row in enumerate(pred):
        for col, guess in enumerate(pred_row):
            if guess == 1:
                truth = ref[row][col]
                if truth == 1:
                    tp += 1
                elif truth == 0:
                    fp += 1
            elif guess == 0 and ref[row][col] == 1:
                fn += 1

    pc = rc = f1 = 0
    if tp:
        pc = tp / (tp + fp)
        rc = tp / (tp + fn)
        f1 = 2 * pc * rc / (pc + rc)
    print(f'precision: {pc}, recall: {rc}, F1: {f1}')
    return f1


In [None]:
import torch
from torch import nn
import torch.nn.functional as F

# adapted from https://github.com/DHPO/GHM_Loss.pytorch
class GHM_Loss(nn.Module):
    """Base class for Gradient Harmonizing Mechanism (GHM) losses.

    Positions whose target equals ``ignore_label_id`` are dropped. For the
    remaining N samples, the gradient magnitude ``g`` produced by
    ``_custom_loss_grad`` is bucketed into ``bins`` equal-width bins. Bin counts
    are smoothed across calls with an exponential moving average (``alpha``
    weights the previous counts). Each sample is then weighted by
    ``N / max(count_of_its_bin * number_of_nonempty_bins, 1)`` and passed to
    ``_custom_loss``.

    Subclasses implement ``_custom_loss(x, target, weight)`` and
    ``_custom_loss_grad(x, target)``.

    NOTE: the moving-average state is updated on every forward call, including
    calls made while evaluating.
    """

    def __init__(self, bins, alpha, ignore_label_id=-100):
        super().__init__()
        self._bins = bins
        self._alpha = alpha
        # EMA of per-bin sample counts; None until the first non-empty batch.
        self._last_bin_count = None
        self._ignore_label_id = ignore_label_id

    def _g2bin(self, g):
        """Map gradient magnitudes to integer bin indices in ``[0, bins - 1]``.

        The ``- 0.0001`` keeps ``g == 1`` in the last bin. The clamp guards
        subclasses whose ``g`` can fall outside ``[0, 1]``.
        """
        return torch.floor(g * (self._bins - 0.0001)).long().clamp(0, self._bins - 1)

    def _custom_loss(self, x, target, weight):
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        raise NotImplementedError

    def forward(self, x, target):
        """Return the GHM-weighted scalar loss.

        Args:
            x: per-position scores, masked with ``target != ignore_label_id``
                (e.g. ``(N_total, C)`` logits for a ``(N_total,)`` target).
            target: per-position labels; ``ignore_label_id`` entries are skipped.

        Returns:
            Scalar tensor. When every position is ignored this is a zero that
            is still attached to the graph of ``x``, and the EMA is untouched.
        """
        valid_label = target != self._ignore_label_id
        x = x[valid_label]
        target = target[valid_label]
        N = x.size(0)
        if N == 0:
            # A mean over zero samples would be NaN and poison training.
            return x.sum() * 0.0

        # Flatten to (N,): subclasses may return (N, 1). An (N, 1) bin index
        # would give (N, 1) weights that broadcast against an (N,) loss into an
        # (N, N) matrix, silently discarding the per-sample weighting.
        g = torch.abs(self._custom_loss_grad(x, target)).detach().view(-1)
        bin_idx = self._g2bin(g)
        bin_count = torch.bincount(bin_idx, minlength=self._bins).float()
        if self._last_bin_count is None:
            self._last_bin_count = bin_count
        else:
            bin_count = self._alpha * self._last_bin_count + \
                (1 - self._alpha) * bin_count
            self._last_bin_count = bin_count
        nonempty_bins = (bin_count > 0).sum().item()
        gd = torch.clamp(bin_count * nonempty_bins, min=1)
        beta = N / gd
        return self._custom_loss(x, target, beta[bin_idx])


class GHMC_Loss(GHM_Loss):
    """GHM-C: GHM-weighted cross-entropy over class logits ``x`` of shape ``(N, C)``."""

    def __init__(self, bins, alpha, ignore_label_id=-100):
        super().__init__(bins, alpha, ignore_label_id)

    def _custom_loss(self, x, target, weight):
        criterion = CrossEntropyLoss(reduction='none')
        loss = criterion(x, target)
        # Align the weights with the per-sample loss so (N, 1) weights cannot
        # broadcast into an (N, N) product.
        return torch.mean(loss * weight.reshape(loss.shape))

    def _custom_loss_grad(self, x, target):
        # d(cross-entropy)/d(true-class logit) = p_target - 1; shape (N, 1).
        return torch.gather(F.softmax(x, dim=-1).detach(), 1, target.view(-1, 1)) - 1.0


In [None]:
def evaluate(model, data_loader, criterion, device=device):
    """Evaluate ``model`` on ``data_loader``; return ``(f1, mean_batch_loss)``.

    Each batch is trimmed to its longest non-padded length, moved to
    ``device``, and run without gradients. Predictions are the argmax over the
    last logit dimension. Positions beyond a row's attention-mask length, or
    whose label equals ``ignore_label_id``, are dropped before scoring with
    ``F1``. The loss is ``criterion`` averaged over the batches seen.

    NOTE: ``criterion`` is called here as well, so any internal state it keeps
    (e.g. a moving average) is also updated by evaluation batches.
    """
    model.eval()
    val_true, val_pred = [], []
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            n_batches += 1
            # Drop columns that are padding in every row of this batch.
            mlen = batch['attention_mask'].sum(1).max().item()
            batch = {k: v[:, :mlen].to(device) if v.dim() == 2 else v.to(device)
                     for k, v in batch.items()}
            logits = model(**batch).logits
            labels = batch['labels']
            total_loss += criterion(logits.view(-1, logits.size(-1)),
                                    labels.view(-1)).item()
            y_pred = logits.argmax(dim=-1).cpu().tolist()
            y_true = labels.cpu().tolist()
            real_len = batch['attention_mask'].sum(1).cpu().tolist()
            for pred_row, true_row, n_real in zip(y_pred, y_true, real_len):
                kept = [(p, t) for p, t in zip(pred_row[:n_real], true_row[:n_real])
                        if t != ignore_label_id]
                val_pred.append([p for p, _ in kept])
                val_true.append([t for _, t in kept])
    # Mean over the batches actually seen (previously divided by batches + 1).
    eval_loss = total_loss / max(n_batches, 1)
    f1 = F1(val_true, val_pred)

    return f1, eval_loss


In [None]:
def train_and_eval(model, train_loader, valid_loader, criterion,
                   optimizer, scheduler, num_epochs, device=device,
                   accum_steps=None, log_every=20):
    """Train with gradient accumulation, evaluating and checkpointing every epoch.

    Each micro-batch is trimmed to its longest non-padded length and scored
    with ``criterion`` on flattened logits. It is back-propagated with the loss
    divided by ``accum_steps``. The optimizer and scheduler step once every
    ``accum_steps`` micro-batches. After each epoch the model is scored with
    ``evaluate``. Its state dict is saved to ``best_model.pth`` whenever the
    returned score (F1, printed as "acc") beats the best seen so far.

    Args:
        model: model called as ``model(**batch)`` returning an object with ``.logits``.
        train_loader, valid_loader: loaders yielding dicts of tensors that
            include ``attention_mask`` and ``labels``.
        criterion: loss called as ``criterion(logits_2d, labels_1d)``.
        optimizer, scheduler: stepped once per accumulated update.
        num_epochs: number of passes over ``train_loader``.
        device: device batches are moved to; also forwarded to ``evaluate``.
        accum_steps: micro-batches per optimizer step; defaults to the
            module-level ``grad_acu_steps``.
        log_every: print the mean running loss every this many optimizer steps.
    """
    if accum_steps is None:
        accum_steps = grad_acu_steps
    # The micro-batch counter restarts every epoch, so this is the exact number
    # of optimizer steps taken per epoch.
    steps_per_epoch = len(train_loader) // accum_steps
    best_acc = 0.0
    progress_bar = tqdm(total=num_epochs * steps_per_epoch)
    for epoch in range(num_epochs):
        model.train()
        # Discard any partial accumulation left from the previous epoch (when
        # len(train_loader) is not a multiple of accum_steps). Previously those
        # gradients leaked into this epoch's first update.
        optimizer.zero_grad()
        print("***** Running training epoch {} *****".format(epoch+1))
        running_loss = 0
        for count, batch in enumerate(train_loader, start=1):
            # Drop columns that are padding in every row of this batch.
            mlen = batch['attention_mask'].sum(1).max().item()
            batch = {k: v[:, :mlen].to(device) if v.dim() == 2 else v.to(device)
                     for k, v in batch.items()}
            logits = model(**batch).logits
            loss = criterion(logits.view(-1, logits.size(-1)),
                             batch['labels'].view(-1)) / accum_steps
            running_loss += loss.item()
            loss.backward()
            if count % accum_steps == 0:
                if count % (accum_steps * log_every) == 0:
                    # Sum of log_every optimizer steps' (pre-scaled) losses -> mean per step.
                    running_loss /= log_every
                    print(f'running_loss:{running_loss}')
                    running_loss = 0
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
        acc, eval_loss = evaluate(model, valid_loader, criterion, device=device)
        print(f'evaluation loss: {eval_loss}')
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), "best_model.pth")
        print("current acc is {:.4f}, best acc is {:.4f}".format(acc, best_acc))
    progress_bar.close()


In [None]:
grad_acu_steps = 4
num_epochs = 8
# train_and_eval restarts its micro-batch counter every epoch, so it performs
# exactly len(train_loader) // grad_acu_steps optimizer steps per epoch. Size
# the LR schedule to that count. The former num_epochs*len(train_loader)//grad_acu_steps
# over-counted whenever len(train_loader) is not a multiple of grad_acu_steps,
# so the cosine decay never reached its end.
num_training_steps = num_epochs * (len(train_loader) // grad_acu_steps)

# Seed the RNGs so classifier-head initialisation, dropout and the DataLoader
# shuffle order are reproducible across runs (GPU kernels may still be
# non-deterministic).
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config = BertConfig.from_pretrained(
    model_path, max_position_embeddings=maxlen, num_labels=2)
model = BertForTokenClassification.from_pretrained(
    model_path, config=config, ignore_mismatched_sizes=True).to(device)

# Weight decay for everything except biases and LayerNorm parameters. The
# classifier head uses a higher learning rate (1e-4) than the encoder (2e-5).
no_decay = ['bias', 'LayerNorm']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.bert.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.bert.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.classifier.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01, 'lr': 1e-4},
    {'params': [p for n, p in model.classifier.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': 1e-4},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=200,
                                            num_training_steps=num_training_steps, num_cycles=0.5)
criterion = GHMC_Loss(bins=10, alpha=0.9)

train_and_eval(model, train_loader, eval_loader,
               criterion, optimizer, scheduler, num_epochs)