<a href="https://colab.research.google.com/github/MatiasNazareth1993-coder/Virtual-cell/blob/main/Simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# # Pseudocode (conceptual) — safe to run in silico
# initialize_simulator()
# initialize_sequence_population()
# initialize_policy_agent()

# for iteration in range(N_iterations):
#     # 1) Propose telomere variants
#     candidates = propose_sequences(sequence_population)

#     # 2) For each candidate run organ-on-chip simulation with policy agent
#     results = []
#     for seq in candidates:
#         simulator.load_telomere_model(seq)
#         # policy_agent may be pre-trained or co-trained
#         sim_outcome = run_simulation(simulator, policy_agent, steps=T)
#         results.append((seq, sim_outcome.metrics))

#     # 3) Evaluate and update sequence optimizer
#     sequence_population = update_sequence_population(results)

#     # 4) Train/finetune policy agent on collected trajectories
#     policy_agent.update(from_trajectories(results))

#     # 5) Log metrics; enforce safety checks
#     if detect_unacceptable_risk(results):
#         apply_safety_mitigation()

In [7]:
alpha = 0.5
beta = 0.2
gamma = 0.3
delta = 0.1

# Placeholder values for demonstration; please replace with actual simulation outputs/metrics
organ_function_t = 0.8
organ_function_tminus1 = 0.75
oncogenic_risk_proxy = 0.1
change_in_senescent_fraction = 0.05
telomerase_activation_cost = 0.02

reward = (alpha * (organ_function_t - organ_function_tminus1)
         - beta * oncogenic_risk_proxy
         - gamma * change_in_senescent_fraction
         - delta * telomerase_activation_cost)

In [None]:
{
  "StabilityScore":                0–1,
  "TelomeraseAffinityScore":       0–1,
  "FragilityRiskScore":            0–1 (higher = worse),
  "SecondaryStructureScore":       0–1,
  "ImmuneTriggerScore":            0–1 (higher = worse)
}


In [11]:
# Input: telomere sequence (e.g., “TTAGGGTTAGGC…”)
# → one-hot encode (A,T,G,C)


# fused = concat(sequence_embedding, cell_state_embedding)


# {
#   "sequence": "...",
#   "environment": { debris_load, phag_activity, cell_stress },
#   "true_simulated_outcomes": { stability, affinity, fragility, etc. }
# }


# if StabilityScore > 0.7 and TelomeraseAffinityScore > 0.6:
#       allow_low_activation()
# elif FragilityRiskScore > 0.5:
#       block_activation()
# else:
#       cautious_mode()

In [13]:
import torch.nn as nn

# ------------------------------------------------------
# 1) TEL-ML MODEL (SEQUENCE + CONTEXT PREDICTOR)
# ------------------------------------------------------

class TelomerePredictor(nn.Module):
    def __init__(self):
        super().__init__()

        # CNN for local motif detection
        self.cnn = nn.Sequential(
            nn.Conv1d(4, 32, kernel_size=6, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=6, padding=2),
            nn.ReLU()
        )

        # Transformer encoder for long-range interactions
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=64, nhead=4, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # MLP for contextual/environment features
        self.context_mlp = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )

        # Fusion + prediction head
        self.pred_head = nn.Sequential(
            nn.Linear(64 + 32, 64),
            nn.ReLU(),
            nn.Linear(64, 5),   # 5 prediction outputs
            nn.Sigmoid()        # normalize to 0–1
        )

    def forward(self, seq_onehot, context_vec):
        # seq_onehot shape: [B, L, 4]
        x = seq_onehot.permute(0, 2, 1)  # → [B, 4, L]
        x = self.cnn(x)                  # → [B, 64, L]
        x = x.permute(0, 2, 1)           # → [B, L, 64]
        x = self.transformer(x)          # → [B, L, 64]
        seq_embed = x.mean(dim=1)        # pooled representation

        context_embed = self.context_mlp(context_vec)

        fused = torch.cat([seq_embed, context_embed], dim=1)
        return self.pred_head(fused)

In [14]:
# =========================================================
# MULTI-TASK + MULTI-HEAD ATTENTION VARIANT (SAFE, CONCEPTUAL)
# =========================================================

