In [1]:
import torch
from torch import nn
from utils_2 import get_vocab, get_tags

In [2]:
def load_model_state(model, path, device=None):
    """
    Loads the model's parameters into a pre-defined architecture and prepares it for inference.

    Inputs:
    - model: an nn.Module whose architecture matches the saved state_dict
    - path: path of a file holding a saved state_dict
    - device: where to place the model; defaults to "cuda:0" when CUDA is available,
      otherwise the CPU (the previous version always used "cuda:0" and failed without a GPU)

    Returns:
    - the same model object (modified in place), in eval mode, on `device`
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine;
    # weights_only restricts unpickling to tensors/containers (a state_dict needs nothing more).
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # Set to evaluation mode (disables dropout)
    model.to(device)
    return model


class NER(nn.Module):
    """
    Sequence tagger: embedding -> bidirectional LSTM -> dropout -> linear, one score vector per token.
    """

    def __init__(self, vocab_size=5, embedding_dim=50, hidden_size=50, n_classes=5, dropout=0.25):
        """
        The constructor of our NER model
        Inputs:
        - vocab_size: the number of unique words (rows of the embedding table)
        - embedding_dim: the embedding dimension
        - hidden_size: the LSTM hidden size per direction
        - n_classes: the number of final classes (tags)
        - dropout: dropout probability applied to the LSTM outputs (default 0.25)
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(p=dropout)
        # Forward and backward states are concatenated, hence 2 * hidden_size input features.
        self.linear = nn.Linear(hidden_size * 2, n_classes)

    def forward(self, sentences):
        """
        This function does the forward pass of our model
        Inputs:
        - sentences: integer tensor of token indices of shape (batch_size, max_length);
          an unbatched (max_length,) tensor also works and yields (max_length, n_classes)

        Returns:
        - final_output: tensor of shape (batch_size, max_length, n_classes) of unnormalized scores
        """
        embedding = self.embedding(sentences)
        lstm, _ = self.lstm(embedding)
        dropout = self.dropout(lstm)
        final_output = self.linear(dropout)
        return final_output

In [3]:
# Word -> index vocabulary shared by all four taggers.
vocab = get_vocab('processed_input/train_vocab.txt')

# Tag -> index mapping for each tagger, loaded in the same order as before:
# complex topping, pizza order, drink order, order boundary.
tags_ct, tags_pz, tags_dr, tags_ob = (
    get_tags(tags_path)
    for tags_path in (
        'processed_input/train_complex_topping_tags.txt',
        'processed_input/train_pizza_orders_tags.txt',
        'processed_input/train_drink_orders_tags.txt',
        'processed_input/train_orders_tags.txt',
    )
)


In [4]:
# Show each tagger's tag -> index mapping, one dict per line.
print(tags_ct, tags_pz, tags_dr, tags_ob, sep="\n")

{'QUANTITY': 0, 'TOPPING': 1, 'TOPPING_S': 2, 'QUANTITY_S': 3, 'NONE': 4}
{'TOPPING_S': 0, 'COMPLEX_TOPPING_S': 1, 'NOT_STYLE_S': 2, 'STYLE': 3, 'NUMBER': 4, 'TOPPING': 5, 'SIZE': 6, 'NONE': 7, 'NOT_TOPPING': 8, 'NOT_TOPPING_S': 9, 'NUMBER_S': 10, 'NOT_COMPLEX_TOPPING_S': 11, 'STYLE_S': 12, 'NOT_COMPLEX_TOPPING': 13, 'COMPLEX_TOPPING': 14, 'SIZE_S': 15, 'NOT_STYLE': 16}
{'SIZE': 0, 'CONTAINERTYPE': 1, 'DRINKTYPE': 2, 'DRINKTYPE_S': 3, 'NONE': 4, 'CONTAINERTYPE_S': 5, 'VOLUME': 6, 'NUMBER': 7, 'VOLUME_S': 8, 'SIZE_S': 9, 'NUMBER_S': 10}
{'DRINKORDER': 0, 'DRINKORDER_S': 1, 'PIZZAORDER': 2, 'PIZZAORDER_S': 3, 'NONE': 4}


