In [1]:
# load data
import json
import os
import random
import urllib.request

import pandas
github_path = "https://raw.githubusercontent.com/allenai/Break/refs/heads/master/break_dataset/QDMR"
data = {}
for name in ['train', 'dev', 'test']:
    # URLs always use '/', so build them by formatting rather than os.path.join,
    # which would insert '\\' on Windows.
    url = '{}/{}.csv'.format(github_path, name)
    filepath = '{}.csv'.format(name)
    if not os.path.isfile(filepath):
        # Download with the standard library instead of shelling out to `wget`
        # (os.system ignores failures and wget is not installed everywhere).
        # The payload is read fully before writing, so a failed download
        # raises here instead of leaving a partial file behind.
        with urllib.request.urlopen(url) as response:
            payload = response.read()
        with open(filepath, 'wb') as handle:
            handle.write(payload)
    assert os.path.isfile(filepath), filepath
    # read_csv opens and closes the file itself (the original leaked the handle).
    lines = pandas.read_csv(filepath)
    # Round-trip through JSON so each record is a dict of plain Python values.
    data[name] = json.loads(lines.to_json(orient='records'))
    
def jumble(decomp, ii):
    """Return a randomly corrupted version of a decomposition string.

    The corruption type is selected by ``ii % 3``:
      0 -> shuffle the order of the ' ;'-separated clauses,
      1 -> swap two randomly chosen ' '-separated words,
      2 -> shuffle the words inside one randomly chosen clause.

    Uses the global ``random`` module, so the result depends on its state.
    The result can equal ``decomp`` (e.g. a single clause cannot be
    reordered); callers must check and retry.
    """
    clause_separator = ' ;'
    word_separator = ' '
    mode = ii % 3
    if mode == 0:
        clauses = decomp.split(clause_separator)
        random.shuffle(clauses)
        decomp = clause_separator.join(clauses)
    elif mode == 1:
        words = decomp.split(word_separator)
        if len(words) < 2:
            # No two distinct positions exist; the rejection loop below
            # would never terminate.
            return decomp
        indices = (0, 0)
        while indices[0] == indices[1]:
            indices = (random.randrange(len(words)), random.randrange(len(words)))
        words[indices[0]], words[indices[1]] = words[indices[1]], words[indices[0]]
        decomp = word_separator.join(words)
    elif mode == 2:
        clauses = [clause.split(word_separator) for clause in decomp.split(clause_separator)]
        index = random.randrange(len(clauses))
        random.shuffle(clauses[index])
        decomp = clause_separator.join(word_separator.join(c) for c in clauses)
    return decomp


def _corrupt_one(decomp, ii, max_attempts):
    # Try the corruption type chosen by ii first, then the other two, so that
    # strings one type cannot change (e.g. a single clause) still get corrupted.
    for offset in range(3):
        for _ in range(max_attempts):
            new_decomp = jumble(decomp, ii + offset)
            if new_decomp != decomp:
                return new_decomp
    raise ValueError('cannot corrupt decomposition {!r}'.format(decomp))


def corrupt(data, max_attempts=100):
    """Make negative examples for the grammar by corrupting each decomposition.

    Three corruption types (see ``jumble``):
      1. change the order of clauses,
      2. swap 2 words,
      3. shuffle word order inside a clause.
    Item ``ii`` uses type ``ii % 3``. If that type cannot change the string
    within ``max_attempts`` tries, the other types are tried.

    Reseeds the global ``random`` module with 0, so the output is reproducible.

    Args:
        data: list of dicts with a 'decomposition' string; not modified.
        max_attempts: retries per corruption type before moving on.

    Returns:
        A list of shallow copies of the input dicts whose 'decomposition'
        differs from the original.

    Raises:
        ValueError: if no corruption type can change some decomposition
            (e.g. a single word).
    """
    random.seed(0)
    output = []
    for z in data:
        assert isinstance(z, dict)
        output.append(dict(z))
    assert output == data
    for ii, x in enumerate(output):
        assert data[ii] == output[ii]
        decomp = x['decomposition']
        new_decomp = _corrupt_one(decomp, ii, max_attempts)
        assert output[ii]['decomposition'] != new_decomp
        output[ii]['decomposition'] = new_decomp
        assert data[ii] != output[ii]
    return output
    
