In [1]:
from tqdm import tqdm
import json
import torch
from torch.nn import Softmax

from torch.utils.data import DataLoader
from llmtosql.model import WikiSQLModel
from llmtosql.trainer import Trainer
from llmtosql.dataloader import WikiSQLDataset
from llmtosql.utils.utils import plot_history, plot_history_base, load_model, load_history

In [2]:
# Directory containing the saved model checkpoint (model.pth).
path = 'model_output'

In [3]:
# Build the cross-attention BERT model and load the checkpoint stored in `path`
# (reusing the directory variable instead of repeating the literal).
# NOTE(review): load_model is project-defined; presumably it restores the saved weights into `model`.
model = WikiSQLModel(base_model_type='bert-base-uncased', attention_type='cross')
model = load_model(model, f'{path}/model.pth')

2023-05-08 16:39:30 [info     ] Using cross attention mechanism
2023-05-08 16:39:30 [info     ] 3 heads model -- ['SELECT', 'AGG', 'CONDS']


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# WikiSQL test split; construction tokenizes every sample (see the log output).
# NOTE(review): `model` is passed in, presumably so the dataset uses its tokenizer — confirm in WikiSQLDataset.
test_set = WikiSQLDataset(type='test', model=model)
# NOTE(review): test_loader is not used by any later cell in this notebook.
test_loader = DataLoader(test_set, batch_size=32)

2023-05-08 16:39:34 [info     ] Tokenizing dataset.


100%|██████████| 15878/15878 [00:16<00:00, 963.72it/s] 


In [5]:
# Prefer the GPU when one is available. NOTE(review): `device` is not used by later cells.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Count (and print) every value pair in each sample's 'CHECK' entry whose
# case-insensitive string forms differ; each line printed is
# "<CHECK[0] value> -- <CHECK[1] value> -- <question>".
# NOTE(review): CHECK presumably holds (ground-truth values, values rebuilt
# from the generated labels) — confirm in WikiSQLDataset.
count = 0
for sample in test_set:
    reference_values, candidate_values = sample['CHECK'][0], sample['CHECK'][1]
    for reference, candidate in zip(reference_values, candidate_values):
        if str(reference).lower() == str(candidate).lower():
            continue
        print(f'{reference} -- {candidate} -- {sample["input"][0]}')
        count += 1
count

Can I Love? -- can i love -- Which examples ask the existential question "Can I Love?"
"Hell" -- hell -- what's the original air date with title  "hell"
"Poppin' Tags" -- poppin' tags -- How many episodes in season 6 titles "Poppin' Tags"?
"Lockdown" -- lockdown -- How many series are named "Lockdown"?
"House of Cards" -- house of cards -- What was the latest episode of "House of Cards"?
"The Birthday Present" -- the birthday present -- What was the # for the episode "the birthday present"?
"The Wedding" -- the wedding -- What was the # for the episode "the wedding"?
"Bad Influence" -- bad influence -- What is the series number for the title "Bad Influence"
"Na Jane Aise Ho Gaya Kaise" -- na jane aise ho gaya kaise -- What was the lyricst of "na jane aise ho gaya kaise"?
"Jekhanete Jaai Ami 1" -- jekhanete jaai ami 1 -- who is the music director for the song "jekhanete jaai ami 1"?
"Detalles" -- detalles -- What is the order number for the song choice "detalles"? 
"The Great Mayoralty 

254

In [7]:
# Mismatching value pairs as a percentage of the number of test samples
# (a single sample can contribute more than one pair).
count/len(test_set)*100

1.5996976949237938

In [8]:
# Question text of the first two test samples ('input' is a (question, columns) tuple).
question_1 = test_set[0]['input'][0]
question_2 = test_set[1]['input'][0]

In [9]:
# Hand-written question whose condition value ('51.82%') ends in a non-word character.
question_3 = 'If % lunsford is 51.82% what is the % mcconnell in Letcher?'

In [10]:
# WHERE-clause values to locate in question_1, question_2 and question_3 respectively.
where_answer_1 = "Terrence Ross"
where_answer_2 = "1995-96"
where_answer_3 = "51.82%"

In [11]:
# Whitespace tokenization of question_1; note that "ross'" keeps its apostrophe.
question_1.split()

