In [75]:
# 训练数据预处理
import numpy as np
from transformers import BertTokenizer, AdamW, BertModel, BertPreTrainedModel, BertConfig, AutoTokenizer, AutoModel
import pandas as pd
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
import tqdm

# load pre-trained model

In [65]:
# Alternative checkpoint kept for reference (XLNet via the Auto* classes):
# tokenizer = AutoTokenizer.from_pretrained('/home/zhoujx/Pretrained_models/chinese_xlnet_base_pytorch')

# Model configuration read from the checkpoint's JSON file, with
# output_hidden_states switched on.
bert_config = BertConfig.from_pretrained(
    r'/home/zhoujx/Pretrained_models/chinese_roberta_wwm_large_ext_pytorch/bert_config.json',
    output_hidden_states=True,
)

# WordPiece tokenizer built directly from the checkpoint's vocabulary file.
# NOTE(review): `config=` is forwarded to the tokenizer here; it looks unused
# by BertTokenizer -- confirm against the installed transformers version.
tokenizer = BertTokenizer.from_pretrained(
    r'/home/zhoujx/Pretrained_models/chinese_roberta_wwm_large_ext_pytorch/vocab.txt',
    config=bert_config,
)

I0415 10:30:03.341242 140138955777792 configuration_utils.py:281] loading configuration file /home/zhoujx/Pretrained_models/chinese_roberta_wwm_large_ext_pytorch/bert_config.json
I0415 10:30:03.344393 140138955777792 configuration_utils.py:319] Model config BertConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "min_length": 0,
  "m

# load data

In [78]:
# Labelled training set and unlabelled test set, both read from the
# current working directory.
df_train = pd.read_csv("./nCoV_100k_train.labled.csv")
df_test = pd.read_csv("./nCov_10k_test.csv")

In [None]:
# Widen text columns in DataFrame previews so the Weibo content is not cut
# off after a few characters. The original call, pd.set_option(''), raised
# ValueError because set_option needs (option_name, value) pairs.
# NOTE(review): the intended option was not recorded in the notebook --
# adjust the option name/value if a different display setting was meant.
pd.set_option('display.max_colwidth', 200)

In [79]:
# Preview the first rows of the training frame.
df_train.head()

