# Model Prediction

In [1]:
from transformers import AutoTokenizer, BertForTokenClassification # Import AutoTokenizer and BertForTokenClassification from the transformers library for NLP tasks.
import torch # Import the PyTorch library for tensor computations and deep learning.
import numpy as np # Import NumPy for numerical operations and array manipulations.
import argparse # Import argparse for parsing command-line arguments.
from typing import List # Import List from the typing module for type annotations.
from config import Config # Import Config class from the config module, used for loading and accessing configuration settings.
# Import utility functions: read_labels (to read label data), get_label_map and get_inv_label_map (for mapping labels to indices and vice versa).
from utils import read_labels, get_label_map, get_inv_label_map
import argparse # Re-import argparse (duplicate import, not necessary).
import sys # Import sys for interacting with the Python interpreter (e.g., command-line arguments, system exit).
import dill

  from pandas.core import (


In [13]:
# Parse the gold-standard test set. Each non-blank line is "<label> <word>";
# a blank (or whitespace-only) line ends the current sentence.
# sentences[i] and labels[i] are parallel lists for sentence i.
sentences = []
labels = []

curr_sentence = []
curr_labels = []

# The file contains Arabic text; decode it explicitly instead of relying on the
# platform's default locale encoding. NOTE(review): assumes the file is UTF-8.
with open("TestingData.txt", "r", encoding="utf-8") as file:
    for line in file:
        fields = line.split()
        if fields:
            label, word = fields[0], fields[1]
            curr_sentence.append(word)
            curr_labels.append(label)
        else:
            # Separator line: close the current sentence (kept even if empty so
            # sentence indices match the original parsing).
            sentences.append(curr_sentence)
            labels.append(curr_labels)
            curr_sentence = []
            curr_labels = []

# The file may not end with a blank line; do not lose the final sentence.
if curr_sentence:
    sentences.append(curr_sentence)
    labels.append(curr_labels)

print("DONE!")

DONE!


In [14]:
class NERPredictor:
    """Named-entity predictor built on a fine-tuned BERT token classifier.

    Builds a ``BertForTokenClassification`` whose classifier head is sized from
    a label file, restores fine-tuned weights from ``model_path`` and turns the
    per-word-piece predictions back into one label per reconstructed word.
    """

    def __init__(self, model_path: str, labels_path: str = 'NewEntities.txt'):
        """Load labels, model weights and tokenizer.

        Args:
            model_path: Path to a ``state_dict`` saved with ``torch.save``.
            labels_path: File passed to ``read_labels`` to build the label set
                (defaults to the file this notebook was originally run with).
        """
        self.cfg = Config()  # Supplies MODEL_NAME (base checkpoint and tokenizer id).

        # Label list plus label->index and index->label mappings.
        self.label_list = read_labels(labels_path)
        self.label_map = get_label_map(self.label_list)
        self.inv_label_map = get_inv_label_map(self.label_list)

        # Base BERT with a token-classification head sized to the label set.
        self.model = BertForTokenClassification.from_pretrained(
            self.cfg.MODEL_NAME,
            return_dict=True,
            num_labels=len(self.label_map),
            output_attentions=False,
            output_hidden_states=False
        )

        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        # Put the model on CPU in inference mode once, rather than on every predict().
        self.model.to('cpu')
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.MODEL_NAME)

    def predict(self, sentences: str) -> List[str]:
        """Predict one label per word of a single sentence.

        Word pieces are glued back into words before labels are collected; each
        word takes the label predicted for its first piece. The result is
        intended to line up with ``sentences.split()``. Merge rules (the piece
        is appended to the previous word):

        * WordPiece continuations starting with ``##``;
        * a ``.`` right after a digit (not just before [SEP]), plus the next piece;
        * ``%`` / ``٪``, except when it is the first real token;
        * ``٬`` / ``٫`` between two digit-bearing pieces, plus the next piece.

        Args:
            sentences: ONE sentence with words joined by spaces (despite the name).

        Returns:
            Predicted label strings, one per reconstructed word.
        """
        # NOTE(review): no truncation is requested, so inputs longer than the
        # model's maximum sequence length will fail inside the model call.
        input_ids = self.tokenizer.encode(sentences, return_tensors='pt')
        with torch.no_grad():  # Inference only: skip autograd bookkeeping.
            output = self.model(input_ids)

        # Highest-scoring label index per word piece (single batch row).
        label_indices = np.argmax(output.logits.to('cpu').numpy(), axis=2)[0]
        piece_ids = input_ids.to('cpu').numpy()[0]
        tokens = self.tokenizer.convert_ids_to_tokens(piece_ids)
        # [CLS]/[SEP] frame the sequence and receive no word label.
        special_ids = {self.tokenizer.cls_token_id, self.tokenizer.sep_token_id}

        new_tokens, new_labels = [], []
        found_decimal = False  # True => the next piece continues the current number.
        prev_token = ""        # Last kept/merged piece; "" until one exists.

        for i, raw in enumerate(tokens):
            token = raw
            label_idx = label_indices[i]

            if found_decimal:
                token = "##" + raw
                found_decimal = False

            # Decimal point right after a digit: glue it and the following piece on.
            if token.startswith(".") and i > 0 and tokens[i - 1][-1].isdigit() and i != len(tokens) - 2:
                found_decimal = True
                token = "##" + raw

            # Percent signs attach to the preceding number, except as the first
            # word (there is no previous word to attach to).
            if token in ("%", "٪") and i != 1:
                token = "##" + raw

            # Arabic thousands/decimal separator between digit-bearing pieces.
            if (token in ("٬", "٫")
                    and i + 1 < len(tokens)
                    and any(ch.isdigit() for ch in prev_token)
                    and any(ch.isdigit() for ch in tokens[i + 1])):
                token = "##" + raw
                found_decimal = True

            if token.startswith("##") and new_tokens:
                new_tokens[-1] += token[2:]
            else:
                if piece_ids[i] in special_ids:
                    continue
                new_labels.append(self.inv_label_map[label_idx])
                new_tokens.append(token)

            prev_token = token

        return new_labels
    
