## Dataset loading

In [None]:
from datasets import load_dataset

# Map every split name to the CSV file that holds it.
data_files = {
    "train": "imdb_train.csv",
    "validation": "imdb_validation.csv",
    "test": "imdb_test.csv",
}

# Build a DatasetDict with one Dataset per split using the generic "csv" builder,
# then print a summary of each split's columns and row count.
dataset = load_dataset("csv", data_files=data_files)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 35000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 5000
    })
})


In [38]:
import pandas as pd

# 转成 DataFrame 显示
df = pd.DataFrame(dataset["train"][:10])
print(df)

                                                text  label
0  It's a Time Machine all right. It runs in "rea...      0
1  Recognizing the picture of the diner on the co...      0
2  Hmm, Hip Hop music to a period western. Modern...      0
3  This is a formula B science fiction movie, and...      0
4  Strange yet emotionally disturbing chiller abo...      1
5  As a fan of Eric Rohmer's studies of the conte...      0
6  I am one of Jehovah's Witnesses and I also wor...      1
7  Watching Josh Kornbluth 'act' in this movie re...      0
8  I had fun watching this movie, mainly due to S...      1
9  The penultimate episode of Star Trek's third s...      1


## Tokenization and model definition

In [39]:
from transformers import AutoTokenizer

# Load the tokenizer that matches the "bert-base-uncased" checkpoint used for the model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Maximum number of tokens per review (128 / 256 / 512 are common choices).
MAX_LENGTH = 256


def tokenize_function(example, max_length=MAX_LENGTH):
    """Tokenize a batch of reviews into fixed-length BERT inputs.

    Every sequence is truncated and padded to exactly ``max_length`` tokens, so the
    downstream ``default_data_collator`` can stack batches without dynamic padding.

    Args:
        example: a batch from ``Dataset.map(..., batched=True)``; its ``"text"`` entry
            is the list of review strings.
        max_length: target sequence length (defaults to ``MAX_LENGTH``).

    Returns:
        The tokenizer's output for the batch (e.g. ``input_ids`` and the other encoder
        inputs); ``map`` merges these as new columns.
    """
    return tokenizer(
        example["text"],
        padding="max_length",  # pad every sequence to the same length
        truncation=True,       # cut reviews longer than max_length
        max_length=max_length,
    )


# Tokenize every split. DatasetDict.map applies the function to each split and
# returns a DatasetDict, so tokenized_datasets["train"] etc. keep working. Unlike the
# previous dict comprehension, this does not rebind the name `dataset` inside the loop.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/35000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [40]:
# Inspect one encoded training example: the original text/label plus the tokenizer outputs.
sample_encoding = tokenized_datasets["train"][0]
print(sample_encoding)

{'text': 'It\'s a Time Machine all right. It runs in "real time" for 96 minutes but it felt like 96 years. The first 20 minutes were utterly superfluous. Massive amounts of "dead" time throughout. What happened? When will something happen? Who cares? Apparently the film was made on a tight budget, I note for your edification the following: The Morlochs: nothing like saving a little money by reusing the sets and costumes from Lord of the Rings part I, hey? The "scary dude" in charge of controlling the Morlochs... The scariest thing these guys could think of was somebody wearing one of Gene Simmons: (of the band Kiss) old costumes??? Little-known fact: freaks of the future have perfectly manicured nails.<br /><br />Save your money, save your time. Pass on this one.', 'label': 0, 'input_ids': [101, 2009, 1005, 1055, 1037, 2051, 3698, 2035, 2157, 1012, 2009, 3216, 1999, 1000, 2613, 2051, 1000, 2005, 5986, 2781, 2021, 2009, 2371, 2066, 5986, 2086, 1012, 1996, 2034, 2322, 2781, 2020, 12580, 

In [47]:
from transformers import BertForSequenceClassification

from transformers import set_seed

# Seed Python, NumPy and torch RNGs before the randomly initialised classification head
# is created, so a Restart & Run All reproduces the same initialisation and the same
# downstream torch-RNG consumers (dropout, shuffled batches).
SEED = 42
set_seed(SEED)

# Pretrained BERT encoder with a freshly initialised 2-way classification head.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [48]:
# Rename the "label" column to "labels" in every split; the training/evaluation loops
# read batch["labels"]. The column check makes the cell idempotent: re-running it
# after the rename no longer raises because "label" is already gone.
for split, split_dataset in list(tokenized_datasets.items()):
    if "label" in split_dataset.column_names:
        tokenized_datasets[split] = split_dataset.rename_column("label", "labels")

In [51]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

# Batch size shared by all three splits.
batch_size = 16

# Every split is already padded to a fixed length, so default_data_collator can stack
# examples straight into tensors. Only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=default_data_collator)

train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(tokenized_datasets["validation"], **loader_kwargs)
test_dataloader = DataLoader(tokenized_datasets["test"], **loader_kwargs)

In [43]:
from transformers import get_scheduler
from torch.optim import AdamW
import torch

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model first, then build the optimizer, so the optimizer is guaranteed to
# reference the device-resident parameters. NOTE: re-run this cell whenever `model`
# is re-created; otherwise the optimizer keeps updating the old model's parameters.
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)

# Total optimisation steps = epochs x batches per epoch; the scheduler spans exactly this.
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)

