In [50]:
import math
from typing import Dict, Optional

import torch

from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import ExtractIr, FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode
from e3nn.math import soft_unit_step
import torch
from e3nn import nn
from torch_scatter import scatter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch_cluster import radius_graph


In [57]:
irreps_input = o3.Irreps("15x0e")
irreps_query = o3.Irreps("15x0e")
irreps_key = o3.Irreps("15x0e")
irreps_output = o3.Irreps("128x0e")

In [87]:
import torch
from torch.utils.data import Dataset, DataLoader

class RandomGraphDataset(Dataset):
    def __init__(self, num_graphs, num_nodes_per_graph=20):
        self.num_graphs = num_graphs
        self.num_nodes_per_graph = num_nodes_per_graph

    def __len__(self):
        return self.num_graphs

    def __getitem__(self, idx):
        num_nodes = self.num_nodes_per_graph
        pos = torch.randn(num_nodes, 3)
        x = irreps_input.randn(num_nodes, -1)  # Replace with your actual irreps_input initialization
        
        # You can return any additional data as needed
        return {'pos': pos, 'x': x}

# Parameters
num_graphs = 100
num_nodes_per_graph = 20
batch_size = 10  # Adjust batch size as needed

# Create dataset and dataloader
dataset = RandomGraphDataset(num_graphs, num_nodes_per_graph)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Example of iterating through the dataloader
for batch in dataloader:
    # Access batch data
    batch_pos = batch['pos']  # Tensor of shape (batch_size, num_nodes_per_graph, 3)
    batch_x = batch['x']      # Tensor of shape (batch_size, num_nodes_per_graph, -1)
    
    # Perform operations with batch data as needed
    print("Batch position tensor shape:", batch_pos.shape)
    print("Batch feature tensor shape:", batch_x.shape)


Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 20, 3])
Batch feature tensor shape: torch.Size([10, 20, 15])
Batch position tensor shape: torch.Size([10, 2

In [58]:
num_nodes = 20

pos = torch.randn(num_nodes, 3)
f = irreps_input.randn(num_nodes, -1)

# create graph
max_radius = 1.3
edge_src, edge_dst = radius_graph(pos, max_radius)
edge_vec = pos[edge_src] - pos[edge_dst]
edge_length = edge_vec.norm(dim=1)

In [59]:
pos.shape

torch.Size([20, 3])

In [60]:
h_q = o3.Linear(irreps_input, irreps_query)



In [61]:
number_of_basis = 10
edge_length_embedded = soft_one_hot_linspace(
    edge_length,
    start=0.0,
    end=max_radius,
    number=number_of_basis,
    basis='smooth_finite',
    cutoff=True  # goes (smoothly) to zero at `start` and `end`
)
edge_length_embedded = edge_length_embedded.mul(number_of_basis**0.5)

In [62]:
edge_weight_cutoff = soft_unit_step(10 * (1 - edge_length / max_radius))

In [63]:
irreps_sh = o3.Irreps.spherical_harmonics(3)
edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, True, normalization='component')

In [64]:
tp_k = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_key, shared_weights=False)
fc_k = nn.FullyConnectedNet([number_of_basis, 16, tp_k.weight_numel], act=torch.nn.functional.silu)

tp_v = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)
fc_v = nn.FullyConnectedNet([number_of_basis, 16, tp_v.weight_numel], act=torch.nn.functional.silu)

dot = o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "0e")



In [65]:
# compute the queries (per node), keys (per edge) and values (per edge)
q = h_q(f)
k = tp_k(f[edge_src], edge_sh, fc_k(edge_length_embedded))
v = tp_v(f[edge_src], edge_sh, fc_v(edge_length_embedded))

# compute the softmax (per edge)
exp = edge_weight_cutoff[:, None] * dot(q[edge_dst], k).exp()  # compute the numerator
z = scatter(exp, edge_dst, dim=0, dim_size=len(f))  # compute the denominator (per nodes)
z[z == 0] = 1  # to avoid 0/0 when all the neighbors are exactly at the cutoff
alpha = exp / z[edge_dst]

# compute the outputs (per node)
f_out = scatter(alpha.relu().sqrt() * v, edge_dst, dim=0, dim_size=len(f))

In [79]:
dir(tp_k.weight)

