In [158]:
import torch
import os
import gc
import pickle
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim
from sklearn.model_selection import train_test_split
import torchtext.data
import spacy
from torchtext.data import Field, RawField, BucketIterator, Example, Dataset
import matplotlib.pyplot as plt

In [3]:
torch.manual_seed(42)

<torch._C.Generator at 0x107cc9af0>

In [4]:
#filepaths
clean_img_path = 'clean_img'
clean_txt_path = 'clean_txt'
clean_entity_path = 'clean_entity'

In [277]:
#collect all records into a list
records = []
clean_imgs = torch.load(os.path.join(clean_img_path, 'clean.pt'))

for i, name in enumerate(os.listdir(clean_txt_path)):
    name, __ = os.path.splitext(name)
    img = clean_imgs[i]
    txt_filepath = os.path.join(clean_txt_path, name + '.txt')
    txt_file = open(txt_filepath, 'rb')
    text = txt_file.read().decode('latin1')
    txt_file.close()
    ent_filepath = os.path.join(clean_entity_path, name + '.pkl')
    #print(ent_filepath)
    #ent_file = open(ent_filepath, 'rb')
    try:
        with open(ent_filepath, 'rb') as my_file:
            unpickler = pickle.Unpickler(my_file)
            entities = unpickler.load()
    except EOFError:
        print(ent_filepath)
        print('An EOFError exception occurred. The file is empty')
    #entities = pickle.load(ent_file)
    #record = [img, text, entities]
    record = [text, entities]
    records.append(record)

In [278]:
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15
train_data, temp_data = train_test_split(records, test_size=(1 - train_ratio))
val_data, test_data = train_test_split(temp_data, test_size=test_ratio / (val_ratio + test_ratio))
print("Training set len:", len(train_data))
print("Validation set len:", len(val_data))
print("Test set len:", len(test_data))

Training set len: 681
Validation set len: 146
Test set len: 146


In [245]:
train_data[0]

["3180303 LIAN HING STATIONERY SDN BHD (162761-M) NO.32 & 33,JALAN SR 1/9. SEKSYEN 9. TAMAN SERDANG RAYA, 43300 SERI KEMBANGAN, SELANGOR DARUL EHSAN GST ID : 002139201536 TAX INVOICE 27/03/2018 NO.: CS-20242 QTY TAX RM DURSFILE H399(110 X 95MM) 100 SR 58.30 NAME BADGE (H) @ 0.5500 809 METAL NAME BADGE CLIP 1 SR 21.20 100'S @ 20.0000 TOTAL AMT INCL. GST @ 6%: 79.50 ROUNDING ADJUSTMENT: TOTAL AMT PAYABLE: 79.50 PAID AMOUNT: 100.00 CHANGE: 20.50 TOTAL QTY TENDER: 101 GST SUMMARY AMOUNT TAX (RM) (RM) SR @ A 75.00 4.50 TOTAL 75.00 4.50 THANK YOU FOR ANY ENQUIRY, PLEASE CONTACT US;",
 ['LIAN HING STATIONERY SDN BHD',
  '27/03/2018',
  'NO.32 & 33, JALAN SR 1/9, SEKSYEN 9, TAMAN SERDANG RAYA, 43300 SERI KEMBANGAN, SELANGOR DARUL EHSAN',
  '79.50']]

In [264]:
char_vocab = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
              'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
              '`', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '-', '=', '~', '!', '@', '#', '$', '%', '^', '&', '*', '(', ')', '_', '+',
              '[', ']', '\\', ';', '\'', ',', '.', '?', '{', '}', '|', ':', '\"', '<', '>', '?', '/', 'Â', '·']

def index_to_char(i):
    return char_vocab[i]

def char_to_index(c):
    return char_vocab.index(c)

In [275]:
#probabilities of [start name, end name, start date, end date, start address, end address, start total, end total]
#return tensor of ground truth probabilities; end index is inclusive
def get_ground_probs(text, ents, length):
    name_start = torch.zeros(length)
    name_end = torch.zeros(length)
    date_start = torch.zeros(length)
    date_end = torch.zeros(length)
    address_start = torch.zeros(length)
    address_end = torch.zeros(length)
    total_start = torch.zeros(length)
    total_end = torch.zeros(length)
    name_start_index = text.index(ents[0])
    name_end_index = name_start_index + len(ents[0]) - 1
    date_start_index = text.index(ents[1])
    date_end_index = date_start_index + len(ents[1]) - 1
    if text.find(ents[2]) == -1:
        print("text: ", text)
        print("addr: ", ents[2])
    address_start_index = text.index(ents[2])
    address_end_index = date_start_index + len(ents[2]) - 1
    total_start_index = text.index(ents[3])
    total_end_index = total_start_index + len(ents[3]) - 1
    name_start[name_start_index] = 1
    name_end[name_end_index] = 1
    date_start[date_start_index] = 1
    date_end[date_end_index] = 1
    address_start[address_start_index] = 1
    address_end[address_end_index] = 1
    total_start[total_start_index] = 1
    total_end[total_end_index] = 1
    return torch.stack([name_start, name_end, date_start, date_end, address_start, address_end, total_start, total_end])

