In [1]:
!pip install -q biopython
!pip install -q dash jupyter-dash dash-bootstrap-components
from IPython.display import clear_output

import dash
from dash import html, dcc, Input, Output, State, ctx, MATCH, ALL
 

from Bio.PDB import PDBParser, MMCIFParser
import os
import numpy as np
import dash
import dash_bootstrap_components as dbc
import numpy as np
import plotly.graph_objects as go
 

In [2]:



app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# Most form components (PDB input, contig builder, run buttons) are created by
# callbacks after the first render, so callbacks must be allowed to reference
# IDs that are not in the initial layout.
app.config.suppress_callback_exceptions = True
modes = ["Unconditional", "Binder Design", "Motif Scaffolding", "Partial Diffusion"]

app.layout = dbc.Container(
    [
        html.H3("RFdiffusion Generator"),
        dbc.Label("Select Mode"),
        dcc.Dropdown(modes, id="mode", value="Unconditional"),
        html.Hr(),
        # Mode-specific fields, filled in by the `display_fields` callback.
        html.Div(id="form-fields"),
        html.Hr(),
        html.H4("Run Parameters"),
        dbc.Row([
            dbc.Col([
                dbc.Label("Iterations"),
                dcc.Slider(
                    id="iterations",
                    min=25,
                    max=200,
                    step=25,
                    marks={i: str(i) for i in range(25, 201, 25)},
                    value=50,
                    tooltip={"placement": "bottom", "always_visible": True}
                ),
            ], width=4),
            dbc.Col([
                dbc.Label("Number of Designs"),
                dcc.Slider(
                    id="num-designs",
                    min=1,
                    max=32,
                    # step=None restricts the slider to the marked powers of two.
                    step=None,
                    marks={2**i: str(2**i) for i in range(6)},
                    value=1,
                    tooltip={"placement": "bottom", "always_visible": True}
                ),
            ], width=4),
            dbc.Col([
                dbc.Label("Visualization"),
                dcc.Dropdown(
                    id="visualization",
                    options=[
                        {"label": "None", "value": "none"},
                        {"label": "Image", "value": "image"},
                        {"label": "Interactive", "value": "interactive"}
                    ],
                    value="image",
                    clearable=False
                ),
            ], width=4),
        ]),
        dbc.Row([
            dbc.Col([
                dbc.Label("Use Beta Model"),
                dbc.Checklist(
                    id="use-beta-model",
                    options=[{"label": "Yes", "value": True}],
                    # Checklist value is the list of checked option values:
                    # [] means off, [True] means on. Starting from [False]
                    # (not an option value) left a truthy stray entry behind.
                    value=[],
                    inline=True,
                    switch=True
                ),
                dbc.Label('if you are seeing lots of helices, switch to the "beta" params for a better SSE balance.' , style={"fontSize": "0.8em"}),
            ], width=4),
        ]),
        html.Hr(),
        # Run button for the current mode, filled in by `display_fields`.
        html.Div(id="runButtonSpace"),

        html.Br(),
        html.Br(),
        html.Div(id="output"),
        dcc.Store(id="contigs-list", data=[]),
        # Chain dicts ({"id", "residues"}) produced by `analyze_pdb`.
        dcc.Store(id="chain-list", data=[]),
        # Contig-builder segments. Every entry carries an "id" equal to its
        # position, the same shape entries added via "Add Segment" have.
        dcc.Store(id="segments-store", data=[{ 'id': 0, 'segment-type':'none', "free-length":1, 'fixed-chain':'', 'fixed-range':'' }]),
    ]
)

