In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

# Simulating a differentiable feature computation
def differentiable_feature_computation(atomic_numbers, positions):
    # Example F, P, S, H are now matrices, derived from the positions
    F = torch.matmul(positions, positions.T)  # Example F feature (matrix of dot products)
    P = positions @ positions.T               # Example P feature (same as F in this case)
    S = positions * positions                 # Example S feature (element-wise product)
    H = torch.mean(positions, dim=0, keepdim=True) @ positions.T  # Example H feature (mean with matmul)
    return F, P, S, H

# Class to cache FPSH features
class MolFeatureWithCache:
    def __init__(self, atomic_numbers, positions):
        self.atomic_numbers = atomic_numbers
        self.positions = positions.requires_grad_()  # Positions need gradients
        self._cached_fpsh = None  # This will store the cached FPSH matrices

    def compute_or_cache_features(self, cache_enabled=True):
        if self._cached_fpsh is None or not cache_enabled:
            # Compute FPSH features for the first time or if caching is disabled
            for _ in range(10000):
                F, P, S, H = differentiable_feature_computation(self.atomic_numbers, self.positions)
            self._cached_fpsh = (F, P, S, H)  # Cache the computed features
        return self._cached_fpsh

    def clear_cache(self):
        # Clears the cache (only if positions change)
        self._cached_fpsh = None

# A simple ML model that takes FPSH features as input and predicts scalar energy
class SimpleMLModel(nn.Module):
    def __init__(self):
        super(SimpleMLModel, self).__init__()
        self.fc = nn.Linear(27, 1)  # Assume flattened FPSH concatenates into a vector of size 9 * 4

    def forward(self, F, P, S, H):
        F_flat = F.view(-1)  # Flatten the F matrix
        P_flat = P.view(-1)  # Flatten the P matrix
        S_flat = S.view(-1)  # Flatten the S matrix
        H_flat = H.view(-1)  # Flatten the H matrix

        # Concatenate all features into one long vector
        features = torch.cat([F_flat, P_flat, S_flat, H_flat], dim=0)
        energy = self.fc(features.unsqueeze(0))  # Add batch dimension for the Linear layer
        return energy

# Create a toy dataset of atomic positions and atomic numbers
def create_toy_dataset(n_molecules):
    dataset = []
    for i in range(n_molecules):
        # Random atomic positions for 3 atoms in 2D space
        positions = torch.randn(3, 2)
        # Random atomic numbers (just for the sake of it, but not used here)
        atomic_numbers = torch.randint(1, 10, (3,))
        mol_feature = MolFeatureWithCache(atomic_numbers, positions)
        dataset.append(mol_feature)
    return dataset

# Training loop
def train(model, dataset, epochs=1, learning_rate=1e-3, cache_enabled=True):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    positions_grads = []
    for epoch in range(epochs):
        for mol_feature in dataset:
            optimizer.zero_grad()

            # Get cached or compute FPSH features
            F, P, S, H = mol_feature.compute_or_cache_features(cache_enabled=cache_enabled)

            # Forward pass through the ML model to compute energy
            energy = model(F, P, S, H)

            # Loss function: Minimize energy
            loss = torch.sum(energy ** 2)

            # Backpropagate energy to compute dE/dW
            loss.backward(retain_graph=True)

            # Store the gradients of the positions
            positions_grads.append(copy.deepcopy(mol_feature.positions.grad.clone()))

            # Update the model parameters
            optimizer.step()

    return positions_grads

# Compare gradients with and without caching
def compare_grads(grads_with_cache, grads_without_cache):
    for i, (grad_cache, grad_no_cache) in enumerate(zip(grads_with_cache, grads_without_cache)):
        if torch.allclose(grad_cache, grad_no_cache):
            print(f"Gradients for molecule {i} are the same with and without caching.")
        else:
            print(f"Gradients for molecule {i} differ with and without caching.")

# Main function to test the example
def main():
    # Create toy dataset with 5 molecules
    dataset = create_toy_dataset(n_molecules=2)

    # Initialize the ML model
    model = SimpleMLModel()

    grads_with_cache = train(model, dataset, epochs=1, cache_enabled=True)

    # Train the model without caching
    grads_without_cache = train(model, dataset, epochs=1, cache_enabled=False)

    print(f"grads_with_cache: {grads_with_cache}")
    print(f"grads_without_cache: {grads_without_cache}")

    # Compare the gradients
    compare_grads(grads_with_cache, grads_without_cache)

if __name__ == "__main__":
    main()


grads_with_cache: [tensor([[-0.0096, -0.5187],
        [ 0.1414, -0.2812],
        [-0.0567,  0.2288]]), tensor([[ 0.0724, -0.2029],
        [ 0.2140, -0.0736],
        [-0.0720,  0.1057]])]
grads_without_cache: [tensor([[-0.0316, -0.9845],
        [ 0.2710, -0.5067],
        [-0.1156,  0.4451]]), tensor([[ 0.1361, -0.3991],
        [ 0.4105, -0.1482],
        [-0.1392,  0.1871]])]
Gradients for molecule 0 differ with and without caching.
Gradients for molecule 1 differ with and without caching.
