In [None]:
import config
import numpy as np
import Test_without_Training

def get_data(K1, K2, subj1, subj2, seed=None):
    """Load per-subject splits and build a shuffled mixed training set.

    One ``TremorModelTrainer`` is created per subject and asked for its data
    via ``return_K_data``. The two training sets are then concatenated and
    shuffled jointly (samples and labels with the same permutation).

    Args:
        K1: Value passed as ``K`` to ``return_K_data`` for ``subj1``.
        K2: Value passed as ``K`` to ``return_K_data`` for ``subj2``.
        subj1: Subject identifier for the first trainer.
        subj2: Subject identifier for the second trainer.
        seed: Controls the shuffle of the combined training set. ``None``
            (default) keeps the original behaviour of drawing from NumPy's
            global RNG. An int or ``np.random.Generator`` gives a
            reproducible shuffle without touching global RNG state.

    Returns:
        dict with keys:
            ``trainer``: the subj2 trainer. Callers use it only to call
                ``train_multiple_dataset``. This presumes that method does
                not depend on subject-specific trainer state. TODO confirm.
            ``X1_train``/``y1_train``, ``X2_train``/``y2_train``: the
                unshuffled per-subject training splits.
            ``X_train``/``y_train``: the shuffled union of both subjects.
            ``X1_test``/``y1_test``, ``X2_test``/``y2_test``: the
                per-subject test splits.

    Raises:
        ValueError: if the combined samples and labels differ in length.
            A joint permutation would otherwise silently misalign them.
    """
    trainer1 = Test_without_Training.TremorModelTrainer(config, subject=subj1)
    X1_train, y1_train, X1_test, y1_test = trainer1.return_K_data(K=K1)

    trainer2 = Test_without_Training.TremorModelTrainer(config, subject=subj2)
    X2_train, y2_train, X2_test, y2_test = trainer2.return_K_data(K=K2)

    # Combine training data from both subjects along the sample axis.
    X_train = np.concatenate([X1_train, X2_train], axis=0)
    y_train = np.concatenate([y1_train, y2_train], axis=0)
    if len(X_train) != len(y_train):
        raise ValueError(
            f"combined training samples ({len(X_train)}) and labels "
            f"({len(y_train)}) differ in length"
        )

    # Shuffle samples and labels with the same index array so pairs stay aligned.
    if seed is None:
        order = np.random.permutation(len(X_train))  # original global-RNG path
    else:
        order = np.random.default_rng(seed).permutation(len(X_train))
    X_train, y_train = X_train[order], y_train[order]

    return {
        "trainer": trainer2,  # just pick one to call train_multiple_dataset
        "X1_train": X1_train, "y1_train": y1_train,
        "X2_train": X2_train, "y2_train": y2_train,
        "X_train": X_train,   "y_train": y_train,
        "X1_test": X1_test,   "y1_test": y1_test,
        "X2_test": X2_test,   "y2_test": y2_test,
    }

def get_results(K1, K2, subj1, subj2):
    """Train three models and score each on both subjects' test splits.

    The models are trained, in this order, on:
      1. the shuffled mix of both subjects,
      2. subject 2 only,
      3. subject 1 only.
    Each is trained via ``trainer.train_multiple_dataset`` with subject 1's
    test split as the validation/score set. Each is then scored again on
    subject 2's test split with ``model.evaluate``.

    Returns:
        A pair of lists, each ordered [mixed, subject-2-only, subject-1-only]:
          * the first element returned by ``train_multiple_dataset``
            (subject 1 test split);
          * the second element returned by ``model.evaluate``
            (subject 2 test split).
        The two lists may use different units. ``print_results`` prints the
        first as-is and multiplies the second by 100. Verify against
        ``train_multiple_dataset``.
    """
    bundle = get_data(K1, K2, subj1, subj2)
    runner = bundle["trainer"]
    subj1_X_test, subj1_y_test = bundle["X1_test"], bundle["y1_test"]
    subj2_X_test, subj2_y_test = bundle["X2_test"], bundle["y2_test"]

    # Order matters: it fixes both the training sequence and the result indexing.
    training_sets = [
        (bundle["X_train"], bundle["y_train"]),    # mixed
        (bundle["X2_train"], bundle["y2_train"]),  # subject 2 only
        (bundle["X1_train"], bundle["y1_train"]),  # subject 1 only
    ]

    scores_on_subj1 = []
    fitted = []
    for features, labels in training_sets:
        score, net = runner.train_multiple_dataset(features, labels, subj1_X_test, subj1_y_test)
        scores_on_subj1.append(score)
        fitted.append(net)

    scores_on_subj2 = []
    for net in fitted:
        # Assumes evaluate returns exactly (loss, metric); keep the metric.
        _, metric = net.evaluate(subj2_X_test, subj2_y_test, verbose=0)
        scores_on_subj2.append(metric)

    return scores_on_subj1, scores_on_subj2

def print_results(acc_H, acc_X):
    """Print the two-subject accuracy report produced by ``get_results``.

    Both arguments are indexed as [mixed, subject-2-only, subject-1-only].
    The display order differs from that storage order, so each row below
    records which index it reads.

    Args:
        acc_H: Scores on subject 1's test split. They are printed unscaled,
            so they are presumably already percentages. TODO confirm.
        acc_X: Scores on subject 2's test split. They are multiplied by 100
            before printing, which suggests fractions in [0, 1].
    """
    # (label, index into the score sequence) in display order.
    tested_on_1 = (
        ("Trained with Subjects 1 and 2 / Tested on 1 (Inter-session):     ", 0),
        ("Trained with Subject 1        / Tested on 1 (Inter-session):     ", 2),
        ("Trained with Subject 2        / Tested on 1 (Inter-subject):     ", 1),
    )
    tested_on_2 = (
        ("Trained with Subjects 1 and 2 / Tested on 2 (Inter-session):     ", 0),
        ("Trained with Subject 2        / Tested on 2 (Inter-session):     ", 1),
        ("Trained with Subject 1        / Tested on 2 (Inter-subject):     ", 2),
    )

    for label, pos in tested_on_1:
        print(f"{label}{acc_H[pos]:.2f}%")
    print()  # blank separator line between the two sections

    for label, pos in tested_on_2:
        print(f"{label}{acc_X[pos] * 100:.2f}%")
