In [6]:
import gradio as gr
import plotly.graph_objects as go
import requests
from Bio.PDB import PDBParser
import numpy as np
import tempfile
import os
import subprocess
import csv
import re
import shutil
import time
import pandas as pd
from pathlib import Path

# ============== Search and Download Functions ==============

def search_pdb_for_first_hit(protein_name: str):
    """Search RCSB PDB by structure title and return the top-scoring entry.

    Args:
        protein_name: Phrase matched against ``struct.title`` with the
            ``contains_phrase`` operator. Surrounding whitespace is ignored.

    Returns:
        A ``(pdb_id, message)`` tuple. ``pdb_id`` is ``None`` when the name
        is blank, nothing matches, or the request fails; ``message`` is a
        human-readable status line in every case.
    """
    if not protein_name or not protein_name.strip():
        return None, "Please enter a protein name"
    protein_name = protein_name.strip()

    query = {
        "query": {
            "type": "terminal",
            "service": "text",
            "parameters": {
                "attribute": "struct.title",
                "operator": "contains_phrase",
                "value": protein_name
            }
        },
        "return_type": "entry",
        "request_options": {
            "return_all_hits": False,
            "results_content_type": ["experimental"],
            "sort": [{"sort_by": "score", "direction": "desc"}]
        }
    }

    url = "https://search.rcsb.org/rcsbsearch/v2/query"

    try:
        response = requests.post(url, json=query, timeout=30)
        response.raise_for_status()

        # The RCSB Search API reports "no hits" as HTTP 204 with an empty
        # body; calling .json() on that would raise a decode error and be
        # shown to the user as a search failure instead of "no results".
        if response.status_code == 204 or not response.content:
            return None, f"No structures found for '{protein_name}'"

        data = response.json()
        result_set = data.get('result_set', [])
        if not result_set:
            return None, f"No structures found for '{protein_name}'"

        first_pdb_id = result_set[0]['identifier']
        return first_pdb_id, f"Found PDB ID: {first_pdb_id}"

    except (requests.exceptions.RequestException, ValueError, KeyError) as e:
        # ValueError: malformed JSON body; KeyError: hit without 'identifier'.
        return None, f"Error searching PDB: {e}"


