In [None]:
# 🧪 Developer Lab Notebook
# Experiment: Upstream Disruption Analysis for Interpretability

import pandas as pd
import numpy as np
import random
import torch
import matplotlib.pyplot as plt
from core import AnnotatorPipeline

# Load Model
# Construct the project pipeline once; later cells read attributes from it such
# as `annotator.model_tis` and `annotator.tokenizer`. What the constructor loads
# (weights, device placement) is defined in the project-local `core` module and
# is not visible from this notebook.
annotator = AnnotatorPipeline()

# Helper Functions

def generate_kmer(seq, k=6, overlap=1):
    """Split ``seq`` into space-separated substrings of length ``k``.

    Despite its name, ``overlap`` is the step between consecutive window
    starts: ``overlap=1`` emits every length-``k`` window. A ``seq`` shorter
    than ``k`` produces an empty string.
    """
    last_start = len(seq) - k
    windows = (seq[pos:pos + k] for pos in range(0, last_start + 1, overlap))
    return " ".join(windows)

def disrupt_sequence(seq, disruption_ratio=0.5, window_size=60, rng=None):
    """Randomly substitute bases inside a window centred on the sequence midpoint.

    ``int(window_len * disruption_ratio)`` distinct positions are drawn from
    the window, and each one is replaced with a *different* nucleotide from
    ``A/T/C/G``. The ratio is therefore the fraction of window bases that
    actually change. Drawing from all four bases would leave about a quarter
    of the chosen positions unchanged. Replacements are upper-case; the
    comparison with the original base is case-insensitive.

    NOTE(review): the window spans both sides of the midpoint. The notebook
    presumably places the TIS there, so this perturbs upstream *and*
    downstream context. Confirm this matches the "upstream" framing of the
    experiment.

    Args:
        seq: Nucleotide string.
        disruption_ratio: Fraction of window positions to mutate, in [0, 1].
        window_size: Window length in bases, clipped to the sequence bounds.
        rng: Object providing ``sample`` and ``choice`` (for example
            ``random.Random(seed)``) for reproducible runs. Defaults to the
            global ``random`` module.

    Returns:
        The disrupted sequence as a new string (``seq`` itself is not modified).

    Raises:
        ValueError: If ``disruption_ratio`` is outside [0, 1].
    """
    if not 0.0 <= disruption_ratio <= 1.0:
        raise ValueError(f"disruption_ratio must be in [0, 1], got {disruption_ratio}")
    rng = random if rng is None else rng

    center = len(seq) // 2
    raw_start = center - window_size // 2
    # End at raw_start + window_size (rather than center + half) so an odd
    # window keeps its full length instead of silently losing one base.
    start = max(raw_start, 0)
    end = min(raw_start + window_size, len(seq))

    bases = list(seq)
    num_bases_to_disrupt = int((end - start) * disruption_ratio)
    for idx in rng.sample(range(start, end), num_bases_to_disrupt):
        current = bases[idx].upper()
        # Exclude the current base so every selected position really changes.
        bases[idx] = rng.choice([b for b in "ATCG" if b != current])
    return "".join(bases)

def predict_prob(model, tokenizer, seqs, device=None, batch_size=None):
    """Return the model's class-1 (TIS) probability for each input sequence.

    Args:
        model: Torch classifier whose forward pass returns an object with a
            ``.logits`` tensor; column 1 is treated as the TIS class. The
            notebook passes ``annotator.model_tis``. The model is switched to
            eval mode.
        tokenizer: Callable accepting ``(seqs, return_tensors='pt',
            padding=True, truncation=True)`` and returning a mapping of tensors.
        seqs: A single string or an iterable of (k-mer tokenised) strings.
        device: Device for the inputs. Defaults to the device of the model's
            first parameter (CPU if it has none), so inputs always land where
            the weights live instead of relying on a global setting.
        batch_size: Maximum sequences per forward pass. ``None`` sends
            everything in a single batch; set it to bound memory on large
            inputs.

    Returns:
        1-D float32 numpy array of class-1 probabilities, one per sequence.

    Raises:
        ValueError: If ``batch_size`` is given and smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # A bare string would otherwise be iterated character by character.
    seqs = [seqs] if isinstance(seqs, str) else list(seqs)
    if not seqs:
        return np.empty(0, dtype=np.float32)

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    step = batch_size if batch_size is not None else len(seqs)
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(seqs), step):
            batch = tokenizer(seqs[start:start + step], return_tensors='pt', padding=True, truncation=True)
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits.float()
            # torch.softmax is max-shifted internally; the previous raw
            # exp(logits)/sum(exp(logits)) overflowed to inf/nan on large logits.
            chunks.append(torch.softmax(logits, dim=-1)[:, 1].cpu().numpy())
    return np.concatenate(chunks)  # Probability of class 1 (TIS)

# Input file: path to a CSV with one nucleotide string per row in a 'sequence' column.
uploaded_file = './webtool/developer-lab/sample_tis_sequences.csv'  # <-- Replace or use file_uploader in notebook UI

# Load Data
data = pd.read_csv(uploaded_file)  # must have 'sequence' and 'label' columns
# Fail early with a clear message rather than a KeyError deep in the experiment loop.
if 'sequence' not in data.columns:
    raise ValueError(
        f"{uploaded_file} has no 'sequence' column (found: {list(data.columns)})"
    )
# The k-mer and disruption helpers call len() on every value, so a missing
# (NaN) sequence would crash them; drop such rows up front and report it.
n_missing = int(data['sequence'].isna().sum())
if n_missing:
    print(f"Dropping {n_missing} rows with no sequence.")
    data = data.dropna(subset=['sequence']).reset_index(drop=True)
print(f"Loaded {len(data)} sequences.")

# Parameters
disruption_ratios = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]  # fraction of window bases to randomise
window_size = 60  # Number of nucleotides around center to consider

# Seed Python's global RNG, which disrupt_sequence draws from, so that
# re-running the notebook reproduces the same disrupted sequences.
random_seed = 42
random.seed(random_seed)

# Prepare Results: ratio -> per-sequence TIS probabilities (filled in by the experiment loop)
results = {r: [] for r in disruption_ratios}

# Run experiment: for each ratio, perturb every sequence, tokenise, and score.
for ratio in disruption_ratios:
    print(f"Running for disruption ratio {ratio}...")

    # Ratio 0.0 is the unperturbed baseline, so sequences pass through as-is.
    if ratio == 0.0:
        disrupted_seqs = list(data['sequence'])
    else:
        disrupted_seqs = [
            disrupt_sequence(s, disruption_ratio=ratio, window_size=window_size)
            for s in data['sequence']
        ]

    # The tokenizer is fed space-separated 6-mers taken with stride 1.
    kmers = [generate_kmer(s, k=6, overlap=1) for s in disrupted_seqs]

    probs = predict_prob(annotator.model_tis, annotator.tokenizer, kmers)
    results[ratio] = probs

# Plot: mean TIS-class probability at each disruption level
mean_probs = [np.mean(results[level]) for level in disruption_ratios]

plt.figure(figsize=(8, 6))
plt.plot(disruption_ratios, mean_probs, marker='o')
plt.gca().set(
    title='TIS Probability Drop vs. Upstream Disruption',
    xlabel='Disruption Ratio',
    ylabel='Mean TIS Class Probability',
)
plt.grid(True)
plt.show()