In [1]:
import torch
from transformers import AutoTokenizer, FNetForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 加载 IMDb 数据集
dataset = load_dataset("imdb")

# 划分数据集
train_dataset = dataset["train"]
validation_dataset = dataset["test"]

Downloading readme: 100%|██████████| 7.81k/7.81k [00:00<?, ?B/s]
Downloading data: 100%|██████████| 21.0M/21.0M [00:22<00:00, 938kB/s]
Downloading data: 100%|██████████| 20.5M/20.5M [01:00<00:00, 340kB/s]
Downloading data: 100%|██████████| 42.0M/42.0M [00:50<00:00, 840kB/s]
Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 214782.79 examples/s]
Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 285670.38 examples/s]
Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 278952.27 examples/s]


In [3]:
# 初始化模型和分词器
tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
model = FNetForSequenceClassification.from_pretrained("google/fnet-base", num_labels=2)

# 定义微调参数
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
num_epochs = 3
batch_size = 32

# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Some weights of FNetForSequenceClassification were not initialized from the model checkpoint at google/fnet-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# 微调模型
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in train_loader:
        inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
        labels = torch.tensor(batch["label"])
        
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

  labels = torch.tensor(batch["label"])


KeyboardInterrupt: 

In [None]:
# 评估模型（可选）
# validation_loader = DataLoader(validation_dataset, batch_size=batch_size)
# evaluation_code_here()

# 推断（可选）
# inference_code_here()