In [None]:
import os
import torch
import numpy as np
import logging
import glob
from torch_geometric.data import Data, Batch
from collections import defaultdict

# Import the model definition and plotting utilities
from Train_exp import ExperimentalGNN, CONFIG
from diagnostic_plots import create_diagnostic_plots  # Import the plotting function

def setup_logging():
    """Configure the root logger to print INFO-and-above records.

    Records go through a single ``logging.StreamHandler`` (stderr by default)
    using a ``time [LEVEL] message`` layout. As with any ``basicConfig`` call,
    this is a no-op if the root logger already has handlers.
    """
    console = logging.StreamHandler()
    log_format = '%(asctime)s [%(levelname)s] %(message)s'
    logging.basicConfig(handlers=[console], format=log_format, level=logging.INFO)

def parse_binary_file(filepath):
    """Read ``<bitstring> <count>`` lines from *filepath* into a dict.

    Blank (or whitespace-only) lines are ignored. Every other line must hold
    exactly two whitespace-separated fields; the second is parsed with
    ``int``. If a bitstring appears on several lines, the last count wins.

    Raises:
        ValueError: if a line does not have exactly two fields or the count
            is not an integer.
    """
    counts = {}
    with open(filepath, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            key, value = stripped.split()
            counts[key] = int(value)
    return counts

def calculate_probabilities(states_dict):
    """Turn raw state counts into probabilities that sum to one.

    Keys and their order are preserved; each value is divided by the total
    count. An empty mapping yields an empty dict, while a non-empty mapping
    whose counts sum to zero raises ZeroDivisionError.
    """
    norm = sum(states_dict.values())
    probabilities = {}
    for key, value in states_dict.items():
        probabilities[key] = value / norm
    return probabilities

def binary_to_rydberg_probabilities(binary_states, N):
    """Convert bitstring probabilities into per-site excitation probabilities.

    Site ``i`` corresponds to bit ``i`` of ``int(state, 2)``, counting from
    the least-significant (right-most) character of the bitstring. The
    result for site ``i`` is the summed probability of all states where that
    bit is 1.

    Args:
        binary_states: mapping bitstring -> probability.
        N: number of sites to report; bits at positions >= N are ignored.

    Returns:
        float ndarray of shape (N,).
    """
    p_rydberg = np.zeros(N)
    for state, prob in binary_states.items():
        state_int = int(state, 2)
        for i in range(N):
            if (state_int >> i) & 1:
                p_rydberg[i] += prob
    return p_rydberg

def reshape_for_subsystem(psi, A_indices, N):
    """Arrange the amplitudes of *psi* as a matrix for the bipartition A|B.

    Bit ``i`` of ``int(state, 2)`` is the state of site ``i``. Sites in A
    (sorted ascending) form the row index, remaining sites the column index;
    within each subsystem the k-th smallest site becomes bit ``k`` of that
    index.

    Args:
        psi: mapping bitstring -> amplitude.
        A_indices: iterable of distinct site indices in subsystem A.
        N: total number of sites.

    Returns:
        complex128 ndarray of shape (2**N_A, 2**N_B); entries for bitstrings
        absent from *psi* are zero.
    """
    A_indices = sorted(A_indices)
    a_set = set(A_indices)
    B_indices = [i for i in range(N) if i not in a_set]
    N_A = len(A_indices)
    N_B = N - N_A

    # Precompute each site's bit position inside its subsystem index so the
    # per-state loop does O(1) lookups instead of list scans.
    a_pos = {site: k for k, site in enumerate(A_indices)}
    b_pos = {site: k for k, site in enumerate(B_indices)}

    psi_reshaped = np.zeros((2**N_A, 2**N_B), dtype=np.complex128)

    for state_str, amplitude in psi.items():
        state_int = int(state_str, 2)
        i_A = 0
        i_B = 0
        for i in range(N):
            bit = (state_int >> i) & 1
            if i in a_pos:
                i_A |= bit << a_pos[i]
            else:
                i_B |= bit << b_pos[i]
        psi_reshaped[i_A, i_B] = amplitude

    return psi_reshaped

def calculate_entropy(binary_states, N, rng=None):
    """Von Neumann entropy of a randomly chosen bipartition.

    The wavefunction is rebuilt as ``psi(s) = sqrt(p(s))``, i.e. every
    amplitude is taken real and non-negative (phases are not available from
    counts). The entropy is that of this reconstructed state.

    Args:
        binary_states: mapping bitstring -> probability.
        N: number of sites; must be >= 2 so both subsystems are non-empty.
        rng: optional random source exposing ``randint`` and ``choice``
            (e.g. ``np.random.RandomState(seed)``) to make the bipartition
            reproducible. Defaults to the global ``np.random`` state.

    Returns:
        (entropy, subsystem_mask): entropy is ``-sum(l * ln(l))`` over the
        normalised squared singular values of the reshaped state;
        subsystem_mask is an int array of length N with 1 for sites in A.

    Raises:
        ValueError: if N < 2.
    """
    if N < 2:
        raise ValueError(f"calculate_entropy needs at least 2 sites, got N={N}")
    if rng is None:
        rng = np.random

    psi = {state: np.sqrt(prob) for state, prob in binary_states.items()}

    # Subsystem A gets between 1 and N-1 sites, so B is never empty either.
    A_size = rng.randint(1, N)
    A_indices = rng.choice(N, A_size, replace=False)

    subsystem_mask = np.zeros(N, dtype=int)
    subsystem_mask[A_indices] = 1

    psi_matrix = reshape_for_subsystem(psi, A_indices, N)
    # Only the singular values (Schmidt coefficients) are needed.
    s = np.linalg.svd(psi_matrix, compute_uv=False)
    s_squared = s**2
    s_squared_normalized = s_squared / np.sum(s_squared)
    # The 1e-12 offset keeps log() finite for vanishing Schmidt weights.
    entropy = -np.sum(s_squared_normalized * np.log(s_squared_normalized + 1e-12))

    return entropy, subsystem_mask

def create_graph_data(binary_states, N, x_spacing=6.0, y_spacing=6.0, r_cut=25.0):
    """Create a graph representation of one bitstring distribution for the GNN.

    Args:
        binary_states: mapping bitstring -> probability.
        N: number of sites; must be even because sites are placed two per row.
        x_spacing, y_spacing: lattice spacings used for node positions.
        r_cut: sites whose Euclidean distance is <= r_cut are joined by a
            pair of directed edges.

    Returns:
        torch_geometric ``Data`` with
        - x: (N, 5) node features [x, y, p_rydberg, in_A, boundary_dist],
        - edge_index: (2, E) long tensor, edge_attr: (E, 3) [angle, correlation, dist],
        - y: entropy of a random bipartition, plus graph-level scalars
          (system_size, total_rydberg, rydberg_density, config_entropy, nA, nB).

    Raises:
        ValueError: if N is odd (the layout below only places 2 * (N // 2) sites).
    """
    if N % 2 != 0:
        raise ValueError(
            f"create_graph_data places sites two per row and needs an even N, got N={N}"
        )

    # Layout: N // 2 rows of 2 sites; site i sits at row i // 2, column i % 2.
    n_rows = N // 2
    n_cols = 2

    p_rydberg = binary_to_rydberg_probabilities(binary_states, N)
    entropy, subsystem_mask = calculate_entropy(binary_states, N)

    positions = np.array([
        (col * x_spacing, row * y_spacing)
        for row in range(n_rows) for col in range(n_cols)
    ], dtype=np.float32)

    # Distance, in site-index units (not lattice units), from each site to the
    # nearest site on the other side of the bipartition.
    boundary_dist = np.zeros(N)
    for i in range(N):
        mask_i = subsystem_mask[i]
        min_dist = N  # exceeds any index distance |i - j| <= N - 1
        for j in range(N):
            if subsystem_mask[j] != mask_i:
                min_dist = min(min_dist, abs(i - j))
        boundary_dist[i] = min_dist

    node_features = torch.tensor(np.column_stack([
        positions,
        p_rydberg.reshape(-1, 1),
        subsystem_mask.reshape(-1, 1),
        boundary_dist.reshape(-1, 1)
    ]), dtype=torch.float)

    edges = []
    edge_attrs = []

    for i in range(N):
        for j in range(i + 1, N):
            dx = positions[i, 0] - positions[j, 0]
            dy = positions[i, 1] - positions[j, 1]
            dist = np.sqrt(dx*dx + dy*dy)

            if dist <= r_cut:
                edges.append([i, j])
                edges.append([j, i])

                angle = np.arctan2(dy, dx)
                # NOTE(review): product of single-site marginals, not a
                # measured two-point correlation.
                correlation = p_rydberg[i] * p_rydberg[j]

                # NOTE(review): the reverse edge gets -angle (mirror across the
                # x axis) rather than angle + pi (the opposite direction). Kept
                # as-is on the assumption the model was trained on these exact
                # features — confirm against the training pipeline.
                edge_attrs.extend([[angle, correlation, dist],
                                   [-angle, correlation, dist]])

    # reshape() keeps the (2, E) / (E, 3) shapes even when no pair is within
    # r_cut; torch.tensor([]) alone would give a 1-D (0,) tensor.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t()
    edge_attr = torch.tensor(edge_attrs, dtype=torch.float).reshape(-1, 3)

    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([entropy], dtype=torch.float),
        system_size=torch.tensor([[N]], dtype=torch.float),
        total_rydberg=torch.tensor([p_rydberg.sum()], dtype=torch.float),
        rydberg_density=torch.tensor([p_rydberg.sum()/N], dtype=torch.float),
        config_entropy=torch.tensor([[entropy]], dtype=torch.float),
        nA=torch.tensor([[subsystem_mask.sum()]], dtype=torch.float),
        nB=torch.tensor([[N - subsystem_mask.sum()]], dtype=torch.float)
    )

    return data

