In [None]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

def test_model(train_path, test_path):
    """Train a decision tree on one split and print its evaluation report.

    Both files are read as tab-separated tables. Every column of the training
    table that is not one of the identifier/label columns excluded below is
    used as a feature; ``classification_x`` is the target.

    Args:
        train_path: Path to the training TSV file.
        test_path: Path to the evaluation TSV file.

    Returns:
        None. The data shapes and the classification report are printed.
    """
    train_df = pd.read_csv(train_path, sep="\t")
    test_df = pd.read_csv(test_path, sep="\t")

    # The encoder is only used to obtain the sorted set of class names
    # present in the training data; the classifier is fit on the raw labels.
    le = LabelEncoder()
    le.fit(train_df["classification_x"])

    # Detect features: everything except identifiers, raw sequence and labels.
    non_feature_cols = {
        "Unnamed: 0", "Geneid", "DNASequence", "classification_x", "group",
    }
    feature_cols = [
        col for col in train_df.columns if col not in non_feature_cols
    ]

    X_train = train_df[feature_cols]
    y_train = train_df["classification_x"]

    X_val = test_df[feature_cols]
    y_val = test_df["classification_x"]

    print(f"Train: {train_path} | Test: {test_path}")
    print(f"X_train: {X_train.shape}; y_train: {y_train.shape}")
    print(f"X_val: {X_val.shape}; y_val: {y_val.shape}")

    clf = DecisionTreeClassifier(random_state=42)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_val)
    # Pass labels explicitly so they always pair up with target_names.
    # Without it, the report infers the label set from y_val/y_pred and
    # raises ValueError when a held-out split is missing a training class.
    print(
        classification_report(
            y_val,
            y_pred,
            labels=le.classes_,
            target_names=le.classes_,
        )
    )
    print("-" * 60)

# Every (train, test) file pair to evaluate: the stratified split first,
# followed by the leave-one-group-out folds 0 through 5.
splits = [
    (
        "../data/combined-data-stratified-split/train_data.tsv",
        "../data/combined-data-stratified-split/test_data.tsv",
    ),
    *(
        (
            f"../data/leave-one-group-out-split/splits/train_split_{fold}.tsv",
            f"../data/leave-one-group-out-split/splits/test_split_{fold}.tsv",
        )
        for fold in range(6)
    ),
]

# Evaluate each split in turn; every call prints its own report.
for train_path, test_path in splits:
    test_model(train_path=train_path, test_path=test_path)


Train: ../data/combined-data-stratified-split/train_data.tsv | Test: ../data/combined-data-stratified-split/test_data.tsv
X_train: (784, 85); y_train: (784,)
X_val: (197, 85); y_val: (197,)
              precision    recall  f1-score   support

       early       0.43      0.50      0.46        46
        late       0.69      0.62      0.66        93
      middle       0.55      0.57      0.56        58

    accuracy                           0.58       197
   macro avg       0.56      0.56      0.56       197
weighted avg       0.59      0.58      0.58       197

------------------------------------------------------------
Train: ../data/leave-one-group-out-split/splits/train_split_0.tsv | Test: ../data/leave-one-group-out-split/splits/test_split_0.tsv
X_train: (863, 85); y_train: (863,)
X_val: (54, 85); y_val: (54,)
              precision    recall  f1-score   support

       early       0.05      0.33      0.08         3
        late       0.93      0.41      0.57        32
      m