### Cell for Setup:

In [None]:
import os, sys
import pandas as pd
import torch
import numpy as np
import faiss
from tqdm.notebook import tqdm # For notebook-friendly progress bars
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from pathlib import Path
from pathlib import Path
import sys
# Work from the repository checkout so the relative MODEL_PATH and DATA_DIR
# below are resolved against it.
os.chdir("/mnt/c/GitHub/HEARTS") 
# Use current working directory (for notebooks)
# NOTE(review): because of the chdir above, Path().resolve() is the HEARTS
# checkout and `.parent` is the directory that *contains* it, so sdd_path
# below presumably lands outside the repository. Confirm whether `.parent`
# is still intended or predates the chdir.
PROJECT_DIR = Path().resolve().parent

# Add the correct parent folder of `sdd`
sdd_path = PROJECT_DIR / "scripts/starmie"

# Appended (not inserted first), so any `sdd` package already importable from
# an earlier sys.path entry takes precedence over this one.
sys.path.append(str(sdd_path))

# Import modules from sdd
from sdd.pretrain import load_checkpoint, inference_on_tables
from sdd.dataset import PretrainTableDataset

# Configuration
# Both paths are relative and therefore depend on the working directory set above.
MODEL_PATH = "checkpoints/starmie/santos/model_drop_col_tfidf_entity_column_0.pt"
DATA_DIR = "data/santos/datalake/"
# Forwarded to extract_all_column_embeddings as batch_size_for_inference.
EMBEDDING_BATCH_SIZE = 32


def get_all_tables(data_dir_path):
    """Load every ``.csv`` file directly inside *data_dir_path* as a DataFrame.

    Files are visited in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, and anything built from this dict's iteration order
    (for example positional ids in an index) would otherwise differ between
    runs or machines.

    Args:
        data_dir_path: Directory to scan. Sub-directories are not descended
            into; only names ending in ``.csv`` are considered.

    Returns:
        dict mapping each filename (suffix included) to its DataFrame. Files
        that fail to parse are reported on stdout and left out.
    """
    all_tables = {}
    for filename in sorted(os.listdir(data_dir_path)):
        if not filename.endswith(".csv"):
            continue
        file_path = os.path.join(data_dir_path, filename)
        try:
            # Only '\n' is treated as the row terminator.
            df = pd.read_csv(file_path, lineterminator='\n')
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the
            # whole load, so report it and move on.
            print(f"Warning: Skipping problematic file {filename}: {str(e)}")
        else:
            all_tables[filename] = df
    return all_tables

