In [4]:
from pathlib import Path
from typing import List, Set

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Reuse your TARGET_COLS from above: the same target columns the per-target
# models were trained on. Used as the default `target_cols` of
# `make_union_test_csv` (the list is only read there, never mutated).
TARGET_COLS = [
    "oc_usda.c729_w.pct",
    "c.tot_usda.a622_w.pct",
    "n.tot_usda.a623_w.pct",
    "ph.h2o_usda.a268_index",
    "ph.cacl2_usda.a481_index",
    "cec_usda.a723_cmolc.kg",
    "ec_usda.a364_ds.m",
    "clay.tot_usda.a334_w.pct",
    "sand.tot_usda.c60_w.pct",
    "silt.tot_usda.c62_w.pct",
    "bd_usda.a4_g.cm3",
    "wr.10kPa_usda.a414_w.pct",
    "wr.33kPa_usda.a415_w.pct",
    "wr.1500kPa_usda.a417_w.pct",
    "awc.33.1500kPa_usda.c80_w.frac",
    "fe.ox_usda.a60_w.pct",
    "al.ox_usda.a59_w.pct",
    "fe.dith_usda.a66_w.pct",
    "al.dith_usda.a65_w.pct",
    "p.ext_usda.a1070_mg.kg",
    "k.ext_usda.a1065_mg.kg",
    "mg.ext_usda.a1066_mg.kg",
    "ca.ext_usda.a1059_mg.kg",
    "na.ext_usda.a1068_mg.kg",
]


def _collect_union_test_indices(
    df: pd.DataFrame,
    target_cols: List[str],
    test_size: float,
    random_state: int,
    min_rows_required: int,
) -> List[int]:
    """
    Return sorted positional row indices of `df` that are in ANY target's test split.

    For each target present in `df` with at least `min_rows_required`
    non-NaN values, the rows where that target is non-NaN are split with
    `train_test_split(test_size=..., random_state=...)` and the test
    positions are mapped back to positions in `df`. Targets that are absent
    or too sparse are reported and skipped.

    Raises
    ------
    ValueError
        If a target column cannot be converted to float.
    """
    all_test_indices: Set[int] = set()

    for target in target_cols:
        if target not in df.columns:
            print(f"[SKIP] Target {target} not found in dataframe.")
            continue

        try:
            y = df[target].to_numpy(dtype=float)
        except ValueError as err:
            # Name the offending column instead of surfacing a bare cast error.
            raise ValueError(
                f"Target column {target!r} cannot be converted to float."
            ) from err

        mask = ~np.isnan(y)
        n_available = int(mask.sum())
        print(f"Target {target}: {n_available} non-missing rows.")

        if n_available < min_rows_required:
            print(f"  [SKIP] Only {n_available} (< {min_rows_required}) rows.")
            continue

        # Positions in `df` where this target is non-missing, and the
        # 0..n_available-1 positions within that filtered subset.
        row_idx_all = np.where(mask)[0]
        idx_all = np.arange(row_idx_all.shape[0])

        # NOTE(review): this reproduces training's split only if training
        # filtered rows the same way (non-NaN target only), used the same
        # test_size/random_state, and did not stratify — confirm against
        # `train_per_target_models`.
        _, idx_test = train_test_split(
            idx_all, test_size=test_size, random_state=random_state
        )

        # Map subset positions back to `df` positions; store as plain ints.
        all_test_indices.update(row_idx_all[idx_test].tolist())

    return sorted(all_test_indices)


def make_union_test_csv(
    csv_path: str,
    target_cols: List[str] = TARGET_COLS,
    test_size: float = 0.2,
    random_state: int = 42,
    min_rows_required: int = 200,
    output_path: str = "test_union.csv",
) -> None:
    """
    Reconstructs the per-target train/test split used in `train_per_target_models`
    (using the same test_size and random_state), then takes the UNION of all
    test-row indices across targets and writes those rows to a single CSV.

    Parameters
    ----------
    csv_path : str
        Path to the cleaned OSSL-style CSV used in training.
    target_cols : list of str
        Targets you trained models for. Columns missing from the CSV are skipped.
    test_size : float
        Same test_size as used in training.
    random_state : int
        Same random_state as used in training (42).
    min_rows_required : int
        Only reconstruct splits for targets that had enough rows in training.
    output_path : str
        Where to save the resulting test CSV. Missing parent directories are
        created.

    Raises
    ------
    RuntimeError
        If no target contributed any test rows.
    ValueError
        If a target column cannot be converted to float.
    """
    df = pd.read_csv(csv_path, low_memory=False)
    print(f"Loaded dataset: {df.shape} (rows, columns)")

    test_indices_sorted = _collect_union_test_indices(
        df, target_cols, test_size, random_state, min_rows_required
    )
    if not test_indices_sorted:
        raise RuntimeError("No test indices collected (check target_cols/min_rows_required).")

    # Indices are positions (from np.where), so select with iloc, not loc.
    # iloc with a list already returns a new frame, so no extra .copy().
    df_test = df.iloc[test_indices_sorted]
    print(f"Total unique test rows (union across targets): {len(df_test)}")

    # DataFrame.to_csv does not create missing directories.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    df_test.to_csv(output_path, index=False)
    print(f"Saved union test CSV to: {output_path}")

# Build the union test CSV from the cleaned dataset used for training.
cleaned_dataset_path = r"C:\Users\SAADB\Desktop\Python_Code\Google_Hackathon\ossl_cleaned.csv"
test_set_path = r"C:\Users\SAADB\Desktop\Python_Code\Google_Hackathon\Model_and_Results\ossl_test_union.csv"
make_union_test_csv(
    csv_path=cleaned_dataset_path,
    output_path=test_set_path,
)

Loaded dataset: (135651, 2837) (rows, columns)
Target oc_usda.c729_w.pct: 132490 non-missing rows.
Target c.tot_usda.a622_w.pct: 90488 non-missing rows.
Target n.tot_usda.a623_w.pct: 131247 non-missing rows.
Target ph.h2o_usda.a268_index: 100174 non-missing rows.
Target ph.cacl2_usda.a481_index: 93994 non-missing rows.
Target cec_usda.a723_cmolc.kg: 76046 non-missing rows.
Target ec_usda.a364_ds.m: 55822 non-missing rows.
Target clay.tot_usda.a334_w.pct: 79922 non-missing rows.
Target sand.tot_usda.c60_w.pct: 79816 non-missing rows.
Target silt.tot_usda.c62_w.pct: 79869 non-missing rows.
Target bd_usda.a4_g.cm3: 51265 non-missing rows.
Target wr.10kPa_usda.a414_w.pct: 4110 non-missing rows.
Target wr.33kPa_usda.a415_w.pct: 19459 non-missing rows.
Target wr.1500kPa_usda.a417_w.pct: 41345 non-missing rows.
Target awc.33.1500kPa_usda.c80_w.frac: 16175 non-missing rows.
Target fe.ox_usda.a60_w.pct: 28259 non-missing rows.
Target al.ox_usda.a59_w.pct: 28260 non-missing rows.
Target fe.dith_