# Reduce the size of dev to N evenly spaced items (smaller canonical set), then
# build its corrupted counterpart 'antidev' as negatives for grammar precision.
N = 100
# max(..., 1) guards against a dev split smaller than N: a slice step of 0 raises.
step = max(len(data['dev']) // N, 1)
data['dev'] = data['dev'][0:len(data['dev']):step]
data['dev'] = data['dev'][:N]
data['antidev'] = corrupt(data['dev'])
assert data['dev'] != data['antidev']
for name in ['train', 'dev', 'antidev']:
    print('=== {} ({} items)==='.format(name.upper(), len(data[name])))
    print('\n'.join(json.dumps(x, indent=2, sort_keys=True) for x in data[name][:3]))
# Length statistics of training decompositions, measured in whitespace-separated
# words; computed once instead of re-splitting the whole split three times.
train_lengths = [len(x['decomposition'].split()) for x in data['train']]
max_length_train = max(train_lengths)
avg_length_train = sum(train_lengths)/len(train_lengths)
var_length_train = sum((length-avg_length_train)**2 for length in train_lengths)/len(train_lengths)

=== TRAIN (44321 items)===
{
  "decomposition": "return homepages ;return #1 of  PVLDB",
  "operators": "['select', 'filter']",
  "question_id": "ACADEMIC_train_0",
  "question_text": "return me the homepage of PVLDB . ",
  "split": "train"
}
{
  "decomposition": "return homepages ;return #1 of  H. V. Jagadish",
  "operators": "['select', 'filter']",
  "question_id": "ACADEMIC_train_1",
  "question_text": "return me the homepage of \" H. V. Jagadish \" . ",
  "split": "train"
}
{
  "decomposition": "return references ;return #1 of  Making database systems usable ;return number of  #2",
  "operators": "['select', 'filter', 'aggregate']",
  "question_id": "ACADEMIC_train_10",
  "question_text": "return me the number of references of \" Making database systems usable \" . ",
  "split": "train"
}
=== DEV (100 items)===
{
  "decomposition": "return flights ;return #1 from  denver ;return #2 to philadelphia ;return #3 if  available",
  "operators": "['select', 'filter', 'filter', 'filter']",

In [2]:
# Grammar
import lark
import tqdm
def get_grammar_parse_rate(grammar, data, hard=False):
    """Count how many items of ``data`` the Lark ``grammar`` accepts.

    Args:
        grammar: Lark grammar source text.
        data: list of dicts with a 'decomposition' string.
        hard: debug mode -- print every input and let the first parse error
            propagate instead of counting it as a rejection.

    Returns:
        (positives, negatives): number of items that parsed / did not parse.
    """
    p = lark.Lark(grammar)
    positives, negatives = 0, 0
    for x in tqdm.tqdm(data):
        parse = None
        decomp = x['decomposition'] + '\n' # add final newline as end of generation token
        if hard:
            print(decomp)
            parse = p.parse(decomp)
        else:
            try:
                parse = p.parse(decomp)
            except lark.exceptions.LarkError:
                # Only parse failures count as rejections; other errors
                # (including KeyboardInterrupt) are no longer swallowed.
                pass
        parsed = int(parse is not None)
        positives += parsed
        negatives += (1-parsed)
    assert positives + negatives == len(data)
    return positives, negatives

def get_grammar_metrics(grammar, dev=None, antidev=None):
    """Score a grammar as an accept/reject classifier.

    Gold decompositions (``dev``, default ``data['dev']``) are positives and
    corrupted ones (``antidev``, default ``data['antidev']``) are negatives.

    Returns:
        dict with 'precision', 'recall' and 'f1'. A metric whose denominator
        is zero (e.g. the grammar accepts nothing) is reported as 0.0 instead
        of raising ZeroDivisionError.
    """
    if dev is None:
        dev = data['dev']
    if antidev is None:
        antidev = data['antidev']

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator else 0.0

    tp, fn = get_grammar_parse_rate(grammar, dev)
    fp, _tn = get_grammar_parse_rate(grammar, antidev)
    return {
        'precision': _ratio(tp, tp + fp),
        'recall': _ratio(tp, tp + fn),
        'f1': _ratio(2 * tp, 2 * tp + fp + fn),
    }

In [3]:
# Grammar
def _lark_string_literal(text):
    # Quote text as a Lark string literal, escaping backslashes and double
    # quotes so n-grams such as 'say "hi"' do not break the grammar.
    return '"{}"'.format(text.replace('\\', '\\\\').replace('"', '\\"'))


def make_ngram_grammar(data, vocab_fraction, n=3):
    """Build a Lark grammar whose vocabulary is the most frequent training n-grams.

    For each order ``nn`` in 1..n, every full-length window of ``nn``
    whitespace-separated words in each 'decomposition' is counted. The top
    ``max(int(vocab_fraction / n * distinct_nn_grams), 1)`` are kept, ranked by
    count (ties broken by reverse lexical order). The items 'return', ';return',
    ' ' and ';' are never added as vocabulary because the grammar emits
    'return' and ' ;' structurally.

    Sentences are up to 10 ' ;'-separated "return <expression>" clauses, and
    each expression is up to 10 vocabulary items separated by 1-3 spaces.
    A trailing NEWLINE is required as the end-of-generation token.

    Raises:
        ValueError: if no n-gram survives filtering (e.g. empty ``data``).
    """
    from collections import Counter
    counts = {nn: Counter() for nn in range(1, n + 1)}
    for x in data:
        words = x['decomposition'].split()
        for nn in range(1, n + 1):
            # Only full windows: the original also counted the truncated tails
            # (shorter than nn words) as nn-grams.
            counts[nn].update(' '.join(words[ii:ii + nn]) for ii in range(len(words) - nn + 1))
    excluded = {'return', ';return', ' ', ';'}
    alternatives = []
    for nn in range(1, n + 1):
        num_items = max(int(vocab_fraction/n*len(counts[nn])), 1)
        candidates = [(gram, count) for gram, count in counts[nn].items() if gram not in excluded]
        top = sorted(candidates, key=lambda gc: (gc[1], gc[0]), reverse=True)[:num_items] # only keep top words
        literals = ' | '.join(_lark_string_literal(gram) for gram in sorted(gram for gram, _ in top))
        # An order with no n-grams would otherwise leave a dangling '|'.
        if literals:
            alternatives.append(literals)
    if not alternatives:
        raise ValueError('no n-grams available to build a grammar vocabulary')
    ngrams = ' | '.join(alternatives)
    # Be careful to write \\n and not \n when declaring a grammar directly as a string
    # sometimes there are double whitespaces so I'm accounting for that
    grammar = """
    NEWLINE: "\\n"
    
    SPACE: " " | "  " | "   "
    
    ?start: sentence NEWLINE

    RETURN: "return"

    SEPARATOR: " ;"

    ?sentence: RETURN SPACE expression (SEPARATOR RETURN SPACE expression)~0..9
    
    ?expression: word (SPACE word)~0..9
    
    ?word: """+ngrams
    return grammar

In [4]:
# Few-shot retriever
from rank_bm25 import BM25Okapi
class FewShotRetriever:
    """BM25 retriever over training (question, decomposition) pairs.

    Used to pick few-shot examples: the training questions are the corpus and
    a query question retrieves the most similar training pairs.
    """
    def __init__(self, data_set):
        # data_set contains items from the training set: dicts with at least
        # 'question_text' and 'decomposition'.
        assert isinstance(data_set, list)
        assert all(isinstance(x, dict) for x in data_set)
        self.data = [(x['question_text'], x['decomposition']) for x in data_set]
        self.bm25 = None

    def build_index(self):
        """Tokenize the training questions on single spaces and build the BM25 index."""
        self.corpus = [x[0] for x in self.data]
        tokenized_corpus = [doc.split(" ") for doc in self.corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def get_samples(self, query, n=4):
        """Return up to ``n`` (question, decomposition) pairs ranked by BM25 score for ``query``.

        Builds the index on first use if ``build_index`` has not been called
        (previously this crashed with an AttributeError on ``None``).
        """
        if self.bm25 is None:
            self.build_index()
        tokenized_query = query.split(" ")
        return self.bm25.get_top_n(tokenized_query, self.data, n=n)
        
# One BM25 index over the full training split, shared by every model below.
retriever = FewShotRetriever(data_set=data['train'])
retriever.build_index()

In [5]:
# Model
import outlines
from outlines.generate.api import GenerationParameters
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Default generation budget: mean + 3 standard deviations of training
# decomposition length. NOTE(review): those lengths count whitespace-separated
# words, while max_tokens limits tokenizer tokens; the truncated predictions in
# the outputs suggest this budget may be too small -- confirm.
MAX_TOKENS = int(avg_length_train + 3 * var_length_train ** .5)
class Model:
    """Few-shot QDMR decomposition generator built on a causal LM.

    Each prompt contains ``n_few_shot`` training examples retrieved with
    ``retriever``. With ``grammar=None`` generation uses a Hugging Face
    text-generation pipeline (do_sample=False, stopping at a newline);
    otherwise outlines' CFG-constrained generator is used with ``grammar``.
    """
    def __init__(self, model_name, retriever, grammar, n_few_shot=8, max_tokens=MAX_TOKENS):
        # Validate the cheap arguments before loading any model weights.
        assert isinstance(retriever, FewShotRetriever)
        assert isinstance(model_name, str)
        assert isinstance(max_tokens, int)
        assert isinstance(n_few_shot, int)
        if grammar is not None:
            assert isinstance(grammar, str)
        self.retriever = retriever
        self.grammar = grammar
        self.max_tokens = max_tokens
        self.n_few_shot = n_few_shot
        if self.grammar is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=self.tokenizer.eos_token_id)
            # Build the pipeline once; it used to be rebuilt on every call.
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer
            )
        else:
            self.model = outlines.models.transformers(model_name)
            self.generator = outlines.generate.cfg(self.model, self.grammar)

    def generate(self, input_strings):
        """Generate one decomposition per question; returns a list of dicts with 'prompt' and 'lm_output'."""
        if self.grammar is None:
            return self.generate_unconstrained(input_strings)
        else:
            return self.generate_constrained(input_strings)

    def generate_unconstrained(self, input_strings):
        """Generate without a grammar using the Hugging Face pipeline."""
        assert isinstance(input_strings, list)
        assert all(isinstance(ii, str) for ii in input_strings)
        prompts = self.make_prompts(input_strings)
        generation_args = {
            "max_new_tokens": self.max_tokens,
            "return_full_text": False,
            "do_sample": False,
            "stop_strings": ['\n'],
            "tokenizer": self.tokenizer,
            "pad_token_id": self.tokenizer.eos_token_id
        }
        output = self.pipe(prompts, **generation_args)
        # The pipeline returns one list of candidates per prompt; keep the first.
        output = [oo[0] for oo in output]
        assert len(output) == len(prompts)
        # strip() removes surrounding whitespace, including the final newline.
        output = [{'prompt': prompt, 'lm_output': x['generated_text'].strip()} for x, prompt in zip(output, prompts)]
        return output

    def generate_constrained(self, input_strings):
        """Generate with the outlines CFG generator built from ``self.grammar``."""
        assert isinstance(input_strings, list)
        assert all(isinstance(ii, str) for ii in input_strings)
        prompts = self.make_prompts(input_strings)
        lm_outputs = self.generator(prompts, max_tokens=self.max_tokens, stop_at=['\n'])
        assert len(lm_outputs) == len(prompts)
        output = [{'prompt': prompt, 'max_tokens': self.max_tokens, 'lm_output': lmo.strip()}
                  for lmo, prompt in zip(lm_outputs, prompts)]
        return output

    def make_prompt(self, input_string):
        """Build a 'Question: ...\\nAnswer: ...' few-shot prompt ending with an open answer."""
        # The retriever may return fewer than n_few_shot samples (e.g. a small
        # corpus), so the count is not asserted.
        samples = self.retriever.get_samples(input_string, n=self.n_few_shot)
        prompt = ["Question: {}\nAnswer: {}".format(x, y) for x, y in samples]
        prompt.append("Question: {}\nAnswer: ".format(input_string))
        prompt = '\n'.join(prompt)
        return prompt

    def make_prompts(self, input_strings):
        """Build one prompt per input question."""
        prompts = [self.make_prompt(input_string) for input_string in input_strings]
        return prompts

In [6]:
# run on validation
import tqdm
def inference(model, data, batch_size=1):
    """Generate decompositions for every item of ``data`` with ``model``.

    Args:
        model: object whose ``generate(list_of_questions)`` returns one dict
            with an 'lm_output' key per question (as ``Model.generate`` does).
        data: list of dicts with 'question_text', 'decomposition' and 'question_id'.
        batch_size: number of questions passed to ``model.generate`` at once.

    Returns:
        List of dicts with 'decomposition' (model output), 'question',
        'gold' (reference decomposition) and 'id', in input order.

    Raises:
        ValueError: if ``batch_size`` < 1.
        RuntimeError: if the model returns a different number of outputs than inputs.
    """
    if batch_size < 1:
        # A negative step would silently yield no batches at all.
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
    output = []
    for start in tqdm.tqdm(range(0, len(data), batch_size)):
        batch = data[start:start + batch_size]
        lm_output = model.generate([item['question_text'] for item in batch])
        # zip() would silently drop items on a length mismatch.
        if len(lm_output) != len(batch):
            raise RuntimeError('model returned {} outputs for {} inputs'.format(len(lm_output), len(batch)))
        output.extend({'decomposition': generated['lm_output'], 'question': item['question_text'],
                       'gold': item['decomposition'], 'id': item['question_id']}
                      for item, generated in zip(batch, lm_output))
    return output

In [7]:
# Evaluator for Break dataset
from scripts.evaluate_predictions import evaluate
from evaluation.decomposition import Decomposition
def run_eval(model, batch_size=10, dataset=None, output_path_base=None):
    """Run ``model`` on ``dataset`` and score it with the Break evaluator.

    Args:
        model: object accepted by ``inference``.
        batch_size: questions per ``model.generate`` call.
        dataset: items to evaluate; defaults to ``data['dev']``.
        output_path_base: passed unchanged to ``evaluate``. Defaults to the
            BREAK_EVAL_OUTPUT_BASE environment variable, else the original path.

    Returns:
        Whatever ``evaluate`` returns (the metrics).
    """
    if dataset is None:
        dataset = data['dev']
    if output_path_base is None:
        # NOTE(review): machine-specific fallback kept for backward
        # compatibility; set BREAK_EVAL_OUTPUT_BASE or pass the argument elsewhere.
        output_path_base = os.environ.get('BREAK_EVAL_OUTPUT_BASE', '/home/nils/Desktop/code-generation-lab')
    output = inference(model, dataset, batch_size=batch_size)
    ids = [oo['id'] for oo in output]
    questions = [oo['question'] for oo in output]
    decompositions = [Decomposition.from_str(oo['decomposition']) for oo in output]
    golds = [Decomposition.from_str(oo['gold']) for oo in output]
    metadata = None
    metrics = evaluate(ids, questions, decompositions, golds, metadata, output_path_base)
    return metrics

In [8]:
# Baseline: the LLM without any grammar constraint, one retrieved example per prompt.
unconstrained_model = Model(model_name="HuggingFaceTB/SmolLM-135M", retriever=retriever, grammar=None, n_few_shot=1)
m = run_eval(model=unconstrained_model)
del unconstrained_model

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [10:07<00:00, 60.72s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 39.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 70350.62it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 36.14it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 35.16it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 5306.22it/s]

evaluating example #0
	id: ATIS_dev_0
	question: what flights are available tomorrow from denver to philadelphia 
	gold: return flights ;return #1 from denver ;return #2 to philadelphia ;return #3 if available
	prediction: 1000000000000000000000000000000000000000000000000000000000
	exact_match: 0
	match: 0.0
	structural_match: 1.0
	sari: 0.431
	ged: 1.0
	normalized_exact_match: 0
	normalized_match: 0.0
	normalized_structural_match: 0.0
	normalized_sari: 0.542
evaluating example #1
	id: ATIS_dev_170
	question: what nonstop flights are available from oakland to philadelphia arriving between 5 and 6pm 
	gold: return flights ;return #1 that are nonstop ;return #2 from oakland ;return #3 to philadelphia ;return #4 arriving between 5 and 6pm
	prediction: 5pm to 6pm ; 6pm to 7pm ; 7pm to 8pm ; 8pm to 9pm ; 9pm to 10pm ; 10pm to 11pm ; 11pm to 12
	exact_match: 0
	match: 0.113
	structural_match: 1.0
	sari: 0.224
	ged: 0.887
	normalized_exact_match: 0
	normalized_match: 0.217
	normalized_structu




In [9]:
# Unigram grammar built from the top 10% of distinct training unigrams, scored
# as an accept/reject classifier on dev (positives) vs antidev (negatives).
unigram_grammar = make_ngram_grammar(data=data['train'], vocab_fraction=0.1, n=1)
unigram_grammar_metrics = get_grammar_metrics(grammar=unigram_grammar)
print(unigram_grammar_metrics)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:16<00:00,  1.30it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:16<00:00,  1.31it/s]