def process_directory(directory_path, model_path):
    """Run the trained GNN on every ``*.txt`` file in *directory_path*.

    Each file is parsed as ``<bitstring> <count>`` lines, normalised,
    converted to a graph and fed to the model. Files that fail are logged
    with a traceback and skipped; empty files are skipped with a warning.

    Args:
        directory_path: directory to scan (non-recursive) for ``*.txt`` files.
        model_path: path to a saved ``state_dict`` for ``ExperimentalGNN``.

    Returns:
        dict of numpy arrays, one entry per successfully processed file, with
        keys 'predictions', 'targets', 'sizes', 'rydberg_density' and
        'total_density' (the latter is a constant 1.0 placeholder).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # A dummy 18-site graph is built only to read the feature dimensions the
    # model has to be constructed with.
    sample_data = create_graph_data({'0' * 18: 1}, 18)
    model = ExperimentalGNN(
        num_node_features=sample_data.x.size(1),
        edge_attr_dim=sample_data.edge_attr.size(1),
        hidden_channels=CONFIG['hidden_channels'],
        num_layers=CONFIG['num_layers'],
        dropout_p=CONFIG['dropout_p']
    ).to(device)

    # The checkpoint is a state_dict, so restrict unpickling to tensors and
    # plain containers instead of arbitrary Python objects.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    predictions = []
    targets = []
    system_sizes = []
    rydberg_densities = []
    total_densities = []

    # glob() order is arbitrary; sort so results are reproducible.
    for filepath in sorted(glob.glob(os.path.join(directory_path, "*.txt"))):
        try:
            states_dict = parse_binary_file(filepath)
            if not states_dict:
                logging.warning("Skipping %s: no states found", filepath)
                continue
            N = len(next(iter(states_dict)))

            probabilities = calculate_probabilities(states_dict)
            data = create_graph_data(probabilities, N)
            data = data.to(device)

            with torch.no_grad():
                pred = model(data)
                log_s_over_n = pred[0, 0]
                # NOTE(review): this evaluates exp(N * pred[0, 0]), which is the
                # absolute entropy only if column 0 is log(S) / N. If it is
                # log(S / N) instead, the absolute value would be
                # N * exp(pred[0, 0]). Confirm against the training targets.
                abs_pred = torch.exp(log_s_over_n * data.system_size.squeeze(-1))

            predictions.append(float(abs_pred.cpu()))
            targets.append(float(data.y.cpu()))
            system_sizes.append(int(data.system_size.cpu()))
            rydberg_densities.append(float(data.rydberg_density.cpu()))
            total_densities.append(1.0)  # Placeholder for total density

        except Exception:
            # Best-effort batch: keep the traceback in the log and move on.
            logging.exception("Error processing %s", filepath)

    if not predictions:
        logging.warning("No files were processed successfully in %s", directory_path)

    diagnostics = {
        'predictions': np.array(predictions),
        'targets': np.array(targets),
        'sizes': np.array(system_sizes),
        'rydberg_density': np.array(rydberg_densities),
        'total_density': np.array(total_densities)
    }

    return diagnostics

def main(data_dir=None, model_path=None, save_dir="diagnostic_plots"):
    """Run predictions over a directory of count files and save diagnostic plots.

    Args:
        data_dir: directory of ``*.txt`` bitstring-count files. Defaults to
            the experiment directory hard-coded below.
        model_path: model checkpoint to load. Defaults to
            ``CONFIG['best_model_path']``.
        save_dir: output directory passed to ``create_diagnostic_plots``.
    """
    setup_logging()

    if data_dir is None:
        data_dir = r'C:\Users\amssa\Downloads\rydberg_test13_NR_2.0_2_6_NS_T_2_\1'
    if model_path is None:
        model_path = CONFIG['best_model_path']

    diagnostics = process_directory(data_dir, model_path)
    create_diagnostic_plots(diagnostics, save_plots=True, save_dir=save_dir)

if __name__ == "__main__":
    # Entry point: run the prediction + plotting pipeline with main()'s defaults.
    main()

  model.load_state_dict(torch.load(model_path, map_location=device))


KeyboardInterrupt: 