In [1]:
# Mount Google Drive so the dataset and model directories under /content/drive are reachable.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# === 3. 导入模块 ===
import os
import pandas as pd
from sklearn.model_selection import train_test_split
# Weights & Biases logging is disabled for this run. If you re-enable it, supply
# WANDB_API_KEY through Colab secrets or the environment -- never paste the key into
# the notebook (a key that was ever stored here should be revoked and rotated).
os.environ['WANDB_MODE'] = 'disabled'

# === 4. Data loading and preprocessing ===
def load_data(file_path, text_col='text', label_col='Y'):
    """Read a CSV of labelled texts and return parallel lists of texts and labels.

    Args:
        file_path: Path (or file-like object) accepted by ``pandas.read_csv``.
        text_col: Column holding the raw text (default ``'text'``).
        label_col: Column holding the class label (default ``'Y'``).

    Returns:
        ``(texts, labels)``: two lists of equal length; labels are Python ints.

    Rows with a missing text or label are dropped. A NaN text would crash the
    string preprocessing later, and a NaN label would force the column to float.
    """
    df = pd.read_csv(file_path)
    df = df.dropna(subset=[text_col, label_col])
    texts = df[text_col].astype(str).tolist()
    # Integer labels keep the Hugging Face classifier on its single-label
    # (cross-entropy) path; float labels make it switch to multi-label BCE.
    labels = df[label_col].astype(int).tolist()
    return texts, labels

# Dataset locations on Google Drive (upload both CSVs there before running).
train_file = "/content/drive/MyDrive/Colab Notebooks/sarcasm/sarcasm_train.csv"
test_file = "/content/drive/MyDrive/Colab Notebooks/sarcasm/sarcasm_test.csv"

# Read the labelled training and test files (train first, then test).
(train_texts, train_labels), (test_texts, test_labels) = map(load_data, (train_file, test_file))

# Hold out 20% of the training data for validation, with a fixed seed for reproducibility.
(train_texts, val_texts,
 train_labels, val_labels) = train_test_split(
    train_texts,
    train_labels,
    test_size=0.2,
    random_state=42,
)

print(f"conjunto de entrenamiento: {len(train_texts)} | conjunto de validación: {len(val_texts)} | conjunto de pruebas: {len(test_texts)}")


conjunto de entrenamiento: 16026 | conjunto de validación: 4007 | conjunto de pruebas: 8586


In [5]:
import os
import re
import numpy as np

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import torch

# === 5. BERT-specific preprocessing ===
# WordPiece tokenizer for the same uncased checkpoint the classifier is loaded from.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path='bert-base-uncased',
)

def preprocess(text):
    """Normalise a raw text before BERT tokenisation.

    Steps:
        1. Lower-case the text.
        2. Strip URLs.
        3. Glue a negation word to the following word ("not good" -> "not_good").
        4. Trim surrounding whitespace.

    The negation join only fires when whitespace followed by a word actually
    comes next. A negation at the end of the text or before punctuation is
    therefore left untouched ("yes or no" stays "yes or no", not "yes or no_").

    Args:
        text: Raw input string.

    Returns:
        The cleaned string.
    """
    text = text.lower()
    text = re.sub(r'http\S+', '', text)  # remove URLs
    # \s+ together with the (?=\w) lookahead requires at least one space and
    # a following word character, so no dangling "_" is produced.
    # Note: BERT's basic tokenizer treats "_" as punctuation, so "not_good" is
    # still split into "not", "_", "good" before WordPiece.
    text = re.sub(r'\b(not|no|never)\s+(?=\w)', r'\1_', text)
    return text.strip()

# Apply the same normalisation to every split (train, validation, test).
cleaned_train, cleaned_val, cleaned_test = (
    list(map(preprocess, split_texts))
    for split_texts in (train_texts, val_texts, test_texts)
)

