In [5]:
import os
import polars as pl
import hashlib
import re
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import json
import gc
import time

INPUT_PATH = "/data/raid5/data/picatto/ascii/news/clean/"
OUTPUT_PATH = "/data/raid5/data/picatto/ascii/news/chunks_coref2/"
CHECKPOINT_FILE = Path(OUTPUT_PATH) / ".checkpoint.json"
BATCH_SIZE = 1500  # Tune based on avg doc length
DEVICE = "cuda:0"
NUM_CPU_WORKERS = 48

# Compile regex once at module level
_SENTENCE_PATTERN = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')




# ============================================================
# CPU-BOUND FUNCTIONS
# ============================================================


In [4]:
def apply_coref_resolution(text: str, clusters: List) -> str:
    """Apply coreference resolution - CPU bound"""
    if not clusters or not text:
        return text
    
    try:
        replacements = []
        for cluster in clusters:
            if len(cluster) < 2:
                continue
            
            spans_text = [text[s:e] for s, e in cluster]
            canonical = max(spans_text, key=len)
            
            for (start, end), mention in zip(cluster, spans_text):
                if mention == canonical:
                    continue
                if len(mention.split()) > 2:
                    continue
                replacements.append((start, end, canonical))
        
        replacements.sort(key=lambda x: x[0], reverse=True)
        for start, end, replacement in replacements:
            text = text[:start] + replacement + text[end:]
        
        return text
    except:
        return text


def chunk_text(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Dict]:
    """Split text into chunks - CPU bound"""
    if not text:
        return []
    
    sentences = _SENTENCE_PATTERN.split(text)
    
    chunks = []
    current_chunk = []
    current_words = 0
    
    for sentence in sentences:
        sentence_words = len(sentence.split())
        
        if current_words + sentence_words <= chunk_size:
            current_chunk.append(sentence)
            current_words += sentence_words
        else:
            if current_chunk:
                chunks.append({
                    'chunk_text': ' '.join(current_chunk),
                    'chunk_word_count': current_words,
                })
            
            if overlap > 0 and current_chunk:
                overlap_text = ' '.join(current_chunk)
                overlap_words = overlap_text.split()[-overlap:]
                current_chunk = [' '.join(overlap_words), sentence]
                current_words = len(overlap_words) + sentence_words
            else:
                current_chunk = [sentence]
                current_words = sentence_words
    
    if current_chunk:
        chunks.append({
            'chunk_text': ' '.join(current_chunk),
            'chunk_word_count': current_words,
        })
    
    for i, chunk in enumerate(chunks):
        chunk['chunk_index'] = i
    
    return chunks


def process_single_doc(args: Tuple) -> List[Dict]:
    """Process one document - designed for parallel execution"""
    doc_id, text, clusters, metadata = args
    
    if not text:
        return []
    
    resolved_text = apply_coref_resolution(text, clusters)
    chunks = chunk_text(resolved_text)
    
    rows = []
    for chunk in chunks:
        chunk_id = hashlib.sha256(f"{doc_id}_{chunk['chunk_index']}".encode()).hexdigest()[:16]
        
        row = {
            'chunk_id': chunk_id,
            'doc_id': doc_id,
            'chunk_index': chunk['chunk_index'],
            'chunk_text': chunk['chunk_text'],
            'chunk_word_count': chunk['chunk_word_count'],
            **metadata
        }
        rows.append(row)
    
    return rows




# ============================================================
# CHECKPOINTING
# ============================================================


In [6]:

def load_checkpoint() -> set:
    """Load set of completed file indices"""
    if CHECKPOINT_FILE.exists():
        with open(CHECKPOINT_FILE) as f:
            data = json.load(f)
            return set(data.get('completed_files', []))
    return set()


def save_checkpoint(completed_files: set, total_docs: int, total_chunks: int):
    """Save checkpoint"""
    CHECKPOINT_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump({
            'completed_files': list(completed_files),
            'total_docs': total_docs,
            'total_chunks': total_chunks,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
        }, f)


