# In [1]:
"""
Predictive Chemical Degradation Network Generator - User Version
"""

import pickle
import pandas as pd
import numpy as np
import networkx as nx
from collections import Counter, defaultdict
import json
import base64
from io import BytesIO
import warnings
import sys
warnings.filterwarnings('ignore')

# RDKit imports
try:
    from rdkit import Chem
    from rdkit.Chem import Draw, AllChem, Descriptors
    from rdkit.Chem import DataStructs
    from rdkit.Chem.Fingerprints import FingerprintMols
    from rdkit import RDLogger
    RDLogger.DisableLog('rdApp.*')
    RDKIT_AVAILABLE = True
except ImportError:
    RDKIT_AVAILABLE = False
    print("❌ Error: This feature requires RDKit library")
    print("Installation command: pip install rdkit-pypi")
    exit(1)

class OptimizedPredictiveDegradationNetwork:
    def __init__(self, from_pickle='degradation_model.pkl'):
        """
        Initialize predictive network analyzer from pickle file
        
        Parameters:
        -----------
        from_pickle : str
            Path to pickle file with pre-processed data
        """
        self.load_from_pickle(from_pickle)
        
    def load_from_pickle(self, filename):
        """
        Load processed data from pickle file
        """
        print(f"Loading from pickle file: {filename}")
        
        try:
            with open(filename, 'rb') as f:
                data = pickle.load(f)
            
            # Restore all attributes
            self.alpha = data['alpha']
            self.beta = data['beta']
            self.gamma = data['gamma']
            self.reaction_systems = data['reaction_systems']
            self.smiles_cache = data.get('smiles_cache', {})
            self.forward_index = defaultdict(list, data['forward_index'])
            self.backward_index = defaultdict(list, data['backward_index'])
            self.chem_groups_all = defaultdict(set, data['chem_groups_all'])
            self.chem_groups = data['chem_groups']
            self.all_templates = data['all_templates']
            self.reaction_type_map = data['reaction_type_map']
            self.template_class_freq = data['template_class_freq']
            self.reaction_type_freq = data['reaction_type_freq']
            
            # Restore molecule database from SMILES
            self.molecule_db = {}
            for smiles, mol_smiles in data['molecule_db'].items():
                mol = Chem.MolFromSmiles(mol_smiles)
                if mol:
                    self.molecule_db[smiles] = mol
            
            # Restore fingerprints
            self.fingerprints = {}
            for smiles, fp_string in data['fingerprints_data'].items():
                # Convert bit string back to fingerprint
                fp = DataStructs.CreateFromBitString(fp_string)
                self.fingerprints[smiles] = fp
            
            print(f"✅ Loaded processed data from {filename}")
            print(f"   - Molecules: {len(self.molecule_db)}")
            print(f"   - Templates: {len(self.all_templates)}")
            print(f"   - Forward reactions: {sum(len(v) for v in self.forward_index.values())}")
            print(f"✓ Confidence model parameters: α={self.alpha}, β={self.beta}, γ={self.gamma}")
            
        except FileNotFoundError:
            print(f"❌ Error: Model file '{filename}' not found")
            print("Please ensure the degradation_model.pkl file is in the current directory")
            raise
        except Exception as e:
            print(f"❌ Error loading model: {str(e)}")
            raise
    
    def standardize_smiles(self, smiles):
        """
        Standardize SMILES string to ensure unique representation for same molecular structure
        Uses cache to avoid redundant calculations
        """
        if pd.isna(smiles) or smiles == '':
            return smiles
            
        # Check cache
        if smiles in self.smiles_cache:
            return self.smiles_cache[smiles]
        
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                self.smiles_cache[smiles] = smiles
                return smiles
            
            # Use canonical SMILES to ensure uniqueness
            canonical_smiles = Chem.MolToSmiles(mol, canonical=True)
            self.smiles_cache[smiles] = canonical_smiles
            return canonical_smiles
        except:
            self.smiles_cache[smiles] = smiles
            return smiles
    
    def calculate_similarity(self, smiles1, smiles2):
        """Calculate Tanimoto similarity between two molecules (using standardized SMILES)"""
        # Standardize both SMILES
        std_smiles1 = self.standardize_smiles(smiles1)
        std_smiles2 = self.standardize_smiles(smiles2)
        
        if std_smiles1 not in self.fingerprints or std_smiles2 not in self.fingerprints:
            return 0.0
        
        fp1 = self.fingerprints[std_smiles1]
        fp2 = self.fingerprints[std_smiles2]
        
        return DataStructs.TanimotoSimilarity(fp1, fp2)
    
    def check_system_consistency(self, reaction_type1, reaction_type2):
        """
        Score the agreement of two reaction systems.

        Returns 0.5 when either system is missing (None/NaN), 1.0 when both
        are equal, and 0.0 otherwise.
        """
        either_missing = pd.isna(reaction_type1) or pd.isna(reaction_type2)
        if either_missing:
            return 0.5
        if reaction_type1 == reaction_type2:
            return 1.0
        return 0.0
    
    def calculate_confidence_optimized(self, similarity, template_freq_log, consistency):
        """Optimized confidence calculation (implementing special rules)"""
        # Rule 1: Perfect match
        if similarity == 1.0 and consistency == 1.0:
            return 1.0
        
        # Rule 2: Same molecule but different system
        elif similarity == 1.0 and consistency < 1.0:
            return self.alpha + self.beta + self.gamma * consistency
        
        # Rule 3: General case
        else:
            return self.alpha * similarity + self.beta * template_freq_log + self.gamma * consistency
    
    def find_similar_molecules(self, query_smiles, threshold=0.3, top_n=20):
        """Find molecules similar to query molecule (using standardized SMILES)"""
        # Standardize query SMILES
        query_smiles_std = self.standardize_smiles(query_smiles)
        
        query_mol = Chem.MolFromSmiles(query_smiles_std)
        if not query_mol:
            return []
        
        query_fp = AllChem.GetMorganFingerprintAsBitVect(query_mol, 2, nBits=2048)
        
        similarities = []
        for smiles, fp in self.fingerprints.items():
            if smiles != query_smiles_std:  # Compare using standardized SMILES
                similarity = DataStructs.TanimotoSimilarity(query_fp, fp)
                if similarity >= threshold:
                    similarities.append((smiles, similarity))
        
        similarities.sort(key=lambda x: x[1], reverse=True)
        return similarities[:top_n]
    
    def apply_reaction_template(self, template, reactant_smiles):
        """Apply reaction template to reactant to generate products (with product standardization)"""
        try:
            if not template or '>>' not in template:
                return []
            
            rxn = AllChem.ReactionFromSmarts(template)
            if not rxn:
                return []
            
            # Standardize reactant SMILES
            reactant_smiles_std = self.standardize_smiles(reactant_smiles)
            reactant = Chem.MolFromSmiles(reactant_smiles_std)
            if not reactant:
                return []
            
            products_tuples = rxn.RunReactants((reactant,))
            
            products = []
            seen_canonical = set()  # For deduplication
            
            for product_tuple in products_tuples[:5]:
                for product in product_tuple:
                    try:
                        # Use canonical SMILES to ensure uniqueness
                        product_smiles = Chem.MolToSmiles(product, canonical=True)
                        
                        # Check if we've seen this standardized SMILES
                        if product_smiles not in seen_canonical:
                            seen_canonical.add(product_smiles)
                            products.append(product_smiles)
                    except:
                        continue
            
            return products
        except Exception as e:
            return []
    
    def predict_degradation_products(self, query_smiles, query_reaction_type=None):
        """
        Predict degradation products for query molecule (using optimized confidence model and standardized SMILES)
        """
        # Standardize query SMILES
        query_smiles = self.standardize_smiles(query_smiles)
        
        predictions = []
        similarity_threshold = 0.3  # Fixed threshold
        
        # Find similar molecules
        similar_molecules = self.find_similar_molecules(
            query_smiles, 
            threshold=similarity_threshold,
            top_n=20
        )
        
        if not similar_molecules:
            return predictions
        
        # Collect all candidate reactions
        candidate_reactions = []
        
        for similar_smiles, similarity in similar_molecules:
            for reaction in self.forward_index.get(similar_smiles, []):
                if reaction['template'] and reaction['template'] != '':
                    # Get log-transformed template frequency
                    template_frequency_log = self.template_class_freq.get(
                        reaction['template_class'], 
                        0.001
                    )
                    
                    # Calculate system consistency
                    if query_reaction_type:
                        system_consistency = self.check_system_consistency(
                            query_reaction_type, 
                            reaction['reaction_type']
                        )
                    else:
                        system_consistency = self.reaction_type_freq.get(
                            reaction['reaction_type'], 
                            0.5
                        )
                    
                    # Use optimized confidence calculation
                    confidence = self.calculate_confidence_optimized(
                        similarity,
                        template_frequency_log,
                        system_consistency
                    )
                    
                    candidate_reactions.append({
                        'template': reaction['template'],
                        'template_class': reaction['template_class'],
                        'reaction_type': reaction['reaction_type'],
                        'source_smiles': similar_smiles,
                        'source_product': reaction['product'],
                        'confidence': confidence
                    })
        
        # Sort by confidence
        candidate_reactions.sort(key=lambda x: x['confidence'], reverse=True)
        
        # Apply templates to generate predicted products
        seen_products = set()
        for candidate in candidate_reactions[:50]:
            products = self.apply_reaction_template(candidate['template'], query_smiles)
            
            for product_smiles in products:
                # Products already standardized in apply_reaction_template
                if product_smiles not in seen_products:
                    seen_products.add(product_smiles)
                    predictions.append({
                        'product': product_smiles,
                        'template': candidate['template'],
                        'template_class': candidate['template_class'],
                        'reaction_type': candidate['reaction_type'],
                        'confidence': candidate['confidence'],
                        'source': candidate['source_smiles']
                    })
        
        # Sort by confidence
        predictions.sort(key=lambda x: x['confidence'], reverse=True)
        
        return predictions
    
    def build_predictive_network(self, center_smiles, max_depth=2, confidence_threshold=0.5,
                                query_reaction_type=None, *, expand_threshold=0.7,
                                max_predictions_per_node=10):
        """
        Build a hybrid network of known and predicted degradation pathways around a molecule.

        Breadth-first from ``center_smiles``. Each visited molecule contributes
        its known reactions from ``forward_index`` and then up to
        ``max_predictions_per_node`` template-based predictions whose products
        are not already linked from it.

        Parameters:
        -----------
        center_smiles : str
            Starting molecule (canonicalised internally).
        max_depth : int, default 2
            Molecules at depth >= max_depth are not expanded.
        confidence_threshold : float, default 0.5
            Edges with lower confidence are not added. Known reactions score
            1.0 (same or unspecified system) or 0.9 (a different system than
            ``query_reaction_type``).
        query_reaction_type : str, optional
            Reaction system of interest; known reactions from other systems
            are shown as predicted edges.
        expand_threshold : float, default 0.7
            Predicted products above this confidence are expanded further.
        max_predictions_per_node : int, default 10
            Cap on predicted products added per molecule.

        Returns:
        --------
        G : networkx.MultiDiGraph
            Nodes keyed by canonical SMILES with attributes ``smiles``,
            ``chem_group``, ``is_center``, ``is_known``, ``node_type``
            ('real'/'predicted') and ``depth_from_center``; edges carry
            ``template_class``, ``template``, ``edge_type``, ``confidence`` and
            ``depth``.
        edge_info : defaultdict(list)
            Maps ``(reactant, product)`` to the reaction records behind that pair.
        """
        from collections import deque  # local import keeps this method self-contained

        G = nx.MultiDiGraph()
        edge_info = defaultdict(list)

        center_smiles = self.standardize_smiles(center_smiles)

        is_known = center_smiles in self.chem_groups
        G.add_node(center_smiles,
                   smiles=center_smiles,
                   chem_group=self.chem_groups.get(center_smiles, 'predicted'),
                   is_center=True,
                   is_known=is_known,
                   node_type='real' if is_known else 'predicted',
                   depth_from_center=0)

        visited = set()
        # Entries: (molecule, depth, reaction system used to score its predictions)
        queue = deque([(center_smiles, 0, query_reaction_type)])

        while queue:
            current_smiles, depth, current_reaction_type = queue.popleft()
            if current_smiles in visited or depth >= max_depth:
                continue
            visited.add(current_smiles)
            child_depth = depth + 1

            # 1) Known pathways recorded in the database.
            for reaction in self.forward_index.get(current_smiles, []):
                product = reaction['product']

                if query_reaction_type and reaction['reaction_type'] != query_reaction_type:
                    # Known reaction but observed in another system: shown as a
                    # prediction with a fixed confidence of 0.9.
                    edge_type, confidence = 'predicted', 0.9
                else:
                    edge_type, confidence = 'real', 1.0

                if confidence < confidence_threshold:
                    continue

                self._upsert_network_node(G, product, child_depth,
                                          chem_group=self.chem_groups.get(product, 'unknown'),
                                          is_known=True,
                                          node_type=edge_type)
                if edge_type == 'real' and G.nodes[product].get('node_type') != 'real':
                    # Reached by a same-system known reaction: it is a real
                    # molecule even if it was first added as a prediction.
                    G.nodes[product]['node_type'] = 'real'
                    G.nodes[product]['is_known'] = True

                G.add_edge(current_smiles, product,
                           template_class=reaction['template_class'],
                           template=reaction['template'],
                           edge_type=edge_type,
                           confidence=confidence,
                           depth=child_depth)
                edge_info[(current_smiles, product)].append({
                    **reaction,
                    'edge_type': edge_type,
                    'confidence': confidence
                })

                if edge_type == 'real' or confidence > expand_threshold:
                    next_system = reaction['reaction_type'] if edge_type == 'real' else current_reaction_type
                    queue.append((product, child_depth, next_system))

            # 2) Template-based predictions not already linked from this molecule.
            predictions = self.predict_degradation_products(
                current_smiles,
                query_reaction_type=current_reaction_type
            )
            existing_products = set(G.successors(current_smiles))
            predictions = [p for p in predictions
                           if p['product'] not in existing_products
                           and p['product'] != current_smiles  # no self-loops
                           and p['confidence'] >= confidence_threshold]

            for pred in predictions[:max_predictions_per_node]:
                product = pred['product']
                if Chem.MolFromSmiles(product) is None:
                    continue

                # Same known/chem_group convention as the center node
                self._upsert_network_node(G, product, child_depth,
                                          chem_group=self.chem_groups.get(product, 'predicted'),
                                          is_known=product in self.chem_groups,
                                          node_type='predicted')

                G.add_edge(current_smiles, product,
                           template_class=pred['template_class'],
                           template=pred['template'],
                           edge_type='predicted',
                           confidence=pred['confidence'],
                           depth=child_depth)
                edge_info[(current_smiles, product)].append({
                    'product': product,
                    'template_class': pred['template_class'],
                    'template': pred['template'],
                    'edge_type': 'predicted',
                    'confidence': pred['confidence'],
                    'source': pred['source']
                })

                if child_depth < max_depth and pred['confidence'] > expand_threshold:
                    queue.append((product, child_depth, pred['reaction_type']))

        return G, edge_info

    @staticmethod
    def _upsert_network_node(G, smiles, depth, chem_group, is_known, node_type):
        """Add a non-center node, or only lower the recorded depth of an existing one."""
        if smiles not in G:
            G.add_node(smiles,
                       smiles=smiles,
                       chem_group=chem_group,
                       is_center=False,
                       is_known=is_known,
                       node_type=node_type,
                       depth_from_center=depth)
        elif depth < G.nodes[smiles].get('depth_from_center', float('inf')):
            G.nodes[smiles]['depth_from_center'] = depth
    
    def get_molecule_image(self, smiles, size=(200, 200)):
        """Generate base64 encoding of molecular structure image (using standardized SMILES)"""
        try:
            # Standardize SMILES
            smiles_std = self.standardize_smiles(smiles)
            mol = Chem.MolFromSmiles(smiles_std)
            if mol is None:
                return None
            
            img = Draw.MolToImage(mol, size=size)
            
            buffer = BytesIO()
            img.save(buffer, format='PNG')
            buffer.seek(0)
            img_base64 = base64.b64encode(buffer.read()).decode('utf-8')
            
            return f"data:image/png;base64,{img_base64}"
        except:
            return None
    
    def generate_predictive_network_html(self, center_smiles, max_depth, confidence_threshold, 
                                        reaction_type, G, edge_info):
        """Generate HTML visualization of predictive network"""
        
        # Prepare node data
        nodes = []
        node_id_map = {}
        
        degrees = dict(G.degree())
        in_degrees = dict(G.in_degree())
        out_degrees = dict(G.out_degree())
        
        node_colors = {
            'real': '#4CAF50',
            'predicted': '#2196F3',
            'center': '#FF5722'
        }
        
        print("\nGenerating molecular structure images...")
        
        for i, node in enumerate(G.nodes()):
            node_data = G.nodes[node]
            node_id = f"node_{i}"
            node_id_map[node] = node_id
            
            mol_image = self.get_molecule_image(node)
            
            if node_data.get('is_center'):
                color = node_colors['center']
                size = 35
                border_width = 5
            else:
                color = node_colors[node_data.get('node_type', 'predicted')]
                size = min(15 + degrees[node] * 2, 30)
                border_width = 3 if node_data.get('node_type') == 'real' else 2
            
            node_type_text = "Real molecule" if node_data.get('node_type') == 'real' else "Predicted molecule"
            if node_data.get('is_center'):
                node_type_text = "Center node"
            
            title_text = (f"【{node_type_text}】\n"
                         f"SMILES: {node}\n"
                         f"━━━━━━━━━━━━━━━━\n"
                         f"Distance from center: {node_data.get('depth_from_center', 0)}\n"
                         f"In-degree: {in_degrees[node]}\n"
                         f"Out-degree: {out_degrees[node]}")
            
            nodes.append({
                'id': node_id,
                'label': '',
                'title': title_text,
                'color': color,
                'size': size,
                'borderWidth': border_width,
                'borderWidthSelected': border_width + 2,
                'smiles': node,
                'node_type': node_data.get('node_type', 'predicted'),
                'is_center': node_data.get('is_center', False),
                'mol_image': mol_image
            })
        
        # Prepare edge data
        edges = []
        edge_id = 0
        
        for (u, v), reactions in edge_info.items():
            if u not in node_id_map or v not in node_id_map:
                continue
            
            edge_types = [r.get('edge_type', 'predicted') for r in reactions]
            is_predicted = 'predicted' in edge_types
            
            if is_predicted:
                confidence = reactions[0].get('confidence', 0.5)
                
                # Edge color based on confidence
                if confidence == 0.9:
                    # Real pathway from other systems
                    color = {'color': '#2196F3', 'opacity': 0.8}
                    width = 2.5
                else:
                    opacity = 0.3 + confidence * 0.6
                    color = {'color': '#2196F3', 'opacity': opacity}
                    width = 1 + confidence * 3
                
                dashes = True
                
                # Edge description (don't show reaction type)
                edge_title = f"【Predicted degradation pathway】\n"
                edge_title += f"Confidence: {confidence:.2%}\n"
                edge_title += f"Template Class: {reactions[0].get('template_class', 'Unknown')}"
            else:
                confidence = 1.0
                color = {'color': '#4CAF50', 'opacity': 0.8}
                dashes = False
                width = 2
                
                edge_title = f"【Real degradation pathway】\n"
                edge_title += f"Confidence: 100%\n"
                edge_title += f"Template Class: {reactions[0].get('template_class', 'Unknown')}"
            
            edges.append({
                'id': f"edge_{edge_id}",
                'from': node_id_map[u],
                'to': node_id_map[v],
                'title': edge_title,
                'width': width,
                'color': color,
                'dashes': dashes,
                'arrows': {'to': {'enabled': True, 'scaleFactor': 0.8}},
                'smooth': {'type': 'continuous'},
                'edge_type': 'predicted' if is_predicted else 'real',
                'confidence': confidence
            })
            edge_id += 1
        
        # Generate HTML
        html = f'''<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Predictive Degradation Network - {center_smiles[:30]}</title>
    <script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
    <style>
        body {{
            margin: 0;
            padding: 0;
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        }}
        
        #header {{
            background: rgba(255, 255, 255, 0.95);
            padding: 15px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }}
        
        #network {{
            width: 100%;
            height: calc(100vh - 250px);
            background: white;
            border-radius: 10px;
            margin: 10px;
            box-shadow: 0 4px 20px rgba(0,0,0,0.1);
        }}
        
        #controls {{
            padding: 15px;
            background: rgba(255, 255, 255, 0.95);
            display: flex;
            align-items: center;
            gap: 20px;
            margin: 10px;
            border-radius: 10px;
        }}
        
        .btn {{
            padding: 10px 20px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-weight: bold;
            transition: transform 0.2s;
        }}
        
        .btn:hover {{
            transform: translateY(-2px);
        }}
        
        #info {{
            padding: 15px;
            background: rgba(255, 255, 255, 0.95);
            margin: 10px;
            border-radius: 10px;
            display: flex;
            justify-content: space-between;
            align-items: center;
        }}
        
        .legend {{
            position: absolute;
            top: 80px;
            right: 20px;
            background: rgba(255, 255, 255, 0.95);
            padding: 15px;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }}
        
        .legend-item {{
            display: flex;
            align-items: center;
            margin: 10px 0;
        }}
        
        .legend-color {{
            width: 30px;
            height: 30px;
            margin-right: 10px;
            border-radius: 50%;
            border: 2px solid #ddd;
        }}
        
        .legend-line {{
            width: 40px;
            height: 3px;
            margin-right: 10px;
        }}
        
        #molecule-popup {{
            display: none;
            position: absolute;
            background: white;
            border: 3px solid #667eea;
            border-radius: 10px;
            padding: 15px;
            z-index: 1000;
            box-shadow: 0 4px 20px rgba(0,0,0,0.2);
        }}
        
        .stats {{
            display: flex;
            gap: 20px;
        }}
        
        .stat-item {{
            padding: 5px 10px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 5px;
            font-size: 14px;
        }}
        
        .param-info {{
            background: #f0f0f0;
            padding: 8px 15px;
            border-radius: 5px;
            font-size: 13px;
            color: #333;
        }}
    </style>
</head>
<body>
    <div id="header">
        <h2 style="margin: 0; color: #333;">🧪 Predictive Chemical Degradation Network Analysis (Standardized Version)</h2>
    </div>
    
    <div id="controls">
        <button class="btn" onclick="resetView()">🔄 Reset View</button>
        <button class="btn" onclick="toggleLabels()">🏷️ Toggle Labels</button>
        <button class="btn" onclick="togglePhysics()">⚙️ Physics Engine</button>
    </div>
    
    <div id="network"></div>
    
    <div id="info">
        <div class="stats">
            <span class="stat-item">Nodes: {G.number_of_nodes()}</span>
            <span class="stat-item">Edges: {G.number_of_edges()}</span>
            <span class="stat-item">Depth: {max_depth}</span>
            <span class="stat-item">Confidence Threshold: {confidence_threshold:.1%}</span>
        </div>
        <div class="param-info">
            Reaction System: {reaction_type if reaction_type else 'All systems'}
        </div>
        <div>
            <span style="color: #4CAF50;">● Real pathways</span> | 
            <span style="color: #2196F3;">● Predicted pathways</span>
        </div>
    </div>
    
    <div class="legend">
        <h4 style="margin: 0 0 15px 0;">Legend</h4>
        
        <div class="legend-item">
            <div class="legend-color" style="background: #FF5722;"></div>
            <span>Center Node</span>
        </div>
        
        <div class="legend-item">
            <div class="legend-color" style="background: #4CAF50;"></div>
            <span>Real Molecule</span>
        </div>
        
        <div class="legend-item">
            <div class="legend-color" style="background: #2196F3;"></div>
            <span>Predicted Molecule</span>
        </div>
        
        <hr style="margin: 10px 0;">
        
        <div class="legend-item">
            <div class="legend-line" style="background: #4CAF50; height: 3px;"></div>
            <span>Current System</span>
        </div>
        
        
        <div class="legend-item">
            <div class="legend-line" style="background: #2196F3; border-top: 3px dashed #2196F3; height: 0;"></div>
            <span>Predicted Pathway</span>
        </div>
        
        <hr style="margin: 10px 0;">
        
        </div>
    </div>
    
    <div id="molecule-popup">
        <div style="text-align: center;">
            <img id="mol-image" src="" style="max-width: 200px; max-height: 200px;">
            <div id="mol-smiles" style="margin-top: 10px; font-size: 12px; word-break: break-all;"></div>
        </div>
    </div>
    
    <script>
        var nodes = new vis.DataSet({json.dumps(nodes)});
        var edges = new vis.DataSet({json.dumps(edges)});
        
        var container = document.getElementById('network');
        var data = {{
            nodes: nodes,
            edges: edges
        }};
        
        var options = {{
            nodes: {{
                shape: 'dot',
                font: {{
                    size: 12,
                    color: '#333',
                    strokeWidth: 3,
                    strokeColor: '#fff'
                }}
            }},
            edges: {{
                smooth: {{
                    type: 'continuous',
                    forceDirection: 'none'
                }}
            }},
            physics: {{
                enabled: true,
                barnesHut: {{
                    gravitationalConstant: -8000,
                    centralGravity: 0.3,
                    springLength: 200,
                    damping: 0.09
                }}
            }},
            interaction: {{
                hover: true,
                tooltipDelay: 100,
                navigationButtons: true
            }}
        }};
        
        var network = new vis.Network(container, data, options);
        var showLabels = false;
        var physicsEnabled = true;
        
        // Molecule structure popup
        network.on("click", function(params) {{
            document.getElementById('molecule-popup').style.display = 'none';
            
            if (params.nodes.length > 0) {{
                var nodeId = params.nodes[0];
                var node = nodes.get(nodeId);
                
                if (node.mol_image) {{
                    var popup = document.getElementById('molecule-popup');
                    var img = document.getElementById('mol-image');
                    var smiles = document.getElementById('mol-smiles');
                    
                    img.src = node.mol_image;
                    smiles.textContent = node.smiles;
                    
                    popup.style.display = 'block';
                    popup.style.left = params.event.center.x + 'px';
                    popup.style.top = params.event.center.y + 'px';
                }}
            }}
        }});
        
        function resetView() {{
            network.fit();
        }}
        
        function toggleLabels() {{
            showLabels = !showLabels;
            nodes.forEach(function(node) {{
                nodes.update({{
                    id: node.id,
                    label: showLabels ? (node.smiles.length > 20 ? node.smiles.substring(0, 20) + '...' : node.smiles) : ''
                }});
            }});
        }}
        
        function togglePhysics() {{
            physicsEnabled = !physicsEnabled;
            network.setOptions({{ physics: {{ enabled: physicsEnabled }} }});
        }}
        
        network.on("stabilizationIterationsDone", function() {{
            network.fit();
        }});
    </script>
</body>
</html>'''
        
        return html
    
    def query_predictive_network(self):
        """Run the interactive prompt loop for building predictive networks.

        Each iteration asks for a SMILES string plus optional query
        parameters (reaction system, confidence threshold, search depth),
        builds the network with ``build_predictive_network``, prints summary
        statistics and writes an HTML visualization to the current working
        directory.

        The loop ends when the user types ``quit`` (case-insensitive) or when
        standard input is closed (EOF).
        """
        print("\n" + "="*60)
        print("🔬 Predictive Chemical Degradation Network Generator (Standardized Version)")
        print("="*60)
        print("Feature: Predict degradation pathways for unknown molecules based on optimized confidence model")
        print(f"Model parameters: α={self.alpha}, β={self.beta}, γ={self.gamma}")
        print("✨ SMILES standardization enabled, identical structures will be automatically merged")

        while True:
            print("\n" + "-"*40)
            try:
                keep_going = self._run_single_prediction_query()
            except EOFError:
                # input() raises EOFError once stdin is exhausted (piped input,
                # closed terminal). End the session cleanly instead of letting
                # the exception escape to the caller's generic error handler.
                print("\n👋 Goodbye!")
                break
            if not keep_going:
                break

    def _run_single_prediction_query(self):
        """Handle one prompt -> build -> report -> save cycle.

        Returns
        -------
        bool
            False when the user asked to quit, True to keep prompting
            (including after invalid input or a failed run).
        """
        center_smiles = input("Enter molecule SMILES (type 'quit' to exit): ").strip()

        if center_smiles.lower() == 'quit':
            print("\n👋 Goodbye!")
            return False

        if not center_smiles:
            print("⚠️ Please enter a valid SMILES")
            return True

        # MolFromSmiles returns None for unparsable input.
        if Chem.MolFromSmiles(center_smiles) is None:
            print("❌ Invalid SMILES, please check your input")
            return True

        # Work with the standardized form so identical structures merge.
        standardized_smiles = self.standardize_smiles(center_smiles)
        if standardized_smiles != center_smiles:
            print(f"📝 Original SMILES: {center_smiles}")
            print(f"✓ Standardized SMILES: {standardized_smiles}")
            center_smiles = standardized_smiles

        if center_smiles in self.chem_groups:
            print("✅ This molecule exists in the database")
            print(f"   Chemical group: {self.chem_groups.get(center_smiles, 'unknown')}")
        else:
            print("ℹ️ This molecule is not in the database, will use prediction mode")

        reaction_type = self._ask_reaction_system()
        confidence_threshold = self._ask_confidence_threshold()
        max_depth = self._ask_search_depth()

        print("\n⚙️ Parameter settings:")
        print(f"   - Reaction system: {reaction_type if reaction_type else 'All systems'}")
        print(f"   - Confidence threshold: {confidence_threshold:.1%}")
        print(f"   - Search depth: {max_depth}")
        print("   - Similarity threshold: 30% (fixed)")

        print("\nBuilding predictive network...")
        G, edge_info = self.build_predictive_network(
            center_smiles,
            max_depth=max_depth,
            confidence_threshold=confidence_threshold,
            query_reaction_type=reaction_type
        )

        # A graph holding only the query molecule means nothing passed.
        if G.number_of_nodes() <= 1:
            print("\n❌ No degradation pathways found meeting confidence threshold")
            print("   Suggestion: Lower confidence threshold or select another system")
            return True

        self._report_network_statistics(G)

        print("\nGenerating visualization...")
        html_content = self.generate_predictive_network_html(
            center_smiles, max_depth, confidence_threshold, reaction_type, G, edge_info
        )

        filename = self._prediction_html_filename(
            center_smiles, reaction_type, confidence_threshold, max_depth
        )
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(html_content)
        except OSError as exc:
            # Report and keep the session alive; the query can be retried.
            print(f"\n❌ Could not write {filename}: {exc}")
            return True

        print("\n✅ Visualization generated!")
        print(f"📁 Filename: {filename}")
        print("\nLegend:")
        print("   🟢 Green solid line = Real degradation pathway (100% confidence)")
        print("   🔵 Blue dashed line = Predicted degradation pathway")
        print("   📊 Edge thickness and opacity reflect confidence level")
        return True

    def _ask_reaction_system(self):
        """Prompt for a reaction system.

        Returns the entry of ``self.reaction_systems`` for the chosen menu
        number, or None to search all systems (empty, invalid or unknown
        choice).
        """
        print("\nSelect reaction system (enter number, press Enter to select all systems):")
        print("  1. cAOP (Catalytic Advanced Oxidation)")
        print("  2. H2O2-based (Hydrogen Peroxide-based)")
        print("  3. PS-based (Persulfate-based)")
        print("  4. eAOP (Electrochemical Advanced Oxidation)")
        print("  5. light-based")
        print("  6. other")
        print("  7. Ozone-based")
        print("  8. pAOP (Physical Advanced Oxidation)")

        system_choice = input("\nEnter choice (1-8, or press Enter to skip): ").strip()
        if not system_choice:
            print("   No system specified, will search all systems")
            return None

        try:
            choice_num = int(system_choice)
        except ValueError:
            print("⚠️ Invalid input, will search all systems")
            return None

        if not 1 <= choice_num <= 8:
            print("⚠️ Invalid choice, will search all systems")
            return None

        try:
            reaction_type = self.reaction_systems[choice_num]
        except (KeyError, IndexError):
            # The menu above is hard-coded while reaction_systems comes from
            # the model file; don't let a mismatch abort the whole session.
            print("⚠️ Invalid choice, will search all systems")
            return None

        print(f"   Selected system: {reaction_type}")
        return reaction_type

    @staticmethod
    def _ask_confidence_threshold():
        """Prompt for a confidence threshold in [0.3, 1.0]; default 0.5."""
        threshold_str = input("\nEnter confidence threshold (0.3-1.0, default 0.5): ").strip()
        if not threshold_str:
            return 0.5
        try:
            value = float(threshold_str)
        except ValueError:
            print("⚠️ Invalid threshold, using default 0.5")
            return 0.5
        # Chained comparison so NaN (every comparison False) is rejected too;
        # NaN used to pass and later crash int(threshold * 100).
        if not 0.3 <= value <= 1.0:
            print("⚠️ Threshold must be between 0.3-1.0, using default 0.5")
            return 0.5
        return value

    @staticmethod
    def _ask_search_depth():
        """Prompt for a search depth in [1, 5]; default 2."""
        depth_str = input("Enter search depth (default 2, range 1-5): ").strip()
        if not depth_str:
            return 2
        try:
            value = int(depth_str)
        except ValueError:
            print("⚠️ Invalid depth value, using default 2")
            return 2
        if not 1 <= value <= 5:
            print("⚠️ Depth must be between 1-5, using default 2")
            return 2
        return value

    @staticmethod
    def _report_network_statistics(G):
        """Print node/edge counts and predicted-edge confidence statistics.

        Nodes and edges are classified by their ``node_type`` / ``edge_type``
        attributes ('real' or 'predicted'). Predicted edges with confidence
        0.9 are reported separately as "Other systems".
        """
        import math

        node_types = [attrs.get('node_type') for _, attrs in G.nodes(data=True)]
        edge_attrs = [attrs for _, _, attrs in G.edges(data=True)]

        real_edge_count = sum(1 for d in edge_attrs if d.get('edge_type') == 'real')
        predicted = [d for d in edge_attrs if d.get('edge_type') == 'predicted']

        # Compare with a tolerance (confidences are floats) and only among
        # predicted edges, so "Predicted" = predicted - other-system edges
        # cannot go negative.
        other_system_count = sum(
            1 for d in predicted
            if d.get('confidence') is not None and math.isclose(d['confidence'], 0.9)
        )

        print("\n📊 Network statistics:")
        print(f"   Nodes: {G.number_of_nodes()} (Real: {node_types.count('real')}, Predicted: {node_types.count('predicted')})")
        print(f"   Edges: {G.number_of_edges()} (Current system: {real_edge_count}, Other systems: {other_system_count}, Predicted: {len(predicted) - other_system_count})")

        pred_confidences = [d.get('confidence', 0) for d in predicted]
        if pred_confidences:
            avg_confidence = sum(pred_confidences) / len(pred_confidences)
            print(f"   Predicted average confidence: {avg_confidence:.2%}")
            print(f"   Predicted confidence range: {min(pred_confidences):.2%} - {max(pred_confidences):.2%}")

    @staticmethod
    def _prediction_html_filename(smiles, reaction_type, confidence_threshold, max_depth):
        """Build a filesystem-safe, collision-resistant HTML output filename.

        Characters outside a conservative whitelist (e.g. / \\ * : ? " < > |)
        become '_', and the SMILES part is capped at 30 characters. When that
        loses information, a short SHA-256 digest of the full SMILES is
        appended. Otherwise cis/trans isomers such as C/C=C\\C and C/C=C/C, or
        molecules sharing a 30-character prefix, would overwrite each
        other's output.
        """
        import hashlib
        import re

        safe_name = re.sub(r'[^\w()\[\]=#@+.\-]', '_', smiles)[:30]
        if safe_name != smiles:
            digest = hashlib.sha256(smiles.encode('utf-8')).hexdigest()[:8]
            safe_name = f"{safe_name}_{digest}"

        if reaction_type:
            system_suffix = re.sub(r'[^\w.\-]', '_', str(reaction_type)).replace('-', '_')
        else:
            system_suffix = 'all'

        # round(), not int(): 0.57 * 100 == 56.99999999999999.
        conf_pct = int(round(confidence_threshold * 100))
        return f"prediction_network_{safe_name}_{system_suffix}_conf{conf_pct}_depth{max_depth}.html"

