In [1]:
from tqdm import tqdm
import json
import torch
from torch.nn import Softmax

from torch.utils.data import DataLoader
from llmtosql.model import WikiSQLModel
from llmtosql.trainer import Trainer
from llmtosql.dataloader import WikiSQLDataset
from llmtosql.utils.utils import plot_history, plot_history_base, load_model, load_history

In [2]:
# Name of the output directory; the checkpoint loaded in the next cell
# ('model_output/model.pth') lives under it. Not referenced again in the
# visible cells.
path = 'model_output'

In [3]:
# Build the BERT-based model with cross attention and pass it straight to
# load_model together with the checkpoint path; whatever load_model returns
# is bound to `model` for the rest of the notebook.
model = load_model(
    WikiSQLModel(base_model_type='bert-base-uncased', attention_type='cross'),
    'model_output/model.pth',
)

2023-05-05 21:37:14 [info     ] Using cross attention mechanism
2023-05-05 21:37:14 [info     ] 3 heads model -- ['SELECT', 'AGG', 'CONDS']


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Test split of the dataset. The model instance is handed to the dataset
# constructor (presumably so it can use the model's tokenizer — TODO confirm).
test_set = WikiSQLDataset(type='test', model=model)
# Batches of 32 samples, in dataset order (no shuffle requested).
test_loader = DataLoader(test_set, batch_size=32)

2023-05-05 21:37:18 [info     ] Tokenizing dataset.


100%|██████████| 15878/15878 [00:13<00:00, 1176.05it/s]


In [5]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
# First two test-set questions: element 0 of each sample's 'input' entry.
question_1, question_2 = (test_set[i]['input'][0] for i in range(2))

In [7]:
# Hand-typed third example; its answer ("51.82%") contains punctuation,
# which makes it a useful edge case for the regex-based matching below.
question_3 = 'If % lunsford is 51.82% what is the % mcconnell in Letcher?'

In [8]:
# Expected answer strings, paired in order with question_1, question_2 and
# question_3.
where_answer_1, where_answer_2, where_answer_3 = (
    "Terrence Ross",
    "1995-96",
    "51.82%",
)

In [9]:
# Whitespace tokenization of question_1; the token indices computed further
# down refer to positions in this list.
question_1.split()

['What', 'is', 'terrence', "ross'", 'nationality']

In [10]:
# Whitespace tokenization of question_2.
question_2.split()

['What', 'clu', 'was', 'in', 'toronto', '1995-96']

In [11]:
# Show the raw question string before tokenization.
question_1

"What is terrence ross' nationality"

In [12]:
import re

In [13]:
# One case-insensitive regex per answer word. Each matches a run of word
# characters containing "terrence" / "ross", so a token such as "ross'" still
# hits through its word part.
pattern_list = [r'(?i)\b\w*terrence\w*\b', r'(?i)\b\w*ross\w*\b']

In [14]:
# For every pattern (outer) scan the question tokens (inner) and keep the
# position of each token the pattern matches. Indices are therefore grouped
# by pattern, in pattern order.
index_list = [
    position
    for regex in pattern_list
    for position, word in enumerate(question_1.split())
    if re.search(regex, word)
]
index_list

[2, 3]

In [15]:
# [start index, offset from the first recorded hit to the last recorded hit]
span_start = index_list[0]
map_list = [span_start, index_list[-1] - span_start]

In [16]:
# Display the [start, offset] pair.
map_list

[2, 1]

In [17]:
def try_generate_mapping(question_token_list, pattern_list, gt):
    """Map a ground-truth answer onto token positions in a question.

    ``pattern_list`` and ``gt`` are walked in parallel (``zip``, so surplus
    entries of the longer one are ignored). For each (regex, answer token)
    pair, every question token is tested in order and its index is recorded
    when the regex matches somewhere in the token or the token equals the
    answer token exactly (case-sensitive comparison).

    Args:
        question_token_list: Question tokens, e.g. ``question.split()``.
        pattern_list: Regex patterns, one per answer token.
        gt: Ground-truth answer tokens, aligned with ``pattern_list``.

    Returns:
        ``[start, offset]`` where ``start`` is the first recorded index and
        ``offset`` is the last recorded index minus ``start`` (``0`` when a
        single token was hit). Indices are recorded pair by pair, so
        ``start`` belongs to the first pair that produced a hit.

    Raises:
        IndexError: If no question token matches any (pattern, answer token)
            pair.
    """
    index_list = [
        idx
        for pattern, key in zip(pattern_list, gt)
        for idx, token in enumerate(question_token_list)
        # Only the existence of a match matters, so re.search is sufficient.
        if re.search(pattern, token) or key == token
    ]
    if not index_list:
        # IndexError (rather than ValueError) so that callers guarding the
        # failed lookup with `except IndexError` keep working; the message
        # says which answer could not be located.
        raise IndexError(
            f"no question token matches the answer tokens {gt!r}"
        )
    first = index_list[0]
    return [first, index_list[-1] - first]

In [19]:
# Same example as the manual cells above, now through the helper; the result
# is the [start, offset] pair for "Terrence Ross" inside question_1.
try_generate_mapping(question_1.split(), pattern_list, where_answer_1.split())

[2, 1]

In [21]:
# Single-token answer; the pattern is written out by hand for this check.
try_generate_mapping(question_2.split(), [r'(?i)\b\w*1995-96\w*\b'], where_answer_2.split())

[5, 0]

In [22]:
# Answer split on whitespace: one entry per pattern to build.
where_answer_1.split()

['Terrence', 'Ross']

In [23]:
# Build one pattern per answer word. re.escape keeps any regex metacharacter
# in the answer (".", "+", "(", ...) literal instead of letting it act as
# regex syntax; for plain alphabetic words the pattern text is unchanged.
[fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_1.split()]

['(?i)\\b\\w*terrence\\w*\\b', '(?i)\\b\\w*ross\\w*\\b']

In [25]:
try_generate_mapping(
    question_1.split(),
    # Patterns derived from the answer itself; re.escape keeps answer text
    # literal so it can never be interpreted as regex syntax.
    [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_1.split()],
    where_answer_1.split(),
)

[2, 1]

In [26]:
try_generate_mapping(
    question_2.split(),
    # Patterns derived from the answer itself; re.escape keeps answer text
    # literal so it can never be interpreted as regex syntax.
    [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_2.split()],
    where_answer_2.split(),
)

[5, 0]

In [27]:
try_generate_mapping(
    question_3.split(),
    # re.escape: the "." in "51.82%" must match a literal dot, not any
    # character, when the answer is turned into a pattern.
    [fr'(?i)\b\w*{re.escape(token.lower())}\w*\b' for token in where_answer_3.split()],
    where_answer_3.split(),
)

[4, 0]