In [None]:
# Training Curves
# Two side-by-side panels, train vs. test per epoch:
#   left  - RMSE (square root of the recorded MSE losses)
#   right - R2 score
plt.figure(figsize=(18, 6))
# NOTE: plt.rcParams.update changes matplotlib's global defaults, so these
# font sizes and grid settings also apply to every figure drawn after this point.
plt.rcParams.update({
    'axes.titlesize': 20,
    'axes.labelsize': 22,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 20,
    'grid.alpha': 0.3,
    'grid.linestyle': '--',
    'grid.linewidth': 0.5
})

# The recorded losses are MSE values; convert each one to RMSE for display.
train_rmse = list(map(np.sqrt, train_losses))
test_rmse = list(map(np.sqrt, test_losses))

# One entry per panel: (train series, test series, train label, test label, y-axis label).
panel_specs = [
    (train_rmse, test_rmse, 'Train RMSE', 'Test RMSE', 'RMSE Loss'),
    (train_r2_scores, test_r2_scores, 'Train R2', 'Test R2', 'R2 Score'),
]

ax1 = plt.subplot(1, 2, 1)
ax2 = plt.subplot(1, 2, 2)
for ax, (train_series, test_series, train_tag, test_tag, y_label) in zip((ax1, ax2), panel_specs):
    ax.plot(train_series, label=train_tag, linewidth=2, color='blue')
    ax.plot(test_series, label=test_tag, linewidth=2, color='red')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.grid(True)
    ax.legend()

plt.tight_layout()
plt.show()

# Per-node prediction scatter plot and metrics table (currently disabled).
# The helper lists below stay defined at cell level so the commented-out code
# can be re-enabled as-is.
# NOTE: plt.figure(...) is called inside the disabled code rather than here.
# Creating it while the plotting code is commented out would only open an
# empty figure that is never drawn into.
colors = ['r', 'g', 'b', 'c', 'm']
markers = ['o', 's', 'D', '^', 'v']
labels = [f'{i}' for i in range(5)]
node_metrics = []

# plt.figure(figsize=(8, 8))
# for node_id in range(5):
#     mask = test_node_ids == node_id
#     t = test_targets[mask]
#     p = test_preds[mask]

#     plt.scatter(t, p, 
#                c=colors[node_id], 
#                marker=markers[node_id],
#                s=90,
#                alpha=0.8,
#                label=labels[node_id])

#     # Per-node error metrics
#     mse = np.mean((t - p) ** 2)
#     rmse = np.sqrt(mse)
#     r2 = r2_score(t, p)
#     node_metrics.append({
#         'node_id': node_id,
#         'mse': mse,
#         'rmse': rmse,
#         'r2': r2,
#         'count': len(t)
#     })

# min_val = 0
# max_val = 0.5
# plt.plot([min_val, max_val], [min_val, max_val], 'k-', linewidth=1.5, alpha=0.7)
# plt.xlim(-0.02, 0.5)
# plt.ylim(-0.02, 0.5)
# plt.xticks(np.arange(0, 0.51, 0.1))
# plt.yticks(np.arange(0, 0.51, 0.1))
# plt.xlabel('True Values', fontsize=20, weight='normal')
# plt.ylabel('Predictions', fontsize=20, weight='normal')
# plt.grid(True, linestyle='--', alpha=0.3)
# legend = plt.legend(title='NodeNum', fontsize=18)
# plt.setp(legend.get_title(), fontsize=16)
# plt.tight_layout()
# plt.show()


# # Node Performance Evaluation (MSE / RMSE / R2 per node)
# print("\n=== Node Prediction Performance Evaluation ===")
# print(f"{'Node':<6}{'Samples':<8}{'MSE':<12}{'RMSE':<12}{'R2':<10}")
# print("-" * 50)
# sorted_by_mse = sorted(node_metrics, key=lambda x: x['mse'])
# for metrics in sorted_by_mse:
#     # Field widths match the header row above so the columns line up.
#     print(f"{metrics['node_id']:<6}{metrics['count']:<8}{metrics['mse']:<12.6f}{metrics['rmse']:<12.6f}{metrics['r2']:<10.4f}")

# print("\nSorted by R2:")
# sorted_by_r2 = sorted(node_metrics, key=lambda x: -x['r2'])
# for metrics in sorted_by_r2:
#     print(f"Node{metrics['node_id']}: R2={metrics['r2']:.4f}, RMSE={metrics['rmse']:.4f}, MSE={metrics['mse']:.6f}")

# Gating Response Visualization
# Sweep a scalar input through the fusion gate and plot:
#   - every raw gate output (blue)
#   - one mean curve per group (red)
#   - the overall means (black / green)
#   - a GRID_SIDE x GRID_SIDE heatmap of the group means averaged over the sweep
N_SWEEP = 100   # number of input points in the sweep
N_GROUPS = 25   # gate outputs are averaged into this many groups
GRID_SIDE = 5   # heatmap side length; GRID_SIDE * GRID_SIDE must equal N_GROUPS

model.eval()
test_values = torch.linspace(0, 0.26, N_SWEEP).unsqueeze(1).to(device)

