# Dyad Position Predictor
Train a neural network to predict dyad positions on a per-position basis from integer-encoded DNA sequences (one code in the range 0–7 per position).

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
from pathlib import Path

from ChromatinFibers import simulate_chromatin_fibers, SequencePlotter  

In [2]:
# Size of the data set: number of reads and base pairs per read.
n_reads = 50_000
n_bp = 10_000

# Checkpoint path; the data-set size is part of the file name.
filename = rf'data/LLM models/dyad_predictor {n_reads}_{n_bp}.pt'

# Project plotting helper, used further down to caption figures.
plotter = SequencePlotter()

In [None]:
# Generate the data only if the cached .npz file does not exist yet.
data_filename = filename.replace('.pt', '.npz')
data_path = Path(data_filename)
if not data_path.exists():
    # Create the output folder *before* the long-running simulation so the run
    # cannot fail at the very end just because the directory is missing.
    data_path.parent.mkdir(parents=True, exist_ok=True)

    dyad_positions, _, encoded_seq = simulate_chromatin_fibers(n_samples=n_reads, length=n_bp)

    np.savez_compressed(
        data_filename,
        dyad_positions=np.array(dyad_positions, dtype=object),
        encoded_seq=np.array(encoded_seq, dtype=object),
    )
    print("Data saved:", data_filename)
else:
    print("Data already exists, skipping simulation.")
    # The arrays were stored with dtype=object, hence allow_pickle=True.
    # NOTE: unpickling executes arbitrary code - only load files you created.
    # The context manager closes the underlying zip file once both arrays
    # have been read into memory.
    with np.load(data_filename, allow_pickle=True) as data:
        dyad_positions = data["dyad_positions"]
        encoded_seq = data["encoded_seq"]

  5%|▍         | 2349/50000 [1:18:08<36:19:06,  2.74s/it]

## Step 1: Define the Model Architecture

In [None]:
class DyadPredictor(nn.Module):
    """Per-position dyad predictor: embedding -> 2x Conv1d -> BiLSTM -> MLP head.

    Args:
        vocab_size: number of distinct input codes (embedding rows).
        embedding_dim: size of each embedding vector.
        hidden_dim: number of convolution channels / LSTM input features.
            Must be at least 2 because each LSTM direction gets
            ``hidden_dim // 2`` units.
        num_layers: number of stacked BiLSTM layers.
        dropout: dropout probability used in the head and, when
            ``num_layers > 1``, between LSTM layers.

    Raises:
        ValueError: if ``hidden_dim`` is smaller than 2.
    """
    def __init__(self, vocab_size=8, embedding_dim=16, hidden_dim=64, num_layers=2, dropout=0.3):
        super().__init__()
        if hidden_dim < 2:
            raise ValueError(f"hidden_dim must be >= 2, got {hidden_dim}")
        self.vocab_size = vocab_size

        # Embedding layer (map integer codes to dense vectors)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Conv blocks for local context; padding=2 with kernel_size=5 keeps seq_len
        self.conv1 = nn.Conv1d(embedding_dim, hidden_dim, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

        # BiLSTM for wider context. Each direction has hidden_dim // 2 units, so
        # the concatenated output has 2 * (hidden_dim // 2) features. That equals
        # hidden_dim only when hidden_dim is even, therefore the head below is
        # sized from the real LSTM output width instead of from hidden_dim.
        lstm_units = hidden_dim // 2
        lstm_out_dim = 2 * lstm_units
        self.lstm = nn.LSTM(hidden_dim, lstm_units, num_layers=num_layers,
                            bidirectional=True,
                            # inter-layer dropout is only meaningful with >1 layer
                            dropout=dropout if num_layers > 1 else 0,
                            batch_first=True)

        # Output head: one logit per position
        self.fc = nn.Sequential(
            nn.Linear(lstm_out_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x):
        """Forward pass.

        Args:
            x: (batch_size, seq_len) integer tensor - encoded sequence
        Returns:
            logits: (batch_size, seq_len, 1) - per-position dyad logits
                (no sigmoid applied)
        """
        # Embedding: (batch, seq_len) -> (batch, seq_len, embed_dim)
        x = self.embedding(x)

        # Conv1d expects channels first: (batch, embed_dim, seq_len)
        x = x.transpose(1, 2)
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))

        # Back to (batch, seq_len, hidden_dim) for the batch_first LSTM
        x = x.transpose(1, 2)

        # LSTM: (batch, seq_len, hidden_dim) -> (batch, seq_len, 2 * (hidden_dim // 2))
        x, _ = self.lstm(x)

        # Per-position classification: -> (batch, seq_len, 1)
        logits = self.fc(x)

        return logits

