# Hybrid Quantum-Classical Protein Folding Tutorial

This notebook demonstrates the key components of the hybrid model and how to use them.

## Contents
1. Create a simple protein sequence
2. Initialize the hybrid model
3. Generate structure predictions
4. Visualize the predicted structure
5. Analyze the energy landscape
6. Compute structural properties


In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.distance import cdist

from hqpf.models import HybridModel, StructureGenerator, SurrogateModel
from hqpf.data import ProteinDataset

print('Imports successful!')


## 1. Create a simple protein sequence


In [None]:
# Residue alphabet in index order: the position of each one-letter symbol in
# this string is the integer code assigned to it.
AA_ORDER = "ARNDCQEGHILKMFPSTWYV"
aa_to_idx = {symbol: code for code, symbol in enumerate(AA_ORDER)}

# Example input containing each of the 20 symbols exactly once.
sequence_str = "ACDEFGHIKLMNPQRSTVWY"

# Encode the string as a 1-D int64 tensor; a symbol missing from the
# alphabet raises KeyError during the lookup.
encoded = [aa_to_idx[symbol] for symbol in sequence_str]
sequence = torch.tensor(encoded, dtype=torch.long)

print(f"Sequence: {sequence_str}")
print(f"Length: {len(sequence)}")


## 2. Initialize the hybrid model


In [None]:
# Seed the random number generators before the model is built so that any
# torch- or NumPy-driven randomness in this notebook (construction here, and
# the generation call in the next section) repeats under Restart & Run All.
# NOTE(review): only torch and NumPy are seeded; confirm whether hqpf draws
# randomness from any other source (e.g. Python's `random` or a simulator RNG).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build the hybrid model sized to the example sequence.
model = HybridModel(
    n_residues=len(sequence),
    n_qubits=20,
    embedding_dim=64,
    hidden_dim=128,
    use_quantum=False,  # Use simulator
    device='cpu'
)

# Total number of scalar entries across all parameters exposed by the model.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model initialized with {n_params:,} parameters")


## 3. Generate structure predictions


In [None]:
# Put the model in evaluation mode before generating candidates.
model.eval()

# Options forwarded to the model call, gathered in one place so they are
# easy to tweak.
generation_kwargs = dict(
    n_candidates=10,
    use_surrogate=True,
    temperature=1.0,
)

# Gradients are not needed for inference.
with torch.no_grad():
    outputs = model(sequence, **generation_kwargs)

# Pull out the entries used by the following sections.
best_structure, best_energy = outputs['best_structure'], outputs['best_energy']

print(f"Best energy: {best_energy.item():.4f}")
print(f"Structure shape: {best_structure.shape}")


## 4. Visualize the predicted structure


In [None]:
# Convert the best structure to a NumPy array. detach() and cpu() keep the
# conversion working if the tensor carries gradient history or lives on a
# non-CPU device (e.g. after changing `device` when building the model).
coords = best_structure.detach().cpu().numpy()

# The array is indexed as (n_points, 3): columns 0/1/2 are plotted as x/y/z.
xs, ys, zs = coords[:, 0], coords[:, 1], coords[:, 2]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot backbone: a line through consecutive points.
ax.plot(xs, ys, zs, 'b-', linewidth=2, alpha=0.7, label='Backbone')

# Plot residues: one marker per point, coloured by its position in the chain.
scatter = ax.scatter(xs, ys, zs,
                     c=np.arange(len(coords)), cmap='viridis', s=100, label='Residues')

ax.set_xlabel('X (\u00c5)')
ax.set_ylabel('Y (\u00c5)')
ax.set_zlabel('Z (\u00c5)')
ax.set_title(f'Predicted Structure: {sequence_str}')
ax.legend()

# Attach the colorbar explicitly to this figure/axes rather than relying on
# pyplot's notion of the current axes.
fig.colorbar(scatter, ax=ax, label='Residue Index')
plt.tight_layout()
plt.show()


## 5. Analyze the energy landscape


In [None]:
# Energies of all generated candidates as a NumPy array. detach() and cpu()
# keep the conversion working for tensors that track gradients or live on a
# non-CPU device.
energies = outputs['energies'].detach().cpu().numpy()

# Histogram of candidate energies with the best candidate marked by a
# dashed vertical line.
energy_fig, energy_ax = plt.subplots(figsize=(10, 6))
energy_ax.hist(energies, bins=20, edgecolor='black', alpha=0.7)
energy_ax.axvline(best_energy.item(), color='r', linestyle='--', linewidth=2, label='Best')
energy_ax.set_xlabel('Energy (a.u.)')
energy_ax.set_ylabel('Count')
energy_ax.set_title('Energy Distribution of Candidate Structures')
energy_ax.legend()
energy_ax.grid(alpha=0.3)
plt.show()


## 6. Compute structural properties


In [None]:
# Distance threshold below which two points count as "in contact".
# Used for both the contact map and the plot title so the two cannot drift.
CONTACT_CUTOFF = 8.0

# Radius of gyration: root-mean-square distance of the points from their centroid.
center = coords.mean(axis=0)
rg = np.sqrt(np.mean(np.sum((coords - center) ** 2, axis=1)))

# End-to-end distance: Euclidean distance between the first and last point.
end_to_end = np.linalg.norm(coords[-1] - coords[0])

# Contact map: 1.0 where the pairwise distance is strictly below the cutoff,
# with the diagonal (each point against itself) cleared.
dist_matrix = cdist(coords, coords)
contact_map = (dist_matrix < CONTACT_CUTOFF).astype(float)
np.fill_diagonal(contact_map, 0)

print(f"Radius of gyration: {rg:.2f} \u00c5")
print(f"End-to-end distance: {end_to_end:.2f} \u00c5")
# The map is symmetric, so each contact appears twice; halve the sum.
print(f"Number of contacts: {int(contact_map.sum() / 2)}")

# Plot contact map
contact_fig, contact_ax = plt.subplots(figsize=(8, 8))
contact_img = contact_ax.imshow(contact_map, cmap='Blues')
contact_fig.colorbar(contact_img, ax=contact_ax, label='Contact')
contact_ax.set_xlabel('Residue Index')
contact_ax.set_ylabel('Residue Index')
contact_ax.set_title(f'Contact Map ({CONTACT_CUTOFF:g} \u00c5 cutoff)')
plt.tight_layout()
plt.show()
