In [1]:
!wget -q https://raw.githubusercontent.com/GwenTsang/tests/refs/heads/main/all_Flaubert.txt
!wget -q https://raw.githubusercontent.com/GwenTsang/tests/refs/heads/main/french.txt
!wget -q https://raw.githubusercontent.com/GwenTsang/tests/refs/heads/main/l.txt

In [2]:
import re
import unicodedata
import numpy as np
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Sequence

# Settings for the corpus reconstruction. The filter/sampling keys are read by
# generate_corpus_content(); TRAIN_FILES is read by main().
CONFIG = {
    # Keep a paragraph only if its widest "(" ... ")" span is at most this many characters.
    "max_pair_dist": 280,
    # Keep a paragraph only if its parentheses nest no deeper than this.
    "max_depth": 2,
    # Upper bound on how many paragraphs are sampled (fewer if there are fewer candidates).
    "n_paragraphs": 3000,
    # normalize_text: lowercase the text when True.
    "lowercase": False,
    # normalize_text: NFKD-decompose and drop combining marks when True.
    "fold_diacritics": True,
    # normalize_text: "none" leaves digits alone, "hash" maps each digit to "#",
    # any other value maps each digit to "0".
    "digit_map": "none",
    # normalize_text: squeeze runs of spaces/tabs and runs of 3+ newlines when True.
    "collapse_whitespace": True,
    # Seed for the NumPy generator that draws the sample.
    "seed": 10,
    # Input text files, loaded in this order.
    "TRAIN_FILES": ["all_Flaubert.txt", "french.txt", "l.txt"]
}

# --- Data Structures and Helper Functions ---

@dataclass(frozen=True)
class ParagraphInfo:
    """Immutable record pairing a paragraph with its parenthesis statistics.

    build_paragraph_infos() fills the four fields after ``text`` positionally
    from the tuple returned by _analyze_parentheses().
    """
    # The paragraph text as returned by load_paragraphs().
    text: str
    # True when the text contains at least one "(" and at least one ")".
    has_pair: bool
    # True when every ")" closes an earlier "(" and no "(" is left open.
    balanced: bool
    # Largest index difference between a matched "(" and its ")".
    max_pair_dist: int
    # Deepest parenthesis nesting level reached while scanning.
    max_depth: int

@dataclass(frozen=True)
class TextNormConfig:
    """Immutable set of switches consumed by normalize_text()."""
    # Lowercase the text when True.
    lowercase: bool
    # Apply NFKD decomposition and drop combining characters when True.
    fold_diacritics: bool
    # "none" keeps digits; "hash" replaces each digit with "#"; any other
    # value replaces each digit with "0".
    digit_map: str
    # Collapse runs of spaces/tabs to one space and 3+ newlines to two when True.
    collapse_whitespace: bool

def _analyze_parentheses(par: str) -> Tuple[bool, bool, int, int]:
    """
    Exact reproduction of the analysis logic used in the training script.
    """
    has_pair = "(" in par and ")" in par
    stack, max_dist, depth, max_depth, balanced = [], 0, 0, 0, True
    for i, ch in enumerate(par):
        if ch == "(":
            stack.append(i)
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == ")":
            if not stack:
                balanced = False
                break
            max_dist = max(max_dist, i - stack.pop())
            depth = max(0, depth - 1)
    if stack:
        balanced = False
    return has_pair, balanced, max_dist, max_depth

def load_paragraphs(files: Sequence[Path]) -> List[str]:
    """
    Read each file and return its non-empty paragraphs, in file order.

    A paragraph boundary is a newline, optional whitespace, and another
    newline (so whitespace-only lines also separate paragraphs). Each
    paragraph is stripped; empty ones are dropped.

    Files are decoded as UTF-8 with errors="ignore" (undecodable bytes are
    silently dropped) -- the original note marks this as crucial for matching
    the original code. Missing files are reported on stdout and skipped.
    """
    collected = []
    for fp in files:
        if fp.exists():
            raw = fp.read_text(encoding="utf-8", errors="ignore")
            for chunk in re.split(r"\n\s*\n", raw):
                trimmed = chunk.strip()
                if trimmed:
                    collected.append(trimmed)
        else:
            print(f"Warning: File {fp} not found. Skipping.")
    return collected

def build_paragraph_infos(files: Sequence[Path]) -> List[ParagraphInfo]:
    """Load every paragraph from ``files`` and attach its parenthesis statistics."""
    infos = []
    for paragraph in load_paragraphs(files):
        has_pair, balanced, max_dist, max_depth = _analyze_parentheses(paragraph)
        infos.append(
            ParagraphInfo(
                text=paragraph,
                has_pair=has_pair,
                balanced=balanced,
                max_pair_dist=max_dist,
                max_depth=max_depth,
            )
        )
    return infos

