In [1]:
import os
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
import logging
import random
import warnings
import joblib

# Configuration settings
# Single module-level dictionary consumed by the loading / processing functions below.
CONFIG = {
    # Parquet table read by load_data(). NOTE(review): absolute, machine-specific path.
    'data_path': r'C:\Users\amssa\Documents\Codes\New\Von-Neumann-Entropy-GNN\Random sub sets\spin_system_properties_gpu1-7.parquet',
    # Passed as `root` to SpinSystemDataset in load_data().
    'processed_dir': './processed',
    # Not read by any code visible in this file. NOTE(review): PyG datasets usually
    # write under <root>/processed/, so the real cache file may differ - confirm.
    'processed_file': './processed/data.pt',
    # Graphs per batch in create_dataloaders().
    'batch_size': 1024,
    # Seeds the RNGs in set_seed() and the DataFrame shuffle in load_data().
    'random_seed': 42,
    # Neighbour search radius used when building graph edges; compared against
    # positions built from x_spacing / y_spacing, so it shares their units.
    'distance_threshold': 25,
    # Not read by any code visible in this file (presumably for saved scalers - confirm).
    'scalers_path': 'scalers.pkl',
}

def setup_logging():
    """Install INFO-level root logging that writes timestamped lines to a stream handler."""
    console_handler = logging.StreamHandler()
    log_format = '%(asctime)s [%(levelname)s] %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format, handlers=[console_handler])

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random generators with the same value.

    The CUDA generator is seeded as well when a CUDA device is available.
    """
    # Same order as before: Python stdlib, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

class SpinSystemDataset(InMemoryDataset):
    """
    In-memory graph dataset built from a table of quantum spin systems.

    Every row of the input DataFrame becomes one graph with N = Nx * 2 nodes
    and the row's ``Von_Neumann_Entropy`` as the target ``y``.

    Features:
    ---------
    Node Features (9-dimensional):
    - Normalized Positions (2D)
    - Rydberg Probability (1D)
    - Partition Feature (1D)
    - Neighbor Density (1D)
    - Boundary Distance (1D)
    - Angular Position (1D)
    - Interaction Energy (1D)
    - Configuration Entropy (1D, one per-graph value repeated on every node)

    Edge Features (4-dimensional):
    - Van der Waals Interaction (1D)
    - Quantum Correlation (1D)
    - Edge Orientation (1D)
    - Relative Position (1D)

    Graph-level attributes attached to each ``Data`` object:
    ``Omega``, ``Delta``, ``Energy``, ``total_rydberg``, ``rydberg_density``,
    ``system_size`` and ``config_entropy``.
    """

    def __init__(self, dataframe, root='.', transform=None, pre_transform=None):
        """
        Parameters
        ----------
        dataframe : pandas.DataFrame
            One row per spin system. ``process`` reads the columns ``Nx``,
            ``x_spacing``, ``y_spacing``, ``Top_50_Indices``,
            ``Top_50_Probabilities``, ``Subsystem_Mask``,
            ``Von_Neumann_Entropy``, ``Omega``, ``Delta`` and ``Energy``.
        root : str
            Dataset root directory handed to ``InMemoryDataset``.
        transform, pre_transform : callable or None
            Forwarded to ``InMemoryDataset``; ``pre_transform`` is applied to
            every graph inside ``process``.
        """
        # Must be set before the base initialiser runs, because process() reads it.
        # NOTE(review): the base class presumably triggers process() itself when
        # the processed file is missing - confirm against the installed PyG version.
        self.df = dataframe
        super().__init__(root, transform, pre_transform)

        processed_path = self.processed_paths[0]
        if not os.path.exists(processed_path):
            # Fallback in case the base class did not build the cache.
            self.process()
        # Always load after (re)building; previously the fallback branch saved
        # the file but left self.data / self.slices unset.
        self.data, self.slices = torch.load(processed_path)

    @property
    def raw_file_names(self):
        """No raw files: the data comes from the DataFrame given to ``__init__``."""
        return []

    @property
    def processed_file_names(self):
        """Name of the single collated cache file."""
        return ['data.pt']

    def download(self):
        """Nothing to download."""
        pass

    def process(self):
        """Process raw quantum data into graph structures.

        Builds one graph per DataFrame row, applies ``pre_transform`` when it
        is set, then collates everything and saves it to ``processed_paths[0]``.
        """
        data_list = []
        for idx, row in self.df.iterrows():
            data = self._build_graph(row)
            data_list.append(data)

            if idx % 1000 == 0:
                logging.info(f"Processed {idx} graphs")
                logging.info(f"Node features shape: {data.x.shape}")
                logging.info(f"Edge features shape: {data.edge_attr.shape}")

        if self.pre_transform:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _build_graph(self, row):
        """Convert one DataFrame row into a ``Data`` graph with node, edge and global features."""
        # Basic system setup; int() keeps range() working if Nx arrives as a float
        Nx = int(row['Nx'])
        Ny = 2
        N = Nx * Ny

        # Create atomic positions
        # NOTE(review): x varies over range(Ny) and y over range(Nx), i.e. the
        # grid is 2 sites wide in x and Nx sites long in y. Confirm this matches
        # the simulation's site ordering (the bit index i below relies on it).
        x_spacing = row['x_spacing']
        y_spacing = row['y_spacing']
        positions = np.array([
            (col * x_spacing, row_idx * y_spacing)
            for row_idx in range(Nx) for col in range(Ny)
        ])
        positions = torch.tensor(positions, dtype=torch.float)

        # Normalize positions to [0,1] range (epsilon avoids division by zero)
        pos_min = positions.min(dim=0).values
        pos_max = positions.max(dim=0).values
        normalized_positions = (positions - pos_min) / (pos_max - pos_min + 1e-8)

        # Calculate Rydberg probabilities: node i accumulates the probability
        # of every listed basis state whose bit i is set.
        states = [int(state) for state in row['Top_50_Indices']]
        state_probs = row['Top_50_Probabilities']
        p_rydberg = torch.zeros(N, dtype=torch.float)
        for state, prob in zip(states, state_probs):
            for i in range(N):
                if state & (1 << i):
                    p_rydberg[i] += prob
        p_rydberg = p_rydberg.unsqueeze(1)

        # Create subsystem partition feature
        subsystem_mask = torch.tensor([int(bit) for bit in row['Subsystem_Mask']],
                                      dtype=torch.float).unsqueeze(1)

        # Calculate nearest neighbors within the configured radius
        # NOTE(review): the query points are the fitted points, so each
        # neighbour list presumably contains the atom itself; the density and
        # interaction features below then include a self term - confirm intent.
        distance_threshold = CONFIG['distance_threshold']
        nbrs = NearestNeighbors(radius=distance_threshold, algorithm='ball_tree').fit(positions.numpy())
        indices = nbrs.radius_neighbors(positions.numpy(), return_distance=False)

        # Calculate local density of Rydberg excitations
        neighbor_density = torch.zeros(N, 1)
        for i in range(N):
            neighbors = indices[i]
            if len(neighbors) > 0:
                neighbor_density[i] = torch.mean(p_rydberg[neighbors])

        # Calculate distance to system boundaries (lower boundary is 0 by construction)
        boundary_distance = torch.zeros(N, 1)
        for i in range(N):
            x, y = positions[i]
            dx = min(x, pos_max[0] - x)
            dy = min(y, pos_max[1] - y)
            boundary_distance[i] = min(dx, dy)

        # Calculate angular positions
        angles = torch.atan2(normalized_positions[:, 1],
                             normalized_positions[:, 0]).unsqueeze(1)

        # Calculate local interaction energies
        nn_interaction = torch.zeros(N, 1)
        for i in range(N):
            neighbors = indices[i]
            if len(neighbors) > 0:
                nn_interaction[i] = torch.sum(p_rydberg[i] * p_rydberg[neighbors])

        # Calculate configuration entropy: a single scalar summed over all
        # nodes, then repeated so that every node carries the same value.
        probs = p_rydberg.squeeze()
        config_entropy = -torch.sum(probs * torch.log(probs + 1e-10) +
                                    (1 - probs) * torch.log(1 - probs + 1e-10))
        local_entropy = config_entropy.repeat(N).unsqueeze(1)

        # Combine node features
        node_features = torch.cat([
            normalized_positions,     # [N, 2]
            p_rydberg,                # [N, 1]
            subsystem_mask,           # [N, 1]
            neighbor_density,         # [N, 1]
            boundary_distance,        # [N, 1]
            angles,                   # [N, 1]
            nn_interaction,           # [N, 1]
            local_entropy             # [N, 1]
        ], dim=1)

        # Create edge connections: each neighbour pair is stored once, as (i, j) with i < j
        edge_pairs = [
            [i, int(j)]
            for i in range(N)
            for j in indices[i]
            if i < j
        ]
        if edge_pairs:
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        else:
            # An empty list would give a 1-D tensor with no dim 1, so build the
            # conventional empty [2, 0] index explicitly.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        # Calculate edge features
        num_edges = edge_index.size(1)
        if num_edges > 0:
            pos_i = positions[edge_index[0]]
            pos_j = positions[edge_index[1]]
            edge_vectors = pos_j - pos_i
            distances = torch.norm(edge_vectors, dim=1, keepdim=True)

            # Van der Waals interaction
            epsilon = 1e-8
            inv_r6 = 1.0 / (distances.pow(6) + epsilon)

            # Edge geometry features
            edge_angles = torch.atan2(edge_vectors[:, 1],
                                      edge_vectors[:, 0]).unsqueeze(1)
            relative_pos = (pos_i + pos_j) / 2 - positions.mean(dim=0)
            relative_pos_norm = torch.norm(relative_pos, dim=1, keepdim=True)

            # Calculate quantum correlations: P(i and j excited) - P(i) * P(j)
            edge_correlation = torch.zeros(num_edges, 1)
            for k in range(num_edges):
                # Plain ints keep the bit tests in Python instead of on 0-dim tensors
                i = int(edge_index[0, k])
                j = int(edge_index[1, k])
                joint_prob = 0.0
                for state, prob in zip(states, state_probs):
                    if (state & (1 << i)) and (state & (1 << j)):
                        joint_prob += prob
                edge_correlation[k] = joint_prob - p_rydberg[i] * p_rydberg[j]

            edge_attr = torch.cat([
                inv_r6,             # [E, 1]
                edge_correlation,   # [E, 1]
                edge_angles,        # [E, 1]
                relative_pos_norm   # [E, 1]
            ], dim=1)
        else:
            edge_attr = torch.empty((0, 4), dtype=torch.float)

        # Create graph with target (Von Neumann Entropy)
        entropy = torch.tensor([row['Von_Neumann_Entropy']], dtype=torch.float)
        data = Data(x=node_features,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=entropy)

        # Add global features
        data.Omega = torch.tensor([[row['Omega']]], dtype=torch.float)
        data.Delta = torch.tensor([[row['Delta']]], dtype=torch.float)
        data.Energy = torch.tensor([[row['Energy']]], dtype=torch.float)
        data.total_rydberg = torch.sum(p_rydberg)
        data.rydberg_density = data.total_rydberg / N
        data.system_size = torch.tensor([[N]], dtype=torch.float)
        data.config_entropy = config_entropy.unsqueeze(0)

        return data

def load_data(config):
    """Load the parquet table, shuffle it, and wrap it in a SpinSystemDataset.

    Parameters
    ----------
    config : dict
        Must provide ``data_path`` (parquet file), ``random_seed`` (shuffle
        seed) and ``processed_dir`` (passed to the dataset as ``root``).

    Returns
    -------
    SpinSystemDataset
        Dataset built from the shuffled rows.

    Raises
    ------
    FileNotFoundError
        If ``config['data_path']`` does not exist.
    """
    import io  # local import: only needed to capture DataFrame.info() output

    warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
    data_path = config['data_path']
    if not os.path.exists(data_path):
        logging.error(f"Data file not found at {data_path}")
        raise FileNotFoundError(f"Data file not found at {data_path}")

    df = pd.read_parquet(data_path)

    logging.info("First few rows of the dataset:")
    logging.info(df.head())
    logging.info("\nDataset Information:")
    # DataFrame.info() prints to stdout and returns None, so logging its return
    # value recorded "None". Capture the summary in a buffer and log that.
    info_buffer = io.StringIO()
    df.info(buf=info_buffer)
    logging.info(info_buffer.getvalue())

    # Shuffle once with a fixed seed so the later slice-based split is random
    # yet reproducible.
    df_shuffled = df.sample(frac=1, random_state=config['random_seed']).reset_index(drop=True)
    dataset = SpinSystemDataset(df_shuffled, root=config['processed_dir'])

    logging.info(f'\nTotal graphs in dataset: {len(dataset)}')
    if len(dataset) > 0:
        # Indexing an empty dataset would raise; only show a sample when one exists.
        logging.info('\nSample Data Object:')
        logging.info(dataset[0])

    return dataset

def split_dataset(dataset, config):
    """Cut the dataset into consecutive 80% / 10% / 10% train, validation and test slices.

    ``config`` is accepted for a uniform call signature but is not read here.
    Returns the tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    n_graphs = len(dataset)
    # Slice boundaries: validation starts at 80%, test starts at 90%.
    val_start, test_start = (int(fraction * n_graphs) for fraction in (0.8, 0.9))

    splits = (
        dataset[:val_start],
        dataset[val_start:test_start],
        dataset[test_start:],
    )
    train_dataset, val_dataset, test_dataset = splits

    logging.info(f'\nTraining graphs: {len(train_dataset)}')
    logging.info(f'Validation graphs: {len(val_dataset)}')
    logging.info(f'Test graphs: {len(test_dataset)}')
    return splits