def download_pdb_content(pdb_id: str, output_dir: str = "proteins"):
    """Fetch a PDB-format coordinate file from RCSB and cache it on disk.

    The ID is stripped and upper-cased; the file is saved as
    ``<output_dir>/<PDB_ID>.pdb``.

    Returns:
        ``(pdb_text, message, saved_path)`` on success, or
        ``(None, error_message, None)`` on any download/write failure.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(exist_ok=True)
    entry = pdb_id.strip().upper()
    source_url = f"https://files.rcsb.org/download/{entry}.pdb"
    destination = target_dir / f"{entry}.pdb"

    try:
        reply = requests.get(source_url, timeout=30)
        reply.raise_for_status()
        body = reply.text
        with open(destination, 'w') as handle:
            handle.write(body)
    except requests.exceptions.HTTPError as err:
        if err.response.status_code != 404:
            return None, f"HTTP Error: {err}", None
        return None, f"Error: PDB ID '{entry}' not found", None
    except Exception as err:
        return None, f"Error downloading file: {err}", None

    size_kb = len(body) / 1024
    return body, f"Downloaded PDB file ({size_kb:.2f} KB)", str(destination)


def download_fasta_content(pdb_id: str, output_dir: str = "proteins"):
    """Fetch the FASTA sequence(s) of a PDB entry and save them locally.

    Writes ``<output_dir>/<PDB_ID>.fasta``. A response body that does not
    begin with ``>`` is treated as "no sequence" and nothing is written.

    Returns:
        ``(fasta_text, message, saved_path)`` on success, otherwise
        ``(None, error_message, None)``. Only request errors are converted
        into a return value; file-system errors propagate.
    """
    folder = Path(output_dir)
    folder.mkdir(exist_ok=True)
    entry = pdb_id.strip().upper()
    source_url = f"https://www.rcsb.org/fasta/entry/{entry}"
    destination = folder / f"{entry}.fasta"

    try:
        reply = requests.get(source_url, timeout=30)
        reply.raise_for_status()
        fasta_text = reply.text

        looks_like_fasta = fasta_text.strip().startswith(">")
        if not looks_like_fasta:
            return None, f"No valid FASTA data found for {entry}", None

        with open(destination, 'w') as handle:
            handle.write(fasta_text)
    except requests.exceptions.RequestException as err:
        return None, f"Error downloading FASTA: {err}", None

    return fasta_text, "Downloaded FASTA sequence", str(destination)


# ============== Visualization Functions ==============

def parse_pdb_structure(pdb_data, pdb_id="structure"):
    """Parse PDB-format text and flatten every atom into a list of dicts.

    Args:
        pdb_data: Contents of a PDB file as a string.
        pdb_id: Identifier handed to Biopython for the parsed structure.

    Returns:
        One dict per atom, in iteration order over all models, chains and
        residues, with keys ``coord`` (from ``Atom.get_coord()``),
        ``element``, ``residue`` (residue name), ``res_id`` (residue
        sequence number), ``chain`` and ``name`` (atom name).

    Raises:
        ValueError: if ``pdb_data`` is empty or ``None``.
        Any exception from ``PDBParser.get_structure`` propagates.
    """
    import io

    if not pdb_data:
        raise ValueError("No PDB data to parse")

    parser = PDBParser(QUIET=True)
    # get_structure accepts an open file handle, so parse straight from
    # memory. The previous temp-file round trip (delete=False) leaked the
    # file whenever writing it failed, because its path was never recorded.
    structure = parser.get_structure(pdb_id, io.StringIO(pdb_data))

    atoms_data = []
    for model in structure:
        for chain in model:
            chain_id = chain.id
            for residue in chain:
                res_name = residue.get_resname()
                res_id = residue.get_id()[1]
                for atom in residue:
                    atoms_data.append({
                        'coord': atom.get_coord(),
                        'element': atom.element,
                        'residue': res_name,
                        'res_id': res_id,
                        'chain': chain_id,
                        'name': atom.name
                    })
    return atoms_data


def get_atom_color(atom_type):
    """Map an element symbol to its CPK hex colour.

    The lookup is case-insensitive. Symbols not in the table fall back to
    deep pink (``#FF1493``) so they stand out in the plot.
    """
    cpk = dict(
        C='#909090', N='#3050F8', O='#FF0D0D', S='#FFFF30',
        P='#FF8000', H='#FFFFFF', F='#90E050', CL='#1FF01F',
        BR='#A62929', I='#940094', FE='#E06633', CA='#3DFF00',
    )
    symbol = atom_type.upper()
    return cpk[symbol] if symbol in cpk else '#FF1493'


def find_bonds(atoms_data, max_distance=2.0, window=20):
    """Guess bonds from inter-atomic distances between nearby list entries.

    For atom ``i`` only indices ``i+1 .. i+window-1`` are examined, which
    keeps the search roughly linear. Pairs in different chains, or whose
    residue numbers differ by more than one, are skipped.

    Args:
        atoms_data: Atom dicts as produced by ``parse_pdb_structure``; only
            ``coord``, ``chain`` and ``res_id`` are read.
        max_distance: Pairs strictly closer than this (in coordinate units,
            Å for PDB data) are reported as bonded.
        window: How far ahead in the list to look. The default 20 matches
            the previous hard-coded value; increase it for structures with
            explicit hydrogens, where a residue can exceed 20 atoms and
            peptide bonds would otherwise be missed.

    Returns:
        A list of ``(i, j)`` index pairs with ``i < j``.
    """
    bonds = []
    n_atoms = len(atoms_data)

    for i, atom1 in enumerate(atoms_data):
        for j in range(i + 1, min(i + window, n_atoms)):
            atom2 = atoms_data[j]

            if atom1['chain'] != atom2['chain']:
                continue
            if abs(atom1['res_id'] - atom2['res_id']) > 1:
                continue

            dist = np.linalg.norm(atom1['coord'] - atom2['coord'])
            if dist < max_distance:
                bonds.append((i, j))

    return bonds


def create_stick_traces(atoms_data, bonds):
    """Build the Plotly line trace(s) that draw bonds as grey sticks.

    All bonds are packed into a single ``Scatter3d`` trace, with ``None``
    between consecutive segments so Plotly draws them as separate lines.
    Emitting one trace per bond (the previous approach) produced thousands
    of traces for an ordinary protein, which makes the figure very slow to
    build and render.

    Args:
        atoms_data: Atom dicts with a 3-element ``coord``.
        bonds: Iterable of ``(i, j)`` index pairs into ``atoms_data``.

    Returns:
        A list holding one trace, or an empty list when there are no bonds,
        so callers can keep iterating over the result.
    """
    xs, ys, zs = [], [], []
    for i, j in bonds:
        coord1 = atoms_data[i]['coord']
        coord2 = atoms_data[j]['coord']
        # The trailing None breaks the line so bonds are not chained together.
        xs.extend((coord1[0], coord2[0], None))
        ys.extend((coord1[1], coord2[1], None))
        zs.extend((coord1[2], coord2[2], None))

    if not xs:
        return []

    return [go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        line=dict(color='gray', width=3),
        showlegend=False,
        hoverinfo='skip'
    )]


def create_protein_visualization(atoms_data, style="sphere", show_backbone=True, title="Protein Structure"):
    """Create complete protein visualization"""
    coords = np.array([atom['coord'] for atom in atoms_data])
    atom_types = [atom['element'] for atom in atoms_data]
    colors = [get_atom_color(atom) for atom in atom_types]
    
    fig = go.Figure()
    
    # Add backbone trace
    if show_backbone:
        ca_atoms = [i for i, atom in enumerate(atoms_data) if atom['name'] == 'CA']
        if len(ca_atoms) > 1:
            ca_coords = coords[ca_atoms]
            backbone_trace = go.Scatter3d(
                x=ca_coords[:, 0],
                y=ca_coords[:, 1],
                z=ca_coords[:, 2],
                mode='lines',
                line=dict(color='lightblue', width=6),
                name='Backbone',
                hoverinfo='skip'
            )
            fig.add_trace(backbone_trace)
    
    # For stick visualization, add bond lines
    if style == "stick" or style == "ball-and-stick":
        bonds = find_bonds(atoms_data)
        stick_traces = create_stick_traces(atoms_data, bonds)
        for trace in stick_traces:
            fig.add_trace(trace)
    
    # Add atoms
    marker_size = 8 if style == "sphere" else 3 if style == "stick" else 5
    
    atoms_trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(
            size=marker_size,
            color=colors,
            line=dict(width=0.5, color='white')
        ),
        text=[f"{atom['name']} ({atom['element']}) - {atom['residue']}{atom['res_id']}" 
              for atom in atoms_data],
        hovertemplate='<b>%{text}</b><br>X: %{x:.2f}<br>Y: %{y:.2f}<br>Z: %{z:.2f}<extra></extra>',
        name='Atoms'
    )
    fig.add_trace(atoms_trace)
    
    # Update layout
    fig.update_layout(
        title=f"{title} ({len(atoms_data)} atoms)",
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            bgcolor='white',
            xaxis=dict(showbackground=True, backgroundcolor='rgb(230, 230, 230)'),
            yaxis=dict(showbackground=True, backgroundcolor='rgb(230, 230, 230)'),
            zaxis=dict(showbackground=True, backgroundcolor='rgb(230, 230, 230)'),
            aspectmode='data'
        ),
        showlegend=True,
        width=900,
        height=700,
        margin=dict(l=0, r=0, t=40, b=0)
    )
    
    return fig


# ============== RAMPlot Functions ==============

def run_ramplot(input_folder="proteins", output_folder="my_analysis_folder"):
    """Run the ``ramplot pdb`` command-line tool and collect its plot images.

    Invokes ``ramplot pdb -i <input_folder> -o <output_folder> -m 0 -r 600
    -p png`` and then looks for PNG files in ``<output_folder>/Plots``.

    Args:
        input_folder: Folder whose PDB files RAMPlot analyses.
        output_folder: Destination for RAMPlot output; created if missing
            (its parent must already exist).

    Returns:
        ``(map_2d, map_3d, std_2d, std_3d, status)``: paths of the four plot
        PNGs (``None`` for any not found) followed by a status message. All
        four paths are ``None`` when the command itself fails.
    """
    Path(output_folder).mkdir(exist_ok=True)

    cmd = [
        "ramplot", "pdb",
        "-i", input_folder,
        "-o", output_folder,
        "-m", "0",
        "-r", "600",
        "-p", "png"
    ]

    try:
        subprocess.run(cmd, check=True, text=True, capture_output=True)

        plots_folder = Path(output_folder) / "Plots"
        if not plots_folder.exists():
            return None, None, None, None, "No plots folder generated"

        plot_files = {'map_2d': None, 'map_3d': None, 'std_2d': None, 'std_3d': None}

        # Sorted iteration + first-match-wins makes the pick independent of
        # directory listing order when several files match the same slot.
        for file in sorted(plots_folder.glob("*.png")):
            name = file.name
            if "StdMapType2DGeneralGly" in name:
                key = 'std_2d'
            elif "StdMapType3DGeneral" in name:
                key = 'std_3d'
            elif "StdMapType" in name:
                # Any other "Std" plot would also contain the plain
                # "MapType2DAll"/"MapType3DAll" substrings checked below
                # (e.g. a "StdMapType2DAll" file), so it must not fill those
                # slots. NOTE(review): exact RAMPlot file names not verified.
                continue
            elif "MapType2DAll" in name:
                key = 'map_2d'
            elif "MapType3DAll" in name:
                key = 'map_3d'
            else:
                continue
            if plot_files[key] is None:
                plot_files[key] = str(file)

        return (plot_files['map_2d'], plot_files['map_3d'],
                plot_files['std_2d'], plot_files['std_3d'],
                "RAMPlot completed successfully")

    except FileNotFoundError:
        # subprocess.run raises this when the executable is not on PATH.
        return None, None, None, None, "Error running RAMPlot: 'ramplot' executable not found"
    except subprocess.CalledProcessError as e:
        return None, None, None, None, f"Error running RAMPlot: {e.stderr}"
    except Exception as e:
        return None, None, None, None, f"Error: {str(e)}"


def extract_favoured_percentage(csv_path):
    """Read a RAMPlot summary CSV and return its "Favoured" percentage.

    Searches the whole file for text shaped like
    ``Favoured: ,<count>,(<pct>%)``.

    Returns:
        The percentage as a float, or ``None`` when the pattern is absent or
        the file cannot be read.
    """
    pattern = r'Favoured:\s*,\d+,\((\d+\.?\d*)%\)'
    try:
        with open(csv_path, 'r') as handle:
            found = re.search(pattern, handle.read())
    except Exception:
        return None
    return None if found is None else float(found.group(1))


# ============== SWISS-MODEL Functions ==============

def run_swiss_model(fasta_path, api_token, project_title="Homology_Model"):
    """Run SWISS-MODEL homology modeling"""
    BASE_URL = "https://swissmodel.expasy.org"
    HEADERS = {"Authorization": f"Token {api_token}"}
    
    try:
        # Read FASTA file
        sequences = []
        current_sequence = []
        
        with open(fasta_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('>'):
                    if current_sequence:
                        sequences.append("".join(current_sequence))
                        current_sequence = []
                else:
                    if line:
                        current_sequence.append(line)
            
            if current_sequence:
                sequences.append("".join(current_sequence))
        
        if not sequences:
            return None, "No valid sequences found in FASTA file"
        
        FASTA_INPUT = sequences[0] if len(sequences) == 1 else sequences
        
        # Submit job
        payload = {
            "target_sequences": FASTA_INPUT,
            "project_title": project_title
        }
        
        submit_response = requests.post(
            f"{BASE_URL}/automodel/", 
            headers=HEADERS, 
            json=payload
        )
        submit_response.raise_for_status()
        
        project_id = submit_response.json().get("project_id")
        
        # Poll for results
        max_attempts = 60  # 60 minutes max
        for attempt in range(max_attempts):
            status_response = requests.get(
                f"{BASE_URL}/project/{project_id}/models/summary/", 
                headers=HEADERS
            )
            status_response.raise_for_status()
            
            status_data = status_response.json()
            job_status = status_data.get("status")
            
            if job_status == "COMPLETED":
                models = status_data.get("models")
                if not models:
                    return None, "Job completed but no models found"
                
                model_id = models[0].get("model_id")
                
                # Download PDB
                pdb_response = requests.get(
                    f"{BASE_URL}/project/{project_id}/models/{model_id}.pdb",
                    headers=HEADERS
                )
                pdb_response.raise_for_status()
                
                # Save file
                output_filename = f"proteins/{project_id}_{model_id}.pdb"
                with open(output_filename, "w") as f:
                    f.write(pdb_response.text)
                
                return pdb_response.text, f"SWISS-MODEL completed. Saved to {output_filename}"
                
            elif job_status == "FAILED":
                return None, "SWISS-MODEL job failed"
            
            time.sleep(60)  # Wait 1 minute
        
        return None, "SWISS-MODEL job timed out"
        
    except Exception as e:
        return None, f"Error running SWISS-MODEL: {str(e)}"


# ============== CASTp Functions ==============

def run_castp_analysis(final_pdb_path):
    """Submit a structure to CASTp via ``castpfoldpy`` and tabulate its pockets.

    Wipes and recreates ``castp_results/``, runs the CLI with a 600 s
    timeout, then parses the first ``*.pocInfo`` found in the first
    subdirectory of ``castp_results/`` (or in ``castp_results/`` itself when
    it has no subdirectories).

    Returns:
        ``(dataframe_or_None, status_text)``; the status text accumulates
        one line per step, followed by the failure reason if any.
    """
    if not final_pdb_path or not os.path.exists(final_pdb_path):
        return None, "No final structure available for CASTp analysis"

    castp_dir = "castp_results"
    log = []

    def report(extra=None):
        # Join the progress log, optionally followed by a final line.
        text = "\n".join(log)
        return text if extra is None else text + "\n" + extra

    if os.path.exists(castp_dir):
        try:
            shutil.rmtree(castp_dir)
        except Exception as exc:
            log.append(f"Warning: Could not clear CASTp directory - {exc}")
        else:
            log.append("Cleared old CASTp results")

    os.makedirs(castp_dir, exist_ok=True)
    log.extend(["Running CASTp pocket analysis...", "This may take several minutes..."])

    try:
        subprocess.run(
            ["castpfoldpy", "--submit-download", "-p", final_pdb_path,
             "-d", castp_dir, "--pocket"],
            check=True, capture_output=True, text=True, timeout=600,
        )
        log.append("CASTp analysis completed successfully")

        root = Path(castp_dir)
        subdirs = [entry for entry in root.iterdir() if entry.is_dir()]
        if subdirs:
            search_dir = subdirs[0]
            log.append(f"Found results directory: {search_dir.name}")
        else:
            search_dir = root

        pocinfo_files = list(search_dir.glob("*.pocInfo"))
        if not pocinfo_files:
            return None, report("No .pocInfo file found in results")

        pocinfo_path = pocinfo_files[0]
        log.append(f"Found pocket info file: {pocinfo_path.name}")

        table = parse_pocinfo_file(pocinfo_path)
        if table is None or len(table) == 0:
            return None, report("No pocket data found or failed to parse")

        log.append(f"Successfully parsed {len(table)} pockets")
        return table, report()

    except subprocess.TimeoutExpired:
        return None, report("CASTp analysis timed out (>10 minutes)")
    except subprocess.CalledProcessError as exc:
        detail = exc.stderr if exc.stderr else 'Unknown error'
        return None, report(f"CASTp command failed: {detail}")
    except Exception as exc:
        return None, report(f"Error: {str(exc)}")


def parse_pocinfo_file(pocinfo_path):
    """Load the ``POC:`` rows of a CASTp ``.pocInfo`` file into a DataFrame.

    Each line starting with ``POC:`` is stripped and split on tabs; rows
    with at least 11 fields are kept, and fields 1-9 (after the ``POC:``
    marker) become the columns Molecule, ID, N_mth, Area_sa, Area_ms,
    Vol_sa, Vol_ms, Length and cnr, all as strings.

    Returns:
        A DataFrame, or ``None`` if no rows qualify or reading fails (the
        error is printed).
    """
    columns = ('Molecule', 'ID', 'N_mth', 'Area_sa', 'Area_ms',
               'Vol_sa', 'Vol_ms', 'Length', 'cnr')
    try:
        with open(pocinfo_path, 'r') as handle:
            split_rows = (
                line.strip().split('\t')
                for line in handle
                if line.startswith("POC:")
            )
            records = [
                dict(zip(columns, fields[1:10]))
                for fields in split_rows
                if len(fields) >= 11
            ]
        return pd.DataFrame(records) if records else None
    except Exception as e:
        print(f"Error parsing pocInfo file: {e}")
        return None


# ============== Main Functions ==============

def search_and_visualize_protein(protein_name, style="sphere", show_backbone=True):
    """Gradio handler for Step 1: search, download and render a structure.

    Deletes the ``proteins``, ``my_analysis_folder`` and ``castp_results``
    working folders, looks up the top RCSB hit for ``protein_name``,
    downloads its PDB and FASTA files and draws the structure.

    Returns:
        ``(figure, status_text, pdb_path, fasta_path, pdb_text)``; on any
        failure the figure and the three trailing values are ``None``.
    """
    if not protein_name or not protein_name.strip():
        return None, "Please enter a protein name", None, None, None

    log = [f"Searching for: {protein_name}"]

    def failed():
        # Uniform "nothing to show" result carrying the log so far.
        return None, "\n".join(log), None, None, None

    for folder in ("proteins", "my_analysis_folder", "castp_results"):
        if not os.path.exists(folder):
            continue
        try:
            shutil.rmtree(folder)
        except Exception as exc:
            log.append(f"Warning: Could not clear '{folder}' - {exc}")
        else:
            log.append(f"Cleared old data in '{folder}'")

    pdb_id, message = search_pdb_for_first_hit(protein_name)
    log.append(message)
    if not pdb_id:
        return failed()

    log.append(f"Downloading structure data for {pdb_id}")
    pdb_content, message, pdb_path = download_pdb_content(pdb_id)
    log.append(message)
    if not pdb_content:
        return failed()

    _, message, fasta_path = download_fasta_content(pdb_id)
    log.append(message)

    try:
        log.append("Parsing structure and generating visualization")
        atoms_data = parse_pdb_structure(pdb_content, pdb_id)
        fig = create_protein_visualization(atoms_data, style, show_backbone, "Original Structure")
        log.append(f"Visualization complete - {len(atoms_data)} atoms displayed")
    except Exception as exc:
        log.append(f"Error: {str(exc)}")
        return failed()

    return fig, "\n".join(log), pdb_path, fasta_path, pdb_content


def generate_ramplot_and_process(pdb_path, fasta_path, original_pdb_content, api_token, style, show_backbone):
    """Gradio handler for Steps 2-3: RAMPlot analysis and final structure.

    Runs RAMPlot on the ``proteins`` folder, reads the "Favoured" percentage
    from ``my_analysis_folder/Analysis.csv`` and then either keeps the
    original structure (>= 90 % favoured) or requests a SWISS-MODEL homology
    model (< 90 %; needs ``api_token`` and ``fasta_path``).

    Returns:
        An 8-tuple ``(map_2d, map_3d, std_2d, std_3d, status_text,
        final_figure, final_status, final_pdb_path)`` in the order of the
        Gradio outputs; the figure and final path are ``None`` on failure.
    """
    if not pdb_path:
        return None, None, None, None, "Please download a protein structure first", None, "No structure available", None

    status_messages = ["Running RAMPlot analysis..."]

    map_2d, map_3d, std_2d, std_3d, ramplot_status = run_ramplot()
    status_messages.append(ramplot_status)

    def outputs(fig, final_status, final_path):
        # Assemble the 8 values in the order the Gradio outputs expect.
        return (map_2d, map_3d, std_2d, std_3d, "\n".join(status_messages),
                fig, final_status, final_path)

    favoured_percent = extract_favoured_percentage("my_analysis_folder/Analysis.csv")
    if favoured_percent is None:
        status_messages.append("Could not extract Favoured percentage from Analysis.csv")
        return outputs(None, "Analysis incomplete", None)

    status_messages.append(f"Favoured percentage: {favoured_percent}%")

    if favoured_percent >= 90.0:
        status_messages.append("Favoured percentage >= 90%. Retaining original structure.")
        try:
            atoms_data = parse_pdb_structure(original_pdb_content)
            fig = create_protein_visualization(atoms_data, style, show_backbone, "Final Structure (Original)")
        except Exception as e:
            return outputs(None, f"Error: {str(e)}", None)
        return outputs(fig, f"Favoured: {favoured_percent}% - Original structure retained", pdb_path)

    status_messages.append("Favoured percentage < 90%. Running SWISS-MODEL...")

    if not api_token or "YOUR_API_TOKEN" in api_token:
        status_messages.append("Error: Please provide a valid SWISS-MODEL API token")
        return outputs(None, "API token required", None)

    if not fasta_path:
        # The FASTA download in Step 1 can fail independently of the PDB
        # download; report that instead of a cryptic open(None) error.
        status_messages.append("Error: No FASTA sequence available for SWISS-MODEL")
        return outputs(None, "FASTA sequence required", None)

    new_pdb_content, swiss_msg = run_swiss_model(fasta_path, api_token)
    status_messages.append(swiss_msg)

    if not new_pdb_content:
        return outputs(None, "SWISS-MODEL failed", None)

    try:
        # run_swiss_model has just written its model into proteins/, so it is
        # the most recently modified .pdb there. glob() order is arbitrary,
        # and taking its last element could return the original download
        # (which CASTp would then analyse instead of the new model).
        pdb_files = list(Path("proteins").glob("*.pdb"))
        new_pdb_path = (str(max(pdb_files, key=lambda p: p.stat().st_mtime))
                        if pdb_files else None)

        atoms_data = parse_pdb_structure(new_pdb_content)
        fig = create_protein_visualization(atoms_data, style, show_backbone, "Final Structure (SWISS-MODEL)")
    except Exception as e:
        return outputs(None, f"Error: {str(e)}", None)

    return outputs(fig, f"Favoured: {favoured_percent}% - New structure from SWISS-MODEL", new_pdb_path)


# ============== Gradio Interface ==============

# Build the four-step Gradio UI. Component creation order inside each
# `with` context determines on-screen layout, and the .click() wiring at the
# end maps each handler's return tuple onto its `outputs` list positionally.
# NOTE(review): the handlers work in fixed relative folders ("proteins",
# "my_analysis_folder", "castp_results") and Step 1 deletes them, so
# concurrent sessions would overwrite each other's files — confirm this is
# only meant for single-user use.
with gr.Blocks(title="Protein Structure Viewer", theme=gr.themes.Default()) as demo:
    gr.Markdown("# Protein Structure Viewer")
    gr.Markdown("Search and visualize protein structures from RCSB Protein Data Bank")
    
    # Hidden state variables
    # Values carried between steps: Step 1 fills the first three; Steps 2-3
    # fill final_pdb_path_state, which Step 4 (CASTp) reads.
    pdb_path_state = gr.State(None)
    fasta_path_state = gr.State(None)
    pdb_content_state = gr.State(None)
    final_pdb_path_state = gr.State(None)
    
    # Section 1: Protein Search and Visualization
    gr.Markdown("## Step 1: Search and Download Protein")
    
    with gr.Row():
        with gr.Column(scale=1):
            protein_input = gr.Textbox(
                label="Protein Name",
                placeholder="Enter protein name",
                lines=1
            )
            
            # Also reused by Step 2/3 when drawing the final structure.
            style_dropdown = gr.Dropdown(
                choices=["sphere", "stick", "ball-and-stick"],
                value="sphere",
                label="Visualization Style"
            )
            
            backbone_checkbox = gr.Checkbox(
                label="Show Backbone",
                value=True
            )
            
            search_btn = gr.Button("Process", variant="primary", size="lg")
            
            status_output = gr.Textbox(
                label="Status",
                lines=8,
                interactive=False
            )
    
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Original 3D Structure")
    
    # Section 2: RAMPlot Analysis
    gr.Markdown("## Step 2: RAMPlot Analysis")
    
    with gr.Row():
        with gr.Column(scale=1):
            # Only used when the Favoured percentage is below 90%.
            api_token_input = gr.Textbox(
                label="SWISS-MODEL API Token (required if Favoured < 90%)",
                placeholder="Enter your API token",
                type="password",
                lines=1
            )
            
            ramplot_btn = gr.Button("Generate RAMPlot & Process", variant="primary", size="lg")
            
            ramplot_status = gr.Textbox(
                label="Analysis Status",
                lines=10,
                interactive=False
            )
        
        with gr.Column(scale=2):
            with gr.Row():
                plot_map_2d = gr.Image(label="MapType 2D All")
                plot_map_3d = gr.Image(label="MapType 3D All")
            with gr.Row():
                plot_std_2d = gr.Image(label="StdMapType 2D General Gly")
                plot_std_3d = gr.Image(label="StdMapType 3D General")
    
    # Section 3: Final Structure
    gr.Markdown("## Step 3: Final Structure")
    
    with gr.Row():
        with gr.Column(scale=1):
            final_status_output = gr.Textbox(
                label="Final Status",
                lines=3,
                interactive=False
            )
        
        with gr.Column(scale=2):
            final_plot_output = gr.Plot(label="Final 3D Structure")
    
    # Section 4: CASTp Pocket Analysis
    gr.Markdown("## Step 4: CASTp Pocket Analysis")
    
    with gr.Row():
        with gr.Column(scale=1):
            castp_btn = gr.Button("Run CASTp Analysis", variant="primary", size="lg")
            
            castp_status = gr.Textbox(
                label="CASTp Status",
                lines=8,
                interactive=False
            )
        
        with gr.Column(scale=2):
            # Headers mirror the column names built by parse_pocinfo_file.
            castp_output = gr.Dataframe(
                label="Pocket Information",
                headers=["Molecule", "ID", "N_mth", "Area_sa", "Area_ms", 
                         "Vol_sa", "Vol_ms", "Length", "cnr"],
                interactive=False
            )
    
    # Connect buttons
    # Step 1 -> returns (figure, status, pdb_path, fasta_path, pdb_text).
    search_btn.click(
        fn=search_and_visualize_protein,
        inputs=[protein_input, style_dropdown, backbone_checkbox],
        outputs=[plot_output, status_output, pdb_path_state, fasta_path_state, pdb_content_state]
    )
    
    # Steps 2-3 -> returns the 8-tuple consumed below, in this exact order.
    ramplot_btn.click(
        fn=generate_ramplot_and_process,
        inputs=[pdb_path_state, fasta_path_state, pdb_content_state, api_token_input, 
                style_dropdown, backbone_checkbox],
        outputs=[plot_map_2d, plot_map_3d, plot_std_2d, plot_std_3d, 
                 ramplot_status, final_plot_output, final_status_output, final_pdb_path_state]
    )
    
    # Step 4 -> returns (pocket DataFrame or None, status text).
    castp_btn.click(
        fn=run_castp_analysis,
        inputs=[final_pdb_path_state],
        outputs=[castp_output, castp_status]
    )

if __name__ == "__main__":
    # NOTE(review): share=True publishes a public *.gradio.live URL (as in
    # the cell output). Anyone with the link can trigger downloads,
    # subprocess runs and working-folder deletions — confirm this is intended.
    demo.launch(share=True)

* Running on local URL:  http://127.0.0.1:7868
* Running on public URL: https://c18f5962d9cef28d7a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


In [11]:
import gradio as gr
import plotly.graph_objects as go
import requests
from Bio.PDB import PDBParser, PDBIO, Select
from Bio.PDB.Polypeptide import is_aa
import numpy as np
import tempfile
import os
import subprocess
import csv
import re
import shutil
import time
import pandas as pd
from pathlib import Path
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

# ============== Search and Download Functions ==============

def search_pdb_for_first_hit(protein_name: str):
    """Search RCSB PDB by structure title and return the top-scoring entry.

    Args:
        protein_name: Phrase matched against ``struct.title`` with the
            ``contains_phrase`` operator. Surrounding whitespace is ignored.

    Returns:
        A ``(pdb_id, message)`` tuple. ``pdb_id`` is ``None`` when the name
        is blank, nothing matches, or the request fails; ``message`` is a
        human-readable status line in every case.
    """
    if not protein_name or not protein_name.strip():
        return None, "Please enter a protein name"
    protein_name = protein_name.strip()

    query = {
        "query": {
            "type": "terminal",
            "service": "text",
            "parameters": {
                "attribute": "struct.title",
                "operator": "contains_phrase",
                "value": protein_name
            }
        },
        "return_type": "entry",
        "request_options": {
            "return_all_hits": False,
            "results_content_type": ["experimental"],
            "sort": [{"sort_by": "score", "direction": "desc"}]
        }
    }

    url = "https://search.rcsb.org/rcsbsearch/v2/query"

    try:
        response = requests.post(url, json=query, timeout=30)
        response.raise_for_status()

        # The RCSB Search API reports "no hits" as HTTP 204 with an empty
        # body; calling .json() on that would raise a decode error and be
        # shown to the user as a search failure instead of "no results".
        if response.status_code == 204 or not response.content:
            return None, f"No structures found for '{protein_name}'"

        data = response.json()
        result_set = data.get('result_set', [])
        if not result_set:
            return None, f"No structures found for '{protein_name}'"

        first_pdb_id = result_set[0]['identifier']
        return first_pdb_id, f"Found PDB ID: {first_pdb_id}"

    except (requests.exceptions.RequestException, ValueError, KeyError) as e:
        # ValueError: malformed JSON body; KeyError: hit without 'identifier'.
        return None, f"Error searching PDB: {e}"


def download_pdb_content(pdb_id: str, output_dir: str = "proteins"):
    """Fetch a PDB-format coordinate file from RCSB and cache it on disk.

    The ID is stripped and upper-cased; the file is saved as
    ``<output_dir>/<PDB_ID>.pdb``.

    Returns:
        ``(pdb_text, message, saved_path)`` on success, or
        ``(None, error_message, None)`` on any download/write failure.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(exist_ok=True)
    entry = pdb_id.strip().upper()
    source_url = f"https://files.rcsb.org/download/{entry}.pdb"
    destination = target_dir / f"{entry}.pdb"

    try:
        reply = requests.get(source_url, timeout=30)
        reply.raise_for_status()
        body = reply.text
        with open(destination, 'w') as handle:
            handle.write(body)
    except requests.exceptions.HTTPError as err:
        if err.response.status_code != 404:
            return None, f"HTTP Error: {err}", None
        return None, f"Error: PDB ID '{entry}' not found", None
    except Exception as err:
        return None, f"Error downloading file: {err}", None

    size_kb = len(body) / 1024
    return body, f"Downloaded PDB file ({size_kb:.2f} KB)", str(destination)


def download_fasta_content(pdb_id: str, output_dir: str = "proteins"):
    """Fetch the FASTA sequence(s) of a PDB entry and save them locally.

    Writes ``<output_dir>/<PDB_ID>.fasta``. A response body that does not
    begin with ``>`` is treated as "no sequence" and nothing is written.

    Returns:
        ``(fasta_text, message, saved_path)`` on success, otherwise
        ``(None, error_message, None)``. Only request errors are converted
        into a return value; file-system errors propagate.
    """
    folder = Path(output_dir)
    folder.mkdir(exist_ok=True)
    entry = pdb_id.strip().upper()
    source_url = f"https://www.rcsb.org/fasta/entry/{entry}"
    destination = folder / f"{entry}.fasta"

    try:
        reply = requests.get(source_url, timeout=30)
        reply.raise_for_status()
        fasta_text = reply.text

        looks_like_fasta = fasta_text.strip().startswith(">")
        if not looks_like_fasta:
            return None, f"No valid FASTA data found for {entry}", None

        with open(destination, 'w') as handle:
            handle.write(fasta_text)
    except requests.exceptions.RequestException as err:
        return None, f"Error downloading FASTA: {err}", None

    return fasta_text, "Downloaded FASTA sequence", str(destination)


# ============== Visualization Functions ==============

def parse_pdb_structure(pdb_data, pdb_id="structure"):
    """Parse PDB-format text into a flat list of per-atom records.

    The text is handed to Biopython through an in-memory text buffer, so no
    temporary file has to be created and cleaned up (the previous version
    could leak the temp file if writing it failed).

    Args:
        pdb_data: Full contents of a PDB file as a string.
        pdb_id: Identifier given to the Biopython structure object.

    Returns:
        A list of dicts, one per atom, in model -> chain -> residue -> atom
        order, with keys ``coord`` (``atom.get_coord()``), ``element``,
        ``residue`` (residue name), ``res_id`` (residue sequence number),
        ``chain`` and ``name``. Atoms from every model are included.

    Raises:
        Whatever ``PDBParser.get_structure`` raises for unparsable input.
    """
    import io

    parser = PDBParser(QUIET=True)
    # get_structure accepts an open file handle as well as a filename.
    structure = parser.get_structure(pdb_id, io.StringIO(pdb_data))

    atoms_data = []
    for model in structure:
        for chain in model:
            chain_id = chain.id
            for residue in chain:
                res_name = residue.get_resname()
                res_id = residue.get_id()[1]  # (hetflag, seq number, icode)

                for atom in residue:
                    atoms_data.append({
                        'coord': atom.get_coord(),
                        'element': atom.element,
                        'residue': res_name,
                        'res_id': res_id,
                        'chain': chain_id,
                        'name': atom.name
                    })

    return atoms_data


def get_atom_color(atom_type):
    """Map an element symbol to its CPK hex colour, ignoring case.

    Elements missing from the table fall back to deep pink (``#FF1493``) so
    they stand out in the plot.
    """
    palette = dict(
        C='#909090', N='#3050F8', O='#FF0D0D', S='#FFFF30',
        P='#FF8000', H='#FFFFFF', F='#90E050', CL='#1FF01F',
        BR='#A62929', I='#940094', FE='#E06633', CA='#3DFF00',
    )
    symbol = atom_type.upper()
    if symbol in palette:
        return palette[symbol]
    return '#FF1493'


def find_bonds(atoms_data, max_distance=2.0, window=20):
    """Infer bonds from inter-atomic distances.

    Each atom ``i`` is only compared with atoms ``i+1 .. i+window-1`` in list
    order, and pairs are skipped when they are on different chains or in
    residues whose numbers differ by more than one. This keeps the search
    roughly linear, but bonds between atoms more than ``window - 1``
    positions apart are never found. Raise ``window`` for residues with many
    atoms, for example after hydrogens have been added.

    Args:
        atoms_data: Atom records as produced by ``parse_pdb_structure``; each
            needs ``coord``, ``chain`` and ``res_id``.
        max_distance: A pair is bonded when its distance is strictly below
            this value (same units as the coordinates).
        window: Size of the look-ahead window in list positions (default 20,
            the previously hard-coded value).

    Returns:
        A list of ``(i, j)`` index pairs with ``i < j``, ordered by ``i``
        then ``j``.
    """
    n_atoms = len(atoms_data)
    if n_atoms < 2:
        return []

    coords = np.asarray([atom['coord'] for atom in atoms_data])
    chains = [atom['chain'] for atom in atoms_data]
    res_ids = [atom['res_id'] for atom in atoms_data]

    bonds = []
    for i in range(n_atoms):
        stop = min(i + window, n_atoms)
        if stop <= i + 1:
            continue
        # One vectorised distance computation per atom instead of one per pair.
        dists = np.linalg.norm(coords[i + 1:stop] - coords[i], axis=1)
        for j, dist in zip(range(i + 1, stop), dists):
            if chains[j] != chains[i]:
                continue
            if abs(res_ids[i] - res_ids[j]) > 1:
                continue
            if dist < max_distance:
                bonds.append((i, j))

    return bonds


def create_stick_traces(atoms_data, bonds):
    """Build the Plotly trace that draws bonds as grey line segments.

    All bonds go into a single ``Scatter3d`` trace. A ``None`` entry after
    each segment breaks the line, so the same segments are drawn as before,
    but the figure holds one trace instead of one trace per bond. That
    matters because a protein easily has thousands of bonds.

    Args:
        atoms_data: Atom records with a ``coord`` entry (x, y, z).
        bonds: Iterable of ``(i, j)`` index pairs into ``atoms_data``.

    Returns:
        A list containing the single bond trace, or an empty list when there
        are no bonds. The list return type is kept so callers can still
        iterate over the result.
    """
    if not bonds:
        return []

    xs, ys, zs = [], [], []
    for i, j in bonds:
        start = atoms_data[i]['coord']
        end = atoms_data[j]['coord']
        # None separates this segment from the next one.
        xs.extend((start[0], end[0], None))
        ys.extend((start[1], end[1], None))
        zs.extend((start[2], end[2], None))

    trace = go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        line=dict(color='gray', width=3),
        connectgaps=False,
        showlegend=False,
        hoverinfo='skip'
    )
    return [trace]


