# In [None]:
import numpy as np
import scipy.sparse as sparse
import scipy.sparse.linalg as linalg
import matplotlib.pyplot as plt
from datetime import datetime
import json

class MutationLibrary:
    
    def __init__(self, beam_length=5.0):
        self.beam_length = beam_length
        self.mutation_types = self._initialize_mutation_types()
    
    def _initialize_mutation_types(self):
        return {
                'joint_misaligned': {
                'category': '连接界面',
                'name': '错位连接',
                'description': '刚度波动，先升后降',
                'default_params': {
                    'amplitude_range': [(-0.15, -0.05), (0.05, 0.15)],  
                    'width_range': (0.2, 0.3),
                    'shape': 'bimodal',  
                    'symmetry': 'asymmetric',
                    'position_constraint': 'any',
                    'probability_weight': 0.1,
                },
                'generate_function': self._generate_joint_misaligned
            },

            'joint_normal': {
                'category': '连接界面',
                'name': '正常连接界面',
                'description': '轻微刚度过渡，±10%变化',
                'default_params': {
                    'amplitude_range': (-0.1, 0.1),  
                    'width_range': (0.04, 0.06),     
                    'shape': 'gaussian',
                    'symmetry': 'symmetric',
                    'position_constraint': 'any',  
                    'probability_weight': 0.3,     
                },
                'generate_function': self._generate_joint_normal
            },
            
            'joint_loose': {
                'category': '连接界面',
                'name': '松动连接',
                'description': '刚度显著下降，20-40%损失，左侧影响更大',
                'default_params': {
                    'amplitude_range': (-0.4, -0.2),
                    'width_range': (0.10, 0.25),
                    'shape': 'asymmetric_gaussian',  
                    'symmetry': 'left_biased',  
                    'position_constraint': 'any',
                    'probability_weight': 0.25,
                },
                'generate_function': self._generate_joint_loose
            },
            
            'joint_over_tight': {
                'category': '连接界面',
                'name': '过紧连接',
                'description': '刚度异常上升，20-30%增加',
                'default_params': {
                    'amplitude_range': (0.2, 0.3),
                    'width_range': (0.10, 0.25),
                    'shape': 'sharp_gaussian',  
                    'symmetry': 'symmetric',
                    'position_constraint': 'any',
                    'probability_weight': 0.15,
                },
                'generate_function': self._generate_joint_over_tight
            },
            
            'joint_misaligned': {
                'category': '连接界面',
                'name': '错位连接',
                'description': '刚度波动，先升后降',
                'default_params': {
                    'amplitude_range': (-0.15, 0.15),  
                    'width_range': (0.2, 0.3),
                    'shape': 'bimodal',  
                    'symmetry': 'asymmetric',
                    'position_constraint': 'any',
                    'probability_weight': 0.1,
                },
                'generate_function': self._generate_joint_misaligned
            },
            
            'overload_damage': {
                'category': '外部损伤',
                'name': '过载损伤',
                'description': '屈服后刚度下降，大范围',
                'default_params': {
                    'amplitude_range': (-0.25, -0.15),
                    'width_range': (0.3, 1.0),
                    'shape': 'plateau',  
                    'symmetry': 'symmetric',
                    'position_constraint': 'middle',  
                    'probability_weight': 0.2,
                },
                'generate_function': self._generate_overload_damage
            },
            
            'scratch_damage': {
                'category': '外部损伤',
                'name': '划痕/刮伤',
                'description': '线状表面损伤，浅层影响',
                'default_params': {
                    'amplitude_range': (-0.2, -0.1),
                    'width_range': (0.005, 0.01),  
                    'length_range': (0.1, 0.8),     
                    'shape': 'linear_decay',
                    'symmetry': 'unidirectional',
                    'position_constraint': 'any',
                    'probability_weight': 0.15,
                },
                'generate_function': self._generate_scratch_damage
            },
            
            'impact_dent': {
                'category': '外部损伤',
                'name': '冲击凹陷',
                'description': '局部凹陷，中心下降，边缘波动',
                'default_params': {
                    'amplitude_range': (-0.5, -0.2),
                    'diameter_range': (0.05, 0.15),  
                    'shape': 'impact_wave',  
                    'symmetry': 'radial',
                    'position_constraint': 'any',
                    'probability_weight': 0.1,
                },
                'generate_function': self._generate_impact_dent
            },
            
            'uniform_corrosion': {
                'category': '腐蚀磨损',
                'name': '均匀腐蚀',
                'description': '大面积均匀刚度下降',
                'default_params': {
                    'amplitude_range': (-0.3, -0.1),
                    'width_range': (0.5, 2.0),
                    'shape': 'uniform',
                    'symmetry': 'symmetric',
                    'position_constraint': 'any',
                    'probability_weight': 0.1,
                },
                'generate_function': self._generate_uniform_corrosion
            },
            
            'hairline_crack': {
                'category': '裂缝损伤',
                'name': '发丝裂缝',
                'description': '极窄区域，轻微刚度下降',
                'default_params': {
                    'amplitude_range': (-0.15, -0.05),
                    'width_range': (0.002, 0.005),  
                    'shape': 'sharp_step',  
                    'symmetry': 'symmetric',
                    'position_constraint': 'stress_concentration',  
                    'probability_weight': 0.08,
                },
                'generate_function': self._generate_hairline_crack
            },
            
            'surface_crack': {
                'category': '裂缝损伤',
                'name': '表面裂缝',
                'description': '一侧开放，渐进式下降',
                'default_params': {
                    'amplitude_range': (-0.2, -0.1),
                    'width_range': (0.01, 0.03),
                    'shape': 'exponential_decay',
                    'symmetry': 'one_sided',  
                    'position_constraint': 'surface',
                    'probability_weight': 0.07,
                },
                'generate_function': self._generate_surface_crack
            },
            
            'fatigue_crack_cluster': {
                'category': '裂缝损伤',
                'name': '疲劳裂纹群',
                'description': '多个小裂纹聚集，通常在根部',
                'default_params': {
                    'amplitude_range': (-0.3, -0.2),  
                    'width_range': (0.1, 0.3),        
                    'num_cracks_range': (3, 7),       
                    'individual_amplitude_range': (-0.1, -0.05),  
                    'shape': 'cluster',
                    'symmetry': 'clustered',
                    'position_constraint': 'root_only',  
                    'probability_weight': 0.05,
                },
                'generate_function': self._generate_fatigue_crack_cluster
            }
        }

    def _generate_joint_normal(self, position=None, **kwargs):
        params = self.mutation_types['joint_normal']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.5, self.beam_length - 0.5)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        
        return {
            'type': 'joint_normal',
            'name': '正常连接界面',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'shape': 'gaussian',
            'symmetry': 'symmetric',
            'metadata': {
                'physics': '理想螺栓连接、正常焊接',
                'severity': '轻微',
                'repair_priority': '低'
            }
        }
    
    def _generate_joint_loose(self, position=None, **kwargs):
        params = self.mutation_types['joint_loose']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.5, self.beam_length - 0.5)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        left_bias = np.random.uniform(0.6, 0.8)  
        
        return {
            'type': 'joint_loose',
            'name': '松动连接',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'left_width': width * left_bias,
            'right_width': width * (1 - left_bias),
            'shape': 'asymmetric_gaussian',
            'symmetry': 'left_biased',
            'metadata': {
                'physics': '螺栓松动、焊缝开裂',
                'severity': '中等',
                'repair_priority': '高'
            }
        }
    
    def _generate_joint_over_tight(self, position=None, **kwargs):
        params = self.mutation_types['joint_over_tight']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.5, self.beam_length - 0.5)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        sharpness = np.random.uniform(2.0, 4.0)  
        
        return {
            'type': 'joint_over_tight',
            'name': '过紧连接',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'sharpness': sharpness,  
            'shape': 'sharp_gaussian',
            'symmetry': 'symmetric',
            'metadata': {
                'physics': '螺栓过拧、预应力效应',
                'severity': '警告',
                'repair_priority': '中'
            }
        }
    
    def _generate_joint_misaligned(self, position=None, **kwargs):
        params = self.mutation_types['joint_misaligned']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.5, self.beam_length - 0.5)
        
        width = np.random.uniform(*params['width_range'])
        peak1_amplitude = np.random.uniform(0.05, 0.15)  
        peak2_amplitude = np.random.uniform(-0.15, -0.05)  
        peak_separation = width * 0.3  
        
        return {
            'type': 'joint_misaligned',
            'name': '错位连接',
            'position': position,
            'width': width,
            'amplitude': [peak1_amplitude, peak2_amplitude],
            'peak_separation': peak_separation,
            'shape': 'bimodal',
            'symmetry': 'asymmetric',
            'metadata': {
                'physics': '安装错位、偏心连接',
                'severity': '中等',
                'repair_priority': '高'
            }
        }
    
    def _generate_overload_damage(self, position=None, **kwargs):
        params = self.mutation_types['overload_damage']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.2 * self.beam_length, 0.6 * self.beam_length)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        plateau_ratio = np.random.uniform(0.3, 0.7)  
        
        return {
            'type': 'overload_damage',
            'name': '过载损伤',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'plateau_ratio': plateau_ratio,
            'shape': 'plateau',
            'symmetry': 'symmetric',
            'metadata': {
                'physics': '超载使用、塑性变形',
                'severity': '严重',
                'repair_priority': '紧急'
            }
        }
    
    def _generate_scratch_damage(self, position=None, **kwargs):
        params = self.mutation_types['scratch_damage']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.2, self.beam_length - 0.2)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        length = np.random.uniform(*params['length_range'])
        
        angle = np.random.uniform(-30, 30)  
        
        return {
            'type': 'scratch_damage',
            'name': '划痕/刮伤',
            'position': position,
            'width': width,
            'length': length,
            'amplitude': amplitude,
            'angle': angle,
            'shape': 'linear_decay',
            'symmetry': 'unidirectional',
            'metadata': {
                'physics': '工具刮擦、安装损伤',
                'severity': '轻微',
                'repair_priority': '低'
            }
        }
    
    def _generate_impact_dent(self, position=None, **kwargs):
        params = self.mutation_types['impact_dent']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.3, self.beam_length - 0.3)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        diameter = np.random.uniform(*params['diameter_range'])
        ripple_factor = np.random.uniform(1.5, 2.5)  
        
        return {
            'type': 'impact_dent',
            'name': '冲击凹陷',
            'position': position,
            'diameter': diameter,
            'amplitude': amplitude,
            'ripple_factor': ripple_factor,
            'shape': 'impact_wave',
            'symmetry': 'radial',
            'metadata': {
                'physics': '物体撞击、机械损伤',
                'severity': '中等',
                'repair_priority': '中'
            }
        }
    
    def _generate_uniform_corrosion(self, position=None, **kwargs):
        params = self.mutation_types['uniform_corrosion']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.1 * self.beam_length, 0.9 * self.beam_length)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        uniformity = np.random.uniform(0.8, 0.95)  
        
        return {
            'type': 'uniform_corrosion',
            'name': '均匀腐蚀',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'uniformity': uniformity,
            'shape': 'uniform',
            'symmetry': 'symmetric',
            'metadata': {
                'physics': '电化学腐蚀、氧化',
                'severity': '中等到严重',
                'repair_priority': '高'
            }
        }
    
    def _generate_hairline_crack(self, position=None, **kwargs):
        """生成发丝裂缝"""
        params = self.mutation_types['hairline_crack']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            
            position = np.random.choice([
                np.random.uniform(0.1, 0.3),
                np.random.uniform(0.4, 0.6),
                np.random.uniform(0.7, 0.9)
            ]) * self.beam_length
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        sharpness = np.random.uniform(10.0, 20.0)  
        
        return {
            'type': 'hairline_crack',
            'name': '发丝裂缝',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'sharpness': sharpness,
            'shape': 'sharp_step',
            'symmetry': 'symmetric',
            'metadata': {
                'physics': '疲劳初期、微裂纹',
                'severity': '警告',
                'repair_priority': '中'
            }
        }
    
    def _generate_surface_crack(self, position=None, **kwargs):
        params = self.mutation_types['surface_crack']['default_params'].copy()
        params.update(kwargs)
        
        if position is None:
            position = np.random.uniform(0.2, self.beam_length - 0.2)
        
        amplitude = np.random.uniform(*params['amplitude_range'])
        width = np.random.uniform(*params['width_range'])
        decay_rate = np.random.uniform(2.0, 5.0)  
        
        return {
            'type': 'surface_crack',
            'name': '表面裂缝',
            'position': position,
            'width': width,
            'amplitude': amplitude,
            'decay_rate': decay_rate,
            'shape': 'exponential_decay',
            'symmetry': 'one_sided',
            'metadata': {
                'physics': '腐蚀裂纹、表面疲劳',
                'severity': '中等',
                'repair_priority': '高'
            }
        }
    
    def _generate_fatigue_crack_cluster(self, position=None, **kwargs):
        params = self.mutation_types['fatigue_crack_cluster']['default_params'].copy()
        params.update(kwargs)
        
        position = 0.1 * self.beam_length  
        
        overall_amplitude = np.random.uniform(*params['amplitude_range'])
        cluster_width = np.random.uniform(*params['width_range'])
        num_cracks = np.random.randint(*params['num_cracks_range'])
        
        individual_cracks = []
        for i in range(num_cracks):
            crack_pos = position + np.random.uniform(0, cluster_width)
            crack_amp = np.random.uniform(*params['individual_amplitude_range'])
            crack_width = np.random.uniform(0.005, 0.02)
            individual_cracks.append({
                'position': crack_pos,
                'amplitude': crack_amp,
                'width': crack_width
            })
        
        return {
            'type': 'fatigue_crack_cluster',
            'name': '疲劳裂纹群',
            'position': position,
            'width': cluster_width,
            'amplitude': overall_amplitude,
            'num_cracks': num_cracks,
            'individual_cracks': individual_cracks,
            'shape': 'cluster',
            'symmetry': 'clustered',
            'metadata': {
                'physics': '高周疲劳区域',
                'severity': '严重',
                'repair_priority': '紧急'
            }
        }

    def generate_random_mutations(self, n_mutations_range=(0, 3), 
                                 exclude_categories=None,
                                 include_categories=None):

        n_mutations = np.random.randint(*n_mutations_range)
        
        if n_mutations == 0:
            return []
        
        available_types = {}
        for type_name, type_info in self.mutation_types.items():
            category = type_info['category']
            
            if exclude_categories and category in exclude_categories:
                continue
            if include_categories and category not in include_categories:
                continue
            
            available_types[type_name] = type_info
        
        if not available_types:
            return []
        
        type_names = list(available_types.keys())
        weights = [available_types[name]['default_params']['probability_weight'] 
                  for name in type_names]
        weights = np.array(weights) / np.sum(weights)  
        
        selected_types = np.random.choice(type_names, size=n_mutations, 
                                         p=weights, replace=False)
        
        mutations = []
        for type_name in selected_types:
            type_info = available_types[type_name]
            generate_func = type_info['generate_function']
            
            mutation = generate_func()
            mutations.append(mutation)
        
        return mutations

