In [1]:
# e3nn
#  |
#  +-- o3        all about rotations and parity
#  +-- nn        modules to make neural networks
#  +-- io        utility classes for Spherical signals and Cartesian tensors

# more info at https://docs.e3nn.org/en/stable/api/e3nn.html

# The main class in e3nn.o3 is TensorProduct
# TensorProduct is a pytorch Module

# o3.TensorProduct                              general class to implement tensor products (optionally with parameters)
# |
# +-- o3.FullTensorProduct                      "usual" tensor product (no parameters): give the two inputs and it deduces the outputs
# +-- o3.FullyConnectedTensorProduct            define the 2 inputs and the output, and it connects everything together
# +-- o3.ElementwiseTensorProduct

# more info at https://docs.e3nn.org/en/stable/api/o3/o3_tp.html

In [2]:
from functools import lru_cache
from typing import Dict, List

import e3nn.o3 as o3
import torch
import torch.nn.functional as F
from torch import Tensor

#Get CG coefficients
@lru_cache(maxsize=None)
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
    """
    Return the (cached) Q^{d_out,d_in}_J matrix from equation (8).

    The Wigner 3j symbol for (J, d_in, d_out) is computed in float64 on ``device``
    and its axes are reversed, so the result is indexed as (d_out, d_in, J).
    """
    wigner = o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device)
    return wigner.permute(2, 1, 0)

# Use the GPU when one is present; fall back to the CPU so this cell also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
J, d_in, d_out = 2, 1, 1
clebsch_gordon_tensor = get_clebsch_gordon(J, d_in, d_out, device)
print(clebsch_gordon_tensor.shape)  # axes are (d_out, d_in, J)



torch.Size([3, 3, 5])


In [3]:
@lru_cache(maxsize=None)
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
    """
    Collect the (cached) Clebsch-Gordan tensors for every (d_in, d_out) pair.

    The outer list has one entry per pair, with d_in as the slow index and d_out as the
    fast index (both running over 0..max_degree). Each inner list holds the tensors for
    J = |d_in - d_out| .. d_in + d_out, in increasing J.
    """
    degrees = range(max_degree + 1)
    return [
        [get_clebsch_gordon(J, d_in, d_out, device)
         for J in range(abs(d_in - d_out), d_in + d_out + 1)]
        for d_in in degrees
        for d_out in degrees
    ]

cg_all = get_all_clebsch_gordon(2, device)

# cg_all is a flat list over (d_in, d_out) pairs: pair index = d_in * (max_degree + 1) + d_out.
# cg_all[pair][i] is the tensor for J = |d_in - d_out| + i.
print(len(cg_all))  # number of (d_in, d_out) pairs: (max_degree + 1) ** 2
print(cg_all[0][0].shape)  # d_in=0, d_out=0, J=0
print(cg_all[1][0].shape)  # d_in=0, d_out=1, J=1
#print(cg_all[0][2].shape)  # d_in=0, d_out=1, J=2
print(cg_all[2][0].shape)  # d_in=0, d_out=2, J=2
print(cg_all[3][0].shape)  # d_in=1, d_out=0, J=1
print(cg_all[5][0].shape)  # d_in=1, d_out=2, J=1
print(cg_all[5][1].shape)  # d_in=1, d_out=2, J=2
print(cg_all[5][2].shape)  # d_in=1, d_out=2, J=3


9
torch.Size([1, 1, 1])
torch.Size([3, 1, 3])
torch.Size([5, 1, 5])
torch.Size([1, 3, 3])
torch.Size([5, 3, 3])
torch.Size([5, 3, 5])
torch.Size([5, 3, 7])


In [4]:
from torch.cuda.nvtx import range as nvtx_range

def degree_to_dim(d):
    """
    Number of spherical harmonics coefficients for a given degree: 2 * d + 1.

    :param d: Degree, as a Python int or a Tensor of degrees
    :return:  2 * d + 1, of the same kind as ``d`` (an int stays an int)
    """
    # Ints are deliberately not wrapped in a tensor: the result is used as a split size
    # for torch.split, which expects plain ints, and this keeps the function consistent
    # with the int-typed degree_to_dim used by the scripted fusing code.
    return 2 * d + 1

