In [None]:
from pathlib import Path
import pandas as pd
from classification_fine_tuning import (
    download_and_unzip_spam_data,
    create_balanced_dataset,
    random_split,
    SpamDataset
)
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken

# Source archive of the SMS Spam Collection and the local names used for it.
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path, extracted_path = "sms_spam_collection.zip", "sms_spam_collection"
# Tab-separated data file located inside the extraction directory.
data_file_path = Path(extracted_path).joinpath("SMSSpamCollection.tsv")

# Download and unzip the spam dataset.
download_and_unzip_spam_data(
    url=url,
    zip_path=zip_path,
    extracted_path=extracted_path,
    data_file_path=data_file_path,
)
# Read the tab-separated data; header=None means the file's first row is data,
# so the column names are supplied here.
df = pd.read_csv(filepath_or_buffer=data_file_path, sep="\t", header=None, names=["Label", "Text"])
# Create a balanced dataset, then encode the string labels as integers
# ("ham" -> 0, "spam" -> 1).
balanced_df = create_balanced_dataset(df=df)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
# Randomly split into training, validation and test sets.
# NOTE(review): presumably the test set receives the remaining fraction after
# train_frac and validation_frac -- confirm against random_split.
train_df, validation_df, test_df = random_split(df=balanced_df, train_frac=0.7, validation_frac=0.1)
# Save each split to a CSV file. `index` is documented as a bool, so pass
# False explicitly to leave the DataFrame index out of the files.
train_df.to_csv(path_or_buf="classification_lora_train.csv", index=False)
validation_df.to_csv(path_or_buf="classification_lora_val.csv", index=False)
test_df.to_csv(path_or_buf="classification_lora_test.csv", index=False)
# Initialize the tokenizer using tiktoken's "gpt2" encoding.
tokenizer = tiktoken.get_encoding("gpt2")
# Create the training SpamDataset first; it is built with max_length=None and
# its resulting max_length attribute is reused for the other two splits.
train_dataset = SpamDataset(
    csv_file="classification_lora_train.csv",
    max_length=None,
    tokenizer=tokenizer,
)
# Validation and test datasets (built in that order) share the training max_length.
val_dataset, test_dataset = (
    SpamDataset(csv_file=csv_name, max_length=train_dataset.max_length, tokenizer=tokenizer)
    for csv_name in ("classification_lora_val.csv", "classification_lora_test.csv")
)
# Create the PyTorch DataLoaders.
batch_size, num_workers = 8, 0
# The training loader shuffles and drops the final incomplete batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)
# Validation and test loaders keep every sample and are not shuffled.
val_loader, test_loader = (
    DataLoader(dataset=split, batch_size=batch_size, num_workers=num_workers, drop_last=False)
    for split in (val_dataset, test_dataset)
)
# Check the batch dimensions and the total number of batches per loader.
print("Train loader:")
# An empty loader would leave input_batch/target_batch unbound and fail below
# with an opaque NameError, so fail early with an explicit message instead.
if len(train_loader) == 0:
    raise RuntimeError(
        "train_loader yields no batches; cannot report batch dimensions "
        "(check the training set size against batch_size and drop_last)."
    )
# Iterate over every training batch; after the loop the two names refer to
# the last batch that was produced.
for input_batch, target_batch in train_loader:
    pass
print("Input batch dimensions:", input_batch.shape)
print("Label batch dimensions:", target_batch.shape)

print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} test batches")

2026-02-05 14:36:31.883937: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-05 14:36:31.972859: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-05 14:36:34.071161: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


File already exists and is up-to-date: gpt2/124M/checkpoint
File already exists and is up-to-date: gpt2/124M/encoder.json
File already exists and is up-to-date: gpt2/124M/hparams.json
File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/124M/model.ckpt.index
File already exists and is up-to-date: gpt2/124M/model.ckpt.meta


KeyboardInterrupt: 