In [1]:
import torch
import pandas as pd
import numpy as np

# 数据预处理

## 数据读取

In [2]:
from tqdm.auto import tqdm
def txt2list(path):
    Text = []
    with open(path, "r", encoding="utf-8") as F:
        for line in tqdm(F):
            Text.append(line.strip().split("_!_"))
    return Text

In [3]:
train_path = r"../input/textclassification/classification/train.txt"
test_path = r"../input/textclassification/classification/test.txt"
valid_path = r"../input/textclassification/classification/valid.txt"

train_Text = txt2list(train_path)
test_Text = txt2list(test_path)
valid_Text = txt2list(valid_path)

train_Text[0], test_Text[0], valid_Text[0]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

(['6552368441838272771', '101', 'news_culture', '发酵床的垫料种类有哪些？哪种更好？', ''],
 ['6551700932705387022',
  '101',
  'news_culture',
  '京城最值得你来场文化之旅的博物馆',
  '保利集团,马未都,中国科学技术馆,博物馆,新中国'],
 ['6552452982015787268', '101', 'news_culture', '松涛听雨莺婉转，下联？', ''])

## 标记数据

In [4]:
train_arr = np.array(train_Text)
valid_arr = np.array(valid_Text)
test_arr = np.array(test_Text)

# 重新制作标签为 0 -> 16

X_train = train_arr[:, 3]
y_train = np.apply_along_axis(lambda x: x - 100, 0, train_arr[:, 1].astype(np.int32))

X_valid = valid_arr[:, 3]
y_valid = np.apply_along_axis(lambda x: x - 100, 0, valid_arr[:, 1].astype(np.int32))

X_test = test_arr[:, 3]
y_test = np.apply_along_axis(lambda x: x - 100, 0, test_arr[:, 1].astype(np.int32))

X_train[0], y_train[0]

('发酵床的垫料种类有哪些？哪种更好？', 1)

In [5]:
## 使用autotokenizer对文本进行标记
from transformers import AutoTokenizer

# 使用bert-base-chinese模型

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

train_encoding = tokenizer(list(X_train), truncation=True, padding=True, max_length=64)
valid_encoding = tokenizer(list(X_valid), truncation=True, padding=True, max_length=64)
test_encoding = tokenizer(list(X_test), truncation=True, padding=True, max_length=64)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/624 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/107k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/263k [00:00<?, ?B/s]

## 创建数据集

In [6]:
from torch.utils.data import Dataset,DataLoader,TensorDataset

class tokenizedDataset(Dataset):
    def __init__(self, encoding, labels):
        super(tokenizedDataset, self).__init__()
        self.encoding = encoding
        self.labels = labels

    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encoding.items()}
        item['labels'] = torch.tensor(int(self.labels[idx]))
        return item
    
train_dataset = tokenizedDataset(train_encoding, y_train)
valid_dataset = tokenizedDataset(valid_encoding, y_valid)
test_dataset = tokenizedDataset(test_encoding, y_test)

In [7]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

# 模型

In [8]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=17)
# 虽然是15分类的数据，但是ind是17分类的ind，因此假定是有两类无数据的17分类

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

Downloading:   0%|          | 0.00/393M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

## 训练函数

In [9]:
from tqdm.auto import tqdm

def train(model, train_loader):
    model.train()
    train_acc = 0.
    train_loss = 0.
    c = 0
    
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        logits = outputs.logits
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        predictions = torch.argmax(logits, dim=-1)
        num_correct = torch.sum(predictions == labels)

        train_acc += num_correct.cpu()
        train_loss += loss.item()
                
        c += 1
        if c % 500 == 0:
            print("epoth: %d, iter_num: %d, loss: %.4f" % (epoch, c, loss.item()))
    
    print('train Loss: {:.6f}, Acc: {:.6f}'.format(loss, train_acc))
    print("-------------------------------")
    return train_loss, train_acc

In [10]:
for batch in train_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)
    print(outputs[0])
    break

tensor(2.9775, device='cuda:0', grad_fn=<NllLossBackward0>)