def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
    """
    Evaluate the spherical harmonics of degrees 0 .. 2 * max_degree at ``relative_pos``.

    The harmonics of all degrees are computed in one call (with ``normalize=True``) and
    then split along dim 1 into one chunk per degree, of width ``degree_to_dim(degree)``.
    """
    degrees = list(range(2 * max_degree + 1))
    widths = [degree_to_dim(degree) for degree in degrees]
    with nvtx_range('spherical harmonics'):
        stacked = o3.spherical_harmonics(degrees, relative_pos, normalize=True)
        return torch.split(stacked, widths, dim=1)

In [5]:
# Define input: relative position tensor for one point in 3D
relative_pos = torch.tensor([[0.0, 1.0, 2.0]])  # Shape: (1, 3)

# Compute spherical harmonics up to max_degree = 2
max_degree = 2

# Get spherical harmonics: one tensor per degree 0 .. 2 * max_degree
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)

# Print the shape and the values of the spherical harmonics of each degree
for i, sh in enumerate(spherical_harmonics):
    print(f"Spherical harmonics of degree {i}: {sh.shape}")
    print(sh)

# The full tuple holds every degree, so label it as such (it used to be labelled "degree 0")
print("Spherical harmonics of all degrees: ", spherical_harmonics)

Spherical harmonics of degree 0: torch.Size([1, 1])
tensor([[0.2821]])
Spherical harmonics of degree 1: torch.Size([1, 3])
tensor([[0.0000, 0.2185, 0.4370]])
Spherical harmonics of degree 2: torch.Size([1, 5])
tensor([[ 0.0000,  0.0000, -0.1262,  0.4370,  0.4370]])
Spherical harmonics of degree 3: torch.Size([1, 7])
tensor([[ 0.0000,  0.0000,  0.0000, -0.3338,  0.0000,  0.5171,  0.4222]])
Spherical harmonics of degree 4: torch.Size([1, 9])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000, -0.1693, -0.4282,  0.1514,  0.5664,
          0.4005]])
Spherical harmonics of degree 0:  (tensor([[0.2821]]), tensor([[0.0000, 0.2185, 0.4370]]), tensor([[ 0.0000,  0.0000, -0.1262,  0.4370,  0.4370]]), tensor([[ 0.0000,  0.0000,  0.0000, -0.3338,  0.0000,  0.5171,  0.4222]]), tensor([[ 0.0000,  0.0000,  0.0000,  0.0000, -0.1693, -0.4282,  0.1514,  0.5664,
          0.4005]]))


In [6]:
@torch.jit.script
def get_basis_script(max_degree: int,
                     use_pad_trick: bool,
                     spherical_harmonics: List[Tensor],
                     clebsch_gordon: List[List[Tensor]],
                     amp: bool) -> Dict[str, Tensor]:
    """
    Compute pairwise bases matrices for degrees up to max_degree
    :param max_degree:            Maximum input or output degree
    :param use_pad_trick:         Pad some of the odd dimensions for a better use of Tensor Cores
    :param spherical_harmonics:   List of computed spherical harmonics, indexed by degree J
    :param clebsch_gordon:        List of computed CG-coefficients: one inner list per (d_in, d_out) pair
                                  (d_in is the slow index), each inner list ordered by increasing J
    :param amp:                   When true, return bases in FP16 precision
    :return:                      Dict mapping the key 'd_in,d_out' to the basis tensor of that pair
    """
    basis = {}
    idx = 0  # Flat index of the current (d_in, d_out) pair inside clebsch_gordon
    # Double for loop instead of product() because of JIT script
    for d_in in range(max_degree + 1):
        for d_out in range(max_degree + 1):
            key = f'{d_in},{d_out}'
            K_Js = []
            for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
                Q_J = clebsch_gordon[idx][freq_idx]
                # Contract the harmonics of degree J (n, f) with Q_J (k, l, f) over f -> (n, l, k)
                K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))

            basis[key] = torch.stack(K_Js, 2)  # Stack on second dim so order is n l f k
            if amp:
                basis[key] = basis[key].half()
            if use_pad_trick:
                basis[key] = F.pad(basis[key], (0, 1))  # Pad the k dimension, that can be sliced later

            idx += 1

    return basis