# Linear LR schedule. With zero warmup steps the LR starts at 2e-5 and decays linearly
# to 0 over num_training_steps. Set num_warmup_steps > 0 to add a linear warmup phase.
num_warmup_steps = 0
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [44]:
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# Mean-reduced cross-entropy criterion.
# NOTE(review): the training loop uses `outputs.loss` computed by the model itself,
# so this object is not referenced by `train`/`evaluate` as written.
loss_fn = CrossEntropyLoss(reduction="mean")

In [45]:
import torch
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score


def train(model, train_dataloader, val_dataloader, optimizer, lr_scheduler, device, num_epochs=3):
    """Fine-tune ``model`` for ``num_epochs`` epochs, reporting train and validation metrics.

    Each batch (a dict of tensors) is moved to ``device`` and passed to the model as
    keyword arguments. The loss is taken from ``outputs.loss``, and batches are expected
    to contain a ``"labels"`` entry. The optimizer and LR scheduler are stepped once per
    batch. After every epoch the validation split is scored with ``evaluate``.

    Args:
        model: classifier whose outputs expose ``.loss`` and ``.logits``.
        train_dataloader: iterable of dict batches used for training.
        val_dataloader: iterable of dict batches scored after each epoch.
        optimizer: optimizer over ``model``'s trainable parameters.
        lr_scheduler: scheduler stepped after every optimizer step.
        device: device each batch is moved to.
        num_epochs: number of passes over ``train_dataloader``.

    Raises:
        ValueError: if ``optimizer`` does not manage ``model``'s trainable parameters.
            For example, the model was re-created in a later cell without rebuilding
            the optimizer.
    """
    _check_optimizer_tracks_model(model, optimizer)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        total_loss = 0.0
        num_batches = 0
        all_preds = []
        all_labels = []

        # evaluate() switches the model to eval mode, so re-enable training mode each epoch.
        model.train()
        loop = tqdm(train_dataloader, desc="Training")

        for batch in loop:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            logits = outputs.logits

            # Backward pass and parameter / LR update.
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Running metrics.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            loop.set_postfix(loss=f"{batch_loss:.4f}")

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.detach().cpu().numpy())
            all_labels.extend(batch["labels"].detach().cpu().numpy())

        # Report the mean per-batch loss rather than the raw sum, so the number is
        # comparable across epochs and dataset sizes.
        avg_loss = total_loss / max(num_batches, 1)
        train_acc = accuracy_score(all_labels, all_preds)
        train_f1 = f1_score(all_labels, all_preds)

        print(f"Train Loss: {avg_loss:.4f} | Accuracy: {train_acc:.4f} | F1: {train_f1:.4f}")

        # Validation loop.
        evaluate(model, val_dataloader, device)


def _check_optimizer_tracks_model(model, optimizer):
    """Raise ValueError if any trainable parameter of ``model`` is missing from ``optimizer``.

    Guards against the notebook hidden-state bug where the model cell is re-run after the
    optimizer cell. In that case the optimizer silently keeps updating a stale model object.
    """
    optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
    missing = [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and id(param) not in optimizer_param_ids
    ]
    if missing:
        raise ValueError(
            f"optimizer does not track {len(missing)} trainable model parameter(s) "
            f"(e.g. {missing[0]!r}); rebuild the optimizer and scheduler after (re)creating the model."
        )


