In [1]:
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
import torch.nn as nn
import numpy as np
from tqdm import tqdm

# 加载数据集
my_dataset = load_dataset("sst2")

# 加载BERT模型和分词器
model_path = 'model'  # 使用预训练的BERT模型
tokenizer = BertTokenizer.from_pretrained(model_path)
bert_model = BertModel.from_pretrained(model_path)

# 将模型和数据移动到GPU上
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model.to(device)

# 定义情感分类模型
class SentimentClassification(nn.Module):
    def __init__(self, bert_model, hidden_dim, output_dim):
        super().__init__()
        self.bert = bert_model
        self.fc1 = nn.Linear(bert_model.config.hidden_size, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        # 获取BERT的输出
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # 获取最后一个隐藏层的表示
        pooled_output = outputs.last_hidden_state[:, 0, :]
        
        # 使用全连接层进行分类
        x = self.fc1(pooled_output)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Tokenize examples
tokenized_dataset = my_dataset.map(lambda example: tokenizer(example["sentence"], padding="max_length", truncation=True), batched=True)

# 切分数据集
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]
small_train_dataset = tokenized_dataset["train"].shuffle(seed=42).select(range(80))
small_eval_dataset = tokenized_dataset["validation"].shuffle(seed=42).select(range(32))

# 定义数据加载器
batch_size = 2

def collate_fn(batch):
    # 提取输入列表和标签列表
    input_ids_list = [example['input_ids'] for example in batch]
    attention_mask_list = [example['attention_mask'] for example in batch]
    labels_list = [example['label'] for example in batch]
    
    # 转换为张量并删除大小为1的维度
    input_ids = torch.tensor(input_ids_list).squeeze().to(device)
    attention_mask = torch.tensor(attention_mask_list).squeeze().to(device)
    labels = torch.tensor(labels_list).to(device)
    
    return input_ids, attention_mask, labels

