In [62]:
from datasets import load_dataset, load_from_disk
from transformers import BertForMaskedLM, BertTokenizerFast, BertForTokenClassification
import torch
from tqdm import tqdm
import numpy as np

In [81]:
dataset = load_dataset('R5dwMg/foodiereview_yue', split='train')
# dataset = load_from_disk('data/nlptea_dataset')['train']

In [82]:
tokenizer = BertTokenizerFast.from_pretrained("hon9kon9ize/bert-base-cantonese")
# model = BertForMaskedLM.from_pretrained("models/canto-pretrain/checkpoint-27000")
# model = BertForMaskedLM.from_pretrained("bert-base-chinese")
# model = BertForTokenClassification.from_pretrained("bert-base-chinese", num_labels=2) 
# model = BertForTokenClassification.from_pretrained("models/canto-pretrain/checkpoint-27000", num_labels=2)
model = BertForMaskedLM.from_pretrained("hon9kon9ize/bert-base-cantonese")

In [83]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch

def collate_fn(batch, tokenizer):
    texts = [item['text'] for item in batch]
    return tokenizer(texts, return_tensors='pt', padding=True, truncation=True)

def evaluate_model_batched(model, tokenizer, dataset, batch_size=8, device='cuda' if torch.cuda.is_available() else 'cpu'):
    model.to(device)
    model.eval()

    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda b: collate_fn(b, tokenizer))

    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = input_ids.clone()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            total_loss += loss.item() * input_ids.size(0)  # Multiply by batch size
            total_samples += input_ids.size(0)

    average_loss = total_loss / total_samples
    return average_loss


In [84]:
average_loss = evaluate_model_batched(model, tokenizer, dataset, batch_size=8, device='mps')
print(f"Average Loss for BERT Model: {average_loss}")

 57%|█████▋    | 407/720 [05:54<04:32,  1.15it/s]


KeyboardInterrupt: 

In [77]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"cantonese_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens and non-target labels to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx] if label[word_idx] == 1 else -100)
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

def evaluate_token_class():
    total_loss = 0.0
    total_tokens = 0

    tokenized_dataset = dataset.map(
        tokenize_and_align_labels,
        batched=True,
    )
    for example in tqdm(tokenized_dataset):
        inputs = tokenizer(example["tokens"], truncation=True, is_split_into_words=True, return_tensors="pt").to('mps')
        labels = torch.tensor(example["labels"]).unsqueeze(0).to('mps')

        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            # check if loss.item() is a NaN
            if loss.item() is not None and not np.isnan(loss.item()):
                total_loss += loss.item()
                total_tokens += inputs.input_ids.size(1)
        

    average_loss = total_loss / total_tokens
    return average_loss

In [78]:
average_loss = evaluate_token_class()
print(f"Average Loss for Token Classification: {average_loss}")

100%|██████████| 2062/2062 [01:12<00:00, 28.40it/s]

Average Loss for Token Classification: 0.8704476618805232