def segment_row(index, segment, chainList):
    """Build one editable row of the contig builder.

    Args:
        index: Value used as the ``"index"`` key of every pattern-matching
            component ID in the row.
        segment: Dict that may contain ``segment-type``, ``free-length``,
            ``fixed-chain`` and ``fixed-range``; missing keys fall back to
            ``'none'`` / ``''``.
        chainList: Iterable of dicts with an ``"id"`` key; each id becomes an
            option of the chain dropdown.

    Returns:
        An ``html.Div`` holding a segment-type dropdown, a length input (shown
        only for ``"free"``), a chain dropdown plus residue-range input (shown
        only for ``"fixed"``), and a delete button.
    """
    kind = segment.get('segment-type', 'none')
    length = segment.get('free-length', '')
    chain = segment.get('fixed-chain', '')
    residues = segment.get('fixed-range', '')

    # Inputs irrelevant to the chosen segment type are hidden, not removed.
    free_style = {} if kind == 'free' else {"display": "none"}
    fixed_style = {} if kind == 'fixed' else {"display": "none"}

    type_options = [
        {"label": "Select", "value": "none"},
        {"label": "Free Region", "value": "free"},
        {"label": "Fixed Motif", "value": "fixed"},
        {"label": "Chain Break", "value": "break"},
    ]
    chain_options = [{"label": entry["id"], "value": entry["id"]} for entry in chainList]

    type_col = dbc.Col(
        dcc.Dropdown(
            id={"type": "segment-type", "index": index},
            options=type_options,
            value=kind,
            clearable=False,
        ),
        width=2,
    )
    length_col = dbc.Col(
        dbc.Input(
            id={"type": "free-length", "index": index},
            type="text",
            placeholder="Fixed Length or Random Range (e.g. 100 or 100-250)",
            value=length,
            style=free_style
        ),
        width=3,
    )
    fixed_col = dbc.Col(
        [
            dcc.Dropdown(
                id={"type": "fixed-chain", "index": index},
                options=chain_options,
                placeholder="Chain",
                value=chain,
                style=fixed_style
            ),
            dbc.Input(
                id={"type": "fixed-range", "index": index},
                type="text",
                placeholder="Fixed Residues (e.g. 10-25)",
                value=residues,
                style=fixed_style
            ),
        ],
        width=4,
    )
    delete_col = dbc.Col(
        dbc.Button(
            "Delete",
            id={"type": "delete-segment", "index": index},
            color="danger",
            n_clicks=0
        ),
        width=1,
    )

    row = dbc.Row([type_col, length_col, fixed_col, delete_col], className="mb-2")
    return html.Div(id={"type": "segment-row", "index": index}, children=[row])


def contigBuilder(segments, chainList):
    """Build the contig-builder section of the form.

    Args:
        segments: List of segment dicts; one ``segment_row`` is rendered per
            entry, indexed by position.
        chainList: Chain dicts (with ``"id"``) forwarded to each row.

    Returns:
        A one-element list containing a Div with the ``segments-container``
        (header row + segment rows), the "Add Segment" button, and the
        ``contigs-preview`` area.
    """
    header = dbc.Row([
        dbc.Col(dbc.Label("Segment Type"), width=2),
        dbc.Col(dbc.Label("Generate Length"), width=3),
        dbc.Col(dbc.Label("Fixed Residues"), width=4),
    ])
    rows = [segment_row(position, seg, chainList) for position, seg in enumerate(segments)]
    container = html.Div(id="segments-container", children=[header, *rows])

    add_button = dbc.Button(
        "Add Segment", id="add-segment", color="primary", className="mt-2", n_clicks=0
    )
    preview = html.Div(
        id="contigs-preview",
        style={"fontFamily": "monospace", "fontSize": "1.2em"},
    )

    section = html.Div(
        [
            container,
            add_button,
            html.Hr(),
            html.H5("Current Contigs String:"),
            preview,
        ]
    )
    return [section]
    
@app.callback(
    Output("segments-container", "children"),
    Input("segments-store", "data"),
    Input("chain-list", "data"),
)
def render_segments(segments, chainList):
    """Re-render the header row and one ``segment_row`` per stored segment.

    Runs whenever ``segments-store`` or ``chain-list`` changes. Rows are
    indexed by their position in ``segments``.
    """
    header = dbc.Row([
        dbc.Col(dbc.Label("Segment Type"), width=2),
        dbc.Col(dbc.Label("Generate Length"), width=3),
        dbc.Col(dbc.Label("Fixed Residues"), width=4),
    ])
    children = [header]
    children.extend(segment_row(pos, seg, chainList) for pos, seg in enumerate(segments))
    return children
    
