In [1]:
from tqdm import tqdm
import json
import torch
from torch.nn import Softmax

from torch.utils.data import DataLoader
from llmtosql.model import WikiSQLModel
from llmtosql.trainer import Trainer
from llmtosql.dataloader import WikiSQLDataset
from llmtosql.utils.utils import plot_history, plot_history_base, load_model, load_history

In [2]:
path = 'model_output'

In [3]:
model = WikiSQLModel(base_model_type='bert-base-uncased', attention_type='cross', inference=True)
model = load_model(model, 'model_output/model.pth')

2023-04-18 22:08:00 [info     ] Using cross attention mechanism
2023-04-18 22:08:00 [info     ] 3 heads model -- ['SELECT', 'AGG', 'CONDS']


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
test_set = WikiSQLDataset(type='test', model=model)
test_loader = DataLoader(test_set, batch_size=32)

2023-04-18 22:08:04 [info     ] Tokenizing dataset.


100%|██████████| 15878/15878 [00:14<00:00, 1072.75it/s]


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [38]:
sel = []
agg = []
conds = []
with tqdm(test_loader, unit='batch') as tepoch:
    for data in tepoch:
        inputs, _ = model.unpack(data, device)
        outputs = model(inputs)
        predictions = model.predict(outputs)
        sel.extend(predictions[0].tolist())
        agg.extend(predictions[1].tolist())
        if predictions[2] is None:
            conds.extend(None)
        else:
            for idx, cond in enumerate(predictions[2]):
                if len(cond.shape) == 1:
                    cond = cond.unsqueeze(1)
                if idx == 0:
                    continue
                elif idx == 1:
                    cond_1 = cond.T.tolist()
                elif idx == 2:
                    cond = cond - 1
                    cond_2 = cond.T.tolist()
                elif idx == 3:
                    outer_list = []
                    for condition in torch.transpose(predictions[2][3].T, 1, 2).tolist():
                        batch_list = []
                        for batch in condition:
                            word_list = model.tokenizer.convert_ids_to_tokens(batch, skip_special_tokens=True)
                            batch_list.append(' '.join(word_list))
                        outer_list.append(batch_list)
                    cond_3 = outer_list
            all_conds = []
            for c1, c2, c3 in zip(cond_1, cond_2, cond_3):
                inner_all_conds = []
                for b1, b2, b3 in zip(c1, c2, c3):
                    if b2 == -1:
                        b1, b2, b3 = None, None, None
                    inner_all_conds.append((b1, b2, b3))
                all_conds.append(inner_all_conds)
            conds.extend([list(x) for x in zip(*all_conds)])
        break

  0%|          | 0/497 [00:40<?, ?batch/s]


In [39]:
final = []
for s, a, c in zip(sel, agg, conds):
    solution = {
        "query": {
            "sel":s,
            "agg":a
        }
    }
    if c is not None:
        solution["query"]["conds"] = [list(x) for x in c]
    final.append(solution)

In [40]:
final

[{'query': {'sel': 2, 'agg': 0, 'conds': [[None, None, None]]}},
 {'query': {'sel': 1, 'agg': 0, 'conds': [[None, None, None]]}},
 {'query': {'sel': 0, 'agg': 0, 'conds': [[None, None, None]]}},
 {'query': {'sel': 5, 'agg': 3, 'conds': [[None, None, None]]}},
 {'query': {'sel': 1, 'agg': 0, 'conds': [[None, None, None]]}},
 {'query': {'sel': 2, 'agg': 3, 'conds': [[None, None, None]]}},
 {'query': {'sel': 1, 'agg': 0, 'conds': [[None, None, None]]}},
 {'query': {'sel': 2,
   'agg': 3,
   'conds': [[0,
     0,
     'united accountants mayoral buffet oval grossing generalized ##漢 apartment parody terrorist parody inherently 0 vh1 nhs habits ) windy nhs ##漢 ) windy luggage distinction − vicky drilling ##rooms defensive drilling ##漢 み windy ##歌 grossing copa grossing varieties parody sick windy parody orgasm ) ( copa tissue [unused646] ( elsewhere parody [unused646] windy papacy ##cased % ##oof orgasm ##漢 guards 1763 侍 6 windy vh1 ᴰ rising emperor campaigned ease parody orioles – windy par

In [None]:
test_file = 'model_output/test_results.jsonl'

In [None]:
with open(test_file, 'w+') as f:
    for line in final:
        json.dump(line, f)
        f.write('\n')