# User Story 14 / 15
@LuiseJedlitschka

**Cross-Validation Strategy: Leave-One-Group-Out**

To evaluate the generalization capability of our models and to detect overfitting, we employed the leave-one-group-out cross-validation strategy as implemented in scikit-learn. In this approach, each dataset — corresponding to a specific phage — is used once as the test set (singleton), while the remaining datasets collectively form the training set. This ensures that, in each split, the model is validated on data from a phage that was not seen during training, providing a robust assessment of performance across different biological backgrounds.

Note:
- The data was not explicitly stratified according to the classification classes ("early", "middle", "late") during the splitting process as that is not part of the leave-one-group-out strategy. As a result, the distribution of these classes may vary between the training and test sets in each split.
- Due to its very unusual class distribution, the sprenger dataset is left out entirely for now.

An overview of the class distribution in the training and test sets for each split is provided in
data/leave-one-group-out-split/logo_class_distributions.tsv.

All corresponding training and test files for each split are saved in
data/leave-one-group-out-split/splits/.

In [None]:
import os
import glob
import pandas as pd
from sklearn.model_selection import LeaveOneGroupOut

# Directory with the TSV feature tables (one file per dataset)
directory = "../data/feature_tables"
# Output directory for the combined table, the split files and the overview
output_dir = "../data/leave-one-group-out-split"
os.makedirs(output_dir, exist_ok=True)

# List of all .tsv files in the directory. glob returns them in arbitrary
# order, so sort to keep the row order of the written tables reproducible.
tsv_files = sorted(glob.glob(os.path.join(directory, "*.tsv")))
if not tsv_files:
    # pd.concat would fail with a far less helpful message on an empty list
    raise FileNotFoundError(f"No .tsv files found in {directory!r}")

# Combine all TSV files into one DataFrame; the file name of each dataset is
# stored in the "group" column and defines the cross-validation groups.
df = pd.concat(
    [pd.read_csv(f, sep="\t").assign(group=os.path.basename(f)) for f in tsv_files],
    ignore_index=True,
)

# Save combined table containing the features and the group of each gene
output_path = os.path.join(output_dir, "combined.tsv")
df.to_csv(output_path, sep="\t", index=False)

# Prepare subfolder for splits
splits_dir = os.path.join(output_dir, "splits")
os.makedirs(splits_dir, exist_ok=True)

# One split per group: every dataset is used exactly once as the test set
logo = LeaveOneGroupOut()
results = []

# Classes of the whole data set; identical for every split, so computed once
all_classes = sorted(df["classification_x"].unique())

for i, (train_idx, test_idx) in enumerate(logo.split(df, groups=df["group"])):
    train_df = df.iloc[train_idx]
    test_df = df.iloc[test_idx]

    # Save train and test data for this split
    train_path = os.path.join(splits_dir, f"train_split_{i}.tsv")
    test_path = os.path.join(splits_dir, f"test_split_{i}.tsv")
    train_df.to_csv(train_path, sep="\t", index=False)
    test_df.to_csv(test_path, sep="\t", index=False)

    # Train and evaluate a model on this split (prints its report)
    test_model(train_df, test_df)

    # Relative and absolute class distribution in the train and test set,
    # computed once per split instead of once per class
    train_ratios = train_df["classification_x"].value_counts(normalize=True)
    test_ratios = test_df["classification_x"].value_counts(normalize=True)
    train_counts = train_df["classification_x"].value_counts()
    test_counts = test_df["classification_x"].value_counts()

    # Check for gene ids that occur in both the train and the test set
    overlapping_genes = set(train_df["Geneid"]) & set(test_df["Geneid"])
    if overlapping_genes:
        print(f"Split {i}: {len(overlapping_genes)} overlapping genes found!")
    else:
        print(f"Split {i}: No overlapping genes.")

    # The test set holds exactly one group, so its first row names it
    # (the value is the file name of the left-out dataset)
    group_left_out = test_df["group"].iloc[0]

    for cls in all_classes:
        results.append(
            {
                "split": i,
                "group_left_out": group_left_out,
                "class": cls,
                # classes absent from a subset are reported as 0
                "train_ratio": train_ratios.get(cls, 0),
                "test_ratio": test_ratios.get(cls, 0),
                "train_count": train_counts.get(cls, 0),
                "test_count": test_counts.get(cls, 0),
            }
        )

