In [3]:
# %%
# Cell 1: Imports and cross-modality training function
import os
import wandb
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score

def train_cross_modality(
    train_modality: str = "no_audio",
    test_modality: str = "only_audio",
    duration_s: int = 10,
    cfg: dict = None
):
    """
    Train a RandomForest on one modality and evaluate on another.

    Each call opens its own W&B run in the "rf-segment-classification"
    project, fits the classifier on the training CSV, scores it on the test
    CSV using only the feature columns both files share, logs the metrics and
    the ten largest feature importances, and closes the run.

    Args:
        train_modality: feature set for training (e.g., 'no_audio', 'with_audio', 'only_audio')
        test_modality: feature set for testing
        duration_s: window duration in seconds (part of the CSV file name)
        cfg: dict of hyperparameters (keys: n_estimators, max_depth, random_state);
            any key given here overrides the built-in default
    Returns:
        dict: test metrics for local collection (modalities, duration,
        accuracy, macro F1, and precision/recall/F1 of class "1")
    Raises:
        ValueError: if the two CSVs have no feature column in common.
    """
    # Default hyperparameters, overridden key by key by the caller's cfg.
    params = {"n_estimators": 100, "max_depth": 10, "random_state": 42}
    if cfg:
        params.update(cfg)

    # Build file paths
    # NOTE(review): when train_modality == test_modality both paths resolve to
    # the same CSV, so the model is scored on the rows it was fitted on. Those
    # "diagonal" results are in-sample, not held-out estimates.
    base_dir = "./features/fixed_really/balanced"
    train_file = os.path.join(base_dir, train_modality, f"features_{train_modality}_{duration_s}s_balanced.csv")
    test_file  = os.path.join(base_dir, test_modality,  f"features_{test_modality}_{duration_s}s_balanced.csv")

    # Metadata / bookkeeping columns that must never be used as model inputs.
    drop = ["video_id","segment","participant","start_time","end_time","label",
            "padded_duration_s","original_duration","real_sample_count",
            "expected_sample_count","padding_applied"]

    # Using the run as a context manager guarantees it is finished even when
    # loading, fitting or scoring raises, so a failed call does not leave an
    # open run behind in the kernel.
    with wandb.init(
        project="rf-segment-classification",
        job_type=f"train_{train_modality}_test_{test_modality}",
        config=params
    ) as run:
        config = run.config

        # Load and align data: keep only features present in both files, in a
        # deterministic (sorted) column order.
        df_tr = pd.read_csv(train_file)
        df_te = pd.read_csv(test_file)
        feat_tr = set(df_tr.columns) - set(drop)
        feat_te = set(df_te.columns) - set(drop)
        common = sorted(feat_tr & feat_te)
        if not common:
            raise ValueError(
                f"No shared feature columns between {train_file} and {test_file}"
            )

        X_tr, y_tr = df_tr[common].fillna(0), df_tr["label"]
        X_te, y_te = df_te[common].fillna(0), df_te["label"]

        # Train and predict
        clf = RandomForestClassifier(
            n_estimators=config.n_estimators,
            max_depth=config.max_depth,
            class_weight="balanced",
            random_state=config.random_state,
            n_jobs=-1
        )
        clf.fit(X_tr, y_tr)
        preds = clf.predict(X_te)

        report = classification_report(y_te, preds, output_dict=True, zero_division=0)
        # Same zero_division policy as the report above, so an unpredicted
        # class scores 0 without emitting a warning.
        f1_mac = f1_score(y_te, preds, average="macro", zero_division=0)

        metrics = {
            "test/accuracy": report["accuracy"],
            "test/f1_macro": f1_mac,
            **{f"test/{cls}/{m}": report[cls][m]
               for cls in ["0","1"] for m in ["precision","recall","f1-score"]}
        }

        # Ten largest importances, keyed by feature name.
        imps = pd.Series(clf.feature_importances_, index=common).nlargest(10)
        importances = {f"{feat}": float(val) for feat, val in imps.items()}

        # Log to W&B in a single call: one history step per run instead of
        # one extra step per feature.
        run.log({**metrics, **importances})

    return {
        "train_modality": train_modality,
        "test_modality": test_modality,
        "duration_s": duration_s,
        "accuracy": report["accuracy"],
        "f1_macro": f1_mac,
        "precision_1": report["1"]["precision"],
        "recall_1": report["1"]["recall"],
        "f1_1": report["1"]["f1-score"]
    }

# %%
# Cell 2: Run all cross-modality tests
import itertools

def run_all_cross_modality_tests(duration_s=10, cfg=None):
    """
    Run train_cross_modality for every (train, test) modality pair.

    Args:
        duration_s: window duration in seconds, forwarded to each run
        cfg: optional hyperparameter dict, forwarded to each run
    Returns:
        pd.DataFrame: one row per pair, built from the dicts returned by
        train_cross_modality
    """
    modalities = ["with_audio", "no_audio", "only_audio"]
    rows = []

    # product(..., repeat=2) yields the pairs in nested-loop order, with the
    # training modality varying slowest.
    for source_mod, target_mod in itertools.product(modalities, repeat=2):
        print(f"▶ Training on {source_mod}, Testing on {target_mod}")
        rows.append(
            train_cross_modality(
                train_modality=source_mod,
                test_modality=target_mod,
                duration_s=duration_s,
                cfg=cfg,
            )
        )

    return pd.DataFrame(rows)



In [4]:
# %%
# Cell 3: Run & display results
# Authenticate with W&B before the sweep of runs starts.
wandb.login()
df_results = run_all_cross_modality_tests(duration_s=10)
# End the cell on the frame itself rather than print(): the notebook then
# renders it as a single table instead of plain text that wraps the wide
# columns into separate chunks.
df_results