def get_batch_ground_probs(batch_text, batch_ents): #[batch, 8*char]
    max_len = len(batch_text[-1])
    ans = []
    for text, ents in zip(batch_text, batch_ents):
        assert(len(text) > 0)
        ans.append(get_ground_probs(text, ents, max_len))
    return torch.stack(ans).view(batch_text.size(0), -1)

In [270]:
def text_to_onehot(text, length):
    ident = torch.eye(len(char_vocab))
    ans = []
    for char in text:
        ans.append(ident[char_to_index(char)])
    for i in range(length - len(text)):
        ans.append(ident[char_to_index(' ')])
    assert(len(ans) == length)
    return torch.stack(ans)

def get_batch_numerical(batch_data):
    #for data in batch_data:
    #    assert(len(data[0]) > 0)
    batch_text = [data[0] for data in batch_data]
    batch_ents = [data[1] for data in batch_data]
    max_len = len(batch_text[-1])
    batch_enc_text = torch.stack([text_to_onehot(text, max_len) for text in batch_text])
    batch_probs = get_batch_ground_probs(batch_text, batch_ents)
    return batch_enc_text, batch_probs

def get_data_loader(data, batch_size):
    #for d in data:
    #    assert(len(d[0]) > 0)
    sorted_data = sorted(data, key=lambda x: len(x[0]))
    ans = []
    for i in range(0, len(sorted_data), batch_size):
        batch_data = sorted_data[i:i+batch_size]
        ans.append(get_batch_numerical(batch_data))
    return ans

In [249]:
class EXT(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EXT, self).__init__()
        self.emb = nn.Embedding(input_size, input_size)
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 8)

    def forward(self, x):
        x = self.emb(x)
        h0 = torch.zeros(1, len(x), self.hidden_size)  # Initial hidden state
        out, _ = self.rnn(x, h0) #[batch index, char index, vocab index]
        out = self.fc(out) #[batch index, char index, 8]
        out = torch.transpose(out, 1, 2) #[batch index, 8, char index]
        out = F.softmax(out, dim=2).view(out.size(0), -1) #[batch, 8*char]
        return out #The output_probs tensor contains probabilities for each class at each position in the sequence