In [7]:
def get_basis(relative_pos: Tensor,
              max_degree: int = 4,
              compute_gradients: bool = False,
              use_pad_trick: bool = False,
              amp: bool = False) -> Dict[str, Tensor]:
    """
    Build the dict of pairwise bases for the given relative positions.

    :param relative_pos:       Relative positions the spherical harmonics are evaluated at
    :param max_degree:         Maximum input or output degree
    :param compute_gradients:  Whether autograd is enabled while the bases are assembled
    :param use_pad_trick:      Forwarded to get_basis_script (pads the last basis dimension)
    :param amp:                Forwarded to get_basis_script (FP16 bases when true)
    :return:                   Dict mapping 'd_in,d_out' to the corresponding basis tensor
    """
    with nvtx_range('spherical harmonics'):
        sh_per_degree = get_spherical_harmonics(relative_pos, max_degree)

    with nvtx_range('CB coefficients'):
        cg_per_pair = get_all_clebsch_gordon(max_degree, relative_pos.device)

    # Only the final assembly of the bases runs under the requested grad mode.
    with torch.autograd.set_grad_enabled(compute_gradients), nvtx_range('bases'):
        return get_basis_script(
            max_degree=max_degree,
            use_pad_trick=use_pad_trick,
            spherical_harmonics=sh_per_degree,
            clebsch_gordon=cg_per_pair,
            amp=amp,
        )
        
        


In [8]:

basis = get_basis(relative_pos, max_degree=2, compute_gradients=True, use_pad_trick=False, amp=False)

# Every entry has shape [num_edges, 2*d_in+1, num_of_J, 2*d_out+1]
print(basis.keys())
print(basis['0,1'].shape)  # d_in=0, d_out=1, J=1
# explain, [batch, d_in,  num_of_J, d_out]
print(basis['1,1'].shape)  # d_in=1, d_out=1 J=[0, 1, 2]
print(basis['1,2'].shape)  # d_in=1, d_out=2  j=[1, 2, 3]


dict_keys(['0,0', '0,1', '0,2', '1,0', '1,1', '1,2', '2,0', '2,1', '2,2'])
torch.Size([1, 1, 1, 3])
torch.Size([1, 3, 3, 3])
torch.Size([1, 3, 3, 5])


In [9]:
# Next step : how these bases are fused

def degree_to_dim(degree: int) -> int:
    """Return the number of components of a feature of the given degree, i.e. 2 * degree + 1."""
    return degree + degree + 1

@torch.jit.script
def update_basis_with_fused(basis: Dict[str, Tensor],
                            max_degree: int,
                            use_pad_trick: bool,
                            fully_fused: bool) -> Dict[str, Tensor]:
    """
    Update the basis dict with partially and optionally fully fused bases

    :param basis:          Dict of pairwise bases keyed 'd_in,d_out'; must contain '0,0'
    :param max_degree:     Maximum input or output degree
    :param use_pad_trick:  When true, the per-output-degree fused bases get one extra (zero) last column
    :param fully_fused:    When true, also add the 'fully_fused' entry
    :return:               The same dict with 'out{d}_fused', 'in{d}_fused' (and optionally 'fully_fused')
                           added and the '0,0' entry removed; callers should use this return value
    """
    num_edges = basis['0,0'].shape[0]
    device = basis['0,0'].device
    dtype = basis['0,0'].dtype
    # Total width of all degrees 0..max_degree laid side by side
    sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])

    # Fused per output degree
    for d_out in range(max_degree + 1):
        # A (d_in, d_out) pair contributes degree_to_dim(min(d_in, d_out)) frequencies
        sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
                                  device=device, dtype=dtype)
        # acc_d / acc_f: running offsets along the input-dim and frequency axes
        acc_d, acc_f = 0, 0
        for d_in in range(max_degree + 1):
            basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
            :degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_in)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'out{d_out}_fused'] = basis_fused

    # Fused per input degree
    for d_in in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
                                  device=device, dtype=dtype)
        # acc_d / acc_f: running offsets along the output-dim and frequency axes
        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            # The slice on the last axis drops any padding column of the pairwise basis
            basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
                = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_out)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'in{d_in}_fused'] = basis_fused

    if fully_fused:
        # Fully fused
        # Double sum this way because of JIT script
        sum_freq = sum([
            sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
        ])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)

        # Place the per-output-degree fused bases side by side along the frequency and output axes
        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            b = basis[f'out{d_out}_fused']
            basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
                                                                                              :degree_to_dim(d_out)]
            acc_f += b.shape[2]
            acc_d += degree_to_dim(d_out)

        basis['fully_fused'] = basis_fused

    del basis['0,0']  # We know that the basis for l = k = 0 is filled with a constant
    return basis