## 评估函数

In [11]:
from sklearn.metrics import classification_report

target_names = ['news_story',
'news_culture',
'news_entertainment',
'news_sports',
'news_finance',
'idx5-无数据',
'news_house',
'news_car',
'news_edu',
'news_tech',
'news_military',
'idx11-无数据',
'news_travel',
'news_world',
'stock',
'news_agriculture',
'news_game']

def evaluation(model, valid_loader):
    model.eval()
    eval_acc = 0.
    eval_loss = 0.
    pred_list = []
    labels_list = []
    for batch in tqdm(valid_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        num_correct = torch.sum(predictions == labels)
        
        eval_acc += num_correct.cpu()
        eval_loss += loss.item()
        
        pred_list.append(predictions.cpu())
        labels_list.append(labels.cpu())


    labels = torch.cat(labels_list).view(-1)
    pred = torch.cat(pred_list).view(-1)
    eval_loss /= len(labels)
    eval_acc /= len(labels)
    
    print('valid Loss: {:.6f}, Acc: {:.6f}'.format(loss, eval_acc))
    print("-------------------------------")
    print(classification_report(labels, pred, labels=range(17), target_names=target_names))
    return eval_loss, eval_acc

# 训练

In [12]:
from torch.optim import AdamW
from transformers import get_scheduler

optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 1

In [13]:
print("training on ", device)

for epoch in range(num_epochs):
    print('epoch {}'.format(epoch + 1))
    train_loss, train_acc = train(model, train_loader)
    valid_loss, valid_acc = evaluation(model, valid_loader)

training on  cuda
epoch 1


  0%|          | 0/19135 [00:00<?, ?it/s]

epoth: 0, iter_num: 500, loss: 0.2711
epoth: 0, iter_num: 1000, loss: 1.0612
epoth: 0, iter_num: 1500, loss: 0.6627
epoth: 0, iter_num: 2000, loss: 0.7588
epoth: 0, iter_num: 2500, loss: 0.9288
epoth: 0, iter_num: 3000, loss: 0.4514
epoth: 0, iter_num: 3500, loss: 0.8017
epoth: 0, iter_num: 4000, loss: 0.3315
epoth: 0, iter_num: 4500, loss: 0.4068
epoth: 0, iter_num: 5000, loss: 1.1941
epoth: 0, iter_num: 5500, loss: 0.4379
epoth: 0, iter_num: 6000, loss: 0.2473
epoth: 0, iter_num: 6500, loss: 0.3364
epoth: 0, iter_num: 7000, loss: 0.3692
epoth: 0, iter_num: 7500, loss: 0.3929
epoth: 0, iter_num: 8000, loss: 0.2160
epoth: 0, iter_num: 8500, loss: 0.1653
epoth: 0, iter_num: 9000, loss: 0.6662
epoth: 0, iter_num: 9500, loss: 0.3494
epoth: 0, iter_num: 10000, loss: 0.4817
epoth: 0, iter_num: 10500, loss: 0.4790
epoth: 0, iter_num: 11000, loss: 0.7480
epoth: 0, iter_num: 11500, loss: 0.8667
epoth: 0, iter_num: 12000, loss: 0.7908
epoth: 0, iter_num: 12500, loss: 0.3355
epoth: 0, iter_num: 

  0%|          | 0/2392 [00:00<?, ?it/s]

valid Loss: 1.294381, Acc: 0.866693
-------------------------------
                    precision    recall  f1-score   support

        news_story       0.74      0.85      0.79       624
      news_culture       0.82      0.89      0.85      2878
news_entertainment       0.91      0.89      0.90      3924
       news_sports       0.96      0.91      0.94      3729
      news_finance       0.79      0.77      0.78      2640
          idx5-无数据       0.00      0.00      0.00         0
        news_house       0.88      0.90      0.89      1820
          news_car       0.92      0.91      0.91      3581
          news_edu       0.91      0.86      0.88      2674
         news_tech       0.85      0.86      0.85      4154
     news_military       0.87      0.85      0.86      2494
         idx11-无数据       0.00      0.00      0.00         0
       news_travel       0.81      0.84      0.82      2181
        news_world       0.84      0.78      0.81      2618
             stock       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# 评估测试集表现

In [14]:
len(test_dataset), len(valid_dataset)

(38273, 38265)

In [15]:
evaluation(model, train_loader)

  0%|          | 0/19135 [00:00<?, ?it/s]

valid Loss: 0.022841, Acc: 0.887565
-------------------------------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                    precision    recall  f1-score   support

        news_story       0.77      0.87      0.82      4991
      news_culture       0.84      0.91      0.88     22377
news_entertainment       0.92      0.90      0.91     31464
       news_sports       0.97      0.93      0.95     30015
      news_finance       0.83      0.80      0.81     21657
          idx5-无数据       0.00      0.00      0.00         0
        news_house       0.89      0.93      0.91     14134
          news_car       0.93      0.92      0.93     28660
          news_edu       0.93      0.88      0.90     21655
         news_tech       0.87      0.89      0.88     33360
     news_military       0.88      0.87      0.88     20003
         idx11-无数据       0.00      0.00      0.00         0
       news_travel       0.86      0.85      0.85     17200
        news_world       0.87      0.80      0.83     21554
             stock       0.00      0.00      0.00       268
  news_agriculture       0.84      0.91

(0.02356302402865506, tensor(0.8876))

In [16]:
evaluation(model, valid_loader)

  0%|          | 0/2392 [00:00<?, ?it/s]

valid Loss: 0.198649, Acc: 0.866693
-------------------------------
                    precision    recall  f1-score   support

        news_story       0.74      0.85      0.79       624
      news_culture       0.82      0.89      0.85      2878
news_entertainment       0.91      0.89      0.90      3924
       news_sports       0.96      0.91      0.94      3729
      news_finance       0.79      0.77      0.78      2640
          idx5-无数据       0.00      0.00      0.00         0
        news_house       0.88      0.90      0.89      1820
          news_car       0.92      0.91      0.91      3581
          news_edu       0.91      0.86      0.88      2674
         news_tech       0.85      0.86      0.85      4154
     news_military       0.87      0.85      0.86      2494
         idx11-无数据       0.00      0.00      0.00         0
       news_travel       0.81      0.84      0.82      2181
        news_world       0.84      0.78      0.81      2618
             stock       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


(0.028208660772981776, tensor(0.8667))

In [17]:
evaluation(model, test_loader)

  0%|          | 0/2393 [00:00<?, ?it/s]

valid Loss: 0.346272, Acc: 0.866668
-------------------------------
                    precision    recall  f1-score   support

        news_story       0.75      0.84      0.79       658
      news_culture       0.82      0.88      0.85      2776
news_entertainment       0.90      0.88      0.89      4008
       news_sports       0.96      0.92      0.94      3824
      news_finance       0.80      0.76      0.78      2788
          idx5-无数据       0.00      0.00      0.00         0
        news_house       0.86      0.91      0.89      1718
          news_car       0.92      0.91      0.91      3544
          news_edu       0.91      0.85      0.88      2729
         news_tech       0.84      0.88      0.86      4029
     news_military       0.86      0.86      0.86      2487
         idx11-无数据       0.00      0.00      0.00         0
       news_travel       0.82      0.83      0.82      2041
        news_world       0.84      0.78      0.81      2737
             stock       0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


(0.028421858979592356, tensor(0.8667))

In [18]:
# classes
# 0 民生 故事 news_story
# 1 文化 文化 news_culture
# 2 娱乐 娱乐 news_entertainment
# 3 体育 体育 news_sports
# 4 财经 财经 news_finance
# 6 房产 房产 news_house
# 7 汽车 汽车 news_car
# 8 教育 教育 news_edu 
# 9 科技 科技 news_tech
# 10 军事 军事 news_military
# 12 旅游 旅游 news_travel
# 13 国际 国际 news_world
# 14 证券 股票 stock
# 15 农业 三农 news_agriculture
# 16 电竞 游戏 news_game