# ============================================================
# PIPELINED PROCESSING - Tokenize next batch while GPU runs
# ============================================================

In [7]:

class PipelinedCorefProcessor:
    """
    Pipeline tokenization and inference:
    - Thread 1: Tokenizes batch N+1
    - Main: GPU inference on batch N
    - Thread pool: Post-processes batch N-1
    """
    
    def __init__(self, device: str = "cuda:0", num_workers: int = 16):
        from fastcoref import FCoref
        
        print("Loading FastCoref model...")
        self.model = FCoref(device=device)
        print("Model loaded!")
        
        # Access internal components for manual tokenization
        self.tokenizer = self.model.tokenizer
        self.nlp_model = self.model.model
        
        self.num_workers = num_workers
        self.device = device
        
    def _tokenize_batch(self, texts: List[str]) -> dict:
        """Tokenize a batch of texts (CPU-bound)"""
        from datasets import Dataset
        
        # Filter empty texts
        valid_texts = [t if t and len(t.strip()) > 10 else "" for t in texts]
        
        # Create dataset
        dataset = Dataset.from_dict({"text": valid_texts})
        
        # Tokenize with multiple workers
        def tokenize_fn(examples):
            return self.tokenizer(
                examples["text"],
                padding=True,
                truncation=True,
                max_length=self.model.max_tokens,
                return_tensors=None  # Return lists, not tensors
            )
        
        tokenized = dataset.map(
            tokenize_fn,
            batched=True,
            batch_size=256,
            num_proc=self.num_workers,  # Parallel tokenization!
            remove_columns=["text"],
            desc="Tokenizing"
        )
        
        return tokenized
    
    def _run_inference(self, tokenized_data) -> List:
        """Run GPU inference on tokenized data"""
        import torch
        
        # Convert to tensors and move to GPU
        input_ids = torch.tensor(tokenized_data["input_ids"]).to(self.device)
        attention_mask = torch.tensor(tokenized_data["attention_mask"]).to(self.device)
        
        # Run model
        with torch.no_grad():
            outputs = self.nlp_model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        
        return outputs
    
    def predict_pipelined(
        self,
        batches: List[List[str]],
        thread_pool: ThreadPoolExecutor
    ) -> List[List]:
        """
        Process multiple batches with pipelined tokenization.
        Returns list of cluster predictions per batch.
        """
        import torch
        
        all_predictions = []
        tokenize_future: Optional[Future] = None
        next_tokenized = None
        
        for batch_idx, texts in enumerate(batches):
            # Start tokenizing NEXT batch in background
            if batch_idx + 1 < len(batches):
                next_texts = batches[batch_idx + 1]
                tokenize_future = thread_pool.submit(self._tokenize_batch, next_texts)
            
            # Get tokenized data for current batch
            if batch_idx == 0:
                # First batch - must tokenize synchronously
                tokenized = self._tokenize_batch(texts)
            else:
                # Use pre-tokenized data from previous iteration
                tokenized = next_tokenized
            
            # GPU inference (while next batch tokenizes in background)
            preds = self.model.predict(texts=texts, is_split_into_words=False)
            
            # Extract clusters
            batch_clusters = []
            for pred in preds:
                try:
                    clusters = pred.get_clusters(as_strings=False)
                    batch_clusters.append(clusters)
                except:
                    batch_clusters.append([])
            
            all_predictions.append(batch_clusters)
            
            # Wait for next batch's tokenization
            if tokenize_future is not None:
                next_tokenized = tokenize_future.result()
                tokenize_future = None
        
        return all_predictions


