In [1]:
from tqdm import tqdm
import json
import torch
from torch.nn import Softmax

from torch.utils.data import DataLoader
from llmtosql.model import WikiSQLModel
from llmtosql.trainer import Trainer
from llmtosql.dataloader import WikiSQLDataset
from llmtosql.utils.utils import plot_history, plot_history_base, load_model, load_history

In [2]:
path = 'model_output'

In [3]:
# Instantiate the model, then hand it to load_model together with the
# checkpoint file. The checkpoint location is derived from `path` (set in the
# cell above) so the output directory is not repeated as a second literal.
model = WikiSQLModel(base_model_type='bert-base-uncased', attention_type='cross')
model = load_model(model, f'{path}/model.pth')

2023-05-07 18:24:08 [info     ] Using cross attention mechanism
2023-05-07 18:24:08 [info     ] 3 heads model -- ['SELECT', 'AGG', 'CONDS']


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Build the 'test' split; the model is passed to the dataset constructor
# (the log output of this cell shows the dataset being tokenized here).
test_set = WikiSQLDataset(type='test', model=model)
# Batches of 32 samples.
# NOTE(review): test_loader is not used by any cell visible in this part of
# the notebook — confirm whether it is still needed.
test_loader = DataLoader(test_set, batch_size=32)

2023-05-07 18:24:13 [info     ] Tokenizing dataset.


100%|██████████| 15878/15878 [00:15<00:00, 998.55it/s] 


In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): no visible cell moves the model or any tensor to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Each sample's 'input' is a (question, columns) pair; index 0 is the question text.
question_1 = test_set[0]['input'][0]
question_2 = test_set[1]['input'][0]

In [7]:
question_3 = 'If % lunsford is 51.82% what is the % mcconnell in Letcher?'

In [8]:
where_answer_1 = "Terrence Ross"
where_answer_2 = "1995-96"
where_answer_3 = "51.82%"

In [9]:
question_1.split()

['What', 'is', 'terrence', "ross'", 'nationality']

In [10]:
question_2.split()

['What', 'clu', 'was', 'in', 'toronto', '1995-96']

In [11]:
question_1

"What is terrence ross' nationality"

In [12]:
import re

In [13]:
pattern_list = [r'(?i)\b\w*terrence\w*\b', r'(?i)\b\w*ross\w*\b']

In [14]:
# Positions of the question tokens hit by each pattern, grouped in pattern
# order (all hits for the first pattern, then for the second, ...).
question_1_tokens = question_1.split()  # split once instead of once per pattern
index_list = [
    idx
    for pattern in pattern_list
    for idx, token in enumerate(question_1_tokens)
    if re.search(pattern, token)  # only the existence of a match matters
]
index_list

[2, 3]

In [15]:
map_list = [index_list[0], index_list[-1]-index_list[0]]

In [16]:
map_list

[2, 1]

In [17]:
def try_generate_mapping(question_token_list, pattern_list, gt):
    """Locate the tokens of a condition value inside a tokenized question.

    ``pattern_list`` and ``gt`` are walked in parallel (extra entries of the
    longer one are ignored). A question token is a hit for a ``(pattern, key)``
    pair when the regex ``pattern`` finds a match in it or when the token is
    exactly equal to ``key`` (case-sensitive comparison).

    Args:
        question_token_list: The question split into tokens.
        pattern_list: One regular expression per entry of ``gt``.
        gt: Tokens of the condition value being searched for.

    Returns:
        ``[start, offset]`` where ``start`` is the smallest hit index and
        ``offset`` is the distance from ``start`` to the largest hit index
        (a single-token hit therefore yields an offset of ``0``).

    Raises:
        IndexError: If no question token is hit by any pattern or key.
    """
    hit_indices = [
        idx
        for pattern, key in zip(pattern_list, gt)
        for idx, token in enumerate(question_token_list)
        if re.search(pattern, token) or key == token
    ]
    if not hit_indices:
        # Same exception type the unguarded indexing used to raise, but with a
        # message that tells the reader what actually went wrong.
        raise IndexError(
            f'no question token matches any of {list(gt)!r} in {list(question_token_list)!r}'
        )
    # Hits are collected in pattern order, not question order, so the span has
    # to be taken from the extremes rather than from the first/last appended
    # index (which could give a wrong start and a negative offset).
    start, end = min(hit_indices), max(hit_indices)
    return [start, end - start]

In [18]:
try_generate_mapping(question_1.split(), pattern_list, where_answer_1.split())

[2, 1]

In [19]:
try_generate_mapping(question_2.split(), [r'(?i)\b\w*1995-96\w*\b'], where_answer_2.split())