In [10]:
updated_basis = update_basis_with_fused(basis, max_degree=2, use_pad_trick=False, fully_fused=True)
print(updated_basis.keys())
# out{d}_fused: [batch, sum(d_to_dim(d_in)), num_of_J, d_to_dim(d_out)]
print(updated_basis['out0_fused'].shape)  # d_in=0,1,2 d_out=0, J=0
#[batch, sum(d_to_dim(d_in)), num_of_J, d_out]
print(updated_basis['out1_fused'].shape)  # d_in=0,1,2 d_out=1, J=[1],[0, 1, 2],[1,2, 3]
# in{d}_fused: [batch, d_to_dim(d_in), num_of_J, sum(d_to_dim(d_out))]
print(updated_basis['in1_fused'].shape)  # d_in=1 d_out=0,1,2 J=[1],[0, 1, 2],[1,2, 3]
# fully_fused: [batch, sum(d_to_dim(d_in)), num_of_J, sum(d_to_dim(d_out))]
print(updated_basis['fully_fused'].shape)  # d_in=0,1,2 d_out=0,1,2 J=[0],[1],[2],[1],[0,1,2],[1,2,3],[2],[1,2,3],[0,1,2,3,4]
# num_of_J for fully_fused is the total over all (d_in, d_out) pairs: 1+1+1+1+3+3+1+3+5 = 19

dict_keys(['0,1', '0,2', '1,0', '1,1', '1,2', '2,0', '2,1', '2,2', 'out0_fused', 'out1_fused', 'out2_fused', 'in0_fused', 'in1_fused', 'in2_fused', 'fully_fused'])
torch.Size([1, 9, 3, 1])
torch.Size([1, 9, 7, 3])
torch.Size([1, 3, 7, 9])
torch.Size([1, 9, 19, 9])


# Using the basis functions to build message passing  
Source node i carries a feature of degree l (dimension 2l+1); neighbour j carries a feature of degree k (dimension 2k+1).  
We need a matrix W of shape [2l+1, 2k+1] to bridge degree k to degree l.  
The Clebsch-Gordan coefficients combined with the spherical harmonics Y serve this purpose.  
Clebsch-Gordan(l,k,J): [(2l+1)(2k+1), 2J+1] --> can be viewed as [2l+1, 2k+1, 2J+1]  
Y for a given J with fixed l, k: [2J+1]  
J ranges over |l-k|...l+k; for each J, Clebsch-Gordan(l,k,J)*Y(J) -> [2l+1, 2k+1]  (the basis shown above)  
We further feed invariant edge information, such as ||xi-xj||, into an MLP to learn a scalar weight for each J  
Final message: learned scalar * Clebsch-Gordan(l,k,J)*Y(J)  

Considering channel_in, channel_out and the number of J: the learned scalars form a [channel_in, channel_out, num of J] tensor   


In [11]:
# This is how the scalers are learned:
import torch.nn as nn
class RadialProfile(nn.Module):
    r"""
    Radial profile function.
    Outputs weights used to weigh basis matrices in order to get convolution kernels.
    In TFN notation: $R^{l,k}$
    In SE(3)-Transformer notation: $\phi^{l,k}$

    Note:
        In the original papers, this function only depends on relative node distances ||x||.
        Here, we allow this function to also take as input additional invariant edge features.
        This does not break equivariance and adds expressive power to the model.

    Diagram:
        invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
    """
    # The docstring is a raw string: it contains the LaTeX command \phi, and "\p" is an
    # invalid escape sequence in a normal string literal.

    def __init__(
            self,
            num_freq: int,
            channels_in: int,
            channels_out: int,
            edge_dim: int = 1,
            mid_dim: int = 32,
            use_layer_norm: bool = False
    ):
        """
        :param num_freq:         Number of frequencies (number of J)
        :param channels_in:      Number of input channels
        :param channels_out:     Number of output channels
        :param edge_dim:         Number of invariant edge features (input to the radial function)
        :param mid_dim:          Size of the hidden MLP layers
        :param use_layer_norm:   Apply layer normalization between MLP layers
        """
        super().__init__()
        modules = [
            nn.Linear(edge_dim, mid_dim),
            nn.LayerNorm(mid_dim) if use_layer_norm else None,
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim) if use_layer_norm else None,
            nn.ReLU(),
            # One weight per (frequency, input channel, output channel) triple, without bias
            nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
        ]

        # Drop the placeholders left where layer normalization is disabled
        self.net = nn.Sequential(*[m for m in modules if m is not None])

    def forward(self, features: Tensor) -> Tensor:
        """
        :param features:  Invariant edge features, last dimension of size edge_dim
        :return:          Radial weights, last dimension of size num_freq * channels_in * channels_out
        """
        return self.net(features)

