In [1]:
from transformers import BertForMaskedLM, BertTokenizerFast, BertTokenizer
import torch
from transformers import AdamW, get_scheduler
from torch.utils.data import DataLoader, random_split
from src.dataset import TurtleSoupDataset
from src.utils import plot_training_validation_loss, plot_training_validation_acc
from src.model import PET, DiffPET
from run import train_pet_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook echoes the selected device

device(type='cuda')

In [4]:
# Load the tokenizer and the masked-LM model from the same checkpoint name,
# then move the model onto the selected device.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertForMaskedLM.from_pretrained("bert-large-uncased")
model = model.to(device)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archit

In [5]:
# Reproducibility: seed torch's global RNG before any DataLoader or training
# code runs, so a "Restart & Run All" starts from the same random state
# (shuffling order and dropout masks both draw from this generator).
seed = 42
torch.manual_seed(seed)

# Optimisation hyperparameters shared by every training run in this notebook.
batch_size = 8
epochs = 20
learning_rate = 1e-5

# Cloze-style prompt containing a single [MASK] slot. It is passed to
# TurtleSoupDataset and to DiffPET further down.
template = "Based on the judgment rule, this player's guess is [MASK]."

# Label name -> lowercase label word; passed to TurtleSoupDataset.
# NOTE(review): presumably the keys match the label strings stored in the
# data files — confirm against TurtleSoupDataset.
label_map = {
    "Correct": "correct",
    "Incorrect": "incorrect",
    "Unknown": "unknown"
}

In [3]:
# Input files, given relative to the notebook's working directory.
prompt_path = "./prompts/prompt_en.json"      # prompt file handed to TurtleSoupDataset
train_data_path = "./data/en_train_8k.json"   # training split
test_data_path = "./data/en_test_1.5k.json"   # test split (used to build the validation set)

In [6]:
# Constructor options that are identical for both splits.
dataset_options = dict(max_length=512, template=template, label_map=label_map)

# Both datasets share the prompt file and tokenizer; only the data file differs.
# Note that the validation set is built from the test file.
train_dataset = TurtleSoupDataset(train_data_path, prompt_path, tokenizer, **dataset_options)
val_dataset = TurtleSoupDataset(test_data_path, prompt_path, tokenizer, **dataset_options)

# Create the DataLoaders: the training split is shuffled, the validation split
# keeps its original order.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
)
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [7]:
# Optimizer and learning-rate schedule for the PET run.
#
# torch.optim.AdamW is used instead of the AdamW class exported by
# transformers, which is deprecated in favour of the PyTorch implementation.
# eps=1e-6 keeps the epsilon the transformers class used by default
# (PyTorch's own default is 1e-8).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
    eps=1e-6,
)

# Total number of optimisation steps: batches per epoch times epochs.
# NOTE(review): assumes train_pet_model steps the scheduler once per batch — confirm.
num_training_steps = len(train_dataloader) * epochs

# Linear schedule without warmup over the whole run.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)



In [None]:
# Wrap the masked-LM model, tokenizer and device in the PET model class.
pet_model = PET(model, tokenizer, device)

# Run training, then unpack the four returned metric series
# (train/validation loss and accuracy).
pet_history = train_pet_model(
    pet_model,
    train_dataloader,
    val_dataloader,
    optimizer,
    lr_scheduler,
    epochs=epochs,
)
train_losses, train_accuracies, val_losses, val_accuracies = pet_history

In [None]:
# Plot the PET run: loss curves first, then accuracy curves.
plot_training_validation_loss(
    train_losses,
    val_losses,
)
plot_training_validation_acc(
    train_accuracies,
    val_accuracies,
)

In [None]:
# Optional: uncomment to save the trained PET model under ./params/.
# pet_model.save_model('./params/bert-turtle-soup-pet-en')

## DiffPET

In [None]:
# Fresh optimizer and learning-rate schedule for the DiffPET run, using the
# same settings as the PET run.
#
# NOTE(review): `model` is the same object that was passed to PET(...) and to
# the previous optimizer, so when the notebook is run top to bottom DiffPET
# presumably starts from the PET-fine-tuned weights rather than the pretrained
# checkpoint. Reload the checkpoint first if the two methods are meant to be
# compared independently.
# NOTE(review): this optimizer is created before DiffPET(...) is constructed;
# if DiffPET registers additional trainable parameters they are not covered
# here — confirm against src.model.
#
# torch.optim.AdamW is used instead of the AdamW class exported by
# transformers, which is deprecated in favour of the PyTorch implementation.
# eps=1e-6 keeps the epsilon the transformers class used by default
# (PyTorch's own default is 1e-8).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
    eps=1e-6,
)

# Total number of optimisation steps: batches per epoch times epochs.
num_training_steps = len(train_dataloader) * epochs

# Linear schedule without warmup over the whole run.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Ask PyTorch's CUDA caching allocator to release cached memory blocks that are
# not currently in use. Tensors that are still referenced are not freed by this call.
torch.cuda.empty_cache()

In [None]:
# Label words handed to DiffPET. Derived from label_map (dicts preserve
# insertion order, so this yields ['correct', 'incorrect', 'unknown']) so the
# words used by the dataset and by DiffPET cannot drift apart.
labels = list(label_map.values())

In [None]:
# Build the DiffPET model from the masked-LM model, tokenizer, prompt template,
# label words and device.
diff_pet_model = DiffPET(model, tokenizer, template, labels, device)

# Run training through the same entry point as the PET run, then unpack the
# four returned metric series. These names overwrite the PET run's metrics.
diff_pet_history = train_pet_model(
    diff_pet_model,
    train_dataloader,
    val_dataloader,
    optimizer,
    lr_scheduler,
    epochs=epochs,
)
train_losses, train_accuracies, val_losses, val_accuracies = diff_pet_history

In [None]:
# Plot the DiffPET run: loss curves first, then accuracy curves.
plot_training_validation_loss(
    train_losses,
    val_losses,
)
plot_training_validation_acc(
    train_accuracies,
    val_accuracies,
)

In [None]:
# Optional: uncomment to save the trained DiffPET model under ./params/.
# diff_pet_model.save_model('./params/bert-turtle-soup-diffpet-en')