In [None]:
# Build one two-column dataset per candidate causal direction. The column order
# encodes the hypothesis (presumably fit_model treats column 0 as the cause and
# column 1 as the effect — TODO confirm against fit_model).
traindata_age_to_target = np.column_stack((data['Age'].values, data['Target'].values))
traindata_target_to_age = np.column_stack((data['Target'].values, data['Age'].values))

# Number of trailing epochs averaged to obtain a stable final score per run.
LAST_N_EPOCHS = 10

# Fit each model in both directions. Only the per-epoch score history is kept;
# the fitted model object returned first is discarded.
_, score_CANM_age_to_target = fit_model(CANM, traindata_age_to_target, verbose=True)
_, score_CANM_target_to_age = fit_model(CANM, traindata_target_to_age, verbose=True)

_, score_TransformerVAE_age_to_target = fit_model(TransformerVAE, traindata_age_to_target, verbose=True)
_, score_TransformerVAE_target_to_age = fit_model(TransformerVAE, traindata_target_to_age, verbose=True)

# Average the score over the last LAST_N_EPOCHS epochs of each run. Scores are
# treated as "higher is better" (the plots below label them negative loss).
# An empty history averages to NaN.
avg_score_CANM_age_to_target = np.mean(score_CANM_age_to_target[-LAST_N_EPOCHS:])
avg_score_CANM_target_to_age = np.mean(score_CANM_target_to_age[-LAST_N_EPOCHS:])
avg_score_TransformerVAE_age_to_target = np.mean(score_TransformerVAE_age_to_target[-LAST_N_EPOCHS:])
avg_score_TransformerVAE_target_to_age = np.mean(score_TransformerVAE_target_to_age[-LAST_N_EPOCHS:])


def _infer_direction(score_age_to_target, score_target_to_age):
    """Return the direction label whose average score is higher.

    Returns "Undetermined" when the scores tie or either one is non-finite
    (NaN/inf, e.g. from a diverged run or an empty score history). A bare
    ``>`` comparison would silently report "Target -> Age" in those cases.
    """
    if not (np.isfinite(score_age_to_target) and np.isfinite(score_target_to_age)):
        return "Undetermined"
    if score_age_to_target == score_target_to_age:
        return "Undetermined"
    return "Age -> Target" if score_age_to_target > score_target_to_age else "Target -> Age"


# Determine the inferred directions and print the supporting values.
direction_CANM = _infer_direction(avg_score_CANM_age_to_target, avg_score_CANM_target_to_age)
direction_TransformerVAE = _infer_direction(avg_score_TransformerVAE_age_to_target, avg_score_TransformerVAE_target_to_age)

print("\nCANM Inferred Direction:", direction_CANM)
print("  Avg Score Age -> Target:", avg_score_CANM_age_to_target)
print("  Avg Score Target -> Age:", avg_score_CANM_target_to_age)

print("\nTransformerVAE Inferred Direction:", direction_TransformerVAE)
print("  Avg Score Age -> Target:", avg_score_TransformerVAE_age_to_target)
print("  Avg Score Target -> Age:", avg_score_TransformerVAE_target_to_age)

# Per-epoch score curves for TV-CANM (the TransformerVAE runs) only: one line per
# candidate direction, with the reverse direction dashed. CANM curves are not
# drawn here.
fig_convergence, ax_convergence = plt.subplots(figsize=(14, 8))
ax_convergence.plot(score_TransformerVAE_age_to_target, label='TV-CANM Age -> Target')
ax_convergence.plot(score_TransformerVAE_target_to_age, label='TV-CANM Target -> Age', linestyle='--')
ax_convergence.set_xlabel('Epoch')
ax_convergence.set_ylabel('Negative Average Loss')
ax_convergence.set_title('Causal Direction Convergence for TV-CANM (Age to Target 450 - Epochs)')
ax_convergence.legend()
plt.show()

# One bar per (model, direction) pair showing its trailing-epoch average score.
# Each tuple is (label, score, colour); darker shades mark the Age -> Target runs.
bar_specs = [
    ('CANM Age -> Target', avg_score_CANM_age_to_target, 'blue'),
    ('CANM Target -> Age', avg_score_CANM_target_to_age, 'lightblue'),
    ('TV-CANM Age -> Target', avg_score_TransformerVAE_age_to_target, 'green'),
    ('TV-CANM Target -> Age', avg_score_TransformerVAE_target_to_age, 'lightgreen'),
]
# Unzip the specs back into the three parallel lists used for plotting.
labels, scores, colors = map(list, zip(*bar_specs))

fig_scores, ax_scores = plt.subplots(figsize=(10, 6))
ax_scores.bar(labels, scores, color=colors)
ax_scores.set_xlabel('Direction')
ax_scores.set_ylabel('Average Score (Negative Loss)')
ax_scores.set_title('Comparison of Causal Direction Scores for CANM and TV-CANM (Age to Target 450 - Epochs)')
plt.show()