# 1. DBBE

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics import adjusted_rand_score, v_measure_score
from tqdm import tqdm
from collections import defaultdict
import igraph as ig
import leidenalg as la
import faiss
import os
from typing import Dict
import re
import unicodedata
import pandas as pd
from collections import defaultdict
from itertools import combinations

# Output directory for every artefact written by the DBBE cell.
os.makedirs('dbbe_semantic_results', exist_ok=True)

csv_path = 'paper_verses.csv'
df = pd.read_csv(csv_path)
# Drop rows with a missing text OR idgroup *before* casting text to str:
# astype(str) turns NaN into the literal string 'nan', which dropna() no
# longer recognises as missing. copy() avoids chained-assignment warnings.
df = df.dropna(subset=['text', 'idgroup']).copy()
df['text'] = df['text'].astype(str)
# NOTE(review): float32 represents integers exactly only up to 2**24 -
# confirm that idgroup identifiers stay below that bound.
df['idgroup'] = df['idgroup'].astype('float32')


# Encoder used to embed the verses; trust_remote_code executes code shipped
# with the model repository.
model_name = 'kevinkrahn/shlm-grc-en'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Run on the GPU when one is available; eval() disables dropout etc.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

def cls_pooling(model_output):
    """Return position 0 along the second axis of the model's first output.

    ``model_output[0]`` is presumably the per-token hidden states with shape
    (batch, tokens, dim) - TODO confirm against the model class.
    """
    first_output = model_output[0]
    return first_output[:, 0]

# Embed every verse in fixed-size batches; the per-batch arrays are stacked
# into one matrix after the loop.
verses = df['text'].tolist()
embeddings = []
batch_size = 32

for i in tqdm(range(0, len(verses), batch_size), desc="Embeddings"):
    batch = verses[i:i+batch_size]
    # Pad/truncate so the batch forms a rectangular tensor, then move every
    # tokenizer output onto the model's device.
    encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    
    # Inference only: no gradients needed.
    with torch.no_grad():
        model_output = model(**encoded_input)
    
    batch_embeddings = cls_pooling(model_output).cpu().numpy()
    embeddings.append(batch_embeddings)

embeddings = np.vstack(embeddings)
# L2-normalise each row so that inner products equal cosine similarities.
# NOTE(review): an all-zero row would produce NaN here.
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

# Approximate k-nearest-neighbour search over the normalised embeddings with
# an inverted-file (IVF) FAISS index using the inner-product metric.
n_vectors, dimension = embeddings.shape
# Number of inverted lists: sqrt(n), clamped to [1, min(1024, n_vectors)].
nlist = int(np.sqrt(n_vectors))
nlist = max(1, nlist)
nlist = min(nlist, 1024)        # cap to avoid too many lists
nlist = min(nlist, n_vectors)
k = 50        # neighbours retrieved per verse
nprobe = 10   # inverted lists visited per query

quantizer = faiss.IndexFlatIP(dimension)
index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)

index.train(embeddings.astype('float32'))
index.add(embeddings.astype('float32'))
index.nprobe = nprobe
# The loop that consumes `indices` treats -1 as "no neighbour" and reads
# `distances` as similarities.
distances, indices = index.search(embeddings.astype('float32'), k)

similarity_thresholds = [0.70, 0.75, 0.80, 0.85, 0.90]
threshold_results = []
all_resolution_results = []

# Grid search: for every similarity threshold build a neighbour graph, then
# sweep the Leiden (CPM) resolution and score each partition against the
# ground-truth `idgroup` labels.
for sim_threshold in similarity_thresholds:
    print(f"\n{'='*80}")
    print(f"Threshold: {sim_threshold:.2f}")
    print(f"{'='*80}")

    # Undirected candidate edges: neighbour pairs whose FAISS similarity
    # reaches the threshold, stored as (low, high) so duplicates collapse.
    candidate_pairs = set()

    for i in tqdm(range(len(embeddings)), desc=f"Building graph (t={sim_threshold:.2f})"):
        for j_idx, distance in zip(indices[i], distances[i]):
            j = int(j_idx)
            # Skip self-matches and the -1 entries treated as "no neighbour".
            if j == i or j == -1:
                continue
            if distance >= sim_threshold:
                candidate_pairs.add((min(i, j), max(i, j)))

    n_pairs = len(candidate_pairs)
    avg_degree = n_pairs * 2 / len(embeddings)

    print(f"Candidate pairs: {n_pairs:,} (avg degree: {avg_degree:.1f})")

    if n_pairs == 0:
        print("No pairs found - skipping")
        # Same keys as the record appended for evaluated thresholds, so the
        # summary table has a stable column set.
        threshold_results.append({
            'threshold': sim_threshold,
            'n_pairs': 0,
            'avg_degree': 0.0,
            'best_resolution': None,
            'best_ari': 0,
            'best_v_measure': 0,
            'n_clusters': 0
        })
        continue

    # Recompute each edge weight as the dot product of the two embeddings.
    edges = []
    weights = []

    for i, j in tqdm(candidate_pairs, desc="Edge weights"):
        edges.append((i, j))
        weights.append(float(np.dot(embeddings[i], embeddings[j])))

    g = ig.Graph(n=len(embeddings), edges=edges, directed=False)

    print(f"Graph: {g.vcount()} nodes, {g.ecount()} edges")

    # Min-max scale the weights to [0, 1] and cube them to emphasise the
    # strongest edges. When every edge has the same weight the range is zero;
    # use uniform weights instead of dividing by zero (NaN weights).
    w = np.array(weights)
    w_range = w.max() - w.min()
    if w_range > 0:
        w_scaled = ((w - w.min()) / w_range) ** 3
    else:
        w_scaled = np.ones_like(w)
    g.es['weight'] = w_scaled.tolist()

    # Halve the weight of every edge incident to a hub vertex. An edge
    # joining two hubs is halved twice, as in the original implementation.
    hub_thresh = 500
    for v, degree in enumerate(g.degree()):
        if degree > hub_thresh:
            for e in g.incident(v):
                g.es[e]['weight'] *= 0.5

    print("Leiden resolution sweep...")
    resolutions = np.logspace(-2, 1, 20)

    best_ari = -1
    best_labels = None
    best_res = None
    best_v = None

    for res in tqdm(resolutions, desc="Resolutions"):
        # n_iterations=-1: iterate until the partition stops improving.
        partition = la.find_partition(
            g,
            la.CPMVertexPartition,
            weights='weight',
            resolution_parameter=res,
            n_iterations=-1
        )
        labels = np.array(partition.membership)
        ari = adjusted_rand_score(df['idgroup'], labels)
        v_measure = v_measure_score(df['idgroup'], labels)
        n_clusters = len(set(labels))

        # Keep every partition as its own column so it can be inspected later.
        col_name = f'cluster_t{int(sim_threshold*100)}_r{res:.6f}'
        df[col_name] = labels

        all_resolution_results.append({
            'threshold': sim_threshold,
            'resolution': res,
            'ari': ari,
            'v_measure': v_measure,
            'n_clusters': n_clusters,
            'column_name': col_name
        })

        if ari > best_ari:
            best_ari = ari
            best_labels = labels
            best_res = res
            best_v = v_measure

    n_clusters = len(set(best_labels))

    print(f"Best resolution: {best_res:.4f}")
    print(f"Best ARI: {best_ari:.4f}")
    print(f"Best V-measure: {best_v:.4f}")
    print(f"Clusters: {n_clusters}")

    threshold_results.append({
        'threshold': sim_threshold,
        'n_pairs': n_pairs,
        'avg_degree': avg_degree,
        'best_resolution': best_res,
        'best_ari': best_ari,
        'best_v_measure': best_v,
        'n_clusters': n_clusters
    })

    df[f'cluster_t{int(sim_threshold*100)}_best'] = best_labels

print("\n" + "="*80)
print("GRID SEARCH RESULTS (BEST PER THRESHOLD)")
print("="*80)

results_df = pd.DataFrame(threshold_results)
print(results_df.to_string(index=False))

# Thresholds that produced no candidate pairs were skipped: they have no
# '<...>_best' column and no resolution, so they must not be selected.
evaluated_df = results_df[results_df['n_pairs'] > 0]
if evaluated_df.empty:
    raise RuntimeError("No similarity threshold produced any candidate pairs; nothing to select")

best_threshold_row = evaluated_df.loc[evaluated_df['best_ari'].idxmax()]
best_threshold = best_threshold_row['threshold']

print(f"\n{'='*80}")
print(f"OVERALL BEST THRESHOLD: {best_threshold:.2f}")
print(f"{'='*80}")
print(f"ARI: {best_threshold_row['best_ari']:.4f}")
print(f"V-measure: {best_threshold_row['best_v_measure']:.4f}")
print(f"Resolution: {best_threshold_row['best_resolution']:.4f}")
print(f"Clusters: {int(best_threshold_row['n_clusters'])}")

# Expose the overall winner under a fixed column name for downstream code.
df['cluster_best'] = df[f'cluster_t{int(best_threshold*100)}_best']

df.to_csv("dbbe_semantic_results/faiss_leiden_gridsearch_results.csv", index=False)
results_df.to_csv("dbbe_semantic_results/threshold_gridsearch_summary.csv", index=False)

all_resolution_df = pd.DataFrame(all_resolution_results)
all_resolution_df.to_csv("dbbe_semantic_results/all_threshold_resolution_results.csv", index=False)

print("\nFiles saved:")
print("  - dbbe_semantic_results/faiss_leiden_gridsearch_results.csv")
print("  - dbbe_semantic_results/threshold_gridsearch_summary.csv")
print("  - dbbe_semantic_results/all_threshold_resolution_results.csv")

print(f"\nTotal clustering columns created: {len(all_resolution_results)}")
print(f"Thresholds tested: {len(similarity_thresholds)}")
print(f"Resolutions per threshold: {len(resolutions)}")

def calculate_jaccard_similarity(clusters_a, clusters_b):
    """Return |A & B| / |A | B| for two sets, or 0.0 if either set is empty."""
    if clusters_a and clusters_b:
        shared = clusters_a & clusters_b
        combined = clusters_a | clusters_b
        return len(shared) / len(combined)
    return 0.0

def reconstruct_poems(df):
    """Group verse-level cluster ids by poem.

    Reads the 'idoriginal_poem' and 'cluster_leiden_fixed' columns of *df*.
    A cluster id of -1 is not recorded. Poems whose verses contributed no
    cluster id still get an (empty) entry.

    Returns:
        (poem_to_clusters, poem_verse_counts): poem id -> set of cluster ids,
        and poem id -> number of rows seen for that poem.
    """
    clusters_by_poem = defaultdict(set)
    verses_per_poem = defaultdict(int)
    seen_poems = set()

    for _, verse in df.iterrows():
        pid = verse['idoriginal_poem']
        label = verse['cluster_leiden_fixed']
        seen_poems.add(pid)
        verses_per_poem[pid] += 1
        if label == -1:
            continue
        clusters_by_poem[pid].add(label)

    # Make sure poems without any usable verse cluster are still present.
    for pid in seen_poems:
        clusters_by_poem.setdefault(pid, set())

    n_with = sum(1 for members in clusters_by_poem.values() if members)
    n_without = len(clusters_by_poem) - n_with

    print(f"\nReconstructed {len(clusters_by_poem)} poems")
    print(f"  - Poems with verse clusters: {n_with}")
    print(f"  - Poems without verse clusters: {n_without}")
    return clusters_by_poem, verses_per_poem

def evaluate_against_ground_truth(df, poem_clusters):
    """Score predicted poem clusters against the 'type_id' ground truth.

    The ground-truth label of a poem is the first 'type_id' found for its
    'idoriginal_poem' in *df*. Poems in *poem_clusters* that do not occur in
    *df* are ignored.

    Returns:
        (ari, v_measure, y_true, y_pred)
    """
    type_by_poem = df.groupby('idoriginal_poem')['type_id'].first().to_dict()

    matched = [
        (type_by_poem[pid], label)
        for pid, label in poem_clusters.items()
        if pid in type_by_poem
    ]
    y_true = [truth for truth, _ in matched]
    y_pred = [label for _, label in matched]

    ari = adjusted_rand_score(y_true, y_pred)
    v_measure = v_measure_score(y_true, y_pred)
    return ari, v_measure, y_true, y_pred

def cluster_poems_jaccard(poem_to_clusters, similarity_threshold=0.66):
    """Cluster poems by single linkage on the Jaccard similarity of their verse clusters.

    Every pair of poems whose Jaccard similarity reaches
    *similarity_threshold* becomes an edge; connected components of the
    resulting graph are found with a union-find (union by rank, path
    compression).

    Returns:
        (poem_clusters, edges): poem id -> representative poem id of its
        component, and the list of (poem_a, poem_b) edges.
    """
    poem_ids = list(poem_to_clusters.keys())

    # combinations() yields pairs in the same (i < j) order as a nested loop.
    edges = [
        (pid_a, pid_b)
        for pid_a, pid_b in combinations(poem_ids, 2)
        if calculate_jaccard_similarity(
            poem_to_clusters[pid_a], poem_to_clusters[pid_b]
        ) >= similarity_threshold
    ]

    parent = {pid: pid for pid in poem_ids}
    rank = dict.fromkeys(poem_ids, 0)

    def find(node):
        # First pass: locate the root.
        root = node
        while parent[root] != root:
            root = parent[root]
        # Second pass: point every node on the path straight at the root.
        while parent[node] != root:
            parent[node], node = root, parent[node]
        return root

    for pid_a, pid_b in edges:
        root_a, root_b = find(pid_a), find(pid_b)
        if root_a == root_b:
            continue
        # Attach the shallower tree under the deeper one.
        if rank[root_a] < rank[root_b]:
            root_a, root_b = root_b, root_a
        parent[root_b] = root_a
        if rank[root_a] == rank[root_b]:
            rank[root_a] += 1

    poem_clusters = {pid: find(pid) for pid in poem_ids}
    return poem_clusters, edges

def calculate_perfect_reconstruction_rate(df, poem_clusters):
    """Return the share of ground-truth types reproduced exactly by a predicted cluster.

    A ground-truth type (first 'type_id' per 'idoriginal_poem' in *df*) counts
    as perfectly reconstructed when some predicted cluster contains exactly
    the same set of poems.

    Returns:
        (reconstruction_rate, perfectly_reconstructed, total_gt_clusters);
        the rate is 0 when there are no ground-truth types.
    """
    type_by_poem = df.groupby('idoriginal_poem')['type_id'].first().to_dict()

    members_by_type = defaultdict(set)
    for pid, type_id in type_by_poem.items():
        members_by_type[type_id].add(pid)

    members_by_prediction = defaultdict(set)
    for pid, label in poem_clusters.items():
        members_by_prediction[label].add(pid)

    predicted_groups = list(members_by_prediction.values())
    n_perfect = sum(
        1
        for truth_members in members_by_type.values()
        if any(truth_members == group for group in predicted_groups)
    )
    n_types = len(members_by_type)

    rate = n_perfect / n_types if n_types > 0 else 0
    return rate, n_perfect, n_types