class StiffnessFieldGenerator:
    
    def __init__(self, L=5.0, n_elements=300):

        self.L = L
        self.n_elements = n_elements
        self.x_nodes = np.linspace(0, L, n_elements + 1)
        self.x_centers = (self.x_nodes[:-1] + self.x_nodes[1:]) / 2
        self.dx = self.x_centers[1] - self.x_centers[0]
    
    def generate_kl_field(self, mean_E=200e9, cv=0.1, length_scale=1.0, 
                         n_terms=None, random_seed=None):

        if random_seed is not None:
            np.random.seed(random_seed)
        
        sigma_log = np.sqrt(np.log(1 + cv**2))
        mu_log = np.log(mean_E) - 0.5 * sigma_log**2
        
        x = self.x_centers
        n = len(x)
        
        X1, X2 = np.meshgrid(x, x, indexing='ij')
        distance_matrix = np.abs(X1 - X2)
        
        C = np.exp(-0.5 * (distance_matrix / length_scale)**2)
        
        eigenvalues, eigenvectors = np.linalg.eigh(C)
        
        idx = np.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        
        if n_terms is None:
            cumulative_variance = np.cumsum(eigenvalues) / np.sum(eigenvalues)
            n_terms = np.where(cumulative_variance >= 0.95)[0][0] + 1
            n_terms = min(n_terms, 10)  
        
        eigenvalues = eigenvalues[:n_terms]
        eigenvectors = eigenvectors[:, :n_terms]
        
        xi = np.random.randn(n_terms)
        
        Y = mu_log + np.dot(eigenvectors, np.sqrt(eigenvalues) * xi) * sigma_log
        
        E_base = np.exp(Y)
        
        return E_base
    
    def add_mutations(self, E_base, mutations):

        E_with_mutations = E_base.copy()
        mutation_mask = np.zeros_like(E_base, dtype=bool)
        
        for mutation in mutations:
            pos = mutation.get('position', 2.5)
            width = mutation.get('width', 0.1)  
            amplitude = mutation.get('amplitude', -0.3)
            shape = mutation.get('shape', 'gaussian')
            mtype = mutation.get('type', 'joint')
            
            distances = np.abs(self.x_centers - pos)
            
            if width <= self.dx:

                idx = np.argmin(distances)
                influence_range = distances <= (self.dx * 0.51)  
            else:
                influence_range = distances <= (width / 2)
            
            if shape == 'bimodal' and isinstance(amplitude, (list, tuple)) and len(amplitude) == 2:
                peak_separation = mutation.get('peak_separation', width * 0.3)
                
                pos1 = pos - peak_separation/2
                dist1 = np.abs(self.x_centers - pos1)
                influence_range1 = dist1 <= (width / 2)
                weights1 = np.exp(-0.5 * (dist1[influence_range1] / (width/4))**2)
                if np.max(weights1) > 0:
                    weights1 = weights1 / np.max(weights1)
                
                pos2 = pos + peak_separation/2
                dist2 = np.abs(self.x_centers - pos2)
                influence_range2 = dist2 <= (width / 2)
                weights2 = np.exp(-0.5 * (dist2[influence_range2] / (width/4))**2)
                if np.max(weights2) > 0:
                    weights2 = weights2 / np.max(weights2)
                
                E_with_mutations[influence_range1] *= (1.0 + amplitude[0] * weights1)

                E_with_mutations[influence_range2] *= (1.0 + amplitude[1] * weights2)
                

                mutation_mask[influence_range1] = True
                mutation_mask[influence_range2] = True
                
            elif shape == 'cluster' and mtype == 'fatigue_crack_cluster':

                individual_cracks = mutation.get('individual_cracks', [])
                for crack in individual_cracks:
                    crack_pos = crack.get('position', pos)
                    crack_amp = crack.get('amplitude', -0.1)
                    crack_width = crack.get('width', 0.01)
                    

                    crack_dist = np.abs(self.x_centers - crack_pos)
                    crack_influence = crack_dist <= (crack_width / 2)
                    
                    if np.any(crack_influence):

                        sigma = crack_width / 4
                        weights = np.exp(-0.5 * (crack_dist[crack_influence] / sigma)**2)
                        if np.max(weights) > 0:
                            weights = weights / np.max(weights)
                        
                        E_with_mutations[crack_influence] *= (1.0 + crack_amp * weights)
                        mutation_mask[crack_influence] = True
            
            else:

                if isinstance(amplitude, (list, tuple)):

                    amplitude = amplitude[0] if len(amplitude) > 0 else -0.3
                    print(f"警告：突变类型 {mtype} 的振幅为列表，已取第一个值 {amplitude}")
                
                if shape == 'gaussian':
                    sigma = width / 4  
                    weights = np.exp(-0.5 * (distances[influence_range] / sigma)**2)
                elif shape == 'step':
                    weights = np.ones(np.sum(influence_range))
                elif shape == 'triangle':
                    weights = 1.0 - distances[influence_range] / (width / 2)
                elif shape == 'asymmetric_gaussian':
                    left_width = mutation.get('left_width', width * 0.7)
                    right_width = mutation.get('right_width', width * 0.3)
                    
                    weights = np.zeros(np.sum(influence_range))
                    for i, idx in enumerate(np.where(influence_range)[0]):
                        dist = distances[idx]
                        if self.x_centers[idx] < pos:  
                            sigma = left_width / 4
                        else: 
                            sigma = right_width / 4
                        weights[i] = np.exp(-0.5 * (dist / sigma)**2)
                else:
                    weights = np.ones(np.sum(influence_range))
                
                if np.max(weights) > 0:
                    weights = weights / np.max(weights)
                
                if amplitude < 0:  
                    reduction = -amplitude * weights
                    E_with_mutations[influence_range] *= (1.0 - reduction)
                else:  
                    increase = amplitude * weights
                    E_with_mutations[influence_range] *= (1.0 + increase)
                
                # 更新突变掩码
                mutation_mask[influence_range] = True
        
        return E_with_mutations, mutation_mask
    
    def generate_complete_field(self, mean_E=200e9, cv=0.1, length_scale=1.0,
                               mutations=None, random_seed=None):

        E_base = self.generate_kl_field(
            mean_E=mean_E, cv=cv, length_scale=length_scale,
            random_seed=random_seed
        )
        
        if mutations is None:
            mutations = []
        
        E_final, mutation_mask = self.add_mutations(E_base, mutations)
        
        return E_final, E_base, mutation_mask

