In [1]:
import os
import pandas as pd 
import glob
import numpy as np
import torch
from torch.utils.data import DataLoader
import transformers
import torch.nn as nn
from tqdm import tqdm
from prepare_invoice_ner_dataset import label_idx_dict
from prepare_invoice_ner_dataset import split_tokenize_label_dataset, split_tokenize_label_file, form_input
from print_ner_tag import print_ner_labels, print_ner_labels_detokenized


In [2]:
test_split_path = './split/test.txt'

from transformers import AutoTokenizer, BertTokenizer
bert_path = '../kaggle_ner/huggingface-bert/bert-base-uncased/'
#model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(bert_path, do_lower_case=True)

In [3]:
config = {'MAX_LEN':128,
          'tokenizer': tokenizer,
          'batch_size':32,
          'Epoch': 1,
          'device': 'cuda' if torch.cuda.is_available() else 'cpu',
          'model_name':'model1_bert_base_uncased_3_epochs.bin'
         }

In [4]:
model = transformers.BertForTokenClassification.from_pretrained('bert-base-uncased',  num_labels = len(label_idx_dict))
model = nn.DataParallel(model)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-u

In [5]:
#model.load_state_dict(torch.load(config['model_name']))
#model = model.to(config['device'])


In [6]:
mode = model.eval()

In [7]:
final_test_split_and_tokenized_file_list, final_test_split_and_tokenized_labels, final_test_split_word_id_list, final_test_split_token_ids_list = split_tokenize_label_dataset(test_split_path, tokenizer, overlap=0)

test_prod_input = form_input(final_test_split_and_tokenized_file_list, final_test_split_and_tokenized_labels, 
                              final_test_split_token_ids_list, final_test_split_word_id_list, config, data_type='test')


In [8]:
invoice_number = 1

if_estimation = True
estimation_list = []
gt_list = []
print("Sample Invoice {}".format(invoice_number))
print()
for sample_idx in range(len(test_prod_input)):
    #print("Estimation")
    
    
    #print("sample_idx: {}".format(sample_idx))
    dataset = test_prod_input[sample_idx]

    batch_input_ids = torch.unsqueeze(dataset['ids'], 0).to(config['device'], dtype = torch.long)
    batch_att_mask = torch.unsqueeze(dataset['att_mask'], 0).to(config['device'], dtype = torch.long)
    batch_tok_type_id = torch.unsqueeze(dataset['tok_type_id'], 0).to(config['device'], dtype = torch.long)
    batch_target = torch.unsqueeze(dataset['target'], 0).to(config['device'], dtype = torch.long)
    
    bert_tokens = dataset['bert_tokens']
    bert_token_word_ids = dataset['bert_token_word_ids']
    
    output = model(batch_input_ids, 
                       token_type_ids=None,
                       attention_mask=batch_att_mask,
                       #labels=batch_target
              )
    sample_ner_result = output['logits'].detach().numpy()[0]
    sample_estimated_tags = np.argmax(sample_ner_result, -1)
    sample_tag_labels = test_prod_input[sample_idx]['target'].numpy()   
    
    sample_estimated_tag_labels = np.array([list(label_idx_dict.keys())[label_idx] for label_idx in list(sample_estimated_tags)])
    sample_ground_truth_tag_labels = np.array([list(label_idx_dict.keys())[label_idx] for label_idx in list(sample_tag_labels)])
        
    estimation_list.append((bert_tokens, sample_estimated_tag_labels[:len(bert_tokens)], bert_token_word_ids))
    gt_list.append((bert_tokens, sample_ground_truth_tag_labels[:len(bert_tokens)], bert_token_word_ids))
    
    if(len(bert_tokens)!=128):     
        print("Estimation")
        for estimation_tuple in estimation_list:
            print_ner_labels_detokenized(estimation_tuple[0], estimation_tuple[1], estimation_tuple[2])
        print()
        print()
        print("Ground truth")
        for gt_tuple in gt_list:
            print_ner_labels_detokenized(gt_tuple[0], gt_tuple[1], gt_tuple[2])
        print()
        print()
        print()
        
        gt_list = []
        estimation_list = []
        invoice_number += 1
        

        if(invoice_number==6):
            break
            
        print("Sample Invoice {}".format(invoice_number))

Sample Invoice 1

