In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Normal, kl_divergence
import logging
import time
from tqdm import tqdm
import pickle
import os
import random

# Configure root logging for the notebook (INFO level, timestamped lines)
# and obtain this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class CDVAE_TI_Generator:
    """
    Crystal Diffusion Variational Autoencoder (CDVAE) with Reinforcement Learning
    for targeted generation of Topological Insulator materials.

    A policy network proposes latent vectors, a decoder turns them into crystal
    structures, and latent-space surrogate models score them. The scores are
    combined into one scalar reward per sample, which drives either a REINFORCE
    or an actor-critic update of the policy.

    Args:
        config: Dict of hyper-parameters. ``latent_dim`` and ``hidden_dim`` are
            required; most other keys are read with ``dict.get`` defaults.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Initialize CDVAE model components
        self.initialize_models()

        # Set up optimizers
        self.setup_optimizers()

        # Initialize results tracking
        self.results = {
            'rewards': [],
            'z_gap': [],
            'topological_indices': [],
            'formation_energies': [],
            'best_structures': [],
            'best_rewards': [],
        }

        # Replay buffer for experience replay. The fallback size matches
        # ReplayBuffer's own default, so a config without 'buffer_size' works.
        self.replay_buffer = ReplayBuffer(config.get('buffer_size', 10000))

    def initialize_models(self):
        """Initialize CDVAE decoder, policy/critic networks and surrogate models.

        Raises:
            ImportError: If the ``cdvae`` package cannot be imported.
        """
        # Import specific model classes
        try:
            from cdvae.pl_modules.decoder import GemNetTDecoder
            from cdvae.common.data_utils import ATOM_TYPES
        except ImportError:
            logger.error("Failed to import CDVAE modules. Please ensure CDVAE is installed correctly.")
            raise

        # Get dimensions and parameters from config
        self.latent_dim = self.config['latent_dim']
        self.n_elements = len(self.config['elements']) if 'elements' in self.config else len(ATOM_TYPES)

        # TODO: possibly write an encoder file addition to cdvae
        # Initialize encoder (if using pre-trained weights)
        # if self.config.get('use_encoder', False):
        #     self.encoder = GraphEncoder(
        #         hidden_dim=self.config['hidden_dim'],
        #         latent_dim=self.latent_dim,
        #         use_layer_norm=self.config.get('use_layer_norm', True)
        #     ).to(self.device)
        #
        #     if self.config.get('encoder_checkpoint'):
        #         self._load_model(self.encoder, self.config['encoder_checkpoint'])
        # else:
        #     self.encoder = None
        self.encoder = None

        # Initialize decoder.
        # NOTE(review): confirm these keyword arguments (and the single-argument
        # `self.decoder(z)` call in generate_structures) against the installed
        # cdvae GemNetTDecoder; upstream releases may use a different signature.
        self.decoder = GemNetTDecoder(
            latent_dim=self.latent_dim,
            hidden_dim=self.config['hidden_dim'],
            n_elements=self.n_elements,
            cutoff=self.config.get('cutoff', 6.0),
            max_neighbors=self.config.get('max_neighbors', 20),
            use_layer_norm=self.config.get('use_layer_norm', True)
        ).to(self.device)

        if self.config.get('decoder_checkpoint'):
            self._load_model(self.decoder, self.config['decoder_checkpoint'])

        # Initialize policy network for RL
        self.policy_net = PolicyNetwork(
            latent_dim=self.latent_dim,
            hidden_dims=self.config.get('policy_hidden_dims', [256, 256]),
            activation=self.config.get('policy_activation', 'relu')
        ).to(self.device)

        # Initialize critic network for actor-critic methods
        if self.config.get('use_critic', True):
            self.critic = CriticNetwork(
                latent_dim=self.latent_dim,
                hidden_dims=self.config.get('critic_hidden_dims', [256, 128]),
                activation=self.config.get('critic_activation', 'relu')
            ).to(self.device)
        else:
            self.critic = None

        # DFT surrogate models - predict quantum properties directly from latent space
        self.energy_predictor = EnergyPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=self.config.get('energy_predictor_dims', [128, 64])
        ).to(self.device)

        self.topological_predictor = TopologicalPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=self.config.get('topo_predictor_dims', [128, 64])
        ).to(self.device)

        if self.config.get('surrogate_checkpoint'):
            self._load_surrogate_models(self.config['surrogate_checkpoint'])

    def setup_optimizers(self):
        """Create Adam optimizers for the policy, the critic (if any) and,
        when ``train_surrogates`` is set, the surrogate models."""
        # Policy optimizer
        self.policy_optimizer = torch.optim.Adam(
            self.policy_net.parameters(),
            lr=self.config.get('policy_lr', 1e-4),
            weight_decay=self.config.get('policy_weight_decay', 1e-6)
        )

        # Critic optimizer (if using actor-critic)
        if self.critic is not None:
            self.critic_optimizer = torch.optim.Adam(
                self.critic.parameters(),
                lr=self.config.get('critic_lr', 3e-4),
                weight_decay=self.config.get('critic_weight_decay', 1e-6)
            )

        # Surrogate model optimizers for fine-tuning
        if self.config.get('train_surrogates', False):
            self.energy_optimizer = torch.optim.Adam(
                self.energy_predictor.parameters(),
                lr=self.config.get('surrogate_lr', 1e-4)
            )

            self.topo_optimizer = torch.optim.Adam(
                self.topological_predictor.parameters(),
                lr=self.config.get('surrogate_lr', 1e-4)
            )

    def _load_model(self, model, checkpoint_path):
        """Best-effort load of ``model`` weights from ``checkpoint_path``.

        Accepts PyTorch-Lightning checkpoints (weights under ``state_dict``
        with a ``model.`` key prefix) and plain state dicts. Loading is
        non-strict, and failures are logged rather than raised.
        """
        prefix = 'model.'
        try:
            # torch.load unpickles the file: only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            if 'state_dict' in checkpoint:
                # Handle pytorch-lightning checkpoints. Strip only the leading
                # prefix; str.replace would also mangle inner names such as
                # 'submodel.weight'.
                state_dict = {
                    k[len(prefix):]: v
                    for k, v in checkpoint['state_dict'].items()
                    if k.startswith(prefix)
                }
                model.load_state_dict(state_dict, strict=False)
            else:
                # Handle regular torch checkpoints
                model.load_state_dict(checkpoint, strict=False)
            logger.info(f"Loaded weights from {checkpoint_path}")
        except Exception as e:
            logger.error(f"Failed to load weights: {e}")

    def _load_surrogate_models(self, checkpoint_path):
        """Best-effort load of both surrogate models from one checkpoint dict
        with 'energy_predictor' and 'topo_predictor' entries."""
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.energy_predictor.load_state_dict(checkpoint['energy_predictor'])
            self.topological_predictor.load_state_dict(checkpoint['topo_predictor'])
            logger.info(f"Loaded surrogate models from {checkpoint_path}")
        except Exception as e:
            logger.error(f"Failed to load surrogate models: {e}")

    def generate_structures(self, batch_size=None):
        """Sample latents from the policy and decode them into structures.

        Args:
            batch_size: Number of samples; defaults to ``config['batch_size']``
                (32 if unset).

        Returns:
            Tuple ``(structures, z_sampled, log_probs)``. ``structures`` is
            whatever the decoder returns; ``log_probs`` has one entry per sample.
        """
        if batch_size is None:
            batch_size = self.config.get('batch_size', 32)

        # Sample latent vectors from the policy network
        z_noise = torch.randn(batch_size, self.latent_dim, device=self.device)
        z_sampled, log_probs = self.policy_net(z_noise)

        # Generate structures using the decoder. Assumed (not verified here) to
        # return a dict with frac_coords, atom_types and lattice entries.
        with torch.no_grad():
            generated_structures = self.decoder(z_sampled)

        return generated_structures, z_sampled, log_probs

    def evaluate_structures(self, structures, z_vectors):
        """Score latent vectors with the surrogate models.

        Returns:
            Dict of numpy arrays: 'formation_energies' (EnergyPredictor output),
            'topological_indices' (TopologicalPredictor output) and 'band_gaps'.
        """
        with torch.no_grad():
            energies = self.energy_predictor(z_vectors)

            # Predict topological indices (Z2 invariants, Chern numbers, etc.)
            topo_indices = self.topological_predictor(z_vectors)

            # Calculate band gaps (can be part of the topological predictor or separate)
            band_gaps = self.estimate_band_gap(structures, z_vectors)

        evaluations = {
            'formation_energies': energies.cpu().numpy(),
            'topological_indices': topo_indices.cpu().numpy(),
            'band_gaps': band_gaps.cpu().numpy() if isinstance(band_gaps, torch.Tensor) else band_gaps
        }

        return evaluations

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a numpy array (tensors are detached and moved to CPU)."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def calculate_rewards(self, evaluations):
        """Combine surrogate predictions into per-structure rewards.

        Args:
            evaluations: Dict with 'formation_energies', 'topological_indices'
                and 'band_gaps' (numpy arrays or tensors). Energies and gaps may
                be shaped (B,) or (B, 1). Topological indices are expected as
                (B, K), with K matching ``topo_weights``.

        Returns:
            Dict with 'total', 'stability', 'topological' and 'band_gap' reward
            arrays, each of shape (B,).
        """
        # Flatten per-sample scalars. EnergyPredictor emits (B, 1), and adding a
        # (B, 1) array to (B,) arrays would silently broadcast to (B, B).
        energies = self._to_numpy(evaluations['formation_energies']).reshape(-1)
        topo_indices = self._to_numpy(evaluations['topological_indices'])
        band_gaps = self._to_numpy(evaluations['band_gaps']).reshape(-1)

        # Stability: lower formation energy is better, and only energies below
        # the threshold earn a reward.
        stability_threshold = self.config.get('stability_threshold', 0.1)
        stability_rewards = -energies * (energies < stability_threshold)

        # Topological: weighted sum over the predicted invariants. For Z2
        # invariants, (1;000) is the usual 3D strong-TI target.
        topo_weights = np.asarray(self.config.get('topo_weights', [2.0, 1.0, 1.0, 1.0]), dtype=float)
        topo_rewards = np.sum(topo_indices * topo_weights, axis=-1)

        # Band gap: triangular reward peaking at the target gap and reaching
        # zero at +/- gap_tolerance (values in eV).
        target_gap = self.config.get('target_band_gap', 0.3)
        gap_tolerance = self.config.get('gap_tolerance', 0.2)
        gap_rewards = 1.0 - np.minimum(np.abs(band_gaps - target_gap) / gap_tolerance, 1.0)

        # Combine reward components with configurable weights
        w_stability = self.config.get('w_stability', 1.0)
        w_topological = self.config.get('w_topological', 2.0)
        w_gap = self.config.get('w_gap', 1.5)

        combined_rewards = (w_stability * stability_rewards +
                            w_topological * topo_rewards +
                            w_gap * gap_rewards)

        return {
            'total': combined_rewards,
            'stability': stability_rewards,
            'topological': topo_rewards,
            'band_gap': gap_rewards
        }

    def estimate_band_gap(self, structures, z_vectors):
        """Mock band-gap estimate from the first latent dimension plus noise.

        Placeholder only: ``structures`` is ignored and the result is random.
        Replace with a trained predictor.
        """
        batch_size = z_vectors.shape[0]
        return (0.2 + 0.3 * torch.sigmoid(z_vectors[:, 0])
                + 0.1 * torch.randn(batch_size, device=self.device))

    def _rewards_to_tensor(self, rewards, reference):
        """Convert a batch of rewards to a flat tensor on ``self.device``.

        The tensor takes the dtype of ``reference`` (the log-probs). Rewards
        built from mixed float32/float64 arrays come out as float64, and mixing
        that with float32 predictions fails in ``F.mse_loss``'s backward pass.
        """
        if isinstance(rewards, torch.Tensor):
            tensor = rewards.detach()
        else:
            tensor = torch.as_tensor(np.asarray(rewards))
        return tensor.to(device=self.device, dtype=reference.dtype).reshape(-1)

    def _clip_gradients(self, parameters):
        """Clip the gradient norm in place when ``clip_grad`` is enabled."""
        if self.config.get('clip_grad', False):
            torch.nn.utils.clip_grad_norm_(parameters, self.config.get('max_grad_norm', 1.0))

    def reinforce_update(self, rewards, log_probs):
        """Update the policy with REINFORCE using batch-normalized rewards.

        Returns:
            The scalar policy loss.
        """
        rewards_tensor = self._rewards_to_tensor(rewards, log_probs)

        # Normalize rewards. The std of a single sample is undefined (NaN), so
        # only center in that case.
        centered = rewards_tensor - rewards_tensor.mean()
        if rewards_tensor.numel() > 1:
            rewards_normalized = centered / (rewards_tensor.std() + 1e-8)
        else:
            rewards_normalized = centered

        policy_loss = -(log_probs * rewards_normalized).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self._clip_gradients(self.policy_net.parameters())
        self.policy_optimizer.step()

        return policy_loss.item()

    def actor_critic_update(self, z_vectors, rewards, log_probs):
        """Update policy and critic with a one-step actor-critic rule.

        Falls back to :meth:`reinforce_update` (returning a single float) when
        no critic is configured.

        Returns:
            Tuple ``(policy_loss, critic_loss)``.
        """
        if self.critic is None:
            return self.reinforce_update(rewards, log_probs)

        rewards_tensor = self._rewards_to_tensor(rewards, log_probs)

        # The critic regresses on the sampled latents as fixed inputs. Detaching
        # keeps critic gradients out of the policy graph, which
        # policy_loss.backward() has already freed.
        value_predictions = self.critic(z_vectors.detach()).reshape(-1)

        # Advantages use the critic as a baseline (no gradient into the critic)
        advantages = rewards_tensor - value_predictions.detach()
        policy_loss = -(log_probs * advantages).mean()
        critic_loss = F.mse_loss(value_predictions, rewards_tensor.to(value_predictions.dtype))

        # Update policy network
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self._clip_gradients(self.policy_net.parameters())
        self.policy_optimizer.step()

        # Update critic network
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self._clip_gradients(self.critic.parameters())
        self.critic_optimizer.step()

        return policy_loss.item(), critic_loss.item()

    @staticmethod
    def _select_structure(structures, index):
        """Return the ``index``-th structure from a batched decoder output.

        Dict outputs (one batched array per key) are sliced key by key. Any
        other indexable batch is indexed directly.
        """
        if isinstance(structures, dict):
            return {
                key: value[index] if isinstance(value, (torch.Tensor, np.ndarray, list, tuple)) else value
                for key, value in structures.items()
            }
        return structures[index]

    def train_step(self):
        """Run one generate -> evaluate -> reward -> update cycle.

        Returns:
            Dict with 'mean_reward', 'max_reward', 'mean_energy' and
            'policy_loss' for this step.
        """
        structures, z_vectors, log_probs = self.generate_structures()
        evaluations = self.evaluate_structures(structures, z_vectors)
        rewards_dict = self.calculate_rewards(evaluations)
        total_rewards = rewards_dict['total']

        # Store experience in replay buffer
        z_rows = z_vectors.detach().cpu().numpy()
        log_prob_rows = log_probs.detach().cpu().numpy()
        for z_row, reward, log_prob in zip(z_rows, total_rewards, log_prob_rows):
            self.replay_buffer.add(z_row, reward, log_prob)

        # Update policy using actor-critic or REINFORCE
        if self.critic is not None:
            policy_loss = self.actor_critic_update(z_vectors, total_rewards, log_probs)[0]
        else:
            policy_loss = self.reinforce_update(total_rewards, log_probs)

        # Track best structures
        best_idx = int(np.argmax(total_rewards))
        best_reward = float(total_rewards[best_idx])
        if not self.results['best_rewards'] or best_reward > max(self.results['best_rewards']):
            self.results['best_structures'].append(self._select_structure(structures, best_idx))

        mean_reward = np.mean(total_rewards)
        mean_energy = np.mean(evaluations['formation_energies'])

        self.results['rewards'].append(mean_reward)
        self.results['formation_energies'].append(mean_energy)
        self.results['topological_indices'].append(np.mean(evaluations['topological_indices']))
        self.results['best_rewards'].append(best_reward)

        return {
            'mean_reward': mean_reward,
            'max_reward': np.max(total_rewards),
            'mean_energy': mean_energy,
            'policy_loss': policy_loss
        }

    def train(self, num_iterations=None):
        """Train for ``num_iterations`` steps (default ``config['num_iterations']``
        or 500), logging and checkpointing periodically."""
        if num_iterations is None:
            num_iterations = self.config.get('num_iterations', 500)

        log_frequency = self.config.get('log_frequency', 10)
        save_frequency = self.config.get('save_frequency', 100)

        logger.info(f"Starting training for {num_iterations} iterations")

        for iteration in tqdm(range(num_iterations)):
            step_results = self.train_step()

            # Log progress periodically
            if iteration % log_frequency == 0:
                logger.info(
                    f"Iteration {iteration} | "
                    f"Mean Reward: {step_results['mean_reward']:.4f} | "
                    f"Max Reward: {step_results['max_reward']:.4f} | "
                    f"Mean Energy: {step_results['mean_energy']:.4f} | "
                    f"Policy Loss: {step_results['policy_loss']:.4f}"
                )

            # Save checkpoints periodically (skipping iteration 0)
            if iteration % save_frequency == 0 and iteration > 0:
                self.save_checkpoint(f"checkpoint_iter_{iteration}.pt")

        logger.info("Training completed")
        self.save_checkpoint("final_checkpoint.pt")
        self.save_results("training_results.pkl")

    def save_checkpoint(self, filename):
        """Save policy (and critic, if any) weights plus optimizer state and
        config under ``config['checkpoint_dir']``."""
        checkpoint_dir = self.config.get('checkpoint_dir', './checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(checkpoint_dir, filename)

        checkpoint = {
            'policy_state_dict': self.policy_net.state_dict(),
            'policy_optimizer': self.policy_optimizer.state_dict(),
            'config': self.config,
            'iteration': len(self.results['rewards'])
        }

        if self.critic is not None:
            checkpoint['critic_state_dict'] = self.critic.state_dict()
            checkpoint['critic_optimizer'] = self.critic_optimizer.state_dict()

        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Saved checkpoint to {checkpoint_path}")

    def save_results(self, filename):
        """Pickle ``self.results`` into ``config['results_dir']``."""
        results_dir = self.config.get('results_dir', './results')
        os.makedirs(results_dir, exist_ok=True)

        results_path = os.path.join(results_dir, filename)

        with open(results_path, 'wb') as f:
            pickle.dump(self.results, f)

        logger.info(f"Saved results to {results_path}")

    def load_checkpoint(self, checkpoint_path):
        """Restore policy (and critic, if present) from a checkpoint written by
        :meth:`save_checkpoint`.

        Returns:
            The stored iteration count (0 if absent).
        """
        # torch.load unpickles the file: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.policy_net.load_state_dict(checkpoint['policy_state_dict'])
        self.policy_optimizer.load_state_dict(checkpoint['policy_optimizer'])

        if self.critic is not None and 'critic_state_dict' in checkpoint:
            self.critic.load_state_dict(checkpoint['critic_state_dict'])
            self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])

        logger.info(f"Loaded checkpoint from {checkpoint_path}")
        return checkpoint.get('iteration', 0)


class PolicyNetwork(nn.Module):
    """Gaussian policy network for RL-based latent space exploration.

    An MLP maps an input noise vector to the mean of a diagonal Gaussian over
    the latent space. The per-dimension log standard deviation is a learnable,
    input-independent parameter.

    Args:
        latent_dim: Size of both the input noise vector and the sampled latent.
        hidden_dims: Widths of the hidden layers.
        activation: 'relu', 'leaky_relu', 'tanh' or 'silu' (case-insensitive);
            unknown names fall back to ReLU.
    """

    def __init__(self, latent_dim, hidden_dims=(256, 256), activation='relu'):
        super().__init__()

        self.latent_dim = latent_dim

        # Map activation function string to actual function
        act_fn = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.2),
            'tanh': nn.Tanh(),
            'silu': nn.SiLU()
        }.get(activation.lower(), nn.ReLU())

        # Build hidden layers
        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(act_fn)
            input_dim = hidden_dim

        self.fc_layers = nn.Sequential(*layers)
        # Output layer for the Gaussian mean
        self.fc_mu = nn.Linear(input_dim, latent_dim)

        # Learnable log std for exploration (starts at std = 1)
        self.log_std = nn.Parameter(torch.zeros(latent_dim))

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, z_noise):
        """
        Sample latent vectors and their log-probabilities.

        Args:
            z_noise: Random noise tensor of shape [batch_size, latent_dim]

        Returns:
            z_sampled: Sampled latent vectors, [batch_size, latent_dim]
                (reparameterized, so still differentiable w.r.t. the policy).
            log_probs: Log-probabilities of the samples, [batch_size], with
                score-function gradients w.r.t. both the mean and log_std.
        """
        x = self.fc_layers(z_noise)
        mu = self.fc_mu(x)

        # Clamp log std for numerical stability
        std = torch.exp(self.log_std.clamp(-20, 2))
        dist = Normal(mu, std)

        z_sampled = dist.rsample()

        # REINFORCE / actor-critic need log pi(a) at a *fixed* action.
        # Evaluating it at the reparameterized sample makes the mean cancel
        # ((z - mu) == std * eps), so mu would receive exactly zero gradient.
        log_probs = dist.log_prob(z_sampled.detach()).sum(dim=-1)

        return z_sampled, log_probs


class CriticNetwork(nn.Module):
    """Value network for the actor-critic update.

    An MLP mapping a latent vector to a single scalar value estimate.

    Args:
        latent_dim: Input dimensionality.
        hidden_dims: Widths of the hidden layers (an immutable tuple by
            default, so the default cannot be shared and mutated).
        activation: 'relu', 'leaky_relu', 'tanh' or 'silu' (case-insensitive);
            unknown names fall back to ReLU.
    """

    def __init__(self, latent_dim, hidden_dims=(256, 128), activation='relu'):
        super().__init__()

        # Map activation function string to actual function
        act_fn = {
            'relu': nn.ReLU(),
            'leaky_relu': nn.LeakyReLU(0.2),
            'tanh': nn.Tanh(),
            'silu': nn.SiLU()
        }.get(activation.lower(), nn.ReLU())

        # Build hidden layers
        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(act_fn)
            input_dim = hidden_dim

        # Output layer - single value output
        layers.append(nn.Linear(input_dim, 1))

        self.model = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, z):
        """
        Estimate the value of each latent vector.

        Args:
            z: Latent vector of shape [batch_size, latent_dim]

        Returns:
            value: Tensor of shape [batch_size, 1]
        """
        return self.model(z)


class EnergyPredictor(nn.Module):
    """Surrogate model predicting formation energy from a latent vector.

    A ReLU MLP with a single scalar output.

    Args:
        latent_dim: Input dimensionality.
        hidden_dims: Widths of the hidden layers (an immutable tuple by
            default, so the default cannot be shared and mutated).
    """

    def __init__(self, latent_dim, hidden_dims=(128, 64)):
        super().__init__()

        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim

        # Output layer - single value for formation energy
        layers.append(nn.Linear(input_dim, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Predict formation energy; returns shape [batch_size, 1]."""
        return self.model(z)


