In [None]:
class HGNN(nn.Module):


In [4]:
from torch_geometric.nn.conv import MessagePassing
from typing import Union, Tuple
import torch

class CGConv(MessagePassing):
    def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0,
                 aggr: str = 'add', normalization: str = None,
                 bias: bool = True, if_exp: bool = False, **kwargs):
        super(CGConv, self).__init__(aggr=aggr, flow="source_to_target", **kwargs)
        self.channels = channels
        self.dim = dim
        self.normalization = normalization
        self.if_exp = if_exp

        if isinstance(channels, int):
            channels = (channels, channels)

        self.lin_f = nn.Linear(sum(channels) + dim, channels[1], bias=bias)
        self.lin_s = nn.Linear(sum(channels) + dim, channels[1], bias=bias)
        if self.normalization == 'BatchNorm':
            self.bn = nn.BatchNorm1d(channels[1], track_running_stats=True)
        elif self.normalization == 'LayerNorm':
            self.ln = LayerNorm(channels[1])
        elif self.normalization == 'PairNorm':
            self.pn = PairNorm(channels[1])
        elif self.normalization == 'InstanceNorm':
            self.instance_norm = InstanceNorm(channels[1])
        elif self.normalization == 'GraphNorm':
            self.gn = GraphNorm(channels[1])
        elif self.normalization == 'DiffGroupNorm':
            self.group_norm = DiffGroupNorm(channels[1], 128)
        elif self.normalization is None:
            pass
        else:
            raise ValueError('Unknown normalization function: {}'.format(normalization))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_f.reset_parameters()
        self.lin_s.reset_parameters()
        if self.normalization == 'BatchNorm':
            self.bn.reset_parameters()

    def forward(self, x: Union[torch.Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor, batch, distance, size: Size = None) -> torch.Tensor:
        """"""
        if isinstance(x, torch.Tensor):
            x: PairTensor = (x, x)

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, distance=distance, size=size)
        if self.normalization == 'BatchNorm':
            out = self.bn(out)
        elif self.normalization == 'LayerNorm':
            out = self.ln(out, batch)
        elif self.normalization == 'PairNorm':
            out = self.pn(out, batch)
        elif self.normalization == 'InstanceNorm':
            out = self.instance_norm(out, batch)
        elif self.normalization == 'GraphNorm':
            out = self.gn(out, batch)
        elif self.normalization == 'DiffGroupNorm':
            out = self.group_norm(out)
        out += x[1]
        return out

    def message(self, x_i, x_j, edge_attr: OptTensor, distance) -> torch.Tensor:
        z = torch.cat([x_i, x_j, edge_attr], dim=-1)
        out = self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z))
        if self.if_exp:
            sigma = 3
            n = 2
            out = out * torch.exp(-distance ** n / sigma ** n / 2).view(-1, 1)
        return out

    def __repr__(self):
        return '{}({}, dim={})'.format(self.__class__.__name__, self.channels, self.dim)


NameError: name 'PairTensor' is not defined

In [3]:
class MPLayer(nn.Module):
    def __init__(self, in_atom_fea_len, in_edge_fea_len, out_edge_fea_len, if_exp, if_edge_update, normalization,
                 atom_update_net, gauss_stop, output_layer=False):
        super(MPLayer, self).__init__()
        if atom_update_net == 'CGConv':
            self.cgconv = CGConv(channels=in_atom_fea_len,
                                 dim=in_edge_fea_len,
                                 aggr='add',
                                 normalization=normalization,
                                 if_exp=if_exp)

        self.if_edge_update = if_edge_update
        self.atom_update_net = atom_update_net
        if if_edge_update:
            if output_layer:
                self.e_lin = nn.Sequential(nn.Linear(in_edge_fea_len + in_atom_fea_len * 2, 128),
                                           nn.SiLU(),
                                           nn.Linear(128, out_edge_fea_len),
                                           )
            else:
                self.e_lin = nn.Sequential(nn.Linear(in_edge_fea_len + in_atom_fea_len * 2, 128),
                                           nn.SiLU(),
                                           nn.Linear(128, out_edge_fea_len),
                                           nn.SiLU(),
                                           )

    def forward(self, atom_fea, edge_idx, edge_fea, batch, distance, edge_vec):
        if self.atom_update_net == 'PAINN':
            atom_fea = self.cgconv(atom_fea, edge_idx, edge_fea, batch, edge_vec)
            atom_fea_s = atom_fea.node_fea_s
        else:
            atom_fea = self.cgconv(atom_fea, edge_idx, edge_fea, batch, distance)
            atom_fea_s = atom_fea
        if self.if_edge_update:
            row, col = edge_idx
            edge_fea = self.e_lin(torch.cat([atom_fea_s[row], atom_fea_s[col], edge_fea], dim=-1))
            return atom_fea, edge_fea
        else:
            return atom_fea

NameError: name 'nn' is not defined