In [None]:
# 1. Data Loading and Preprocessing
import os
import torch
import numpy as np
from torch_geometric.data import Data
from tqdm import tqdm

# Set data path
# Raw string so the backslashes of the Windows path are kept literally.
data_path = r'G:\000_New_data\Local\Local_Region\Phase49Aspect1_Local_Region\LocalSize21'

# Load all files
# os.listdir() returns entries in arbitrary order; sorting the names keeps
# everything derived from this list in a reproducible order between runs.
all_files = sorted(f for f in os.listdir(data_path) if f.endswith('.txt'))

# Group by sample
# Resulting layout: {sample key: {'micro' | 'mfrac' | 'damage': {node index: filename}}}
# The sample key is the first two '_'-separated tokens of the filename; the
# node index is the integer in the last token (before the first '.') minus 1.
sample_groups = {}

# Checked in this order; the first marker found in the filename wins.
kind_markers = (
    ('MicroInfo', 'micro'),
    ('Mfraction', 'mfrac'),
    ('DamageStrain', 'damage'),
)

for fname in all_files:
    tokens = fname.split('_')
    if len(tokens) < 3:
        # Too few name components to carry a sample key and a node number
        continue

    number_text = tokens[-1].split('.')[0]
    try:
        node_id = int(number_text) - 1
    except ValueError:
        # Last token does not start with an integer: ignore this file
        continue

    sample_key = '_'.join(tokens[:2])
    buckets = sample_groups.setdefault(
        sample_key,
        {'micro': {}, 'mfrac': {}, 'damage': {}},
    )

    for marker, kind in kind_markers:
        if marker in fname:
            buckets[kind][node_id] = fname
            break


# Sample processing loop: build one graph per complete sample.
#
# Outputs (module level):
#   graphs    - list of Data objects; x has one row per node holding the
#               PATCH_SIZE*PATCH_SIZE microstructure values followed by the
#               node's Mfraction value, y holds one damage-strain value per node.
#   file_info - list of dicts, one per entry of `graphs` and in the same order,
#               with the full paths of the files each graph was built from.

NUM_NODES = 5     # files of each kind required for a sample to be used
PATCH_SIZE = 21   # a microstructure file must hold PATCH_SIZE * PATCH_SIZE values


def _read_first_value(path, sep=None):
    """Return the first `sep`-separated number on the first line of `path`.

    `sep=None` splits on whitespace. Returns 0.0 when the file cannot be
    opened or decoded, when its first line is blank, or when the first token
    is not a valid float (best-effort behaviour kept from the original script).
    """
    try:
        with open(path, 'r') as fh:
            first_line = fh.readline().strip()
        return float(first_line.split(sep)[0]) if first_line else 0.0
    except (OSError, ValueError):
        return 0.0


def _read_micro_features(path, expected=PATCH_SIZE * PATCH_SIZE):
    """Return the per-line values of `path` as a 1-D array of length `expected`.

    One value is taken per line (the text before the first comma); entries
    that do not parse as float become 0.0. If the number of lines differs from
    `expected`, an all-zero vector is returned instead. Errors while opening
    the file are not caught here.
    """
    values = []
    with open(path, 'r') as fh:
        for line in fh:
            try:
                values.append(float(line.strip().split(',')[0]))
            except ValueError:
                values.append(0.0)
    if len(values) != expected:
        return np.zeros(expected)
    return np.array(values)


graphs = []
file_info = []

for group_key, file_dicts in sample_groups.items():
    micro = file_dicts['micro']
    mfrac = file_dicts['mfrac']
    damage = file_dicts['damage']

    # Check if NUM_NODES nodes are present for every file kind
    if not (len(micro) == len(mfrac) == len(damage) == NUM_NODES):
        continue

    # The three kinds must describe the same nodes; otherwise row i of the
    # features, its Mfraction value and its label would belong to different
    # nodes (the previous positional lookup could misalign or raise IndexError).
    if not (micro.keys() == mfrac.keys() == damage.keys()):
        print(f"Skipping sample {group_key}: node IDs differ between file kinds")
        continue

    # Sort by node ID so that row i of x and y always refers to the same node
    node_ids = sorted(micro)
    micro_paths = [os.path.join(data_path, micro[n]) for n in node_ids]
    mfrac_paths = [os.path.join(data_path, mfrac[n]) for n in node_ids]
    damage_paths = [os.path.join(data_path, damage[n]) for n in node_ids]

    # Process file information
    file_info.append({
        'micro': micro_paths,
        'mfrac': mfrac_paths,
        'damage': damage_paths
    })

    # 1 + 2. Microstructure features with the node's Mfraction value appended
    node_features = [
        np.append(_read_micro_features(micro_path), _read_first_value(mfrac_path))
        for micro_path, mfrac_path in zip(micro_paths, mfrac_paths)
    ]

    # 3. Construct graph structure: chain 0 -> 1 -> ... -> NUM_NODES - 1
    # NOTE(review): each edge is stored in one direction only, so the graph is
    # directed as far as edge_index is concerned - confirm this is intended.
    edge_index = torch.tensor([
        list(range(NUM_NODES - 1)),
        list(range(1, NUM_NODES))
    ], dtype=torch.long)

    # 4. Read damage strain values (first comma-separated field of line one)
    y_values = [_read_first_value(path, sep=',') for path in damage_paths]

    # Stack into a single ndarray first: building a tensor from a list of
    # ndarrays is considerably slower.
    x = torch.tensor(np.stack(node_features), dtype=torch.float)
    y = torch.tensor(y_values, dtype=torch.float).unsqueeze(1)

    data = Data(x=x, edge_index=edge_index, y=y)
    graphs.append(data)

# Check results
# Print a summary and, when at least one graph was built, inspect the first
# one. The guard avoids an IndexError on graphs[0] when nothing was loaded.
print(f"Successfully loaded {len(graphs)} graph data samples")
if graphs:
    first_graph = graphs[0]
    print("First sample information:")
    print(first_graph)
    print("Node feature dimensions:", first_graph.x.shape)
    print("Edge index:", first_graph.edge_index)
    print("Labels:", first_graph.y)
else:
    print("No complete samples were found; check data_path and the file naming pattern")