-
Notifications
You must be signed in to change notification settings - Fork 227
/
cgcnn.py
230 lines (201 loc) · 7.93 KB
/
cgcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.nn.models.schnet import GaussianSmearing
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import conditional_grad
from ocpmodels.datasets.embeddings import KHOT_EMBEDDINGS, QMOF_KHOT_EMBEDDINGS
from ocpmodels.models.base import BaseModel
@registry.register_model("cgcnn")
class CGCNN(BaseModel):
    r"""Implementation of the Crystal Graph CNN model from the
    `"Crystal Graph Convolutional Neural Networks for an Accurate
    and Interpretable Prediction of Material Properties"
    <https://arxiv.org/abs/1710.10324>`_ paper.

    Args:
        num_atoms (int): Forwarded to :class:`BaseModel`; not otherwise used
            by this class.
        bond_feat_dim (int): Dimension of bond (edge) features. The edge
            features are the Gaussian expansion of interatomic distances,
            so this must equal :obj:`num_gaussians`.
        num_targets (int): Number of targets to predict.
        use_pbc (bool, optional): If set to :obj:`True`, account for periodic boundary conditions.
            (default: :obj:`True`)
        regress_forces (bool, optional): If set to :obj:`True`, predict forces by differentiating
            energy with respect to positions.
            (default: :obj:`True`)
        atom_embedding_size (int, optional): Size of atom embeddings.
            (default: :obj:`64`)
        num_graph_conv_layers (int, optional): Number of graph convolutional layers.
            (default: :obj:`6`)
        fc_feat_size (int, optional): Size of fully connected layers.
            (default: :obj:`128`)
        num_fc_layers (int, optional): Number of fully connected layers
            before the output layer. (default: :obj:`4`)
        otf_graph (bool, optional): If set to :obj:`True`, compute graph edges on the fly.
            (default: :obj:`False`)
        cutoff (float, optional): Cutoff distance for interatomic interactions;
            also the upper bound of the Gaussian smearing range.
            (default: :obj:`6.0`)
        num_gaussians (int, optional): Number of Gaussians used for smearing.
            (default: :obj:`50`)
        embeddings (str, optional): Elemental embedding table to use:
            :obj:`"khot"` for the original CGCNN K-hot embeddings or
            :obj:`"qmof"` for the QMOF K-hot embeddings.
            (default: :obj:`"khot"`)

    Raises:
        ValueError: If :obj:`embeddings` is neither ``"khot"`` nor ``"qmof"``,
            or if :obj:`bond_feat_dim` differs from :obj:`num_gaussians`.
    """

    def __init__(
        self,
        num_atoms: int,
        bond_feat_dim: int,
        num_targets: int,
        use_pbc: bool = True,
        regress_forces: bool = True,
        atom_embedding_size: int = 64,
        num_graph_conv_layers: int = 6,
        fc_feat_size: int = 128,
        num_fc_layers: int = 4,
        otf_graph: bool = False,
        cutoff: float = 6.0,
        num_gaussians: int = 50,
        embeddings: str = "khot",
    ) -> None:
        super(CGCNN, self).__init__(num_atoms, bond_feat_dim, num_targets)
        self.regress_forces = regress_forces
        self.use_pbc = use_pbc
        self.cutoff = cutoff
        self.otf_graph = otf_graph
        # Neighbor cap stored on self; presumably read by the graph
        # construction inherited from BaseModel (generate_graph) — confirm.
        self.max_neighbors = 50

        # Select the elemental K-hot embedding table.
        if embeddings == "khot":
            embeddings = KHOT_EMBEDDINGS
        elif embeddings == "qmof":
            embeddings = QMOF_KHOT_EMBEDDINGS
        else:
            raise ValueError(
                'embeddings must be either "khot" for original CGCNN K-hot '
                'elemental embeddings or "qmof" for QMOF K-hot elemental '
                f"embeddings, got {embeddings!r}"
            )

        # Every CGCNNConv receives edge_attr = distance_expansion(distances),
        # which is num_gaussians wide, but its linear layer is sized from
        # bond_feat_dim. A mismatch would otherwise only show up later as an
        # opaque shape error inside forward, so reject it here.
        if bond_feat_dim != num_gaussians:
            raise ValueError(
                f"bond_feat_dim ({bond_feat_dim}) must equal num_gaussians "
                f"({num_gaussians}) because edge features are the Gaussian "
                "expansion of interatomic distances"
            )

        # Dense lookup table for atomic numbers 1..100: row ``z - 1`` holds
        # the embedding of element ``z`` (matching the ``- 1`` in _forward).
        # It is a plain tensor attribute rather than a registered buffer, so
        # it is not in state_dict and is moved to the input device in _forward.
        num_elements = 100
        khot_dim = len(embeddings[1])
        self.embedding = torch.zeros(num_elements, khot_dim)
        for z in range(1, num_elements + 1):
            self.embedding[z - 1] = torch.tensor(embeddings[z])

        # NOTE: submodules are created in the original order so that
        # parameter initialisation consumes the RNG identically.
        self.embedding_fc = nn.Linear(khot_dim, atom_embedding_size)
        self.convs = nn.ModuleList(
            [
                CGCNNConv(
                    node_dim=atom_embedding_size,
                    edge_dim=bond_feat_dim,
                    cutoff=cutoff,
                )
                for _ in range(num_graph_conv_layers)
            ]
        )
        self.conv_to_fc = nn.Sequential(
            nn.Linear(atom_embedding_size, fc_feat_size), nn.Softplus()
        )
        # Optional hidden MLP of (num_fc_layers - 1) Linear + Softplus pairs;
        # _forward checks for its presence with hasattr.
        if num_fc_layers > 1:
            layers = []
            for _ in range(num_fc_layers - 1):
                layers.append(nn.Linear(fc_feat_size, fc_feat_size))
                layers.append(nn.Softplus())
            self.fcs = nn.Sequential(*layers)
        self.fc_out = nn.Linear(fc_feat_size, self.num_targets)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

    @conditional_grad(torch.enable_grad())
    def _forward(self, data):
        """Predict per-graph targets (energy) for a batch.

        Looks up elemental embeddings for ``data.atomic_numbers``, builds the
        graph with :meth:`generate_graph`, Gaussian-expands edge distances,
        then applies the convolutions, per-graph pooling and the dense head.

        Side effects: sets ``data.x``, ``data.edge_index`` and
        ``data.edge_attr`` on the input batch.
        """
        # The lookup table is not a registered buffer, so follow the input.
        if self.embedding.device != data.atomic_numbers.device:
            self.embedding = self.embedding.to(data.atomic_numbers.device)
        # Atomic numbers are 1-based, table rows 0-based.
        # NOTE(review): an atomic number of 0 would index row -1 (element 100)
        # without error.
        data.x = self.embedding[data.atomic_numbers.long() - 1]

        (
            edge_index,
            distances,
            distance_vec,
            cell_offsets,
            _,  # cell offset distances
            neighbors,
        ) = self.generate_graph(data)
        data.edge_index = edge_index
        data.edge_attr = self.distance_expansion(distances)

        # Forward pass through the network.
        mol_feats = self._convolve(data)
        mol_feats = self.conv_to_fc(mol_feats)
        if hasattr(self, "fcs"):
            mol_feats = self.fcs(mol_feats)
        energy = self.fc_out(mol_feats)
        return energy

    def forward(self, data):
        """Return the predicted energy, plus forces if ``regress_forces``.

        Forces are the negative gradient of the energy with respect to
        ``data.pos``. ``create_graph=True`` keeps the force computation
        differentiable, so a loss on the forces can be backpropagated.

        Returns:
            ``energy`` when ``self.regress_forces`` is False, otherwise the
            tuple ``(energy, forces)``.
        """
        if self.regress_forces:
            data.pos.requires_grad_(True)
        energy = self._forward(data)

        if not self.regress_forces:
            return energy

        forces = -1 * (
            torch.autograd.grad(
                energy,
                data.pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
            )[0]
        )
        return energy, forces

    def _convolve(self, data):
        """
        Returns the output of the convolution layers before they are passed
        into the dense layers.

        Node features are projected from ``data.x`` by ``embedding_fc``,
        refined by each CGCNNConv layer in turn, then mean-pooled per graph
        using ``data.batch``.
        """
        node_feats = self.embedding_fc(data.x)
        for conv in self.convs:
            node_feats = conv(node_feats, data.edge_index, data.edge_attr)
        mol_feats = global_mean_pool(node_feats, data.batch)
        return mol_feats
