In [None]:
# Install the Hugging Face stack used by the cells below.
# %pip (rather than !pip) installs into the environment of the running kernel.
# NOTE(review): --force-reinstall also replaces packages that may already be
# imported (e.g. NumPy); restart the runtime after this cell before importing.
%pip install --force-reinstall transformers datasets evaluate scikit-learn accelerate --no-build-isolation

Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting accelerate
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting filelock (from transformers)
  Downloading filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting numpy>=1.17 (from transformers)
  Downloading numpy-2.3.5-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
[2K     [9

In [None]:
# Replace the preinstalled PyTorch build with the pinned CUDA 12.1 wheels, then
# install the pinned causal-conv1d / mamba-ssm packages.
# %pip targets the running kernel's environment; one package per line keeps the
# install order explicit without relying on shell `&&` chaining.
%pip uninstall -y torch torchvision torchaudio
%pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
%pip install causal-conv1d==1.4.0
%pip install mamba-ssm==2.2.2

Found existing installation: torch 2.9.1
Uninstalling torch-2.9.1:
  Successfully uninstalled torch-2.9.1
Found existing installation: torchvision 0.24.0+cu126
Uninstalling torchvision-0.24.0+cu126:
  Successfully uninstalled torchvision-0.24.0+cu126
Found existing installation: torchaudio 2.9.0+cu126
Uninstalling torchaudio-2.9.0+cu126:
  Successfully uninstalled torchaudio-2.9.0+cu126
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.4.0
  Downloading https://download.pytorch.org/whl/cu121/torch-2.4.0%2Bcu121-cp312-cp312-linux_x86_64.whl (799.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m799.0/799.0 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.0
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.0%2Bcu121-cp312-cp312-linux_x86_64.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m114.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting t

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub repository that provides both the tokenizer and the pretrained backbone.
MODEL_ID = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"

# The repository ships custom code; opting in explicitly avoids the blocking
# "Do you wish to run the custom code? [y/N]" prompt during Run All.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
backbone = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)

class BinaryDiseaseClassifier(nn.Module):
    """Two-class classification head on top of a sequence backbone.

    The backbone's per-token hidden states are mean-pooled over the positions
    selected by ``attention_mask`` and passed through a three-layer MLP that
    emits two logits.
    """

    def __init__(self, backbone, hidden_size=512):
        """Build the head.

        Args:
            backbone: Module called as ``backbone(input_ids)``; it must return
                either an object with ``last_hidden_state`` or a sequence whose
                first item is the hidden-state tensor.
            hidden_size: Width of the backbone's hidden states, i.e. the input
                width of the first linear layer.
        """
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.LayerNorm(hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size // 4, 2)
        )

    def forward(self, input_ids, attention_mask):
        """Return logits with two columns per input row.

        Args:
            input_ids: Token ids forwarded to the backbone (the mask is not
                passed to the backbone, only used for pooling).
            attention_mask: Per-token weights; after ``unsqueeze(-1)`` it is
                expanded to the hidden-state shape.
        """
        outputs = self.backbone(input_ids)
        hidden = outputs.last_hidden_state if hasattr(outputs, 'last_hidden_state') else outputs[0]
        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        # Masked mean over dim 1; the clamp prevents division by zero when a
        # row has no selected positions.
        pooled = torch.sum(hidden * mask_expanded, 1) / torch.clamp(mask_expanded.sum(1), min=1e-9)
        return self.classifier(self.dropout(pooled))

# Wrap the pretrained backbone with the classification head and restore the
# fine-tuned weights.
model = BinaryDiseaseClassifier(backbone).to(device)
# The file is passed straight to load_state_dict, so it is expected to be a
# plain state dict; weights_only=True avoids unpickling arbitrary objects.
model.load_state_dict(
    torch.load('caduceus_binary_final.pth', map_location=device, weights_only=True)
)
# Inference only from here on.
model.eval()

def predict(sequence):
    """Classify one DNA sequence with the module-level ``model``.

    Only the first 512 characters are tokenized. Returns a tuple
    ``(label, score)`` where ``label`` is "POSITIVE" when the class-1
    probability exceeds 0.5 and "NEGATIVE" otherwise, and ``score`` is that
    class-1 probability formatted as a percentage with one decimal.
    """
    tokens = tokenizer(
        sequence[:512],
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt',
    )
    input_ids = tokens['input_ids'].to(device)
    # NOTE(review): treats token id 0 as padding — confirm against
    # tokenizer.pad_token_id and the masking used during training.
    attention_mask = (input_ids != 0).long()

    with torch.no_grad():
        class_probs = torch.softmax(model(input_ids, attention_mask), dim=1)
    prob = class_probs[0, 1].item()

    label = "POSITIVE" if prob > 0.5 else "NEGATIVE"
    return label, f"{prob*100:.1f}%"

# Interactive driver: read a sequence from stdin, classify it with predict(),
# and print the returned label and formatted score. `sequence` stays in the
# notebook namespace after this cell runs.
sequence = input("Enter DNA sequence: ")
result, confidence = predict(sequence)
print(f"\nResult: {result} (Confidence: {confidence})")

KeyboardInterrupt: Interrupted by user

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub repository that provides both the tokenizer and the pretrained backbone.
MODEL_ID = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"

# The repository ships custom code; opting in explicitly avoids the blocking
# "Do you wish to run the custom code? [y/N]" prompt during Run All.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

backbone = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)


# ======================================================
# 🔥 FIXED CLASSIFIER (MATCHES CHECKPOINT SHAPES EXACTLY)
# ======================================================
class BinaryDiseaseClassifier(nn.Module):
    """Two-class classification head on top of a sequence backbone.

    Hidden states from the backbone are mean-pooled over the positions
    selected by ``attention_mask`` and fed to a three-layer MLP that emits two
    logits.
    """

    def __init__(self, backbone, hidden_size=512):
        """Build the head.

        Args:
            backbone: Module called as ``backbone(input_ids)``; it must return
                either an object with ``last_hidden_state`` or a sequence whose
                first item is the hidden-state tensor.
            hidden_size: Input width of the head. The default reproduces the
                512 -> 256 -> 128 -> 2 layer widths this cell was written for.
        """
        super().__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(0.3)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.LayerNorm(hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(hidden_size // 4, 2)
        )

    def forward(self, input_ids, attention_mask):
        """Return logits with two columns per input row.

        ``attention_mask`` is used only for pooling; the backbone receives
        ``input_ids`` alone.
        """
        outputs = self.backbone(input_ids)
        hidden = outputs.last_hidden_state if hasattr(outputs, 'last_hidden_state') else outputs[0]

        mask_expanded = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        # Masked mean over dim 1; the clamp prevents division by zero when a
        # row has no selected positions.
        pooled = torch.sum(hidden * mask_expanded, 1) / torch.clamp(mask_expanded.sum(1), min=1e-9)

        return self.classifier(self.dropout(pooled))


# Wrap the pretrained backbone with the classification head and restore the
# fine-tuned weights.
model = BinaryDiseaseClassifier(backbone).to(device)
# The file is passed straight to load_state_dict, so it is expected to be a
# plain state dict; weights_only=True avoids unpickling arbitrary objects.
model.load_state_dict(
    torch.load("caduceus_binary_final.pth", map_location=device, weights_only=True)
)
# Inference only from here on.
model.eval()



def predict(sequence):
    """Classify one DNA sequence with the module-level ``model``.

    Only the first 512 characters are tokenized.

    Returns:
        ``(label, score)``: ``label`` is "POSITIVE" when the class-1
        probability exceeds 0.5, otherwise "NEGATIVE"; ``score`` is that
        class-1 probability as a percentage string with one decimal
        (e.g. "49.1%").
    """
    encoding = tokenizer(
        sequence[:512],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )

    input_ids = encoding["input_ids"].to(device)
    # NOTE(review): treats token id 0 as padding — confirm against
    # tokenizer.pad_token_id and the masking used during training.
    attention_mask = (input_ids != 0).long()

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        prob = torch.softmax(logits, dim=1)[0, 1].item()

    pred = "POSITIVE" if prob > 0.5 else "NEGATIVE"
    # A probability in [0, 1] is scaled by 100 (not 10000) to get a percentage.
    return pred, f"{prob * 100:.1f}%"


# Interactive driver: read a sequence from stdin, classify it with predict(),
# and print the returned label and formatted score string. `sequence` stays in
# the notebook namespace after this cell runs.
sequence = input("Enter DNA sequence: ")
result, confidence = predict(sequence)
print(f"\nResult: {result} (Confidence: {confidence})")


The repository kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 .
 You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
The repository kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 .
 You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do

  model.load_state_dict(torch.load("caduceus_binary_final.pth", map_location=device))


Enter DNA sequence: GCTTCACGTGTACCATGTTCCCGGCGGCCTCCTCGAAGGGCCTGTGCGGCTGCCGGCCCAGCTCCCGCAGGCTGCACAGCTTGGGCAGCCAGGTCCACGAG

Result: NEGATIVE (Confidence: 49.1%)


In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub repository that provides both the tokenizer and the pretrained backbone.
MODEL_ID = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"

# The repository ships custom code; opting in explicitly avoids the blocking
# "Do you wish to run the custom code? [y/N]" prompt during Run All.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
backbone = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)

# The checkpoint bundles the head's hyper-parameters with its weights.
# NOTE(review): this load does not pass weights_only=True; only load this file
# from a trusted source, or verify that it loads with weights_only=True.
checkpoint = torch.load('dna_disease_classifier_final.pth', map_location=device)
model_config = checkpoint['model_config']
num_labels = model_config['num_labels']
actual_hidden_size = model_config['actual_hidden_size']
backbone_call_method = model_config['backbone_call_method']