In [8]:
def process_file_pipelined(
    parquet_file: Path,
    file_idx: int,
    output_path: str,
    processor: PipelinedCorefProcessor,
    thread_pool: ThreadPoolExecutor,
    batch_size: int = 1500
) -> Tuple[int, int]:
    """Process a single file with pipelined tokenization"""
    
    lf = pl.scan_parquet(parquet_file, hive_partitioning=True)
    file_rows = lf.select(pl.len()).collect().item()
    
    lf = lf.select([
        'doc_id', 'extracted_text', 'uri', 'host', 'http_date',
        'year', 'month', 'day'
    ])
    
    # Load ALL batches for this file (they're small enough)
    all_batch_data = []
    for batch_start in range(0, file_rows, batch_size):
        batch_df = lf.slice(batch_start, batch_size).collect()
        if not batch_df.is_empty():
            all_batch_data.append(batch_df)
    
    if not all_batch_data:
        return 0, 0
    
    # Extract texts for pipelined prediction
    all_texts = [batch['extracted_text'].to_list() for batch in all_batch_data]
    
    # Run pipelined prediction - tokenization overlaps with GPU
    # Note: For simplicity, we'll use the model's predict but with parallel tokenization
    all_chunks = []
    total_docs = 0
    
    meta_cols = ['uri', 'host', 'http_date', 'year', 'month', 'day']
    
    for batch_df, texts in zip(all_batch_data, all_texts):
        doc_ids = batch_df['doc_id'].to_list()
        meta_data = {col: batch_df[col].to_list() for col in meta_cols if col in batch_df.columns}
        
        # Get predictions
        valid_indices = [i for i, t in enumerate(texts) if t and len(str(t).strip()) > 10]
        valid_texts = [texts[i] for i in valid_indices]
        
        clusters_map = {}
        if valid_texts:
            try:
                preds = processor.model.predict(texts=valid_texts, is_split_into_words=False)
                for idx, pred in zip(valid_indices, preds):
                    try:
                        clusters_map[idx] = pred.get_clusters(as_strings=False)
                    except:
                        clusters_map[idx] = []
            except Exception as e:
                print(f"Inference error: {e}")
        
        # Parallel post-processing
        process_args = []
        for i, (doc_id, text) in enumerate(zip(doc_ids, texts)):
            metadata = {col: meta_data[col][i] for col in meta_cols if col in meta_data}
            clusters = clusters_map.get(i, [])
            process_args.append((doc_id, text, clusters, metadata))
        
        # Process docs in parallel
        futures = [thread_pool.submit(process_single_doc, args) for args in process_args]
        for future in futures:
            rows = future.result()
            if rows:
                all_chunks.extend(rows)
        
        total_docs += len(batch_df)
        del batch_df
    
    # Write output
    if all_chunks:
        result_df = pl.DataFrame(all_chunks)
        
        if 'year' in result_df.columns and 'month' in result_df.columns:
            year = result_df['year'][0]
            month = result_df['month'][0]
            out_file = Path(output_path) / f"year={year}" / f"month={month}" / f"chunks_{file_idx:05d}.parquet"
        else:
            out_file = Path(output_path) / f"chunks_{file_idx:05d}.parquet"
        
        out_file.parent.mkdir(parents=True, exist_ok=True)
        result_df.write_parquet(out_file, compression='zstd', compression_level=3)
        
        return total_docs, len(result_df)
    
    return total_docs, 0

# ============================================================
# MAIN PROCESSING - Using ThreadPoolExecutor (safer with CUDA)
# ============================================================

