In [None]:
import os
import json
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split

# Identifier of the Text2SQL iteration whose parquet artefacts are loaded below.
run_name = "220905_250k_train_knn"

# Root directory containing the `sketch` checkout. Set TEXT2SQL_BASE_PATH to point
# elsewhere; otherwise fall back to /Users/jawaugh/labs if it exists, else /home/jawaugh.
base_path = os.environ.get("TEXT2SQL_BASE_PATH") or (
    "/Users/jawaugh/labs" if os.path.exists("/Users/jawaugh/labs") else "/home/jawaugh"
)

# Every artefact of a run sits in one directory as `<run_name>_<kind>.parquet`.
iterations_dir = os.path.join(base_path, "sketch/sketch/examples/Text2SQL_Iterations")
groundtruth_path = os.path.join(iterations_dir, f"{run_name}_groundtruth.parquet")
knn_path = os.path.join(iterations_dir, f"{run_name}_knn.parquet")
sketchpad_path = os.path.join(iterations_dir, f"{run_name}_sketchpad.parquet")
metrics_path = os.path.join(iterations_dir, f"{run_name}_metrics.parquet")

In [None]:
# Compute device for all tensors and layers in this notebook. Only the CPU is
# used; an Apple-silicon "mps" device could be chosen instead when
# torch.backends.mps.is_available() reports it, but that path is disabled.
_DEVICE_NAME = "cpu"
device = torch.device(_DEVICE_NAME)
print(device)

In [None]:

class BinaryKNNSketch(Dataset):
    """Map-style dataset pairing per-row metric and ground-truth tensors for one run.

    Both parquet files are read eagerly. Every column is cast to float32,
    missing values are replaced with the sentinel ``-1``, and the result is
    stored as a 2-D float32 tensor on the module-level ``device``. Row ``i`` of
    the metrics file is paired with row ``i`` of the ground-truth file, so both
    must have the same number of rows.

    Args:
        metrics_path: parquet file whose columns form the (optional) model input.
        groundtruth_path: parquet file of targets. The columns in
            ``GROUNDTRUTH_DROP_COLUMNS`` are dropped first; every remaining
            column must be castable to float32.
        use_metrics_as_input: when False (the default, and the previous
            hard-coded behaviour) ``__getitem__`` returns the ground-truth row as
            both input and target, i.e. an identity mapping (presumably a sanity
            check that the network can reproduce its targets). When True it
            returns ``(metrics_row, groundtruth_row)``.

    Raises:
        AssertionError: if the two files have different row counts.
    """

    # Ground-truth columns excluded from the target tensor (identifiers, strings,
    # path and union) so that the float32 cast only sees the remaining columns.
    GROUNDTRUTH_DROP_COLUMNS = ('left', 'right', 'left_string', 'right_string', 'path', 'union')

    def __init__(self, metrics_path, groundtruth_path, use_metrics_as_input=False):
        self.use_metrics_as_input = use_metrics_as_input
        self.metrics = self._frame_to_tensor(pd.read_parquet(metrics_path), device)
        groundtruth_frame = pd.read_parquet(groundtruth_path).drop(
            columns=list(self.GROUNDTRUTH_DROP_COLUMNS)
        )
        self.groundtruth = self._frame_to_tensor(groundtruth_frame, device)
        assert len(self.metrics) == len(self.groundtruth), (
            f"metrics has {len(self.metrics)} rows but groundtruth has {len(self.groundtruth)}"
        )

    @staticmethod
    def _frame_to_tensor(frame, device):
        """Cast all columns of ``frame`` to float32, fill NaNs with -1, return a tensor on ``device``."""
        numeric = frame.astype({c: 'float32' for c in frame.columns}).fillna(-1)
        return torch.tensor(numeric.to_numpy(), dtype=torch.float32, device=device)

    def __len__(self):
        """Number of paired rows."""
        return len(self.metrics)

    def __getitem__(self, idx):
        """Return ``(input_row, target_row)`` for row ``idx``; see ``use_metrics_as_input``."""
        inputs = self.metrics if self.use_metrics_as_input else self.groundtruth
        return inputs[idx], self.groundtruth[idx]