# Wrap tokenised texts + labels as a torch Dataset for the Trainer.
class SarcasmDataset(torch.utils.data.Dataset):
    """Pre-tokenised (text, label) pairs for sequence classification.

    All texts are tokenised once in ``__init__``. ``padding=True`` pads every
    sequence to the longest one in this dataset (capped by ``max_length``).

    Args:
        texts: List of already-cleaned strings.
        labels: Class labels, one per text.
        text_tokenizer: Tokenizer callable; defaults to the module-level
            ``tokenizer`` when omitted.
        max_length: Truncation length in tokens (default 128; sarcastic
            texts are usually short).

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, text_tokenizer=None, max_length=128):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length ({len(texts)} != {len(labels)})"
            )
        encode = tokenizer if text_tokenizer is None else text_tokenizer
        self.encodings = encode(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt'
        )
        self.labels = labels

    def __getitem__(self, idx):
        """Return the encoded tensors for example ``idx`` plus its ``labels`` tensor."""
        item = {k: v[idx] for k, v in self.encodings.items()}
        # Force int64: the HF classification head only uses cross-entropy for
        # integer labels (float labels would select the multi-label BCE loss).
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        """Number of examples."""
        return len(self.labels)

# Build one Dataset per split from the cleaned texts and their labels.
train_dataset, val_dataset, test_dataset = (
    SarcasmDataset(split_texts, split_labels)
    for split_texts, split_labels in (
        (cleaned_train, train_labels),
        (cleaned_val, val_labels),
        (cleaned_test, test_labels),
    )
)

# === 6. Model configuration and training ===
# Checkpoints and the final model are written under the mounted Google Drive.
model_save_path = "/content/drive/MyDrive/Colab Notebooks/model/sarcasm"
os.makedirs(model_save_path, exist_ok=True)

# Load the pretrained encoder with a newly initialised 2-class classification head.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
    # NOTE: 0.1 is already bert-base-uncased's default attention dropout, so this
    # does NOT add regularisation. Raise it (and/or hidden_dropout_prob) to actually
    # counter the overfitting seen in the rising validation loss across epochs.
    attention_probs_dropout_prob=0.1,
).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Training hyperparameters, grouped by concern.
training_args = TrainingArguments(
    # Checkpointing: one checkpoint per epoch, keep at most two on disk.
    output_dir=model_save_path,
    save_strategy="epoch",
    save_total_limit=2,
    # Validation after every epoch.
    # NOTE: newer transformers releases name this argument `eval_strategy`.
    evaluation_strategy="epoch",
    per_device_eval_batch_size=64,
    # Optimisation.
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    # After training, reload the checkpoint with the best validation F1.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    # Logging.
    logging_dir='./logs',
    logging_steps=50,
)

def compute_metrics(pred):
    """Compute binary classification metrics for the Trainer's evaluation loop.

    Args:
        pred: Evaluation output exposing ``label_ids`` (gold labels) and
            ``predictions`` (logits, one row per example).

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall`` and ``f1``
        (sklearn defaults: positive class is label 1).
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # zero_division=0: if no positive is predicted (possible early in training),
    # report 0 explicitly instead of emitting UndefinedMetricWarning.
    return {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, zero_division=0),
        'recall': recall_score(labels, preds, zero_division=0),
        'f1': f1_score(labels, preds, zero_division=0),
    }


# Trainer with early stopping: training halts once the validation F1
# (metric_for_best_model) fails to improve for 2 consecutive evaluations.
# EarlyStoppingCallback relies on load_best_model_at_end=True and per-epoch
# eval/save strategies, which training_args sets.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

# Fine-tune. With load_best_model_at_end, the trainer holds the best checkpoint afterwards.
trainer.train()

# Persist that model together with its tokenizer in the same directory.
trainer.save_model(output_dir=model_save_path)
tokenizer.save_pretrained(save_directory=model_save_path)

# === 7. Model evaluation on the held-out test set ===
predictions = trainer.predict(test_dataset)
bert_preds = np.argmax(predictions.predictions, axis=1)  # logits -> predicted class ids

# Score the test predictions with the same four metrics used during validation.
accuracy, precision, recall, f1 = (
    score_fn(test_labels, bert_preds)
    for score_fn in (accuracy_score, precision_score, recall_score, f1_score)
)

print(f"\nBERT Results for Sarcasm Detection:")
print(f"Accuracy: {accuracy:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1-Score: {f1:.4f}")



Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.2919,0.286147,0.895433,0.920091,0.852459,0.884985
2,0.204,0.354835,0.907162,0.894956,0.9101,0.902465
3,0.0575,0.519682,0.913651,0.922362,0.892121,0.906989
4,0.023,0.573443,0.918642,0.906072,0.923321,0.914615
5,0.0001,0.644421,0.918393,0.91863,0.907456,0.913009



BERT Results for Sarcasm Detection:
Accuracy: 0.9115 | Precision: 0.9031 | Recall: 0.9115 | F1-Score: 0.9073


In [7]:
# === 8. Error-case analysis ===
# One record per misclassified test example: raw/cleaned text, gold label, prediction, logits.
errors = [
    {
        'Text': test_texts[i],
        'Cleaned Text': cleaned_test[i],
        'True': test_labels[i],
        'Predicted': bert_preds[i],
        'Logits': predictions.predictions[i].tolist(),
    }
    for i in range(len(test_labels))
    if bert_preds[i] != test_labels[i]
]