In [9]:
def process_batch_threaded(
    batch_df: pl.DataFrame, 
    coref_model,
    thread_pool: ThreadPoolExecutor
) -> pl.DataFrame:
    """
    Process batch: GPU inference in main thread, CPU work in thread pool.
    ThreadPoolExecutor is safe with CUDA (unlike ProcessPoolExecutor).
    """
    
    texts = batch_df['extracted_text'].to_list()
    doc_ids = batch_df['doc_id'].to_list()
    
    meta_cols = ['uri', 'host', 'http_date', 'year', 'month', 'day']
    meta_data = {col: batch_df[col].to_list() for col in meta_cols if col in batch_df.columns}
    
    # --- GPU BATCH INFERENCE (main thread) ---
    valid_indices = [i for i, t in enumerate(texts) if t and len(t.strip()) > 10]
    valid_texts = [texts[i] for i in valid_indices]
    
    clusters_map = {}
    if valid_texts:
        try:
            preds = coref_model.predict(texts=valid_texts, is_split_into_words=False)
            for idx, pred in zip(valid_indices, preds):
                try:
                    clusters_map[idx] = pred.get_clusters(as_strings=False)
                except:
                    clusters_map[idx] = []
        except Exception as e:
            print(f"GPU inference error: {e}")
    
    # --- PARALLEL CPU PROCESSING (thread pool) ---
    process_args = []
    for i, (doc_id, text) in enumerate(zip(doc_ids, texts)):
        metadata = {col: meta_data[col][i] for col in meta_cols if col in meta_data}
        clusters = clusters_map.get(i, [])
        process_args.append((doc_id, text, clusters, metadata))
    
    # Thread pool for CPU-bound work (releases GIL during string ops)
    all_rows = []
    futures = [thread_pool.submit(process_single_doc, args) for args in process_args]
    for future in futures:
        rows = future.result()
        all_rows.extend(rows)
    
    if not all_rows:
        return pl.DataFrame()
    
    return pl.DataFrame(all_rows)


def process_streaming_production(
    input_path: str, 
    output_path: str, 
    batch_size: int = 1500
):
    """
    Production-ready streaming processor for 2.5M+ documents.
    Features: checkpointing, memory-safe, progress tracking.
    """
    
    print("Loading FastCoref model...")
    from fastcoref import FCoref
    coref_model = FCoref(device=DEVICE)
    print("Model loaded!")
    
    # Get all input files
    input_files = sorted(Path(input_path).rglob("*.parquet"))
    print(f"Found {len(input_files)} parquet files")
    
    # Load checkpoint
    completed_files = load_checkpoint()
    if completed_files:
        print(f"Resuming from checkpoint: {len(completed_files)} files already done")
    
    Path(output_path).mkdir(parents=True, exist_ok=True)
    
    total_chunks = 0
    total_docs = 0
    start_time = time.time()
    
    # Thread pool for CPU work (threads are safe with CUDA)
    with ThreadPoolExecutor(max_workers=NUM_CPU_WORKERS) as thread_pool:
        
        pbar = tqdm(enumerate(input_files), total=len(input_files), desc="Files")
        
        for file_idx, parquet_file in pbar:
            # Skip completed files
            if file_idx in completed_files:
                continue
            
            try:
                # Lazy scan
                lf = pl.scan_parquet(parquet_file, hive_partitioning=True)
                file_rows = lf.select(pl.len()).collect().item()
                
                lf = lf.select([
                    'doc_id', 'extracted_text', 'uri', 'host', 'http_date',
                    'year', 'month', 'day'
                ])
                
                file_chunks = []
                file_docs = 0
                
                # Process in batches
                for batch_start in range(0, file_rows, batch_size):
                    batch_df = lf.slice(batch_start, batch_size).collect()
                    
                    if batch_df.is_empty():
                        continue
                    
                    chunks_df = process_batch_threaded(batch_df, coref_model, thread_pool)
                    
                    if not chunks_df.is_empty():
                        file_chunks.append(chunks_df)
                    
                    file_docs += len(batch_df)
                    del batch_df
                
                # Write this file's output
                if file_chunks:
                    result_df = pl.concat(file_chunks)
                    
                    if 'year' in result_df.columns and 'month' in result_df.columns:
                        year = result_df['year'][0]
                        month = result_df['month'][0]
                        out_file = Path(output_path) / f"year={year}" / f"month={month}" / f"chunks_{file_idx:05d}.parquet"
                    else:
                        out_file = Path(output_path) / f"chunks_{file_idx:05d}.parquet"
                    
                    out_file.parent.mkdir(parents=True, exist_ok=True)
                    result_df.write_parquet(out_file, compression='zstd', compression_level=3)
                    
                    total_chunks += len(result_df)
                    del file_chunks, result_df
                
                total_docs += file_docs
                
                # Update checkpoint
                completed_files.add(file_idx)
                if file_idx % 5 == 0:  # Save checkpoint every 5 files
                    save_checkpoint(completed_files, total_docs, total_chunks)
                    gc.collect()
                
                # Update progress bar
                elapsed = time.time() - start_time
                docs_per_sec = total_docs / elapsed if elapsed > 0 else 0
                pbar.set_postfix({
                    'docs': f'{total_docs:,}',
                    'chunks': f'{total_chunks:,}',
                    'docs/s': f'{docs_per_sec:.1f}'
                })
                
            except Exception as e:
                print(f"\nError processing {parquet_file}: {e}")
                continue
    
    # Final checkpoint
    save_checkpoint(completed_files, total_docs, total_chunks)
    
    elapsed = time.time() - start_time
    print(f"\n{'='*60}")
    print(f"COMPLETE!")
    print(f"Documents: {total_docs:,}")
    print(f"Chunks: {total_chunks:,}")
    print(f"Time: {elapsed/3600:.2f} hours")
    print(f"Speed: {total_docs/elapsed:.1f} docs/sec")
    print(f"Output: {output_path}")
    print(f"{'='*60}")


