In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import Normal, kl_divergence
import logging
import time
import tqdm
import pickle
import os
import random

import sys


# Location of the local CDVAE checkout. Set the CDVAE_ROOT environment variable
# to point somewhere else; the fallback is the path this notebook was written with.
CDVAE_ROOT = os.environ.get(
    'CDVAE_ROOT',
    '/Users/abiralshakya/Documents/Research/Topological_Insulators_OnGithub/generative_nmti/cdvae',
)
# Re-running the cell must not stack duplicate entries on sys.path.
if CDVAE_ROOT not in sys.path:
    sys.path.insert(0, CDVAE_ROOT)
import cdvae

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class CDVAE_TI_Generator:
    """
    Crystal Diffusion Variational Autoencoder (CDVAE) with Reinforcement Learning
    for targeted generation of Topological Insulator materials.
    """
    def __init__(self, config):
        """Build the networks, optimizers and bookkeeping containers from ``config``.

        Args:
            config: Dict-like settings. ``buffer_size`` is optional here and
                falls back to 10000 entries (the default ``ReplayBuffer``
                declares), matching how every other optional setting in this
                class is read with ``.get``.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = torch.float32  # dtype used for the tensors this class creates
        logger.info(f"Using device: {self.device}")

        # The networks must exist before the optimizers that hold their parameters.
        self.initialize_models()
        self.setup_optimizers()

        # Per-iteration history plus the running list of record-setting structures.
        self.results = {
            'rewards': [],
            'z_gap': [],
            'topological_indices': [],
            'formation_energies': [],
            'best_structures': [],
            'best_rewards': [],
        }

        # Experience replay storage; previously a missing key raised KeyError.
        self.replay_buffer = ReplayBuffer(config.get('buffer_size', 10000))
        
    def initialize_models(self):
        """Create the decoder, policy, critic and surrogate networks on ``self.device``.

        Sets ``latent_dim``, ``n_elements``, ``encoder`` (always ``None``),
        ``decoder``, ``policy_net``, ``critic`` (``None`` when ``use_critic`` is
        falsy), ``energy_predictor`` and ``topological_predictor``. Checkpoints
        named by ``decoder_checkpoint`` / ``surrogate_checkpoint`` are loaded
        when those config entries are truthy.

        Raises:
            ImportError: Re-raised after logging when the CDVAE modules cannot
                be imported.
        """
        # Imported locally so a broken CDVAE install is reported with a clear log line.
        try:
            from cdvae.pl_modules.decoder import GemNetTDecoder
            from cdvae.common.data_utils import StandardScalerTorch
            from cdvae.common.data_utils import ATOM_TYPES
        except ImportError:
            logger.error("Failed to import CDVAE modules. Please ensure CDVAE is installed correctly.")
            raise

        cfg = self.config
        target = self.device

        self.latent_dim = cfg['latent_dim']
        # An explicit element list in the config takes precedence over ATOM_TYPES.
        if 'elements' in cfg:
            self.n_elements = len(cfg['elements'])
        else:
            self.n_elements = len(ATOM_TYPES.get())

        # TODO: possibly write an encoder file addition to cdvae. Until then no
        # encoder is built here.
        self.encoder = None

        # Decoder (the `cutoff` and `use_layer_norm` options are intentionally not passed).
        self.decoder = GemNetTDecoder(
            latent_dim=self.latent_dim,
            hidden_dim=cfg['hidden_dim'],
            max_neighbors=cfg.get('max_neighbors', 20),
        ).to(target)

        decoder_ckpt = cfg.get('decoder_checkpoint')
        if decoder_ckpt:
            self._load_model(self.decoder, cfg['decoder_checkpoint'])

        # Actor: the policy network used for RL.
        self.policy_net = PolicyNetwork(
            latent_dim=self.latent_dim,
            hidden_dims=cfg.get('policy_hidden_dims', [256, 256]),
            activation=cfg.get('policy_activation', 'relu'),
        ).to(target)

        # Critic: only built for actor-critic training.
        if not cfg.get('use_critic', True):
            self.critic = None
        else:
            self.critic = CriticNetwork(
                latent_dim=self.latent_dim,
                hidden_dims=cfg.get('critic_hidden_dims', [256, 128]),
                activation=cfg.get('critic_activation', 'relu'),
            ).to(target)

        # Surrogates that score a latent vector directly.
        self.energy_predictor = EnergyPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=cfg.get('energy_predictor_dims', [128, 64]),
        ).to(target)

        self.topological_predictor = TopologicalPredictor(
            latent_dim=self.latent_dim,
            hidden_dims=cfg.get('topo_predictor_dims', [128, 64]),
        ).to(target)

        surrogate_ckpt = cfg.get('surrogate_checkpoint')
        if surrogate_ckpt:
            self._load_surrogate_models(cfg['surrogate_checkpoint'])
        
    def setup_optimizers(self):
        """Set up optimizers for different components."""
        # Policy optimizer
        self.policy_optimizer = torch.optim.Adam(
            self.policy_net.parameters(),
            lr=self.config.get('policy_lr', 1e-4),
            weight_decay=self.config.get('policy_weight_decay', 1e-6)
        )
        
        # Critic optimizer (if using actor-critic)
        if self.critic is not None:
            self.critic_optimizer = torch.optim.Adam(
                self.critic.parameters(),
                lr=self.config.get('critic_lr', 3e-4),
                weight_decay=self.config.get('critic_weight_decay', 1e-6)
            )
        
        # Surrogate model optimizers for fine-tuning
        if self.config.get('train_surrogates', False):
            self.energy_optimizer = torch.optim.Adam(
                self.energy_predictor.parameters(),
                lr=self.config.get('surrogate_lr', 1e-4)
            )
            
            self.topo_optimizer = torch.optim.Adam(
                self.topological_predictor.parameters(),
                lr=self.config.get('surrogate_lr', 1e-4)
            )
            
    def _load_model(self, model, checkpoint_path):
        """Load model weights from checkpoint."""
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            if 'state_dict' in checkpoint:
                # Handle pytorch-lightning checkpoints
                state_dict = {k.replace('model.', ''): v for k, v in checkpoint['state_dict'].items() 
                              if k.startswith('model.')}
                model.load_state_dict(state_dict, strict=False)
            else:
                # Handle regular torch checkpoints
                model.load_state_dict(checkpoint, strict=False)
            logger.info(f"Loaded weights from {checkpoint_path}")
        except Exception as e:
            logger.error(f"Failed to load weights: {e}")
            
    def _load_surrogate_models(self, checkpoint_path):
        """Load surrogate model weights."""
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.energy_predictor.load_state_dict(checkpoint['energy_predictor'])
            self.topological_predictor.load_state_dict(checkpoint['topo_predictor'])
            logger.info(f"Loaded surrogate models from {checkpoint_path}")
        except Exception as e:
            logger.error(f"Failed to load surrogate models: {e}")

    
    def generate_structures(self, batch_size=None):
        """Generate crystal structures using the policy network and decoder."""
        if batch_size is None:
            batch_size = self.config.get('batch_size', 32)
        
        # Sample latent vectors from the policy network
        z_noise = torch.randn(batch_size, self.latent_dim, device=self.device, dtype=self.dtype)
        z_sampled, log_probs = self.policy_net(z_noise)
        
        # For testing, let's use a smaller batch size and fewer atoms per crystal
        
        batch_size = z_sampled.shape[0]
        max_atoms  = self.config.get('max_atoms', 10)

        # Start with zeros (or any padding value)
        pred_frac_coords = torch.zeros(batch_size, max_atoms, 3,
                                    device=self.device, dtype=self.dtype)
        pred_atom_types  = torch.zeros(batch_size, max_atoms,
                                    device=self.device, dtype=torch.long)

        
        # Create a batch where each structure has a different number of atoms
        num_atoms = torch.randint(2, max_atoms+1, (batch_size,), device=self.device)
        
        # Lists to store fractional coordinates and atom types for each structure
        frac_coords_list = []
        atom_types_list = []

        for i, n_i in enumerate(num_atoms):
            # Random fractional coordinates for this structure
            frac_coords = torch.rand(n_i, 3, device=self.device, dtype=self.dtype)
            frac_coords_list.append(frac_coords)

            # Random atom types for this structure
            atom_types = torch.randint(0, self.n_elements, (n_i,), device=self.device)
            atom_types_list.append(atom_types)

        # Concatenate the lists to create the final tensors
        frac_coords = torch.cat(frac_coords_list, dim=0)
        atom_types = torch.cat(atom_types_list, dim=0)
        
        # Random unit cell parameters
        lengths = torch.rand(batch_size, 3, device=self.device, dtype=self.dtype) * 5 + 5  # Between 5-10 Å
        angles = torch.rand(batch_size, 3, device=self.device, dtype=self.dtype) * 30 + 90  # Between 90-120°

        for i, n_i in enumerate(num_atoms):
            # random coords for this sample
            pred_frac_coords[i, :n_i] = torch.rand(n_i, 3,
                                                device=self.device,
                                                dtype=self.dtype)
            # random atom types
            pred_atom_types[i, :n_i] = torch.randint(
                0, self.n_elements, (n_i,),
                device=self.device
            )
    
        
        # Generate structures using the decoder
        with torch.no_grad():
            try:
                pred_cart_coord_diff, pred_atom_types = self.decoder(
                    z_sampled,
                    frac_coords,
                    atom_types,
                    num_atoms,
                    lengths,
                    angles
                )
                
                # Combine the results
                generated_structures = {
                    'frac_coords': frac_coords,
                    'atom_types': atom_types,
                    'num_atoms': num_atoms,
                    'lengths': lengths,
                    'angles': angles,
                    'pred_cart_coord_diff': pred_cart_coord_diff,
                    'pred_atom_types': pred_atom_types
                }
                
            except Exception as e:
                print(f"Error in decoder: {e}")
                generated_structures = {
                    'frac_coords': frac_coords,
                    'atom_types': atom_types,
                    'num_atoms': num_atoms,
                    'lengths': lengths,
                    'angles': angles
                }
    
        return generated_structures, z_sampled, log_probs

    def evaluate_structures(self, structures, z_vectors):
        """Evaluate generated structures using surrogate models."""
        # Predict formation energies
        with torch.no_grad():
            energies = self.energy_predictor(z_vectors)
            
            # Predict topological indices (Z2 invariants, Chern numbers, etc.)
            topo_indices = self.topological_predictor(z_vectors)
            
            # Calculate band gaps (can be part of the topological predictor or separate)
            band_gaps = self.estimate_band_gap(structures, z_vectors)
            
        # Combine predictions into a comprehensive evaluation
        evaluations = {
            'formation_energies': energies.cpu().numpy(),
            'topological_indices': topo_indices.cpu().numpy(),
            'band_gaps': band_gaps.cpu().numpy() if isinstance(band_gaps, torch.Tensor) else band_gaps
        }
        
        return evaluations

    def calculate_rewards(self, evaluations):
        """Calculate rewards based on desired material properties."""
        # Extract evaluations
        energies = evaluations['formation_energies']
        topo_indices = evaluations['topological_indices']
        band_gaps = evaluations['band_gaps']
        
        # Convert to numpy for easier manipulation
        if isinstance(energies, torch.Tensor):
            energies = energies.cpu().numpy().squeeze()
        if isinstance(topo_indices, torch.Tensor):
            topo_indices = topo_indices.cpu().numpy()
        if isinstance(band_gaps, torch.Tensor):
            band_gaps = band_gaps.cpu().numpy().squeeze()
            
        # Calculate stability reward component
        # Lower formation energy is better, but must be below threshold to be stable
        stability_threshold = self.config.get('stability_threshold', 0.1)
        stability_rewards = -energies * (energies < stability_threshold)
        
        # Calculate topological reward component
        # For Z2 invariants, we typically want (1;000) for 3D TIs
        # This is a simplified example - actual implementation depends on how topo_indices are represented
        topo_rewards = np.sum(topo_indices * self.config.get('topo_weights', [2.0, 1.0, 1.0, 1.0]), axis=1)
        
        # Calculate band gap reward component
        # Usually want a moderate band gap (not too small, not too large)
        target_gap = self.config.get('target_band_gap', 0.3)  # in eV
        gap_tolerance = self.config.get('gap_tolerance', 0.2)  # in eV
        gap_rewards = 1.0 - np.minimum(np.abs(band_gaps - target_gap) / gap_tolerance, 1.0)
        
        # Combine reward components with configurable weights
        w_stability = self.config.get('w_stability', 1.0)
        w_topological = self.config.get('w_topological', 2.0)
        w_gap = self.config.get('w_gap', 1.5)
        
        combined_rewards = (w_stability * stability_rewards + 
                           w_topological * topo_rewards +
                           w_gap * gap_rewards)
        
        # Create rewards dictionary
        rewards_dict = {
            'total': combined_rewards,
            'stability': stability_rewards,
            'topological': topo_rewards,
            'band_gap': gap_rewards
        }
        
        return rewards_dict

    def estimate_band_gap(self, structures, z_vectors):
        """Estimate band gaps of structures using a surrogate model."""
        # This would typically be a separate model or part of topological_predictor
        # For simplicity, we'll use a mock implementation
        batch_size = z_vectors.shape[0]
        
        # Mock band gap estimation (replace with actual model)
        # In practice, this would use a trained neural network or other predictor
        gaps = 0.2 + 0.3 * torch.sigmoid(z_vectors[:, 0]) + 0.1 * torch.randn(batch_size).to(self.device)
        
        return gaps

    def reinforce_update(self, rewards, log_probs):
        """Update policy network using REINFORCE algorithm."""
        # Convert to tensor with the right dtype
        rewards_tensor = torch.tensor(rewards, device=self.device, dtype=self.dtype)
        
        # Normalize rewards
        rewards_normalized = (rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8)
        
        # Calculate policy loss
        policy_loss = -(log_probs * rewards_normalized).mean()
        
        # Update policy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        
        # Optional gradient clipping
        if self.config.get('clip_grad', False):
            torch.nn.utils.clip_grad_norm_(
                self.policy_net.parameters(), 
                self.config.get('max_grad_norm', 1.0)
            )
            
        self.policy_optimizer.step()
        
        return policy_loss.item()

    def actor_critic_update(self, z_vectors, rewards, log_probs):
        """Update policy and critic networks using Actor-Critic algorithm."""
        if self.critic is None:
            return self.reinforce_update(rewards, log_probs)
            
        # Convert rewards to tensor with proper dtype
        rewards_tensor = torch.tensor(rewards, device=self.device, dtype=self.dtype)
        
        # Get critic's value predictions
        value_predictions = self.critic(z_vectors).squeeze()
        
        # Calculate advantages
        advantages = rewards_tensor - value_predictions.detach()
        
        # Calculate policy (actor) loss
        policy_loss = -(log_probs * advantages).mean()
        
        # Calculate value (critic) loss
        critic_loss = F.mse_loss(value_predictions, rewards_tensor)
        
        # Option 1: Combine losses and do a single backward pass
        total_loss = policy_loss + critic_loss
        
        # Zero all gradients
        self.policy_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        
        # Single backward pass
        total_loss.backward()
        
        # Apply gradient clipping if needed
        if self.config.get('clip_grad', False):
            torch.nn.utils.clip_grad_norm_(
                self.policy_net.parameters(), 
                self.config.get('max_grad_norm', 1.0)
            )
            torch.nn.utils.clip_grad_norm_(
                self.critic.parameters(), 
                self.config.get('max_grad_norm', 1.0)
            )
        
        # Update both networks
        self.policy_optimizer.step()
        self.critic_optimizer.step()
        
        return policy_loss.item(), critic_loss.item()
    
  
    def train_step(self):
        """Perform a single training step."""
        # Generate structures
        structures, z_vectors, log_probs = self.generate_structures()
        print(len(structures))

        if not structures: 
            logger.error("no strucutres generated in this step")
            return  {'mean_reward': 0, 'max_reward': 0, 'mean_energy': 0, 'policy_loss': 0}
        
        # Evaluate structures
        evaluations = self.evaluate_structures(structures, z_vectors)
        
        # Calculate rewards
        rewards_dict = self.calculate_rewards(evaluations)
        total_rewards = rewards_dict['total']
        
        # Store experience in replay buffer
        for i in range(len(total_rewards)):
            self.replay_buffer.add(
                z_vectors[i].detach().cpu().numpy(),
                total_rewards[i],
                log_probs[i].detach().cpu().numpy()
            )
        
        # Update policy using actor-critic or REINFORCE
        if self.critic is not None:
            loss_info = self.actor_critic_update(z_vectors, total_rewards, log_probs)
            policy_loss = loss_info[0]
        else:
            policy_loss = self.reinforce_update(total_rewards, log_probs)
            
        # Track best structures
        best_idx = np.argmax(total_rewards)

        if best_idx >= len(total_rewards):
            logger.warning(f"best_idx {best_idx} out of bounds for total_rewards with length {len(total_rewards)}")
            best_idx = 0  # Fallback to first item

        #print(f"Type of total_rewards: {type(total_rewards)}")
        #print(f"Shape of total_rewards: {total_rewards.shape if hasattr(total_rewards, 'shape') else 'no shape attribute'}")
        #print(f"Type of total_rewards[best_idx]: {type(total_rewards[best_idx])}")
            
        # Fix: Access the scalar value directly without trying to index further
        #best_reward = float(total_rewards[best_idx])  # Remove the [0] indexing
        reward_value = total_rewards[best_idx]
        if isinstance(reward_value, np.ndarray):
            if reward_value.size == 1:
                best_reward = float(reward_value.item())
            else:
                # If it's an array with multiple values, take the first one
                best_reward = float(reward_value[0])
        else:
            # If it's already a scalar type (int, float)
            best_reward = float(reward_value)

        # Determine if this iteration’s best is a new overall best
        if len(self.results['best_rewards']) == 0:
            is_new_best = True
        else:
            prev_best = max(self.results['best_rewards'])
            # force a Python bool
            is_new_best = bool(best_reward > prev_best)

        if is_new_best:
            best_structure = {
                'frac_coords': structures['frac_coords'][best_idx].tolist(),
                'atom_types': structures['atom_types'][best_idx].tolist(),
                'lengths': structures['lengths'][best_idx].tolist(),
                'angles': structures['angles'][best_idx].tolist()
            }
            self.results['best_structures'].append(best_structure)
            self.results['best_rewards'].append(best_reward)
            logger.info(f"New best structure found with reward: {best_reward:.4f}")

        # Log results
        self.results['rewards'].append(total_rewards)
        self.results['z_gap'].append(evaluations['band_gaps'])
        self.results['topological_indices'].append(evaluations['topological_indices'])
        self.results['formation_energies'].append(evaluations['formation_energies'])
    # Surrogate model to predict topological invariants (Z2, Chern number) 
    # from latent space.
    

    def __init__(self, latent_dim, hidden_dims=[128, 64], num_invariants=4):
        super().__init__()
        
        # Build network layers
        layers = []
        input_dim = latent_dim
        
        # Build hidden layers
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
            
        # Output layer - multiple values for topological invariants
        # For 3D topological insulators, typically 4 Z2 invariants (ν₀;ν₁ν₂ν₃)
        self.feature_extractor = nn.Sequential(*layers)
        self.invariant_head = nn.Linear(input_dim, num_invariants)
        
    def forward(self, z):
        """Predict topological invariants from latent vector."""
        features = self.feature_extractor(z)
        # Apply sigmoid to constrain outputs between 0 and 1
        # In practice, these would be discretized to 0 or 1 when interpreting
        invariants = torch.sigmoid(self.invariant_head(features))
        return invariants


class ReplayBuffer:
    """Fixed-capacity ring buffer of ``(z, reward, log_prob)`` experiences.

    Once ``max_size`` entries are stored, each new entry overwrites the oldest one.
    """

    def __init__(self, max_size=10000):
        """Create an empty buffer.

        Raises:
            ValueError: If ``max_size`` is not positive (such a buffer could
                never store an entry).
        """
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")
        self.max_size = max_size
        self.buffer = []
        self.position = 0  # slot the next add() writes to

    def add(self, z, reward, log_prob):
        """Store one experience, overwriting the oldest entry when full."""
        if len(self.buffer) < self.max_size:
            # Still growing: open a new slot at the write position.
            self.buffer.append(None)
        self.buffer[self.position] = (z, reward, log_prob)
        self.position = (self.position + 1) % self.max_size

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct experiences uniformly at random.

        Returns:
            ``(z, rewards, log_probs)`` as numpy arrays. Three empty arrays are
            returned when nothing can be sampled (empty buffer or a
            non-positive ``batch_size``).
        """
        count = min(batch_size, len(self.buffer))
        if count <= 0:
            # zip(*[]) yields nothing to unpack, so handle the empty case explicitly.
            return np.array([]), np.array([]), np.array([])
        batch = random.sample(self.buffer, count)
        z, rewards, log_probs = map(np.array, zip(*batch))
        return z, rewards, log_probs

    def __len__(self):
        """Return current buffer size."""
        return len(self.buffer)