['What', 'is', 'terrence', "ross'", 'nationality']

In [12]:
# Whitespace tokenization of question_2.
question_2.split()

['What', 'clu', 'was', 'in', 'toronto', '1995-96']

In [13]:
# Raw text of question_1.
question_1

"What is terrence ross' nationality"

In [14]:
import re

In [15]:
# Case-insensitive patterns matching any token that contains "terrence" / "ross".
pattern_list = [r'(?i)\b\w*terrence\w*\b', r'(?i)\b\w*ross\w*\b']

In [16]:
# Indices of question_1 tokens matched by each pattern, collected pattern by
# pattern (an index appears once for every pattern that matches its token).
question_1_tokens = question_1.split()
index_list = [
    idx
    for pattern in pattern_list
    for idx, token in enumerate(question_1_tokens)
    if re.findall(pattern, token)
]
index_list

[2, 3]

In [17]:
# [first matched index, last matched index - first matched index] for the span found above.
map_list = [index_list[0], index_list[-1]-index_list[0]]

In [18]:
# Display the span computed above.
map_list

[2, 1]

In [19]:
def try_generate_mapping(question_token_list, pattern_list, gt):
    """Prototype: locate a condition value's token span in a question.

    Each token of ``gt`` is paired with the regex at the same position in
    ``pattern_list``. A question token matches when that regex finds it, or
    when it equals the ``gt`` token exactly. Patterns are scanned in order and
    tokens left to right; every match index is collected.

    Args:
        question_token_list: the question split into tokens.
        pattern_list: one regex per ``gt`` token.
        gt: tokens of the condition value to locate.

    Returns:
        ``[first_index, last_index - first_index]``, built from the first and
        last collected indices. NOTE: the second element is the span length
        minus one, unlike ``_generate_mapping``, which returns the length.

    Raises:
        ValueError: if no question token matches any pattern. Previously this
            surfaced as an unexplained IndexError.
    """
    index_list = [
        idx
        for pattern, key in zip(pattern_list, gt)
        for idx, token in enumerate(question_token_list)
        if re.findall(pattern, token) or key == token
    ]
    if not index_list:
        raise ValueError(f'none of {list(gt)!r} found in {list(question_token_list)!r}')
    return [index_list[0], index_list[-1] - index_list[0]]

In [20]:
# Span of "Terrence Ross" in question_1, using the hand-written patterns.
try_generate_mapping(question_1.split(), pattern_list, where_answer_1.split())

[2, 1]

In [21]:
# Single-token value with a hand-written pattern ('-' is a literal outside a character class).
try_generate_mapping(question_2.split(), [r'(?i)\b\w*1995-96\w*\b'], where_answer_2.split())

[5, 0]

In [22]:
# Tokens of the first WHERE value.
where_answer_1.split()

['Terrence', 'Ross']

In [23]:
# One case-insensitive "token contains X" pattern per value token.
# NOTE(review): tokens are not re.escape()d, so regex metacharacters in values (e.g. '.') act as regex syntax.
[fr'(?i)\b\w*{token.lower()}\w*\b' for token in where_answer_1.split()]

['(?i)\\b\\w*terrence\\w*\\b', '(?i)\\b\\w*ross\\w*\\b']

In [24]:
# Same mapping for question_1, with the patterns generated from the value tokens.
try_generate_mapping(question_1.split(), [fr'(?i)\b\w*{token.lower()}\w*\b' for token in where_answer_1.split()], where_answer_1.split())

[2, 1]

In [25]:
# Generated-pattern mapping for question_2.
try_generate_mapping(question_2.split(), [fr'(?i)\b\w*{token.lower()}\w*\b' for token in where_answer_2.split()], where_answer_2.split())

[5, 0]

In [26]:
# Generated-pattern mapping for question_3. '51.82%' is found through exact token equality,
# because the pattern's trailing \b cannot match after the non-word character '%'.
try_generate_mapping(question_3.split(), [fr'(?i)\b\w*{token.lower()}\w*\b' for token in where_answer_3.split()], where_answer_3.split())

[4, 0]

In [27]:
# Inspect the first test sample (table id, columns, input, tokenized inputs, ...).
test_set[0]

