# HTS Plate Viewer

Interactive 96-well plate visualization for High-Throughput Screening campaigns. Works as both a Jupyter notebook and a Voila web app.


In [None]:
import base64
import html
import io
import os

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import Markdown, clear_output, display

from notebook_utils import demo_mode_banner, get_api_client

# Connect to the screening API. API_URL takes precedence over AMPRENTA_API_URL.
# get_api_client returns a client plus a demo-mode flag; either a set flag or a
# missing client switches the notebook to demo data.
api_url = os.environ.get("API_URL") or os.environ.get("AMPRENTA_API_URL")
client, demo_mode = get_api_client(api_url=api_url)

_use_demo_data = demo_mode or client is None
if not _use_demo_data:
    print(f"Connected to {getattr(client, 'api_url', api_url)}")
else:
    demo_mode_banner()
    print("API unavailable — demo mode enabled.")


def parse_well(well_pos: str) -> tuple:
    """Parse well position string to row/col indices.

    Args:
        well_pos: Well position string (e.g., "A01", "H12")

    Returns:
        Tuple of (row_index, col_index) where A=0, H=7, col 1-12 -> 0-11
    """
    if not well_pos or len(well_pos) < 2:
        return None, None
    row = ord(well_pos[0].upper()) - ord("A")
    col = int(well_pos[1:]) - 1
    return row, col


def calculate_z_prime(pos_controls: list[float], neg_controls: list[float]) -> float | None:
    """Compute Z' factor given positive/negative control values."""
    if not pos_controls or not neg_controls:
        return None
    mu_pos = np.mean(pos_controls)
    mu_neg = np.mean(neg_controls)
    sd_pos = np.std(pos_controls)
    sd_neg = np.std(neg_controls)
    denom = abs(mu_pos - mu_neg)
    if denom == 0:
        return None
    return 1 - (3 * (sd_pos + sd_neg) / denom)


def render_zprime_badge(z_prime: float | None) -> widgets.HTML:
    """Render a pill-shaped, color-coded badge showing the Z' factor.

    Color bands:
        - green for Z' >= 0.5
        - yellow for 0 <= Z' < 0.5
        - red for Z' < 0
        - grey "N/A" when Z' is unavailable

    Args:
        z_prime: Z' factor value, or None when it could not be computed.

    Returns:
        An ``ipywidgets.HTML`` widget containing the styled badge.
    """
    # NaN/inf (e.g. from NaN control readouts) fails every comparison below and
    # would be shown as a red "nan" badge; treat it as unavailable instead.
    if z_prime is None or not np.isfinite(z_prime):
        color, text_color, label = "#6c757d", "#fff", "Z' factor: N/A"
    else:
        label = f"Z' factor: {z_prime:.2f}"
        if z_prime >= 0.5:
            color, text_color = "#28a745", "#fff"
        elif z_prime >= 0:
            # Dark text: white on the yellow background is hard to read.
            color, text_color = "#ffc107", "#212529"
        else:
            color, text_color = "#dc3545", "#fff"
    return widgets.HTML(
        value=(
            f'<span style="display:inline-block; padding:6px 12px; margin-right:12px; '
            f'border-radius:999px; background:{color}; color:{text_color}; '
            f'font-weight:700; font-family:system-ui, -apple-system, sans-serif;">'
            f"{label}</span>"
        )
    )


def make_download_link(label: str, data_url: str, filename: str) -> widgets.HTML:
    """Create a styled HTML anchor that downloads ``data_url`` as ``filename``.

    All three values are HTML-escaped before interpolation. Filenames may
    embed externally supplied identifiers (this notebook builds them from
    campaign IDs), and an unescaped quote or angle bracket would break out of
    the attribute and inject markup.

    Args:
        label: Plain-text link caption (escaped, so it cannot carry markup).
        data_url: ``data:`` URL holding the file contents (e.g. base64 PNG/CSV).
        filename: Suggested download filename.

    Returns:
        An ``ipywidgets.HTML`` widget containing the link.
    """
    safe_label = html.escape(label)
    safe_href = html.escape(data_url, quote=True)
    safe_name = html.escape(filename, quote=True)
    return widgets.HTML(
        value=(
            f'<a download="{safe_name}" href="{safe_href}" target="_blank" rel="noopener" '
            f'style="display:inline-block; margin-right:10px; padding:6px 12px; '
            f'background:#0d6efd; color:#fff; border-radius:4px; text-decoration:none; '
            f'font-weight:600; font-family:system-ui, -apple-system, sans-serif;">'
            f"{safe_label}</a>"
        )
    )


In [None]:
# Load campaigns (demo fallback if API unavailable)
from dataclasses import dataclass


@dataclass
class _DemoCampaign:
    """Offline stand-in for an API campaign record, used when demo mode is active.

    Provides the attributes the plate viewer reads from a campaign: name, ID,
    description, assay type, target, and the well/hit counts used in the
    summary section.
    """

    campaign_id: str
    campaign_name: str
    description: str | None = None
    assay_type: str | None = None
    target: str | None = None
    # Placeholder plate size (one 96-well plate).
    total_wells: int | None = 96
    # NOTE(review): fixed placeholder, not derived from the randomly generated
    # demo hits, so the reported hit count may not match the plotted plate.
    # Confirm whether None (fall back to the number of hits) is intended.
    hit_count: int | None = 12


try:
    # Live API when a client is available; otherwise a single placeholder campaign.
    if not (demo_mode or client is None):
        campaigns = client.screening.list_campaigns()
    else:
        campaigns = [
            _DemoCampaign(
                campaign_id="DEMO-HTS-001",
                campaign_name="Demo HTS Campaign",
                description="Sample campaign shown when API is unavailable.",
                assay_type="biochemical",
                target="DEMO",
            )
        ]
    print(f"Loaded {len(campaigns)} campaigns")
except Exception as e:
    # Any failure leaves an empty list so the rest of the UI still renders.
    campaigns = []
    print(f"Error loading campaigns: {e!r}")

# Dropdown entries are (label, value) pairs. When nothing loaded, a single
# None-valued placeholder keeps the selector from being empty.
campaign_options = (
    [(f"{c.campaign_name} ({c.campaign_id})", c.campaign_id) for c in campaigns]
    if campaigns
    else [("No campaigns found", None)]
)


In [None]:
# Create widgets.
# `campaign_options` (built in the previous cell) already contains a
# "No campaigns found" placeholder when nothing loaded, so it is used as-is
# rather than being overridden with a second, differently worded placeholder.
# The selector is disabled in that case because there is nothing to choose.
campaign_dropdown = widgets.Dropdown(
    options=campaign_options,
    description='Campaign:',
    disabled=not campaigns,
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='500px'),
)

# Everything the viewer renders (header, heatmap, exports, tables) goes here.
output_area = widgets.Output()

# Dimensions of a standard 96-well plate (rows A-H, columns 1-12).
_PLATE_ROWS = 8
_PLATE_COLS = 12


@dataclass
class _DemoHit:
    """Minimal stand-in for an API hit record, used only in demo mode."""

    well_position: str
    compound_id: str | None = None
    normalized_value: float | None = None
    raw_value: float | None = None
    z_score: float | None = None
    hit_category: str | None = None
    control_type: str | None = None


def _generate_demo_hits() -> list:
    """Return a reproducible set of fake hits scattered over a 96-well plate.

    A fixed seed keeps the demo plate identical across selections; each well
    becomes a hit with probability 0.12.
    """
    rng = np.random.default_rng(42)
    demo_hits = []
    for r in range(_PLATE_ROWS):
        for c in range(_PLATE_COLS):
            if rng.random() < 0.12:
                # Keyword arguments evaluate left to right, so the RNG draw
                # order (normal, uniform, normal) and thus the plate is stable.
                demo_hits.append(
                    _DemoHit(
                        well_position=f"{chr(ord('A') + r)}{c + 1:02d}",
                        compound_id=f"DEMO-CMP-{r}{c:02d}",
                        normalized_value=float(rng.normal(0.0, 1.0)),
                        raw_value=float(rng.uniform(0, 1)),
                        z_score=float(rng.normal(2.5, 0.6)),
                        hit_category="hit",
                    )
                )
    return demo_hits


def _show_markdown(text: str) -> None:
    """Render Markdown in the current output (print() would show raw '##'/'**')."""
    display(Markdown(text))


def _campaign_header_markdown(campaign) -> str:
    """Build the Markdown header (name, ID and optional metadata) for a campaign."""
    lines = [f"## {campaign.campaign_name}", f"**Campaign ID:** {campaign.campaign_id}"]
    for label, attr in (
        ("Description", "description"),
        ("Assay Type", "assay_type"),
        ("Target", "target"),
    ):
        value = getattr(campaign, attr, None)
        if value:
            lines.append(f"**{label}:** {value}")
    # Blank lines between entries so each renders as its own paragraph.
    return "\n\n".join(lines)


def _build_plate_matrices(hits) -> tuple:
    """Map hit records onto the plate grid.

    Returns:
        ``(values, is_hit)``:
            - ``values`` is a float matrix holding each well's
              normalized_value (falling back to raw_value; NaN where unknown).
            - ``is_hit`` is a boolean mask of wells that appear in ``hits``.
        Records with a missing, malformed or off-plate well position are skipped.
    """
    values = np.full((_PLATE_ROWS, _PLATE_COLS), np.nan)
    is_hit = np.zeros((_PLATE_ROWS, _PLATE_COLS), dtype=bool)
    for hit in hits:
        if not hit.well_position:
            continue
        try:
            row, col = parse_well(hit.well_position)
        except ValueError:
            # Malformed column part (e.g. "Axx"): skip it rather than abort the plate.
            continue
        if row is None or col is None:
            continue
        if not (0 <= row < _PLATE_ROWS and 0 <= col < _PLATE_COLS):
            continue
        value = hit.normalized_value if hit.normalized_value is not None else hit.raw_value
        if value is not None:
            values[row, col] = value
        is_hit[row, col] = True
    return values, is_hit


def _control_values(hits, control_type: str) -> list:
    """Collect normalized values of control records whose control_type matches."""
    return [
        h.normalized_value
        for h in hits
        if getattr(h, "control_type", None) == control_type and h.normalized_value is not None
    ]


def _hit_to_row(hit) -> dict:
    """Flatten a hit record into a table/CSV row, substituting 'N/A' for missing fields."""

    def or_na(value):
        # Numeric fields: keep legitimate zeros, only replace None.
        return value if value is not None else 'N/A'

    return {
        'Well': hit.well_position or 'N/A',
        'Compound ID': hit.compound_id or 'N/A',
        'Normalized Value': or_na(hit.normalized_value),
        'Raw Value': or_na(hit.raw_value),
        'Z-Score': or_na(hit.z_score),
        'Category': hit.hit_category or 'N/A',
        'Control Type': getattr(hit, "control_type", None) or 'N/A',
    }


def _ranking_value(hit) -> float:
    """Sort key for the top-hits table: normalized_value, else raw_value, else 0.

    NaN is mapped to -inf (sorted last in descending order) because NaN
    comparisons are always False and would leave the sort order ill-defined.
    """
    value = hit.normalized_value if hit.normalized_value is not None else (hit.raw_value or 0)
    value = float(value)
    return float("-inf") if np.isnan(value) else value


def _draw_plate_heatmap(ax, values, is_hit, title: str) -> None:
    """Draw the plate heatmap on ``ax`` and outline hit wells in red."""
    has_values = not np.isnan(values).all()
    sns.heatmap(
        values,
        annot=False,
        fmt='.2f',
        cmap="rocket_r",
        cbar_kws={'label': 'Value'},
        ax=ax,
        # Fall back to a 0-1 range when no well has a value, so nanmin/nanmax aren't called on all-NaN.
        vmin=np.nanmin(values) if has_values else 0,
        vmax=np.nanmax(values) if has_values else 1,
        linewidths=0.5,
        linecolor='gray',
    )
    # Cell (row, col) spans [col, col + 1] x [row, row + 1] in data coordinates.
    for row, col in zip(*np.nonzero(is_hit)):
        ax.add_patch(plt.Rectangle((col, row), 1, 1, fill=False, edgecolor='red', linewidth=3))
    ax.set_xticks(np.arange(_PLATE_COLS) + 0.5)
    ax.set_xticklabels([str(i + 1) for i in range(_PLATE_COLS)])
    ax.set_yticks(np.arange(_PLATE_ROWS) + 0.5)
    ax.set_yticklabels([chr(ord('A') + i) for i in range(_PLATE_ROWS)])
    ax.set_xlabel('Column')
    ax.set_ylabel('Row')
    ax.set_title(title)


def _png_download_link(fig, campaign_id):
    """Return a PNG download link for ``fig``, or None if rendering fails (best-effort)."""
    try:
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    except Exception as exc:
        # Export is optional: report the problem but keep the rest of the view.
        print(f"PNG export unavailable: {exc!r}")
        return None
    png_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return make_download_link(
        "⬇️ Export Heatmap (PNG)",
        f"data:image/png;base64,{png_b64}",
        f"hts_plate_{campaign_id}.png",
    )


def _csv_download_link(df, campaign_id):
    """Return a CSV download link for ``df``, or None when it has no rows."""
    if df.empty:
        return None
    csv_b64 = base64.b64encode(df.to_csv(index=False).encode("utf-8")).decode("ascii")
    return make_download_link(
        "⬇️ Export Hits (CSV)",
        f"data:text/csv;base64,{csv_b64}",
        f"hts_hits_{campaign_id}.csv",
    )


def _summary_markdown(campaign, hits) -> str:
    """Build the summary section: total wells, hit count and hit rate."""
    # A missing or zero well count is unusable as a denominator; fall back to the record count.
    total_wells = getattr(campaign, "total_wells", None) or len(hits)
    # A reported hit_count of 0 is meaningful, so only fall back when it is missing.
    hit_count = getattr(campaign, "hit_count", None)
    if hit_count is None:
        hit_count = len(hits)
    hit_rate = (hit_count / total_wells * 100) if total_wells > 0 else 0
    return "\n\n".join([
        "## Summary",
        f"**Total Wells:** {total_wells}",
        f"**Hit Count:** {hit_count}",
        f"**Hit Rate:** {hit_rate:.2f}%",
    ])


def _render_campaign(campaign_id) -> None:
    """Fetch one campaign (or demo data) and render header, plate, exports, summary and table."""
    if demo_mode or client is None:
        campaign = next((c for c in campaigns if getattr(c, "campaign_id", None) == campaign_id), None)
        hits = _generate_demo_hits()
    else:
        campaign = client.screening.get_campaign(campaign_id)
        hits = client.screening.get_campaign_hits(campaign_id)

    if campaign is None:
        print(f"Campaign {campaign_id!r} not found.")
        return
    # Materialize once (hits is iterated several times) and tolerate a None response.
    hits = list(hits or [])

    _show_markdown(_campaign_header_markdown(campaign))

    # Note: this uses only hits. A full plate needs an endpoint returning all results.
    values, is_hit = _build_plate_matrices(hits)
    pos_controls = _control_values(hits, "positive")
    neg_controls = _control_values(hits, "negative")
    z_prime = calculate_z_prime(pos_controls, neg_controls) if pos_controls and neg_controls else None
    df_all_hits = pd.DataFrame([_hit_to_row(h) for h in hits])

    fig, ax = plt.subplots(figsize=(14, 8))
    try:
        _draw_plate_heatmap(ax, values, is_hit, f'HTS Plate: {campaign.campaign_name}')
        # Lay out before exporting so the PNG matches what is displayed.
        fig.tight_layout()
        controls = [render_zprime_badge(z_prime)]
        for link in (
            _png_download_link(fig, campaign.campaign_id),
            _csv_download_link(df_all_hits, campaign.campaign_id),
        ):
            if link is not None:
                controls.append(link)
        display(widgets.HBox(controls))
        plt.show()
    finally:
        # Always release the figure, even if drawing or export fails, so that
        # repeated selections don't accumulate open figures.
        plt.close(fig)

    _show_markdown(_summary_markdown(campaign, hits))

    if hits:
        _show_markdown("## Top 10 Hits")
        top_hits = sorted(hits, key=_ranking_value, reverse=True)[:10]
        display(pd.DataFrame([_hit_to_row(h) for h in top_hits]))
    else:
        _show_markdown("*No hits found for this campaign.*")


def update_plate_viewer(campaign_id):
    """Render the plate view for ``campaign_id`` into ``output_area``.

    The function clears the output area, then shows:
        - the campaign header
        - a heatmap with a Z' badge and export links
        - summary statistics
        - a top-10 hits table

    Without an API client (demo mode), a reproducible fake plate is generated
    instead. Errors are printed with a traceback inside the output area rather
    than raised, so one bad campaign cannot break a widget callback.

    Args:
        campaign_id: Selected campaign ID; a falsy value shows a prompt instead.
    """
    with output_area:
        clear_output(wait=True)

        if not campaign_id:
            print("Please select a campaign.")
            return

        try:
            _render_campaign(campaign_id)
        except Exception as e:
            print(f"Error loading campaign: {e}")
            import traceback
            traceback.print_exc()

def _on_campaign_change(change):
    """Re-render the plate whenever the dropdown's value changes."""
    update_plate_viewer(change['new'])


# Only react to changes of the selected value, not other trait updates.
campaign_dropdown.observe(_on_campaign_change, names='value')

# Selector first, then the area the callback renders into.
display(campaign_dropdown, output_area)

# Render the initial selection (if any) without waiting for a change event.
_initial_campaign = campaign_dropdown.value
if _initial_campaign:
    update_plate_viewer(_initial_campaign)


## Usage

1. Select a campaign from the dropdown above
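2. The plate heatmap is displayed with:
   - Color intensity based on `normalized_value` (falling back to `raw_value` when it is missing)
   - Red borders around every well returned by the campaign-hits query (including any control wells it returns)
   - A Z' factor badge, PNG/CSV export links, summary statistics, and a top-10 hits table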
3. To run as a Voila app: `voila hts_plate_viewer.ipynb`

**Note:** Only hit wells are currently displayed; all other wells appear blank. Visualizing the full plate requires an endpoint that returns every well result, not just hits.
