In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import generative_nmti.cdvae as cdvae
from generative_nmti.cdvae.pl_modules.decoder import GemNetTDecoder

ModuleNotFoundError: No module named 'generative_nmti'

In [10]:
# Report the installed versions of the torch_scatter / torch_sparse extensions.
import torch_scatter
import torch_sparse

for _extension in (torch_scatter, torch_sparse):
    print(_extension.__version__)

2.1.2
0.6.18


In [None]:
class LatentPolicyNet(nn.Module):
    """Diagonal-Gaussian policy over a latent space, for REINFORCE training.

    An MLP maps an input noise vector to the mean of a Normal distribution;
    the standard deviation is a learnable parameter shared across inputs
    (one value per latent dimension).

    Args:
        latent_dim: Size of both the input noise vector and the sampled latent.
        hidden_dim: Width of the MLP's hidden layer (default 256).
    """

    def __init__(self, latent_dim, hidden_dim=256):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        # Learnable log std, one per latent dim; exp() in forward keeps std > 0.
        self.log_std = nn.Parameter(torch.zeros(latent_dim))

    def forward(self, z_noise):
        """Sample latents from the policy and return their log-probabilities.

        Args:
            z_noise: Tensor whose last dimension is ``latent_dim``.

        Returns:
            A tuple ``(z_sampled, log_prob)``:
            - ``z_sampled`` has the same shape as ``z_noise`` and carries no gradient.
            - ``log_prob`` is the log-density summed over the last dimension.
        """
        mu = self.fc(z_noise)
        std = torch.exp(self.log_std)
        dist = torch.distributions.Normal(mu, std)
        # REINFORCE needs the score-function gradient of log pi(z) with z held
        # fixed. With rsample() the sample itself depends on mu, and that pathwise
        # term exactly cancels d log_prob / d mu, so the MLP would get zero
        # gradient. sample() detaches z and leaves the correct gradient.
        z_sampled = dist.sample()
        return z_sampled, dist.log_prob(z_sampled).sum(dim=-1)

In [None]:
def reinforce_update(policy_net, optimizer, rewards, log_probs):
    """Apply one REINFORCE policy-gradient step.

    Rewards are standardized within the batch (zero mean, unit std) and used
    as advantages that weight the per-sample log-probabilities.

    Args:
        policy_net: The policy being trained. It is not used directly: the
            gradient reaches its parameters through ``log_probs``. Kept for API
            compatibility.
        optimizer: Optimizer over the policy's parameters.
        rewards: Sequence (or 1-D tensor) of per-sample scalar rewards,
            aligned element-wise with ``log_probs``.
        log_probs: 1-D tensor of per-sample log-probabilities that requires grad.

    Returns:
        The scalar policy loss as a Python float.
    """
    # Rewards are constants in REINFORCE: detach them and match log_probs'
    # dtype/device (this also accepts int rewards, which mean() would reject).
    advantages = torch.as_tensor(
        rewards, dtype=log_probs.dtype, device=log_probs.device
    ).detach()
    advantages = advantages - advantages.mean()
    # The unbiased std() of a single sample is NaN, which would poison the loss
    # and the parameters; only rescale when the std is defined.
    if advantages.numel() > 1:
        advantages = advantages / (advantages.std() + 1e-8)
    loss = -(log_probs * advantages).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
def estimate_formation_energy(structure):
    """Return a stand-in formation-energy value for ``structure``.

    Placeholder: ``structure`` is ignored. One standard-normal draw is taken
    from torch's global RNG and returned as a Python float.
    TODO: replace with the actual implementation.
    """
    noise = torch.randn(1)
    return float(noise[0])

def predict_magnetic_ordering(structure):
    """Return a stand-in magnetic-ordering score in (0, 1) for ``structure``.

    Placeholder: ``structure`` is ignored. One standard-normal draw from
    torch's global RNG is passed through a sigmoid and returned as a float.
    TODO: replace with the actual implementation.
    """
    logit = torch.randn(1)
    probability = logit.sigmoid()
    return float(probability[0])

# The GemNet-T decoder instance is kept under a private name, so the `decoder`
# function below does not shadow it and recurse into itself. It is built
# lazily on first use, so this cell does not need `latent_dim` (defined only
# in the hyperparameter cell further down).
_gemnet_decoder = None


def decoder(z):
    """Decode a batch of latent vectors into one structure per latent.

    Args:
        z: Latent tensor whose first dimension indexes samples. Each row is
            decoded separately, so ``structures[i]`` corresponds to ``z[i]``;
            REINFORCE pairs ``rewards[i]`` with ``log_probs[i]``.

    Returns:
        A list of ``z.shape[0]`` decoder outputs, one per latent row.
    """
    global _gemnet_decoder
    if _gemnet_decoder is None:
        # NOTE(review): constructor kwargs are carried over from the original
        # call; confirm they match GemNetTDecoder's signature in generative_nmti,
        # and load trained weights -- a freshly constructed decoder is untrained.
        _gemnet_decoder = GemNetTDecoder(latent_dim=z.shape[-1], n_elements=10)  # Example: Adjust n_elements
    # Rewards are non-differentiable (REINFORCE), so no autograd graph is needed.
    with torch.no_grad():
        # z[i:i + 1] keeps a leading batch dim of 1 per call.
        # NOTE(review): assumes the decoder accepts a (1, latent_dim) latent -- confirm.
        return [_gemnet_decoder(z[i:i + 1]) for i in range(z.shape[0])]

def reward_function(structure):
    """Score a generated structure for the RL objective.

    The score is the magnetic-ordering estimate weighted by 2.0, minus the
    formation-energy estimate: lower energy and a higher magnetic score both
    raise the reward.

    Returns:
        ``2.0 * magnetic_score - energy``.
    """
    # Energy is estimated first; the original evaluation order is kept.
    energy = estimate_formation_energy(structure)
    magnetic_weight = 2.0  # Tunable trade-off between the two terms
    return magnetic_weight * predict_magnetic_ordering(structure) - energy

In [None]:
# Hyperparameters
num_steps = 100
batch_size = 32
latent_dim = 16
learning_rate = 1e-3
seed = 0  # Fixes policy init and all torch sampling so runs are reproducible

torch.manual_seed(seed)

# Initialize policy network and optimizer
policy_net = LatentPolicyNet(latent_dim)
optimizer = torch.optim.Adam(policy_net.parameters(), lr=learning_rate)

for step in range(num_steps):
    z_noise = torch.randn(batch_size, latent_dim)
    z_sampled, log_probs = policy_net(z_noise)

    # Decode structures from the CDVAE decoder
    generated_structures = decoder(z_sampled)

    # Score each generated structure
    rewards = [reward_function(structure) for structure in generated_structures]

    # REINFORCE weights log_probs[i] by rewards[i]; they must pair one-to-one.
    if len(rewards) != log_probs.shape[0]:
        raise ValueError(
            f"decoder returned {len(rewards)} structures for "
            f"{log_probs.shape[0]} sampled latents"
        )

    # Update policy using REINFORCE
    loss = reinforce_update(policy_net, optimizer, rewards, log_probs)

    avg_reward = sum(rewards) / len(rewards)
    print(f"Step {step} | Avg Reward: {avg_reward:.3f} | Policy Loss: {loss:.4f}")