In [3]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from BertBiLSTMCRF import BertBiLSTMCRF
import datasets
from dataloader import NERDataset

In [4]:
# Load the single-split JSON dataset and hold out 20% of it for validation.
# A fixed seed makes the split reproducible across kernel restarts; the splits
# are taken by name instead of relying on the DatasetDict's key order.
data = datasets.load_dataset("json", data_files='dataset.json', split='train')
data_splits = data.train_test_split(test_size=0.2, seed=42)
train_dataset, val_dataset = data_splits['train'], data_splits['test']

In [5]:
# `align_labels` requests offset mappings, which Hugging Face only provides for fast tokenizers.
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

In [6]:
def align_labels(text, tags, tok=None, max_length=512):
    """Tokenize one raw string and pair each non-special token with the next tag.

    Tags are consumed strictly in order: the i-th token whose offset is not
    ``(0, 0)`` receives ``tags[i]``. Tokens with an empty ``(0, 0)`` offset
    (the special tokens such as [CLS]/[SEP]) get label 0, as does any token
    left over once ``tags`` is exhausted. Tags beyond the (possibly truncated)
    token count are ignored.

    NOTE(review): this assumes ``tags`` is already aligned one-to-one with the
    tokenizer's sub-tokens (not with whitespace words or characters) — confirm
    against how dataset.json was produced.

    Args:
        text: raw input string (not pre-split into words).
        tags: sequence of integer tag ids.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_length: truncation length in tokens (default 512).

    Returns:
        dict of 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``,
        all of the same length. Batch padding is done later by the collate fn.
    """
    if tok is None:
        tok = tokenizer
    # No `padding=` here: padding a single sequence is a no-op; the collate
    # function pads across the batch.
    encoded = tok(text, truncation=True, max_length=max_length,
                  return_offsets_mapping=True, is_split_into_words=False)

    labels = []
    remaining_tags = iter(tags)
    for offset in encoded['offset_mapping']:
        # tuple() so the check also works if offsets arrive as lists.
        if tuple(offset) == (0, 0):
            labels.append(0)
        else:
            labels.append(next(remaining_tags, 0))

    # One label per offset, hence per input id — no extra padding is needed.
    return {
        "input_ids": torch.tensor(encoded['input_ids']),
        "attention_mask": torch.tensor(encoded['attention_mask']),
        "labels": torch.tensor(labels)
    }

In [7]:
def token_func(batch):
    """Collate a list of raw examples into padded tensors for the model.

    Each example (a mapping with 'text' and 'tags') is passed through
    `align_labels`. The per-example 1-D tensors are then right-padded to the
    longest example in the batch: input_ids with the tokenizer's pad id,
    attention_mask and labels with 0.

    Returns:
        dict with 'input_ids', 'attention_mask' and 'labels', each of shape
        (batch, longest_seq_len).
    """
    encoded = [align_labels(text=example['text'], tags=example['tags']) for example in batch]

    def _pad(key, fill_value):
        # Stack the 1-D tensors for `key`, filling the tail of shorter ones.
        return pad_sequence([enc[key] for enc in encoded], batch_first=True, padding_value=fill_value)

    return {
        "input_ids": _pad("input_ids", tokenizer.pad_token_id),
        "attention_mask": _pad("attention_mask", 0),
        "labels": _pad("labels", 0),
    }

In [8]:
batch_size = 16

# Wrap the HF splits under new names so re-running this cell does not wrap
# an already-wrapped NERDataset a second time.
train_ner_dataset = NERDataset(train_dataset)
val_ner_dataset = NERDataset(val_dataset)

train_dataloader = DataLoader(train_ner_dataset, batch_size=batch_size, shuffle=True, collate_fn=token_func)
val_dataloader = DataLoader(val_ner_dataset, batch_size=batch_size, shuffle=False, collate_fn=token_func)

# Use the first GPU when present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): num_tags must equal the number of distinct tag ids in dataset.json — confirm.
model = BertBiLSTMCRF('bert-base-uncased', num_tags=9, lstm_hidden_size=768).to(device)

In [None]:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

num_epochs = 20
optimizer = AdamW(model.parameters(), lr=5e-5)
best_val_loss = float('inf')
save_path = 'best_model.pth'
# Train on whatever device the model was placed on in the setup cell.
device = next(model.parameters()).device