[5, 0]

In [20]:
where_answer_1.split()

['Terrence', 'Ross']

In [21]:
# One case-insensitive "word containing this token" regex per answer token.
# re.escape keeps any regex metacharacter in the answer from being read as
# regex syntax.
[fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_1.split()]

['(?i)\\b\\w*terrence\\w*\\b', '(?i)\\b\\w*ross\\w*\\b']

In [22]:
# Patterns are derived from the answer tokens; re.escape keeps any regex
# metacharacter in the answer from being read as regex syntax.
try_generate_mapping(question_1.split(), [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_1.split()], where_answer_1.split())

[2, 1]

In [23]:
# Patterns are derived from the answer tokens; re.escape keeps any regex
# metacharacter in the answer from being read as regex syntax.
try_generate_mapping(question_2.split(), [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_2.split()], where_answer_2.split())

[5, 0]

In [24]:
# '51.82%' contains '.', which would match any character if interpolated
# unescaped; re.escape makes the pattern match the answer text literally.
try_generate_mapping(question_3.split(), [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_3.split()], where_answer_3.split())

[4, 0]

In [25]:
test_set[0]

{'table_id': '1-10015132-16',
 'columns': 'Player, No., Nationality, Position, Years in Toronto, School/Club Team',
 'input': ("What is terrence ross' nationality",
  'Player, No., Nationality, Position, Years in Toronto, School/Club Team'),
 'tokenized_inputs': {'question': {'input_ids': tensor([  101,  2054,  2003, 25170,  5897,  5811,  1005, 10662,  2447,  1010,
           2053,  1012,  1010, 10662,  1010,  2597,  1010,  2086,  1999,  4361,
           1010,  2082,  1013,  2252,  2136,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,

In [26]:
# NOTE(review): presumably the (start index, token count) span of the first
# condition value inside the question — a later cell slices the cleaned
# question with it that way. Confirm the 'conds' label layout against
# WikiSQLDataset.
cond_range = test_set[0]['labels']['conds'][3][0]

In [27]:
# Run the dataset's own token clean-up over the whitespace-split question.
# NOTE(review): _generate_cond3 is a private helper of WikiSQLDataset and its
# exact behaviour is not visible here; for this sample the recorded output
# shows lower-cased tokens with the trailing apostrophe kept.
cleaned_q = list(WikiSQLDataset._generate_cond3(test_set[0]['input'][0].split()))

In [28]:
test_set[0]['input'][0]

"What is terrence ross' nationality"

In [29]:
cleaned_q

['what', 'is', 'terrence', "ross'", 'nationality']

In [30]:
# Treat cond_range as (start index, token count) and slice out those tokens.
cleaned_q[cond_range[0]:cond_range[0]+cond_range[1]]

['terrence', "ross'"]

In [31]:
test_set[1]

{'table_id': '1-10015132-16',
 'columns': 'Player, No., Nationality, Position, Years in Toronto, School/Club Team',
 'input': ('What clu was in toronto 1995-96',
  'Player, No., Nationality, Position, Years in Toronto, School/Club Team'),
 'tokenized_inputs': {'question': {'input_ids': tensor([  101,  2054, 18856,  2226,  2001,  1999,  4361,  2786,  1011,  5986,
           2447,  1010,  2053,  1012,  1010, 10662,  1010,  2597,  1010,  2086,
           1999,  4361,  1010,  2082,  1013,  2252,  2136,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,   

In [32]:
# Walk the two parallel sequences stored under 'CHECK' (called gt and pred
# here) and report every pair that still differs after lower-casing both
# sides and stripping '"' from the ends of the gt value.
count = 0
for test in test_set:
    gt_values, pred_values = test['CHECK'][0], test['CHECK'][1]
    for gt, pred in zip(gt_values, pred_values):
        if str(gt).lower().strip('"') == str(pred).lower():
            continue  # values agree after normalisation
        print(f'{gt} -- {pred} -- {test["input"][0]}')
        count += 1
count

Can I Love? -- can i love -- Which examples ask the existential question "Can I Love?"
"Ambush" (Part 1) -- ambush (part 1) -- Who is the writer of the episode called "Ambush" (part 1)?
bumped St. Catharine's -- bumped st. catharine --  how many 1st day with 3rd day being bumped st. catharine's
" Save the Last Dance for Me " -- save the last dance for me --  how many artbeingt with song title being " save the last dance for me "
0.1 -- 2012 -- During the quarter of 2012 Q2, how many millions of Blackberry OS phones where shipped when 0.1 million others were shipped?
Shasta H.S. -- shasta h.s -- How many entries are there for class when the prior experience is shasta h.s.
M.C.G. -- m.c.g -- What is the home team that played on M.C.G. grounds?
M. Van der Goten ( BEL ) --  -- Name the 3 where weightlifter is m. van der goten ( bel )
"Breakthrough" "Burēku surū" (ブレーク·スルー) -- breakthrough burēku surū (ブレーク·スルー) -- What is the number of theWhat is the number of the chapter that is called "b

235

In [33]:
q = 'What number game had a high assist of lebron james (7) and high point of lebron james (21)?'
cond = 'LeBron James (7)'

In [34]:
def _clean_text(text):
    char_list = '?"()+,$[]{};*'
    for char in char_list:
        text = text.replace(char, '')
    text = text.replace("'s", '')
    text = text.replace("'", '')
    return text.lower()

In [35]:
from collections import defaultdict

In [36]:
def _generate_mapping(question_token_list, pattern_list, gt):
    """Find the span of question tokens that corresponds to a condition value.

    ``pattern_list`` and ``gt`` are walked in parallel. Every question token is
    passed through ``_clean_text`` and recorded as a hit for ``key`` when:

    * ``key`` is a single character and equals the cleaned token exactly
      (after lower-casing ``key``); or
    * ``key`` is longer and either its regex finds a match in the cleaned
      token, the lower-cased ``key`` equals the cleaned token, or both are
      plain numbers with the same numeric value (so ``'7.0'`` matches ``'7'``).

    A span is accepted when a hit for the first key and a hit for the last key
    are exactly ``len(gt)`` tokens apart (inclusive). Tokens in between are
    not checked.

    Args:
        question_token_list: The question split into raw (uncleaned) tokens.
        pattern_list: One regular expression per entry of ``gt``.
        gt: Tokens of the condition value.

    Returns:
        ``[start, length]`` with ``length == len(gt)``. When several spans
        qualify, the one with the smallest ``start`` is returned.

    Raises:
        ValueError: If no qualifying span exists.
        IndexError: If ``gt`` is empty.
    """
    number_re = re.compile(r'^[-+]?(?:[0-9]+,)*[0-9]+(?:\.[0-9]+)?$')
    # Clean each question token once instead of once per key.
    cleaned_tokens = [_clean_text(token) for token in question_token_list]

    token_dict = defaultdict(list)
    for pattern, key in zip(pattern_list, gt):
        key_lower = key.lower()
        if len(key) == 1:
            # A one-character key inside a "contains" regex would hit almost
            # every token, so only an exact match counts.
            hits = [idx for idx, cleaned in enumerate(cleaned_tokens) if cleaned == key_lower]
        else:
            # Numeric value of the key, or None when the key is not a plain number.
            key_number = float(_clean_text(key)) if number_re.search(key) else None
            hits = [
                idx
                for idx, cleaned in enumerate(cleaned_tokens)
                if re.search(pattern, cleaned)
                or key_lower == cleaned
                or (
                    key_number is not None
                    and number_re.search(cleaned)
                    and float(cleaned) == key_number
                )
            ]
        token_dict[key].extend(hits)

    first_tokens = set(token_dict[gt[0]])
    end_tokens = set(token_dict[gt[-1]])
    span_length = len(gt)
    # Sort so the chosen span does not depend on set iteration order.
    starts = sorted(
        start for start in first_tokens if start + span_length - 1 in end_tokens
    )
    if not starts:
        # Previously this path failed with an UnboundLocalError.
        raise ValueError(
            f'no span of {span_length} token(s) matching {list(gt)!r} '
            f'found in {list(question_token_list)!r}'
        )
    return [starts[0], span_length]

In [37]:
# One case-insensitive "word containing this token" regex per cleaned
# condition token. _clean_text strips some punctuation but leaves characters
# such as '.', '|' and '\' in place, so re.escape is needed to keep them from
# being read as regex syntax.
pattern = [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in _clean_text(str(cond)).split()]

In [38]:
_generate_mapping(q.split(), pattern, _clean_text(str(cond)).split())

[8, 3]

In [39]:
txt = "Washington Capital's"

In [40]:
txt.endswith("'s")

True

In [41]:
re.sub(r"'s", '', txt)

'Washington Capital'

In [42]:
list(WikiSQLDataset._generate_cond3(txt.split()))

['washington', "capital's"]

In [43]:
WikiSQLDataset._digitize(txt.strip(","))

'Washington Capital'

In [44]:
' '.join(txt.split())

"Washington Capital's"