class MultiTaskTelomereModel(nn.Module):
    def __init__(self, num_context_features=12,
                 d_model=96, num_heads=6, depth=3):
        super().__init__()

        # -----------------------
        # 1) CNN for sequence motifs
        # -----------------------
        self.cnn = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, d_model, kernel_size=5, padding=2),
            nn.ReLU()
        )

        # -----------------------
        # 2) Transformer for multi-head attention
        # -----------------------
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4*d_model,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth
        )

        # -----------------------
        # 3) Context MLP (Organ-on-chip + cell-state)
        # -----------------------
        self.context_mlp = nn.Sequential(
            nn.Linear(num_context_features, 64),
            nn.ReLU(),
            nn.Linear(64, d_model)
        )

        # -----------------------
        # 4) Shared fusion layer
        # -----------------------
        self.shared_fusion = nn.Sequential(
            nn.Linear(2*d_model, d_model),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # -----------------------
        # 5) Task-Specific Heads
        # -----------------------
        self.stability_head  = self._make_task_head()
        self.affinity_head   = self._make_task_head()
        self.fragility_head  = self._make_task_head()
        self.structure_head  = self._make_task_head()
        self.immune_head     = self._make_task_head()

    # small task-specific prediction module
    def _make_task_head(self):
        return nn.Sequential(
            nn.Linear(96, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, seq_onehot, context_vec):
        # ----- CNN -----
        x = seq_onehot.permute(0, 2, 1)  # [B, 4, L]
        x = self.cnn(x)                  # [B, d_model, L]
        x = x.permute(0, 2, 1)           # [B, L, d_model]

        # ----- Transformer -----
        x = self.transformer(x)
        seq_embed = x.mean(dim=1)

        # ----- Context -----
        cont_embed = self.context_mlp(context_vec)

        # ----- Shared fusion -----
        fused = torch.cat([seq_embed, cont_embed], dim=1)
        fused = self.shared_fusion(fused)

        # ----- Multi-task outputs -----
        return {
            "StabilityScore":        self.stability_head(fused),
            "TelomeraseAffinity":    self.affinity_head(fused),
            "FragilityRisk":         self.fragility_head(fused),
            "SecondaryStructure":    self.structure_head(fused),
            "ImmuneTrigger":         self.immune_head(fused)
        }


In [18]:
# loss = (w1*MSE(Stability)
#       + w2*MSE(Affinity)
#       + w3*MSE(Fragility)
#       + w4*MSE(Structure)
#       + w5*MSE(Immune))
# # SharedRepresentation → Multi-Head Cross-Attention → Task Heads

In [21]:
# Creating a full training pipeline + synthetic dataset generator for the
# multi-task, multi-head telomere predictor (simulation-only, safe).
#
# This notebook:
# 1) Generates a synthetic dataset of telomere-like sequences + context features
# 2) Builds a multi-task model (CNN + Transformer + task heads)
# 3) Trains the model on synthetic targets (stability, affinity, fragility, structure, immune)
# 4) Evaluates and saves model + dataset
# 5) Displays sample data and training curves
#
# Outputs saved to /mnt/data:
# - /mnt/data/synthetic_telomere_dataset.csv
# - /mnt/data/telomere_multitask_model.pth
# - /mnt/data/training_history.npy
#
# NOTE: This code is purely computational and does NOT produce any wet-lab or
# actionable biological instructions. It's safe to run and intended for in-silico simulation.

# Execute the training pipeline
import os
import random
import math
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import matplotlib.pyplot as plt

# For interactive dataframe display in the notebook UI
try:
    from caas_jupyter_tools import display_dataframe_to_user
except Exception:
    display_dataframe_to_user = None

# reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

OUTDIR = Path("/mnt/data")
OUTDIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# Synthetic dataset generator
# -------------------------
BASE_REPEAT = "TTAGGG"  # reference repeat (human-like) used only as inspiration
NUCLEOTIDES = ['A', 'T', 'G', 'C']

def mutate_repeat_unit(base, mutation_rate=0.1):
    """Mutate a base repeat unit by randomly substituting bases with probability mutation_rate."""
    out = []
    for ch in base:
        if random.random() < mutation_rate:
            out.append(random.choice(NUCLEOTIDES))
        else:
            out.append(ch)
    return ''.join(out)

def build_sequence(repeat_unit, repeats):
    return repeat_unit * repeats

def sequence_entropy(seq):
    counts = {n: seq.count(n) for n in NUCLEOTIDES}
    probs = np.array(list(counts.values())) / len(seq)
    probs = probs[probs > 0]
    return -np.sum(probs * np.log2(probs))

def g_rich_runs_score(seq):
    # proxy for G-quadruplex-like propensity: count occurrences of "GGG" and longer runs
    score = 0
    run = 0
    for ch in seq:
        if ch == 'G':
            run += 1
        else:
            if run >= 3:
                score += run
            run = 0
    if run >= 3:
        score += run
    return score / max(1, len(seq)/6)  # normalize by length/6

def palindromic_score(seq, k=6):
    # count small palindromic windows as proxy for hairpins
    score = 0
    for i in range(len(seq)-k+1):
        window = seq[i:i+k]
        # simple reverse complement check
        rc = window[::-1].translate(str.maketrans("ATGC","TACG"))
        if rc == window:
            score += 1
    return score / max(1, len(seq)/k)

def motif_purity_score(seq, repeat_unit):
    # fraction of sequence that matches the repeated repeat_unit perfectly when tiled
    L = len(repeat_unit)
    perfect = 0
    for i in range(0, len(seq), L):
        block = seq[i:i+L]
        if block == repeat_unit:
            perfect += 1
    return perfect * L / len(seq)

def synth_targets_from_sequence(seq, repeat_unit, context_vec):
    """
    Heuristic synthetic target generator. Produces five values in [0,1]:
    StabilityScore, TelomeraseAffinityScore, FragilityRiskScore,
    SecondaryStructureScore, ImmuneTriggerScore
    """
    length = len(seq)
    repeats = length / max(1, len(repeat_unit))
    purity = motif_purity_score(seq, repeat_unit)  # 0..1
    entropy = sequence_entropy(seq) / 2.0  # normalize roughly
    gscore = g_rich_runs_score(seq)  # normalized proxy
    palin = palindromic_score(seq)
    variability = 1.0 - purity

    # Context modifiers (safe, synthetic)
    debris, phag_activity, inflammation, replication_stress = context_vec[:4]

    # Stability: higher with purity, length; lowered by gscore/palin/replication stress
    stability_raw = 0.5 * purity + 0.4 * (math.tanh((repeats-8)/8)+1)/2 - 0.3 * gscore - 0.2 * palin - 0.2 * replication_stress
    stability = float(1/(1+math.exp(-stability_raw)))  # sigmoid-ish to 0..1

    # Telomerase affinity: higher for periodicity & moderate G-richness; penalize high entropy
    affinity_raw = 0.6 * purity + 0.2 * (math.tanh(gscore)+1)/2 - 0.3 * entropy + 0.2 * (1-replication_stress)
    affinity = float(1/(1+math.exp(-affinity_raw)))

    # Fragility risk: increased by variability, palindromes, and high replication stress
    frag_raw = 0.5 * variability + 0.4 * palin + 0.3 * replication_stress + 0.2 * entropy
    frag = float(1/(1+math.exp(- (frag_raw - 0.5) )))

    # Secondary structure score: proxy from gscore + palin
    sec_raw = 0.6 * gscore + 0.4 * palin
    sec = float(1/(1+math.exp(- (sec_raw - 0.3) )))

    # Immune trigger: higher entropy and exposed non-native motifs, increased by inflammation
    imm_raw = 0.5 * entropy + 0.3 * variability + 0.4 * inflammation
    imm = float(1/(1+math.exp(- (imm_raw - 0.2) )))

    # Add small noise
    stability = min(max(stability + np.random.normal(0, 0.02), 0.0), 1.0)
    affinity = min(max(affinity + np.random.normal(0, 0.03), 0.0), 1.0)
    frag = min(max(frag + np.random.normal(0, 0.03), 0.0), 1.0)
    sec = min(max(sec + np.random.normal(0, 0.02), 0.0), 1.0)
    imm = min(max(imm + np.random.normal(0, 0.03), 0.0), 1.0)

    return stability, affinity, frag, sec, imm

# Generate dataset parameters
N_SAMPLES = 4000
MIN_REPEATS = 5
MAX_REPEATS = 60
MAX_SEQ_LEN = MAX_REPEATS * len(BASE_REPEAT)  # for padding

records = []
sequences = []

for i in range(N_SAMPLES):
    # choose a repeat unit by mutating base with variable mutation rate
    mut_rate = random.choice([0.0, 0.05, 0.1, 0.2])
    ru = mutate_repeat_unit(BASE_REPEAT, mutation_rate=mut_rate)
    repeats = random.randint(MIN_REPEATS, MAX_REPEATS)
    seq = build_sequence(ru, repeats)

    # randomly insert occasional random short motifs to increase diversity
    if random.random() < 0.15:
        pos = random.randint(0, max(0, len(seq)-10))
        insert = ''.join(random.choices(NUCLEOTIDES, k=6))
        seq = seq[:pos] + insert + seq[pos:]

    # context features (synthetic)
    debris = random.random()  # 0-1
    phag_activity = random.random()
    inflammation = random.random()
    replication_stress = random.random()
    # additional context features (stem cell fraction, local nutrients, oxygen proxy, time)
    stem_frac = random.random()
    nutrients = random.random()
    oxygen = random.random()
    time_of_day = random.random()

    context_vec = [debris, phag_activity, inflammation, replication_stress,
                   stem_frac, nutrients, oxygen, time_of_day, random.random(), random.random(), 0.0, 0.0][:12]

    stability, affinity, frag, sec, imm = synth_targets_from_sequence(seq, ru, context_vec)

    records.append({
        "sequence": seq,
        "repeat_unit": ru,
        "repeats": repeats,
        "length": len(seq),
        "purity": motif_purity_score(seq, ru),
        "entropy": sequence_entropy(seq),
        "gscore": g_rich_runs_score(seq),
        "palin": palindromic_score(seq),
        "debris": debris,
        "phag_activity": phag_activity,
        "inflammation": inflammation,
        "replication_stress": replication_stress,
        "stem_frac": stem_frac,
        "nutrients": nutrients,
        "oxygen": oxygen,
        "time_of_day": time_of_day,
        "StabilityScore": stability,
        "TelomeraseAffinityScore": affinity,
        "FragilityRiskScore": frag,
        "SecondaryStructureScore": sec,
        "ImmuneTriggerScore": imm
    })
    sequences.append(seq)

df = pd.DataFrame.from_records(records)
# Save CSV for external inspection
csv_path = OUTDIR / "synthetic_telomere_dataset.csv"
df.to_csv(csv_path, index=False)

# -------------------------
# Dataset and dataloader
# -------------------------
# One-hot encode sequences and pad to MAX_SEQ_LEN
def one_hot_encode_seq(seq, max_len=MAX_SEQ_LEN):
    mapping = {'A':0,'T':1,'G':2,'C':3}
    arr = np.zeros((max_len, 4), dtype=np.float32)
    for i, ch in enumerate(seq[:max_len]):
        if ch in mapping:
            arr[i, mapping[ch]] = 1.0
    return arr

# Build tensors
seq_tensors = np.stack([one_hot_encode_seq(s) for s in df['sequence'].values])
context_features = df[["debris","phag_activity","inflammation","replication_stress",
                       "stem_frac","nutrients","oxygen","time_of_day","purity","entropy",
                       "gscore","palin"]].values.astype(np.float32)
targets = df[["StabilityScore","TelomeraseAffinityScore","FragilityRiskScore",
              "SecondaryStructureScore","ImmuneTriggerScore"]].values.astype(np.float32)

# Simple Dataset
class TelomereDataset(Dataset):
    def __init__(self, seq_tensor, context, targets):
        self.seq = torch.tensor(seq_tensor)  # [N, L, 4]
        self.context = torch.tensor(context)
        self.targets = torch.tensor(targets)

    def __len__(self):
        return self.seq.shape[0]

    def __getitem__(self, idx):
        return self.seq[idx], self.context[idx], self.targets[idx]

dataset = TelomereDataset(seq_tensors, context_features, targets)

# Train/test split
n_train = int(len(dataset)*0.8)
n_val = len(dataset) - n_train
train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(SEED))

BATCH_SIZE = 32
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# -------------------------
# Model definition (multi-task + multi-head)
# -------------------------
class MultiTaskTelomereModel(nn.Module):
    def __init__(self, seq_len=MAX_SEQ_LEN, d_model=96, num_heads=6, depth=2, num_context=12):
        super().__init__()
        self.d_model = d_model
        # CNN
        self.cnn = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, d_model, kernel_size=5, padding=2),
            nn.ReLU()
        )
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=4*d_model, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        # Context MLP
        self.context_mlp = nn.Sequential(
            nn.Linear(num_context, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
            nn.ReLU()
        )
        # Shared fusion
        self.shared_fusion = nn.Sequential(
            nn.Linear(2*d_model, d_model),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        # Task heads
        def task_head():
            return nn.Sequential(
                nn.Linear(d_model, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
                nn.Sigmoid()
            )
        self.stability_head = task_head()
        self.affinity_head = task_head()
        self.fragility_head = task_head()
        self.structure_head = task_head()
        self.immune_head = task_head()

    def forward(self, seq_onehot, context_vec):
        # seq_onehot: [B, L, 4]
        x = seq_onehot.permute(0,2,1)  # [B, 4, L]
        x = self.cnn(x)               # [B, d_model, L]
        x = x.permute(0,2,1)          # [B, L, d_model]
        x = self.transformer(x)       # [B, L, d_model]
        seq_embed = x.mean(dim=1)     # [B, d_model]

        cont = self.context_mlp(context_vec)  # [B, d_model]
        fused = torch.cat([seq_embed, cont], dim=1)  # [B, 2*d_model]
        fused = self.shared_fusion(fused)            # [B, d_model]

        out1 = self.stability_head(fused).squeeze(-1)
        out2 = self.affinity_head(fused).squeeze(-1)
        out3 = self.fragility_head(fused).squeeze(-1)
        out4 = self.structure_head(fused).squeeze(-1)
        out5 = self.immune_head(fused).squeeze(-1)
        out = torch.stack([out1, out2, out3, out4, out5], dim=1)
        return out

# -------------------------
# Training loop
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiTaskTelomereModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Weighted MSE on the 5 tasks (we may prioritize fragility/immune slightly)
task_weights = torch.tensor([1.0, 1.0, 1.2, 0.8, 1.2], device=device)

def compute_loss(preds, targets):
    mse = (preds - targets).pow(2)
    weighted = mse * task_weights
    return weighted.mean()

EPOCHS = 12
train_history = {"train_loss": [], "val_loss": []}

for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    for seq_batch, ctx_batch, tgt_batch in train_loader:
        seq_batch = seq_batch.to(device)
        ctx_batch = ctx_batch.to(device)
        tgt_batch = tgt_batch.to(device)
        preds = model(seq_batch, ctx_batch)
        loss = compute_loss(preds, tgt_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * seq_batch.size(0)
    train_loss = running_loss / len(train_loader.dataset)
    model.eval()
    val_running = 0.0
    with torch.no_grad():
        for seq_batch, ctx_batch, tgt_batch in val_loader:
            seq_batch = seq_batch.to(device)
            ctx_batch = ctx_batch.to(device)
            tgt_batch = tgt_batch.to(device)
            preds = model(seq_batch, ctx_batch)
            loss = compute_loss(preds, tgt_batch)
            val_running += loss.item() * seq_batch.size(0)
    val_loss = val_running / len(val_loader.dataset)
    train_history["train_loss"].append(train_loss)
    train_history["val_loss"].append(val_loss)
    print(f"Epoch {epoch:02d}  Train Loss: {train_loss:.5f}  Val Loss: {val_loss:.5f}")

# Save model and training history
model_path = OUTDIR / "telomere_multitask_model.pth"
torch.save(model.state_dict(), model_path)
np.save(OUTDIR / "training_history.npy", np.array([train_history["train_loss"], train_history["val_loss"]]))

# Quick evaluation: compute MAE per task on validation set
model.eval()
mae = np.zeros(5)
count = 0
with torch.no_grad():
    for seq_batch, ctx_batch, tgt_batch in val_loader:
        seq_batch = seq_batch.to(device)
        ctx_batch = ctx_batch.to(device)
        tgt_batch = tgt_batch.to(device)
        preds = model(seq_batch, ctx_batch).cpu().numpy()
        mae += np.abs(preds - tgt_batch.numpy()).sum(axis=0)
        count += preds.shape[0]
mae = mae / count

metrics = {
    "MAE_Stability": float(mae[0]),
    "MAE_Affinity": float(mae[1]),
    "MAE_Fragility": float(mae[2]),
    "MAE_SecondaryStructure": float(mae[3]),
    "MAE_Immune": float(mae[4])
}
print("Validation MAE per task:", metrics)

# Save a small sample of the dataset for display
sample_df = df.sample(n=12, random_state=SEED).reset_index(drop=True)
sample_path = OUTDIR / "synthetic_dataset_sample.csv"
sample_df.to_csv(sample_path, index=False)

# Display sample dataframe if UI tool available
if display_dataframe_to_user is not None:
    display_dataframe_to_user("Synthetic Telomere Dataset Sample", sample_df)

# Plot training curves
plt.figure(figsize=(6,4))
plt.plot(train_history["train_loss"], label="train_loss")
plt.plot(train_history["val_loss"], label="val_loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Loss Curve")
plt.tight_layout()
plot_path = OUTDIR / "training_loss.png"
plt.savefig(plot_path)
plt.close()

# Save full dataset
# (already saved as synthetic_telomere_dataset.csv)
print("\nSaved files:")
print(" - Dataset CSV:", str(csv_path))
print(" - Model:", str(model_path))
print(" - Training history:", str(OUTDIR / 'training_history.npy'))
print(" - Training loss plot:", str(plot_path))

# Provide download links (UI will show files under /mnt/data)
print("\nDownload links (use in the chat UI):")
print(f"[Download synthetic dataset] (sandbox:{csv_path})")
print(f"[Download trained model] (sandbox:{model_path})")
print(f"[Download training loss plot] (sandbox:{plot_path})")

# End of pipeline code.

Epoch 01  Train Loss: 0.00341  Val Loss: 0.00123
Epoch 02  Train Loss: 0.00127  Val Loss: 0.00110
Epoch 03  Train Loss: 0.00105  Val Loss: 0.00093
Epoch 04  Train Loss: 0.00096  Val Loss: 0.00089
Epoch 05  Train Loss: 0.00091  Val Loss: 0.00090
Epoch 06  Train Loss: 0.00090  Val Loss: 0.00087
Epoch 07  Train Loss: 0.00089  Val Loss: 0.00086
Epoch 08  Train Loss: 0.00088  Val Loss: 0.00084
Epoch 09  Train Loss: 0.00087  Val Loss: 0.00088
Epoch 10  Train Loss: 0.00086  Val Loss: 0.00092
Epoch 11  Train Loss: 0.00085  Val Loss: 0.00086
Epoch 12  Train Loss: 0.00085  Val Loss: 0.00087
Validation MAE per task: {'MAE_Stability': 0.01740162990987301, 'MAE_Affinity': 0.02459146432578564, 'MAE_Fragility': 0.02554728798568249, 'MAE_SecondaryStructure': 0.01979451548308134, 'MAE_Immune': 0.02517802432179451}

Saved files:
 - Dataset CSV: /mnt/data/synthetic_telomere_dataset.csv
 - Model: /mnt/data/telomere_multitask_model.pth
 - Training history: /mnt/data/training_history.npy
 - Training loss pl

In [22]:
# Creating a full training pipeline + synthetic dataset generator for the
# multi-task, multi-head telomere predictor (simulation-only, safe).
#
# This notebook:
# 1) Generates a synthetic dataset of telomere-like sequences + context features
# 2) Builds a multi-task model (CNN + Transformer + task heads)
# 3) Trains the model on synthetic targets (stability, affinity, fragility, structure, immune)
# 4) Evaluates and saves model + dataset
# 5) Displays sample data and training curves
#
# Outputs saved to /mnt/data:
# - /mnt/data/synthetic_telomere_dataset.csv
# - /mnt/data/telomere_multitask_model.pth
# - /mnt/data/training_history.npy
#
# NOTE: This code is purely computational and does NOT produce any wet-lab or
# actionable biological instructions. It's safe to run and intended for in-silico simulation.

# Execute the training pipeline
import os
import random
import math
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import matplotlib.pyplot as plt

# For interactive dataframe display in the notebook UI
try:
    from caas_jupyter_tools import display_dataframe_to_user
except Exception:
    display_dataframe_to_user = None

# reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

OUTDIR = Path("/mnt/data")
OUTDIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# Synthetic dataset generator
# -------------------------
BASE_REPEAT = "TTAGGG"  # reference repeat (human-like) used only as inspiration
NUCLEOTIDES = ['A', 'T', 'G', 'C']

def mutate_repeat_unit(base, mutation_rate=0.1):
    """Mutate a base repeat unit by randomly substituting bases with probability mutation_rate."""
    out = []
    for ch in base:
        if random.random() < mutation_rate:
            out.append(random.choice(NUCLEOTIDES))
        else:
            out.append(ch)
    return ''.join(out)

def build_sequence(repeat_unit, repeats):
    return repeat_unit * repeats

def sequence_entropy(seq):
    counts = {n: seq.count(n) for n in NUCLEOTIDES}
    probs = np.array(list(counts.values())) / len(seq)
    probs = probs[probs > 0]
    return -np.sum(probs * np.log2(probs))

def g_rich_runs_score(seq):
    # proxy for G-quadruplex-like propensity: count occurrences of "GGG" and longer runs
    score = 0
    run = 0
    for ch in seq:
        if ch == 'G':
            run += 1
        else:
            if run >= 3:
                score += run
            run = 0
    if run >= 3:
        score += run
    return score / max(1, len(seq)/6)  # normalize by length/6

def palindromic_score(seq, k=6):
    # count small palindromic windows as proxy for hairpins
    score = 0
    for i in range(len(seq)-k+1):
        window = seq[i:i+k]
        # simple reverse complement check
        rc = window[::-1].translate(str.maketrans("ATGC","TACG"))
        if rc == window:
            score += 1
    return score / max(1, len(seq)/k)

def motif_purity_score(seq, repeat_unit):
    # fraction of sequence that matches the repeated repeat_unit perfectly when tiled
    L = len(repeat_unit)
    perfect = 0
    for i in range(0, len(seq), L):
        block = seq[i:i+L]
        if block == repeat_unit:
            perfect += 1
    return perfect * L / len(seq)

def synth_targets_from_sequence(seq, repeat_unit, context_vec):
    """
    Heuristic synthetic target generator. Produces five values in [0,1]:
    StabilityScore, TelomeraseAffinityScore, FragilityRiskScore,
    SecondaryStructureScore, ImmuneTriggerScore
    """
    length = len(seq)
    repeats = length / max(1, len(repeat_unit))
    purity = motif_purity_score(seq, repeat_unit)  # 0..1
    entropy = sequence_entropy(seq) / 2.0  # normalize roughly
    gscore = g_rich_runs_score(seq)  # normalized proxy
    palin = palindromic_score(seq)
    variability = 1.0 - purity

    # Context modifiers (safe, synthetic)
    debris, phag_activity, inflammation, replication_stress = context_vec[:4]

    # Stability: higher with purity, length; lowered by gscore/palin/replication stress
    stability_raw = 0.5 * purity + 0.4 * (math.tanh((repeats-8)/8)+1)/2 - 0.3 * gscore - 0.2 * palin - 0.2 * replication_stress
    stability = float(1/(1+math.exp(-stability_raw)))  # sigmoid-ish to 0..1

    # Telomerase affinity: higher for periodicity & moderate G-richness; penalize high entropy
    affinity_raw = 0.6 * purity + 0.2 * (math.tanh(gscore)+1)/2 - 0.3 * entropy + 0.2 * (1-replication_stress)
    affinity = float(1/(1+math.exp(-affinity_raw)))

    # Fragility risk: increased by variability, palindromes, and high replication stress
    frag_raw = 0.5 * variability + 0.4 * palin + 0.3 * replication_stress + 0.2 * entropy
    frag = float(1/(1+math.exp(- (frag_raw - 0.5) )))

    # Secondary structure score: proxy from gscore + palin
    sec_raw = 0.6 * gscore + 0.4 * palin
    sec = float(1/(1+math.exp(- (sec_raw - 0.3) )))

    # Immune trigger: higher entropy and exposed non-native motifs, increased by inflammation
    imm_raw = 0.5 * entropy + 0.3 * variability + 0.4 * inflammation
    imm = float(1/(1+math.exp(- (imm_raw - 0.2) )))

    # Add small noise
    stability = min(max(stability + np.random.normal(0, 0.02), 0.0), 1.0)
    affinity = min(max(affinity + np.random.normal(0, 0.03), 0.0), 1.0)
    frag = min(max(frag + np.random.normal(0, 0.03), 0.0), 1.0)
    sec = min(max(sec + np.random.normal(0, 0.02), 0.0), 1.0)
    imm = min(max(imm + np.random.normal(0, 0.03), 0.0), 1.0)

    return stability, affinity, frag, sec, imm

# Generate dataset parameters
N_SAMPLES = 4000
MIN_REPEATS = 5
MAX_REPEATS = 60
MAX_SEQ_LEN = MAX_REPEATS * len(BASE_REPEAT)  # for padding

records = []
sequences = []

for i in range(N_SAMPLES):
    # choose a repeat unit by mutating base with variable mutation rate
    mut_rate = random.choice([0.0, 0.05, 0.1, 0.2])
    ru = mutate_repeat_unit(BASE_REPEAT, mutation_rate=mut_rate)
    repeats = random.randint(MIN_REPEATS, MAX_REPEATS)
    seq = build_sequence(ru, repeats)

    # randomly insert occasional random short motifs to increase diversity
    if random.random() < 0.15:
        pos = random.randint(0, max(0, len(seq)-10))
        insert = ''.join(random.choices(NUCLEOTIDES, k=6))
        seq = seq[:pos] + insert + seq[pos:]

    # context features (synthetic)
    debris = random.random()  # 0-1
    phag_activity = random.random()
    inflammation = random.random()
    replication_stress = random.random()
    # additional context features (stem cell fraction, local nutrients, oxygen proxy, time)
    stem_frac = random.random()
    nutrients = random.random()
    oxygen = random.random()
    time_of_day = random.random()

    context_vec = [debris, phag_activity, inflammation, replication_stress,
                   stem_frac, nutrients, oxygen, time_of_day, random.random(), random.random(), 0.0, 0.0][:12]

    stability, affinity, frag, sec, imm = synth_targets_from_sequence(seq, ru, context_vec)

    records.append({
        "sequence": seq,
        "repeat_unit": ru,
        "repeats": repeats,
        "length": len(seq),
        "purity": motif_purity_score(seq, ru),
        "entropy": sequence_entropy(seq),
        "gscore": g_rich_runs_score(seq),
        "palin": palindromic_score(seq),
        "debris": debris,
        "phag_activity": phag_activity,
        "inflammation": inflammation,
        "replication_stress": replication_stress,
        "stem_frac": stem_frac,
        "nutrients": nutrients,
        "oxygen": oxygen,
        "time_of_day": time_of_day,
        "StabilityScore": stability,
        "TelomeraseAffinityScore": affinity,
        "FragilityRiskScore": frag,
        "SecondaryStructureScore": sec,
        "ImmuneTriggerScore": imm
    })
    sequences.append(seq)

df = pd.DataFrame.from_records(records)
# Save CSV for external inspection
csv_path = OUTDIR / "synthetic_telomere_dataset.csv"
df.to_csv(csv_path, index=False)

# -------------------------
# Dataset and dataloader
# -------------------------
# One-hot encode sequences and pad to MAX_SEQ_LEN
def one_hot_encode_seq(seq, max_len=MAX_SEQ_LEN):
    mapping = {'A':0,'T':1,'G':2,'C':3}
    arr = np.zeros((max_len, 4), dtype=np.float32)
    for i, ch in enumerate(seq[:max_len]):
        if ch in mapping:
            arr[i, mapping[ch]] = 1.0
    return arr

# Build tensors
seq_tensors = np.stack([one_hot_encode_seq(s) for s in df['sequence'].values])
context_features = df[["debris","phag_activity","inflammation","replication_stress",
                       "stem_frac","nutrients","oxygen","time_of_day","purity","entropy",
                       "gscore","palin"]].values.astype(np.float32)
targets = df[["StabilityScore","TelomeraseAffinityScore","FragilityRiskScore",
              "SecondaryStructureScore","ImmuneTriggerScore"]].values.astype(np.float32)

# Simple Dataset
class TelomereDataset(Dataset):
    def __init__(self, seq_tensor, context, targets):
        self.seq = torch.tensor(seq_tensor)  # [N, L, 4]
        self.context = torch.tensor(context)
        self.targets = torch.tensor(targets)

    def __len__(self):
        return self.seq.shape[0]

    def __getitem__(self, idx):
        return self.seq[idx], self.context[idx], self.targets[idx]

dataset = TelomereDataset(seq_tensors, context_features, targets)

# Train/test split
n_train = int(len(dataset)*0.8)
n_val = len(dataset) - n_train
train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(SEED))

BATCH_SIZE = 32
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# -------------------------
# Model definition (multi-task + multi-head)
# -------------------------
class MultiTaskTelomereModel(nn.Module):
    def __init__(self, seq_len=MAX_SEQ_LEN, d_model=96, num_heads=6, depth=2, num_context=12):
        super().__init__()
        self.d_model = d_model
        # CNN
        self.cnn = nn.Sequential(
            nn.Conv1d(4, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, d_model, kernel_size=5, padding=2),
            nn.ReLU()
        )
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=4*d_model, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        # Context MLP
        self.context_mlp = nn.Sequential(
            nn.Linear(num_context, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
            nn.ReLU()
        )
        # Shared fusion
        self.shared_fusion = nn.Sequential(
            nn.Linear(2*d_model, d_model),
            nn.ReLU(),
            nn.Dropout(0.1)
        )
        # Task heads
        def task_head():
            return nn.Sequential(
                nn.Linear(d_model, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
                nn.Sigmoid()
            )
        self.stability_head = task_head()
        self.affinity_head = task_head()
        self.fragility_head = task_head()
        self.structure_head = task_head()
        self.immune_head = task_head()

    def forward(self, seq_onehot, context_vec):
        # seq_onehot: [B, L, 4]
        x = seq_onehot.permute(0,2,1)  # [B, 4, L]
        x = self.cnn(x)               # [B, d_model, L]
        x = x.permute(0,2,1)          # [B, L, d_model]
        x = self.transformer(x)       # [B, L, d_model]
        seq_embed = x.mean(dim=1)     # [B, d_model]

        cont = self.context_mlp(context_vec)  # [B, d_model]
        fused = torch.cat([seq_embed, cont], dim=1)  # [B, 2*d_model]
        fused = self.shared_fusion(fused)            # [B, d_model]

        out1 = self.stability_head(fused).squeeze(-1)
        out2 = self.affinity_head(fused).squeeze(-1)
        out3 = self.fragility_head(fused).squeeze(-1)
        out4 = self.structure_head(fused).squeeze(-1)
        out5 = self.immune_head(fused).squeeze(-1)
        out = torch.stack([out1, out2, out3, out4, out5], dim=1)
        return out

# -------------------------
# Training loop
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultiTaskTelomereModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Weighted MSE on the 5 tasks (we may prioritize fragility/immune slightly)
task_weights = torch.tensor([1.0, 1.0, 1.2, 0.8, 1.2], device=device)

def compute_loss(preds, targets):
    mse = (preds - targets).pow(2)
    weighted = mse * task_weights
    return weighted.mean()

EPOCHS = 12
train_history = {"train_loss": [], "val_loss": []}

for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    for seq_batch, ctx_batch, tgt_batch in train_loader:
        seq_batch = seq_batch.to(device)
        ctx_batch = ctx_batch.to(device)
        tgt_batch = tgt_batch.to(device)
        preds = model(seq_batch, ctx_batch)
        loss = compute_loss(preds, tgt_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * seq_batch.size(0)
    train_loss = running_loss / len(train_loader.dataset)
    model.eval()
    val_running = 0.0
    with torch.no_grad():
        for seq_batch, ctx_batch, tgt_batch in val_loader:
            seq_batch = seq_batch.to(device)
            ctx_batch = ctx_batch.to(device)
            tgt_batch = tgt_batch.to(device)
            preds = model(seq_batch, ctx_batch)
            loss = compute_loss(preds, tgt_batch)
            val_running += loss.item() * seq_batch.size(0)
    val_loss = val_running / len(val_loader.dataset)
    train_history["train_loss"].append(train_loss)
    train_history["val_loss"].append(val_loss)
    print(f"Epoch {epoch:02d}  Train Loss: {train_loss:.5f}  Val Loss: {val_loss:.5f}")

# Save model and training history
model_path = OUTDIR / "telomere_multitask_model.pth"
torch.save(model.state_dict(), model_path)
np.save(OUTDIR / "training_history.npy", np.array([train_history["train_loss"], train_history["val_loss"]]))

# Quick evaluation: compute MAE per task on validation set
model.eval()
mae = np.zeros(5)
count = 0
with torch.no_grad():
    for seq_batch, ctx_batch, tgt_batch in val_loader:
        seq_batch = seq_batch.to(device)
        ctx_batch = ctx_batch.to(device)
        tgt_batch = tgt_batch.to(device)
        preds = model(seq_batch, ctx_batch).cpu().numpy()
        mae += np.abs(preds - tgt_batch.numpy()).sum(axis=0)
        count += preds.shape[0]
mae = mae / count

metrics = {
    "MAE_Stability": float(mae[0]),
    "MAE_Affinity": float(mae[1]),
    "MAE_Fragility": float(mae[2]),
    "MAE_SecondaryStructure": float(mae[3]),
    "MAE_Immune": float(mae[4]),
}
print("Validation MAE per task:", metrics)

# Save a small sample of the dataset for display
sample_df = df.sample(n=12, random_state=SEED).reset_index(drop=True)
sample_path = OUTDIR / "synthetic_dataset_sample.csv"
sample_df.to_csv(sample_path, index=False)

# Display sample dataframe if UI tool available
if display_dataframe_to_user is not None:
    display_dataframe_to_user("Synthetic Telomere Dataset Sample", sample_df)

# Plot training curves
plt.figure(figsize=(6,4))
plt.plot(train_history["train_loss"], label="train_loss")
plt.plot(train_history["val_loss"], label="val_loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Loss Curve")
plt.tight_layout()
plot_path = OUTDIR / "training_loss.png"
plt.savefig(plot_path)
plt.close()

# Save full dataset
# (already saved as synthetic_telomere_dataset.csv)
print("\nSaved files:")
print(" - Dataset CSV:", str(csv_path))
print(" - Model:", str(model_path))
print(" - Training history:", str(OUTDIR / 'training_history.npy'))
print(" - Training loss plot:", str(plot_path))

# Provide download links (UI will show files under /mnt/data)
print("\nDownload links (use in the chat UI):")
print(f"[Download synthetic dataset] (sandbox:{csv_path})")
print(f"[Download trained model] (sandbox:{model_path})")
print(f"[Download training loss plot] (sandbox:{plot_path})")

# End of pipeline code.


Epoch 01  Train Loss: 0.00341  Val Loss: 0.00123
Epoch 02  Train Loss: 0.00127  Val Loss: 0.00110
Epoch 03  Train Loss: 0.00105  Val Loss: 0.00093
Epoch 04  Train Loss: 0.00096  Val Loss: 0.00089
Epoch 05  Train Loss: 0.00091  Val Loss: 0.00090
Epoch 06  Train Loss: 0.00090  Val Loss: 0.00087
Epoch 07  Train Loss: 0.00089  Val Loss: 0.00086
Epoch 08  Train Loss: 0.00088  Val Loss: 0.00084
Epoch 09  Train Loss: 0.00087  Val Loss: 0.00088
Epoch 10  Train Loss: 0.00086  Val Loss: 0.00092
Epoch 11  Train Loss: 0.00085  Val Loss: 0.00086
Epoch 12  Train Loss: 0.00085  Val Loss: 0.00087
Validation MAE per task: {'MAE_Stability': 0.01740162990987301, 'MAE_Affinity': 0.02459146432578564, 'MAE_Fragility': 0.02554728798568249, 'MAE_SecondaryStructure': 0.01979451548308134, 'MAE_Immune': 0.02517802432179451}

Saved files:
 - Dataset CSV: /mnt/data/synthetic_telomere_dataset.csv
 - Model: /mnt/data/telomere_multitask_model.pth
 - Training history: /mnt/data/training_history.npy
 - Training loss pl