## Step 2: Define Custom Dataset

In [None]:
class DyadDataset(Dataset):
    """Dataset for dyad position prediction.

    Each item is a pair ``(seq_tensor, label_tensor)`` of length ``max_seq_len``:
    the int64 encoded sequence and a float32 label that is 1 at dyad positions,
    0 elsewhere and -1 at padded positions.
    """
    def __init__(self, data_list, max_seq_len=None):
        """
        Args:
            data_list: list of tuples (dyad_positions, encoded_sequence)
                - dyad_positions: list/array of integers (positions with dyads)
                - encoded_sequence: list/array of ints (0-7); object-dtype
                  arrays (e.g. loaded from a pickled .npz) are accepted too
            max_seq_len: optional, pad/truncate sequences to this length;
                defaults to the longest sequence in data_list (0 if empty)
        """
        self.data = data_list
        # default=0 keeps an empty data_list from raising in max()
        self.max_seq_len = max_seq_len or max((len(seq) for _, seq in data_list), default=0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        dyad_pos, encoded_seq = self.data[idx]

        # Explicit int64 copy: torch cannot convert object-dtype arrays, which
        # is what the cached data path produces.
        seq_array = np.array(encoded_seq, dtype=np.int64)
        seq_len = len(seq_array)

        # Binary label: 1 if dyad at position, 0 otherwise; out-of-range
        # positions are ignored.
        label = np.zeros(seq_len, dtype=np.float32)
        positions = np.asarray(dyad_pos, dtype=np.int64).ravel()
        label[positions[(positions >= 0) & (positions < seq_len)]] = 1.0

        seq_tensor = torch.from_numpy(seq_array)
        label_tensor = torch.from_numpy(label)

        # Pad/truncate to max_seq_len
        if seq_len < self.max_seq_len:
            pad_len = self.max_seq_len - seq_len
            # NOTE: sequence padding uses code 0, which is also a regular
            # vocabulary entry; padded positions are identified by label -1.
            seq_tensor = torch.nn.functional.pad(seq_tensor, (0, pad_len), value=0)
            label_tensor = torch.nn.functional.pad(label_tensor, (0, pad_len), value=-1)  # -1 for padding
        elif seq_len > self.max_seq_len:
            seq_tensor = seq_tensor[:self.max_seq_len]
            label_tensor = label_tensor[:self.max_seq_len]

        return seq_tensor, label_tensor

## Step 3: Load Your Data
Replace this with your actual data loading logic.

In [None]:
# EXAMPLE (disabled): simulate data on the fly instead of reusing the arrays
# prepared in the data cell above.
# def generate_synthetic_data(n_samples=500, length=10_000):
#     """Generate synthetic dyad/sequence pairs for demonstration."""
#     dyad_positions, _, encoded_seq = simulate_chromatin_fibers(n_samples=n_samples, length=length)
#     return list(zip(dyad_positions, encoded_seq))
#
# print("Loading data...")
# all_data = generate_synthetic_data(n_samples=n_reads, length=n_bp)

# Pair each read's dyad positions with its encoded sequence.
all_data = list(zip(dyad_positions, encoded_seq))

# Contiguous split: first 70 % train, next 15 % validation, remainder test.
n_train = int(0.7 * len(all_data))
n_val = int(0.15 * len(all_data))

train_data, val_data, test_data = (
    all_data[:n_train],
    all_data[n_train:n_train+n_val],
    all_data[n_train+n_val:],
)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

## Step 4: Create DataLoaders

In [None]:
# Wrap the three splits in datasets. The training set determines the padded
# length; validation and test are padded/truncated to that same length.
train_dataset = DyadDataset(train_data)
val_dataset, test_dataset = (
    DyadDataset(split, max_seq_len=train_dataset.max_seq_len)
    for split in (val_data, test_data)
)

batch_size = 32

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader, test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    for ds in (val_dataset, test_dataset)
)

