# Importing libraries and setting seeds for reproducibility

In [1]:
import random
import numpy as np
import functools
import torch
import torchtext.functional as F
from torch.utils.data import DataLoader
from torchtext.datasets import UDPOS
from torchtext.vocab import build_vocab_from_iterator
from transformers import BertModel,BertTokenizer


from preprocessing import *
from model import *

In [2]:
# One seed shared by every random number generator the notebook relies on.
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Everything runs on the CPU. To use a GPU instead, select it with
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu") and also seed
# it via torch.cuda.manual_seed_all(SEED).
device = "cpu"

# Ask PyTorch to use deterministic algorithms so repeated runs are comparable.
torch.use_deterministic_algorithms(True)

# Creating the dataloaders

In [3]:
# Hyperparameters
BATCH_SIZE = 32
TRANSFORMER = "bert-base-uncased"

# Tokenizer that belongs to the pretrained BERT checkpoint named above.
tokenizer = BertTokenizer.from_pretrained(TRANSFORMER)

# Special tokens as strings ...
init_token = tokenizer.cls_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token
sep_token = tokenizer.sep_token

# ... and the ids the tokenizer assigns to them.
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)
sep_token_idx = tokenizer.convert_tokens_to_ids(sep_token)

# Maximum input length (in tokens) for this tokenizer's model. Read from the
# documented per-instance attribute `model_max_length` rather than the legacy
# class-level `max_model_input_sizes[...]` table, which newer transformers
# releases no longer provide.
max_input_length = tokenizer.model_max_length



# Raw UDPOS splits used for training and validation.
train_datapipe = UDPOS(split="train")
valid_datapipe = UDPOS(split="valid")

# Tag vocabulary built from the second field of every training sample, with
# the BERT special tokens registered as vocabulary specials.
pos_vocab = build_vocab_from_iterator(
    (sample[1] for sample in train_datapipe),
    specials=[init_token, pad_token, sep_token],
)

# Tensor holding every index of the tag vocabulary: 0 .. len(pos_vocab) - 1.
T_CAL = torch.arange(len(pos_vocab))

# Keyword arguments that both preprocessing helpers receive.
_shared_preprocessing_kwargs = {
    "max_input_length": max_input_length,
    "init_token": init_token,
    "sep_token": sep_token,
}

# `prepare_words` bound to the BERT tokenizer and the shared settings.
text_preprocessor = functools.partial(
    prepare_words,
    tokenizer=tokenizer,
    **_shared_preprocessing_kwargs,
)

# `prepare_tags` bound to the tag vocabulary and the shared settings.
tag_preprocessor = functools.partial(
    prepare_tags,
    pos_vocab=pos_vocab,
    **_shared_preprocessing_kwargs,
)


def apply_transform(x):
    """Preprocess a single dataset sample.

    Args:
        x: Indexable sample; ``x[0]`` is handed to ``text_preprocessor`` and
            ``x[1]`` to ``tag_preprocessor``.

    Returns:
        A 2-tuple ``(processed x[0], processed x[1])``.
    """
    words, tags = x[0], x[1]
    processed_words = text_preprocessor(words)
    processed_tags = tag_preprocessor(tags)
    return processed_words, processed_tags


def _to_batched_columns(datapipe):
    """Preprocess each sample, batch them, and expose columns "words"/"pos"."""
    transformed = datapipe.map(apply_transform)
    batched = transformed.batch(BATCH_SIZE)
    return batched.rows2columnar(["words", "pos"])


# Both splits go through the same pipeline. Batching already happens inside
# the datapipe, hence batch_size=None on the DataLoader.
train_datapipe = _to_batched_columns(train_datapipe)
train_dataloader = DataLoader(train_datapipe, batch_size=None, shuffle=False)

valid_datapipe = _to_batched_columns(valid_datapipe)
valid_dataloader = DataLoader(valid_datapipe, batch_size=None, shuffle=False)

Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.


# Training

In [6]:
EPOCHS = 3
LR = 2e-5

# Sweep over the `beta` argument of NeuralCRF and keep the model whose
# reported accuracy is highest.
# best_acc starts at -inf (not 0.0) so the first trained model is always
# recorded; this guarantees best_model / best_beta are bound after the loop
# even if no run scores above zero.
best_acc = float("-inf")
best_model = None
best_beta = None
betas = [0.0,0.1,1]
for beta in betas:
    print('beta : {} \n'.format(beta))
    # Re-seed before every run so each beta starts from the same RNG state.
    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    # Fresh pretrained encoder per run; nothing is shared between betas.
    bert = BertModel.from_pretrained(TRANSFORMER)
    crf = NeuralCRF(
        pad_idx_word=pad_token_idx,
        pad_idx_pos=pos_vocab[pad_token],
        bos_idx=init_token_idx,
        eos_idx=sep_token_idx,
        bot_idx=pos_vocab[init_token],
        eot_idx=pos_vocab[sep_token],
        t_cal=T_CAL,
        transformer=bert,
        beta=beta
    )
    # Release cached GPU memory left over from the previous run, if any.
    if device != 'cpu':
        torch.cuda.empty_cache()
    crf.to(device)
    accuracy = train_model_report_accuracy(
        crf,
        LR,
        EPOCHS,
        train_dataloader,
        valid_dataloader,
        pad_token_idx,
        pos_vocab[pad_token],
        device
    )
    # Strict comparison: on ties the earlier beta is kept.
    if accuracy > best_acc:
        best_acc = accuracy
        best_model = crf
        best_beta = beta

beta : 0.0 

Epoch: 1 / 3


100%|██████████| 392/392 [26:13<00:00,  4.01s/it]


-------------------------
Development set accuracy: 0.8823398351669312
-------------------------
Epoch: 2 / 3


100%|██████████| 392/392 [25:07<00:00,  3.85s/it]


-------------------------
Development set accuracy: 0.9094575643539429
-------------------------
Epoch: 3 / 3


100%|██████████| 392/392 [23:58<00:00,  3.67s/it]


-------------------------
Development set accuracy: 0.9199662208557129
-------------------------
beta : 0.1 

Epoch: 1 / 3


100%|██████████| 392/392 [33:50<00:00,  5.18s/it]


-------------------------
Development set accuracy: 0.8871060609817505
-------------------------
Epoch: 2 / 3


100%|██████████| 392/392 [33:51<00:00,  5.18s/it]


-------------------------
Development set accuracy: 0.9016318321228027
-------------------------
Epoch: 3 / 3


100%|██████████| 392/392 [33:51<00:00,  5.18s/it]


-------------------------
Development set accuracy: 0.9221358299255371
-------------------------
beta : 1 

Epoch: 1 / 3


100%|██████████| 392/392 [33:54<00:00,  5.19s/it]


-------------------------
Development set accuracy: 0.8996726274490356
-------------------------
Epoch: 2 / 3


100%|██████████| 392/392 [5:30:14<00:00, 50.55s/it]     


-------------------------
Development set accuracy: 0.9156255125999451
-------------------------
Epoch: 3 / 3


100%|██████████| 392/392 [34:36<00:00,  5.30s/it]


-------------------------
Development set accuracy: 0.9252203702926636
-------------------------


In [7]:
# Write the state dict (not the whole module object) of the best model found
# during the beta sweep to disk.
PATH = "pos_model.pt"
best_state = best_model.state_dict()
torch.save(best_state, PATH)