In [None]:
import importlib
import torch
from torch.utils.tensorboard import SummaryWriter
from transformers import BertModel

from dataLoader.DataLoader import get_data_loader
from utils.Constants import tokenizer

from models.QABert import QABertTrainer
from train.Trainer import train_epoch

In [None]:
# Choose the device used by every later cell: the index of the current CUDA
# device when one is available, otherwise the string 'cpu'.
gpu_available = torch.cuda.is_available()
if not gpu_available:
    print('Failed to find GPU. Will use CPU.')
    device = 'cpu'
else:
    device = torch.cuda.current_device()
    print('Current device:', torch.cuda.get_device_name(device))
    # Hand unused cached GPU memory back from PyTorch's caching allocator.
    torch.cuda.empty_cache()


In [None]:
# Refresh the data-loading modules (so edits made on disk are picked up without
# restarting the kernel) and build the dev / test loaders used below.
import dataLoader.DataLoader
import dataLoader.DataLoaderUtils

# Reload in the same order as before: DataLoader first, then DataLoaderUtils.
for stale_module in (dataLoader.DataLoader, dataLoader.DataLoaderUtils):
    importlib.reload(stale_module)

# Bind the names only after the reload so they point at the fresh definitions.
from dataLoader.DataLoader import get_data_loader
from dataLoader.DataLoaderUtils import (
    get_question_answers_for_where_value_def_length,
    get_question_answers_def_length,
)

# Both loaders yield one example per batch (batch_size = 1).
dev_data_loader = get_data_loader(
    data_type='dev',
    tokenizer=tokenizer,
    batch_size=1,
)
test_data_loader = get_data_loader(
    data_type='test',
    tokenizer=tokenizer,
    batch_size=1,
)


In [None]:
# ---------------------------------------------------------------------------
# Build the six task trainers on one shared, frozen BERT encoder and run
# train_epoch with TensorBoard logging.
# ---------------------------------------------------------------------------

# Import every project module as a module object so it can be reloaded below.
# models.QABert is imported explicitly here instead of relying on an earlier
# cell having loaded it.
import models.QABert
import models.SelectRanker
import models.WhereRanker
import models.AggregationClassifier
import models.WhereConditionClassifier
import models.WhereNumberClassifier
import train.Trainer
import utils.Constants

# Reload first, bind names afterwards: `from X import name` copies the object
# at import time, so names imported *before* a reload keep pointing at the old
# definitions. utils.Constants goes first because the model modules presumably
# import from it (TODO confirm); train.Trainer goes last so it is re-executed
# after the model modules.
importlib.reload(utils.Constants)
importlib.reload(models.QABert)
importlib.reload(models.SelectRanker)
importlib.reload(models.WhereRanker)
importlib.reload(models.WhereConditionClassifier)
importlib.reload(models.WhereNumberClassifier)
importlib.reload(models.AggregationClassifier)
importlib.reload(train.Trainer)

from utils.Constants import tokenizer, PRE_TRAINED_MODEL_NAME
from models.QABert import QABertTrainer
from models.SelectRanker import SelectRankerTrainer
from models.WhereRanker import WhereRankerTrainer
from models.WhereConditionClassifier import WhereConditionClassifierTrainer
from models.WhereNumberClassifier import WhereNumberClassifierTrainer
from models.AggregationClassifier import AggregationClassifierTrainer
from train.Trainer import train_epoch, save_model, load_model

# One pretrained encoder shared by all trainers; its weights are frozen so
# no gradients are computed for them.
bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert.requires_grad_(False)

# NOTE: this rebinds the name `models` from the imported package to a dict of
# trainers. Later cells read the dict under this name, so it is kept; any cell
# that needs the package again must run `import models.<submodule>` first.
models = dict(
    selection_trainer = SelectRankerTrainer(device, dev_data_loader, bert),
    agg_class_trainer = AggregationClassifierTrainer(device, dev_data_loader, bert, use_pretrained=False),
    where_ranker_trainer = WhereRankerTrainer(device, dev_data_loader, bert),
    where_cond_class_trainer = WhereConditionClassifierTrainer(device, dev_data_loader, bert, use_pretrained=False),
    where_numb_class_trainer = WhereNumberClassifierTrainer(device, dev_data_loader, bert),
    qa_trainer = QABertTrainer(device, dev_data_loader, bert, use_pretrained=False),
)

# Optional checkpointing (left disabled):
#save_model(models,"./checkpoint")
#models = load_model("./checkpoint/16_Mar_2021_19_21", dev_data_loader, device)

