In [1]:
import torch
import numpy as np
import os
import gc
import pickle
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim
from sklearn.model_selection import train_test_split
import torchtext.data
import spacy
from torchtext.data import Field, RawField, BucketIterator, Example, Dataset
import matplotlib.pyplot as plt
from torchmetrics.text import CharErrorRate

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x1094d1ad0>

In [3]:
#filepaths
clean_img_path = 'clean_img'
clean_txt_path = 'clean_txt'
clean_entity_path = 'clean_entity'

In [4]:
#collect all records into a list
records = []
clean_imgs = torch.load(os.path.join(clean_img_path, 'clean.pt'))

for i, name in enumerate(os.listdir(clean_txt_path)):
    name, __ = os.path.splitext(name)
    img = clean_imgs[i]
    txt_filepath = os.path.join(clean_txt_path, name + '.txt')
    txt_file = open(txt_filepath, 'rb')
    text = txt_file.read().decode('latin1')
    txt_file.close()
    ent_filepath = os.path.join(clean_entity_path, name + '.pkl')
    #print(ent_filepath)
    #ent_file = open(ent_filepath, 'rb')
    try:
        with open(ent_filepath, 'rb') as my_file:
            unpickler = pickle.Unpickler(my_file)
            entities = unpickler.load()
    except EOFError:
        print(ent_filepath)
        print('An EOFError exception occurred. The file is empty')
    #entities = pickle.load(ent_file)
    #record = [img, text, entities]
    record = [text, entities]
    records.append(record)

In [5]:
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15
train_data, temp_data = train_test_split(records, test_size=(1 - train_ratio))
val_data, test_data = train_test_split(temp_data, test_size=test_ratio / (val_ratio + test_ratio))
print("Training set len:", len(train_data))
print("Validation set len:", len(val_data))
print("Test set len:", len(test_data))

Training set len: 681
Validation set len: 146
Test set len: 146


In [6]:
train_data[0]

['POTTERS GARDEN SDN BHD (1153774-D) BATU 11 . SG BULOH , 47000 SELANGOR . TEL : 016-667 0982 , 016-333 3812 GST REG NO : 000392024064 TAX INVOICE TABLE : 01 POS ID : G INV NO : G/00002268 CASHIER: ADMINISTRATOR INV DT: 04/01/2018 12:29:18 PM RM ITEM AMOUNT AA00007 PLASTIC PLATE AA00022 CHINA POT 2 @7.90 AA00016 PLANT AA00016 PLANT 5 SUB TOTAL 1.00 15.80 16.00 7.00 39.80 SR SR SR SR NET TOTAL CASH CHANGE 39.80 40.00 0.20 TAX SUMMARY SR INCLUSIVE GST 6% AMOUNT 37.54 TAX 2.26 *** THANK YOU PLEASE COME AGAIN *** ITEM SOLD ARE NOT REFUNDABLE & EXCHANGEABLE',
 ['POTTERS GARDEN SDN BHD',
  '04/01/2018',
  'BATU 11 . SG BULOH , 47000 SELANGOR .',
  '39.80']]

In [7]:
char_vocab = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
              'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
              '`', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '-', '=', '~', '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '_', '+',
              '[', ']', '\\', ';', '\'', ',', '.', '?', '{', '}', '|', ':', '\"', '<', '>', '?', '/', 'Â', '·', 'Ã', '\x83', '\x82']

def index_to_char(i):
    return char_vocab[i]

def char_to_index(c):
    return char_vocab.index(c)

In [8]:
def initial_space_count(s):
    ans = 0
    for c in s:
        if c == ' ':
            ans += 1
        else:
            break
    return ans

def get_index_ignoring_spaces(main_string, substring):
    # Remove spaces from both main string and substring for comparison
    main_string_no_spaces = main_string.replace(" ", "")
    substring_no_spaces = substring.replace(" ", "")

    # Get the index where the substring (ignoring spaces) occurs in the modified main string
    index = main_string_no_spaces.find(substring_no_spaces)
    end_index = index + len(substring_no_spaces)

    #print("index: ", index)
    #print("end index: ", end_index)

    if index != -1:
        # Calculate the adjusted index considering spaces in the original string
        non_space_count = 0
        adjusted_index = 0
        adjusted_end_index = 0
        for i in range(len(main_string)):
            if non_space_count < index:
                adjusted_index += 1;
            else:
                break
            if main_string[i] != ' ':
                non_space_count += 1;
        non_space_count = 0
        for i in range(len(main_string)):
            if non_space_count < end_index:
                adjusted_end_index += 1;
            else:
                break
            if main_string[i] != ' ':
                non_space_count += 1;
        adjusted_index += initial_space_count(main_string[adjusted_index:])
        return adjusted_index, adjusted_end_index
    else:
        return -1, -1