# 在数据加载器中使用自定义的collate_fn
train_dataloader = DataLoader(
    small_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
eval_dataloader = DataLoader(
    small_eval_dataset, batch_size=batch_size, collate_fn=collate_fn)

# 初始化情感分类模型
output_dim = 2  # 二分类任务
hidden_dim = 128  # 隐藏层维度
model = SentimentClassification(bert_model, hidden_dim, output_dim).to(device)

# 定义优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 训练参数设置
num_epochs = 50
eval_steps = len(train_dataloader) // 10

# 模型训练
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# 设置训练参数
num_training_steps = len(train_dataloader) * num_epochs
progress_bar = tqdm(range(num_training_steps))
def compute_metrics(model, eval_dataloader):
    model.eval()
    eval_loss = 0
    eval_accuracy = 0
    total_eval_samples = 0
    
    for batch in eval_dataloader:
        with torch.no_grad():
            inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
            outputs = model(**inputs)
            logits = outputs
            loss = criterion(logits, batch[2])
        
        eval_loss += loss.item()
        logits_np = logits.detach().cpu().numpy()
        predictions = np.argmax(logits_np, axis=-1)
        labels_np = batch[2].cpu().numpy()
        eval_accuracy += np.sum(predictions == labels_np)
        total_eval_samples += len(labels_np)
    
    eval_loss /= len(eval_dataloader)
    eval_accuracy /= total_eval_samples
    return eval_loss, eval_accuracy

# 模型训练
model.train()
for epoch in range(num_epochs):
    print('epoch:', epoch)
    for batch in train_dataloader:
        inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        outputs = model(**inputs)
        loss = criterion(outputs, batch[2])
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
    
    # 在评估数据集上计算指标
    eval_loss, eval_accuracy = compute_metrics(model, eval_dataloader)
    print("Eval Loss:", eval_loss)
    print("Eval Accuracy:", eval_accuracy)
    
    # 保存当前模型
    torch.save(model.state_dict(), './checkpoint.pt')

# 加载最佳模型
best_model = SentimentClassification(bert_model, hidden_dim, output_dim).to(device)
best_model.load_state_dict(torch.load('./checkpoint.pt'))


  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/2000 [00:00<?, ?it/s]

epoch: 0


  2%|▏         | 40/2000 [00:02<01:48, 18.00it/s]

Eval Loss: 0.6853411234915257
Eval Accuracy: 0.53125


  2%|▏         | 42/2000 [00:04<07:48,  4.17it/s]

epoch: 1


  4%|▍         | 80/2000 [00:08<03:02, 10.54it/s]

Eval Loss: 0.6623741574585438
Eval Accuracy: 0.5625


  4%|▍         | 82/2000 [00:09<10:47,  2.96it/s]

epoch: 2


  6%|▌         | 120/2000 [00:13<02:53, 10.84it/s]

Eval Loss: 1.0610534744337201
Eval Accuracy: 0.53125


  6%|▌         | 122/2000 [00:15<09:55,  3.15it/s]

epoch: 3


  8%|▊         | 160/2000 [00:19<03:25,  8.97it/s]

Eval Loss: 0.5703373402357101
Eval Accuracy: 0.75


  8%|▊         | 162/2000 [00:21<10:29,  2.92it/s]

epoch: 4


 10%|█         | 200/2000 [00:25<03:20,  9.00it/s]

Eval Loss: 0.7488344277953729
Eval Accuracy: 0.8125


 10%|█         | 201/2000 [00:26<12:48,  2.34it/s]

epoch: 5


 12%|█▏        | 239/2000 [00:29<01:34, 18.56it/s]

Eval Loss: 0.8141119051724672
Eval Accuracy: 0.8125


 12%|█▏        | 243/2000 [00:30<04:37,  6.32it/s]

epoch: 6


 14%|█▍        | 279/2000 [00:32<01:31, 18.85it/s]

Eval Loss: 0.8288086804095656
Eval Accuracy: 0.8125


 14%|█▍        | 283/2000 [00:33<04:28,  6.40it/s]

epoch: 7


 16%|█▌        | 319/2000 [00:35<02:42, 10.35it/s]

Eval Loss: 0.8773941621766426
Eval Accuracy: 0.8125


 16%|█▌        | 322/2000 [00:37<07:53,  3.55it/s]

epoch: 8


 18%|█▊        | 360/2000 [00:41<02:30, 10.87it/s]

Eval Loss: 0.8891776298405603
Eval Accuracy: 0.8125


 18%|█▊        | 362/2000 [00:42<08:09,  3.35it/s]

epoch: 9


 20%|██        | 400/2000 [00:46<02:52,  9.28it/s]

Eval Loss: 0.917597218620358
Eval Accuracy: 0.78125


 20%|██        | 401/2000 [00:48<09:52,  2.70it/s]

epoch: 10


 22%|██▏       | 439/2000 [00:51<02:11, 11.91it/s]

Eval Loss: 0.9372809387277812
Eval Accuracy: 0.78125


 22%|██▏       | 443/2000 [00:53<06:17,  4.12it/s]

epoch: 11


 24%|██▍       | 480/2000 [00:57<02:28, 10.21it/s]

Eval Loss: 0.9444773899449501
Eval Accuracy: 0.78125


 24%|██▍       | 482/2000 [00:58<07:22,  3.43it/s]

epoch: 12


 26%|██▌       | 520/2000 [01:02<02:33,  9.65it/s]

Eval Loss: 0.9731908648100216
Eval Accuracy: 0.78125


 26%|██▌       | 522/2000 [01:04<07:21,  3.35it/s]

epoch: 13


 28%|██▊       | 560/2000 [01:08<02:41,  8.93it/s]

Eval Loss: 0.9782810426549986
Eval Accuracy: 0.78125


 28%|██▊       | 562/2000 [01:09<08:15,  2.90it/s]

epoch: 14


 30%|███       | 600/2000 [01:13<02:23,  9.74it/s]

Eval Loss: 0.9954189840937033
Eval Accuracy: 0.78125


 30%|███       | 602/2000 [01:15<07:33,  3.08it/s]

epoch: 15


 32%|███▏      | 640/2000 [01:18<02:06, 10.79it/s]

Eval Loss: 1.015680251584854
Eval Accuracy: 0.78125


 32%|███▏      | 642/2000 [01:20<07:08,  3.17it/s]

epoch: 16


 34%|███▍      | 680/2000 [01:24<02:32,  8.65it/s]

Eval Loss: 1.0209865599754266
Eval Accuracy: 0.78125


 34%|███▍      | 681/2000 [01:25<09:26,  2.33it/s]

epoch: 17


 36%|███▌      | 720/2000 [01:30<02:16,  9.39it/s]

Eval Loss: 1.0426147987891454
Eval Accuracy: 0.78125


 36%|███▌      | 722/2000 [01:31<07:26,  2.86it/s]

epoch: 18


 38%|███▊      | 760/2000 [01:35<02:15,  9.17it/s]

Eval Loss: 1.046975118937553
Eval Accuracy: 0.78125


 38%|███▊      | 763/2000 [01:37<05:39,  3.64it/s]

epoch: 19


 40%|████      | 800/2000 [01:39<01:04, 18.52it/s]

Eval Loss: 1.0598961972282268
Eval Accuracy: 0.78125


 40%|████      | 802/2000 [01:40<04:10,  4.78it/s]

epoch: 20


 42%|████▏     | 840/2000 [01:42<01:02, 18.53it/s]

Eval Loss: 1.0733922637009528
Eval Accuracy: 0.78125


 42%|████▏     | 842/2000 [01:43<04:25,  4.36it/s]

epoch: 21


 44%|████▍     | 879/2000 [01:46<02:04,  9.04it/s]

Eval Loss: 1.0761425620439695
Eval Accuracy: 0.78125


 44%|████▍     | 883/2000 [01:48<04:52,  3.81it/s]

epoch: 22


 46%|████▌     | 920/2000 [01:51<01:09, 15.44it/s]

Eval Loss: 1.0923535237452597
Eval Accuracy: 0.75


 46%|████▌     | 922/2000 [01:52<04:28,  4.01it/s]

epoch: 23


 48%|████▊     | 960/2000 [01:56<01:44,  9.94it/s]

Eval Loss: 1.0946504254316096
Eval Accuracy: 0.75


 48%|████▊     | 962/2000 [01:58<05:42,  3.03it/s]

epoch: 24


 50%|████▉     | 999/2000 [02:02<01:27, 11.39it/s]

Eval Loss: 1.1048736733428086
Eval Accuracy: 0.75


 50%|█████     | 1002/2000 [02:03<04:44,  3.51it/s]

epoch: 25


 52%|█████▏    | 1040/2000 [02:07<01:51,  8.62it/s]

Eval Loss: 1.1153699382775812
Eval Accuracy: 0.75


 52%|█████▏    | 1042/2000 [02:09<05:39,  2.82it/s]

epoch: 26


 54%|█████▍    | 1079/2000 [02:13<01:41,  9.10it/s]

Eval Loss: 1.1185396870205295
Eval Accuracy: 0.75


 54%|█████▍    | 1082/2000 [02:14<05:03,  3.03it/s]

epoch: 27


 56%|█████▌    | 1120/2000 [02:18<01:38,  8.93it/s]

Eval Loss: 1.1321905516379047
Eval Accuracy: 0.75


 56%|█████▌    | 1121/2000 [02:20<06:07,  2.39it/s]

epoch: 28


 58%|█████▊    | 1159/2000 [02:24<01:38,  8.50it/s]

Eval Loss: 1.1360621668427484
Eval Accuracy: 0.75


 58%|█████▊    | 1162/2000 [02:26<04:54,  2.85it/s]

epoch: 29


 60%|██████    | 1200/2000 [02:29<01:15, 10.58it/s]

Eval Loss: 1.144849481563142
Eval Accuracy: 0.75


 60%|██████    | 1202/2000 [02:31<04:32,  2.93it/s]

epoch: 30


 62%|██████▏   | 1240/2000 [02:35<01:03, 12.03it/s]

Eval Loss: 1.152430091293354
Eval Accuracy: 0.75


 62%|██████▏   | 1242/2000 [02:36<04:05,  3.09it/s]

epoch: 31


 64%|██████▍   | 1279/2000 [02:40<01:25,  8.39it/s]

Eval Loss: 1.155596310592955
Eval Accuracy: 0.75


 64%|██████▍   | 1282/2000 [02:42<03:44,  3.20it/s]

epoch: 32


 66%|██████▌   | 1320/2000 [02:46<01:03, 10.65it/s]

Eval Loss: 1.169466278362961
Eval Accuracy: 0.75


 66%|██████▌   | 1322/2000 [02:48<03:48,  2.96it/s]

epoch: 33


 68%|██████▊   | 1360/2000 [02:50<00:34, 18.50it/s]

Eval Loss: 1.1727727545548987
Eval Accuracy: 0.75


 68%|██████▊   | 1362/2000 [02:51<02:27,  4.31it/s]

epoch: 34


 70%|███████   | 1400/2000 [02:53<00:32, 18.51it/s]

Eval Loss: 1.1808628744074667
Eval Accuracy: 0.75


 70%|███████   | 1402/2000 [02:54<02:05,  4.77it/s]

epoch: 35


 72%|███████▏  | 1439/2000 [02:58<00:54, 10.27it/s]

Eval Loss: 1.1916990761710622
Eval Accuracy: 0.75


 72%|███████▏  | 1441/2000 [02:59<03:08,  2.97it/s]

epoch: 36


 74%|███████▍  | 1479/2000 [03:03<00:53,  9.70it/s]

Eval Loss: 1.1944852510860073
Eval Accuracy: 0.75


 74%|███████▍  | 1483/2000 [03:05<01:59,  4.33it/s]

epoch: 37


 76%|███████▌  | 1520/2000 [03:09<00:54,  8.74it/s]

Eval Loss: 1.2069988156399631
Eval Accuracy: 0.75


 76%|███████▌  | 1521/2000 [03:10<03:45,  2.12it/s]

epoch: 38


 78%|███████▊  | 1559/2000 [03:14<00:42, 10.41it/s]

Eval Loss: 1.2096653893750045
Eval Accuracy: 0.75


 78%|███████▊  | 1562/2000 [03:15<01:56,  3.76it/s]

epoch: 39


 80%|███████▉  | 1599/2000 [03:19<00:38, 10.32it/s]

Eval Loss: 1.2175417158941855
Eval Accuracy: 0.75


 80%|████████  | 1602/2000 [03:21<01:55,  3.45it/s]

epoch: 40


 82%|████████▏ | 1640/2000 [03:25<00:42,  8.48it/s]

Eval Loss: 1.2262658015570196
Eval Accuracy: 0.75


 82%|████████▏ | 1643/2000 [03:26<01:36,  3.71it/s]

epoch: 41


 84%|████████▍ | 1680/2000 [03:30<00:34,  9.40it/s]

Eval Loss: 1.2284113312271074
Eval Accuracy: 0.75


 84%|████████▍ | 1681/2000 [03:31<01:54,  2.78it/s]

epoch: 42


 86%|████████▌ | 1720/2000 [03:36<00:32,  8.73it/s]

Eval Loss: 1.239759272430092
Eval Accuracy: 0.75


 86%|████████▌ | 1722/2000 [03:37<01:45,  2.63it/s]

epoch: 43


 88%|████████▊ | 1760/2000 [03:41<00:18, 13.15it/s]

Eval Loss: 1.2417562702794385
Eval Accuracy: 0.75


 88%|████████▊ | 1762/2000 [03:42<01:12,  3.30it/s]

epoch: 44


 90%|█████████ | 1800/2000 [03:46<00:18, 10.91it/s]

Eval Loss: 1.2509224095047102
Eval Accuracy: 0.75


 90%|█████████ | 1802/2000 [03:48<01:08,  2.87it/s]

epoch: 45


 92%|█████████▏| 1840/2000 [03:52<00:18,  8.78it/s]

Eval Loss: 1.259675869703642
Eval Accuracy: 0.75


 92%|█████████▏| 1842/2000 [03:53<00:59,  2.64it/s]

epoch: 46


 94%|█████████▍| 1879/2000 [03:57<00:11, 10.82it/s]

Eval Loss: 1.2616180361583247
Eval Accuracy: 0.75


 94%|█████████▍| 1883/2000 [03:59<00:26,  4.39it/s]

epoch: 47


 96%|█████████▌| 1919/2000 [04:01<00:04, 18.55it/s]

Eval Loss: 1.2714806131043588
Eval Accuracy: 0.75


 96%|█████████▌| 1923/2000 [04:02<00:12,  6.05it/s]

epoch: 48


 98%|█████████▊| 1959/2000 [04:04<00:02, 18.54it/s]

Eval Loss: 1.2746129347106034
Eval Accuracy: 0.75


 98%|█████████▊| 1963/2000 [04:05<00:06,  5.77it/s]

epoch: 49


100%|█████████▉| 1999/2000 [04:08<00:00,  9.24it/s]

Eval Loss: 1.281988362876291
Eval Accuracy: 0.75


<All keys matched successfully>