In [None]:
from dotenv import load_dotenv

# Load variables from a .env file (when one is found) into the process
# environment, so the os.environ lookups for the DATABASE_* settings in the
# following cell can succeed without exporting them in the shell.
load_dotenv()

In [None]:
from sqlalchemy.engine import URL
import os

# Connection settings come from the process environment; indexing os.environ
# raises KeyError if any of these variables is missing.
HOST_IP = os.environ["DATABASE_IP"]
DATABASE_USER = os.environ["DATABASE_USER"]
DATABASE_PASSWORD = os.environ["DATABASE_PASSWORD"]
DATABASE_PORT = os.environ["DATABASE_PORT"]

# Assemble the SQLAlchemy URL for the "mimicllm" database (psycopg2 driver).
# Building it through URL.create keeps the password out of a hand-formatted
# connection string.
_connection_settings = dict(
    drivername="postgresql+psycopg2",
    username=DATABASE_USER,
    password=DATABASE_PASSWORD,
    host=HOST_IP,
    port=DATABASE_PORT,
    database="mimicllm",
)
connection_url = URL.create(**_connection_settings)

In [None]:
from sqlalchemy import create_engine, Column, Text, BigInteger
from sqlalchemy.orm import sessionmaker, declarative_base

# Declarative base that the ORM model below inherits from.
Base = declarative_base()


class TokenizedData(Base):
    """ORM mapping for the ``tokenized_data`` table in the ``mimicllm`` schema.

    Rows are keyed by ``token_id``. ``attention_mask`` and ``input_ids`` are
    stored as text; their serialization format is not visible in this
    notebook (presumably encoded token sequences -- confirm against whatever
    populates the table).
    """

    __tablename__ = "tokenized_data"
    __table_args__ = {"schema": "mimicllm"}

    # Big-integer primary key; the export loop pages through the table on it.
    token_id = Column(BigInteger, primary_key=True)
    attention_mask = Column(Text)
    input_ids = Column(Text)


# Engine for the connection_url assembled from the DATABASE_* environment
# variables, plus a session factory bound to it. Calling Session() returns a
# new session that talks to the database through this engine.
engine = create_engine(connection_url)
Session = sessionmaker(bind=engine)

In [None]:
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import pickle

# Export mimicllm.tokenized_data to data/train.parquet and data/test.parquet.
#
# The table is read in chunks using keyset pagination on token_id; every chunk
# is split into a train and a test part and appended to the matching Parquet
# file. This walks the WHOLE table, so the cell can run for a long time on a
# large table. Existing output files are overwritten.

chunk_size = 1000  # rows fetched from the database per query
test_size = 0.2  # fraction of every chunk that goes to the test file
random_seed = 42  # fixed seed so re-running the cell reproduces the same split

parquet_dir = "data"

os.makedirs(parquet_dir, exist_ok=True)

train_parquet_file = os.path.join(parquet_dir, "train.parquet")
test_parquet_file = os.path.join(parquet_dir, "test.parquet")

# Pagination cursor: the largest token_id exported so far.
# NOTE(review): starting at 0 assumes token_id values are positive; rows with
# token_id <= 0 would never be selected -- confirm against the table.
last_id = 0

# The writers are created lazily because each needs the schema of the first
# table that is written to its file.
train_writer = None
test_writer = None

session = Session()

try:
    total_rows = session.query(TokenizedData).count()

    with tqdm(total=total_rows, desc="Processing") as pbar:
        while True:
            # Fetch the next chunk after the cursor, ordered by the unique key
            # so that no row is skipped or read twice.
            query = (
                session.query(TokenizedData)
                .order_by(TokenizedData.token_id)
                .filter(TokenizedData.token_id > last_id)
                .limit(chunk_size)
            )
            chunk = pd.read_sql(query.statement, session.bind)

            if chunk.empty:
                # No rows left beyond the cursor: the export is complete.
                break

            if len(chunk) < 2:
                # train_test_split raises ValueError for a single row (the
                # train part would be empty), so a one-row tail chunk is kept
                # entirely in the train set.
                train_chunk, test_chunk = chunk, chunk.iloc[0:0]
            else:
                train_chunk, test_chunk = train_test_split(
                    chunk, test_size=test_size, random_state=random_seed
                )

            # Append the train part.
            train_table = pa.Table.from_pandas(train_chunk, preserve_index=False)
            if train_writer is None:
                train_writer = pq.ParquetWriter(
                    train_parquet_file, train_table.schema, compression="snappy"
                )
            train_writer.write_table(train_table)

            # Append the test part; it is empty only for a one-row chunk, and
            # an empty table must not be used to derive the writer schema.
            if not test_chunk.empty:
                test_table = pa.Table.from_pandas(test_chunk, preserve_index=False)
                if test_writer is None:
                    test_writer = pq.ParquetWriter(
                        test_parquet_file, test_table.schema, compression="snappy"
                    )
                test_writer.write_table(test_table)

            # Advance the cursor to the last (largest) token_id of this chunk.
            last_id = int(chunk["token_id"].iloc[-1])

            pbar.update(len(chunk))
finally:
    # Always finalize the Parquet files and release the database connection,
    # even when the loop fails part-way; an unclosed writer leaves a file
    # without its footer.
    if train_writer is not None:
        train_writer.close()
    if test_writer is not None:
        test_writer.close()
    session.close()