In [1]:
import random
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn as nn

In [2]:
from config import parse_args

args = parse_args()
def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
setup_seed(args.seed)

In [3]:
data_path = '../dataset/train.csv'
df = pd.read_csv(data_path, delimiter="\t")
df['tag'] = df['tag'].apply(lambda x: eval(x))
df.info()

df.head(5)


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6000 entries, 0 to 5999
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    6000 non-null   object
 1   tag     6000 non-null   object
dtypes: object(2)
memory usage: 93.9+ KB


Unnamed: 0,text,tag
0,会安博物馆等，漫步会安古镇各精致的工艺品店、品尝路边的小吃摊，体验当地的风土民情。,[会安古镇]
1,贝蒂斯vs西班牙人,"[贝蒂斯, 西班牙人]"
2,最终橘子熊在特种部队项目以7：2，跑跑卡丁车项目以7：1痛击曜越太阳神，,[橘子熊]
3,2008年11月22日，北京的气温陡降到零下4度，但雍和宫星光现场里“beijing,[北京]
4,光谱代理《大战略PERFECT3》繁体版,[光谱]


In [4]:
bio_list = []
for i in tqdm(range(len(df))):
    text = df['text'][i]
    tags = df['tag'][i]
    bios = ['O']*len(text)
    for t in tags:
        idx = text.find(t)
        bios[idx] = 'B-0'
        for j in range(idx+1, idx+len(t)):
            bios[j] = 'I-0'
    bio_list.append(bios)

100%|██████████| 6000/6000 [00:00<00:00, 74072.68it/s]


In [5]:
df['bio'] = bio_list


In [6]:
from sklearn.model_selection import train_test_split
train_data, valid_data = train_test_split(df, test_size = 0.2, random_state=args.seed)
train_data.index = list(range(len(train_data)))
valid_data.index = list(range(len(valid_data)))
# print(len(train_data), len(valid_data))

In [7]:
# 将text和标注组合存进元组
train_data['training_data'] = train_data.apply(lambda row: [list(row['text']), row['bio']], axis=1)
valid_data['validating_data'] = valid_data.apply(lambda row: [list(row['text']), row['bio']], axis=1)

# test_data['testing_data'] = test_data.apply(lambda row: list(row['text']), axis=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [8]:
training_data_txt = train_data['training_data'].to_list()
validating_data_txt = valid_data['validating_data'].to_list()
# testing_data_txt = test_data['testing_data'].to_list()
print('训练集大小：',len(training_data_txt))
print('验证集大小：',len(validating_data_txt))


训练集大小： 4800
验证集大小： 1200


In [9]:
# --------------------------建立字典，字: idx-------------------------------------
word2idx = {}
# 训练集的
for sentence, tags in training_data_txt:
    for word in sentence:
        if word not in word2idx:
            word2idx[word] = len(word2idx)

# 验证集的
for sentence, tags in validating_data_txt:
    for word in sentence:
        if word not in word2idx:
            word2idx[word] = len(word2idx)

# 测试集的
# testing_data = testing_data_txt
# for sentence in testing_data:
#     for word in sentence:
#         if word not in word2idx:
#             word2idx[word] = len(word2idx)

# 加2个特殊字符
word2idx['<UNK>'] = len(word2idx)
word2idx['<PAD>'] = len(word2idx)

args.word2idx = word2idx
import pickle
with open('./word2idx.pkl', 'wb') as f:
    pickle.dump(args.word2idx, f)


args.vocab_len = len(word2idx)

print('vocab_len: ', args.vocab_len)

vocab_len:  3040


In [10]:

args.tag2idx = {'O':0, 'B-0':1, 'I-0':2}
args.idx2tag = {0: 'O', 1: 'B-0', 2:'I-0'}

In [11]:
# training_data_txt

In [12]:
from data_helper import create_data_loader
train_data_loader = create_data_loader(training_data_txt, args)
valid_data_loader = create_data_loader(validating_data_txt, args)
# test_data_loader = create_data_loader(testing_data_txt, configs) # 没有标签的测试集就不这样构建，因为没有label

In [13]:
len(train_data_loader),len(valid_data_loader)

(300, 75)

In [14]:
def jaccard_score(pred, label):
    return len(set(pred) & set(label)) / len(set(pred) | set(label))

In [15]:
def train_epoch(model, data_loader, optimizer, args):
    # 训练模式
    model = model.train()
    train_loss = 0
    for sample in tqdm(data_loader):
        sentence_tensor = sample['sentence_tensor'].to(args.device)
        mask_tensor = sample['mask_tensor'].to(args.device)
        label_tensor = sample['label_tensor'].to(args.device)
        # print(sentence_tensor)
        # print(mask_tensor)
        # print(label_tensor)
        out, loss = model(sentence_tensor=sentence_tensor,
                        label_tensor=label_tensor,
                        mask_tensor=mask_tensor)
        # print(out)

        train_loss += loss.item()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # scheduler.step()
        optimizer.zero_grad()

    return train_loss/len(data_loader)

from ark_nlp.factory.utils.conlleval import get_entity_bio
def return_entity(label):
    entity_labels = []
    for _type, _start_idx, _end_idx in get_entity_bio(label, id2label=None):
            entity_labels.append({
                'start_idx': _start_idx,
                'end_idx': _end_idx,
                'type': _type
            })
    entity_labels = [str(dic['start_idx'])+'-'+str(dic['end_idx']) for dic in entity_labels]
    return entity_labels


