In [329]:
import os
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as npa
from tqdm import tqdm
import json
import jieba
import pickle as pkl
from typing import Any

import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn.functional as F
from TorchCRF import CRF

from transformers import BertTokenizerFast

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [3]:
with open('renmindata.pkl', 'rb') as f:
    data = pkl.load(f)

In [4]:
def tokenize(text):
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    text = "这是一个测试文本。" if text is None else text
    tokens = tokenizer.tokenize(text)

    return tokens

In [5]:
from datasets import load_dataset

ds = load_dataset("qgyd2021/chinese_ner_sft", "Bank", trust_remote_code=True)

In [6]:
text_sample = ds['train']['text'][0]
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
text_sample = text_sample[:30]

In [73]:
class LSTM_CRF(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, tag_size) -> None:
        super().__init__()
        self.tag_size = tag_size
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size //2, num_layers=num_layers, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_size, tag_size)
        self.trans = nn.Parameter(torch.rand((tag_size, tag_size)))
        
    def _get_lstm(self, x):
        x = self.embed(x)
        lstm_out, _ = self.lstm(x)
        return lstm_out
    
    def _forward_alg(self, feats):
        init_alpha = torch.full((1, self.tag_size), -1000)
        init_alpha[0][0] = 0.0
        forward_var = init_alpha
        batch_size = feats.shape[0]
        for word_idx in range(feats.shape[1]):
            feat = feats[:, word_idx, :]
            scores = (
                forward_var.unsqueeze(1).expand(batch_size, self.tag_size, self.tag_size) +
                self.trans.unsqueeze(0).expand(batch_size, self.tag_size, self.tag_size) + 
                feat.unsqueeze(2).expand(batch_size, self.tag_size, self.tag_size)
            )
            forward_var = torch.logsumexp(scores, dim=2)
        
        return forward_var
    
    def _score_sentence(self, feats, tags):
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([0], dtype=torch.long), tags])
        for i, feat in enumerate(feats):
            score = score + self.trans[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        return score
    
    def _viterbi_decode(self, feats):
        backpointers = []
        init_vvars = torch.full((1, self.tag_size), -1000.0)
        init_vvars[0][0] = 0.0
        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []
            viterbivars_t = []
            for next_tag in range(self.tag_size):
                next_tag_var = forward_var + self.trans[next_tag]
                best_tag_id = torch.argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)
        
        best_tag_id = torch.argmax(forward_var)
        path_score = forward_var[0][best_tag_id]
        
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        best_path.pop()
        best_path.reverse()
        return path_score, best_path
    
    def neg_log_likelihood(self, x, tags):
        feats = self._get_lstm(x)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score
    
    def forward(self, x):
        feats = self._get_lstm(x)
        score, tag_seq = self._viterbi_decode(feats)
        return score, tag_seq


In [9]:
ds['train']['text'][0][:30]

'交行14年用过，半年准备提额，却直接被降到1Ｋ，半年期间只T'

In [10]:
ds['train']['entities'][0]

{'start_idx': [0, 12, 19, 42, 54, 58, 64, 70],
 'end_idx': [2, 14, 21, 44, 56, 60, 66, 71],
 'entity_text': ['交行', '提额', '降到', '消费', '增加', '提额', '分期', '降'],
 'entity_label': ['BANK',
  'COMMENTS_N',
  'COMMENTS_ADJ',
  'COMMENTS_N',
  'COMMENTS_N',
  'COMMENTS_N',
  'PRODUCT',
  'COMMENTS_ADJ'],
 'entity_names': [['银行', '银行名称'],
  ['金融名词'],
  ['形容词'],
  ['金融名词'],
  ['金融名词'],
  ['金融名词'],
  ['产品', '产品名称', '金融名词', '金融产品', '银行产品'],
  ['形容词']]}

In [91]:
def _mapping_idx(text, entity):
    tokens = tokenizer.tokenize(text)
    entity_idx = 0
    label_list = ['O'] * len(tokens)
    map_index = tokenizer(text, return_offsets_mapping=True)['offset_mapping']
    map_index.pop(0)
    map_index.pop(-1)
    
    for token_idx, token_int in enumerate(map_index):
        char_start, char_end = token_int
        if entity_idx >= len(entity['start_idx']):
            print(token_idx)
            break
        try:
            entity_start, entity_end = entity['start_idx'][entity_idx], entity['end_idx'][entity_idx]
            label = entity['entity_label'][entity_idx]
        except:
            # print(tokens, entity['entity_label'])
            print(entity_idx)
            print(token_idx, token_int)
            print(entity['start_idx'])
        if char_start == entity_start:
            label_list[token_idx] = f'B-{label}'
        elif char_start > entity_start and char_end <= entity_end:
            label_list[token_idx] = f'I-{label}'
        
        if char_end >= entity_end:
            entity_idx += 1
    return tokens, label_list

In [92]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
text_sample = ds['train']['text'][1]
entity = ds['train']['entities'][1]
_mapping_idx(text_sample, entity)

