# Custom NER

In [1]:
# !pip install spacy

In [2]:
import pandas as pd

In [3]:
import plac
import random
import warnings
from pathlib import Path
import spacy
from spacy.util import minibatch, compounding


In [4]:
data = pd.read_csv('./label_text_product_attrs.csv', dtype=str)

In [5]:
data.rename(columns={'sale_price': 'ppu', 'final_price': 'total_price'}, inplace=True)

In [6]:
row = data.iloc[0]
row

image_name                                        img_1_padded_0.jpg
text               04902430779746 NX DOWNY doahoa ngotngao\nVAT10...
sku                                                      4.90243E+12
product_name                               NX DOWNY doahoa ngotngao 
quantity                                                           1
ppu                                                        88,000.00
total_price                                                88,000.00
discounted_part                                                  NaN
original_price                                            103,000.00
Name: 0, dtype: object

In [7]:
print(row['text'])

04902430779746 NX DOWNY doahoa ngotngao\nVAT10%   1    88,000.00    88,000.00 \nGia goc:   103,000.00 


In [8]:
attributes = data.columns.difference(['image_name', 'text'])

In [9]:
attributes

Index(['discounted_part', 'original_price', 'ppu', 'product_name', 'quantity',
       'sku', 'total_price'],
      dtype='object')

In [10]:
a = 'ABC12345ABC12AAABC'
def find_2nd(string, substring):
   return string.find(substring, string.find(substring) + 1)
find_2nd(a, '1')

sub = 'ABC'
import re
[i for i in range(len(a)) if a.startswith(sub, i)]


[0, 8, 15]

In [11]:
a[14:]

'AABC'

In [12]:
import numpy as np

In [13]:
def get_row_entities(row, debug=False):
    row_entities = []
    cache = []
    text = row['text']
    for attr in ['sku', 'product_name', 'quantity', 'ppu', 'total_price', 'discounted_part', 'original_price']:
        if debug: print(attr)
        if not pd.isnull(row[attr]):
            value = str(row[attr])
            
            value = value.strip()
            
            if not value in text:
                print(attr, value, 'not in text')
                            
            else:
                try:
                    indices = [i for i in range(len(text)) if text.startswith(value, i)]
                    if debug: print(f'org indices = {indices}')
                    
                    indices_ = indices.copy()
                    for item in indices_:
                        if any(item in cache_ for cache_ in cache):
                            indices.remove(item)
                    if debug: print(f'indices = {indices}')
                    start_index = indices[0]                   
                    
                    end_index = start_index + len(value)
                    cache.append(range(start_index, end_index))
                    if debug: print(f'cache = {cache}')
                    if start_index < 0:
                        print(start_index)
                    if end_index < 0:
                        print(end_index)
                    row_entities.append((start_index, end_index, attr))
                except Exception as e:
                    print(row)
                    print(text)
                    print('Error', row_entities, value)
                    raise e
                    
    for i, item in enumerate(row_entities):
        if item[-1] == 'product_name':
            product_name_range = item[0:-1]
            product_name_id = i

    if debug: print('Product name ranges', product_name_range)

    splits = row.text[product_name_range[0]: product_name_range[1]].split(' ')

    ranges = []
    for item in splits:
        ranges.append((row.text.find(item), row.text.find(item) + len(item), 'product_name'))

    row_entities.pop(product_name_id)
    row_entities += ranges
    
    def is_overlapped(entities):
        ranges = [range(item[0], item[1]) for item in entities]
        if debug: print('Total ranges', ranges)
        return len(reduce(lambda x, y: set(x).intersection(y), ranges)) > 0
    
    from functools import reduce
    
    assert not is_overlapped(row_entities)
    return row_entities

In [14]:
row = data.iloc[56]

In [16]:
entities = get_row_entities(row, debug=True)
entities

sku
sku 4.90243E+12 not in text
product_name
org indices = [15]
indices = [15]
cache = [range(15, 39)]
quantity
org indices = [9, 44, 50, 55, 69, 93]
indices = [9, 44, 50, 55, 69, 93]
cache = [range(15, 39), range(9, 10)]
ppu
org indices = [55, 69]
indices = [55, 69]
cache = [range(15, 39), range(9, 10), range(55, 65)]
total_price
org indices = [55, 69]
indices = [69]
cache = [range(15, 39), range(9, 10), range(55, 65), range(69, 79)]
discounted_part
original_price
org indices = [93]
indices = [93]
cache = [range(15, 39), range(9, 10), range(55, 65), range(69, 79), range(93, 103)]
Product name ranges (15, 39)
Total ranges [range(9, 10), range(55, 65), range(69, 79), range(93, 103), range(15, 17), range(18, 23), range(24, 30), range(31, 39)]