# Function for Jupyter usage
def run_predictor(pickle_file='degradation_model.pkl'):
    """
    Launch the interactive predictor (intended for Jupyter notebooks).

    Loads the pre-processed model stored in ``pickle_file`` and hands control
    to the interactive query loop. Any exception raised while loading or
    querying is reported on stdout rather than propagated. Nothing is
    returned.

    Note: the model file is deserialized with ``pickle``, so only load files
    from a trusted source.

    Usage (in a notebook cell):

        run_predictor('degradation_model.pkl')
    """
    if RDKIT_AVAILABLE:
        try:
            predictor = OptimizedPredictiveDegradationNetwork(from_pickle=pickle_file)
            predictor.query_predictive_network()
        except Exception as exc:
            # Best-effort boundary for interactive use: report, don't raise.
            print(f"❌ Error: {str(exc)}")
    else:
        print("❌ RDKit is required. Install with: pip install rdkit-pypi")

# Main program
if __name__ == "__main__":
    print("="*60)
    print("🚀 Starting Predictive Chemical Degradation Network Generator")
    print("   (Running from pre-processed data)")
    print("="*60)

    # Defensive: the RDKit import fallback at the top of the file already
    # exits when RDKit is missing. Use sys.exit rather than the site-module
    # exit(), which is absent under `python -S` and in frozen builds.
    if not RDKIT_AVAILABLE:
        print("❌ Error: This feature requires RDKit library")
        print("Installation command: pip install rdkit-pypi")
        sys.exit(1)

    # Choose the model file. Under IPython/Jupyter, sys.argv belongs to the
    # kernel process (typically '-f <connection file>'), so ignore it there.
    try:
        get_ipython()  # noqa: F821 - only defined when running under IPython
        pickle_file = 'degradation_model.pkl'
    except NameError:
        if len(sys.argv) > 1 and not sys.argv[1].startswith('--'):
            pickle_file = sys.argv[1]
        else:
            pickle_file = 'degradation_model.pkl'

    # run_predictor performs exactly the load -> query -> report-error
    # sequence that was previously duplicated here.
    run_predictor(pickle_file)

🚀 Starting Predictive Chemical Degradation Network Generator
   (Running from pre-processed data)
Loading from pickle file: degradation_model.pkl
✅ Loaded processed data from degradation_model.pkl
   - Molecules: 3059
   - Templates: 6600
   - Forward reactions: 6600
✓ Confidence model parameters: α=0.8, β=0.1, γ=0.1

🔬 Predictive Chemical Degradation Network Generator (Standardized Version)
Feature: Predict degradation pathways for unknown molecules based on optimized confidence model
Model parameters: α=0.8, β=0.1, γ=0.1
✨ SMILES standardization enabled, identical structures will be automatically merged

----------------------------------------
📝 Original SMILES: C1=CC(=CC=C1O)S(=O)(=O)C2=CC=C(C=C2)O
✓ Standardized SMILES: O=S(=O)(c1ccc(O)cc1)c1ccc(O)cc1
✅ This molecule exists in the database
   Chemical group: phenol-based

Select reaction system (enter number, press Enter to select all systems):
  1. cAOP (Catalytic Advanced Oxidation)
  2. H2O2-based (Hydrogen Peroxide-based)
  3.