def _backbone_trace(atoms_data, coords):
    """Return a line trace through alpha-carbon atoms, or None if fewer than two.

    The line is broken (a ``None`` point) whenever the chain changes or the
    residue number goes backwards, for example at the start of another
    model. This stops separate chains from being joined by a spurious
    segment.
    """
    xs, ys, zs = [], [], []
    n_ca = 0
    previous = None  # (chain, res_id) of the last CA added
    for idx, atom in enumerate(atoms_data):
        # Calcium ions also use the atom name 'CA' but have element 'CA';
        # only carbon alpha atoms belong on the backbone.
        if atom['name'] != 'CA' or atom['element'].upper() == 'CA':
            continue
        if previous is not None and (
            atom['chain'] != previous[0] or atom['res_id'] < previous[1]
        ):
            xs.append(None)
            ys.append(None)
            zs.append(None)
        x, y, z = coords[idx]
        xs.append(x)
        ys.append(y)
        zs.append(z)
        previous = (atom['chain'], atom['res_id'])
        n_ca += 1

    if n_ca < 2:
        return None

    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        line=dict(color='lightblue', width=6),
        name='Backbone',
        hoverinfo='skip'
    )


def create_protein_visualization(atoms_data, style="sphere", show_backbone=True, title="Protein Structure"):
    """Build an interactive Plotly 3D figure of a parsed structure.

    Args:
        atoms_data: Atom records from ``parse_pdb_structure`` (keys
            ``coord``, ``element``, ``name``, ``residue``, ``res_id``,
            ``chain``).
        style: ``"sphere"`` draws large atom markers only. ``"stick"`` and
            ``"ball-and-stick"`` add bond lines with small or medium markers.
            Any other value draws medium markers without bonds.
        show_backbone: Draw a per-chain line through the alpha carbons.
        title: Figure title; the atom count is appended.

    Returns:
        A ``plotly.graph_objects.Figure``. Empty ``atoms_data`` yields a
        figure with an empty atom trace instead of raising.
    """
    # reshape keeps a (0, 3) array for empty input so column slicing works.
    coords = np.array([atom['coord'] for atom in atoms_data]).reshape(-1, 3)
    colors = [get_atom_color(atom['element']) for atom in atoms_data]

    fig = go.Figure()

    if show_backbone:
        backbone = _backbone_trace(atoms_data, coords)
        if backbone is not None:
            fig.add_trace(backbone)

    # Bond lines only for the stick-based styles.
    if style in ("stick", "ball-and-stick"):
        for trace in create_stick_traces(atoms_data, find_bonds(atoms_data)):
            fig.add_trace(trace)

    marker_size = {"sphere": 8, "stick": 3}.get(style, 5)

    atoms_trace = go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=dict(
            size=marker_size,
            color=colors,
            line=dict(width=0.5, color='white')
        ),
        text=[f"{atom['name']} ({atom['element']}) - {atom['residue']}{atom['res_id']}"
              for atom in atoms_data],
        hovertemplate='<b>%{text}</b><br>X: %{x:.2f}<br>Y: %{y:.2f}<br>Z: %{z:.2f}<extra></extra>',
        name='Atoms'
    )
    fig.add_trace(atoms_trace)

    fig.update_layout(
        title=f"{title} ({len(atoms_data)} atoms)",
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            bgcolor='white',
            xaxis=dict(showbackground=True, backgroundcolor='rgb(230, 230, 230)'),
            yaxis=dict(showbackground=True, backgroundcolor='rgb(230, 230, 230)'),
            zaxis=dict(showbackground=True, backgroundcolor='rgb(230, 230, 230)'),
            aspectmode='data'
        ),
        showlegend=True,
        width=900,
        height=700,
        margin=dict(l=0, r=0, t=40, b=0)
    )

    return fig