# Per-epoch history, consumed by the plotting cell.
train_losses = []
val_losses = []
accuracies = []
precisions = []
recalls = []
f1_scores = []


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        attention_masks = batch['attention_mask'].to(device)

        # Called with `labels`, the model returns the loss to optimise.
        loss = model(input_ids, attention_masks, labels=labels)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # No per-step torch.cuda.empty_cache(): it forces a sync and slows
        # training, and the caching allocator reuses freed blocks anyway.
    return total_loss / max(len(loader), 1)


def _evaluate(model, loader, device):
    """Return (mean loss, flat true labels, flat predicted labels) over `loader`."""
    model.eval()
    total_loss = 0.0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_masks = batch['attention_mask'].to(device)

            loss = model(input_ids, attention_mask=attention_masks, labels=labels)
            total_loss += loss.item()

            # Without `labels` the model returns one decoded tag sequence per
            # example; presumably each is already cut to its unpadded length —
            # TODO confirm against BertBiLSTMCRF. Gold labels are trimmed to match.
            predictions_batch = model(input_ids, attention_mask=attention_masks)
            for row, pred in enumerate(predictions_batch):
                y_true.extend(labels[row][:len(pred)].tolist())
                y_pred.extend(pred)
    return total_loss / max(len(loader), 1), y_true, y_pred


for epoch in range(num_epochs):
    avg_epoch_loss = _train_one_epoch(model, train_dataloader, optimizer, device)
    print("Epoch: {} Average loss: {:.6f}".format(epoch + 1, avg_epoch_loss))

    avg_val_loss, true_labels_trimmed, predictions = _evaluate(model, val_dataloader, device)

    # Macro averaging weights every tag id equally, including 0, which
    # `align_labels` also assigns to special tokens. NOTE(review): consider
    # excluding tag 0 via `labels=` if it is the "O" tag.
    precision, recall, f1, _ = precision_recall_fscore_support(y_true=true_labels_trimmed,
                                                               y_pred=predictions,
                                                               average='macro',
                                                               zero_division=0
                                                               )
    ac_score = accuracy_score(y_true=true_labels_trimmed, y_pred=predictions)

    print(f"Validation loss: {avg_val_loss:.6f}")
    print(f'Accuracy on validation set: {ac_score:.6f}')
    print(f"Precision: {precision:.6f}, Recall: {recall:.6f}, F1: {f1:.6f}")

    # Checkpoint only when validation loss improves.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), save_path)
        print(f"Epoch {epoch + 1}: Validation loss improved, model saved to {save_path}")
    else:
        print(f"Epoch {epoch + 1}: Validation loss did not improve from {best_val_loss:.6f}")

    train_losses.append(avg_epoch_loss)
    val_losses.append(avg_val_loss)
    accuracies.append(ac_score)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

In [None]:
import matplotlib.pyplot as plt

# The first SKIP_EPOCHS recorded epochs are left off both plots (presumably so
# their values do not dominate the y-scale). The x-axis is derived from the
# recorded history, so each point is labelled with its true 1-based epoch
# number and any number of completed epochs plots correctly.
SKIP_EPOCHS = 1
epochs = range(SKIP_EPOCHS + 1, len(train_losses) + 1)

# Training and validation loss
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(epochs, train_losses[SKIP_EPOCHS:], 'bo-', label='Training Loss')
ax.plot(epochs, val_losses[SKIP_EPOCHS:], 'ro-', label='Validation Loss')
ax.set(title='Training and Validation Loss', xlabel='Epochs', ylabel='Loss')
ax.legend()
plt.show()

# Accuracy, precision, recall and F1 score
fig, ax = plt.subplots(figsize=(10, 7))
for values, fmt, label in (
    (accuracies, 'go-', 'Accuracy'),
    (precisions, 'mo-', 'Precision'),
    (recalls, 'co-', 'Recall'),
    (f1_scores, 'yo-', 'F1 Score'),
):
    ax.plot(epochs, values[SKIP_EPOCHS:], fmt, label=label)
ax.set(title='Accuracy, Precision, Recall, and F1 Score', xlabel='Epochs', ylabel='Score')
ax.legend()
plt.show()