Skip to content

Commit

Permalink
comments atomistic-machine-learning#1:SO3 features operations
Browse files Browse the repository at this point in the history
  • Loading branch information
1Bigsunflower committed Aug 15, 2023
1 parent 403dfe8 commit 29d5587
Showing 1 changed file with 36 additions and 33 deletions.
69 changes: 36 additions & 33 deletions src/schnetpack/nn/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
]


class RealSphericalHarmonics(nn.Module):
class RealSphericalHarmonics(nn.Module): # 用于生成一批向量的实球谐函数
"""
Generates the real spherical harmonics for a batch of vectors.
Expand All @@ -40,31 +40,32 @@ def __init__(self, lmax: int, dtype_str: str = "float32"):
dtype_str: dtype for spherical harmonics coefficients
"""
super().__init__()
self.lmax = lmax
self.lmax = lmax # 最大角动量

(
powers,
zpow,
cAm,
cBm,
cPi,
) = self._generate_Ylm_coefficients(lmax)
) = self._generate_Ylm_coefficients(lmax) # 计算球谐函数的系数

# 注册缓冲区(buffer),在模型训练过程中被自动优化器所优化
dtype = as_dtype(dtype_str)
self.register_buffer("powers", powers.to(dtype=dtype), False)
self.register_buffer("zpow", zpow.to(dtype=dtype), False)
self.register_buffer("cAm", cAm.to(dtype=dtype), False)
self.register_buffer("cBm", cBm.to(dtype=dtype), False)
self.register_buffer("cPi", cPi.to(dtype=dtype), False)

ls = torch.arange(0, lmax + 1)
nls = 2 * ls + 1
self.lidx = torch.repeat_interleave(ls, nls)
self.midx = torch.cat([torch.arange(-l, l + 1) for l in ls])
ls = torch.arange(0, lmax + 1) # 生成一个从0到lmax的张量ls
nls = 2 * ls + 1 # 新的张量nls,其中每个元素是对应ls中的元素乘以2再加1
self.lidx = torch.repeat_interleave(ls, nls) # ls对应的m值列表self.lidx
self.midx = torch.cat([torch.arange(-l, l + 1) for l in ls]) # 角动量指标列表midx

self.register_buffer("flidx", self.lidx.to(dtype=dtype), False)
self.register_buffer("flidx", self.lidx.to(dtype=dtype), False) # 将ls作为缓冲区flidx进行注册

def _generate_Ylm_coefficients(self, lmax: int):
def _generate_Ylm_coefficients(self, lmax: int): # 计算球谐函数的系数
# see: https://en.wikipedia.org/wiki/Spherical_harmonics#Real_forms

# calculate Am/Bm coefficients
Expand Down Expand Up @@ -98,7 +99,7 @@ def _generate_Ylm_coefficients(self, lmax: int):

return powers, zpow, cAm, cBm, cPi

