In [2]:
import numpy as np
import os
import torch
import json
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from transformers import BertForMaskedLM, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
def preprocess_data(data):
    input_texts = []
    labels = []

    for example in data:
        example = json.loads(example)

        input_text = example['content']
        ground_truth = example['groundTruth']
        candidates = example['candidates']
        
        for i, idiom in enumerate(ground_truth):
            candidates_str = '，'.join([c for c in candidates[i]])
            input_text = input_text.replace('#idiom#', f"#[MASK][MASK][MASK][MASK]#({candidates})", 1)
            # input_text = input_text.replace('#idiom#', f"#[MASK][MASK][MASK][MASK]#({candidates_str})", 1)
        if len(input_text) > 500:
            continue  
        input_texts.append(input_text)
        labels.append(''.join(ground_truth))


    concat_inputs = tokenizer(input_texts,return_token_type_ids=False)
    labels = tokenizer(labels, return_token_type_ids=False, return_attention_mask=False, add_special_tokens=False)
    return concat_inputs, labels


In [5]:
class IdiomDataset(Dataset):

    def __init__(self, inputs, labels) -> None:
        super().__init__()
        self.inputs = inputs
        self.labels = self._get_label(inputs["input_ids"], labels["input_ids"])

    
    def __len__(self):
        return len(self.inputs["input_ids"])
    
    def _get_label(self, inputs, labels):
        results = []
        for inp, label in zip(inputs, labels):
            inp = np.array(inp)
            l = np.full_like(inp, fill_value=-100)
            l[inp == tokenizer.mask_token_id] = label
            results.append(l)
        return results


    def __getitem__(self, index):
        return {"input_ids": self.inputs["input_ids"][index], "attention_mask": self.inputs["attention_mask"][index], "labels": self.labels[index]}


def collate_fn(batch):
    batch_input_ids = [torch.LongTensor(each["input_ids"]) for each in batch]
    batch_att_mask = [torch.LongTensor(each["attention_mask"]) for each in batch]
    batch_label = [torch.LongTensor(each["labels"]) for each in batch]
    padded_batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_att_mask = pad_sequence(batch_att_mask, batch_first=True, padding_value=0)
    padded_batch_label = pad_sequence(batch_label, batch_first=True, padding_value=-100)
    return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}
    
def to_device(data, device):
    new_data = {}
    for k in data:
        new_data[k] = data[k].to(device)
    return new_data


In [6]:
def train(model:nn.Module, train_loader:DataLoader, optimizer:optim.Optimizer, log_step=100):
    model.train()
    epoch_loss = 0.0
    log_loss = 0.0
    for idx, batch in enumerate(train_loader, 1):
        optimizer.zero_grad()
        batch = to_device(batch, device)
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        log_loss += loss.item()
        if idx % log_step == 0:
            print(f"Train Step: {idx} Loss: {log_loss / log_step}")
            log_loss = 0.0
    return epoch_loss / len(train_loader)
        

@torch.no_grad()
def evaluate(model:nn.Module, eval_loader:DataLoader):
    eval_loss = 0.0
    correct = 0
    total = 0
    model.eval()
    for batch in eval_loader:
        batch = to_device(batch, device)
        output = model(**batch)
        loss = output.loss
        eval_loss += loss.item()
        pred = output.logits.argmax(-1)
        label = batch["labels"]
        correct += torch.where(label!=-100, pred==label, 0).sum().item()
        total += torch.sum(label != -100).item()

    eval_acc = correct / total
    eval_loss = eval_loss / len(eval_loader) 
    print(total, correct)
    return eval_acc, eval_loss

In [7]:
# Load the Chinese Idioms dataset
train_data_file = './data/train_15000.txt'
val_data_file = './data/dev_2000.txt'


with open(train_data_file) as f:
    train_data = f.readlines()

with open(val_data_file) as f:
    val_data = f.readlines()

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese", cache_dir="./models")

train_inputs, train_labels = preprocess_data(train_data)
val_inputs, val_labels = preprocess_data(val_data)