In [18]:
import cdvae

# Example configuration
def get_default_config():
    """Return a fresh dict holding the default CDVAE + RL training settings."""
    model_dims = {
        "latent_dim": 64,
        "hidden_dim": 128,
    }
    chemistry = {
        # Elements to consider
        "elements": ["Si", "Ge", "Sn", "Pb", "Bi", "Sb", "Te", "Se", "O"],
    }
    training = {
        "batch_size": 32,
        "num_iterations": 1000,
        "policy_lr": 1e-4,
        "critic_lr": 3e-4,
        "surrogate_lr": 1e-4,
    }
    reinforcement = {
        "use_critic": True,      # actor-critic instead of REINFORCE
        "clip_grad": True,
        "max_grad_norm": 1.0,
        "buffer_size": 5000,     # replay buffer capacity
    }
    reward_shape = {
        "stability_threshold": 0.1,
        "target_band_gap": 0.3,                # eV
        "gap_tolerance": 0.2,                  # acceptable deviation from the target
        "topo_weights": [2.0, 1.0, 1.0, 1.0],  # weights for the Z2 invariants
    }
    reward_weights = {
        "w_stability": 1.0,
        "w_topological": 2.0,
        "w_gap": 1.5,
    }
    bookkeeping = {
        "log_frequency": 10,
        "save_frequency": 100,
        "checkpoint_dir": "./checkpoints",
        "results_dir": "./results",
    }
    # Merge the sections in their original order so key order is preserved.
    return {
        **model_dims,
        **chemistry,
        **training,
        **reinforcement,
        **reward_shape,
        **reward_weights,
        **bookkeeping,
    }

