In [1]:
from transformers import BertForMaskedLM, BertTokenizerFast
import torch
from transformers import AdamW, get_scheduler
from torch.utils.data import DataLoader
from src.dataset import ContrastiveLearningDataset
from src.utils import plot_training_validation_loss
from src.model import ContrastiveTrainer
from run import train_contrastive_model
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute backend once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(backend)
device  # bare last expression so the notebook echoes the selected device

device(type='cuda')

In [3]:
# Load the pretrained masked-LM model and its fast tokenizer from the same
# checkpoint name, then move the model onto the selected device.
checkpoint = "bert-base-chinese"
model = BertForMaskedLM.from_pretrained(checkpoint)
model = model.to(device)
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

In [4]:
# File locations used by this notebook (relative to the working directory).
# Training split; passed to ContrastiveLearningDataset as `data_path`.
train_data_path: str = "./data/TurtleBench-extended-zh/train_8k.json"
# Test split and prompt file; not referenced by the cells shown here —
# presumably consumed by later evaluation code (TODO confirm).
test_data_path: str = "./data/TurtleBench-extended-zh/test_1.5k.json"
prompt_path: str = "./prompts/prompt_zh.json"

In [5]:
# Training hyper-parameters, grouped so they can be tuned in one place.
batch_size: int = 4          # samples per batch for both DataLoaders
epochs: int = 10             # configured number of training epochs
learning_rate: float = 1e-5  # learning rate handed to the optimizer

In [6]:
# Build the full contrastive-learning dataset from the training file.
contrastive_dataset = ContrastiveLearningDataset(
    data_path=train_data_path,
    tokenizer=tokenizer,
    max_length=256,
)

# Hold out 15 % of the examples for validation. The split is done on indices
# so both subsets are views over the same underlying dataset object.
# A fixed random_state keeps the train/validation partition identical across
# kernel restarts; without it every run validates on a different subset and
# loss curves from separate runs are not comparable.
train_indices, val_indices = train_test_split(
    list(range(len(contrastive_dataset))),
    test_size=0.15,
    random_state=42,
)
train_dataset = torch.utils.data.Subset(contrastive_dataset, train_indices)
val_dataset = torch.utils.data.Subset(contrastive_dataset, val_indices)

# Shuffle only the training batches; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# Optimizer over all model parameters.
# NOTE(review): the AdamW exported by `transformers` is deprecated upstream in
# favour of `torch.optim.AdamW` — consider migrating when upgrading transformers.
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# Size the schedule from the configured `epochs` instead of a separate literal,
# so the decay horizon follows the hyper-parameter cell. This should equal the
# number of epochs actually passed to the training call; otherwise the linear
# schedule will not reach zero at the end of training.
# Assumes one scheduler step per training batch — TODO confirm in run.py.
num_training_steps = len(train_dataloader) * epochs
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)



In [None]:
# Bundle the model, tokenizer and device into the project's trainer object.
contrastive_trainer = ContrastiveTrainer(model, tokenizer, device)

# Run training and keep the two loss histories it returns.
# `epochs` comes from the hyper-parameter cell rather than a separate literal:
# the LR scheduler's num_training_steps is sized for that same epoch count, so
# training for fewer epochs would stop the linear decay part-way through.
train_losses, val_losses = train_contrastive_model(
    trainer=contrastive_trainer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epochs=epochs,
    margin=1.0,  # presumably the contrastive-loss margin — verify in run.py
)

In [None]:
# Visualise the two loss histories returned by train_contrastive_model, using
# the project's plotting helper from src.utils.
plot_training_validation_loss(train_losses, val_losses)