[(9, 10, 'quantity'),
 (55, 65, 'ppu'),
 (69, 79, 'total_price'),
 (93, 103, 'original_price'),
 (15, 17, 'product_name'),
 (18, 23, 'product_name'),
 (24, 30, 'product_name'),
 (31, 39, 'product_name')]

In [17]:
print(row)
print()
for i, j, name in entities:
    print(f"{name}: {row['text'][i:j]}")

image_name                                       img_16_padded_1.jpg
text               04902430418287 NX DOWNY doahoa ngotngao\nVAT10...
sku                                                      4.90243E+12
product_name                               NX DOWNY doahoa ngotngao 
quantity                                                           1
ppu                                                       129,000.00
total_price                                               129,000.00
discounted_part                                                  NaN
original_price                                            159,000.00
Name: 56, dtype: object

quantity: 1
ppu: 129,000.00
total_price: 129,000.00
original_price: 159,000.00
product_name: NX
product_name: DOWNY
product_name: doahoa
product_name: ngotngao


In [18]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(data, test_size=0.2, random_state=42)


In [19]:
TRAIN_DATA = []
for index, row in train.iterrows():
    TRAIN_DATA.append((row['text'], {"entities": get_row_entities(row)}))

sku 4.90243E+12 not in text
sku 8.9385E+12 not in text
sku 8.9351E+12 not in text
sku 8.93471E+12 not in text
sku 8.93851E+12 not in text
sku 4.90243E+12 not in text
sku 8.93605E+12 not in text
sku 8.9361E+12 not in text
sku 8.93473E+12 not in text
sku 8.93601E+12 not in text
sku 8.93487E+12 not in text
sku 8.93601E+12 not in text
sku 4.90243E+12 not in text
sku 2.00013E+12 not in text
sku 4.90243E+12 not in text
sku 8.85113E+12 not in text
sku 8.85205E+12 not in text
sku 4.90243E+12 not in text
sku 8.69122E+12 not in text
sku 8.9348E+12 not in text
sku 8.93482E+12 not in text
sku 8.93602E+12 not in text
sku 4.90243E+12 not in text
sku 8.93511E+12 not in text
sku 8.93604E+12 not in text
sku 8.93507E+12 not in text
sku 8.93468E+12 not in text
sku 8.93487E+12 not in text
sku 8.93522E+12 not in text
sku 8.93607E+12 not in text
sku 8.93501E+12 not in text
sku 8.93603E+12 not in text
sku 8.9385E+12 not in text
sku 4.90243E+12 not in text
sku 8.85113E+12 not in text
sku 8.93506E+12 not in te

In [19]:
def trim_entity_spans(data: list) -> list:

    """Removes leading and trailing white spaces from entity spans.

    Args:
    data (list): The data to be cleaned in spaCy JSON format.
|
    Returns:
    list: The cleaned data.
    """
    invalid_span_tokens = re.compile(r'\s')

    cleaned_data = []
    for text, annotations in data:
        entities = annotations['entities']
        valid_entities = []
        for start, end, label in entities:
            valid_start = start
            valid_end = end
            # if there's preceding spaces, move the start position to nearest character
            while valid_start < len(text) and invalid_span_tokens.match(
                    text[valid_start]):
                valid_start += 1
            while valid_end > 1 and invalid_span_tokens.match(
                    text[valid_end - 1]):
                valid_end -= 1
            valid_entities.append([valid_start, valid_end, label])
        cleaned_data.append([text, {'entities': valid_entities}])
    return cleaned_data


In [20]:
trim_entity_spans(TRAIN_DATA[:1])

[['04902430779746  NX DOWNY doahoa ngotngao\\nVAT10%   1    38,053.21    38,053.21 \\nSo tien giam tuong ung:   49,946.79 \\nGia goc:   103,000.00 ',
  {'entities': [[0, 14, 'sku'],
    [45, 46, 'quantity'],
    [56, 65, 'ppu'],
    [69, 78, 'total_price'],
    [107, 116, 'discounted_part'],
    [130, 140, 'original_price'],
    [16, 18, 'product_name'],
    [19, 24, 'product_name'],
    [25, 31, 'product_name'],
    [32, 40, 'product_name']]}]]