print(f"\nNúmero total de errores: {len(errors)} (tasa de error: {len(errors)/len(test_labels):.2%})")

# Show the first five mistakes in detail.
print("Ejemplos de errores críticos:")
for case_no, case in enumerate(errors[:5], start=1):
    print(f"\nCaso {case_no}:")
    print(f"Etiqueta real: {case['True']}, Predicción: {case['Predicted']}")
    print(f"Texto original: {case['Text'][:200]!r}")
    print(f"Texto procesado: {case['Cleaned Text'][:200]!r}")
    print(f"Logits: {case['Logits']}")


Número total de errores: 760 (tasa de error: 8.85%)
Ejemplos de errores críticos:

Caso 1:
Etiqueta real: 1, Predicción: 0
Texto original: 'raytheon ceo sends obama another article about mounting unrest in libya'
Texto procesado: 'raytheon ceo sends obama another article about mounting unrest in libya'
Logits: [4.342543601989746, -3.9952735900878906]

Caso 2:
Etiqueta real: 0, Predicción: 1
Texto original: 'paula abdul back at it'
Texto procesado: 'paula abdul back at it'
Logits: [-4.56586217880249, 4.266960144042969]

Caso 3:
Etiqueta real: 0, Predicción: 1
Texto original: 'fidelity matches some ira contributions'
Texto procesado: 'fidelity matches some ira contributions'
Logits: [-4.588874340057373, 4.459471702575684]

Caso 4:
Etiqueta real: 1, Predicción: 0
Texto original: 'bus transporting carnival cruise passengers crashes into sewage treatment plant'
Texto procesado: 'bus transporting carnival cruise passengers crashes into sewage treatment plant'
Logits: [4.005356311798096, -3.

In [8]:
# Split test predictions into misclassified and correctly classified examples,
# then print summary rates and the first few examples of each group.

def _case_record(i):
    """Return the inspection record for test example ``i`` (texts, gold label, prediction, logits)."""
    return {
        'Text': test_texts[i],
        'Cleaned Text': cleaned_test[i],
        'True': test_labels[i],
        'Predicted': bert_preds[i],
        'Logits': predictions.predictions[i].tolist(),
    }


def _print_cases(cases, title, case_prefix, limit=5):
    """Print ``title`` followed by the first ``limit`` records of ``cases``."""
    print(title)
    for idx, case in enumerate(cases[:limit], 1):
        print(f"\n{case_prefix} {idx}:")
        print(f"Etiqueta real: {case['True']}, Predicción: {case['Predicted']}")
        print(f"Texto original: {case['Text'][:200]!r}")
        print(f"Texto procesado: {case['Cleaned Text'][:200]!r}")
        print(f"Logits: {case['Logits']}")


errors = []
corrects = []
for i in range(len(test_labels)):
    bucket = errors if bert_preds[i] != test_labels[i] else corrects
    bucket.append(_case_record(i))

# Guard against an empty test set so the rate computation cannot divide by zero.
n_test = max(len(test_labels), 1)
print(f"\nNúmero total de errores: {len(errors)} (tasa de error: {len(errors)/n_test:.2%})")
print(f"Número total de aciertos: {len(corrects)} (tasa de acierto: {len(corrects)/n_test:.2%})")

_print_cases(errors, "\nEjemplos de errores críticos:", "Caso")
_print_cases(corrects, "\nEjemplos de predicciones correctas:", "Caso correcto")


Número total de errores: 760 (tasa de error: 8.85%)
Número total de aciertos: 7826 (tasa de acierto: 91.15%)

Ejemplos de errores críticos:

Caso 1:
Etiqueta real: 1, Predicción: 0
Texto original: 'raytheon ceo sends obama another article about mounting unrest in libya'
Texto procesado: 'raytheon ceo sends obama another article about mounting unrest in libya'
Logits: [4.342543601989746, -3.9952735900878906]

Caso 2:
Etiqueta real: 0, Predicción: 1
Texto original: 'paula abdul back at it'
Texto procesado: 'paula abdul back at it'
Logits: [-4.56586217880249, 4.266960144042969]

Caso 3:
Etiqueta real: 0, Predicción: 1
Texto original: 'fidelity matches some ira contributions'
Texto procesado: 'fidelity matches some ira contributions'
Logits: [-4.588874340057373, 4.459471702575684]

Caso 4:
Etiqueta real: 1, Predicción: 0
Texto original: 'bus transporting carnival cruise passengers crashes into sewage treatment plant'
Texto procesado: 'bus transporting carnival cruise passengers crashes in