In [14]:
import time
import numpy as np  

import torch
import torch.nn as nn

from escnn import gspaces
import escnn.nn as enn

from molnet.escnn_models import InnerBatchNorm3D

device = "mps"

In [15]:
x = torch.randn(4, 8, 64, 64, 10, device=device)

In [16]:
torch_model = nn.Sequential(
    nn.Conv3d(8, 8, 3, padding=1),
    nn.BatchNorm3d(8)
)
torch_model.to(device)

Sequential(
  (0): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [20]:
r2 = gspaces.rot2dOnR3(4)
in_type = enn.FieldType(r2, 2*[r2.regular_repr])
out_type = enn.FieldType(r2, 8*[r2.regular_repr])

escnn_conv = enn.R3Conv(
    in_type,
    out_type=out_type,
    kernel_size=3
)
escnn_bnorm = InnerBatchNorm3D(out_type)

escnn_model = enn.SequentialModule(
    escnn_conv,
    escnn_bnorm
)

escnn_model.to(device)

SequentialModule(
  (0): R3Conv([C4_on_R3[(False, False, 4)]: {regular (x2)}(8)], [C4_on_R3[(False, False, 4)]: {regular (x8)}(32)], kernel_size=3, stride=1)
  (1): InnerBatchNorm3D([C4_on_R3[(False, False, 4)]: {regular (x8)}(32)], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [21]:
torch_times = []
for i in range(50):
    t0 = time.perf_counter()
    y = torch_model(x)
    t1 = time.perf_counter()
    if i < 5:
        continue
    torch_times.append(t1 - t0)


print(f"torch:")
print(f"  mean: {np.mean(torch_times):.6f}")
print(f"  std: {np.std(torch_times):.6f}")
print(f"  min: {np.min(torch_times):.6f}")
print(f"  max: {np.max(torch_times):.6f}")


torch:
  mean: 0.000427
  std: 0.000190
  min: 0.000288
  max: 0.001267


In [22]:
escnn_times = []
for i in range(50):
    t0 = time.perf_counter()
    
    y = escnn_model(in_type(x))
    
    t1 = time.perf_counter()
    if i < 5:
        continue
    escnn_times.append(t1 - t0)

print(f"escnn:")
print(f"  mean: {np.mean(escnn_times):.6f}")
print(f"  std: {np.std(escnn_times):.6f}")
print(f"  min: {np.min(escnn_times):.6f}")
print(f"  max: {np.max(escnn_times):.6f}")

escnn:
  mean: 0.007273
  std: 0.001097
  min: 0.006110
  max: 0.009143