Estimation
[31mtax[0m [35minvoice[0m [35minvoice[0m [44mno[0m [44m.[0m [44msm/18[0m [33m-[0m [30m19/01159[0m [44mdelivery[0m [44mnote[0m [44msm/18[0m [44m-[0m [30m19/01159[0m [42msupplier[0m [44m's[0m [43mref[0m [44m.[0m [42m([0m [44mk[0m [44m)[0m [44mdated[0m [30m17-may-18[0m [44mmode[0m [44m/[0m [35mterms[0m [31mof[0m [44mpayment[0m [44mother[0m [44mreference(s[0m [44m)[0m [44mshall[0m [31mmarketing[0m [44mshop[0m [44mno.6[0m [30mllyods[0m [44mchamber[0m [44mbarne[0m [44mroad[0m [30m409[0m [44mmangalwar[0m [44mpeth[0m [44mpune[0m [44mgstin[0m [44m/[0m [44muin[0m [44m:[0m [30m27adxps3921j1zt[0m [30mstate[0m [31mname[0m [34m:[0m [30mmaharashtra[0m [32mcode[0m [34m:[0m [42m27[0m [44mcontact[0m [41m:[0m [44m20[0m [33m-[0m [32m-[0m [30m26052747[0m [31me[0m [41m-[0m [41mmail[0m [45m:[0m [31mshahmarketing4@gmail.com[0m [44mblyer[0m [35msiddha

Estimation
[31mgst[0m [35minvoice[0m [42m([0m [44moriginal[0m [36mfor[0m [44mrecipient[0m [41m)[0m [44mm[0m [35minvoice[0m [44mno[0m [41m.[0m [41me[0m [41m-[0m [41mway[0m [44mbill[0m [44mno[0m [41m.[0m [44msm/20[0m [30m-[0m [30m21/01746[0m [30m2.5122e+11[0m [44mdelivery[0m [44mnote[0m [44mdated[0m [30m07-sep-20[0m [44mmode[0m [44m/[0m [31mterms[0m [44mof[0m [44mpayment[0m [31mshah[0m [31mmarketing[0m [45m([0m [31mmangalwar[0m [34mpeth[0m [41m)[0m [41m([0m [42mbath[0m [41mdivision[0m [41m)[0m [44mllyods[0m [44mchambers[0m [44mshop[0m [31mno.5[0m [30m&[0m [30m6[0m [30m409[0m [44mmangalwar[0m [43mpeth[0m [35mbarne[0m [44mroad[0m [45mpune[0m [44mgstin[0m [44m/[0m [44muin[0m [44m:[0m [30m27adxps3921j1zt[0m [31mstate[0m [31mname[0m [34m:[0m [34mmaharashtra[0m [44mcode[0m [34m:[0m [42m27[0m [31mcontact[0m [44m:[0m [44m20[0m [44m-[0m [42m26052647[0m [44m/[

Estimation
[31mgst[0m [35minvoice[0m [42m([0m [44moriginal[0m [44mfor[0m [44mrecipient[0m [41m)[0m [35minvoice[0m [31mno[0m [41m.[0m [44msm/19[0m [30m-[0m [30m20/00500[0m [44mdelivery[0m [44mnote[0m [44me[0m [41m-[0m [41mway[0m [44mbill[0m [44mno[0m [44m.[0m [44mdated[0m [30m22-apr-19[0m [44mmode[0m [44m/[0m [31mterms[0m [31mof[0m [44mpayment[0m [44mother[0m [44mreference(s[0m [44m)[0m [31mshah[0m [44mmarketing[0m [42m([0m [31mmangalwar[0m [34mpeth[0m [44m)[0m [44mshop[0m [44mno.6[0m [44mllyods[0m [44mchamber[0m [44mbarne[0m [44mroad[0m [30m409[0m [44mmangalwar[0m [44mpeth[0m [45mpune[0m [44mgstin[0m [44m/[0m [44muin[0m [44m:[0m [30m27adxps3921j1zt[0m [30mstate[0m [31mname[0m [34m:[0m [30mmaharashtra[0m [31mcode[0m [34m:[0m [42m27[0m [44mcontact[0m [41m:[0m [30m20[0m [33m-[0m [30m-[0m [30m26052747[0m [31me[0m [41m-[0m [42mmail[0m [45m:[0m [31mshahma

Estimation
[30maviva[0m [31mdoors[0m [31mfactory[0m [44maddress:-[0m [43mgat[0m [31mno.255b[0m [44mplot[0m [44mno.26/27[0m [31mjyotiba[0m [44mnagar[0m [31mtalawade[0m [31mpune-411062[0m [31mregistered[0m [44moffice[0m [34m:[0m [41m-[0m [33me-83[0m [44madinath[0m [44msociety[0m [35mpune[0m [45m-[0m [35msatara[0m [45mroad[0m [35mpune[0m [44m-[0m [35m411037[0m [44mcontact[0m [41m:[0m [41m-[0m [36m20[0m [41m-[0m [30m24269505[0m [44mmob:-[0m [41m91[0m [30m8446380632[0m [30m-[0m [44memail[0m [44m:[0m [41m-[0m [30mavivadoor@gmail.com[0m [42mtax[0m [35minvoice[0m [44moriginal[0m [35minvoice[0m [44mno[0m [44m.[0m [44m:[0m [30mti-356/19[0m [30m-[0m [36m20[0m [44mbuyer[0m [44mparty[0m [44mm[0m [44m/[0m [36ms[0m [44m.[0m [42m:[0m [31msiddharth[0m [31mproperties[0m [36m.[0m [44meden[0m [44mcourt[0m [44mmodel[0m [31mcolony[0m [31mpune[0m [42mdate[0m [32m:[0m [33m13-0

Estimation
[30maviva[0m [31mdoors[0m [31mfactory[0m [44maddress:-[0m [43mgat[0m [31mno.255b[0m [44mplot[0m [44mno.26/27[0m [31mjyotiba[0m [44mnagar[0m [31mtalawade[0m [31mpune-411062[0m [43mistered[0m [44moffice[0m [41m:[0m [41m-[0m [33me-83[0m [44madinath[0m [44msociety[0m [42mpune[0m [45m-[0m [45msatara[0m [44mroad[0m [35mpune[0m [44m-[0m [35m411037[0m [44mcontact[0m [41m:[0m [41m-[0m [36m20[0m [41m-[0m [30m24269505[0m [41mmob:-[0m [41m91[0m [30m8446380632[0m [30m-[0m [44memail[0m [44m:[0m [41m-[0m [30mavivadoor@gmail.com[0m [31mtax[0m [44minvoice[0m [44moriginal[0m [44mbuye[0m [44minvoice[0m [31mno[0m [44m.[0m [30m:[0m [30mti-500/18[0m [41m-[0m [31m19[0m [31mm[0m [41m/[0m [31ms[0m [44m.[0m [44mdate[0m [30m:[0m [33m25-12-2018[0m [41mdharth[0m [44mreal[0m [43mventures[0m [44mllp[0m [44m.[0m [44mce[0m [44mno-60[0m [31m1[0m [31m&[0m [30m602[0m [36mmccl