In [1]:
print("hello")
!export CUDA_VISIBLE_DEVICES=0,1,2,3

hello


In [2]:
!pip3 install transformers
# helper for question to sql metadata data
import json
import os
import string
import unicodedata
import transformers

pretrained_weights = {
    ("bert", "base"): "bert-base-uncased",
    ("bert", "large"): "bert-large-uncased-whole-word-masking",
    ("roberta", "base"): "roberta-base",
    ("roberta", "large"): "roberta-large",
    ("albert", "xlarge"): "albert-xlarge-v2"
}

def create_base_model(config):
    weights_name = pretrained_weights[(config["base_class"], config["base_name"])]
    if config["base_class"] == "bert":
        return transformers.BertModel.from_pretrained(weights_name)
    elif config["base_class"] == "roberta":
        return transformers.RobertaModel.from_pretrained(weights_name)
    elif config["base_class"] == "albert":
        return transformers.AlbertModel.from_pretrained(weights_name)
    else:
        raise Exception("base_class {0} not supported".format(config["base_class"]))

def create_tokenizer(config):
    weights_name = pretrained_weights[(config["base_class"], config["base_name"])]
    if config["base_class"] == "bert":
        return transformers.BertTokenizer.from_pretrained(weights_name)
    elif config["base_class"] == "roberta":
        return transformers.RobertaTokenizer.from_pretrained(weights_name)
    elif config["base_class"] == "albert":
        return transformers.AlbertTokenizer.from_pretrained(weights_name)
    else:
        raise Exception("base_class {0} not supported".format(config["base_class"]))

def read_conf(conf_path):
    config = {}
    for line in open(conf_path, encoding="utf8"):
        if line.strip() == "" or line[0] == "#":
             continue
        fields = line.strip().split("\t")
        config[fields[0]] = fields[1]
    config["train_data_path"] =  os.path.abspath(config["train_data_path"])
    config["dev_data_path"] =  os.path.abspath(config["dev_data_path"])

    return config

def read_jsonl(jsonl):
    for line in open(jsonl, encoding="utf8"):
        sample = json.loads(line.rstrip())
        yield sample

def is_whitespace(c):
    if c == " " or c == "\t" or c == "\n" or c == "\r":
        return True
    cat = unicodedata.category(c)
    if cat == "Zs":
        return True
    return False