# Poem-level evaluation: reload the verse-level grid-search output, merge
# poems by the Jaccard similarity of their verse clusters at several
# thresholds, and save the metrics and plots.
df = pd.read_csv("dbbe_semantic_results/faiss_leiden_gridsearch_results.csv")

# reconstruct_poems() reads 'cluster_leiden_fixed'; feed it the overall best
# verse-level partition.
if 'cluster_best' in df.columns:
    df['cluster_leiden_fixed'] = df['cluster_best']
else:
    raise ValueError("Column 'cluster_best' not found in CSV")

poem_to_clusters, _ = reconstruct_poems(df)

thresholds = [0.50, 0.60, 0.70, 0.8]
results = []

for thresh in thresholds:
    print(f"\nThreshold {thresh:.0%}...")
    poem_clusters, edges = cluster_poems_jaccard(poem_to_clusters, thresh)
    # Overwritten on every iteration: after the loop the column holds the
    # assignment for the LAST threshold in `thresholds`.
    df['poem_cluster_id'] = df['idoriginal_poem'].map(poem_clusters)

    ari, v_measure, _, _ = evaluate_against_ground_truth(df, poem_clusters)
    reconstruction_rate, n_perfect, n_total_gt = calculate_perfect_reconstruction_rate(df, poem_clusters)

    results.append({
        'threshold': thresh,
        'n_poem_clusters': len(set(poem_clusters.values())),
        'n_edges': len(edges),
        'ari': ari,
        'v_measure': v_measure,
        'perfect_reconstruction_rate': reconstruction_rate,
        'n_perfect_clusters': n_perfect,
        'n_total_gt_clusters': n_total_gt
    })

results_df = pd.DataFrame(results)
# NOTE: written with the default index=True, so both CSVs gain an unnamed
# leading index column.
df.to_csv('dbbe_semantic_results/dbbe_poems_semantic.csv')
results_df.to_csv('dbbe_semantic_results/dbbe_poems_semantic_stats.csv')

results_df = results_df.sort_values('threshold')

# Line plot: poem-level ARI as a function of the Jaccard threshold.
plt.figure(figsize=(6,4))
plt.plot(results_df['threshold'], results_df['ari'], marker='o', linestyle='-')
plt.xticks(results_df['threshold'])
plt.xlabel("Jaccard Similarity Threshold")
plt.ylabel("Adjusted Rand Index (ARI)")
plt.title("ARI of Poem Clustering vs Threshold")
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig("dbbe_semantic_results/ari_poemlevel_sem_dbbe.png", dpi=300)
plt.close()

# From here on `df` is rebound to the verse-level (threshold, resolution)
# results table, no longer the verse data.
df = pd.read_csv("dbbe_semantic_results/all_threshold_resolution_results.csv")

# Keep the heatmap small: drop two thresholds and resolutions above 1.0.
df = df[~df['threshold'].isin([0.85, 0.75])]
df = df[df['resolution'] <= 1.0]

# Round for readable axis labels and annotations.
df['resolution'] = df['resolution'].round(4)
df['threshold'] = df['threshold'].round(4)
df['ari'] = df['ari'].round(4)

heatmap_data = df.pivot(index='resolution', columns='threshold', values='ari')

plt.figure(figsize=(8, 6))
sns.heatmap(
    heatmap_data, 
    annot=True, 
    fmt=".4f", 
    cmap="viridis", 
    cbar_kws={'label': 'ARI'},
    annot_kws={"fontsize":14}
)
plt.ylabel("Resolution")
plt.xlabel("Similarity Threshold")
plt.title("ARI heatmap across Threshold and Resolution", fontsize=16)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.tight_layout()

plt.savefig("dbbe_semantic_results/ari_verselevel_sem_dbbe.png", dpi=300)
plt.close()

print("\nAll outputs saved to dbbe_semantic_results/")

In [None]:
# Helper printing all verses that were clustered together with a given target idgroup
df = pd.read_csv("dbbe_semantic_results/dbbe_poems_semantic.csv")

target_idgroup = 831
# The grid search stores the best partition per threshold as
# 'cluster_t<NN>_best' (plain 'cluster_t<NN>' columns are never written).
threshold_cols = ['cluster_t70_best', 'cluster_t75_best', 'cluster_t80_best',
                  'cluster_t85_best', 'cluster_t90_best', 'cluster_leiden_fixed']
# A threshold that produced no candidate pairs has no column; skip it rather
# than fail with a KeyError.
threshold_cols = [col for col in threshold_cols if col in df.columns]
poem_to_cluster = df.groupby('idoriginal_poem')['poem_cluster_id'].first().to_dict()


verse_row = df[df['idgroup'] == target_idgroup]
if verse_row.empty:
    print("idgroup not found")
else:
    # Use the first verse of the target group as the reference verse.
    print(f"Verse {target_idgroup}: {verse_row['text'].iloc[0]}\n")

    for col in threshold_cols:
        cluster_id = verse_row[col].iloc[0]
        same_cluster = df[df[col] == cluster_id]
        same_cluster_sorted = same_cluster.sort_values(['idoriginal_poem', 'order'])

        poem_ids_in_cluster = same_cluster_sorted['idoriginal_poem'].unique().tolist()

        print(f"=== {col} | Cluster {cluster_id} | {len(poem_ids_in_cluster)} poems ===")

        for _, row in same_cluster_sorted.iterrows():
            poem_id = row['idoriginal_poem']
            verse_text = row['text']
            type_id = row['type_id']
            verse_group = row['idgroup']

            poem_cluster = poem_to_cluster.get(poem_id, "N/A")
            print(f"  Poem ID: {poem_id} | Verse: {verse_text} | Poem Cluster: {poem_cluster} | Type ID: {type_id} | Verse Group {verse_group}")

        print("\n")


# 2. Full dataset

In [None]:
import re
import unicodedata
from typing import Dict, List, Tuple
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict, Counter
import time
import pickle
import gzip
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import igraph as ig
import leidenalg as la
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
import faiss
from functools import partial
import os
import multiprocessing as mp
import psutil
import platform
import socket
from datetime import datetime
import threading
import sys

# CUDA environment settings: launch blocking enabled, only device 0 visible.
# NOTE(review): these are set after `import torch`; presumably they must be
# in place before CUDA is first initialised in this process - confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Input data and output locations for the full-dataset run.
csv_path = 'concatenated.csv'
RESULTS_DIR = Path("full_semantic_results")
# Cluster-specific scratch location used for checkpoints.
CHECKPOINT_DIR = Path("/scratch/gent/vo/000/gvo00042/vsc48660/full_semantic_clustering_checkpoints_tmp")

RANDOM_SEED = 42

