In [1]:
from transformers import BertForMaskedLM, BertTokenizer
import torch
from transformers import AdamW, get_scheduler
from torch.utils.data import DataLoader
from src.dataset import TurtleSoupDataset
from src.utils import plot_training_validation_loss, plot_training_validation_acc, save_training_results
from src.model import PET, DiffPET
from run import train_pet_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
# Load the tokenizer and the masked-LM model from the same checkpoint name,
# then move the model onto the selected device.
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
model = BertForMaskedLM.from_pretrained("bert-large-uncased")
model = model.to(device)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archit

In [4]:
# Training hyperparameters shared by the PET and DiffPET runs below.
batch_size = 4
epochs = 10
learning_rate = 1e-5

# Fix the RNG seed so shuffling in the training DataLoader (and any other
# torch-driven randomness) is repeatable across "Restart & Run All".
seed = 42
torch.manual_seed(seed)

# Prompt template; the [MASK] slot is where the label word is predicted.
template = "Based on the judgment rule, this player's guess is [MASK]"
# Maps each dataset label code to the word that stands for it in the template.
label_map = {
    "T": "correct",
    "F": "incorrect",
    "N": "unknown"
}

In [5]:
# Prompt definitions consumed by the dataset.
prompt_path = "./prompts/prompt_en.json"

# English TurtleBench splits: training file and held-out file.
train_data_path = "./data/TurtleBench-extended-en/train_8k.json"
test_data_path = "./data/TurtleBench-extended-en/test_1.5k.json"

In [6]:
# Build both datasets with identical tokenisation/template settings.
train_dataset = TurtleSoupDataset(
    train_data_path,
    prompt_path,
    tokenizer,
    max_length=512,
    template=template,
    label_map=label_map,
)
# NOTE(review): the "validation" set is read from test_data_path — confirm
# that using the test split for per-epoch validation is intended.
val_dataset = TurtleSoupDataset(
    test_data_path,
    prompt_path,
    tokenizer,
    max_length=512,
    template=template,
    label_map=label_map,
)

# Create the DataLoaders; only the training batches are shuffled.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## PET

In [12]:
# transformers' own AdamW is deprecated in favour of the PyTorch optimizer.
# eps=1e-6 keeps the epsilon that the transformers implementation used by default.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
    eps=1e-6,
)
# Total scheduler length = batches per epoch * epochs
# (assumes train_pet_model steps the scheduler once per batch — TODO confirm).
num_training_steps = len(train_dataloader) * epochs
# Linear decay from learning_rate with no warm-up phase.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)



In [13]:
# Wrap the masked-LM model for PET training.
pet_model = PET(model, tokenizer, device)

# Run training; the call returns four metric series that the
# plotting/saving cells below consume.
(train_losses, train_accuracies, val_losses, val_accuracies) = train_pet_model(
    pet_model,
    train_dataloader,
    val_dataloader,
    optimizer,
    lr_scheduler,
    epochs=epochs,
)

In [14]:
# Plot training vs. validation loss, then training vs. validation accuracy,
# using the metric series returned by the PET training call.
plot_training_validation_loss(train_losses, val_losses)
plot_training_validation_acc(train_accuracies, val_accuracies)

In [15]:
# Save the trained PET model to ./params/bert-turtle-soup-pet-en.
pet_model.save_model('./params/bert-turtle-soup-pet-en')
# Hand the four metric series to save_training_results
# (first argument is presumably the run name — confirm against src.utils).
save_training_results("pet_en", train_losses, train_accuracies, val_losses, val_accuracies)

## DiffPET

In [7]:
# Start DiffPET from the pretrained checkpoint. Without this reload, a
# top-to-bottom run would continue from the weights the PET section already
# fine-tuned, so the two methods would not be compared from the same start.
# NOTE(review): `pet_model`, if it exists, still references the previous
# weights; delete it first if GPU memory is tight.
model = BertForMaskedLM.from_pretrained("bert-large-uncased").to(device)

# transformers' own AdamW is deprecated in favour of the PyTorch optimizer.
# eps=1e-6 keeps the epsilon that the transformers implementation used by default.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
    eps=1e-6,
)
# Total scheduler length = batches per epoch * epochs
# (assumes train_pet_model steps the scheduler once per batch — TODO confirm).
num_training_steps = len(train_dataloader) * epochs
# Fresh linear-decay schedule (no warm-up) bound to the new optimizer.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)



In [8]:
# Ask PyTorch's CUDA caching allocator to release unused cached memory
# before the DiffPET run.
torch.cuda.empty_cache()

In [9]:
# Derive the DiffPET label words from label_map so they cannot drift from the
# words the dataset/template use (dicts preserve insertion order, so this is
# ['correct', 'incorrect', 'unknown']).
labels = list(label_map.values())

In [19]:
# Wrap the masked-LM model for DiffPET, passing the template and label words.
diff_pet_model = DiffPET(model, tokenizer, template, labels, device)

# Same training entry point as the PET run; the returned metric series
# overwrite the PET ones held in these names.
(train_losses, train_accuracies, val_losses, val_accuracies) = train_pet_model(
    diff_pet_model,
    train_dataloader,
    val_dataloader,
    optimizer,
    lr_scheduler,
    epochs=epochs,
)

Epoch 1/10:   0%|          | 0/2039 [00:00<?, ?it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 1/2039 [00:00<27:27,  1.24it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 2/2039 [00:01<32:51,  1.03it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 3/2039 [00:02<24:52,  1.36it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 4/2039 [00:03<25:18,  1.34it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 5/2039 [00:04<27:19,  1.24it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 6/2039 [00:04<28:29,  1.19it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 7/2039 [00:05<24:12,  1.40it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 8/2039 [00:06<26:48,  1.26it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 9/2039 [00:07<28:10,  1.20it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   0%|          | 10/2039 [00:07<24:12,  1.40it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 11/2039 [00:08<26:25,  1.28it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 12/2039 [00:09<27:56,  1.21it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 13/2039 [00:10<24:04,  1.40it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 14/2039 [00:10<26:13,  1.29it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 15/2039 [00:11<27:34,  1.22it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 16/2039 [00:12<28:30,  1.18it/s]

template_positions:  torch.Size([4, 512])
template_ids: torch.Size([12])
input_ids[template_positions]: torch.Size([48])


Epoch 1/10:   1%|          | 16/2039 [00:13<28:46,  1.17it/s]

In [None]:
# Plot training vs. validation loss, then training vs. validation accuracy,
# using the metric series returned by the DiffPET training call.
plot_training_validation_loss(train_losses, val_losses)
plot_training_validation_acc(train_accuracies, val_accuracies)

In [None]:
# Saving the DiffPET model and its metrics is currently disabled;
# uncomment the two lines below to persist them.
# diff_pet_model.save_model('./params/bert-turtle-soup-diffpet-en')
# save_training_results("diffpet_en", train_losses, train_accuracies, val_losses, val_accuracies)