In [1]:
import pickle
import matplotlib.pyplot as plt


In [2]:
# Restore the averaged metric histories saved by the training run.
# NOTE: pickle.load can execute arbitrary code embedded in the file -- only
# open metrics files produced by your own pipeline, never untrusted ones.
metrics_path = 'Results/MicroKPNN_encoder_confounder_free_plots/metrics.pkl'
with open(metrics_path, 'rb') as metrics_file:
    data = pickle.load(metrics_file)

# One entry per data split. The plotting cell indexes the train/val dicts by
# metric name (e.g. 'loss_history', 'accuracy') and plots each value against
# the epoch number; the structure of the test dict is not shown here.
train_avg_metrics = data['train_avg_metrics']
val_avg_metrics = data['val_avg_metrics']
test_avg_metrics = data['test_avg_metrics']

# Now, train_avg_metrics, val_avg_metrics, and test_avg_metrics are restored


In [5]:
import matplotlib.pyplot as plt
import os

# Derive the epoch axis from the recorded history instead of hard-coding 100,
# so this cell keeps working when the model is trained for a different number
# of epochs. Kept as a module-level name in case later cells rely on it.
epochs = range(1, len(train_avg_metrics['loss_history']) + 1)

# Set global font size (applies to every figure created below)
plt.rcParams.update({
    'font.size': 14,  # Base font size
    'axes.titlesize': 16,  # Title size
    'axes.labelsize': 14,  # X and Y labels size
    'xtick.labelsize': 12,  # X-tick labels size
    'ytick.labelsize': 12,  # Y-tick labels size
    'legend.fontsize': 12,  # Legend font size
})

# Create directory if not exists
output_dir = 'Results/MicroKPNN_encoder_confounder_free_plots'
os.makedirs(output_dir, exist_ok=True)


def _plot_history(train_metrics, val_metrics, key, title, ylabel, path):
    """Plot the per-epoch history of one metric and save the figure to ``path``.

    Parameters
    ----------
    train_metrics : dict
        Maps metric names to per-epoch sequences; ``train_metrics[key]`` is
        drawn as the 'Train' curve.
    val_metrics : dict or None
        Same layout as ``train_metrics``; when not None, ``val_metrics[key]``
        is overlaid as the 'Validation' curve.
    key : str
        Metric name to look up in the dicts.
    title, ylabel : str
        Figure title and y-axis label.
    path : str
        Output file path; the extension selects the format (PDF here).

    Raises
    ------
    KeyError
        If ``key`` is missing from one of the supplied dicts.
    """
    series = [('Train', train_metrics[key])]
    if val_metrics is not None:
        series.append(('Validation', val_metrics[key]))

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        for label, history in series:
            # Each curve gets x-values matching its own length, so train/val
            # histories of different lengths no longer make plot() raise.
            ax.plot(range(1, len(history) + 1), history, label=label, linewidth=2)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(path, dpi=300)  # High DPI for papers
    finally:
        # Release the figure even if saving fails, so figures don't pile up.
        plt.close(fig)


# (metric key, title, y-axis label, output file name, overlay validation curve?)
history_plots = [
    # 1. Correlation G Loss History (train only)
    ('gloss_history', "Average Correlation Loss History", "Loss",
     'correlation_g_loss_history.pdf', False),
    # 2. Average Distance Correlation History
    ('dcor_history', "Average Distance Correlation History", "Distance Correlation",
     'average_distance_correlation_history.pdf', True),
    # 3. Average Disease Loss History
    ('loss_history', "Average Phenotype Loss History", "Disease Loss",
     'average_disease_loss_history.pdf', True),
    # 4. Average Accuracy History
    ('accuracy', "Average Accuracy History", "Accuracy",
     'average_accuracy_history.pdf', True),
    # 5. Average F1 Score History
    ('f1_score', "Average F1 Score History", "F1 Score",
     'average_f1_score_history.pdf', True),
    # 6. Average AUCPR Score History
    ('auc_pr', "Average AUCPR Score History", "AUCPR Score",
     'average_aucpr_score_history.pdf', True),
]

for metric_key, plot_title, y_label, file_name, with_val in history_plots:
    _plot_history(
        train_avg_metrics,
        val_avg_metrics if with_val else None,
        metric_key,
        plot_title,
        y_label,
        f'{output_dir}/{file_name}',
    )


