In [647]:
import os
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertConfig, BertForTokenClassification

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
print(device)

cpu


In [648]:
dataframe = pd.read_csv('data.csv',encoding='unicode-escape')
dataframe.tag.value_counts()

OTHER        1314
B-POKEMON     139
B-MOVE         72
I-MOVE         46
I-POKEMON      20
B-ABILITY      14
B-ITEM         11
I-ITEM         11
B-STATS         9
B-TYPE          8
B-NATURE        3
B-STATUS        3
I-ABILITY       1
B-TIER          1
I-TIER          1
Name: tag, dtype: int64

In [649]:
labels2ids = {k:v for v,k in enumerate(dataframe.tag.unique())}
ids2labels = {v:k for v,k in enumerate(dataframe.tag.unique())} 

In [650]:
dataframe['sentence_copy'] = dataframe[['sentence no','word','tag']].groupby(['sentence no'])['word'].transform(lambda x: ' '.join(x))
dataframe['word_labels'] = dataframe[['sentence no','word','tag']].groupby(['sentence no'])['tag'].transform(lambda x: ','.join(x))
dataframe[['sentence_copy','word_labels']].drop_duplicates().reset_index(drop=True)

Unnamed: 0,sentence_copy,word_labels
0,Overview,OTHER
1,Dragonite is a devastating sweeper with Supers...,"B-POKEMON,OTHER,OTHER,OTHER,OTHER,OTHER,B-MOVE..."
2,"It can select from great priority, additional ...","OTHER,OTHER,OTHER,OTHER,OTHER,OTHER,OTHER,OTHE..."
3,Dragonite's ability Multiscale gives it an eas...,"B-POKEMON,OTHER,B-ABILITY,OTHER,OTHER,OTHER,OT..."
4,A subpar Speed tier leaves Dragonite easily re...,"OTHER,OTHER,OTHER,OTHER,OTHER,B-POKEMON,OTHER,..."
...,...,...
62,"Additionally, many sets can set up Stealth Roc...","OTHER,OTHER,OTHER,OTHER,OTHER,OTHER,B-MOVE,I-M..."
63,Celesteela and Skarmory: Unless Dragonite is r...,"B-POKEMON,OTHER,B-POKEMON,OTHER,B-POKEMON,OTHE..."
64,Dragonite can attempt to take advantage of Cel...,"B-POKEMON,OTHER,OTHER,OTHER,OTHER,OTHER,OTHER,..."
65,Unaware: Most Unaware walls are capable of inf...,"B-ABILITY,OTHER,B-ABILITY,OTHER,OTHER,OTHER,OT..."


In [651]:
MAX_LEN = 128
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
EPOCHS = 1
LEARNING_RATE = 1e-05
MAX_NORM_GRAD = 10
CHECKPOINT = 'bert-base-uncased'
TOKENIZER = BertTokenizerFast.from_pretrained(CHECKPOINT)