# NOTE(review): training runs on the *dev* loader and evaluation on the *test*
# loader — confirm this is intentional and not a debugging leftover.
writer = SummaryWriter(log_dir = "runs/")
try:
    train_epoch(
        models = models,
        train_data_loader = dev_data_loader,
        eval_data_loader = test_data_loader,
        device = device,
        batch_size = 16, report_size = 4, eval_size = 64,
        writer = writer
    )
finally:
    # Flush and close the event file even when training is interrupted;
    # a later cell rebinds `writer`, after which this one could not be closed.
    writer.close()

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Write each trainer's model graph to its own TensorBoard run directory
# (runs/<trainer key>/), traced with one sample batch from the dev loader.
sample_batch = next(iter(dev_data_loader))
for key, model in models.items():
    # A separate local name keeps the training cell's `writer` untouched.
    graph_writer = SummaryWriter(f'runs/{key}/')
    try:
        print(key)
        graph_writer.add_graph(model.get_model(), model.parse_input(sample_batch))
    finally:
        # Close every writer, not only the last one created by the loop, so
        # each event file is flushed even if add_graph raises.
        graph_writer.close()

In [None]:
# Call train_epoch several more times on the dev loader. The repeat count is a
# named constant instead of copy-pasted calls, so it can be tuned in one place.
NUM_EXTRA_EPOCHS = 4

# NOTE(review): the keyword call earlier in this notebook also passes
# eval_data_loader, batch/report/eval sizes and a writer; here `device` is the
# third positional argument — confirm this matches train_epoch's signature.
for epoch_idx in range(NUM_EXTRA_EPOCHS):
    train_epoch(models, dev_data_loader, device)

In [None]:
# Smoke test for WhereRankerTrainer: reload its module, run predict on the
# fifth batch of the dev loader and compute the loss against the WHERE target.
import utils.Constants
import models.WhereRanker

for stale_module in (models.WhereRanker, utils.Constants):
    importlib.reload(stale_module)

from models.WhereRanker import WhereRankerTrainer

# NOTE(review): no BERT encoder is passed here, unlike the training cell —
# presumably the trainer has a default; confirm.
where_ranker = WhereRankerTrainer(device, dev_data_loader)

# Advance five steps; only the fifth batch is kept.
batch_iterator = iter(dev_data_loader)
for step in range(5):
    d = next(batch_iterator)

# These three names are also read by the following cell, so they are kept.
input_ids, attention_mask, token_type_ids = (
    d[field].to(device)
    for field in ("input_ids", "attention_mask", "token_type_ids")
)

where_target = d["target"]['WHERE'].to(device)
where_output = where_ranker.predict(input_ids, attention_mask, token_type_ids)

# Last expression of the cell, so the notebook displays the loss.
where_ranker.calc_loss(where_output, where_target)

In [None]:
# Smoke test for WhereConditionClassifierTrainer: for every WHERE column of
# the fifth dev batch, predict the condition and compute the loss against the
# matching WHERE_CONDITIONS target.
import models.WhereConditionClassifier

importlib.reload(models.WhereConditionClassifier)

from models.WhereConditionClassifier import WhereConditionClassifierTrainer

# NOTE(review): no BERT encoder is passed here, unlike the training cell —
# presumably the trainer has a default; confirm.
where_cond_classifier = WhereConditionClassifierTrainer(device, dev_data_loader)

# Advance five steps; only the fifth batch is kept.
iterator = iter(dev_data_loader)
for i in range(5):
    d = next(iterator)

# Take the model inputs from this cell's own batch. Previously they were read
# from variables left behind by another cell, so the cell failed on its own
# and could pair inputs with targets from a different batch.
input_ids = d["input_ids"].to(device)
attention_mask = d["attention_mask"].to(device)
token_type_ids = d["token_type_ids"].to(device)

where_cond_targets = d["target"]['WHERE_CONDITIONS'].to(device)
where_columns = d["target"]['WHERE'].to(device)

# Indices of the non-zero entries of the WHERE target along dim 1.
# NOTE(review): topk does not promise a stable order among equal values, so
# the pairing with where_cond_targets below relies on the order it returns —
# confirm this matches how WHERE_CONDITIONS is laid out.
num_where_columns = torch.count_nonzero(where_columns).item()
target_idx = torch.topk(where_columns, k=num_where_columns, dim=1)[1]

for where_column, where_cond_target in zip(target_idx.view(-1), where_cond_targets.view(-1)):
    where_outputs = where_cond_classifier.predict(
        input_ids = input_ids,
        attention_mask = attention_mask,
        token_type_ids = token_type_ids,
        where_column = where_column
    )

    where_cond_classifier.calc_loss(where_outputs, where_cond_target)


In [None]:
# Smoke test for QABertTrainer: for every WHERE_VALUE target of the fifth dev
# batch, predict the start/end outputs and compute the loss.
import models.QABert