with torch.no_grad():
    raw_weights = model.fusion.gate(test_values)
    # NOTE(review): assumes the gate returns a (N_SWEEP, n_features) tensor with
    # n_features divisible by N_GROUPS (this cell was written for 800 = 32 * 25).
    n_features = raw_weights.shape[1]
    group_size = n_features // N_GROUPS
    # Reshaping to (N_SWEEP, group_size, N_GROUPS) puts feature j in group
    # j % N_GROUPS; averaging over dim 1 yields one curve per group.
    grouped_weights = raw_weights.view(N_SWEEP, group_size, N_GROUPS).mean(dim=1)
    # Per-feature view of the group means, in the same layout as raw_weights:
    # column j holds the mean of the group that feature j belongs to.
    expanded_weights = (grouped_weights.unsqueeze(1)
                        .repeat(1, group_size, 1)
                        .view(N_SWEEP, n_features))
    mean_raw_weights = raw_weights.mean(dim=1)
    # Groups are equal-sized, so this equals mean_raw_weights up to float rounding.
    mean_grouped_weights = grouped_weights.mean(dim=1)
    heatmap_data = grouped_weights.mean(dim=0).view(GRID_SIDE, GRID_SIDE).cpu()
    mean_spatial_weights = heatmap_data.mean()

# Copy the plotted tensors to host memory once, so the per-line loops below
# do not each trigger their own device-to-host transfer.
x_host = test_values.cpu().numpy()
raw_host = raw_weights.cpu().numpy()
grouped_host = grouped_weights.cpu().numpy()

plt.figure(figsize=(12, 12))
plt.subplot(2, 1, 1)

# Raw weights: one faint line per gate output; only the first gets a legend entry.
for col in range(n_features):
    plt.plot(x_host, raw_host[:, col],
             color='blue', alpha=0.15, linewidth=0.5,
             label='Raw Weights' if col == 0 else "")

# Grouped weights: one line per group, taken directly from grouped_weights
# (the distinct curves contained in expanded_weights).
for grp in range(N_GROUPS):
    plt.plot(x_host, grouped_host[:, grp],
             color='red', linewidth=1.5,
             label='Grouped Weights' if grp == 0 else "")

# Overall mean curves
plt.plot(x_host, mean_raw_weights.cpu(),
         color='black', linewidth=3, linestyle='--',
         label='Mean Raw Weights')
plt.plot(x_host, mean_grouped_weights.cpu(),
         color='green', linewidth=3, linestyle='-.',
         label='Mean Grouped Weights')

plt.xlabel("Global Martensite Fraction Value", fontsize=22)
plt.ylabel("Gate Weight Value", fontsize=22)
plt.grid(True, alpha=0.3)
# Keep one legend entry per distinct label.
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(),
           fontsize=18, framealpha=0.5)
plt.tick_params(axis='both', which='major', labelsize=22)

# Heatmap subplot: group means averaged over the sweep, laid out on the grid.
plt.subplot(2, 1, 2)
im = plt.imshow(heatmap_data, cmap='viridis', aspect=1)
plt.title(f"Spatial Weight Distribution (Mean={mean_spatial_weights:.3f})",
          fontsize=22, pad=20)
cbar = plt.colorbar(im, shrink=0.8)
cbar.ax.tick_params(labelsize=18)
cbar.set_label('Weight Value', fontsize=16)
plt.xticks(np.arange(GRID_SIDE), labels=np.arange(GRID_SIDE), fontsize=18)
plt.yticks(np.arange(GRID_SIDE), labels=np.arange(GRID_SIDE), fontsize=18)

plt.tight_layout()
plt.show()

# Prediction Results Scatter Plot (training and test sets)
# Run the model over every batch of both loaders, then draw a true-vs-predicted
# scatter per split against the y = x reference line.
plt.figure(figsize=(8, 8))

# Marker style per split; the keys are read individually in the scatter call below.
dataset_styles = {
    'train': {'color': 'indigo', 'marker': 'o', 'label': 'TrainSet', 'alpha': 1, 's': 120},
    'test': {'color': 'deeppink', 'marker': 'D', 'label': 'TestSet', 'alpha': 0.8, 's': 140}
}

# Per split: the source loader plus per-batch lists of targets and predictions.
all_datasets = {
    split_name: {'loader': split_loader, 'true': [], 'pred': []}
    for split_name, split_loader in (('train', train_loader), ('test', test_loader))
}

# Inference only: eval mode and no autograd bookkeeping.
model.eval()
with torch.no_grad():
    for dataset_name, bucket in all_datasets.items():
        loader = bucket['loader']
        for data in loader:
            data = data.to(device)
            out = model(data)
            bucket['true'].append(data.y.cpu().numpy())
            bucket['pred'].append(out.cpu().numpy())

# One scatter layer per split; test is drawn after (on top of) train.
for dataset_name in ('train', 'test'):
    style = dataset_styles[dataset_name]
    true_values = np.concatenate(all_datasets[dataset_name]['true'])
    pred_values = np.concatenate(all_datasets[dataset_name]['pred'])
    plt.scatter(true_values, pred_values,
                c=style['color'], marker=style['marker'], s=style['s'],
                alpha=style['alpha'], label=style['label'])

# Ideal y = x reference line, spanning exactly the visible axis range.
min_val, max_val = -0.04, 0.5
plt.plot([min_val, max_val], [min_val, max_val], 'k', linewidth=1.5, alpha=0.7)
plt.xlim(min_val, max_val)
plt.ylim(min_val, max_val)

plt.xlabel('True Values', fontsize=22)
plt.ylabel('Predictions', fontsize=22)
plt.grid(True, linestyle='--', alpha=0.3)
plt.legend(fontsize=18)

plt.tight_layout()
plt.show()