# Compound Triage Dashboard

Interactive dashboard for evaluating compounds using Lipinski's Rule of Five. Works as both a Jupyter notebook and a Voila web app.


In [None]:
import os
import base64

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import pandas as pd
import plotly.express as px

from notebook_utils import get_api_client, demo_mode_banner

# Resolve the API endpoint from the environment; API_URL takes precedence
# over AMPRENTA_API_URL. get_api_client hands back the client and a demo flag.
api_url = os.getenv("API_URL") or os.getenv("AMPRENTA_API_URL")
client, demo_mode = get_api_client(api_url=api_url)

# Tell the user which mode the notebook ended up in
if not demo_mode and client is not None:
    print(f"Connected to {getattr(client, 'api_url', api_url)}")
else:
    demo_mode_banner()
    print("API unavailable — demo mode enabled.")


In [None]:
# Load compounds (demo fallback if API unavailable)
from dataclasses import dataclass


@dataclass
class _DemoCompound:
    """Stand-in compound record used when the dashboard runs in demo mode.

    Carries the attribute names the rest of the notebook reads from each
    compound. Every descriptor defaults to ``None``, which the scoring
    helpers treat as "no data".
    """

    # Identifier shown in the table, the plot hover text and the CSV export
    compound_id: str
    # Molecular weight; scored by score_mw and used by the MW range filter
    molecular_weight: float | None = None
    # LogP; scored by score_logp and used by the LogP range filter
    logp: float | None = None
    # Hydrogen-bond donor count; scored by score_hbd
    hbd_count: int | None = None
    # Hydrogen-bond acceptor count; scored by score_hba
    hba_count: int | None = None
    # Rotatable bond count; scored by score_rotatable_bonds
    rotatable_bonds: int | None = None


# Populate `compounds`: the API listing when connected, otherwise three
# built-in demo records.
try:
    if not demo_mode and client is not None:
        compounds = client.compounds.list()
    else:
        compounds = [
            _DemoCompound(
                compound_id="DEMO-001",
                molecular_weight=320.4,
                logp=2.1,
                hbd_count=1,
                hba_count=4,
                rotatable_bonds=5,
            ),
            _DemoCompound(
                compound_id="DEMO-002",
                molecular_weight=520.0,
                logp=4.7,
                hbd_count=3,
                hba_count=9,
                rotatable_bonds=11,
            ),
            _DemoCompound(
                compound_id="DEMO-003",
                molecular_weight=680.2,
                logp=6.1,
                hbd_count=2,
                hba_count=12,
                rotatable_bonds=16,
            ),
        ]

    print(f"Loaded {len(compounds)} compounds")
except Exception as load_error:
    # Any failure leaves the dashboard with an empty list instead of
    # aborting the notebook run.
    compounds = []
    print(f"Error loading compounds: {load_error!r}")


In [None]:
def score_mw(mw):
    """Return the ``(symbol, colour)`` traffic light for a molecular weight.

    Below 500 is green, 500 to 600 inclusive is orange, anything else is
    red. A missing value yields ``(None, "gray")``.
    """
    if mw is None:
        return None, "gray"

    # Start from the worst band and upgrade when a threshold is met
    verdict = ("✗", "red")
    if mw < 500:
        verdict = ("✓", "green")
    elif mw <= 600:
        verdict = ("⚠", "orange")
    return verdict

def score_logp(logp):
    """Return the ``(symbol, colour)`` traffic light for a LogP value.

    Up to 3 inclusive is green, above 3 up to 5 inclusive is orange,
    anything else is red. A missing value yields ``(None, "gray")``.
    """
    if logp is None:
        return None, "gray"
    return (
        ("✓", "green") if logp <= 3
        else ("⚠", "orange") if logp <= 5
        else ("✗", "red")
    )