print(f"Max sequence length: {train_dataset.max_seq_len}")
print(f"DataLoaders created: train={len(train_loader)} batches, val={len(val_loader)}, test={len(test_loader)}")

## Step 5: Initialize Model and Training Components

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

model = DyadPredictor(
    vocab_size=8,
    embedding_dim=16,
    hidden_dim=64,
    num_layers=2,
    dropout=0.3,
).to(device)

# BCEWithLogitsLoss applies the sigmoid internally. reduction='none' keeps the
# per-position losses so padded positions can be masked out afterwards.
pos_weight = torch.tensor([10.0]).to(device)
criterion = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# 'verbose' is deliberately not passed, for compatibility across PyTorch versions.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Derive pos_weight from the training labels (padding, labelled -1, is skipped)

def compute_pos_weight_from_dataset(dataset):
    """Return the negative/positive label ratio of *dataset* as a 1-element tensor.

    Args:
        dataset: iterable of (sequence, label) pairs; labels hold 1 for
            positives, 0 for negatives and negative values for padding.

    Returns:
        float32 tensor of shape (1,) with ``neg / pos``, the value
        BCEWithLogitsLoss expects as ``pos_weight``. When no positive label
        exists a warning is printed and ``pos`` is treated as 1.
    """
    n_pos = n_neg = 0
    for _, labels in dataset:
        values = np.asarray(labels)
        real = values[values >= 0]  # drop padded positions
        if real.size == 0:
            continue
        n_pos += int(np.count_nonzero(real == 1))
        n_neg += int(np.count_nonzero(real == 0))

    if n_pos == 0:
        # Guard against dividing by zero below.
        print("Warning: no positive examples found in train_dataset; setting pos=1 to avoid division by zero")
        n_pos = 1

    return torch.tensor([float(n_neg) / float(n_pos)], dtype=torch.float32)

# Class-balance weight from the training labels, limited to [1, 100] so a very
# skewed data set cannot produce an extreme loss scale.
pos_weight = torch.clamp(
    compute_pos_weight_from_dataset(train_dataset).to(device),
    min=1.0,
    max=100.0,
)
print(f"Computed pos_weight: {pos_weight.item():.4f}")

# Per-position (unreduced) weighted loss; padding is masked by the caller.
criterion_weighted = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)

# Last expression of the cell, so the notebook echoes the criterion.
criterion_weighted