In [12]:
# Combine all of the above into the building block of the convolution (message passing)
from enum import Enum
class ConvSE3FuseLevel(Enum):
    # Which kind of fused basis a convolution works with. VersatileConvSE3.forward only
    # distinguishes FULL from the other levels (for non-FULL levels it strips a padded column).
    # NOTE(review): OUT / IN / FULL presumably correspond to the 'out{d}_fused', 'in{d}_fused'
    # and 'fully_fused' bases - confirm against the code that selects the level.
    NONE = 0
    OUT = 1
    IN = 2
    FULL = 3

class VersatileConvSE3(nn.Module):
    r"""
    Spherical convolution block.
    In TFN notation: $C^{l,k}$
    In SE(3)-Transformer notation: $\phi^{l,k}$

    The radial profile turns invariant edge features into per-edge weights; these weights
    combine the input features, projected through the basis, into the output features.

    Diagram:
        invariant edge features ───> radial profile ───> radial weights ──┐
        input features ───> (features @ basis) ───────────────────────────┴──> output features
    """

    def __init__(
            self,
            max_degree: int,
            channels_in: int,
            channels_out: int,
            edge_dim: int = 1,
            mid_dim: int = 32,
            use_pad_trick: bool = False,
            fully_fused: bool = True,
            amp: bool = False
    ):
        """
        :param max_degree:       Maximum input or output degree
        :param channels_in:      Number of input channels
        :param channels_out:     Number of output channels
        :param edge_dim:         Number of invariant edge features (input to the radial function)
        :param mid_dim:          Size of the hidden MLP layers
        :param use_pad_trick:    Pad some of the odd dimensions for a better use of Tensor Cores
        :param fully_fused:      Use fully fused bases
        :param amp:              When true, return bases in FP16 precision (not used by this module itself)
        """
        super().__init__()
        self.max_degree = max_degree
        self.use_pad_trick = use_pad_trick
        self.fully_fused = fully_fused

        # forward() reads the attributes below; they were previously never set, so every
        # call failed with AttributeError.
        self.channels_in = channels_in
        self.channels_out = channels_out
        # Must equal the num_freq given to the radial profile, otherwise the .view() in
        # forward() cannot reshape the radial weights.
        # NOTE(review): the basis passed to forward() must have this many frequencies (axis 2);
        # confirm that max_degree + 1 is the intended count for the bases in use.
        self.freq_sum = max_degree + 1
        self.fuse_level = ConvSE3FuseLevel.FULL if fully_fused else ConvSE3FuseLevel.NONE

        # Radial profile function
        self.radial_profile = RadialProfile(
            num_freq=self.freq_sum,
            channels_in=channels_in,
            channels_out=channels_out,
            edge_dim=edge_dim,
            mid_dim=mid_dim,
            use_layer_norm=True
        )

    def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor ) -> Tensor:
        """
        :param features:             Input features (num_edges, channels_in, in_dim)
        :param invariant_edge_feats: Invariant edge features (num_edges, edge_dim)
        :param basis:                Basis matrices (num_edges, in_dim, num_of_J, out_dim), possibly with a
                                     padded last dimension; None selects the k = l = 0 non-fused case
        :return:
            Output features (num_edges, channels_out, out_dim)
        """
        with nvtx_range('VersatileConvSE3'):
            num_edges = features.shape[0]
            in_dim = features.shape[2]
            with nvtx_range('RadialProfile'):
                # (num_edges, channels_out, channels_in * num_of_J)
                radial_weights = self.radial_profile(invariant_edge_feats) \
                    .view(-1, self.channels_out, self.channels_in * self.freq_sum)

            if basis is not None:
                # This block performs the einsum n i l, n o i f, n l f k -> n o k
                # features:       n i l   : num_edges, in_channels, in_dim (2l+1)            # l is the input degree
                # radial_weights: n o i f : num_edges, out_channels, in_channels, num_of_J
                # basis:          n l f k : num_edges, in_dim, num_of_J, out_dim (2k+1)      # k is the output degree
                # output:         n o k   : num_edges, out_channels, out_dim (2k+1)
                out_dim = basis.shape[-1]
                if self.fuse_level != ConvSE3FuseLevel.FULL:
                    out_dim += out_dim % 2 - 1  # Account for padded basis: an even width loses its pad column
                basis_view = basis.view(num_edges, in_dim, -1)
                # (n, i, f*k) regrouped as (n, i*f, k) so it can be contracted with the radial weights
                tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
                return (radial_weights @ tmp)[:, :, :out_dim]
            else:
                # k = l = 0 non-fused case
                return radial_weights @ features
       