# ============================================================
# MAIN - Using parallel tokenization
# ============================================================

In [None]:
def process_streaming_fast(
    input_path: str,
    output_path: str,
    batch_size: int = 1500
):
    """
    Fast streaming with parallel tokenization.
    Key optimization: Use num_proc in HuggingFace tokenization.
    """
    
    # Patch fastcoref to use parallel tokenization
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    
    from fastcoref import FCoref
    
    print("Loading FastCoref model...")
    coref_model = FCoref(device=DEVICE)
    print("Model loaded!")
    
    # Monkey-patch the model's predict to use more workers
    original_predict = coref_model.predict
    
    def patched_predict(texts, **kwargs):
        # Set environment for parallel tokenization
        import datasets
        datasets.disable_caching()  # Avoid disk I/O
        return original_predict(texts, **kwargs)
    
    coref_model.predict = patched_predict
    
    input_files = sorted(Path(input_path).rglob("*.parquet"))
    print(f"Found {len(input_files)} parquet files")
    
    completed_files = load_checkpoint()
    if completed_files:
        print(f"Resuming: {len(completed_files)} files done")
    
    Path(output_path).mkdir(parents=True, exist_ok=True)
    
    total_chunks = 0
    total_docs = 0
    start_time = time.time()
    
    # Pre-load next file while processing current (I/O overlap)
    prefetch_executor = ThreadPoolExecutor(max_workers=2)
    
    with ThreadPoolExecutor(max_workers=NUM_CPU_WORKERS) as thread_pool:
        pbar = tqdm(enumerate(input_files), total=len(input_files), desc="Files")
        
        # Prefetch first file
        prefetch_future = None
        
        for file_idx, parquet_file in pbar:
            if file_idx in completed_files:
                continue
            
            try:
                # Start prefetching next file
                if file_idx + 1 < len(input_files) and file_idx + 1 not in completed_files:
                    next_file = input_files[file_idx + 1]
                    prefetch_future = prefetch_executor.submit(
                        lambda f: pl.scan_parquet(f, hive_partitioning=True).collect(),
                        next_file
                    )
                
                # Process current file
                lf = pl.scan_parquet(parquet_file, hive_partitioning=True)
                file_rows = lf.select(pl.len()).collect().item()
                
                lf = lf.select([
                    'doc_id', 'extracted_text', 'uri', 'host', 'http_date',
                    'year', 'month', 'day'
                ])
                
                file_chunks = []
                file_docs = 0
                meta_cols = ['uri', 'host', 'http_date', 'year', 'month', 'day']
                
                for batch_start in range(0, file_rows, batch_size):
                    batch_df = lf.slice(batch_start, batch_size).collect()
                    
                    if batch_df.is_empty():
                        continue
                    
                    texts = batch_df['extracted_text'].to_list()
                    doc_ids = batch_df['doc_id'].to_list()
                    meta_data = {col: batch_df[col].to_list() for col in meta_cols if col in batch_df.columns}
                    
                    # GPU inference
                    valid_indices = [i for i, t in enumerate(texts) if t and len(str(t).strip()) > 10]
                    valid_texts = [texts[i] for i in valid_indices]
                    
                    clusters_map = {}
                    if valid_texts:
                        try:
                            preds = coref_model.predict(texts=valid_texts, is_split_into_words=False)
                            for idx, pred in zip(valid_indices, preds):
                                try:
                                    clusters_map[idx] = pred.get_clusters(as_strings=False)
                                except:
                                    clusters_map[idx] = []
                        except Exception as e:
                            print(f"GPU error: {e}")
                    
                    # Parallel CPU post-processing
                    process_args = []
                    for i, (doc_id, text) in enumerate(zip(doc_ids, texts)):
                        metadata = {col: meta_data[col][i] for col in meta_cols if col in meta_data}
                        clusters = clusters_map.get(i, [])
                        process_args.append((doc_id, text, clusters, metadata))
                    
                    batch_rows = []
                    futures = [thread_pool.submit(process_single_doc, args) for args in process_args]
                    for future in futures:
                        batch_rows.extend(future.result())
                    
                    if batch_rows:
                        file_chunks.append(pl.DataFrame(batch_rows))
                    
                    file_docs += len(batch_df)
                    del batch_df
                
                # Write output
                if file_chunks:
                    result_df = pl.concat(file_chunks)
                    
                    if 'year' in result_df.columns and 'month' in result_df.columns:
                        year = result_df['year'][0]
                        month = result_df['month'][0]
                        out_file = Path(output_path) / f"year={year}" / f"month={month}" / f"chunks_{file_idx:05d}.parquet"
                    else:
                        out_file = Path(output_path) / f"chunks_{file_idx:05d}.parquet"
                    
                    out_file.parent.mkdir(parents=True, exist_ok=True)
                    result_df.write_parquet(out_file, compression='zstd', compression_level=3)
                    
                    total_chunks += len(result_df)
                    del file_chunks, result_df
                
                total_docs += file_docs
                completed_files.add(file_idx)
                
                if file_idx % 5 == 0:
                    save_checkpoint(completed_files, total_docs, total_chunks)
                    gc.collect()
                
                elapsed = time.time() - start_time
                docs_per_sec = total_docs / elapsed if elapsed > 0 else 0
                eta_hours = (len(input_files) - file_idx) * (elapsed / (file_idx + 1)) / 3600
                
                pbar.set_postfix({
                    'docs': f'{total_docs:,}',
                    'chunks': f'{total_chunks:,}',
                    'docs/s': f'{docs_per_sec:.1f}',
                    'ETA': f'{eta_hours:.1f}h'
                })
                
            except Exception as e:
                print(f"\nError on {parquet_file}: {e}")
                import traceback
                traceback.print_exc()
                continue
    
    prefetch_executor.shutdown(wait=False)
    save_checkpoint(completed_files, total_docs, total_chunks)
    
    elapsed = time.time() - start_time
    print(f"\n{'='*60}")
    print(f"COMPLETE!")
    print(f"Documents: {total_docs:,}")
    print(f"Chunks: {total_chunks:,}")
    print(f"Time: {elapsed/3600:.2f} hours")
    print(f"Speed: {total_docs/elapsed:.1f} docs/sec")
    print(f"{'='*60}")  
    