In [21]:
def train_spacy(TRAIN_DATA):
    nlp = spacy.blank('en')  # create blank Language class
    # create the built-in pipeline components and add them to the pipeline
    # nlp.create_pipe works for built-ins that are registered with spaCy
    if 'ner' not in nlp.pipe_names:
        ner = nlp.create_pipe('ner')
        nlp.add_pipe(ner, last=True)


    # add labels
    for _, annotations in TRAIN_DATA:
         for ent in annotations.get('entities'):
            ner.add_label(ent[2])

    # get names of other pipes to disable them during training
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'ner']
    with nlp.disable_pipes(*other_pipes):  # only train NER
        optimizer = nlp.begin_training()
        for itn in range(10):
            print("Statring iteration " + str(itn))
            random.shuffle(TRAIN_DATA)
            losses = {}
            for text, annotations in TRAIN_DATA:
                nlp.update(
                [text],  # batch of texts
                [annotations],  # batch of annotations
                drop=0.2,  # dropout - make it harder to memorise data
                sgd=optimizer,  # callable to update weights
                losses=losses)
            print(losses)
    return nlp


In [22]:
def main(model=None, output_dir='.', n_iter=100):
    """Load the model, set up the pipeline and train the entity recognizer."""
    if model is not None:
        nlp = spacy.load(model)  # load existing spaCy model
        print("Loaded model '%s'" % model)
    else:
        nlp = spacy.blank("en")  # create blank Language class
        print("Created blank 'en' model")

    # create the built-in pipeline components and add them to the pipeline
    # nlp.create_pipe works for built-ins that are registered with spaCy
    if "ner" not in nlp.pipe_names:
        ner = nlp.create_pipe("ner")
        nlp.add_pipe(ner, last=True)
    # otherwise, get it so we can add labels
    else:
        ner = nlp.get_pipe("ner")

    # add labels
    for _, annotations in TRAIN_DATA:
        for ent in annotations.get("entities"):
            ner.add_label(ent[2])

    # get names of other pipes to disable them during training
    pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"]
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
    # only train NER
    with nlp.disable_pipes(*other_pipes), warnings.catch_warnings():
        try:
            # show warnings for misaligned entity spans once
            warnings.filterwarnings("once", category=UserWarning, module='spacy')

            # reset and initialize the weights randomly – but only if we're
            # training a new model
            if model is None:
                nlp.begin_training()
            for itn in range(n_iter):
                random.shuffle(TRAIN_DATA)
                losses = {}
                # batch up the examples using spaCy's minibatch
                batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
                for batch in batches:
                    texts, annotations = zip(*batch)
                    nlp.update(
                        texts,  # batch of texts
                        annotations,  # batch of annotations
                        drop=0.5,  # dropout - make it harder to memorise data
                        losses=losses,
                    )
                print("Losses", losses)
        except Exception as e:
            print(texts, annotations)
            raise e
#     # test the trained model
#     for text, _ in TRAIN_DATA:
#         doc = nlp(text)
#         print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
#         print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])

    # save model to output directory
    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            output_dir.mkdir()
        nlp.to_disk(output_dir)
        print("Saved model to", output_dir)



In [23]:
nlp = spacy.load('en_core_web_sm')

docs = []
for text, annot in TRAIN_DATA:

    doc = nlp(text)

    tags = spacy.gold.biluo_tags_from_offsets(doc, annot['entities'])
    
    print(np.array(doc))
    print(np.array(tags))
    break

OSError: [E050] Can't find model 'en_core_web_sm'. It doesn't seem to be a shortcut link, a Python package or a valid path to a data directory.

In [48]:
print(TRAIN_DATA[0][0][46:47])
TRAIN_DATA[0]

0


('04902430779746  NX DOWNY doahoa ngotngao\\nVAT10%   1    38,053.21    38,053.21 \\nSo tien giam tuong ung:   49,946.79 \\nGia goc:   103,000.00 ',
 {'entities': [(0, 14, 'sku'),
   (45, 46, 'quantity'),
   (56, 65, 'ppu'),
   (69, 78, 'total_price'),
   (107, 116, 'discounted_part'),
   (130, 140, 'original_price'),
   (16, 18, 'product_name'),
   (19, 24, 'product_name'),
   (25, 31, 'product_name'),
   (32, 40, 'product_name')]})

In [22]:
# nlp = train_spacy(trim_entity_spans(TRAIN_DATA))

In [23]:
for i, row in enumerate(TRAIN_DATA):
    if row[0].startswith('08936034200116'):
        print(i)

239


In [24]:
main()

Created blank 'en' model


  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)
  gold = GoldParse(doc, **gold)