In [13]:
# SE3_Transformer attention calculation
# Omitted due to the complexity of the code; the general idea is to use TFN convolutions to generate K and V,
# and a linear layer to generate Q.


In [14]:
# example usage of SE3_transformer
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber

# Fibers of the example model: the input has 32 channels of degree 0; the hidden and output
# fibers are both created with num_degrees=2 and 32 channels.
Fiber_in = Fiber({0: 32})
Fiber_hid = Fiber.create(num_degrees=2, num_channels=32)
Fiber_out = Fiber.create(num_degrees=2, num_channels=32)

test_se3_transformer = SE3Transformer(
    num_layers=2,
    fiber_in=Fiber_in,
    fiber_hidden=Fiber_hid,
    fiber_out=Fiber_out,
    num_heads=4,
    channels_div=2,
    fiber_edge=Fiber({}),  # empty edge fiber
    norm=True,
    use_layer_norm=True,
    tensor_cores=False,
    low_memory=False,
)



In [15]:
# example usage of SE3_transformer for QM9 DATASET
# Explicit imports instead of `from se3_transformer.data_loading.qm9 import *`: the following
# cells use dgl, DGLGraph and DataLoader, which the star import only supplied implicitly.
import dgl
from dgl import DGLGraph
from torch.utils.data import DataLoader
from se3_transformer.data_loading.qm9 import QM9DataModule

qm9_dataset = QM9DataModule(data_dir='qm9',task='homo', batch_size=32, num_workers=4,num_degrees=4, precompute_bases=True)

# explore the qm9 dataset: the full data is randomly divided into train, test and val
print(len(qm9_dataset.ds_test), len(qm9_dataset.ds_train), len(qm9_dataset.ds_val))
# data info: graph (containing node and edge features), target, basis for the SE3-Transformer
# NOTE(review): `.dataset[0]` indexes the dataset underlying ds_train, presumably the full
# (unsplit) dataset rather than the first training sample - confirm.
graph,y,basis=qm9_dataset.ds_train.dataset[0]
# node information : 'pos' and 'attr'
print(graph.ndata)
#edge information: 'edge_attr'
print(graph.edges()) # edge index
print(graph.edata) # edge feature
print(basis.keys()) # unfused basis
print(basis['0,0'].shape) # num_edge, 2l+1, num_of_J, 2k+1 (+1 column of padding)
print(basis['1,2'].shape)
print(basis['0,0'][0,0,0,:])
print(basis['1,2'][0,0,0,:])


# Load QM9 datase


Done loading data from cached files.


Precomputing QM9 bases: 100%|██████████| 4089/4089 [00:43<00:00, 93.28it/s] 

13083 100000 17748
{'pos': tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]]), 'attr': tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
(tensor([0, 0, 0, 0, 1, 2, 3, 4]), tensor([1, 2, 3, 4, 0, 0, 0, 0]))
{'edge_attr': tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])}
dict_keys(['0,0', '0,1', '0,2', '0,3', '1,0', '1,1', '1,2', '1,3', '2,0', '2,1', '2,2', '2,3', '3,0', '3,1', '3,2', '3,3'])
torch.Size([8, 1, 1, 2])
torc




In [16]:
# transform QM9 data so that it's suitable for SE3_Transformer input
# it's done by using  _collate function in the QM9DataModule 
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
    """Return, for every edge, the destination node position minus the source node position ([edge, 3])."""
    positions = qm9_graph.ndata['pos']
    edge_src, edge_dst = qm9_graph.edges()
    return positions[edge_dst] - positions[edge_src]

