# In [None]:
import pandas as pd
import torch
import numpy as np
from torch_geometric.data import Data
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler, StandardScaler
import glob, random
import pandas as pd
import torch
import numpy as np
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv,GATConv,SAGEConv, global_mean_pool, global_max_pool, global_add_pool,GINEConv,SGConv,MessagePassing
from torch_geometric.utils import dense_to_sparse
import matplotlib.pyplot as plt
from egnn_pytorch import EGNN
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import r2_score
from torch_scatter import scatter_mean
import time
from collections import defaultdict

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.

    Args:
        seed: Seed value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Fix every RNG seed before any data loading or model construction happens.
set_seed(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Glob patterns for the structure (node) CSVs and the parameter CSVs.
graph_file_pattern = "dataset1/graph_*.csv"
parameter_file_pattern = "dataset1/parameter_*.csv"

# Min-max scalers; they are fitted inside create_graph_dataset.
node_feature_scaler = MinMaxScaler()
graph_attribute_scaler = MinMaxScaler()
target_scaler = MinMaxScaler()

# One-hot encoders fitted on fixed vocabularies, passed as single-column arrays.
shape_encoder = OneHotEncoder(sparse_output=False).fit(
    np.array([['circle'], ['square'], ['ellipse']])
)
material_encoder = OneHotEncoder(sparse_output=False).fit(
    np.array([['SiO2'], ['ZBLAN'], ['GeO2']])
)
def create_graph_dataset(graph_file_pattern, parameter_file_pattern, device=device):
    """Build one ``Data`` graph per parameter row from paired CSV files.

    Structure CSV columns as used here: 0-1 are node coordinates (only used
    to build k-NN edges), 2-5 are numeric node features and 6 is the shape
    category. Parameter CSV columns as used here: 1 is the material category,
    2-3 are numeric graph attributes and 6-10 are the regression targets.

    Side effects: fits the module-level ``node_feature_scaler``,
    ``graph_attribute_scaler`` and ``target_scaler`` on *all* files.
    NOTE(review): because the scalers see every sample before the later
    train/val/test split, the split is not strictly leak-free.

    Args:
        graph_file_pattern: Glob pattern for the structure CSVs.
        parameter_file_pattern: Glob pattern for the parameter CSVs.
        device: Device on which feature/attribute/target tensors are created.

    Returns:
        List of ``Data`` objects; each carries ``x``, ``edge_index``,
        ``edge_weights``, ``y``, ``graph_attr`` and ``filename`` (the
        ``(structure_file, parameter_file)`` pair it came from).

    Raises:
        ValueError: If no files match or the two patterns match a different
            number of files.
    """
    # glob returns files in arbitrary order; sort both lists so that the i-th
    # structure file is paired with the i-th parameter file deterministically.
    structure_files = sorted(glob.glob(graph_file_pattern))
    parameter_files = sorted(glob.glob(parameter_file_pattern))

    if not structure_files:
        raise ValueError(f"No structure files match pattern {graph_file_pattern!r}")
    if len(structure_files) != len(parameter_files):
        raise ValueError(
            f"Found {len(structure_files)} structure files but "
            f"{len(parameter_files)} parameter files; they must pair one-to-one"
        )

    # First pass: read every file pair once and keep the raw arrays so the
    # scalers can be fitted on the full dataset before anything is normalised.
    records = []
    for struct_file, param_file in zip(structure_files, parameter_files):
        df_structure = pd.read_csv(struct_file)
        df_params = pd.read_csv(param_file)

        node_shapes = df_structure.iloc[:, 6].astype(str).values.reshape(-1, 1)
        node_features = np.hstack((
            df_structure.iloc[:, [2, 3, 4, 5]].values,
            shape_encoder.transform(node_shapes),
        ))

        fiber_material = df_params.iloc[:, 1].astype(str).values.reshape(-1, 1)
        graph_attributes = np.hstack((
            df_params.iloc[:, [2, 3]].values,
            material_encoder.transform(fiber_material),
        ))

        records.append({
            'node_features': node_features,
            'graph_attributes': graph_attributes,
            'targets': df_params.iloc[:, 6:11].values,
            'coordinates': df_structure.iloc[:, [0, 1]].values,
            'filename': (struct_file, param_file),
        })

    node_feature_scaler.fit(np.vstack([r['node_features'] for r in records]))
    graph_attribute_scaler.fit(np.vstack([r['graph_attributes'] for r in records]))
    target_scaler.fit(np.vstack([r['targets'] for r in records]))

    # Second pass: normalise and emit one graph per parameter row. All graphs
    # from the same structure file share node features and edges.
    all_graphs = []
    for record in records:
        node_features_tensor = torch.tensor(
            node_feature_scaler.transform(record['node_features']), dtype=torch.float
        ).to(device)
        graph_attributes_tensor = torch.tensor(
            graph_attribute_scaler.transform(record['graph_attributes']), dtype=torch.float
        ).to(device)
        targets_tensor = torch.tensor(
            target_scaler.transform(record['targets']), dtype=torch.float
        ).to(device)
        edge_index, edge_weights = create_edges_from_coordinates(record['coordinates'])

        for i in range(graph_attributes_tensor.shape[0]):
            all_graphs.append(Data(
                x=node_features_tensor,
                edge_index=edge_index,
                edge_weights=edge_weights,
                y=targets_tensor[i],
                graph_attr=graph_attributes_tensor[i],
                filename=record['filename'],
            ))
    return all_graphs


def create_no_edge_graph(num_nodes):
    """Return an empty edge set: a ``(2, 0)`` long index and a ``(0,)`` float weight tensor.

    ``num_nodes`` is accepted for signature parity with the other graph
    builders but is not used.
    """
    no_edges = torch.zeros((2, 0), dtype=torch.long)
    no_weights = torch.zeros(0, dtype=torch.float)
    return no_edges, no_weights

def create_fully_connected_graph(num_nodes):
    """Connect every ordered pair of distinct nodes with unit weight.

    Args:
        num_nodes: Number of nodes in the graph.

    Returns:
        ``(edge_index, edge_weights)`` where ``edge_index`` is a ``(2, E)``
        long tensor listing edges in row-major ``(i, j)`` order with
        ``i != j`` and ``edge_weights`` is a ``(E,)`` tensor of ones.
        For fewer than two nodes both tensors are empty (``E == 0``).
    """
    # With fewer than two nodes there is no pair to connect; return a
    # correctly shaped (2, 0) index instead of failing on a 1-D empty tensor.
    if num_nodes < 2:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float)

    node_ids = torch.arange(num_nodes, dtype=torch.long)
    sources = node_ids.repeat_interleave(num_nodes)
    targets = node_ids.repeat(num_nodes)
    not_self_loop = sources != targets

    edge_index = torch.stack((sources[not_self_loop], targets[not_self_loop])).contiguous()
    edge_weights = torch.ones(edge_index.shape[1])
    return edge_index, edge_weights
    
