In [None]:
!pip install datasets

In [10]:
# get the dataset
from dataset import NERDataset
from dataset import collate_batch

train_dataset = NERDataset(mode='train')
test_dataset = NERDataset(mode='test')
dev_dataset = NERDataset(mode='dev')

In [11]:
from tqdm import tqdm
import torch.nn as nn
import torch

# train the model for only one epoch
def train_epoch(model, train_dataloader, optimizer, epoch=None, clip=None, device='cpu'):
    loop = tqdm(
        enumerate(train_dataloader),
        total=len(train_dataloader),
        desc=f'Training {epoch if epoch else ""}',
    )

    model.train()
    train_loss = 0
    for i, batch in loop:
        input_ids, token_type_ids, tags_ids = batch[:3]
        input_ids, token_type_ids, tags_ids = input_ids.to(device), token_type_ids.to(device), tags_ids.to(device)

        optimizer.zero_grad()

        # make the prediction
        loss = model(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=(input_ids != 0).long().to(device),
            labels=tags_ids,
        )[0]

        loss.backward()
        if clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        train_loss += loss.item()
        loop.set_postfix(**{"loss": train_loss / (i + 1)})
    return train_loss / len(train_dataloader)

# evaluate the model for only one epoch
def eval_epoch(model, eval_dataloader, epoch=None, device='cpu'):
    loop = tqdm(
        enumerate(eval_dataloader),
        total=len(eval_dataloader),
        desc=f'Evaluating {epoch if epoch else ""}',
    )

    model.eval()
    eval_loss = 0
    with torch.no_grad():
        for i, batch in loop:
            input_ids, token_type_ids, tags_ids = batch[:3]
            input_ids, token_type_ids, tags_ids = input_ids.to(device), token_type_ids.to(device), tags_ids.to(device)

            # make the prediction
            loss = model(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=(input_ids != 0).long().to(device),
                labels=tags_ids,
            )[0]

            eval_loss += loss.item()
            loop.set_postfix(**{"loss": eval_loss / (i + 1)})
    return eval_loss / len(eval_dataloader)


def train(
    model=None,
    loaders=None,
    optimizer=None,
    epochs=10,
    device=None,
    clip_grad=None,
    ckpt_path='best.pt',
    best_loss=float('inf'),
    cur_epoch=1,
    return_model=False,
):
    epoch_cnt = 0
    for epoch in range(cur_epoch, epochs + cur_epoch):
        train_loss = train_epoch(model, loaders[0], optimizer, epoch, clip_grad, device)
        if len(loaders) > 1:
            val_loss = eval_epoch(model, loaders[1], epoch, device)
        else:
            val_loss = train_loss

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model, ckpt_path)

    if return_model:
        return best_loss, model
    return best_loss

In [36]:
# get the dataloaders
from torch.utils.data import DataLoader

BATCH_SIZE = 16
NUM_WORKERS = 0

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_batch)
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_batch)

In [37]:
for batch in train_dataloader:
    input_ids, token_type_ids, tags_ids = batch
    print('input_ids.shape:', input_ids.shape)
    print('token_type_ids.shape:', token_type_ids.shape)
    print('tags_ids.shape:', tags_ids.shape)
    break

input_ids.shape: torch.Size([16, 128])
token_type_ids.shape: torch.Size([16, 128])
tags_ids.shape: torch.Size([16, 128])


In [38]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Fine-tuning the model

In [50]:
# get the model
from transformers import BertForTokenClassification

model = BertForTokenClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels = 29 * 4 + 1, return_dict = False).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
# parameters
LEARNING_RATE = 1e-3
EPOCHS = 10

parameters = [p for n, p in model.named_parameters() if n.startswith('classifier')]
optimizer = torch.optim.Adam(parameters, lr=LEARNING_RATE)

In [52]:
train(model, loaders=(train_dataloader, test_dataloader), optimizer=optimizer, epochs=EPOCHS, device=device)

Training 1: 100%|██████████| 29/29 [00:37<00:00,  1.29s/it, loss=2.82]
Evaluating 1: 100%|██████████| 6/6 [00:01<00:00,  4.02it/s, loss=1.84]
Training 2: 100%|██████████| 29/29 [00:24<00:00,  1.17it/s, loss=1.5]
Evaluating 2: 100%|██████████| 6/6 [00:01<00:00,  4.00it/s, loss=1.42]
Training 3: 100%|██████████| 29/29 [00:25<00:00,  1.12it/s, loss=1.17]
Evaluating 3: 100%|██████████| 6/6 [00:01<00:00,  3.36it/s, loss=1.19]
Training 4: 100%|██████████| 29/29 [00:26<00:00,  1.09it/s, loss=0.981]
Evaluating 4: 100%|██████████| 6/6 [00:01<00:00,  4.00it/s, loss=1.05]
Training 5: 100%|██████████| 29/29 [00:24<00:00,  1.17it/s, loss=0.854]
Evaluating 5: 100%|██████████| 6/6 [00:01<00:00,  3.92it/s, loss=0.957]
Training 6: 100%|██████████| 29/29 [00:24<00:00,  1.18it/s, loss=0.77]
Evaluating 6: 100%|██████████| 6/6 [00:01<00:00,  4.11it/s, loss=0.891]
Training 7: 100%|██████████| 29/29 [00:24<00:00,  1.19it/s, loss=0.707]
Evaluating 7: 100%|██████████| 6/6 [00:01<00:00,  3.58it/s, loss=0.838]
T

0.7487468322118124

In [54]:
model = torch.load('best.pt').to(device)

In [55]:
def validate(text, model, dataset, addit=0):
    t = dataset.tokenizer(text)

    input_ids = torch.LongTensor(t['input_ids']).reshape(1, -1)
    token_type_ids = torch.LongTensor(t['token_type_ids']).reshape(1, -1)
    attention_mask = (input_ids != 0).long()

    input_ids = input_ids.to(device)
    token_type_ids = token_type_ids.to(device)
    attention_mask = attention_mask.to(device)

    model.eval()
    output = model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    id2tag = dataset.id2tags
    typs = output[0].argmax(-1).squeeze()

    res = []
    start = 0
    ltag = []
    i = 0
    while i < len(typs):
        if typs[i] == 0:
            i += 1
            continue

        j = i + 1
        while j < len(typs):
            if id2tag[typs[j].item()][2:] != id2tag[typs[i].item()][2:]:
                break
            j += 1

        dataset.tokenizer.decode(t['input_ids'][i:j])
        target = dataset.tokenizer.decode(t['input_ids'][i:j])

        res.append([text.find(target, start) + addit, text.find(target, start) + len(target) + addit - 1, id2tag[typs[i].item()][2:]])
        #print(res[-1], target)
        i = j

    return res

In [56]:
validate("Привет Максим, я в Париже", model, train_dataset)

[[7, 12, 'PERSON'], [19, 24, 'CITY']]

In [57]:
import json
write = open("test.jsonl", "w")

with open("target_test.jsonl", "r") as f:
    for line in f.readlines():
        start = 0
        l = json.loads(line)
        sentences = l['senences']
        l['ners'] = []
        for sentence in sentences.split('\n'):
            l['ners'].extend(validate(sentence, model, train_dataset, addit=start))
            start += len(sentence) + 1
        write.write(json.dumps(l))
        write.write('\n')
write.close()