In [48]:
import time
import numpy as np  

import torch
import torch.nn as nn

from escnn import gspaces
import escnn.nn as enn

from molnet.escnn_models import InnerBatchNorm3D

In [57]:
x = torch.randn(4, 8, 64, 64, 10, device='cuda')

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [53]:
torch_conv = nn.Conv3d(8, 8, 3, padding=1)
torch_bnorm = nn.BatchNorm3d(8)

In [54]:

r2 = gspaces.rot2dOnR3(4)
in_type = enn.FieldType(r2, 2*[r2.regular_repr])
out_type = enn.FieldType(r2, 2*[r2.regular_repr])

escnn_conv = enn.R3Conv(
    in_type,
    out_type=out_type,
    kernel_size=3
)
escnn_bnorm = InnerBatchNorm3D(out_type)

In [55]:
torch_times = []
for i in range(50):
    t0 = time.perf_counter()
    y = torch_conv(x)
    y = torch_bnorm(y)
    t1 = time.perf_counter()
    if i < 5:
        continue
    torch_times.append(t1 - t0)


print(f"torch:")
print(f"  mean: {np.mean(torch_times):.6f}")
print(f"  std: {np.std(torch_times):.6f}")
print(f"  min: {np.min(torch_times):.6f}")
print(f"  max: {np.max(torch_times):.6f}")


torch:
  mean: 0.003171
  std: 0.000883
  min: 0.002626
  max: 0.008321


In [56]:
escnn_times = []
for i in range(50):
    t0 = time.perf_counter()
    
    y = escnn_conv(in_type(x))
    y = escnn_bnorm(y)
    
    t1 = time.perf_counter()
    if i < 5:
        continue
    escnn_times.append(t1 - t0)

print(f"escnn:")
print(f"  mean: {np.mean(escnn_times):.6f}")
print(f"  std: {np.std(escnn_times):.6f}")
print(f"  min: {np.min(escnn_times):.6f}")
print(f"  max: {np.max(escnn_times):.6f}")

escnn:
  mean: 0.004503
  std: 0.000767
  min: 0.003741
  max: 0.007124