{'table_id': '1-10015132-16',
 'columns': 'Player, No., Nationality, Position, Years in Toronto, School/Club Team',
 'input': ("What is terrence ross' nationality",
  'Player, No., Nationality, Position, Years in Toronto, School/Club Team'),
 'tokenized_inputs': {'question': {'input_ids': tensor([  101,  2054,  2003, 25170,  5897,  5811,  1005, 10662,  2447,  1010,
           2053,  1012,  1010, 10662,  1010,  2597,  1010,  2086,  1999,  4361,
           1010,  2082,  1013,  2252,  2136,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,

In [28]:
# Span label of the first condition; used below as [start, length] over the cleaned question tokens.
# NOTE(review): the meaning of conds[3] is inferred from that usage — confirm in WikiSQLDataset.
cond_range = test_set[0]['labels']['conds'][3][0]

In [29]:
# Apply the dataset's own token cleaning (a private static helper) to question 0's tokens;
# the output shows it lower-cases tokens but keeps "ross'" as is.
cleaned_q = list(WikiSQLDataset._generate_cond3(test_set[0]['input'][0].split()))

In [30]:
# Original question text, for comparison with cleaned_q.
test_set[0]['input'][0]

"What is terrence ross' nationality"

In [31]:
# Cleaned question tokens.
cleaned_q

['what', 'is', 'terrence', "ross'", 'nationality']

In [32]:
# Tokens covered by the labelled span (start index, then number of tokens).
cleaned_q[cond_range[0]:cond_range[0]+cond_range[1]]

['terrence', "ross'"]

In [33]:
# Inspect the second test sample.
test_set[1]

{'table_id': '1-10015132-16',
 'columns': 'Player, No., Nationality, Position, Years in Toronto, School/Club Team',
 'input': ('What clu was in toronto 1995-96',
  'Player, No., Nationality, Position, Years in Toronto, School/Club Team'),
 'tokenized_inputs': {'question': {'input_ids': tensor([  101,  2054, 18856,  2226,  2001,  1999,  4361,  2786,  1011,  5986,
           2447,  1010,  2053,  1012,  1010, 10662,  1010,  2597,  1010,  2086,
           1999,  4361,  1010,  2082,  1013,  2252,  2136,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,   

In [34]:
# Re-runs the CHECK mismatch count from In [6] with identical code (the output below matches it).
# NOTE(review): duplicated cell — consider moving this loop into a shared helper function.
count = 0
for test in test_set:
    for gt, pred in zip(test['CHECK'][0], test['CHECK'][1]):
        if str(gt).lower() != str(pred).lower():
            print(f'{gt} -- {pred} -- {test["input"][0]}')
            count += 1
count

Can I Love? -- can i love -- Which examples ask the existential question "Can I Love?"
"Hell" -- hell -- what's the original air date with title  "hell"
"Poppin' Tags" -- poppin' tags -- How many episodes in season 6 titles "Poppin' Tags"?
"Lockdown" -- lockdown -- How many series are named "Lockdown"?
"House of Cards" -- house of cards -- What was the latest episode of "House of Cards"?
"The Birthday Present" -- the birthday present -- What was the # for the episode "the birthday present"?
"The Wedding" -- the wedding -- What was the # for the episode "the wedding"?
"Bad Influence" -- bad influence -- What is the series number for the title "Bad Influence"
"Na Jane Aise Ho Gaya Kaise" -- na jane aise ho gaya kaise -- What was the lyricst of "na jane aise ho gaya kaise"?
"Jekhanete Jaai Ami 1" -- jekhanete jaai ami 1 -- who is the music director for the song "jekhanete jaai ami 1"?
"Detalles" -- detalles -- What is the order number for the song choice "detalles"? 
"The Great Mayoralty 

254

In [35]:
# Mismatching value pairs as a percentage of the number of test samples.
count/len(test_set)*100

1.5996976949237938

In [36]:
# The value 'LeBron James (7)' appears in a question that also contains the
# near-duplicate 'lebron james (21)'; only the first occurrence ends in '(7)'.
q = 'What number game had a high assist of lebron james (7) and high point of lebron james (21)?'
cond = 'LeBron James (7)'

In [37]:
def _clean_text(text):
    char_list = '?"()+,$[]{};*'
    for char in char_list:
        text = text.replace(char, '')
    text = text.replace("'s", '')
    text = text.replace("'", '')
    return text.lower()

In [38]:
from collections import defaultdict

In [39]:
def _generate_mapping(question_token_list, pattern_list, gt):
    """Locate condition value ``gt`` as a contiguous span of question tokens.

    Every question token is cleaned with ``_clean_text``. Each value token
    ``key`` in ``gt`` is paired position-wise with a regex in
    ``pattern_list``. A cleaned question token matches ``key`` when:

    * ``key`` is a single character: the token equals ``key.lower()`` exactly
      (a one-character "contains" regex would match almost anything);
    * otherwise: the paired regex finds the token, or the token equals
      ``key.lower()``, or both are numbers with the same float value.

    The span must start on a match of ``gt[0]``, end on a match of ``gt[-1]``
    and contain exactly ``len(gt)`` tokens. Interior tokens are not checked.
    When several spans qualify, the leftmost one is returned.

    Args:
        question_token_list: the question split into tokens.
        pattern_list: one regex per ``gt`` token, in the same order.
        gt: cleaned tokens of the condition value.

    Returns:
        ``[start_index, span_length]`` with ``span_length == len(gt)``.

    Raises:
        ValueError: if ``gt`` is empty or no qualifying span exists.
            Previously these cases raised IndexError / UnboundLocalError.
    """
    if not gt:
        raise ValueError('gt must contain at least one token')
    number_re = r'^[-+]?(?:[0-9]+,)*[0-9]+(?:\.[0-9]+)?$'
    # Clean each question token once instead of once per comparison.
    cleaned_tokens = [_clean_text(token) for token in question_token_list]

    token_dict = defaultdict(set)
    for pattern, key in zip(pattern_list, gt):
        key_lower = key.lower()
        key_is_number = bool(re.findall(number_re, key))
        for idx, token in enumerate(cleaned_tokens):
            if len(key) == 1:
                matched = key_lower == token
            else:
                matched = bool(
                    re.findall(pattern, token)
                    or key_lower == token
                    or (key_is_number
                        and re.findall(number_re, token)
                        and float(token) == float(_clean_text(key)))
                )
            if matched:
                token_dict[key].add(idx)

    span_length = len(gt)
    start_positions = token_dict[gt[0]]
    # The start index is fixed by the end index, so one pass over the sorted
    # end positions finds the leftmost qualifying span. The original nested
    # loop kept whichever pair came last in set-iteration order.
    for end in sorted(token_dict[gt[-1]]):
        start = end - span_length + 1
        if start in start_positions:
            return [start, span_length]
    raise ValueError(f'no {span_length}-token span for {gt!r} in {list(question_token_list)!r}')

In [40]:
# One case-insensitive "token contains X" regex per cleaned value token.
# re.escape stops characters that _clean_text keeps (e.g. '.', '|', '\') from
# being read as regex syntax; unescaped they caused false matches or re.error.
pattern = [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in _clean_text(str(cond)).split()]

In [41]:
# Span of 'LeBron James (7)': start at token 8 ('lebron'), 3 tokens long (through '(7)').
_generate_mapping(q.split(), pattern, _clean_text(str(cond)).split())

[8, 3]

In [42]:
# A value with a possessive suffix, used to compare how each cleaner handles "'s".
txt = "Washington Capital's"

In [43]:
# Does the value end with a possessive "'s"?
txt.endswith("'s")

True

In [44]:
# Plain regex removal of "'s" (removes every occurrence, not only a word-final one).
re.sub(r"'s", '', txt)

'Washington Capital'

In [45]:
# The dataset's _generate_cond3 cleaner on the same value; the output keeps "capital's".
list(WikiSQLDataset._generate_cond3(txt.split()))

['washington', "capital's"]

In [46]:
# The dataset's _digitize helper (exact purpose not visible here — confirm);
# for this input the output shows the trailing "'s" removed.
WikiSQLDataset._digitize(txt.strip(","))

'Washington Capital'

In [47]:
# Whitespace normalisation alone leaves the possessive intact.
' '.join(txt.split())

"Washington Capital's"