# used as transform function in dataloader
def _collate(samples):
    """
    Merge a list of dataset samples into one batch (used as the DataLoader collate function).

    Each sample is (graph, target) or (graph, target, bases).

    :return: (batched_graph, node_feats, edge_feats, all_bases, targets) when the samples carry
             bases, otherwise (batched_graph, node_feats, edge_feats, targets)
    """
    graphs, y, *bases = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)

    # Edge attributes become the degree-0 edge features (a trailing axis of size 1 is appended)
    edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
    batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
    # The first six node attributes become the degree-0 node features
    node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
    # Raw target values; no mean/std normalisation is applied here
    targets = torch.cat(y)

    if not bases:
        return batched_graph, node_feats, edge_feats, targets

    # Concatenate the per-sample bases along the first (batch/edge) dimension, key by key
    per_sample_bases = bases[0]
    all_bases = {
        key: torch.cat([sample_bases[key] for sample_bases in per_sample_bases], dim=0)
        for key in per_sample_bases[0].keys()
    }
    return batched_graph, node_feats, edge_feats, all_bases, targets
        

# Collate two samples by hand to inspect a single batch. The result is a tuple
# (a collated batch), not a DataLoader, so it gets a name that says so.
collated_batch = _collate([qm9_dataset.ds_train.dataset[0], qm9_dataset.ds_train.dataset[1]])
print(collated_batch[0]) # batched graph
print(collated_batch[1]['0'].shape) # node features (only degree 0; there is no '1' entry)
print(collated_batch[2]['0'].shape) # edge features also only l0 order
print(collated_batch[3]['0,0'].shape) # basis
print(collated_batch[4].shape) # targets

Graph(num_nodes=9, num_edges=14,
      ndata_schemes={'pos': Scheme(shape=(3,), dtype=torch.float32), 'attr': Scheme(shape=(11,), dtype=torch.float32)}
      edata_schemes={'edge_attr': Scheme(shape=(4,), dtype=torch.float32), 'rel_pos': Scheme(shape=(3,), dtype=torch.float32)})
torch.Size([9, 6, 1])
torch.Size([14, 4, 1])
torch.Size([14, 1, 1, 2])
torch.Size([2])


In [17]:
# Training loader: batches are assembled by the data module's own collate function
train_dataloader= DataLoader(
    qm9_dataset.ds_train,
    batch_size=16,
    collate_fn=qm9_dataset._collate,
    shuffle=True,
    num_workers=4,
)

# Inspect the first batch only
for batched_data in train_dataloader:

    print(batched_data[0]) # batched graph
    print(batched_data[1]['0'].shape) # node features
    # there are no degree-1 node features, so batched_data[1] has no '1' entry
    print(batched_data[2]['0'].shape) # edge features also only l0 order
    print(batched_data[3]['0,0'].shape) # basis
    print(batched_data[4].shape) # targets

    break

Graph(num_nodes=290, num_edges=602,
      ndata_schemes={'pos': Scheme(shape=(3,), dtype=torch.float32), 'attr': Scheme(shape=(11,), dtype=torch.float32)}
      edata_schemes={'edge_attr': Scheme(shape=(4,), dtype=torch.float32), 'rel_pos': Scheme(shape=(3,), dtype=torch.float32)})
torch.Size([290, 6, 1])
torch.Size([602, 4, 1])
torch.Size([602, 1, 1, 2])
torch.Size([16])


In [18]:
# We will use qm9 data to train the test_se3_transformer model (we will use the pooled version)
# write one iteration of training procedure
import torch
from torch.utils.data import DataLoader
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The input fiber has to match the node features (6 channels of degree 0), and the edge fiber
# has to match the edge features (4 channels of degree 0).
Fiber_in = Fiber({0: 6})
Fiber_out = Fiber({0: 32})
num_edge_features = 4

test_se3_transformer = SE3TransformerPooled(
    num_layers=2,
    fiber_in=Fiber_in,
    fiber_out=Fiber_out,
    num_degrees=2,
    num_channels=32,
    num_heads=4,
    channels_div=2,
    fiber_edge=Fiber({0: num_edge_features}),
    output_dim=1,
    norm=True,
    use_layer_norm=True,
    tensor_cores=True,
    low_memory=False,
    pooling='max',
)

model = test_se3_transformer.to(device)

