# Signature Heatmap Analysis

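This notebook visualizes a gene signature as a clustered heatmap. Because sample-level expression data may not be available from the API, it plots each signature component's weight and direction (features × channels) instead.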

In [None]:
import os

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from notebook_utils import get_api_client, demo_mode_banner, display_error

# Connect to API (demo fallback if unavailable).
# API_URL takes precedence; AMPRENTA_API_URL is the secondary source. If neither
# is set, api_url is None and get_api_client decides how to handle that.
api_url = os.environ.get("API_URL") or os.environ.get("AMPRENTA_API_URL")
client, demo_mode = get_api_client(api_url=api_url)

# Later cells check `demo_mode or client is None` to decide whether to call the API.
if demo_mode or client is None:
    demo_mode_banner()
    print("API unavailable — demo mode enabled.")
else:
    # Prefer the URL the client reports; fall back to the one we passed in.
    print(f"Connected to {getattr(client, 'api_url', api_url)}")

## Load Signatures

In [None]:
# Load signatures from API (demo fallback if unavailable).
# On any failure the notebook switches to demo mode and continues with no signatures.
try:
    if demo_mode or client is None:
        signatures = []
    else:
        # Normalize an empty/None response to [] so len() and slicing below are
        # safe; otherwise a None result would be misreported as a load failure.
        signatures = client.signatures.list(limit=10) or []

    print(f"Loaded {len(signatures)} signatures")

    # Display up to the first five signature names.
    for sig in signatures[:5]:
        print(f"  - {getattr(sig, 'name', '(unnamed)')}")
except Exception as e:
    # Best-effort: fall back to demo mode rather than aborting the notebook.
    demo_mode = True
    demo_mode_banner()
    signatures = []
    display_error("Failed to load signatures", details=repr(e))

## Create Expression Matrix

In [None]:
# Create a matrix from real signature components (best-effort)
# Note: If sample-level expression is not available via API, we visualize component weights/directions.

# Fetch full details (including components) for the first listed signature.
sig = None
if not (demo_mode or client is None) and signatures:
    try:
        sig = client.signatures.get(signatures[0].id)
    except Exception as e:
        display_error("Failed to fetch signature details", details=repr(e))

# One row per named component:
#   weight    -> float(weight); 1.0 if the attribute is missing, 0.0 if it is None
#   direction -> +1 if its text contains "up", -1 if it contains "down", else 0
rows = []
if sig is not None:
    for c in getattr(sig, "components", []) or []:
        fname = getattr(c, "feature_name", None)
        if not fname:
            continue  # components without a feature name cannot be labelled on the heatmap
        weight = getattr(c, "weight", 1.0)
        direction = getattr(c, "direction", None)
        d_str = str(direction).lower() if direction is not None else ""
        if "up" in d_str:
            d_score = 1
        elif "down" in d_str:
            d_score = -1
        else:
            d_score = 0

        rows.append({"feature": fname, "weight": float(weight) if weight is not None else 0.0, "direction": d_score})

# Demo fallback: used whenever no named components were obtained (demo mode,
# fetch failure, or a signature without usable components). Announce it so the
# placeholder genes are not mistaken for data returned by the API.
if not rows:
    print("No signature components available — showing placeholder demo features.")
    rows = [
        {"feature": "BRCA1", "weight": 1.0, "direction": 1},
        {"feature": "TP53", "weight": 0.8, "direction": -1},
        {"feature": "EGFR", "weight": 0.6, "direction": 1},
        {"feature": "MYC", "weight": 0.7, "direction": 1},
        {"feature": "PIK3CA", "weight": 0.5, "direction": -1},
    ]

# DataFrame (features x "channels"): index = feature name, columns = weight, direction
df_expression = pd.DataFrame(rows).set_index("feature")

print(f"Matrix shape: {df_expression.shape}")
print(f"Features: {df_expression.shape[0]}")
print(f"Channels: {df_expression.shape[1]}")
print("\nFirst few rows:")
display(df_expression.head())

## Normalize/Scale Data

In [None]:
from sklearn.preprocessing import StandardScaler

# Z-score each channel (column) across all features, so channels with different
# units (component weight vs. direction score) become comparable on one colour scale.
# Scaling each feature row across its channels instead would be degenerate here:
# with only two channels every row collapses to ±1 (or 0), discarding magnitudes.
# StandardScaler maps constant columns to 0 instead of dividing by zero.
scaler = StandardScaler()
df_scaled = pd.DataFrame(
    scaler.fit_transform(df_expression),
    index=df_expression.index,
    columns=df_expression.columns,
)

print("Data normalized (Z-score)")
print(f"Mean: {df_scaled.values.mean():.4f}")
print(f"Std: {df_scaled.values.std():.4f}")

## Generate Clustered Heatmap

In [None]:
# Create clustered heatmap using seaborn.
# sns.clustermap creates its own figure (sized via figsize below), so no
# plt.figure() is opened beforehand; that would only leave an empty extra figure.
n_rows, n_cols = df_scaled.shape

# Clustermap with row and column clustering where possible
clustered = sns.clustermap(
    df_scaled,
    method='ward',
    metric='euclidean',
    cmap='RdBu_r',
    center=0,
    vmin=-2,
    vmax=2,
    figsize=(14, 10),
    cbar_kws={"label": "Normalized Expression"},
    # Hierarchical clustering needs at least two observations along an axis
    # (e.g. a signature with a single component yields one row).
    row_cluster=n_rows > 1,
    col_cluster=n_cols > 1,
    xticklabels=True,
    yticklabels=True,
)

# The current figure is the clustermap's figure, so suptitle lands there.
# No extra plt.tight_layout(): clustermap already lays out its figure and then
# positions the colorbar; a second layout pass can undo that placement.
plt.suptitle('Signature Expression Heatmap (Clustered)', fontsize=16, fontweight='bold', y=1.02)
plt.show()

## Save Figure

In [None]:
# Save the heatmap figure.
# Set SAVE_FIGURE = True to write the clustermap from the previous cell to disk.
SAVE_FIGURE = False
output_path = "signature_heatmap.png"

if SAVE_FIGURE:
    # bbox_inches='tight' expands the saved area so the title, placed slightly
    # above the figure top, is not clipped.
    clustered.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Figure saved to {output_path}")