# Convert results to DataFrame (one row per split and class)
split_summary = pd.DataFrame(results)

# Save overview of each split
overview_path = os.path.join(output_dir, "logo_class_distributions.tsv")
split_summary.to_csv(overview_path, sep="\t", index=False)


X_train: (1164, 85); y_train: (1164,)
X_train: (54, 85); y_train: (54,)
              precision    recall  f1-score   support

       early       0.25      0.33      0.29         3
        late       0.74      0.78      0.76        32
      middle       0.62      0.53      0.57        19

    accuracy                           0.67        54
   macro avg       0.54      0.55      0.54        54
weighted avg       0.67      0.67      0.67        54

Split 0: No overlapping genes.
X_train: (981, 85); y_train: (981,)
X_train: (237, 85); y_train: (237,)
              precision    recall  f1-score   support

       early       0.38      0.26      0.31        78
        late       0.29      0.41      0.34        59
      middle       0.39      0.40      0.39       100

    accuracy                           0.35       237
   macro avg       0.36      0.35      0.35       237
weighted avg       0.36      0.35      0.35       237

Split 1: No overlapping genes.
X_train: (930, 85); y_train: (93

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


X_train: (1164, 85); y_train: (1164,)
X_train: (54, 85); y_train: (54,)
              precision    recall  f1-score   support

       early       0.30      0.67      0.41         9
        late       0.58      0.65      0.61        23
      middle       0.50      0.18      0.27        22

    accuracy                           0.46        54
   macro avg       0.46      0.50      0.43        54
weighted avg       0.50      0.46      0.44        54

Split 6: No overlapping genes.


| Metric        | Meaning                                                                                                                                                |
| ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
| **precision** | Out of all predicted instances of a class, how many were actually correct?<br>**Formula**: TP / (TP + FP)                                              |
| **recall**    | Out of all actual instances of a class, how many did the model correctly detect?<br>**Formula**: TP / (TP + FN)                                        |
| **f1-score**  | Harmonic mean of precision and recall.<br>Good single metric for imbalanced classes.<br>**Formula**: 2 \* (precision \* recall) / (precision + recall) |
| **support**   | Number of true examples of each class in the dataset.   |
| **accuracy**     | Overall: (correct predictions) / (total samples)                                                                                 |
| **macro avg**    | Unweighted average across all classes.<br>Each class contributes equally (good for comparing classes).                           |
| **weighted avg** | Weighted average, where each class's contribution is proportional to its support.<br>More realistic when classes are imbalanced. |

In [44]:
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report


def test_model(train_df, test_df):
    """Fit a decision tree on one split and print its classification report.

    Parameters
    ----------
    train_df : pandas.DataFrame
        Training rows of the split. Must contain the label column
        ``classification_x``; every column not listed in
        ``non_feature_cols`` below is used as a feature.
    test_df : pandas.DataFrame
        Held-out rows of the split, with the same columns as ``train_df``.

    Returns
    -------
    None
        The shapes of the feature/label sets and the classification report
        are printed.
    """
    label_col = "classification_x"
    # Columns that are not model features (label, group, identifier and
    # sequence columns). "Unnamed: 0" is presumably a leftover index column
    # from an earlier export -- confirm; it is skipped if absent.
    non_feature_cols = {"Unnamed: 0", "Geneid", "DNASequence", label_col, "group"}

    # Derive the feature columns from the argument itself rather than from
    # notebook globals, so the function works for any train/test pair.
    feature_cols = [col for col in train_df.columns if col not in non_feature_cols]

    X_train = train_df[feature_cols]
    y_train = train_df[label_col]

    X_val = test_df[feature_cols]
    y_val = test_df[label_col]

    print(f"X_train: {X_train.shape}; y_train: {y_train.shape}")
    print(f"X_val: {X_val.shape}; y_val: {y_val.shape}")

    # Fixed seed so repeated runs give the same tree
    clf = DecisionTreeClassifier(random_state=42)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_val)

    # Report every class seen during training, even if the held-out group
    # lacks one of them; passing the labels explicitly keeps the report rows
    # aligned with the class names. zero_division=0 reports undefined metrics
    # as 0 (the default value) without emitting a warning per metric.
    class_labels = sorted(y_train.unique())
    print(classification_report(y_val, y_pred, labels=class_labels, zero_division=0))