def _build_contig_string(segments):
    """Return the RFdiffusion contig string described by ``segments``.

    Within one chain, segments are joined with ``"/"``. A ``"break"`` segment
    starts a new chain, and chains are joined with RFdiffusion's chain-break
    token ``"/0 "`` (e.g. ``"A1-150/0 70-100"``). Incomplete free/fixed
    segments are skipped, as are empty chains produced by leading, trailing
    or repeated breaks.
    """
    chains = [[]]
    for seg in segments:
        seg_type = seg.get("segment-type")
        if seg_type == "free" and seg.get("free-length"):
            chains[-1].append(str(seg["free-length"]))
        elif seg_type == "fixed" and seg.get("fixed-chain") and seg.get("fixed-range"):
            chains[-1].append(f"{seg['fixed-chain']}{seg['fixed-range']}")
        elif seg_type == "break":
            chains.append([])
    return "/0 ".join("/".join(parts) for parts in chains if parts)


@app.callback(
    Output("segments-store", "data"),
    Output("add-segment", "n_clicks"),
    Output({"type": "delete-segment", "index": ALL}, "n_clicks"),
    Output("contigs-preview", "children"),
    [
        Input("add-segment", "n_clicks"),
        Input({"type": "delete-segment", "index": ALL}, "n_clicks"),
        Input({"type": "segment-type", "index": ALL}, "value"),
        Input({"type": "free-length", "index": ALL}, "value"),
        Input({"type": "fixed-chain", "index": ALL}, "value"),
        Input({"type": "fixed-range", "index": ALL}, "value"),
    ],
    [State("segments-store", "data")],
    prevent_initial_call=True,
)
def update_segments(
    add_clicks,
    delete_clicks,
    seg_types,
    free_lengths,
    fixed_chains,
    fixed_ranges,
    segments,
):
    """Sync ``segments-store`` with the contig-builder UI and refresh the preview.

    Depending on which input fired:

    * a delete button: remove the segment at that row's index;
    * "Add Segment": append an empty segment;
    * any field: copy the current field values into the stored segments.

    Every segment's ``"id"`` is then reset to its position.

    Returns:
        ``(segments, 0, [0] * len(delete_clicks), contig_string)``. The click
        counters are reset so the next click is detected again.
    """
    triggered = ctx.triggered_id
    segments = list(segments or [])

    if (
        isinstance(triggered, dict)
        and triggered.get("type") == "delete-segment"
        # any() tolerates None click counts, which np.sum did not.
        and any(delete_clicks)
    ):
        # Rows are rendered with index = position in the list, so delete by
        # position. Matching on s["id"] raised KeyError for entries without
        # an "id" key, such as the initial store entry.
        segments = [s for pos, s in enumerate(segments) if pos != triggered["index"]]
    elif triggered == "add-segment" and add_clicks:
        segments.append({
            "segment-type": "none",
            "free-length": "",
            "fixed-chain": "",
            "fixed-range": "",
        })
    else:
        # zip() guards against the store and the rendered inputs briefly
        # having different lengths; indexing by position raised IndexError.
        for seg, seg_type, free_len, chain, residues in zip(
            segments, seg_types, free_lengths, fixed_chains, fixed_ranges
        ):
            seg["segment-type"] = seg_type
            seg["free-length"] = free_len
            seg["fixed-chain"] = chain
            seg["fixed-range"] = residues

    for pos, seg in enumerate(segments):
        seg["id"] = pos

    return segments, 0, [0] * len(delete_clicks), _build_contig_string(segments)