In [19]:
# Example usage
if __name__ == "__main__":
    import random
    import matplotlib.pyplot as plt

    # Seed every RNG the pipeline draws from so runs are repeatable.
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # torch.set_default_tensor_type is deprecated; set_default_dtype is the
    # supported way to make float32 the default.
    torch.set_default_dtype(torch.float32)

    # Get default configuration
    config = get_default_config()

    # Create training framework
    ti_generator = CDVAE_TI_Generator(config)

    # Train the model
    ti_generator.train(num_iterations=500)

    def batch_means(history):
        """Reduce each iteration's per-sample array to its mean over the batch axis."""
        return np.array([
            np.mean(np.atleast_1d(np.asarray(step, dtype=float)), axis=0)
            for step in history
        ])

    results = ti_generator.results

    # Each history entry is a whole batch, so average it before plotting;
    # plotting the raw lists would draw one line per sample (or fail for the
    # 3-D topological-index history).
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    ax = axes[0, 0]
    ax.plot(batch_means(results['rewards']))
    ax.set_title('Average Reward')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Reward')

    ax = axes[0, 1]
    ax.plot(batch_means(results['formation_energies']))
    ax.set_title('Average Formation Energy')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Energy (eV)')

    # One line per invariant.
    ax = axes[1, 0]
    ax.plot(batch_means(results['topological_indices']))
    ax.set_title('Average Topological Index')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Index Value')

    # best_rewards only grows when a new record is set, so its x-axis counts
    # improvements rather than training iterations.
    ax = axes[1, 1]
    ax.plot(results['best_rewards'])
    ax.set_title('Best Reward')
    ax.set_xlabel('Improvement #')
    ax.set_ylabel('Reward')

    fig.tight_layout()
    fig.savefig('training_results.png')
    plt.show()

    # Generate some final structures
    structures, _, _ = ti_generator.generate_structures()

TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got dict"