def create_dataloaders(train_dataset, val_dataset, test_dataset, config):
    """Wrap the three splits in DataLoaders that share ``config['batch_size']``.

    Only the training loader shuffles; validation and test keep dataset order.
    Returns ``(train_loader, val_loader, test_loader)``.
    """
    batch_size = config['batch_size']
    split_plan = (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
        for split, reshuffle in split_plan
    )

def main():
    """Run the whole pipeline: configure logging, seed RNGs, load, split and batch the data."""
    setup_logging()
    set_seed(CONFIG['random_seed'])

    full_dataset = load_data(CONFIG)
    splits = split_dataset(full_dataset, CONFIG)
    # The loaders are built but not consumed further inside this function.
    train_loader, val_loader, test_loader = create_dataloaders(*splits, CONFIG)

    logging.info("Data processing and loading completed successfully.")

# Entry point: run the pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

2024-12-24 21:52:43,052 [INFO] First few rows of the dataset:
2024-12-24 21:52:43,060 [INFO]    Nx      Delta      Omega  x_spacing  y_spacing      Energy  \
0   7  40.030606   9.988305   7.118764   6.387401 -239.252283   
1   7  23.347429  39.473622   4.798695   6.056938 -221.844105   
2   7   9.494868  20.161970   4.728944   7.021446 -105.435782   
3   7   5.336133  10.603160   4.180909   5.301321  -44.204418   
4   7  16.238134  16.933983   6.918425   6.550230 -122.907435   

                                      Top_50_Indices  \