@app.callback(
    Output("form-fields", "children"),
    Output("runButtonSpace", "children"),
    Input("mode", "value"),
    State("segments-store", "data"),
    State("chain-list", "data")
)
def display_fields(mode, segments, chainList):
    """Build the form fields and the run button for the selected mode.

    Args:
        mode: Value of the ``mode`` dropdown. It may be ``None`` because that
            dropdown does not set ``clearable=False``.
        segments: Current ``segments-store`` data, forwarded to the contig
            builder.
        chainList: Current ``chain-list`` data (chain dicts with ``"id"``).

    Returns:
        ``(fields, run_button)``. ``run_button`` is ``None`` (nothing
        rendered) when ``mode`` is not one of the known modes.
    """
    # Each mode has its own run button ID so that each run callback only
    # fires for its own mode.
    run_button_ids = {
        "Unconditional": "run-button-unconditional",
        "Binder Design": "run-button-binder-design",
        "Motif Scaffolding": "run-button-motif-scaffolding",
        "Partial Diffusion": "run-button-partial-diffusion",
    }

    base_fields = []
    if mode in ("Binder Design", "Motif Scaffolding", "Partial Diffusion"):
        base_fields.extend(
            [
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Label("PDB File"),
                                dcc.Input(
                                    id="pdb",
                                    type="text",
                                    placeholder="path/to/file.pdb",
                                    # Start empty: a developer's local path is
                                    # meaningless on any other machine.
                                    value="",
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                html.Button(
                                    "Analyze PDB", id="analyze-pdb", n_clicks=0
                                ),
                            ]
                        ),
                    ]
                ),
                html.Hr(),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                html.Div(id="pdb-analysis-output"),
                            ]
                        ),
                    ]
                ),
            ]
        )

    if mode in ("Binder Design", "Motif Scaffolding"):
        base_fields.extend(contigBuilder(segments, chainList))
    elif mode != "Partial Diffusion":
        base_fields.append(
            dbc.Row(
                [
                    dbc.Col(
                        [dbc.Label("Sequence Length (Residues or Length Range)"), dbc.Input(id="contigs", placeholder="100 or 50-200", type="text")]
                    ),
                ]
            ),
        )

    if mode == "Binder Design":
        base_fields.append(
            dbc.Row(
                [dbc.Col([dbc.Label("Hotspot"), dbc.Input(id="hotspot", type="text")])]
            )
        )

    if mode == "Partial Diffusion":
        base_fields.append(
            dbc.Row(
                [dbc.Col([
                    dbc.Label("partial_T"),
                    dcc.Dropdown(
                        id="partial_T",
                        options=[
                            {"label": "Auto", "value": "auto"},
                            {"label": "10", "value": "10"},
                            {"label": "20", "value": "20"},
                            {"label": "40", "value": "40"},
                            {"label": "60", "value": "60"},
                            {"label": "80", "value": "80"},
                        ],
                        value="auto",
                        clearable=False
                    ),
                    dbc.Label("Specify number of noising steps", style={"fontSize": "0.5em", "color": "gray"}),
                ])]
            )
        )

    if mode == "Unconditional":
        base_fields.append(
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Label("Symmetry"),
                            dcc.Dropdown(
                                ["none", "auto", "cyclic", "dihedral"],
                                id="symmetry",
                                value="none",
                            ),
                            dbc.Label("symmetry='auto' enables automatic symmetry detection with AnAnaS", style={"fontSize": "0.5em", "color": "gray"}),
                        ]
                    ),
                    dbc.Col(
                        [
                            dbc.Label("Order"),
                            dcc.Dropdown(
                                [str(i) for i in range(1, 13)],
                                id="order",
                                value="1",
                            ),
                        ]
                    ),
                ]
            )
        )

    # A lookup instead of an if/elif chain: an unknown or cleared mode yields
    # no button rather than an UnboundLocalError. The original trailing
    # commas also wrapped each button in a 1-tuple.
    button_id = run_button_ids.get(mode)
    runButton = (
        dbc.Button("Run RFdiffusion", id=button_id, color="primary")
        if button_id is not None
        else None
    )

    return base_fields, runButton

