In [None]:
from transformers import BertForMaskedLM, BertTokenizer
import torch
from transformers import AdamW, get_scheduler
from torch.utils.data import DataLoader
from src.dataset import TurtleSoupDataset
from src.utils import plot_training_validation_loss, plot_training_validation_acc, save_training_results
from src.model import DiffPET
from run import train_pet_model

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Load the masked-LM and its tokenizer from the same pretrained checkpoint so
# the vocabulary and the model weights always agree.
pretrained_name = "bert-large-uncased"
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
model = BertForMaskedLM.from_pretrained(pretrained_name)
model = model.to(device)

In [None]:
# Fix the global torch RNG so stochastic parts of the run that draw from it
# (e.g. DataLoader shuffling) start from the same state on every
# Restart & Run All. This does not by itself guarantee bit-exact results on
# CUDA, where some kernels are non-deterministic.
seed = 42
torch.manual_seed(seed)

# Training hyperparameters.
batch_size = 4
epochs = 10
learning_rate = 1e-5

# Prompt template; the [MASK] slot is where the label word is predicted.
template = "Based on the judgment rule, this player's guess is [MASK]"

# Maps the single-letter labels (T / F / N) to the label word used for the
# [MASK] position.
label_map = {
    "T": "correct",
    "F": "incorrect",
    "N": "unknown"
}

In [None]:
# Both splits live in the same dataset directory; the prompt file is separate.
data_dir = "./data/TurtleBench-extended-en"
train_data_path = f"{data_dir}/train_8k.json"
test_data_path = f"{data_dir}/test_1.5k.json"
prompt_path = "./prompts/prompt_en.json"

In [None]:
# Both splits are built with identical prompt/tokenisation settings; only the
# data file differs.
# NOTE(review): the "validation" split is read from test_data_path — confirm
# that evaluating on the test file during training is intended.
dataset_kwargs = dict(max_length=512, template=template, label_map=label_map)
train_dataset = TurtleSoupDataset(train_data_path, prompt_path, tokenizer, **dataset_kwargs)
val_dataset = TurtleSoupDataset(test_data_path, prompt_path, tokenizer, **dataset_kwargs)

# Create the DataLoaders: shuffle only the training split.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Use PyTorch's AdamW: the `transformers.AdamW` class is deprecated in favour
# of `torch.optim.AdamW`. eps=1e-6 keeps the epsilon that the transformers
# implementation used by default (torch's own default is 1e-8).
# NOTE(review): the optimizer is built from `model.parameters()` before the
# DiffPET wrapper is created; if DiffPET registers trainable parameters of its
# own, they would not be optimised — confirm against src.model.DiffPET.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    eps=1e-6,
    weight_decay=1e-4,
)

# One scheduler step per batch over all epochs; linear schedule with no warm-up.
num_training_steps = len(train_dataloader) * epochs
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Label words passed to DiffPET; same words, in the same order, as the values
# of label_map (T, F, N).
labels = [
    "correct",
    "incorrect",
    "unknown",
]

In [None]:
# Wrap the masked-LM, tokenizer, prompt template and label words in the
# project's DiffPET model.
diff_pet_model = DiffPET(model, tokenizer, template, labels, device)

# Run training; the call returns four values, unpacked here as the train/val
# loss and accuracy histories (presumably one entry per epoch — TODO confirm).
(
    train_losses,
    train_accuracies,
    val_losses,
    val_accuracies,
) = train_pet_model(
    diff_pet_model,
    train_dataloader,
    val_dataloader,
    optimizer,
    lr_scheduler,
    epochs=epochs,
)

In [None]:
# Visualise the training run: loss curves first, then accuracy curves, each
# comparing the training history against the validation history.
plot_training_validation_loss(train_losses, val_losses)
plot_training_validation_acc(train_accuracies, val_accuracies)