{'precision': 0.625, 'recall': 0.35, 'f1': 0.44871794871794873}





In [10]:
# Evaluate the LLM constrained by the unigram grammar. The metrics are kept
# (previously the return value was discarded) so they can be compared with `m`.
ngram_model = Model("HuggingFaceTB/SmolLM-135M", retriever, unigram_grammar, n_few_shot=1)
ngram_metrics = run_eval(ngram_model)
del ngram_model

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [22:28<00:00, 134.85s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 40.08it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 70021.77it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 36.73it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 35.11it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3618.99it/s]

evaluating example #0
	id: ATIS_dev_0
	question: what flights are available tomorrow from denver to philadelphia 
	gold: return flights ;return #1 from denver ;return #2 to philadelphia ;return #3 if available
	prediction: return flights ;return #1 from denver ;return #2 to philadelphia ;return #3 that are on monday
	exact_match: 0
	match: 0.857
	structural_match: 1.0
	sari: 0.881
	ged: 0.143
	normalized_exact_match: 0
	normalized_match: 0.833
	normalized_structural_match: 0.818
	normalized_sari: 0.884
evaluating example #1
	id: ATIS_dev_170
	question: what nonstop flights are available from oakland to philadelphia arriving between 5 and 6pm 
	gold: return flights ;return #1 that are nonstop ;return #2 from oakland ;return #3 to philadelphia ;return #4 arriving between 5 and 6pm
	prediction: return flights ;return #1 from oakland ;return #2 from philadelphia ;return #3 from philadelphia ;return #4 from philadelphia ;return #5 from philadelphia ;return #6 from philadelphia
	exact_match:




In [14]:
# Hand-written baseline grammar from grammar.lark; `with` closes the file
# (the previous ''.join(open(...)) leaked the handle).
with open('grammar.lark', 'r') as grammar_file:
    baseline_grammar = grammar_file.read()
baseline_grammar_metrics = get_grammar_metrics(baseline_grammar)
print(baseline_grammar_metrics)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 33.88it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 42.00it/s]

{'precision': 0.6190476190476191, 'recall': 0.13, 'f1': 0.21487603305785125}





In [15]:
# Evaluate the LLM constrained by the hand-written baseline grammar. The metrics
# are kept (previously the return value was discarded) for comparison.
baseline_model = Model("HuggingFaceTB/SmolLM-135M", retriever, baseline_grammar, n_few_shot=1)
baseline_metrics = run_eval(baseline_model)
del baseline_model

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [18:15<00:00, 109.54s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 46.87it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 80582.21it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 36.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 41.39it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 5470.24it/s]

evaluating example #0
	id: ATIS_dev_0
	question: what flights are available tomorrow from denver to philadelphia 
	gold: return flights ;return #1 from denver ;return #2 to philadelphia ;return #3 if available
	prediction: return flights from denver ;return #1 from philadelphia ;return #2 to philadelphia ;return #3 that are
	exact_match: 0
	match: 0.81
	structural_match: 1.0
	sari: 0.819
	ged: 0.19
	normalized_exact_match: 0
	normalized_match: 0.862
	normalized_structural_match: 0.857
	normalized_sari: 0.912
evaluating example #1
	id: ATIS_dev_170
	question: what nonstop flights are available from oakland to philadelphia arriving between 5 and 6pm 
	gold: return flights ;return #1 that are nonstop ;return #2 from oakland ;return #3 to philadelphia ;return #4 arriving between 5 and 6pm
	prediction: return flights from oakland ;return #1 from philadelphia ;return #2 from oakland ;return #3 from philadelphia
	exact_match: 0
	match: 0.642
	structural_match: 1.0
	sari: 0.426
	ged: 0.358
	no


