In [1]:
import json
import re

def tokenize(s):
    # Extract tokens: parentheses or sequences of non-whitespace, non-parenthesis characters.
    tokens = re.findall(r'\(|\)|[^\s()]+', s)
    return tokens

def parse_tokens(tokens):
    # Parse tokens into a nested list structure
    stack = []
    current_list = []
    for token in tokens:
        if token == '(':
            stack.append(current_list)
            current_list = []
        elif token == ')':
            finished = current_list
            current_list = stack.pop()
            current_list.append(finished)
        else:
            current_list.append(token)
    return current_list

def normalize_structure(tree):
    if not isinstance(tree, list):
        return None

    def is_key(token):
        return token in [
            "ORDER", "PIZZAORDER", "DRINKORDER", "NUMBER", "SIZE", "STYLE", "TOPPING",
            "COMPLEX_TOPPING", "QUANTITY", "VOLUME", "DRINKTYPE", "CONTAINERTYPE", "NOT"
        ]

    # Clean the list by keeping sublists and tokens as-is for further analysis
    cleaned = []
    for el in tree:
        cleaned.append(el)

    if len(cleaned) > 0 and isinstance(cleaned[0], str) and is_key(cleaned[0]):
        key = cleaned[0]
        if key == "ORDER":
            pizzaorders = []
            drinkorders = []
            for sub in cleaned[1:]:
                node = normalize_structure(sub)
                if isinstance(node, dict):
                    if "PIZZAORDER" in node:
                        if isinstance(node["PIZZAORDER"], list):
                            pizzaorders.extend(node["PIZZAORDER"])
                        else:
                            pizzaorders.append(node["PIZZAORDER"])
                    if "DRINKORDER" in node:
                        if isinstance(node["DRINKORDER"], list):
                            drinkorders.extend(node["DRINKORDER"])
                        else:
                            drinkorders.append(node["DRINKORDER"])
                    if node.get("TYPE") == "PIZZAORDER":
                        pizzaorders.append(node)
                    if node.get("TYPE") == "DRINKORDER":
                        drinkorders.append(node)
            result = {}
            if pizzaorders:
                result["PIZZAORDER"] = pizzaorders
            if drinkorders:
                result["DRINKORDER"] = drinkorders
            if result:
                return {"ORDER": result}
            else:
                return {}

        elif key == "PIZZAORDER":
            number = None
            size = None
            style = None
            toppings = []
            for sub in cleaned[1:]:
                node = normalize_structure(sub)
                if isinstance(node, dict):
                    t = node.get("TYPE")
                    if t == "NUMBER":
                        number = node["VALUE"]
                    elif t == "SIZE":
                        size = node["VALUE"]
                    elif t == "STYLE":
                        style = node["VALUE"]
                    elif t == "TOPPING":
                        toppings.append(node)
            result = {}
            if number is not None:
                result["NUMBER"] = number
            if size is not None:
                result["SIZE"] = size
            if style is not None:
                result["STYLE"] = style
            if toppings:
                result["AllTopping"] = toppings
            # Mark type internally, will remove later
            result["TYPE"] = "PIZZAORDER"
            return result

        elif key == "DRINKORDER":
            number = None
            volume = None
            drinktype = None
            containertype = None
            for sub in cleaned[1:]:
                node = normalize_structure(sub)
                if isinstance(node, dict):
                    t = node.get("TYPE")
                    if t == "NUMBER":
                        number = node["VALUE"]
                    elif t == "VOLUME":
                        volume = node["VALUE"]
                    elif t == "DRINKTYPE":
                        drinktype = node["VALUE"]
                    elif t == "CONTAINERTYPE":
                        containertype = node["VALUE"]
            result = {}
            if number is not None:
                result["NUMBER"] = number
            if volume is not None:
                result["VOLUME"] = volume
            if drinktype is not None:
                result["DRINKTYPE"] = drinktype
            if containertype is not None:
                result["CONTAINERTYPE"] = containertype
            result["TYPE"] = "DRINKORDER"
            return result

        elif key in ["NUMBER","SIZE","STYLE","VOLUME","DRINKTYPE","CONTAINERTYPE","QUANTITY"]:
            values = []
            for el in cleaned[1:]:
                if isinstance(el, str):
                    values.append(el)
            value_str = " ".join(values).strip()
            return {
                "TYPE": key,
                "VALUE": value_str
            }

        elif key == "TOPPING":
            values = []
            for el in cleaned[1:]:
                if isinstance(el, str):
                    values.append(el)
            topping_str = " ".join(values).strip()
            return {
                "TYPE": "TOPPING",
                "NOT": False,
                "Quantity": None,
                "Topping": topping_str
            }

        elif key == "COMPLEX_TOPPING":
            quantity = None
            topping = None
            for sub in cleaned[1:]:
                node = normalize_structure(sub)
                if isinstance(node, dict):
                    t = node.get("TYPE")
                    if t == "QUANTITY":
                        quantity = node["VALUE"]
                    elif t == "TOPPING":
                        topping = node["Topping"]
            return {
                "TYPE": "TOPPING",
                "NOT": False,
                "Quantity": quantity,
                "Topping": topping
            }

        elif key == "NOT":
            for sub in cleaned[1:]:
                node = normalize_structure(sub)
                if isinstance(node, dict) and node.get("TYPE") == "TOPPING":
                    node["NOT"] = True
                    if "Quantity" not in node:
                        node["Quantity"] = None
                    return node
            return None

    else:
        # Try to parse sublists and combine orders found
        combined_order = {"PIZZAORDER": [], "DRINKORDER": []}
        found_order = False

        for el in cleaned:
            node = normalize_structure(el)
            if isinstance(node, dict):
                if "ORDER" in node:
                    found_order = True
                    order_node = node["ORDER"]
                    if "PIZZAORDER" in order_node:
                        combined_order["PIZZAORDER"].extend(order_node["PIZZAORDER"])
                    if "DRINKORDER" in order_node:
                        combined_order["DRINKORDER"].extend(order_node["DRINKORDER"])
                elif node.get("TYPE") == "PIZZAORDER":
                    found_order = True
                    combined_order["PIZZAORDER"].append(node)
                elif node.get("TYPE") == "DRINKORDER":
                    found_order = True
                    combined_order["DRINKORDER"].append(node)

        if found_order:
            final = {}
            if combined_order["PIZZAORDER"]:
                final["PIZZAORDER"] = combined_order["PIZZAORDER"]
            if combined_order["DRINKORDER"]:
                final["DRINKORDER"] = combined_order["DRINKORDER"]
            return {"ORDER": final} if final else {}

        return None