class EnhancedStiffnessFieldGenerator(StiffnessFieldGenerator):
    
    def __init__(self, L=5.0, n_elements=300):
        """Initialise the grid and attach a mutation library.

        Args:
            L: Beam length; also handed to the mutation library as its
                beam length.
            n_elements: Number of elements of the underlying grid.
        """
        super().__init__(L=L, n_elements=n_elements)
        # Source of randomly parameterised mutation records.
        library = MutationLibrary(L)
        self.mutation_library = library
    
    def generate_with_random_mutations(self, mean_E=200e9, cv=0.1, 
                                      length_scale=1.0,
                                      mutation_config=None,
                                      random_seed=None,
                                      ensure_fixed_end_joint=True,
                                      min_joint_spacing=1.5):
        """Generate a base field and overlay custom plus random mutations.

        Args:
            mean_E, cv, length_scale: Forwarded to ``generate_kl_field``.
            mutation_config: Optional dict; recognised keys are
                ``'custom_mutations'`` (list of mutation dicts),
                ``'n_mutations_range'``, ``'exclude_categories'`` and
                ``'include_categories'``.  The dict and its
                ``'custom_mutations'`` list are not modified.
            random_seed: When not ``None``, seeds NumPy's global random
                state once, before the base field is sampled.
            ensure_fixed_end_joint: When true and no custom mutations were
                supplied, a 'joint_normal' mutation is placed at
                ``position=0.05``; additionally the result is topped up to
                at least three joint mutations.
            min_joint_spacing: Minimum distance enforced between joint
                mutations when more than one is present.

        Returns:
            ``(E_final, E_base, mutation_mask, all_mutations)`` where
            ``all_mutations`` holds the full (new-format) mutation records.
        """
        if random_seed is not None:
            np.random.seed(random_seed)

        E_base = self.generate_kl_field(
            mean_E=mean_E, cv=cv, length_scale=length_scale
        )

        if mutation_config is None:
            mutation_config = {}

        # Work on a copy: appending to the caller's own list would leak the
        # fixed-end joint into their config and change later calls.
        custom_mutations = list(mutation_config.get('custom_mutations') or [])

        if ensure_fixed_end_joint and not custom_mutations:
            custom_mutations.append(
                self.mutation_library._generate_joint_normal(position=0.05)
            )

        random_mutations = self.mutation_library.generate_random_mutations(
            n_mutations_range=mutation_config.get('n_mutations_range', (0, 3)),
            exclude_categories=mutation_config.get('exclude_categories'),
            include_categories=mutation_config.get('include_categories')
        )

        all_mutations = custom_mutations + random_mutations

        joint_count = sum(
            1 for m in all_mutations if m['type'].startswith('joint')
        )
        if joint_count > 1:
            all_mutations = self._enforce_joint_spacing(all_mutations, min_joint_spacing)

        if ensure_fixed_end_joint:
            all_mutations = self._ensure_minimum_joints(all_mutations, min_count=3)

        # add_mutations consumes the reduced ("old") record layout.
        old_format_mutations = [
            self._convert_to_old_format(mutation) for mutation in all_mutations
        ]

        E_final, mutation_mask = self.add_mutations(E_base, old_format_mutations)

        return E_final, E_base, mutation_mask, all_mutations
    
    def _enforce_joint_spacing(self, mutations, min_spacing=1.5):
        """Push joint mutations apart so neighbours are ``min_spacing`` apart.

        Mutations are processed in order of position.  A joint that lies
        closer than ``min_spacing`` to the previous *joint* is moved to
        ``previous_joint + min_spacing``, provided that stays below
        ``self.L - 0.5``; otherwise it is left where it is.  Non-joint
        mutations are never moved and no longer interrupt the check
        (previously a non-joint lying between two joints caused the second
        joint to be skipped).

        Args:
            mutations: List of mutation dicts with ``'position'`` and
                ``'type'`` keys.
            min_spacing: Required minimum distance between joints.

        Returns:
            A new list ordered by the original positions; moved joints are
            shallow copies, so the input dicts are not modified.  Lists of
            length 0 or 1 are returned unchanged.
        """
        if len(mutations) <= 1:
            return mutations

        adjusted_mutations = []
        last_joint_pos = None

        for current in sorted(mutations, key=lambda m: m['position']):
            if current['type'].startswith('joint'):
                too_close = (
                    last_joint_pos is not None
                    and abs(current['position'] - last_joint_pos) < min_spacing
                )
                if too_close:
                    new_pos = last_joint_pos + min_spacing
                    # Only move while the joint stays clear of the far end.
                    if new_pos < self.L - 0.5:
                        current = {**current, 'position': new_pos}
                last_joint_pos = current['position']

            adjusted_mutations.append(current)

        return adjusted_mutations
    
    def _ensure_minimum_joints(self, mutations, min_count=3):
        """Top up ``mutations`` until it holds ``min_count`` joint mutations.

        New joints are placed on a 20-point grid over
        ``[0.5, self.L - 0.5]`` at points that are at least 1.0 away from
        every joint already present, with a randomly chosen joint type.
        Stops early when no admissible grid point is left.

        Args:
            mutations: List of mutation dicts; new joints are appended to
                this list in place.
            min_count: Required number of joint mutations.

        Returns:
            The same ``mutations`` list object.
        """
        taken_positions = [
            m['position'] for m in mutations if m['type'].startswith('joint')
        ]
        missing = min_count - len(taken_positions)
        if missing <= 0:
            return mutations

        new_joint_types = ['joint_normal', 'joint_loose', 'joint_over_tight']
        library = self.mutation_library
        factories = {
            'joint_normal': library._generate_joint_normal,
            'joint_loose': library._generate_joint_loose,
        }
        candidate_grid = np.linspace(0.5, self.L - 0.5, 20)

        for _ in range(missing):
            free_slots = [
                slot for slot in candidate_grid
                if all(abs(slot - taken) >= 1.0 for taken in taken_positions)
            ]
            if not free_slots:
                break

            # Random draws happen in this order: position, then type.
            new_pos = np.random.choice(free_slots)
            new_type = np.random.choice(new_joint_types)

            # Anything that is not normal/loose is generated as over-tight.
            factory = factories.get(new_type, library._generate_joint_over_tight)
            mutations.append(factory(position=new_pos))
            taken_positions.append(new_pos)

        return mutations
    
    def _convert_to_old_format(self, mutation):
        """Reduce a mutation record to the layout ``add_mutations`` reads.

        ``position``, ``type`` and ``metadata`` are always carried over.
        The remaining keys depend on the type:

        * ``joint_misaligned`` - width, two-element amplitude, bimodal
          shape, peak separation;
        * ``joint_loose`` - width, amplitude, asymmetric shape, left and
          right widths;
        * ``fatigue_crack_cluster`` - width, amplitude, cluster shape,
          crack count and the individual cracks;
        * anything else - width, scalar amplitude (first element if a
          list or tuple was given) and the record's own shape tag.

        Args:
            mutation: Mutation dict as produced by the mutation library.

        Returns:
            A new dict; the input is not modified.
        """
        kind = mutation['type']
        converted = {
            'position': mutation['position'],
            'type': kind,
            'metadata': mutation.get('metadata', {})
        }

        if kind == 'joint_misaligned':
            extra = {
                'width': mutation.get('width', 0.2),
                'amplitude': mutation['amplitude'],
                'shape': 'bimodal',
                'peak_separation': mutation.get('peak_separation', 0.06)
            }
        elif kind == 'joint_loose':
            extra = {
                'width': mutation.get('width', 0.15),
                'amplitude': mutation['amplitude'],
                'shape': 'asymmetric_gaussian',
                'left_width': mutation.get('left_width', 0.105),
                'right_width': mutation.get('right_width', 0.045)
            }
        elif kind == 'fatigue_crack_cluster':
            extra = {
                'width': mutation['width'],
                'amplitude': mutation['amplitude'],
                'shape': 'cluster',
                'num_cracks': mutation['num_cracks'],
                'individual_cracks': mutation['individual_cracks']
            }
        else:
            scalar_amplitude = mutation.get('amplitude', -0.3)
            if isinstance(scalar_amplitude, (list, tuple)):
                # Generic shapes need a scalar: keep the first entry.
                scalar_amplitude = (
                    scalar_amplitude[0] if len(scalar_amplitude) > 0 else -0.3
                )
            extra = {
                'width': mutation.get('width', 0.1),
                'amplitude': scalar_amplitude,
                'shape': mutation.get('shape', 'gaussian')
            }

        converted.update(extra)
        return converted

        