In [26]:
# Run on the GPU when one is present, otherwise on the CPU, instead of hard-coding
# "cuda:0" (moving a model to CUDA raises on machines without it).
# NOTE: `device` must stay a module-level name -- feed_model reads it to place its inputs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Architecture hyper-parameters; they must match the checkpoints or load_state_dict fails.
EMBEDDING_DIM = 70

model_boundary    = NER(embedding_dim=EMBEDDING_DIM, hidden_size=200, n_classes=len(tags_ob), vocab_size=len(vocab))
model_pizza_order = NER(embedding_dim=EMBEDDING_DIM, hidden_size=500, n_classes=len(tags_pz), vocab_size=len(vocab))
model_drink_order = NER(embedding_dim=EMBEDDING_DIM, hidden_size=500, n_classes=len(tags_dr), vocab_size=len(vocab))
model_complex     = NER(embedding_dim=EMBEDDING_DIM, hidden_size=500, n_classes=len(tags_ct), vocab_size=len(vocab))

load_model_state(model_boundary, "models/order_boundary_x95.3.pth")
load_model_state(model_pizza_order, "models/pizza_order_x94.4.pth")
load_model_state(model_drink_order, "models/drink_order_x100.0.pth")
load_model_state(model_complex, "models/complex_x100.0.pth")

# load_model_state picks its own device; re-target every model explicitly so it lives
# on the same device that feed_model sends its input tensors to.
model_boundary.to(device)
model_pizza_order.to(device)
model_drink_order.to(device)
model_complex.to(device)

  model.load_state_dict(torch.load(path))