class ResourceMonitor:
    """Sample this process's RAM and CUDA memory once per second in a background thread.

    Call start() to begin sampling, stop() to end it, and get_stats() to read
    the peak and average figures (in GiB) collected so far.
    """

    def __init__(self):
        self.monitoring = False
        self.thread = None
        self.peak_ram_gb = 0
        self.peak_gpu_mem_gb = 0
        self.ram_samples = []
        self.gpu_samples = []
        self.process = psutil.Process()

    def _monitor_loop(self):
        """Collect one RAM and one GPU sample per second until monitoring is switched off."""
        while self.monitoring:
            ram_gb = self.process.memory_info().rss / (1024**3)
            self.ram_samples.append(ram_gb)
            self.peak_ram_gb = max(self.peak_ram_gb, ram_gb)

            try:
                if torch.cuda.is_available():
                    # Sample the CURRENT allocation for the average; the peak
                    # comes from the allocator's high-water mark. Averaging
                    # the high-water mark itself would not be an average.
                    self.gpu_samples.append(torch.cuda.memory_allocated() / (1024**3))
                    self.peak_gpu_mem_gb = max(
                        self.peak_gpu_mem_gb,
                        torch.cuda.max_memory_allocated() / (1024**3),
                    )
            except Exception:
                # GPU telemetry is best-effort and must never stop the sampler,
                # but leave a trace instead of swallowing the error silently.
                logging.getLogger(__name__).debug("GPU memory sampling failed", exc_info=True)

            time.sleep(1)

    def start(self):
        """Start sampling; if a sampler thread is still alive, reuse it."""
        if self.thread is not None and self.thread.is_alive():
            # A previous stop() timed out on join(): re-arm the running loop
            # instead of spawning a second thread that would double-sample.
            self.monitoring = True
            return
        self.monitoring = True
        self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Ask the sampler to finish and wait up to two seconds for it."""
        self.monitoring = False
        if self.thread:
            self.thread.join(timeout=2)

    def get_stats(self):
        """Return peak and average RAM / GPU memory in GiB (averages are 0 without samples)."""
        return {
            'peak_ram_gb': self.peak_ram_gb,
            'avg_ram_gb': np.mean(self.ram_samples) if self.ram_samples else 0,
            'peak_gpu_mem_gb': self.peak_gpu_mem_gb,
            'avg_gpu_mem_gb': np.mean(self.gpu_samples) if self.gpu_samples else 0
        }

class TimingLogger:
    """Record the wall-clock duration of named pipeline stages, one stage at a time."""

    def __init__(self):
        self.stages = {}
        self.current_stage = None
        self.stage_start = None

    def start_stage(self, name):
        """Begin timing *name*, replacing any stage that is currently open."""
        self.current_stage, self.stage_start = name, time.time()

    def end_stage(self):
        """Close the open stage and store its duration; do nothing if none is open."""
        if not (self.current_stage and self.stage_start):
            return
        self.stages[self.current_stage] = time.time() - self.stage_start
        self.current_stage = self.stage_start = None

    def get_summary(self):
        """Return a shallow copy of the stage -> seconds mapping."""
        return dict(self.stages)

# Module-level singletons shared by all pipeline stages below.
resource_monitor = ResourceMonitor()
timing_logger = TimingLogger()

def get_system_info():
    """Return a dict describing the host: OS, Python, CPU, RAM and CUDA devices.

    The GPU keys ('gpu_available', 'gpu_count', 'gpu_names',
    'gpu_total_memory_gb') are always present; without CUDA they are
    False / 0 / empty lists.
    """
    gib = 1024**3

    info = {
        'hostname': socket.gethostname(),
        'platform': platform.platform(),
        'python_version': platform.python_version(),
        'processor': platform.processor(),
        'cpu_count_physical': psutil.cpu_count(logical=False),
        'cpu_count_logical': psutil.cpu_count(logical=True),
        'total_ram_gb': psutil.virtual_memory().total / gib,
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }

    has_gpu = torch.cuda.is_available()
    # An empty range keeps the comprehensions below valid on CPU-only hosts.
    device_ids = range(torch.cuda.device_count()) if has_gpu else range(0)

    info['gpu_available'] = has_gpu
    info['gpu_count'] = len(device_ids)
    info['gpu_names'] = [torch.cuda.get_device_name(i) for i in device_ids]
    info['gpu_total_memory_gb'] = [
        torch.cuda.get_device_properties(i).total_memory / gib for i in device_ids
    ]

    return info

# Timestamped INFO-level logging for the whole pipeline.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger(__name__)

# parents=True: CHECKPOINT_DIR is a deeply nested scratch path, and a plain
# mkdir() raises FileNotFoundError when an intermediate directory is missing.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

def detect_optimal_resources():
    """Inspect CPU, RAM and GPU and derive the pipeline's tuning parameters.

    Logs the detected hardware and the chosen configuration, pickles the
    configuration to ``CHECKPOINT_DIR / 'resource_config.pkl'`` and returns
    it as a dict.
    """
    def _clamp(value, low, high):
        # Restrict value to the closed interval [low, high].
        return max(low, min(high, value))

    rule = "="*80
    gib = 1024**3

    cpu_count_physical = psutil.cpu_count(logical=False) or 1
    cpu_count_logical = psutil.cpu_count(logical=True) or 1
    total_ram_gb = psutil.virtual_memory().total / gib

    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count() if gpu_available else 0
    gpu_memory_gb = [
        torch.cuda.get_device_properties(i).total_memory / gib
        for i in range(gpu_count)
    ]
    gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]

    logger.info("\n" + rule)
    logger.info("DYNAMIC RESOURCE DETECTION FOR HPC")
    logger.info(rule)
    logger.info(f"Physical CPU cores: {cpu_count_physical}")
    logger.info(f"Logical CPU cores:  {cpu_count_logical}")
    logger.info(f"Total RAM:          {total_ram_gb:.2f} GB")
    logger.info(f"GPU available:      {gpu_available}")
    if gpu_available:
        logger.info(f"GPU count:          {gpu_count}")
        for i, (name, mem) in enumerate(zip(gpu_names, gpu_memory_gb)):
            logger.info(f"  GPU {i}: {name} ({mem:.2f} GB)")
    logger.info(rule + "\n")

    # Parallel workers: 90% of the logical cores, kept within [8, 128].
    max_workers = _clamp(int(cpu_count_logical * 0.90), 8, 128)

    # FAISS threads: 80% of the physical cores, kept within [8, 64].
    n_threads = _clamp(int(cpu_count_physical * 0.8), 8, 64)

    # Sample size grows with RAM, kept within [15000, 50000].
    sample_size = _clamp(int(total_ram_gb * 250), 15000, 50000)

    # FAISS nprobe follows the physical core count, kept within [16, 64].
    faiss_nprobe = _clamp(cpu_count_physical, 16, 64)

    # Neighbours per query grow with RAM, kept within [100, 300].
    n_neighbors = _clamp(int(total_ram_gb * 3), 100, 300)

    # Bootstrap rounds and stability pairs follow the worker count.
    n_bootstrap = _clamp(max_workers // 12, 2, 10)
    stability_pairs = _clamp(max_workers * 75, 2000, 10000)

    # Batch sizes grow with RAM.
    batch_size = _clamp(int(total_ram_gb * 200), 10000, 50000)
    search_batch_size = _clamp(int(total_ram_gb * 300), 30000, 100000)
    edge_batch_size = _clamp(int(total_ram_gb * 150), 20000, 50000)

    # Poem-level stage: 85% of the logical cores, kept within [16, 128].
    poem_max_workers = _clamp(int(cpu_count_logical * 0.85), 16, 128)
    poem_batch_size = _clamp(int(total_ram_gb * 300), 50000, 100000)

    config = {
        'max_workers': max_workers,
        'n_threads': n_threads,
        'sample_size': sample_size,
        'faiss_nprobe': faiss_nprobe,
        'n_neighbors': n_neighbors,
        'n_bootstrap': n_bootstrap,
        'stability_pairs': stability_pairs,
        'batch_size': batch_size,
        'search_batch_size': search_batch_size,
        'edge_batch_size': edge_batch_size,
        'poem_max_workers': poem_max_workers,
        'poem_batch_size': poem_batch_size,
        'cpu_physical': cpu_count_physical,
        'cpu_logical': cpu_count_logical,
        'ram_gb': total_ram_gb,
        'gpu_available': gpu_available,
        'gpu_count': gpu_count,
        'gpu_memory_gb': gpu_memory_gb,
        'gpu_names': gpu_names
    }

    logger.info("OPTIMIZED CONFIGURATION FOR HPC")
    logger.info(rule)
    logger.info(f"Max workers (parallel):   {max_workers}")
    logger.info(f"FAISS threads:            {n_threads}")
    logger.info(f"Sample size:              {sample_size:,}")
    logger.info(f"FAISS nprobe:             {faiss_nprobe}")
    logger.info(f"N neighbors:              {n_neighbors}")
    logger.info(f"N bootstrap:              {n_bootstrap}")
    logger.info(f"Stability pairs:          {stability_pairs:,}")
    logger.info(f"Batch size:               {batch_size:,}")
    logger.info(f"Search batch size:        {search_batch_size:,}")
    logger.info(f"Edge batch size:          {edge_batch_size:,}")
    logger.info(f"Poem max workers:         {poem_max_workers}")
    logger.info(f"Poem batch size:          {poem_batch_size:,}")
    logger.info(rule + "\n")

    # Persist the configuration next to the checkpoints for later reference.
    resource_info_path = CHECKPOINT_DIR / 'resource_config.pkl'
    with open(resource_info_path, 'wb') as f:
        pickle.dump(config, f)
    logger.info(f"Resource configuration saved to: {resource_info_path}\n")

    return config

# Get optimal configuration
resource_config = detect_optimal_resources()

# Unpack the detected configuration into module-level constants used by the
# stages below.
N_NEIGHBORS = resource_config['n_neighbors']
BATCH_SIZE = resource_config['batch_size']
N_THREADS = resource_config['n_threads']
FAISS_NPROBE = resource_config['faiss_nprobe']

SAMPLE_SIZE = resource_config['sample_size']
N_BOOTSTRAP = resource_config['n_bootstrap']
STABILITY_PAIRS = resource_config['stability_pairs']
MAX_WORKERS = resource_config['max_workers']

SEARCH_BATCH_SIZE = resource_config['search_batch_size']
EDGE_BATCH_SIZE = resource_config['edge_batch_size']
POEM_MAX_WORKERS = resource_config['poem_max_workers']
POEM_BATCH_SIZE = resource_config['poem_batch_size']

logger.info("="*80)
logger.info("HPC-OPTIMIZED SEMANTIC CLUSTERING PIPELINE")
logger.info(f"Sample: {SAMPLE_SIZE:,}, Bootstrap: {N_BOOTSTRAP}, Workers: {MAX_WORKERS}")
logger.info("="*80)

# Start background resource sampling and the overall wall clock.
resource_monitor.start()
script_start_time = time.time()

# Precompiled patterns: anything that is not a word character or whitespace,
# and runs of whitespace.
CLEAN_PATTERN = re.compile(r'[^\w\s]')
WHITESPACE_PATTERN = re.compile(r'\s+')

def preprocess_text(text: str, options: Dict[str, bool] = None) -> str:
    """Return *text* NFC-normalized with leading and trailing whitespace removed.

    ``options`` is accepted but not used by this implementation.
    """
    normalized = unicodedata.normalize('NFC', text)
    stripped = normalized.strip()
    return stripped

def cls_pooling(model_output):
    """Return position 0 along the second axis of the model's first output.

    ``model_output[0]`` is presumably the per-token hidden states with shape
    (batch, tokens, dim) - TODO confirm against the model class.
    """
    first_output = model_output[0]
    return first_output[:, 0]

timing_logger.start_stage("01_loading_embeddings")

embeddings_file = CHECKPOINT_DIR / 'embeddings.npz'
metadata_file = CHECKPOINT_DIR / 'metadata.pkl.gz'

if embeddings_file.exists() and metadata_file.exists():
    # ------------------------------------------------------------------
    # Fast path: reuse embeddings, dataframe and metadata from a checkpoint.
    # ------------------------------------------------------------------
    logger.info("\n" + "="*80)
    logger.info("Loading checkpoint")
    logger.info("="*80)

    start_time = time.time()

    embeddings_data = np.load(embeddings_file)
    embeddings = embeddings_data['embeddings'].astype('float32')

    # NOTE: pickle executes arbitrary code on load; only read checkpoints
    # written by this pipeline.
    with gzip.open(metadata_file, 'rb') as f:
        metadata = pickle.load(f)

    df = pd.read_parquet(CHECKPOINT_DIR / 'df_minimal.parquet')
    source_datasets = metadata['source_datasets']
    dataset_to_indices = metadata['dataset_to_indices']

    logger.info(f"Loaded {len(embeddings):,} embeddings in {time.time()-start_time:.1f}s")

    # Load previous embedding timing if available
    timing_file = CHECKPOINT_DIR / 'timing_metadata.pkl'
    if timing_file.exists():
        with open(timing_file, 'rb') as f:
            embed_timing = pickle.load(f)
            if 'gpu_used' in embed_timing:
                logger.info(f"Previous embedding GPU: {embed_timing.get('gpu_used', 'N/A')}")

    LOADED_FROM_CHECKPOINT = True

else:
    # ------------------------------------------------------------------
    # Slow path: read the CSV, embed every verse and write a checkpoint.
    # ------------------------------------------------------------------
    logger.info("\n" + "="*80)
    logger.info("No checkpoint found - creating embeddings")
    logger.info("="*80)
    embedding_start_time = time.time()

    df = pd.read_csv(csv_path)
    df = df.dropna(subset=['verse', 'source_dataset'])
    # Keep verses whose length lies in [20, 256) characters.
    df = df[df['verse'].fillna('').astype(str).str.len() >= 20]
    df = df[df['verse'].fillna('').astype(str).str.len() < 256]
    df['verse'] = df['verse'].apply(preprocess_text)
    # Positional row numbers must match embedding rows from here on.
    df = df.reset_index(drop=True)

    logger.info(f"Total verses: {len(df):,}")

    from transformers import AutoTokenizer, AutoModel

    model_name = 'kevinkrahn/shlm-grc-en'
    logger.info(f"Loading model: {model_name}")
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Log GPU details for embeddings
    gpu_info_str = "CPU only"
    if torch.cuda.is_available():
        gpu_idx = torch.cuda.current_device()
        gpu_name = torch.cuda.get_device_name(gpu_idx)
        gpu_mem = torch.cuda.get_device_properties(gpu_idx).total_memory / (1024**3)
        gpu_info_str = f"GPU {gpu_idx}: {gpu_name} ({gpu_mem:.2f} GB)"
        logger.info(f"GPU for embeddings: {gpu_info_str}")

    model = model.to(device)
    model.eval()

    logger.info("Computing embeddings...")
    verses = df['verse'].tolist()
    embeddings = []

    # Scale embedding batch size based on GPU memory
    if torch.cuda.is_available():
        gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        embed_batch_size = min(128, max(32, int(gpu_mem_gb * 8)))
    else:
        embed_batch_size = 16

    logger.info(f"Embedding batch size: {embed_batch_size}")

    for i in tqdm(range(0, len(verses), embed_batch_size), desc="Embedding"):
        # Slice outside the try so `batch` is always bound in the handler.
        batch = verses[i:i+embed_batch_size]
        try:
            encoded_input = tokenizer(batch, padding=True, truncation=True, return_tensors='pt')
            encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

            with torch.no_grad():
                model_output = model(**encoded_input)

            batch_embeddings = cls_pooling(model_output).cpu().numpy()
            embeddings.append(batch_embeddings)
        except Exception as e:
            # Best effort: keep row alignment with `df` by inserting zero
            # vectors for a batch that could not be embedded.
            logger.warning(f"Error at batch {i}: {e}")
            embeddings.append(np.zeros((len(batch), model.config.hidden_size), dtype=np.float32))

    embeddings = np.vstack(embeddings).astype('float32')
    # L2-normalise rows. Zero vectors (failed batches) have norm 0; dividing
    # by that would produce NaN rows, so leave those rows as zeros instead.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    zero_rows = norms == 0
    n_zero_rows = int(zero_rows.sum())
    if n_zero_rows:
        logger.warning(f"{n_zero_rows:,} verses have all-zero embeddings; leaving them unnormalised")
        norms[zero_rows] = 1.0
    embeddings = embeddings / norms

    source_datasets = df['source_dataset'].values
    dataset_to_indices = defaultdict(list)
    for idx, dataset in enumerate(source_datasets):
        dataset_to_indices[dataset].append(idx)

    logger.info("\nSaving checkpoint...")

    np.savez_compressed(CHECKPOINT_DIR / 'embeddings.npz', embeddings=embeddings)
    embedding_total_time = time.time() - embedding_start_time

    # Pause sampling briefly to read consistent statistics.
    resource_monitor.stop()
    resource_stats = resource_monitor.get_stats()
    resource_monitor.start()  # Restart monitoring

    checkpoint_timing = {
        'embedding_generation_time': embedding_total_time,
        'total_verses_processed': len(df),
        'embedding_batch_size': embed_batch_size,
        'device_used': str(device),
        'gpu_used': gpu_info_str,
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'peak_ram_gb': resource_stats['peak_ram_gb'],
        'avg_ram_gb': resource_stats['avg_ram_gb'],
        'peak_gpu_mem_gb': resource_stats['peak_gpu_mem_gb'],
        'avg_gpu_mem_gb': resource_stats['avg_gpu_mem_gb']
    }

    with open(CHECKPOINT_DIR / 'timing_metadata.pkl', 'wb') as f:
        pickle.dump(checkpoint_timing, f)

    # Only the columns needed downstream go into the parquet checkpoint.
    essential_cols = ['verse', 'source_dataset']
    for col in ['idoriginal_poem', 'idgroup', 'order']:
        if col in df.columns:
            essential_cols.append(col)

    df_minimal = df[essential_cols].copy()
    for col in df_minimal.columns:
        if df_minimal[col].dtype == 'object':
            df_minimal[col] = df_minimal[col].astype(str)
            # Dict form is unambiguous: replace(x, None) has been interpreted
            # as "pad-fill" by some pandas versions.
            df_minimal[col] = df_minimal[col].replace({'nan': None, 'None': None})

    df_minimal.to_parquet(CHECKPOINT_DIR / 'df_minimal.parquet',
                          compression='gzip', index=True)

    # Written last: its presence (checked above) marks a complete checkpoint.
    metadata = {
        'source_datasets': source_datasets,
        'dataset_to_indices': dataset_to_indices
    }
    with gzip.open(CHECKPOINT_DIR / 'metadata.pkl.gz', 'wb') as f:
        pickle.dump(metadata, f, protocol=pickle.HIGHEST_PROTOCOL)

    logger.info("Checkpoint saved")
    logger.info(f"GPU used for embeddings: {gpu_info_str}")
    LOADED_FROM_CHECKPOINT = False

timing_logger.end_stage()

# The remaining stages follow the same pipeline logic as above, now driven by
# the dynamically detected resource parameters (SAMPLE_SIZE, N_THREADS, ...).

# Stage 2: draw a stratified sample and search its nearest neighbours.
timing_logger.start_stage("02_sample_preparation")

logger.info("\n" + "="*80)
logger.info("Fast sample preparation")
logger.info("="*80)

def stratified_sample(df, n_sample=15000):
    """Draw a dataset-proportional random sample of row positions from ``df``.

    Every distinct ``source_dataset`` value contributes
    ``int(n_sample * share)`` rows, where ``share`` is that dataset's fraction
    of ``df``.  Because each quota is truncated, the total can be slightly
    smaller than ``n_sample`` and datasets whose quota truncates to zero
    contribute nothing.

    Args:
        df: DataFrame with a ``source_dataset`` column.
        n_sample: Target number of rows across all datasets.

    Returns:
        Sorted integer array of *positional* offsets (``0 .. len(df) - 1``).
        Positions rather than index labels are returned so the result can be
        used to index arrays stored row-by-row alongside ``df`` even when
        ``df`` does not carry a default ``RangeIndex``.

    Sampling draws from the global ``np.random`` state.
    """
    labels = df['source_dataset'].to_numpy()
    total_size = len(labels)
    sample_indices = []

    for dataset in df['source_dataset'].unique():
        # Row positions (not index labels) belonging to this dataset.
        dataset_positions = np.flatnonzero(labels == dataset)
        dataset_size = len(dataset_positions)
        proportion = dataset_size / total_size
        n_from_dataset = min(int(n_sample * proportion), dataset_size)
        if n_from_dataset > 0:
            sampled = np.random.choice(dataset_positions, size=n_from_dataset, replace=False)
            sample_indices.extend(sampled)

    # Explicit integer dtype keeps an empty result usable as an index array.
    return np.array(sorted(sample_indices), dtype=np.int64)

# Draw the stratified sample and gather the matching embedding rows and
# dataset labels (both are indexed with the values returned by the sampler).
sample_indices = stratified_sample(df, n_sample=SAMPLE_SIZE)
# normalize_L2 below is applied to this array, so it is taken as a separate copy.
sample_embeddings = embeddings[sample_indices].copy()
sample_dataset_map = np.array([source_datasets[i] for i in sample_indices])

logger.info(f"Sample size: {len(sample_indices):,}")

logger.info("Building FAISS index...")
logger.info(f"Using {N_THREADS} threads for FAISS operations")
faiss.omp_set_num_threads(N_THREADS)

start_time = time.time()
dimension = embeddings.shape[1]
# Flat inner-product index over the L2-normalised sample vectors.
index_sample = faiss.IndexFlatIP(dimension)
faiss.normalize_L2(sample_embeddings)
index_sample.add(sample_embeddings)

# Ask for up to 200 neighbours per sample vector, capped at (sample size - 1);
# the sample is searched against itself.
k = min(200, len(sample_embeddings) - 1)
similarities, indices = index_sample.search(sample_embeddings, k)

logger.info(f"Neighbor search complete in {time.time()-start_time:.1f}s")

timing_logger.end_stage()


# Stage 03: define the coarse search grid and derive similarity thresholds
# from the distribution of cross-dataset neighbour similarities in the sample.
timing_logger.start_stage("03_parameter_grid_setup")

threshold_percentiles_coarse = [96, 97, 98, 99]
resolutions_coarse = np.logspace(-6, -1, 8)

logger.info(f"\nCoarse parameter grid:")
logger.info(f"  Thresholds: {threshold_percentiles_coarse}")
logger.info(f"  Resolutions: {len(resolutions_coarse)} values")
logger.info(f"  Total: {len(threshold_percentiles_coarse) * len(resolutions_coarse)} combinations")

logger.info("\nPrecomputing cross-dataset similarities...")
start_time = time.time()

# Keep every (node, neighbour) similarity whose endpoints come from different
# datasets.  All neighbour columns are inspected: a node's self-match is
# always same-dataset and is therefore excluded by the mask itself, so there
# is no need to assume it sits in column 0 (with tied similarities, e.g.
# duplicated vectors, another vector may be returned first).
neighbor_datasets = sample_dataset_map[indices]
cross_mask = neighbor_datasets != sample_dataset_map[:, None]
cross_similarities = similarities[cross_mask]
logger.info(f"Collected {len(cross_similarities):,} pairs in {time.time()-start_time:.1f}s")

if cross_similarities.size == 0:
    # Percentiles of an empty array are meaningless; stop with a clear message.
    logger.error("No cross-dataset neighbor pairs found in the sample; cannot derive thresholds.")
    raise SystemExit(1)

all_percentiles = list(range(50, 100, 5))
threshold_lookup = {p: np.percentile(cross_similarities, p) for p in all_percentiles}

logger.info(f"Threshold range: {threshold_lookup[50]:.4f} (P50) to {threshold_lookup[95]:.4f} (P95)")

timing_logger.end_stage()

def compute_stability_fast(partitions_list, n_nodes, sample_size=STABILITY_PAIRS):
    """Estimate a pairwise co-clustering score from randomly sampled node pairs.

    Up to ``sample_size`` random node pairs are drawn (pairs whose two ends
    coincide are discarded).  The score is the fraction of (pair, partition)
    combinations in which both nodes of the pair carry the same label.

    Args:
        partitions_list: Sequence of membership vectors, one per partition.
        n_nodes: Number of nodes each membership vector covers.
        sample_size: Upper bound on the number of pairs drawn.

    Returns:
        ``0.0`` when ``n_nodes < 100``; otherwise the co-clustering fraction.

    NOTE(review): the value is the share of sampled pairs that are
    co-clustered, averaged over partitions; it does not compare partitions
    with each other. Confirm this is the intended notion of "stability".
    """
    if n_nodes < 100:
        return 0.0

    pair_budget = min(sample_size, n_nodes * (n_nodes - 1) // 2)

    # Draw both endpoints independently, then discard degenerate (i, i) pairs.
    left = np.random.randint(0, n_nodes, pair_budget)
    right = np.random.randint(0, n_nodes, pair_budget)
    distinct = left != right
    left = left[distinct]
    right = right[distinct]

    together = 0
    for labels in partitions_list:
        labels = np.array(labels)
        together += np.sum(labels[left] == labels[right])

    return together / (len(left) * len(partitions_list))

def evaluate_single_combination(args):
    """Score one (threshold, resolution) pair on the precomputed sample edges.

    ``args`` is a 7-tuple: threshold percentile, threshold value, CPM
    resolution, edge array, edge-weight array, per-node dataset labels and the
    node count.  Edges with weight below the threshold are dropped, Leiden
    (CPM) is run ``N_BOOTSTRAP`` times with seeds ``0 .. N_BOOTSTRAP - 1``,
    and summary statistics are computed from the seed-0 partition.

    Returns:
        A dict of metrics, or ``None`` when no edge survives the threshold,
        when the seed-0 partition has zero clusters or one cluster per node,
        or when any exception is raised (the exception is logged as a warning).
    """
    threshold_pct, threshold, resolution, edges_data, weights_data, dataset_map, n_nodes = args

    try:
        keep = weights_data >= threshold
        kept_edges = edges_data[keep]
        kept_weights = weights_data[keep]

        if len(kept_edges) == 0:
            return None

        g = ig.Graph(n=n_nodes, edges=kept_edges.tolist(), directed=False)
        g.es['weight'] = kept_weights.tolist()

        # Repeat the clustering with different seeds; keep labels and quality.
        bootstrap_memberships, bootstrap_qualities = [], []
        for seed in range(N_BOOTSTRAP):
            run = la.find_partition(
                g,
                la.CPMVertexPartition,
                weights='weight',
                resolution_parameter=resolution,
                n_iterations=5,
                seed=seed
            )
            bootstrap_memberships.append(run.membership)
            bootstrap_qualities.append(run.quality())

        # The seed-0 run is the reference partition for all reported counts.
        membership = bootstrap_memberships[0]
        n_clusters = len(set(membership))

        if n_clusters in (0, n_nodes):
            return None

        stability = compute_stability_fast(bootstrap_memberships, n_nodes)

        datasets_per_cluster = defaultdict(set)
        size_per_cluster = defaultdict(int)
        for node, cid in enumerate(membership):
            datasets_per_cluster[cid].add(dataset_map[node])
            size_per_cluster[cid] += 1

        n_cross_clusters = sum(len(members) > 1 for members in datasets_per_cluster.values())
        pct_cross_clusters = (n_cross_clusters / n_clusters * 100) if n_clusters > 0 else 0
        n_singleton = list(size_per_cluster.values()).count(1)

        n_edges = g.ecount()
        density = 2*n_edges/(n_nodes*(n_nodes-1)) if n_nodes > 1 else 0

        return {
            'threshold_percentile': threshold_pct,
            'threshold_value': threshold,
            'resolution': resolution,
            'n_edges': n_edges,
            'graph_density': density,
            'stability': stability,
            'n_clusters': n_clusters,
            'n_singleton': n_singleton,
            'avg_cluster_size': n_nodes / n_clusters,
            'n_cross_clusters': n_cross_clusters,
            'pct_cross_clusters': pct_cross_clusters,
            'avg_quality': np.mean(bootstrap_qualities),
            'modularity': g.modularity(membership, weights='weight')
        }

    except Exception as e:
        logger.warning(f"Error at P{threshold_pct}, res={resolution:.2e}: {e}")
        return None

# Stage 04: turn the sample k-NN result into a reusable cross-dataset edge list.
timing_logger.start_stage("04_edge_precomputation")

logger.info("\nPrecomputing edge structures...")
start_time = time.time()

# Flatten the (node, neighbour-column) grid in row-major order so that each
# entry is one candidate edge: source node, neighbour node, similarity.
n_sample_nodes, n_neighbor_cols = indices.shape
edge_sources = np.repeat(np.arange(n_sample_nodes, dtype=indices.dtype), n_neighbor_cols)
edge_targets = indices.ravel()
edge_sims = similarities.ravel()

# Keep only neighbour > source.  This records each pair at most once and also
# removes self-matches (and any negative "no result" ids) without assuming the
# self-match is stored in column 0 -- with tied similarities another vector
# can be returned first, and skipping column 0 would drop that neighbour.
# NOTE(review): a pair i < j is only recorded when j appears in i's neighbour
# list; pairs that are found only from j's side are not added.
upper_mask = edge_targets > edge_sources
edge_sources = edge_sources[upper_mask]
edge_targets = edge_targets[upper_mask]
edge_sims = edge_sims[upper_mask]

# Only edges whose endpoints belong to different datasets are retained.
cross_mask = sample_dataset_map[edge_targets] != sample_dataset_map[edge_sources]

all_edges = np.column_stack((edge_sources[cross_mask], edge_targets[cross_mask])).astype(np.int32)
all_weights = edge_sims[cross_mask].astype(np.float32)

logger.info(f"Precomputed {len(all_edges):,} edges in {time.time()-start_time:.1f}s")

timing_logger.end_stage()

# Stage 05: evaluate every coarse (threshold percentile, resolution) pair in
# a process pool and collect the non-None results.
timing_logger.start_stage("05_coarse_sweep")

logger.info("\n" + "="*80)
logger.info("Stage 1: Coarse parallel sweep")
logger.info("="*80)

# Make sure each coarse percentile has a cached similarity cutoff.
for threshold_pct in threshold_percentiles_coarse:
    if threshold_pct not in threshold_lookup:
        threshold_lookup[threshold_pct] = np.percentile(cross_similarities, threshold_pct)

# One argument tuple per grid cell, percentile-major / resolution-minor.
coarse_args = [
    (
        threshold_pct,
        threshold_lookup[threshold_pct],
        resolution,
        all_edges,
        all_weights,
        sample_dataset_map,
        len(sample_embeddings),
    )
    for threshold_pct in threshold_percentiles_coarse
    for resolution in resolutions_coarse
]

logger.info(f"Testing {len(coarse_args)} combinations with {MAX_WORKERS} workers...")
start_time = time.time()

coarse_results = []
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = [executor.submit(evaluate_single_combination, args) for args in coarse_args]

    with tqdm(total=len(futures), desc="Coarse sweep") as pbar:
        for future in as_completed(futures):
            result = future.result()
            # Workers return None for combinations that could not be scored.
            if result is not None:
                coarse_results.append(result)
            pbar.update(1)

logger.info(f"Coarse sweep complete in {time.time()-start_time:.1f}s")
logger.info(f"  Valid results: {len(coarse_results)} / {len(coarse_args)}")

timing_logger.end_stage()

# Stage 06: refine the search around the most stable coarse result.
timing_logger.start_stage("06_fine_sweep")

logger.info("\n" + "="*80)
logger.info("Stage 2: Fine sweep around best region")
logger.info("="*80)

if not coarse_results:
    logger.error("No valid coarse results! Cannot proceed with fine sweep.")
    sys.exit(1)

coarse_df = pd.DataFrame(coarse_results)

# The coarse row with the highest stability anchors the fine grid.
best_coarse = coarse_df.loc[coarse_df['stability'].idxmax()]
best_thresh_pct = best_coarse['threshold_percentile']
best_res_coarse = best_coarse['resolution']

logger.info(f"Best coarse result (by stability): P{best_thresh_pct}, res={best_res_coarse:.2e}")
logger.info(f"  Stability: {best_coarse['stability']:.3f}")
logger.info(f"  Clusters: {best_coarse['n_clusters']:,.0f}")
logger.info(f"  Cross-dataset: {best_coarse['pct_cross_clusters']:.1f}%")

# Percentile candidates at -10, -5, 0, +5, +10 around the best one, clamped to
# [50, 95] and de-duplicated.
# NOTE(review): the upper clamp of 95 lies below every value in
# threshold_percentiles_coarse (96-99), so the "+5"/"+10" candidates collapse
# to 95 -- confirm this is intended.
thresh_fine = sorted({
    max(50, best_thresh_pct - 10),
    max(50, best_thresh_pct - 5),
    best_thresh_pct,
    min(95, best_thresh_pct + 5),
    min(95, best_thresh_pct + 10),
})

# Seven log-spaced resolutions spanning half a decade either side of the best.
log_res = np.log10(best_res_coarse)
res_fine = np.logspace(log_res - 0.5, log_res + 0.5, 7)

logger.info(f"\nFine grid:")
logger.info(f"  Thresholds: {thresh_fine}")
logger.info(f"  Resolutions: {len(res_fine)} values around {best_res_coarse:.2e}")
logger.info(f"  Total: {len(thresh_fine) * len(res_fine)} combinations")

fine_args = []
for threshold_pct in thresh_fine:
    if threshold_pct not in threshold_lookup:
        threshold_lookup[threshold_pct] = np.percentile(cross_similarities, threshold_pct)
    threshold = threshold_lookup[threshold_pct]

    for resolution in res_fine:
        # Skip cells the coarse sweep already covered: same percentile and a
        # resolution within 10% relative tolerance.
        for tested in coarse_results:
            if (np.isclose(tested['threshold_percentile'], threshold_pct)
                    and np.isclose(tested['resolution'], resolution, rtol=0.1)):
                break
        else:
            fine_args.append((
                threshold_pct,
                threshold,
                resolution,
                all_edges,
                all_weights,
                sample_dataset_map,
                len(sample_embeddings)
            ))

logger.info(f"Testing {len(fine_args)} new combinations...")
start_time = time.time()

fine_results = []
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures = [executor.submit(evaluate_single_combination, args) for args in fine_args]

    with tqdm(total=len(futures), desc="Fine sweep") as pbar:
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                fine_results.append(result)
            pbar.update(1)

logger.info(f"Fine sweep complete in {time.time()-start_time:.1f}s")
logger.info(f"  Valid results: {len(fine_results)} / {len(fine_args)}")

timing_logger.end_stage()

# Stage 07: merge coarse and fine results, rank them by stability, persist the
# table and pick the top-ranked row as the final parameters.
timing_logger.start_stage("07_parameter_analysis")

logger.info("\n" + "="*80)
logger.info("Final analysis - selection by stability")
logger.info("="*80)

all_results = coarse_results + fine_results
sweep_df = pd.DataFrame(all_results)

sweep_df = sweep_df.sort_values('stability', ascending=False)
sweep_df.to_csv(RESULTS_DIR / 'joint_parameter_sweep_results.csv', index=False)

# After the descending sort the first row is the most stable combination.
best_params = sweep_df.iloc[0]
best_threshold = best_params['threshold_value']
best_resolution = best_params['resolution']

# The header reports the number of rows actually listed (at most 10), so it
# can no longer disagree with the listing below.
top_rows = sweep_df.head(10)

logger.info("\n" + "="*80)
logger.info(f"Top {len(top_rows)} parameter combinations (by stability)")
logger.info("="*80)

for idx, (i, row) in enumerate(top_rows.iterrows(), 1):
    logger.info(f"\n#{idx}. Threshold: P{row['threshold_percentile']} ({row['threshold_value']:.4f}), "
               f"Resolution: {row['resolution']:.2e}")
    logger.info(f"     Stability: {row['stability']:.3f}")
    logger.info(f"     Clusters: {row['n_clusters']:,.0f}, Singletons: {row['n_singleton']:,.0f}")
    logger.info(f"     Cross-dataset: {row['n_cross_clusters']:,.0f} ({row['pct_cross_clusters']:.1f}%)")
    logger.info(f"     Modularity: {row['modularity']:.3f}, Quality: {row['avg_quality']:.3f}")

logger.info("\n" + "="*80)
logger.info("Selected parameters (highest stability)")
logger.info("="*80)
logger.info(f"Threshold: P{best_params['threshold_percentile']} = {best_threshold:.4f}")
logger.info(f"Resolution: {best_resolution:.6e}")
logger.info(f"Stability: {best_params['stability']:.3f}")
logger.info(f"Clusters: {best_params['n_clusters']:,.0f}")
logger.info(f"Cross-dataset: {best_params['pct_cross_clusters']:.1f}%")
logger.info("="*80)

timing_logger.end_stage()

# Stage 08: four-panel summary figure of the parameter sweep.
timing_logger.start_stage("08_visualization")

logger.info("\nCreating visualization...")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Panel (0, 0): stability for each (resolution, threshold percentile) cell.
pivot_data = sweep_df.pivot_table(
    values='stability',
    index='resolution',
    columns='threshold_percentile',
    aggfunc='first'
)

pivot_data.index = [f"{res:.3e}" for res in pivot_data.index]

ax = axes[0, 0]
sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='RdYlGn', ax=ax,
           cbar_kws={'label': 'Stability'})
ax.set_ylabel('Resolution', fontweight='bold')
ax.set_xlabel('Threshold Percentile', fontweight='bold')
ax.set_title('Stability Heatmap', fontweight='bold', fontsize=14)

# Panel (0, 1): histogram of stability with the selected value marked.
ax = axes[0, 1]
ax.hist(sweep_df['stability'], bins=30, color='#0173B2', alpha=0.7, edgecolor='black')
ax.axvline(best_params['stability'], color='red', linestyle='--', linewidth=2,
          label=f"Best: {best_params['stability']:.3f}")
ax.set_xlabel('Stability', fontweight='bold')
ax.set_ylabel('Frequency', fontweight='bold')
ax.set_title('Stability Distribution', fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

# Panel (1, 0): one stability-vs-resolution line per threshold percentile.
ax = axes[1, 0]
for thresh_pct in sorted(sweep_df['threshold_percentile'].unique()):
    # sweep_df is ordered by stability; sort by resolution so the connecting
    # line runs left-to-right along the x axis instead of zig-zagging.
    data = sweep_df[sweep_df['threshold_percentile'] == thresh_pct].sort_values('resolution')
    ax.plot(data['resolution'], data['stability'], 'o-',
           label=f'P{thresh_pct}', alpha=0.7, markersize=4)
ax.axhline(best_params['stability'], color='red', linestyle='--',
          linewidth=1, alpha=0.5, label=f'Best: {best_params["stability"]:.3f}')
ax.set_xlabel('Resolution', fontweight='bold')
ax.set_ylabel('Stability', fontweight='bold')
ax.set_title('Stability vs Resolution', fontweight='bold')
ax.set_xscale('log')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# Panel (1, 1): stability against cluster count, coloured by threshold value.
ax = axes[1, 1]
scatter = ax.scatter(sweep_df['n_clusters'], sweep_df['stability'],
                    c=sweep_df['threshold_value'], cmap='viridis',
                    s=50, alpha=0.6, edgecolors='black', linewidth=0.5)
ax.scatter(best_params['n_clusters'], best_params['stability'],
          color='red', s=200, marker='*', edgecolors='black', linewidth=2,
          label='Best', zorder=10)
ax.set_xlabel('Number of Clusters', fontweight='bold')
ax.set_ylabel('Stability', fontweight='bold')
ax.set_title('Stability vs Cluster Count (colored by threshold)', fontweight='bold')
ax.set_xscale('log')
plt.colorbar(scatter, ax=ax, label='Threshold Value')
ax.legend()
ax.grid(alpha=0.3)

fig.suptitle(f'Fast Joint Parameter Sweep - Stability-Based Selection (n={len(sweep_df)} combinations)',
            fontsize=16, fontweight='bold', y=0.995)

plt.tight_layout()
plot_path = RESULTS_DIR / 'fast_parameter_sweep_summary.png'
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
logger.info(f"Plot saved: {plot_path}")
plt.close()

timing_logger.end_stage()

# Persist the selected parameters: a one-row CSV summary plus a pickle holding
# just the threshold and resolution.
summary = {
    'best_threshold_percentile': best_params['threshold_percentile'],
    'best_threshold_value': best_threshold,
    'best_resolution': best_resolution,
}
# Metrics copied verbatim from the winning sweep row, in output-column order.
summary.update({
    metric: best_params[metric]
    for metric in (
        'stability',
        'n_clusters',
        'n_singleton',
        'avg_cluster_size',
        'n_cross_clusters',
        'pct_cross_clusters',
        'modularity',
    )
})
# Run-level settings recorded alongside the metrics.
summary.update({
    'n_combinations_tested': len(sweep_df),
    'sample_size': SAMPLE_SIZE,
    'n_bootstrap': N_BOOTSTRAP,
    'selection_criterion': 'stability',
})

summary_frame = pd.DataFrame([summary])
summary_frame.to_csv(RESULTS_DIR / 'best_parameters_summary.csv', index=False)

params_for_clustering = dict(threshold=best_threshold, resolution=best_resolution)

with open(RESULTS_DIR / 'optimal_parameters.pkl', 'wb') as f:
    pickle.dump(params_for_clustering, f)

logger.info(f"\nResults saved to: {RESULTS_DIR}")
for closing_line in ("="*80, "Fast joint parameter sweep complete", "="*80):
    logger.info(closing_line)

# Stage 09: build an IVF inner-product index over all embeddings and run a
# batched nearest-neighbour search for every vector.
timing_logger.start_stage("09_graph_construction")

logger.info("\n" + "="*80)
logger.info("Stage 3: Fast approximate graph construction")
logger.info("="*80)

logger.info("Normalizing embeddings...")
faiss.normalize_L2(embeddings)

logger.info("Building FAISS IVF index...")
logger.info(f"Using {N_THREADS} threads for FAISS")
faiss.omp_set_num_threads(N_THREADS)

start_time = time.time()

# 4 * sqrt(N) inverted lists, capped at 16384.  The count is additionally kept
# within [1, N]: the index cannot be trained with fewer vectors than lists,
# which 4 * sqrt(N) would otherwise require for very small inputs.
nlist = min(int(4 * np.sqrt(len(embeddings))), 16384)
nlist = max(1, min(nlist, len(embeddings)))
nprobe = FAISS_NPROBE

logger.info(f"FAISS params: nlist={nlist}, nprobe={nprobe}")
logger.info(f"Targeting {N_NEIGHBORS} neighbors per node")

# Always build the CPU index first; it is used as-is unless it can be moved
# to the GPU below.
quantizer = faiss.IndexFlatIP(dimension)
index_full = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_INNER_PRODUCT)
on_gpu = False

if torch.cuda.is_available():
    try:
        logger.info("Attempting GPU setup for FAISS...")

        res = faiss.StandardGpuResources()
        # index_full is only rebound if the transfer succeeds, so a failure
        # leaves the CPU index in place.
        index_full = faiss.index_cpu_to_gpu(res, 0, index_full)

        logger.info("Successfully created GPU index")
        on_gpu = True

    except Exception as e:
        logger.warning(f"GPU setup failed: {str(e)}")
        logger.info("Falling back to CPU (this is fine, just slower)")
else:
    logger.info("CUDA not available, using CPU")

logger.info(f"Training index on {'GPU' if on_gpu else 'CPU'}...")

# Above one million vectors, train on a random subset of at most 500k rows.
if len(embeddings) > 1000000:
    train_sample_size = min(500000, len(embeddings))
    train_indices = np.random.choice(len(embeddings), train_sample_size, replace=False)
    train_data = embeddings[train_indices].copy()
    logger.info(f"Training on sample of {train_sample_size:,} vectors...")
    index_full.train(train_data)
else:
    index_full.train(embeddings)

logger.info("Adding vectors to index...")
index_full.add(embeddings)
index_full.nprobe = nprobe

logger.info(f"FAISS index built in {time.time()-start_time:.1f}s (mode: {'GPU' if on_gpu else 'CPU'})")

logger.info(f"Searching for {N_NEIGHBORS} nearest neighbors...")
SEARCH_BATCH_SIZE = 50000

start_time = time.time()
all_similarities = []
all_indices = []

n_search_batches = (len(embeddings) + SEARCH_BATCH_SIZE - 1) // SEARCH_BATCH_SIZE
logger.info(f"Processing {n_search_batches} search batches of size {SEARCH_BATCH_SIZE:,}...")

# Query in fixed-size batches and stack the per-batch results afterwards.
for i in tqdm(range(0, len(embeddings), SEARCH_BATCH_SIZE), desc="Neighbor search", total=n_search_batches):
    batch_end = min(i + SEARCH_BATCH_SIZE, len(embeddings))
    batch_emb = embeddings[i:batch_end]

    D, I = index_full.search(batch_emb, N_NEIGHBORS)
    all_similarities.append(D)
    all_indices.append(I)

all_similarities = np.vstack(all_similarities)
all_indices = np.vstack(all_indices)

logger.info(f"Neighbor search complete in {time.time()-start_time:.1f}s")

timing_logger.end_stage()

# Stage 10: build (or reload) the thresholded cross-dataset edge list for the
# full neighbour graph.
import gc

timing_logger.start_stage("10_edge_construction")

logger.info("\n" + "="*80)
logger.info("Stage 4: Memory-efficient edge construction")
logger.info("="*80)

edge_checkpoint = CHECKPOINT_DIR / 'edges_checkpoint.npz'
threshold = best_threshold

# Stay None until a usable checkpoint has been loaded.
all_edges = None
all_weights = None

if edge_checkpoint.exists():
    logger.info("Found edge checkpoint - loading...")
    # The context manager closes the underlying archive once the data is read.
    with np.load(edge_checkpoint) as edge_data:
        if 'threshold' in edge_data.files:
            # Edges were filtered with a specific threshold; only reuse them
            # when that threshold matches the one selected in this run.
            checkpoint_threshold = float(edge_data['threshold'])
            checkpoint_usable = bool(np.isclose(checkpoint_threshold, threshold))
            if not checkpoint_usable:
                logger.warning(
                    f"Edge checkpoint was built with threshold {checkpoint_threshold:.6f}, "
                    f"current threshold is {threshold:.6f} - rebuilding edges"
                )
        else:
            # Checkpoints written before the threshold was stored are accepted
            # as before, but cannot be validated.
            logger.warning("Edge checkpoint has no stored threshold; "
                           "cannot verify it matches the current threshold")
            checkpoint_usable = True

        if checkpoint_usable:
            all_edges = [tuple(pair) for pair in edge_data['edges'].tolist()]
            all_weights = edge_data['weights'].tolist()
            logger.info(f"Loaded {len(all_edges):,} edges from checkpoint")

if all_edges is None:
    logger.info("Building edge list with cross-dataset filtering...")
    start_time = time.time()

    all_edges = []
    all_weights = []

    EDGE_BATCH_SIZE = 20000

    for batch_start in tqdm(range(0, len(embeddings), EDGE_BATCH_SIZE), desc="Building edges"):
        batch_end = min(batch_start + EDGE_BATCH_SIZE, len(embeddings))

        for node_idx in range(batch_start, batch_end):
            dataset_i = source_datasets[node_idx]

            # Inspect every neighbour column.  The `> node_idx` test below
            # already removes the self-match (and negative "no result" ids),
            # so column 0 is not assumed to be the node itself -- with tied
            # similarities another vector can be returned first.
            neighbors = all_indices[node_idx]
            sims = all_similarities[node_idx]

            valid_mask = (neighbors > node_idx) & (sims >= threshold)
            valid_neighbors = neighbors[valid_mask]
            valid_sims = sims[valid_mask]

            if len(valid_neighbors) > 0:
                neighbor_datasets = np.array([source_datasets[n] for n in valid_neighbors])
                cross_dataset_mask = neighbor_datasets != dataset_i

                final_neighbors = valid_neighbors[cross_dataset_mask]
                final_sims = valid_sims[cross_dataset_mask]

                for neighbor, sim in zip(final_neighbors, final_sims):
                    all_edges.append((node_idx, int(neighbor)))
                    all_weights.append(float(sim))

        # Trigger a garbage collection every tenth batch (not on the first).
        if (batch_start // EDGE_BATCH_SIZE) % 10 == 0 and batch_start > 0:
            gc.collect()

    logger.info(f"{len(all_edges):,} cross-dataset edges in {time.time()-start_time:.1f}s")

    logger.info("Saving edge checkpoint...")
    np.savez_compressed(
        CHECKPOINT_DIR / 'edges_checkpoint.npz',
        edges=np.array(all_edges, dtype=np.int32),
        weights=np.array(all_weights, dtype=np.float32),
        # Stored so a later run can detect a checkpoint built with a
        # different threshold.
        threshold=np.float64(threshold)
    )
    logger.info("Edge checkpoint saved")

timing_logger.end_stage()

# Stage 11: two-step Leiden (CPM) clustering -- a coarse pass over the whole
# graph, then a per-cluster refinement of the cross-dataset coarse clusters.
timing_logger.start_stage("11_verse_clustering")

logger.info("\n" + "="*80)
logger.info("Stage 5: Hierarchical Leiden clustering")
logger.info("="*80)

logger.info("Building graph...")
g = ig.Graph(n=len(embeddings), edges=all_edges, directed=False)
g.es['weight'] = all_weights

logger.info(f"Full graph: {g.vcount():,} nodes, {g.ecount():,} edges")

logger.info("\nStep 1: Coarse clustering (fast)...")
start_time = time.time()

# The coarse pass uses a fixed resolution of 0.01, independent of the sweep.
coarse_partition = la.find_partition(
    g,
    la.CPMVertexPartition,
    weights='weight',
    resolution_parameter=0.01,
    n_iterations=3,
    seed=42
)

coarse_labels = np.array(coarse_partition.membership)
n_coarse = len(set(coarse_labels))

logger.info(f"{n_coarse:,} coarse clusters in {time.time()-start_time:.1f}s")

# Per coarse cluster: which datasets it touches and which nodes it contains.
coarse_cluster_info = defaultdict(lambda: {'datasets': set(), 'nodes': []})

for idx, cid in enumerate(coarse_labels):
    coarse_cluster_info[cid]['datasets'].add(source_datasets[idx])
    coarse_cluster_info[cid]['nodes'].append(idx)

cross_dataset_clusters = [cid for cid, info in coarse_cluster_info.items()
                          if len(info['datasets']) > 1]

logger.info(f"Cross-dataset coarse clusters: {len(cross_dataset_clusters):,}")

# Scientific notation: the swept resolutions span several orders of magnitude
# below 1, which a fixed four-decimal format would print as 0.0000.
logger.info(f"\nStep 2: Refining with resolution {best_resolution:.2e}...")

final_labels = coarse_labels.copy()
# Refined clusters receive fresh ids starting after the coarse id range.
next_cluster_id = n_coarse

cluster_sizes = [(cid, len(coarse_cluster_info[cid]['nodes']))
                 for cid in cross_dataset_clusters]
cluster_sizes.sort(key=lambda x: x[1])

refined_count = 0
skipped_small = 0
skipped_large = 0

for coarse_cid, size in tqdm(cluster_sizes, desc="Refining"):
    # Clusters below 10 nodes or above 100000 nodes are left unrefined.
    if size < 10:
        skipped_small += 1
        continue

    if size > 100000:
        skipped_large += 1
        continue

    try:
        cluster_nodes = coarse_cluster_info[coarse_cid]['nodes']
        subg = g.subgraph(cluster_nodes)

        if subg.ecount() > 0:
            sub_partition = la.find_partition(
                subg,
                la.CPMVertexPartition,
                weights='weight',
                resolution_parameter=best_resolution,
                n_iterations=10,
                seed=42
            )

            sub_labels = np.array(sub_partition.membership)
            n_sub_clusters = len(set(sub_labels))

            # Only relabel when the refinement actually split the cluster.
            if n_sub_clusters > 1:
                # Sub-graph vertex k is taken to correspond to cluster_nodes[k].
                final_labels[cluster_nodes] = next_cluster_id + sub_labels

                next_cluster_id += n_sub_clusters
                refined_count += 1

    except Exception as e:
        logger.warning(f"Error refining cluster {coarse_cid}: {e}")
        continue

logger.info(f"Refined: {refined_count:,}, Skipped small: {skipped_small:,}, "
            f"Skipped large: {skipped_large:,}")

timing_logger.end_stage()

# Stage 12: summarise the verse-level clustering and write the result files.
timing_logger.start_stage("12_verse_results")

logger.info("\n" + "="*80)
logger.info("Analyzing verse clustering results")
logger.info("="*80)

# Per final cluster: the datasets it touches and the row positions it holds.
cluster_datasets = defaultdict(set)
cluster_verses = defaultdict(list)

for idx, cid in enumerate(final_labels):
    cluster_datasets[cid].add(source_datasets[idx])
    cluster_verses[cid].append(idx)

n_total_clusters = len(set(final_labels))
n_cross = sum(1 for ds in cluster_datasets.values() if len(ds) > 1)
cross_verses = sum(len(cluster_verses[cid]) for cid in set(final_labels)
                  if len(cluster_datasets[cid]) > 1)

logger.info(f"\nVerse-level clustering results:")
logger.info(f"  Total clusters: {n_total_clusters:,}")
logger.info(f"  Cross-dataset clusters: {n_cross:,} ({n_cross/n_total_clusters*100:.1f}%)")
logger.info(f"  Cross-dataset verses: {cross_verses:,} ({cross_verses/len(final_labels)*100:.1f}%)")

df['cluster_id'] = final_labels

# One summary record per cluster that spans more than one dataset.
cluster_info = []
for cid, datasets in cluster_datasets.items():
    if len(datasets) > 1:
        member_rows = cluster_verses[cid]
        counts = Counter(source_datasets[i] for i in member_rows)
        cluster_info.append({
            'cluster_id': cid,
            'size': len(member_rows),
            'n_datasets': len(datasets),
            'datasets': ', '.join(sorted(datasets)),
            'dataset_counts': dict(counts)
        })

# Explicit columns keep the frame sortable (and the CSV header intact) when
# there are no cross-dataset clusters at all; without them an empty frame has
# no 'size' column and sort_values raises KeyError.
cluster_info_columns = ['cluster_id', 'size', 'n_datasets', 'datasets', 'dataset_counts']
cluster_info_df = pd.DataFrame(cluster_info, columns=cluster_info_columns).sort_values(
    'size', ascending=False)

logger.info(f"\nTop 20 cross-dataset clusters:")
if len(cluster_info_df) > 0:
    print(cluster_info_df.head(20).to_string(index=False))

logger.info("\n" + "="*80)
logger.info("Saving verse clustering results")
logger.info("="*80)

df.to_csv(RESULTS_DIR / "concatenated_cross_dataset_clusters.csv", index=False)
logger.info(f"Full results: {RESULTS_DIR / 'concatenated_cross_dataset_clusters.csv'}")

# Rows whose cluster spans more than one dataset.
cross_mask = df['cluster_id'].map(lambda cid: len(cluster_datasets[cid]) > 1)
df_cross = df[cross_mask].copy()

df_cross = df_cross[['verse', 'source_dataset', 'cluster_id']]
df_cross.to_csv(RESULTS_DIR / "cross_dataset_verses_only.csv", index=False)
logger.info(f"Cross-dataset only: {RESULTS_DIR / 'cross_dataset_verses_only.csv'} ({len(df_cross):,} verses)")

cluster_info_df.to_csv(RESULTS_DIR / 'cross_dataset_clusters_summary.csv', index=False)
logger.info(f"Cluster summary: {RESULTS_DIR / 'cross_dataset_clusters_summary.csv'}")

verse_summary = {
    'n_verses': len(df),
    'n_datasets': len(set(source_datasets)),
    'threshold': threshold,
    'resolution': best_resolution,
    'n_total_clusters': n_total_clusters,
    'n_cross_clusters': n_cross,
    'pct_cross_clusters': n_cross/n_total_clusters*100,
    'n_cross_verses': cross_verses,
    'pct_cross_verses': cross_verses/len(df)*100
}

pd.DataFrame([verse_summary]).to_csv(RESULTS_DIR / 'verse_clustering_summary.csv', index=False)
logger.info(f"Summary: {RESULTS_DIR / 'verse_clustering_summary.csv'}")

timing_logger.end_stage()

from numba import njit

# Configuration for the poem-level stage.  The input is the per-verse CSV
# written by the verse clustering stage above.
INPUT_CSV = RESULTS_DIR / "concatenated_cross_dataset_clusters.csv"

POEM_SAMPLE_SIZE = 15000
# NOTE(review): presumably (start, stop, number of values) -- confirm against
# the code that consumes this tuple.
JACCARD_THRESHOLD_RANGE = (0.8, 0.9, 2)
POEM_MAX_WORKERS = 32
POEM_BATCH_SIZE = 50000

# Log the configuration banner for the poem-level stage.
logger.info("\n" + "="*80)
logger.info("Poem-level clustering based on verse cluster Jaccard similarity")
logger.info("="*80)
logger.info(f"Input: {INPUT_CSV}")
logger.info(f"Sample size: {POEM_SAMPLE_SIZE:,}")
logger.info(f"Jaccard threshold range: {JACCARD_THRESHOLD_RANGE}")
logger.info(f"Workers: {POEM_MAX_WORKERS}")
logger.info("="*80)

@njit
def compute_jaccard_similarity(a_clusters, b_clusters):
    """Return the Jaccard similarity of the distinct values in two sequences.

    The result is ``|A & B| / |A | B|`` where ``A`` and ``B`` are the sets of
    values in ``a_clusters`` and ``b_clusters``; ``0.0`` is returned when both
    are empty.
    """
    left = set(a_clusters)
    right = set(b_clusters)

    n_union = len(left | right)
    if n_union == 0:
        return 0.0

    n_shared = len(left & right)
    return n_shared / n_union

class UnionFind:
    """Disjoint-set structure over hashable elements.

    ``parent`` maps each element to its parent (roots map to themselves) and
    ``rank`` holds the per-element rank used to decide which root absorbs the
    other on ``union``.
    """
    __slots__ = ['parent', 'rank']

    def __init__(self, elements):
        """Start with every element of ``elements`` in its own set."""
        self.parent = dict((item, item) for item in elements)
        self.rank = dict.fromkeys(elements, 0)

    def find(self, x):
        """Return the root of ``x``'s set, re-pointing the visited path at it."""
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: attach every node on the path directly to the root.
        node = x
        while self.parent[node] != root:
            above = self.parent[node]
            self.parent[node] = root
            node = above

        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``; return False if already merged."""
        keep, absorb = self.find(x), self.find(y)
        if keep == absorb:
            return False

        # The root with the larger rank survives; ties keep x's root.
        if self.rank[keep] < self.rank[absorb]:
            keep, absorb = absorb, keep

        self.parent[absorb] = keep
        if self.rank[keep] == self.rank[absorb]:
            self.rank[keep] += 1
        return True

    def get_clusters(self):
        """Return a ``{root: set_of_members}`` dict covering every element."""
        members_by_root = defaultdict(set)
        for item in self.parent:
            members_by_root[self.find(item)].add(item)
        return dict(members_by_root)

def reconstruct_poems_from_verses(df):
    """Group clustered verses back into poems.

    A poem is identified by ``idoriginal_poem`` and ``source_dataset`` joined
    with ``'___'``.  Verses whose ``cluster_id`` is ``-1`` are ignored.

    Args:
        df: Verse table with ``idoriginal_poem``, ``cluster_id`` and
            ``source_dataset`` columns (``order`` is used for sorting when
            present).

    Returns:
        ``(poem_to_clusters, poem_to_dataset, poem_metadata)`` where
        ``poem_to_clusters`` maps each poem id to the sorted unique cluster
        ids of its verses (``int32`` array), ``poem_to_dataset`` maps it to
        the dataset of its first row, and ``poem_metadata`` holds
        ``poem_to_size`` (unique-cluster count per poem) and
        ``poems_by_dataset`` (poem ids grouped by dataset).

    Raises:
        ValueError: If a required column is missing.
    """
    timing_logger.start_stage("13_reconstruct_poems")

    logger.info("\nReconstructing poems from verses...")

    for col in ('idoriginal_poem', 'cluster_id', 'source_dataset'):
        if col not in df.columns:
            raise ValueError(f"Error: '{col}' column not found in input CSV.")

    df_valid = df[df['cluster_id'] != -1].copy()

    logger.info(f"  Total verses: {len(df):,}")
    logger.info(f"  Valid verses (cluster_id != -1): {len(df_valid):,}")

    poem_part = df_valid['idoriginal_poem'].astype(str)
    dataset_part = df_valid['source_dataset'].astype(str)
    df_valid['poem_composite_id'] = poem_part + '___' + dataset_part

    has_order = 'order' in df_valid.columns
    sort_keys = ['poem_composite_id', 'order'] if has_order else 'poem_composite_id'
    df_valid = df_valid.sort_values(sort_keys)

    poem_to_clusters = {}
    poem_to_dataset = {}
    poem_to_size = {}

    for composite_id, rows in df_valid.groupby('poem_composite_id'):
        distinct_clusters = np.unique(rows['cluster_id'].values).astype(np.int32)
        poem_to_clusters[composite_id] = distinct_clusters
        poem_to_dataset[composite_id] = rows['source_dataset'].iloc[0]
        poem_to_size[composite_id] = len(distinct_clusters)

    logger.info(f"  Reconstructed {len(poem_to_clusters):,} poems")
    logger.info(f"  Average unique clusters per poem: {np.mean(list(poem_to_size.values())):.1f}")

    # Group poem ids by dataset; the per-dataset counts logged below are the
    # lengths of these groups.
    poems_by_dataset = defaultdict(list)
    for poem_id, dataset in poem_to_dataset.items():
        poems_by_dataset[dataset].append(poem_id)

    logger.info(f"  Poems by dataset:")
    for dataset in sorted(poems_by_dataset):
        count = len(poems_by_dataset[dataset])
        logger.info(f"      {dataset}: {count:,}")

    poem_metadata = {
        'poem_to_size': poem_to_size,
        'poems_by_dataset': dict(poems_by_dataset)
    }

    timing_logger.end_stage()
    return poem_to_clusters, poem_to_dataset, poem_metadata

def build_cluster_to_poems_index(poem_to_clusters):
    """Invert ``poem_to_clusters`` into a verse-cluster -> poems lookup.

    Args:
        poem_to_clusters: Mapping of poem id to an iterable of verse-cluster ids.

    Returns:
        A plain dict keyed by ``int`` cluster id. Each value is a NumPy array
        of the poem ids containing that cluster, in the iteration order of
        ``poem_to_clusters``.
    """
    logger.info("\nBuilding cluster-to-poems inverted index...")

    members = defaultdict(list)
    progress = tqdm(poem_to_clusters.items(), desc="Indexing")
    for pid, cluster_ids in progress:
        for raw_cid in cluster_ids:
            # Normalise NumPy integer scalars to plain ints for the keys.
            members[int(raw_cid)].append(pid)

    # Freeze every membership list as an array once the index is complete.
    return {cid: np.array(pids) for cid, pids in members.items()}

def stratified_sample_poems(poem_to_clusters, poem_to_dataset, poem_to_size, n_sample=10000):
    """Draw a reproducible sample of poem ids stratified by source and size bin.

    Each (source, size bin) stratum contributes in proportion to its share of
    all poems, with at least one poem per stratum. If rounding leaves the
    sample short, it is topped up uniformly at random from the poems not yet
    chosen; if the per-stratum minimum overshoots, the list is truncated.

    Args:
        poem_to_clusters: Mapping whose keys are the poem ids to sample from.
        poem_to_dataset: Mapping of poem id to its source dataset label.
        poem_to_size: Mapping of poem id to the size value used for binning
            (bins: (0, 5], (5, 10], (10, 20], (20, 50], (50, inf)).
        n_sample: Target number of poems to return.

    Returns:
        A list of at most ``n_sample`` poem ids (empty if there are no poems).

    Note:
        Reseeds NumPy's global RNG with ``RANDOM_SEED``.
    """
    logger.info(f"\nStratified sampling of {n_sample:,} poems...")

    np.random.seed(RANDOM_SEED)

    # An empty frame has no 'n_verses' column, so bail out before binning.
    if not poem_to_clusters:
        logger.info("  Sampled 0 poems")
        return []

    poem_df = pd.DataFrame([
        {
            'poem_id': poem_id,
            'n_verses': poem_to_size[poem_id],
            'source': poem_to_dataset[poem_id],
        }
        for poem_id in poem_to_clusters
    ])

    poem_df['size_bin'] = pd.cut(poem_df['n_verses'],
                                  bins=[0, 5, 10, 20, 50, np.inf],
                                  labels=['tiny', 'small', 'medium', 'large', 'huge'])

    sample_indices = []
    n_total = len(poem_df)

    # observed=True: only visit strata that actually contain poems (and avoid
    # pandas' deprecation warning for categorical groupers).
    for _, group in poem_df.groupby(['source', 'size_bin'], observed=True):
        n_in_group = len(group)
        n_sample_group = max(1, int(n_sample * n_in_group / n_total))
        n_sample_group = min(n_sample_group, n_in_group)

        sampled = group.sample(n=n_sample_group, random_state=RANDOM_SEED)
        sample_indices.extend(sampled['poem_id'].tolist())

    if len(sample_indices) < n_sample:
        remaining = n_sample - len(sample_indices)
        available = set(poem_to_clusters) - set(sample_indices)
        if available:
            # Set iteration order varies between interpreter runs for string
            # ids; sort first so the seeded draw is actually reproducible.
            candidates = sorted(available)
            additional = np.random.choice(candidates,
                                          size=min(remaining, len(candidates)),
                                          replace=False)
            # tolist() turns NumPy string scalars back into plain ids.
            sample_indices.extend(additional.tolist())

    sample_indices = sample_indices[:n_sample]
    logger.info(f"  Sampled {len(sample_indices):,} poems")

    return sample_indices

def _candidate_pairs_cache_key(poem_to_clusters, poems_by_dataset):
    """Return a short hex digest identifying the input of a candidate-pair search."""
    import hashlib

    digest = hashlib.sha256()
    for dataset in sorted(poems_by_dataset, key=str):
        digest.update(f"dataset:{dataset}\n".encode("utf-8"))
        for poem_id in sorted(poems_by_dataset[dataset], key=str):
            digest.update(f"poem:{poem_id}\n".encode("utf-8"))
            cluster_ids = np.sort(np.asarray(poem_to_clusters[poem_id], dtype=np.int64))
            digest.update(cluster_ids.tobytes())
    return digest.hexdigest()[:16]


def find_cross_dataset_candidate_pairs_batched(poem_to_clusters, poem_to_dataset, cluster_to_poems, poems_by_dataset):
    """Find poem pairs from different datasets that share a verse cluster.

    Results are cached on disk under ``CHECKPOINT_DIR``. The cache file name
    includes a fingerprint of the input poems and their clusters, so a cache
    written for one input (e.g. the grid-search sample) is never loaded for a
    different one (e.g. the full corpus), and stale caches are ignored.

    Args:
        poem_to_clusters: Mapping of poem id to its array of verse-cluster ids.
        poem_to_dataset: Not used by this function; kept for interface
            compatibility.
        cluster_to_poems: Inverted index, cluster id -> poem ids.
            NOTE(review): assumed to be derived from ``poem_to_clusters``; it
            is not part of the cache fingerprint.
        poems_by_dataset: Mapping of dataset label to the poem ids it contains.

    Returns:
        A set of ``(poem_a, poem_b)`` tuples, each sorted, where the two poems
        belong to different datasets and share at least one cluster id.
    """
    timing_logger.start_stage("14_find_poem_pairs")

    logger.info("\nFinding cross-dataset candidate pairs (batched)...")

    datasets = list(poems_by_dataset.keys())
    cache_key = _candidate_pairs_cache_key(poem_to_clusters, poems_by_dataset)
    pair_file = CHECKPOINT_DIR / f"candidate_pairs_{cache_key}.npz"

    if pair_file.exists():
        logger.info("  Loading cached candidate pairs...")
        # Context manager closes the underlying zip file; tolist() yields
        # plain Python ids rather than NumPy string scalars.
        with np.load(pair_file) as data:
            cached_pairs = set(zip(data['p1'].tolist(), data['p2'].tolist()))
        timing_logger.end_stage()
        return cached_pairs

    all_pairs = set()

    for i, dataset1 in enumerate(datasets):
        for dataset2 in datasets[i+1:]:
            logger.info(f"  Processing {dataset1} x {dataset2}...")
            poems1 = poems_by_dataset[dataset1]
            poems2_set = set(poems_by_dataset[dataset2])

            batch_pairs = set()

            for poem_id in tqdm(poems1, desc=f"  {dataset1}"):
                # Every poem that shares any cluster with this one...
                candidates = set()
                for cluster_id in poem_to_clusters[poem_id]:
                    candidates.update(cluster_to_poems.get(int(cluster_id), ()))

                # ...restricted to poems of the other dataset.
                candidates &= poems2_set

                for other_poem in candidates:
                    batch_pairs.add(tuple(sorted([poem_id, other_poem])))

            all_pairs.update(batch_pairs)
            logger.info(f"    Found {len(batch_pairs):,} pairs")

    logger.info(f"  Total candidate pairs: {len(all_pairs):,}")

    if all_pairs:
        p1_list, p2_list = zip(*all_pairs)
        np.savez_compressed(pair_file, p1=p1_list, p2=p2_list)
        logger.info(f"  Cached pairs to {pair_file}")

    timing_logger.end_stage()
    return all_pairs

def compute_cluster_cohesion(poem_to_clusters, cluster_assignments, max_sample=500):
    """Estimate the mean within-cluster Jaccard similarity of poem clusters.

    Only poems that are keys of ``poem_to_clusters`` and have an entry in
    ``cluster_assignments`` are considered. Clusters with more than 30 members
    are subsampled to 30 with NumPy's global RNG, and each member is compared
    with at most the next nine members.

    Args:
        poem_to_clusters: Mapping of poem id to its verse-cluster ids.
        cluster_assignments: Mapping of poem id to poem-cluster id.
        max_sample: Stop once this many per-cluster means have been collected.

    Returns:
        The mean of the per-cluster mean similarities, or ``0.0`` when no
        cluster yielded a comparison.
    """
    members_by_cluster = defaultdict(list)
    for pid in poem_to_clusters:
        assigned = cluster_assignments.get(pid)
        if assigned is None:
            continue
        members_by_cluster[assigned].append(pid)

    per_cluster_means = []
    for members in members_by_cluster.values():
        if len(members) < 2:
            continue

        picked = (
            np.random.choice(members, 30, replace=False)
            if len(members) > 30
            else members
        )
        n_picked = len(picked)

        # Window of up to nine following members per poem bounds the cost.
        scores = [
            compute_jaccard_similarity(
                poem_to_clusters[picked[a]],
                poem_to_clusters[picked[b]],
            )
            for a in range(n_picked)
            for b in range(a + 1, min(a + 10, n_picked))
        ]

        if scores:
            per_cluster_means.append(np.mean(scores))

        if len(per_cluster_means) >= max_sample:
            break

    if not per_cluster_means:
        return 0.0
    return np.mean(per_cluster_means)

def evaluate_single_poem_config(args):
    """Cluster the sampled poems at one Jaccard threshold and report statistics.

    Args:
        args: 5-tuple ``(jaccard_thresh, sample_poems, poem_to_clusters,
            candidate_pairs, poem_to_dataset)`` packed into a single argument.

    Returns:
        A dict of clustering statistics for the threshold, or ``None`` when
        any exception is raised during evaluation (a warning is logged).
    """
    threshold, sampled_ids, poem_to_clusters, candidate_pairs, poem_to_dataset = args

    try:
        in_sample = set(sampled_ids)
        # Keep only the pairs whose two poems are both part of the sample.
        usable_pairs = {
            (left, right)
            for left, right in candidate_pairs
            if left in in_sample and right in in_sample
        }

        uf = UnionFind(sampled_ids)
        merges = 0
        for left, right in usable_pairs:
            similarity = compute_jaccard_similarity(
                poem_to_clusters[left],
                poem_to_clusters[right],
            )
            # Short-circuit: union() is attempted only at or above threshold.
            if similarity >= threshold and uf.union(left, right):
                merges += 1

        poem_clusters = uf.get_clusters()
        cluster_assignments = {
            member: cid
            for cid, members in poem_clusters.items()
            for member in members
        }

        cluster_sizes = [len(members) for members in poem_clusters.values()]
        mean_size = np.mean(cluster_sizes)
        largest = max(cluster_sizes) if cluster_sizes else 0

        # A cluster is cross-dataset when its poems map to more than one dataset.
        n_cross = sum(
            1
            for members in poem_clusters.values()
            if len({poem_to_dataset.get(m) for m in members}) > 1
        )

        cohesion = compute_cluster_cohesion(poem_to_clusters, cluster_assignments)

        return {
            'jaccard_threshold': threshold,
            'n_clusters': len(poem_clusters),
            'n_singletons': cluster_sizes.count(1),
            'n_cross_dataset_clusters': n_cross,
            'avg_cluster_size': mean_size,
            'max_cluster_size': largest,
            'cohesion': cohesion,
            'merges': merges,
            'n_cross_dataset_pairs': len(usable_pairs)
        }

    except Exception as e:
        logger.warning(f"Error at jaccard={threshold:.2f}: {e}")
        return None

def _plot_poem_grid_search(results_df, best, plot_path):
    """Draw the 2x3 grid-search diagnostic figure and save it to ``plot_path``."""
    # Line panels must be drawn in threshold order: ``results_df`` arrives
    # ranked by quality score, which would make the connecting lines zigzag.
    by_threshold = results_df.sort_values('jaccard_threshold')
    thresholds = by_threshold['jaccard_threshold']
    best_threshold = best['jaccard_threshold']

    sns.set_palette("colorblind")
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('Poem-level cross-dataset clustering: Grid search', fontsize=16, fontweight='bold')

    # (axis, column, y-label, title, line colour or None for palette default)
    line_panels = [
        (axes[0, 0], 'quality_score', 'Quality Score', 'Quality vs Jaccard', None),
        (axes[0, 1], 'cohesion', 'Cohesion', 'Cohesion vs Jaccard', 'orange'),
        (axes[0, 2], 'n_cross_dataset_clusters', 'Cross-Dataset Clusters',
         'Cross-Dataset Clusters vs Jaccard', 'green'),
        (axes[1, 0], 'n_clusters', 'Total Clusters', 'Total Clusters vs Jaccard', 'purple'),
    ]
    for ax, column, y_label, title, color in line_panels:
        line_kwargs = {'linewidth': 2, 'markersize': 6}
        if color is not None:
            line_kwargs['color'] = color
        ax.plot(thresholds, by_threshold[column], 'o-', **line_kwargs)

        # Only the quality panel carries a legend entry for the best threshold.
        marker_kwargs = {'color': 'red', 'linestyle': '--', 'linewidth': 2}
        if column == 'quality_score':
            marker_kwargs['label'] = 'Best'
        ax.axvline(best_threshold, **marker_kwargs)

        ax.set_xlabel('Jaccard Threshold', fontweight='bold')
        ax.set_ylabel(y_label, fontweight='bold')
        ax.set_title(title, fontweight='bold')
        if column == 'quality_score':
            ax.legend()
        ax.grid(alpha=0.3)

    ax = axes[1, 1]
    ax.scatter(results_df['cohesion'], results_df['n_cross_dataset_clusters'],
              c=results_df['quality_score'], cmap='RdYlGn', s=100, edgecolors='black')
    ax.scatter(best['cohesion'], best['n_cross_dataset_clusters'],
              color='red', s=300, marker='*', edgecolors='black', linewidth=2, zorder=10)
    ax.set_xlabel('Cohesion', fontweight='bold')
    ax.set_ylabel('Cross-Dataset Clusters', fontweight='bold')
    ax.set_title('Cohesion vs Cross-Dataset', fontweight='bold')
    ax.grid(alpha=0.3)

    ax = axes[1, 2]
    ax.hist(results_df['quality_score'], bins=15, color='#0173B2', alpha=0.7, edgecolor='black')
    ax.axvline(best['quality_score'], color='red', linestyle='--', linewidth=2, label=f"Best: {best['quality_score']:.3f}")
    ax.set_xlabel('Quality Score', fontweight='bold')
    ax.set_ylabel('Frequency', fontweight='bold')
    ax.set_title('Score Distribution', fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()


def grid_search_poem_parameters(poem_to_clusters, poem_to_dataset, poem_to_size, poems_by_dataset):
    """Pick the Jaccard threshold for poem clustering via grid search on a sample.

    A stratified sample of ``POEM_SAMPLE_SIZE`` poems is clustered once per
    threshold in ``JACCARD_THRESHOLD_RANGE`` (start, stop, count) using a
    process pool. Each run gets a quality score:
    0.50 * normalised cohesion + 0.30 * normalised cross-dataset cluster ratio
    + 0.20 * (1 - singleton ratio). Results are written as a CSV and a PNG
    under ``RESULTS_DIR``.

    Args:
        poem_to_clusters: Mapping of poem id to its verse-cluster ids.
        poem_to_dataset: Mapping of poem id to its source dataset label.
        poem_to_size: Mapping of poem id to the size used for stratification.
        poems_by_dataset: Not used by this function; kept for interface
            compatibility.

    Returns:
        ``(best_jaccard, results_df)``: the best-scoring threshold as a float
        and the per-threshold results sorted by descending quality score.

    Raises:
        RuntimeError: If no threshold evaluation produced a result.
    """
    timing_logger.start_stage("15_poem_parameter_search")

    logger.info("\n" + "="*80)
    logger.info("Grid search: Jaccard threshold")
    logger.info("="*80)

    sample_poems = stratified_sample_poems(poem_to_clusters, poem_to_dataset,
                                          poem_to_size, POEM_SAMPLE_SIZE)

    sample_poems_by_dataset = defaultdict(list)
    for poem_id in sample_poems:
        sample_poems_by_dataset[poem_to_dataset[poem_id]].append(poem_id)

    # Build the sample-restricted mapping once and share it between the index
    # and the candidate-pair search.
    sample_poem_to_clusters = {p: poem_to_clusters[p] for p in sample_poems}
    cluster_to_poems = build_cluster_to_poems_index(sample_poem_to_clusters)

    candidate_pairs = find_cross_dataset_candidate_pairs_batched(
        sample_poem_to_clusters,
        poem_to_dataset,
        cluster_to_poems,
        sample_poems_by_dataset
    )

    jaccard_thresholds = np.linspace(JACCARD_THRESHOLD_RANGE[0],
                                     JACCARD_THRESHOLD_RANGE[1],
                                     int(JACCARD_THRESHOLD_RANGE[2]))

    logger.info(f"\nParameter grid:")
    logger.info(f"  Jaccard thresholds: {len(jaccard_thresholds)} values from {jaccard_thresholds[0]:.2f} to {jaccard_thresholds[-1]:.2f}")

    # NOTE(review): the full poem_to_clusters mapping is pickled for every
    # task. Passing only the sample subset would be cheaper, but it changes
    # the iteration order seen by the cohesion estimate, so it is left as is.
    args_list = [
        (jaccard_thresh, sample_poems, poem_to_clusters, candidate_pairs, poem_to_dataset)
        for jaccard_thresh in jaccard_thresholds
    ]

    logger.info(f"\nRunning grid search with {POEM_MAX_WORKERS} workers...")
    start_time = time.time()
    results = []

    with ProcessPoolExecutor(max_workers=POEM_MAX_WORKERS) as executor:
        futures = {executor.submit(evaluate_single_poem_config, args): args for args in args_list}
        with tqdm(total=len(futures), desc="Grid search") as pbar:
            for future in as_completed(futures):
                result = future.result()
                # Workers return None for a threshold that failed to evaluate.
                if result is not None:
                    results.append(result)
                pbar.update(1)

    logger.info(f"Grid search complete in {time.time() - start_time:.1f}s")

    # Fail with a clear message instead of a KeyError on the empty frame below.
    if not results:
        raise RuntimeError(
            "Poem grid search produced no results: every Jaccard threshold evaluation failed"
        )

    results_df = pd.DataFrame(results)

    def normalize(series):
        """Min-max scale to [0, 1]; a constant series maps to 0.5."""
        min_val = series.min()
        max_val = series.max()
        if max_val - min_val < 1e-10:
            return pd.Series(0.5, index=series.index)
        return (series - min_val) / (max_val - min_val)

    cohesion_score = normalize(results_df['cohesion'])

    cross_dataset_ratio = results_df['n_cross_dataset_clusters'] / (results_df['n_clusters'] + 1e-10)
    cross_dataset_score = normalize(cross_dataset_ratio)

    singleton_ratio = results_df['n_singletons'] / len(sample_poems)
    balance_score = np.clip(1 - singleton_ratio, 0, 1)

    results_df['quality_score'] = (
        cohesion_score * 0.50 +
        cross_dataset_score * 0.30 +
        balance_score * 0.20
    )

    results_df = results_df.sort_values('quality_score', ascending=False)
    results_csv = RESULTS_DIR / 'poem_parameter_grid_search.csv'
    results_df.to_csv(results_csv, index=False)
    logger.info(f"Results saved: {results_csv}")

    logger.info("\nCreating poem parameter visualizations...")

    # Frame is sorted by descending quality, so the first row is the winner.
    best = results_df.iloc[0]

    plot_path = RESULTS_DIR / 'poem_grid_search_comprehensive.png'
    _plot_poem_grid_search(results_df, best, plot_path)
    logger.info(f"Visualization saved: {plot_path}")

    best_jaccard = float(best['jaccard_threshold'])

    logger.info("\n" + "="*80)
    logger.info("Selected configuration (highest quality)")
    logger.info("="*80)
    logger.info(f"Jaccard threshold: {best_jaccard:.3f}")
    logger.info(f"Quality score: {best['quality_score']:.3f}")
    logger.info(f"Cohesion: {best['cohesion']:.3f}")
    logger.info(f"Cross-dataset clusters: {best['n_cross_dataset_clusters']}")

    timing_logger.end_stage()
    return best_jaccard, results_df

def cluster_all_poems_batched(poem_to_clusters, poem_to_dataset, poems_by_dataset, jaccard_threshold):
    """Cluster all poems by merging cross-dataset pairs above a Jaccard threshold.

    Candidate pairs (poems from different datasets sharing a verse cluster)
    are scored with ``compute_jaccard_similarity`` on their cluster-id sets,
    and pairs at or above ``jaccard_threshold`` are merged with a union-find.

    Args:
        poem_to_clusters: Mapping of poem id to its verse-cluster ids.
        poem_to_dataset: Mapping of poem id to its source dataset label.
        poems_by_dataset: Mapping of dataset label to its poem ids.
        jaccard_threshold: Minimum similarity for two poems to be merged.

    Returns:
        ``(cluster_assignments, poem_clusters)``: poem id -> poem-cluster id,
        and poem-cluster id -> member poems as returned by the union-find.
    """
    timing_logger.start_stage("16_poem_clustering")

    logger.info("\n" + "="*80)
    logger.info("Clustering all poems (batched, cross-dataset only)")
    logger.info("="*80)
    logger.info(f"Jaccard threshold: {jaccard_threshold:.3f}")

    cluster_to_poems = build_cluster_to_poems_index(poem_to_clusters)

    candidate_pairs = find_cross_dataset_candidate_pairs_batched(
        poem_to_clusters, poem_to_dataset, cluster_to_poems, poems_by_dataset
    )

    logger.info(f"\nClustering {len(poem_to_clusters):,} poems...")
    poem_ids = list(poem_to_clusters.keys())
    uf = UnionFind(poem_ids)

    merges = 0
    # Materialise the pair set once; rebuilding the list for every batch made
    # batching quadratic in the number of pairs.
    pair_list = list(candidate_pairs)
    batch_starts = range(0, len(pair_list), POEM_BATCH_SIZE)

    for start in tqdm(batch_starts, desc="Processing batches"):
        for p1, p2 in pair_list[start:start + POEM_BATCH_SIZE]:
            jaccard = compute_jaccard_similarity(
                poem_to_clusters[p1],
                poem_to_clusters[p2]
            )

            # union() reports whether the two poems were in different clusters.
            if jaccard >= jaccard_threshold and uf.union(p1, p2):
                merges += 1

    logger.info(f"  Performed {merges:,} merges")

    poem_clusters = uf.get_clusters()
    cluster_assignments = {
        poem: cluster_id
        for cluster_id, poems in poem_clusters.items()
        for poem in poems
    }

    n_clusters = len(poem_clusters)
    cluster_sizes = [len(poems) for poems in poem_clusters.values()]
    n_singletons = sum(1 for size in cluster_sizes if size == 1)

    # A cluster is cross-dataset when its poems map to more than one dataset.
    n_cross_dataset_clusters = sum(
        1
        for poems in poem_clusters.values()
        if len({poem_to_dataset.get(p) for p in poems}) > 1
    )

    # Guard the statistics so an empty corpus logs zeros instead of crashing
    # in max() / warning in np.mean().
    avg_cluster_size = np.mean(cluster_sizes) if cluster_sizes else 0.0
    max_cluster_size = max(cluster_sizes) if cluster_sizes else 0

    logger.info(f"\n  Total poem clusters: {n_clusters:,}")
    logger.info(f"  Cross-dataset clusters: {n_cross_dataset_clusters:,}")
    logger.info(f"  Singleton poems: {n_singletons:,}")
    logger.info(f"  Avg cluster size: {avg_cluster_size:.2f}")
    logger.info(f"  Max cluster size: {max_cluster_size}")

    timing_logger.end_stage()
    return cluster_assignments, poem_clusters

# --- Poem-level pipeline: reconstruct poems, tune the threshold, cluster ---
poem_to_clusters, poem_to_dataset, poem_metadata = reconstruct_poems_from_verses(df)
poem_to_size = poem_metadata['poem_to_size']
poems_by_dataset = poem_metadata['poems_by_dataset']

best_jaccard, grid_results = grid_search_poem_parameters(
    poem_to_clusters, poem_to_dataset, poem_to_size, poems_by_dataset
)

# The grid-search frame is already saved to disk; release it before the
# full-corpus clustering.
del grid_results
import gc
gc.collect()

cluster_assignments, poem_clusters = cluster_all_poems_batched(
    poem_to_clusters, poem_to_dataset, poems_by_dataset, best_jaccard
)

timing_logger.start_stage("17_poem_results")

# Same composite key format as used when the poems were reconstructed, so the
# poem-level assignments can be mapped back onto the verse rows.
df['poem_composite_id'] = df['idoriginal_poem'].astype(str) + '___' + df['source_dataset'].astype(str)
df['poem_cluster_id'] = df['poem_composite_id'].map(cluster_assignments)

# Poem clusters whose members come from more than one dataset.
cross_dataset_cluster_ids = {
    cluster_id
    for cluster_id, poems in poem_clusters.items()
    if len({poem_to_dataset.get(p) for p in poems}) > 1
}

df['is_cross_dataset_poem_cluster'] = df['poem_cluster_id'].isin(cross_dataset_cluster_ids)

output_csv = RESULTS_DIR / 'poems_clustered_by_verse_jaccard.csv'
df.to_csv(output_csv, index=False)
logger.info(f"\nResults saved: {output_csv}")

poem_summary = {
    'n_verses': len(df),
    'n_poems': len(poem_to_clusters),
    'n_datasets': len(set(poem_to_dataset.values())),
    'best_jaccard_threshold': best_jaccard,
    'n_poem_clusters': len(set(cluster_assignments.values())),
    'n_cross_dataset_clusters': len(cross_dataset_cluster_ids),
    # Count poems via cluster membership. ``df`` has one row per verse, so
    # summing its flag column would report verses, not poems.
    'n_poems_in_cross_dataset_clusters': sum(
        len(poem_clusters[cluster_id]) for cluster_id in cross_dataset_cluster_ids
    ),
    # The verse-level count is kept under an accurate name.
    'n_verses_in_cross_dataset_clusters': int(df['is_cross_dataset_poem_cluster'].sum())
}

pd.DataFrame([poem_summary]).to_csv(RESULTS_DIR / 'poem_clustering_summary.csv', index=False)

logger.info("\n" + "="*80)
logger.info("Poem-level clustering complete")
logger.info("="*80)
logger.info(f"Cross-dataset poem clusters: {poem_summary['n_cross_dataset_clusters']:,}")
logger.info(f"Poems in cross-dataset clusters: {poem_summary['n_poems_in_cross_dataset_clusters']:,}")
logger.info(f"All results saved to: {RESULTS_DIR}/")
logger.info("="*80)

timing_logger.end_stage()

# --- Final performance report: system info, resources, timings, results ---
resource_monitor.stop()
total_time = time.time() - script_start_time

system_info = get_system_info()
resource_stats = resource_monitor.get_stats()
timing_summary = timing_logger.get_summary()

report_lines = [
    "="*80,
    "Comprehensive clustering performance report",
    "="*80,
    "",
]

report_lines.extend([
    "System information",
    "-" * 80,
    f"Hostname:            {system_info['hostname']}",
    f"Platform:            {system_info['platform']}",
    f"Python Version:      {system_info['python_version']}",
    f"Processor:           {system_info['processor']}",
    f"CPU Cores (Physical):{system_info['cpu_count_physical']}",
    f"CPU Cores (Logical): {system_info['cpu_count_logical']}",
    f"Total RAM:           {system_info['total_ram_gb']:.2f} GB",
    f"GPU Available:       {system_info['gpu_available']}",
])
# if system_info['gpu_available']:
#     report_lines.append(f"GPU Name:            {system_info['gpu_name']}")
#     report_lines.append(f"GPU Memory:          {system_info['gpu_total_memory_gb']:.2f} GB")
# report_lines.append(f"Timestamp:           {system_info['timestamp']}")
report_lines.append("")

report_lines.extend([
    "Peak resource usage",
    "-" * 80,
    f"Peak RAM Usage:      {resource_stats['peak_ram_gb']:.2f} GB",
    f"Average RAM Usage:   {resource_stats['avg_ram_gb']:.2f} GB",
])
if system_info['gpu_available']:
    report_lines.append(f"Peak GPU Memory:     {resource_stats['peak_gpu_mem_gb']:.2f} GB")
    report_lines.append(f"Average GPU Memory:  {resource_stats['avg_gpu_mem_gb']:.2f} GB")
report_lines.append("")

report_lines.append("Timing breakdown (by stage)")
report_lines.append("-" * 80)

# Percentages are relative to the sum of measured stages, not wall-clock time.
total_measured = sum(timing_summary.values())
for stage_name, duration in timing_summary.items():
    pct = (duration / total_measured * 100) if total_measured > 0 else 0
    report_lines.append(f"{stage_name:.<50} {duration:>8.1f}s ({pct:>5.1f}%)")

report_lines.append(f"{'Total measured time':.<50} {total_measured:>8.1f}s")
report_lines.append(f"{'Total wall clock time':.<50} {total_time:>8.1f}s ({total_time/60:>6.1f} min)")
report_lines.append("")

report_lines.append("Detailed timing analysis")
report_lines.append("-" * 80)

# Stage names carry a two-digit prefix: 01-12 are summed as verse-level,
# 13-17 as poem-level.
verse_prefixes = tuple(f"{stage_no:02d}_" for stage_no in range(1, 13))
poem_prefixes = tuple(f"{stage_no:02d}_" for stage_no in range(13, 18))

verse_stages = [k for k in timing_summary.keys() if k.startswith(verse_prefixes)]
verse_time = sum(timing_summary.get(k, 0) for k in verse_stages)

poem_stages = [k for k in timing_summary.keys() if k.startswith(poem_prefixes)]
poem_time = sum(timing_summary.get(k, 0) for k in poem_stages)

report_lines.append(f"Verse-level clustering:  {verse_time:>8.1f}s ({verse_time/60:>6.1f} min)")
report_lines.append(f"Poem-level clustering:   {poem_time:>8.1f}s ({poem_time/60:>6.1f} min)")
report_lines.append("")

report_lines.extend([
    "Clustering results summary",
    "-" * 80,
    "Verse-level:",
    f"  Total verses:             {verse_summary['n_verses']:,}",
    f"  Total clusters:           {verse_summary['n_total_clusters']:,}",
    f"  Cross-dataset clusters:   {verse_summary['n_cross_clusters']:,} ({verse_summary['pct_cross_clusters']:.1f}%)",
    f"  Cross-dataset verses:     {verse_summary['n_cross_verses']:,} ({verse_summary['pct_cross_verses']:.1f}%)",
    "",
    "Poem-level:",
    f"  Total poems:              {poem_summary['n_poems']:,}",
    f"  Total clusters:           {poem_summary['n_poem_clusters']:,}",
    f"  Cross-dataset clusters:   {poem_summary['n_cross_dataset_clusters']:,}",
    f"  Poems in cross-clusters:  {poem_summary['n_poems_in_cross_dataset_clusters']:,}",
    "",
])

report_lines.append("Performance metrics")
report_lines.append("-" * 80)
if verse_time > 0:
    report_lines.append(f"Verse clustering throughput:  {verse_summary['n_verses'] / verse_time:.1f} verses/sec")
if poem_time > 0:
    report_lines.append(f"Poem clustering throughput:   {poem_summary['n_poems'] / poem_time:.1f} poems/sec")
# Guarded like the two rates above so a zero elapsed time cannot divide by zero.
if total_time > 0:
    report_lines.append(f"Overall processing rate:      {verse_summary['n_verses'] / total_time:.1f} verses/sec")
report_lines.append("")

report_lines.extend([
    "="*80,
    "End of report",
    "="*80,
])

report_path = RESULTS_DIR / 'clustering_performance_report.txt'
# Explicit UTF-8: hostname/processor strings may contain non-ASCII characters
# that the platform default encoding cannot write.
with open(report_path, 'w', encoding='utf-8') as f:
    f.write('\n'.join(report_lines))

for line in report_lines:
    logger.info(line)

logger.info(f"\nPerformance report saved to: {report_path}")
logger.info("="*80)
logger.info("All processing complete")
logger.info("="*80)