def train(model:torch.nn.Module,
          dataloader: DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module
        ):
    """
    Run one training epoch and return the mean loss per batch.

    :param model:       Model called as model(graph, node_feats, edge_feats, basis)
    :param dataloader:  Iterable of batches (graph, node_feats, edge_feats, basis, target) or,
                        when no bases are precomputed, (graph, node_feats, edge_feats, target)
    :param optimizer:   Optimizer over the model parameters
    :param loss_fn:     Loss called as loss_fn(output, target)
    :return:            Mean loss over the batches as a Python float (0.0 if there were no batches)
    """
    # Work on the device the model lives on instead of relying on a global `device`.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss, num_batches = 0.0, 0
    model.train()
    for data in dataloader:
        # Unpack the tuple; the basis entry is optional
        graph, node_feats, edge_feats, *maybe_basis, target = data
        basis = maybe_basis[0] if maybe_basis else None

        # Move each element to the device
        graph = graph.to(run_device)  # DGLGraph or similar object
        node_feats = {k: v.to(run_device) for k, v in node_feats.items()}
        edge_feats = {k: v.to(run_device) for k, v in edge_feats.items()}
        if basis is not None:
            basis = {k: v.to(run_device) for k, v in basis.items()}
        target = target.to(run_device)

        # Forward pass
        output = model(graph, node_feats, edge_feats, basis)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the value: summing the loss tensors themselves would keep
        # autograd history of every batch alive for the whole epoch.
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0
    

def accuracy_fn(output, target):
    """Placeholder for an accuracy metric: currently computes nothing and returns None."""
    # Sketch of a classification accuracy, kept for reference:
    # pred = output.argmax(dim=1, keepdim=True)
    # correct = pred.eq(target.view_as(pred)).sum().item()
    # return correct / len(target)
    return None
    

In [19]:
# training 
# Adam with a fixed learning rate; mean squared error between the model output and the target
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# Train for 10 epochs and report the mean training loss of each epoch
for epoch in range(10):
    train_loss = train(model, train_dataloader, optimizer, loss_fn)
    print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}")
    #print(f"Epoch {epoch+1}, Accuracy: {train_acc:.4f}")

  assert input.numel() == input.storage().size(), "Cannot convert view " \


Epoch 1, Loss: 0.1906
Epoch 2, Loss: 0.1135
Epoch 3, Loss: 0.0909
Epoch 4, Loss: 0.0764
Epoch 5, Loss: 0.0681
Epoch 6, Loss: 0.0615
Epoch 7, Loss: 0.0569
Epoch 8, Loss: 0.0530
Epoch 9, Loss: 0.0498
Epoch 10, Loss: 0.0472


In [20]:
# print the number of parameters in the model
# Count only the parameters that receive gradients
trainable_params = [param for param in model.parameters() if param.requires_grad]
num_params = sum(param.numel() for param in trainable_params)
print(f"Number of parameters in the model: {num_params}")


Number of parameters in the model: 284641


In [21]:
# Save the model: only the state_dict (parameters and buffers) is written, so loading it
# back requires building the same model first and calling load_state_dict on it.
torch.save(model.state_dict(), 'se3_transformer_model.pth')

In [23]:
# test the model
def test(model:torch.nn.Module,
          dataloader: DataLoader,
          loss_fn: torch.nn.Module
        ):
    test_loss, test_acc = 0.0, 0.0
    model.eval()
    with torch.no_grad():
        for data in dataloader:
            graph, node_feats, edge_feats, basis,target = data

            # Move each element to the device
            graph = graph.to(device)  # DGLGraph or similar object
            node_feats = {k: v.to(device) for k, v in node_feats.items()}  # Dictionary of tensors
            edge_feats = {k: v.to(device) for k, v in edge_feats.items()}  # Dictionary of tensors
            basis = {k: v.to(device) for k, v in basis.items()}  # Optional dictionary of tensors
            target = target.to(device)

            # Forward pass
            output = model(graph, node_feats, edge_feats, basis)
            #forwad pass syntaxt(graph: DGLGraph, node_feats: Dict[str, Tensor],
            #           edge_feats: Optional[Dict[str, Tensor]] = None,
            #            basis: Optional[Dict[str, Tensor]] = None)
            loss=loss_fn(output, target)
            test_loss +=loss
        
        test_loss = test_loss/len(dataloader)

    return train_loss

# test the model

# Evaluation loader: no shuffling, since the order of the batches is irrelevant for the mean
# loss and a fixed order keeps the evaluation deterministic.
test_dataloader = DataLoader(
    qm9_dataset.ds_test,
    batch_size=16,
    collate_fn=qm9_dataset._collate,
    shuffle=False,
    num_workers=2,
)

test_loss = test(model, test_dataloader, loss_fn)
print(f"Test Loss: {test_loss:.4f}")
    

  assert input.numel() == input.storage().size(), "Cannot convert view " \


Test Loss: 0.0472