def create_circular_graph(num_nodes):
    """Connect node ``i`` to node ``(i + 1) % num_nodes`` in both directions.

    Args:
        num_nodes: Number of nodes on the ring.

    Returns:
        ``(edge_index, edge_weights)``: a ``(2, 2 * num_nodes)`` long tensor
        and a matching tensor of ones. For ``num_nodes == 0`` both are empty.
        Note that one node yields two self-loops and two nodes yield each
        directed edge twice, because the wrap-around edge coincides with the
        forward edge.
    """
    pairs = []
    for i in range(num_nodes):
        successor = (i + 1) % num_nodes
        pairs.append([i, successor])
        pairs.append([successor, i])

    # An empty list would become a 1-D tensor and break the shape[1] lookup.
    if not pairs:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float)

    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    edge_weights = torch.ones(edge_index.shape[1])
    return edge_index, edge_weights

def create_random_graph(num_nodes, edge_prob=0.2):
    """Sample an undirected random graph with unit edge weights.

    Each unordered pair ``(i, j)`` with ``i < j`` is kept when
    ``random.random() < edge_prob``; kept pairs are stored in both directions.

    Args:
        num_nodes: Number of nodes.
        edge_prob: Probability of keeping each unordered pair.

    Returns:
        ``(edge_index, edge_weights)``: a ``(2, E)`` long tensor and a
        ``(E,)`` tensor of ones. Both are empty when no pair is sampled.
    """
    pairs = []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if random.random() < edge_prob:
                pairs.append([i, j])
                pairs.append([j, i])

    # No sampled edge (small graph or low probability) must still give a
    # (2, 0) index rather than a 1-D tensor that breaks shape[1].
    if not pairs:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float)

    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    edge_weights = torch.ones(edge_index.shape[1])
    return edge_index, edge_weights