class DiseaseClassifier(nn.Module):
    """Multi-output classification head on top of a sequence backbone.

    The backbone's hidden states are pooled over dim 1 and passed through a
    three-layer MLP that emits ``num_labels`` logits.
    """

    def __init__(self, backbone, num_labels, actual_hidden_size, backbone_call_method, dropout_rate=0.3):
        """Build the head.

        Args:
            backbone: Module producing the hidden states.
            num_labels: Number of output logits.
            actual_hidden_size: Width of the backbone's hidden states (input
                width of the first linear layer).
            backbone_call_method: How ``forward`` invokes the backbone:
                "keyword_args" passes ``input_ids=`` and ``attention_mask=``,
                "input_ids_only" passes ``input_ids`` alone, and any other
                value passes both positionally.
            dropout_rate: Dropout probability used before and inside the MLP.
        """
        super().__init__()
        self.backbone = backbone
        self.backbone_call_method = backbone_call_method
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Sequential(
            nn.Linear(actual_hidden_size, actual_hidden_size // 2),
            nn.LayerNorm(actual_hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(actual_hidden_size // 2, actual_hidden_size // 4),
            nn.LayerNorm(actual_hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(actual_hidden_size // 4, num_labels)
        )

    def forward(self, input_ids, attention_mask=None):
        """Return one row of ``num_labels`` logits per input row."""
        if self.backbone_call_method == "keyword_args":
            outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        elif self.backbone_call_method == "input_ids_only":
            outputs = self.backbone(input_ids)
        else:
            outputs = self.backbone(input_ids, attention_mask)

        # Accept a model-output object, a tuple, or a bare tensor.
        if hasattr(outputs, 'last_hidden_state'):
            hidden = outputs.last_hidden_state
        elif isinstance(outputs, tuple):
            hidden = outputs[0]
        else:
            hidden = outputs

        # Masked mean only when a mask is given and the backbone also saw it;
        # otherwise a plain mean over every position.
        if attention_mask is not None and self.backbone_call_method != "input_ids_only":
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
            pooled = torch.sum(hidden * mask_expanded, 1) / torch.clamp(mask_expanded.sum(1), min=1e-9)
        else:
            pooled = hidden.mean(dim=1)

        return self.classifier(self.dropout(pooled))

# Rebuild the classifier from the hyper-parameters read out of the checkpoint,
# then restore its weights from the checkpoint's 'model_state_dict' entry.
model = DiseaseClassifier(backbone, num_labels, actual_hidden_size, backbone_call_method).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference only from here on.
model.eval()

def predict(sequence, threshold=0.5):
    """Return a 0/1 integer array with one entry per model output.

    Only the first 512 characters of ``sequence`` are tokenized. Each logit is
    passed through a sigmoid independently; an entry is 1 when its probability
    is at least ``threshold``.
    """
    tokens = tokenizer(
        sequence[:512],
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt',
    )
    input_ids = tokens['input_ids'].to(device)
    # NOTE(review): treats token id 0 as padding — confirm against
    # tokenizer.pad_token_id and the masking used during training.
    attention_mask = (input_ids != 0).long()

    with torch.no_grad():
        label_probs = torch.sigmoid(model(input_ids, attention_mask))
    probs = label_probs[0].cpu().numpy()

    is_positive = probs >= threshold
    return is_positive.astype(int)

# Interactive driver: read a sequence from stdin, run the multi-label
# predict(), then print the raw 0/1 vector and the indices whose entry is 1.
# `sequence` stays in the notebook namespace after this cell runs.
sequence = input("Enter DNA sequence: ")
predictions = predict(sequence)
print(f"\nPredictions: {predictions}")
print(f"Detected diseases at positions: {[i for i, p in enumerate(predictions) if p == 1]}")

The repository kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 .
 You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y
The repository kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16 .
 You can inspect the repository content at https://hf.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do

  checkpoint = torch.load('dna_disease_classifier_final.pth', map_location=device)


Enter DNA sequence: TAGCATGGAAACAGTTAAACTGAAGCTTTCTTCTCCTTATAGGTTGCCATCTTTTCTTGATCTCTGCAATAGCTTTCCCTGGATTCAGACCCTTGAAAAAA

Predictions: [1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Detected diseases at positions: [0]


In [None]:
# Pin NumPy below 2.0 for numba/shap. %pip installs into the running kernel's
# environment.
# NOTE(review): a NumPy that is already imported stays loaded until the runtime
# is restarted; restart after this cell, otherwise `import shap` can still see
# the old NumPy and fail.
%pip install --upgrade --force-reinstall "numpy<2.0" numba shap==0.45.0


Collecting numpy<2.0
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numba
  Downloading numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting shap==0.45.0
  Downloading shap-0.45.0-cp312-cp312-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)
Collecting scipy (from shap==0.45.0)
  Using cached scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (62 kB)
Collecting scikit-learn (from shap==0.45.0)
  Using cached scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting pandas (from shap==0.45.0)
  Using cache

In [None]:
import shap
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
from tqdm import tqdm
import seaborn as sns

# Create the output directory (relative to the current working directory) that
# the plotting code in this cell saves its PNG files into; no error if it
# already exists.
os.makedirs("shap_outputs_multilabel", exist_ok=True)

class MultiLabelGenomicShapExplainer:
    """Per-nucleotide attribution scores for a multi-label sequence classifier.

    NOTE: the "SHAP" naming is kept for compatibility, but the scores are
    computed by exhaustive single-nucleotide substitution: a position's score
    is the mean drop in a label's probability over the three alternative
    bases. No Shapley-value estimator is involved.
    """

    def __init__(self, model, tokenizer, device, disease_labels, max_length=512):
        """Store the collaborators used for scoring.

        Args:
            model: Callable invoked as ``model(input_ids, attention_mask)``
                that returns logits.
            tokenizer: Callable returning a mapping with an ``'input_ids'``
                tensor.
            device: Device the input ids are moved to.
            disease_labels: Sequence of label names; its length is taken as
                the number of labels.
            max_length: Maximum number of characters analysed per sequence.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.disease_labels = disease_labels
        self.num_labels = len(disease_labels)
        self.max_length = max_length
        self.nucleotides = ['A', 'T', 'C', 'G']

    def predict_proba(self, sequence):
        """Get model prediction probabilities for all labels"""
        encoding = self.tokenizer(
            sequence[:self.max_length],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        # NOTE(review): treats token id 0 as padding — confirm against
        # tokenizer.pad_token_id and the masking used during training.
        attention_mask = (input_ids != 0).long()

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probs = torch.sigmoid(logits).cpu().numpy()[0]

        return probs

    def compute_shap_values_multilabel(self, sequence, label_indices=None):
        """
        Compute substitution-based attribution scores for several labels.

        Every position holding one of A/T/C/G is replaced in turn by each of
        the other three bases; the score is the mean of
        ``baseline_prob - mutated_prob``. Positions holding any other
        character keep a score of 0. Each mutated sequence is scored once and
        the result is shared by all requested labels.

        Args:
            sequence: DNA sequence string (truncated to ``max_length``).
            label_indices: Label indices to score. ``None`` selects every
                label whose baseline probability is greater than 0.

        Returns:
            shap_values_dict: Dictionary mapping label_idx -> score array with
                one entry per analysed position.
            analyzed_seq: The (truncated) sequence that was analysed.
            baseline_probs: Baseline probabilities for all labels.
        """
        seq = sequence[:self.max_length]
        seq_length = len(seq)

        # Get baseline prediction for all labels
        baseline_probs = self.predict_proba(seq)

        if label_indices is None:
            label_indices = [i for i in range(self.num_labels) if baseline_probs[i] > 0.0]

        print(f"\n Computing SHAP values for {len(label_indices)} disease labels...")
        print(f"   Sequence length: {seq_length} nucleotides")

        # effects[i, k] = mean probability drop of output k when position i is
        # substituted. predict_proba returns all outputs at once, so a single
        # scan serves every label.
        effects = np.zeros((seq_length, len(baseline_probs)))
        if len(label_indices) > 0:
            for i in tqdm(range(seq_length), desc=f"  Computing SHAP", leave=False):
                original_nt = seq[i]

                # Ambiguous or unexpected characters are not mutated.
                if original_nt not in self.nucleotides:
                    continue

                deltas = [
                    baseline_probs - self.predict_proba(seq[:i] + mutant_nt + seq[i+1:])
                    for mutant_nt in self.nucleotides
                    if mutant_nt != original_nt
                ]
                effects[i] = np.mean(deltas, axis=0)

        # Slice out one column per requested label.
        shap_values_dict = {}
        for label_idx in label_indices:
            label_name = self.disease_labels[label_idx]
            baseline_prob = baseline_probs[label_idx]

            print(f"\n Processing: {label_name} (prob: {baseline_prob:.4f})")

            shap_values = effects[:, label_idx].copy()
            shap_values_dict[label_idx] = shap_values
            print(f"   Completed: Mean |SHAP| = {np.mean(np.abs(shap_values)):.6f}")

        return shap_values_dict, seq, baseline_probs


def create_multilabel_visualizations(explainer, shap_values_dict, analyzed_seq,
                                     baseline_probs, prob_threshold=0.5):
    """
    Create comprehensive visualizations for multi-label SHAP analysis
    Optimized for ~100bp sequences with 1-4 diseases

    Labels are drawn when they have an entry in ``shap_values_dict`` and a
    baseline probability above ``prob_threshold``. If no analysed label clears
    the threshold (e.g. the caller fell back to its top-N labels), every
    analysed label is drawn instead of producing no plots.

    Args:
        explainer: Object exposing ``disease_labels`` (label names by index).
        shap_values_dict: Mapping label index -> per-position score array.
        analyzed_seq: The sequence the scores refer to.
        baseline_probs: Array of baseline probabilities for all labels.
        prob_threshold: Probability cut-off for selecting labels to draw.

    Returns:
        None. PNG files are written to ``shap_outputs_multilabel/`` and each
        figure is also shown via ``plt.show()``.
    """

    seq_length = len(analyzed_seq)
    feature_names = [f"Pos{i}_{nt}" for i, nt in enumerate(analyzed_seq)]
    tick_step = 5  # x-axis tick spacing shared by all position plots

    # Analysed labels, ranked by descending probability.
    sorted_indices = np.argsort(baseline_probs)[::-1]
    ranked = [(idx, explainer.disease_labels[idx], baseline_probs[idx])
              for idx in sorted_indices
              if idx in shap_values_dict]
    analyzed_diseases = [entry for entry in ranked if entry[2] > prob_threshold]

    # Without this fallback, labels analysed below the threshold would be
    # computed but never shown.
    used_fallback = not analyzed_diseases and bool(ranked)
    if used_fallback:
        analyzed_diseases = ranked

    num_diseases = len(analyzed_diseases)

    print("\n" + "="*80)
    print(f" Creating visualizations for {num_diseases} disease(s)")
    print("="*80)

    if num_diseases == 0:
        print("⚠️  No diseases to visualize!")
        return

    if used_fallback:
        print(f" No analyzed disease exceeds probability {prob_threshold}; "
              f"showing all {num_diseases} analyzed disease(s)")

    # =====================================================================
    # PLOT 1: Multi-Label Heatmap
    # =====================================================================
    print("\n PLOT 1: Multi-Label SHAP Heatmap")

    shap_matrix = np.array([shap_values_dict[idx] for idx, _, _ in analyzed_diseases])
    disease_names = [name[:40] for _, name, _ in analyzed_diseases]

    # Optimized for 100bp sequences
    fig_width = 20
    fig_height = max(4, num_diseases * 1.2)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Symmetric colour limits keep zero at the centre of the diverging map.
    vmax = np.max(np.abs(shap_matrix))
    im = ax.imshow(shap_matrix, cmap='RdBu_r', aspect='auto', vmin=-vmax, vmax=vmax)

    ax.set_xticks(np.arange(0, seq_length, tick_step))
    ax.set_xticklabels([f"{i}" for i in range(0, seq_length, tick_step)], fontsize=8)
    ax.set_yticks(np.arange(num_diseases))
    ax.set_yticklabels(disease_names, fontsize=10)

    # Colorbar
    cbar = plt.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label('SHAP Value', fontsize=11, fontweight='bold')

    ax.set_xlabel('Nucleotide Position', fontsize=11, fontweight='bold')
    ax.set_ylabel('Disease', fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig('shap_outputs_multilabel/1_multilabel_heatmap.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    print(" Saved: 1_multilabel_heatmap.png")

    # =====================================================================
    # PLOT 2: Aggregated scores (sum over the drawn labels)
    # =====================================================================
    print("\n PLOT 2: Aggregated SHAP Values")

    aggregated_shap = np.zeros(seq_length)
    for idx, _, _ in analyzed_diseases:
        aggregated_shap += shap_values_dict[idx]

    fig, ax = plt.subplots(figsize=(20, 5))
    colors = ['#FF6B6B' if v > 0 else '#4ECDC4' for v in aggregated_shap]
    ax.bar(range(seq_length), aggregated_shap, color=colors, alpha=0.8,
           edgecolor='black', linewidth=0.5)

    ax.axhline(y=0, color='black', linestyle='-', linewidth=1.5)
    ax.set_xlabel('Nucleotide Position', fontsize=11, fontweight='bold')
    ax.set_ylabel('Aggregated SHAP Value', fontsize=11, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    ax.set_xticks(range(0, seq_length, tick_step))
    ax.set_xticklabels(range(0, seq_length, tick_step), fontsize=8)

    plt.tight_layout()
    plt.savefig('shap_outputs_multilabel/2_aggregated_shap.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    print(" Saved: 2_aggregated_shap.png")

    # =====================================================================
    # PLOT 3: One waterfall plot per label
    # =====================================================================
    print(f"\n📊 PLOT 3: Individual Waterfall Plots ({num_diseases} disease(s))")

    for rank, (idx, name, prob) in enumerate(analyzed_diseases, 1):
        shap_values = shap_values_dict[idx]

        shap_explanation = shap.Explanation(
            values=shap_values,
            base_values=prob,
            data=np.array(list(analyzed_seq)),
            feature_names=np.array(feature_names)
        )

        plt.figure(figsize=(10, 8))
        shap.plots.waterfall(shap_explanation, max_display=20, show=False)

        # Remove default title and add custom
        ax = plt.gca()
        ax.set_title(f"{name} (p={prob:.4f})", fontsize=11, fontweight='bold', pad=10)

        plt.tight_layout()

        safe_name = name[:30].replace(" ", "_").replace("/", "-")
        plt.savefig(f'shap_outputs_multilabel/3_{rank}_waterfall_{safe_name}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()
        print(f"   Saved: 3_{rank}_waterfall_{safe_name}.png")

    # =====================================================================
    # PLOT 4: Mean |score| per label (needs at least two labels)
    # =====================================================================
    if num_diseases > 1:
        print("\n PLOT 4: Disease Comparison")

        avg_abs_shap = [np.mean(np.abs(shap_values_dict[idx])) for idx, _, _ in analyzed_diseases]
        disease_names_short = [name[:35] for _, name, _ in analyzed_diseases]

        fig, ax = plt.subplots(figsize=(10, max(4, num_diseases * 0.8)))
        colors_bar = plt.cm.viridis(np.linspace(0.3, 0.9, num_diseases))
        ax.barh(disease_names_short, avg_abs_shap, color=colors_bar,
                edgecolor='black', linewidth=0.8)

        ax.set_xlabel('Average |SHAP Value|', fontsize=11, fontweight='bold')
        ax.set_ylabel('Disease', fontsize=11, fontweight='bold')
        ax.grid(axis='x', alpha=0.3, linestyle='--')

        # Add value labels
        for i, val in enumerate(avg_abs_shap):
            ax.text(val, i, f' {val:.6f}', va='center', fontsize=9)

        plt.tight_layout()
        plt.savefig('shap_outputs_multilabel/4_disease_comparison.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()
        print(" Saved: 4_disease_comparison.png")
    else:
        print("\n⏭  PLOT 4: Skipped (only 1 disease)")

    # =====================================================================
    # PLOT 5: Sequence track on top, one score track per label below
    # =====================================================================
    print("\n PLOT 5: Genomic Multi-Label Overlay")

    nt_colors = {'A': '#00CC00', 'T': '#FF0000', 'G': '#FFB300', 'C': '#0000FF', 'N': '#808080'}

    fig_height = max(6, 2.5 + num_diseases * 1.3)

    # num_diseases >= 1 here, so at least two rows are requested and `axes`
    # is always an array.
    fig, axes = plt.subplots(num_diseases + 1, 1, figsize=(20, fig_height),
                             gridspec_kw={'height_ratios': [0.8] + [1]*num_diseases})

    # Top: DNA Sequence
    ax_seq = axes[0]
    base_fontsize = 9  # Fixed for 100bp

    for i, nt in enumerate(analyzed_seq):
        color = nt_colors.get(nt, '#808080')
        # Larger and more opaque letters mark positions with larger |score|.
        importance = np.abs(aggregated_shap[i])
        fontsize = base_fontsize + min(importance * 25, 5)
        alpha = 0.6 + min(importance * 2, 0.4)

        ax_seq.text(i, 0, nt, fontsize=fontsize, ha='center', va='center',
                    color=color, fontweight='bold', alpha=alpha)

    ax_seq.set_xlim(-1, seq_length)
    ax_seq.set_ylim(-0.5, 0.5)
    ax_seq.axis('off')

    # Individual disease SHAP plots
    for ax_idx, (label_idx, name, prob) in enumerate(analyzed_diseases):
        ax = axes[ax_idx + 1]
        shap_vals = shap_values_dict[label_idx]
        colors_plot = ['#FF6B6B' if v > 0 else '#4ECDC4' for v in shap_vals]

        ax.bar(range(seq_length), shap_vals, color=colors_plot, alpha=0.8,
               edgecolor='black', linewidth=0.3)
        ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
        ax.set_ylabel('SHAP', fontsize=9, fontweight='bold')
        ax.text(0.01, 0.95, f'{name[:45]} (p={prob:.4f})',
                transform=ax.transAxes, fontsize=9, va='top', fontweight='bold')
        ax.grid(axis='y', alpha=0.2, linestyle='--')
        ax.set_xlim(-1, seq_length)

        # Only the bottom track carries position ticks.
        if ax_idx < num_diseases - 1:
            ax.set_xticks([])
        else:
            ax.set_xticks(range(0, seq_length, tick_step))
            ax.set_xticklabels(range(0, seq_length, tick_step), fontsize=8)
            ax.set_xlabel('Nucleotide Position', fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.savefig('shap_outputs_multilabel/5_genomic_multilabel.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    print(" Saved: 5_genomic_multilabel.png")

    # =====================================================================
    # PLOT 6: Absolute-score heatmap
    # =====================================================================
    print("\n PLOT 6: Contribution Matrix")

    contribution_matrix = np.abs(shap_matrix)

    fig, ax = plt.subplots(figsize=(20, max(4, num_diseases * 1.2)))

    im = ax.imshow(contribution_matrix, cmap='YlOrRd', aspect='auto')

    # Ticks
    ax.set_xticks(np.arange(0, seq_length, tick_step))
    ax.set_xticklabels([f"{i}" for i in range(0, seq_length, tick_step)], fontsize=8)
    ax.set_yticks(np.arange(num_diseases))
    ax.set_yticklabels(disease_names, fontsize=10)

    # Colorbar
    cbar = plt.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label('|SHAP Value|', fontsize=11, fontweight='bold')

    ax.set_xlabel('Nucleotide Position', fontsize=11, fontweight='bold')
    ax.set_ylabel('Disease', fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig('shap_outputs_multilabel/6_contribution_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    print(" Saved: 6_contribution_matrix.png")

    # =====================================================================
    # PLOT 7: One force plot per label
    # =====================================================================
    print(f"\n PLOT 7: Individual Force Plots ({num_diseases} disease(s))")

    for rank, (idx, name, prob) in enumerate(analyzed_diseases, 1):
        shap_values = shap_values_dict[idx]

        shap_explanation = shap.Explanation(
            values=shap_values,
            base_values=prob,
            data=np.array(list(analyzed_seq)),
            feature_names=np.array(feature_names)
        )

        # The size is passed to shap instead of pre-creating a figure: the
        # matplotlib force plot is expected to open its own figure, so a
        # pre-created one would be left open (TODO confirm for the installed
        # shap version).
        shap.plots.force(shap_explanation, matplotlib=True, show=False, figsize=(20, 3))

        # Add custom title
        ax = plt.gca()
        ax.set_title(f"{name} (base={prob:.4f})", fontsize=10, fontweight='bold', pad=8)

        plt.tight_layout()

        safe_name = name[:30].replace(" ", "_").replace("/", "-")
        plt.savefig(f'shap_outputs_multilabel/7_{rank}_force_{safe_name}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()
        # Close every open figure so none accumulates across iterations.
        plt.close('all')
        print(f"   Saved: 7_{rank}_force_{safe_name}.png")

    print("\n" + "="*80)
    print(" All visualizations completed!")
    print("="*80)



print("\n" + "="*80)
print("MULTI-LABEL SHAP ANALYSIS FOR GENOMIC DISEASE PREDICTION")
print("="*80)

# `disease_labels` is not defined by any cell shown in this notebook, so a
# fresh-kernel Run All would stop here with a NameError. Fall back to generic
# placeholder names, one per model output.
# TODO(review): replace the placeholders with the real label names from training.
if "disease_labels" not in globals():
    disease_labels = [f"Label_{i}" for i in range(num_labels)]
    print(f" disease_labels is not defined; using {len(disease_labels)} placeholder label names")

# Initialize explainer
explainer = MultiLabelGenomicShapExplainer(
    model=model,
    tokenizer=tokenizer,
    device=device,
    disease_labels=disease_labels,
    max_length=512
)

# Disease selection
PROBABILITY_THRESHOLD = 0.5

# `sequence` is the string entered in the prediction cell above.
baseline_probs_initial = explainer.predict_proba(sequence)
sorted_indices = np.argsort(baseline_probs_initial)[::-1]

high_prob_indices = [idx for idx in range(len(baseline_probs_initial))
                     if baseline_probs_initial[idx] > PROBABILITY_THRESHOLD]

print(f"\n Disease Selection (Threshold = {PROBABILITY_THRESHOLD}):")
print("="*80)

if len(high_prob_indices) > 0:
    print(f" Found {len(high_prob_indices)} disease(s) with probability > {PROBABILITY_THRESHOLD}:")
    for rank, idx in enumerate(sorted(high_prob_indices, key=lambda x: baseline_probs_initial[x], reverse=True), 1):
        print(f"  {rank}. {disease_labels[idx][:65]}: {baseline_probs_initial[idx]:.6f}")
else:
    # Nothing clears the threshold: analyse the three most probable labels.
    print(f"  No diseases found with probability > {PROBABILITY_THRESHOLD}")
    print(f" Falling back to top 3 diseases:")
    high_prob_indices = sorted_indices[:3].tolist()
    for rank, idx in enumerate(high_prob_indices, 1):
        print(f"  {rank}. {disease_labels[idx][:65]}: {baseline_probs_initial[idx]:.6f}")

print("="*80)

# Compute SHAP
shap_values_dict, analyzed_seq, baseline_probs = explainer.compute_shap_values_multilabel(
    sequence=sequence,
    label_indices=high_prob_indices
)

# Create visualizations
create_multilabel_visualizations(
    explainer=explainer,
    shap_values_dict=shap_values_dict,
    analyzed_seq=analyzed_seq,
    baseline_probs=baseline_probs,
    prob_threshold=PROBABILITY_THRESHOLD
)

# ---- Text summary of the run ------------------------------------------------
# Reads analyzed_seq, baseline_probs, shap_values_dict and disease_labels from
# the notebook namespace and prints a report; nothing is written to disk here.
print("\n" + "="*80)
print(" MULTI-LABEL SHAP ANALYSIS SUMMARY")
print("="*80)

seq_length = len(analyzed_seq)
# Label indices ordered from highest to lowest baseline probability.
sorted_indices_final = np.argsort(baseline_probs)[::-1]

# Nucleotide counts are case-sensitive (upper-case A/T/G/C only).
print(f"""
Sequence Statistics:
  • Length: {seq_length} bp
  • A: {analyzed_seq.count('A')}  T: {analyzed_seq.count('T')}  G: {analyzed_seq.count('G')}  C: {analyzed_seq.count('C')}

Disease Predictions:
  • Total Labels: {len(disease_labels)}
  • Labels Analyzed (SHAP computed): {len(shap_values_dict)}
  • Labels with prob > 0.5: {sum(baseline_probs > 0.5)}
  • Labels with prob > 0.1: {sum(baseline_probs > 0.1)}
""")

print("Top 15 Disease Predictions:")
print(f"{'Rank':<5} {'Disease':<58} {'Probability':>12} {'Avg |SHAP|':>12}")
print("-" * 90)

# Labels without computed scores are listed with an average of 0.0 and no
# marker; names are truncated/padded to 58 characters by the format spec.
for rank, idx in enumerate(sorted_indices_final[:15], 1):
    name = disease_labels[idx]
    prob = baseline_probs[idx]
    avg_shap = np.mean(np.abs(shap_values_dict[idx])) if idx in shap_values_dict else 0.0
    marker = "🎯" if idx in shap_values_dict else "  "
    print(f"{marker} {rank:<3} {name:<58.58} {prob:>12.6f} {avg_shap:>12.6f}")

# Per-disease SHAP statistics (only labels that have computed scores)
print("\n" + "="*80)
print("Per-Disease SHAP Statistics:")
print("="*80)

for idx in sorted_indices_final:
    if idx not in shap_values_dict:
        continue

    name = disease_labels[idx]
    prob = baseline_probs[idx]
    shap_vals = shap_values_dict[idx]

    # Positions of the largest and smallest (most negative) score.
    top_pos_idx = np.argmax(shap_vals)
    top_neg_idx = np.argmin(shap_vals)

    print(f"\n  {name}")
    print(f"   Probability: {prob:.6f}")
    print(f"   Mean SHAP: {np.mean(shap_vals):.6f}")
    print(f"   Std SHAP: {np.std(shap_vals):.6f}")
    print(f"   Max SHAP: {np.max(shap_vals):.6f} at Pos{top_pos_idx}({analyzed_seq[top_pos_idx]})")
    print(f"   Min SHAP: {np.min(shap_vals):.6f} at Pos{top_neg_idx}({analyzed_seq[top_neg_idx]})")

    # Top 5 important positions, by absolute score, largest first
    top_5_idx = np.argsort(np.abs(shap_vals))[-5:][::-1]
    print(f"   Top 5 positions: ", end="")
    print(", ".join([f"Pos{i}({analyzed_seq[i]})={shap_vals[i]:+.6f}" for i in top_5_idx]))

# Regional analysis: find the 10 bp window with the highest mean absolute
# score, using the scores summed over every analysed label.
print("\n" + "="*80)
print(" Regional Analysis (10bp windows):")
print("="*80)

aggregated_shap_all = np.zeros(seq_length)
for idx in shap_values_dict.keys():
    aggregated_shap_all += shap_values_dict[idx]

window_size = 10
if seq_length >= window_size:
    # Moving average of |score|; 'valid' keeps only windows that lie fully
    # inside the sequence, so index k covers positions k .. k+window_size-1.
    windowed_importance = np.convolve(
        np.abs(aggregated_shap_all),
        np.ones(window_size) / window_size,
        mode='valid'
    )
    top_region_idx = np.argmax(windowed_importance)
    # Exclusive end index: the printed range "start-end" does not include `end`.
    top_region_end = top_region_idx + window_size

    print(f"  • Most important region: Position {top_region_idx}-{top_region_end}")
    print(f"    Sequence: {analyzed_seq[top_region_idx:top_region_end]}")
    print(f"    Avg. |SHAP|: {windowed_importance[top_region_idx]:.6f}")
else:
    print(f"  • Sequence too short for windowed analysis")

print("\n" + "="*80)
print(" Analysis completed!")
print(" All plots saved to 'shap_outputs_multilabel/'")
print("="*80)

# Static legend of the plot files.
# NOTE(review): the plot counts quoted below look inconsistent with the
# plotting function, which writes one waterfall and one force plot per
# displayed disease — confirm the totals.
print("""
 GENERATED PLOTS:

1. Multi-Label Heatmap - Disease × Position SHAP matrix
2. Aggregated SHAP - Sum of SHAP across all diseases
3. Individual Waterfall Plots - One per disease (1-4 plots)
4. Disease Comparison - Only if multiple diseases
5. Genomic Multi-Label Overlay - Sequence + disease tracks
6. Contribution Matrix - Absolute SHAP heatmap
7. Individual Force Plots - One per disease (1-4 plots)

Total: 5-9 plots depending on number of diseases analyzed
""")

ImportError: Numba needs NumPy 2.0 or less. Got NumPy 2.3.

In [None]:
# Web-backend dependencies. %pip installs into the running kernel's environment.
# NOTE(review): none of these are pinned; pin versions for reproducibility.
%pip install flask flask-cors pyngrok transformers shap tqdm seaborn matplotlib torch


Collecting flask-cors
  Downloading flask_cors-6.0.1-py3-none-any.whl.metadata (5.3 kB)
Collecting pyngrok
  Downloading pyngrok-7.5.0-py3-none-any.whl.metadata (8.1 kB)
Downloading flask_cors-6.0.1-py3-none-any.whl (13 kB)
Downloading pyngrok-7.5.0-py3-none-any.whl (24 kB)
Installing collected packages: pyngrok, flask-cors
Successfully installed flask-cors-6.0.1 pyngrok-7.5.0


In [None]:

# UNIFIED BACKEND: SINGLE + MULTI PRED + FULL SHAP (Colab + ngrok Compatible)


import os, json, torch, torch.nn as nn, numpy as np
from transformers import AutoModel, AutoTokenizer
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from pyngrok import ngrok
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
import shap
from tqdm import tqdm

# CONFIG

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
MODEL_ID = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
# Upper bound on the number of characters/tokens of a sequence that is analysed.
MAX_TOKEN_LEN = 512

#  Use absolute paths (important for Colab)
SINGLE_SHAP_DIR = "/content/shap_outputs"
MULTI_SHAP_DIR = "/content/shap_outputs_multilabel"

os.makedirs(SINGLE_SHAP_DIR, exist_ok=True)
os.makedirs(MULTI_SHAP_DIR, exist_ok=True)

# Set ngrok token.
# SECURITY: the token must never be written into the notebook source. It is read
# from the NGROK_AUTH_TOKEN environment variable; when that is absent the user is
# prompted with hidden input so the value does not end up in a saved cell or its
# output. Any token that was previously committed in this notebook should be
# revoked and regenerated in the ngrok dashboard.
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN", "").strip()
if not NGROK_AUTH_TOKEN:
    from getpass import getpass
    NGROK_AUTH_TOKEN = getpass("ngrok auth token: ").strip()
if not NGROK_AUTH_TOKEN:
    raise RuntimeError(
        "No ngrok auth token provided. Set the NGROK_AUTH_TOKEN environment "
        "variable or enter the token when prompted."
    )
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# LOAD TOKENIZER + BACKBONE

# One tokenizer is shared by both classifiers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Two separate backbone instances (single-label first, multi-label second), so
# each classifier wraps its own copy of the pretrained network.
backbone_single, backbone_multi = (
    AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
    for _ in range(2)
)

# MODEL DEFINITIONS

class BinaryDiseaseClassifier(nn.Module):
    """Two-class head on top of a sequence backbone with mask-weighted mean pooling.

    The head is Linear -> LayerNorm -> ReLU -> Dropout, applied twice (halving
    the width each time), followed by a final Linear to 2 logits. Attribute
    names and the order of layers inside ``classifier`` determine the
    state-dict keys, so they must stay as they are for checkpoints to load.
    """

    def __init__(self, backbone, hidden_size=512):
        """
        Args:
            backbone: module called as ``backbone(input_ids)``; its output must
                expose ``last_hidden_state`` or be indexable with ``[0]``.
            hidden_size: width of the backbone's per-token hidden vectors.
        """
        super().__init__()
        half_width = hidden_size // 2
        quarter_width = hidden_size // 4
        self.backbone = backbone
        self.dropout = nn.Dropout(0.3)
        head_layers = [
            nn.Linear(hidden_size, half_width),
            nn.LayerNorm(half_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(half_width, quarter_width),
            nn.LayerNorm(quarter_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(quarter_width, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, 2).

        Token states are averaged over the positions where ``attention_mask``
        is non-zero; the clamp keeps the division finite for an all-zero mask.
        """
        backbone_out = self.backbone(input_ids)
        if hasattr(backbone_out, "last_hidden_state"):
            token_states = backbone_out.last_hidden_state
        else:
            token_states = backbone_out[0]
        weights = attention_mask.unsqueeze(-1).expand(token_states.size()).float()
        weighted_sum = torch.sum(token_states * weights, 1)
        token_count = torch.clamp(weights.sum(1), min=1e-9)
        pooled = weighted_sum / token_count
        return self.classifier(self.dropout(pooled))

class DiseaseClassifier(nn.Module):
    """Multi-label head on top of a sequence backbone with plain mean pooling.

    Same head layout as the binary classifier (two Linear/LayerNorm/ReLU/Dropout
    stages that halve the width, then a final Linear), but the last layer emits
    ``num_labels`` logits. Attribute names and layer order fix the state-dict
    keys and must not change.
    """

    def __init__(self, backbone, num_labels, actual_hidden_size):
        """
        Args:
            backbone: module called as ``backbone(input_ids)``; its output must
                expose ``last_hidden_state`` or be indexable with ``[0]``.
            num_labels: number of output logits.
            actual_hidden_size: width of the backbone's per-token hidden vectors.
        """
        super().__init__()
        half_width = actual_hidden_size // 2
        quarter_width = actual_hidden_size // 4
        self.backbone = backbone
        self.dropout = nn.Dropout(0.3)
        head_layers = [
            nn.Linear(actual_hidden_size, half_width),
            nn.LayerNorm(half_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(half_width, quarter_width),
            nn.LayerNorm(quarter_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(quarter_width, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape (batch, num_labels).

        ``attention_mask`` is accepted for call-site compatibility but is not
        used: the token states are averaged over every position.
        """
        backbone_out = self.backbone(input_ids)
        if hasattr(backbone_out, "last_hidden_state"):
            token_states = backbone_out.last_hidden_state
        else:
            token_states = backbone_out[0]
        sequence_mean = token_states.mean(dim=1)
        return self.classifier(self.dropout(sequence_mean))


# LOAD MODELS + LABELS

print(" Loading weights...")

# Binary (single-label) classifier: wrap its backbone, then load the saved weights.
single_model = BinaryDiseaseClassifier(backbone_single).to(device)
try:
    # First attempt restricts unpickling via weights_only=True.
    single_model.load_state_dict(torch.load("caduceus_binary_final.pth", map_location=device, weights_only=True))
except TypeError:
    # Presumably for torch versions whose torch.load has no weights_only keyword
    # — TODO confirm. Note this handler also catches a TypeError raised by
    # load_state_dict itself, because both calls sit inside the try.
    single_model.load_state_dict(torch.load("caduceus_binary_final.pth", map_location=device))
# Inference mode: disables the Dropout layers in the head.
single_model.eval()

# Multi-label checkpoint: a dict that carries the model configuration and the
# label mappings alongside the weights (keys read below).
# NOTE(review): weights_only is not passed here, so depending on the installed
# torch version this may fully unpickle the file — only load trusted checkpoints.
checkpoint = torch.load("dna_disease_classifier_final.pth", map_location=device)
num_labels = checkpoint["model_config"]["num_labels"]
actual_hidden_size = checkpoint["model_config"]["actual_hidden_size"]
label_mappings = checkpoint["label_mappings"]

def load_and_sync_labels(label_mappings, num_labels):
    """Build the list of display labels, one per model output, in output order.

    The returned list always has exactly ``num_labels`` entries and entry ``i``
    names model output ``i``. Position is never shifted: a label that is
    rejected by the clean-up step is replaced in place by a placeholder rather
    than removed, because removing it would move every later label onto the
    wrong output index.

    Args:
        label_mappings: one of
            * dict with ``"id2label"`` (id -> label); ids may be ints or
              numeric strings and are ordered numerically,
            * dict with ``"label2id"`` (label -> id),
            * dict with ``"top_labels"`` (sequence of labels already in order),
            * a plain list/tuple of labels.
            Anything else yields placeholders only.
        num_labels: number of outputs of the model's final layer.

    Returns:
        list[str] of length ``num_labels``. Labels have surrounding whitespace
        stripped and ``_``/``-`` replaced by spaces. Placeholders are named
        ``Unknown_<output index>``.
    """
    def _id_order(key):
        # Numeric ids sort by value ("2" before "10"); anything that is not a
        # number sorts after all numeric ids, alphabetically.
        try:
            return (0, int(key), "")
        except (TypeError, ValueError):
            return (1, 0, str(key))

    def extract_clean_labels(label_mappings):
        if isinstance(label_mappings, dict):
            id2label = label_mappings.get("id2label")
            label2id = label_mappings.get("label2id")
            if isinstance(id2label, dict):
                labels = [id2label[key] for key in sorted(id2label, key=_id_order)]
            elif isinstance(label2id, dict):
                inv = {v: k for k, v in label2id.items()}
                labels = [inv[i] for i in sorted(inv, key=_id_order)]
            elif "top_labels" in label_mappings:
                labels = list(label_mappings["top_labels"])
            else:
                print(" Could not interpret label_mappings structure; using default placeholders.")
                labels = []
        elif isinstance(label_mappings, (list, tuple)):
            labels = [str(x) for x in label_mappings]
        else:
            print(" Could not interpret label_mappings structure; using default placeholders.")
            labels = []

        # Clean up formatting. Labels of two characters or fewer are treated as
        # junk, but their slot is kept so later labels stay aligned with the
        # model's output indices.
        cleaned = []
        for position, lbl in enumerate(labels):
            text = str(lbl).strip().replace("_", " ").replace("-", " ")
            cleaned.append(text if len(text) > 2 else f"Unknown_{position}")
        return cleaned

    # Extract and synchronize
    disease_labels = extract_clean_labels(label_mappings)
    original_len = len(disease_labels)

    #  Synchronize with model output
    if original_len != num_labels:
        print(f" Label mismatch detected: model expects {num_labels}, but found {original_len}.")
        if original_len > num_labels:
            # Too many labels → trim extras
            trimmed_labels = disease_labels[num_labels:]
            disease_labels = disease_labels[:num_labels]
            print(f"🔧 Trimmed {len(trimmed_labels)} extra labels: {trimmed_labels}")
        else:
            # Too few labels → pad with placeholders named after the output
            # index they stand for (also keeps them distinct from in-place ones).
            missing = num_labels - original_len
            added_labels = [f"Unknown_{i}" for i in range(original_len, num_labels)]
            disease_labels.extend(added_labels)
            print(f"🔧 Added {missing} placeholder labels: {added_labels}")

    print(f" Final label count synchronized: {len(disease_labels)}")
    return disease_labels


# Display names for the multi-label outputs, index-aligned with the model's logits.
disease_labels = load_and_sync_labels(label_mappings, num_labels)


multi_model = DiseaseClassifier(backbone_multi, num_labels, actual_hidden_size).to(device)
# strict=False lets loading continue when checkpoint and model keys differ.
# Report any difference instead of hiding it: a missing key means that
# parameter keeps its freshly initialised value and predictions are unreliable.
_multi_load_report = multi_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
if _multi_load_report.missing_keys:
    print(f" WARNING: {len(_multi_load_report.missing_keys)} parameter(s) absent from the checkpoint "
          f"(left at initial values), e.g. {list(_multi_load_report.missing_keys)[:5]}")
if _multi_load_report.unexpected_keys:
    print(f" WARNING: {len(_multi_load_report.unexpected_keys)} checkpoint entr(ies) not used by the model, "
          f"e.g. {list(_multi_load_report.unexpected_keys)[:5]}")
# Inference mode: disables the Dropout layers in the head.
multi_model.eval()

print(f" Single model loaded")
print(f" Multi model loaded with {len(disease_labels)} labels")


# PREDICT HELPERS

def predict_single(sequence: str):
    """Classify one DNA sequence with the binary model.

    Only the first ``MAX_TOKEN_LEN`` characters are used; the tokenizer pads or
    truncates to exactly ``MAX_TOKEN_LEN`` tokens.

    Returns:
        dict with ``"result"`` ("POSITIVE" when the class-1 probability exceeds
        0.5, else "NEGATIVE") and ``"confidence"``, which is always the class-1
        probability formatted as a percentage — also for NEGATIVE results.
    """
    tokens = tokenizer(
        sequence[:MAX_TOKEN_LEN],
        truncation=True,
        padding='max_length',
        max_length=MAX_TOKEN_LEN,
        return_tensors='pt',
    )
    ids = tokens["input_ids"].to(device)
    # NOTE(review): token id 0 is treated as padding here — confirm this matches
    # the tokenizer's actual pad token id.
    mask = (ids != 0).long()
    with torch.no_grad():
        class_probs = torch.softmax(single_model(ids, mask), dim=1)
    positive_prob = class_probs[0, 1].item()
    verdict = "POSITIVE" if positive_prob > 0.5 else "NEGATIVE"
    return {"result": verdict, "confidence": f"{positive_prob * 100:.2f}%"}

def predict_multi(sequence: str, threshold=0.7, top_k=15):
    """Score one DNA sequence against every disease label.

    Args:
        sequence: DNA string; only the first ``MAX_TOKEN_LEN`` characters are used.
        threshold: minimum sigmoid probability for a label to count as detected.
        top_k: number of highest-probability labels to return.

    Returns:
        dict with ``"top_results"`` (list of ``(label, probability)`` pairs,
        highest probability first) and ``"detected"`` (labels whose probability
        is at least ``threshold``, in label-index order).
    """
    tokens = tokenizer(
        sequence[:MAX_TOKEN_LEN],
        truncation=True,
        padding="max_length",
        max_length=MAX_TOKEN_LEN,
        return_tensors="pt",
    )
    ids = tokens["input_ids"].to(device)
    # The multi-label model accepts the mask but pools over every position.
    mask = (ids != 0).long()
    with torch.no_grad():
        label_probs = torch.sigmoid(multi_model(ids, mask))[0].cpu().numpy()

    ranking = np.argsort(label_probs)[::-1][:top_k]
    top_results = []
    for label_idx in ranking:
        top_results.append((disease_labels[label_idx], float(label_probs[label_idx])))

    detected = []
    for position in range(len(label_probs)):
        if label_probs[position] >= threshold:
            detected.append(disease_labels[position])
    return {"top_results": top_results, "detected": detected}

# SHAP HELPERS

def run_shap_single(sequence: str):
    """Attribute the binary model's disease probability to single nucleotides and plot it.

    The attribution is an in-silico mutagenesis scan, not a sampled Shapley
    estimate: every A/T/C/G position is replaced by each of the three other
    bases and the position's score is the mean of (baseline class-1 probability
    minus mutant class-1 probability). Positions holding any other character
    keep a score of 0. The scores are wrapped in ``shap.Explanation`` objects so
    the shap plotting functions can render them.

    Plot files use fixed names inside ``SINGLE_SHAP_DIR``, so each call
    overwrites the images written by the previous call.

    Args:
        sequence: DNA string; only the first ``MAX_TOKEN_LEN`` characters are analysed.

    Returns:
        dict with ``"plots"`` (paths of the PNG files, in creation order),
        ``"baseline_probability"`` (float) and ``"summary"`` (text report).

    Raises:
        ValueError: if ``sequence`` is empty (previously this surfaced late as a
            division by zero while sizing the plots).
    """
    if not sequence:
        raise ValueError("sequence must be a non-empty DNA string")

    os.makedirs(SINGLE_SHAP_DIR, exist_ok=True)
    print("\n" + "="*70)
    print("COMPUTING SHAP VALUES (Single-label)")
    print("="*70)

    class GenomicShapExplainer:
        """Single-base substitution scan against the model's class-1 probability."""

        def __init__(self, model, tokenizer, device, max_length=512):
            self.model = model
            self.tokenizer = tokenizer
            self.device = device
            self.max_length = max_length
            self.nucleotides = ['A', 'T', 'C', 'G']

        def predict_proba(self, sequence):
            """Return the softmax class probabilities (1-D array) for one sequence."""
            encoding = self.tokenizer(
                sequence[:self.max_length],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            input_ids = encoding['input_ids'].to(self.device)
            attention_mask = (input_ids != 0).long()
            with torch.no_grad():
                logits = self.model(input_ids, attention_mask)
                probs = torch.softmax(logits, dim=1)
            return probs.cpu().numpy()[0]

        def compute_shap_values(self, sequence):
            """Return (per-position scores, analysed sequence, baseline class-1 probability)."""
            seq = sequence[:self.max_length]
            seq_length = min(len(seq), self.max_length)
            baseline_probs = self.predict_proba(seq)
            baseline_disease_prob = baseline_probs[1]
            shap_values = np.zeros(seq_length)
            print(f" Computing SHAP values for {seq_length} nucleotide positions...")
            for i in tqdm(range(seq_length)):
                original_nt = seq[i]
                if original_nt not in self.nucleotides:
                    # Ambiguous or unknown base: leave its score at 0.
                    continue
                mutation_effects = []
                for mutant_nt in self.nucleotides:
                    if mutant_nt == original_nt:
                        continue
                    mutated_seq = seq[:i] + mutant_nt + seq[i+1:]
                    mutated_probs = self.predict_proba(mutated_seq)
                    mutated_disease_prob = mutated_probs[1]
                    # Positive effect: the original base raises the disease probability.
                    effect = baseline_disease_prob - mutated_disease_prob
                    mutation_effects.append(effect)
                if mutation_effects:
                    shap_values[i] = np.mean(mutation_effects)
            return shap_values, seq, baseline_disease_prob

    explainer = GenomicShapExplainer(single_model, tokenizer, device, max_length=MAX_TOKEN_LEN)
    shap_values, analyzed_seq, baseline_prob = explainer.compute_shap_values(sequence)

    feature_names = [f"Pos{i}_{nt}" for i, nt in enumerate(analyzed_seq)]
    seq_length = len(analyzed_seq)

    print("\n Creating SHAP Explanation object...")
    # 1-D explanation: used by the single-instance plots (waterfall, force).
    shap_explanation = shap.Explanation(
        values=shap_values,
        base_values=baseline_prob,
        data=np.array(list(analyzed_seq)),
        feature_names=np.array(feature_names)
    )
    # Same data with a leading sample axis: used by the bar and beeswarm plots.
    shap_explanation_reshaped = shap.Explanation(
        values=shap_values.reshape(1, -1),
        base_values=np.array([baseline_prob]),
        data=np.array([list(analyzed_seq)]),
        feature_names=np.array(feature_names)
    )

    paths = []

    def _banner(title):
        # Console separator printed before each plot.
        print("\n" + "="*70); print(title); print("="*70)

    def _save_and_close(filename):
        # Write the current figure to SINGLE_SHAP_DIR and record its path.
        # close('all') rather than close(): a shap plot may open its own figure,
        # which would leave the one created beforehand open. In a long-running
        # server every such leftover figure stays in memory.
        plt.tight_layout()
        out = f"{SINGLE_SHAP_DIR}/{filename}"
        plt.savefig(out, dpi=300, bbox_inches='tight')
        plt.close('all')
        paths.append(out)

    # PLOT 2: WATERFALL
    _banner(" PLOT 2: WATERFALL")
    max_display_waterfall = min(20, seq_length)
    plt.figure(figsize=(10, max(8, max_display_waterfall * 0.4)))
    shap.plots.waterfall(shap_explanation, max_display=max_display_waterfall, show=False)
    _save_and_close("1_waterfall_plot.png")

    # PLOT 3: FORCE
    _banner(" PLOT 3: FORCE")
    fig_width = max(12, seq_length * 0.15)
    plt.figure(figsize=(fig_width, 3))
    shap.plots.force(shap_explanation, matplotlib=True, show=False)
    _save_and_close("2_force_plot.png")

    # PLOT 4: DECISION
    _banner(" PLOT 4: DECISION")
    fig_height = max(8, seq_length * 0.08)
    plt.figure(figsize=(10, fig_height))
    shap.decision_plot(
        base_value=baseline_prob,
        shap_values=shap_values,
        features=np.array(list(analyzed_seq)),
        feature_names=feature_names,
        highlight=0,
        show=False
    )
    _save_and_close("3_decision_plot.png")

    # PLOT 5: BAR
    _banner(" PLOT 5: BAR")
    max_display_bar = min(30, seq_length)
    plt.figure(figsize=(10, max(6, max_display_bar * 0.3)))
    shap.plots.bar(shap_explanation_reshaped, max_display=max_display_bar, show=False)
    _save_and_close("4_bar_plot.png")

    # PLOT 6: BEESWARM
    _banner(" PLOT 6: BEESWARM")
    max_display_beeswarm = min(30, seq_length)
    plt.figure(figsize=(10, max(6, max_display_beeswarm * 0.3)))
    shap.plots.beeswarm(shap_explanation_reshaped, max_display=max_display_beeswarm, show=False)
    _save_and_close("5_beeswarm_plot.png")

    # PLOT 7: HEATMAP
    _banner("PLOT 7: HEATMAP")
    shap_matrix = shap_values.reshape(1, -1)
    plt.figure(figsize=(max(15, seq_length * 0.2), 3))
    sns.heatmap(
        shap_matrix, cmap='RdBu_r', center=0,
        xticklabels=[f"{i}:{nt}" for i, nt in enumerate(analyzed_seq)],
        yticklabels=['SHAP'], cbar_kws={'label': 'SHAP Value (Disease Risk)'},
        linewidths=0.5, linecolor='gray'
    )
    plt.xticks(rotation=90, fontsize=8); plt.yticks(fontsize=10)
    plt.xlabel('Position:Nucleotide'); plt.title('SHAP Heatmap: Position-wise Contribution')
    _save_and_close("6_heatmap.png")

    # PLOT 8: GENOMIC VISUALIZATION
    _banner(" PLOT 8: GENOMIC VISUALIZATION")
    nt_colors = {'A': '#00CC00', 'T': '#FF0000', 'G': '#FFB300', 'C': '#0000FF', 'N': '#808080'}
    fig_width_genomic = max(15, seq_length * 0.2)
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width_genomic, 8),
                                   gridspec_kw={'height_ratios': [1, 2]})
    base_fontsize = max(6, min(12, 800 // seq_length))
    # Top panel: the sequence itself, letters scaled and darkened by |score|.
    for i, nt in enumerate(analyzed_seq):
        color = nt_colors.get(nt, '#808080')
        importance = np.abs(shap_values[i])
        fontsize = base_fontsize + min(importance * 50, base_fontsize * 0.5)
        alpha = 0.5 + min(importance * 3, 0.5)
        ax1.text(i, 0, nt, fontsize=fontsize, ha='center', va='center',
                 color=color, fontweight='bold', alpha=alpha)
    ax1.set_xlim(-1, seq_length); ax1.set_ylim(-0.5, 0.5); ax1.axis('off')
    ax1.set_title('DNA Sequence (size ∝ importance)', fontsize=14, fontweight='bold', pad=10)
    # Bottom panel: signed score per position.
    colors = ['#FF6B6B' if v > 0 else '#4ECDC4' for v in shap_values]
    ax2.bar(range(seq_length), shap_values, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)
    ax2.axhline(y=0, color='black', linestyle='-', linewidth=1.5)
    ax2.set_xlabel('Nucleotide Position', fontsize=12, fontweight='bold')
    ax2.set_ylabel('SHAP Value\n(Disease Risk)', fontsize=12, fontweight='bold')
    ax2.set_xlim(-1, seq_length); ax2.grid(axis='y', alpha=0.3, linestyle='--')
    tick_interval = max(5, seq_length // 20)
    ax2.set_xticks(range(0, seq_length, tick_interval))
    ax2.set_xticklabels(range(0, seq_length, tick_interval))
    ax2.set_title('Red = Increases Risk | Teal = Decreases Risk', fontsize=12, style='italic', pad=10)
    _save_and_close("7_genomic_visualization.png")

    # SUMMARY (print + JSON)
    print("\n" + "="*70)
    print(" SHAP ANALYSIS SUMMARY (Single-label)")
    print("="*70)
    # Scores whose magnitude is at most this value are counted as neutral.
    threshold = 0.001
    n_positive = int(np.sum(shap_values > threshold))
    n_negative = int(np.sum(shap_values < -threshold))
    n_neutral = int(seq_length - n_positive - n_negative)
    top_pos_idx = int(np.argmax(shap_values))
    top_neg_idx = int(np.argmin(shap_values))
    summary_text = f"""
Model Prediction:
  • Disease Probability: {baseline_prob:.2%}
  • Classification: {'Hereditary cancer predisposing syndrome' if baseline_prob > 0.5 else 'HEALTHY No Disease'}
  • Base Value: {baseline_prob:.6f}

Sequence Statistics:
  • Total Length: {seq_length} bp
  • A: {analyzed_seq.count('A')}  T: {analyzed_seq.count('T')}  G: {analyzed_seq.count('G')}  C: {analyzed_seq.count('C')}

SHAP Statistics:
  • Max SHAP: {np.max(shap_values):.6f} at Pos {top_pos_idx} ({analyzed_seq[top_pos_idx]})
  • Min SHAP: {np.min(shap_values):.6f} at Pos {top_neg_idx} ({analyzed_seq[top_neg_idx]})

Feature Contributions (|v|>{threshold}):
  • Risk-Increasing: {n_positive} ({n_positive/seq_length*100:.1f}%)
  • Risk-Decreasing: {n_negative} ({n_negative/seq_length*100:.1f}%)
  • Neutral: {n_neutral} ({n_neutral/seq_length*100:.1f}%)
"""
    print(summary_text)
    return {
        "plots": paths,
        "baseline_probability": float(baseline_prob),
        "summary": summary_text
    }


# FULL SHAP (MULTI)

def run_shap_multi(sequence: str, prob_threshold=0.5):
    """Attribute multi-label disease probabilities to single nucleotides and plot them.

    Labels are selected first: every label whose sigmoid probability exceeds
    ``prob_threshold``, or the three most probable labels when none does. For
    the selected labels an in-silico mutagenesis scan is run (not a sampled
    Shapley estimate): each A/T/C/G position is replaced by each of the three
    other bases and the position's score for a label is the mean of (baseline
    probability minus mutant probability). Other characters keep a score of 0.

    Plot files are written to ``MULTI_SHAP_DIR``; names depend only on rank and
    label, so a later call can overwrite files from an earlier one.

    Args:
        sequence: DNA string; only the first ``MAX_TOKEN_LEN`` characters are analysed.
        prob_threshold: probability above which a label is selected.

    Returns:
        dict with ``"plots"`` (PNG paths in creation order), ``"summary"``
        (text report), ``"analyzed_diseases"`` (list of
        ``{"index", "name", "prob"}`` for the plotted labels) and
        ``"baseline_probs"`` (probability of every label as floats).

    Raises:
        ValueError: if ``sequence`` is empty.
    """
    if not sequence:
        raise ValueError("sequence must be a non-empty DNA string")

    os.makedirs(MULTI_SHAP_DIR, exist_ok=True)
    print("\n" + "="*80)
    print("MULTI-LABEL SHAP ANALYSIS FOR GENOMIC DISEASE PREDICTION")
    print("="*80)

    class MultiLabelGenomicShapExplainer:
        """Single-base substitution scan against per-label sigmoid probabilities."""

        def __init__(self, model, tokenizer, device, disease_labels, max_length=512):
            self.model = model
            self.tokenizer = tokenizer
            self.device = device
            self.disease_labels = disease_labels
            self.num_labels = len(disease_labels)
            self.max_length = max_length
            self.nucleotides = ['A', 'T', 'C', 'G']

        def predict_proba(self, sequence):
            """Return the sigmoid probability of every label (1-D array) for one sequence."""
            encoding = self.tokenizer(
                sequence[:self.max_length],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            input_ids = encoding['input_ids'].to(self.device)
            attention_mask = (input_ids != 0).long()
            with torch.no_grad():
                # Use the model handed to the constructor, not the module-level global.
                logits = self.model(input_ids, attention_mask)
                probs = torch.sigmoid(logits).cpu().numpy()[0]
            return probs

        def compute_shap_values_multilabel(self, sequence, label_indices=None):
            """Return (dict label index -> per-position scores, analysed sequence, baseline probs).

            One forward pass yields the probabilities of all labels, so the
            sequence is scanned once and every requested label is filled from
            the same mutant predictions.
            """
            seq = sequence[:self.max_length]
            seq_length = min(len(seq), self.max_length)
            baseline_probs = self.predict_proba(seq)
            if label_indices is None:
                label_indices = [i for i in range(self.num_labels) if baseline_probs[i] > 0.0]
            print(f"\n Computing SHAP values for {len(label_indices)} disease labels...")
            print(f"   Sequence length: {seq_length} nucleotides")

            shap_values_dict = {label_idx: np.zeros(seq_length) for label_idx in label_indices}
            if not shap_values_dict:
                # Nothing requested: skip the scan entirely.
                return shap_values_dict, seq, baseline_probs

            for label_idx in shap_values_dict:
                print(f"\n Processing: {self.disease_labels[label_idx]} (prob: {baseline_probs[label_idx]:.4f})")

            for i in tqdm(range(seq_length), desc="  Computing SHAP", leave=False):
                original_nt = seq[i]
                if original_nt not in self.nucleotides:
                    # Ambiguous or unknown base: leave its scores at 0.
                    continue
                mutant_probs = [
                    self.predict_proba(seq[:i] + mutant_nt + seq[i+1:])
                    for mutant_nt in self.nucleotides
                    if mutant_nt != original_nt
                ]
                for label_idx, label_scores in shap_values_dict.items():
                    # Positive effect: the original base raises this label's probability.
                    mutation_effects = [
                        baseline_probs[label_idx] - probs[label_idx] for probs in mutant_probs
                    ]
                    label_scores[i] = np.mean(mutation_effects)

            for label_idx, label_scores in shap_values_dict.items():
                print(f"   Completed: Mean |SHAP| = {np.mean(np.abs(label_scores)):.6f}")
            return shap_values_dict, seq, baseline_probs

    def create_multilabel_visualizations(explainer, shap_values_dict, analyzed_seq,
                                         baseline_probs, prob_threshold=0.5):
        """Render every plot for the labels that have scores.

        Returns (list of written PNG paths, list of (index, name, probability)
        for the plotted labels, most probable first).
        """
        seq_length = len(analyzed_seq)
        feature_names = [f"Pos{i}_{nt}" for i, nt in enumerate(analyzed_seq)]
        sorted_indices = np.argsort(baseline_probs)[::-1]
        scored = [idx for idx in sorted_indices if idx in shap_values_dict]
        above_threshold = [idx for idx in scored if baseline_probs[idx] > prob_threshold]
        # When no scored label clears the threshold (the caller's top-3
        # fallback), plot the scored labels anyway instead of discarding them.
        selected = above_threshold if above_threshold else scored
        analyzed_diseases = [(idx, explainer.disease_labels[idx], baseline_probs[idx])
                             for idx in selected]
        num_diseases = len(analyzed_diseases)
        print("\n" + "="*80)
        print(f" Creating visualizations for {num_diseases} disease(s)")
        print("="*80)
        paths = []

        if num_diseases == 0:
            print("  No diseases to visualize!")
            return paths, analyzed_diseases

        def _save_and_close(out):
            # Write the current figure and record its path. close('all') rather
            # than close(): a shap plot may open its own figure, which would
            # leave the one created beforehand open and held in memory.
            plt.tight_layout()
            plt.savefig(out, dpi=300, bbox_inches='tight')
            plt.close('all')
            paths.append(out)

        # Shared across several plots.
        position_ticks = range(0, seq_length, 5)
        shap_matrix = np.array([shap_values_dict[idx] for idx, _, _ in analyzed_diseases])
        disease_names = [f"{name[:40]}" for _, name, _ in analyzed_diseases]
        aggregated_shap = np.zeros(seq_length)
        for idx, _, _ in analyzed_diseases:
            aggregated_shap += shap_values_dict[idx]

        # PLOT 1: Multi-Label Heatmap
        print("\n PLOT 1: Multi-Label SHAP Heatmap")
        fig, ax = plt.subplots(figsize=(20, max(4, num_diseases * 1.2)))
        # Symmetric colour limits so that 0 maps to the middle of the colormap.
        vmax = np.max(np.abs(shap_matrix))
        im = ax.imshow(shap_matrix, cmap='RdBu_r', aspect='auto', vmin=-vmax, vmax=vmax)
        ax.set_xticks(np.arange(0, seq_length, 5))
        ax.set_xticklabels([f"{i}" for i in position_ticks], fontsize=8)
        ax.set_yticks(np.arange(num_diseases))
        ax.set_yticklabels(disease_names, fontsize=10)
        cbar = plt.colorbar(im, ax=ax, pad=0.02)
        cbar.set_label('SHAP Value', fontsize=11, fontweight='bold')
        ax.set_xlabel('Nucleotide Position'); ax.set_ylabel('Disease')
        _save_and_close(f"{MULTI_SHAP_DIR}/1_multilabel_heatmap.png")

        # PLOT 2: Aggregated SHAP
        print("\n PLOT 2: Aggregated SHAP Values")
        fig, ax = plt.subplots(figsize=(20, 5))
        colors = ['#FF6B6B' if v > 0 else '#4ECDC4' for v in aggregated_shap]
        ax.bar(range(seq_length), aggregated_shap, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)
        ax.axhline(y=0, color='black', linestyle='-', linewidth=1.5)
        ax.set_xlabel('Nucleotide Position'); ax.set_ylabel('Aggregated SHAP Value')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.set_xticks(position_ticks); ax.set_xticklabels(position_ticks, fontsize=8)
        _save_and_close(f"{MULTI_SHAP_DIR}/2_aggregated_shap.png")

        # PLOT 3: Individual Waterfalls
        print(f"\n PLOT 3: Individual Waterfall Plots ({num_diseases} disease(s))")
        for rank, (idx, name, prob) in enumerate(analyzed_diseases, 1):
            shap_explanation = shap.Explanation(
                values=shap_values_dict[idx],
                base_values=prob,
                data=np.array(list(analyzed_seq)),
                feature_names=np.array(feature_names)
            )
            plt.figure(figsize=(10, 8))
            shap.plots.waterfall(shap_explanation, max_display=20, show=False)
            ax = plt.gca()
            ax.set_title(f"{name} (p={prob:.4f})", fontsize=11, fontweight='bold', pad=10)
            safe_name = name[:30].replace(" ", "_").replace("/", "-")
            _save_and_close(f'{MULTI_SHAP_DIR}/3_{rank}_waterfall_{safe_name}.png')

        # PLOT 4: Disease Comparison
        if num_diseases > 1:
            print("\n PLOT 4: Disease Comparison")
            avg_abs_shap = [np.mean(np.abs(shap_values_dict[idx])) for idx, _, _ in analyzed_diseases]
            disease_names_short = [name[:35] for _, name, _ in analyzed_diseases]
            fig, ax = plt.subplots(figsize=(10, max(4, num_diseases * 0.8)))
            colors_bar = plt.cm.viridis(np.linspace(0.3, 0.9, num_diseases))
            ax.barh(disease_names_short, avg_abs_shap, color=colors_bar, edgecolor='black', linewidth=0.8)
            ax.set_xlabel('Average |SHAP Value|'); ax.set_ylabel('Disease'); ax.grid(axis='x', alpha=0.3, linestyle='--')
            for i, val in enumerate(avg_abs_shap):
                ax.text(val, i, f' {val:.6f}', va='center', fontsize=9)
            _save_and_close(f"{MULTI_SHAP_DIR}/4_disease_comparison.png")
        else:
            print("\n  PLOT 4: Skipped (only 1 disease)")

        # PLOT 5: Genomic Multi-Label Overlay
        print("\n PLOT 5: Genomic Multi-Label Overlay")
        nt_colors = {'A': '#00CC00', 'T': '#FF0000', 'G': '#FFB300', 'C': '#0000FF', 'N': '#808080'}
        fig_height = max(6, 2.5 + num_diseases * 1.3)
        # One sequence track on top, then one score track per label.
        fig, axes = plt.subplots(num_diseases + 1, 1, figsize=(20, fig_height),
                                 gridspec_kw={'height_ratios': [0.8] + [1]*num_diseases})
        ax_seq = axes[0]; base_fontsize = 9
        for i, nt in enumerate(analyzed_seq):
            color = nt_colors.get(nt, '#808080')
            importance = np.abs(aggregated_shap[i])
            fontsize = base_fontsize + min(importance * 25, 5)
            alpha = 0.6 + min(importance * 2, 0.4)
            ax_seq.text(i, 0, nt, fontsize=fontsize, ha='center', va='center',
                        color=color, fontweight='bold', alpha=alpha)
        ax_seq.set_xlim(-1, seq_length); ax_seq.set_ylim(-0.5, 0.5); ax_seq.axis('off')
        for ax_idx, (label_idx, name, prob) in enumerate(analyzed_diseases):
            ax = axes[ax_idx + 1]
            shap_vals = shap_values_dict[label_idx]
            colors_plot = ['#FF6B6B' if v > 0 else '#4ECDC4' for v in shap_vals]
            ax.bar(range(seq_length), shap_vals, color=colors_plot, alpha=0.8, edgecolor='black', linewidth=0.3)
            ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
            ax.set_ylabel('SHAP', fontsize=9, fontweight='bold')
            ax.text(0.01, 0.95, f'{name[:45]} (p={prob:.4f})', transform=ax.transAxes, fontsize=9, va='top', fontweight='bold')
            ax.grid(axis='y', alpha=0.2, linestyle='--'); ax.set_xlim(-1, seq_length)
            if ax_idx < num_diseases - 1:
                ax.set_xticks([])
            else:
                # Only the bottom track carries the position axis.
                ax.set_xticks(position_ticks)
                ax.set_xticklabels(position_ticks, fontsize=8)
                ax.set_xlabel('Nucleotide Position', fontsize=10, fontweight='bold')
        _save_and_close(f"{MULTI_SHAP_DIR}/5_genomic_multilabel.png")

        # PLOT 6: Contribution Matrix
        print("\n PLOT 6: Contribution Matrix")
        contribution_matrix = np.abs(shap_matrix)
        fig, ax = plt.subplots(figsize=(20, max(4, num_diseases * 1.2)))
        im = ax.imshow(contribution_matrix, cmap='YlOrRd', aspect='auto')
        ax.set_xticks(np.arange(0, seq_length, 5))
        ax.set_xticklabels([f"{i}" for i in position_ticks], fontsize=8)
        ax.set_yticks(np.arange(num_diseases))
        ax.set_yticklabels(disease_names, fontsize=10)
        cbar = plt.colorbar(im, ax=ax, pad=0.02)
        cbar.set_label('|SHAP Value|', fontsize=11, fontweight='bold')
        ax.set_xlabel('Nucleotide Position'); ax.set_ylabel('Disease')
        _save_and_close(f"{MULTI_SHAP_DIR}/6_contribution_matrix.png")

        # PLOT 7: Individual Force Plots
        print(f"\n PLOT 7: Individual Force Plots ({num_diseases} disease(s))")
        for rank, (idx, name, prob) in enumerate(analyzed_diseases, 1):
            shap_explanation = shap.Explanation(
                values=shap_values_dict[idx],
                base_values=prob,
                data=np.array(list(analyzed_seq)),
                feature_names=np.array(feature_names)
            )
            plt.figure(figsize=(20, 3))
            shap.plots.force(shap_explanation, matplotlib=True, show=False)
            ax = plt.gca()
            ax.set_title(f"{name} (base={prob:.4f})", fontsize=10, fontweight='bold', pad=8)
            safe_name = name[:30].replace(" ", "_").replace("/", "-")
            _save_and_close(f'{MULTI_SHAP_DIR}/7_{rank}_force_{safe_name}.png')

        return paths, analyzed_diseases

    # disease selection
    explainer = MultiLabelGenomicShapExplainer(
        model=multi_model, tokenizer=tokenizer, device=device,
        disease_labels=disease_labels, max_length=MAX_TOKEN_LEN
    )
    baseline_probs_initial = explainer.predict_proba(sequence)
    sorted_indices = np.argsort(baseline_probs_initial)[::-1]
    high_prob_indices = [idx for idx in range(len(baseline_probs_initial))
                         if baseline_probs_initial[idx] > prob_threshold]
    print(f"\n Disease Selection (Threshold = {prob_threshold}):")
    if len(high_prob_indices) > 0:
        print(f" Found {len(high_prob_indices)} disease(s) with probability > {prob_threshold}:")
        for rank, idx in enumerate(sorted(high_prob_indices, key=lambda x: baseline_probs_initial[x], reverse=True), 1):
            print(f"  {rank}. {disease_labels[idx][:65]}: {baseline_probs_initial[idx]:.6f}")
    else:
        print(f"  No diseases > {prob_threshold}. Falling back to top 3.")
        high_prob_indices = sorted_indices[:3].tolist()
        for rank, idx in enumerate(high_prob_indices, 1):
            print(f"  {rank}. {disease_labels[idx][:65]}: {baseline_probs_initial[idx]:.6f}")

    shap_values_dict, analyzed_seq, baseline_probs = explainer.compute_shap_values_multilabel(
        sequence=sequence,
        label_indices=high_prob_indices
    )
    plot_paths, analyzed_diseases = create_multilabel_visualizations(
        explainer=explainer,
        shap_values_dict=shap_values_dict,
        analyzed_seq=analyzed_seq,
        baseline_probs=baseline_probs,
        prob_threshold=prob_threshold
    )

    # SUMMARY: built once, then both logged and returned.
    sorted_indices_final = np.argsort(baseline_probs)[::-1]
    summary_lines = [
        " MULTI-LABEL SHAP ANALYSIS SUMMARY",
        "=" * 80,
        "\nSequence Statistics:",
        f"  • Length: {len(analyzed_seq)} bp",
        f"  • A: {analyzed_seq.count('A')}  T: {analyzed_seq.count('T')}  G: {analyzed_seq.count('G')}  C: {analyzed_seq.count('C')}\n",
        "Disease Predictions:",
        f"  • Total Labels: {len(disease_labels)}",
        f"  • Labels Analyzed (SHAP computed): {len(shap_values_dict)}",
        f"  • Labels with prob > 0.5: {int(sum(baseline_probs > 0.5))}",
        f"  • Labels with prob > 0.1: {int(sum(baseline_probs > 0.1))}\n",
        "Top 15 Disease Predictions:",
        f"{'Rank':<5} {'Disease':<58} {'Probability':>12} {'Avg |SHAP|':>12}",
        "-" * 90,
    ]
    for rank, idx in enumerate(sorted_indices_final[:15], 1):
        name = disease_labels[idx]
        prob = baseline_probs[idx]
        has_scores = idx in shap_values_dict
        avg_shap = float(np.mean(np.abs(shap_values_dict[idx]))) if has_scores else 0.0
        # "*" marks the labels for which scores were computed.
        marker = "*" if has_scores else "  "
        summary_lines.append(f"{marker} {rank:<3} {name:<58.58} {prob:>12.6f} {avg_shap:>12.6f}")
    summary_text = "\n".join(summary_lines)

    print("\n" + "="*80)
    print(summary_text)

    return {
        "plots": plot_paths,
        "summary": summary_text,
        "analyzed_diseases": [
            {"index": int(idx), "name": name, "prob": float(prob)}
            for idx, name, prob in analyzed_diseases
        ],
        "baseline_probs": [float(x) for x in baseline_probs]
    }

# FLASK APP

# Single Flask application object; every route below is registered on it.
app = Flask(__name__)
# CORS is applied to the whole app with no restricting arguments.
# NOTE(review): presumably this permits requests from any origin — confirm that
# is intended, since the server is exposed through a public tunnel.
CORS(app)

# Serve SHAP image files from Colab absolute paths
@app.route("/shap_outputs/<path:filename>")
def serve_single_shap(filename):
    """Serve one file from the single-label SHAP output directory.

    Args:
        filename: Path (relative to /content/shap_outputs) taken from the URL.

    Returns:
        Whatever ``send_from_directory`` produces for that directory/filename
        pair. A diagnostic line is printed first when the file is not on disk;
        the request is still handed to ``send_from_directory`` in that case.
    """
    directory = "/content/shap_outputs"
    candidate = os.path.join(directory, filename)
    is_missing = not os.path.exists(candidate)
    if is_missing:
        # Log only — the actual response is still decided by send_from_directory.
        print(f" File not found (single): {candidate}")
    return send_from_directory(directory, filename)

@app.route("/shap_outputs_multilabel/<path:filename>")
def serve_multi_shap(filename):
    """Serve one file from the multi-label SHAP output directory.

    Args:
        filename: Path (relative to /content/shap_outputs_multilabel) taken
            from the URL.

    Returns:
        Whatever ``send_from_directory`` produces for that directory/filename
        pair. A diagnostic line is printed first when the file is not on disk;
        the request is still handed to ``send_from_directory`` in that case.
    """
    directory = "/content/shap_outputs_multilabel"
    candidate = os.path.join(directory, filename)
    is_missing = not os.path.exists(candidate)
    if is_missing:
        # Log only — the actual response is still decided by send_from_directory.
        print(f" File not found (multi): {candidate}")
    return send_from_directory(directory, filename)

@app.route("/health")
def health():
    """Liveness probe: report that the server is up and which device is in use.

    Returns:
        JSON object ``{"status": "running", "device": <device as text>}``.
    """
    # `device` is defined elsewhere in the notebook. str() is a no-op when it is
    # already a plain string and keeps the payload JSON-serializable if it is a
    # torch.device object instead (NOTE(review): confirm which one it is).
    return jsonify({"status": "running", "device": str(device)})

@app.route("/predict", methods=["POST"])
def predict_api():
    """Run a disease prediction for a DNA sequence posted as JSON.

    Expected JSON object body:
        sequence (str, required): the sequence; stripped and upper-cased here.
        mode (str, optional, default "single"): "single" calls
            ``predict_single``; any other string calls ``predict_multi``.

    Returns:
        The JSON-encoded result of the selected predictor, or a JSON error
        object with HTTP 400 when the body is not a JSON object, the sequence
        is missing/empty/not a string, or the mode is not a string.
    """
    # force=True parses the body regardless of Content-Type; silent=True turns a
    # malformed body into None so the client gets the JSON 400 below instead of
    # an HTML error page.
    data = request.get_json(force=True, silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    raw_sequence = data.get("sequence", "")
    raw_mode = data.get("mode", "single")

    # A JSON null or non-string value would otherwise crash on .strip()/.lower()
    # and surface as a 500.
    if not isinstance(raw_sequence, str) or not raw_sequence.strip():
        return jsonify({"error": "DNA sequence required"}), 400
    if not isinstance(raw_mode, str):
        return jsonify({"error": "mode must be a string"}), 400

    sequence = raw_sequence.strip().upper()
    mode = raw_mode.lower()

    if mode == "single":
        res = predict_single(sequence)
    else:
        res = predict_multi(sequence)
    return jsonify(res)

# Open the ngrok tunnel to local port 5000 (the port app.run() binds further down)
# and keep its public URL in a module-level global. shap_api reads `public_url`
# each time a request is handled, so the tunnel has to exist before the server
# starts serving requests — not merely before shap_api is defined.
public_url = ngrok.connect(5000).public_url
print("🔗 Public API endpoint:", public_url)

@app.route("/shap", methods=["POST"])
def shap_api():
    """Run a SHAP analysis for a posted DNA sequence and return plot URLs.

    Expected JSON object body:
        sequence (str, required): the sequence; stripped and upper-cased here.
        mode (str, optional, default "single"): "single" calls
            ``run_shap_single``; any other string calls ``run_shap_multi``.
        prob_threshold (number, optional, default 0.5): only read in the
            non-single branch and forwarded to ``run_shap_multi``.

    Returns:
        JSON object with ``message``, ``summary`` (from the analysis payload)
        and ``plots`` (public URLs for every reported plot file that exists on
        disk). Responds with a JSON error and HTTP 400 when the body is not a
        JSON object, the sequence is missing/empty/not a string, the mode is
        not a string, or prob_threshold cannot be converted to float.
    """
    # force=True parses the body regardless of Content-Type; silent=True turns a
    # malformed body into None so the client gets the JSON 400 below instead of
    # an HTML error page.
    data = request.get_json(force=True, silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    raw_sequence = data.get("sequence", "")
    raw_mode = data.get("mode", "single")

    # A JSON null or non-string value would otherwise crash on .strip()/.lower()
    # and surface as a 500.
    if not isinstance(raw_sequence, str) or not raw_sequence.strip():
        return jsonify({"error": "DNA sequence required"}), 400
    if not isinstance(raw_mode, str):
        return jsonify({"error": "mode must be a string"}), 400

    sequence = raw_sequence.strip().upper()
    mode = raw_mode.lower()

    if mode == "single":
        payload = run_shap_single(sequence)
        folder_route = "shap_outputs"
    else:
        # Reject an unparseable threshold before starting the analysis.
        try:
            prob_threshold = float(data.get("prob_threshold", 0.5))
        except (TypeError, ValueError):
            return jsonify({"error": "prob_threshold must be a number"}), 400
        payload = run_shap_multi(sequence, prob_threshold=prob_threshold)
        folder_route = "shap_outputs_multilabel"

    # Build links on the ngrok public URL (not localhost) so a remote frontend
    # can fetch the images through the matching /<folder_route>/<file> route.
    base_url = public_url.rstrip("/")
    plots = []
    for p in payload.get("plots", []):
        if os.path.exists(p):
            plots.append(f"{base_url}/{folder_route}/{os.path.basename(p)}")
        else:
            print(f" Plot not found on disk: {p}")

    print(f" Returning {len(plots)} plot URLs to frontend")

    return jsonify({
        "message": f"SHAP ({mode}) completed",
        "summary": payload.get("summary", ""),
        "plots": plots
    })


# START SERVER

# Start Flask's built-in server on local port 5000 — the same port the ngrok
# tunnel above forwards to. This call blocks the cell while the server is running.
app.run(port=5000)


Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenization_caduceus.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16:
- tokenization_caduceus.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


special_tokens_map.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

configuration_caduceus.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16:
- configuration_caduceus.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_caduceus.py: 0.00B [00:00, ?B/s]

  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


modeling_rcps.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16:
- modeling_rcps.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16:
- modeling_caduceus.py
- modeling_rcps.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/30.9M [00:00<?, ?B/s]

 Loading weights...


  checkpoint = torch.load("dna_disease_classifier_final.pth", map_location=device)


 Label mismatch detected: model expects 100, but found 95.
🔧 Added 5 placeholder labels: ['Unknown_0', 'Unknown_1', 'Unknown_2', 'Unknown_3', 'Unknown_4']
 Final label count synchronized: 100
 Single model loaded
 Multi model loaded with 100 labels
🔗 Public API endpoint: https://gema-thirstless-insincerely.ngrok-free.dev
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:25:53] "OPTIONS /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:26:01] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:26:05] "OPTIONS /shap HTTP/1.1" 200 -



COMPUTING SHAP VALUES (Single-label)
 Computing SHAP values for 101 nucleotide positions...


100%|██████████| 101/101 [00:20<00:00,  4.88it/s]



 Creating SHAP Explanation object...

 PLOT 2: WATERFALL

 PLOT 3: FORCE

 PLOT 4: DECISION

 PLOT 5: BAR

 PLOT 6: BEESWARM

PLOT 7: HEATMAP

 PLOT 8: GENOMIC VISUALIZATION


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:26:37] "POST /shap HTTP/1.1" 200 -



 SHAP ANALYSIS SUMMARY (Single-label)

Model Prediction:
  • Disease Probability: 84.81%
  • Classification: Hereditary cancer predisposing syndrome
  • Base Value: 0.848091

Sequence Statistics:
  • Total Length: 101 bp
  • A: 27  T: 35  G: 16  C: 23

SHAP Statistics:
  • Max SHAP: 0.840123 at Pos 61 (C)
  • Min SHAP: -0.017432 at Pos 5 (T)

Feature Contributions (|v|>0.001):
  • Risk-Increasing: 66 (65.3%)
  • Risk-Decreasing: 31 (30.7%)
  • Neutral: 4 (4.0%)

 Returning 7 plot URLs to frontend


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:28:57] "[33mGET /content/shap_outputs/1_waterfall_plot.png HTTP/1.1[0m" 404 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:28:58] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:29:10] "[33mGET /content/1_waterfall_plot.png HTTP/1.1[0m" 404 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:29:43] "GET /shap_outputs/1_waterfall_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:30:16] "OPTIONS /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:30:17] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:30:47] "OPTIONS /shap HTTP/1.1" 200 -



COMPUTING SHAP VALUES (Single-label)
 Computing SHAP values for 101 nucleotide positions...


100%|██████████| 101/101 [00:16<00:00,  6.05it/s]



 Creating SHAP Explanation object...

 PLOT 2: WATERFALL

 PLOT 3: FORCE

 PLOT 4: DECISION

 PLOT 5: BAR

 PLOT 6: BEESWARM

PLOT 7: HEATMAP

 PLOT 8: GENOMIC VISUALIZATION


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:15] "POST /shap HTTP/1.1" 200 -



 SHAP ANALYSIS SUMMARY (Single-label)

Model Prediction:
  • Disease Probability: 84.81%
  • Classification: Hereditary cancer predisposing syndrome
  • Base Value: 0.848091

Sequence Statistics:
  • Total Length: 101 bp
  • A: 27  T: 35  G: 16  C: 23

SHAP Statistics:
  • Max SHAP: 0.840123 at Pos 61 (C)
  • Min SHAP: -0.017432 at Pos 5 (T)

Feature Contributions (|v|>0.001):
  • Risk-Increasing: 66 (65.3%)
  • Risk-Decreasing: 31 (30.7%)
  • Neutral: 4 (4.0%)

 Returning 7 plot URLs to frontend


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:16] "GET /shap_outputs/1_waterfall_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:16] "GET /shap_outputs/2_force_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:16] "GET /shap_outputs/3_decision_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:18] "GET /shap_outputs/5_beeswarm_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:18] "GET /shap_outputs/4_bar_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:19] "GET /shap_outputs/6_heatmap.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:19] "GET /shap_outputs/7_genomic_visualization.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:50] "OPTIONS /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:50] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:31:53] "OPTIONS /shap HTTP/1.1" 200 -