▶ Training on with_audio, Testing on with_audio


0,1
SMA,▁
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_var,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
SMA,0.06868
accelX_filtered_deriv_std,0.10279
accelX_filtered_energy,0.04593
accelX_filtered_var,0.05146
accelY_filtered_deriv_std,0.05379
accelZ_filtered_deriv_std,0.09164
accelZ_filtered_var,0.05117
corr_xz,0.09022
corr_yz,0.06824
test/0/f1-score,0.98506


▶ Training on with_audio, Testing on no_audio


0,1
SMA,▁
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_var,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
SMA,0.06868
accelX_filtered_deriv_std,0.10279
accelX_filtered_energy,0.04593
accelX_filtered_var,0.05146
accelY_filtered_deriv_std,0.05379
accelZ_filtered_deriv_std,0.09164
accelZ_filtered_var,0.05117
corr_xz,0.09022
corr_yz,0.06824
test/0/f1-score,0.94534


▶ Training on with_audio, Testing on only_audio


0,1
SMA,▁
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_var,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
SMA,0.06868
accelX_filtered_deriv_std,0.10279
accelX_filtered_energy,0.04593
accelX_filtered_var,0.05146
accelY_filtered_deriv_std,0.05379
accelZ_filtered_deriv_std,0.09164
accelZ_filtered_var,0.05117
corr_xz,0.09022
corr_yz,0.06824
test/0/f1-score,0.6691


▶ Training on no_audio, Testing on with_audio


0,1
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_energy,▁
corr_xy,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
accelX_filtered_deriv_std,0.11423
accelX_filtered_energy,0.07252
accelX_filtered_var,0.07695
accelY_filtered_deriv_std,0.06991
accelZ_filtered_deriv_std,0.07518
accelZ_filtered_energy,0.04668
corr_xy,0.04928
corr_xz,0.06677
corr_yz,0.09609
test/0/f1-score,0.93581


▶ Training on no_audio, Testing on no_audio


0,1
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_energy,▁
corr_xy,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
accelX_filtered_deriv_std,0.11423
accelX_filtered_energy,0.07252
accelX_filtered_var,0.07695
accelY_filtered_deriv_std,0.06991
accelZ_filtered_deriv_std,0.07518
accelZ_filtered_energy,0.04668
corr_xy,0.04928
corr_xz,0.06677
corr_yz,0.09609
test/0/f1-score,0.96305


▶ Training on no_audio, Testing on only_audio


0,1
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_energy,▁
corr_xy,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
accelX_filtered_deriv_std,0.11423
accelX_filtered_energy,0.07252
accelX_filtered_var,0.07695
accelY_filtered_deriv_std,0.06991
accelZ_filtered_deriv_std,0.07518
accelZ_filtered_energy,0.04668
corr_xy,0.04928
corr_xz,0.06677
corr_yz,0.09609
test/0/f1-score,0.66203


▶ Training on only_audio, Testing on with_audio


0,1
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_energy,▁
accelZ_filtered_var,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
accelX_filtered_deriv_std,0.09344
accelX_filtered_energy,0.04672
accelX_filtered_var,0.04937
accelY_filtered_deriv_std,0.0668
accelZ_filtered_deriv_std,0.07305
accelZ_filtered_energy,0.0548
accelZ_filtered_var,0.06005
corr_xz,0.05375
corr_yz,0.07016
test/0/f1-score,0.81166


▶ Training on only_audio, Testing on no_audio


0,1
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_energy,▁
accelZ_filtered_var,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
accelX_filtered_deriv_std,0.09344
accelX_filtered_energy,0.04672
accelX_filtered_var,0.04937
accelY_filtered_deriv_std,0.0668
accelZ_filtered_deriv_std,0.07305
accelZ_filtered_energy,0.0548
accelZ_filtered_var,0.06005
corr_xz,0.05375
corr_yz,0.07016
test/0/f1-score,0.80251


▶ Training on only_audio, Testing on only_audio


0,1
accelX_filtered_deriv_std,▁
accelX_filtered_energy,▁
accelX_filtered_var,▁
accelY_filtered_deriv_std,▁
accelZ_filtered_deriv_std,▁
accelZ_filtered_energy,▁
accelZ_filtered_var,▁
corr_xz,▁
corr_yz,▁
test/0/f1-score,▁

0,1
accelX_filtered_deriv_std,0.09344
accelX_filtered_energy,0.04672
accelX_filtered_var,0.04937
accelY_filtered_deriv_std,0.0668
accelZ_filtered_deriv_std,0.07305
accelZ_filtered_energy,0.0548
accelZ_filtered_var,0.06005
corr_xz,0.05375
corr_yz,0.07016
test/0/f1-score,0.70388


  train_modality test_modality  duration_s  accuracy  f1_macro  precision_1  \
0     with_audio    with_audio          10  0.973498  0.934024     0.790496   
1     with_audio      no_audio          10  0.899048  0.642835     0.491015   
2     with_audio    only_audio          10  0.518558  0.392897     0.706004   
3       no_audio    with_audio          10  0.885342  0.699891     0.435645   
4       no_audio      no_audio          10  0.935759  0.858741     0.610633   
5       no_audio    only_audio          10  0.520429  0.418313     0.626087   
6     only_audio    with_audio          10  0.703411  0.557102     0.197781   
7     only_audio      no_audio          10  0.686191  0.519539     0.156313   
8     only_audio    only_audio          10  0.711234  0.711056     0.701244   

   recall_1      f1_1  
0  1.000000  0.882991  
1  0.260411  0.340329  
2  0.063602  0.116691  
3  0.496252  0.463978  
4  0.986837  0.754437  
5  0.101441  0.174593  
6  0.643296  0.302544  
7  0.486226  0.23