class CGCNNConv(MessagePassing):
    """Implements the message passing layer from
    `"Crystal Graph Convolutional Neural Networks for an
    Accurate and Interpretable Prediction of Material Properties"
    <https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301>`_.

    Each edge message is the gated product ``sigmoid(z1) * softplus(z2)``,
    where ``[z1, z2]`` is a batch-normalised linear map of
    ``[x_i, x_j, edge_attr]``. Messages are summed (``aggr="add"``), then
    layer-normalised, added to the input as a residual and passed through
    softplus.

    Args:
        node_dim (int): Size of node features. Not forwarded to
            :class:`MessagePassing`.
        edge_dim (int): Size of edge features.
        cutoff (float, optional): Stored as ``self.cutoff``; not used by this
            layer's computation. (default: :obj:`6.0`)
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self, node_dim, edge_dim, cutoff: float = 6.0, **kwargs
    ) -> None:
        super(CGCNNConv, self).__init__(aggr="add")
        self.node_feat_size = node_dim
        self.edge_feat_size = edge_dim
        self.cutoff = cutoff
        # A single projection produces both the gate half (z1) and the core
        # half (z2) of each message, hence the 2 * node_feat_size width.
        self.lin1 = nn.Linear(
            2 * self.node_feat_size + self.edge_feat_size,
            2 * self.node_feat_size,
        )
        self.bn1 = nn.BatchNorm1d(2 * self.node_feat_size)
        self.ln1 = nn.LayerNorm(self.node_feat_size)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-initialise ``lin1`` weights, zero its bias, reset both norms."""
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        self.bn1.reset_parameters()
        self.ln1.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        """
        Arguments:
            x has shape [num_nodes, node_feat_size]
            edge_index has shape [2, num_edges]
            edge_attr is [num_edges, edge_feat_size]

        Returns:
            tensor of shape [num_nodes, node_feat_size]
        """
        out = self.propagate(
            edge_index, x=x, edge_attr=edge_attr, size=(x.size(0), x.size(0))
        )
        # Residual connection around the normalised aggregate. The functional
        # form avoids constructing a new nn.Softplus module on every call.
        return nn.functional.softplus(self.ln1(out) + x)

    def message(self, x_i, x_j, edge_attr):
        """
        Arguments:
            x_i has shape [num_edges, node_feat_size]
            x_j has shape [num_edges, node_feat_size]
            edge_attr has shape [num_edges, edge_feat_size]

        Returns:
            tensor of shape [num_edges, node_feat_size]
        """
        z = self.lin1(torch.cat([x_i, x_j, edge_attr], dim=1))
        z = self.bn1(z)
        z1, z2 = z.chunk(2, dim=1)
        # Gate in (0, 1) multiplied by a non-negative core signal.
        return torch.sigmoid(z1) * nn.functional.softplus(z2)