# ============== RAMPlot Functions ==============

def run_ramplot(input_folder="proteins", output_folder="my_analysis_folder"):
    """Run the ``ramplot`` CLI over a folder of PDB files and collect its PNGs.

    Invokes ``ramplot pdb -i <input_folder> -o <output_folder> -m 0 -r 600
    -p png`` and then looks for the generated images in
    ``<output_folder>/Plots``.

    Returns:
        ``(map_2d, map_3d, std_2d, std_3d, message)``. The first four are PNG
        paths, or ``None`` when the matching file was not produced. On any
        failure all four are ``None`` and ``message`` explains why.
    """
    Path(output_folder).mkdir(exist_ok=True)

    cmd = [
        "ramplot", "pdb",
        "-i", input_folder,
        "-o", output_folder,
        "-m", "0",
        "-r", "600",
        "-p", "png"
    ]

    # (result key, substring of the PNG file name), checked in this order;
    # the first substring that matches a file name claims that file.
    plot_patterns = (
        ('map_2d', "MapType2DAll"),
        ('map_3d', "MapType3DAll"),
        ('std_2d', "StdMapType2DGeneralGly"),
        ('std_3d', "StdMapType3DGeneral"),
    )

    try:
        subprocess.run(cmd, check=True, text=True, capture_output=True)

        plots_folder = Path(output_folder) / "Plots"
        if not plots_folder.exists():
            return None, None, None, None, "No plots folder generated"

        plot_files = dict.fromkeys(key for key, _ in plot_patterns)
        # Sorted so the chosen files do not depend on directory listing order.
        for file in sorted(plots_folder.glob("*.png")):
            for key, marker in plot_patterns:
                if marker in file.name:
                    plot_files[key] = str(file)
                    break

        return (plot_files['map_2d'], plot_files['map_3d'],
                plot_files['std_2d'], plot_files['std_3d'],
                "RAMPlot completed successfully")

    except FileNotFoundError:
        return (None, None, None, None,
                "Error running RAMPlot: 'ramplot' executable not found on PATH")
    except subprocess.CalledProcessError as e:
        # Fall back to stdout / exit code when the tool wrote nothing to stderr.
        details = ((e.stderr or "").strip()
                   or (e.stdout or "").strip()
                   or f"exit code {e.returncode}")
        return None, None, None, None, f"Error running RAMPlot: {details}"
    except Exception as e:
        return None, None, None, None, f"Error: {str(e)}"


