In [1]:
import os
import sys

from dotenv import load_dotenv

# Project root: the parent of the current working directory (presumably the
# notebook's own folder when launched from Jupyter — confirm if run elsewhere).
root_dir = os.path.abspath("..")

# Make project-local modules importable. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate sys.path entry.
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Load environment variables from <root>/.env. load_dotenv returns False when it
# set no variables (e.g. the file is missing); keeping the call as the cell's last
# expression surfaces that result in the output.
dotenv_path = os.path.join(root_dir, ".env")
load_dotenv(dotenv_path)

False

In [None]:
from pathlib import Path
import math
import polars as pl
import numpy as np
from sentence_transformers import SentenceTransformer

In [None]:
# All inputs and outputs of this notebook live under <root>/data.
DATA_DIR = os.path.join(root_dir, "data")

# Source table containing the text column to embed.
PARQUET_PATH = os.path.join(DATA_DIR, "processed_flight_features_test.parquet")

# Intermediate per-batch .npz shards, and the parquet splits assembled from them.
OUTPUT_DIR = os.path.join(DATA_DIR, "embedded_flight_feature_lite_test")
PARQUET_OUT_DIR = os.path.join(DATA_DIR, "embedded_flight_feature_lite_parquet_test")

# Which column of PARQUET_PATH to embed, and the sentence-transformers model used.
COL_NAME, MODEL_NAME = "flight_text", "all-MiniLM-L6-v2"
MODEL_NAME = "all-MiniLM-L6-v2"

In [None]:
# Lazy scan of the text column with a 0-based "row_id" index; only the row
# count is actually materialized in this cell.
scan = (
    pl.scan_parquet(PARQUET_PATH)
    .select([COL_NAME])
    .with_row_index("row_id")
)
row_count = scan.select(pl.len()).collect(engine="streaming").item()
print(f"[INFO] Total rows: {row_count}")

In [None]:
# Instantiate the sentence-transformers encoder named in the config cell.
model = SentenceTransformer(model_name_or_path=MODEL_NAME)

In [None]:
BATCH_SIZE = 128           # rows per model.encode call, and per output .npz shard
READ_CHUNK_SIZE = 8192     # rows read from parquet per iteration; tune this based on memory
ROW_COUNT = row_count

# np.savez_compressed does not create directories; ensure the shard dir exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Remove shards left by a previous run. If that run produced more shards (e.g. a
# smaller BATCH_SIZE), the leftovers would be globbed by the next cell and
# duplicate row_ids in the final table.
for stale_shard in Path(OUTPUT_DIR).glob("embeddings_part*.npz"):
    stale_shard.unlink()

start = 0
chunk_idx = 0

while start < ROW_COUNT:
    end = min(start + READ_CHUNK_SIZE, ROW_COUNT)
    print(f"[INFO] Loading rows {start} to {end} of {ROW_COUNT}")

    # Slice the lazy scan to rows [start, end) so only that window is read,
    # rather than indexing and filtering the whole file on every iteration.
    # The row index is offset by `start`, so row_id is still the absolute
    # 0-based row number in PARQUET_PATH.
    df_block = (
        pl.scan_parquet(PARQUET_PATH)
          .select([COL_NAME])
          .slice(start, end - start)
          .with_row_index("row_id", offset=start)
          .collect(engine="streaming")
    )

    texts = df_block[COL_NAME].to_list()
    row_ids = df_block["row_id"].to_list()

    # Encode and persist one shard per BATCH_SIZE rows; each shard stores the
    # row ids alongside their (L2-normalized) embeddings.
    for i in range(0, len(texts), BATCH_SIZE):
        subtexts = texts[i:i + BATCH_SIZE]
        subids = row_ids[i:i + BATCH_SIZE]
        emb = model.encode(subtexts, batch_size=BATCH_SIZE,
                           show_progress_bar=False,
                           convert_to_numpy=True,
                           normalize_embeddings=True)
        out_file = os.path.join(OUTPUT_DIR, f"embeddings_part{chunk_idx:05d}.npz")
        np.savez_compressed(out_file, row_ids=np.array(subids), embeddings=emb)
        print(f"[INFO] Saved {len(subtexts)} embeddings to {out_file}")
        chunk_idx += 1

    start = end

