In [None]:
# Main execution
#
# Fits each model in both candidate causal directions (Age -> Target and
# Target -> Age), averages the tail of each run's 'train_score' history,
# reports the direction with the higher average per model, and plots the
# results.
if __name__ == "__main__":
    # Configuration parameters
    DATA_FILE = "your_data.csv"  # Change this to your data file
    TRAIN_SPLIT = 0.8           # Proportion of data for training
    NUM_EPOCHS = 450            # Number of epochs for training
    AVERAGING_WINDOW = 10       # Number of last epochs to average for final score

    # Load and preprocess data
    data = pd.read_csv(DATA_FILE)

    # Standardize Age and Target variables (zero mean, unit sample std).
    # NOTE(review): the statistics are computed over ALL rows, i.e. before the
    # train/test split below, so test rows influence the scaling of the
    # training data. Confirm this is intended.
    for column in ('Age', 'Target'):
        column_std = data[column].std()
        # A constant column (std == 0) or fewer than two rows (std is NaN)
        # would silently fill the column with NaN/inf; fail loudly instead.
        if not column_std > 0:
            raise ValueError(
                f"Cannot standardize column {column!r}: "
                f"standard deviation is {column_std}."
            )
        data[column] = (data[column] - data[column].mean()) / column_std

    # Train and test splits: the first TRAIN_SPLIT fraction of rows (in file
    # order, no shuffling) is used for training, the remainder for testing.
    split_idx = int(TRAIN_SPLIT * len(data))
    train_rows = slice(None, split_idx)
    test_rows = slice(split_idx, None)

    def _stack_pair(first, second, rows):
        """Return an (n_rows, 2) array whose columns are `first`, `second` of `data`."""
        return np.array([
            data[first].values[rows],
            data[second].values[rows]
        ]).transpose()

    # Prepare datasets for both directions (column order is swapped between them)
    traindata_age_to_target = _stack_pair('Age', 'Target', train_rows)
    testdata_age_to_target = _stack_pair('Age', 'Target', test_rows)
    traindata_target_to_age = _stack_pair('Target', 'Age', train_rows)
    testdata_target_to_age = _stack_pair('Target', 'Age', test_rows)

    datasets = {
        'age_to_target': (traindata_age_to_target, testdata_age_to_target),
        'target_to_age': (traindata_target_to_age, testdata_target_to_age),
    }
    model_classes = {'CANM': CANM, 'TransformerVAE': TransformerVAE}

    # Run models for both directions. Dict ordering keeps the run order:
    # CANM (age->target, target->age), then TransformerVAE (same two).
    results = {
        model_name: {
            direction_key: fit_model(model_cls, train_set, test_set,
                                     epochs=NUM_EPOCHS, verbose=True)
            for direction_key, (train_set, test_set) in datasets.items()
        }
        for model_name, model_cls in model_classes.items()
    }

    # Calculate average scores from last N epochs of each run's 'train_score'
    avg_scores = {
        model_name: {
            direction_key: np.mean(run['train_score'][-AVERAGING_WINDOW:])
            for direction_key, run in runs.items()
        }
        for model_name, runs in results.items()
    }

    # Determine inferred causal directions: the direction with the strictly
    # higher average score wins; a tie (or a NaN score) falls through to
    # "Target -> Age".
    for model in ['CANM', 'TransformerVAE']:
        direction = "Age -> Target" if avg_scores[model]['age_to_target'] > avg_scores[model]['target_to_age'] else "Target -> Age"
        print(f"\n{model} Inferred Direction:", direction)
        print("  Avg Score Age -> Target:", avg_scores[model]['age_to_target'])
        print("  Avg Score Target -> Age:", avg_scores[model]['target_to_age'])

    # Plot TransformerVAE convergence
    plt.figure(figsize=(12, 6))
    plt.plot(results['TransformerVAE']['age_to_target']['train_score'],
             label='TV-CANM Age -> Target')
    plt.plot(results['TransformerVAE']['target_to_age']['train_score'],
             label='TV-CANM Target -> Age', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Negative Average Loss')
    plt.title(f'TV-CANM Convergence ({NUM_EPOCHS} Epochs)')
    plt.legend()
    plt.show()

    # Bar plot to compare the models
    labels = ['CANM Age -> Target',
              'CANM Target -> Age',
              'TV-CANM Age -> Target',
              'TV-CANM Target -> Age']

    # Same order as `labels`: CANM first, then TransformerVAE.
    scores = [avg_scores['CANM']['age_to_target'],
              avg_scores['CANM']['target_to_age'],
              avg_scores['TransformerVAE']['age_to_target'],
              avg_scores['TransformerVAE']['target_to_age']]

    plt.figure(figsize=(10, 6))
    plt.bar(labels, scores, color=['blue', 'lightblue', 'green', 'lightgreen'])
    plt.xlabel('Direction')
    plt.ylabel('Average Score (Negative Loss)')
    plt.title(f'Comparison of Causal Direction Scores for CANM and TV-CANM ({NUM_EPOCHS} Epochs)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()