In [9]:
#probabilities of [start name, end name, start date, end date, start address, end address, start total, end total]
#return tensor of ground truth probabilities; end index is inclusive
def get_ground_probs(text, ents, length):
    name_start = torch.zeros(length)
    name_end = torch.zeros(length)
    date_start = torch.zeros(length)
    date_end = torch.zeros(length)
    address_start = torch.zeros(length)
    address_end = torch.zeros(length)
    total_start = torch.zeros(length)
    total_end = torch.zeros(length)
    name_start_index, name_end_index = get_index_ignoring_spaces(text, ents[0])
    name_end_index -= 1
    date_start_index, date_end_index = get_index_ignoring_spaces(text, ents[1])
    date_end_index -= 1
    address_start_index, address_end_index = get_index_ignoring_spaces(text, ents[2])
    address_end_index -= 1
    total_start_index, total_end_index = get_index_ignoring_spaces(text, ents[3])
    total_end_index -= 1
    name_start[name_start_index] = 1
    name_end[name_end_index] = 1
    date_start[date_start_index] = 1
    date_end[date_end_index] = 1
    address_start[address_start_index] = 1
    if text.find(ents[2]) == -1:
        print("text: ", text)
        print("addr: ", ents[2])
    address_end[address_end_index] = 1
    total_start[total_start_index] = 1
    total_end[total_end_index] = 1
    return torch.stack([name_start, name_end, date_start, date_end, address_start, address_end, total_start, total_end])

def get_batch_ground_probs(batch_text, batch_ents): #[batch, 8*char]
    max_len = len(batch_text[-1])
    ans = []
    for text, ents in zip(batch_text, batch_ents):
        assert(len(text) > 0)
        ans.append(get_ground_probs(text, ents, max_len))
    return torch.stack(ans).view(len(batch_text), -1)

In [10]:
def text_to_onehot(text, length):
    ident = torch.eye(len(char_vocab))
    ans = []
    for char in text:
        ans.append(ident[char_to_index(char)])
    for i in range(length - len(text)):
        ans.append(ident[char_to_index(' ')])
    assert(len(ans) == length)
    return torch.stack(ans)

def get_batch_numerical(batch_data):
    #for data in batch_data:
    #    assert(len(data[0]) > 0)
    batch_text = [data[0] for data in batch_data]
    batch_ents = [data[1] for data in batch_data]
    max_len = len(batch_text[-1])
    batch_enc_text = torch.stack([text_to_onehot(text, max_len) for text in batch_text])
    batch_probs = get_batch_ground_probs(batch_text, batch_ents)
    return batch_enc_text, batch_probs

def get_data_loader(data, batch_size):
    #for d in data:
    #    assert(len(d[0]) > 0)
    sorted_data = sorted(data, key=lambda x: len(x[0]))
    ans = []
    for i in range(0, len(sorted_data), batch_size):
        batch_data = sorted_data[i:i+batch_size]
        batch_text = [d[0] for d in batch_data]
        batch_ents = [d[1] for d in batch_data]
        batch_enc_text, ground_probs = get_batch_numerical(batch_data)
        ans.append((batch_enc_text, ground_probs, batch_text, batch_ents))
    return ans

In [27]:
class EXT(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EXT, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 8)

    def forward(self, x):
        #print(x.shape)
        assert(torch.is_tensor(x))
        h0 = torch.zeros(1, len(x), self.hidden_size)  # Initial hidden state
        out, _ = self.rnn(x, h0) #[batch index, char index, vocab index]
        out = self.fc(out) #[batch index, char index, 8]
        out = torch.transpose(out, 1, 2) #[batch index, 8, char index]
        out = F.softmax(out, dim=2).view(out.size(0), -1) #[batch, 8*char]
        return out #The output_probs tensor contains probabilities for each class at each position in the sequence