In [652]:
class dataset(Dataset):
    def __init__(self,dataframe,tokenizer,max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        # get sentences and world label
        sentence = self.data.sentence_copy[index].strip().split()
        # word_labels = self.data.word_labels[index].strip().split(',')
        word_labels = self.data.word_labels[index].split(',')

        encoding = self.tokenizer(sentence,
                                #   is_pretokenized=True,
                                  return_offsets_mapping=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)
        
        # create token labels only for first wordpieces of each tokenized word
        labels = [labels2ids[label] for label in word_labels]

        # create an empty array of -100 of length max_length
        encoded_labels = np.ones(len(encoding['offset_mapping']),dtype=int)*-100

        # overwrite only those labels whose zeroth offset position is equal to zero but not first 
        i = 0
        for idx,mapping in enumerate(encoding['offset_mapping']):
            if mapping[0] == 0 and mapping[1] != 0:
                # overwrite label
                encoded_labels[idx] = labels[i]
                i+=1

        # convert to pytorch tensors
        item = {key:torch.as_tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.as_tensor(encoded_labels)
        
        return item

    def __len__(self):
        return self.len

In [653]:
train_size = 0.8
train_dataset = dataframe.sample(frac=train_size,random_state=200)
test_dataset = dataframe.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

train_set = dataset(train_dataset,TOKENIZER,MAX_LEN)
test_set = dataset(test_dataset,TOKENIZER,MAX_LEN)

In [654]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle':True,
                'num_workers':0}
test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle':True,
                'num_workers':0}

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    # for item in batch:
    #     print(item)
    batch_input_ids = [item['input_ids'] for item in batch]
    padded_input_ids = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
    print("Input IDs ---> ", padded_input_ids.shape)

    batch_token_type_ids = [item['token_type_ids'] for item in batch]
    padded_token_type_ids = pad_sequence(batch_token_type_ids, batch_first=True, padding_value=0)
    print("Token Type IDs ---> ", padded_token_type_ids.shape)

    batch_attention_mask = [item['attention_mask'] for item in batch]
    padded_attention_mask = pad_sequence(batch_attention_mask, batch_first=True, padding_value=0)
    print("Attention Mask ---> ", padded_attention_mask.shape)

    batch_offset_mapping = [item['offset_mapping'] for item in batch]
    padded_offset_mapping = pad_sequence(batch_offset_mapping, batch_first=True, padding_value=0)
    print('Offset Mapping ---> ',padded_offset_mapping.shape)

    batch_labels = [item['labels'] for item in batch]
    padded_labels = pad_sequence(batch_labels, batch_first=True, padding_value=0)
    print("Labels ---> ",padded_labels.shape)

    # batch_token_type_ids = [item['token_type_ids'] for item in batch]
    # batch_attention_mask = [item['attention_mask'] for item in batch]
    # batch_offset_mapping = [item['offset_mapping'] for item in batch]
    # batch_labels = [item['labels'] for item in batch]

    # return {'input_ids': padded_input_ids}
    return {'input_ids':padded_input_ids,
            'token_type_ids':padded_token_type_ids,
            'attention_mask':padded_attention_mask,
            'offset_mapping':padded_offset_mapping,
            'labels':padded_labels
            }

train_loader = DataLoader(train_set,**train_params,collate_fn=collate_fn)
test_loader = DataLoader(test_set,**test_params,collate_fn=collate_fn)

In [655]:
# for k, v in train_loader.dataset[132].items():
#     print(f"{k} ---> {len(v)}")
#     print(f"{k} ---> {v}")


In [656]:
model = BertForTokenClassification.from_pretrained(CHECKPOINT,num_labels=len(labels2ids))
model.to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [657]:
# train_set[0]['input_ids'].unsqueeze(0).shape,train_set[0]['input_ids'].shape
OPTIMIZER = torch.optim.AdamW(params=model.parameters(),lr=LEARNING_RATE)

In [658]:
def train(epoch):
    train_loss, train_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    train_preds, train_labels = [], []
    # put model into training mode
    model.train()

    for idx, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].to(device,dtype=torch.long)
        # print(batch)
        attention_mask = batch['attention_mask'].to(device,dtype=torch.long)
        labels = batch['labels'].to(device,dtype=torch.long)
        print(model(input_ids=input_ids,attention_mask=attention_mask,labels=labels))
        loss, train_logits = model(input_ids=input_ids,attention_mask=attention_mask,labels=labels)
        train_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)

        if idx%100==0:
            loss_step = train_loss/nb_tr_steps
            print(f"Training loss per 100 training steps: {loss_step}")

        #compute training accuracy
        flattened_targets = labels.view(-1) # shape(batch_size*seq_len)
        active_logits = train_logits.view(-1,model.num_labels) # shape(batch_size*seq_len, num_labels)
        flattened_preds = torch.argmax(active_logits,axis=1) # shape(batch_size*seq_len)

        # only compute accuracy at active labels
        active_accuracy = labels.view(-1) != -100 #  # shape(batch_size*seq_len)
        
        labels = torch.masked_select(flattened_targets,active_accuracy)
        preds = torch.masked_select(flattened_preds,active_accuracy)

        train_labels.extend(labels)
        train_preds.extend(preds)

        tmp_tr_acc += accuracy_score(labels.cpu().numpy(),preds.cpu().numpy())
        train_accuracy += tmp_tr_acc

        # gradient clipping
        torch.nn.utils.clip_grad_norm(
            parameters=model.parameters(),max_norm=MAX_NORM_GRAD
            )
        
        # backward pass
        OPTIMIZER.zero_grad()
        loss.backward()
        OPTIMIZER.step()

    epoch_loss = train_loss/nb_tr_steps
    train_accuracy = train_accuracy/nb_tr_steps

    print(f"Training loss epoch: {epoch_loss}")
    print(f"Training accuracy epoch: {train_accuracy}")

