In [1]:
from tqdm import tqdm
import json
import torch
from torch.nn import Softmax

from torch.utils.data import DataLoader
from llmtosql.model import WikiSQLModel
from llmtosql.trainer import Trainer
from llmtosql.dataloader import WikiSQLDataset
from llmtosql.utils.utils import plot_history, plot_history_base, load_model, load_history

In [2]:
path = 'model_output'

In [3]:
model = WikiSQLModel(base_model_type='bert-base-uncased', attention_type='cross')
model = load_model(model, 'model_output/model.pth')

2023-05-10 16:39:04 [info     ] Using cross attention mechanism
2023-05-10 16:39:04 [info     ] 3 heads model -- ['SELECT', 'AGG', 'CONDS']


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [28]:
test_set = WikiSQLDataset(type='dev', model=model)
test_loader = DataLoader(test_set, batch_size=32)

2023-05-10 17:14:49 [info     ] Tokenizing dataset.


100%|██████████| 8421/8421 [00:07<00:00, 1171.23it/s]


In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [30]:
model = model.to(device)

In [39]:
sel = []
agg = []
conds = []
with tqdm(test_loader, unit='batch') as tepoch:
    for data in tepoch:
        questions = (data['input'][0])
        inputs, _ = model.unpack(data, device)
        outputs = model(inputs)
        predictions = model.predict(outputs)
        sel.extend(predictions[0].tolist())
        agg.extend(predictions[1].tolist())
        for idx, cond in enumerate(predictions[2]):
            if len(cond.shape) == 1:
                cond = cond.unsqueeze(1)
            if idx == 0:
                max_num_conditions = torch.max(cond).item()
                # print(max_num_conditions)
                if max_num_conditions == 0:
                    cond_1 = cond_2, cond_3 = [[None]], [[None]], [[None]]
                    break
            elif idx == 1:
                cond_1 = cond.T.tolist()
                cond_1 = cond_1[:max_num_conditions]
            elif idx == 2:
                cond = cond - 1
                cond_2 = cond.T.tolist()
                cond_2 = cond_2[:max_num_conditions]
            elif idx == 3:
                outer_list = []
                # print(torch.transpose(predictions[2][3].T, 1, 2).tolist())
                for condition in torch.transpose(predictions[2][3].T, 1, 2).tolist():
                    batch_list = [WikiSQLDataset._digitize(' '.join((q.split())[cond_range[0]:cond_range[0] + cond_range[1]])) for cond_range, q in zip(condition, questions)]
                    # for batch in condition:
                    #     word_list = model.tokenizer.convert_ids_to_tokens(batch, skip_special_tokens=True)
                    #     batch_list.append(' '.join(word_list))
                    outer_list.append(batch_list)
                cond_3 = outer_list
                cond_3 = cond_3[:max_num_conditions]
        print(cond_1, cond_2, cond_3)
        all_conds = []
        for c1, c2, c3 in zip(cond_1, cond_2, cond_3):
            inner_all_conds = []
            for b1, b2, b3 in zip(c1, c2, c3):
                if b2 == -1:
                    b1, b2, b3 = None, None, None
                inner_all_conds.append((b1, b2, b3))
            all_conds.append(inner_all_conds)
        conds.extend([list(x) for x in zip(*all_conds)])
        break

  0%|          | 0/264 [00:37<?, ?batch/s]

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, -1, -1, -1, -1]] [['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 'What', '', '', '', '', '', '', ''], ['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 'What', '', '', '', '', '', '', '']]





In [40]:
final = []
for s, a, c in zip(sel, agg, conds):
    solution = {
        "query": {
            "sel":s,
            "agg":a
        }
    }
    if all([all([x is None for x in cond]) for cond in c]):
        c = None
    if c is not None:
        solution["query"]["conds"] = [list(x) for x in c]
    final.append(solution)

In [41]:
final

[{'query': {'sel': 3, 'agg': 0}},
 {'query': {'sel': 1, 'agg': 3}},
 {'query': {'sel': 0, 'agg': 0}},
 {'query': {'sel': 0, 'agg': 0}},
 {'query': {'sel': 0, 'agg': 0}},
 {'query': {'sel': 0, 'agg': 0}},
 {'query': {'sel': 5, 'agg': 0}},
 {'query': {'sel': 1, 'agg': 3}},
 {'query': {'sel': 3, 'agg': 3}},
 {'query': {'sel': 2, 'agg': 0}},
 {'query': {'sel': 5, 'agg': 0}},
 {'query': {'sel': 2, 'agg': 0}},
 {'query': {'sel': 4, 'agg': 0}},
 {'query': {'sel': 0, 'agg': 0}},
 {'query': {'sel': 3, 'agg': 0}},
 {'query': {'sel': 0, 'agg': 0}},
 {'query': {'sel': 3, 'agg': 0}},
 {'query': {'sel': 3, 'agg': 0}},
 {'query': {'sel': 1, 'agg': 3}},
 {'query': {'sel': 5, 'agg': 0}},
 {'query': {'sel': 7, 'agg': 0}},
 {'query': {'sel': 5, 'agg': 2}},
 {'query': {'sel': 4, 'agg': 0}},
 {'query': {'sel': 4, 'agg': 0}},
 {'query': {'sel': 0, 'agg': 0, 'conds': [[0, 2, 'What'], [0, 2, 'What']]}},
 {'query': {'sel': 4, 'agg': 0}},
 {'query': {'sel': 4, 'agg': 3}},
 {'query': {'sel': 0, 'agg': 0}},
 {'qu

In [None]:
test_file = 'model_output/test_results.jsonl'

In [None]:
with open(test_file, 'w+') as f:
    for line in final:
        json.dump(line, f)
        f.write('\n')