In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F

In [2]:
# 1. 读取数据
data = pd.read_csv('./movie-review-sentiment-analysis-kernels-only/train.tsv', delimiter='\t')
data = data[["Phrase", "Sentiment"]]

# 2. 数据切分
train_texts, test_texts, train_labels, test_labels = train_test_split(
    data["Phrase"].tolist(), 
    data["Sentiment"].tolist(), 
    test_size=0.2, 
    random_state=42
)

In [3]:
# 3. 定义自定义数据集
class PhraseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [5]:
# 4. 加载分词器和数据集
tokenizer = BertTokenizer.from_pretrained("./models/bert-base-uncased")
train_dataset = PhraseDataset(train_texts, train_labels, tokenizer)
test_dataset = PhraseDataset(test_texts, test_labels, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 5. 初始化模型
model = BertForSequenceClassification.from_pretrained("./models/bert-base-uncased", num_labels=5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [6]:
# 6. 定义优化器和损失函数
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = torch.nn.CrossEntropyLoss()

In [7]:
# 7. 训练和评估函数
def train_one_epoch(model, data_loader, optimizer, criterion, device):
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    progress_bar = tqdm(data_loader, desc="Training")
    
    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = torch.argmax(F.softmax(logits, dim=1), dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        accuracy = correct / total

        progress_bar.set_postfix({"Loss": running_loss / total, "Accuracy": accuracy})

def evaluate(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(data_loader, desc="Evaluating")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            running_loss += loss.item()
            preds = torch.argmax(F.softmax(logits, dim=1), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            accuracy = correct / total

            progress_bar.set_postfix({"Loss": running_loss / total, "Accuracy": accuracy})

    return running_loss / total, accuracy

In [10]:
# 8. 训练模型
epochs = 10
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_accuracy = evaluate(model, test_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

Epoch 1/10


Training: 100%|██████████| 7803/7803 [07:15<00:00, 17.93it/s, Loss=0.0316, Accuracy=0.79] 
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.46it/s, Loss=0.0128, Accuracy=0.687]


Validation Loss: 0.0128, Validation Accuracy: 0.6866
Epoch 2/10


Training: 100%|██████████| 7803/7803 [07:14<00:00, 17.94it/s, Loss=0.0268, Accuracy=0.821]
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.15it/s, Loss=0.0141, Accuracy=0.676]


Validation Loss: 0.0141, Validation Accuracy: 0.6759
Epoch 3/10


Training: 100%|██████████| 7803/7803 [07:13<00:00, 17.99it/s, Loss=0.0222, Accuracy=0.853]
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.49it/s, Loss=0.0157, Accuracy=0.674]


Validation Loss: 0.0157, Validation Accuracy: 0.6740
Epoch 4/10


Training: 100%|██████████| 7803/7803 [07:16<00:00, 17.89it/s, Loss=0.0182, Accuracy=0.88] 
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.04it/s, Loss=0.0178, Accuracy=0.653]


Validation Loss: 0.0178, Validation Accuracy: 0.6534
Epoch 5/10


Training: 100%|██████████| 7803/7803 [07:15<00:00, 17.93it/s, Loss=0.0153, Accuracy=0.901]
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.38it/s, Loss=0.0193, Accuracy=0.661]


Validation Loss: 0.0193, Validation Accuracy: 0.6608
Epoch 6/10


Training: 100%|██████████| 7803/7803 [07:16<00:00, 17.87it/s, Loss=0.0129, Accuracy=0.918]
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.11it/s, Loss=0.0212, Accuracy=0.657]


Validation Loss: 0.0212, Validation Accuracy: 0.6566
Epoch 7/10


Training: 100%|██████████| 7803/7803 [07:17<00:00, 17.82it/s, Loss=0.0111, Accuracy=0.931] 
Evaluating: 100%|██████████| 488/488 [00:26<00:00, 18.63it/s, Loss=0.0219, Accuracy=0.654]


Validation Loss: 0.0219, Validation Accuracy: 0.6541
Epoch 8/10


Training: 100%|██████████| 7803/7803 [07:17<00:00, 17.85it/s, Loss=0.00962, Accuracy=0.941]
Evaluating: 100%|██████████| 488/488 [00:26<00:00, 18.77it/s, Loss=0.0239, Accuracy=0.662]


Validation Loss: 0.0239, Validation Accuracy: 0.6622
Epoch 9/10


Training: 100%|██████████| 7803/7803 [07:14<00:00, 17.94it/s, Loss=0.00846, Accuracy=0.948]
Evaluating: 100%|██████████| 488/488 [00:24<00:00, 19.54it/s, Loss=0.026, Accuracy=0.659] 


Validation Loss: 0.0260, Validation Accuracy: 0.6594
Epoch 10/10


Training: 100%|██████████| 7803/7803 [07:15<00:00, 17.90it/s, Loss=0.00751, Accuracy=0.955]
Evaluating: 100%|██████████| 488/488 [00:25<00:00, 19.29it/s, Loss=0.0265, Accuracy=0.658]

Validation Loss: 0.0265, Validation Accuracy: 0.6576





In [9]:
# 9. 保存模型
torch.save(model.state_dict(), "bert_sentiment_model.pth")