importlib.reload(models.QABert)

from models.QABert import QABertTrainer

# NOTE(review): no BERT encoder is passed here, unlike the training cell —
# presumably the trainer has a default; confirm.
qa_ranker = QABertTrainer(device, dev_data_loader)

# Advance five steps; only the fifth batch is kept.
batch_iterator = iter(dev_data_loader)
for step in range(5):
    d = next(batch_iterator)

# Move the QA inputs to the device and squeeze dim 0 once up front
# (presumably the size-1 batch dimension — TODO confirm), so the loop only
# has to index by condition number.
qa_ids = d["qa_input_ids"].to(device).squeeze(0)
qa_mask = d["qa_attention_mask"].to(device).squeeze(0)
qa_segments = d["qa_token_type_ids"].to(device).squeeze(0)

for cond_num, where_cond_target in enumerate(d["target"]['WHERE_VALUE']):
    # Each condition's row is flattened to 1-D before being fed to predict.
    span_start, span_end = qa_ranker.predict(
        input_ids=qa_ids[cond_num].view(-1),
        attention_mask=qa_mask[cond_num].view(-1),
        token_type_ids=qa_segments[cond_num].view(-1),
    )
    qa_ranker.calc_loss(span_start, span_end, where_cond_target)

In [None]:

# ---------------------------------------------------------------------------
# Stand-alone inference cell: repeats its own imports and device selection,
# builds the trainers on a frozen BERT encoder and runs get_request on one
# dev example.
# ---------------------------------------------------------------------------
import importlib
import torch
from transformers import BertModel

# Current CUDA device index when a GPU is available, otherwise 'cpu'.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    print('Current device:', torch.cuda.get_device_name(device))
    torch.cuda.empty_cache()
else:
    print('Failed to find GPU. Will use CPU.')
    device = 'cpu'

# Import every module that is reloaded below as a module object.
# dataLoader.DataLoaderUtils, models.QABert and train.Trainer are imported
# explicitly so this cell does not depend on another cell (or on a transitive
# import) having loaded them.
import dataLoader.DataLoader
import dataLoader.DataLoaderUtils
import models.QABert
import models.SelectRanker
import models.WhereRanker
import models.AggregationClassifier
import models.WhereConditionClassifier
import models.WhereNumberClassifier
import train.Trainer
import utils.Constants

# Reload first, bind names afterwards: `from X import name` copies the object
# at import time, so names imported before a reload stay stale.
# utils.Constants goes first because the other modules presumably import from
# it (TODO confirm); train.Trainer goes last, as in the training cell.
importlib.reload(utils.Constants)

importlib.reload(models.QABert)
importlib.reload(models.SelectRanker)
importlib.reload(models.WhereRanker)
importlib.reload(models.WhereConditionClassifier)
importlib.reload(models.WhereNumberClassifier)
importlib.reload(models.AggregationClassifier)

importlib.reload(dataLoader.DataLoader)
importlib.reload(dataLoader.DataLoaderUtils)

importlib.reload(train.Trainer)

from utils.Constants import tokenizer, PRE_TRAINED_MODEL_NAME

from dataLoader.DataLoader import get_data_loader, get_input_data
from train.Trainer import train_epoch, save_model, load_model, get_request

from models.QABert import QABertTrainer
from models.SelectRanker import SelectRankerTrainer
from models.WhereRanker import WhereRankerTrainer
from models.WhereConditionClassifier import WhereConditionClassifierTrainer
from models.WhereNumberClassifier import WhereNumberClassifierTrainer
from models.AggregationClassifier import AggregationClassifierTrainer


# One pretrained encoder shared by all trainers; its weights are frozen.
bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert.requires_grad_(False)

# The trainers receive an empty list where the training cell passes a data
# loader. NOTE: this rebinds `models` from the imported package to a dict.
models = dict(
    selection_trainer = SelectRankerTrainer(device, [], bert),
    agg_class_trainer = AggregationClassifierTrainer(device, [], bert, use_pretrained=False),
    where_ranker_trainer = WhereRankerTrainer(device, [], bert),
    where_cond_class_trainer = WhereConditionClassifierTrainer(device, [], bert, use_pretrained=False),
    where_numb_class_trainer = WhereNumberClassifierTrainer(device, [], bert),
    qa_trainer = QABertTrainer(device, [], bert, use_pretrained=False),
)

# Fetch example 0 of the dev split, using the freshly imported name.
table_name, columns, types, question = get_input_data(data_type = 'dev', tokenizer = tokenizer, batch_size = 1, idx=0)

# Last expression of the cell, so the notebook displays the result.
get_request(models, table_name, columns, types, question, tokenizer, device)