In [16]:
import os
import numpy as np
import torch
from torch import cuda, nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import GPT2LMHeadModel, BertTokenizer
import json

In [3]:

tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall", cache_dir="./models")
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall", cache_dir="./models")
device = 'cuda' if cuda.is_available() else 'cpu'
model = model.to(device)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/217 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/577 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/421M [00:00<?, ?B/s]

In [5]:
# Preprocess the inputs and outputs
def preprocess_data(data):
    input_texts = []
    labels = []

    for example in data:
        example = json.loads(example)

        input_text = example['content']
        ground_truth = example['groundTruth']
        candidates = example['candidates']
        
        for i, idiom in enumerate(ground_truth):
            input_text = "请从（）里选择出最合适的成语: " + input_text
            candidates_str = '|'.join([c for c in candidates[i]])
            input_text = input_text.replace('#idiom#', "（" + candidates_str + "）", 1)
        if len(input_text) > 500:
            continue  
        input_texts.append(input_text)
        labels.append('、'.join(ground_truth))


    concat_inputs = tokenizer(input_texts, labels, return_token_type_ids=False)
    return concat_inputs



In [102]:
class IdiomDataset(Dataset):

    def __init__(self, data) -> None:
        super().__init__()
        self.data = data
        self.label = self._get_label(data["input_ids"])
        self.is_inference = False
    
    def __len__(self):
        return len(self.data["input_ids"])
    
    def _get_label(self, inputs):
        labels = []
        for inp in inputs:
            sep_idx = inp.index(102)
            label = [-100] * len(inp)
            label[sep_idx + 1:] = inp[sep_idx + 1:]
            labels.append(label)
        return labels

    def inference(self):
        self.is_inference = True

    def train(self):
        self.is_inference = False
        
    def is_inference(self):
        return self.is_inference

    def __getitem__(self, index):
        if not self.is_inference:
            return {"input_ids": self.data["input_ids"][index], "attention_mask": self.data["attention_mask"][index], "labels": self.label[index]}
        else:
            sep_idx = self.data["input_ids"][index].index(102)
            input_ids = self.data["input_ids"][index][:sep_idx+1]
            att_mask = self.data["attention_mask"][index][:sep_idx+1]
            label = self.label[index][sep_idx+1:]
            return {"input_ids": input_ids, "attention_mask":att_mask, "labels":label}


def collate_fn(batch):
    batch_input_ids = [torch.LongTensor(each["input_ids"]) for each in batch]
    batch_att_mask = [torch.LongTensor(each["attention_mask"]) for each in batch]
    batch_label = [torch.LongTensor(each["labels"]) for each in batch]
    padded_batch_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_att_mask = pad_sequence(batch_att_mask, batch_first=True, padding_value=0)
    padded_batch_label = pad_sequence(batch_label, batch_first=True, padding_value=-100)
    return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}

def to_device(data, device):
    new_data = {}
    for k in data:
        new_data[k] = data[k].to(device)
    return new_data



In [7]:
def train(model:nn.Module, train_loader:DataLoader, optimizer:optim.Optimizer, log_step=100):
    model.train()
    epoch_loss = 0.0
    log_loss = 0.0
    for idx, batch in enumerate(train_loader, 1):
        optimizer.zero_grad()
        batch = to_device(batch, device)
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        log_loss += loss.item()
        if idx % log_step == 0:
            print(f"Train Step: {idx} Loss: {log_loss / log_step}")
            log_loss = 0.0
    return epoch_loss / len(train_loader)
        

@torch.no_grad()
def evaluate(model:nn.Module, eval_loader:DataLoader):
    eval_loss = 0.0
    correct = 0
    total = 0
    model.eval()
    for batch in eval_loader:
        batch = to_device(batch, device)
        output = model(**batch)
        loss = output.loss
        eval_loss += loss.item()
        pred = output.logits.argmax(-1)[..., :-1]
        label = batch["labels"][..., 1:]
        correct += torch.where(label!=-100, pred==label, 0).sum().item()
        total += torch.sum(label != -100).item()

    eval_acc = correct / total
    eval_loss = eval_loss / len(eval_loader) 
    print(total, correct)
    return eval_acc, eval_loss


In [23]:
# Load the Chinese Idioms dataset
train_data_file = './data/train_15000.txt'
val_data_file = './data/dev_2000.txt'


with open(train_data_file) as f:
    train_data = f.readlines()

with open(val_data_file) as f:
    val_data = f.readlines()

train_inputs = preprocess_data(train_data)
val_inputs = preprocess_data(val_data)

In [9]:
train_dataset = IdiomDataset(train_inputs)
val_dataset = IdiomDataset(val_inputs)

train_loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn, shuffle=False)

In [None]:
epoches = 5
optimizer = optim.Adam(model.parameters(), lr=2e-5)
model.train()

for epoch in range(1, epoches+1):
    print(f"Training Epoch {epoch}")
    train_loss = train(model, train_loader, optimizer)
    print(f"Epoch {epoch} Training Loss: {train_loss}")
    eval_acc, eval_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch} Eval Acc: {eval_acc}; Eval Loss: {eval_loss}")
torch.save(model.state_dict(), "gpt2_ckpt.pt")