0  [6553, 9830, 6425, 9766, 6297, 6537, 9798, 931...   
1  [0, 1, 8192, 4096, 2, 4097, 8194, 8193, 4098, ...   
2  [0, 8192, 2, 1, 4096, 8194, 4097, 8193, 4098, ...   
3  [0, 1, 4096, 8192, 2, 8193, 4098, 4097, 8194, ...   
4  [8738, 4369, 8802, 4497, 9762, 8742, 4377, 641...   

                                Top_50_Probabilities  Von_Neumann_Entropy  \
0  [0.38220482824683005, 0.38220477448728546, 0.0...             0.666110   
1  [0.007107020650037715,

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1500000 entries, 0 to 1499999
Data columns (total 11 columns):
 #   Column                Non-Null Count    Dtype  
---  ------                --------------    -----  
 0   Nx                    1500000 non-null  int64  
 1   Delta                 1500000 non-null  float64
 2   Omega                 1500000 non-null  float64
 3   x_spacing             1500000 non-null  float64
 4   y_spacing             1500000 non-null  float64
 5   Energy                1500000 non-null  float64
 6   Top_50_Indices        1500000 non-null  object 
 7   Top_50_Probabilities  1500000 non-null  object 
 8   Von_Neumann_Entropy   1500000 non-null  float64
 9   N_A                   1500000 non-null  int64  
 10  Subsystem_Mask        1500000 non-null  object 
dtypes: float64(6), int64(2), object(3)
memory usage: 125.9+ MB


Processing...
2024-12-24 21:52:44,104 [INFO] Processed 0 graphs
2024-12-24 21:52:44,105 [INFO] Node features shape: torch.Size([6, 9])
2024-12-24 21:52:44,105 [INFO] Edge features shape: torch.Size([15, 4])
2024-12-24 21:53:01,833 [INFO] Processed 1000 graphs
2024-12-24 21:53:01,835 [INFO] Node features shape: torch.Size([10, 9])
2024-12-24 21:53:01,835 [INFO] Edge features shape: torch.Size([45, 4])
2024-12-24 21:53:19,496 [INFO] Processed 2000 graphs
2024-12-24 21:53:19,496 [INFO] Node features shape: torch.Size([8, 9])
2024-12-24 21:53:19,497 [INFO] Edge features shape: torch.Size([28, 4])
2024-12-24 21:53:37,785 [INFO] Processed 3000 graphs
2024-12-24 21:53:37,786 [INFO] Node features shape: torch.Size([4, 9])
2024-12-24 21:53:37,786 [INFO] Edge features shape: torch.Size([6, 4])
2024-12-24 21:53:56,619 [INFO] Processed 4000 graphs
2024-12-24 21:53:56,620 [INFO] Node features shape: torch.Size([12, 9])
2024-12-24 21:53:56,620 [INFO] Edge features shape: torch.Size([54, 4])
2024-12-