@app.callback(
    [Output("pdb-analysis-output", "children"),
     Output("chain-list", "data")],
    [Input("analyze-pdb", "n_clicks")],
    [State("pdb", "value")],
    prevent_initial_call=True,
)
def analyze_pdb(n_clicks, pdb_path):
    """Parse a PDB/mmCIF file and summarise its chains.

    Args:
        n_clicks: Click count of the "Analyze PDB" button.
        pdb_path: Filesystem path from the ``pdb`` input. Paths ending in
            ``.cif`` or ``.mmcif`` (case-insensitive) use ``MMCIFParser``;
            all others use ``PDBParser``.

    Returns:
        ``(children, chain_details)``. ``children`` is the analysis markup
        (summary, CA-trace plot, chain list) or a red error panel.
        ``chain_details`` is a list of ``{"id": chain_id, "residues": count}``
        dicts for the first model, or ``[]`` on error / missing input.
        Any exception during parsing or analysis is shown in the error panel
        rather than raised.
    """
    if not n_clicks or not pdb_path:
        return "", []
    try:
        if pdb_path.lower().endswith((".cif", ".mmcif")):
            parser = MMCIFParser()
        else:
            parser = PDBParser()

        structure = parser.get_structure("structure", pdb_path)

        num_models = len(structure)
        # Analyse only the first model. Multi-model files (e.g. NMR ensembles)
        # repeat the same chain IDs in every model. Collecting chains across
        # all models inflated the residue total, listed each chain once per
        # model, and then looked the chains up in model 0 anyway.
        first_model = next(iter(structure), None)
        chains = list(first_model) if first_model is not None else []
        chain_ids = [chain.id for chain in chains]

        # Per-chain residue counts; this list is also what `chain-list` stores.
        chain_details = [
            {"id": chain.id, "residues": len(list(chain.get_residues()))}
            for chain in chains
        ]
        total_residues = sum(detail["residues"] for detail in chain_details)

        analysis_result = [
            html.H5("PDB Analysis"),
            html.P(
                f"Models: {num_models} Chains: {len(chains)} ({', '.join(chain_ids)}) Total residues: {total_residues}"
            ),
        ]

        # One CA trace per chain, coloured by position along the chain.
        chain_vis = []
        for chain in chains:
            ca_atoms = [atom for residue in chain for atom in residue if atom.name == 'CA']
            if not ca_atoms:
                continue
            n_ca = len(ca_atoms)
            colors = [
                f'rgb({int(255*i/n_ca)}, {int(255*(1-i/n_ca))}, {int(255*abs(0.5-i/n_ca)*2)})'
                for i in range(n_ca)
            ]
            chain_vis.append(
                go.Scatter3d(
                    x=[atom.coord[0] for atom in ca_atoms],
                    y=[atom.coord[1] for atom in ca_atoms],
                    z=[atom.coord[2] for atom in ca_atoms],
                    mode='lines+markers',
                    name=f'Chain {chain.id}',
                    line=dict(color=colors, width=5),
                    marker=dict(size=3, color=colors),
                )
            )

        if chain_vis:
            fig = go.Figure(data=chain_vis)
            fig.update_layout(
                scene=dict(
                    xaxis=dict(showbackground=False),
                    yaxis=dict(showbackground=False),
                    zaxis=dict(showbackground=False),
                ),
                margin=dict(l=0, r=0, b=0, t=30),
                height=500,
            )
            analysis_result.extend([
                html.H5("Structure Visualization"),
                dcc.Graph(figure=fig),
                html.Hr(),
                html.H5("Chain Details"),
                html.Ul([
                    html.Li(f"Chain {chain['id']}: {chain['residues']} residues")
                    for chain in chain_details
                ])
            ])
        return analysis_result, chain_details
    except Exception as e:
        # UI boundary: surface any parse/analysis failure to the user.
        return (
            html.Div(
                [html.P("Error analyzing PDB file:"), html.Pre(str(e))],
                style={"color": "red"},
            ),
            [],
        )

