-
Notifications
You must be signed in to change notification settings - Fork 16
/
schnet.py
148 lines (130 loc) · 4.95 KB
/
schnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Optional, Set, Union
import torch
import torch_scatter
from graphein.protein.tensor.data import ProteinBatch
from torch_geometric.data import Batch
from torch_geometric.nn.models import SchNet
from proteinworkshop.types import EncoderOutput
class SchNetModel(SchNet):
    """SchNet encoder operating on precomputed edges and generic node features.

    Subclasses :class:`torch_geometric.nn.models.SchNet` but replaces the
    input embedding and the final linear layer with lazily-shaped linear
    layers, and runs message passing over ``batch.edge_index`` instead of
    building an interaction graph inside the forward pass.
    """

    def __init__(
        self,
        hidden_channels: int = 128,
        out_dim: int = 1,
        num_filters: int = 128,
        num_layers: int = 6,
        num_gaussians: int = 50,
        cutoff: float = 10,
        max_num_neighbors: int = 32,
        readout: str = "add",
        dipole: bool = False,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        atomref: Optional[torch.Tensor] = None,
    ):
        """
        Initializes an instance of the SchNetModel class with the provided
        parameters.

        :param hidden_channels: Number of channels in the hidden layers
            (default: ``128``)
        :type hidden_channels: int
        :param out_dim: Output dimension of the model (default: ``1``)
        :type out_dim: int
        :param num_filters: Number of filters used in convolutional layers
            (default: ``128``)
        :type num_filters: int
        :param num_layers: Number of convolutional layers in the model;
            passed to the parent class as ``num_interactions``
            (default: ``6``)
        :type num_layers: int
        :param num_gaussians: Number of Gaussian functions used for radial
            filters (default: ``50``)
        :type num_gaussians: int
        :param cutoff: Cutoff distance for interactions (default: ``10``)
        :type cutoff: float
        :param max_num_neighbors: Maximum number of neighboring atoms to
            consider. Forwarded to the parent constructor; ``forward`` in
            this class reads edges from ``batch.edge_index``
            (default: ``32``)
        :type max_num_neighbors: int
        :param readout: Global pooling method to be used (default: ``"add"``)
        :type readout: str
        :param dipole: Forwarded unchanged to the parent constructor; not
            referenced by ``forward`` in this class (default: ``False``)
        :type dipole: bool
        :param mean: Forwarded unchanged to the parent constructor; not
            referenced by ``forward`` in this class (default: ``None``)
        :type mean: Optional[float]
        :param std: Forwarded unchanged to the parent constructor; not
            referenced by ``forward`` in this class (default: ``None``)
        :type std: Optional[float]
        :param atomref: Forwarded unchanged to the parent constructor; not
            referenced by ``forward`` in this class (default: ``None``)
        :type atomref: Optional[torch.Tensor]
        """
        # Pass everything by keyword: some torch_geometric releases place an
        # ``interaction_graph`` parameter between ``cutoff`` and
        # ``max_num_neighbors``, so positional passing would shift every
        # later argument by one slot. ``interaction_graph`` is deliberately
        # left at the parent's default because ``forward`` uses the edges
        # already stored on the batch.
        super().__init__(
            hidden_channels=hidden_channels,
            num_filters=num_filters,
            num_interactions=num_layers,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
            max_num_neighbors=max_num_neighbors,
            readout=readout,
            dipole=dipole,
            mean=mean,
            std=std,
            atomref=atomref,
        )
        # Keep the pooling mode as a plain string; ``forward`` hands it to
        # ``torch_scatter.scatter`` as the ``reduce`` argument.
        # NOTE(review): if the installed parent class registers ``readout``
        # as a sub-module, assigning a string here may raise - confirm
        # against the pinned torch_geometric version.
        self.readout = readout
        # Overwrite the input embedding: the input feature size is inferred
        # from the first batch seen.
        self.embedding = torch.nn.LazyLinear(hidden_channels)
        # Overwrite the final predictor so it projects to ``out_dim``.
        self.lin2 = torch.nn.LazyLinear(out_dim)

    @property
    def required_batch_attributes(self) -> Set[str]:
        """
        Required batch attributes for this encoder.

        - ``x``: Node features (shape: :math:`(n, d)`)
        - ``pos``: Node positions (shape: :math:`(n, 3)`)
        - ``edge_index``: Edge indices (shape: :math:`(2, e)`)
        - ``batch``: Batch indices (shape: :math:`(n,)`)

        :return: Set of required batch attributes
        :rtype: Set[str]
        """
        return {"pos", "edge_index", "x", "batch"}

    def forward(self, batch: Union[Batch, ProteinBatch]) -> EncoderOutput:
        """Implements the forward pass of the SchNet encoder.

        Returns the node embedding and graph embedding in a dictionary.

        :param batch: Batch of data to encode.
        :type batch: Union[Batch, ProteinBatch]
        :return: Dictionary of node and graph embeddings. Contains
            ``node_embedding`` and ``graph_embedding`` fields. The node
            embedding is of shape :math:`(|V|, d)` and the graph embedding is
            of shape :math:`(n, d)`, where :math:`|V|` is the number of nodes
            and :math:`n` is the number of graphs in the batch and :math:`d` is
            the dimension of the embeddings.
        :rtype: EncoderOutput
        """
        h = self.embedding(batch.x)

        # Edge length: Euclidean distance between the two endpoints of each
        # edge, then expanded by the parent's ``distance_expansion`` module.
        u, v = batch.edge_index
        edge_weight = (batch.pos[u] - batch.pos[v]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        # Residual update through every interaction block.
        for interaction in self.interactions:
            h = h + interaction(h, batch.edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        # Graph embedding: pool node embeddings per graph using the index in
        # ``batch.batch`` and the configured reduction.
        return EncoderOutput(
            {
                "node_embedding": h,
                "graph_embedding": torch_scatter.scatter(
                    h, batch.batch, dim=0, reduce=self.readout
                ),
            }
        )
if __name__ == "__main__":
    # Manual smoke test: build the encoder from its YAML config, encode a
    # small batch of random proteins, and print each intermediate object.
    import hydra
    import omegaconf
    import pyrootutils
    from graphein.protein.tensor.data import get_random_protein

    # Locate the project root and load the encoder configuration.
    project_root = pyrootutils.setup_root(__file__, pythonpath=True)
    config_path = project_root / "configs" / "encoder" / "schnet.yaml"
    config = omegaconf.OmegaConf.load(config_path)
    print(config)

    model = hydra.utils.instantiate(config.schnet)
    print(model)

    # Four random proteins collated into a single batch.
    proteins = []
    for _ in range(4):
        proteins.append(get_random_protein())
    demo_batch = ProteinBatch().from_protein_list(
        proteins, follow_batch=["coords"]
    )

    # Populate the attributes the encoder reads (x, pos, edge_index, batch).
    demo_batch.batch = demo_batch.coords_batch
    # presumably a k-nearest-neighbour graph with k=8 - confirm against
    # graphein's edge-spec strings
    demo_batch.edges("knn_8", cache="edge_index")
    # Index 1 along the second axis of ``coords`` is used as the node
    # position (presumably the alpha-carbon - TODO confirm atom ordering).
    demo_batch.pos = demo_batch.coords[:, 1, :]
    demo_batch.x = demo_batch.residue_type
    print(demo_batch)

    encoded = model.forward(demo_batch)
    print(encoded)