@torch.no_grad()
def evaluate(model, dataloader, device, split_name="Validation"):
    """Score ``model`` on ``dataloader``, then print and return accuracy and F1.

    Runs without gradient tracking and puts the model in eval mode. Predictions are the
    argmax over ``outputs.logits``. Batches are expected to contain a ``"labels"`` entry.

    Args:
        model: classifier whose outputs expose ``.logits``.
        dataloader: iterable of dict batches.
        device: device each batch is moved to.
        split_name: label used in the printed summary (default ``"Validation"``,
            matching the previous output). Pass e.g. ``"Test"`` for the test split.

    Returns:
        tuple: ``(accuracy, f1)`` as computed by scikit-learn.
    """
    model.eval()
    all_preds = []
    all_labels = []

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch["labels"].cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    print(f"{split_name} Accuracy: {acc:.4f} | F1 Score: {f1:.4f}")
    return acc, f1

In [52]:
# Use the same num_epochs the LR scheduler was sized with, so the schedule and loop agree.
train(model, train_dataloader, val_dataloader, optimizer, lr_scheduler, device, num_epochs=num_epochs)


Epoch 1/3


Training:  76%|███████▌  | 1659/2188 [1:38:24<31:22,  3.56s/it]


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

@torch.no_grad()
def final_evaluate(model, dataloader, device):
    """Report test-set metrics and plot a confusion matrix.

    Prints accuracy, binary F1 and a per-class report, then draws a 2x2 confusion-matrix
    heatmap. Label 0 is shown as "neg" and label 1 as "pos". Batches are expected to
    contain a ``"labels"`` entry.

    Args:
        model: classifier whose outputs expose ``.logits``.
        dataloader: iterable of dict batches (normally the test split).
        device: device each batch is moved to.

    Returns:
        tuple: ``(accuracy, f1)``.
    """
    class_ids = [0, 1]
    class_names = ["neg", "pos"]

    model.eval()
    all_preds = []
    all_labels = []

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch["labels"].cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    print(f"✅ Test Accuracy: {acc:.4f} | F1 Score: {f1:.4f}")
    print(classification_report(all_labels, all_preds, labels=class_ids,
                                target_names=class_names, digits=4))

    # Confusion-matrix visualisation. Pinning `labels` keeps the matrix 2x2 and aligned
    # with the tick labels even if a class is absent from the labels or predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set(xlabel="Predicted", ylabel="True", title="Confusion Matrix")
    plt.show()

    return acc, f1

In [None]:
# Score the held-out test split with the dedicated test-time helper (evaluate() would
# label these numbers as "Validation"). The trailing ";" suppresses the returned value.
final_evaluate(model, test_dataloader, device);

In [None]:
# Persist the fine-tuned weights and the tokenizer. Both paths live in one place so the
# reload helper below reads exactly what was written. The previous commented-out reload
# code pointed at "bert_sentiment.pt" / "bert_tokenizer/", which are never saved.
MODEL_WEIGHTS_PATH = "bert_base.pt"
TOKENIZER_DIR = "bert_base_tokenizer/"

torch.save(model.state_dict(), MODEL_WEIGHTS_PATH)
tokenizer.save_pretrained(TOKENIZER_DIR)


def load_finetuned(device, weights_path=MODEL_WEIGHTS_PATH, tokenizer_dir=TOKENIZER_DIR):
    """Rebuild the fine-tuned classifier and its tokenizer from the files saved above.

    Args:
        device: device to place the reloaded model on.
        weights_path: path of the saved ``state_dict``.
        tokenizer_dir: directory written by ``tokenizer.save_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``, with the model in eval mode.
    """
    from transformers import AutoTokenizer, BertForSequenceClassification

    reloaded_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    # weights_only=True avoids unpickling arbitrary objects from the checkpoint file.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    reloaded_model.load_state_dict(state_dict)
    reloaded_model.to(device)
    reloaded_model.eval()

    reloaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    return reloaded_model, reloaded_tokenizer