def extract_all_column_embeddings(model, unlabeled_dataset, tables_dict, batch_size_for_inference=32):
    """Embed every column of every table and return one record per column.

    Args:
        model: Model handed to ``inference_on_tables``; moved to the GPU when
            one is available (otherwise the CPU) and switched to eval mode.
        unlabeled_dataset: Passed to ``inference_on_tables`` as ``unlabeled``.
        tables_dict: Mapping of table name to DataFrame.
        batch_size_for_inference: Forwarded to ``inference_on_tables`` as
            ``batch_size``.

    Returns:
        List of dicts with keys ``"table_name"``, ``"column_name"`` and
        ``"embedding"`` (a NumPy array). Tables that are empty, yield no
        embeddings, or yield a different number of embeddings than they have
        columns are skipped with a printed message.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(target)
    print(device)
    model.to(device)
    model.eval()

    records = []
    print("Extracting column embeddings...")
    progress = tqdm(tables_dict.items(), desc="Processing tables for embeddings")
    for table_name, df in progress:
        if df.empty or len(df.columns) == 0:
            print(f"Skipping empty or column-less table: {table_name}")
            continue

        # NOTE(review): every call receives a single-table list, so the batch
        # size presumably has little effect here — confirm against
        # inference_on_tables.
        per_table_results = inference_on_tables(
            tables=[df],
            model=model,
            unlabeled=unlabeled_dataset,
            batch_size=batch_size_for_inference,
        )

        if not per_table_results or not per_table_results[0]:
            print(f"Warning: No embeddings returned for table {table_name}")
            continue

        column_vectors = per_table_results[0]

        # Embeddings are paired with columns by position, so the counts must agree.
        if len(df.columns) != len(column_vectors):
            print(f"Warning: Mismatch columns ({len(df.columns)}) vs embeddings ({len(column_vectors)}) for {table_name}. Skipping.")
            continue

        records.extend(
            {
                "table_name": table_name,
                "column_name": col_name,
                "embedding": np.array(column_vectors[position]),
            }
            for position, col_name in enumerate(df.columns)
        )
    return records

def build_faiss_index(all_column_data_list):
    """Build an inner-product FAISS index over L2-normalised column embeddings.

    Row ``i`` of the index always corresponds to ``all_column_data_list[i]``.

    Args:
        all_column_data_list: Sequence of dicts with keys ``"table_name"``,
            ``"column_name"`` and ``"embedding"`` (array-like, same length for
            every entry).

    Returns:
        ``(index, idx_to_info)`` where ``idx_to_info`` maps a FAISS row id to
        ``{"table_name": ..., "column_name": ...}``. Returns ``(None, None)``
        when there are no records or the embeddings have zero width.
    """
    if not all_column_data_list:
        return None, None

    # Fresh float32 array: normalize_L2 below works in place and must not
    # touch the callers' embedding arrays.
    embeddings = np.array(
        [col_data["embedding"] for col_data in all_column_data_list],
        dtype=np.float32,
    )
    if embeddings.size == 0:
        # Every embedding is empty, so there is nothing to index.
        return None, None

    # Force exactly one row per record. Scalar embeddings arrive as a 1-D
    # array of length N and must become (N, 1); collapsing them into a single
    # row would leave idx_to_info with more entries than the index has rows.
    embeddings = embeddings.reshape(len(all_column_data_list), -1)

    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # For cosine similarity (Inner Product on normalized vectors)

    # Normalize embeddings before adding to IndexFlatIP for cosine similarity
    faiss.normalize_L2(embeddings)
    index.add(embeddings)

    # Create a mapping from FAISS index to original column info
    idx_to_info = {
        i: {"table_name": col_data["table_name"], "column_name": col_data["column_name"]}
        for i, col_data in enumerate(all_column_data_list)
    }
    return index, idx_to_info


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects from the checkpoint file, so only load checkpoints from a trusted
# source.
ckpt = torch.load(MODEL_PATH, map_location=device, weights_only=False)
# load_checkpoint yields the model plus the dataset object that is later
# passed to inference_on_tables as `unlabeled`.
model, unlabeled_dataset = load_checkpoint(ckpt)
# extract_all_column_embeddings repeats these two calls itself; they are kept
# here so the model is ready even if that function is not the first consumer.
model.to(device)
model.eval()

# Load tables and extract all embeddings
all_tables_dict = get_all_tables(DATA_DIR)
all_column_data = extract_all_column_embeddings(model, unlabeled_dataset, all_tables_dict, EMBEDDING_BATCH_SIZE)
# Both values are None when no column embeddings were produced.
faiss_index, idx_to_info_map = build_faiss_index(all_column_data)

print(f"Setup complete. Loaded {len(all_tables_dict)} tables and {len(all_column_data)} columns.")
print(f"FAISS index built with {faiss_index.ntotal if faiss_index else 0} embeddings.")

cuda


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda
Extracting column embeddings...


Processing tables for embeddings:   0%|          | 0/550 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:07<00:00,  7.33s/it]
100%|██████████| 1/1 [00:06<00:00,  6.97s/it]
100%|██████████| 1/1 [00:05<00:00,  5.70s/it]
100%|██████████| 1/1 [00:05<00:00,  5.96s/it]
100%|██████████| 1/1 [00:06<00:00,  6.07s/it]
100%|██████████| 1/1 [00:08<00:00,  8.25s/it]
100%|██████████| 1/1 [00:09<00:00,  9.93s/it]
100%|██████████| 1/1 [00:06<00:00,  6.67s/it]
100%|██████████| 1/1 [00:06<00:00,  6.81s/it]
100%|██████████| 1/1 [00:08<00:00,  8.05s/it]
100%|██████████| 1/1 [00:05<00:00,  5.34s/it]
100%|██████████| 1/1 [00:12<00:00, 12.64s/it]
100%|██████████| 1/1 [00:07<00:00,  7.84s/it]
100%|██████████| 1/1 [00:04<00:00,  4.78s/it]
100%|██████████| 1/1 [00:03<00:00,  3.85s/it]
100%|██████████| 1/1 [00:02<00:00,  2.46s/it]
100%|██████████| 1/1 [00:02<00:00,  2.64s/it]
100%|██████████| 1/1 [00:02<00:00,  2.24s/it]
100%|██████████| 1/1 [00:02<00:00,  2.54s/it]
100%|██████████| 1/1 [00:18<00:00, 18.29s/it]
100%|██████████| 1/1 [00:17<00:00, 17.14s/it]
100%|██████████| 1/1 [00:06<00:00,

### Cell for Query and Search

In [None]:
# --- User Input ---
target_query_column_name = "ticket_created_date_time"
target_query_table_name = "311_calls_historic_data_0.csv"  # Set to None if not specific
top_n_to_visualize = 10

# --- Find Query Embedding ---
# Take the first record whose column name matches and, when a table name is
# given, whose table name matches as well.
query_embedding_vector = None
query_column_info = next(
    (
        candidate
        for candidate in all_column_data
        if candidate["column_name"] == target_query_column_name
        and (
            target_query_table_name is None
            or candidate["table_name"] == target_query_table_name
        )
    ),
    None,
)
if query_column_info is None:
    raise ValueError("Query column not found.")
# Single-row float32 matrix, the shape handed to the index search below.
query_embedding_vector = query_column_info["embedding"].astype('float32').reshape(1, -1)

# Normalise a copy so the raw query vector stays available to later cells.
normalized_query_embedding = query_embedding_vector.copy()
faiss.normalize_L2(normalized_query_embedding)

# --- Perform FAISS Search ---
num_to_retrieve = top_n_to_visualize + 1  # +1 to handle self-match
distances, indices = faiss_index.search(normalized_query_embedding, num_to_retrieve)

# --- Process Search Results ---
search_results_list = []
for faiss_idx, similarity_score in zip(indices[0], distances[0]):
    # -1 is skipped; presumably the id FAISS uses for an unfilled result slot.
    if faiss_idx == -1:
        continue
    retrieved_col_info = idx_to_info_map[faiss_idx]
    same_table = retrieved_col_info["table_name"] == query_column_info["table_name"]
    same_column = retrieved_col_info["column_name"] == query_column_info["column_name"]
    if same_table and same_column:
        # The query column itself is not reported as a result.
        continue
    search_results_list.append({**retrieved_col_info, "similarity": similarity_score})
    if len(search_results_list) >= top_n_to_visualize:
        break

print(f"Query: {query_column_info['table_name']}.{query_column_info['column_name']}")
for res in search_results_list:
    print(f"  Similar: {res['table_name']}.{res['column_name']} (Score: {res['similarity']:.4f})")

### Cell for Bar Chart Visualization:

In [None]:
# One horizontal bar per retrieved column: the label is the table name with
# the column name on a second line, the bar length is its similarity score.
labels = []
scores = []
for res in search_results_list:
    labels.append(f"{res['table_name']}\n.{res['column_name']}")
    scores.append(res['similarity'])

plt.figure(figsize=(10, 8))
sns.barplot(x=scores, y=labels, palette="viridis")

# Titles and axis text.
plt.title(f"Top {len(search_results_list)} columns similar to '{query_column_info['table_name']}.{query_column_info['column_name']}'", fontsize=14)
plt.xlabel("Similarity Score (Cosine Similarity)", fontsize=12)
plt.ylabel("Table.Column", fontsize=12)
for set_tick_font in (plt.xticks, plt.yticks):
    set_tick_font(fontsize=10)

plt.tight_layout()
plt.show()

### Cell for t-SNE Visualization:

In [None]:
# Project the query embedding and its nearest neighbours to 2-D with t-SNE.
# The three lists below are kept index-aligned: point i has embedding
# embeddings_for_tsne[i], annotation tsne_labels[i] and colour tsne_colors[i].
embeddings_for_tsne = [query_embedding_vector.flatten()]  # Query embedding
tsne_labels = [f"QUERY:\n{query_column_info['table_name']}\n.{query_column_info['column_name']}"]
tsne_colors = ['red']  # Color for query

# Add top N similar columns (e.g., top 5 for clarity)
for res in search_results_list[:5]:
    # Search results carry only names and scores, so look the embedding up
    # again by (table, column).
    for col_data_item in all_column_data:
        if col_data_item['table_name'] == res['table_name'] and col_data_item['column_name'] == res['column_name']:
            embeddings_for_tsne.append(col_data_item['embedding'])
            tsne_labels.append(f"Similar:\n{res['table_name']}\n.{res['column_name']}\n(Score: {res['similarity']:.2f})")
            tsne_colors.append('green')  # Color for similar items
            break

# Optional: Add a few random "other" columns for contrast
# Ensure these are not the query or already in top similar
# For example:
# count = 0
# for col_data_item in np.random.permutation(all_column_data): # Shuffle to get random items
#     is_query = col_data_item['table_name'] == query_column_info['table_name'] and col_data_item['column_name'] == query_column_info['column_name']
#     is_in_results = any(r['table_name'] == col_data_item['table_name'] and r['column_name'] == col_data_item['column_name'] for r in search_results_list[:5])
#     if not is_query and not is_in_results and count < 3:
#         embeddings_for_tsne.append(col_data_item['embedding'])
#         tsne_labels.append(f"Other:\n{col_data_item['table_name']}\n.{col_data_item['column_name']}")
#         tsne_colors.append('blue') # Color for other items
#         count += 1


embeddings_array_for_tsne = np.array(embeddings_for_tsne).astype('float32')
n_samples = embeddings_array_for_tsne.shape[0]

if n_samples > 1:  # t-SNE needs at least 2 samples
    # Perplexity must be less than the number of samples. min() already
    # yields n_samples - 1 for small inputs, so no further adjustment is
    # needed.
    perplexity_value = min(30, n_samples - 1)

    # The iteration count is left at the library default (1000) rather than
    # passed explicitly: the keyword was renamed from `n_iter` to `max_iter`
    # in newer scikit-learn releases, so naming either one breaks on some
    # versions.
    # NOTE(review): the hasattr probe below does not appear to match any
    # attribute of TSNE in released scikit-learn, so learning_rate presumably
    # always resolves to 200.0 — confirm before relying on 'auto'.
    tsne = TSNE(
        n_components=2,
        random_state=42,
        perplexity=perplexity_value,
        learning_rate='auto' if hasattr(TSNE, '_EXPLORATION_EARLY_EXAGGERATION_DECAY') else 200.0,
    )
    embeddings_2d = tsne.fit_transform(embeddings_array_for_tsne)

    plt.figure(figsize=(14, 10))
    for i in range(embeddings_2d.shape[0]):
        # No per-point legend label: the legend is built from explicit
        # handles below, so such labels would never be displayed.
        plt.scatter(embeddings_2d[i, 0], embeddings_2d[i, 1], color=tsne_colors[i], s=150, alpha=0.7)
        plt.text(embeddings_2d[i, 0] + 0.03, embeddings_2d[i, 1] + 0.03, tsne_labels[i], fontsize=9)

    plt.title(f"t-SNE visualization of embeddings related to '{query_column_info['table_name']}.{query_column_info['column_name']}'", fontsize=14)
    plt.xlabel("t-SNE Component 1", fontsize=12)
    plt.ylabel("t-SNE Component 2", fontsize=12)
    # Create a custom legend
    handles = [plt.Line2D([0], [0], marker='o', color='w', label='Query', markersize=10, markerfacecolor='red'),
               plt.Line2D([0], [0], marker='o', color='w', label='Similar', markersize=10, markerfacecolor='green'),
               plt.Line2D([0], [0], marker='o', color='w', label='Other', markersize=10, markerfacecolor='blue')]
    if len(tsne_colors) > len(search_results_list[:5]) + 1:  # if 'other' samples were added
        plt.legend(handles=handles, title="Column Types", fontsize=10)
    else:
        plt.legend(handles=handles[:2], title="Column Types", fontsize=10)

    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()
else:
    print("Not enough unique embeddings to plot with t-SNE (need at least 2).")