NER(
  (embedding): Embedding(307, 70)
  (lstm): LSTM(70, 500, batch_first=True, bidirectional=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (linear): Linear(in_features=1000, out_features=5, bias=True)
)

In [27]:
def tags_inverse(tags):
    """
    Invert a tag -> index mapping into index -> tag.

    If two tags share an index, the one appearing later in `tags` wins.
    """
    return {index: tag for tag, index in tags.items()}

# Index -> tag-name lookups used to decode each model's argmax predictions.
inv_tags_ct, inv_tags_pz, inv_tags_dr, inv_tags_ob = map(
    tags_inverse, (tags_ct, tags_pz, tags_dr, tags_ob)
)

# sep="\n" prints one dict per line, exactly like four separate print calls.
print(inv_tags_ct, inv_tags_pz, inv_tags_dr, inv_tags_ob, sep="\n")

{0: 'QUANTITY', 1: 'TOPPING', 2: 'TOPPING_S', 3: 'QUANTITY_S', 4: 'NONE'}
{0: 'TOPPING_S', 1: 'COMPLEX_TOPPING_S', 2: 'NOT_STYLE_S', 3: 'STYLE', 4: 'NUMBER', 5: 'TOPPING', 6: 'SIZE', 7: 'NONE', 8: 'NOT_TOPPING', 9: 'NOT_TOPPING_S', 10: 'NUMBER_S', 11: 'NOT_COMPLEX_TOPPING_S', 12: 'STYLE_S', 13: 'NOT_COMPLEX_TOPPING', 14: 'COMPLEX_TOPPING', 15: 'SIZE_S', 16: 'NOT_STYLE'}
{0: 'SIZE', 1: 'CONTAINERTYPE', 2: 'DRINKTYPE', 3: 'DRINKTYPE_S', 4: 'NONE', 5: 'CONTAINERTYPE_S', 6: 'VOLUME', 7: 'NUMBER', 8: 'VOLUME_S', 9: 'SIZE_S', 10: 'NUMBER_S'}
{0: 'DRINKORDER', 1: 'DRINKORDER_S', 2: 'PIZZAORDER', 3: 'PIZZAORDER_S', 4: 'NONE'}


In [28]:

def feed_model(model, query, inv_tags, vocabulary=None, target_device=None):
    """
    Run one tagger on a single query and return one tag name per token.

    Inputs:
    - model: a tagger mapping a 1-D tensor of token ids to per-token class scores
    - query: raw text; tokens are the non-empty pieces of query.split(' ')
      (the run_* helpers split the same way, so their word lists line up with the tags)
    - inv_tags: index -> tag name
    - vocabulary: word -> index; defaults to the notebook-level `vocab`.
      Unknown words map to vocabulary['<UNK>']
    - target_device: device for the input tensor; defaults to the notebook-level `device`

    Returns:
    - list of tag names, one per token ([] for an empty / whitespace-only query)
    """
    if vocabulary is None:
        vocabulary = vocab
    if target_device is None:
        target_device = device

    tokens = [token for token in query.split(' ') if token != '']
    if not tokens:
        # torch.tensor([]) is a float tensor, which nn.Embedding rejects; nothing to tag anyway.
        return []

    ids = [vocabulary[token] if token in vocabulary else vocabulary['<UNK>'] for token in tokens]
    x_tensor = torch.tensor(ids, dtype=torch.long, device=target_device)
    with torch.no_grad():  # inference only: don't build an autograd graph
        scores = model(x_tensor)  # call the module (not .forward) so hooks run
    predictions = torch.argmax(scores, dim=-1).tolist()
    return [inv_tags[p] for p in predictions]

# Smoke test: tag each word of a short query with the order-boundary model.
feed_model(model=model_boundary, query="can i have one pizza", inv_tags=inv_tags_ob)

['NONE', 'NONE', 'NONE', 'PIZZAORDER_S', 'PIZZAORDER']

In [55]:
def run_complex(order):
    """
    Tag a complex-topping phrase with the complex-topping model and render it.

    A run that starts with QUANTITY_S / TOPPING_S and continues with QUANTITY / TOPPING
    becomes "(QUANTITY w1 w2 ...) " / "(TOPPING w1 w2 ...) "; every other word is
    copied through followed by a single space.
    """
    words = [token for token in order.split(' ') if token != '']
    order_result = feed_model(model_complex, order, inv_tags_ct)

    # start tag -> continuation tag (which is also the label written to the output)
    continuation_of = {"TOPPING_S": "TOPPING", "QUANTITY_S": "QUANTITY"}

    pieces = []
    position = 0
    total = len(order_result)  # equals len(words): feed_model tokenizes the same way
    while position < total:
        label = continuation_of.get(order_result[position])
        if label is None:
            pieces.append(words[position] + " ")
            position += 1
            continue
        span_end = position + 1
        while span_end < total and order_result[span_end] == label:
            span_end += 1
        pieces.append(f"({label} {' '.join(words[position:span_end])}) ")
        position = span_end
    return "".join(pieces)

def run_pizza_order(order):
    """
    Tag one pizza-order span with the pizza model and render it as a bracketed string.

    Inputs:
    - order: the words of a single pizza order, space separated

    Returns:
    - a string in which each tagged span becomes "(LABEL words) ", negated spans become
      "(NOT (LABEL words) ) ", complex-topping spans are re-parsed by run_complex and wrapped
      in "(COMPLEX ...) " (or "(NOT (COMPLEX ...) ) "), and untagged words are copied through
      followed by a single space
    """
    words = [token for token in order.split(' ') if token != '']
    order_result = feed_model(model_pizza_order, order, inv_tags_pz)

    # start tag -> continuation tag; the continuation tag is also the printed label.
    NORMAL_SPANS = {"TOPPING_S": "TOPPING", "STYLE_S": "STYLE", "SIZE_S": "SIZE", "NUMBER_S": "NUMBER"}
    # Negated attributes. tags_pz has no NOT_SIZE* / NOT_NUMBER* tags; those entries are a
    # harmless safety net in case such tags are ever produced.
    NOT_SPANS = {
        "NOT_TOPPING_S": "NOT_TOPPING",
        "NOT_STYLE_S": "NOT_STYLE",
        "NOT_SIZE_S": "NOT_SIZE",
        "NOT_NUMBER_S": "NOT_NUMBER",
    }
    COMPLEX_STARTS = ("COMPLEX_TOPPING_S", "NOT_COMPLEX_TOPPING_S")
    # A complex span may continue with either polarity (as before), but it must stop at a new
    # *_S start tag. The old substring test `"COMPLEX_TOPPING" in tag` also matched
    # COMPLEX_TOPPING_S, so two adjacent complex toppings were merged into one span.
    COMPLEX_CONTS = ("COMPLEX_TOPPING", "NOT_COMPLEX_TOPPING")

    def read_span(start, cont_tags):
        """Return (space-joined words of the span starting at `start`, index just past it).

        `cont_tags` must be a tuple: `in` on a plain string would be a substring test.
        """
        end = start + 1
        while end < len(order_result) and order_result[end] in cont_tags:
            end += 1
        return ' '.join(words[start:end]), end

    result = ""
    index = 0
    while index < len(order_result):  # len(order_result) == len(words)
        tag = order_result[index]
        if tag in NORMAL_SPANS:
            label = NORMAL_SPANS[tag]
            text, index = read_span(index, (label,))
            result += f"({label} {text}) "
        elif tag in NOT_SPANS:
            label = NOT_SPANS[tag]
            text, index = read_span(index, (label,))
            result += f"(NOT ({label[4:]} {text}) ) "  # label[4:] strips the "NOT_" prefix
        elif tag in COMPLEX_STARTS:
            text, index = read_span(index, COMPLEX_CONTS)
            val = run_complex(text)
            if tag.startswith("NOT_"):
                result += f"(NOT (COMPLEX {val}) ) "
            else:
                result += f"(COMPLEX {val}) "
        else:
            result += words[index] + " "
            index += 1

    return result

def run_drink_order(order):
    """
    Tag one drink-order span with the drink model and render it as a bracketed string.

    A run that starts with an X_S tag (SIZE, VOLUME, NUMBER, DRINKTYPE, CONTAINERTYPE)
    and continues with X tags becomes "(X w1 w2 ...) "; every other word is copied
    through followed by a single space.
    """
    words = [token for token in order.split(' ') if token != '']
    order_result = feed_model(model_drink_order, order, inv_tags_dr)

    # start tag -> continuation tag (which is also the label written to the output)
    span_label = {
        "SIZE_S": "SIZE",
        "VOLUME_S": "VOLUME",
        "NUMBER_S": "NUMBER",
        "DRINKTYPE_S": "DRINKTYPE",
        "CONTAINERTYPE_S": "CONTAINERTYPE",
    }

    pieces = []
    pos = 0
    total = len(order_result)  # equals len(words): feed_model tokenizes the same way
    while pos < total:
        label = span_label.get(order_result[pos])
        if label is None:
            pieces.append(words[pos] + " ")
            pos += 1
            continue
        end = pos + 1
        while end < total and order_result[end] == label:
            end += 1
        pieces.append(f"({label} {' '.join(words[pos:end])}) ")
        pos = end
    return "".join(pieces)

def run_order(order):
    """
    Split a (lower-cased) query into pizza / drink orders and render it.

    The order-boundary model marks the spans; each span is re-tagged by run_pizza_order or
    run_drink_order and wrapped as "(PIZZAORDER ...) " / "(DRINKORDER ...) ". Words outside
    any span are copied through followed by a single space. An empty query returns "".
    """
    words = [token for token in order.split(' ') if token != '']
    if not words:
        # Nothing to tag; also avoids feeding an empty tensor to the boundary model.
        return ""
    order_result = feed_model(model_boundary, order, inv_tags_ob)

    result = ""
    index = 0
    while index < len(order_result):  # len(order_result) == len(words)
        tag = order_result[index]
        # Both the start tag and a stray continuation tag open a new span. The old code used
        # `tag in 'PIZZAORDER_S'`, a *substring* test on a string that happened to accept
        # exactly these two tags; they are spelled out so no other tag can match by accident.
        if tag in ('PIZZAORDER_S', 'PIZZAORDER'):
            cont_tag, handler = 'PIZZAORDER', run_pizza_order
        elif tag in ('DRINKORDER_S', 'DRINKORDER'):
            cont_tag, handler = 'DRINKORDER', run_drink_order
        else:
            result += words[index] + " "
            index += 1
            continue

        span_words = [words[index]]
        index += 1
        while index < len(order_result) and order_result[index] == cont_tag:
            span_words.append(words[index])
            index += 1
        result += f"({cont_tag} {handler(' '.join(span_words))}) "
    return result
    
def run_query(query):
    """Lower-case the query, parse its orders with run_order, and wrap the result in an (ORDER ...) node."""
    parsed = run_order(query.lower())
    return f"(ORDER {parsed})"


# End-to-end example: raw query -> bracketed order parse.
print(run_query(query="i want to order one pepperoni pizza with ham and without onions and one coke"))

(ORDER i want to order (PIZZAORDER (NUMBER one) (TOPPING pepperoni) pizza with (TOPPING ham) and without (NOT (TOPPING onions) ) ) and (DRINKORDER (NUMBER one) (DRINKTYPE coke) ) )