# Run-button callbacks, one per mode.
# Motif scaffolding
@app.callback(
    Input("run-button-motif-scaffolding", "n_clicks"),
    State("contigs-preview", "children"),
    State("pdb", "value"),
    State("iterations", "value"),
    State("num-designs", "value"),
    State("visualization", "value"),
    State("use-beta-model", "value"),
)
def run_motif_scaffolding(n_clicks, contigs, pdb, iterations, num_designs, visualization, use_beta_model):
    """Print the motif-scaffolding run parameters when the button is clicked.

    The callback declares no Output, so nothing is sent back to the page; the
    parameters are only printed to the server's stdout.

    Args:
        n_clicks: Click count of the motif-scaffolding run button.
        contigs: Text currently shown in ``contigs-preview``.
        pdb: Value of the ``pdb`` path input.
        iterations: Value of the ``iterations`` slider.
        num_designs: Value of the ``num-designs`` slider.
        visualization: Value of the ``visualization`` dropdown.
        use_beta_model: Value list of the ``use-beta-model`` checklist.
    """
    if not n_clicks:
        # No Output exists, so return nothing (the other path also returns
        # None) instead of a value Dash has nowhere to send.
        return None

    lines = [
        "Mode = Motif Scaffolding",
        f"Contigs = {contigs}",
        f"PDB = {pdb}",
        f"Iterations = {iterations}",
        f"Number of Designs = {num_designs}",
        f"Visualization = {visualization}",
        f"Use Beta Model = {use_beta_model}",
    ]
    print(f"Motif Scaffolding run with parameters: {lines}")
    
# Unconditional
@app.callback(
    Input("run-button-unconditional", "n_clicks"),
    State("contigs", "value"),
    State("symmetry", "value"),
    State("order", "value"),
    State("iterations", "value"),
    State("num-designs", "value"),
    State("visualization", "value"),
    State("use-beta-model", "value"),
)
def run_unconditional(n_clicks, contigs, symmetry, order, iterations, num_designs, visualization, use_beta_model):
    """Print the unconditional-generation run parameters when the button is clicked.

    The callback declares no Output, so nothing is sent back to the page; the
    parameters are only printed to the server's stdout.

    Args:
        n_clicks: Click count of the unconditional run button.
        contigs: Value of the ``contigs`` length input.
        symmetry: Value of the ``symmetry`` dropdown.
        order: Value of the ``order`` dropdown (a string).
        iterations: Value of the ``iterations`` slider.
        num_designs: Value of the ``num-designs`` slider.
        visualization: Value of the ``visualization`` dropdown.
        use_beta_model: Value list of the ``use-beta-model`` checklist.
    """
    if not n_clicks:
        # No Output exists, so return nothing (the other path also returns
        # None) instead of a value Dash has nowhere to send.
        return None

    lines = [
        "Mode = Unconditional",
        f"Contigs = {contigs}",
        f"Symmetry = {symmetry}",
        f"Order = {order}",
        f"Iterations = {iterations}",
        f"Number of Designs = {num_designs}",
        f"Visualization = {visualization}",
        f"Use Beta Model = {use_beta_model}",
    ]
    print(f"Unconditional run with parameters: {lines}")

# Partial diffusion
@app.callback(
    Input("run-button-partial-diffusion", "n_clicks"),
    State("pdb", "value"),
    State("partial_T", "value"),
    State("iterations", "value"),
    State("num-designs", "value"),
    State("visualization", "value"),
    State("use-beta-model", "value"),
)
def run_partial_diffusion(n_clicks, pdb, partial_T, iterations, num_designs, visualization, use_beta_model):
    """Print the partial-diffusion run parameters when the button is clicked.

    The callback declares no Output, so nothing is sent back to the page; the
    parameters are only printed to the server's stdout.

    Args:
        n_clicks: Click count of the partial-diffusion run button.
        pdb: Value of the ``pdb`` path input.
        partial_T: Value of the ``partial_T`` dropdown (``"auto"`` or a
            number as a string).
        iterations: Value of the ``iterations`` slider.
        num_designs: Value of the ``num-designs`` slider.
        visualization: Value of the ``visualization`` dropdown.
        use_beta_model: Value list of the ``use-beta-model`` checklist.
    """
    if not n_clicks:
        # No Output exists, so return nothing (the other path also returns
        # None) instead of a value Dash has nowhere to send.
        return None

    lines = [
        "Mode = Partial Diffusion",
        f"PDB = {pdb}",
        f"Partial T = {partial_T}",
        f"Iterations = {iterations}",
        f"Number of Designs = {num_designs}",
        f"Visualization = {visualization}",
        f"Use Beta Model = {use_beta_model}",
    ]
    print(f"Partial Diffusion run with parameters: {lines}")
 


# Start the server only when executed directly (script or notebook cell, where
# __name__ is "__main__"), not when this module is imported.
if __name__ == "__main__":
    app.run(debug=True)