def score_hbd(hbd):
    """Return the ``(symbol, colour)`` traffic light for an H-bond donor count.

    Up to 5 is green, 6 or 7 is orange, more than 7 is red. A missing
    value yields ``(None, "gray")``.
    """
    if hbd is None:
        return None, "gray"
    if hbd <= 5:
        return "✓", "green"
    if hbd <= 7:
        return "⚠", "orange"
    return "✗", "red"

def score_hba(hba):
    """Return the ``(symbol, colour)`` traffic light for an H-bond acceptor count.

    Up to 10 is green, 11 or 12 is orange, more than 12 is red. A missing
    value yields ``(None, "gray")``.
    """
    if hba is None:
        return None, "gray"

    # Bands in ascending order of their inclusive upper limit
    bands = ((10, ("✓", "green")), (12, ("⚠", "orange")))
    for upper_limit, verdict in bands:
        if hba <= upper_limit:
            return verdict
    return "✗", "red"

def score_rotatable_bonds(rot_bonds):
    """Return the ``(symbol, colour)`` traffic light for a rotatable bond count.

    Up to 10 is green, 11 to 15 is orange, more than 15 is red. A missing
    value yields ``(None, "gray")``.
    """
    if rot_bonds is None:
        return None, "gray"

    if rot_bonds <= 10:
        symbol, color = "✓", "green"
    elif rot_bonds <= 15:
        symbol, color = "⚠", "orange"
    else:
        symbol, color = "✗", "red"
    return symbol, color

def calculate_lipinski_score(compound):
    """Count how many of the five drug-likeness criteria a compound passes.

    Criteria checked:

    * molecular weight < 500
    * LogP <= 5 (an upper limit only, so negative LogP values pass)
    * H-bond donors <= 5
    * H-bond acceptors <= 10
    * rotatable bonds <= 10

    A property whose value is ``None`` never counts as passed.

    Args:
        compound: Object exposing ``molecular_weight``, ``logp``,
            ``hbd_count``, ``hba_count`` and ``rotatable_bonds`` attributes.

    Returns:
        int: Number of criteria satisfied, from 0 to 5.
    """
    # (value, pass-predicate) pairs; every property gets the same explicit
    # None check instead of relying on truthiness.
    checks = (
        (compound.molecular_weight, lambda value: value < 500),
        (compound.logp, lambda value: value <= 5),
        (compound.hbd_count, lambda value: value <= 5),
        (compound.hba_count, lambda value: value <= 10),
        (compound.rotatable_bonds, lambda value: value <= 10),
    )
    return sum(
        1 for value, passes in checks if value is not None and passes(value)
    )


In [None]:
# Containers for the rendered table, the scatter plot and the CSV download link
output_area = widgets.Output()
plot_output = widgets.Output()
export_link = widgets.HTML("")

# Known MW / LogP values across the loaded compounds; they set the slider bounds
mw_values = [
    c.molecular_weight
    for c in compounds
    if getattr(c, "molecular_weight", None) is not None
]
logp_values = [
    c.logp
    for c in compounds
    if getattr(c, "logp", None) is not None
]

# Use the observed extremes, or fixed defaults when no values are available
if mw_values:
    mw_min, mw_max = float(min(mw_values)), float(max(mw_values))
else:
    mw_min, mw_max = 0.0, 800.0

if logp_values:
    logp_min, logp_max = float(min(logp_values)), float(max(logp_values))
else:
    logp_min, logp_max = -2.0, 8.0

# Both sliders start fully open and only report a value once released
mw_range_slider = widgets.FloatRangeSlider(
    description='MW range:',
    min=mw_min,
    max=mw_max,
    value=(mw_min, mw_max),
    step=10.0,
    continuous_update=False,
    layout=widgets.Layout(width='400px'),
)

logp_range_slider = widgets.FloatRangeSlider(
    description='LogP range:',
    min=logp_min,
    max=logp_max,
    value=(logp_min, logp_max),
    step=0.1,
    continuous_update=False,
    layout=widgets.Layout(width='400px'),
)