In [None]:
# Regroup the per-batch .npz shards into at most NUM_OUTPUTS parquet splits.
NUM_OUTPUTS = 15

# Shard names are zero-padded, so lexicographic order matches write order.
npz_files = sorted(Path(OUTPUT_DIR).glob("embeddings_part*.npz"))
total_files = len(npz_files)
print(f"[INFO] Found {total_files} chunk files")

# Ensure the target directory exists before writing parquet files into it.
os.makedirs(PARQUET_OUT_DIR, exist_ok=True)
# Remove splits from an earlier run: with a different NUM_OUTPUTS or shard count,
# leftover merged_part*.parquet files would be picked up by the merge cell.
for stale_split in Path(PARQUET_OUT_DIR).glob("merged_part*.parquet"):
    stale_split.unlink()

# how many files per parquet group (ceil); trailing splits may end up empty
files_per_split = math.ceil(total_files / NUM_OUTPUTS)

for split_idx in range(NUM_OUTPUTS):
    start = split_idx * files_per_split
    end = min((split_idx + 1) * files_per_split, total_files)
    split_files = npz_files[start:end]

    if not split_files:  # no files left
        break

    print(f"[INFO] Processing split {split_idx+1}/{NUM_OUTPUTS}: files {start} to {end-1} ({len(split_files)} files)")

    all_tables = []

    for f in split_files:
        # The context manager closes the NpzFile's underlying file handle;
        # indexing an NpzFile loads the array into memory, so both arrays stay
        # usable after the `with` block exits.
        with np.load(f) as data:
            row_ids = data["row_ids"]
            embeddings = data["embeddings"]
        dim = embeddings.shape[1]

        # One column per embedding dimension: emb_0 ... emb_{dim-1}.
        embed_cols = {f"emb_{i}": embeddings[:, i] for i in range(dim)}

        df = pl.DataFrame({
            "row_id": row_ids,
            **embed_cols
        })
        all_tables.append(df)

    merged_df = pl.concat(all_tables, how="vertical").sort("row_id")

    out_path = os.path.join(PARQUET_OUT_DIR, f"merged_part{split_idx:02d}.parquet")
    merged_df.write_parquet(out_path)
    print(f"[INFO] Saved split {split_idx+1} -> {out_path}")

In [None]:
# Collect the parquet splits written by the previous cell (adjust pattern if needed).
parquet_files = sorted(Path(PARQUET_OUT_DIR).glob("merged_part*.parquet"))
print(f"[INFO] Found {len(parquet_files)} parquet files")

# pl.concat on an empty list raises an error that doesn't say why; fail with
# context instead so a missing/empty output directory is obvious.
if not parquet_files:
    raise FileNotFoundError(f"No merged_part*.parquet files found in {PARQUET_OUT_DIR}")

all_tables = []
for f in parquet_files:
    print(f"[INFO] Loading {f}")
    all_tables.append(pl.read_parquet(f))

# Merge into one DataFrame ordered by row_id.
merged_df = pl.concat(all_tables, how="vertical").sort("row_id")

# Invariants: row_ids are unique and their count matches the source row count
# computed at the top of the notebook (catches stale or missing shards).
assert merged_df["row_id"].n_unique() == merged_df.height, "duplicate row_id values in merged embeddings"
assert merged_df.height == row_count, f"expected {row_count} rows, got {merged_df.height}"

print(f"[INFO] Final merged shape: {merged_df.shape}")

In [None]:
# Persist the consolidated embeddings table under <root>/data.
out_path = os.path.join(
    root_dir,
    "data",
    "embed_flight_feature_test.parquet",
)
merged_df.write_parquet(file=out_path)