In [22]:
import torch.nn as nn
from transformers import BertModel

# Bert

class BertClassifier(nn.Module):
    def __init__(self, bert_config, num_labels):
        super().__init__()
        # 定义BERT模型
        self.bert = BertModel(config=bert_config)
        # 定义Dropout层
        self.dropout = nn.Dropout(p=0.2)  # Dropout概率为0.2
        # 定义分类器
        self.classifier = nn.Linear(bert_config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # BERT的输出
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # 取[CLS]位置的pooled output
        pooled = bert_output[1]
        # 在pooled output上应用Dropout
        pooled = self.dropout(pooled)
        # 分类
        logits = self.classifier(pooled)
        # 返回结果
        return logits


# Bert+BiLSTM，用法与BertClassifier一样，可直接在train里面调用
class BertLstmClassifier(nn.Module):
    def __init__(self, bert_config, num_labels):
        super().__init__()
        self.bert = BertModel(config=bert_config)
        self.lstm = nn.LSTM(input_size=bert_config.hidden_size, hidden_size=bert_config.hidden_size, num_layers=2, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(bert_config.hidden_size*2, num_labels)  # 双向LSTM 需要乘以2

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        last_hidden_state = outputs.last_hidden_state
        out, _ = self.lstm(last_hidden_state)
        logits = self.classifier(out[:, -1, :]) # 取最后时刻的输出
        return logits


In [23]:
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import BertTokenizer
import csv
class SinaNewsDataset(Dataset):
    def __init__(self, filename, tokenizer):
        # 数据集初始化
        self.labels = ['0','1','2','3']
        self.labels_id = list(range(4))
        self.tokenizer = tokenizer
        self.input_ids = []
        self.token_type_ids = []
        self.attention_mask = []
        self.label_id = []
        self.load_data(filename)
    
    def load_data(self, filename):
        # 加载数据
        print('Loading data from:', filename)
        with open(filename, 'r', encoding='utf-8') as rf:
            reader = csv.reader(rf)
            next(reader)  # 跳过表头，如果没有表头可以注释掉
            for line in tqdm(reader, ncols=100):
                label_id, text = int(line[0]), line[1]
                token = self.tokenizer(
                    text,
                    add_special_tokens=True,
                    padding='max_length',
                    truncation=True,
                    max_length=512
                )
                self.input_ids.append(np.array(token['input_ids']))
                self.token_type_ids.append(np.array(token.get('token_type_ids', [])))
                self.attention_mask.append(np.array(token['attention_mask']))
                self.label_id.append(label_id)

        print(f'Data loaded successfully: {len(self.input_ids)} samples.')

    def __getitem__(self, index):
        return self.input_ids[index], self.token_type_ids[index], self.attention_mask[index], self.label_id[index]

    def __len__(self):
        return len(self.input_ids)
    
    
## 取消注释测试
# tokenizer = BertTokenizer.from_pretrained('model/bert-base-chinese')
# data_loader = SinaNewsDataset("data/weibo_4/simplifyweibo_4_moods.csv",tokenizer)
# print(data_loader[0])

In [24]:
import csv
import random

def split_csv(file_path, train_path, val_path):
    """
    先随机取出数据集的1/5，再从中按8:2划分为训练集和验证集。

    Args:
        file_path (str): 原始 CSV 文件路径。
        train_path (str): 训练集文件路径。
        val_path (str): 验证集文件路径。
    """
    # 读取数据
    with open(file_path, 'r', encoding='utf-8') as f:
        reader = list(csv.reader(f))
        header = reader[0]  # 提取表头
        data = reader[1:]   # 提取数据部分

    # 随机取出1/5的数据
    total_count = len(data)
    subset_count = total_count // 5
    subset_indices = random.sample(range(total_count), subset_count)
    subset_data = [data[i] for i in subset_indices]

    # 对这1/5按8:2划分为训练集和验证集
    split_index = int(len(subset_data) * 0.8)
    random.shuffle(subset_data)  # 再次打乱
    train_data = subset_data[:split_index]
    val_data = subset_data[split_index:]

    # 保存训练集
    with open(train_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)  # 写入表头
        writer.writerows(train_data)  # 写入数据

    # 保存验证集
    with open(val_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)  # 写入表头
        writer.writerows(val_data)  # 写入数据

    print(f"数据集拆分完成：")
    print(f"从原始数据集中随机选取 1/5 的数据。")
    print(f"训练集保存至: {train_path}, 样本数: {len(train_data)}")
    print(f"验证集保存至: {val_path}, 样本数: {len(val_data)}")

# 示例用法
input_file = 'data/weibo_4/simplifyweibo_4_moods.csv'  # 原始 CSV 文件路径
train_file = 'data/weibo_4/train.csv'  # 保存训练集的路径
val_file = 'data/weibo_4/valid.csv'  # 保存验证集的路径

# 创建随机种子保证复现性（可选）
random.seed(42)

split_csv(input_file, train_file, val_file)


数据集拆分完成：
从原始数据集中随机选取 1/5 的数据。
训练集保存至: data/weibo_4/train.csv, 样本数: 57878
验证集保存至: data/weibo_4/valid.csv, 样本数: 14470


In [25]:


import os
import torch
import torch.nn as nn
from transformers import BertTokenizer, AdamW, BertConfig
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn import metrics


def main():

    # 参数设置
    model_path = r'model/bert-base-chinese/'
    data_path = r'data/weibo_4/'
    batch_size = 32
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = 10
    learning_rate = 5e-6    #Learning Rate不宜太大
    tokenizer = BertTokenizer.from_pretrained(model_path)

    # 获取到dataset
    train_dataset = SinaNewsDataset(data_path + 'train.csv', tokenizer)
    valid_dataset = SinaNewsDataset(data_path + 'valid.csv', tokenizer)


    # 生成Batch
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # 读取BERT的配置文件
    bert_config = BertConfig.from_pretrained(model_path)
    num_labels = len(train_dataset.labels)

    # 初始化模型
    model = BertClassifier(bert_config, num_labels).to(device)

    # 优化器
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # 损失函数
    criterion = nn.CrossEntropyLoss()

    best_f1 = 0

    for epoch in range(1, epochs+1):
        losses = 0      # 损失
        accuracy = 0    # 准确率

        model.train()
        
        train_bar = tqdm(train_dataloader, ncols=100)
        for input_ids, token_type_ids, attention_mask, label_id in train_bar:
            # 梯度清零
            model.zero_grad()
            train_bar.set_description('Epoch %i train' % epoch)

            # 传入数据，调用model.forward()
            output = model(
                input_ids=input_ids.to(device), 
                attention_mask=attention_mask.to(device), 
                token_type_ids=token_type_ids.to(device), 
            )

            # 计算loss
            loss = criterion(output, label_id.to(device))
            losses += loss.item()

            pred_labels = torch.argmax(output, dim=1)   # 预测出的label
            acc = torch.sum(pred_labels == label_id.to(device)).item() / len(pred_labels) #acc
            accuracy += acc

            loss.backward()
            optimizer.step()
            train_bar.set_postfix(loss=loss.item(), acc=acc)

        average_loss = losses / len(train_dataloader)
        average_acc = accuracy / len(train_dataloader)

        print('\tTrain ACC:', average_acc, '\tLoss:', average_loss)

        # 验证
        model.eval()
        losses = 0      # 损失
        pred_labels = []
        true_labels = []
        valid_bar = tqdm(valid_dataloader, ncols=100)
        for input_ids, token_type_ids, attention_mask, label_id  in valid_bar:
            valid_bar.set_description('Epoch %i valid' % epoch)

            output = model(
                input_ids=input_ids.to(device), 
                attention_mask=attention_mask.to(device), 
                token_type_ids=token_type_ids.to(device), 
            )
            
            loss = criterion(output, label_id.to(device))
            losses += loss.item()

            pred_label = torch.argmax(output, dim=1)   # 预测出的label
            acc = torch.sum(pred_label == label_id.to(device)).item() / len(pred_label) #acc
            valid_bar.set_postfix(loss=loss.item(), acc=acc)

            pred_labels.extend(pred_label.cpu().numpy().tolist())
            true_labels.extend(label_id.numpy().tolist())

        average_loss = losses / len(valid_dataloader)
        print('\tLoss:', average_loss)
        
        # 分类报告
        report = metrics.classification_report(true_labels, pred_labels, labels=valid_dataset.labels_id, target_names=valid_dataset.labels)
        print('* Classification Report:')
        print(report)

        # f1 用来判断最优模型
        f1 = metrics.f1_score(true_labels, pred_labels, labels=valid_dataset.labels_id, average='micro')
        
        if not os.path.exists('models'):
            os.makedirs('models')
        
        # 判断并保存验证集上表现最好的模型
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'models/best_model.pkl')
        
if __name__ == '__main__':
    main()

Loading data from: data/weibo_4/train.csv


57878it [00:15, 3856.22it/s]


Data loaded successfully: 57878 samples.
Loading data from: data/weibo_4/valid.csv


14470it [00:03, 3743.90it/s]


Data loaded successfully: 14470 samples.


Epoch 1 train: 100%|██████████████████████| 1809/1809 [20:13<00:00,  1.49it/s, acc=0.636, loss=1.02]


	Train ACC: 0.5491701844313784 	Loss: 1.1699675033934673


Epoch 1 valid: 100%|████████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.667, loss=1.17]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


	Loss: 1.101266787541623