def remove_type_keys(obj):
    # Recursively remove "TYPE" keys from all dictionaries
    if isinstance(obj, dict):
        obj.pop("TYPE", None)
        for k, v in obj.items():
            remove_type_keys(v)
    elif isinstance(obj, list):
        for item in obj:
            remove_type_keys(item)


def preprocess(text):
    tokens = tokenize(text)
    parsed = parse_tokens(tokens)
    result = normalize_structure(parsed)
    remove_type_keys(result)
    return result


In [2]:
import json
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hid_dim, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(hid_dim*2, hid_dim)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, src):
        # src: [batch, src_len]
        embedded = self.dropout(self.embedding(src))
        # outputs: [batch, src_len, hid_dim*2]
        outputs, (hidden, cell) = self.lstm(embedded)
        
        # hidden and cell are from both directions
        # We can combine them:
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))
        
        # cell: just take last layer's cell states (similar to hidden)
        # If you want, you can do similar combination for cell:
        cell = torch.zeros_like(hidden)  # for simplicity
        
        return outputs, (hidden.unsqueeze(0), cell.unsqueeze(0))
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hid_dim, batch_first=True)
        self.out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, input_step, hidden, cell):
        # input_step: [batch]
        # hidden, cell: [1, batch, hid_dim]
        input_step = input_step.unsqueeze(1) # [batch, 1]
        embedded = self.dropout(self.embedding(input_step)) # [batch, 1, emb_dim]
        outputs, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        # outputs: [batch, 1, hid_dim]
        pred = self.out(outputs.squeeze(1)) # [batch, output_dim]
        return pred, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, tgt_vocab):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tgt_vocab = tgt_vocab

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        # src: [batch, src_len]
        # tgt: [batch, tgt_len]
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        outputs = torch.zeros(batch_size, tgt_len, len(self.tgt_vocab)).to(src.device)

        enc_outputs, (hidden, cell) = self.encoder(src)

        input_step = tgt[:,0] # <SOS>

        for t in range(1, tgt_len):
            pred, hidden, cell = self.decoder(input_step, hidden, cell)
            outputs[:, t, :] = pred
            # Teacher forcing
            teacher_force = True if torch.rand(1).item() < teacher_forcing_ratio else False
            top1 = pred.argmax(1)
            input_step = tgt[:, t] if teacher_force else top1

        return outputs

In [3]:
def tokenize_input(text):
    return text.strip().split()

