# 01_build_dataset

In [1]:
from pathlib import Path
import json, re, hashlib
import pandas as pd
import numpy as np

HERE = Path.cwd().resolve()

def find_config(start: Path) -> Path:
    for p in [start, *start.parents]:
        cand = p / "project_config.json"
        if cand.exists():
            return cand
    raise FileNotFoundError("project_config.json not found. Run 00_config_and_checks.ipynb first.")

CONFIG_PATH = find_config(HERE)
cfg = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
print("Config:", CONFIG_PATH)



INPUT_DIR = Path(cfg["INPUT_DIR"])
PROCESSED_DIR = Path(cfg["PROCESSED_DIR"])

ENRON_CSV = Path(cfg["ENRON_CSV"])
INVOICE1_CSV = Path(cfg["INVOICE1_CSV"])
INVOICE2_CSV = Path(cfg["INVOICE2_CSV"])
ARXIV_PATH = Path(cfg["ARXIV_PATH"])

RNG = np.random.default_rng(42)

Config: C:\Users\viach\Downloads\document-classifier-portfolio-v2\project_config.json


In [2]:
def basic_clean(t: str) -> str:
    if not isinstance(t, str):
        return ""
    t = t.replace("\x00", " ")
    t = re.sub(r"[\r\t]", " ", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t

def filter_by_length(df: pd.DataFrame, col: str = "text", min_chars: int = 80, max_chars: int = 12000) -> pd.DataFrame:
    n = df[col].astype(str).str.len()
    return df[(n >= min_chars) & (n <= max_chars)].copy()

def stable_id(text: str) -> str:
    h = hashlib.sha256(text.encode("utf-8", errors="ignore")).hexdigest()
    return h[:16]

In [3]:
def extract_email_body(raw_msg: str) -> str:
    if not isinstance(raw_msg, str):
        return ""
    parts = re.split(r"\n\s*\n", raw_msg, maxsplit=1)
    body = parts[1] if len(parts) > 1 else raw_msg
    lines = []
    for line in body.splitlines():
        s = line.strip()
        if s.startswith(">"):
            continue
        if re.match(r"^\s*On .+ wrote:\s*$", line):
            continue
        lines.append(line)
    return "\n".join(lines)

def load_enron(path: Path, max_rows: int | None = None) -> pd.DataFrame:
    df = pd.read_csv(path)
    text_col = None
    for c in ["message", "text", "content", "body"]:
        if c in df.columns:
            text_col = c
            break
    if text_col is None:
        obj_cols = [c for c in df.columns if df[c].dtype == "object"]
        if not obj_cols:
            raise ValueError("No text column found in Enron CSV.")
        text_col = obj_cols[0]

    if max_rows and len(df) > max_rows:
        df = df.sample(max_rows, random_state=42)

    out = pd.DataFrame()
    out["text"] = df[text_col].astype(str).map(extract_email_body).map(basic_clean)
    out["doc_type"] = "EMAIL"
    out["source"] = "ENRON"
    return out

In [4]:
INVOICE1_REQUIRED = [
    "id_invoice", "issuedDate", "country", "service", "total",
    "discount", "tax", "invoiceStatus", "balance", "dueDate", "client"
]

def _safe(row: dict, key: str, default: str = "") -> str:
    v = row.get(key, default)
    if pd.isna(v):
        return default
    return str(v).strip()

def invoice1_row_to_text(row: dict) -> str:
    inv_id = _safe(row, "id_invoice")
    issued = _safe(row, "issuedDate")
    country = _safe(row, "country")
    service = _safe(row, "service")
    total = _safe(row, "total")
    discount = _safe(row, "discount")
    tax = _safe(row, "tax")
    status = _safe(row, "invoiceStatus")
    balance = _safe(row, "balance")
    due = _safe(row, "dueDate")
    client = _safe(row, "client")

    lines = [
        f"INVOICE {inv_id}".strip(),
        f"Issued: {issued}".strip(),
        f"Client: {client}".strip(),
        f"Country: {country}".strip(),
        f"Service: {service}".strip(),
        f"Total: {total}".strip(),
        f"Discount: {discount}".strip(),
        f"Tax: {tax}".strip(),
        f"Balance: {balance}".strip(),
        f"Status: {status}".strip(),
        f"Due: {due}".strip(),
    ]
    lines = [ln for ln in lines if ln.split(":")[-1].strip() != ""]
    return "\n".join(lines)

def load_invoices_1(path: Path, max_rows: int | None = None) -> pd.DataFrame:
    df = pd.read_csv(path)
    missing = [c for c in INVOICE1_REQUIRED if c not in df.columns]
    if missing:
        raise ValueError(f"Missing invoice1 columns: {missing}")
    if max_rows:
        df = df.head(max_rows)
    texts = df.apply(lambda r: invoice1_row_to_text(r.to_dict()), axis=1)
    out = pd.DataFrame({"text": texts.map(basic_clean)})
    out["doc_type"] = "INVOICE"
    out["source"] = "INVOICE_SET_1"
    return out

In [5]:
KAGGLE_REQUIRED = [
    "first_name","last_name","email","product_id","qty","amount",
    "invoice_date","address","city","stock_code","job"
]

def invoice2_row_to_text(row: dict) -> str:
    fn = _safe(row, "first_name")
    ln = _safe(row, "last_name")
    email = _safe(row, "email")
    product_id = _safe(row, "product_id")
    qty = _safe(row, "qty")
    amount = _safe(row, "amount")
    inv_date = _safe(row, "invoice_date")
    addr = _safe(row, "address")
    city = _safe(row, "city")
    stock = _safe(row, "stock_code")
    job = _safe(row, "job")

    client = (fn + " " + ln).strip() or "Customer"
    lines = [
        "INVOICE",
        f"Date: {inv_date}".strip(),
        f"Customer: {client}".strip(),
        f"Company: {job}".strip(),
        f"Address: {addr} {city}".strip(),
        f"Email: {email}".strip(),
        f"Item: {product_id} {stock}".strip(),
        f"Qty: {qty}".strip(),
        f"Unit price: {amount}".strip(),
    ]
    lines = [ln for ln in lines if ln.split(":")[-1].strip() != "" and ln.strip() != "Address:"]
    return "\n".join(lines)

def load_invoices_2(path: Path, max_rows: int | None = None) -> pd.DataFrame:
    df = pd.read_csv(path)
    missing = [c for c in KAGGLE_REQUIRED if c not in df.columns]
    if missing:
        raise ValueError(f"Missing invoice2 columns: {missing}")
    if max_rows:
        df = df.head(max_rows)
    texts = df.apply(lambda r: invoice2_row_to_text(r.to_dict()), axis=1)
    out = pd.DataFrame({"text": texts.map(basic_clean)})
    out["doc_type"] = "INVOICE"
    out["source"] = "INVOICE_SET_2"
    return out

In [6]:
import json

def load_arxiv(path: Path, max_rows: int | None = None) -> pd.DataFrame:
    rows = []
    if path.suffix.lower() == ".jsonl":
        with open(path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                if max_rows and i >= max_rows:
                    break
                line = line.strip()
                if not line:
                    continue
                obj = json.loads(line)
                title = str(obj.get("title", "")).strip()
                abstract = str(obj.get("abstract", "")).strip()
                txt = (title + "\n\n" + abstract).strip()
                if txt:
                    rows.append(txt)
    else:
        # tolerate JSONL in .json
        with open(path, "r", encoding="utf-8") as f:
            head = f.read(4096)
        if "\n" in head and head.lstrip().startswith("{"):
            with open(path, "r", encoding="utf-8") as f:
                for i, line in enumerate(f):
                    if max_rows and i >= max_rows:
                        break
                    line = line.strip()
                    if not line:
                        continue
                    obj = json.loads(line)
                    title = str(obj.get("title", "")).strip()
                    abstract = str(obj.get("abstract", "")).strip()
                    txt = (title + "\n\n" + abstract).strip()
                    if txt:
                        rows.append(txt)
        else:
            data = json.loads(Path(path).read_text(encoding="utf-8"))
            if isinstance(data, list):
                it = data
            else:
                it = data.get("data", [])
            for obj in it:
                if max_rows and len(rows) >= max_rows:
                    break
                title = str(obj.get("title", "")).strip()
                abstract = str(obj.get("abstract", "")).strip()
                txt = (title + "\n\n" + abstract).strip()
                if txt:
                    rows.append(txt)

    out = pd.DataFrame({"text": pd.Series(rows).map(basic_clean)})
    out["doc_type"] = "SCIENTIFIC_PAPER"
    out["source"] = "ARXIV"
    return out

In [7]:
# Build
parts = []

if ENRON_CSV.exists():
    parts.append(load_enron(ENRON_CSV, max_rows=None))
if INVOICE1_CSV.exists():
    parts.append(load_invoices_1(INVOICE1_CSV, max_rows=None))
if INVOICE2_CSV.exists():
    parts.append(load_invoices_2(INVOICE2_CSV, max_rows=None))
if ARXIV_PATH.exists():
    parts.append(load_arxiv(ARXIV_PATH, max_rows=None))

if not parts:
    raise FileNotFoundError("No datasets found in input/")

df = pd.concat(parts, ignore_index=True)
df["text"] = df["text"].astype(str).map(basic_clean)
df = df[df["text"].str.len() > 0].copy()

df["doc_id"] = df["text"].map(stable_id)
df = df.drop_duplicates(subset=["doc_id"]).reset_index(drop=True)
df = filter_by_length(df, "text", min_chars=80, max_chars=12000)

print(df.shape)
print(df["doc_type"].value_counts())

(3113730, 4)
doc_type
SCIENTIFIC_PAPER    2872037
EMAIL                221693
INVOICE               20000
Name: count, dtype: int64


In [None]:
предательство, измена, инстаграм измена, 

Unnamed: 0,text,doc_type,source,doc_id
1,Traveling to have a business meeting takes the...,EMAIL,ENRON,2631cf6b64a37507
3,"Randy, Can you send me a schedule of the salar...",EMAIL,ENRON,89ce8dde64d2b23b
6,Please cc the following distribution list with...,EMAIL,ENRON,6eb7c06288a2d944
8,1. login: pallen pw: ke9davis I don't think th...,EMAIL,ENRON,fa7758f7c5f01bde
9,---------------------- Forwarded by Phillip K ...,EMAIL,ENRON,f243c5ccecb4af42
...,...,...,...,...
3135380,On the origin of the irreversibility line in t...,SCIENTIFIC_PAPER,ARXIV,99175dfe1f070e81
3135381,Nonlinear Response of HTSC Thin Film Microwave...,SCIENTIFIC_PAPER,ARXIV,920c9e984b480968
3135382,Critical State Flux Penetration and Linear Mic...,SCIENTIFIC_PAPER,ARXIV,099995b18c91d508
3135383,Density of States and NMR Relaxation Rate in A...,SCIENTIFIC_PAPER,ARXIV,c2ce38d3013486fe


In [10]:
invoice1_df = pd.read_csv(INVOICE1_CSV)
invoice2_df = pd.read_csv(INVOICE2_CSV)

inv_target = len(invoice1_df) + len(invoice2_df)
print("inv_target:", inv_target)

if inv_target <= 0:
    raise ValueError("Invoice target is 0. Check invoice datasets loading.")

classes = ["EMAIL", "INVOICE", "SCIENTIFIC_PAPER"]

# sanity
for cls in classes:
    n = int((df["doc_type"] == cls).sum())
    if n == 0:
        raise ValueError(f"No rows for class {cls}")
    if cls != "INVOICE" and n < inv_target:
        print(f"Warning: {cls} has only {n} (< inv_target {inv_target}). Will cap to {n} instead.")

cap = min(
    inv_target,
    int((df["doc_type"] == "EMAIL").sum()),
    int((df["doc_type"] == "SCIENTIFIC_PAPER").sum()),
)

# keep invoices as-is, cap the other two
email_cap = df[df["doc_type"] == "EMAIL"].sample(n=cap, random_state=42)
paper_cap = df[df["doc_type"] == "SCIENTIFIC_PAPER"].sample(n=cap, random_state=42)

inv_df = df[df["doc_type"] == "INVOICE"].copy()

# if invoices are larger than cap, cap them too (for full balance)
inv_cap = inv_df.sample(n=cap, random_state=42) if len(inv_df) > cap else inv_df

df_bal = pd.concat([email_cap, inv_cap, paper_cap], ignore_index=True)
df_bal = df_bal.sample(frac=1.0, random_state=42).reset_index(drop=True)

print("inv_target:", inv_target)
print("cap used:", cap)
print("Balanced counts:", df_bal["doc_type"].value_counts().to_dict())

df = df_bal

inv_target: 20000
inv_target: 20000
cap used: 20000
Balanced counts: {'EMAIL': 20000, 'INVOICE': 20000, 'SCIENTIFIC_PAPER': 20000}


In [11]:
from sklearn.model_selection import train_test_split

train_df, tmp_df = train_test_split(df, test_size=0.30, random_state=42, stratify=df["doc_type"])
val_df, test_df = train_test_split(tmp_df, test_size=0.50, random_state=42, stratify=tmp_df["doc_type"])

PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
train_path = PROCESSED_DIR / "train.csv"
val_path = PROCESSED_DIR / "val.csv"
test_path = PROCESSED_DIR / "test.csv"

train_df.to_csv(train_path, index=False)
val_df.to_csv(val_path, index=False)
test_df.to_csv(test_path, index=False)

print("Saved:", train_path, val_path, test_path)

Saved: C:\Users\viach\Downloads\document-classifier-portfolio-v2\data\processed\train.csv C:\Users\viach\Downloads\document-classifier-portfolio-v2\data\processed\val.csv C:\Users\viach\Downloads\document-classifier-portfolio-v2\data\processed\test.csv