* Classification Report:
              precision    recall  f1-score   support

           0       0.57      0.98      0.73      8003
           1       0.42      0.03      0.06      2113
           2       0.35      0.10      0.15      2186
           3       0.00      0.00      0.00      2168

    accuracy                           0.56     14470
   macro avg       0.34      0.28      0.23     14470
weighted avg       0.43      0.56      0.43     14470



Epoch 2 train: 100%|██████████████████████| 1809/1809 [20:15<00:00,  1.49it/s, acc=0.409, loss=1.29]


	Train ACC: 0.5597894366551083 	Loss: 1.1023511673087505


Epoch 2 valid: 100%|████████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.667, loss=0.95]


	Loss: 1.0819044146053576
* Classification Report:
              precision    recall  f1-score   support

           0       0.60      0.97      0.74      8003
           1       0.32      0.13      0.19      2113
           2       0.33      0.02      0.03      2186
           3       0.36      0.08      0.13      2168

    accuracy                           0.57     14470
   macro avg       0.40      0.30      0.27     14470
weighted avg       0.48      0.57      0.46     14470



Epoch 3 train: 100%|█████████████████████| 1809/1809 [20:16<00:00,  1.49it/s, acc=0.636, loss=0.952]


	Train ACC: 0.5719037388813509 	Loss: 1.0725173981496812