def build_vocab(dataset, min_freq=1):
    # Collect frequencies
    freq = {}
    for src_tokens, tgt_tokens in dataset:
        for t in src_tokens:
            freq[t] = freq.get(t, 0) + 1
        for t in tgt_tokens:
            freq[t] = freq.get(t, 0) + 1
    
    # Create vocab
    vocab = {"<PAD>":0, "<SOS>":1, "<EOS>":2, "<UNK>":3}
    idx = len(vocab)
    for token, count in freq.items():
        if count >= min_freq and token not in vocab:
            vocab[token] = idx
            idx += 1
    return vocab

def numericalize(tokens, vocab):
    return [vocab.get(t, vocab["<UNK>"]) for t in tokens]

def pad_batch(sequences, pad_idx):
    # Pad a list of sequences to the max length in the batch
    max_len = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        padded.append(seq + [pad_idx]*(max_len - len(seq)))
    return padded

def collate_fn(batch, src_vocab, tgt_vocab):
    # batch: list of tuples (src_tokens, tgt_tokens)
    src_seqs, tgt_seqs = zip(*batch)
    # Numericalize
    src_seqs_num = [numericalize(s, src_vocab) for s in src_seqs]
    tgt_seqs_num = [[tgt_vocab["<SOS>"]]+numericalize(t, tgt_vocab)+[tgt_vocab["<EOS>"]] for t in tgt_seqs]

    src_padded = pad_batch(src_seqs_num, src_vocab["<PAD>"])
    tgt_padded = pad_batch(tgt_seqs_num, tgt_vocab["<PAD>"])

    src_tensor = torch.tensor(src_padded, dtype=torch.long)
    tgt_tensor = torch.tensor(tgt_padded, dtype=torch.long)
    return src_tensor, tgt_tensor

In [4]:
def train_model(model, dataloader, optimizer, criterion, clip=1):
    model.train()
    epoch_loss = 0
    total_batches = len(dataloader)
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch_idx, batch in enumerate(progress_bar):
        src, tgt = batch
        src = src.to(model.device)
        tgt = tgt.to(model.device)
        
        optimizer.zero_grad()
        output = model(src, tgt)  # output: [batch, tgt_len, output_dim]
        
        # Reshape for loss calculation
        output_dim = output.shape[-1]
        # Exclude the <SOS> token at the start
        output = output[:, 1:].reshape(-1, output_dim)  
        tgt = tgt[:, 1:].reshape(-1)
        
        loss = criterion(output, tgt)
        loss.backward()

        # torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        epoch_loss += loss.item()
        avg_loss = epoch_loss / (batch_idx + 1)
        progress_bar.set_description(f"Training Progress: Batch {batch_idx + 1}/{total_batches}, Avg Loss: {avg_loss:.8f}")
        
    return epoch_loss / len(dataloader)


def greedy_decode(model, src_tokens, src_vocab, tgt_vocab, max_len=50):
    model.eval()
    src_ids = [src_vocab.get(t, src_vocab["<UNK>"]) for t in src_tokens]
    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0) # [1, src_len]
    src_tensor = src_tensor.to(model.device)  # Move input to the same device as model

    with torch.no_grad():
        enc_outputs, (hidden, cell) = model.encoder(src_tensor)

    input_step = torch.tensor([tgt_vocab["<SOS>"]], dtype=torch.long).to(model.device)
    decoded_tokens = []

    for _ in range(max_len):
        with torch.no_grad():
            pred, hidden, cell = model.decoder(input_step, hidden, cell)
        top1 = pred.argmax(1)
        token_id = top1.item()
        if token_id == tgt_vocab["<EOS>"]:
            break
        # Convert token_id back to token
        decoded_tokens.append([k for k,v in tgt_vocab.items() if v==token_id][0])
        input_step = top1

    return " ".join(decoded_tokens)


In [5]:
class PizzaDataset(Dataset):
    def __init__(self, data, src_field, tgt_field):

        self.data = data
        self.src_field = src_field
        self.tgt_field = tgt_field

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src_text = self.data[idx][self.src_field]
        tgt_text = self.data[idx][self.tgt_field]
        src_tokens = tokenize(src_text)
        tgt_tokens = tokenize(tgt_text)
        return src_tokens, tgt_tokens

In [6]:
import pandas as pd
train_json_path = "../dataset/PIZZA_train.json"  
dev_jsonl_path = "../dataset/PIZZA_dev.json" 
train_data = pd.read_json(train_json_path, lines=True)


In [7]:
data_list = train_data.to_dict(orient='records')

In [8]:
dataset = PizzaDataset(data_list, "train.SRC", "train.TOP-DECOUPLED")
dataset.__getitem__(1)