def normalize_text(s: str, cfg: TextNormConfig) -> str:
    """
    Normalize ``s`` according to the switches in ``cfg``.

    Per the original author's note, this matches the LSTM preprocessing, so
    the steps and their order must stay as they are:

    1. lowercase (if ``cfg.lowercase``);
    2. NFKD-decompose and drop combining characters (if ``cfg.fold_diacritics``);
    3. replace every digit with "#" when ``cfg.digit_map`` is "hash", or with
       "0" for any other value except "none";
    4. collapse runs of spaces/tabs to a single space and runs of three or
       more newlines to exactly two (if ``cfg.collapse_whitespace``).
    """
    text = s.lower() if cfg.lowercase else s

    if cfg.fold_diacritics:
        decomposed = unicodedata.normalize("NFKD", text)
        # combining() is 0 for base characters; anything non-zero is a mark.
        kept = [char for char in decomposed if unicodedata.combining(char) == 0]
        text = "".join(kept)

    if cfg.digit_map != "none":
        replacement = "#" if cfg.digit_map == "hash" else "0"
        text = re.sub(r"\d", replacement, text)

    if cfg.collapse_whitespace:
        squeezed = re.sub(r"[ \t]+", " ", text)
        text = re.sub(r"\n{3,}", "\n\n", squeezed)

    return text

def generate_corpus_content(infos: List[ParagraphInfo], config: dict) -> str:
    """
    Build the corpus text: filter paragraphs, sample them, normalize, join.

    A paragraph is a candidate when it has a parenthesis pair, is balanced,
    and stays within ``config["max_pair_dist"]`` and ``config["max_depth"]``.
    Up to ``config["n_paragraphs"]`` candidates are drawn without replacement
    using a NumPy generator seeded with ``config["seed"]``; each is normalized
    and the results are joined with a blank line. Returns "" when nothing
    passes the filter. Paragraph counts are printed to stdout.
    """

    def _qualifies(info) -> bool:
        # Same short-circuit order as the filter conditions: pair, balance,
        # span limit, depth limit.
        if not (info.has_pair and info.balanced):
            return False
        if info.max_pair_dist > config["max_pair_dist"]:
            return False
        return info.max_depth <= config["max_depth"]

    # Step 1: keep only the texts of qualifying paragraphs.
    candidates = [info.text for info in infos if _qualifies(info)]

    print(f"Total paragraphs found: {len(infos)}")
    print(f"Candidates passing filter: {len(candidates)}")

    if not candidates:
        return ""

    # Step 2: draw indices without replacement. A fresh generator is created
    # on every call, so the draw depends only on the seed and the pool size.
    rng = np.random.default_rng(config["seed"])
    n_select = min(config["n_paragraphs"], len(candidates))
    picks = rng.choice(len(candidates), size=n_select, replace=False)
    chosen = [candidates[idx] for idx in picks]

    # Step 3: normalize each sampled paragraph with the configured switches.
    norm_cfg = TextNormConfig(
        lowercase=config["lowercase"],
        fold_diacritics=config["fold_diacritics"],
        digit_map=config["digit_map"],
        collapse_whitespace=config["collapse_whitespace"],
    )
    normalized = (normalize_text(paragraph, norm_cfg) for paragraph in chosen)
    return "\n\n".join(normalized)

# --- Main Execution ---

def main():
    """Rebuild the corpus from CONFIG["TRAIN_FILES"] and write it to disk.

    Progress messages go to stdout; the result is written as UTF-8 to
    ``reconstructed_lstm_corpus.txt`` in the current working directory.
    """
    output_filename = "reconstructed_lstm_corpus.txt"
    input_files = list(map(Path, CONFIG["TRAIN_FILES"]))

    print("Analyzing input files...")
    infos = build_paragraph_infos(input_files)

    print("Generating corpus text...")
    final_text = generate_corpus_content(infos, CONFIG)

    print(f"Writing {len(final_text):,} characters to {output_filename}...")
    Path(output_filename).write_text(final_text, encoding="utf-8")

    print("Done.")

# Run the reconstruction only when this code is executed directly, not on import.
if __name__ == "__main__":
    main()

Analyzing input files...
Generating corpus text...
Total paragraphs found: 23644
Candidates passing filter: 3249
Writing 3,640,437 characters to reconstructed_lstm_corpus.txt...
Done.