('00000002954821 S-ca basa nc bo dau tam u\\nVAT10%    70,000.00    34,440.00 \\n0.492 KG', '04902430415637 NX DOWNY doahoa thom mat\\nVAT10%   2    129,900.00    259,800.00 \\nGia goc:   159,000.00 ', '04902430647250 BG Tide huong Downy 3.8kg\\nVAT10%   1    127,000.00    127,000.00 \\nGia goc:   153,000.00 ', '00000002954486  Pate heo kg-BMQ\\nVAT10%    187,900.00    37,580.00 \\n0.200 KG') ({'entities': [(0, 14, 'sku'), (77, 85, 'quantity'), (52, 61, 'ppu'), (65, 74, 'total_price'), (15, 19, 'product_name'), (20, 24, 'product_name'), (25, 27, 'product_name'), (28, 30, 'product_name'), (31, 34, 'product_name'), (35, 38, 'product_name'), (33, 34, 'product_name')]}, {'entities': [(0, 14, 'sku'), (50, 51, 'quantity'), (55, 65, 'ppu'), (69, 79, 'total_price'), (93, 103, 'original_price'), (15, 17, 'product_name'), (18, 23, 'product_name'), (24, 30, 'product_name'), (31, 35, 'product_name'), (36, 39, 'product_name')]}, {'entities': [(0, 14, 'sku'), (45, 46, 'quantity'), (56, 66, 'ppu'), (

ValueError: [E103] Trying to set conflicting doc.ents: '(31, 34, 'product_name')' and '(33, 34, 'product_name')'. A token can only be part of one entity, so make sure the entities you're setting don't overlap.

In [20]:
TEST_DATA = []
for index, row in test.iterrows():
    TEST_DATA.append((row['text'], {"entities": get_row_entities(row)}))

sku 4.90243E+12 not in text
sku 8.93467E+12 not in text
sku 4.90243E+12 not in text
sku 4.90243E+12 not in text
sku 8.93604E+12 not in text
sku 8.9351E+12 not in text
sku 8.93607E+12 not in text
sku 4.90243E+12 not in text
sku 8.93614E+12 not in text
sku 8.9348E+12 not in text
sku 8.93601E+12 not in text
sku 8.93466E+12 not in text
sku 8.93851E+12 not in text
sku 8.93604E+12 not in text
sku 4.90243E+12 not in text
sku 8.93511E+12 not in text
sku 4.90243E+12 not in text
sku 8.93511E+12 not in text
sku 8.9385E+12 not in text
sku 4.90243E+12 not in text
sku 8.93604E+12 not in text
sku 4.90243E+12 not in text
sku 8.93603E+12 not in text
sku 4.90243E+12 not in text
sku 8.93487E+12 not in text
sku 4.90243E+12 not in text
sku 8.93458E+12 not in text
sku 8.93487E+12 not in text
sku 8.93505E+12 not in text
sku 4.90243E+12 not in text
sku 8.93606E+12 not in text
sku 8.9348E+12 not in text
sku 4.90243E+12 not in text
sku 8.93601E+12 not in text
sku 4.90243E+12 not in text
sku 8.93502E+12 not in t

In [24]:
text = test.iloc[0].text

In [26]:
text

'04902430418287 NX DOWNY doahoa ngotngao\\nVAT10%   2    129,900.00    259,800.00 \\nGia goc:   159,000.00 '

In [25]:
output_dir = '../pretrained_models/NER'

print("Loading from", output_dir)
nlp2 = spacy.load(output_dir)
# for text, _ in TEST_DATA:
doc = nlp2(text)
print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])


Loading from ../pretrained_models/NER
Entities [('04902430418287', 'sku'), ('NX', 'product_name'), ('DOWNY', 'product_name'), ('doahoa', 'product_name'), ('2', 'quantity'), ('129,900.00', 'ppu'), ('259,800.00', 'total_price'), ('159,000.00', 'original_price')]
Tokens [('04902430418287', 'sku', 3), ('NX', 'product_name', 3), ('DOWNY', 'product_name', 3), ('doahoa', 'product_name', 3), ('ngotngao\\nVAT10', '', 2), ('%', '', 2), ('  ', '', 2), ('2', 'quantity', 3), ('   ', '', 2), ('129,900.00', 'ppu', 3), ('   ', '', 2), ('259,800.00', 'total_price', 3), ('\\nGia', '', 2), ('goc', '', 2), (':', '', 2), ('  ', '', 2), ('159,000.00', 'original_price', 3)]
