In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 定义数据集
class NewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        text = str(self.texts[index])
        label = self.labels[index]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True
        )

        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }
    
# 定义Self-Attention层
class SelfAttention(nn.Module):
    def __init__(self, hidden_size):
        super(SelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )

    def forward(self, encoder_outputs):
        energy = self.projection(encoder_outputs)
        weights = torch.softmax(energy.squeeze(-1), dim=1)
        outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)
        return outputs

# 定义模型
class NewsClassifier(nn.Module):
    # hidden_size = 128
    def __init__(self, num_classes, hidden_size=768, num_layers=2, bidirectional=True):
        super(NewsClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('../../bert-base-multilingual-cased')
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, 
                            bidirectional=bidirectional, batch_first=True)
        
        self.attention = SelfAttention(hidden_size * (2 if bidirectional else 1))
        self.fc = nn.Linear(hidden_size * (2 if bidirectional else 1), num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        lstm_outputs, _ = self.lstm(last_hidden_state)
        attention_outputs = self.attention(lstm_outputs)
        logits = self.fc(attention_outputs)
        return logits

# 定义训练函数
def train_model(model, dataloader, optimizer, scheduler, device, epoch):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch + 1}', leave=False)

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(logits, labels)
        total_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        avg_loss = total_loss / len(dataloader)
        # progress_bar.set_postfix({'Loss': avg_loss:.4f})
        # progress_bar.set_postfix({'Loss': avg_loss:.4f})
        progress_bar.set_postfix({'Loss': "{:.4f}".format(avg_loss) })
        


    return avg_loss

