# Using the DyadPredictorLLM Class

This notebook demonstrates how to use the `DyadPredictorLLM` class for training and applying dyad prediction models.

The class provides a clean interface for:
- Creating new models (simple or dilated)
- Training on HDF5 data files
- Loading pre-trained models from JSON configs
- Applying models to predict dyad positions
- Visualizing model architecture

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys

# Add the parent directory to the path so the local modules
# (DyadPredictorLLM, ChromatinFibers, Plotter) can be imported
sys.path.insert(0, '..')

from DyadPredictorLLM import DyadPredictorLLM
from ChromatinFibers import read_simulation_results
from Plotter import SequencePlotter

plotter = SequencePlotter()

## Option 1: Create and Train a New Model

You can create either a 'simple' or 'dilated' model architecture.

In [None]:
# Initialize a new simple model
predictor = DyadPredictorLLM()
predictor.init_model(
    model_type='simple',  # or 'dilated'
    embedding_dim=16,
    hidden_dim=64,
    num_layers=2,
    dropout=0.3
)
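
The `embedding_dim` argument implies that each base is mapped to an integer token before it reaches an embedding layer. A minimal sketch of such an encoding, assuming a five-symbol vocabulary (A, C, G, T, plus N for anything else); `BASE_TO_TOKEN` and `encode_sequence` are hypothetical names, and `DyadPredictorLLM` may use a different mapping internally:

```python
# Hypothetical tokenizer: this mapping is an illustrative assumption,
# not necessarily the one DyadPredictorLLM uses.
BASE_TO_TOKEN = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}


def encode_sequence(seq: str) -> list[int]:
    """Map a DNA string to integer tokens, one token per base pair.

    Lower-case bases are accepted; any unknown character becomes the N token.
    """
    return [BASE_TO_TOKEN.get(base, BASE_TO_TOKEN["N"]) for base in seq.upper()]


print(encode_sequence("ACGTnX"))  # [0, 1, 2, 3, 4, 4]
```

Because the output has one token per base, a model built on it can return one dyad probability per bp position.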

In [None]:
# Train the model
data_filename = r"data/LLM models/test.h5"

predictor.train(
    data_filename=data_filename,
    model_filename=r"data/LLM models/my_model.pt",  # optional, auto-generated if not provided
    epochs=50,
    batch_size=32,
    learning_rate=1e-3,
    patience=5,
    max_batches_per_epoch=100,  # Set to None for full dataset
    max_eval_batches=50,  # Set to None for full validation
)
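
With `patience=5`, training presumably stops early once the validation loss stops improving. A sketch of the usual patience rule, assuming `train()` follows the common convention (stop after `patience` consecutive epochs without a new best validation loss); `early_stopping_epoch` is a hypothetical helper, not part of the class:

```python
def early_stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None.

    Training stops once the validation loss has failed to improve on the
    best value seen so far for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best validation loss: reset the patience counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None  # never triggered: all epochs would run


losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.65, 0.68, 0.69, 0.70]
print(early_stopping_epoch(losses, patience=5))  # 8
```

The best loss (0.65) is reached at epoch 3 and epochs 4 to 8 bring no strict improvement, so training would halt at epoch 8 instead of running all 50 epochs. Many implementations also restore the weights from the best epoch; whether `train()` does so is not stated in this notebook.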

In [None]:
# Plot training history
predictor.plot_training_history()

## Option 2: Load an Existing Model

Load a pre-trained model from a JSON configuration file.

In [None]:
# Load model from JSON config
predictor = DyadPredictorLLM()
predictor.load_from_json(r"data/LLM models/test_15000.json")

# The corresponding .pt file with weights is automatically loaded
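
A sketch of what `load_from_json()` plausibly does, assuming the weights file sits next to the config with the same stem (`test_15000.json` pairs with `test_15000.pt`) and the JSON stores `init_model()` keyword arguments. `read_config`, `weights_path_for`, and the config keys used here are hypothetical:

```python
import json
import tempfile
from pathlib import Path


def weights_path_for(config_path):
    """Assumed naming rule: foo.json -> foo.pt in the same directory."""
    return Path(config_path).with_suffix(".pt")


def read_config(config_path):
    """Read the JSON config into a dict (e.g. init_model() keyword arguments)."""
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)


# Demo with a throwaway config; the keys are hypothetical, mirroring init_model()
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "test_15000.json"
    cfg_path.write_text(json.dumps({"model_type": "simple", "embedding_dim": 16}))
    config = read_config(cfg_path)
    print(config["model_type"], weights_path_for(cfg_path).name)  # simple test_15000.pt
```

Keeping the architecture in JSON and the weights in a separate `.pt` file means the model can be rebuilt with the right shapes before the weights are loaded into it.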

## Visualize Model Architecture

See how data flows through the model layers.

In [None]:
predictor.visualize_model(sequence_length=100)
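
Since `apply()` returns one probability per base (the `dyad_probabilities` array is indexed by position in bp), any convolution in the model must preserve the sequence length. A sketch of the standard 1-D convolution length formula, assuming PyTorch-style `Conv1d` layers with stride 1 and "same" padding; `conv1d_output_length` and `same_padding` are hypothetical helpers:

```python
def conv1d_output_length(length, kernel_size, dilation=1, padding=0, stride=1):
    """Output length of a 1-D convolution (the formula documented for torch.nn.Conv1d)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size, dilation=1):
    """Padding that keeps the length unchanged for an odd kernel at stride 1."""
    return dilation * (kernel_size - 1) // 2


# A 100 bp input through kernel-7 convolutions at each dilation rate
for d in (1, 2, 4, 8):
    p = same_padding(7, d)
    print(f"dilation={d}: padding={p}, output length={conv1d_output_length(100, 7, dilation=d, padding=p)}")
```

Without padding, a kernel-7 convolution would shrink the 100 bp input to 94 positions, and the per-base output would no longer line up with the sequence.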

## Apply Model to Predict Dyad Positions

Use the trained model to predict dyad positions for a specific sequence.

In [None]:
# Apply model to a specific sample
result = predictor.apply(
    data_filename=r"data/LLM models/test.h5",
    index=0,  # Sample index in the HDF5 file
    threshold=0.3,  # Probability threshold for calling peaks
    return_dict=True  # Return full results dictionary
)

print(f"Predicted dyad positions: {result['predicted_dyads']}")
print(f"True dyad positions: {result['true_dyads']}")
print(f"Number of predictions: {len(result['predicted_dyads'])}")
print(f"Number of true dyads: {len(result['true_dyads'])}")
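
The notebook only describes `threshold` as the "probability threshold for calling peaks". One plausible rule, sketched here as an assumption rather than what `apply()` actually does, keeps positions that clear the threshold and are local maxima of the probability track; `call_peaks` is a hypothetical name:

```python
def call_peaks(probabilities, threshold=0.3):
    """Return positions whose probability is >= threshold and is a local maximum.

    A position qualifies if it is greater than its left neighbour and at least
    as large as its right neighbour, so a flat-topped peak is reported once.
    """
    peaks = []
    n = len(probabilities)
    for i, p in enumerate(probabilities):
        if p < threshold:
            continue
        left = probabilities[i - 1] if i > 0 else float("-inf")
        right = probabilities[i + 1] if i < n - 1 else float("-inf")
        if p > left and p >= right:
            peaks.append(i)
    return peaks


probs = [0.1, 0.4, 0.2, 0.1, 0.35, 0.5, 0.5, 0.2, 0.25]
print(call_peaks(probs, threshold=0.3))  # [1, 5]
```

Position 4 clears the threshold but sits on the rising flank of the peak at 5, so it is not reported. `scipy.signal.find_peaks(probs, height=0.3)` provides a similar, more configurable rule (for example a minimum `distance` between peaks).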

In [None]:
# Plot predictions vs ground truth
fig, axes = plt.subplots(3, 1, figsize=(15, 8))

# Plot 1: Predicted probabilities
axes[0].plot(result['dyad_probabilities'], linewidth=1)
axes[0].axhline(0.3, color='red', linestyle='--', label='Threshold')  # keep in sync with the threshold passed to apply()
axes[0].set_ylabel('Dyad Probability')
axes[0].set_title('Model Predictions')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Predicted dyad positions
pred_binary = np.zeros_like(result['dyad_probabilities'])
pred_binary[result['predicted_dyads']] = 1
axes[1].plot(pred_binary, linewidth=1, color='blue')
axes[1].set_ylabel('Predicted Dyad')
axes[1].set_ylim(-0.1, 1.1)
axes[1].set_title('Predicted Dyad Positions')
axes[1].grid(True, alpha=0.3)

# Plot 3: True dyad positions
true_binary = np.zeros_like(result['dyad_probabilities'])
true_binary[result['true_dyads']] = 1
axes[2].plot(true_binary, linewidth=1, color='green')
axes[2].set_ylabel('True Dyad')
axes[2].set_xlabel('Position (bp)')
axes[2].set_ylim(-0.1, 1.1)
axes[2].set_title('Ground Truth Dyad Positions')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plotter.add_caption("Dyad prediction results for a single sequence")
plt.show()

## Create a Dilated Model

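Dilated convolutions space the kernel taps `dilation` bp apart, so they cover a wider stretch of sequence with the same number of weights. Combining several dilation rates (`conv_dilations`) lets the model pick up sequence features at multiple length scales.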

In [None]:
# Create dilated model with larger capacity
predictor_dilated = DyadPredictorLLM()
predictor_dilated.init_model(
    model_type='dilated',
    embedding_dim=32,
    hidden_dim=128,
    num_layers=3,
    dropout=0.3,
    conv_dilations=(1, 2, 4, 8),  # Multi-scale receptive fields
    conv_kernel_size=7
)

# Uncomment to train on the same data file
# predictor_dilated.train(
#     data_filename=r"data/LLM models/test.h5",
#     epochs=50,
#     batch_size=32,
# )
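
The receptive field shows how much sequence context each output position can see. The notebook does not say whether the dilated convolutions are stacked in sequence or run as parallel branches, so this sketch computes both cases for `conv_kernel_size=7` and `conv_dilations=(1, 2, 4, 8)`; the helper names are hypothetical, and any later layers (e.g. recurrent ones, if present) would widen the context further:

```python
def stacked_receptive_field(kernel_size, dilations):
    """Receptive field (bp) of stride-1 dilated convolutions applied one after another."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


def widest_branch_receptive_field(kernel_size, dilations):
    """Receptive field (bp) if each dilation is a parallel branch over the same input."""
    return 1 + (kernel_size - 1) * max(dilations)


print(stacked_receptive_field(7, (1, 2, 4, 8)))        # 91
print(widest_branch_receptive_field(7, (1, 2, 4, 8)))  # 49
```

A single undilated kernel-7 convolution sees only 7 bp, so the dilation rates enlarge the context considerably while each layer keeps the same number of weights.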

## Batch Prediction

Apply the model to the first `n_samples` sequences in the HDF5 file and compare the number of predicted and true dyads.

In [None]:
# Predict for multiple samples
n_samples = 5
results = []

for i in range(n_samples):
    result = predictor.apply(
        data_filename=r"data/LLM models/test.h5",
        index=i,
        threshold=0.3,
        return_dict=True
    )
    results.append(result)
    print(f"Sample {i}: Predicted {len(result['predicted_dyads'])} dyads, "
          f"True {len(result['true_dyads'])} dyads")

In [None]:
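# Per-base precision, recall and F1, averaged across samples.
# A predicted dyad counts as correct only if it lands on exactly the same bp as a true dyad.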
from sklearn.metrics import precision_score, recall_score, f1_score

precisions = []
recalls = []
f1_scores = []

for result in results:
    seq_len = len(result['dyad_probabilities'])
    
    # Create binary arrays
    y_true = np.zeros(seq_len)
    y_true[result['true_dyads']] = 1
    
    y_pred = np.zeros(seq_len)
    y_pred[result['predicted_dyads']] = 1
    
    # Calculate metrics
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)

print(f"\nAverage Metrics (n={n_samples}):")
print(f"  Precision: {np.mean(precisions):.4f} ± {np.std(precisions):.4f}")
print(f"  Recall:    {np.mean(recalls):.4f} ± {np.std(recalls):.4f}")
print(f"  F1-Score:  {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
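
Exact per-base scoring is strict: a prediction 1 bp away from a true dyad counts as both a false positive and a false negative. A common alternative matches each prediction to at most one true dyad within a tolerance window. This sketch uses greedy closest-first matching (an optimal assignment, such as the Hungarian algorithm, can occasionally match more pairs); `match_with_tolerance` and the default `tolerance=5` bp are assumptions, not part of the class:

```python
def match_with_tolerance(true_positions, predicted_positions, tolerance=5):
    """Greedily pair predictions with true dyads lying within +/- `tolerance` bp.

    Closest pairs are matched first and each true dyad is used at most once.
    Returns (precision, recall, f1).
    """
    candidates = sorted(
        (abs(p - t), i, j)
        for i, p in enumerate(predicted_positions)
        for j, t in enumerate(true_positions)
        if abs(p - t) <= tolerance
    )
    used_pred, used_true = set(), set()
    for _, i, j in candidates:
        if i not in used_pred and j not in used_true:
            used_pred.add(i)
            used_true.add(j)
    tp = len(used_pred)
    # len() checks also work when the positions are NumPy arrays
    precision = tp / len(predicted_positions) if len(predicted_positions) > 0 else 0.0
    recall = tp / len(true_positions) if len(true_positions) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


print(match_with_tolerance([100, 300, 500], [102, 290, 700], tolerance=5))
```

Usage with the results above: `match_with_tolerance(r['true_dyads'], r['predicted_dyads'], tolerance=5)` for each `r` in `results`. Here only 102 falls within 5 bp of a true dyad, giving one third for all three scores; with `tolerance=10`, the prediction at 290 also matches.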

## Summary

The `DyadPredictorLLM` class provides:

1. **`.init_model()`** - Initialize a new model with specified architecture
2. **`.load_from_json()`** - Load a pre-trained model from JSON config
3. **`.train()`** - Train the model on HDF5 data with automatic checkpointing
4. **`.apply()`** - Predict dyad positions for a specific sequence
5. **`.visualize_model()`** - Visualize the model architecture
6. **`.plot_training_history()`** - Plot training/validation loss curves

The class handles both 'simple' and 'dilated' model architectures, manages data loading from HDF5 files, and provides a clean API for all common tasks.