In [250]:
class TransformerEncoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(TransformerEncoder, self).__init__()
        self.linear_q = nn.Linear(input_size, hidden_size)
        self.linear_k = nn.Linear(input_size, hidden_size)
        self.linear_v = nn.Linear(input_size, hidden_size)
        self.linear_x = nn.Linear(input_size, hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.norm = nn.LayerNorm(hidden_size)
    
    def forward(self, x):
        q, k, v = self.linear_q(x), self.linear_k(x), self.linear_v(x)
        x = self.norm(self.linear_x(x) + self.attention(q, k, v))
        x = self.norm(x + self.fc(x))
        return x

class EXT2(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EXT2, self).__init__()
        self.emb = nn.Embedding(input_size, input_size)
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        self.transformer = TransformerEncoder(input_size, hidden_size)  # Include TransformerEncoder here
        self.fc = nn.Linear(hidden_size, 8)

    def forward(self, x):
        x = self.emb(x)
        h0 = torch.zeros(1, len(x), self.hidden_size)  # Initial hidden state
        out_rnn, _ = self.rnn(x, h0)  # GRU output
        out_transformer = self.transformer(x)  # Transformer output
        out_combined = out_rnn + out_transformer  # Combine outputs
        out = self.fc(out_combined)  # Apply linear layer
        out = torch.transpose(out, 1, 2) #[batch index, 8, char index]
        out = F.softmax(out, dim=2).view(out.size(0), -1) #[batch, 8*char]
        return out

In [251]:
def cer(reference, hypothesis):
    # Convert the sentences into character lists
    ref = list(reference)
    hyp = list(hypothesis)

    # Create a matrix of size (len(ref)+1) x (len(hyp)+1)
    d = np.zeros((len(ref) + 1) * (len(hyp) + 1), dtype=np.uint8)
    d = d.reshape((len(ref) + 1, len(hyp) + 1))

    # Initialize the first row and column to be the distance from the empty string
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j

    # Populate the rest of the matrix
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            if ref[i - 1] == hyp[j - 1]:
                cost = 0
            else:
                cost = 1
            d[i][j] = min(d[i - 1][j] + 1,      # deletion
                          d[i][j - 1] + 1,      # insertion
                          d[i - 1][j - 1] + cost)  # substitution

    # The CER is the cost of transforming hypothesis into reference divided by the number of characters in the reference
    cer_value = float(d[len(ref)][len(hyp)]) / len(ref)

    return cer_value

def car(reference, hypothesis):
    return 1 - cer(reference, hypothesis)

In [252]:
def get_model_name(name, batch_size, learning_rate, epoch):
    return "model_{}_bs{}_lr{}_epoch{}".format(name, batch_size, learning_rate, epoch)

#based on CER
def get_accuracy(model, data_itr):
    total = 0
    total_car = 0 #sum of char acc rate over all samples
    total_ear = 0 #sum of entity acc rate over all samples
    total_corr = 0 #sum of sample acc rate over all samples
    for batch_text, batch_ents in data_itr:
        batch_output = model(batch_text).view(batch_output.size(0), 8, -1) #[batch, 8, char]
        batch_output = torch.argmax(batch_output, dim=2) #[batch, 8]

        pred_names = [text[output[0]:output[1]+1] for text, output in zip(batch_text, batch_output)] #+1 since string slicing is exclusive
        pred_dates = [text[output[2]:output[3]+1] for text, output in zip(batch_text, batch_output)]
        pred_addresses = [text[output[4]:output[5]+1] for text, output in zip(batch_text, batch_output)]
        pred_totals = [text[output[6]:output[7]+1] for text, output in zip(batch_text, batch_output)]

        total += len(batch_text)
        total_car += sum([car(ents[0]+ents[1]+ents[2]+ents[3], pred_name+pred_date+pred_address+pred_total)
                          for ents, pred_name, pred_date, pred_address, pred_total in
                          zip(batch_ents, pred_names, pred_dates, pred_addresses, pred_totals)])
        total_ear += sum([(int(ents[0]==pred_name) + int(ents[1]==pred_date) + int(ents[2]==pred_address) + int(ents[3]==pred_total))/4.0
                          for ents, pred_name, pred_date, pred_address, pred_total in
                          zip(batch_ents, pred_names, pred_dates, pred_addresses, pred_totals)
        ])
        total_corr += sum([int(ents[0]+ents[1]+ents[2]+ents[3] == pred_name+pred_date+pred_address+pred_total)
                          for ents, pred_name, pred_date, pred_address, pred_total in
                          zip(batch_ents, pred_names, pred_dates, pred_addresses, pred_totals)
        ])
    char_acc = total_car / total
    ent_acc = total_ear / total
    sample_acc = total_corr / total
    return char_acc, ent_acc, sample_acc

In [253]:
def get_loss(model, data_itr, criterion):
    loss = 0
    total = 0
    with torch.no_grad():
        for batch_text, batch_ents in data_itr:
            batch_output = model(batch_text)
            ground_probs = get_batch_ground_probs(batch_text, batch_ents) #[batch, 8*char]
            loss += criterion(batch_output, ground_probs)
            total += batch_output.size(0)
    return loss / total

In [254]:
def train_model(model, train_set, val_set, batch_size=32, num_epochs=5, learning_rate=1e-5):
    # DataLoaders for train and validation sets
    # Create BucketIterator
    train_itr = get_data_loader(train_set, batch_size)
    val_itr = get_data_loader(val_set, batch_size)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    train_loss, val_loss, train_acc, valid_acc = [], [], [], []
    epochs = []

    for epoch in range(num_epochs):
        for i, (batch_text, batch_ents) in enumerate(train_itr):
            optimizer.zero_grad()
            pred = model(batch_text)
            loss = criterion(pred, get_batch_ground_probs(batch_text, batch_ents))
            loss.backward()
            optimizer.step()

        train_loss.append(get_loss(model, train_itr, criterion))
        val_loss.append(get_loss(model, val_itr, criterion))

        epochs.append(epoch)
        train_acc.append(get_accuracy(model, train_itr))
        valid_acc.append(get_accuracy(model, val_itr))
        print("Epoch %d; Train Loss %f; Val Loss %f Train Acc %f; Val Acc %f" % (
            epoch+1, train_loss[-1], val_loss[-1], train_acc[-1], valid_acc[-1]))

    # plotting
    plt.title("Training Curve")
    plt.plot(epochs, train_loss, label="Train")
    plt.plot(epochs, val_loss, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    plt.title("Training Curve")
    plt.plot(epochs, train_acc, label="Train")
    plt.plot(epochs, valid_acc, label="Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc='best')
    plt.show()

In [279]:
george = EXT(len(char_vocab), 8*len(char_vocab))

train_model(george, train_data, val_data)

text:  RESTAURANT SIN DU K3-113,JL IBRAHIM SULTAN 80300 JOHOR BAHRU JOHOR H/P: 019-7521215 016-7867868 09/03/2018 21:28 0001 000000#7259 CASHIER01 DPT.05 RM 149.00 DPT.04 RM 21.00 CASH RM 170.00
addr:  K3-113, JL IBRAHIM SULTAN 80300 JOHOR BAHRU JOHOR


ValueError: substring not found