In [None]:
import os

# The notebook presumably lives one level below the repository root; relative paths such as
# configs/, logs/ and datasets/ are resolved from the root. Only change directory when
# configs/ is not already visible, so re-running this cell does not keep climbing upwards.
if not os.path.isdir("configs"):
    os.chdir("../")

In [None]:
import json
from utils import process_config
from src import T5ModelForTableReasoning
import torch
import torch.nn as nn
from data import WikiTQReasoningDataset
from transformers import AutoTokenizer

In [None]:
from tqdm import tqdm

In [None]:
import pandas as pd

In [None]:
# Load the experiment config (tokenizer, data and model settings) and normalize it.
with open("configs/wiki_tq_reasoning/t5.json") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
from datasets import load_dataset

In [None]:
# WikiTableQuestions training split (records with question / answers / table).
dataset = load_dataset("wikitablequestions")["train"]

In [None]:
# NOTE(review): hard-coded "t5-3b"; the dataset classes below load config.tokenizer.tokenizer_path
# instead — confirm both resolve to the same vocabulary.
tokenizer = AutoTokenizer.from_pretrained("t5-3b")

In [None]:
model = T5ModelForTableReasoning(config)

In [None]:
# Inference only: make sure dropout is disabled during beam search
# (an nn.Module starts in training mode unless the wrapper switches it).
model.eval()
# map_location="cpu" deserializes the checkpoint on CPU rather than on the GPU it was saved from;
# the model is moved to cuda:0 in the next cell.
model.load_state_dict(torch.load("logs/table_question_reasoning_t5_3b_bootstrapping_baseline_loss_calc_change/checkpoints/epoch=80.pt", map_location = "cpu"))

In [None]:
# Fine-tuned model on the first GPU.
model.to("cuda:0")

In [None]:
# Comparison model: same config, but no fine-tuned checkpoint is loaded into it.
# NOTE(review): eval() is never called on this model — confirm the wrapper is in inference mode.
model_pretrained = T5ModelForTableReasoning(config)
model_pretrained.to("cuda:1")

In [None]:
# Index of the training-split example inspected below.
idx = 1122

In [None]:
# Pull one WikiTQ example and build both a display frame and the linearized model input.
sample = dataset[idx]
question = sample["question"]
answer = ", ".join(sample["answers"]).lower()
table_column_names = sample["table"]["header"]
table_content_values = sample["table"]["rows"]

# Lower-cased, string-valued copy of the table, used only for display.
table_df = pd.DataFrame.from_dict({
    str(col).lower(): [str(row[col_idx]).lower() for row in table_content_values]
    for col_idx, col in enumerate(table_column_names)
})

# "[HEADER] c1 | c2 [ROW] 0: v1 | v2 ..." — the same layout the dataset classes produce.
table = "[HEADER] " + " | ".join(table_column_names) + "".join(
    f" [ROW] {row_id}: " + " | ".join(cells) for row_id, cells in enumerate(table_content_values)
)

input_text = f"Question: {question} Answer: {answer}. "


In [None]:
print(f"Question: {question}")
print(f"Answer: {answer}")

display(table_df)

In [None]:
# Encode the (prompt, linearized table) pair with the tokenizer settings from the config.
# NOTE(review): uses config.tokenizer.max_length, whereas the dataset classes use input_max_length.
tokenized_input  = tokenizer(input_text, table, add_special_tokens = config.tokenizer.add_special_tokens,
                            padding = config.tokenizer.padding, truncation = config.tokenizer.truncation, 
                            max_length = config.tokenizer.max_length, return_tensors = config.tokenizer.return_tensors,
                            return_token_type_ids = config.tokenizer.return_token_type_ids,
                            return_attention_mask = config.tokenizer.return_attention_mask)

input_ids = tokenized_input["input_ids"]
attention_mask = tokenized_input["attention_mask"]

In [None]:
# Beam search (3 beams) with the fine-tuned model on cuda:0.
predicted_ids = model.model.generate(input_ids = input_ids.to("cuda:0"), attention_mask = attention_mask.to("cuda:0"), 
                                     max_new_tokens = config.tokenizer.output_max_length, num_beams = 3, early_stopping = True)

In [None]:
# squeeze() drops the batch dimension (a single input string was encoded) before decoding.
predicted_reason = tokenizer.decode(predicted_ids.squeeze(), skip_special_tokens = True)

In [None]:
predicted_reason

In [None]:
# Same generation with the comparison model (no fine-tuned checkpoint) on cuda:1.
predicted_ids_pretrained = model_pretrained.model.generate(input_ids = input_ids.to("cuda:1"), attention_mask = attention_mask.to("cuda:1"), 
                                     max_new_tokens = config.tokenizer.output_max_length, num_beams = 3, early_stopping = True)

In [None]:
predicted_reason_pretrained = tokenizer.decode(predicted_ids_pretrained.squeeze(), skip_special_tokens = True)

In [None]:
predicted_reason_pretrained

In [None]:
# Reasoning CSV wrapped with the `WikiTQReasoningDataset` imported from `data`.
# NOTE(review): a class with the same name is defined in a later cell and shadows this import once run.
reasoning_dataset = pd.read_csv(config.data.data_path)
train_dataset = WikiTQReasoningDataset(dataset = reasoning_dataset, config= config)

In [None]:
# This dataset's items unpack into five values; the third is discarded here.
reason_input_ids, reason_attention_mask, _, reason_output_ids, reason_labels = train_dataset.__getitem__(1)

In [None]:
# Decode only the label positions that are not masked with -100.
x = tokenizer.decode(reason_labels[reason_labels != -100])

In [None]:
# Re-encode the decoded labels (presumably to compare with the dataset's own label ids).
x_ids = tokenizer(x, add_special_tokens = config.tokenizer.add_special_tokens,
                            padding = config.tokenizer.padding, truncation = config.tokenizer.truncation, 
                            max_length = config.tokenizer.max_length, return_tensors = config.tokenizer.return_tensors,
                            return_token_type_ids = config.tokenizer.return_token_type_ids,
                            return_attention_mask = config.tokenizer.return_attention_mask)["input_ids"].squeeze()

In [None]:
# Register any missing special tokens, all mapped to the EOS token here.
# NOTE(review): the dataset classes below map a missing BOS to "<s>" instead of EOS.
if "bos_token" not in list(tokenizer.special_tokens_map.keys()):
    tokenizer.add_special_tokens({"bos_token": tokenizer.special_tokens_map["eos_token"]})

if "pad_token" not in list(tokenizer.special_tokens_map.keys()):
    tokenizer.add_special_tokens({"pad_token": tokenizer.special_tokens_map["eos_token"]})

if "sep_token" not in list(tokenizer.special_tokens_map.keys()):
    tokenizer.add_special_tokens({"sep_token": tokenizer.special_tokens_map["eos_token"]})

if "mask_token" not in list(tokenizer.special_tokens_map.keys()):
    tokenizer.add_special_tokens({"mask_token": tokenizer.special_tokens_map["eos_token"]})

In [None]:
# Mask padding positions with -100, the same sentinel the labels above use.
x_ids[x_ids == tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map["pad_token"])] = -100
# x_ids[x_ids == tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map["sep_token"])] = -100
# x_ids[x_ids == tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map["bos_token"])] = -100

In [None]:
tokenizer.decode(x_ids[x_ids != -100])

# Generate Reasons on WikiTQ dataset using T5-3b

In [None]:
import json
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset

from src import T5ModelForTableReasoning
from utils import process_config

from torch.utils.data import Dataset, DataLoader