# 定义评估函数
def evaluate_model(model, dataloader, device):
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            batch_predictions = torch.argmax(logits, dim=1)

            predictions.extend(batch_predictions.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    return classification_report(true_labels, predictions, digits=4)



In [3]:

# 加载数据
file_path = '../../datasets_FIX2/FIX2_deduplicated_mangoNews_Nums3000p_CategoryMerge_new_undersampled_Example.csv'
# file_path = '../datasets_FIX2/FIX2_deduplicated_mangoNews_Nums3000p_CategoryMerge_new_undersampled.csv'

data = pd.read_csv(file_path,low_memory=False,lineterminator="\n")

texts = data['body'].tolist()
labels = data['category1'].tolist()

# 对标签进行编码
unique_labels = list(set(labels))
label_to_id = {label: i for i, label in enumerate(unique_labels)}
labels = [label_to_id[label] for label in labels]
num_classes = len(unique_labels)

# 加载BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('../../bert-base-multilingual-cased')



In [4]:
# 设置超参数
max_length = 256
batch_size = 16
epochs = 2
learning_rate = 2e-5
hidden_size = 768
num_layers = 2
bidirectional = True

# 使用KFold进行交叉验证
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

best_models = []
best_reports = []

for fold, (train_index, val_index) in enumerate(kfold.split(texts, labels)):
    print(f'Fold {fold + 1}')
    print('-' * 30)

    train_texts, val_texts = [texts[i] for i in train_index], [texts[i] for i in val_index]
    train_labels, val_labels = [labels[i] for i in train_index], [labels[i] for i in val_index]

    train_dataset = NewsDataset(train_texts, train_labels, tokenizer, max_length)
    val_dataset = NewsDataset(val_texts, val_labels, tokenizer, max_length)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = NewsClassifier(num_classes, hidden_size, num_layers, bidirectional)
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    best_val_loss = float('inf')
    best_model = None

    for epoch in range(epochs):
        train_loss = train_model(model, train_dataloader, optimizer, scheduler, device, epoch)
        val_report = evaluate_model(model, val_dataloader, device)

        val_loss = 1 - float(val_report.split('\n')[-3].split()[-2])  # 提取验证集损失

        print(f'Epoch {epoch + 1}/{epochs}')
        print(f'Train Loss: {train_loss:.4f}')
        print('Validation Report:')
        print(val_report)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # best_model = model.state_dict()
            # best_report = val_report
            torch.save(model.state_dict(), f'best_MultiBert_BiLSTM_SelfAttention_model_fold_{fold + 1}.pth')


    # print(f'Fold {fold + 1} Best Validation Results:')
    # print(best_report)
    # print()

# 在每个fold结束后,评估最佳模型在验证集上的性能
    # best_model = LSTMClassifier(bert_model, hidden_size, num_classes)
    best_model = NewsClassifier(num_classes, hidden_size, num_layers, bidirectional)

    best_model.load_state_dict(torch.load(f'best_MultiBert_BiLSTM_SelfAttention_model_fold_{fold + 1}.pth'))
    best_model.to(device)
    val_report = evaluate_model(best_model, val_dataloader, device)
    # all_reports.append(val_report)

    print(f'Fold {fold + 1} Best Validation Report:')
    print(val_report)
    print()

Fold 1
------------------------------


                                                                     

Epoch 1/2
Train Loss: 1.9951
Validation Report:
              precision    recall  f1-score   support

           0     0.5926    0.5517    0.5714        29
           1     0.8400    0.8400    0.8400        25
           2     0.6667    0.9630    0.7879        27
           3     0.5238    0.6471    0.5789        17
           4     0.9167    0.9565    0.9362        23
           5     1.0000    0.9048    0.9500        21
           6     0.7143    0.6250    0.6667        16
           7     0.7391    0.7083    0.7234        24
           8     1.0000    0.4444    0.6154        18

    accuracy                         0.7500       200
   macro avg     0.7770    0.7379    0.7411       200
weighted avg     0.7717    0.7500    0.7464       200



                                                                     

Epoch 2/2
Train Loss: 1.4363
Validation Report:
              precision    recall  f1-score   support

           0     0.7037    0.6552    0.6786        29
           1     0.8519    0.9200    0.8846        25
           2     0.7419    0.8519    0.7931        27
           3     0.6316    0.7059    0.6667        17
           4     0.9167    0.9565    0.9362        23
           5     0.9091    0.9524    0.9302        21
           6     0.6000    0.7500    0.6667        16
           7     0.8947    0.7083    0.7907        24
           8     1.0000    0.6111    0.7586        18

    accuracy                         0.7950       200
   macro avg     0.8055    0.7901    0.7895       200
weighted avg     0.8086    0.7950    0.7945       200

Fold 1 Best Validation Report:
              precision    recall  f1-score   support

           0     0.7037    0.6552    0.6786        29
           1     0.8519    0.9200    0.8846        25
           2     0.7419    0.8519    0.7931        27

                                                                     

Epoch 1/2
Train Loss: 1.9848
Validation Report:
              precision    recall  f1-score   support

           0     0.7143    0.5000    0.5882        20
           1     0.6667    0.7826    0.7200        23
           2     0.5526    0.9545    0.7000        22
           3     0.6923    0.6429    0.6667        28
           4     0.9615    1.0000    0.9804        25
           5     0.9444    0.7727    0.8500        22
           6     0.9286    0.5652    0.7027        23
           7     0.6875    0.7857    0.7333        14
           8     0.7619    0.6957    0.7273        23

    accuracy                         0.7450       200
   macro avg     0.7678    0.7444    0.7410       200
weighted avg     0.7724    0.7450    0.7438       200



                                                                     

Epoch 2/2
Train Loss: 1.4170
Validation Report:
              precision    recall  f1-score   support

           0     0.7857    0.5500    0.6471        20
           1     0.6538    0.7391    0.6939        23
           2     0.6774    0.9545    0.7925        22
           3     0.6897    0.7143    0.7018        28
           4     1.0000    0.9600    0.9796        25
           5     0.9524    0.9091    0.9302        22
           6     0.9412    0.6957    0.8000        23
           7     0.7500    0.8571    0.8000        14
           8     0.7273    0.6957    0.7111        23

    accuracy                         0.7850       200
   macro avg     0.7975    0.7862    0.7840       200
weighted avg     0.7990    0.7850    0.7845       200

Fold 2 Best Validation Report:
              precision    recall  f1-score   support

           0     0.7857    0.5500    0.6471        20
           1     0.6538    0.7391    0.6939        23
           2     0.6774    0.9545    0.7925        22

                                                                     

Epoch 1/2
Train Loss: 1.9999
Validation Report:
              precision    recall  f1-score   support

           0     0.6667    0.4000    0.5000        20
           1     0.8333    0.9091    0.8696        22
           2     0.9048    0.8636    0.8837        22
           3     0.8095    0.8095    0.8095        21
           4     0.8636    1.0000    0.9268        19
           5     0.8947    0.8500    0.8718        20
           6     0.8966    0.8125    0.8525        32
           7     0.4706    0.8000    0.5926        20
           8     0.9444    0.7083    0.8095        24

    accuracy                         0.7950       200
   macro avg     0.8094    0.7948    0.7907       200
weighted avg     0.8182    0.7950    0.7959       200



                                                                     

Epoch 2/2
Train Loss: 1.4729
Validation Report:
              precision    recall  f1-score   support

           0     0.6500    0.6500    0.6500        20
           1     0.8077    0.9545    0.8750        22
           2     0.7097    1.0000    0.8302        22
           3     0.8571    0.8571    0.8571        21
           4     0.9500    1.0000    0.9744        19
           5     0.9474    0.9000    0.9231        20
           6     0.8889    0.7500    0.8136        32
           7     0.8750    0.7000    0.7778        20
           8     0.9000    0.7500    0.8182        24

    accuracy                         0.8350       200
   macro avg     0.8429    0.8402    0.8355       200
weighted avg     0.8446    0.8350    0.8336       200

Fold 3 Best Validation Report:
              precision    recall  f1-score   support

           0     0.6500    0.6500    0.6500        20
           1     0.8077    0.9545    0.8750        22
           2     0.7097    1.0000    0.8302        22

                                                                     

Epoch 1/2
Train Loss: 1.9442
Validation Report:
              precision    recall  f1-score   support

           0     0.7391    0.6800    0.7083        25
           1     0.7647    0.9286    0.8387        28
           2     0.7353    0.9259    0.8197        27
           3     0.5946    1.0000    0.7458        22
           4     1.0000    0.9062    0.9508        32
           5     1.0000    1.0000    1.0000         7
           6     0.8000    0.8000    0.8000        15
           7     1.0000    0.4762    0.6452        21
           8     0.8182    0.3913    0.5294        23

    accuracy                         0.7850       200
   macro avg     0.8280    0.7898    0.7820       200
weighted avg     0.8182    0.7850    0.7744       200



                                                                     

Epoch 2/2
Train Loss: 1.3900
Validation Report:
              precision    recall  f1-score   support

           0     0.7917    0.7600    0.7755        25
           1     0.8966    0.9286    0.9123        28
           2     0.8125    0.9630    0.8814        27
           3     0.6667    1.0000    0.8000        22
           4     0.9688    0.9688    0.9688        32
           5     0.8750    1.0000    0.9333         7
           6     0.7647    0.8667    0.8125        15
           7     0.9231    0.5714    0.7059        21
           8     0.8333    0.4348    0.5714        23

    accuracy                         0.8300       200
   macro avg     0.8369    0.8326    0.8179       200
weighted avg     0.8432    0.8300    0.8201       200

Fold 4 Best Validation Report:
              precision    recall  f1-score   support

           0     0.7917    0.7600    0.7755        25
           1     0.8966    0.9286    0.9123        28
           2     0.8125    0.9630    0.8814        27

                                                                     

Epoch 1/2
Train Loss: 1.9606
Validation Report:
              precision    recall  f1-score   support

           0     0.6111    0.4583    0.5238        24
           1     0.6429    0.9000    0.7500        20
           2     0.7391    0.8095    0.7727        21
           3     0.6970    0.8846    0.7797        26
           4     0.8750    0.9545    0.9130        22
           5     0.9167    0.8800    0.8980        25
           6     0.8571    0.8182    0.8372        22
           7     1.0000    0.6957    0.8205        23
           8     0.7692    0.5882    0.6667        17

    accuracy                         0.7800       200
   macro avg     0.7898    0.7766    0.7735       200
weighted avg     0.7913    0.7800    0.7761       200



                                                                     

Epoch 2/2
Train Loss: 1.4003
Validation Report:
              precision    recall  f1-score   support

           0     0.6364    0.5833    0.6087        24
           1     0.7826    0.9000    0.8372        20
           2     0.7500    0.8571    0.8000        21
           3     0.8400    0.8077    0.8235        26
           4     1.0000    0.9545    0.9767        22
           5     0.8800    0.8800    0.8800        25
           6     0.8636    0.8636    0.8636        22
           7     0.7826    0.7826    0.7826        23
           8     0.8000    0.7059    0.7500        17

    accuracy                         0.8150       200
   macro avg     0.8150    0.8150    0.8136       200
weighted avg     0.8156    0.8150    0.8140       200

Fold 5 Best Validation Report:
              precision    recall  f1-score   support

           0     0.6364    0.5833    0.6087        24
           1     0.7826    0.9000    0.8372        20
           2     0.7500    0.8571    0.8000        21