def eval_epoch(model, data_loader, args):
    # 验证模式
    model = model.eval()
    val_loss = 0
    jc_score_list = []
    # 关闭自动求导，省内存加速，因为是不是训练模式了，没必要求导
    with torch.no_grad():
        for sample in tqdm(data_loader):
            sentence_tensor = sample['sentence_tensor'].to(args.device)
            mask_tensor = sample['mask_tensor'].to(args.device)
            label_tensor = sample['label_tensor'].to(args.device)
            out, loss = model(sentence_tensor=sentence_tensor,
                        label_tensor=label_tensor,
                        mask_tensor=mask_tensor)

            val_loss += loss.item()


            predict_ids = out
            # predict_ids
            #%%
            label_ids = sample['label_tensor'].numpy().tolist()

            entity_all_label_ids = []
            entity_all_predict_ids = []
            for i in range(len(label_ids)):
                tmp_label, tmp_predict = [], []
                # 因为我crf有做mask所以这里的len(len(predict_tag[i]))是不带有pad的长度
                for j in range(0, len(predict_ids[i])):
                    tmp_label.append(args.idx2tag[label_ids[i][j]])
                    tmp_predict.append(args.idx2tag[predict_ids[i][j]])
                entity_all_label_ids.append(tmp_label)
                entity_all_predict_ids.append(tmp_predict)


            for label, pred in zip(entity_all_label_ids, entity_all_predict_ids):
                label_entity = return_entity(label)
                pred_entity = return_entity(pred)
                jc_score_list.append(jaccard_score(pred=pred_entity, label=label_entity))

    return val_loss/len(data_loader), np.mean(jc_score_list)


In [16]:
from model import BiLSTM_CRF
import torch.optim as optim
if torch.cuda.is_available():
    args.device = 'cuda:0'
    print('使用：', args.device,' ing........')
model = BiLSTM_CRF(args).to(args.device)
# 优化器
optimizer = optim.Adam(model.parameters(), lr=1e-3)


使用： cuda:0  ing........


In [17]:
best_jc_score = 0
for epoch in range(args.max_epochs):
    print('——'*10, f'Epoch {epoch + 1}/{args.max_epochs}', '——'*10)
    train_loss = train_epoch(model, train_data_loader, optimizer, args)
    # #scheduler.step()
    # print('-'*20)
    print(f'Train loss : {round(train_loss, 2)}\n')
    val_loss, jc_score = eval_epoch(model, valid_data_loader, args)


    if jc_score>best_jc_score:
        best_jc_score = jc_score
        print(f'val loss : {round(val_loss, 3)}')
        print(f"jc_score: {round(jc_score, 3)}")
        print('-'*20)
        torch.save(model.state_dict(), './save_model/best_model.pth')
        print('+'*6,'best save_model saved','+'*6)


———————————————————— Epoch 1/16 ————————————————————
Train loss : 175.03

val loss : 117.332
jc_score: 0.434
--------------------
++++++ best save_model saved ++++++
———————————————————— Epoch 2/16 ————————————————————
Train loss : 85.81

val loss : 89.803
jc_score: 0.526
--------------------
++++++ best save_model saved ++++++
———————————————————— Epoch 3/16 ————————————————————
Train loss : 52.94

val loss : 83.354
jc_score: 0.594
--------------------
++++++ best save_model saved ++++++
———————————————————— Epoch 4/16 ————————————————————
Train loss : 33.87

———————————————————— Epoch 5/16 ————————————————————
Train loss : 23.72

val loss : 86.901
jc_score: 0.635
--------------------
++++++ best save_model saved ++++++
———————————————————— Epoch 6/16 ————————————————————
Train loss : 17.47

———————————————————— Epoch 7/16 ————————————————————
Train loss : 14.19

———————————————————— Epoch 8/16 ————————————————————
Train loss : 12.3

———————————————————— Epoch 9/16 ———————————————————

100%|██████████| 300/300 [00:23<00:00, 12.60it/s]
100%|██████████| 75/75 [00:02<00:00, 25.94it/s]
100%|██████████| 300/300 [00:21<00:00, 13.71it/s]
100%|██████████| 75/75 [00:02<00:00, 25.63it/s]
100%|██████████| 300/300 [00:21<00:00, 14.09it/s]
100%|██████████| 75/75 [00:02<00:00, 26.16it/s]
100%|██████████| 300/300 [00:21<00:00, 14.27it/s]
100%|██████████| 75/75 [00:02<00:00, 25.82it/s]
100%|██████████| 300/300 [00:21<00:00, 13.68it/s]
100%|██████████| 75/75 [00:02<00:00, 25.86it/s]
100%|██████████| 300/300 [00:23<00:00, 12.76it/s]
100%|██████████| 75/75 [00:03<00:00, 24.64it/s]
100%|██████████| 300/300 [00:22<00:00, 13.09it/s]
100%|██████████| 75/75 [00:02<00:00, 25.08it/s]
100%|██████████| 300/300 [00:22<00:00, 13.31it/s]
100%|██████████| 75/75 [00:02<00:00, 25.42it/s]
100%|██████████| 300/300 [00:25<00:00, 11.70it/s]
100%|██████████| 75/75 [00:03<00:00, 20.97it/s]
100%|██████████| 300/300 [00:24<00:00, 12.33it/s]
100%|██████████| 75/75 [00:03<00:00, 24.83it/s]
100%|██████████| 300