def create_epsilon_neighborhood_graph(coordinates, epsilon = 15):
    """Connect every pair of nodes whose Euclidean distance is below ``epsilon``.

    Args:
        coordinates: Array-like or tensor with one row of coordinates per
            node. Non-tensor input is converted to a float32 tensor.
        epsilon: Distance threshold (strict ``<``).

    Returns:
        ``(edge_index, edge_weights)``: a ``(2, E)`` long tensor holding each
        qualifying pair as ``(i, j)`` immediately followed by ``(j, i)``, and
        a ``(E,)`` float32 tensor with the pair distance for *each* directed
        edge, so both tensors always have the same number of edges. Both are
        empty when no pair qualifies. Results are returned on the CPU.
    """
    if not isinstance(coordinates, torch.Tensor):
        coordinates = torch.tensor(coordinates, dtype=torch.float32)

    num_nodes = coordinates.size(0)
    if num_nodes < 2:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float32)

    points = coordinates.reshape(num_nodes, -1)
    # All unordered pairs i < j, in the same row-major order as a nested loop.
    first, second = torch.triu_indices(num_nodes, num_nodes, offset=1, device=points.device)
    distances = torch.norm(points[first] - points[second], dim=1)

    close_enough = distances < epsilon
    first, second, distances = first[close_enough], second[close_enough], distances[close_enough]

    # Interleave (i, j) and (j, i) so each undirected pair occupies two
    # consecutive columns, and repeat the distance once per direction.
    forward = torch.stack((first, second))
    backward = torch.stack((second, first))
    edge_index = torch.stack((forward, backward), dim=2).reshape(2, 2 * first.numel())
    edge_weights = distances.repeat_interleave(2).to(torch.float32)

    return edge_index.contiguous().cpu(), edge_weights.cpu()
    
def create_edges_from_coordinates(coordinates, k=6):
    """Build directed k-nearest-neighbour edges weighted by Euclidean distance.

    Each point queries its ``k`` nearest fitted points (the query point itself
    is normally among them and is dropped), so a node usually gets ``k - 1``
    outgoing edges.

    Args:
        coordinates: Array-like with one row of coordinates per node.
        k: Number of neighbours requested per node, including the node
            itself. Clamped to the number of nodes so small graphs work.

    Returns:
        ``(edge_index, edge_weights)``: a ``(2, E)`` long tensor and a
        ``(E,)`` float tensor of distances. Both are empty when there are no
        nodes or no non-self neighbours.
    """
    num_points = len(coordinates)
    if num_points == 0:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float)

    # kneighbors raises when n_neighbors exceeds the number of fitted samples.
    neigh = NearestNeighbors(n_neighbors=min(k, num_points))
    neigh.fit(coordinates)
    dist_matrix, indices = neigh.kneighbors(coordinates)

    edge_index = []
    edge_weights = []
    for i, (neighbour_ids, neighbour_dists) in enumerate(zip(indices, dist_matrix)):
        for j, dist in zip(neighbour_ids, neighbour_dists):
            if i != j:
                edge_index.append([i, int(j)])
                edge_weights.append(float(dist))

    # A single node has only itself as neighbour; keep the (2, 0) shape.
    if not edge_index:
        return torch.empty((2, 0), dtype=torch.long), torch.empty(0, dtype=torch.float)

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_weights = torch.tensor(edge_weights, dtype=torch.float)

    return edge_index, edge_weights

# Dataset split
def split_dataset(graphs):
    """Split ``graphs`` into train, validation and test lists.

    20% of the samples are held out first, and that hold-out is then halved
    into validation and test. Both splits use ``random_state=42``.

    Returns:
        ``(train_data, val_data, test_data)``.
    """
    train_data, holdout = train_test_split(graphs, test_size=0.2, random_state=42)
    val_data, test_data = train_test_split(holdout, test_size=0.5, random_state=42)
    return train_data, val_data, test_data


# Load every CSV pair into graph samples (this also fits the module-level scalers).
all_graph_data = create_graph_dataset(graph_file_pattern, parameter_file_pattern)

# Partition the samples into train / validation / test lists and report their sizes.
train_graphs, val_graphs, test_graphs = split_dataset(all_graph_data)
print(f"Train graphs: {len(train_graphs)}, Val graphs: {len(val_graphs)}, Test graphs: {len(test_graphs)}")