In [659]:
# train_loader.dataset[0].keys()

In [660]:
for epoch in range(EPOCHS):
    print(f"Training epoch: {epoch+1}")
    train(epoch)

Training epoch: 1
Input IDs --->  torch.Size([4, 34, 128])
Token Type IDs --->  torch.Size([4, 34, 128])
Attention Mask --->  torch.Size([4, 34, 128])
Offset Mapping --->  torch.Size([4, 34, 128, 2])
Labels --->  torch.Size([4, 34])


ValueError: too many values to unpack (expected 2)

In [None]:
def valid(model, test_loader):
    model.eval()

    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels = [], []

    with torch.no_grad():
        for idx, batch in enumerate(test_loader):
            input_ids = batch['input_ids'].to(device,dtype=torch.long)
            attention_mask = batch['attention_mask'].to(device,dtype=torch.long)
            labels = batch['labels'].to(device,dtype=torch.long)

            loss, eval_logits = model(input_ids=input_ids,attention_mask=attention_mask,labels=labels)
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)

            if idx % 100 == 0:
                loss_step = eval_loss/nb_eval_steps
                print(f"Validation loss per 100 evaliuation steps {loss_step}")

            # compute evalutaion accuracy
            flattened_targets = labels.view(-1) # shape(batch_size*seq_len)
            active_logits = eval_logits.view(-2, model.num_labels) # shape(batch_size*seq_len, num_labels)
            flattened_preds = torch.argmax(active_logits,axis=1) # shape(batch_size*seq_len)

            # only compute accuracy at active labels
            active_accuracy = labels.view(-1) != -100 # shape(batch_size, seq_len)

            labels = torch.masked_select(flattened_targets,active_accuracy)
            preds = torch.masked_select(flattened_preds,active_accuracy)

            eval_labels.extend(labels)
            eval_preds.extend(preds)

            tmp_eval_accuracy = accuracy_score(labels.cpu().numpy(),preds.cpu().numpy())
            eval_accuracy += tmp_eval_accuracy

        labels = [ids2labels [id.item()] for id in eval_labels]
        preds = [ids2labels [id.item()] for id in eval_preds]

        eval_loss = eval_loss/nb_eval_steps
        eval_accuracy = eval_accuracy/nb_eval_steps

        print(f"Validation loss: {eval_loss}")
        print(f"Validation accuracy: {eval_accuracy}")

        return labels, preds

In [None]:
labels, preds = valid(model,test_loader)

In [None]:
from seqeval.metrics import classification_report
print(classification_report(labels, preds))

In [None]:
sentence = """
Tornadus has several other options over Taunt in the third moveslot. Rain Dance and Sunny Day can help support Pokemon like Palafin and Flutter Mane, respectively. Icy Wind, meanwhile, can double down on speed control while dealing some chip damage and is very helpful against opposing Tailwind teams.
"""

In [None]:
inputs = TOKENIZER(sentence.split(),
                   is_pretokinzed=True,
                   return_offsets_mapping=True,
                   padding='max_length',
                   truncation=True,
                   max_length=MAX_LEN,
                   return_tensors='pt')

# move to gpu
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
labels = inputs['labels'].to(device)

# forward pass
outputs = model(input_ids,attention_mask)
logits = outputs[0]

active_logits = logits.view(-1, model.num_labels) # shape(batch_size*seq_len, num_labels)
flattened_preds = torch.argmax(active_logits,axis=1) # shape(batch_size*seq_len)

tokens = TOKENIZER.convert_ids_to_tokens(input_ids.squeeze().tolist())
token_preds = [ids2labels[i] for i in flattened_preds.cpu().numpy()]

wp_preds = list(zip(tokens, token_preds)) # list of tuples. Each tuple = (wordpiece, prediction)

preds = []
for token_pred, mapping in zip(wp_preds,inputs['offset_mapping'].squeeze().tolist()):
    # only predictions on first word pieces are important
    if mapping[0] == 0 and mapping[1] != 0:
        preds.append(token_pred[1])
    else:
        continue

print(sentence.split())
print(preds)