In [1]:
import os
import shutil
import warnings
import polars as pl
import lance

from datasets import Dataset, load_dataset, IterableDataset

In [None]:
# Marker that opens every record (precedes the question text).
Q_SEP = "Q:\n\n"
# Marker that separates the question from what follows it.
A_SEP = "\n\nA:\n\n"

def process_qna(series: pl.Series):
    """Split raw "Q:/A:" strings into separate question and answer columns.

    The leading ``Q_SEP`` is stripped from each string, the remainder is
    split on ``A_SEP`` into at most three pieces, and the pieces are
    exposed as columns.  Only the first two pieces are kept, as
    ``question`` and ``answer``; the third piece (everything after a
    second ``A_SEP``) is discarded.

    Args:
        series: string Series holding the raw records.

    Returns:
        The unnested result with columns ``question`` and ``answer``.
    """
    without_prefix = series.str.strip_prefix(Q_SEP)
    pieces = without_prefix.str.splitn(A_SEP, 3)
    labelled = pieces.struct.rename_fields(["question", "answer", "_"])
    columns = labelled.struct.unnest()
    # "_" only exists to absorb the trailing remainder of the split.
    return columns.drop("_")



# Byte budget used to size each streamed batch.
STREAM_BATCH_BYTES = 1 << 27 # ~128MB

def process_batch(path: str, batch_iter, i: int):
    """Pull one batch from ``batch_iter``, parse it, and write it out as a Lance dataset.

    Batch 0 is written to ``path`` itself; every later batch is written to
    a sibling location named ``<path>_batch_<i>``.

    Args:
        path: location of the base dataset.
        batch_iter: iterator yielding mappings that have a ``"text"`` entry.
        i: zero-based index of the batch being processed.

    Raises:
        StopIteration: propagated from ``next(batch_iter)`` when the
            iterator has no more batches.
    """
    print(f"Processing Batch: {i}")
    raw_batch = next(batch_iter)

    # The first batch seeds the base dataset; the rest go to scratch locations.
    if i > 0:
        target = path + f"_batch_{i}"
    else:
        target = path

    qna_frame = process_qna(pl.Series("text", raw_batch["text"]))
    lance.write_dataset(qna_frame.to_arrow(), target)

def get_lance_dataset(path: str, batch_limit: int | None=None, force_reload=False):
    """Open the Lance dataset at ``path``, building it from the streamed corpus if needed.

    If ``path`` already exists and ``force_reload`` is false, the existing
    dataset is opened and returned as-is.  Otherwise the
    ``bigscience-data/roots_code_stackexchange`` train split is streamed in
    batches sized to roughly ``STREAM_BATCH_BYTES``; batch 0 is written to
    ``path`` and later batches to ``<path>_batch_<i>`` scratch datasets,
    which are then appended to the base dataset and deleted.

    Args:
        path: location of the dataset.  Use an absolute path; otherwise the
            existence check may miss a previous build and redownload.
        batch_limit: maximum number of batches to process, or ``None`` to
            consume the whole stream.
        force_reload: rebuild even if ``path`` already exists.  Any existing
            directory at ``path`` is deleted first.

    Returns:
        The Lance dataset opened at ``path``.

    Raises:
        AssertionError: if no batch could be processed.
    """
    if os.path.exists(path) and not force_reload:
        return lance.dataset(path)

    ds = load_dataset("bigscience-data/roots_code_stackexchange", streaming=True)["train"]
    train_info = ds.info.splits["train"]
    # Rows per batch = byte budget * average rows per byte.  Clamp to 1 so a
    # very large average row size cannot truncate the batch size to zero.
    stream_batch_size = max(1, int(STREAM_BATCH_BYTES * (train_info.num_examples / train_info.num_bytes)))
    batch_iter = iter(ds.select_columns("text").batch(stream_batch_size))

    # `c` is both the index of the next batch and the count of batches
    # successfully written so far.
    c = 0
    while batch_limit is None or c < batch_limit:
        # Clear whatever a forced reload or an interrupted earlier run left at
        # the location process_batch is about to write to (it uses `path` for
        # batch 0 and `<path>_batch_<i>` afterwards).
        target = path if c == 0 else "".join([path, f"_batch_{c}"])
        if os.path.isdir(target):
            shutil.rmtree(target)

        try:
            process_batch(path, batch_iter, c)
        except StopIteration:
            # Normal end of the stream: not an error, nothing to report.
            break
        except Exception as e:
            # Best effort: keep the batches written so far and stop here.
            warnings.warn(f"Caught exception in iterator: {e}")
            break
        c += 1

    print(f"Processed Batches: {c}")
    assert c > 0, f"Empty or error-ridden dataset: {path}"

    print(f"\nSaving Dataset...")
    base_dataset = lance.dataset(path)
    # Fold every scratch dataset into the base one, then remove the scratch copy.
    for i in range(1, c):
        print(f"Appending Batch: {i}")
        aux_path = "".join([path, f"_batch_{i}"])
        aux_dataset = lance.dataset(aux_path)
        base_dataset.insert(aux_dataset)
        shutil.rmtree(aux_path)

    print("Saved Dataset")

    return base_dataset

db_name = "stackexchange_base_db_lance"
# Keep the dataset under ./cache of the current working directory; the path is
# absolute so repeated runs resolve to the same location.  Only the first 5
# streamed batches are processed.
lance_dataset = get_lance_dataset(os.path.join(os.getcwd(), "cache", db_name), batch_limit=5)

Resolving data files:   0%|          | 0/54 [00:00<?, ?it/s]

Processing Batch: 0
Processing Batch: 1
Processing Batch: 2
Processing Batch: 3
Processing Batch: 4
Processed Batches: 5

Saving Dataset...
Appending Batch: 1


TypeError: str.join() takes exactly one argument (2 given)

In [5]:
# Value passed as `batch_size` when reading the dataset back in batches.
LOAD_BATCH_SIZE = 1 << 16

lance_batches = lance_dataset.to_batches(
    batch_size=LOAD_BATCH_SIZE,
)

In [None]:
import torch
from sentence_transformers import SentenceTransformer


In [None]:
# NOTE(review): placeholder cell — the model name is an empty string and
# encode() is called with no input, so this presumably fails as written.
# TODO: fill in the intended model identifier and the texts to embed.
SentenceTransformer("").encode()