(['单',
  '标',
  '我',
  '有',
  '了',
  '，',
  '最',
  '近',
  'visa',
  '双',
  '标',
  '返',
  '现',
  '活',
  '动',
  '好'],
 ['B-PRODUCT',
  'I-PRODUCT',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B-PRODUCT',
  'B-PRODUCT',
  'I-PRODUCT',
  'B-COMMENTS_N',
  'I-COMMENTS_N',
  'I-COMMENTS_N',
  'I-COMMENTS_N',
  'B-COMMENTS_ADJ'])

In [82]:
entity

{'start_idx': [0, 8, 12, 14, 18],
 'end_idx': [2, 12, 14, 18, 19],
 'entity_text': ['单标', 'visa', '双标', '返现活动', '好'],
 'entity_label': ['PRODUCT',
  'PRODUCT',
  'PRODUCT',
  'COMMENTS_N',
  'COMMENTS_ADJ'],
 'entity_names': [['产品', '产品名称', '金融名词', '金融产品', '银行产品'],
  ['产品', '产品名称', '金融名词', '金融产品', '银行产品'],
  ['产品', '产品名称', '金融名词', '金融产品', '银行产品'],
  ['金融名词'],
  ['形容词']]}

In [358]:
class EntityData(Dataset):
    def __init__(self, data, tokenizer) -> None:
        super().__init__()
        self.labels = None
        self.token_list = None
        
        self.tokenizer = tokenizer
        self.tag_size = None
        self.PAD_TAG = 'PAD'
        
        raw_text = data['text']
        entities = data['entities']
        self._get_CRF_labels(entities)
        self._processing_data(raw_text, entities)
        self.tag_size = len(self.labels2id)
    
    def __len__(self):
        return len(self.token_list)
    
    def _get_CRF_labels(self, entities):
        raw_labels = {label for entity in entities for label in entity['entity_label']}
        labels_cate = {f'B-{label}' for label in raw_labels} | {f'I-{label}' for label in raw_labels} | {'O'}
        
        self.labels2id = {label: idx+1 for idx, label in enumerate(labels_cate)}
        self.labels2id['[PAD]'] = 0
        self.id2labels = {idx: label for label, idx in self.labels2id.items()}
    
    def _mapping_idx(self, text, entity):
        tokens = self.tokenizer.tokenize(text)
        entity_idx = 0
        label_list = ['O'] * len(tokens)
        map_index = self.tokenizer(text, return_offsets_mapping=True)['offset_mapping']
        map_index.pop(0)
        map_index.pop(-1)
        
        for token_idx, token_int in enumerate(map_index):
            char_start, char_end = token_int
            if entity_idx >= len(entity['start_idx']):
                # print(token_idx)
                break
            # try:
            entity_start, entity_end = entity['start_idx'][entity_idx], entity['end_idx'][entity_idx]
            # except:
            #     print(entity_idx)
            #     print(token_idx, token_int)
            #     print(entity['start_idx'])
            label = entity['entity_label'][entity_idx]
            if char_start == entity_start:
                label_list[token_idx] = f'B-{label}'
            elif char_start > entity_start and char_end <= entity_end:
                label_list[token_idx] = f'I-{label}'
            if char_end >= entity_end:
                entity_idx += 1
        return tokens, label_list
    
    def _processing_data(self, raw_text, entities):
        self.token_list = []
        self.labels = []
        for text, entity in zip(raw_text, entities):
            tokens, label_list = self._mapping_idx(text, entity)
            self.token_list.append(tokens)
            self.labels.append(label_list)
            
    def decode_label(self, labels):
        return [self.id2labels(label) for label in labels]
    
    def decode_text(self, token_ids):
        return self.tokenizer.convert_ids_to_tokens(token_ids)
    
    def __getitem__(self, index) -> Any:
        inputs_id = self.tokenizer.convert_tokens_to_ids(self.token_list[index])
        inputs_labels = [self.labels2id[label] for label in self.labels[index]]
        return torch.tensor(inputs_id, dtype=torch.long), torch.tensor(inputs_labels, dtype=torch.long)
        # return {
        #     "input_id": torch.tensor(inputs_id, dtype=torch.long), 
        #     "label": torch.tensor(inputs_labels, dtype=torch.long)
        #     }
        # return torch.tensor(inputs_id, dtype=torch.long), torch.tensor(label, dtype=torch.long)

In [359]:
dataset = EntityData(ds['train'], tokenizer=tokenizer)

In [356]:
train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])

In [407]:
def collate_fn(batch):
    max_length = 0
    # print('in collate_fn', max_length)
    # print(batch)
    for sample in batch:
        tokens, label = sample
        # tokens = sample['input_id']
        # label = sample['label']
        max_length = max(len(tokens), max_length)
        # print(batch)
        assert len(tokens) == len(label), f'the length of tokens {len(tokens)} is not equal to the length of labels {len(label)}'
    
    padded_tokens = []
    padded_labels = []
    
    for tokens, labels in batch:
        padded_tokens.append(torch.cat([tokens, torch.zeros(max_length - len(tokens), dtype=torch.long)]))
        padded_labels.append(torch.cat([labels, torch.zeros(max_length - len(labels), dtype=torch.long)]))
    
    return torch.stack(padded_tokens), torch.stack(padded_labels)

