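**Goal:**
1) Load each tree and extract its split (internal-node) times and sampling (tip) times.
2) Truncate or zero-pad these vectors to a fixed length, so that every tree yields a feature vector of the same size for a feed-forward neural network (FNN).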


In [8]:
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns
from pathlib import Path
import dendropy
from tqdm import tqdm


In [9]:
# Project root — assumes the notebook is run from a sub-directory (e.g. notebooks/) of it.
BASE = Path.cwd().parent
# CSV of per-tree parameters (must contain a tree_id column).
PARAMS = BASE.joinpath("data", "processed", "preprocessed_parameters.csv")
# Directory holding the tree_<id>.nwk files.
TREES_DIR = BASE.joinpath("data", "raw")
# Number of tip slots in each fixed-length feature vector.
MAX_TIPS = 1_000

In [10]:
def pad_or_truncate(arr, length, pad_value=0.0):
    """Return a float32 copy of ``arr`` forced to exactly ``length`` elements.

    Inputs longer than ``length`` keep their first ``length`` elements.
    Shorter inputs are right-padded with ``pad_value``. The result never shares
    memory with ``arr``.

    Parameters
    ----------
    arr : array-like
        1-D sequence of numbers (NumPy array, list, tuple, ...).
    length : int
        Target length; must be non-negative.
    pad_value : float, optional
        Fill value for padding. The default is 0.0, the historical behavior.

    Returns
    -------
    numpy.ndarray
        float32 array; for 1-D input its shape is ``(length,)``.

    Raises
    ------
    ValueError
        If ``length`` is negative. Previously a negative length silently
        returned ``arr[:length]``, dropping trailing elements.
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")

    # np.array (unlike np.asarray) always copies, so a truncated result is a
    # view of a fresh buffer, never of the caller's array.
    values = np.array(arr, dtype=np.float32)
    if values.size >= length:
        return values[:length]

    padded = np.full(length, pad_value, dtype=np.float32)
    padded[:values.size] = values
    return padded


In [11]:
def tree_to_split_and_sampling_times(tree_path: Path, schema: str = "newick",
                                     include_root: bool = False):
    """Extract sorted split (internal-node) and sampling (tip) times from a tree.

    Times are measured backwards from the tip farthest from the root. That tip
    gets time 0, and larger values lie further in the past.

    Parameters
    ----------
    tree_path : Path
        Path to the tree file.
    schema : str, optional
        DendroPy schema used to parse the file (default ``"newick"``).
    include_root : bool, optional
        If False (the default and the original behavior), the root node is left
        out of the split times. If True, the root's split time (``height`` minus
        the root's own root distance) is included as well.

    Returns
    -------
    split_times : numpy.ndarray
        float32, ascending times of the internal nodes kept.
    sampling_times : numpy.ndarray
        float32, ascending times of all leaves.
    height : float
        Largest root-to-leaf distance (0.0 if no leaves were found).
    """
    tree = dendropy.Tree.get(path=str(tree_path), schema=schema)
    # Sets ``root_distance`` on each node; it is read below for leaves and
    # internal nodes.
    tree.calc_node_root_distances()

    tip_dists = [leaf.root_distance for leaf in tree.leaf_node_iter()]
    height = float(max(tip_dists)) if tip_dists else 0.0

    sampling_times = np.array(sorted(height - d for d in tip_dists), dtype=np.float32)

    # NOTE(review): with include_root=False a fully binary tree with n tips
    # yields n - 2 split times, while callers may size the split vector for
    # n - 1 (e.g. MAX_TIPS - 1) — confirm which is intended.
    kept_nodes = [
        node for node in tree.internal_nodes()
        if include_root or node.parent_node is not None
    ]
    split_times = np.array(
        sorted(height - node.root_distance for node in kept_nodes),
        dtype=np.float32,
    )

    return split_times, sampling_times, height

In [13]:
def _tree_features(tree_path: Path):
    """Return ``(features, truncated)`` for one tree file.

    ``features`` is a float32 vector with this layout:

    - ``MAX_TIPS - 1`` split-time slots;
    - ``MAX_TIPS`` sampling-time slots (both sets fitted by ``pad_or_truncate``);
    - ``[n_tips, height]``.

    ``truncated`` is True when either time vector had more values than its
    slots. Because the times are sorted ascending, truncation keeps the
    smallest (most recent) ones.
    """
    split_t, samp_t, height = tree_to_split_and_sampling_times(tree_path)
    n_split_slots, n_samp_slots = MAX_TIPS - 1, MAX_TIPS

    split_fixed = pad_or_truncate(split_t, n_split_slots)
    samp_fixed = pad_or_truncate(samp_t, n_samp_slots)

    # n_tips is counted before truncation, so the true tree size survives.
    n_tips = len(samp_t)
    summary = np.array([n_tips, height], np.float32)

    features = np.concatenate([split_fixed, samp_fixed, summary])
    truncated = len(split_t) > n_split_slots or len(samp_t) > n_samp_slots
    return features, truncated


def _row_targets(row, target_cols):
    """Return the listed columns of one ``itertuples`` row as a float32 vector."""
    return np.array([float(getattr(row, col)) for col in target_cols], dtype=np.float32)


def main():
    """Build and save the FNN feature/target arrays for every row of ``PARAMS``.

    For each row, this function loads ``TREES_DIR / f"tree_{tree_id}.nwk"`` and
    builds one feature row (see ``_tree_features``). It then writes two files
    to ``BASE / "data" / "processed"``:

    - ``X.npy`` holds the features;
    - ``y.npy`` holds the targets, with columns ``lambda1, mu, psi, lambda2``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``X`` with one row per tree, and ``y`` with shape ``(n_rows, 4)``.

    Raises
    ------
    ValueError
        If required columns are missing or the CSV has no rows.
    FileNotFoundError
        If a referenced tree file does not exist.
    """
    import warnings

    df = pd.read_csv(PARAMS)
    # itertuples() renames columns that are not valid identifiers to positional
    # names; ``lambda`` is a Python keyword, so it must be renamed first.
    df = df.rename(columns={"lambda": "lambda1"})

    target_cols = ("lambda1", "mu", "psi", "lambda2")
    missing = [col for col in ("tree_id", *target_cols) if col not in df.columns]
    if missing:
        raise ValueError(f"{PARAMS} is missing required columns: {missing}")
    if df.empty:
        raise ValueError(f"No parameter rows found in {PARAMS}")

    X_list, y_list = [], []
    n_truncated = 0
    for row in tqdm(df.itertuples(index=False), total=len(df)):
        tree_path = TREES_DIR / f"tree_{int(row.tree_id)}.nwk"
        if not tree_path.exists():
            raise FileNotFoundError(f"Missing tree file: {tree_path}")

        features, truncated = _tree_features(tree_path)
        n_truncated += truncated
        X_list.append(features)
        y_list.append(_row_targets(row, target_cols))

    if n_truncated:
        # Truncation loses information, so surface it instead of doing it silently.
        warnings.warn(
            f"{n_truncated} tree(s) had more times than the MAX_TIPS={MAX_TIPS} "
            "slots allow; the extra values were dropped.",
            stacklevel=2,
        )

    X = np.stack(X_list)
    y = np.stack(y_list)

    out_dir = BASE / "data" / "processed"
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / "X.npy", X)
    np.save(out_dir / "y.npy", y)
    return X, y


In [14]:
# Inside a Jupyter kernel __name__ is "__main__", so this guard runs main()
# when the cell executes; it also keeps the code runnable as an exported script.
if __name__ == "__main__":
    main()

100%|██████████| 100/100 [00:03<00:00, 28.97it/s]
