In [36]:
from transformers import AutoTokenizer, RagRetriever, RagModel
import transformers
import torch
import numpy as np
import pandas as pd
from scipy.io.arff import loadarff
from datasets import load_dataset, concatenate_datasets
from langchain import HuggingFacePipeline
import os, os.path as osp
device = "cuda:0" if torch.cuda.is_available() else "cpu"


In [5]:
root_dir = os.getcwd()
data_dir = osp.join(root_dir, "data")
raw_data_dir = osp.join(data_dir, "raw")
dataset_path = osp.join(raw_data_dir, "bank")

In [3]:
data = load_dataset("wikitablequestions", trust_remote_code=True)

In [13]:
data = concatenate_datasets([data["train"], data["validation"], data["test"]])

In [43]:
def stat_tables(dataset):
    num_tabs = len(dataset)
    num_cols = [len(d["table"]["header"]) for d in data]
    num_rows = [len(d["table"]["rows"]) for d in data]
    
    avg_cols, avg_rows = np.mean(num_cols), np.mean(num_rows)
    min_cols, min_rows = np.min(num_cols), np.min(num_rows)
    max_cols, max_rows = np.max(num_cols), np.max(num_rows)
    
    print(f"Number of tables: {num_tabs}")
    print(f"Average number of cols/rows: {avg_cols:.2f}/{avg_rows:.2f}")
    print(f"Max number of cols/rows: {max_cols:.2f}/{max_rows:.2f}")
    print(f"Min number of cols/rows: {min_cols:.2f}/{min_rows:.2f}")
    
    

In [45]:
stat_tables(data)

Number of tables: 18496
Average number of cols/rows: 6.35/25.38
Max number of cols/rows: 25.00/753.00
Min number of cols/rows: 3.00/4.00


In [49]:
for d in data:
    table = d["table"]
    if len(table["header"]) > 20:
        print(d)
        break
    
    # elif len(table["rows"]) > 300:
    #     print(table["header"])
    #     print(table["rows"])

{'id': 'nt-223', 'question': 'when was the benetton b198 chassis used?', 'answers': ['1998'], 'table': {'header': ['Year', 'Entrant', 'Chassis', 'Engine', 'Tyres', 'Drivers', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', 'Points', 'WCC'], 'rows': [['1998', 'Mild Seven Benetton', 'Benetton B198', 'GC37-01 V10', 'B', '', 'AUS', 'BRA', 'ARG', 'SMR', 'ESP', 'MON', 'CAN', 'FRA', 'GBR', 'AUT', 'GER', 'HUN', 'BEL', 'ITA', 'LUX', 'JPN', '', '33', '5th'], ['1998', 'Mild Seven Benetton', 'Benetton B198', 'GC37-01 V10', 'B', 'Giancarlo Fisichella', 'Ret', '6', '7', 'Ret', 'Ret', '2', '2', '9', '5', 'Ret', '7', '8', 'Ret', '8', '6', '8', '', '33', '5th'], ['1998', 'Mild Seven Benetton', 'Benetton B198', 'GC37-01 V10', 'B', 'Alexander Wurz', '7', '4', '4', 'Ret', '4', 'Ret', '4', '5', '4', '9', '11', '16', 'Ret', 'Ret', '7', '9', '', '33', '5th'], ['1999', 'Mild Seven Benetton', 'Benetton B199', 'FB01 V10', 'B', '', 'AUS', 'BRA', 'SMR', 'MON', 'ESP', '

In [51]:
for d in data:
    table = d["table"]
    if "EUR" in table["header"]:
        print(d)
        break
    for row in table["rows"]:
        if "EUR" in row:
            print(d)
            break

{'id': 'nt-223', 'question': 'when was the benetton b198 chassis used?', 'answers': ['1998'], 'table': {'header': ['Year', 'Entrant', 'Chassis', 'Engine', 'Tyres', 'Drivers', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', 'Points', 'WCC'], 'rows': [['1998', 'Mild Seven Benetton', 'Benetton B198', 'GC37-01 V10', 'B', '', 'AUS', 'BRA', 'ARG', 'SMR', 'ESP', 'MON', 'CAN', 'FRA', 'GBR', 'AUT', 'GER', 'HUN', 'BEL', 'ITA', 'LUX', 'JPN', '', '33', '5th'], ['1998', 'Mild Seven Benetton', 'Benetton B198', 'GC37-01 V10', 'B', 'Giancarlo Fisichella', 'Ret', '6', '7', 'Ret', 'Ret', '2', '2', '9', '5', 'Ret', '7', '8', 'Ret', '8', '6', '8', '', '33', '5th'], ['1998', 'Mild Seven Benetton', 'Benetton B198', 'GC37-01 V10', 'B', 'Alexander Wurz', '7', '4', '4', 'Ret', '4', 'Ret', '4', '5', '4', '9', '11', '16', 'Ret', 'Ret', '7', '9', '', '33', '5th'], ['1999', 'Mild Seven Benetton', 'Benetton B199', 'FB01 V10', 'B', '', 'AUS', 'BRA', 'SMR', 'MON', 'ESP', '

In [57]:
for row, col in data[0]["table"]["rows"].iterrow():
    print(row, col)
    break

AttributeError: 'list' object has no attribute 'iterrow'

In [67]:
unique_tables = []
unique_tnames = set()
for row in data:
    tname = row["table"]["name"]
    if tname not in unique_tnames:
        unique_tnames.add(tname)
        unique_tables.append(row["table"])

In [70]:
num_nodes,num_hyperedges = 0, 0
for table in unique_tables:
    num_nodes += len(table["header"]) * len(table["rows"])
    num_hyperedges += len(table["header"])  +len(table["rows"])

In [71]:
num_nodes, num_hyperedges

(363509, 71000)

In [75]:
row.keys()

dict_keys(['id', 'question', 'answers', 'table'])

In [5]:
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
class BipartiteData(Data):
    pass
x_s = torch.randn(2, 16)  # 2 nodes.
x_t = torch.randn(3, 16)  # 3 nodes.
edge_index = torch.tensor([
    [0, 0, 1, 1],
    [0, 1, 1, 2],
])

data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)

data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))



In [7]:
batch.edge_index

tensor([[0, 0, 1, 1, 3, 3, 4, 4],
        [0, 1, 1, 2, 3, 4, 4, 5]])