Training Epoch 1
Train Step: 100 Loss: 0.9795388734340668
Train Step: 200 Loss: 0.7113763672113419
Train Step: 300 Loss: 0.6371003955602645
Train Step: 400 Loss: 0.5993250006437302
Train Step: 500 Loss: 0.5624085527658462
Train Step: 600 Loss: 0.5670115682482719
Train Step: 700 Loss: 0.5523709109425545
Train Step: 800 Loss: 0.5443778941035271
Train Step: 900 Loss: 0.5170685809850692
Epoch 1 Training Loss: 0.6248667251898535
11950 10259
Epoch 1 Eval Acc: 0.8584937238493724; Eval Loss: 0.45387884497642517
Training Epoch 2
Train Step: 100 Loss: 0.420941047668457
Train Step: 200 Loss: 0.409544515311718
Train Step: 300 Loss: 0.4202533406019211
Train Step: 400 Loss: 0.4160501465201378
Train Step: 500 Loss: 0.4195918321609497
Train Step: 600 Loss: 0.40703207135200503
Train Step: 700 Loss: 0.4191902096569538
Train Step: 800 Loss: 0.4008663776516914
Train Step: 900 Loss: 0.3856555975973606
Epoch 2 Training Loss: 0.41043707115182493
11950 10432
Epoch 2 Eval Acc: 0.8729707112970712; Eval Loss: 0.

In [239]:
from transformers import StoppingCriteria, StoppingCriteriaList
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if all((input_ids==self.keywords[0]).sum(dim=-1) >= 2):
            return True
        return False

stop_words = ["[SEP]"]
stop_ids = [tokenizer.convert_tokens_to_ids(w) for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
stop_criteria_list = StoppingCriteriaList([stop_criteria])

@torch.no_grad()
def fill_idiom(model, loader):

    all_preds = []
    all_labels = []
    model.eval()
    for batch in loader:   
        batch = to_device(batch, device)
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        # pos_ids = batch["position_ids"]
        labels = batch["labels"]
        outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, return_dict_in_generate=True, pad_token_id=50256, max_length=512, top_k=10, stopping_criteria=stop_criteria_list)
        pred_start = torch.nonzero(input_ids==tokenizer.sep_token_id, as_tuple=True)[1][0] + 1
        truncated_outputs = []
        for out in outputs["sequences"]:
            sep_idxs = torch.nonzero(out==tokenizer.sep_token_id, as_tuple=True)[0]
            if len(sep_idxs) == 1:
                end_idx = -1
            else:
                end_idx = sep_idxs[1]
            truncated_outputs.append(out[pred_start:end_idx])
        decode_texts = tokenizer.batch_decode(truncated_outputs)
        gold_texts = tokenizer.batch_decode([l[l != -100][:-1] for l in labels])

        for gold, decode in zip(gold_texts, decode_texts):
            l = set(gold.replace(" ", "").split("、"))
            p = set(decode.replace(" ", "").split("、"))
            all_labels.append(l)
            all_preds.append(p)
    
    return all_preds, all_labels
        

def f1_score(sys, gold):
    tp = 0
    t = 0
    p = 0
    for s, g in zip(sys, gold):
        t += len(g)
        p += len(s)
        tp += len(g & s)
    precision = tp / p if p != 0 else 0
    recall = tp / t if t != 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
    return precision, recall, f1

def left_pad_sequence(sequence, batch_first, padding_value=0):
    padded = []
    max_len = max(len(each) for each in sequence)
    for each in sequence:
        if not isinstance(each, torch.LongTensor):
            each = torch.LongTensor(each)
        pad = torch.full((max_len-len(each),), fill_value=padding_value,dtype=each.dtype)
        padded.append(torch.cat([pad, each]))
    padded = torch.vstack(padded)
    if not batch_first:
        padded = padded.permute(1, 0, 2)
    return padded
        
def inference_colate_fn(batch):
    batch_input_ids = [torch.LongTensor(each["input_ids"]) for each in batch]
    batch_att_mask = [torch.LongTensor(each["attention_mask"]) for each in batch]
    batch_label = [torch.LongTensor(each["labels"]) for each in batch]
    batch_position_ids = [torch.arange(len(each["input_ids"]), dtype=torch.long) for each in batch]
    
    padded_batch_input_ids = left_pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    padded_batch_att_mask = left_pad_sequence(batch_att_mask, batch_first=True, padding_value=0)
    padded_batch_label = pad_sequence(batch_label, batch_first=True, padding_value=-100)
    # padded_batch_position_ids = left_pad_sequence(batch_position_ids, batch_first=True, padding_value=0)
    # return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "position_ids":padded_batch_position_ids, "labels": padded_batch_label}   
    return {"input_ids": padded_batch_input_ids, "attention_mask": padded_batch_att_mask, "labels": padded_batch_label}    

val_dataset.inference()
inf_loader = DataLoader(val_dataset, batch_size=64, collate_fn=inference_colate_fn)

In [240]:
model.load_state_dict(torch.load("gpt2_ckpt.pt", map_location=device))
sys, gold = fill_idiom(model, loader=inf_loader)
p, r, f1 = f1_score(sys, gold)
print(p, r, f1)

0.46860514117151286 0.4672268907563025 0.46791500105196715


In [228]:
# for batch in inf_loader:   
#     batch = to_device(batch, device)
#     input_ids = batch["input_ids"]
#     attention_mask = batch["attention_mask"]
#     pos_ids = batch["position_ids"]
#     labels = batch["labels"]
#     outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, position_ids=pos_ids, return_dict_in_generate=True, pad_token_id=50256, max_length=512, top_k=10, stopping_criteria=stop_criteria_list)
#     res = tokenizer.batch_decode(outputs["sequences"])
#     break

In [116]:
# for batch in inf_loader:   
#     batch = to_device(batch, device)
#     input_ids = batch["input_ids"]
#     attention_mask = batch["attention_mask"]
#     labels = batch["labels"]
#     outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, return_dict_in_generate=True, pad_token_id=50256, max_length=512, top_k=10, stopping_criteria=stop_criteria_list)
#     break