In [None]:
# Colab only: mount Google Drive at /content/gdrive so the project directory
# under MyDrive can be reached by the following cells.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
cd '/content/gdrive/MyDrive/Colab Notebooks/Modified-IRNet'

/content/gdrive/MyDrive/Colab Notebooks/Modified-IRNet


In [None]:
# Sanity check: echo the current working directory (IPython automagic for %pwd)
# to confirm the directory change above took effect.
pwd

'/content/gdrive/MyDrive/Colab Notebooks/Modified-IRNet'

In [None]:
import nltk

# NLTK resources fetched into the local nltk_data directory.
# 'omw-1.4' is the entry the original author tagged as needed on Colab.
_NLTK_RESOURCES = ('averaged_perceptron_tagger', 'punkt', 'wordnet', 'omw-1.4')

# Download in the listed order, keeping each call's return value.
_nltk_download_results = [nltk.download(resource) for resource in _NLTK_RESOURCES]

# Echo the outcome of the last download as the cell's displayed value.
_nltk_download_results[-1]

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [None]:
pip install pattern

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import time
import traceback

In [None]:
import os
import torch
import torch.optim as optim
import tqdm
import copy

In [None]:
from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL

In [None]:
def train(args):
    """
    Train an IRNet model, evaluating and checkpointing after every epoch.

    Steps: load train/validation data with ``utils.load_dataset``, build
    ``IRNet(args, grammar)``, optionally restore weights from
    ``args.load_model``, attach the word embeddings loaded from
    ``args.glove_embed_path``, then run ``args.epoch`` epochs.

    Files written under the directory returned by
    ``utils.init_log_checkpoint_path(args)``:
      * ``epoch.log``              - one line of metrics per epoch
      * ``best_model.model``       - weights with the highest validation ``acc`` so far
      * ``{<epoch>}_{<acc>}.model`` - one checkpoint per epoch (0-based epoch
        index; the braces are literally part of the file name)
      * ``end_model.model``        - final weights, also written if training aborts

    :param args: configuration object. Attributes read directly here:
        dataset, toy, cuda, optimizer, lr, lr_scheduler, lr_scheduler_gammar,
        loss_epoch_threshold, sketch_loss_coefficient, load_model,
        glove_embed_path, epoch, batch_size, beam_size. The object is also
        passed on to ``IRNet`` and several ``utils`` helpers, which may read
        further attributes.
    :return: None. Any ``Exception`` raised inside the epoch loop is caught,
        printed together with its traceback, and not re-raised.
    """

    grammar = semQL.Grammar()
    sql_data, table_data, val_sql_data, val_table_data = utils.load_dataset(
        args.dataset, use_small=args.toy)

    model = IRNet(args, grammar)

    if args.cuda:
        model.cuda()

    # Resolve the optimizer class by attribute lookup instead of eval(), so the
    # configured string can only ever name something inside torch.optim.
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
    if args.lr_scheduler:
        # The scheduler is stepped at the END of each epoch (after the optimizer
        # has run), which is the order PyTorch >= 1.1 expects. Milestones
        # [20, 40] therefore halve-or-scale the LR from the 21st and 41st epoch
        # onwards, exactly what the previous "step first, milestones [21, 41]"
        # arrangement produced.
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40],
                                                   gamma=args.lr_scheduler_gammar)
    else:
        scheduler = None

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model:
        print('load pretrained model from %s'% (args.load_model))
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        pretrained_model = torch.load(args.load_model,
                                      map_location=lambda storage, loc: storage)
        # Keep only the entries the current model knows about. The key set is
        # computed once instead of rebuilding model.state_dict() per key.
        model_keys = set(model.state_dict().keys())
        pretrained_modeled = {key: value for key, value in pretrained_model.items()
                              if key in model_keys}

        model.load_state_dict(pretrained_modeled)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)

    # Directory for the log file and checkpoints. Dumping the configuration
    # with utils.save_args is intentionally not done in this notebook.
    model_save_path = utils.init_log_checkpoint_path(args)
    best_dev_acc = .0

    try:
        with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
            for epoch in tqdm.tqdm(range(args.epoch)):
                # Only the training pass is timed; evaluation is excluded.
                epoch_begin = time.time()
                loss = utils.epoch_train(model, optimizer, args.batch_size,
                                         sql_data, table_data, args,
                                         loss_epoch_threshold=args.loss_epoch_threshold,
                                         sketch_loss_coefficient=args.sketch_loss_coefficient)
                epoch_end = time.time()

                if scheduler is not None:
                    scheduler.step()

                json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size,
                                                              val_sql_data, val_table_data,
                                                              beam_size=args.beam_size)
                eval_acc, nest_acc = utils.eval_acc(json_datas, val_sql_data)

                # Model selection uses `acc` from utils.epoch_acc.
                if acc > best_dev_acc:
                    utils.save_checkpoint(model, os.path.join(model_save_path, 'best_model.model'))
                    best_dev_acc = acc

                # Format the file name BEFORE joining, so a '%' in the save
                # directory cannot break the interpolation. The resulting name
                # is unchanged (braces included).
                epoch_ckpt_name = '{%s}_{%s}.model' % (epoch, acc)
                utils.save_checkpoint(model, os.path.join(model_save_path, epoch_ckpt_name))

                log_str = 'Epoch: %d, Loss: %f, Val Sketch Acc: %f, Val Acc: %f : %f, Nested Acc: %f, time: %f\n' % (
                    epoch + 1, loss, sketch_acc, acc, eval_acc, nest_acc, epoch_end - epoch_begin)
                tqdm.tqdm.write(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()
    except Exception as e:
        # Best effort: keep whatever weights exist, then report the failure
        # without re-raising so the notebook cell itself does not error out.
        utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
        print(e)
        tb = traceback.format_exc()
        print(tb)
    else:
        utils.save_checkpoint(model, os.path.join(model_save_path, 'end_model.model'))
        json_datas, sketch_acc, acc = utils.epoch_acc(model, args.batch_size,
                                                      val_sql_data, val_table_data,
                                                      beam_size=args.beam_size)

        # NOTE(review): `acc` is printed for both "Validation Acc" and
        # "Validation Beam Acc"; no separate beam metric is computed here.
        print("Validation Sketch Acc: %f, Validation Acc: %f, Validation Beam Acc: %f" % (sketch_acc, acc, acc,))

In [None]:
## Arguments
class args:
    """
    Static configuration namespace handed to ``train(args)``.

    Plain class attributes are used instead of a parsed command line. Only
    some of them are read by code visible in this notebook (see the group
    comments); the rest travel with the object into ``IRNet`` and the
    ``utils`` helpers.
    """

    # --- Paths and checkpoint options --------------------------------------
    dataset = './data'
    glove_embed_path = './data/glove.42B.300d.txt'
    save = './'
    save_to = 'model'
    load_model = None
    toy = False

    # --- Run control --------------------------------------------------------
    cuda = True
    seed = 90
    epoch = 50
    max_epoch = -1
    batch_size = 64
    beam_size = 1

    # --- Optimisation and loss ---------------------------------------------
    optimizer = 'Adam'
    lr = 0.001
    lr_scheduler = True
    lr_scheduler_gammar = 0.5
    clip_grad = 5.0
    loss_epoch_threshold = 50
    sketch_loss_coefficient = 0.85

    # --- Model options (not read directly in this notebook) ----------------
    model_name = 'modified_irnet'
    lstm = 'lstm'
    embed_size = 300
    hidden_size = 300
    att_vec_size = 300
    col_embed_size = 300
    action_embed_size = 128
    type_embed_size = 128
    dropout = 0.3
    word_dropout = 0.2
    sentence_features = True
    column_pointer = True
    column_att = 'affine'
    readout = 'identity'
    no_query_vec_to_action_map = False
    query_vec_to_action_diff_map = False
    decode_max_time_step = 40

In [None]:
# Quick check that the configuration class is defined and points at the
# expected dataset directory.
args.dataset

'./data'

In [None]:
import random
import numpy as np

# Seed every random number generator used in this notebook from args.seed so
# that repeated runs start from the same state.

# Python's built-in RNG.
random.seed(int(args.seed))

# NumPy's global RNG; it deliberately gets a value derived from the base seed.
np.random.seed(int(args.seed * 13 / 7))

# PyTorch: CPU generator always, CUDA generator only when CUDA is requested.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

In [None]:
# Disabled debugging snippet: it is wrapped in a string literal so nothing in it
# executes. If enabled it would print every line of the GloVe embedding file.
# Because the string is the cell's last expression, the cell echoes its text.
'''
file_name = args.glove_embed_path
print(file_name)
with open(file_name) as inf: #SS encoding="utf-8"
    for idx, line in enumerate(inf):
        print(idx, line)
'''

'\nfile_name = args.glove_embed_path\nprint(file_name)\nwith open(file_name) as inf: #SS encoding="utf-8"\n    for idx, line in enumerate(inf):\n        print(idx, line)\n'

In [None]:
# Start training with the configuration class defined earlier in the notebook.
# The recorded run took roughly 5h22m for 50 epochs (see the progress bar).
train(args)

Loading from datasets...
Loading data from ./data/tables.json
Loading data from ./data/train.json
Loading data from ./data/dev.json
Use Column Pointer:  True
Enable Learning Rate Scheduler:  True
Loss epoch threshold: 50
Sketch loss coefficient: 0.850000
Loading word embedding from ./data/glove.42B.300d.txt
./data/glove.42B.300d.txt


  2%|▏         | 1/50 [02:04<1:41:19, 124.08s/it]

0.20485436893203884 0.08333333333333333
Epoch: 1, Loss: 10.644575, Val Sketch Acc: 0.619417, Val Acc: 0.274757 : 0.204854, Nested Acc: 0.083333, time: 85.850069



  4%|▍         | 2/50 [04:17<1:43:40, 129.60s/it]

0.28640776699029125 0.1794871794871795
Epoch: 2, Loss: 4.295348, Val Sketch Acc: 0.703883, Val Acc: 0.376699 : 0.286408, Nested Acc: 0.179487, time: 94.328762



  6%|▌         | 3/50 [06:41<1:46:35, 136.07s/it]

0.32233009708737864 0.1987179487179487
Epoch: 3, Loss: 2.902473, Val Sketch Acc: 0.742718, Val Acc: 0.427184 : 0.322330, Nested Acc: 0.198718, time: 104.398635



  8%|▊         | 4/50 [09:15<1:49:53, 143.33s/it]

0.329126213592233 0.22435897435897437
Epoch: 4, Loss: 2.173023, Val Sketch Acc: 0.755340, Val Acc: 0.445631 : 0.329126, Nested Acc: 0.224359, time: 112.589409



 10%|█         | 5/50 [12:00<1:53:22, 151.16s/it]

0.33689320388349514 0.20512820512820512
Epoch: 5, Loss: 1.741524, Val Sketch Acc: 0.755340, Val Acc: 0.439806 : 0.336893, Nested Acc: 0.205128, time: 123.435495



 12%|█▏        | 6/50 [14:54<1:56:26, 158.78s/it]

0.3504854368932039 0.1987179487179487
Epoch: 6, Loss: 1.414367, Val Sketch Acc: 0.785437, Val Acc: 0.468932 : 0.350485, Nested Acc: 0.198718, time: 133.121490



 14%|█▍        | 7/50 [17:59<1:59:56, 167.35s/it]

0.3466019417475728 0.23076923076923078
Epoch: 7, Loss: 1.179195, Val Sketch Acc: 0.771845, Val Acc: 0.454369 : 0.346602, Nested Acc: 0.230769, time: 143.381117



 16%|█▌        | 8/50 [21:14<2:03:16, 176.12s/it]

0.36019417475728155 0.25
Epoch: 8, Loss: 0.990479, Val Sketch Acc: 0.794175, Val Acc: 0.479612 : 0.360194, Nested Acc: 0.250000, time: 153.013752



 18%|█▊        | 9/50 [24:40<2:06:42, 185.42s/it]

0.3553398058252427 0.23717948717948717
Epoch: 9, Loss: 0.863121, Val Sketch Acc: 0.764078, Val Acc: 0.462136 : 0.355340, Nested Acc: 0.237179, time: 163.167266



 20%|██        | 10/50 [28:17<2:10:10, 195.27s/it]

0.36310679611650487 0.23717948717948717
Epoch: 10, Loss: 0.759071, Val Sketch Acc: 0.790291, Val Acc: 0.478641 : 0.363107, Nested Acc: 0.237179, time: 173.952283



 22%|██▏       | 11/50 [32:05<2:13:26, 205.29s/it]

0.341747572815534 0.23076923076923078
Epoch: 11, Loss: 0.693861, Val Sketch Acc: 0.790291, Val Acc: 0.452427 : 0.341748, Nested Acc: 0.230769, time: 179.842327



 24%|██▍       | 12/50 [36:06<2:16:55, 216.20s/it]

0.33980582524271846 0.23717948717948717
Epoch: 12, Loss: 0.598008, Val Sketch Acc: 0.800971, Val Acc: 0.447573 : 0.339806, Nested Acc: 0.237179, time: 191.507258



 26%|██▌       | 13/50 [40:17<2:19:45, 226.64s/it]

0.36699029126213595 0.22435897435897437
Epoch: 13, Loss: 0.546979, Val Sketch Acc: 0.790291, Val Acc: 0.485437 : 0.366990, Nested Acc: 0.224359, time: 200.316622



 28%|██▊       | 14/50 [44:40<2:22:34, 237.62s/it]

0.3485436893203884 0.24358974358974358
Epoch: 14, Loss: 0.514157, Val Sketch Acc: 0.768932, Val Acc: 0.461165 : 0.348544, Nested Acc: 0.243590, time: 215.217470



 30%|███       | 15/50 [49:24<2:26:45, 251.60s/it]

0.36699029126213595 0.2564102564102564
Epoch: 15, Loss: 0.447801, Val Sketch Acc: 0.800971, Val Acc: 0.479612 : 0.366990, Nested Acc: 0.256410, time: 229.283982



 32%|███▏      | 16/50 [54:22<2:30:30, 265.60s/it]

0.34951456310679613 0.1858974358974359
Epoch: 16, Loss: 0.433486, Val Sketch Acc: 0.775728, Val Acc: 0.461165 : 0.349515, Nested Acc: 0.185897, time: 242.160713



 34%|███▍      | 17/50 [59:25<2:32:17, 276.89s/it]

0.3553398058252427 0.23717948717948717
Epoch: 17, Loss: 0.387324, Val Sketch Acc: 0.786408, Val Acc: 0.479612 : 0.355340, Nested Acc: 0.237179, time: 254.490884



 36%|███▌      | 18/50 [1:04:34<2:32:46, 286.44s/it]

0.3592233009708738 0.21794871794871795
Epoch: 18, Loss: 0.368840, Val Sketch Acc: 0.766990, Val Acc: 0.475728 : 0.359223, Nested Acc: 0.217949, time: 252.401067



 38%|███▊      | 19/50 [1:09:50<2:32:36, 295.38s/it]

0.3446601941747573 0.22435897435897437
Epoch: 19, Loss: 0.336943, Val Sketch Acc: 0.788350, Val Acc: 0.458252 : 0.344660, Nested Acc: 0.224359, time: 259.595070



 40%|████      | 20/50 [1:15:09<2:31:10, 302.36s/it]

0.3592233009708738 0.21794871794871795
Epoch: 20, Loss: 0.315121, Val Sketch Acc: 0.781553, Val Acc: 0.477670 : 0.359223, Nested Acc: 0.217949, time: 261.378751



 42%|████▏     | 21/50 [1:20:46<2:31:15, 312.96s/it]

0.3679611650485437 0.23717948717948717
Epoch: 21, Loss: 0.179916, Val Sketch Acc: 0.789320, Val Acc: 0.486408 : 0.367961, Nested Acc: 0.237179, time: 279.651202



 44%|████▍     | 22/50 [1:26:26<2:29:47, 320.99s/it]

0.3640776699029126 0.22435897435897437
Epoch: 22, Loss: 0.122797, Val Sketch Acc: 0.790291, Val Acc: 0.481553 : 0.364078, Nested Acc: 0.224359, time: 290.438341



 46%|████▌     | 23/50 [1:32:26<2:29:43, 332.71s/it]

0.3679611650485437 0.24358974358974358
Epoch: 23, Loss: 0.107881, Val Sketch Acc: 0.793204, Val Acc: 0.494175 : 0.367961, Nested Acc: 0.243590, time: 299.698275



 48%|████▊     | 24/50 [1:38:28<2:28:00, 341.57s/it]

0.36893203883495146 0.2692307692307692
Epoch: 24, Loss: 0.103655, Val Sketch Acc: 0.782524, Val Acc: 0.492233 : 0.368932, Nested Acc: 0.269231, time: 311.219132



 50%|█████     | 25/50 [1:44:57<2:28:16, 355.86s/it]

0.3679611650485437 0.25
Epoch: 25, Loss: 0.106214, Val Sketch Acc: 0.792233, Val Acc: 0.491262 : 0.367961, Nested Acc: 0.250000, time: 326.558103



 52%|█████▏    | 26/50 [1:51:21<2:25:41, 364.22s/it]

0.3650485436893204 0.23076923076923078
Epoch: 26, Loss: 0.091922, Val Sketch Acc: 0.787379, Val Acc: 0.488350 : 0.365049, Nested Acc: 0.230769, time: 332.690935



 54%|█████▍    | 27/50 [1:58:07<2:24:21, 376.58s/it]

0.3533980582524272 0.24358974358974358
Epoch: 27, Loss: 0.090453, Val Sketch Acc: 0.786408, Val Acc: 0.474757 : 0.353398, Nested Acc: 0.243590, time: 341.818742



 56%|█████▌    | 28/50 [2:04:50<2:21:03, 384.71s/it]

0.3699029126213592 0.24358974358974358
Epoch: 28, Loss: 0.085798, Val Sketch Acc: 0.781553, Val Acc: 0.492233 : 0.369903, Nested Acc: 0.243590, time: 351.691832



 58%|█████▊    | 29/50 [2:11:58<2:19:07, 397.50s/it]

0.37281553398058254 0.22435897435897437
Epoch: 29, Loss: 0.079520, Val Sketch Acc: 0.794175, Val Acc: 0.492233 : 0.372816, Nested Acc: 0.224359, time: 361.756247



 60%|██████    | 30/50 [2:19:16<2:16:34, 409.70s/it]

0.3611650485436893 0.23076923076923078
Epoch: 30, Loss: 0.077089, Val Sketch Acc: 0.794175, Val Acc: 0.481553 : 0.361165, Nested Acc: 0.230769, time: 370.849388



 62%|██████▏   | 31/50 [2:26:32<2:12:16, 417.72s/it]

0.38058252427184464 0.24358974358974358
Epoch: 31, Loss: 0.084243, Val Sketch Acc: 0.795146, Val Acc: 0.499029 : 0.380583, Nested Acc: 0.243590, time: 368.352139



 64%|██████▍   | 32/50 [2:34:00<2:08:03, 426.88s/it]

0.3640776699029126 0.22435897435897437
Epoch: 32, Loss: 0.087169, Val Sketch Acc: 0.771845, Val Acc: 0.476699 : 0.364078, Nested Acc: 0.224359, time: 393.483541



 66%|██████▌   | 33/50 [2:41:51<2:04:41, 440.08s/it]

0.36310679611650487 0.24358974358974358
Epoch: 33, Loss: 0.100489, Val Sketch Acc: 0.798058, Val Acc: 0.486408 : 0.363107, Nested Acc: 0.243590, time: 401.420691



 68%|██████▊   | 34/50 [2:49:56<2:00:54, 453.44s/it]

0.3640776699029126 0.22435897435897437
Epoch: 34, Loss: 0.088830, Val Sketch Acc: 0.788350, Val Acc: 0.484466 : 0.364078, Nested Acc: 0.224359, time: 413.289236



 70%|███████   | 35/50 [2:57:55<1:55:15, 461.06s/it]

0.3524271844660194 0.21153846153846154
Epoch: 35, Loss: 0.086066, Val Sketch Acc: 0.800000, Val Acc: 0.477670 : 0.352427, Nested Acc: 0.211538, time: 407.369637



 72%|███████▏  | 36/50 [3:06:06<1:49:41, 470.07s/it]

0.35436893203883496 0.22435897435897437
Epoch: 36, Loss: 0.092066, Val Sketch Acc: 0.791262, Val Acc: 0.474757 : 0.354369, Nested Acc: 0.224359, time: 434.949959



 74%|███████▍  | 37/50 [3:14:44<1:44:57, 484.45s/it]

0.3572815533980582 0.25
Epoch: 37, Loss: 0.081007, Val Sketch Acc: 0.806796, Val Acc: 0.480583 : 0.357282, Nested Acc: 0.250000, time: 444.559890



 76%|███████▌  | 38/50 [3:23:18<1:38:40, 493.37s/it]

0.3553398058252427 0.23076923076923078
Epoch: 38, Loss: 0.090724, Val Sketch Acc: 0.777670, Val Acc: 0.473786 : 0.355340, Nested Acc: 0.230769, time: 456.359318



 78%|███████▊  | 39/50 [3:32:21<1:33:09, 508.18s/it]

0.3572815533980582 0.24358974358974358
Epoch: 39, Loss: 0.076097, Val Sketch Acc: 0.789320, Val Acc: 0.476699 : 0.357282, Nested Acc: 0.243590, time: 466.809330



 80%|████████  | 40/50 [3:41:18<1:26:09, 516.91s/it]

0.3650485436893204 0.25
Epoch: 40, Loss: 0.074394, Val Sketch Acc: 0.793204, Val Acc: 0.490291 : 0.365049, Nested Acc: 0.250000, time: 478.726075



 82%|████████▏ | 41/50 [3:50:47<1:19:52, 532.47s/it]

0.37087378640776697 0.23717948717948717
Epoch: 41, Loss: 0.046891, Val Sketch Acc: 0.800000, Val Acc: 0.490291 : 0.370874, Nested Acc: 0.237179, time: 490.835462



 84%|████████▍ | 42/50 [4:00:04<1:12:00, 540.02s/it]

0.3650485436893204 0.23717948717948717
Epoch: 42, Loss: 0.035396, Val Sketch Acc: 0.804854, Val Acc: 0.485437 : 0.365049, Nested Acc: 0.237179, time: 498.514395



 86%|████████▌ | 43/50 [4:09:51<1:04:38, 554.05s/it]

0.3572815533980582 0.22435897435897437
Epoch: 43, Loss: 0.033311, Val Sketch Acc: 0.795146, Val Acc: 0.473786 : 0.357282, Nested Acc: 0.224359, time: 507.231909



 88%|████████▊ | 44/50 [4:19:32<56:11, 561.94s/it]  

0.3679611650485437 0.23717948717948717
Epoch: 44, Loss: 0.034684, Val Sketch Acc: 0.809709, Val Acc: 0.491262 : 0.367961, Nested Acc: 0.237179, time: 519.839754



 90%|█████████ | 45/50 [4:29:41<48:00, 576.14s/it]

0.36213592233009706 0.23717948717948717
Epoch: 45, Loss: 0.033036, Val Sketch Acc: 0.804854, Val Acc: 0.486408 : 0.362136, Nested Acc: 0.237179, time: 527.188716



 92%|█████████▏| 46/50 [4:39:42<38:53, 583.49s/it]

0.35436893203883496 0.23076923076923078
Epoch: 46, Loss: 0.032588, Val Sketch Acc: 0.798058, Val Acc: 0.476699 : 0.354369, Nested Acc: 0.230769, time: 539.439084



 94%|█████████▍| 47/50 [4:50:16<29:56, 598.78s/it]

0.36310679611650487 0.22435897435897437
Epoch: 47, Loss: 0.032712, Val Sketch Acc: 0.794175, Val Acc: 0.479612 : 0.363107, Nested Acc: 0.224359, time: 549.518964



 96%|█████████▌| 48/50 [5:00:39<20:12, 606.03s/it]

0.36310679611650487 0.21153846153846154
Epoch: 48, Loss: 0.028573, Val Sketch Acc: 0.798058, Val Acc: 0.480583 : 0.363107, Nested Acc: 0.211538, time: 559.989837



 98%|█████████▊| 49/50 [5:11:42<10:23, 623.07s/it]

0.36699029126213595 0.24358974358974358
Epoch: 49, Loss: 0.025830, Val Sketch Acc: 0.803883, Val Acc: 0.492233 : 0.366990, Nested Acc: 0.243590, time: 576.300810



100%|██████████| 50/50 [5:22:27<00:00, 386.95s/it]

0.3640776699029126 0.24358974358974358
Epoch: 50, Loss: 0.026702, Val Sketch Acc: 0.790291, Val Acc: 0.487379 : 0.364078, Nested Acc: 0.243590, time: 559.147318






Validation Sketch Acc: 0.790291, Validation Acc: 0.487379, Validation Beam Acc: 0.487379
