# QAOA Hypothesis Clustering Demo

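This notebook demonstrates clustering of research hypotheses with QAOA (Quantum Approximate Optimization Algorithm), which has a classical fallback, and compares the result against a classical clustering baseline.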

In [None]:
import sys
sys.path.append('../..')

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

from quantum_integration.multilingual_research_agent import Hypothesis, Language
from quantum_integration.quantum_hypothesis_clusterer import QuantumHypothesisClusterer

## 1. Generate Synthetic Hypotheses

In [None]:
# Generate synthetic hypothesis embeddings as n_clusters_true Gaussian blobs,
# so the clusterers below have known structure to recover.
np.random.seed(42)  # fixed seed so the demo is reproducible
n_hypotheses = 30
embedding_dim = 128
n_clusters_true = 3

# Split n_hypotheses as evenly as possible across clusters. The first
# (n_hypotheses % n_clusters_true) clusters get one extra member, so the total
# always equals n_hypotheses. Plain floor division would silently drop the
# remainder whenever the sizes do not divide evenly.
base_size, remainder = divmod(n_hypotheses, n_clusters_true)

# Create clustered embeddings
hypotheses = []
true_labels = []  # ground-truth cluster id per hypothesis, aligned with `hypotheses`
for cluster_id in range(n_clusters_true):
    # Centers are spread with std 5 while members get std 0.5 noise, which keeps
    # the blobs well separated relative to their spread.
    center = np.random.randn(embedding_dim) * 5
    cluster_size = base_size + (1 if cluster_id < remainder else 0)
    for _ in range(cluster_size):
        embedding = center + np.random.randn(embedding_dim) * 0.5
        hyp = Hypothesis(
            text=f"Hypothesis {len(hypotheses)}: Research finding in cluster {cluster_id}",
            language=Language.ENGLISH,
            confidence=np.random.uniform(0.6, 0.95),
            embedding=embedding,
        )
        hypotheses.append(hyp)
        true_labels.append(cluster_id)

# Array form makes the labels easy to compare with the clusterers' assignments.
true_labels = np.array(true_labels)

print(f"Generated {len(hypotheses)} hypotheses with {n_clusters_true} true clusters")

## 2. Visualize Hypothesis Embeddings

In [None]:
# Stack the per-hypothesis embeddings into one matrix and project it onto its
# first two principal components so it can be drawn as a 2-D scatter plot.
embeddings = np.array([hyp.embedding for hyp in hypotheses])
pca = PCA(n_components=2)
embeddings_2d = pca.fit_transform(embeddings)  # one (PC1, PC2) row per hypothesis

fig, ax = plt.subplots(figsize=(10, 6))
pc1, pc2 = embeddings_2d[:, 0], embeddings_2d[:, 1]
ax.scatter(pc1, pc2, alpha=0.6, s=100)
ax.set(xlabel='PC1', ylabel='PC2', title='Hypothesis Embeddings (PCA)')
ax.grid(True, alpha=0.3)
plt.show()

## 3. Run QAOA Clustering

In [None]:
# Build the QAOA clusterer, asking for as many clusters as were generated.
clusterer = QuantumHypothesisClusterer(
    num_clusters=n_clusters_true,
    qaoa_layers=2,
    shots=512,
)

# Cluster the embedding matrix. The returned dict is read below for 'method',
# 'num_clusters' and 'purity'.
qaoa_result = clusterer.cluster(embeddings)

# NOTE(review): the notebook intro says there is a classical fallback, so check
# the reported 'method' before attributing these numbers to QAOA.
print("\nQAOA Clustering Results:")
reported_method = qaoa_result['method']
print(f"Method: {reported_method}")
reported_k = qaoa_result['num_clusters']
print(f"Number of clusters: {reported_k}")
reported_purity = qaoa_result['purity']
print(f"Clustering purity: {reported_purity:.4f}")

## 4. Visualize QAOA Clusters

In [None]:
# Re-draw the 2-D PCA projection, coloring each point by its QAOA cluster id.
fig, ax = plt.subplots(figsize=(10, 6))
clusters = qaoa_result['cluster_assignments']
scatter = ax.scatter(
    embeddings_2d[:, 0],
    embeddings_2d[:, 1],
    c=clusters,
    cmap='viridis',
    alpha=0.6,
    s=100,
)
fig.colorbar(scatter, ax=ax, label='Cluster ID')
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title(f'QAOA Clustering (Purity: {qaoa_result["purity"]:.3f})')
ax.grid(True, alpha=0.3)
plt.show()

## 5. Run Classical Clustering for Comparison

In [None]:
# Classical baseline on the same embeddings.
# NOTE(review): this calls the clusterer's private helpers
# (_compute_similarity_matrix / _classical_cluster), so it may break if
# QuantumHypothesisClusterer's internals change. The original comment called
# this k-means; confirm that against _classical_cluster's implementation.
similarity_matrix = clusterer._compute_similarity_matrix(embeddings)
classical_result = clusterer._classical_cluster(embeddings, similarity_matrix)

print("\nClassical Clustering Results:")
# (label, result key, format spec). An empty spec formats like a plain f"{x}".
for label, key, spec in (
    ("Method", 'method', ""),
    ("Number of clusters", 'num_clusters', ""),
    ("Clustering purity", 'purity', ".4f"),
):
    print(f"{label}: {classical_result[key]:{spec}}")

## 6. Visualize Classical Clusters

In [None]:
# Same 2-D PCA projection, now colored by the classical cluster assignments.
fig, ax = plt.subplots(figsize=(10, 6))
clusters_classical = classical_result['cluster_assignments']
scatter = ax.scatter(
    embeddings_2d[:, 0],
    embeddings_2d[:, 1],
    c=clusters_classical,
    cmap='viridis',
    alpha=0.6,
    s=100,
)
fig.colorbar(scatter, ax=ax, label='Cluster ID')
ax.set(xlabel='PC1', ylabel='PC2')
ax.set_title(f'Classical Clustering (Purity: {classical_result["purity"]:.3f})')
ax.grid(True, alpha=0.3)
plt.show()

## 7. Compare QAOA vs Classical

In [None]:
# Compare the two clusterings on the purity value each result dict reports.
qaoa_purity = qaoa_result['purity']
classical_purity = classical_result['purity']
purity_delta = qaoa_purity - classical_purity

print("\nComparison:")
print(f"QAOA Purity: {qaoa_purity:.4f}")
print(f"Classical Purity: {classical_purity:.4f}")
# This is a signed difference. A negative value means QAOA scored lower than
# the classical baseline, so it is not labelled as an "improvement". The
# explicit sign (+/-) makes the direction obvious.
print(f"Purity difference (QAOA - Classical): {purity_delta:+.4f}")

# Bar chart
fig, ax = plt.subplots(figsize=(8, 5))
methods = ['QAOA', 'Classical']
purities = [qaoa_purity, classical_purity]
ax.bar(methods, purities, color=['blue', 'orange'])
ax.set_ylabel('Clustering Purity')
ax.set_title('QAOA vs Classical Clustering')
ax.set_ylim(0, 1)  # assumes purity is a fraction in [0, 1]
ax.set_axisbelow(True)  # draw gridlines behind the bars, not over them
ax.grid(True, alpha=0.3, axis='y')
plt.show()