class TopologicalPredictor(nn.Module):
    """
    Surrogate model to predict topological invariants (Z2, Chern number)
    from latent space.

    Args:
        latent_dim: Input dimensionality.
        hidden_dims: Widths of the ReLU hidden layers (an immutable tuple by
            default, so the default cannot be shared and mutated).
        num_invariants: Number of outputs. The default of 4 matches the four
            Z2 indices (nu0; nu1 nu2 nu3) of a 3D topological insulator.
    """

    def __init__(self, latent_dim, hidden_dims=(128, 64), num_invariants=4):
        super().__init__()

        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim

        self.feature_extractor = nn.Sequential(*layers)
        self.invariant_head = nn.Linear(input_dim, num_invariants)

    def forward(self, z):
        """Predict invariants; returns [batch_size, num_invariants] in (0, 1)."""
        features = self.feature_extractor(z)
        # Sigmoid keeps outputs in (0, 1); downstream they may be thresholded
        # to 0/1 when interpreted as invariants.
        return torch.sigmoid(self.invariant_head(features))


class ReplayBuffer:
    """Fixed-capacity experience replay buffer for more stable training.

    Experiences are stored as ``(z, reward, log_prob)`` tuples in a ring
    buffer. Once ``max_size`` entries exist, each new entry overwrites the
    oldest one.

    Args:
        max_size: Maximum number of stored experiences; must be >= 1.

    Raises:
        ValueError: If ``max_size`` is smaller than 1.
    """

    def __init__(self, max_size=10000):
        if max_size < 1:
            raise ValueError(f"max_size must be a positive integer, got {max_size}")
        self.max_size = max_size
        self.buffer = []
        # Slot that the next experience is written to.
        self.position = 0

    def add(self, z, reward, log_prob):
        """Add one experience, overwriting the oldest entry once full."""
        experience = (z, reward, log_prob)
        if len(self.buffer) < self.max_size:
            # While filling up, position always equals len(self.buffer).
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience
        self.position = (self.position + 1) % self.max_size

    def sample(self, batch_size):
        """Uniformly sample up to ``batch_size`` experiences without replacement.

        Returns:
            Tuple ``(z, rewards, log_probs)`` of numpy arrays with one entry
            per sampled experience.

        Raises:
            ValueError: If the buffer is empty or ``batch_size`` is below 1.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        batch = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        z, rewards, log_probs = map(np.array, zip(*batch))
        return z, rewards, log_probs

    def __len__(self):
        """Return the number of stored experiences."""
        return len(self.buffer)

In [2]:

# Example configuration
def get_default_config():
    """Return a fresh default hyper-parameter dict for CDVAE + RL training.

    A new dict (with new nested lists) is built on every call, so callers may
    mutate the result freely.
    """
    model_dims = dict(latent_dim=64, hidden_dim=128)

    # Chemical elements the generator may place
    chemistry = dict(elements=["Si", "Ge", "Sn", "Pb", "Bi", "Sb", "Te", "Se", "O"])

    training = dict(
        batch_size=32,
        num_iterations=1000,
        policy_lr=1e-4,
        critic_lr=3e-4,
        surrogate_lr=1e-4,
    )

    rl = dict(
        use_critic=True,      # actor-critic rather than plain REINFORCE
        clip_grad=True,
        max_grad_norm=1.0,
        buffer_size=5000,     # replay buffer capacity
    )

    reward_terms = dict(
        stability_threshold=0.1,
        target_band_gap=0.3,                # eV
        gap_tolerance=0.2,                  # eV, accepted deviation from target
        topo_weights=[2.0, 1.0, 1.0, 1.0],  # one weight per Z2 invariant
    )

    reward_weights = dict(w_stability=1.0, w_topological=2.0, w_gap=1.5)

    bookkeeping = dict(
        log_frequency=10,
        save_frequency=100,
        checkpoint_dir="./checkpoints",
        results_dir="./results",
    )

    # Merge in section order so key order matches the documented layout.
    return {
        **model_dims,
        **chemistry,
        **training,
        **rl,
        **reward_terms,
        **reward_weights,
        **bookkeeping,
    }

In [3]:
# Example usage
if __name__ == "__main__":
    # Plotting is only needed for this demo run, so import it lazily.
    import matplotlib.pyplot as plt

    # Set seeds for reproducibility (random is imported at the top of the file)
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Build the generator from the default configuration and train it
    config = get_default_config()
    ti_generator = CDVAE_TI_Generator(config)
    ti_generator.train(num_iterations=500)

    # Plot training curves on a 2x2 grid: (results key, title, y-axis label)
    panels = [
        ('rewards', 'Average Reward', 'Reward'),
        ('formation_energies', 'Average Formation Energy', 'Energy (eV)'),
        ('topological_indices', 'Average Topological Index', 'Index Value'),
        ('best_rewards', 'Best Reward', 'Reward'),
    ]
    plt.figure(figsize=(12, 8))
    for position, (key, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(ti_generator.results[key])
        plt.title(title)
        plt.xlabel('Iteration')
        plt.ylabel(ylabel)

    plt.tight_layout()
    plt.savefig('training_results.png')
    plt.show()

    # Generate a final batch of structures from the trained policy.
    # generate_structures must be called; unpacking the bound method itself
    # raises TypeError.
    structures, _, _ = ti_generator.generate_structures()

2025-04-05 12:01:09,502 - INFO - Using device: cpu


OSError: dlopen(/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch_scatter/_scatter_cpu.so, 0x0006): Symbol not found: __ZN2at4_ops16div__Tensor_mode4callERNS_6TensorERKS2_NSt3__18optionalIN3c1017basic_string_viewIcEEEE
  Referenced from: <4A3195B8-9E71-3AE7-AE80-DBA66ADAC535> /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch_scatter/_scatter_cpu.so
  Expected in:     <DA215AD3-6EAE-3755-B6A5-A8EB4EF952B0> /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/lib/libtorch_cpu.dylib

In [3]:
import torch
# Show which torch build this kernel actually loaded (useful when diagnosing
# binary-incompatibility errors from compiled extensions).
print(str(torch.__version__))


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/abiralshakya/Library/Python/3.12/lib/python/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/abiralshakya/Library/Python/3.12/lib/python/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/abiralshakya/Library/Python/3.12/lib/python/site-packages/ipykernel/kernelapp.py", line 739, in start
    s

2.2.0


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Normal
import logging
import os
import time
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm

# Configure root logging for this cell (INFO level, timestamped lines); if
# logging was already configured, basicConfig leaves it unchanged.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class SimplifiedDecoder(nn.Module):
    """
    MLP decoder that maps latent vectors directly to crystal-structure tensors.

    A lightweight stand-in for the GemNetT decoder: a shared three-layer ReLU
    trunk feeds four linear heads (fractional coordinates, atom types,
    atom-existence flags and lattice parameters).
    """
    def __init__(self, latent_dim, hidden_dim=128, max_atoms=32, n_elements=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.max_atoms = max_atoms
        self.n_elements = n_elements

        trunk_width = 2 * hidden_dim

        # Shared trunk. The layer order determines the state_dict keys
        # (mlp.0 / mlp.2 / mlp.4), so it must stay as-is for checkpoints.
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, trunk_width), nn.ReLU(),
            nn.Linear(trunk_width, trunk_width), nn.ReLU(),
        )

        # Output heads: 3 coordinates per atom, n_elements type logits per
        # atom, one existence logit per atom, and 6 lattice parameters
        # (a, b, c, alpha, beta, gamma).
        self.coords_head = nn.Linear(trunk_width, 3 * max_atoms)
        self.atom_types_head = nn.Linear(trunk_width, n_elements * max_atoms)
        self.exists_head = nn.Linear(trunk_width, max_atoms)
        self.lattice_head = nn.Linear(trunk_width, 6)

    def forward(self, z):
        """
        Decode a batch of latent vectors.

        Args:
            z: Latent vectors [batch_size, latent_dim].

        Returns:
            dict with
            - frac_coords: [batch_size, max_atoms, 3], squashed into (0, 1)
            - atom_types:  [batch_size, max_atoms, n_elements], softmax
              probabilities over element types
            - atom_mask:   [batch_size, max_atoms], existence probabilities
            - lattice:     [batch_size, 6]; lengths are |x| + 3 (so >= 3),
              angles lie in (60, 120) degrees
        """
        n = z.shape[0]
        features = self.mlp(z)

        coord_logits = self.coords_head(features).view(n, self.max_atoms, 3)
        type_logits = self.atom_types_head(features).view(n, self.max_atoms, self.n_elements)
        raw_lattice = self.lattice_head(features)

        lengths = 3.0 + raw_lattice[:, :3].abs()
        angles = 60 + 60 * torch.sigmoid(raw_lattice[:, 3:])

        return dict(
            frac_coords=torch.sigmoid(coord_logits),
            atom_types=F.softmax(type_logits, dim=-1),
            atom_mask=torch.sigmoid(self.exists_head(features)),
            lattice=torch.cat([lengths, angles], dim=-1),
        )


class TopologicalInsulatorGenerator:
    """
    Simplified framework for generating topological insulator materials
    using a VAE-style decoder with reinforcement learning for optimization.

    A Gaussian policy proposes latent vectors, surrogate networks predict
    formation energy, band gap and topological invariants from those latents,
    and the policy is trained (REINFORCE, or actor-critic when a critic is
    configured) to maximise a weighted reward built from those predictions.

    NOTE(review): the decoder and the three property predictors are created
    with random weights and no optimizer is set up for them here, so rewards
    are only meaningful once pretrained weights are loaded into them.
    """
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Initialize models
        self.initialize_models()

        # Setup optimizers
        self.setup_optimizers()

        # Results tracking: per-iteration metric lists, plus scalar
        # best_* entries written by _track_best.
        self.results = defaultdict(list)

    def initialize_models(self):
        """Build decoder, policy, property predictors and (optionally) critic on self.device."""
        # Get dimensions
        self.latent_dim = self.config['latent_dim']
        self.hidden_dim = self.config['hidden_dim']
        self.n_elements = len(self.config['elements'])
        self.max_atoms = self.config['max_atoms']

        # Initialize decoder
        self.decoder = SimplifiedDecoder(
            latent_dim=self.latent_dim,
            hidden_dim=self.hidden_dim,
            max_atoms=self.max_atoms,
            n_elements=self.n_elements
        ).to(self.device)

        # Initialize policy network
        self.policy_net = PolicyNetwork(
            latent_dim=self.latent_dim,
            hidden_dims=self.config['policy_hidden_dims']
        ).to(self.device)

        # Initialize surrogate property predictors
        self.formation_energy_net = PropertyPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=self.config['property_hidden_dims'],
            output_dim=1,
            name="Formation Energy"
        ).to(self.device)

        self.band_structure_net = PropertyPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=self.config['property_hidden_dims'],
            output_dim=self.config['band_structure_dim'],
            name="Band Structure"
        ).to(self.device)

        self.topological_net = PropertyPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=self.config['property_hidden_dims'],
            output_dim=self.config['topo_invariant_dim'],
            name="Topological Invariants"
        ).to(self.device)

        # Optional: initialize critic for actor-critic method
        if self.config.get('use_critic', True):
            self.critic = CriticNetwork(
                latent_dim=self.latent_dim,
                hidden_dims=self.config['critic_hidden_dims']
            ).to(self.device)
        else:
            self.critic = None

    def setup_optimizers(self):
        """Create Adam optimizers for the policy and, if present, the critic."""
        # Policy optimizer
        self.policy_optimizer = torch.optim.Adam(
            self.policy_net.parameters(),
            lr=self.config['policy_lr'],
            weight_decay=self.config.get('weight_decay', 1e-6)
        )

        # Optional critic optimizer
        if self.critic is not None:
            self.critic_optimizer = torch.optim.Adam(
                self.critic.parameters(),
                lr=self.config['critic_lr'],
                weight_decay=self.config.get('weight_decay', 1e-6)
            )

    def generate_structures(self, batch_size=None):
        """Sample latents from the policy and decode them into structures.

        Args:
            batch_size: Number of samples; defaults to config['batch_size'].

        Returns:
            (structures, z_sampled, log_probs): the decoder output dict, the
            sampled latents [batch_size, latent_dim] and their summed
            log-probabilities [batch_size]. The latter two remain attached to
            the policy's autograd graph.
        """
        if batch_size is None:
            batch_size = self.config['batch_size']

        # Sample from policy network
        z_noise = torch.randn(batch_size, self.latent_dim).to(self.device)
        z_sampled, log_probs = self.policy_net(z_noise)

        # The decoder is not optimised here, so do not build its graph.
        with torch.no_grad():
            structures = self.decoder(z_sampled)

        return structures, z_sampled, log_probs

    def predict_properties(self, z_vectors):
        """Predict surrogate material properties from latent vectors.

        Returns:
            dict with 'formation_energy' [batch], 'band_gap' [batch] and
            'topo_invariants' [batch, topo_invariant_dim] (sigmoid-squashed).
        """
        with torch.no_grad():
            # Predict formation energy (lower is better for stability)
            energy = self.formation_energy_net(z_vectors)

            # Predict band structure features
            band_features = self.band_structure_net(z_vectors)

            # Assumes the first band-structure output is the band gap.
            band_gap = band_features[:, 0:1]

            # Predict topological invariants (Z2, Chern numbers) in (0, 1)
            topo_invariants = torch.sigmoid(self.topological_net(z_vectors))

        # squeeze(-1) rather than squeeze(): keep the batch axis when the
        # batch size is 1 so downstream per-sample indexing still works.
        return {
            'formation_energy': energy.squeeze(-1),
            'band_gap': band_gap.squeeze(-1),
            'topo_invariants': topo_invariants
        }

    def calculate_rewards(self, properties):
        """Combine stability, band-gap and topology scores into per-sample rewards.

        Args:
            properties: dict as returned by predict_properties (tensors or arrays).

        Returns:
            dict of numpy arrays: 'total' plus the individual components
            'stability', 'band_gap' and 'topological'.
        """
        # Extract properties
        energy = properties['formation_energy']
        band_gap = properties['band_gap']
        topo_invariants = properties['topo_invariants']

        # Convert to numpy for easier manipulation
        if isinstance(energy, torch.Tensor):
            energy = energy.cpu().numpy()
        if isinstance(band_gap, torch.Tensor):
            band_gap = band_gap.cpu().numpy()
        if isinstance(topo_invariants, torch.Tensor):
            topo_invariants = topo_invariants.cpu().numpy()

        # 1. Stability reward: negated, clipped formation energy, zeroed when
        #    the energy is at or above the stability threshold.
        stability_threshold = self.config.get('stability_threshold', 0.2)
        stability_reward = -np.clip(energy, -1.0, 1.0) * (energy < stability_threshold)

        # 2. Band gap reward: 1 at the target gap, falling linearly to 0 at
        #    +/- gap_tolerance.
        target_gap = self.config.get('target_band_gap', 0.3)  # in eV
        gap_tolerance = self.config.get('gap_tolerance', 0.2)  # in eV
        gap_reward = 1.0 - np.minimum(np.abs(band_gap - target_gap) / gap_tolerance, 1.0)

        # 3. Topological reward - prefer non-trivial topological insulators
        # For Z2 invariants, we typically want (1;000) for strong 3D TIs
        # Assuming first invariant is the strong Z2 index (ν₀)
        strong_z2_idx = 0
        topo_reward = topo_invariants[:, strong_z2_idx]

        # Slightly penalise non-zero weak indices (ν₁ν₂ν₃).
        if topo_invariants.shape[1] > 1:
            weak_indices = np.mean(topo_invariants[:, 1:], axis=1)
            topo_reward = topo_reward * (1.0 - 0.2 * weak_indices)

        # Combine rewards with configurable weights
        w_stability = self.config.get('w_stability', 1.0)
        w_gap = self.config.get('w_gap', 1.5)
        w_topo = self.config.get('w_topo', 2.0)

        total_reward = (
            w_stability * stability_reward +
            w_gap * gap_reward +
            w_topo * topo_reward
        )

        return {
            'total': total_reward,
            'stability': stability_reward,
            'band_gap': gap_reward,
            'topological': topo_reward
        }

    def _clip_policy_grads(self):
        """Clip policy gradient norm in place when config['clip_grad'] is set."""
        if self.config.get('clip_grad', False):
            torch.nn.utils.clip_grad_norm_(
                self.policy_net.parameters(),
                self.config.get('max_grad_norm', 1.0)
            )

    def update_policy_reinforce(self, rewards, log_probs):
        """Update the policy with REINFORCE using batch-normalised rewards.

        Args:
            rewards: Per-sample rewards (array-like, length = batch size).
            log_probs: Log-probabilities of the sampled actions [batch].

        Returns:
            The scalar policy loss.
        """
        rewards_tensor = torch.as_tensor(rewards, dtype=log_probs.dtype, device=self.device)

        if rewards_tensor.numel() > 1:
            rewards_normalized = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8)
        else:
            # std() of a single sample is NaN; with no batch baseline the
            # centred reward is zero, so contribute no gradient.
            rewards_normalized = torch.zeros_like(rewards_tensor)

        # Calculate policy loss
        policy_loss = -(log_probs * rewards_normalized).mean()

        # Update policy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self._clip_policy_grads()
        self.policy_optimizer.step()

        return policy_loss.item()

    def update_policy_actor_critic(self, z_vectors, rewards, log_probs=None):
        """Update policy and critic with a one-step actor-critic rule.

        Args:
            z_vectors: Sampled latents [batch, latent_dim].
            rewards: Per-sample rewards (array-like).
            log_probs: Log-probabilities of ``z_vectors`` under the policy that
                produced them (as returned by generate_structures). If omitted,
                the legacy behaviour of re-running the policy on ``z_vectors``
                is used, which is only an approximation.

        Returns:
            (policy_loss, critic_loss) as floats, or (None, None) without a critic.
        """
        if self.critic is None:
            return None, None

        # The critic is a state-value baseline. Detach its input so the critic
        # loss neither trains the policy nor backpropagates through the policy
        # graph that policy_loss.backward() below frees.
        value_predictions = self.critic(z_vectors.detach()).squeeze(-1)
        rewards_tensor = torch.as_tensor(rewards, dtype=value_predictions.dtype, device=self.device)

        # Calculate advantages
        advantages = rewards_tensor - value_predictions.detach()

        if log_probs is None:
            logger.warning(
                "update_policy_actor_critic called without log_probs; "
                "re-sampling the policy on z_vectors (approximate)."
            )
            _, log_probs = self.policy_net(z_vectors.detach())

        # Calculate policy (actor) loss
        policy_loss = -(log_probs * advantages).mean()

        # Calculate value (critic) loss
        critic_loss = F.mse_loss(value_predictions, rewards_tensor)

        # Update policy network
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self._clip_policy_grads()
        self.policy_optimizer.step()

        # Update critic network
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return policy_loss.item(), critic_loss.item()

    def _track_best(self, best_reward, best_structure):
        """Store best_structure if best_reward beats every reward recorded so far."""
        # .get() does not trigger the defaultdict factory, so None means "unset".
        previous_best = self.results.get('best_reward')
        if previous_best is None or best_reward > previous_best:
            self.results['best_structure'] = best_structure
            self.results['best_reward'] = float(best_reward)
            self.results['best_iteration'] = len(self.results['rewards']) - 1

    def train_step(self):
        """Sample, score and update once; record metrics and the best structure."""
        # Generate structures
        structures, z_vectors, log_probs = self.generate_structures()

        # Predict properties
        properties = self.predict_properties(z_vectors)

        # Calculate rewards
        rewards_dict = self.calculate_rewards(properties)
        total_rewards = rewards_dict['total']

        # Update policy
        if self.critic is not None:
            actor_loss, critic_loss = self.update_policy_actor_critic(z_vectors, total_rewards, log_probs)
            loss_info = {'actor_loss': actor_loss, 'critic_loss': critic_loss}
        else:
            policy_loss = self.update_policy_reinforce(total_rewards, log_probs)
            loss_info = {'policy_loss': policy_loss}

        # Track best structures
        best_idx = np.argmax(total_rewards)
        best_reward = total_rewards[best_idx]
        best_structure = {k: v[best_idx].cpu().detach().numpy() if isinstance(v, torch.Tensor) else v[best_idx]
                          for k, v in structures.items()}

        # Update results
        self.results['rewards'].append(np.mean(total_rewards))
        self.results['max_rewards'].append(best_reward)
        self.results['formation_energy'].append(np.mean(properties['formation_energy'].cpu().numpy()))
        self.results['band_gap'].append(np.mean(properties['band_gap'].cpu().numpy()))
        self.results['topo_score'].append(np.mean(properties['topo_invariants'].cpu().numpy()[:, 0]))

        self._track_best(best_reward, best_structure)

        return {
            'mean_reward': np.mean(total_rewards),
            'max_reward': best_reward,
            'mean_energy': np.mean(properties['formation_energy'].cpu().numpy()),
            'loss_info': loss_info
        }

    def train(self, num_iterations=None):
        """Run train_step repeatedly, logging and checkpointing periodically.

        Args:
            num_iterations: Defaults to config['num_iterations'] (1000 if unset).

        Returns:
            The accumulated results dict.
        """
        if num_iterations is None:
            num_iterations = self.config.get('num_iterations', 1000)

        logger.info(f"Starting training for {num_iterations} iterations")

        # Training loop
        for iter_idx in tqdm(range(num_iterations)):
            # Perform training step
            step_info = self.train_step()

            # Log progress
            if iter_idx % self.config.get('log_freq', 10) == 0:
                loss_str = ""
                if 'policy_loss' in step_info['loss_info']:
                    loss_str = f"Policy Loss: {step_info['loss_info']['policy_loss']:.4f}"
                elif 'actor_loss' in step_info['loss_info']:
                    loss_str = (f"Actor Loss: {step_info['loss_info']['actor_loss']:.4f}, "
                                f"Critic Loss: {step_info['loss_info']['critic_loss']:.4f}")

                logger.info(
                    f"Iter {iter_idx}/{num_iterations} | "
                    f"Mean Reward: {step_info['mean_reward']:.4f} | "
                    f"Max Reward: {step_info['max_reward']:.4f} | "
                    f"Mean Energy: {step_info['mean_energy']:.4f} | "
                    f"{loss_str}"
                )

            # Save checkpoint
            if iter_idx > 0 and iter_idx % self.config.get('save_freq', 100) == 0:
                self.save_checkpoint(f"checkpoint_iter_{iter_idx}.pt")

        # Save final model
        logger.info("Training completed")
        self.save_checkpoint("final_model.pt")
        self.save_results()

        return self.results

    def save_checkpoint(self, filename):
        """Save policy, decoder and (if present) critic weights plus config.

        NOTE(review): optimizer states and the property predictors are not saved.
        """
        checkpoint_dir = self.config.get('checkpoint_dir', './checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            'policy_state_dict': self.policy_net.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'config': self.config,
            'iteration': len(self.results['rewards']),
        }

        if self.critic is not None:
            checkpoint['critic_state_dict'] = self.critic.state_dict()

        torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
        logger.info(f"Saved checkpoint to {os.path.join(checkpoint_dir, filename)}")

    def save_results(self):
        """Save the metrics dict (pickled inside an .npy file) and training plots."""
        results_dir = self.config.get('results_dir', './results')
        os.makedirs(results_dir, exist_ok=True)

        # Saved as a 0-d object array; load with np.load(..., allow_pickle=True).
        np.save(os.path.join(results_dir, 'training_metrics.npy'), dict(self.results))

        # Create and save plots
        self.plot_training_results(os.path.join(results_dir, 'training_plots.png'))

        logger.info(f"Saved results to {results_dir}")

    def plot_training_results(self, filename):
        """Plot per-iteration reward, energy, band gap and Z2 score to filename."""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Plot average reward
        axes[0, 0].plot(self.results['rewards'])
        axes[0, 0].set_title('Average Reward')
        axes[0, 0].set_xlabel('Iteration')
        axes[0, 0].set_ylabel('Reward')

        # Plot formation energy
        axes[0, 1].plot(self.results['formation_energy'])
        axes[0, 1].set_title('Average Formation Energy')
        axes[0, 1].set_xlabel('Iteration')
        axes[0, 1].set_ylabel('Energy (eV)')

        # Plot band gap
        axes[1, 0].plot(self.results['band_gap'])
        axes[1, 0].set_title('Average Band Gap')
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Band Gap (eV)')

        # Plot topological score
        axes[1, 1].plot(self.results['topo_score'])
        axes[1, 1].set_title('Average Topological Score')
        axes[1, 1].set_xlabel('Iteration')
        axes[1, 1].set_ylabel('Z2 Index (ν₀)')

        plt.tight_layout()
        plt.savefig(filename)
        plt.close()


class PolicyNetwork(nn.Module):
    """Diagonal-Gaussian policy for RL-based latent space exploration.

    An MLP maps an input noise vector to the mean of a Gaussian over the
    latent space; the log standard deviation is a learnable parameter shared
    across inputs. ``forward`` returns a sample and its log-probability.
    """

    def __init__(self, latent_dim, hidden_dims=(256, 256)):
        """
        Args:
            latent_dim: Dimensionality of both the input noise and the output latent.
            hidden_dims: Widths of the hidden ReLU layers. A tuple default is
                used to avoid a shared mutable default; any iterable works.
        """
        super().__init__()

        self.latent_dim = latent_dim

        # Build hidden layers
        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim

        self.fc_layers = nn.Sequential(*layers)
        # Output layer for the Gaussian mean
        self.fc_mu = nn.Linear(input_dim, latent_dim)

        # Learnable log std for exploration (starts at std = 1)
        self.log_std = nn.Parameter(torch.zeros(latent_dim))

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, z_noise):
        """Sample latents from the policy.

        Args:
            z_noise: Input noise [..., latent_dim].

        Returns:
            (z_sampled, log_probs): a reparameterized sample [..., latent_dim]
            and its log-probability summed over the last dimension [...].
        """
        x = self.fc_layers(z_noise)
        mu = self.fc_mu(x)

        # Clamp log std for numerical stability
        std = torch.exp(self.log_std.clamp(-20, 2))
        dist = Normal(mu, std)

        # rsample keeps a pathwise gradient available on z_sampled.
        z_sampled = dist.rsample()

        # Score-function (REINFORCE / actor-critic) gradients need
        # d log pi(a) / d theta with the action a held fixed. Evaluating
        # log_prob on the reparameterized sample itself makes the dependence
        # on mu cancel exactly (z - mu = eps * std), so the mean head would
        # get zero policy gradient. The returned value is unchanged.
        log_probs = dist.log_prob(z_sampled.detach()).sum(dim=-1)

        return z_sampled, log_probs


class CriticNetwork(nn.Module):
    """State-value critic V(z) for the actor-critic method."""

    def __init__(self, latent_dim, hidden_dims=(256, 128)):
        """
        Args:
            latent_dim: Size of the input latent vector.
            hidden_dims: Widths of the hidden ReLU layers. A tuple default is
                used to avoid a shared mutable default; any iterable works.
        """
        super().__init__()

        # Build hidden layers
        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim

        # Output layer - single value output
        layers.append(nn.Linear(input_dim, 1))

        self.model = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, z):
        """Return value estimates [..., 1] for latents z [..., latent_dim]."""
        return self.model(z)


class PropertyPredictor(nn.Module):
    """MLP that predicts a material property vector from a latent vector."""

    def __init__(self, latent_dim, hidden_dims=(128, 64), output_dim=1, name=None):
        """
        Args:
            latent_dim: Size of the input latent vector.
            hidden_dims: Widths of the hidden ReLU layers. A tuple default is
                used to avoid a shared mutable default; any iterable works.
            output_dim: Number of predicted values.
            name: Optional human-readable label, stored as ``self.name``.
        """
        super().__init__()

        self.name = name

        # Build hidden layers
        layers = []
        input_dim = latent_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim

        # Output layer (linear, unbounded)
        layers.append(nn.Linear(input_dim, output_dim))

        self.model = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal weights and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, z):
        """Return predictions [..., output_dim] for latents z [..., latent_dim]."""
        return self.model(z)


# Default configuration
def get_default_config():
    """Return a fresh dict of default hyperparameters for TopologicalInsulatorGenerator.

    A new dict (with new inner lists) is built on every call, so callers may
    mutate the result freely.
    """
    return dict(
        # Model dimensions
        latent_dim=32,
        hidden_dim=128,
        max_atoms=24,
        # Candidate elements, chosen as ones common in topological insulators
        elements=['Bi', 'Sb', 'Te', 'Se', 'Sn', 'Ge', 'Pb', 'O', 'S'],
        # Network architectures
        policy_hidden_dims=[256, 256],
        critic_hidden_dims=[256, 128],
        property_hidden_dims=[128, 64],
        # Property prediction dimensions
        band_structure_dim=5,   # band gap + other band-structure features
        topo_invariant_dim=4,   # Z2 invariants (ν₀;ν₁ν₂ν₃)
        # Training parameters
        batch_size=32,
        num_iterations=500,
        policy_lr=1e-4,
        critic_lr=3e-4,
        weight_decay=1e-6,
        use_critic=True,
        clip_grad=True,
        max_grad_norm=1.0,
        # Reward components
        stability_threshold=0.2,
        target_band_gap=0.3,    # target band gap in eV
        gap_tolerance=0.2,      # acceptable deviation from the target, eV
        # Reward weights
        w_stability=1.0,
        w_gap=1.5,
        w_topo=2.0,
        # Logging and checkpoints
        log_freq=10,
        save_freq=100,
        checkpoint_dir='./checkpoints',
        results_dir='./results',
    )


# Example usage
if __name__ == "__main__":
    import random

    # Seed every RNG the pipeline touches so runs are reproducible.
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Get configuration
    config = get_default_config()

    # Create generator
    ti_generator = TopologicalInsulatorGenerator(config)

    # Train model. The iteration count comes from config['num_iterations']
    # instead of a second hard-coded value that could drift out of sync.
    results = ti_generator.train()

    # Final evaluation is inference only, so no autograd graph is needed.
    with torch.no_grad():
        final_structures, z_vectors, _ = ti_generator.generate_structures(batch_size=10)
        properties = ti_generator.predict_properties(z_vectors)
    rewards = ti_generator.calculate_rewards(properties)

    # Print best structure information
    best_idx = np.argmax(rewards['total'])
    print("\nBest Generated Structure:")
    print(f"- Formation Energy: {properties['formation_energy'][best_idx].item():.4f} eV")
    print(f"- Band Gap: {properties['band_gap'][best_idx].item():.4f} eV")
    print(f"- Z2 Invariants: {properties['topo_invariants'][best_idx].cpu().numpy()}")
    print(f"- Total Reward: {rewards['total'][best_idx]:.4f}")

2025-04-04 17:19:32,548 - INFO - Using device: cpu


ModuleNotFoundError: No module named 'torch._C._dynamo.guards'; 'torch._C._dynamo' is not a package

In [4]:
import torch
# Print the version string of the torch build this kernel imported.
# (The original line was missing its closing parenthesis: a SyntaxError.)
print(torch.__version__)


2.6.0


<function torch._VariableFunctionsClass.fake_quantize_per_tensor_affine>