COMPUTING SHAP VALUES (Single-label)
 Computing SHAP values for 103 nucleotide positions...


100%|██████████| 103/103 [00:17<00:00,  6.03it/s]



 Creating SHAP Explanation object...

 PLOT 2: WATERFALL

 PLOT 3: FORCE

 PLOT 4: DECISION

 PLOT 5: BAR

 PLOT 6: BEESWARM

PLOT 7: HEATMAP

 PLOT 8: GENOMIC VISUALIZATION


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:22] "POST /shap HTTP/1.1" 200 -



 SHAP ANALYSIS SUMMARY (Single-label)

Model Prediction:
  • Disease Probability: 93.91%
  • Classification: Hereditary cancer predisposing syndrome
  • Base Value: 0.939086

Sequence Statistics:
  • Total Length: 103 bp
  • A: 16  T: 13  G: 49  C: 25

SHAP Statistics:
  • Max SHAP: 0.014580 at Pos 69 (A)
  • Min SHAP: -0.003457 at Pos 45 (T)

Feature Contributions (|v|>0.001):
  • Risk-Increasing: 54 (52.4%)
  • Risk-Decreasing: 14 (13.6%)
  • Neutral: 35 (34.0%)

 Returning 7 plot URLs to frontend


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/4_bar_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/5_beeswarm_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/2_force_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/6_heatmap.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/3_decision_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/7_genomic_visualization.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:32:32] "GET /shap_outputs/1_waterfall_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:53:26] "OPTIONS /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 04:53:26] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:11:55] "OPTIONS /predict HTTP/1.1" 200 -
INFO:werkze