Epoch 3 valid: 100%|██████████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.5, loss=1.02]


	Loss: 1.1050543733779958
* Classification Report:
              precision    recall  f1-score   support

           0       0.59      0.99      0.74      8003
           1       0.45      0.05      0.09      2113
           2       0.39      0.01      0.01      2186
           3       0.37      0.13      0.19      2168

    accuracy                           0.57     14470
   macro avg       0.45      0.29      0.26     14470
weighted avg       0.51      0.57      0.45     14470



Epoch 4 train: 100%|██████████████████████| 1809/1809 [20:14<00:00,  1.49it/s, acc=0.545, loss=1.17]


	Train ACC: 0.579247072717222 	Loss: 1.053979066836366


Epoch 4 valid: 100%|███████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.667, loss=0.956]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


	Loss: 1.0549312774708728
* Classification Report:
              precision    recall  f1-score   support

           0       0.60      0.97      0.75      8003
           1       0.42      0.13      0.19      2113
           2       0.39      0.16      0.23      2186
           3       0.00      0.00      0.00      2168

    accuracy                           0.58     14470
   macro avg       0.35      0.32      0.29     14470
weighted avg       0.45      0.58      0.48     14470



Epoch 5 train: 100%|█████████████████████| 1809/1809 [20:14<00:00,  1.49it/s, acc=0.727, loss=0.721]


	Train ACC: 0.5862056510377406 	Loss: 1.035863955980378