In [None]:
class WikiTQReasoningDataset(Dataset):
    """WikiTableQuestions split prepared as encoder inputs for reason generation.

    Each record becomes a prompt ``"Question: <q> Answer: <a>. "`` plus a table.
    The table is linearized as ``"[HEADER] ... [ROW] i: ..."``, or kept as a
    lower-cased ``pandas.DataFrame`` when ``config.tokenizer.special_table_tok``
    is set (for a table-aware tokenizer). ``__getitem__`` returns the squeezed
    ``input_ids`` and ``attention_mask`` of that pair.

    Args:
        dataset: container indexed by split name (e.g. a ``datasets.DatasetDict``);
            each record needs ``question``, ``answers`` and ``table`` with
            ``header`` / ``rows``.
        config: attribute-style config with ``tokenizer``, ``data`` and
            ``training`` sections.
        data_type: split key read from ``dataset`` (default ``"train"``).
    """

    def __init__(self, dataset, config, data_type = "train"):
        super(WikiTQReasoningDataset, self).__init__()

        self.dataset = dataset
        self.config = config
        self.data_type = data_type

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer.tokenizer_path, local_files_only = self.config.tokenizer.local_files_only,
                                                       padding_side = self.config.tokenizer.padding_side)

        # Register any missing special tokens: BOS gets a literal "<s>", the others reuse EOS.
        if "bos_token" not in self.tokenizer.special_tokens_map:
            self.tokenizer.add_special_tokens({"bos_token": "<s>"})
        for token_name in ("pad_token", "sep_token", "mask_token"):
            if token_name not in self.tokenizer.special_tokens_map:
                self.tokenizer.add_special_tokens({token_name: self.tokenizer.special_tokens_map["eos_token"]})

        # Pre-process the whole split once; tokenization happens lazily in __getitem__.
        self.text_input, self.table, self.text_output = self._process_dataset()

    def _tokenizer_kwargs(self, max_length):
        """Keyword arguments shared by every tokenizer call, read from ``config.tokenizer``."""
        tok_cfg = self.config.tokenizer
        return {
            "add_special_tokens": tok_cfg.add_special_tokens,
            "padding": tok_cfg.padding,
            "truncation": tok_cfg.truncation,
            "max_length": max_length,
            "return_tensors": tok_cfg.return_tensors,
            "return_token_type_ids": tok_cfg.return_token_type_ids,
            "return_attention_mask": tok_cfg.return_attention_mask,
        }

    def _tokenize(self, text_input, table = None, max_length = 512, text_output = None):
        """Tokenize a prompt and optional table, optionally appending ``text_output``.

        When ``text_output`` is given it is joined with the SEP token onto the table
        (or onto the prompt if there is no table). That is not supported with
        ``special_table_tok`` and raises ``NotImplementedError``.
        """
        if text_output is not None:
            if self.config.tokenizer.special_table_tok:
                raise NotImplementedError
            separator = f" {self.tokenizer.special_tokens_map['sep_token']} "
            if table is not None:
                table = table + separator + text_output
            else:
                text_input = text_input + separator + text_output

        kwargs = self._tokenizer_kwargs(max_length)
        if self.config.tokenizer.special_table_tok:
            # Table-aware tokenizers take the table first; with no table the text is passed as `answer`.
            if table is not None:
                return self.tokenizer(table, text_input, **kwargs)
            return self.tokenizer(answer = text_input, **kwargs)
        if table is not None:
            return self.tokenizer(text_input, table, **kwargs)
        return self.tokenizer(text_input, **kwargs)

    @staticmethod
    def _table_to_frame(column_names, rows):
        """Build a lower-cased, string-valued DataFrame with one column per header entry.

        Headers that collide after lower-casing collapse into one column (last wins).
        """
        return pd.DataFrame.from_dict({str(col).lower(): [str(row[col_idx]).lower() for row in rows]
                                       for col_idx, col in enumerate(column_names)})

    @staticmethod
    def _select_relevant(table, relevant_rows, relevant_columns):
        """Return the selected rows if any, else the selected columns if any, else the whole table."""
        if len(relevant_rows) > 0:
            return table.iloc[relevant_rows]
        if len(relevant_columns) > 0:
            return table[relevant_columns]
        return table

    @staticmethod
    def _linearize_table(column_names, rows):
        """Serialize as ``"[HEADER] c1 | c2 [ROW] 0: v1 | v2 [ROW] 1: ..."``."""
        parts = ["[HEADER] " + " | ".join(column_names)]
        parts.extend(f" [ROW] {row_id}: " + " | ".join(row) for row_id, row in enumerate(rows))
        return "".join(parts)

    def _process_one_sample(self, data, idx = None):
        """Turn one WikiTQ record into ``(text_input, table, text_output)``.

        For ``training_type == "table_decomposition"`` this returns
        ``(question, table, table_output)``, where ``table_output`` is the relevant
        sub-table. If no decomposition ran or nothing was selected, ``table_output``
        is the full table (this case previously raised UnboundLocalError). Otherwise
        it returns ``(prompt, table, "")``. ``idx`` is accepted for compatibility
        and unused.
        """
        question = data["question"]
        table_column_names = data["table"]["header"]
        table_content_values = data["table"]["rows"]

        answer_list = [str(a).lower() for a in data["answers"]]
        answer = ", ".join(data["answers"]).lower()

        output_text = ""
        input_text = f"Question: {question} Answer: {answer}. "
        is_decomposition = self.config.training.training_type == "table_decomposition"

        # Decomposition target; stays None when decomposition is disabled.
        table_output = None

        if self.config.tokenizer.special_table_tok:
            table = self._table_to_frame(table_column_names, table_content_values)

            if self.config.data.decompose_table:
                # NOTE(review): `_decompose_table` is not defined in this class; enabling
                # data.decompose_table raises AttributeError unless it is supplied elsewhere.
                relevant_rows, relevant_columns = self._decompose_table(question, answer_list, table)
                selected = self._select_relevant(table, relevant_rows, relevant_columns)
                if is_decomposition:
                    table_output = selected
                else:
                    table = selected

            if is_decomposition and table_output is None:
                table_output = table
        else:
            if self.config.data.decompose_table:
                frame = self._table_to_frame(table_column_names, table_content_values)
                relevant_rows, relevant_columns = self._decompose_table(question, answer_list, frame)
                selected = self._select_relevant(frame, relevant_rows, relevant_columns)
                if is_decomposition:
                    table_output = selected
                else:
                    # The (possibly reduced) lower-cased frame replaces the raw table as input.
                    table_column_names = selected.columns.tolist()
                    table_content_values = selected.values.tolist()

            table = self._linearize_table(table_column_names, table_content_values)

            if is_decomposition:
                if table_output is None:
                    table_output = table
                else:
                    table_output = self._linearize_table(table_output.columns.tolist(), table_output.values.tolist())

        if is_decomposition:
            return question, table, table_output
        return input_text, table, output_text

    def _process_dataset(self):
        """Pre-process every record of ``dataset[data_type]`` into three parallel lists."""
        split = self.dataset[self.data_type]
        processed_data = [self._process_one_sample(data, i)
                          for i, data in tqdm(enumerate(split), position = 0, leave = True, total = len(split))]

        text_input = [x[0] for x in processed_data]
        table = [x[1] for x in processed_data]
        text_output = [x[2] for x in processed_data]

        return text_input, table, text_output

    def __len__(self):
        return len(self.text_input)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` for one example, with size-1 dims squeezed."""
        tokenized_input = self._tokenize(self.text_input[index], self.table[index], max_length = self.config.tokenizer.input_max_length)
        return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze()


In [None]:
# Full WikiTableQuestions DatasetDict; the dataset class selects a split by name.
dataset = load_dataset("wikitablequestions")

In [None]:
with open("configs/wiki_tq_reasoning/t5.json", "r") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
# NOTE(review): named `train_dataset` but built from the "test" split.
train_dataset = WikiTQReasoningDataset(dataset = dataset, config = config, data_type = "test")

In [None]:
# shuffle=False keeps generated reasons in the same order as the split's examples.
train_dataloader = DataLoader(train_dataset, batch_size = 8, shuffle = False, num_workers = config.system.num_workers)

In [None]:
# Accumulates one decoded reason per example, in dataset order.
reason_generations = []

In [None]:
model = T5ModelForTableReasoning(config)
# Inference only: make sure dropout is disabled during beam search
# (an nn.Module starts in training mode unless the wrapper switches it).
model.eval()
# map_location="cpu" deserializes the checkpoint on CPU rather than on the GPU it was saved from;
# the model is moved to its GPU in the next cell.
model.load_state_dict(torch.load("logs/table_question_reasoning_t5_3b_bootstrapping_baseline_loss_calc_change/checkpoints/epoch=80.pt", map_location = "cpu"))

In [None]:
# NOTE(review): hard-coded device; the generation loop below must use the same "cuda:6".
model.to("cuda:6")

In [None]:
# Beam-search one reason per example; DataLoader order (shuffle=False) matches the split order.
progress = tqdm(enumerate(train_dataloader), position = 0, leave = True, total = len(train_dataloader))
for i, (input_ids, attention_mask) in progress:
    generated = model.model.generate(
        input_ids = input_ids.to("cuda:6"),
        attention_mask = attention_mask.to("cuda:6"),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )
    predicted_ids = generated.detach().cpu()
    batch_predicted_reason = train_dataset.tokenizer.batch_decode(predicted_ids, skip_special_tokens = True)
    reason_generations.extend(batch_predicted_reason)

In [None]:
# Sanity check: one generated reason per example of the split.
len(reason_generations)

In [None]:
reason_generations[0]

In [None]:
reason_generations[5]

In [None]:
import pickle
# Persist the reasons generated for the "test" split.
with open("datasets/test_wiki_tq_reason.pkl", "wb") as f:
    pickle.dump(reason_generations, f)

In [None]:
import pickle
# Replace the in-memory list with a previously saved one.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted, locally produced files.
with open("datasets/wiki_tq_reason.pkl", "rb") as f:
    reason_generations = pickle.load(f)

In [None]:
import pandas as pd
# Gold reasons; the `id` column indexes into reason_generations (see the next cell).
gold_df = pd.read_csv("datasets/WikiTQReasoningData.csv")

In [None]:
gold_df

In [None]:
# Overwrite generated reasons with the gold ones wherever a gold row exists, keyed by example id.
for idx, gold_reason in zip(gold_df["id"], gold_df["reason"]):
    reason_generations[idx] = gold_reason

In [None]:
import pickle
# Overwrite wiki_tq_reason.pkl in place with the gold-patched list.
with open("datasets/wiki_tq_reason.pkl", "wb") as f:
    pickle.dump(reason_generations, f)

In [None]:
import pickle
with open("datasets/wiki_tq_reason.pkl", "rb") as f:
    reason_generations = pickle.load(f)

In [None]:
from datasets import load_dataset

In [None]:
# NOTE(review): named `train_dataset` but this is the "test" split, while reason_generations came from
# wiki_tq_reason.pkl, which was patched using ids from the gold CSV — confirm both refer to the same split.
train_dataset = load_dataset("wikitablequestions")["test"]

In [None]:
import pandas as pd

In [None]:
idx = 4224

In [None]:
# Show one example together with its (generated or gold-patched) reason.
example = train_dataset[idx]
question = example["question"]
answer = ", ".join(example["answers"]).lower()
table_column_names = example["table"]["header"]
table_content_values = example["table"]["rows"]

table = pd.DataFrame.from_dict({
    str(col).lower(): [str(row[col_idx]).lower() for row in table_content_values]
    for col_idx, col in enumerate(table_column_names)
})

reason = reason_generations[idx]

for line in (f"Question: {question}", f"Answer: {answer}", f"reason: {reason}"):
    print(line)

display(table)


# Generate Reasons on WikiTQ dataset using Flan-T5-XL trained without answer

In [None]:
import json
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset

from src import T5ModelForTableReasoning
from utils import process_config

from torch.utils.data import Dataset, DataLoader

In [None]:
class WikiTQReasoningWithoutAnswerDataset(Dataset):
    """WikiTableQuestions split prepared as encoder inputs, with the answer left out of the prompt.

    Each record becomes a prompt ``"Question: <q> "`` plus a table. The table is
    linearized as ``"[HEADER] ... [ROW] i: ..."``, or kept as a lower-cased
    ``pandas.DataFrame`` when ``config.tokenizer.special_table_tok`` is set.
    ``__getitem__`` returns the squeezed ``input_ids`` and ``attention_mask``.

    Args:
        dataset: container indexed by split name (e.g. a ``datasets.DatasetDict``);
            each record needs ``question``, ``answers`` and ``table`` with
            ``header`` / ``rows``.
        config: attribute-style config with ``tokenizer``, ``data`` and
            ``training`` sections.
        data_type: split key read from ``dataset`` (default ``"train"``).
    """

    def __init__(self, dataset, config, data_type = "train"):
        super(WikiTQReasoningWithoutAnswerDataset, self).__init__()

        self.dataset = dataset
        self.config = config
        self.data_type = data_type

        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer.tokenizer_path, local_files_only = self.config.tokenizer.local_files_only,
                                                       padding_side = self.config.tokenizer.padding_side)

        # Register any missing special tokens: BOS gets a literal "<s>", the others reuse EOS.
        if "bos_token" not in self.tokenizer.special_tokens_map:
            self.tokenizer.add_special_tokens({"bos_token": "<s>"})
        for token_name in ("pad_token", "sep_token", "mask_token"):
            if token_name not in self.tokenizer.special_tokens_map:
                self.tokenizer.add_special_tokens({token_name: self.tokenizer.special_tokens_map["eos_token"]})

        # Pre-process the whole split once; tokenization happens lazily in __getitem__.
        self.text_input, self.table, self.text_output = self._process_dataset()

    def _tokenizer_kwargs(self, max_length):
        """Keyword arguments shared by every tokenizer call, read from ``config.tokenizer``."""
        tok_cfg = self.config.tokenizer
        return {
            "add_special_tokens": tok_cfg.add_special_tokens,
            "padding": tok_cfg.padding,
            "truncation": tok_cfg.truncation,
            "max_length": max_length,
            "return_tensors": tok_cfg.return_tensors,
            "return_token_type_ids": tok_cfg.return_token_type_ids,
            "return_attention_mask": tok_cfg.return_attention_mask,
        }

    def _tokenize(self, text_input, table = None, max_length = 512, text_output = None):
        """Tokenize a prompt and optional table, optionally appending ``text_output``.

        When ``text_output`` is given it is joined with the SEP token onto the table
        (or onto the prompt if there is no table). That is not supported with
        ``special_table_tok`` and raises ``NotImplementedError``.
        """
        if text_output is not None:
            if self.config.tokenizer.special_table_tok:
                raise NotImplementedError
            separator = f" {self.tokenizer.special_tokens_map['sep_token']} "
            if table is not None:
                table = table + separator + text_output
            else:
                text_input = text_input + separator + text_output

        kwargs = self._tokenizer_kwargs(max_length)
        if self.config.tokenizer.special_table_tok:
            # Table-aware tokenizers take the table first; with no table the text is passed as `answer`.
            if table is not None:
                return self.tokenizer(table, text_input, **kwargs)
            return self.tokenizer(answer = text_input, **kwargs)
        if table is not None:
            return self.tokenizer(text_input, table, **kwargs)
        return self.tokenizer(text_input, **kwargs)

    @staticmethod
    def _table_to_frame(column_names, rows):
        """Build a lower-cased, string-valued DataFrame with one column per header entry.

        Headers that collide after lower-casing collapse into one column (last wins).
        """
        return pd.DataFrame.from_dict({str(col).lower(): [str(row[col_idx]).lower() for row in rows]
                                       for col_idx, col in enumerate(column_names)})

    @staticmethod
    def _select_relevant(table, relevant_rows, relevant_columns):
        """Return the selected rows if any, else the selected columns if any, else the whole table."""
        if len(relevant_rows) > 0:
            return table.iloc[relevant_rows]
        if len(relevant_columns) > 0:
            return table[relevant_columns]
        return table

    @staticmethod
    def _linearize_table(column_names, rows):
        """Serialize as ``"[HEADER] c1 | c2 [ROW] 0: v1 | v2 [ROW] 1: ..."``."""
        parts = ["[HEADER] " + " | ".join(column_names)]
        parts.extend(f" [ROW] {row_id}: " + " | ".join(row) for row_id, row in enumerate(rows))
        return "".join(parts)

    def _process_one_sample(self, data, idx = None):
        """Turn one WikiTQ record into ``(text_input, table, text_output)``.

        The prompt contains only the question. The answers are still used, lower-cased,
        to pick relevant rows/columns when ``data.decompose_table`` is on. For
        ``training_type == "table_decomposition"`` this returns
        ``(question, table, table_output)``, with the full table as the target when
        nothing was decomposed (this case previously raised UnboundLocalError).
        ``idx`` is unused.
        """
        question = data["question"]
        table_column_names = data["table"]["header"]
        table_content_values = data["table"]["rows"]

        answer_list = [str(a).lower() for a in data["answers"]]

        output_text = ""
        input_text = f"Question: {question} "
        is_decomposition = self.config.training.training_type == "table_decomposition"

        # Decomposition target; stays None when decomposition is disabled.
        table_output = None

        if self.config.tokenizer.special_table_tok:
            table = self._table_to_frame(table_column_names, table_content_values)

            if self.config.data.decompose_table:
                # NOTE(review): `_decompose_table` is not defined in this class; enabling
                # data.decompose_table raises AttributeError unless it is supplied elsewhere.
                relevant_rows, relevant_columns = self._decompose_table(question, answer_list, table)
                selected = self._select_relevant(table, relevant_rows, relevant_columns)
                if is_decomposition:
                    table_output = selected
                else:
                    table = selected

            if is_decomposition and table_output is None:
                table_output = table
        else:
            if self.config.data.decompose_table:
                frame = self._table_to_frame(table_column_names, table_content_values)
                relevant_rows, relevant_columns = self._decompose_table(question, answer_list, frame)
                selected = self._select_relevant(frame, relevant_rows, relevant_columns)
                if is_decomposition:
                    table_output = selected
                else:
                    # The (possibly reduced) lower-cased frame replaces the raw table as input.
                    table_column_names = selected.columns.tolist()
                    table_content_values = selected.values.tolist()

            table = self._linearize_table(table_column_names, table_content_values)

            if is_decomposition:
                if table_output is None:
                    table_output = table
                else:
                    table_output = self._linearize_table(table_output.columns.tolist(), table_output.values.tolist())

        if is_decomposition:
            return question, table, table_output
        return input_text, table, output_text

    def _process_dataset(self):
        """Pre-process every record of ``dataset[data_type]`` into three parallel lists."""
        split = self.dataset[self.data_type]
        processed_data = [self._process_one_sample(data, i)
                          for i, data in tqdm(enumerate(split), position = 0, leave = True, total = len(split))]

        text_input = [x[0] for x in processed_data]
        table = [x[1] for x in processed_data]
        text_output = [x[2] for x in processed_data]

        return text_input, table, text_output

    def __len__(self):
        return len(self.text_input)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` for one example, with size-1 dims squeezed."""
        tokenized_input = self._tokenize(self.text_input[index], self.table[index], max_length = self.config.tokenizer.input_max_length)
        return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze()


In [None]:
dataset = load_dataset("wikitablequestions")

In [None]:
# NOTE(review): same t5.json config as the with-answer run, but the checkpoint loaded below is a
# flan_t5_xl one — confirm tokenizer/model settings in this config match that checkpoint.
with open("configs/wiki_tq_reasoning/t5.json", "r") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
# Prompts contain only the question. NOTE(review): named `train_dataset` but built from the "test" split.
train_dataset = WikiTQReasoningWithoutAnswerDataset(dataset = dataset, config = config, data_type = "test")

In [None]:
# shuffle=False keeps generated reasons aligned with the split's example order.
train_dataloader = DataLoader(train_dataset, batch_size = 16, shuffle = False, num_workers = config.system.num_workers)

In [None]:
reason_generations = []

In [None]:
model = T5ModelForTableReasoning(config)
# Inference only: make sure dropout is disabled during beam search
# (an nn.Module starts in training mode unless the wrapper switches it).
model.eval()
# Deserialize on CPU; the model is moved to its GPU in the next cell.
model.load_state_dict(torch.load("logs/table_question_reasoning_flan_t5_xl_reason_without_answer_new/checkpoints/epoch=50.pt", map_location="cpu"))

In [None]:
# NOTE(review): hard-coded device; the generation loop below must use the same "cuda:0".
model.to("cuda:0")

In [None]:
# Beam-search one reason per example; DataLoader order (shuffle=False) matches the split order.
progress = tqdm(enumerate(train_dataloader), position = 0, leave = True, total = len(train_dataloader))
for i, (input_ids, attention_mask) in progress:
    generated = model.model.generate(
        input_ids = input_ids.to("cuda:0"),
        attention_mask = attention_mask.to("cuda:0"),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )
    predicted_ids = generated.detach().cpu()
    batch_predicted_reason = train_dataset.tokenizer.batch_decode(predicted_ids, skip_special_tokens = True)
    reason_generations.extend(batch_predicted_reason)

In [None]:
# NOTE(review): displays the entire list; slice it (e.g. reason_generations[:5]) for large splits.
reason_generations

In [None]:
len(reason_generations)

In [None]:
import pickle
with open("datasets/test_wiki_tq_no_answer_in_reason_flant5.pkl", "wb") as f:
    pickle.dump(reason_generations, f)

In [None]:
len(reason_generations)

In [None]:
import pickle
# Load the previously saved reasons to compare against (held as `reason_generations_with_answer`).
with open("datasets/wiki_tq_reason.pkl", "rb") as f:
    reason_generations_with_answer = pickle.load(f)

In [None]:
# Count positions where the two reason strings differ exactly.
# NOTE(review): reason_generations is for the "test" split; confirm wiki_tq_reason.pkl holds the same split,
# otherwise this index-by-index comparison pairs different questions.
# Raises IndexError if reason_generations_with_answer is shorter than reason_generations.
count = 0
total = 0
for i in range(len(reason_generations)):
    if reason_generations[i] != reason_generations_with_answer[i]:
        count += 1
    total += 1


In [None]:
count

In [None]:
total

In [None]:
from datasets import load_dataset
import pandas as pd

# NOTE(review): this is the "train" split, but reason_generations holds generations for the "test" split
# (data_type="test" above), so reason_generations[idx] below may belong to a different question.
train_dataset = load_dataset("wikitablequestions")["train"]

In [None]:
idx = 2011

In [None]:
# Side-by-side view of the two reason lists for one example.
example = train_dataset[idx]
question = example["question"]
answer = ", ".join(example["answers"]).lower()
table_column_names = example["table"]["header"]
table_content_values = example["table"]["rows"]

table = pd.DataFrame.from_dict({
    str(col).lower(): [str(row[col_idx]).lower() for row in table_content_values]
    for col_idx, col in enumerate(table_column_names)
})

reason, reason_with_answer = reason_generations[idx], reason_generations_with_answer[idx]

print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"reason by model trained without answer: {reason}")
print(f"reason by model trained with answer: {reason_with_answer}")

display(table)

In [None]:
import pickle
# NOTE(review): writes the same list already saved to test_wiki_tq_no_answer_in_reason_flant5.pkl, under a second name.
with open("datasets/test_wiki_tq_reason_without_answer.pkl", "wb") as f:
    pickle.dump(reason_generations, f)

# Generate Reasons on WikiTQ dataset using TAPEX trained with answer

In [None]:
import json
import torch
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from transformers import AutoTokenizer
from datasets import load_dataset

from src import T5ModelForTableReasoning
from utils import process_config

from torch.utils.data import Dataset, DataLoader

In [None]:
class WikiTQReasoningDataset(Dataset):
    """WikiTQ samples paired with pre-generated reasons, tokenized for table reasoning.

    All records are pre-processed once in ``__init__`` (via ``_process_dataset``);
    ``__getitem__`` only tokenizes and builds labels. The target text of record ``i`` is
    ``self.reasons[i]``, so ``dataset`` must be in the same order as the pickled reasons.

    NOTE(review):
      * ``pickle`` is used in ``__init__`` but is not imported by the import cell of this
        section; it relies on an earlier ``import pickle`` in the notebook namespace.
      * ``_decompose_table``, ``_whole_word_mask``, ``_get_row_ids``, ``_get_col_ids`` and
        ``self.tokenized_text`` are referenced but not defined on this class, so the
        ``decompose_table``, ``masked_language_modelling`` and ``use_row_col_ids`` code
        paths raise ``AttributeError`` as written.
    """

    def __init__(self, dataset, config):
        """Load the tokenizer and the reasons, then pre-process every record.

        Args:
            dataset: sized iterable of WikiTQ records; each record must provide
                ``"question"``, ``"answers"`` and ``"table"`` (with ``"header"`` and ``"rows"``).
            config: processed config; reads ``config.tokenizer.*``, ``config.data.*``,
                ``config.training.training_type`` and ``config.model.*``.
        """
        super(WikiTQReasoningDataset, self).__init__()

        self.dataset = dataset
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer.tokenizer_path, local_files_only = self.config.tokenizer.local_files_only,
                                                       padding_side = self.config.tokenizer.padding_side)
        
        # Fill in special tokens the tokenizer lacks. pad/sep/mask reuse the EOS token string;
        # bos gets a literal "<s>".
        # NOTE(review): if "<s>" is not already in the vocabulary it is added as a new id;
        # resizing the model embeddings is not done here — confirm the model handles it.
        if "bos_token" not in list(self.tokenizer.special_tokens_map.keys()):
            self.tokenizer.add_special_tokens({"bos_token": "<s>"})

        if "pad_token" not in list(self.tokenizer.special_tokens_map.keys()):
            self.tokenizer.add_special_tokens({"pad_token": self.tokenizer.special_tokens_map["eos_token"]})

        if "sep_token" not in list(self.tokenizer.special_tokens_map.keys()):
            self.tokenizer.add_special_tokens({"sep_token": self.tokenizer.special_tokens_map["eos_token"]})

        if "mask_token" not in list(self.tokenizer.special_tokens_map.keys()):
            self.tokenizer.add_special_tokens({"mask_token": self.tokenizer.special_tokens_map["eos_token"]})

        
        # Target reasons, indexed by record position (see _process_one_sample).
        # pickle.load executes arbitrary code: this path must point at a trusted local file.
        with open("datasets/wiki_tq_reason.pkl", "rb") as f:
            self.reasons = pickle.load(f)



        # Three parallel lists: model inputs, tables, targets.
        self.text_input, self.table, self.text_output = self._process_dataset()
        

    def _tokenize(self, text_input, table = None, max_length = 512, text_output = None):
        """Tokenize ``text_input`` (optionally paired with ``table``) using the configured options.

        Args:
            text_input: prompt text, or a target string when called to build labels.
            table: linearized table string, or a DataFrame when ``special_table_tok`` is set;
                ``None`` tokenizes ``text_input`` alone.
            max_length: forwarded to the tokenizer as ``max_length``.
            text_output: optional target appended after the tokenizer's ``sep_token`` string —
                to ``table`` when a table is given, otherwise to ``text_input``.

        Returns:
            The tokenizer's encoding for these arguments (tensor type per
            ``config.tokenizer.return_tensors``).

        Raises:
            NotImplementedError: if ``text_output`` is given while ``special_table_tok`` is set.
        """

        if text_output is not None:
            if self.config.tokenizer.special_table_tok:
                raise NotImplementedError
            else:
                if table is not None:
                    table = table + f" {self.tokenizer.special_tokens_map['sep_token']} " + text_output
                else:
                    text_input = text_input + f" {self.tokenizer.special_tokens_map['sep_token']} " + text_output
            # text_input = text_input + f" {self.tokenizer.special_tokens_map['sep_token']} " + text_output

        if self.config.tokenizer.special_table_tok:
            if table is not None:
                # Table first, query second — presumably a TAPEX-style (table, query) tokenizer
                # signature; TODO confirm against the configured tokenizer.
                return self.tokenizer(table, text_input, add_special_tokens = self.config.tokenizer.add_special_tokens,
                            padding = self.config.tokenizer.padding, truncation = self.config.tokenizer.truncation, 
                            max_length = max_length, return_tensors = self.config.tokenizer.return_tensors,
                            return_token_type_ids = self.config.tokenizer.return_token_type_ids,
                            return_attention_mask = self.config.tokenizer.return_attention_mask)
            else: 
                # Target-only tokenization through the tokenizer's `answer` keyword
                # (presumably TAPEX-style) — TODO confirm.
                return self.tokenizer(answer = text_input, add_special_tokens = self.config.tokenizer.add_special_tokens,
                            padding = self.config.tokenizer.padding, truncation = self.config.tokenizer.truncation, 
                            max_length = max_length, return_tensors = self.config.tokenizer.return_tensors,
                            return_token_type_ids = self.config.tokenizer.return_token_type_ids,
                            return_attention_mask = self.config.tokenizer.return_attention_mask)
        else:
            if table is not None:
                # Text first, linearized table second (tokenized as a text pair).
                return self.tokenizer(text_input, table, add_special_tokens = self.config.tokenizer.add_special_tokens,
                            padding = self.config.tokenizer.padding, truncation = self.config.tokenizer.truncation, 
                            max_length = max_length, return_tensors = self.config.tokenizer.return_tensors,
                            return_token_type_ids = self.config.tokenizer.return_token_type_ids,
                            return_attention_mask = self.config.tokenizer.return_attention_mask)
            else:
                return self.tokenizer(text_input, add_special_tokens = self.config.tokenizer.add_special_tokens,
                            padding = self.config.tokenizer.padding, truncation = self.config.tokenizer.truncation, 
                            max_length = max_length, return_tensors = self.config.tokenizer.return_tensors,
                            return_token_type_ids = self.config.tokenizer.return_token_type_ids,
                            return_attention_mask = self.config.tokenizer.return_attention_mask)


    def _process_one_sample(self, data, idx = None):
        """Turn one WikiTQ record into a ``(model_input, table, target)`` triple.

        Args:
            data: one record with ``"question"``, ``"answers"`` and ``"table"``.
            idx: position of the record; selects ``self.reasons[idx]`` as the target. It must
                be supplied — with the ``None`` default, indexing ``self.reasons`` fails
                (NOTE(review): assumes ``self.reasons`` is a list).

        Returns:
            ``(question, table, table_output)`` when ``training_type == "table_decomposition"``,
            otherwise ``(input_text, table, output_text)`` with
            ``input_text = "Question: <q> Answer: <a>. "`` and ``output_text`` the stored reason.
            ``table`` is a lower-cased DataFrame when ``special_table_tok`` is set, otherwise a
            linearized ``"[HEADER] ... [ROW] i: ..."`` string.

        NOTE(review): with ``decompose_table`` enabled and ``table_decomposition`` training,
        ``table_output`` is never assigned when both relevant-row and relevant-column lists
        are empty, so the return raises ``UnboundLocalError``.
        """

        question = data["question"]
        table_column_names = data["table"]["header"]
        table_content_values = data["table"]["rows"]

        answer = data["answers"]
        # Lower-cased string answers, used only by table decomposition.
        answer_list = answers = [str(a).lower() for a in data["answers"]]
        # Joins the raw answers (assumes each is already a str), then lower-cases.
        answer = f", ".join(answer).lower()


        # question = self.dataset["question"][idx]
        # table_dict = eval(self.dataset["table"][idx])
        # table_column_names = table_dict["header"]
        # table_content_values = table_dict["rows"]

        # answer = eval(self.dataset["answers"][idx])
        # answer_list = answers = [str(a).lower() for a in self.dataset["answers"]]
        # answer = f", ".join(answer).lower()

        output_text = self.reasons[idx]
        input_text = f"Question: {question} Answer: {answer}. "


        if self.config.tokenizer.special_table_tok:
            
            # table_content_values = [self.expand_numbers(table_content_values[i]) for i in range(len(table_content_values))]

            # table_content_values = [[self.expand_numbers(table_content_values[i][j]) for j in range(len(table_content_values[i]))] for i in range(len(table_content_values))]

            # for i in range(table_content_values):
            #     for j in range(table_content_values[i]):
            #         table_content_values[i][j] = self.expand_numbers(table_content_values[i][j])

            # Column-wise DataFrame with lower-cased headers and cells.
            table = pd.DataFrame.from_dict({str(col).lower(): [str(table_content_values[j][i]).lower() for j in range(len(table_content_values))] for i, col in enumerate(table_column_names)})

            if self.config.data.decompose_table:
                relevant_rows, relevant_columns = self._decompose_table(question, answer_list, table)
                
                if self.config.training.training_type != "table_decomposition":
                    
                    # Row selection takes precedence over column selection.
                    if len(relevant_rows) > 0:
                        table = table.iloc[relevant_rows]
                    
                    elif len(relevant_columns) > 0:
                        table = table[relevant_columns]
                else:
                    # For decomposition training the sub-table is the target, not the input.
                    if len(relevant_rows) > 0:
                        table_output = table.iloc[relevant_rows]
                    
                    elif len(relevant_columns) > 0:
                        table_output = table[relevant_columns]
            
        else:
            
            if self.config.data.decompose_table:
                table = pd.DataFrame.from_dict({str(col).lower(): [str(table_content_values[j][i]).lower() for j in range(len(table_content_values))] for i, col in enumerate(table_column_names)})
                relevant_rows, relevant_columns = self._decompose_table(question, answer_list, table)
                
                if self.config.training.training_type != "table_decomposition":
                    if len(relevant_rows) > 0:
                        table = table.iloc[relevant_rows]
                    
                    elif len(relevant_columns) > 0:
                        table = table[relevant_columns]

                    # Re-derive header/rows from the reduced table for linearization below.
                    table_column_names = table.columns.tolist()
                    table_content_values = table.values.tolist()

                else:
                    if len(relevant_rows) > 0:
                        table_output = table.iloc[relevant_rows]
                    
                    elif len(relevant_columns) > 0:
                        table_output = table[relevant_columns]


            # Linearize: "[HEADER] c1 | c2 [ROW] 0: v1 | v2 [ROW] 1: ..."
            table = "[HEADER] " + " | ".join(table_column_names)
            for row_id, row in enumerate(table_content_values):
                table += f" [ROW] {row_id}: " + " | ".join(row) 

            if self.config.training.training_type == "table_decomposition":
                table_column_names_output = table_output.columns.tolist()
                table_content_values_output = table_output.values.tolist()

                # Same linearization for the target sub-table.
                table_output = "[HEADER] " + " | ".join(table_column_names_output)
                for row_id, row in enumerate(table_content_values_output):
                    table_output += f" [ROW] {row_id}: " + " | ".join(row)

        if self.config.training.training_type == "table_decomposition":
            return question, table, table_output
        else:
            return input_text, table, output_text

    
    def _process_dataset(self):
        """Pre-process every record in order and return three parallel lists.

        Returns:
            ``(text_input, table, text_output)`` lists, element ``i`` coming from record ``i``.
        """

        # processed_data = Parallel(n_jobs = 1)(
        #     delayed(self._process_one_sample)(data, i) for i, data in tqdm(enumerate(self.dataset[self.data_type]), position = 0, leave = True, total = len(self.dataset[self.data_type])) if i < 1000
        # )

        processed_data = []
        for i, data in tqdm(enumerate(self.dataset), position = 0, leave = True, total = len(self.dataset)):
            processed_data.append(self._process_one_sample(data, i))


        text_input = [x[0] for x in processed_data]
        table = [x[1] for x in processed_data]
        text_output = [x[2] for x in processed_data]

        return text_input, table, text_output

    def __len__(self):
        """Number of pre-processed samples."""
        return len(self.text_input)


    def __getitem__(self, index):
        """Tokenize sample ``index`` and build its labels.

        Returns (tensors ``squeeze``-d):
            * encoder-decoder: ``(input_ids, attention_mask, token_type_ids, decoder_input_ids,
              [row_ids, col_ids,] labels)`` — row/col ids only with ``use_row_col_ids``.
            * decoder-only, generation training types: ``(input_ids, attention_mask,
              token_type_ids, [position_ids,] inference_input_ids, inference_attention_mask,
              actual_output_ids, labels)`` — position ids only with ``use_position_ids``.
            * decoder-only, other training types: ``(input_ids, attention_mask, token_type_ids,
              [position_ids,] labels)``.
        Label positions set to ``-100`` match PyTorch ``CrossEntropyLoss``'s default
        ``ignore_index`` (presumably the loss used by the model).

        NOTE(review): returns ``None`` for any other ``config.model.type``; the tokenizer must
        be configured to return ``token_type_ids``.
        """

        
        # NOTE: Currently the implementation of row embeddings, column embeddings and segment embeddings is available for encode-decoder models

        # NOTE: Permute the rows and columns randomly
        # self.table[index] = self.table[index].sample(frac = 1, axis = 1)

        if self.config.model.type == "encoder-decoder":
            if self.config.model.use_table:
                tokenized_input = self._tokenize(self.text_input[index], self.table[index], max_length = self.config.tokenizer.input_max_length)
            else:
                tokenized_input = self._tokenize(self.text_input[index], max_length = self.config.tokenizer.input_max_length)

            if self.config.training.training_type == "description_generation" or self.config.training.training_type == "column_reasoning" \
                  or self.config.training.training_type == "table_question_answering" or self.config.training.training_type == "table_decomposition" \
                    or self.config.training.training_type == "table_reasoning":
                tokenized_output = self._tokenize(self.text_output[index], max_length = self.config.tokenizer.output_max_length)
                labels = tokenized_output["input_ids"][0].clone()
                
                # Target already starts with BOS: shift labels left by one (labels[-1] keeps its
                # old value). Otherwise shift the decoder input right and put BOS at position 0
                # (its last token is dropped).
                if labels[0] == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"]):
                    labels[:-1] = labels[1:].clone()
                else:
                    tokenized_output["input_ids"][0][1:] = tokenized_output["input_ids"][0][:-1].clone()
                    tokenized_output["input_ids"][0][0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"])

                # Padding positions do not contribute to the loss.
                labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["pad_token"])] = -100
                # labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["sep_token"])] = -100
                # labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"])] = -100
                # labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])] = -100

            elif self.config.training.training_type == "masked_language_modelling":

                # NOTE(review): `self._whole_word_mask`, `self.tokenized_text` and (in this branch)
                # `tokenized_output` are not defined, so this path cannot run as written.
                mask_labels, desc_idx = self._whole_word_mask(self.tokenized_text[index])
                mask_labels = torch.nonzero(mask_labels, as_tuple = True)[0] + 2

                # Select the elements from the original tensor based on the random indices
                mask_labels = mask_labels[mask_labels < self.config.tokenizer.input_max_length]
                if mask_labels.size()[0] >= self.config.data.masked_gen_length // 2:
                    mask_labels = mask_labels[:self.config.data.masked_gen_length // 2]

                tokenized_output['input_ids'] = torch.ones(1, self.config.data.masked_gen_length, dtype = torch.long) * self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["sep_token"])
                tokenized_output["input_ids"][0][1:2*mask_labels.size()[0]:2] = tokenized_input["input_ids"][0][mask_labels]
                tokenized_output["input_ids"][0][2*mask_labels.size()[0] + 1:] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["pad_token"])
                tokenized_output["input_ids"][0][0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"])

                tokenized_input["input_ids"][0][mask_labels] = self.tokenizer.mask_token_id

                labels = tokenized_output["input_ids"][0].clone()
                labels[:-1] = labels[1:].clone()
                labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["pad_token"])] = -100
                labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["sep_token"])] = -100

            # NOTE(review): `_get_row_ids` / `_get_col_ids` are not defined on this class.
            if self.config.tokenizer.use_row_col_ids:
                tokenized_text = self.tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"].squeeze(0))
                tokenized_input["row_ids"] = self._get_row_ids(tokenized_text = tokenized_text)
                tokenized_input["col_ids"] = self._get_col_ids(tokenized_text = tokenized_text)


        # Tokenizers of decoder only models do not add start token, add them explicitly
        elif self.config.model.type == "decoder-only":
            tokenized_output = {}
            if self.config.training.training_type == "description_generation" or self.config.training.training_type == "column_reasoning" \
                  or self.config.training.training_type == "table_question_answering" or self.config.training.training_type == "table_decomposition" \
                    or self.config.training.training_type == "table_reasoning":
                
                # Training sequence = prompt + sep + target; inference sequence = prompt only.
                if self.config.model.use_table:
                    tokenized_input = self._tokenize(self.text_input[index], self.table[index], max_length = self.config.tokenizer.input_max_length, text_output = self.text_output[index])
                    inference_tokenized_input = self._tokenize(self.text_input[index], self.table[index], max_length = self.config.tokenizer.input_max_length)
                else:
                    tokenized_input = self._tokenize(self.text_input[index], max_length = self.config.tokenizer.input_max_length, text_output = self.text_output[index])
                    inference_tokenized_input = self._tokenize(self.text_input[index], max_length = self.config.tokenizer.input_max_length)

                # Attention mask is 0 before the first EOS and 1 from it onwards.
                idx = (inference_tokenized_input["input_ids"][0] == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])).nonzero(as_tuple = True)[0]
                if len(idx) != 0:
                    idx = idx[0]
                    # NOTE(review): self-assignment, has no effect.
                    inference_tokenized_input["input_ids"][0] = inference_tokenized_input["input_ids"][0]
                    inference_tokenized_input["attention_mask"][0][:idx] = 0
                    inference_tokenized_input["attention_mask"][0][idx:] = 1

                # Left-pad the inference ids with EOS up to input_max_length.
                # NOTE(review): writing back into row 0 only works if the row already has
                # input_max_length tokens, in which case padded_input equals the original ids —
                # so this is effectively a no-op (or a shape error for shorter rows).
                padded_input = torch.ones(self.config.tokenizer.input_max_length, dtype = torch.long) * self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])
                padded_input[self.config.tokenizer.input_max_length - inference_tokenized_input["input_ids"][0].shape[0]:] = inference_tokenized_input["input_ids"][0]
                inference_tokenized_input["input_ids"][0] = padded_input

                labels = tokenized_input["input_ids"][0].clone()
                actual_output_ids = self._tokenize(self.text_output[index], max_length = self.config.tokenizer.output_max_length)["input_ids"].squeeze()

                # Keep only the target span (between the first and second sep tokens) as labels;
                # with no sep token all labels are ignored except position 0, which is set to EOS.
                indices = (labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["sep_token"])).nonzero(as_tuple = True)[0]
                if len(indices) >= 2:
                    out_start, out_end = indices[0] + 1, indices[1]
                    labels[:out_start], labels[out_end:] = -100, -100
                elif len(indices) == 1:
                    out_start = indices[0] + 1
                    labels[:out_start] = -100
                else:
                    labels[:] = -100
                    labels[0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])


                # Shift the training input right by one and insert BOS at position 0.
                tokenized_input["input_ids"][0][1:] = tokenized_input["input_ids"][0].clone()[:-1]
                tokenized_input["input_ids"][0][0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"])

            elif self.config.training.training_type == "masked_language_modelling":

                if self.config.model.use_table:
                    tokenized_input = self._tokenize(self.text_input[index], self.table[index], max_length = self.config.tokenizer.input_max_length)
                else:
                    tokenized_input = self._tokenize(self.text_input[index], max_length = self.config.tokenizer.input_max_length)

                tokenized_input["input_ids"][0][1:] = tokenized_input["input_ids"][0].clone()[:-1]
                tokenized_input["input_ids"][0][0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"])

                # NOTE(review): `_whole_word_mask` / `tokenized_text` are not defined on this class.
                mask_labels, desc_idx = self._whole_word_mask(self.tokenized_text[index])
                mask_labels = torch.nonzero(mask_labels, as_tuple = True)[0] + 2

                # Select the elements from the original tensor based on the random indices
                mask_labels = mask_labels[mask_labels < self.config.tokenizer.input_max_length]
                if mask_labels.size()[0] >= self.config.data.masked_gen_length // 2:
                    mask_labels = mask_labels[:self.config.data.masked_gen_length // 2]

                # Masked-token targets are written from the 4th EOS position onwards; with fewer
                # than four EOS tokens the sample gets no masking targets.
                eos_indices = (tokenized_input["input_ids"][0] == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])).nonzero(as_tuple = True)[0]
                if len(eos_indices) < 4:
                    # NOTE: No masking possible for this                    
                    labels = torch.ones(tokenized_input["input_ids"].shape[1], dtype = torch.long) * (-100)
                    labels[0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])
                else:

                    out_start = eos_indices[3]
                    mask_labels = mask_labels[:(self.config.tokenizer.input_max_length - out_start) // 2]

                    labels = torch.ones(tokenized_input["input_ids"].shape[1], dtype = torch.long) * (-100)
                    labels[out_start:out_start + 2*mask_labels.size()[0]:2] = tokenized_input["input_ids"][0][mask_labels]
                    labels[:-1] = labels[1:].clone()

                    tokenized_input["input_ids"][0][mask_labels] = self.tokenizer.mask_token_id

                    labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["pad_token"])] = -100
                    labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["sep_token"])] = -100
                    labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["bos_token"])] = -100
                    labels[labels == self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])] = -100

                    # NOTE: Discuss whether this is correct
                    labels[0] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.special_tokens_map["eos_token"])

        # NOTE: Row and column ids is implemented only for encoder-decoder models
        if self.config.model.type == "encoder-decoder":
            if self.config.tokenizer.use_row_col_ids:
                # NOTE(review): position_ids is computed but not returned in this branch.
                position_ids = torch.tensor([i for i in range(tokenized_input["input_ids"].shape[1])], dtype = torch.long)
                return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze(), \
                        tokenized_input["token_type_ids"].squeeze(), tokenized_output["input_ids"].squeeze(), tokenized_input["row_ids"].squeeze(), tokenized_input["col_ids"].squeeze(), labels

            else:
                return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze(), \
                        tokenized_input["token_type_ids"].squeeze(), tokenized_output["input_ids"].squeeze(), labels
        
        elif self.config.model.type == "decoder-only":

            if self.config.training.training_type == "description_generation" or self.config.training.training_type == "column_reasoning" \
                  or self.config.training.training_type == "table_question_answering" or self.config.training.training_type == "table_decomposition" \
                    or self.config.training.training_type == "table_reasoning":
                if self.config.model.use_position_ids:
                    position_ids = torch.tensor([i for i in range(tokenized_input["input_ids"].shape[1])], dtype = torch.long)
                    return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze(), \
                            tokenized_input["token_type_ids"].squeeze(), position_ids, inference_tokenized_input["input_ids"].squeeze(), inference_tokenized_input["attention_mask"].squeeze(), actual_output_ids, labels

                else:
                    return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze(), \
                            tokenized_input["token_type_ids"].squeeze(), inference_tokenized_input["input_ids"].squeeze(), inference_tokenized_input["attention_mask"].squeeze(), actual_output_ids, labels
            
            else:
                if self.config.model.use_position_ids:
                    position_ids = torch.tensor([i for i in range(tokenized_input["input_ids"].shape[1])], dtype = torch.long)
                    return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze(), \
                            tokenized_input["token_type_ids"].squeeze(), position_ids, labels

                else:
                    return tokenized_input["input_ids"].squeeze(), tokenized_input["attention_mask"].squeeze(), \
                            tokenized_input["token_type_ids"].squeeze(), labels

    def collate_fn(self, items):
        """Placeholder that returns ``None``.

        The DataLoader in this notebook is built without ``collate_fn``, so PyTorch's default
        collation is what actually batches the tuples from ``__getitem__``.
        """
        pass

In [None]:
import pickle

In [None]:
dataset = load_dataset("wikitablequestions")

In [None]:
with open("configs/wiki_tq_reasoning/tapex.json", "r") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
train_dataset = WikiTQReasoningDataset(dataset = dataset["train"], config = config)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size = 8, shuffle = False, num_workers = config.system.num_workers)

In [None]:
from src import BartModelForTableReasoning

In [None]:
model = BartModelForTableReasoning(config)
model.load_state_dict(torch.load("logs/table_question_reasoning_tapex_bootstrapping_baseline_loss_calc_change/checkpoints/epoch=30.pt"))

In [None]:
model.to("cuda:0")

In [None]:
reason_generations = []

In [None]:
actual_reasons = []

In [None]:
for i, batch in tqdm(enumerate(train_dataloader), position = 0, leave = True, total = len(train_dataloader)):

    input_ids, attention_mask, _, _, labels = batch
    predicted_ids = model.model.generate(input_ids = input_ids.to("cuda:0"), attention_mask = attention_mask.to("cuda:0"), 
                                     max_new_tokens = config.tokenizer.output_max_length, num_beams = 3, early_stopping = True).detach().cpu()

    batch_predicted_reason = train_dataset.tokenizer.batch_decode(predicted_ids, skip_special_tokens = True)

    reason_generations.extend(batch_predicted_reason)
    actual_reasons.extend(train_dataset.tokenizer.batch_decode(labels[labels != -100], skip_special_tokens=True))

In [None]:
len(reason_generations)

In [None]:
len(train_dataset.text_output)

In [None]:
reason_generations = reason_generations[8:]

In [None]:
for i, (reason, actual_reason) in enumerate(zip(reason_generations, train_dataset.text_output)):

    if i > 3000 and i < 3010:
        print("Generated Reason: ", reason)
        print("Actual Reason: ", actual_reason)
        print("\n")

# Remove answer from reason

In [None]:
import pandas as pd

In [None]:
df = pd.read_csv("datasets/WikiTQReasoningData.csv")

In [None]:
reasons_without_answer = []

In [None]:
count = 0

In [None]:
for i in range(len(df)):
    reason = df["reason"][i]
    if "from table" in reason:
        reason = reason.split("from table")
        if reason[0] == "":
            reason = reason[1]
        else:
            reason = reason[0]

    elif "from the table" in reason:
        reason = reason.split("from the table")[0]
        if reason[0] == "":
            reason = reason[1]
        else:
            reason = reason[0]

    reasons_without_answer.append(reason.strip())

In [None]:
from copy import deepcopy

In [None]:
new_df = deepcopy(df)

In [None]:
new_df["reason"] = reasons_without_answer

In [None]:
new_df.to_csv("datasets/WikiTQReasoningDataWithoutAnswer.csv", index = False)

In [None]:
import pandas as pd

new_df = pd.read_csv("datasets/WikiTQReasoningDataWithoutAnswer.csv")

In [None]:
from data import WikiTQReasoningWithoutAnswerDataset

In [None]:
import json
from utils import process_config

In [None]:
with open("configs/wiki_tq_reasoning/t5.json", "rb") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
dataset = pd.read_csv(config.data.data_path)
train_dataset = WikiTQReasoningWithoutAnswerDataset(dataset, config)
tokenizer = train_dataset.tokenizer

In [None]:
train_dataset.text_output

In [None]:
import numpy as np
for i, text in enumerate(train_dataset.text_output):
    if not isinstance(text, str):
        print(i, text)

In [None]:
df = pd.read_csv("datasets/WikiTQReasoningData.csv")

In [None]:
df["reason"][46]

# Generate reasons on SequentialQA using Flan-T5-XL

In [None]:
from src import T5ModelForTableReasoning
from data import SequentialQADataset
from utils import process_config
import json
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [None]:
from datasets import load_dataset

In [None]:
with open("configs/wiki_tq_reasoning/t5.json", "rb") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
config.training.training_type = "table_question_answering"

In [None]:
dataset = load_dataset("msr_sqa")

In [None]:
train_dataset = SequentialQADataset(dataset = dataset, config = config, data_type = "test")

In [None]:
model = T5ModelForTableReasoning(config)
model.load_state_dict(torch.load("logs/table_question_reasoning_flan_t5_xl_reason_with_answer_rerun/checkpoints/epoch=50.pt", map_location = "cpu"))

In [None]:
model.to("cuda:0")

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size = 16, shuffle = False, num_workers = config.system.num_workers)

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from tqdm import tqdm

In [None]:
reason_generations = []

In [None]:
for i, batch in tqdm(enumerate(train_dataloader), position = 0, leave = True, total = len(train_dataloader)):

    input_ids, attention_mask, _, _, labels = batch
    predicted_ids = model.model.generate(input_ids = input_ids.to("cuda:0"), attention_mask = attention_mask.to("cuda:0"), 
                                     max_new_tokens = config.tokenizer.output_max_length, num_beams = 3, early_stopping = True).detach().cpu()

    batch_predicted_reason = train_dataset.tokenizer.batch_decode(predicted_ids, skip_special_tokens = True)

    reason_generations.extend(batch_predicted_reason)

In [None]:
import pickle
with open("datasets/test_seq_qa_reason_without_answer_flant5.pkl", "wb") as f:
    pickle.dump(reason_generations, f)

In [None]:
reason_generations[2229]

In [None]:
len(reason_generations)

In [None]:
question = train_dataset.text_input[13]
table = train_dataset.table[13]
reason = reason_generations[13]

print(question, end = "\n\n")
print(reason, end = "\n\n")
print(table)

# Generate reasons on FetaQA using Flan-T5-XL

In [None]:
from src import T5ModelForTableReasoning
from data import FetaQADataset
from utils import process_config
import json
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [None]:
from datasets import load_dataset

In [None]:
with open("configs/wiki_tq_reasoning/t5.json", "rb") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
config.training.training_type = "descriptive_table_question_answering"

In [None]:
dataset = load_dataset("DongfuTingle/FeTaQA")

In [None]:
train_dataset = FetaQADataset(dataset = dataset, config = config, data_type = "test")

In [None]:
model = T5ModelForTableReasoning(config)
model.load_state_dict(torch.load("/datadrive/tabllm/logs/table_question_reasoning_flan_t5_xl_reason_with_answer_rerun/checkpoints/epoch=50.pt", map_location = "cpu"))

In [None]:
model.to("cuda:0")

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size = 16, shuffle = False, num_workers = config.system.num_workers)

In [None]:
import os
from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
reason_generations = []

In [None]:
for i, batch in tqdm(enumerate(train_dataloader), position = 0, leave = True, total = len(train_dataloader)):

    input_ids, attention_mask, _, _, labels = batch
    predicted_ids = model.model.generate(input_ids = input_ids.to("cuda:0"), attention_mask = attention_mask.to("cuda:0"), 
                                     max_new_tokens = config.tokenizer.output_max_length, num_beams = 3, early_stopping = True).detach().cpu()

    batch_predicted_reason = train_dataset.tokenizer.batch_decode(predicted_ids, skip_special_tokens = True)

    reason_generations.extend(batch_predicted_reason)

In [None]:
import pickle
with open("datasets/test_feta_qa_reason_without_answer_flant5.pkl", "wb") as f:
    pickle.dump(reason_generations, f)

In [None]:
len(reason_generations)

In [None]:
text_input = train_dataset.text_input[1000]
table = train_dataset.table[1000]
answer = train_dataset.text_output[1000]
reason = reason_generations[1000]

print(text_input, end = "\n\n")
print(answer, end = "\n\n")
print(reason, end = "\n\n")
print(table, end = "\n\n")