class Element:
    """Two-node element that links a pair of node objects.

    The node objects are expected to expose ``dof_a``, ``dof_b`` and
    ``boundary`` attributes.
    """

    def __init__(self, node_a, node_b):
        self.node_a, self.node_b = node_a, node_b

    def get_dof(self):
        """Return the four DOF indices: node_a's pair followed by node_b's pair."""
        first, second = self.node_a, self.node_b
        return (first.dof_a, first.dof_b, second.dof_a, second.dof_b)

    def get_boundary(self):
        """Return one boundary flag per DOF, ordered like ``get_dof``.

        Each node's single ``boundary`` flag is repeated for both of its DOFs.
        """
        return tuple(
            node.boundary
            for node in (self.node_a, self.node_b)
            for _ in range(2)
        )
        
        
class Node:
    """Mesh node: a coordinate, a boundary flag and two DOF indices."""

    def __init__(self, coord):
        self.coord = coord
        # A fresh node is unconstrained and has both DOF indices at zero;
        # whoever builds the mesh is responsible for assigning real values.
        self.boundary = False
        self.dof_a = self.dof_b = 0
        
        
def mesh(xa, xb, n_elements):
    """Build a uniform 1-D mesh on ``[xa, xb]`` and number its free DOFs.

    The first node is flagged as a boundary node and receives no DOF numbers
    (its DOF fields keep whatever ``Node`` initialised them to). Every other
    node gets two consecutive DOF indices, assigned in node order from 0.

    Args:
        xa: Coordinate of the first node.
        xb: Coordinate of the last node.
        n_elements: Number of elements; ``n_elements + 1`` nodes are created.

    Returns:
        Tuple ``(elements, a, n_dof)`` where ``elements`` is the list of
        ``Element`` objects joining consecutive nodes, ``a`` is half the node
        spacing and ``n_dof`` is ``2 * n_elements``.
    """
    node_count = n_elements + 1
    coords, spacing = np.linspace(xa, xb, num=node_count, retstep=True)

    nodes = [Node(x) for x in coords]
    nodes[0].boundary = True

    elements = [Element(left, right) for left, right in zip(nodes[:-1], nodes[1:])]

    # Two DOFs per free node, numbered consecutively in node order.
    free_nodes = [node for node in nodes if not node.boundary]
    for k, node in enumerate(free_nodes):
        node.dof_a = 2 * k
        node.dof_b = 2 * k + 1

    half_length = 0.5 * spacing
    total_dof = 2 * (node_count - 1)
    return elements, half_length, total_dof
        
    