(['large',
  'pie',
  'with',
  'green',
  'pepper',
  'and',
  'with',
  'extra',
  'peperonni'],
 ['(',
  'ORDER',
  '(',
  'PIZZAORDER',
  '(',
  'SIZE',
  'large',
  ')',
  '(',
  'TOPPING',
  'green',
  'pepper',
  ')',
  '(',
  'COMPLEX_TOPPING',
  '(',
  'QUANTITY',
  'extra',
  ')',
  '(',
  'TOPPING',
  'peperonni',
  ')',
  ')',
  ')',
  ')'])

In [9]:
del train_data

In [10]:

all_pairs = [dataset[i] for i in range(len(dataset))]
src_vocab = build_vocab(all_pairs, min_freq=1)
tgt_vocab = src_vocab  #  use one vocab; we may separate if needed


In [11]:
batch_size = 128
train_loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=lambda b: collate_fn(b, src_vocab, tgt_vocab))

In [12]:
tgt_vocab

{'<PAD>': 0,
 '<SOS>': 1,
 '<EOS>': 2,
 '<UNK>': 3,
 'can': 4,
 'i': 5,
 'have': 6,
 'a': 7,
 'large': 8,
 'bbq': 9,
 'pulled': 10,
 'pork': 11,
 '(': 12,
 'ORDER': 13,
 'PIZZAORDER': 14,
 'NUMBER': 15,
 ')': 16,
 'SIZE': 17,
 'TOPPING': 18,
 'pie': 19,
 'with': 20,
 'green': 21,
 'pepper': 22,
 'and': 23,
 'extra': 24,
 'peperonni': 25,
 'COMPLEX_TOPPING': 26,
 'QUANTITY': 27,
 "i'd": 28,
 'like': 29,
 'vegetarian': 30,
 'pizza': 31,
 'STYLE': 32,
 'party': 33,
 'size': 34,
 'stuffed': 35,
 'crust': 36,
 'american': 37,
 'cheese': 38,
 'mushroom': 39,
 'one': 40,
 'personal': 41,
 'sized': 42,
 'artichoke': 43,
 'banana': 44,
 'peppperonis': 45,
 'low': 46,
 'fat': 47,
 'want': 48,
 'regular': 49,
 'without': 50,
 'any': 51,
 'fried': 52,
 'onions': 53,
 'NOT': 54,
 'little': 55,
 'bit': 56,
 'of': 57,
 'high': 58,
 'rise': 59,
 'dough': 60,
 'lot': 61,
 'olive': 62,
 'pesto': 63,
 'sauce': 64,
 'peperonis': 65,
 'yellow': 66,
 'meatball': 67,
 '-': 68,
 'bean': 69,
 'big': 70,
 'meat

In [13]:
# Model hyperparameters
# Define model parameters
INPUT_DIM = len(src_vocab)
OUTPUT_DIM = len(tgt_vocab)
EMB_DIM = 128
HID_DIM = 128
PAD_IDX = src_vocab["<PAD>"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(INPUT_DIM, EMB_DIM, HID_DIM, PAD_IDX)
decoder = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM, PAD_IDX)
model = Seq2Seq(encoder, decoder, tgt_vocab)
model = model.to(device)
model.device = device

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [14]:
dev_json_path = "../dataset/PIZZA_dev.json" 
dev_data = pd.read_json(dev_json_path, lines=True)
dev_data = dev_data.to_dict(orient='records')

In [15]:
## fix maxlen
n_epochs = 5
for epoch in range(n_epochs):
    train_loss = train_model(model, train_loader, optimizer, criterion)


                                                                                                                   

KeyboardInterrupt: 

In [16]:
import re

def clean_parentheses(s):
    # Tokenize parentheses and other tokens
    tokens = re.findall(r'\(|\)|[^\s()]+', s)

    stack = []
    cleaned = []
    open_count = 0
    
    for token in tokens:
        if token == '(':
            open_count += 1
            cleaned.append(token)
        elif token == ')':
            # Only add a closing parenthesis if there is a corresponding open
            if open_count > 0:
                cleaned.append(token)
                open_count -= 1
            # If no open, skip this token
        else:
            # It's a word/token, just add it
            cleaned.append(token)

    return " ".join(cleaned)

In [23]:
def compare_json(pred, gold):
    def compare_nested_dicts(pred_dict, gold_dict, match_count, total_count):
        """
        Recursively compare two dictionaries and count matches and total attributes.
        """
        # Ensure both inputs are dictionaries
        if not isinstance(pred_dict, dict) or not isinstance(gold_dict, dict):
            return match_count, total_count

        for key in gold_dict:
            if key in pred_dict:
                if isinstance(gold_dict[key], dict) and isinstance(pred_dict[key], dict):
                    # Recursively compare nested dictionaries
                    match_count, total_count = compare_nested_dicts(
                        pred_dict[key], gold_dict[key], match_count, total_count
                    )
                elif isinstance(gold_dict[key], list) and isinstance(pred_dict[key], list):
                    # Compare lists element-wise
                    match_count, total_count = compare_nested_lists(
                        pred_dict[key], gold_dict[key], match_count, total_count
                    )
                else:
                    # Compare values directly
                    total_count += 1
                    if gold_dict[key] == pred_dict[key]:
                        match_count += 1
            else:
                # Key missing in prediction
                total_count += 1

        return match_count, total_count

    def compare_nested_lists(pred_list, gold_list, match_count, total_count):
        """
        Compare two lists element-wise or by subset (assuming gold is ground truth).
        """
        if not isinstance(pred_list, list) or not isinstance(gold_list, list):
            return match_count, total_count

        for gold_item in gold_list:
            matched = False
            for pred_item in pred_list:
                # Compare each item in the lists (assuming dictionaries)
                if isinstance(gold_item, dict) and isinstance(pred_item, dict):
                    sub_match_count, sub_total_count = compare_nested_dicts(
                        pred_item, gold_item, 0, 0
                    )
                    if sub_match_count == sub_total_count:
                        matched = True
                        break
                elif gold_item == pred_item:
                    matched = True
                    break
            total_count += 1
            if matched:
                match_count += 1

        return match_count, total_count

    # Initialize match and total counts
    match_count = 0
    total_count = 0

    # Start comparison
    match_count, total_count = compare_nested_dicts(pred, gold, match_count, total_count)

    # Return percentage of correct matches
    accuracy = (match_count / total_count) * 100 if total_count > 0 else 0
    return accuracy, match_count, total_count


In [24]:
total_sequences = len(dev_data)
correct_sequences = 0
pred_list = []
tgt_list = []
accuracy= match_count = total_count = 0
for sample in dev_data:
    pred = greedy_decode(model, tokenize_input(sample["dev.SRC"]), src_vocab, tgt_vocab)
    print("SRC:", sample["dev.SRC"])
    print("PRED:", pred)
    print("GOLD:", sample["dev.TOP"])
    pred = preprocess(clean_parentheses(pred))
    tgt = preprocess(sample["dev.TOP"])
    print("PRED:", pred)
    print("GOLD:", tgt)
    accuracy_1, match_count_1, total_count_1 = compare_json(pred, tgt)
    accuracy += accuracy_1
    match_count += match_count_1
    total_count += total_count_1
    if pred == tgt:
        # print(pred)
        # print(tgt)
        correct_sequences += 1
    # else:
    #     print(pred)
    #     print(tgt)
print(f"Correct {correct_sequences}, Total {total_sequences}")
sequence_accuracy = correct_sequences / total_sequences if total_sequences > 0 else 0
sequence_accuracy * 100

SRC: i want to order two medium pizzas with sausage and black olives and two medium pizzas with pepperoni and extra cheese and three large pizzas with pepperoni and sausage
PRED: ( ORDER ( PIZZAORDER ( NUMBER two ) ( SIZE medium ) ( TOPPING pecorino cheese ) ) ( PIZZAORDER ( NUMBER three ) ( TOPPING pecorino cheese ) ( TOPPING vegan pepperoni ) ) ( DRINKORDER ( NUMBER a ) ( DRINKTYPE sprite ) ) )
GOLD: (ORDER i want to order (PIZZAORDER (NUMBER two ) (SIZE medium ) pizzas with (TOPPING sausage ) and (TOPPING black olives ) ) and (PIZZAORDER (NUMBER two ) (SIZE medium ) pizzas with (TOPPING pepperoni ) and (COMPLEX_TOPPING (QUANTITY extra ) (TOPPING cheese ) ) ) and (PIZZAORDER (NUMBER three ) (SIZE large ) pizzas with (TOPPING pepperoni ) and (TOPPING sausage ) ) )
PRED: {'ORDER': {'PIZZAORDER': [{'NUMBER': 'two', 'SIZE': 'medium', 'AllTopping': [{'NOT': False, 'Quantity': None, 'Topping': 'pecorino cheese'}]}, {'NUMBER': 'three', 'AllTopping': [{'NOT': False, 'Quantity': None, 'Toppin

0.0

In [29]:
accuracy/total_count * 100, match_count, total_count

(48.44961240310077, 5, 430)