In [1]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from BertCRF import BertCRF
import datasets
from dataloader import NERDataset


# Training split of the local NER dataset (path is relative to the current
# working directory).
data = datasets.load_dataset(path='../../data/ner1', split='train')

# Tokenizer shared by align_labels and token_func. It has to be the *fast*
# tokenizer class because align_labels requests return_offsets_mapping.
tokenizer = BertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

def align_labels(text, tags, *, tok=None):
    """Tokenize one example and build per-token ``labels`` / ``attention_mask``.

    Tags are consumed in order, one per non-special token. Tokens whose offset
    mapping is ``(0, 0)`` (special tokens such as [CLS]/[SEP]) are skipped.
    Tokens without a tag (special tokens, or any tokens left once ``tags`` is
    exhausted) get label 0 and mask 0.

    Args:
        text: a single input string (not pre-split into words).
        tags: sequence of integer tag ids, consumed one per non-special token.
        tok: optional tokenizer override; defaults to the module-level
            ``tokenizer``. It must support ``return_offsets_mapping``.

    Returns:
        dict with 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``,
        all with one entry per token of the tokenized sequence.
    """
    if tok is None:
        tok = tokenizer
    encoded = tok(text, truncation=True, padding=True, max_length=512, return_offsets_mapping=True,
                  is_split_into_words=False)
    # For one unbatched string every field is a flat per-token list, and the
    # offset mapping holds one (start, end) pair per token. Indexing it with
    # [0] (as before) yields only the first token's pair, so almost no tokens
    # were labelled.
    offsets = encoded['offset_mapping']
    token_attention = encoded['attention_mask']

    new_labels = []
    new_mask = []
    tag_index = 0
    for position, offset in enumerate(offsets):
        is_special = tuple(offset) == (0, 0)
        if is_special or tag_index >= len(tags):
            new_labels.append(0)
            new_mask.append(0)
        else:
            # NOTE(review): one tag is assigned per *subword* token. If `tags`
            # are per word, a word split into several word pieces shifts every
            # later tag; align with `encoded.word_ids()` in that case.
            new_labels.append(tags[tag_index])
            # Mask entry of this token's own position, not the tag counter.
            new_mask.append(token_attention[position])
            tag_index += 1

    # NOTE(review): position 0 ([CLS]) always gets mask 0. If BertCRF uses
    # pytorch-crf, that CRF requires the first timestep to be unmasked — confirm.
    return {
        "input_ids": torch.tensor(encoded['input_ids']),
        "attention_mask": torch.tensor(new_mask),
        "labels": torch.tensor(new_labels),
    }

def token_func(batch):
    """Collate a list of raw examples into one padded batch of tensors.

    Each example must provide ``'text'`` and ``'tags'``; it is encoded with
    ``align_labels`` and the per-example 1-D tensors are right-padded to the
    longest example. ``input_ids`` are padded with the tokenizer's pad id,
    while ``attention_mask`` and ``labels`` are padded with 0.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each of
        shape (len(batch), longest_sequence).
    """
    encoded = [align_labels(text=example['text'], tags=example['tags']) for example in batch]

    def _pad(key, fill):
        # Stack one field across the batch, right-padding with `fill`.
        return pad_sequence([enc[key] for enc in encoded], batch_first=True, padding_value=fill)

    return {
        "input_ids": _pad('input_ids', tokenizer.pad_token_id),
        "attention_mask": _pad('attention_mask', 0),
        "labels": _pad('labels', 0),
    }

# Wrap the raw split in the project's NERDataset; token_func collates each
# shuffled group of examples into padded tensors.
batch_size = 16
data = NERDataset(data)
dataloader = DataLoader(
    data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=token_func,
)

# NOTE(review): num_tags=9 presumably matches the dataset's tag vocabulary — confirm.
model = BertCRF('bert-base-uncased', num_tags=9)
model = model.to("cuda:0")

# Fine-tune BertCRF and report the mean training loss per epoch.
epochs = 20
optimizer = AdamW(model.parameters(), lr=5e-5)
# Send every batch to whatever device the model lives on instead of
# hard-coding "cuda:0" for each tensor.
device = next(model.parameters()).device

for epoch in range(epochs):
    # Hugging Face `from_pretrained` returns models in eval mode (dropout off).
    # Presumably BertCRF loads its encoder that way, so switch back to
    # training mode explicitly at the start of every epoch.
    model.train()
    train_loss = 0.0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_masks = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # assumes BertCRF returns a scalar loss when `labels` is given — confirm
        loss = model(input_ids, attention_masks, labels=labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Guard against an empty dataloader (len 0) instead of raising ZeroDivisionError.
    avg_epoch_loss = train_loss / max(len(dataloader), 1)
    print(f"Epoch: {epoch + 1} Average loss: {avg_epoch_loss:.4f}")