def shapes(xi, a):
    """Evaluate the two linear shape functions and their scaled derivatives.

    Args:
        xi: Natural coordinate at which to evaluate.
        a: Divisor applied to the natural-coordinate derivatives (the callers
            in this file pass the value returned by ``mesh``, i.e. half the
            node spacing).

    Returns:
        ``(n, dn)``: ``n`` is ``[0.5 * (1 - xi), 0.5 * (1 + xi)]`` and ``dn``
        is ``[-0.5 / a, 0.5 / a]``.
    """
    left_value = 0.5 * (1 - xi)
    right_value = 0.5 * (1 + xi)
    slope = 0.5 / a
    return [left_value, right_value], [-slope, slope]


def element_matrices(a, G, E, Iz, rho, A, kappa):
    """Compute the 4x4 stiffness and mass matrices of one two-node element.

    Columns 0 and 2 of the interpolation rows carry the ``rho * A`` mass term
    and columns 1 and 3 the ``rho * Iz`` term, so the DOF order is presumably
    (deflection, rotation) per node as in a shear-deformable (Timoshenko-type)
    beam element -- confirm against the formulation this was derived from.

    The ``E * Iz`` stiffness term and the mass matrix are integrated with the
    two-point Gauss rule (points at +/-0.577350269189626, unit weights). The
    ``kappa * G * A`` stiffness term is integrated with a single point at
    ``xi = 0`` and weight 2.

    Args:
        a: Element half-length (Jacobian of the natural-coordinate mapping).
        G: Shear modulus of the element.
        E: Young's modulus of the element.
        Iz: Second moment of area.
        rho: Mass density.
        A: Cross-section area.
        kappa: Shear correction factor.

    Returns:
        ``(ke, me)``: the element stiffness and mass matrices as 4x4 arrays.
    """
    stiffness = np.zeros((4, 4))
    mass = np.zeros((4, 4))

    gauss_abscissa = 0.577350269189626
    for xi in (gauss_abscissa, -gauss_abscissa):
        N, dN = shapes(xi, a)
        rot_gradient = np.array([[0, dN[0], 0, dN[1]]])
        defl_interp = np.array([[N[0], 0, N[1], 0]])
        rot_interp = np.array([[0, N[0], 0, N[1]]])

        stiffness += E * Iz * rot_gradient.T @ rot_gradient * a
        mass += (a * rho * A * defl_interp.T @ defl_interp
                 + a * rho * Iz * rot_interp.T @ rot_interp)

    # Single-point rule (xi = 0, weight 2) for the kappa*G*A term.
    N, dN = shapes(0.0, a)
    shear_row = np.array([[dN[0], N[0], dN[1], N[1]]])
    stiffness += 2 * a * kappa * G * A * shear_row.T @ shear_row

    return stiffness, mass