In [345]:
class LSTM_CRF(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, tag_size) -> None:
        super().__init__()
        self.tag_size = tag_size
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size // 2, num_layers=num_layers, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_size, tag_size)
        self.crf = CRF(self.tag_size)

    def _get_lstm_features(self, x):
        x = self.embed(x)
        lstm_out, _ = self.lstm(x)
        # print(f"after lstm: {lstm_out.shape}")
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats
    
    def forward(self, x, tags=None):
        mask = (x != 0).type(torch.bool)
        lstm_feats = self._get_lstm_features(x)
        # print("LSTM feats shape: ", lstm_feats.shape)

        if tags is not None:
            # print("Input shape: ", x.shape)
            # print("Mask shape: ", mask.shape)
            # print("tags shape: ", tags.shape)
            assert mask.shape == tags.shape
            loss = -self.crf.forward(lstm_feats, tags, mask=mask)
            return loss
        else:
            # print("LSTM feats shape: ", lstm_feats.shape)
            if len(lstm_feats.shape) == 2:
                lstm_feats.unsqueeze(0)
            pred_tags = self.crf.viterbi_decode(lstm_feats, mask=mask)
            return pred_tags

In [346]:
import torch
from TorchCRF import CRF
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 2
sequence_size = 3
num_labels = 5
mask = torch.BoolTensor([[1, 1, 1], [1, 1, 0]]).to(device) # (batch_size. sequence_size)
labels = torch.LongTensor([[0, 2, 3], [1, 4, 1]]).to(device)  # (batch_size, sequence_size)
hidden = torch.randn((batch_size, sequence_size, num_labels), requires_grad=True).to(device)
crf = CRF(num_labels)

In [316]:
crf.forward(hidden, labels, mask)

tensor([-5.9362, -2.4696], grad_fn=<SubBackward0>)

In [408]:
vocab_size = len(tokenizer.vocab)
embedding_dim = 128
hidden_dim = 256
tag_size = dataset.tag_size

model = LSTM_CRF(vocab_size=vocab_size, hidden_size=hidden_dim, num_layers=2, tag_size=tag_size)
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

loss_list = []
for epoch in range(10):
    model.train()
    epoch_loss = 0.
    for batch in tqdm(train_dataloader):
        # print(batch)
        # input_ids = batch['input_ids']
        # labels = batch['labels']
        input_ids, labels = batch
        model.zero_grad()
        
        loss = model(input_ids, labels)
        loss = loss.mean()
        
        epoch_loss += loss
        
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} loss: {loss.item()}")
    loss_list.append(epoch_loss)

torch.save(model.state_dict(), 'task4_.pth')

100%|██████████| 16/16 [00:28<00:00,  1.76s/it]


Epoch 0 loss: 16.95537757873535


100%|██████████| 16/16 [00:26<00:00,  1.65s/it]


Epoch 1 loss: 14.407644271850586


100%|██████████| 16/16 [00:22<00:00,  1.43s/it]


Epoch 2 loss: 12.117145538330078


100%|██████████| 16/16 [00:28<00:00,  1.81s/it]


Epoch 3 loss: 9.773664474487305


100%|██████████| 16/16 [00:27<00:00,  1.72s/it]


Epoch 4 loss: 6.583451747894287


100%|██████████| 16/16 [00:26<00:00,  1.63s/it]


Epoch 5 loss: 5.128957271575928


100%|██████████| 16/16 [00:26<00:00,  1.65s/it]


Epoch 6 loss: 3.8844590187072754


100%|██████████| 16/16 [00:28<00:00,  1.79s/it]


Epoch 7 loss: 3.4639732837677


100%|██████████| 16/16 [00:28<00:00,  1.77s/it]


Epoch 8 loss: 2.5577492713928223


100%|██████████| 16/16 [00:28<00:00,  1.76s/it]

Epoch 9 loss: 2.573702096939087





In [410]:
def check_for_pth_files(directory):
    # 列出目录中的所有文件和子目录
    for root, dirs, files in os.walk(directory):
        # 检查每个文件是否以.pth结尾
        for file in files:
            if file.endswith('.pth'):
                return True

def calculate_prediction(true_label, pred_label):
    pad_size = true_label.shape[1]
    pred = []
    for tl, pl in zip(true_label, pred_label):
        seq_len = len(pl)
        pl = torch.tensor(pl, dtype=torch.long)
        # print(tl[:seq_len] == pl)
        pred.append((tl[:seq_len] == pl).float().mean().item()) 
    return pred

In [411]:
if check_for_pth_files('./'):
    model = LSTM_CRF(vocab_size=vocab_size, hidden_size=hidden_dim, num_layers=2, tag_size=tag_size)
    model.load_state_dict(torch.load('task4.pth'))

prediction = []
model.eval()
with torch.no_grad():
    for test_batch in test_dataloader:
        test_input_ids, test_label = test_batch
        pred_label = model(test_input_ids)
        # prediction = model.decode(test_feats, test_mask)
        pred = calculate_prediction(test_label, pred_label)
        prediction += pred

In [412]:
torch.mean(torch.tensor(prediction))

tensor(0.9871)