In [None]:
# Size of the held-out split and the seed that fixes which rows land in it, so
# the test set is identical on every Restart & Run All.
TEST_SIZE = 10_000
SPLIT_SEED = 0

dataset = BinaryKNNSketch(metrics_path, groundtruth_path)
assert len(dataset) > TEST_SIZE, f"need more than {TEST_SIZE} rows to split, got {len(dataset)}"

test_data, train_data = random_split(
    dataset,
    [TEST_SIZE, len(dataset) - TEST_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Basic fully connected network with 2 hidden layers.
class BinaryNN(torch.nn.Module):
    """MLP ``input_size -> hidden_size -> hidden_size -> output_size``.

    Hidden layers use ReLU; the output layer is linear (no activation). The
    model is fit with ``MSELoss`` against targets whose missing values are filled
    with ``-1``. A ReLU on the output would clamp every prediction to ``>= 0``,
    making those targets unreachable and zeroing the gradient of any output unit
    whose pre-activation is negative.

    All layers are created on the module-level ``device``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size, device=device)
        self.linear2 = torch.nn.Linear(hidden_size, hidden_size, device=device)
        self.linear3 = torch.nn.Linear(hidden_size, output_size, device=device)

    def forward(self, x):
        """Map a ``(..., input_size)`` tensor to ``(..., output_size)``."""
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        # Linear output: regression targets may be negative.
        return self.linear3(x)

In [None]:
# One optimiser step per 50k training rows; rows are reshuffled every epoch.
TRAIN_BATCH_SIZE = 50000
train_dataloader = DataLoader(dataset=train_data, shuffle=True, batch_size=TRAIN_BATCH_SIZE)

In [None]:
# Hyperparameters for the training run below.
HIDDEN_SIZE = 200
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
N_EPOCHS = 100
LOG_EVERY = 1   # print the loss every LOG_EVERY batches
MODEL_SEED = 0  # fixes weight initialisation and DataLoader shuffling order

torch.manual_seed(MODEL_SEED)

# Input/output widths are taken from the first (input_row, target_row) sample.
sample_x, sample_y = dataset[0]
model = BinaryNN(len(sample_x), HIDDEN_SIZE, len(sample_y))

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

model.train()
for epoch in range(N_EPOCHS):
    for batch_idx, (x, y) in enumerate(train_dataloader):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_EVERY == 0:
            # Scientific notation: a fixed 1-decimal format shows every loss < 0.05 as 0.0.
            print(f"{epoch:<4} {batch_idx:<5} {loss.item():24.6e}")

In [None]:
# Compare one held-out target row with the model's prediction for it.
SAMPLE_IDX = 22
x, y = test_data[SAMPLE_IDX]

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    y_pred = model(x)

print([f"{value.item():2.3f}" for value in y])
print([f"{value.item():2.3f}" for value in y_pred])

In [None]:
# First five rows of the raw ground-truth file, transposed so each column reads as a row.
groundtruth_preview = pd.read_parquet(groundtruth_path).head(5)
groundtruth_preview.transpose()

In [None]:
# Input tensor `x` of the sample selected above, one entry per element to 3 decimals.
formatted_inputs = [format(element.item(), "2.3f") for element in x]
print(formatted_inputs)

In [None]:
# # Sanity check (disabled): fit BinaryNN on a tiny synthetic dataset to prove the training loop can learn a known mapping
# class FakeData(Dataset):
#     def __init__(self):
#         self.x = torch.tensor([[1], [2], [3], [4], [5], [6]], dtype=torch.float32)
#         self.y = torch.tensor([[1, 5], [2, 4], [3, 3], [4, 2], [5, 1], [6, 0]], dtype=torch.float32)
    
#     def __len__(self):
#         return len(self.x)
    
#     def __getitem__(self, idx):
#         return self.x[idx], self.y[idx]

# loader = DataLoader(FakeData(), batch_size=2, shuffle=True)

# model = BinaryNN(1, 10, 2)
# criterion = torch.nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

# for epoch in range(100):
#     for i, batch in enumerate(loader):
#         x, y = batch
#         y_pred = model(x)
#         loss = criterion(y_pred, y)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         if i % 100 == 0:
#             print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}')