Epoch 5 valid: 100%|███████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.833, loss=0.882]


	Loss: 1.0578647350633381
* Classification Report:
              precision    recall  f1-score   support

           0       0.63      0.94      0.76      8003
           1       0.36      0.25      0.30      2113
           2       0.36      0.14      0.21      2186
           3       0.36      0.05      0.08      2168

    accuracy                           0.58     14470
   macro avg       0.43      0.35      0.34     14470
weighted avg       0.51      0.58      0.51     14470



Epoch 6 train: 100%|█████████████████████| 1809/1809 [20:15<00:00,  1.49it/s, acc=0.636, loss=0.987]


	Train ACC: 0.5905950047741093 	Loss: 1.0250297227638607


Epoch 6 valid: 100%|█████████████████████████| 453/453 [01:41<00:00,  4.46it/s, acc=0.5, loss=0.949]


	Loss: 1.055707980596993
* Classification Report:
              precision    recall  f1-score   support

           0       0.62      0.96      0.75      8003
           1       0.38      0.21      0.27      2113
           2       0.39      0.16      0.23      2186
           3       0.40      0.02      0.03      2168

    accuracy                           0.59     14470
   macro avg       0.45      0.34      0.32     14470
weighted avg       0.52      0.59      0.49     14470



Epoch 7 train: 100%|█████████████████████| 1809/1809 [20:15<00:00,  1.49it/s, acc=0.545, loss=0.917]


	Train ACC: 0.5959690185436454 	Loss: 1.0115163120003021


Epoch 7 valid: 100%|███████████████████████| 453/453 [01:41<00:00,  4.46it/s, acc=0.667, loss=0.742]


	Loss: 1.0792742163666562
* Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.85      0.75      8003
           1       0.30      0.41      0.35      2113
           2       0.34      0.12      0.18      2186
           3       0.31      0.11      0.16      2168

    accuracy                           0.57     14470
   macro avg       0.41      0.37      0.36     14470
weighted avg       0.52      0.57      0.52     14470



Epoch 8 train: 100%|█████████████████████| 1809/1809 [20:14<00:00,  1.49it/s, acc=0.636, loss=0.969]


	Train ACC: 0.6014435398763758 	Loss: 0.995746215233953


Epoch 8 valid: 100%|█████████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.5, loss=0.945]


	Loss: 1.0613147172969961
* Classification Report:
              precision    recall  f1-score   support

           0       0.67      0.87      0.76      8003
           1       0.34      0.33      0.34      2113
           2       0.35      0.19      0.24      2186
           3       0.27      0.12      0.16      2168

    accuracy                           0.57     14470
   macro avg       0.41      0.37      0.37     14470
weighted avg       0.51      0.57      0.53     14470



Epoch 9 train: 100%|██████████████████████| 1809/1809 [20:14<00:00,  1.49it/s, acc=0.545, loss=1.02]


	Train ACC: 0.606472058897432 	Loss: 0.984602918968865


Epoch 9 valid: 100%|████████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.667, loss=1.27]


	Loss: 1.0746362504053852
* Classification Report:
              precision    recall  f1-score   support

           0       0.64      0.94      0.76      8003
           1       0.49      0.10      0.17      2113
           2       0.34      0.34      0.34      2186
           3       0.00      0.00      0.00      2168

    accuracy                           0.59     14470
   macro avg       0.37      0.34      0.32     14470
weighted avg       0.47      0.59      0.49     14470



Epoch 10 train: 100%|█████████████████████| 1809/1809 [20:15<00:00,  1.49it/s, acc=0.455, loss=1.18]


	Train ACC: 0.6109132368460727 	Loss: 0.9744605652992075


Epoch 10 valid: 100%|███████████████████████| 453/453 [01:41<00:00,  4.47it/s, acc=0.667, loss=1.07]


	Loss: 1.0726215982805551
* Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.85      0.75      8003
           1       0.38      0.24      0.30      2113
           2       0.31      0.42      0.35      2186
           3       0.21      0.01      0.01      2168

    accuracy                           0.57     14470
   macro avg       0.39      0.38      0.35     14470
weighted avg       0.51      0.57      0.52     14470