COMPUTING SHAP VALUES (Single-label)
 Computing SHAP values for 103 nucleotide positions...


100%|██████████| 103/103 [00:17<00:00,  5.93it/s]



 Creating SHAP Explanation object...

 PLOT 2: WATERFALL

 PLOT 3: FORCE

 PLOT 4: DECISION

 PLOT 5: BAR

 PLOT 6: BEESWARM

PLOT 7: HEATMAP

 PLOT 8: GENOMIC VISUALIZATION


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:27] "POST /shap HTTP/1.1" 200 -



 SHAP ANALYSIS SUMMARY (Single-label)

Model Prediction:
  • Disease Probability: 93.91%
  • Classification: Hereditary cancer predisposing syndrome
  • Base Value: 0.939086

Sequence Statistics:
  • Total Length: 103 bp
  • A: 16  T: 13  G: 49  C: 25

SHAP Statistics:
  • Max SHAP: 0.014580 at Pos 69 (A)
  • Min SHAP: -0.003457 at Pos 45 (T)

Feature Contributions (|v|>0.001):
  • Risk-Increasing: 54 (52.4%)
  • Risk-Decreasing: 14 (13.6%)
  • Neutral: 35 (34.0%)

 Returning 7 plot URLs to frontend


INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:27] "GET /shap_outputs/1_waterfall_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:28] "GET /shap_outputs/3_decision_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:28] "GET /shap_outputs/2_force_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:30] "GET /shap_outputs/4_bar_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:30] "GET /shap_outputs/5_beeswarm_plot.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:30] "GET /shap_outputs/6_heatmap.png HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [06/Dec/2025 05:12:31] "GET /shap_outputs/7_genomic_visualization.png HTTP/1.1" 200 -
