# Implementing a Minimalist SchNet Model in PyTorch

This notebook provides a step-by-step guide to implementing a minimalist version of **SchNet**, 
a deep learning architecture for modeling molecular properties such as energies and force fields. 
The implementation predicts molecular energies and derives atomic forces from those predictions.

---

## 1. Setting Up the Environment

First, ensure you have the necessary libraries installed. We will use **PyTorch** for building 
the neural network and **PyTorch Geometric (PyG)** for handling graph-based data structures.


In [None]:
# Install PyTorch
# You can install PyTorch by following the instructions on https://pytorch.org/get-started/locally/
# Here's a generic installation command. Please adjust it based on your CUDA version.

# !pip install torch torchvision torchaudio

# Install PyTorch Geometric and its dependencies
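# PyTorch Geometric has specific installation requirements. Visit https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
# Its compiled extensions (torch-scatter, torch-sparse, torch-cluster, torch-spline-conv) must be built for the
# exact PyTorch version you have installed; the page above lists matching prebuilt wheels. Without that wheel
# index, pip may try to compile the extensions from source. A generic command:

# !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric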

# For this minimalist implementation, we'll proceed assuming PyTorch and PyTorch Geometric are installed.
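torch-scatter, torch-sparse and torch-cluster ship compiled binaries that must match the installed PyTorch build. A mismatched install can therefore fail, or print dynamic-linker errors, at import time rather than at install time. The small stdlib-only helper below reports which of the packages this notebook relies on can actually be imported, and their versions where a `__version__` attribute exists. The name `check_imports` is hypothetical and not part of any library.

```python
import importlib


def check_imports(module_names):
    """Try to import each module and map its name to a version string or an error message."""
    report = {}
    for name in module_names:
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "imported (no __version__)")
        except Exception as exc:  # ImportError, or OSError from a mismatched compiled binary
            report[name] = f"FAILED: {type(exc).__name__}: {exc}"
    return report


for name, status in check_imports(
    ["torch", "torch_geometric", "torch_scatter", "torch_cluster"]
).items():
    print(f"{name:16} {status}")
```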


## 2. Importing Necessary Libraries

We'll import the required libraries for building the SchNet model, handling graph data, and performing mathematical operations.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import radius_graph
from torch_scatter import scatter_add
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader




## 3. Defining the SchNet Model Components

We'll define the core components of the SchNet model, including atom embeddings, interaction blocks, and the overall SchNet architecture.


In [None]:
class AtomEmbedding(nn.Module):
    def __init__(self, num_atom_types, embedding_dim):
        super(AtomEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_atom_types, embedding_dim)

    def forward(self, atom_types):
        return self.embedding(atom_types)
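The lookup is easier to see with a tiny standalone `nn.Embedding` of 2 atom types and 4-dimensional vectors (sizes chosen only for readability). The example also exposes a consequence worth keeping in mind. Before any interaction block runs, the two hydrogens of a water molecule have exactly the same vector, so only the distance-dependent interactions can tell them apart.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Standalone lookup table: 2 atom types (O=0, H=1), 4-dimensional vectors for readability
embedding = nn.Embedding(num_embeddings=2, embedding_dim=4)
atom_types = torch.tensor([0, 1, 1])  # one water molecule: O, H, H

vectors = embedding(atom_types)
print(vectors.shape)                        # torch.Size([3, 4])
print(torch.equal(vectors[1], vectors[2]))  # True: both hydrogens read row 1 of the table
```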


In [None]:
class InteractionBlock(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, num_filters):
        super(InteractionBlock, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

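        # Distance encoding: a learned linear map from the scalar distance to num_filters
        # features (a simplification; the SchNet paper expands distances in Gaussian RBFs)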
        self.rbf = nn.Linear(1, num_filters)

        # MLP to generate filters based on distances
        self.mlp = nn.Sequential(
            nn.Linear(num_filters, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

        # MLP for updating atom embeddings
        self.update_mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

    def forward(self, x, edge_index, edge_distance):
        # Encode distances (linear projection, see self.rbf above)
        rbf = self.rbf(edge_distance.unsqueeze(1))  # Shape: [num_edges, num_filters]
        filters = self.mlp(rbf)  # Shape: [num_edges, embedding_dim]

        # Message passing: multiply neighbor (source, edge_index[0]) embeddings by filters
        messages = x[edge_index[0]] * filters  # Shape: [num_edges, embedding_dim]

        # Aggregate messages at the receiving atom (target, edge_index[1]); dim_size keeps
        # a row for every atom, even one without incoming edges
        out = scatter_add(messages, edge_index[1], dim=0, dim_size=x.size(0))  # Shape: [num_nodes, embedding_dim]

        # Update atom embeddings
        out = self.update_mlp(out)
        return x + out  # Residual connection
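The aggregation step sums every edge's message into the row of its receiving atom. The sketch below reproduces that with plain `Tensor.index_add_`, which could stand in for `torch_scatter.scatter_add` if that extension will not load. `scatter_sum` is a hypothetical helper name. Passing an explicit output size matters here. `scatter_add` without `dim_size` sizes its output by the largest index it sees, so a trailing atom with no incoming edges would silently lose its row.

```python
import torch


def scatter_sum(src, index, dim_size):
    """Sum rows of `src` into `dim_size` output rows selected by `index` (plain-PyTorch scatter-add)."""
    out = src.new_zeros((dim_size,) + src.shape[1:])
    return out.index_add_(0, index, src)


# Toy graph with 3 atoms and directed edges 0->1, 2->1, 1->0 (row 0 = source, row 1 = target)
edge_index = torch.tensor([[0, 2, 1],
                           [1, 1, 0]])
messages = torch.tensor([[1.0, 10.0],
                         [2.0, 20.0],
                         [4.0, 40.0]])  # one message per edge

out = scatter_sum(messages, edge_index[1], dim_size=3)
print(out)
# Row 0 gets the message from atom 1 -> [4, 40]
# Row 1 gets messages from atoms 0 and 2 -> [3, 30]
# Row 2 receives nothing but still exists -> [0, 0]
```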


In [None]:
class SchNet(nn.Module):
    def __init__(self, num_atom_types, embedding_dim=128, hidden_dim=128, num_filters=64, num_interactions=3):
        super(SchNet, self).__init__()
        self.atom_embedding = AtomEmbedding(num_atom_types, embedding_dim)
        self.interactions = nn.ModuleList([
            InteractionBlock(embedding_dim, hidden_dim, num_filters) for _ in range(num_interactions)
        ])
        self.output_mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Predict scalar energy
        )

    def forward(self, data):
        """
        data should contain:
            - x: Atom types [num_atoms]
            - pos: Atom positions [num_atoms, 3]
            - batch: Batch indices [num_atoms]
        """
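        atom_types = data.x
        pos = data.pos  # [num_atoms, 3]
        batch = getattr(data, "batch", None)  # [num_atoms]
        if batch is None:  # a single molecule that did not come from a DataLoader
            batch = torch.zeros(atom_types.size(0), dtype=torch.long, device=pos.device)

        # Initial atom embeddings
        x = self.atom_embedding(atom_types)  # [num_atoms, embedding_dim]

        # Create a radius graph (edges between atoms closer than r); passing `batch`
        # keeps edges inside each molecule instead of linking atoms across the batch
        edge_index = radius_graph(pos, r=5.0, batch=batch, loop=False)  # Adjust radius as needed
        # Compute distances for edges
        edge_distance = (pos[edge_index[0]] - pos[edge_index[1]]).norm(p=2, dim=1)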

        # Interaction blocks
        for interaction in self.interactions:
            x = interaction(x, edge_index, edge_distance)

        # Map each atom embedding to a scalar energy contribution, then sum them per molecule
        energy = self.output_mlp(x)  # [num_atoms, 1]
        energy = scatter_add(energy, batch, dim=0)  # [batch_size, 1]
        return energy.squeeze()
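`radius_graph` only knows which atoms belong to which molecule if it receives the `batch` vector. Its signature accepts `batch=`, and it also caps neighbors with `max_num_neighbors`. Without `batch`, atoms from different molecules in a batch are connected whenever they are closer than `r`. In the synthetic dataset below, all molecules sit at the same coordinates, so every atom lies within 5 Å of every other. The plain-PyTorch sketch counts edges for two overlapping water molecules with and without the batch mask. `batched_radius_graph` is a hypothetical, dense O(N²) helper suitable only for small batches.

```python
import torch


def batched_radius_graph(pos, batch, r):
    """Directed edges [2, E] between distinct atoms of the same molecule closer than r (dense, O(N^2))."""
    idx = torch.arange(pos.size(0), device=pos.device)
    dist = torch.cdist(pos, pos)                            # [N, N] pairwise distances
    same_molecule = batch.unsqueeze(0) == batch.unsqueeze(1)
    not_self = idx.unsqueeze(0) != idx.unsqueeze(1)         # plays the role of loop=False
    source, target = (same_molecule & not_self & (dist < r)).nonzero(as_tuple=True)
    return torch.stack([source, target])


water = torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
pos = torch.cat([water, water])                             # two identical, overlapping molecules
batch = torch.tensor([0, 0, 0, 1, 1, 1])

with_batch = batched_radius_graph(pos, batch, r=5.0)
without_batch = batched_radius_graph(pos, torch.zeros(6, dtype=torch.long), r=5.0)
print(with_batch.size(1))     # 12: 3 x 2 directed edges inside each molecule
print(without_batch.size(1))  # 30: each of the 6 atoms is linked to the other 5
```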


## 4. Preparing the Dataset

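For demonstration purposes, we'll create a synthetic dataset of 100 identical water molecules. Because every sample has the same geometry and the same label, the model only needs to output one constant. A training loss near zero therefore shows that the pipeline runs, not that the model has learned how energy depends on geometry. In practice, you'd use real molecular datasets such as QM9.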


In [13]:
# Example: Create a single water molecule
def create_water_molecule():
    # Atom types: Oxygen (0), Hydrogen (1)
    atom_types = torch.tensor([0, 1, 1], dtype=torch.long)
    # Positions in angstroms
    pos = torch.tensor([
        [0.0, 0.0, 0.0],        # Oxygen
        [0.96, 0.0, 0.0],       # Hydrogen 1
        [-0.24, 0.93, 0.0]      # Hydrogen 2
    ], dtype=torch.float)
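    # Placeholder label. Water's total electronic energy is about -76.4 Hartree
    # (roughly -2080 eV), so -76.0 "eV" is an arbitrary target, not a physical value.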
    energy = torch.tensor([-76.0], dtype=torch.float)
    data = Data(x=atom_types, pos=pos, y=energy)
    return data

# Create a dataset with multiple identical water molecules
dataset = [create_water_molecule() for _ in range(100)]
loader = DataLoader(dataset, batch_size=10, shuffle=True)
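Because all 100 samples above are identical, nothing in the data depends on geometry. One way to give the model something to learn is to jitter the atomic positions and label each geometry with a toy energy. The sketch below does this under loudly stated assumptions. The harmonic O-H stretch potential, its force constant `K`, the equilibrium distance `R0` and the noise level are all hypothetical choices for illustration, not a physical model of water. Each `(atom_types, pos, energy)` tuple could be wrapped as `Data(x=atom_types, pos=pos, y=energy)` to feed the `DataLoader`.

```python
import torch

# Hypothetical toy potential (NOT physical): harmonic O-H bond stretches only
R0 = 0.96  # assumed equilibrium O-H distance in angstroms
K = 20.0   # assumed force constant in eV/Å^2, chosen arbitrarily

WATER = torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])


def toy_energy(pos):
    """Sum of K * (r_OH - R0)^2 over the two O-H bonds (atom 0 is the oxygen)."""
    r_oh = (pos[1:] - pos[0]).norm(dim=1)
    return (K * (r_oh - R0) ** 2).sum()


def make_distorted_waters(n, noise=0.05, seed=0):
    """Return a list of (atom_types, pos, energy) tuples with randomly jittered geometries."""
    gen = torch.Generator().manual_seed(seed)
    atom_types = torch.tensor([0, 1, 1])
    samples = []
    for _ in range(n):
        pos = WATER + noise * torch.randn(WATER.shape, generator=gen)
        samples.append((atom_types, pos, toy_energy(pos).reshape(1)))
    return samples


samples = make_distorted_waters(100)
energies = torch.cat([energy for _, _, energy in samples])
print(energies.min().item(), energies.max().item())  # labels now vary with geometry
```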


In [22]:
dataset[0]["pos"]

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.9600,  0.0000,  0.0000],
        [-0.2400,  0.9300,  0.0000]])

In [14]:
batch = next(iter(loader)); batch

DataBatch(x=[30], y=[10], pos=[30, 3], batch=[30], ptr=[11])

In [16]:
batch["x"]

tensor([0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1,
        0, 1, 1, 0, 1, 1])

In [15]:
batch["ptr"]

tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27, 30])
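The `batch` and `ptr` attributes encode the same information in two ways. `batch` gives each atom the index of its molecule. `ptr` lists the offset where each molecule's atoms start, plus the total at the end. The sketch below rebuilds both from per-molecule atom counts with plain PyTorch. It illustrates the relationship only and is not how PyTorch Geometric's collation code is written. In the model, `scatter_add(energy, batch, dim=0)` uses `batch` to sum atomic energies per molecule.

```python
import torch

num_atoms_per_molecule = torch.tensor([3] * 10)  # ten water molecules, as in the batch above

# ptr: start offset of each molecule's atoms, followed by the total atom count
ptr = torch.cat([torch.zeros(1, dtype=torch.long), num_atoms_per_molecule.cumsum(0)])
# batch: the molecule index of every atom
batch = torch.repeat_interleave(torch.arange(len(num_atoms_per_molecule)), num_atoms_per_molecule)

print(ptr)        # tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27, 30])
print(batch[:7])  # tensor([0, 0, 0, 1, 1, 1, 2])
```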

## 5. Training the SchNet Model

We'll define a simple training loop that fits the predicted total energy of each molecule to its label by minimizing the mean-squared error.


In [None]:
# Define the model
num_atom_types = 2  # Oxygen and Hydrogen
model = SchNet(num_atom_types=num_atom_types)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        energy_pred = model(batch)  # [batch_size]
        energy_true = batch.y  # [batch_size]
        loss = criterion(energy_pred, energy_true)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    average_loss = total_loss / len(dataset)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")


Epoch 10/1000, Loss: 0.0280
Epoch 20/1000, Loss: 0.0000
...
Epoch 340/1000, Loss: 0.0000

## 6. Predicting Forces

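To compute forces, we'll take the negative gradient of the predicted energy with respect to the atomic positions, using PyTorch's automatic differentiation. The model was trained only on the energy of a single geometry, so the training data does not constrain these forces. They illustrate the mechanism rather than physically meaningful values.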


In [16]:
def predict_forces(model, data):
    model.eval()
    pos = data.pos.clone().detach().requires_grad_(True)  # Enable gradient computation w.r.t. positions
    batch = torch.zeros(data.num_nodes, dtype=torch.long)  # Single molecule: every atom belongs to graph 0
    data_with_grad = Data(x=data.x, pos=pos, batch=batch)
    energy = model(data_with_grad)
    print(f"energy {energy.sum().item()}")
    energy.sum().backward()  # .sum() yields a scalar whatever shape the model returns
    forces = -pos.grad  # Forces are negative gradients of energy
    return forces.detach()

# Example usage
test_data = create_water_molecule()
predicted_forces = predict_forces(model, test_data)
print("Predicted Forces (eV/Å):")
print(predicted_forces)


energy -1.7082438468933105
Predicted Forces (eV/Å):
tensor([[-0.2111, -0.2725, -0.0000],
        [ 0.5674, -0.2217, -0.0000],
        [-0.3564,  0.4942, -0.0000]])
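A cheap sanity check for autograd forces is to compare them with central finite differences of the energy, one coordinate at a time: F ≈ -(E(x + ε) - E(x - ε)) / (2ε). The sketch below uses a hypothetical harmonic pair energy so that it runs on its own in float64. To check the notebook's model, one could pass something like `lambda p: model(Data(x=test_data.x, pos=p, batch=torch.zeros(3, dtype=torch.long))).sum()` as `energy_fn`. That check would need a larger step (e.g. `eps=1e-3`) and a looser tolerance, because the model runs in float32 and its ReLU kinks can make central differences disagree slightly.

```python
import torch


def autograd_forces(energy_fn, pos):
    """F = -dE/dpos via automatic differentiation."""
    pos = pos.detach().clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(energy_fn(pos), pos)
    return -grad


def finite_difference_forces(energy_fn, pos, eps=1e-5):
    """F_ij ≈ -(E(pos + eps * e_ij) - E(pos - eps * e_ij)) / (2 * eps)."""
    pos = pos.detach().clone()
    forces = torch.zeros_like(pos)
    with torch.no_grad():
        for i in range(pos.size(0)):
            for j in range(pos.size(1)):
                shift = torch.zeros_like(pos)
                shift[i, j] = eps
                forces[i, j] = -(energy_fn(pos + shift) - energy_fn(pos - shift)) / (2 * eps)
    return forces


def pair_energy(pos):
    """Hypothetical toy energy: harmonic springs (k = 1, r0 = 1) between every pair of atoms."""
    i, j = torch.triu_indices(pos.size(0), pos.size(0), offset=1)
    r = (pos[i] - pos[j]).norm(dim=1)
    return ((r - 1.0) ** 2).sum()


pos = torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]], dtype=torch.float64)
f_auto = autograd_forces(pair_energy, pos)
f_fd = finite_difference_forces(pair_energy, pos)
print(torch.allclose(f_auto, f_fd, atol=1e-6))  # True
print(f_auto.sum(dim=0))  # ≈ 0: a translation-invariant energy gives forces that sum to zero
```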


## 7. Explanation of the Implementation

### a. Atom Embeddings

- **Purpose:** Convert discrete atom types into continuous vector representations.
- **Implementation:** `nn.Embedding` layer maps each atom type to a learnable embedding vector.

### b. Interaction Blocks

- **Purpose:** Capture the interactions between atoms based on their spatial relationships.
- **Components:**
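  - **Distance Encoding:** In this notebook, a learned linear layer maps each interatomic distance to `num_filters` features. The SchNet paper instead expands distances in a set of Gaussian radial basis functions, which is where the `rbf` attribute name comes from.
  - **Filter MLP:** Generates continuous filters from the encoded distances.
  - **Message Passing:** Each atom receives a message from every neighbor: the neighbor's embedding multiplied elementwise by the filter for that atom pair.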
  - **Update MLP:** Updates atom embeddings based on the aggregated messages.
  - **Residual Connection:** Adds the updated embeddings back to the original embeddings to facilitate training.
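For comparison with the simplified filter generator described above, here is a sketch closer to the published design. The SchNet paper expands each distance in Gaussian radial basis functions exp(-γ(d - μ_k)²) and uses the shifted softplus activation ssp(x) = ln(0.5eˣ + 0.5). PyTorch Geometric's `SchNet` additionally scales the filters by a cosine cutoff so they fall smoothly to zero at the cutoff radius. A smooth activation also keeps the resulting forces smooth in the positions, which ReLU does not. The class names, the evenly spaced centres and `gamma=10.0` are illustrative assumptions. `FilterNetwork(embedding_dim, num_filters)(edge_distance)` returns a `[num_edges, embedding_dim]` tensor, so it could replace `self.mlp(self.rbf(...))` inside `InteractionBlock`.

```python
import math

import torch
import torch.nn as nn


class GaussianRBF(nn.Module):
    """Expand each distance d into exp(-gamma * (d - mu_k)^2) for evenly spaced centres mu_k."""
    def __init__(self, num_filters=64, cutoff=5.0, gamma=10.0):
        super().__init__()
        self.register_buffer("centers", torch.linspace(0.0, cutoff, num_filters))
        self.gamma = gamma

    def forward(self, distance):                          # distance: [num_edges]
        diff = distance.unsqueeze(-1) - self.centers      # [num_edges, num_filters]
        return torch.exp(-self.gamma * diff ** 2)


class ShiftedSoftplus(nn.Module):
    """ssp(x) = softplus(x) - ln 2 = ln(0.5 * e^x + 0.5); smooth and equal to 0 at x = 0."""
    def forward(self, x):
        return nn.functional.softplus(x) - math.log(2.0)


def cosine_cutoff(distance, cutoff=5.0):
    """0.5 * (cos(pi * d / cutoff) + 1) for d < cutoff, else 0; decays smoothly to 0 at the cutoff."""
    inside = (distance < cutoff).to(distance.dtype)
    return 0.5 * (torch.cos(math.pi * distance / cutoff) + 1.0) * inside


class FilterNetwork(nn.Module):
    """Distance -> filter of size embedding_dim (Gaussian RBF, MLP with ssp, cosine cutoff)."""
    def __init__(self, embedding_dim=128, num_filters=64, cutoff=5.0):
        super().__init__()
        self.cutoff = cutoff
        self.rbf = GaussianRBF(num_filters, cutoff)
        self.mlp = nn.Sequential(
            nn.Linear(num_filters, embedding_dim),
            ShiftedSoftplus(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, distance):
        filters = self.mlp(self.rbf(distance))
        return filters * cosine_cutoff(distance, self.cutoff).unsqueeze(-1)


distances = torch.tensor([0.96, 1.52, 4.99, 6.0])
print(GaussianRBF()(distances).shape)                      # torch.Size([4, 64])
print(FilterNetwork()(distances)[-1].abs().max().item())   # 0.0: no filter beyond the cutoff
```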

### c. SchNet Model

- **Flow:**
  1. **Embedding:** Convert atom types to embeddings.
  2. **Edge Creation:** Construct a graph whose edges connect atom pairs closer than the cutoff radius (5 Å here), and compute each pair's distance.
  3. **Interaction Blocks:** Apply multiple interaction blocks to refine atom embeddings.
  4. **Energy Prediction:** An MLP maps each refined atom embedding to a scalar energy contribution, and these contributions are summed per molecule to give the total energy.

### d. Force Calculation

- **Mechanism:** Forces are derived by differentiating the energy with respect to atomic positions.
- **Implementation:** By setting `requires_grad=True` for positions and calling `backward()` on the energy, PyTorch computes the gradients automatically.
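`energy.backward()` is enough for prediction, but it also accumulates gradients into every model parameter's `.grad`. `torch.autograd.grad` returns only the gradients you ask for. With `create_graph=True` the returned forces stay differentiable, which a force term in a training loss requires. The toy below uses a hypothetical one-parameter energy E = w (r - 1)² for two atoms so the behaviour is easy to inspect.

```python
import torch

pos = torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0]], requires_grad=True)
weight = torch.tensor(2.0, requires_grad=True)  # stands in for a model parameter

energy = weight * ((pos[0] - pos[1]).norm() - 1.0) ** 2  # toy energy E = w (r - 1)^2

# Only dE/dpos is computed; create_graph=True keeps it differentiable
(grad_pos,) = torch.autograd.grad(energy, pos, create_graph=True)
forces = -grad_pos
print(weight.grad)  # None: autograd.grad does not write into .grad

force_loss = (forces ** 2).sum()
force_loss.backward()  # backpropagates through the force computation to the parameter
print(weight.grad)     # now populated (analytically 16 * w * (r - 1)^2 = 0.0512)
```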


## 8. Extending the Implementation

This minimalist implementation can be extended in various ways to better capture the complexities of molecular systems:

1. **Distance Encoding:** Implement more sophisticated radial basis functions or include angular information.
2. **Many-Body Interactions:** Incorporate higher-order interactions beyond pairwise.
3. **Layer Normalization:** Add normalization layers to stabilize training.
4. **Batch Handling:** Improve batching mechanisms for larger and more diverse datasets.
5. **Loss Functions:** Incorporate additional loss terms for force predictions to train the model end-to-end.
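A minimal sketch of a combined objective is loss = (E_pred - E_ref)² + ρ · mean((F_pred - F_ref)²). To keep it runnable without PyTorch Geometric, `PairEnergyModel` (an MLP on pairwise distances) stands in for SchNet, and `reference_energy` (harmonic springs) stands in for reference data. Both names and the weight `rho=10.0` are hypothetical. The essential part is `create_graph=True` on the model's gradient, which lets `loss.backward()` propagate through the predicted forces. The stand-in model uses a smooth `nn.Softplus` activation so that its forces vary smoothly with the positions.

```python
import torch
import torch.nn as nn


class PairEnergyModel(nn.Module):
    """Hypothetical stand-in for SchNet: E = sum over atom pairs of MLP(distance)."""
    def __init__(self, hidden=32):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(1, hidden), nn.Softplus(), nn.Linear(hidden, 1))

    def forward(self, pos):  # pos: [num_atoms, 3]
        i, j = torch.triu_indices(pos.size(0), pos.size(0), offset=1)
        r = (pos[i] - pos[j]).norm(dim=1, keepdim=True)
        return self.mlp(r).sum()


def reference_energy(pos):
    """Hypothetical reference: harmonic springs with r0 = 1.0 between every pair of atoms."""
    i, j = torch.triu_indices(pos.size(0), pos.size(0), offset=1)
    return (((pos[i] - pos[j]).norm(dim=1) - 1.0) ** 2).sum()


def energy_and_forces(energy_fn, pos, create_graph):
    pos = pos.detach().requires_grad_(True)
    energy = energy_fn(pos)
    (grad,) = torch.autograd.grad(energy, pos, create_graph=create_graph)
    return energy, -grad


def train_step(model, optimizer, pos, rho=10.0):
    """One optimizer step on (E_pred - E_ref)^2 + rho * mean((F_pred - F_ref)^2)."""
    e_ref, f_ref = energy_and_forces(reference_energy, pos, create_graph=False)
    e_pred, f_pred = energy_and_forces(model, pos, create_graph=True)
    loss = (e_pred - e_ref.detach()) ** 2 + rho * ((f_pred - f_ref.detach()) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = PairEnergyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
base = torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
losses = [train_step(model, optimizer, base + 0.1 * torch.randn(3, 3)) for _ in range(300)]
print(losses[0], losses[-1])
```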


## 9. Resources for Further Reference

- **Original SchNet Paper:** [Schütt et al., "SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions", 2017](https://arxiv.org/abs/1706.08566)
- **PyTorch Geometric:** [https://pytorch-geometric.readthedocs.io/](https://pytorch-geometric.readthedocs.io/)
- **SchNetPack:** A PyTorch-based package for atomistic simulations using SchNet and related models. [SchNetPack GitHub](https://github.com/atomistic-machine-learning/schnetpack)


## 10. Conclusion

This minimalist SchNet implementation provides a foundational understanding of how molecular properties can be predicted using graph-based neural networks in PyTorch. While simplified, it captures the essence of SchNet's approach to embedding atom types, modeling interactions, and predicting energies and forces. From here, you can enhance the model's complexity and adapt it to real-world datasets and applications in computational chemistry and materials science.