['H',
 'T',
 '__abs__',
 '__add__',
 '__and__',
 '__array__',
 '__array_priority__',
 '__array_wrap__',
 '__bool__',
 '__class__',
 '__complex__',
 '__contains__',
 '__deepcopy__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__div__',
 '__dlpack__',
 '__dlpack_device__',
 '__doc__',
 '__eq__',
 '__float__',
 '__floordiv__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__iadd__',
 '__iand__',
 '__idiv__',
 '__ifloordiv__',
 '__ilshift__',
 '__imod__',
 '__imul__',
 '__index__',
 '__init__',
 '__init_subclass__',
 '__int__',
 '__invert__',
 '__ior__',
 '__ipow__',
 '__irshift__',
 '__isub__',
 '__iter__',
 '__itruediv__',
 '__ixor__',
 '__le__',
 '__len__',
 '__long__',
 '__lshift__',
 '__lt__',
 '__matmul__',
 '__mod__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__nonzero__',
 '__or__',
 '__pos__',
 '__pow__',
 '__radd__',
 '__rand__',
 '__rdiv__',
 '__reduce__',
 '__reduce_ex__',
 '__repr_

In [85]:
type(tp_k.weight)

torch.Tensor

param 不是 nn.Parameter 对象


In [66]:
f_out.shape

torch.Size([20, 128])

In [67]:
f

tensor([[ 1.0108, -0.4120, -1.2691,  0.5434, -0.3108, -0.0259,  1.2164, -0.0178,
          0.7902, -1.5214, -0.4848,  0.4506, -0.0061, -1.0693, -1.6660],
        [ 1.9457, -0.9373, -0.6052, -0.0843,  0.5191,  1.3375, -0.7168, -1.1653,
          0.7648, -1.1443,  0.5113, -0.7004,  0.9156, -0.0039, -1.2587],
        [ 0.7415,  0.8114, -0.4896,  0.7137, -0.5670,  0.6979, -0.3470,  0.9645,
         -2.4666,  0.9865,  1.0848,  0.4396, -0.1511,  0.4272,  0.8656],
        [ 1.0039, -0.8731,  0.1023,  0.8909,  2.0612, -0.3998,  0.5176,  0.5354,
         -1.7104,  1.0056, -0.6347,  0.0039,  1.2202,  0.3048,  1.2008],
        [ 1.7200, -1.4296,  0.7554,  1.1488,  0.2106,  2.5101,  0.1621,  0.0049,
          1.0198,  0.5128,  0.0367,  1.4464,  0.5385,  0.6567,  1.2070],
        [-0.9541,  1.0299,  0.6233,  1.7747, -0.5644,  0.3849, -1.1908,  0.7297,
          0.3244,  0.7336,  1.0145, -0.7193, -0.7312, -1.5739,  0.2230],
        [-1.8391,  0.6928,  0.0222,  1.5460,  1.8051, -1.2119, -0.9513,  0.5

In [68]:
f_out

tensor([[ 0.8893,  0.4559, -0.2268,  ..., -0.7073,  1.0102, -0.7480],
        [ 1.0796,  0.9648, -0.0740,  ...,  0.0138, -0.4639,  0.8033],
        [ 0.1390,  1.8978,  1.2474,  ...,  0.4028, -2.0364, -0.6878],
        ...,
        [-0.1836, -1.1423, -0.2519,  ...,  1.7280, -0.4552,  0.2658],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2556, -0.8711, -0.5214,  ...,  0.1657,  0.2269,  0.2245]],
       grad_fn=<ScatterAddBackward0>)

In [32]:
from torch_geometric.data import Data

In [33]:
dataset=Data(pos=pos,x=f)

In [35]:
class ETG(nn.Module):
    def __init__(self,
                 irreps_in: o3.Irreps,
                 irreps_query: o3.Irreps,
                 irreps_key: o3.Irreps,
                 irreps_output: o3.Irreps,
                 ):
        super(ETG, self).__init__()
        self.irreps_in = irreps_in
        self.irreps_query = irreps_query
        self.irreps_key = irreps_key
        self.irreps_output = irreps_output
        
    

In [36]:
def transformer(
        irreps_input: o3.Irreps,
        irreps_query: o3.Irreps,
        irreps_key: o3.Irreps,
        irreps_output: o3.Irreps,
        dataset,
        
):
    edge_src, edge_dst = radius_graph(dataset.pos, max_radius)
    edge_vec = dataset.pos[edge_src] - dataset.pos[edge_dst]
    edge_length = edge_vec.norm(dim=1)
    number_of_basis = 10
    edge_length_embedded = soft_one_hot_linspace(
        edge_length,
        start=0.0,
        end=max_radius,
        number=number_of_basis,
        basis='smooth_finite',
        cutoff=True
    )
    edge_length_embedded = edge_length_embedded.mul(number_of_basis**0.5)
    edge_weight_cutoff = soft_unit_step(10 * (1 - edge_length / max_radius))
    irreps_sh = o3.Irreps.spherical_harmonics(3)
    dot = o3.FullyConnectedTensorProduct(irreps_query, irreps_key, "0e")
    edge_sh = o3.spherical_harmonics(irreps_sh, edge_vec, True, normalization='component')

    h_q = o3.Linear(irreps_input, irreps_query)
    tp_k = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_key, shared_weights=False)
    fc_k = nn.FullyConnectedNet([number_of_basis, 16, tp_k.weight_numel], act=torch.nn.functional.silu)

    tp_v = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)
    fc_v = nn.FullyConnectedNet([number_of_basis, 16, tp_v.weight_numel], act=torch.nn.functional.silu)
    q = h_q(f)
    k = tp_k(f[edge_src], edge_sh, fc_k(edge_length_embedded))
    v = tp_v(f[edge_src], edge_sh, fc_v(edge_length_embedded))

    exp = edge_weight_cutoff[:, None] * dot(q[edge_dst], k).exp()
    z = scatter(exp, edge_dst, dim=0, dim_size=len(f))
    z[z == 0] = 1
    alpha = exp / z[edge_dst]

    return scatter(alpha.relu().sqrt() * v, edge_dst, dim=0, dim_size=len(f))


In [69]:
transformer(
    irreps_input,
    irreps_query,
    irreps_key,
    irreps_output,
    dataset
)



tensor([[-0.9321,  0.7098,  1.2883,  ...,  0.9811,  0.1229,  0.6901],
        [ 1.6012, -0.7377,  0.2171,  ..., -0.5989, -1.6058,  2.6011],
        [-0.7274, -0.2307, -1.1741,  ..., -0.0054,  0.5369, -1.0118],
        ...,
        [ 0.9583,  0.1718,  0.3004,  ...,  0.5392, -0.6205, -0.7205],
        [-0.2768, -1.1128,  0.5728,  ...,  1.7265,  0.0795,  0.7104],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       grad_fn=<ScatterAddBackward0>)

In [100]:
import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader
import torch_geometric.transforms as T
import torch_geometric.utils as utils
from e3nn import o3


def random_graph_dataset(num_graphs, num_nodes_per_graph=20):
    irreps_input = o3.Irreps("15x0e")
    dataset=[]
    for i in range(num_graphs):
        pos = torch.randn(num_nodes, 3)
        x = irreps_input.randn(num_nodes, -1)
        data = Data(pos=pos, x=x)
        dataset.append(data)
    return dataset



In [105]:
random_graph_dataset(100)

[Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=[20, 15], pos=[20, 3]),
 Data(x=

torch_geometric.data.data.Data

In [106]:
import torch
from torch_geometric.data import Data, Dataset, DataLoader
from torch.utils.data import SubsetRandomSampler
from e3nn import o3

def random_graph_dataset(num_graphs, num_nodes_per_graph=20):
    irreps_input = o3.Irreps("15x0e")
    dataset = []
    for i in range(num_graphs):
        pos = torch.randn(num_nodes_per_graph, 3)
        x = irreps_input.randn(num_nodes_per_graph, -1)
        data = Data(pos=pos, x=x)
        dataset.append(data)
    return dataset

# Parameters
num_graphs = 100
num_nodes_per_graph = 20
batch_size = 10
test_split = 0.2  # 20% of the dataset will be used for testing

# Create dataset
dataset = random_graph_dataset(num_graphs, num_nodes_per_graph)

# Split dataset into train and test sets using SubsetRandomSampler
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(test_split * dataset_size)
train_indices, test_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

# Create DataLoader for train and test sets
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)

# Example of accessing data from DataLoader
for batch in train_loader:
    # Batch is a list of Data objects
    for data in batch:
        print(data)
    break  # Print only the first batch for demonstration

# Similarly, you can iterate over `test_loader` for testing data


('x', tensor([[-0.0970, -0.3019, -0.3401,  ...,  1.5613,  0.2192, -0.6163],
        [ 0.9388,  0.1690,  1.6719,  ...,  0.0951,  1.0620,  1.2787],
        [ 1.5363, -0.9804,  1.1722,  ...,  1.0141,  0.5436, -0.8653],
        ...,
        [-1.0270,  1.0216, -1.9001,  ...,  0.0991, -0.0368,  0.7191],
        [-0.5482,  0.6090, -0.8456,  ...,  1.0077,  0.7487, -0.5869],
        [-0.1358, -1.2702,  1.1213,  ...,  1.4287,  1.8845, -0.6317]]))
('pos', tensor([[-1.6496,  0.6762,  1.4739],
        [ 0.3195,  2.4148, -1.0086],
        [ 0.0045, -2.8196,  1.5646],
        [-1.5862, -0.8392,  0.9361],
        [-2.7861,  1.9581, -0.5676],
        [ 1.0528, -1.0708, -0.4172],
        [ 0.0249, -1.0879, -1.6348],
        [ 1.6456,  0.5276,  0.6098],
        [-0.4563,  0.4636, -0.1804],
        [-0.2991,  0.2490,  0.5873],
        [ 1.9016,  0.3927, -0.7262],
        [-0.5603,  1.2478, -0.3520],
        [-0.8684,  0.6173, -1.4482],
        [ 0.7661, -0.3866,  0.5676],
        [ 1.1414,  0.3377, -2.288