def extract_favoured_percentage(csv_path):
    """Return the "Favoured" percentage reported in a RAMPlot Analysis.csv.

    Looks for text of the form ``Favoured:,<count>,(<pct>%)``. Whitespace
    around the commas, the parentheses and the percent sign is tolerated.

    Args:
        csv_path: Path to the analysis CSV.

    Returns:
        The percentage as a float, or ``None`` when the file cannot be read
        or contains no such entry.
    """
    pattern = r'Favoured:\s*,\s*\d+\s*,\s*\(\s*(\d+\.?\d*)\s*%\s*\)'

    try:
        with open(csv_path, 'r', encoding='utf-8', errors='replace') as f:
            content = f.read()
    except (OSError, TypeError, ValueError):
        # Missing/unreadable file or a non-path argument: report "not found"
        # the same way the caller already handles it, without masking bugs
        # in the parsing code below.
        return None

    match = re.search(pattern, content)
    return float(match.group(1)) if match else None


# ============== SWISS-MODEL Functions ==============

def _read_fasta_sequences(fasta_path):
    """Return the sequences of a FASTA file, one string per record (headers dropped)."""
    sequences = []
    current = []
    with open(fasta_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('>'):
                if current:
                    sequences.append("".join(current))
                    current = []
            elif line:
                current.append(line)
    if current:
        sequences.append("".join(current))
    return sequences


def run_swiss_model(fasta_path, api_token, project_title="Homology_Model", *,
                    poll_interval=60, max_attempts=60, output_dir="proteins",
                    request_timeout=60):
    """Submit a SWISS-MODEL AutoModel job and download the first model.

    Args:
        fasta_path: FASTA file with the target sequence(s). One record is sent
            as a plain string, several records as a list.
        api_token: SWISS-MODEL API token.
        project_title: Title given to the SWISS-MODEL project.
        poll_interval: Seconds to wait between status polls.
        max_attempts: Maximum number of status polls before giving up.
        output_dir: Directory the model PDB is saved into (created if missing).
        request_timeout: Timeout in seconds for each HTTP request, so a stalled
            connection cannot hang the app.

    Returns:
        ``(pdb_text, message)`` on success, ``(None, message)`` otherwise.
        Errors are reported through ``message``, never raised.
    """
    base_url = "https://swissmodel.expasy.org"
    headers = {"Authorization": f"Token {api_token}"}

    try:
        sequences = _read_fasta_sequences(fasta_path)
        if not sequences:
            return None, "No valid sequences found in FASTA file"

        target = sequences[0] if len(sequences) == 1 else sequences

        submit_response = requests.post(
            f"{base_url}/automodel/",
            headers=headers,
            json={"target_sequences": target, "project_title": project_title},
            timeout=request_timeout
        )
        submit_response.raise_for_status()

        project_id = submit_response.json().get("project_id")
        if not project_id:
            return None, "SWISS-MODEL did not return a project ID"

        for attempt in range(max_attempts):
            status_response = requests.get(
                f"{base_url}/project/{project_id}/models/summary/",
                headers=headers,
                timeout=request_timeout
            )
            status_response.raise_for_status()

            status_data = status_response.json()
            job_status = status_data.get("status")

            if job_status == "COMPLETED":
                models = status_data.get("models")
                if not models:
                    return None, "Job completed but no models found"

                model_id = models[0].get("model_id")
                pdb_response = requests.get(
                    f"{base_url}/project/{project_id}/models/{model_id}.pdb",
                    headers=headers,
                    timeout=request_timeout
                )
                pdb_response.raise_for_status()

                # The folder may have been cleared since the FASTA was saved.
                Path(output_dir).mkdir(parents=True, exist_ok=True)
                output_filename = f"{output_dir}/{project_id}_{model_id}.pdb"
                with open(output_filename, "w") as f:
                    f.write(pdb_response.text)

                return pdb_response.text, f"SWISS-MODEL completed. Saved to {output_filename}"

            if job_status == "FAILED":
                return None, "SWISS-MODEL job failed"

            # No point sleeping after the last poll; we are about to give up.
            if attempt < max_attempts - 1:
                time.sleep(poll_interval)

        return None, "SWISS-MODEL job timed out"

    except Exception as e:
        return None, f"Error running SWISS-MODEL: {str(e)}"


# ============== PrankWeb Functions ==============

def _wait_for_zip_download(directory, timeout=60.0, poll_interval=1.0):
    """Poll *directory* until a finished ``.zip`` download is present.

    Chrome stores an unfinished download under a ``.crdownload`` name, so a
    ZIP counts as finished once no such partial file remains. Returns the
    sorted ZIP paths found. The list may be empty if *timeout* seconds pass
    first.
    """
    folder = Path(directory)
    deadline = time.monotonic() + timeout
    while True:
        zips = sorted(folder.glob("*.zip"))
        in_progress = any(folder.glob("*.crdownload"))
        if (zips and not in_progress) or time.monotonic() >= deadline:
            return zips
        time.sleep(poll_interval)


def run_prankweb_analysis(pdb_path, download_timeout=60):
    """Predict ligand-binding pockets by driving the PrankWeb site with Selenium.

    Opens https://prankweb.cz/ in headless Chrome and uploads *pdb_path* as a
    custom structure. It waits up to 10 minutes for results, then downloads
    the prediction ZIP into ``prankweb_results/`` (cleared first). Finally it
    extracts the ZIP and parses the first ``*predictions.csv`` found.

    Args:
        pdb_path: Path of the PDB file to analyse.
        download_timeout: Seconds to wait for the ZIP download to finish.
            This replaces a fixed 10 s sleep that could miss slow downloads.

    Returns:
        ``(DataFrame, status_log)`` on success, ``(None, status_log)``
        otherwise.
    """
    import traceback
    import zipfile

    if not pdb_path or not os.path.exists(pdb_path):
        return None, "No PDB file available for PrankWeb analysis"

    status_messages = []

    # Clear old PrankWeb results so stale ZIP/CSV files are never picked up.
    prankweb_dir = "prankweb_results"
    if os.path.exists(prankweb_dir):
        try:
            shutil.rmtree(prankweb_dir)
            status_messages.append("Cleared old PrankWeb results")
        except Exception as e:
            status_messages.append(f"Warning: Could not clear PrankWeb directory - {e}")

    os.makedirs(prankweb_dir, exist_ok=True)
    absolute_path = os.path.abspath(pdb_path)

    status_messages.append("Starting PrankWeb analysis...")
    status_messages.append("Opening browser (headless mode)...")

    # Headless Chrome that saves downloads straight into prankweb_dir.
    chrome_options = webdriver.ChromeOptions()
    chrome_options.add_argument('--headless=new')
    chrome_options.add_argument('--no-sandbox')
    chrome_options.add_argument('--disable-dev-shm-usage')
    chrome_options.add_argument('--window-size=1920,1080')
    prefs = {
        "download.default_directory": os.path.abspath(prankweb_dir),
        "download.prompt_for_download": False,
    }
    chrome_options.add_experimental_option("prefs", prefs)

    driver = None
    try:
        driver = webdriver.Chrome(options=chrome_options)

        status_messages.append("Opening PrankWeb...")
        driver.get("https://prankweb.cz/")
        time.sleep(3)

        status_messages.append("Selecting custom structure option...")
        wait = WebDriverWait(driver, 30)
        custom_structure = wait.until(EC.presence_of_element_located(
            (By.XPATH, "//*[contains(text(), 'Custom structure')]")))
        # JS click avoids "element not interactable" issues in headless mode.
        driver.execute_script("arguments[0].click();", custom_structure)
        time.sleep(1)

        status_messages.append(f"Uploading file: {os.path.basename(pdb_path)}")
        file_input = driver.find_element(By.CSS_SELECTOR, "input[type='file']")
        file_input.send_keys(absolute_path)
        time.sleep(2)

        status_messages.append("Submitting analysis...")
        submit_btn = wait.until(EC.presence_of_element_located(
            (By.CSS_SELECTOR, "button[type='submit']")))
        driver.execute_script("arguments[0].click();", submit_btn)

        status_messages.append("Waiting for results (this may take several minutes)...")
        wait_long = WebDriverWait(driver, 600)  # 10 minutes
        info_tab = wait_long.until(EC.presence_of_element_located(
            (By.XPATH, "//*[contains(text(), 'Info')]")))

        status_messages.append("Results ready! Downloading data...")
        driver.execute_script("arguments[0].click();", info_tab)
        time.sleep(2)

        download_btn = wait_long.until(EC.presence_of_element_located(
            (By.XPATH, "//*[contains(text(), 'Download prediction data')]")))
        driver.execute_script("arguments[0].click();", download_btn)

        status_messages.append("Download started, waiting for completion...")
        zip_files = _wait_for_zip_download(prankweb_dir, timeout=download_timeout)
        if not zip_files:
            return None, "\n".join(status_messages) + "\nNo ZIP file downloaded"

        zip_path = zip_files[0]
        status_messages.append(f"Extracting {zip_path.name}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(prankweb_dir)

        # Sorted so the same CSV is chosen on every run.
        csv_files = sorted(Path(prankweb_dir).rglob("*predictions.csv"))
        if not csv_files:
            return None, "\n".join(status_messages) + "\nNo predictions CSV found"

        csv_path = csv_files[0]
        status_messages.append(f"Found predictions: {csv_path.name}")

        df = pd.read_csv(csv_path, skipinitialspace=True)
        status_messages.append(f"Successfully parsed {len(df)} pockets")

        return df, "\n".join(status_messages)

    except Exception as e:
        error_msg = f"\nError: {str(e)}\n{traceback.format_exc()}"
        return None, "\n".join(status_messages) + error_msg

    finally:
        if driver:
            driver.quit()


# ============== Protein Processing Functions ==============

class ProteinSelector(Select):
    """PDBIO selector that keeps standard amino-acid residues.

    Every other residue (waters, ions, ligands, non-standard residues) is
    dropped unless its stripped residue name appears in ``keep_residues``.
    """

    def __init__(self, keep_residues=None):
        # Extra residue names to keep besides standard amino acids.
        self.keep_residues = keep_residues or []

    def accept_residue(self, residue):
        """Return True when *residue* should be written out."""
        return bool(
            is_aa(residue, standard=True)
            or residue.get_resname().strip() in self.keep_residues
        )


def prepare_protein_for_docking(pdb_path, ph=7.4):
    """Prepare a receptor PDBQT for docking and visualise the hydrogenated protein.

    Steps:
      1. Keep only standard amino-acid residues (``ProteinSelector``). This
         drops waters, ions and ligands, and writes ``cleaned_protein.pdb``.
      2. Add hydrogens with OpenBabel at the given pH and write
         ``protein_with_H.pdb``.
      3. Assign Gasteiger charges (when available) and write
         ``protein_prepared.pdbqt`` as a rigid molecule, with no ROOT/BRANCH
         torsion tree, which is the form a docking receptor should have.

    All outputs go into ``prepared_protein_meeko/``, which is cleared first.

    Args:
        pdb_path: Input PDB file.
        ph: pH used when adding hydrogens (default 7.4).

    Returns:
        ``(figure, pdbqt_path, status_log)``. The figure and path are
        ``None`` on failure.
    """
    if not pdb_path or not os.path.exists(pdb_path):
        return None, None, "No PDB file available for processing"

    status_messages = []
    output_dir = Path("prepared_protein_meeko")

    if output_dir.exists():
        try:
            shutil.rmtree(output_dir)
            status_messages.append("Cleared old protein preparation results")
        except Exception as e:
            status_messages.append(f"Warning: Could not clear directory - {e}")

    output_dir.mkdir(exist_ok=True)

    try:
        # Step 1: strip everything that is not a standard amino acid.
        status_messages.append("Step 1: Removing water molecules and other non-protein residues...")
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", pdb_path)

        cleaned_pdb = output_dir / "cleaned_protein.pdb"
        pdb_writer = PDBIO()
        pdb_writer.set_structure(structure)
        pdb_writer.save(str(cleaned_pdb), ProteinSelector())
        status_messages.append("✓ Cleaned protein saved")

        # Step 2: add hydrogens with OpenBabel.
        status_messages.append("Step 2: Adding hydrogen atoms...")
        try:
            from openbabel import openbabel as ob
        except ImportError:
            status_messages.append("ERROR: OpenBabel not installed")
            status_messages.append("Install with: pip install openbabel-wheel")
            return None, None, "\n".join(status_messages)

        hydrogenated_pdb = output_dir / "protein_with_H.pdb"

        conversion = ob.OBConversion()
        conversion.SetInAndOutFormats("pdb", "pdb")

        mol = ob.OBMol()
        # ReadFile/WriteFile report failure through their return value, not
        # an exception; surface it instead of writing empty files.
        if not conversion.ReadFile(mol, str(cleaned_pdb)):
            raise RuntimeError(f"OpenBabel could not read {cleaned_pdb}")
        # Arguments: polar-only=False (add all H), correct-for-pH=True, pH.
        mol.AddHydrogens(False, True, ph)
        if not conversion.WriteFile(mol, str(hydrogenated_pdb)):
            raise RuntimeError(f"OpenBabel could not write {hydrogenated_pdb}")

        status_messages.append("✓ Hydrogens added")

        # Step 3: convert to a rigid receptor PDBQT.
        status_messages.append("Step 3: Converting to PDBQT format...")

        final_pdbqt = output_dir / "protein_prepared.pdbqt"

        conversion.SetInAndOutFormats("pdb", "pdbqt")
        # 'r' = write as a rigid molecule (no ROOT/BRANCH torsion tree).
        conversion.AddOption("r", ob.OBConversion.OUTOPTIONS)

        mol2 = ob.OBMol()
        if not conversion.ReadFile(mol2, str(hydrogenated_pdb)):
            raise RuntimeError(f"OpenBabel could not read {hydrogenated_pdb}")
        # NOTE(review): hydrogens were already added in step 2; this second
        # call is kept from the original pipeline — confirm it is needed.
        mol2.AddHydrogens()

        charge_model = ob.OBChargeModel.FindType("gasteiger")
        if charge_model:
            charge_model.ComputeCharges(mol2)
            status_messages.append("✓ Gasteiger charges assigned")

        if not conversion.WriteFile(mol2, str(final_pdbqt)):
            raise RuntimeError(f"OpenBabel could not write {final_pdbqt}")
        status_messages.append("✓ PDBQT file created")

        # Visualise the hydrogenated (pre-PDBQT) structure.
        with open(hydrogenated_pdb, 'r') as f:
            h_content = f.read()

        atoms_data = parse_pdb_structure(h_content)
        fig = create_protein_visualization(
            atoms_data,
            style="ball-and-stick",
            show_backbone=True,
            title="Protein with Hydrogens"
        )

        status_messages.append("✓ Protein preparation complete!")
        status_messages.append(f"Final PDBQT: {final_pdbqt}")

        return fig, str(final_pdbqt), "\n".join(status_messages)

    except Exception as e:
        import traceback
        error_msg = f"\nError: {str(e)}\n{traceback.format_exc()}"
        return None, None, "\n".join(status_messages) + error_msg


# ============== Main Functions ==============

def search_and_visualize_protein(protein_name, style="sphere", show_backbone=True):
    """Step 1 of the pipeline: reset outputs, search RCSB, download and plot.

    Returns:
        A 5-tuple ``(figure, status_text, pdb_path, fasta_path, pdb_text)``.
        When anything fails, the figure and the last three values are
        ``None`` and ``status_text`` holds the accumulated log.
    """
    if not (protein_name and protein_name.strip()):
        return None, "Please enter a protein name", None, None, None

    log = [f"Searching for: {protein_name}"]

    def failure():
        return None, "\n".join(log), None, None, None

    # Remove every previous run's outputs so later steps never mix proteins.
    stale_folders = ("proteins", "my_analysis_folder", "prankweb_results", "prepared_protein_meeko")
    for folder in stale_folders:
        if not os.path.exists(folder):
            continue
        try:
            shutil.rmtree(folder)
        except Exception as e:
            log.append(f"Warning: Could not clear '{folder}' - {e}")
        else:
            log.append(f"Cleared old data in '{folder}'")

    pdb_id, search_msg = search_pdb_for_first_hit(protein_name)
    log.append(search_msg)
    if not pdb_id:
        return failure()

    log.append(f"Downloading structure data for {pdb_id}")
    pdb_content, pdb_msg, pdb_path = download_pdb_content(pdb_id)
    log.append(pdb_msg)
    if not pdb_content:
        return failure()

    # A missing FASTA is logged but not fatal at this stage.
    _fasta_text, fasta_msg, fasta_path = download_fasta_content(pdb_id)
    log.append(fasta_msg)

    try:
        log.append("Parsing structure and generating visualization")
        atoms_data = parse_pdb_structure(pdb_content, pdb_id)
        fig = create_protein_visualization(atoms_data, style, show_backbone, "Original Structure")
        log.append(f"Visualization complete - {len(atoms_data)} atoms displayed")
    except Exception as e:
        log.append(f"Error: {str(e)}")
        return failure()

    return fig, "\n".join(log), pdb_path, fasta_path, pdb_content


def _find_swiss_model_pdb(original_pdb_path, folder="proteins"):
    """Return the newest ``.pdb`` in *folder* other than *original_pdb_path*, or None.

    run_swiss_model saves its model into the same folder as the downloaded
    entry. Taking the newest file that is not the original identifies the
    model regardless of the order in which the filesystem lists the folder.
    """
    original = Path(original_pdb_path).resolve() if original_pdb_path else None
    candidates = [p for p in Path(folder).glob("*.pdb") if p.resolve() != original]
    if not candidates:
        return None
    return str(max(candidates, key=lambda p: p.stat().st_mtime))


def generate_ramplot_and_process(pdb_path, fasta_path, original_pdb_content, api_token, style, show_backbone):
    """Run RAMPlot on the downloaded structure and choose the final model.

    If the RAMPlot "Favoured" percentage is at least 90%, the original
    structure is kept. Otherwise a SWISS-MODEL homology model is built from
    the FASTA sequence and used instead.

    Returns:
        An 8-tuple ``(map_2d, map_3d, std_2d, std_3d, analysis_log,
        final_figure, final_status, final_pdb_path)``. The figure and path
        are ``None`` when no final structure could be produced.
    """
    if not pdb_path:
        return None, None, None, None, "Please download a protein structure first", None, "No structure available", None

    status_messages = ["Running RAMPlot analysis..."]

    map_2d, map_3d, std_2d, std_3d, ramplot_status = run_ramplot()
    status_messages.append(ramplot_status)

    def finish(fig, final_status, final_path):
        return (map_2d, map_3d, std_2d, std_3d, "\n".join(status_messages),
                fig, final_status, final_path)

    csv_path = "my_analysis_folder/Analysis.csv"
    favoured_percent = extract_favoured_percentage(csv_path)

    if favoured_percent is None:
        status_messages.append("Could not extract Favoured percentage from Analysis.csv")
        return finish(None, "Analysis incomplete", None)

    status_messages.append(f"Favoured percentage: {favoured_percent}%")

    if favoured_percent >= 90.0:
        status_messages.append("Favoured percentage >= 90%. Retaining original structure.")
        try:
            atoms_data = parse_pdb_structure(original_pdb_content)
            fig = create_protein_visualization(atoms_data, style, show_backbone, "Final Structure (Original)")
        except Exception as e:
            return finish(None, f"Error: {str(e)}", None)
        return finish(fig, f"Favoured: {favoured_percent}% - Original structure retained", pdb_path)

    status_messages.append("Favoured percentage < 90%. Running SWISS-MODEL...")

    if not api_token or "YOUR_API_TOKEN" in api_token:
        status_messages.append("Error: Please provide a valid SWISS-MODEL API token")
        return finish(None, "API token required", None)

    # The FASTA download in step 1 may have failed; SWISS-MODEL needs it.
    if not fasta_path:
        status_messages.append("Error: No FASTA sequence available for SWISS-MODEL")
        return finish(None, "FASTA sequence required", None)

    new_pdb_content, swiss_msg = run_swiss_model(fasta_path, api_token)
    status_messages.append(swiss_msg)

    if not new_pdb_content:
        return finish(None, "SWISS-MODEL failed", None)

    try:
        # glob() order is arbitrary, so "last file" could be the original PDB;
        # pick the newest file that is not the original instead.
        new_pdb_path = _find_swiss_model_pdb(pdb_path)

        atoms_data = parse_pdb_structure(new_pdb_content)
        fig = create_protein_visualization(atoms_data, style, show_backbone, "Final Structure (SWISS-MODEL)")
    except Exception as e:
        return finish(None, f"Error: {str(e)}", None)

    return finish(fig, f"Favoured: {favoured_percent}% - New structure from SWISS-MODEL", new_pdb_path)


# ============== Gradio Interface ==============

# Build the five-step Gradio UI. Components are placed on the page in the
# order they are created inside the Blocks/Row/Column context managers, so
# statement order here defines the layout.
with gr.Blocks(title="Protein Structure Viewer", theme=gr.themes.Default()) as demo:
    gr.Markdown("# Protein Structure Viewer & Analysis Pipeline")
    gr.Markdown("Complete protein analysis workflow: Search → RAMPlot → PrankWeb → Protein Preparation")
    
    # Hidden state variables
    # Per-session values passed between steps:
    #   pdb_path_state / fasta_path_state / pdb_content_state are set by the
    #   Step 1 search; final_pdb_path_state is set by the Step 2 RAMPlot step
    #   and consumed by PrankWeb (Step 4) and docking preparation (Step 5).
    pdb_path_state = gr.State(None)
    fasta_path_state = gr.State(None)
    pdb_content_state = gr.State(None)
    final_pdb_path_state = gr.State(None)
    
    # Section 1: Protein Search and Visualization
    gr.Markdown("## Step 1: Search and Download Protein")
    
    with gr.Row():
        with gr.Column(scale=1):
            protein_input = gr.Textbox(
                label="Protein Name",
                placeholder="Enter protein name (e.g., hemoglobin, insulin)",
                lines=1
            )
            
            # Style and backbone choices are also reused by Step 2's final plot.
            style_dropdown = gr.Dropdown(
                choices=["sphere", "stick", "ball-and-stick"],
                value="sphere",
                label="Visualization Style"
            )
            
            backbone_checkbox = gr.Checkbox(
                label="Show Backbone",
                value=True
            )
            
            search_btn = gr.Button("🔍 Search & Download", variant="primary", size="lg")
            
            status_output = gr.Textbox(
                label="Status",
                lines=8,
                interactive=False
            )
    
        with gr.Column(scale=2):
            plot_output = gr.Plot(label="Original 3D Structure")
    
    # Section 2: RAMPlot Analysis
    gr.Markdown("## Step 2: RAMPlot Quality Analysis")
    
    with gr.Row():
        with gr.Column(scale=1):
            # Masked input; only used when SWISS-MODEL has to run.
            api_token_input = gr.Textbox(
                label="SWISS-MODEL API Token (required if Favoured < 90%)",
                placeholder="Enter your API token",
                type="password",
                lines=1
            )
            
            ramplot_btn = gr.Button("📊 Generate RAMPlot & Process", variant="primary", size="lg")
            
            ramplot_status = gr.Textbox(
                label="Analysis Status",
                lines=10,
                interactive=False
            )
        
        with gr.Column(scale=2):
            # 2x2 grid for the four PNGs returned by run_ramplot.
            with gr.Row():
                plot_map_2d = gr.Image(label="MapType 2D All")
                plot_map_3d = gr.Image(label="MapType 3D All")
            with gr.Row():
                plot_std_2d = gr.Image(label="StdMapType 2D General Gly")
                plot_std_3d = gr.Image(label="StdMapType 3D General")
    
    # Section 3: Final Structure
    gr.Markdown("## Step 3: Final Structure")
    
    with gr.Row():
        with gr.Column(scale=1):
            final_status_output = gr.Textbox(
                label="Final Status",
                lines=3,
                interactive=False
            )
        
        with gr.Column(scale=2):
            final_plot_output = gr.Plot(label="Final 3D Structure")
    
    # Section 4: PrankWeb Pocket Analysis
    gr.Markdown("## Step 4: PrankWeb Pocket Prediction")
    
    with gr.Row():
        with gr.Column(scale=1):
            prankweb_btn = gr.Button("🔬 Run PrankWeb Analysis", variant="primary", size="lg")
            
            prankweb_status = gr.Textbox(
                label="PrankWeb Status",
                lines=8,
                interactive=False
            )
        
        with gr.Column(scale=2):
            # Column headers preset for the predictions CSV; presumably they
            # match PrankWeb's output columns — verify against a real file.
            prankweb_output = gr.Dataframe(
                label="Pocket Predictions",
                headers=["name", "rank", "score", "probability", "sas_points", 
                         "surf_atoms", "center_x", "center_y", "center_z", 
                         "residue_ids", "surf_atom_ids"],
                interactive=False
            )
    
    # Section 5: Protein Preparation for Docking
    gr.Markdown("## Step 5: Protein Preparation for Molecular Docking")
    
    with gr.Row():
        with gr.Column(scale=1):
            prep_btn = gr.Button("⚗️ Prepare Protein for Docking", variant="primary", size="lg")
            
            prep_status = gr.Textbox(
                label="Preparation Status",
                lines=10,
                interactive=False
            )
            
            pdbqt_path_output = gr.Textbox(
                label="PDBQT File Path",
                lines=1,
                interactive=False
            )
        
        with gr.Column(scale=2):
            prep_plot_output = gr.Plot(label="Protein with Hydrogens")
    
    # Connect buttons
    # Each handler's return tuple is mapped positionally onto `outputs`, so
    # the order of the output components must match the function's tuple.
    search_btn.click(
        fn=search_and_visualize_protein,
        inputs=[protein_input, style_dropdown, backbone_checkbox],
        outputs=[plot_output, status_output, pdb_path_state, fasta_path_state, pdb_content_state]
    )
    
    ramplot_btn.click(
        fn=generate_ramplot_and_process,
        inputs=[pdb_path_state, fasta_path_state, pdb_content_state, api_token_input, 
                style_dropdown, backbone_checkbox],
        outputs=[plot_map_2d, plot_map_3d, plot_std_2d, plot_std_3d, 
                 ramplot_status, final_plot_output, final_status_output, final_pdb_path_state]
    )
    
    # Steps 4 and 5 both operate on the final structure chosen in Step 2.
    prankweb_btn.click(
        fn=run_prankweb_analysis,
        inputs=[final_pdb_path_state],
        outputs=[prankweb_output, prankweb_status]
    )
    
    prep_btn.click(
        fn=prepare_protein_for_docking,
        inputs=[final_pdb_path_state],
        outputs=[prep_plot_output, pdbqt_path_output, prep_status]
    )

if __name__ == "__main__":
    # share=True also publishes the app on a temporary public gradio.live URL
    # (see the run output below); anyone with the link can use it.
    demo.launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://7eafd585d3c1f45d6f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


  Failed to kekulize aromatic bonds in OBMol::PerceiveBondOrders (title is prepared_protein_meeko/cleaned_protein.pdb)