def forward(self, directions: torch.Tensor) -> torch.Tensor:
def forward(self, directions: torch.Tensor) -> torch.Tensor: # 计算给定输入方向上的真实球谐函数值
"""
Args:
directions: batch of unit-length 3D vectors (Nx3)
Expand Down Expand Up @@ -149,7 +150,7 @@ def forward(self, directions: torch.Tensor) -> torch.Tensor:
return sphharm


def scalar2rsh(x: torch.Tensor, lmax: int) -> torch.Tensor:
def scalar2rsh(x: torch.Tensor, lmax: int) -> torch.Tensor: # 用于将形状为 [N, *] 的标量张量扩展成具有最大角动量 lmax 的球谐函数形状的张量
"""
Expand scalar tensor to spherical harmonics shape with angular momentum up to `lmax`
Expand All @@ -160,6 +161,7 @@ def scalar2rsh(x: torch.Tensor, lmax: int) -> torch.Tensor:
Returns:
zero-padded tensor to shape [N, (lmax+1)^2, *]
"""
# 创建零张量,这个零张量表示了在球谐函数形状中需要填充的部分。
y = torch.cat(
[
x,
Expand All @@ -169,12 +171,12 @@ def scalar2rsh(x: torch.Tensor, lmax: int) -> torch.Tensor:
dtype=x.dtype,
),
],
dim=1,
dim=1, # 在维度1上填充了零的张量 y
)
return y
return y # 经过零填充后的球谐函数形状的张量


class SO3TensorProduct(nn.Module):
class SO3TensorProduct(nn.Module): # 计算 SO3 等变的 Clebsch-Gordan 张量积
"""
SO3-equivariant Clebsch-Gordon tensor product.
Expand Down Expand Up @@ -218,7 +220,7 @@ def forward(
return y


class SO3Convolution(nn.Module):
class SO3Convolution(nn.Module): # 对 SO3 特征进行等变卷积的模型,通过在原子间进行旋转不变的卷积运算,能够保留原子特征的等变性质
"""
SO3-equivariant convolution using Clebsch-Gordon tensor product.
Expand All @@ -236,21 +238,22 @@ def __init__(self, lmax: int, n_atom_basis: int, n_radial: int):
self.n_atom_basis = n_atom_basis
self.n_radial = n_radial

cg = generate_clebsch_gordan_rsh(lmax).to(torch.float32)
cg, idx_in_1, idx_in_2, idx_out = sparsify_clebsch_gordon(cg)
cg = generate_clebsch_gordan_rsh(lmax).to(torch.float32) # 生成 Clebsch-Gordan 系数 cg
cg, idx_in_1, idx_in_2, idx_out = sparsify_clebsch_gordon(cg) # 进行稀疏化处理
# 将索引列表和稀疏化后的 Clebsch-Gordan 系数注册为模型的缓冲区,其中 persistent=False 表示这些缓冲区不会被序列化保存
self.register_buffer("idx_in_1", idx_in_1, persistent=False)
self.register_buffer("idx_in_2", idx_in_2, persistent=False)
self.register_buffer("idx_out", idx_out, persistent=False)
self.register_buffer("clebsch_gordan", cg, persistent=False)

self.filternet = snn.Dense(
self.filternet = snn.Dense( # 名为 "filternet" 的 Dense 层
n_radial, n_atom_basis * (self.lmax + 1), activation=None
)

lidx, _ = sh_indices(lmax)
self.register_buffer("Widx", lidx[self.idx_in_1])
lidx, _ = sh_indices(lmax) # 生成球谐函数的索引 lidx
self.register_buffer("Widx", lidx[self.idx_in_1]) # 按照索引idx_in_1保存到widx缓冲区

def _compute_radial_filter(
def _compute_radial_filter( # 计算径向(具有旋转不变性)滤波器
self, radial_ij: torch.Tensor, cutoff_ij: torch.Tensor
) -> torch.Tensor:
"""
Expand All @@ -266,9 +269,9 @@ def _compute_radial_filter(
Wij = self.filternet(radial_ij) * cutoff_ij
Wij = torch.reshape(Wij, (-1, self.lmax + 1, self.n_atom_basis))
Wij = Wij[:, self.Widx]
return Wij
return Wij # 径向滤波器张量 Wij

def forward(
def forward( # 实现SO3特征卷积操作,在卷积过程中,首先根据原子 j 的索引选择相应的 SO3 特征子集,然后根据径向滤波器、方向向量和 Clebsch-Gordan 系数对子集进行多项式的SO3卷积运算,并将结果累加到对应的原子 i 的特征上,最终得到输出的 SO3 特征
self,
x: torch.Tensor,
radial_ij: torch.Tensor,
Expand Down Expand Up @@ -305,7 +308,7 @@ def forward(
return y


class SO3ParametricGatedNonlinearity(nn.Module):
class SO3ParametricGatedNonlinearity(nn.Module): # 实现 SO3-equivariant 参数化门控非线性操作,可以提取和操作输入特征中不同角动量的信息,并通过门控机制对特征进行调节
"""
SO3-equivariant parametric gated nonlinearity.
Expand All @@ -319,8 +322,8 @@ class SO3ParametricGatedNonlinearity(nn.Module):

def __init__(self, n_in: int, lmax: int):
super().__init__()
self.lmax = lmax
self.n_in = n_in
self.lmax = lmax # 最大角动量
self.n_in = n_in # 输入特征维度
self.lidx, _ = sh_indices(lmax)
self.scaling = nn.Linear(n_in, n_in * (lmax + 1))

Expand All @@ -329,10 +332,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.scaling(s0).reshape(-1, self.lmax + 1, self.n_in)
h = h[:, self.lidx]
y = x * torch.sigmoid(h)
return y
return y # 输出特征张量


class SO3GatedNonlinearity(nn.Module):
class SO3GatedNonlinearity(nn.Module): # 用于实现 SO3-equivariant 的门控非线性操作,SO3-equivariant 操作是指在三维空间中进行旋转操作时能够保持特征表示不变的操作
"""
SO3-equivariant gated nonlinearity.
Expand All @@ -346,10 +349,10 @@ class SO3GatedNonlinearity(nn.Module):

def __init__(self, lmax: int):
super().__init__()
self.lmax = lmax
self.lidx, _ = sh_indices(lmax)
self.lmax = lmax # 最大的角动量量子数
self.lidx, _ = sh_indices(lmax) # 初始化为与角动量 l 相关的索引数组

def forward(self, x: torch.Tensor) -> torch.Tensor:
s0 = x[:, 0, :]
y = x * torch.sigmoid(s0[:, None, :])
def forward(self, x: torch.Tensor) -> torch.Tensor: # 实现了门控非线性操作,通过门控机制控制输入特征的流动
s0 = x[:, 0, :] # 表示 l=0、m=0 的角动量分量的特征
y = x * torch.sigmoid(s0[:, None, :]) # 对 s0 进行激活,将激活后的张量 s0 和输入特征张量 x 逐元素相乘,得到形状相同的张量 y
return y

0 comments on commit 29d5587

Please sign in to comment.