def assemble(elements, E_final, G_final, n_dof, Iz, a, rho=7850, A=0.25, kappa=5/6):
    """Assemble the global sparse stiffness and mass matrices.

    Each element's 4x4 matrices come from ``element_matrices`` using that
    element's own ``E`` and ``G``. Entries whose row or column DOF is flagged
    as boundary are dropped; the rest are collected as COO triplets, so
    contributions that share a (row, col) pair are summed when the matrices
    are built.

    Args:
        elements: Sequence of elements exposing ``get_dof()`` and
            ``get_boundary()`` (each returning four values).
        E_final: Per-element Young's modulus, indexed like ``elements``.
        G_final: Per-element shear modulus, indexed like ``elements``.
        n_dof: Size of the global (square) matrices.
        Iz: Second moment of area, shared by all elements.
        a: Element half-length, shared by all elements.
        rho: Mass density. Defaults to the previously hard-coded 7850.
        A: Cross-section area. Defaults to the previously hard-coded 0.25.
        kappa: Shear correction factor. Defaults to the previously
            hard-coded 5/6.

    Returns:
        ``(K, M)``: global stiffness and mass matrices in CSR format.
    """
    rows, cols, k_vals, m_vals = [], [], [], []

    for e_idx, element in enumerate(elements):
        dof = element.get_dof()
        boundary = element.get_boundary()
        ke, me = element_matrices(a, G_final[e_idx], E_final[e_idx], Iz, rho, A, kappa)

        # Local indices of unconstrained DOFs, computed once per element.
        free = [i for i in range(4) if not boundary[i]]
        for i in free:
            for j in free:
                rows.append(dof[i])
                cols.append(dof[j])
                k_vals.append(ke[i, j])
                m_vals.append(me[i, j])

    K = sparse.coo_matrix((k_vals, (rows, cols)), shape=(n_dof, n_dof)).tocsr()
    M = sparse.coo_matrix((m_vals, (rows, cols)), shape=(n_dof, n_dof)).tocsr()

    return K, M



def modal_analysis_with_mode_shapes(n_elements=300, n_modes=10, mutation_config=None, 
                                   random_seed=None, ensure_fixed_end_joint=True,
                                   return_mode_shapes=True):
    """Run a modal analysis of a 5 m beam with a randomly mutated stiffness field.

    The beam is meshed with ``n_elements`` elements (the first node is the
    constrained one, see ``mesh``). A per-element Young's modulus field is
    obtained from ``EnhancedStiffnessFieldGenerator``, and the generalized
    eigenproblem ``K x = w^2 M x`` is solved in shift-invert mode around zero
    for ``n_modes`` eigenpairs.

    Args:
        n_elements: Number of beam elements.
        n_modes: Number of eigenpairs to compute.
        mutation_config: Configuration forwarded to the stiffness generator.
            When ``None`` a default allowing 0-5 mutations with no category
            filtering is used.
        random_seed: Seed forwarded to the stiffness generator.
        ensure_fixed_end_joint: Forwarded to the stiffness generator.
        return_mode_shapes: When false, the mode-shape slot of the result is
            ``None``.

    Returns:
        Tuple ``(freqs, mode_shapes, E_final, mutation_mask, mutations)``.
        ``freqs`` holds the frequencies ``sqrt(w^2) / (2*pi)``.
        ``mode_shapes`` is an ``(n_modes, n_elements + 1)`` array holding the
        first DOF of every node (zero at the constrained node), each row
        scaled to unit maximum absolute value; or ``None``. The remaining
        items are passed through from the stiffness generator.
    """
    L = 5.0
    E_mean = 200e9
    v = 0.3
    h = 0.5
    wid = 0.5
    Iz = wid * h**3 / 12

    elements, a, n_dof = mesh(0, L, n_elements)

    generator = EnhancedStiffnessFieldGenerator(L=L, n_elements=n_elements)

    if mutation_config is None:
        mutation_config = {
            'n_mutations_range': (0, 5),
            'exclude_categories': None,
            'include_categories': None
        }

    E_final, _E_base, mutation_mask, mutations = generator.generate_with_random_mutations(
        mean_E=E_mean, cv=0.1, length_scale=1.0,
        mutation_config=mutation_config,
        random_seed=random_seed,
        ensure_fixed_end_joint=ensure_fixed_end_joint,
        min_joint_spacing=1.5
    )

    # Shear modulus follows the mutated Young's modulus element by element.
    G_final = E_final / (2 * (1 + v))
    K, M = assemble(elements, E_final, G_final, n_dof, Iz, a)

    w2, eigenvectors = linalg.eigsh(K, k=n_modes, M=M, sigma=0, which='LM')
    freqs = np.sqrt(w2) / (2 * np.pi)

    if not return_mode_shapes:
        return freqs, None, E_final, mutation_mask, mutations

    # Free node i (i >= 1) owns DOFs 2*(i-1) and 2*(i-1)+1, so every even row
    # of the eigenvector matrix is the first DOF of nodes 1..n_elements.
    # Column 0 stays zero for the constrained node.
    mode_shapes = np.zeros((n_modes, n_elements + 1))
    mode_shapes[:, 1:] = eigenvectors[0::2, :].T

    # Scale each mode to unit peak amplitude; all-zero rows are left as is.
    peaks = np.max(np.abs(mode_shapes), axis=1)
    scalable = peaks > 0
    mode_shapes[scalable] /= peaks[scalable, np.newaxis]

    return freqs, mode_shapes, E_final, mutation_mask, mutations

def process_mode_shapes_for_nn(mode_shapes, n_fixed_points=50):
    """Resample the leading mode shapes onto a fixed-size uniform grid.

    At most the first five rows of ``mode_shapes`` are kept. Each kept row is
    treated as samples on a uniform grid over ``[0, 5.0]`` and linearly
    interpolated onto ``n_fixed_points`` uniformly spaced points over the
    same interval.

    Args:
        mode_shapes: 2-D array of shape ``(n_modes, n_nodes)``.
        n_fixed_points: Number of output points per mode.

    Returns:
        Array of shape ``(min(5, n_modes), n_fixed_points)``.
    """
    available_modes, node_count = mode_shapes.shape
    kept = min(5, available_modes)

    beam_length = 5.0
    target_x = np.linspace(0, beam_length, n_fixed_points)
    source_x = np.linspace(0, beam_length, node_count)

    resampled = np.zeros((kept, n_fixed_points))
    for out_row, nodal_values in zip(resampled, mode_shapes[:kept]):
        out_row[:] = np.interp(target_x, source_x, nodal_values)

    return resampled