Unnamed: 0,微博id,微博发布时间,发布人账号,微博中文内容,微博图片,微博视频,情感倾向
0,4456072029125500,01月01日 23:50,存曦1988,写在年末冬初孩子流感的第五天，我们仍然没有忘记热情拥抱这2020年的第一天。带着一丝迷信，早...,['https://ww2.sinaimg.cn/orj360/005VnA1zly1gah...,[],0
1,4456074167480980,01月01日 23:58,LunaKrys,开年大模型…累到以为自己发烧了腰疼膝盖疼腿疼胳膊疼脖子疼#Luna的Krystallife#?,[],[],-1
2,4456054253264520,01月01日 22:39,小王爷学辩论o_O,邱晨这就是我爹，爹，发烧快好，毕竟美好的假期拿来养病不太好，假期还是要好好享受快乐，爹，新...,['https://ww2.sinaimg.cn/thumb150/006ymYXKgy1g...,[],1
3,4456061509126470,01月01日 23:08,芩鎟,新年的第一天感冒又发烧的也太衰了但是我要想着明天一定会好的?,['https://ww2.sinaimg.cn/orj360/005FL9LZgy1gah...,[],1
4,4455979322528190,01月01日 17:42,changlwj,问：我们意念里有坏的想法了，天神就会给记下来，那如果有好的想法也会被记下来吗？答：那当然了。...,[],[],1


## train_data & dev_data

In [76]:
# Replace missing Weibo text with the placeholder string so every row has
# something to feed the tokenizer.
df_train['微博中文内容'] = df_train['微博中文内容'].fillna('内容缺失')

In [77]:
# List the column labels of the training frame.
df_train.columns

Index(['微博id', '微博发布时间', '发布人账号', '微博中文内容', '微博图片', '微博视频', '情感倾向'], dtype='object')

In [60]:
def get_data(df, max_length=150):
    """Tokenize the Weibo text column and swap the raw columns for model inputs.

    Every value of the ``微博中文内容`` column is encoded with the module-level
    ``tokenizer`` (padded to ``max_length``). The resulting ``input_ids``,
    ``token_type_ids`` and ``attention_mask`` lists are stored as three new
    columns, after which the raw text and metadata columns are dropped.

    Args:
        df: DataFrame holding at least the columns ``微博发布时间``,
            ``发布人账号``, ``微博中文内容``, ``微博图片`` and ``微博视频``.
        max_length: ``max_length`` passed to ``tokenizer.encode_plus``.
            Defaults to 150, the value previously hard-coded here.

    Returns:
        The same DataFrame object that was passed in; it is modified in place.
    """
    input_ids_list = []
    token_type_ids_list = []
    attention_mask_list = []

    # Walk the column values positionally. The previous df.loc[idx, ...] with
    # idx taken from range(n) was a *label* lookup, which raises KeyError (or
    # returns a Series) whenever the index is not a unique 0..n-1 range, e.g.
    # after rows have been filtered out.
    texts = df['微博中文内容']
    for text in tqdm.tqdm(texts, total=len(texts)):
        tokenize_out = tokenizer.encode_plus(text, pad_to_max_length=True, max_length=max_length)
        input_ids_list.append(tokenize_out['input_ids'])
        token_type_ids_list.append(tokenize_out['token_type_ids'])
        attention_mask_list.append(tokenize_out['attention_mask'])

    # One Python list per row; each column ends up with object dtype.
    df['input_ids'] = input_ids_list
    df['token_type_ids'] = token_type_ids_list
    df['attention_mask'] = attention_mask_list
    # Raw text/metadata is no longer needed once the encoded columns exist.
    df.drop(['微博发布时间','发布人账号','微博中文内容','微博图片','微博视频'], axis=1, inplace=True)
    return df
    

In [61]:
# Encode the training text; get_data adds the input_ids / token_type_ids /
# attention_mask columns and drops the raw text and metadata columns.
df_train = get_data(df_train)

100%|██████████| 100000/100000 [04:07<00:00, 307.27it/s]


In [63]:
# Map the raw sentiment labels (-1 / 0 / 1) to class ids (0 / 1 / 2) and drop
# every row whose label is anything else.
#
# The label column can hold a mix of strings and integers. The previous code
# kept both kinds in the filter but mapped with string keys only, so integer
# labels silently turned into NaN. Normalising to str first makes the filter
# and the mapping agree on both representations.
label_to_class = {'-1': 0, '0': 1, '1': 2}
label_text = df_train['情感倾向'].astype(str)
# .copy() gives an independent frame, so the assignment below does not write
# into a slice of the original (avoids SettingWithCopyWarning).
df_train = df_train[label_text.isin(list(label_to_class))].copy()
df_train['情感倾向'] = df_train['情感倾向'].astype(str).map(label_to_class)
# reset_index() keeps the old row labels as an extra 'index' column, as before.
df_train = df_train.reset_index()
df_train.to_csv('./df_train.csv', index=False)

## test_data

In [67]:
# Replace missing Weibo text in the test set with the same placeholder
# string used for the training set.
df_test['微博中文内容'] = df_test['微博中文内容'].fillna('内容缺失')

In [69]:
# List the column labels of the test frame.
df_test.columns

Index(['微博id', '微博发布时间', '发布人账号', '微博中文内容', '微博图片', '微博视频'], dtype='object')

In [70]:
def get_test_data(df):
    """Tokenize the test-set text column and swap the raw columns for model inputs.

    The test frame needs exactly the same processing as the training frame
    (encode ``微博中文内容``, add ``input_ids`` / ``token_type_ids`` /
    ``attention_mask``, drop the raw text and metadata columns), so this
    delegates to ``get_data`` instead of carrying a second copy of that logic.

    Args:
        df: DataFrame holding at least the columns ``微博发布时间``,
            ``发布人账号``, ``微博中文内容``, ``微博图片`` and ``微博视频``.

    Returns:
        The same DataFrame object that was passed in; it is modified in place.
    """
    # get_data is defined in an earlier cell of this notebook.
    return get_data(df)

In [71]:
# Encode the test text; get_test_data adds the input_ids / token_type_ids /
# attention_mask columns and drops the raw text and metadata columns.
df_test = get_test_data(df_test)

100%|██████████| 10000/10000 [00:24<00:00, 416.17it/s]


In [72]:
# Move the current row labels into an 'index' column (drop=False is the
# default), then write the frame to disk without the new positional index.
df_test = df_test.reset_index(drop=False)
df_test.to_csv(path_or_buf='./df_test.csv', index=False)

In [74]:
# Preview the first rows of the processed test frame.
df_test.head()

Unnamed: 0,index,微博id,input_ids,token_type_ids,attention_mask
0,0,4456068992182160,"[101, 108, 872, 1962, 8439, 108, 3173, 2399, 5...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,1,4456424178427250,"[101, 1920, 2140, 1348, 2697, 1088, 7965, 1853...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,2,4456797466940200,"[101, 6820, 6206, 1343, 6783, 697, 1921, 3890,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,3,4456791021108920,"[101, 2769, 1922, 7410, 749, 1166, 782, 2582, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,4,4457086404997440,"[101, 3362, 4197, 3221, 6206, 4567, 671, 1767,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