In [8]:
val_dataset = IdiomDataset(val_inputs, val_labels)
print(tokenizer.decode(val_dataset[0]["labels"]))
print(val_dataset[0]["labels"])
print(tokenizer.decode(val_dataset[0]["input_ids"]))
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn, shuffle=False)
for each in val_loader:
    print(each["input_ids"])
    print(each.keys())
    break

[UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] 深 恶 痛 绝 [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK] [UNK]
[-100 -100 -100 -100 -100 -100 -

In [9]:
train_dataset = IdiomDataset(train_inputs, train_labels)
val_dataset = IdiomDataset(val_inputs, val_labels)

train_loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn, shuffle=False)

In [10]:
model = BertForMaskedLM.from_pretrained("bert-base-chinese", cache_dir="./models").to(device)

epoches = 5
optimizer = optim.Adam(model.parameters(), lr=2e-5)

model.train()

for epoch in range(1, epoches+1):
    print(f"Training Epoch {epoch}")
    
    train_loss = train(model, train_loader, optimizer)
    print(f"Epoch {epoch} Training Loss: {train_loss}")
    
    eval_acc, eval_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch} Eval Acc: {eval_acc}; Eval Loss: {eval_loss}")
torch.save(model.state_dict(), "bert_ckpt_new.pt")

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training Epoch 1
Train Step: 100 Loss: 2.9327831733226777
Train Step: 200 Loss: 2.3315607142448425
Train Step: 300 Loss: 2.2123726773262025
Train Step: 400 Loss: 2.1301386761665344
Train Step: 500 Loss: 2.0706395006179807
Train Step: 600 Loss: 1.9072985714673996
Train Step: 700 Loss: 1.8154919040203095
Train Step: 800 Loss: 1.7903302580118179
Epoch 1 Training Loss: 2.1094888179923124
8452 5045
Epoch 1 Eval Acc: 0.59690014197823; Eval Loss: 1.383521120337879
Training Epoch 2
Train Step: 100 Loss: 1.1090427428483962
Train Step: 200 Loss: 1.132837101817131
Train Step: 300 Loss: 1.0888219580054284
Train Step: 400 Loss: 1.1478842985630036
Train Step: 500 Loss: 1.1079305881261825
Train Step: 600 Loss: 1.1348290991783143
Train Step: 700 Loss: 1.1474296295642852
Train Step: 800 Loss: 1.1295168882608413
Epoch 2 Training Loss: 1.1243057316494633
8452 5528
Epoch 2 Eval Acc: 0.6540463795551349; Eval Loss: 1.2507230892401784
Training Epoch 3
Train Step: 100 Loss: 0.5392642591893673
Train Step: 200 

In [14]:
@torch.no_grad()
def fill_idiom(model, loader):
    all_preds = []
    all_labels = []
    model.eval()
    for batch in loader:   
        batch = to_device(batch, device)
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        output = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = output.logits.argmax(-1)
        for pred, label in zip(preds, labels):
            idiom_idx = label != -100
            l = label[idiom_idx].split(4)
            p = pred[idiom_idx].split(4)
            gold_idiom =  set("".join(tokenizer.convert_ids_to_tokens(i)) for i in l)
            pred_idiom = set("".join(tokenizer.convert_ids_to_tokens(i)) for i in p)
            all_labels.append(set(gold_idiom))
            all_preds.append(set(pred_idiom))
    
    return all_preds, all_labels
        

def f1_score(sys, gold):
    tp = 0
    t = 0
    p = 0
    for s, g in zip(sys, gold):
        t += len(g)
        p += len(s)
        tp += len(g & s)
    precision = tp / p if p != 0 else 0
    recall = tp / t if t != 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
    return precision, recall, f1
    

In [15]:
model = BertForMaskedLM.from_pretrained("bert-base-chinese", cache_dir="./models").to(device)
model.load_state_dict(torch.load("bert_ckpt_new.pt", map_location=device))

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [16]:
sys, gold = fill_idiom(model, val_loader)
p, r, f1 = f1_score(sys, gold)
print(p, r, f1)

0.5450664136622391 0.5448079658605974 0.5449371591178563