class GCNModel(nn.Module):
    """Graph-level regressor: three GCN layers, mean pooling and a two-layer MLP head.

    Each GCN layer is followed by LayerNorm and ReLU. The pooled graph
    embedding is concatenated with the per-graph attribute vector before the
    head.
    """

    def __init__(self, num_node_features, hidden_dim, output_dim, num_graph_attributes):
        super().__init__()
        self.num_graph_attributes = num_graph_attributes

        # Sub-modules are created in a fixed order so parameter registration
        # (and therefore seeded initialisation) stays the same.
        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.fc1 = nn.Linear(hidden_dim + num_graph_attributes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch, graph_attr):
        """Return one ``output_dim`` prediction per graph in the batch."""
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        hidden = x
        for conv, norm in stages:
            hidden = F.relu(norm(conv(hidden, edge_index)))

        pooled = global_mean_pool(hidden, batch)
        # Batched graph attributes arrive flattened; restore one row per graph.
        attrs = graph_attr.view(-1, self.num_graph_attributes)
        combined = torch.cat((pooled, attrs), dim=1)

        return self.fc2(F.relu(self.fc1(combined)))



class GATModel(nn.Module):
    """Graph-level regressor: three GAT layers, mean pooling and a two-layer MLP head.

    Each attention layer is followed by LayerNorm and ReLU. The pooled graph
    embedding is concatenated with the per-graph attribute vector before the
    head.
    """

    def __init__(self, num_node_features, hidden_dim, output_dim, num_graph_attributes):
        super(GATModel, self).__init__()
        # Stored so forward() does not depend on a module-level global that
        # may be missing or disagree with the value the layers were sized for.
        self.num_graph_attributes = num_graph_attributes
        self.attn1 = GATConv(num_node_features, hidden_dim)
        self.attn2 = GATConv(hidden_dim, hidden_dim)
        self.attn3 = GATConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.fc1 = nn.Linear(hidden_dim + num_graph_attributes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch, graph_attr):
        """Return one ``output_dim`` prediction per graph in the batch."""
        x = self.attn1(x, edge_index)
        x = self.norm1(x)
        x = F.relu(x)
        x = self.attn2(x, edge_index)
        x = self.norm2(x)
        x = F.relu(x)
        x = self.attn3(x, edge_index)
        x = self.norm3(x)
        x = F.relu(x)
        x = global_mean_pool(x, batch)

        # Batched graph attributes arrive flattened; restore one row per graph.
        graph_attr = graph_attr.view(-1, self.num_graph_attributes)

        x = torch.cat((x, graph_attr), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x



class GINModel(nn.Module):
    """Graph-level regressor: three GIN layers, mean pooling and a two-layer MLP head.

    Every GIN layer wraps a Linear-ReLU-Linear MLP and is followed by
    LayerNorm and ReLU. The pooled graph embedding is concatenated with the
    per-graph attribute vector before the head.
    """

    def __init__(self, num_node_features, hidden_dim, output_dim, num_graph_attributes):
        super(GINModel, self).__init__()
        # Stored so forward() does not depend on a module-level global that
        # may be missing or disagree with the value the layers were sized for.
        self.num_graph_attributes = num_graph_attributes
        self.conv1 = GINConv(nn.Sequential(nn.Linear(num_node_features, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))
        self.conv2 = GINConv(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))
        self.conv3 = GINConv(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

        self.fc1 = nn.Linear(hidden_dim + num_graph_attributes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch, graph_attr):
        """Return one ``output_dim`` prediction per graph in the batch."""
        x = self.conv1(x, edge_index)
        x = self.norm1(x)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = self.norm2(x)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = self.norm3(x)
        x = F.relu(x)
        x = global_mean_pool(x, batch)

        # Batched graph attributes arrive flattened; restore one row per graph.
        graph_attr = graph_attr.view(-1, self.num_graph_attributes)
        x = torch.cat((x, graph_attr), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        
class GraphSAGEModel(nn.Module):
    """Graph-level regressor: two GraphSAGE layers, mean pooling and a two-layer MLP head.

    The pooled graph embedding is concatenated with the per-graph attribute
    vector before the head.
    """

    def __init__(self, num_node_features, hidden_dim, output_dim, num_graph_attributes):
        super(GraphSAGEModel, self).__init__()
        # Stored so forward() does not depend on a module-level global that
        # may be missing or disagree with the value the layers were sized for.
        self.num_graph_attributes = num_graph_attributes
        self.conv1 = SAGEConv(num_node_features, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim + num_graph_attributes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch, graph_attr):
        """Return one ``output_dim`` prediction per graph in the batch."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)

        # Batched graph attributes arrive flattened; restore one row per graph.
        graph_attr = graph_attr.view(-1, self.num_graph_attributes)

        # Concatenate the pooled embedding with the graph attributes.
        x = torch.cat((x, graph_attr), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs):
    """Train ``model`` and keep the weights with the lowest validation loss.

    After every epoch the mean batch loss on ``val_loader`` is computed; when
    it improves, the weights are written to ``'best_model.pth'``. At the end
    the best weights are reloaded, saved again to the module-level
    ``final_model_path``, the loss curves are plotted to
    ``'Train and Validation Loss.png'`` and the raw per-epoch losses are
    exported to ``'train_and_test_loss.xlsx'``.

    Args:
        model: Network called as ``model(x, edge_index, batch, graph_attr)``.
        train_loader: Loader yielding training batches.
        val_loader: Loader yielding the batches used for checkpoint selection.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function ``criterion(prediction, target)``.
        num_epochs: Number of passes over ``train_loader``.
    """
    checkpoint_path = 'best_model.pth'
    train_loss_list = []
    val_loss_list = []
    # Infinity guarantees the first finite validation loss is checkpointed.
    best_val_loss = float('inf')
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch, batch.graph_attr)
            # Batched targets arrive flattened; restore one row of 5 per graph.
            target = batch.y.view(-1, 5)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        train_loss_list.append(avg_train_loss)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                out = model(batch.x, batch.edge_index, batch.batch, batch.graph_attr)
                target = batch.y.view(-1, 5)
                loss = criterion(out, target)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        val_loss_list.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"Epoch {epoch+1}: Validation loss improved to {avg_val_loss:.4f}, saving model...")

    # Only reload when this run actually wrote a checkpoint (e.g. not when
    # every validation loss was NaN); otherwise a stale file could be loaded.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))

    torch.save(model.state_dict(), final_model_path)
    print(f"Training complete. Final model saved to {final_model_path}")

    # Losses are plotted in units of 1e-3 to match the axis label.
    epochs = range(1, num_epochs + 1)
    train_loss_list_scaled = [loss * 1e3 for loss in train_loss_list]
    val_loss_list_scaled = [loss * 1e3 for loss in val_loss_list]
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.figure(figsize=(8, 6))
    plt.plot(epochs, train_loss_list_scaled, label='Train')
    plt.plot(epochs, val_loss_list_scaled, label='Validation')
    plt.xlabel('Epoch', fontsize=20)
    plt.ylabel('Loss (× 10$^{-3})$', fontsize=20)
    plt.title('Train and Validation Loss', fontsize=20)
    plt.tick_params(axis='both', labelsize=20, direction='in')
    plt.legend(fontsize=20)

    plt.savefig(f"Train and Validation Loss.png")
    plt.show()

    # The 'Test Loss' column name is kept for existing consumers of the
    # spreadsheet; it holds the loss measured on ``val_loader``.
    loss_data = {
        'Epoch': list(epochs),
        'Train Loss': train_loss_list,
        'Test Loss': val_loss_list
    }

    df = pd.DataFrame(loss_data)
    df.to_excel('train_and_test_loss.xlsx', index=False, engine='openpyxl')
    print("Loss data saved to 'train_and_test_loss.xlsx'.")
    
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and print loss, MAE, RMSE and R².

    The loss is the mean of the per-batch ``criterion`` values. MAE, RMSE and
    R² are computed once over all evaluated samples, so a short final batch
    does not distort them the way averaging per-batch metrics would.

    Args:
        model: Network called as ``model(x, edge_index, batch, graph_attr)``.
        val_loader: Loader yielding the batches to evaluate.
        criterion: Loss function ``criterion(prediction, target)``.
    """
    model.eval()
    valid_loss = 0
    evaluated_batches = 0
    collected_outputs = []
    collected_targets = []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch, batch.graph_attr)
            # Batched targets arrive flattened; restore one row of 5 per graph.
            target = batch.y.view(-1, 5)
            if target.shape[0] != out.shape[0]:
                print(f"Warning: target size {target.shape} does not match output size {out.shape}.")
                continue
            # Loss from the supplied criterion.
            valid_loss += criterion(out, target).item()
            evaluated_batches += 1
            collected_outputs.append(out.cpu())
            collected_targets.append(target.cpu())

    # Skipped batches must not count towards the averages.
    if evaluated_batches == 0:
        print("Validation skipped: no batch could be evaluated.")
        return

    predictions = torch.cat(collected_outputs, dim=0)
    targets = torch.cat(collected_targets, dim=0)
    errors = predictions - targets

    avg_valid_loss = valid_loss / evaluated_batches
    # MAE, RMSE and R² over the whole evaluated set.
    avg_mae = errors.abs().mean().item()
    avg_rmse = errors.pow(2).mean().sqrt().item()
    avg_r2 = r2_score(targets.numpy(), predictions.numpy())
    print(f"Validation Loss: {avg_valid_loss}, MAE: {avg_mae}, RMSE: {avg_rmse}, R²: {avg_r2}")

def test(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print loss, MAE, RMSE and R².

    The loss is the mean of the per-batch ``criterion`` values. MAE, RMSE and
    R² are computed once over all evaluated samples, so a short final batch
    does not distort them the way averaging per-batch metrics would.

    Args:
        model: Network called as ``model(x, edge_index, batch, graph_attr)``.
        test_loader: Loader yielding the batches to evaluate.
        criterion: Loss function ``criterion(prediction, target)``.
    """
    model.eval()
    test_loss = 0
    evaluated_batches = 0
    collected_outputs = []
    collected_targets = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch, batch.graph_attr)

            # Batched targets arrive flattened; restore one row of 5 per graph.
            target = batch.y.view(-1, 5)
            if target.shape[0] != out.shape[0]:
                print(f"Warning: target size {target.shape} does not match output size {out.shape}.")
                continue
            test_loss += criterion(out, target).item()
            evaluated_batches += 1
            collected_outputs.append(out.cpu())
            collected_targets.append(target.cpu())

    # Skipped batches must not count towards the averages.
    if evaluated_batches == 0:
        print("Test skipped: no batch could be evaluated.")
        return

    predictions = torch.cat(collected_outputs, dim=0)
    targets = torch.cat(collected_targets, dim=0)
    errors = predictions - targets

    avg_test_loss = test_loss / evaluated_batches
    # MAE, RMSE and R² over the whole evaluated set.
    avg_mae = errors.abs().mean().item()
    avg_rmse = errors.pow(2).mean().sqrt().item()
    avg_r2 = r2_score(targets.numpy(), predictions.numpy())

    print(f"Test Loss: {avg_test_loss}, MAE: {avg_mae}, RMSE: {avg_rmse}, R²: {avg_r2}")

# Graph-level attributes per sample: 2 numeric columns + 3 one-hot material columns.
num_graph_attributes = 5
input_dim = train_graphs[0].num_node_features + num_graph_attributes

# NOTE(review): the training loader does not reshuffle between epochs.
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=False)
val_loader = DataLoader(val_graphs, batch_size=32, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

best_model_path = "best_model.pth"
final_model_path = 'final_model.pth'

# Layer sizes are read off the first training sample.
model = GCNModel(
    num_node_features=train_graphs[0].num_node_features,
    hidden_dim=64,
    output_dim=train_graphs[0].y.size(0),
    num_graph_attributes=train_graphs[0].graph_attr.size(0),
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.SmoothL1Loss()
start_train_time = time.time()
# Checkpoint selection uses the validation split; the test split stays
# untouched until the final test() call so its metrics are not biased.
train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs=50)
end_train_time = time.time()
train_time = end_train_time - start_train_time
print(f"训练时间: {train_time:.2f} 秒")
validate(model, val_loader, criterion)
test(model, test_loader, criterion)



# In [None]:
def visualize_test_results(model, test_loader, target_scaler):
    """Compare predictions with ground truth on ``test_loader``.

    Loads the weights stored at the module-level ``best_model_path``, prints
    MSE and MAE in normalised target space, writes de-normalised predictions
    and ground truths to ``'predictions_and_ground_truths1.xlsx'`` and saves
    one scatter plot per target column.

    Args:
        model: Network called as ``model(x, edge_index, batch, graph_attr)``.
        test_loader: Loader yielding the batches to evaluate.
        target_scaler: Fitted scaler used to undo target normalisation; its
            ``scale_`` length defines the number of target columns.
    """
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.to(device)
    model.eval()

    target_batches = []
    prediction_batches = []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch, batch.graph_attr)
            target_batches.append(batch.y.cpu().numpy())
            prediction_batches.append(out.cpu().numpy())

    num_features = target_scaler.scale_.shape[0]

    # Batched targets can arrive flattened; give both arrays one row per
    # sample *before* computing metrics so their shapes agree.
    all_targets = np.concatenate(target_batches, axis=0).reshape(-1, num_features)
    all_predictions = np.concatenate(prediction_batches, axis=0).reshape(-1, num_features)

    errors = all_targets - all_predictions
    mse = float(np.mean(errors ** 2))
    mae = float(np.mean(np.abs(errors)))

    print(f'Mean Squared Error (MSE): {mse:.4f}')
    print(f'Mean Absolute Error (MAE): {mae:.4f}')

    all_targets = target_scaler.inverse_transform(all_targets)
    all_predictions = target_scaler.inverse_transform(all_predictions)

    param_names = [
        "Effective Refractive Index (a.u.)",
        "Effective Mode Area (um$^2$)",
        "Nonlinear Coefficient (1/W/km)",
        "Dispersion Coefficient (ps/nm/km)",
        "Group velocity dispersion (ps$^2$/km)",
        "Loss",
        r"$\beta_3$",
        r"$\beta_4$"
    ]

    # Predictions come from the model output, ground truths from the targets.
    columns = {}
    for i in range(num_features):
        columns[f'Predictions ({i + 1})'] = all_predictions[:, i]
        columns[f'Ground Truths ({i + 1})'] = all_targets[:, i]
    df = pd.DataFrame(columns)

    df.to_excel('predictions_and_ground_truths1.xlsx', index=False)

    plt.rcParams['font.family'] = 'Times New Roman'
    for i in range(all_targets.shape[1]):
        lower = all_targets[:, i].min()
        upper = all_targets[:, i].max()
        title = param_names[i] if i < len(param_names) else f"Target {i + 1}"

        plt.figure(figsize=(8, 6))
        plt.scatter(all_targets[:, i], all_predictions[:, i], alpha=0.5)
        # Dashed diagonal marks perfect agreement.
        plt.plot([lower, upper], [lower, upper], 'r--')
        plt.title(f'{title}', fontsize=20)
        plt.xlabel('FEM', fontsize=20)
        plt.ylabel('GNN', fontsize=20)
        plt.xlim(lower, upper)
        plt.ylim(lower, upper)

        plt.tick_params(axis='both', labelsize=20, direction='in')

        plt.grid(False)

        # NOTE(review): this produces file names like "[0].png"; kept as-is
        # so existing outputs keep their names — confirm whether "0.png" was meant.
        plt.savefig(f"{[i]}.png")
        plt.show()

# Time the full test-set visualisation run (model load, inference, Excel export, plots).
start_predict_time = time.time()
visualize_test_results(model, test_loader, target_scaler)
end_predict_time = time.time()
# Elapsed wall-clock seconds; stored but not printed here.
predict_time = end_predict_time - start_predict_time

# In [None]:
def test_and_visualize(model, graph_file_pattern, parameter_file_pattern, target_scaler, device=device):
    """Predict every parameter row of one randomly chosen SiO2 structure and plot it.

    One structure/parameter file pair whose path contains ``"SiO2"`` is
    picked at random. The graph is built the same way as during training
    (feature columns 2-5 plus one-hot shape, k-NN edges from columns 0-1),
    the weights in ``'best_model.pth'`` are loaded, and de-normalised
    predictions and ground truths are written to
    ``'predictions_and_ground_truths.xlsx'`` and plotted against column 3 of
    the parameter file (used as the wavelength axis).

    Args:
        model: Network called as ``model(x, edge_index, batch, graph_attr)``.
        graph_file_pattern: Glob pattern for the structure CSVs.
        parameter_file_pattern: Glob pattern for the parameter CSVs.
        target_scaler: Fitted scaler for the target columns.
        device: Device used for the tensors and the model.

    Raises:
        ValueError: If no SiO2 files match or the two patterns match a
            different number of SiO2 files.
    """
    # Batch is not imported at module level in this file.
    from torch_geometric.data import Batch

    # Sort so that index i refers to the same sample in both lists.
    structure_files = sorted(f for f in glob.glob(graph_file_pattern) if "SiO2" in f)
    parameter_files = sorted(f for f in glob.glob(parameter_file_pattern) if "SiO2" in f)
    if not structure_files:
        raise ValueError(f"No SiO2 structure files match pattern {graph_file_pattern!r}")
    if len(structure_files) != len(parameter_files):
        raise ValueError(
            f"Found {len(structure_files)} SiO2 structure files but "
            f"{len(parameter_files)} SiO2 parameter files; they must pair one-to-one"
        )

    selected_index = random.randint(0, len(structure_files) - 1)
    struct_file = structure_files[selected_index]
    param_file = parameter_files[selected_index]
    print(f"Structure file: {struct_file}")
    print(f"Parameter file: {param_file}")

    df_structure = pd.read_csv(struct_file)
    df_params = pd.read_csv(param_file)

    # Same feature columns as create_graph_dataset, which fitted the scaler.
    node_features = df_structure.iloc[:, [2, 3, 4, 5]].values
    node_shapes = df_structure.iloc[:, 6].astype(str).values.reshape(-1, 1)
    node_shapes_encoded = shape_encoder.transform(node_shapes)
    node_features = np.hstack((node_features, node_shapes_encoded))
    node_features_normalized = node_feature_scaler.transform(node_features)
    node_features_tensor = torch.tensor(node_features_normalized, dtype=torch.float).to(device)

    graph_attributes = df_params.iloc[:, [2, 3]].values
    fiber_material = df_params.iloc[:, 1].astype(str).values.reshape(-1, 1)
    fiber_material_encoded = material_encoder.transform(fiber_material)
    graph_attributes = np.hstack((graph_attributes, fiber_material_encoded))
    graph_attributes_normalized = graph_attribute_scaler.transform(graph_attributes)
    graph_attributes_tensor = torch.tensor(graph_attributes_normalized, dtype=torch.float).to(device)

    targets = df_params.iloc[:, 6:11].values
    targets_normalized = target_scaler.transform(targets)
    targets_tensor = torch.tensor(targets_normalized, dtype=torch.float).to(device)

    # Same k-NN connectivity the training graphs were built with.
    edge_index, _ = create_edges_from_coordinates(df_structure.iloc[:, [0, 1]].values)
    edge_index = edge_index.to(device)

    wavelengths = df_params.iloc[:, 3].values
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.to(device)
    model.eval()

    # One graph per parameter row; all rows share nodes and edges.
    data_list = [
        Data(
            x=node_features_tensor,
            edge_index=edge_index,
            graph_attr=graph_attributes_tensor[i].unsqueeze(0),
            y=targets_tensor[i].unsqueeze(0)
        )
        for i in range(len(wavelengths))
    ]
    batch = Batch.from_data_list(data_list)

    with torch.no_grad():
        preds = model(batch.x, batch.edge_index, batch.batch, batch.graph_attr)
        predictions = target_scaler.inverse_transform(preds.cpu().numpy())
        ground_truths = target_scaler.inverse_transform(batch.y.cpu().numpy())

    df = pd.DataFrame({
        'Wavelength': wavelengths,
        'Predictions (1)': predictions[:, 0],
        'Ground Truths (1)': ground_truths[:, 0],
        'Predictions (2)': predictions[:, 1],
        'Ground Truths (2)': ground_truths[:, 1],
        'Predictions (3)': predictions[:, 2],
        'Ground Truths (3)': ground_truths[:, 2],
        'Predictions (4)': predictions[:, 3],
        'Ground Truths (4)': ground_truths[:, 3],
        'Predictions (5)': predictions[:, 4],
        'Ground Truths (5)': ground_truths[:, 4],
    })

    df.to_excel('predictions_and_ground_truths.xlsx', index=False)

    param_names = ["Effective Refractive Index", "Effective Mode Area", "Nonlinear Coefficient", "Dispersion Coefficient","Group velocity dispersion","loss","beta3","beta4","Effective Refractive Index"]
    param_units = [
        "a.u.",
        "um^2",
        "1/W/km",
        "ps/nm/km",
        "ps^2/km",
        "dB/m",
        "s^3/m",
        "s^4/m",
        "a.u."
    ]
    for i in range(ground_truths.shape[1]):
        plt.figure(figsize=(7, 5))
        plt.plot(wavelengths, predictions[:, i], '--', label=f"GNN")
        plt.plot(wavelengths, ground_truths[:, i], label=f"FEM")
        plt.xlabel("Wavelength (nm)")
        plt.ylabel(f"{param_names[i]} ({param_units[i]})")
        plt.title(f"{param_names[i]}")
        # The legend is drawn on the first figure only.
        if i == 0:
            plt.legend(loc='best', fontsize=12)
        plt.grid(False)
        plt.tick_params(axis='both', direction='in')
        plt.savefig(f"{param_names[i]}.png")
        plt.show()

# Run the single-structure comparison with the trained model and the
# module-level file patterns and target scaler; the device is re-detected here.
test_and_visualize(
    model=model,
    graph_file_pattern=graph_file_pattern,
    parameter_file_pattern=parameter_file_pattern,
    target_scaler=target_scaler,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