def _filter_compounds():
    """Return the compounds whose MW and LogP both lie inside the slider ranges.

    Both ends of each range are inclusive. Compounds missing either value
    are left out of the result.
    """
    (mw_low, mw_high) = mw_range_slider.value
    (logp_low, logp_high) = logp_range_slider.value

    def _within_ranges(candidate):
        weight = getattr(candidate, "molecular_weight", None)
        lipophilicity = getattr(candidate, "logp", None)
        if weight is None or lipophilicity is None:
            return False
        return (
            mw_low <= weight <= mw_high
            and logp_low <= lipophilicity <= logp_high
        )

    return [candidate for candidate in compounds if _within_ranges(candidate)]


def render_triage_table(*_):
    """Redraw the scatter plot, triage table, summary and CSV export link.

    Works on the compounds returned by ``_filter_compounds()``:

    * ``plot_output`` receives an MW vs LogP scatter coloured by Lipinski score.
    * ``output_area`` receives an HTML table with one traffic-light cell per
      property, sorted by Lipinski score (best first), followed by a summary.
    * ``export_link`` receives a base64 ``data:`` link to the filtered rows
      as CSV, or is cleared when nothing matches the filters.

    Values that come from the compound records are HTML-escaped before they
    are placed in the table.

    Args:
        *_: Ignored; accepted so the function can be called with arbitrary
            positional arguments (for example as a callback).
    """
    # Local stdlib import keeps this block self-contained
    from html import escape

    filtered_compounds = _filter_compounds()

    # Update plot
    with plot_output:
        clear_output(wait=True)
        if filtered_compounds:
            df_plot = pd.DataFrame(
                [
                    {
                        "Compound ID": comp.compound_id,
                        "MW": comp.molecular_weight,
                        "LogP": comp.logp,
                        "Lipinski Score": calculate_lipinski_score(comp),
                    }
                    for comp in filtered_compounds
                    if comp.molecular_weight is not None and comp.logp is not None
                ]
            )
            if not df_plot.empty:
                fig = px.scatter(
                    df_plot,
                    x="MW",
                    y="LogP",
                    color="Lipinski Score",
                    hover_data=["Compound ID"],
                    title="Pareto: MW vs LogP",
                )
                fig.update_layout(height=400, margin=dict(l=20, r=20, t=50, b=20))
                display(fig)
            else:
                print("No numeric data available for plot.")
        else:
            print("No compounds match current filters.")

    with output_area:
        clear_output(wait=True)

        if not compounds:
            print("No compounds found.")
            return

        if not filtered_compounds:
            # A single message is enough; there is nothing to tabulate or export
            print("No compounds match current filters.")
            export_link.value = ""
            return

        # (value column, score column, compound attribute, scoring function)
        property_columns = (
            ('MW', 'MW Score', 'molecular_weight', score_mw),
            ('LogP', 'LogP Score', 'logp', score_logp),
            ('HBD', 'HBD Score', 'hbd_count', score_hbd),
            ('HBA', 'HBA Score', 'hba_count', score_hba),
            ('Rot Bonds', 'Rot Score', 'rotatable_bonds', score_rotatable_bonds),
        )

        # Score every property once, keeping the colours next to the display
        # values: cell_colors[i] belongs to triage_data[i].
        triage_data = []
        cell_colors = []
        for comp in filtered_compounds:
            record = {'Compound ID': comp.compound_id or 'N/A'}
            colors = {}
            for value_col, score_col, attr, scorer in property_columns:
                value = getattr(comp, attr)
                symbol, color = scorer(value)
                record[value_col] = value if value is not None else 'N/A'
                record[score_col] = symbol or 'N/A'
                colors[score_col] = color
            record['Lipinski Score'] = calculate_lipinski_score(comp)
            triage_data.append(record)
            cell_colors.append(colors)

        # Sort by Lipinski score (descending). The index labels survive the
        # sort, so each row's label still locates its entry in cell_colors.
        df = pd.DataFrame(triage_data).sort_values('Lipinski Score', ascending=False)

        # Create HTML table with colored cells
        html_rows = [
            '<table border="1" style="border-collapse: collapse; width: 100%;">',
            '<thead><tr>',
        ]
        html_rows.extend(
            f'<th style="padding: 8px; background-color: #f0f0f0;">{col}</th>'
            for col in df.columns
        )
        html_rows.append('</tr></thead><tbody>')

        for idx, row in df.iterrows():
            colors = cell_colors[idx]
            html_rows.append('<tr>')
            html_rows.append(
                f'<td style="padding: 8px;">{escape(str(row["Compound ID"]))}</td>'
            )

            # One plain value cell plus one colour-filled symbol cell per property
            for value_col, score_col, _attr, _scorer in property_columns:
                html_rows.append(
                    f'<td style="padding: 8px;">{escape(str(row[value_col]))}</td>'
                )
                html_rows.append(
                    f'<td style="padding: 8px; background-color: {colors[score_col]}; '
                    f'text-align: center;">{row[score_col]}</td>'
                )

            # Lipinski score column
            score_val = row['Lipinski Score']
            score_color = 'green' if score_val >= 4 else 'orange' if score_val >= 3 else 'red'
            html_rows.append(
                f'<td style="padding: 8px; background-color: {score_color}; '
                f'text-align: center; font-weight: bold;">{score_val}/5</td>'
            )

            html_rows.append('</tr>')

        html_rows.append('</tbody></table>')
        display(HTML(''.join(html_rows)))

        # Summary statistics (total is non-zero here, so the divisions are safe)
        total = len(df)
        high_score = len(df[df['Lipinski Score'] >= 4])
        medium_score = len(df[(df['Lipinski Score'] >= 3) & (df['Lipinski Score'] < 4)])
        low_score = len(df[df['Lipinski Score'] < 3])

        print("\n## Summary")
        print(f"**Total Compounds:** {total}")
        print(f"**High Score (≥4/5):** {high_score} ({high_score/total*100:.1f}%)")
        print(f"**Medium Score (3/5):** {medium_score} ({medium_score/total*100:.1f}%)")
        print(f"**Low Score (<3/5):** {low_score} ({low_score/total*100:.1f}%)")

        # Update export link: the CSV is embedded as a base64 data URI
        csv_data = pd.DataFrame([
            {
                "compound_id": comp.compound_id,
                "molecular_weight": comp.molecular_weight,
                "logp": comp.logp,
                "hbd_count": comp.hbd_count,
                "hba_count": comp.hba_count,
                "rotatable_bonds": comp.rotatable_bonds,
                "lipinski_score": calculate_lipinski_score(comp),
            }
            for comp in filtered_compounds
        ]).to_csv(index=False)
        csv_b64 = base64.b64encode(csv_data.encode('utf-8')).decode('ascii')
        export_link.value = (
            f'<a download="filtered_compounds.csv" '
            f'href="data:text/csv;base64,{csv_b64}" '
            f'style="display:inline-block; margin-top:6px; padding:6px 12px; '
            f'background:#0d6efd; color:#fff; border-radius:4px; text-decoration:none; '
            f'font-weight:600; font-family:system-ui, -apple-system, sans-serif;">'
            f'⬇️ Export Filtered CSV</a>'
        )

# Wire filters
def _on_filter_change(change):
    """Re-render the dashboard when a change notification is for ``value``."""
    if change.get('name') != 'value':
        return
    render_triage_table()

# Re-render whenever either slider's value changes
mw_range_slider.observe(_on_filter_change, names='value')
logp_range_slider.observe(_on_filter_change, names='value')

# Display UI
# Top to bottom: filter sliders, export link, scatter plot, triage table
display(widgets.HBox([mw_range_slider, logp_range_slider]))
display(export_link)
display(plot_output)
display(output_area)
# Initial render so the dashboard is populated before any slider is moved
render_triage_table()