def add_measurement_noise(freqs, mode_shapes, freq_noise_std=0.01, mode_noise_std=0.02):
    """Perturb frequencies and mode shapes with synthetic measurement noise.

    Frequencies receive independent multiplicative Gaussian noise. Each mode
    shape receives additive Gaussian noise that is first passed through the
    recursive filter ``y[i] = 0.7 * y[i-1] + 0.3 * x[i]`` (so neighbouring
    points are correlated), after which the row is rescaled to unit maximum
    absolute value. Random numbers come from NumPy's global generator.

    Args:
        freqs: 1-D array of frequencies.
        mode_shapes: 2-D array ``(n_modes, n_points)`` or ``None``. The input
            array is not modified.
        freq_noise_std: Relative standard deviation of the frequency noise.
        mode_noise_std: Standard deviation of the raw mode-shape noise.

    Returns:
        ``(noisy_freqs, noisy_mode_shapes)``; the second item is ``None``
        when ``mode_shapes`` is ``None``.
    """
    relative_error = np.random.randn(len(freqs)) * freq_noise_std
    noisy_freqs = freqs * (1 + relative_error)

    if mode_shapes is None:
        return noisy_freqs, None

    mode_count, point_count = mode_shapes.shape
    noisy_mode_shapes = mode_shapes.copy()

    for m in range(mode_count):
        noise = np.random.randn(point_count) * mode_noise_std

        # Recursive smoothing: each sample mixes in the already-smoothed
        # previous sample, so this has to run sequentially.
        for k in range(1, point_count):
            noise[k] = 0.7 * noise[k - 1] + 0.3 * noise[k]

        noisy_mode_shapes[m, :] += noise

        # Renormalise so the noisy mode again peaks at magnitude one.
        peak = np.max(np.abs(noisy_mode_shapes[m, :]))
        if peak > 0:
            noisy_mode_shapes[m, :] /= peak

    return noisy_freqs, noisy_mode_shapes

def create_damage_labels(E_final, mutation_mask, n_fixed_points=50):
    """Build binary damage labels and a resampled stiffness field.

    Both inputs are treated as samples on a uniform grid over ``[0, 5.0]``
    and linearly interpolated onto ``n_fixed_points`` uniform points. A point
    is labelled 1 when the interpolated mutation mask exceeds 0.5 or when the
    interpolated stiffness falls below the 30th percentile of ``E_final``.

    Args:
        E_final: 1-D array with one stiffness value per element.
        mutation_mask: NumPy array of the same length marking mutated
            elements (converted with ``astype(float)``).
        n_fixed_points: Number of output points.

    Returns:
        ``(labels, stiffness)``: an integer 0/1 array and the interpolated
        stiffness field, both of length ``n_fixed_points``.
    """
    beam_length = 5.0
    source_x = np.linspace(0, beam_length, len(E_final))
    target_x = np.linspace(0, beam_length, n_fixed_points)

    stiffness_on_grid = np.interp(target_x, source_x, E_final)
    mask_on_grid = np.interp(target_x, source_x, mutation_mask.astype(float))

    # The low-stiffness cut-off is taken from the original per-element field,
    # not from the resampled one.
    low_stiffness_cutoff = np.percentile(E_final, 30)

    is_mutated = mask_on_grid > 0.5
    is_soft = stiffness_on_grid < low_stiffness_cutoff
    labels = (is_mutated | is_soft).astype(int)

    return labels, stiffness_on_grid

def create_multi_class_labels(mutations, n_fixed_points=50):
    """Build per-point damage-category labels on a uniform grid over [0, 5.0].

    Categories are chosen by substring matching on the mutation type, checked
    in this order (first match wins):

        1 -- joint problems      ('joint')
        2 -- external damage     ('overload', 'scratch', 'impact')
        3 -- corrosion / wear    ('corrosion')
        4 -- cracks              ('crack')

    Points not covered by any categorised mutation keep label 0 (healthy).
    Where mutations overlap, the one appearing later in ``mutations`` wins.

    Args:
        mutations: Iterable of dicts with ``position`` and ``type`` keys and
            an optional ``width`` (default 0.1).
        n_fixed_points: Number of grid points.

    Returns:
        Integer array of length ``n_fixed_points`` with values in 0..4.
    """
    beam_length = 5.0
    grid = np.linspace(0, beam_length, n_fixed_points)
    labels = np.zeros(n_fixed_points, dtype=int)  # 0: healthy

    # Ordered (keywords, category) rules; order matters because a type
    # string may contain more than one keyword.
    keyword_rules = (
        (('joint',), 1),
        (('overload', 'scratch', 'impact'), 2),
        (('corrosion',), 3),
        (('crack',), 4),
    )

    for mutation in mutations:
        centre = mutation['position']
        kind = mutation['type']

        category = next(
            (code for words, code in keyword_rules
             if any(word in kind for word in words)),
            None,
        )
        if category is None:
            continue

        extent = mutation.get('width', 0.1)
        lower = max(0, centre - extent/2)
        upper = min(beam_length, centre + extent/2)

        labels[(grid >= lower) & (grid <= upper)] = category

    return labels

def generate_training_sample(sample_id, n_elements=300, n_fixed_points=50,
                            add_noise=True, multi_class=False, random_seed=None):
    """Generate one training sample (inputs, targets and metadata).

    Runs the modal analysis for five modes, resamples the mode shapes onto
    ``n_fixed_points`` points, optionally adds measurement noise, and builds
    the label arrays.

    Args:
        sample_id: Identifier stored in the sample.
        n_elements: Number of beam elements used by the modal analysis.
        n_fixed_points: Number of grid points for mode shapes and labels.
        add_noise: Whether to add measurement noise to the inputs.
        multi_class: When true ``damage_labels`` holds the category labels
            from ``create_multi_class_labels``; otherwise it equals the
            binary labels.
        random_seed: When not ``None``, seeds NumPy's global generator and is
            forwarded to the modal analysis.

    Returns:
        Dict with keys ``sample_id``, ``input`` (``frequencies``,
        ``mode_shapes`` as float32), ``output`` (``stiffness_field`` as
        float32, ``damage_labels`` and ``binary_labels`` as int32) and
        ``metadata``.
    """
    # Number of modes requested from the solver and kept as network input.
    # The same constant bounds the frequency slice below; it previously used
    # an unrelated literal (10) that was larger than the number of modes
    # actually computed.
    n_modes = 5

    if random_seed is not None:
        np.random.seed(random_seed)

    freqs, mode_shapes, E_final, mutation_mask, mutations = modal_analysis_with_mode_shapes(
        n_elements=n_elements,
        n_modes=n_modes,
        random_seed=random_seed,
        ensure_fixed_end_joint=True
    )

    mode_shapes_processed = process_mode_shapes_for_nn(mode_shapes, n_fixed_points)
    clean_freqs = freqs[:n_modes]

    if add_noise:
        noisy_freqs, noisy_mode_shapes = add_measurement_noise(
            clean_freqs, mode_shapes_processed,
            freq_noise_std=0.01,
            mode_noise_std=0.02
        )
    else:
        noisy_freqs = clean_freqs
        noisy_mode_shapes = mode_shapes_processed

    binary_labels, stiffness_interp = create_damage_labels(E_final, mutation_mask, n_fixed_points)

    if multi_class:
        damage_labels = create_multi_class_labels(mutations, n_fixed_points)
    else:
        damage_labels = binary_labels

    sample = {
        'sample_id': sample_id,
        'input': {
            'frequencies': noisy_freqs.astype(np.float32),
            'mode_shapes': noisy_mode_shapes.astype(np.float32),
        },
        'output': {
            'stiffness_field': stiffness_interp.astype(np.float32),
            'damage_labels': damage_labels.astype(np.int32),
            'binary_labels': binary_labels.astype(np.int32),
        },
        'metadata': {
            'n_elements': n_elements,
            'n_fixed_points': n_fixed_points,
            'n_mutations': len(mutations),
            'mutation_types': [mut['type'] for mut in mutations],
            'mutation_positions': [mut['position'] for mut in mutations],
            'random_seed': random_seed,
            'generation_time': datetime.now().isoformat()
        }
    }

    return sample