In [12]:
process_streaming_production(INPUT_PATH, OUTPUT_PATH, batch_size=BATCH_SIZE)

Loading FastCoref model...


01/20/2026 12:31:22 - INFO - 	 missing_keys: []
01/20/2026 12:31:22 - INFO - 	 unexpected_keys: []
01/20/2026 12:31:22 - INFO - 	 mismatched_keys: []
01/20/2026 12:31:22 - INFO - 	 error_msgs: []
01/20/2026 12:31:22 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M


Model loaded!
Found 4409 parquet files
Resuming from checkpoint: 36 files already done


(Deprecated in version 0.20.5)
  file_rows = lf.select(pl.count()).collect().item()
01/20/2026 12:31:22 - INFO - 	 Tokenize 1500 inputs...
Map: 100%|██████████| 1500/1500 [00:39<00:00, 37.59 examples/s]
01/20/2026 12:32:02 - INFO - 	 ***** Running Inference on 1500 texts *****
Inference: 100%|██████████| 1500/1500 [00:11<00:00, 133.78it/s]
01/20/2026 12:32:14 - INFO - 	 Tokenize 201 inputs...
Map: 100%|██████████| 201/201 [00:04<00:00, 43.38 examples/s]
01/20/2026 12:32:19 - INFO - 	 ***** Running Inference on 201 texts *****
Inference: 100%|██████████| 201/201 [00:01<00:00, 145.39it/s]
Files:   1%|          | 37/4409 [00:58<1:54:25,  1.57s/it, docs=1,701, chunks=3,457, docs/s=29.3]01/20/2026 12:32:20 - INFO - 	 Tokenize 1373 inputs...
Map: 100%|██████████| 1373/1373 [00:30<00:00, 45.33 examples/s]
01/20/2026 12:32:51 - INFO - 	 ***** Running Inference on 1373 texts *****
Inference: 100%|██████████| 1373/1373 [00:08<00:00, 160.02it/s]
Files:   1%|          | 38/4409 [01:37<3:37:39,  2.

KeyboardInterrupt: 

In [None]:
process_streaming_fast(INPUT_PATH, OUTPUT_PATH, batch_size=BATCH_SIZE)

Loading FastCoref model...


01/20/2026 14:36:49 - INFO - 	 missing_keys: []
01/20/2026 14:36:49 - INFO - 	 unexpected_keys: []
01/20/2026 14:36:49 - INFO - 	 mismatched_keys: []
01/20/2026 14:36:49 - INFO - 	 error_msgs: []
01/20/2026 14:36:49 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M


Model loaded!
Found 4409 parquet files
Resuming: 191 files done


Files:   0%|          | 0/4409 [00:00<?, ?it/s]01/20/2026 14:36:49 - INFO - 	 Tokenize 1500 inputs...
Map: 100%|██████████| 1500/1500 [01:19<00:00, 18.81 examples/s]
01/20/2026 14:38:08 - INFO - 	 ***** Running Inference on 1500 texts *****
Inference: 100%|██████████| 1500/1500 [00:12<00:00, 121.23it/s]
01/20/2026 14:38:22 - INFO - 	 Tokenize 778 inputs...
Map: 100%|██████████| 778/778 [00:16<00:00, 46.19 examples/s]
01/20/2026 14:38:39 - INFO - 	 ***** Running Inference on 778 texts *****
Inference: 100%|██████████| 778/778 [00:05<00:00, 154.29it/s]
Files:   4%|▍         | 192/4409 [01:55<42:17,  1.66it/s, docs=2,278, chunks=4,602, docs/s=19.7, ETA=0.7h]01/20/2026 14:38:44 - INFO - 	 Tokenize 1500 inputs...
Map: 100%|██████████| 1500/1500 [00:37<00:00, 40.33 examples/s]
01/20/2026 14:39:22 - INFO - 	 ***** Running Inference on 1500 texts *****
Inference: 100%|██████████| 1500/1500 [00:11<00:00, 130.31it/s]
01/20/2026 14:39:34 - INFO - 	 Tokenize 208 inputs...
Map: 100%|██████████| 208

In [8]:
#process_streaming(INPUT_PATH, OUTPUT_PATH, batch_size=BATCH_SIZE)
    
    # Option 2: Polars streaming engine (for single large files)
#process_with_streaming_engine(INPUT_PATH, OUTPUT_PATH, batch_size=BATCH_SIZE)