if __name__ == '__main__':
    # Load the fine-tuned checkpoint for CPU inference.
    predictor = NERPredictor(model_path='JuneModel.pt')

    # Predict a label sequence for every test sentence; labelsPredictedArray[i]
    # corresponds to sentences[i] / labels[i].
    labelsPredictedArray = []
    total = len(sentences)
    for i, words in enumerate(sentences):
        labelsPredictedArray.append(predictor.predict(' '.join(words)))
        # Coarse progress report every 100 sentences.
        if i % 100 == 0:
            print(round((i / total) * 100, 2), "%")

    print("100 %")


Some weights of BertForTokenClassification were not initialized from the model checkpoint at aubmindlab/bert-base-arabertv02 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


0.0 %
11.68 %
23.36 %
35.05 %
46.73 %
58.41 %
70.09 %
81.78 %
93.46 %
100 %


In [48]:
# Snapshot the interpreter's global namespace (data, predictions, model, ...)
# to disk with dill, presumably so the prediction loop need not be re-run.
dill.dump_session(filename='fixedBugs_env.db')

In [15]:
from itertools import zip_longest

# Side-by-side view of one sentence: predicted label, word, gold label.
index = 676  # sentence to inspect; change to look at another one
predicted = labelsPredictedArray[index]
words = sentences[index]
gold = labels[index]

print("{:<20} {:<20} {:<20}".format("Predicted (" + str(len(predicted)) + ")", "Word", "True (" + str(len(gold)) + ")"))
print("--------------------------------------------------------")

# zip_longest keeps printing when the sequences differ in length (exactly the
# case worth inspecting) instead of raising IndexError on a short prediction.
for pred_label, word, true_label in zip_longest(predicted, words, gold, fillvalue=""):
    print("{:<20} {:<20} {:<20}".format(pred_label, word, true_label))


Predicted (58)       Word                 True (58)           
--------------------------------------------------------
OUTSIDE              وتقدر                OUTSIDE             
OUTSIDE              إستهلاكات            OUTSIDE             
OUTSIDE              الطاقة               OUTSIDE             
OUTSIDE              المتجددة             OUTSIDE             
OUTSIDE              في                   OUTSIDE             
OUTSIDE              عام                  B-Date              
B-Date               2015                 I-Date              
OUTSIDE              حسب                  OUTSIDE             
OUTSIDE              ترتيب                OUTSIDE             
OUTSIDE              المساهمة             OUTSIDE             
OUTSIDE              حيث                  OUTSIDE             
OUTSIDE              تمثل                 OUTSIDE             
OUTSIDE              الكتلة               OUTSIDE             
OUTSIDE              الحيوية              OUTSIDE            

In [16]:
# Sanity check: each sentence's predicted label sequence must have the same
# length as its gold sequence, otherwise token-level scoring misaligns.
# Prints the index of every mismatched sentence (followed by a blank line),
# then the total count as "Z: <n>".
assert len(labelsPredictedArray) == len(labels), "labelsPredictedArray does not cover every sentence"

z = 0
for i, (gold_seq, pred_seq) in enumerate(zip(labels, labelsPredictedArray)):
    if len(gold_seq) != len(pred_seq):
        z += 1
        print(i)
        print()

print("Z:", z)

Z: 0


In [1]:
import dill

# Restore the global namespace saved by dill.dump_session earlier in this notebook.
# NOTE(review): this unpickles arbitrary objects — only load files you created.
dill.load_session(filename='fixedBugs_env.db')