def is_punctuation(c):
  """Checks whether `chars` is a punctuation character."""
  cp = ord(c)
  # We treat all non-letter/number ASCII as punctuation.
  # Characters such as "^", "$", and "`" are not in the Unicode
  # Punctuation class but we treat them as punctuation anyways, for
  # consistency.
  if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
      (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
    return True
  cat = unicodedata.category(c)
  if cat.startswith("P") or cat.startswith("S"):
    return True
  return False

def basic_tokenize(doc):
    doc_tokens = []
    char_to_word = []
    word_to_char_start = []
    prev_is_whitespace = True
    prev_is_punc = False
    prev_is_num = False
    for pos, c in enumerate(doc):
        if is_whitespace(c):
            prev_is_whitespace = True
            prev_is_punc = False
        else:
            if prev_is_whitespace or is_punctuation(c) or prev_is_punc or (prev_is_num and not str(c).isnumeric()):
                doc_tokens.append(c)
                word_to_char_start.append(pos)
            else:
                doc_tokens[-1] += c
            prev_is_whitespace = False
            prev_is_punc = is_punctuation(c)
            prev_is_num = str(c).isnumeric()
        char_to_word.append(len(doc_tokens) - 1)

    return doc_tokens, char_to_word, word_to_char_start

class SQLExample(object):
    def __init__(self,
                 qid,
                 question,
                 table_id,
                 column_meta,
                 agg=None,
                 select=None,
                 conditions=None,
                 tokens=None,
                 char_to_word=None,
                 word_to_char_start=None,
                 value_start_end=None,
                 valid=True):
        self.qid = qid
        self.question = question
        self.table_id = table_id
        self.column_meta = column_meta
        self.agg = agg
        self.select = select
        self.conditions = conditions
        self.valid = valid
        if tokens is None:
            self.tokens, self.char_to_word, self.word_to_char_start = basic_tokenize(question)
            self.value_start_end = {}
            if conditions is not None and len(conditions) > 0:
                cur_start = None
                for cond in conditions:
                    value = cond[-1]
                    value_tokens, _, _ = basic_tokenize(value)
                    val_len = len(value_tokens)
                    for i in range(len(self.tokens)):
                        if " ".join(self.tokens[i:i+val_len]).lower() != " ".join(value_tokens).lower():
                            continue
                        s = self.word_to_char_start[i]
                        e = len(question) if i + val_len >= len(self.word_to_char_start) else self.word_to_char_start[i + val_len]
                        recovered_answer_text = question[s:e].strip()
                        if value.lower() == recovered_answer_text.lower():
                            cur_start = i
                            break

                    if cur_start is None:
                        self.valid = False
                        print([value, value_tokens, question, self.tokens])
                        # for c in question:
                        #     print((c, ord(c), unicodedata.category(c)))
                        # raise Exception()
                    else:
                        self.value_start_end[value] = (cur_start, cur_start + val_len)
        else:
            self.tokens, self.char_to_word, self.word_to_char_start, self.value_start_end = tokens, char_to_word, word_to_char_start, value_start_end

    @staticmethod
    def load_from_json(s):
        d = json.loads(s)
        keys = ["qid", "question", "table_id", "column_meta", "agg", "select", "conditions", "tokens", "char_to_word", "word_to_char_start", "value_start_end", "valid"]

        return SQLExample(*[d[k] for k in keys])

    def dump_to_json(self):
        d = {}
        d["qid"] = self.qid
        d["question"] = self.question
        d["table_id"] = self.table_id
        d["column_meta"] = self.column_meta
        d["agg"] = self.agg
        d["select"] = self.select
        d["conditions"] = self.conditions
        d["tokens"] = self.tokens
        d["char_to_word"] = self.char_to_word
        d["word_to_char_start"] = self.word_to_char_start
        d["value_start_end"] = self.value_start_end
        d["valid"] = self.valid

        return json.dumps(d)

    def output_SQ(self, return_str=True):
        agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
        cond_ops = ['=', '>', '<', 'OP']

        agg_text = agg_ops[self.agg]
        select_text = self.column_meta[self.select][0]
        cond_texts = []
        for wc, op, value_text in self.conditions:
            column_text = self.column_meta[wc][0]
            op_text = cond_ops[op]
            cond_texts.append(column_text + op_text + value_text)

        if return_str:
            sq = agg_text + ", " + select_text + ", " + " AND ".join(cond_texts)
        else:
            sq = (agg_text, select_text, set(cond_texts))
        return sq

def get_schema(tables):
    schema, headers, colTypes, naturalMap = {}, {}, {}, {}
    for table in tables:
        values = [set() for _ in range(len(table["header"]))]
        for row in table["rows"]:
            for i, value in enumerate(row):
                values[i].add(str(value).lower())
        columns = {column: values[i] for i, column in enumerate(table["header"])}

        trans = {"text": "string", "real": "real"}
        colTypes[table["id"]] = {col:trans[ty] for ty, col in zip(table["types"], table["header"])}
        schema[table["id"]] = columns
        naturalMap[table["id"]] = {col: col for col in columns}
        headers[table["id"]] = table["header"]

    return schema, headers, colTypes, naturalMap

You should consider upgrading via the '/Library/Frameworks/Python.framework/Versions/3.10/bin/python3.10 -m pip install --upgrade pip' command.[0m[33m
[0m

In [3]:
# Featurizer
import numpy as np
import json
import os
import torch.utils.data as torch_data
from collections import defaultdict
from typing import List

stats = defaultdict(int)

class InputFeature(object):
    def __init__(self,
                 question,
                 table_id,
                 tokens,
                 word_to_char_start,
                 word_to_subword,
                 subword_to_word,
                 input_ids,
                 input_mask,
                 segment_ids):
        self.question = question
        self.table_id = table_id
        self.tokens = tokens
        self.word_to_char_start = word_to_char_start
        self.word_to_subword = word_to_subword
        self.subword_to_word = subword_to_word
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids

        self.columns = None
        self.agg = None
        self.select = None
        self.where_num = None
        self.where = None
        self.op = None
        self.value_start = None
        self.value_end = None

    def output_SQ(self, agg = None, sel = None, conditions = None, return_str=True):
        agg_ops = ['NA', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
        cond_ops = ['=', '>', '<', 'OP']

        if agg is None and sel is None and conditions is None:
            sel = np.argmax(self.select)
            agg = self.agg[sel]
            conditions = []
            for i in range(len(self.where)):
                if self.where[i] == 0:
                    continue
                conditions.append((i, self.op[i], self.value_start[i], self.value_end[i]))

        agg_text = agg_ops[agg]
        select_text = self.columns[sel]
        cond_texts = []
        for wc, op, vs, ve in conditions:
            column_text = self.columns[wc]
            op_text = cond_ops[op]
            word_start, word_end = self.subword_to_word[wc][vs], self.subword_to_word[wc][ve]
            char_start = self.word_to_char_start[word_start]
            char_end = len(self.question) if word_end + 1 >= len(self.word_to_char_start) else self.word_to_char_start[word_end + 1]
            value_span_text = self.question[char_start:char_end]
            cond_texts.append(column_text + op_text + value_span_text.rstrip())

        if return_str:
            sq = agg_text + ", " + select_text + ", " + " AND ".join(cond_texts)
        else:
            sq = (agg_text, select_text, set(cond_texts))

        return sq

class HydraFeaturizer(object):
    def __init__(self, config):
        self.config = config
        self.tokenizer = create_tokenizer(config)
        self.colType2token = {
            "string": "[unused1]",
            "real": "[unused2]"}

    def get_input_feature(self, example: SQLExample, config):
        max_total_length = int(config["max_total_length"])

        input_feature = InputFeature(
            example.question,
            example.table_id,
            [],
            example.word_to_char_start,
            [],
            [],
            [],
            [],
            []
        )

        for column, col_type, _ in example.column_meta:
            # get query tokens
            tokens = []
            word_to_subword = []
            subword_to_word = []
            for i, query_token in enumerate(example.tokens):
                if self.config["base_class"] == "roberta":
                    sub_tokens = self.tokenizer.tokenize(query_token, add_prefix_space=True)
                else:
                    sub_tokens = self.tokenizer.tokenize(query_token)
                cur_pos = len(tokens)
                if len(sub_tokens) > 0:
                    word_to_subword += [(cur_pos, cur_pos + len(sub_tokens))]
                    tokens.extend(sub_tokens)
                    subword_to_word.extend([i] * len(sub_tokens))

            if self.config["base_class"] == "roberta":
                tokenize_result = self.tokenizer.encode_plus(
                    col_type + " " + column,
                    tokens,
                    padding="max_length",
                    max_length=max_total_length,
                    truncation=True,
                    add_prefix_space=True
                )
            else:
                tokenize_result = self.tokenizer.encode_plus(
                    col_type + " " + column,
                    tokens,
                    padding="max_length",
                    max_length=max_total_length,
                    truncation_strategy="longest_first",
                    truncation=True,
                )

            input_ids = tokenize_result["input_ids"]
            input_mask = tokenize_result["attention_mask"]

            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
            column_token_length = 0
            if self.config["base_class"] == "roberta":
                for i, token_id in enumerate(input_ids):
                    if token_id == self.tokenizer.sep_token_id:
                        column_token_length = i + 2
                        break
                segment_ids = [0] * max_total_length
                for i in range(column_token_length, max_total_length):
                    if input_mask[i] == 0:
                        break
                    segment_ids[i] = 1
            else:
                for i, token_id in enumerate(input_ids):
                    if token_id == self.tokenizer.sep_token_id:
                        column_token_length = i + 1
                        break
                segment_ids = tokenize_result["token_type_ids"]

            subword_to_word = [0] * column_token_length + subword_to_word
            word_to_subword = [(pos[0]+column_token_length, pos[1]+column_token_length) for pos in word_to_subword]

            assert len(input_ids) == max_total_length
            assert len(input_mask) == max_total_length
            assert len(segment_ids) == max_total_length

            input_feature.tokens.append(tokens)
            input_feature.word_to_subword.append(word_to_subword)
            input_feature.subword_to_word.append(subword_to_word)
            input_feature.input_ids.append(input_ids)
            input_feature.input_mask.append(input_mask)
            input_feature.segment_ids.append(segment_ids)

        return input_feature

    def fill_label_feature(self, example: SQLExample, input_feature: InputFeature, config):
        max_total_length = int(config["max_total_length"])

        columns = [c[0] for c in example.column_meta]
        col_num = len(columns)
        input_feature.columns = columns

        input_feature.agg = [0] * col_num
        input_feature.agg[example.select] = example.agg
        input_feature.where_num = [len(example.conditions)] * col_num

        input_feature.select = [0] * len(columns)
        input_feature.select[example.select] = 1

        input_feature.where = [0] * len(columns)
        input_feature.op = [0] * len(columns)
        input_feature.value_start = [0] * len(columns)
        input_feature.value_end = [0] * len(columns)

        for colidx, op, _ in example.conditions:
            input_feature.where[colidx] = 1
            input_feature.op[colidx] = op
        for colidx, column_meta in enumerate(example.column_meta):
            if column_meta[-1] == None:
                continue
            se = example.value_start_end[column_meta[-1]]
            try:
                s = input_feature.word_to_subword[colidx][se[0]][0]
                input_feature.value_start[colidx] = s
                e = input_feature.word_to_subword[colidx][se[1]-1][1]-1
                input_feature.value_end[colidx] = e
                assert s < max_total_length and input_feature.input_mask[colidx][s] == 1
                assert e < max_total_length and input_feature.input_mask[colidx][e] == 1

            except:
                print("value span is out of range")
                return False

        # feature_sq = input_feature.output_SQ(return_str=False)
        # example_sq = example.output_SQ(return_str=False)
        # if feature_sq != example_sq:
        #     print(example.qid, feature_sq, example_sq)
        return True

    def load_data(self, data_paths, config, include_label=False):
        model_inputs = {k: [] for k in ["input_ids", "input_mask", "segment_ids"]}
        if include_label:
            for k in ["agg", "select", "where_num", "where", "op", "value_start", "value_end"]:
                model_inputs[k] = []

        pos = []
        input_features = []
        for data_path in data_paths.split("|"):
            cnt = 0
            for line in open(data_path, encoding="utf8"):
                example = SQLExample.load_from_json(line)
                if not example.valid and include_label == True:
                    continue

                input_feature = self.get_input_feature(example, config)
                if include_label:
                    success = self.fill_label_feature(example, input_feature, config)
                    if not success:
                        continue

                # sq = input_feature.output_SQ()
                input_features.append(input_feature)

                cur_start = len(model_inputs["input_ids"])
                cur_sample_num = len(input_feature.input_ids)
                pos.append((cur_start, cur_start + cur_sample_num))

                model_inputs["input_ids"].extend(input_feature.input_ids)
                model_inputs["input_mask"].extend(input_feature.input_mask)
                model_inputs["segment_ids"].extend(input_feature.segment_ids)
                if include_label:
                    model_inputs["agg"].extend(input_feature.agg)
                    model_inputs["select"].extend(input_feature.select)
                    model_inputs["where_num"].extend(input_feature.where_num)
                    model_inputs["where"].extend(input_feature.where)
                    model_inputs["op"].extend(input_feature.op)
                    model_inputs["value_start"].extend(input_feature.value_start)
                    model_inputs["value_end"].extend(input_feature.value_end)

                cnt += 1
                if cnt % 5000 == 0:
                    print(cnt)

                if "DEBUG" in config and cnt > 100:
                    break

        for k in model_inputs:
            model_inputs[k] = np.array(model_inputs[k], dtype=np.int64)

        return input_features, model_inputs, pos

class SQLDataset(torch_data.Dataset):
    def __init__(self, data_paths, config, featurizer, include_label=False):
        self.config = config
        self.featurizer = featurizer
        self.input_features, self.model_inputs, self.pos = self.featurizer.load_data(data_paths, config, include_label)

        print("{0} loaded. Data shapes:".format(data_paths))
        for k, v in self.model_inputs.items():
            print(k, v.shape)

    def __len__(self):
        return self.model_inputs["input_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.model_inputs.items()}


In [4]:
# Create training data (i/o)
import argparse
import os
import sys
import shutil
import datetime
import torch.utils.data as torch_data

config = read_conf("wikisql.conf")

featurizer = HydraFeaturizer(config)
train_data = SQLDataset(config["train_data_path"], config, featurizer, True)
train_data_loader = torch_data.DataLoader(train_data, batch_size=int(config["batch_size"]), shuffle=True, pin_memory=True)

num_samples = len(train_data)


Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

/Users/tulip/Downloads/cs626/Text2SQL/colab_data/data/wikitrain.jsonl loaded. Data shapes:
input_ids (14537, 96)
input_mask (14537, 96)
segment_ids (14537, 96)
agg (14537,)
select (14537,)
where_num (14537,)
where (14537,)
op (14537,)
value_start (14537,)
value_end (14537,)


In [5]:
# base model

import numpy as np
import time
from typing import List

class BaseModel(object):
    """Define common interfaces for HydraNet models"""
    def train_on_batch(self, batch):
        raise NotImplementedError()

    def save(self, model_path, epoch):
        raise NotImplementedError()

    def load(self, model_path, epoch):
        raise NotImplementedError()

    def model_inference(self, model_inputs):
        """model prediction on processed features"""
        raise NotImplementedError()

    def dataset_inference(self, dataset: SQLDataset):
        print("model prediction start")
        start_time = time.time()
        model_outputs = self.model_inference(dataset.model_inputs)

        final_outputs = []
        for pos in dataset.pos:
            final_output = {}
            for k in model_outputs:
                final_output[k] = model_outputs[k][pos[0]:pos[1], :]
            final_outputs.append(final_output)
        print("model prediction end, time elapse: {0}".format(time.time() - start_time))
        assert len(dataset.input_features) == len(final_outputs)

        return final_outputs

    def predict_SQL(self, dataset: SQLDataset, model_outputs=None):
        if model_outputs is None:
            model_outputs = self.dataset_inference(dataset)
        sqls = []
        for input_feature, model_output in zip(dataset.input_features, model_outputs):
            agg, select, where, conditions = self.parse_output(input_feature, model_output, [])

            conditions_with_value_texts = []
            for wc in where:
                _, op, vs, ve = conditions[wc]
                word_start, word_end = input_feature.subword_to_word[wc][vs], input_feature.subword_to_word[wc][ve]
                char_start = input_feature.word_to_char_start[word_start]
                char_end = len(input_feature.question)
                if word_end + 1 < len(input_feature.word_to_char_start):
                    char_end = input_feature.word_to_char_start[word_end + 1]
                value_span_text = input_feature.question[char_start:char_end].rstrip()
                conditions_with_value_texts.append((wc, op, value_span_text))

            sqls.append((agg, select, conditions_with_value_texts))

        return sqls

    def predict_SQL_with_EG(self, engine, dataset: SQLDataset, beam_size=5, model_outputs=None):
        if model_outputs is None:
            model_outputs = self.dataset_inference(dataset)
        sqls = []
        for input_feature, model_output in zip(dataset.input_features, model_outputs):
            agg, select, where_num, conditions = self.beam_parse_output(input_feature, model_output, beam_size)
            query = {"agg": agg, "sel": select, "conds": []}
            wcs = set()
            conditions_with_value_texts = []
            for condition in conditions:
                if len(wcs) >= where_num:
                    break
                _, wc, op, vs, ve = condition
                if wc in wcs:
                    continue

                word_start, word_end = input_feature.subword_to_word[wc][vs], input_feature.subword_to_word[wc][ve]
                char_start = input_feature.word_to_char_start[word_start]
                char_end = len(input_feature.question)
                if word_end + 1 < len(input_feature.word_to_char_start):
                    char_end = input_feature.word_to_char_start[word_end + 1]
                value_span_text = input_feature.question[char_start:char_end].rstrip()

                query["conds"] = [[int(wc), int(op), value_span_text]]
                result, sql = engine.execute_dict_query(input_feature.table_id, query)
                if not result or 'ERROR: ' in result:
                    continue

                conditions_with_value_texts.append((wc, op, value_span_text))
                wcs.add(wc)

            sqls.append((agg, select, conditions_with_value_texts))

        return sqls

    def _get_where_num(self, output):
        # wn = np.argmax(output["where_num"], -1)
        # max_num = 0
        # max_cnt = np.sum(wn == 0)
        # for num in range(1, 5):
        #     cur_cnt = np.sum(wn==num)
        #     if cur_cnt > max_cnt:
        #         max_cnt = cur_cnt
        #         max_num = num
        # def sigmoid(x):
        #     return 1/(1 + np.exp(-x))
        relevant_prob = 1 - np.exp(output["column_func"][:, 2])
        where_num_scores = np.average(output["where_num"], axis=0, weights=relevant_prob)
        where_num = int(np.argmax(where_num_scores))

        return where_num

    def parse_output(self, input_feature, model_output, where_label = []):
        def get_span(i):
            offset = 0
            segment_ids = np.array(input_feature.segment_ids[i])
            for j in range(len(segment_ids)):
                if segment_ids[j] == 1:
                    offset = j
                    break

            value_start, value_end = model_output["value_start"][i, segment_ids == 1], model_output["value_end"][i, segment_ids == 1]
            l = len(value_start)
            sum_mat = value_start.reshape((l, 1)) + value_end.reshape((1, l))
            span = (0, 0)
            for cur_span, _ in sorted(np.ndenumerate(sum_mat), key=lambda x:x[1], reverse=True):
                if cur_span[1] < cur_span[0] or cur_span[0] == l - 1 or cur_span[1] == l - 1:
                    continue
                span = cur_span
                break

            return (span[0]+offset, span[1]+offset)

        select_id_prob = sorted(enumerate(model_output["column_func"][:, 0]), key=lambda x:x[1], reverse=True)
        select = select_id_prob[0][0]
        agg = np.argmax(model_output["agg"][select, :])

        where_id_prob = sorted(enumerate(model_output["column_func"][:, 1]), key=lambda x:x[1], reverse=True)
        where_num = self._get_where_num(model_output)
        where = [i for i, _ in where_id_prob[:where_num]]
        conditions = {}
        for idx in set(where + where_label):
            span = get_span(idx)
            op = np.argmax(model_output["op"][idx, :])
            conditions[idx] = (idx, op, span[0], span[1])

        return agg, select, where, conditions

    def beam_parse_output(self, input_feature, model_output, beam_size=5):
        def get_span(i):
            offset = 0
            segment_ids = np.array(input_feature.segment_ids[i])
            for j in range(len(segment_ids)):
                if segment_ids[j] == 1:
                    offset = j
                    break

            value_start, value_end = model_output["value_start"][i, segment_ids == 1], model_output["value_end"][i, segment_ids == 1]
            l = len(value_start)
            sum_mat = value_start.reshape((l, 1)) + value_end.reshape((1, l))
            spans = []
            for cur_span, sum_logp in sorted(np.ndenumerate(sum_mat), key=lambda x:x[1], reverse=True):
                if cur_span[1] < cur_span[0] or cur_span[0] == l - 1 or cur_span[1] == l - 1:
                    continue
                spans.append((cur_span[0]+offset, cur_span[1]+offset, sum_logp))
                if len(spans) >= beam_size:
                    break

            return spans

        select_id_prob = sorted(enumerate(model_output["column_func"][:, 0]), key=lambda x:x[1], reverse=True)
        select = select_id_prob[0][0]
        agg = np.argmax(model_output["agg"][select, :])

        where_id_prob = sorted(enumerate(model_output["column_func"][:, 1]), key=lambda x:x[1], reverse=True)
        where_num = self._get_where_num(model_output)
        conditions = []
        for idx, wlogp in where_id_prob[:beam_size]:
            op = np.argmax(model_output["op"][idx, :])
            for span in get_span(idx):
                conditions.append((wlogp+span[2], idx, op, span[0], span[1]))
        conditions.sort(key=lambda x:x[0], reverse=True)
        return agg, select, where_num, conditions

In [6]:
# torch model

import torch
import os
import numpy as np
import transformers
from torch import nn

class HydraTorch(BaseModel):
    def __init__(self, config):
        self.config = config
        self.model = HydraNet(config)
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.optimizer, self.scheduler = None, None

    def train_on_batch(self, batch):
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": float(self.config["decay"]),
                },
                {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                 "weight_decay": 0.0},
            ]
            self.optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=float(self.config["learning_rate"]))
            self.scheduler = transformers.get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=int(self.config["num_warmup_steps"]),
                num_training_steps=int(self.config["num_train_steps"]))
            self.optimizer.zero_grad()

        self.model.train()
        for k, v in batch.items():
            batch[k] = v.to(self.device)
        batch_loss = torch.mean(self.model(**batch)["loss"])
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

        return batch_loss.cpu().detach().numpy()

    def model_inference(self, model_inputs):
        self.model.eval()
        model_outputs = {}
        batch_size = 64 # 512
        for start_idx in range(0, model_inputs["input_ids"].shape[0], batch_size):
            input_tensor = {k: torch.from_numpy(model_inputs[k][start_idx:start_idx+batch_size]).to(self.device) for k in ["input_ids", "input_mask", "segment_ids"]}
            with torch.no_grad():
                model_output = self.model(**input_tensor)
            for k, out_tensor in model_output.items():
                if out_tensor is None:
                    continue
                if k not in model_outputs:
                    model_outputs[k] = []
                model_outputs[k].append(out_tensor.cpu().detach().numpy())

        for k in model_outputs:
            model_outputs[k] = np.concatenate(model_outputs[k], 0)

        return model_outputs

    def save(self, model_path, epoch):
        if "SAVE" in self.config and "DEBUG" not in self.config:
            save_path = os.path.join(model_path, "model_{0}.pt".format(epoch))
            if torch.cuda.device_count() > 1:
                torch.save(self.model.module.state_dict(), save_path)
            else:
                torch.save(self.model.state_dict(), save_path)
            print("Model saved in path: %s" % save_path)

    def load(self, model_path, epoch):
        pt_path = os.path.join(model_path, "model_{0}.pt".format(epoch))
        loaded_dict = torch.load(pt_path, map_location=torch.device(self.device))
        if torch.cuda.device_count() > 1:
            self.model.module.load_state_dict(loaded_dict)
        else:
            self.model.load_state_dict(loaded_dict)
        print("PyTorch model loaded from {0}".format(pt_path))

class HydraNet(nn.Module):
    def __init__(self, config):
        super(HydraNet, self).__init__()
        self.config = config
        self.base_model = create_base_model(config)
        

        # #=====Hack for RoBERTa model====
        # self.base_model.config.type_vocab_size = 2
        # single_emb = self.base_model.embeddings.token_type_embeddings
        # self.base_model.embeddings.token_type_embeddings = torch.nn.Embedding(2, single_emb.embedding_dim)
        # self.base_model.embeddings.token_type_embeddings.weight = torch.nn.Parameter(single_emb.weight.repeat([2, 1]), requires_grad=True)
        # #====================================

        drop_rate = float(config["drop_rate"]) if "drop_rate" in config else 0.0
        self.dropout = nn.Dropout(drop_rate)

        bert_hid_size = self.base_model.config.hidden_size
        projected_size = 64
        
        self.projection = nn.Linear(bert_hid_size, projected_size)
        self.column_func = nn.Linear(projected_size, 3)
        self.agg = nn.Linear(projected_size, int(config["agg_num"]))
        self.op = nn.Linear(projected_size, int(config["op_num"]))
        self.where_num = nn.Linear(projected_size, int(config["where_column_num"]) + 1)
        self.start_end = nn.Linear(projected_size, 2)

        # self.column_func = nn.Linear(bert_hid_size, 3)
        # self.agg = nn.Linear(bert_hid_size, int(config["agg_num"]))
        # self.op = nn.Linear(bert_hid_size, int(config["op_num"]))
        # self.where_num = nn.Linear(bert_hid_size, int(config["where_column_num"]) + 1)
        # self.start_end = nn.Linear(bert_hid_size, 2)

    def forward(self, input_ids, input_mask, segment_ids, agg=None, select=None, where=None, where_num=None, op=None, value_start=None, value_end=None):
        # print("[inner] input_ids size:", input_ids.size())
        if self.config["base_class"] == "roberta":
            bert_output, pooled_output = self.base_model(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=None,
                return_dict=False)
        else:
            bert_output, pooled_output = self.base_model(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=segment_ids,
                return_dict=False)
        print("pooled output before: ")
        print(pooled_output.shape)

        print("bert output before: ")
        print(bert_output.shape)

        bert_output = self.projection(bert_output)
        pooled_output = self.projection(pooled_output)
        
        print("pooled output: ")
        print(pooled_output.shape)

        print("bert output: ")
        print(bert_output.shape)

        bert_output = self.dropout(bert_output)
        pooled_output = self.dropout(pooled_output)

        column_func_logit = self.column_func(pooled_output)
        agg_logit = self.agg(pooled_output)
        op_logit = self.op(pooled_output)
        where_num_logit = self.where_num(pooled_output)
        start_end_logit = self.start_end(bert_output)
        value_span_mask = input_mask.to(dtype=bert_output.dtype)
        # value_span_mask[:, 0] = 1
        start_logit = start_end_logit[:, :, 0] * value_span_mask - 1000000.0 * (1 - value_span_mask)
        end_logit = start_end_logit[:, :, 1] * value_span_mask - 1000000.0 * (1 - value_span_mask)

        loss = None
        if select is not None:
            bceloss = nn.BCEWithLogitsLoss(reduction="none")
            cross_entropy = nn.CrossEntropyLoss(reduction="none")

            loss = cross_entropy(agg_logit, agg) * select.float()
            loss += bceloss(column_func_logit[:, 0], select.float())
            loss += bceloss(column_func_logit[:, 1], where.float())
            loss += bceloss(column_func_logit[:, 2], (1-select.float()) * (1-where.float()))
            loss += cross_entropy(where_num_logit, where_num)
            loss += cross_entropy(op_logit, op) * where.float()
            loss += cross_entropy(start_logit, value_start)
            loss += cross_entropy(end_logit, value_end)


        # return loss, column_func_logit, agg_logit, op_logit, where_num_logit, start_logit, end_logit
        log_sigmoid = nn.LogSigmoid()

        return {"column_func": log_sigmoid(column_func_logit),
                "agg": agg_logit.log_softmax(1),
                "op": op_logit.log_softmax(1),
                "where_num": where_num_logit.log_softmax(1),
                "value_start": start_logit.log_softmax(1),
                "value_end": end_logit.log_softmax(1),
                "loss": loss}


In [7]:
def create_model(config, is_train = False) -> BaseModel:
    if config["model_type"] == "pytorch":
        return HydraTorch(config)
    # elif config["model_type"] == "tf":
    #     return HydraTensorFlow(config, is_train, num_gpu)
    else:
        raise NotImplementedError("model type {0} is not supported".format(config["model_type"]))

In [8]:
!pwd
import os
import numpy as np

class HydraEvaluator():
    def __init__(self, output_path, config, hydra_featurizer: HydraFeaturizer, model:BaseModel, note=""):
        self.config = config
        self.model = model
        self.eval_history_file = os.path.join(output_path, "eval.log")
        self.bad_case_dir = os.path.join(output_path, "bad_cases")
        os.mkdir(output_path)
        if "DEBUG" not in config:
            os.mkdir(self.bad_case_dir)
            with open(self.eval_history_file, "w", encoding="utf8") as f:
                f.write(note.rstrip() + "\n")

        self.eval_data = {}
        for eval_path in config["dev_data_path"].split("|") + config["test_data_path"].split("|"):
            eval_data = SQLDataset(eval_path, config, hydra_featurizer, True)
            self.eval_data[os.path.basename(eval_path)] = eval_data

            print("Eval Data file {0} loaded, sample num = {1}".format(eval_path, len(eval_data)))

    def _eval_imp(self, eval_data: SQLDataset, get_sq=True):
        items = ["overall", "agg", "sel", "wn", "wc", "op", "val"]
        acc = {k:0.0 for k in items}
        sq = []
        cnt = 0
        print("before_inf")
        model_outputs = self.model.dataset_inference(eval_data)
        print("after_inf")
        for input_feature, model_output in zip(eval_data.input_features, model_outputs):
            cur_acc = {k:1 for k in acc if k != "overall"}

            select_label = np.argmax(input_feature.select)
            agg_label = input_feature.agg[select_label]
            wn_label = input_feature.where_num[0]
            wc_label = [i for i, w in enumerate(input_feature.where) if w == 1]

            agg, select, where, conditions = self.model.parse_output(input_feature, model_output, wc_label)
            if agg != agg_label:
                cur_acc["agg"] = 0
            if select != select_label:
                cur_acc["sel"] = 0
            if len(where) != wn_label:
                cur_acc["wn"] = 0
            if set(where) != set(wc_label):
                cur_acc["wc"] = 0

            for w in wc_label:
                _, op, vs, ve = conditions[w]
                if op != input_feature.op[w]:
                    cur_acc["op"] = 0

                if vs != input_feature.value_start[w] or ve != input_feature.value_end[w]:
                    cur_acc["val"] = 0

            for k in cur_acc:
                acc[k] += cur_acc[k]

            all_correct = 0 if 0 in cur_acc.values() else 1
            acc["overall"] += all_correct

            if ("DEBUG" in self.config or get_sq) and not all_correct:
                try:
                    true_sq = input_feature.output_SQ()
                    pred_sq = input_feature.output_SQ(agg=agg, sel=select, conditions=[conditions[w] for w in where])
                    task_cor_text = "".join([str(cur_acc[k]) for k in items if k in cur_acc])
                    sq.append([str(cnt), input_feature.question, "|".join([task_cor_text, pred_sq, true_sq])])
                except:
                    pass
            cnt += 1

        result_str = []
        for item in items:
            result_str.append(item + ":{0:.1f}".format(acc[item] * 100.0 / cnt))

        result_str = ", ".join(result_str)

        return result_str, sq

    def eval(self, epochs):
        print(self.bad_case_dir)
        for eval_file in self.eval_data:
            result_str, sq = self._eval_imp(self.eval_data[eval_file])
            print(eval_file + ": " + result_str)

            if "DEBUG" in self.config:
                for text in sq:
                    print(text[0] + ":" + text[1] + "\t" + text[2])
            else:
                with open(self.eval_history_file, "a+", encoding="utf8") as f:
                    f.write("[{0}, epoch {1}] ".format(eval_file, epochs) + result_str + "\n")

                bad_case_file = os.path.join(self.bad_case_dir,
                                           "{0}_epoch_{1}.log".format(eval_file, epochs))
                with open(bad_case_file, "w", encoding="utf8") as f:
                    for text in sq:
                        f.write(text[0] + ":" + text[1] + "\t" + text[2] + "\n")

/Users/tulip/Downloads/cs626/Text2SQL/colab_data


In [9]:
note = ""
model_path = "model--"+datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
config["num_train_steps"] = int(num_samples * int(config["epochs"]) / int(config["batch_size"]))
step_per_epoch = num_samples / int(config["batch_size"])
print("total_steps: {0}, warm_up_steps: {1}".format(config["num_train_steps"], config["num_warmup_steps"]))

model = create_model(config, is_train=True)
evaluator = HydraEvaluator(model_path, config, featurizer, model, note)


total_steps: 1135, warm_up_steps: 400


Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


/Users/tulip/Downloads/cs626/Text2SQL/colab_data/data/wikidev.jsonl loaded. Data shapes:
input_ids (14537, 96)
input_mask (14537, 96)
segment_ids (14537, 96)
agg (14537,)
select (14537,)
where_num (14537,)
where (14537,)
op (14537,)
value_start (14537,)
value_end (14537,)
Eval Data file /Users/tulip/Downloads/cs626/Text2SQL/colab_data/data/wikidev.jsonl loaded, sample num = 14537
data/wikitest.jsonl loaded. Data shapes:
input_ids (14431, 96)
input_mask (14431, 96)
segment_ids (14431, 96)
agg (14431,)
select (14431,)
where_num (14431,)
where (14431,)
op (14431,)
value_start (14431,)
value_end (14431,)
Eval Data file data/wikitest.jsonl loaded, sample num = 14431


In [10]:
print("start training")
print(model_path)
loss_avg, step, epoch = 0.0, 0, 0
while True:
    for batch_id, batch in enumerate(train_data_loader):
        print(batch_id)
        cur_loss = model.train_on_batch(batch)
        loss_avg = (loss_avg * step + cur_loss) / (step + 1)
        step += 1
        print("here1")
        if batch_id % 1 == 0:
            print("here2")
            currentDT = datetime.datetime.now()
            print("[{3}] epoch {0}, batch {1}, batch_loss={2:.4f}".format(epoch, batch_id, cur_loss,
                                                                            currentDT.strftime("%m-%d %H:%M:%S")))
    model.save(model_path, epoch)
    evaluator.eval(epoch)
    epoch += 1
    if epoch >= int(config["epochs"]):
        break


start training
model--20221123_123520
0




pooled output before: 
torch.Size([64, 768])
bert output before: 
torch.Size([64, 96, 768])
pooled output: 
torch.Size([64, 64])
bert output: 
torch.Size([64, 96, 64])
here1
here2
[11-23 12:36:43] epoch 0, batch 0, batch_loss=10.6696
1
pooled output before: 
torch.Size([64, 768])
bert output before: 
torch.Size([64, 96, 768])
pooled output: 
torch.Size([64, 64])
bert output: 
torch.Size([64, 96, 64])
here1
here2
[11-23 12:36:56] epoch 0, batch 1, batch_loss=10.8315
2
pooled output before: 
torch.Size([64, 768])
bert output before: 
torch.Size([64, 96, 768])
pooled output: 
torch.Size([64, 64])
bert output: 
torch.Size([64, 96, 64])
here1
here2
[11-23 12:37:09] epoch 0, batch 2, batch_loss=10.7581
3
pooled output before: 
torch.Size([64, 768])
bert output before: 
torch.Size([64, 96, 768])
pooled output: 
torch.Size([64, 64])
bert output: 
torch.Size([64, 96, 64])
here1
here2
[11-23 12:37:22] epoch 0, batch 3, batch_loss=10.7151
4
pooled output before: 
torch.Size([64, 768])
bert output

In [None]:
model.save(model_path, epoch)

Model saved in path: model--20221106_184457/model_0.pt


In [None]:
evaluator.eval(epoch)

model--20221107_005826/bad_cases
before_inf
model prediction start


KeyboardInterrupt: 