In [28]:
class TransformerEncoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(TransformerEncoder, self).__init__()
        self.linear_q = nn.Linear(input_size, hidden_size)
        self.linear_k = nn.Linear(input_size, hidden_size)
        self.linear_v = nn.Linear(input_size, hidden_size)
        self.linear_x = nn.Linear(input_size, hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        q, k, v = self.linear_q(x), self.linear_k(x), self.linear_v(x)
        temp = self.attention(q, k, v)
        x = self.norm(self.linear_x(x) + temp[0])
        x = self.norm(x + self.fc(x))
        return x

class EXT2(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EXT2, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.transformer = TransformerEncoder(input_size, hidden_size)  # Include TransformerEncoder here
        self.fc = nn.Linear(hidden_size, 8)

    def forward(self, x):
        h0 = torch.zeros(1, len(x), self.hidden_size, device=x.device)  # Initial hidden state
        out_rnn, _ = self.rnn(x, h0)  # GRU output
        out_transformer = self.transformer(x)  # Transformer output
        out_combined = out_rnn + out_transformer  # Combine outputs
        out = self.fc(out_combined)  # Apply linear layer
        out = torch.transpose(out, 1, 2) #[batch index, 8, char index]
        out = F.softmax(out, dim=2).view(out.size(0), -1) #[batch, 8*char]
        return out

In [33]:
cer = CharErrorRate()

def car(reference, hypothesis):
    return 1 - cer(reference, hypothesis).item()

In [14]:
def get_model_name(name, batch_size, learning_rate, epoch):
    return "model_{}_bs{}_lr{}_epoch{}".format(name, batch_size, learning_rate, epoch)

#based on CER
def get_accuracy(model, data_itr):
    total = 0
    total_car = 0 #sum of char acc rate over all samples
    total_ear = 0 #sum of entity acc rate over all samples
    total_corr = 0 #sum of sample acc rate over all samples
    for batch_enc_text, ground_probs, batch_text, batch_ents in data_itr:
        batch_output = model(batch_enc_text).view(len(batch_enc_text), 8, -1) #[batch, 8, char]
        batch_output = torch.argmax(batch_output, dim=2) #[batch, 8]

        pred_names = [text[output[0]:output[1]+1] for text, output in zip(batch_text, batch_output)] #+1 since string slicing is exclusive
        pred_dates = [text[output[2]:output[3]+1] for text, output in zip(batch_text, batch_output)]
        pred_addresses = [text[output[4]:output[5]+1] for text, output in zip(batch_text, batch_output)]
        pred_totals = [text[output[6]:output[7]+1] for text, output in zip(batch_text, batch_output)]

        total += len(batch_text)
        total_car += sum([car(ents[0]+ents[1]+ents[2]+ents[3], pred_name+pred_date+pred_address+pred_total)
                          for ents, pred_name, pred_date, pred_address, pred_total in
                          zip(batch_ents, pred_names, pred_dates, pred_addresses, pred_totals)])
        total_ear += sum([(int(ents[0]==pred_name) + int(ents[1]==pred_date) + int(ents[2]==pred_address) + int(ents[3]==pred_total))/4.0
                          for ents, pred_name, pred_date, pred_address, pred_total in
                          zip(batch_ents, pred_names, pred_dates, pred_addresses, pred_totals)
        ])
        total_corr += sum([int(ents[0]+ents[1]+ents[2]+ents[3] == pred_name+pred_date+pred_address+pred_total)
                          for ents, pred_name, pred_date, pred_address, pred_total in
                          zip(batch_ents, pred_names, pred_dates, pred_addresses, pred_totals)
        ])
    char_acc = total_car / total
    ent_acc = total_ear / total
    sample_acc = total_corr / total
    return char_acc, ent_acc, sample_acc

In [35]:
def get_pred(model, text):
    enc_text = text_to_onehot(text, len(text)).unsqueeze(0)
    output = model(enc_text).view(8, -1)
    output = torch.argmax(output, dim=1).squeeze(0) #[8]

    pred_name = text[output[0]:output[1]+1]
    pred_date = text[output[2]:output[3]+1]
    pred_address = text[output[4]:output[5]+1]
    pred_total = text[output[6]:output[7]+1]
    return pred_name, pred_date, pred_address, pred_total

In [16]:
def get_loss(model, data_itr, criterion):
    loss = 0
    total = 0
    with torch.no_grad():
        for batch_enc_text, ground_probs, batch_text, batch_ents in data_itr:
            batch_output = model(batch_enc_text)
            loss += criterion(batch_output, ground_probs)
            total += batch_output.size(0)
    return loss / total

In [17]:
def train_model(model, model_name, train_set, val_set, batch_size=32, num_epochs=5, learning_rate=1e-5):
    # DataLoaders for train and validation sets
    # Create BucketIterator
    train_itr = get_data_loader(train_set, batch_size)
    val_itr = get_data_loader(val_set, batch_size)
    print("Created data loaders")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    train_loss, val_loss, train_char_acc, val_char_acc, train_ent_acc, val_ent_acc, train_sample_acc, val_sample_acc = [], [], [], [], [], [], [], []
    epochs = []
    iterations = 0

    for epoch in range(num_epochs):
        for i, (batch_enc_text, ground_probs, __, __) in enumerate(train_itr):
            optimizer.zero_grad()
            pred = model(batch_enc_text.to(torch.float32))
            loss = criterion(pred, ground_probs)
            loss.backward()
            optimizer.step()

            iterations += 1
            if iterations % 20 == 0:
                print("iterations: ", iterations)

        train_loss.append(get_loss(model, train_itr, criterion))
        val_loss.append(get_loss(model, val_itr, criterion))

        epochs.append(epoch)
        tca, tea, tsa = get_accuracy(model, train_itr)
        vca, vea, vsa = get_accuracy(model, val_itr)
        train_char_acc.append(tca)
        train_ent_acc.append(tea)
        train_sample_acc.append(tsa)
        val_char_acc.append(vca)
        val_ent_acc.append(vea)
        val_sample_acc.append(vsa)
        torch.save(model.state_dict(), get_model_name(model_name, batch_size, learning_rate, epoch))
        print("Epoch %d; Train Loss %f; Val Loss %f Train Char Acc %f; Val Char Acc %f Train Ent Acc %f; Val Ent Acc %f Train Sample Acc %f; Val Sample Acc %f" % (
            epoch+1, train_loss[-1], val_loss[-1], train_char_acc[-1], val_char_acc[-1], train_ent_acc[-1], val_ent_acc[-1], train_sample_acc[-1], val_sample_acc[-1]))

    #save losses/accs
    stats = [train_loss, val_loss, train_char_acc, val_char_acc, train_ent_acc, val_ent_acc, train_sample_acc, val_sample_acc]
    stats_file = open(model_name + str(epoch) + '.pkl', 'wb')
    pickle.dump(stats, stats_file)
    stats_file.close()

    # plotting
    plt.title("Training Loss Curve")
    plt.plot(epochs, train_loss, label="Train")
    plt.plot(epochs, val_loss, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.title("Training Char Acc Curve")
    plt.plot(epochs, train_char_acc, label="Train")
    plt.plot(epochs, val_char_acc, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc='best')
    plt.show()

    plt.title("Training Ent Acc Curve")
    plt.plot(epochs, train_ent_acc, label="Train")
    plt.plot(epochs, val_ent_acc, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc='best')
    plt.show()

    plt.title("Training Sample Acc Curve")
    plt.plot(epochs, train_sample_acc, label="Train")
    plt.plot(epochs, val_sample_acc, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc='best')
    plt.show()

In [18]:
#hilbert = EXT(len(char_vocab), 2*len(char_vocab))
#train_model(hilbert, 'hilbert', train_data, val_data, 64, 20, 1e-2)

In [19]:
thomas = EXT2(len(char_vocab), 8*len(char_vocab))

train_model(thomas, 'thomas', train_data, val_data, 64, 40, 2e-4)

KeyboardInterrupt: 

In [29]:
hannah = EXT(len(char_vocab), 4*len(char_vocab))
hannah.load_state_dict(torch.load('model_hannah_bs64_lr0.015_epoch39'))

<All keys matched successfully>

In [30]:
test_data[0][0]

'AMANO MALAYSIA SDN BHD (682288-V) 12 JALAN PENGACARA U1/48 TEMASYA INDUSTRIAL PARK 40150 SHAH ALAM SELANGOR TEL: 03-55695002/5003 (GST ID: 001137704960) TAX INVOICE P/S #02 RM3.00 A INV-NO.0002300417000138 T/D #11 TICKET NO.029190 ENTRY TIME PAID TIME PARKING TIME 30/04/2017 (SUN) 19:44 30/04/2017 (SUN) 23:14 RM3.00 RM0.00 THANK YOU INCLUSIVE 6% GST 3:30 TYPE RATE A RM3.00 PARKING FEE GST(INCLUDED) 6.00 % RM2.83 RM0.17 TOTAL PAID CHANGE'

In [52]:
pred_name, pred_date, pred_address, pred_total = get_pred(hannah, test_data[0][0])
print("predicted entities: ")
print("name: ", pred_name)
print("date: ", pred_date)
#print("address: ", pred_address)
print("total: ", pred_total)
print("actual entities: ")
print("name: ", test_data[0][1][0])
print("date: ", test_data[0][1][1])
#print("address: ", test_data[0][1][2])
print("total: ", test_data[0][1][3])

predicted entities: 
name:  AMANO MALAYSIA SDN BHD
date:  001137704960) TAX INVOICE P/S #02 RM3.00 A INV-NO.0002300417000138 T/D #11 TICKET NO.029190 ENTRY TIME PAID TIME PARKING TIME 30/04/2017
total:  
actual entities: 
name:  AMANO MALAYSIA SDN BHD
date:  30/04/2017
total:  RM3.00


In [54]:
s = '3-1707067 F&P PHARMACY (002309592-P) NO.20, GROUNDFLOOR, SELANGOR DARUL EHSAN GST Reg NO 001880666112 JALAN BS 10/6 TAMAN BUKIT SERDANG, SEKSYEN 10, 43300 SERI KEMBANGAN, TEL 03-89599823 TAX INVOICE Doci No Cashier Salesperson 955789210525F CS00110840 F&P Date 02/03/2018 Time. 16.46.00 Ref (GST) 6.00 600 430 380 6.50 530 (GST) Amount Tax 6.00 SR 600 ZRL 430 ZRL 380 SR 650 SR 530 3190 30.68 0.00 122 0.00 31.90 50.00 18.10 Item 1486 Qty 1 1 1 1 1 S/Price S/Price 5.66 600 430 3.58 613 500 HOMECARE GASCOAL 50MG P.P NAPROXEN NA 275 MG YELLOWLOTION 30 MI. PANADOL SOLUBLE TABLET PMS GAUZE BANDAGE 5CM X 4M 9557837400035 1014 1155 95506104 DETTOL 50 ML Total Qty SR 6 Total Sales (Excluding GST) Discount Totai GST Rounding Total Sales (Inclusive of GST) : CASH : Change : GST SUMMARY Tax Code SR ZRL % 6 0 Total: : Amt (RM) 2038 1030 30.68 Tax (RM) 1.22 0.00 1.22 GOODS SOLD ARE NOT RETURNABLE & EXCHANGABLE, THANK YOU.'
entities = ["F&P PHARMACY", "02/03/2018", "NO.20, GROUND FLOOR, JALAN BS 10/6 TAMAN BUKIT SERDANG, SEKSYEN 10, 43300 SERI KEMBANGAN. SELANGOR DARUL EHSAN", "31.90"]

In [55]:
dylan = EXT2(len(char_vocab), 4*len(char_vocab))
dylan.load_state_dict(torch.load('models/model_dylan_bs10_lr0.0015_epoch39', map_location=torch.device('cpu')))

<All keys matched successfully>

In [56]:
pred_name, pred_date, pred_address, pred_total = get_pred(dylan, s)
print("predicted entities: ")
print("name: ", pred_name)
print("date: ", pred_date)
#print("address: ", pred_address)
print("total: ", pred_total)
print("actual entities: ")
print("name: ", entities[0])
print("date: ", entities[1])
#print("address: ", entities[2])
print("total: ", entities[3])

predicted entities: 
name:  3-1707067 F&P PHARMACY
date:  02/03/2018
total:  95578374
actual entities: 
name:  F&P PHARMACY
date:  02/03/2018
total:  31.90


In [53]:
pred_name, pred_date, pred_address, pred_total = get_pred(dylan, test_data[0][0])
print("predicted entities: ")
print("name: ", pred_name)
print("date: ", pred_date)
#print("address: ", pred_address)
print("total: ", pred_total)
print("actual entities: ")
print("name: ", test_data[0][1][0])
print("date: ", test_data[0][1][1])
#print("address: ", test_data[0][1][2])
print("total: ", test_data[0][1][3])

predicted entities: 
name:  AMANO MALAYSIA SDN BHD
date:  30/04/2017
total:  RM3.00
actual entities: 
name:  AMANO MALAYSIA SDN BHD
date:  30/04/2017
total:  RM3.00


In [58]:
test_itr = get_data_loader(test_data, 32)
get_accuracy(dylan, test_itr)

(0.8307773194607501, 0.797945205479452, 0.410958904109589)