def generate_dataset(n_samples, output_dir='dataset', batch_size=1000,
                    n_elements=300, n_fixed_points=50, 
                    add_noise=True, multi_class=False):
    """Generate ``n_samples`` training samples and write them in pickled batches.

    Samples are produced with ``generate_training_sample`` using the seed
    ``42 + sample_id * 100`` and written to
    ``<output_dir>/batch_XXXX.pkl`` (one list of samples per file). A
    ``dataset_stats.json`` summary is written next to the batches.

    Generation is best-effort: a sample that raises is reported on stdout and
    skipped, so a batch file may contain fewer samples than requested. The
    number of samples actually written and the ids of the failed ones are
    recorded in the returned statistics.

    Args:
        n_samples: Number of samples to attempt.
        output_dir: Directory for the batch files and the stats file; created
            if missing.
        batch_size: Maximum number of samples per batch file; must be > 0.
        n_elements: Forwarded to ``generate_training_sample``.
        n_fixed_points: Forwarded to ``generate_training_sample``.
        add_noise: Forwarded to ``generate_training_sample``.
        multi_class: Forwarded to ``generate_training_sample``.

    Returns:
        The statistics dict that was written to ``dataset_stats.json``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    import json
    import os
    import pickle
    from datetime import datetime

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    os.makedirs(output_dir, exist_ok=True)

    dataset_stats = {
        'n_samples_total': n_samples,
        'n_elements': n_elements,
        'n_fixed_points': n_fixed_points,
        'add_noise': add_noise,
        'multi_class': multi_class,
        'generation_start_time': datetime.now().isoformat(),
        'samples_per_batch': batch_size
    }

    # Ceiling division: the last batch may be partially filled.
    n_batches = (n_samples + batch_size - 1) // batch_size

    n_generated = 0
    failed_sample_ids = []

    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_end = min((batch_idx + 1) * batch_size, n_samples)
        batch_size_actual = batch_end - batch_start

        print(f"生成批次 {batch_idx+1}/{n_batches}: 样本 {batch_start+1}-{batch_end}")

        batch_samples = []

        for i in range(batch_size_actual):
            sample_id = batch_start + i

            # Deterministic per-sample seed so any sample can be regenerated.
            random_seed = 42 + sample_id * 100

            try:
                sample = generate_training_sample(
                    sample_id=sample_id,
                    n_elements=n_elements,
                    n_fixed_points=n_fixed_points,
                    add_noise=add_noise,
                    multi_class=multi_class,
                    random_seed=random_seed
                )
            except Exception as e:
                # Deliberate best-effort: one failing sample must not abort a
                # long generation run, but the failure is recorded.
                print(f"  样本 {sample_id} 生成失败: {e}")
                failed_sample_ids.append(sample_id)
                continue

            batch_samples.append(sample)

            if (i + 1) % 100 == 0:
                print(f"  已生成 {i+1}/{batch_size_actual} 个样本")

        n_generated += len(batch_samples)

        batch_filename = os.path.join(output_dir, f'batch_{batch_idx:04d}.pkl')
        with open(batch_filename, 'wb') as f:
            pickle.dump(batch_samples, f)

        print(f"  批次 {batch_idx+1} 已保存到 {batch_filename}")

    dataset_stats['generation_end_time'] = datetime.now().isoformat()
    dataset_stats['n_batches'] = n_batches
    dataset_stats['n_samples_generated'] = n_generated
    dataset_stats['n_samples_failed'] = len(failed_sample_ids)
    dataset_stats['failed_sample_ids'] = failed_sample_ids

    stats_filename = os.path.join(output_dir, 'dataset_stats.json')
    with open(stats_filename, 'w', encoding='utf-8') as f:
        json.dump(dataset_stats, f, indent=2)

    print(f"\n数据集生成完成！")
    print(f"总样本数: {n_samples}")
    print(f"输出目录: {output_dir}")
    print(f"统计信息: {stats_filename}")

    return dataset_stats
    
if __name__ == "__main__":
    # Script entry point: optionally generate the train/val/test datasets and
    # write a summary file.

    print("=" * 60)
    print("梁损伤识别训练数据生成系统")
    print("=" * 60)

    # Set to True to generate the full dataset; kept False so the dataset is
    # not regenerated by accident.
    generate_full_dataset = False

    # Split sizes are defined once so that the generation calls, the progress
    # labels and the summary file cannot drift apart. The values are the ones
    # stated in the labels and the summary (80,000 / 15,000 / 5,000); the
    # generation calls previously passed 90000 and 500 for train and test.
    n_train = 80000
    n_val = 15000
    n_test = 5000
    n_total = n_train + n_val + n_test

    if generate_full_dataset:
        print("开始生成完整数据集...")

        print(f"\n--- 生成训练集 ({n_train:,}样本) ---")
        train_stats = generate_dataset(
            n_samples=n_train,
            output_dir='dataset/train',
            batch_size=1000,  # 1000 samples per batch
            n_elements=300,   # full element count
            n_fixed_points=50,
            add_noise=True,
            multi_class=True   # multi-class labels
        )

        print(f"\n--- 生成验证集 ({n_val:,}样本) ---")
        val_stats = generate_dataset(
            n_samples=n_val,
            output_dir='dataset/val',
            batch_size=500,
            n_elements=300,
            n_fixed_points=50,
            add_noise=True,
            multi_class=True
        )

        print(f"\n--- 生成测试集 ({n_test}样本) ---")
        test_stats = generate_dataset(
            n_samples=n_test,
            output_dir='dataset/test',
            batch_size=500,
            n_elements=300,
            n_fixed_points=50,
            add_noise=True,
            multi_class=True
        )

        total_stats = {
            'total_samples': n_total,
            'train_samples': n_train,
            'val_samples': n_val,
            'test_samples': n_test,
            'generation_completed': datetime.now().isoformat()
        }

        with open('dataset/dataset_summary.json', 'w') as f:
            json.dump(total_stats, f, indent=2)

        print("\n" + "=" * 60)
        print("数据集生成完成!")
        print(f"总样本数: {n_total:,}")
        print(f"训练集: {n_train:,}")
        print(f"验证集: {n_val:,}")
        print(f"测试集: {n_test:,}")
        print("=" * 60)

    else:
        print("\n完整数据集生成已跳过。")
        print("要生成完整数据集，请将 'generate_full_dataset' 设置为 True")
    