## Step 6: Training Loop

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over *train_loader*.

    Args:
        model: network returning (batch, seq_len, 1) logits.
        train_loader: yields (sequence, label) batches; label -1 marks padding.
        criterion: element-wise loss (reduction='none').
        optimizer: optimiser updating ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        Mean of the per-batch masked losses.
    """
    model.train()
    running_loss = 0

    for sequences, targets in tqdm(train_loader, desc="Training"):
        sequences, targets = sequences.to(device), targets.to(device)  # both (batch, seq_len)

        optimizer.zero_grad()

        # (batch, seq_len, 1) -> (batch, seq_len)
        scores = model(sequences).squeeze(-1)

        # Average the loss over real positions only; the clamp avoids a
        # division by zero for a batch consisting solely of padding.
        valid = (targets >= 0).float()
        per_position = criterion(scores, targets)
        batch_loss = (per_position * valid).sum() / valid.sum().clamp(min=1)

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(train_loader)

def validate(model, val_loader, criterion, device):
    """Evaluate *model* on *val_loader* without gradient tracking.

    Args:
        model: network returning (batch, seq_len, 1) logits.
        val_loader: yields (sequence, label) batches; label -1 marks padding.
        criterion: element-wise loss (reduction='none').
        device: device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, probs, labels)``: the mean of the per-batch masked
        losses, and two lists with one NumPy array per batch holding the
        sigmoid probabilities and the labels (padding included).
    """
    model.eval()
    running_loss = 0
    collected_probs, collected_labels = [], []

    with torch.no_grad():
        for sequences, targets in tqdm(val_loader, desc="Validating"):
            sequences, targets = sequences.to(device), targets.to(device)

            scores = model(sequences).squeeze(-1)

            # Masked mean over non-padded positions.
            valid = (targets >= 0).float()
            per_position = criterion(scores, targets)
            batch_loss = (per_position * valid).sum() / valid.sum().clamp(min=1)
            running_loss += batch_loss.item()

            # Keep probabilities and labels for metric computation by the caller.
            collected_probs.append(torch.sigmoid(scores).cpu().numpy())
            collected_labels.append(targets.cpu().numpy())

    return running_loss / len(val_loader), collected_probs, collected_labels

# Training loop with checkpointing on improvement and early stopping.
epochs = 50
patience = 5                      # epochs without improvement before stopping
train_losses, val_losses = [], []
best_val_loss = float('inf')
patience_counter = 0

print("Starting training...")
for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, criterion_weighted, optimizer, device)
    val_loss = validate(model, val_loader, criterion_weighted, device)[0]

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    if not (val_loss < best_val_loss):
        # No improvement: count towards early stopping.
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    else:
        # New best validation loss: reset the counter and checkpoint the weights.
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), filename)
        print(f"  Saved best model to {filename}")

    # Let the plateau scheduler see this epoch's validation loss.
    scheduler.step(val_loss)

print("Training completed!")

## Step 7: Plot Training History

In [None]:
# Loss curves per epoch for the training and validation sets.
plt.figure(figsize=(15, 3))
plt.plot(train_losses, label='Train Loss', marker='o')
plt.plot(val_losses, label='Val Loss', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# 10 % head-room above the largest recorded loss.
plt.ylim(0, max(max(train_losses), max(val_losses)) * 1.1)
# Span the configured epoch budget (was hard-coded to 50) so the axis stays
# correct when `epochs` is changed; early-stopped runs end before the right edge.
plt.xlim(0, epochs)
plt.tight_layout()
plotter.add_caption(f'Training history for {n_reads} reads and {n_bp} bp.')

## Step 8: Evaluate on Test Set

In [None]:
# Load best model. map_location lets a checkpoint written on a GPU be restored
# in a CPU-only session (and vice versa) instead of raising at load time.
model.load_state_dict(torch.load(filename, map_location=device))

# Evaluate on test set with the same weighted criterion used during training,
# so the test loss is comparable with the reported train/val losses.
model.eval()
test_loss, test_preds, test_labels = validate(model, test_loader, criterion_weighted, device)

print(f"Test Loss: {test_loss:.4f}")

# Flatten the per-batch arrays into one vector of probabilities / labels
all_preds_flat = np.concatenate(test_preds).ravel()
all_labels_flat = np.concatenate(test_labels).ravel()

# Remove padding positions (label == -1)
valid_mask = all_labels_flat >= 0
all_preds_flat = all_preds_flat[valid_mask]
all_labels_flat = all_labels_flat[valid_mask]

# Threshold at 0.3 (lower for imbalanced data with weighted loss)
predictions_binary = (all_preds_flat >= 0.3).astype(int)


# Metrics: use scikit-learn when available, otherwise fall back to small NumPy
# implementations with the same call signatures (binary labels 0/1 only).
try:
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
    sklearn_available = True
except ImportError:
    # Only a missing package should trigger the fallback; any other error
    # raised while importing is a real problem and must surface.
    sklearn_available = False
    import warnings
    warnings.warn("scikit-learn not installed; using numpy fallback for basic metrics. Install with: pip install scikit-learn")

    def _binary_counts(y_true, y_pred):
        """Return (tp, fp, fn) for binary labels where 1 is the positive class."""
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        tp = int(((y_true == 1) & (y_pred == 1)).sum())
        fp = int(((y_true == 0) & (y_pred == 1)).sum())
        fn = int(((y_true == 1) & (y_pred == 0)).sum())
        return tp, fp, fn

    def accuracy_score(y_true, y_pred):
        """Fraction of positions where prediction and label agree."""
        return float((np.asarray(y_true) == np.asarray(y_pred)).mean())

    def precision_score(y_true, y_pred, zero_division=0):
        """tp / (tp + fp); ``zero_division`` is returned when nothing was predicted positive."""
        tp, fp, _ = _binary_counts(y_true, y_pred)
        denom = tp + fp
        if denom == 0:
            return float(zero_division)
        return float(tp / denom)

    def recall_score(y_true, y_pred, zero_division=0):
        """tp / (tp + fn); ``zero_division`` is returned when there are no positive labels."""
        tp, _, fn = _binary_counts(y_true, y_pred)
        denom = tp + fn
        if denom == 0:
            return float(zero_division)
        return float(tp / denom)

    def f1_score(y_true, y_pred, zero_division=0):
        """Harmonic mean of precision and recall."""
        p = precision_score(y_true, y_pred, zero_division=zero_division)
        r = recall_score(y_true, y_pred, zero_division=zero_division)
        if (p + r) == 0:
            return float(zero_division)
        return 2 * (p * r) / (p + r)

    def roc_auc_score(y_true, y_score):
        """Area under the ROC curve via the rank-sum (Mann-Whitney U) identity.

        Tied scores receive their average rank, so the result does not depend
        on the order in which equal scores happen to be sorted. Returns NaN
        unless both classes (0 and 1) are present.
        """
        y_true = np.asarray(y_true).ravel()
        y_score = np.asarray(y_score, dtype=np.float64).ravel()
        # Only labels 0 and 1 take part in the computation.
        keep = (y_true == 0) | (y_true == 1)
        y_true = y_true[keep]
        y_score = y_score[keep]
        positives = y_true == 1
        n_pos = int(positives.sum())
        n_neg = int(y_true.size - n_pos)
        if n_pos == 0 or n_neg == 0:
            return float('nan')
        # 1-based average ranks: a group of equal scores occupying ranks
        # (last - count + 1) .. last gets their mean, last - (count - 1) / 2.
        _, inverse, counts = np.unique(y_score, return_inverse=True, return_counts=True)
        last_rank = np.cumsum(counts)
        ranks = (last_rank - (counts - 1) / 2.0)[inverse]
        u_stat = ranks[positives].sum() - n_pos * (n_pos + 1) / 2.0
        return float(u_stat / (n_pos * n_neg))

# Threshold-based metrics use the binarised predictions; ROC-AUC is computed
# from the raw probabilities.
acc = accuracy_score(all_labels_flat, predictions_binary)
prec, rec, f1 = (
    metric(all_labels_flat, predictions_binary, zero_division=0)
    for metric in (precision_score, recall_score, f1_score)
)
auc = roc_auc_score(all_labels_flat, all_preds_flat)

# Report
print(f"\nTest Metrics:")
print(f"  Accuracy:  {acc:.4f}")
print(f"  Precision: {prec:.4f}")
print(f"  Recall:    {rec:.4f}")
print(f"  F1-Score:  {f1:.4f}")
print(f"  ROC-AUC:   {auc:.4f}")

## Step 9: Save Model for Later Use

In [None]:
# Save the configuration needed to rebuild the model.
# Keys with a leading underscore are DyadPredictor constructor arguments; they
# are read back from the trained model itself so the saved config cannot drift
# from the architecture that produced the checkpoint.
model_config = {
    '_vocab_size': model.embedding.num_embeddings,
    '_embedding_dim': model.embedding.embedding_dim,
    '_hidden_dim': model.conv1.out_channels,
    '_num_layers': model.lstm.num_layers,
    '_dropout': model.fc[2].p,  # fc[2] is the nn.Dropout of the output head
    'max_seq_len': train_dataset.max_seq_len,
    'n_reads': n_reads,
    'n_bp': n_bp
}

# Save config next to the checkpoint (same path, .json suffix)
with open(Path(filename).with_suffix('.json'), 'w') as f:
    json.dump(model_config, f, indent=2)

# The weights themselves were already written to `filename` by the training loop.
print("Model and config saved!")
print(f"Config: {model_config}")

## Step 10: Load Saved Model and Predict on a New Sequence

In [None]:
def predict_dyads(model, encoded_sequence, threshold=0.2, device='cpu'):
    """
    Predict dyad positions for a single sequence.

    Args:
        model: trained DyadPredictor (any module mapping a (1, seq_len) int64
            tensor to (1, seq_len, 1) logits)
        encoded_sequence: list/array of integers (0-7); object-dtype arrays
            are accepted
        threshold: probability threshold for positive class (default 0.2)
        device: torch device the input is moved to; must match the model's device

    Returns:
        dyad_positions: list of predicted dyad positions
        probabilities: 1-D array of per-position probabilities
            (length == len(encoded_sequence), also for length 0 and 1)
    """
    model.eval()

    # Explicit int64 copy: torch cannot build a tensor from object-dtype arrays.
    seq_array = np.array(encoded_sequence, dtype=np.int64).reshape(-1)
    if seq_array.size == 0:
        # Nothing to predict; skip the forward pass on an empty sequence.
        return [], np.zeros(0, dtype=np.float32)

    with torch.no_grad():
        seq_tensor = torch.from_numpy(seq_array).unsqueeze(0).to(device)
        logits = model(seq_tensor)
        # reshape(-1) instead of squeeze(): squeeze() would also drop the
        # sequence axis for a length-1 input and return a 0-d array.
        probs = torch.sigmoid(logits).reshape(-1).cpu().numpy()

    dyad_positions = np.flatnonzero(probs >= threshold).tolist()
    return dyad_positions, probs


# Load config (written next to the checkpoint with a .json suffix)
with open(Path(filename).with_suffix('.json'), 'r') as f:
    config = json.load(f)


# Create model: config keys starting with '_' are constructor arguments
# (the underscore is stripped); the remaining keys are metadata.
loaded_model = DyadPredictor(**{k[1:]: v for k, v in config.items() if k.startswith('_')})
# map_location makes the checkpoint loadable regardless of the device it was
# saved from (e.g. trained on GPU, reloaded in a CPU-only session).
loaded_model.load_state_dict(torch.load(filename, map_location=device))
loaded_model = loaded_model.to(device)
loaded_model.eval()

# Use for predictions: simulate one fresh fiber and run the loaded model on it.
# NOTE(review): this rebinds the notebook-level `dyad_positions` and
# `encoded_seq`; re-running the earlier data-split cell afterwards would see
# only this single sample.
dyad_positions, _, encoded_seq = simulate_chromatin_fibers(n_samples=1, length=n_bp)

new_data = [(dyad_positions[0], encoded_seq[0])]
new_seq = new_data[0][1]

dyads, probs = predict_dyads(loaded_model, new_seq, device=device)

index = np.arange(len(new_seq))
nucs = np.zeros_like(index)
# NOTE(review): these marks are overwritten by the footprint loop below (every
# dyad lies inside its own window) - confirm whether that is intended.
nucs[new_data[0][0]] = 1.0

for i in new_data[0][0]:
    # highlight nucleosome region; clamp the start at 0, otherwise a dyad closer
    # than 65 bp to the start yields a negative index that wraps around and the
    # footprint is silently skipped.
    nucs[max(i - 65, 0):i + 65] = -1

n_plots = 10

# Track that is 1 where the sequence code is > 4 and -1 elsewhere (presumably
# codes above 4 denote methylated bases - confirm against the encoder).
# Built once: it is identical for every panel. np.asarray makes the comparison
# work when new_seq is a plain list, not only for NumPy arrays.
methylations = np.zeros_like(index) - 1
methylations[np.asarray(new_seq, dtype=np.int64) > 4] = 1

# One figure per consecutive window; together the windows cover the sequence.
for i in range(n_plots):
    plt.figure(figsize=(12, 1))
    plt.xlabel("i (bp)")

    plt.vlines(new_data[0][0], ymin=-1, ymax=2, color='black',linestyles= 'dotted', alpha=1)
    plt.fill_between(index, probs, label='Predicted Dyad Probability', color='orange', alpha=0.5)
    plt.fill_between(index, nucs, color='blue', alpha=0.5, label='True Dyad Positions')

    plt.plot(index, methylations, 'o', label='methylations', color='green', alpha=0.5, fillstyle='full', markersize=3)

    # Window i of n_plots
    plt.xlim(i*len(index)//n_plots, (i+1)*len(index)//n_plots)
    plt.ylabel("Probability")
    # Values of -1 fall below the visible range; only the +1 markers are shown.
    plt.ylim(-0.1, 1.1)

    # plt.tight_layout
    plt.show()