In [None]:
# Build one (n_samples, 2) training array per hypothesised causal direction.
# Only the column order differs between the two; presumably fit_model treats
# column 0 as the candidate cause and column 1 as the candidate effect
# (TODO: confirm against fit_model).
age_column = data['Age'].values
height_column = data['Height'].values
traindata_age_to_height = np.array([age_column, height_column]).T
traindata_height_to_age = np.array([height_column, age_column]).T

# Fit every model on both column orderings. Only the second element of each
# fit_model result is kept: presumably a per-epoch score series (it is
# plotted against 'Epoch' further down). The call order is fixed, with CANM
# before TransformerVAE and Age -> Height before Height -> Age, in case the
# fits share global random state.

# --- CANM ---
_, score_CANM_age_to_height = fit_model(
    CANM, traindata_age_to_height, verbose=True
)
_, score_CANM_height_to_age = fit_model(
    CANM, traindata_height_to_age, verbose=True
)

# --- TransformerVAE (labelled "TV-CANM" in the plots) ---
_, score_TransformerVAE_age_to_height = fit_model(
    TransformerVAE, traindata_age_to_height, verbose=True
)
_, score_TransformerVAE_height_to_age = fit_model(
    TransformerVAE, traindata_height_to_age, verbose=True
)

# Number of trailing epochs averaged into each final score. Averaging a
# window rather than taking only the last epoch presumably smooths out
# epoch-to-epoch noise in the training curve.
N_TAIL_EPOCHS = 10


def _tail_mean(scores, window=N_TAIL_EPOCHS):
    """Return the mean of the last ``window`` entries of ``scores``.

    If ``scores`` has fewer than ``window`` entries, all of them are averaged.

    Raises:
        ValueError: if ``window`` is less than 1. ``scores[-0:]`` would
            otherwise silently select the whole series.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    return np.mean(scores[-window:])


# Average the last N_TAIL_EPOCHS scores of each run for comparison
avg_score_CANM_age_to_height = _tail_mean(score_CANM_age_to_height)
avg_score_CANM_height_to_age = _tail_mean(score_CANM_height_to_age)
avg_score_TransformerVAE_age_to_height = _tail_mean(score_TransformerVAE_age_to_height)
avg_score_TransformerVAE_height_to_age = _tail_mean(score_TransformerVAE_height_to_age)

def _infer_direction(score_forward, score_backward,
                     forward_label="Age -> Height",
                     backward_label="Height -> Age",
                     undetermined_label="Undetermined"):
    """Return the label of the direction with the higher score.

    Scores are negative losses, so a higher score means a better fit.

    A tie gives no evidence either way, and a NaN score (e.g. from a diverged
    fit) makes every ``>`` comparison False. Both cases therefore return
    ``undetermined_label`` instead of silently favouring ``backward_label``.
    """
    if np.isnan(score_forward) or np.isnan(score_backward) or score_forward == score_backward:
        return undetermined_label
    return forward_label if score_forward > score_backward else backward_label


# Determine the inferred directions and print values
direction_CANM = _infer_direction(avg_score_CANM_age_to_height, avg_score_CANM_height_to_age)
direction_TransformerVAE = _infer_direction(avg_score_TransformerVAE_age_to_height, avg_score_TransformerVAE_height_to_age)

print("\nCANM Inferred Direction:", direction_CANM)
print("  Avg Score Age -> Height:", avg_score_CANM_age_to_height)
print("  Avg Score Height -> Age:", avg_score_CANM_height_to_age)

print("\nTransformerVAE Inferred Direction:", direction_TransformerVAE)
print("  Avg Score Age -> Height:", avg_score_TransformerVAE_age_to_height)
print("  Avg Score Height -> Age:", avg_score_TransformerVAE_height_to_age)

# Per-epoch training curves for TV-CANM (the TransformerVAE model) only, with
# one line per hypothesised direction. CANM is not plotted here. The reverse
# direction is drawn dashed to tell the two lines apart.
plt.figure(figsize=(14, 8))
tv_curves = (
    (score_TransformerVAE_age_to_height, {'label': 'TV-CANM Age -> Height'}),
    (score_TransformerVAE_height_to_age, {'label': 'TV-CANM Height -> Age', 'linestyle': '--'}),
)
for curve, plot_kwargs in tv_curves:
    plt.plot(curve, **plot_kwargs)
plt.xlabel('Epoch')
plt.ylabel('Negative Average Loss')
plt.title('Causal Direction Convergence for TV-CANM (Age to Height 450 - Epochs)')
plt.legend()
plt.show()

# One bar per (model, direction) pair, using the tail-averaged scores.
# Blue shades are CANM and green shades are TV-CANM; the darker shade of each
# pair is Age -> Height.
bar_rows = [
    ('CANM Age -> Height', avg_score_CANM_age_to_height, 'blue'),
    ('CANM Height -> Age', avg_score_CANM_height_to_age, 'lightblue'),
    ('TV-CANM Age -> Height', avg_score_TransformerVAE_age_to_height, 'green'),
    ('TV-CANM Height -> Age', avg_score_TransformerVAE_height_to_age, 'lightgreen'),
]
labels, scores, colors = (list(column) for column in zip(*bar_rows))

plt.figure(figsize=(10, 6))
plt.bar(labels, scores, color=colors)
plt.xlabel('Direction')
plt.ylabel('Average Score (Negative Loss)')
plt.title('Comparison of Causal Direction Scores for CANM and TV-CANM (Age to Height 450 - Epochs)')
plt.show()