In [227]:
import time
import numpy as np  

import torch
import torch.nn as nn

from escnn import gspaces
import escnn.nn as enn

from molnet import escnn_models

device = "cpu"

In [233]:
x = torch.randn(4, 8, 128, 128, 10, device=device)

In [234]:
torch_model = nn.Sequential(
    nn.Conv3d(8, 8, 3, padding=1),
    nn.BatchNorm3d(8),
    nn.MaxPool3d(2),
)
torch_model.to(device)

Sequential(
  (0): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [235]:
r2 = gspaces.rot2dOnR3(4)
in_type = enn.FieldType(r2, 2*[r2.regular_repr])
out_type = enn.FieldType(r2, 2*[r2.regular_repr])

escnn_conv = enn.R3Conv(
    in_type,
    out_type=out_type,
    kernel_size=3
)
escnn_bnorm = escnn_models.InnerBatchNorm3D(out_type)

escnn_model = enn.SequentialModule(
    escnn_conv,
    #enn.GNormBatchNorm(out_type),
    #enn.IIDBatchNorm3d(out_type),
    #enn.InnerBatchNorm(out_type),
    escnn_bnorm,
    escnn_models.NormMaxPool3D(out_type, kernel_size=2),
)

escnn_model.to(device)

SequentialModule(
  (0): R3Conv([C4_on_R3[(False, False, 4)]: {regular (x2)}(8)], [C4_on_R3[(False, False, 4)]: {regular (x2)}(8)], kernel_size=3, stride=1)
  (1): InnerBatchNorm3D([C4_on_R3[(False, False, 4)]: {regular (x2)}(8)], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): NormMaxPool3D()
)

In [236]:
N = 50
t0 = time.perf_counter()
torch_model.train()
for i in range(N):
    y = torch_model(x)
    if i < 5:
        continue
t1 = time.perf_counter()

total_time = t1-t0
avg_time = total_time / (N-5)

print(f"torch:")
print(f"  average time: {avg_time*1000:.3f} ms")

torch:
  average time: 135.407 ms


In [237]:
t0 = time.perf_counter()
in_x = in_type(x)
for i in range(N):
    y = escnn_model(in_x)
    
    if i < 5:
        continue

t1 = time.perf_counter()
total_time = t1-t0
avg_time = total_time / (N-5)

print(f"escnn:")
print(f"  average time: {avg_time*1000:.3f} ms")


escnn:
  average time: 70.955 ms


In [240]:
times = []
in_x = in_type(x)
for i in range(N):
    t0 = time.perf_counter()
    y = escnn_model(in_x)
    t1 = time.perf_counter()
    
    if i < 5:
        continue

    times.append(t1-t0)

print(f"escnn:")
print(f"  average time: {np.mean(times)*1000:.3f} ms")
print(f"  std: {np.std(times)*1000:.3f} ms")
print(f"  min: {np.min(times)*1000:.3f} ms")
print(f"  max: {np.max(times)*1000:.3f} ms")


escnn:
  average time: 62.276 ms
  std: 5.117 ms
  min: 56.660 ms
  max: 83.949 ms


In [241]:
torch_model.train()
with torch.autograd.profiler.profile(use_cpu=True) as prof:
    y = torch_model(x)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::conv3d         0.00%       4.666us        95.79%     138.735ms     138.735ms             1  
                aten::convolution         0.02%      32.417us        95.79%     138.730ms     138.730ms             1  
               aten::_convolution         0.01%      10.999us        95.76%     138.698ms     138.698ms             1  
                aten::slow_conv3d         0.00%       5.751us        95.76%     138.687ms     138.687ms             1  
        aten::slow_conv3d_forward        95.29%     138.011ms        95.75%     138.681ms     138.681ms             1  
                 aten::batch_norm       

In [243]:
escnn_model.train()
with torch.autograd.profiler.profile(use_cpu=True) as prof:
    y = escnn_model(in_type(x))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::conv3d         0.00%       3.916us        91.77%      91.941ms      91.941ms             1  
                aten::convolution         0.01%      11.417us        91.76%      91.937ms      91.937ms             1  
               aten::_convolution         0.01%       8.416us        91.75%      91.925ms      91.925ms             1  
                aten::slow_conv3d         0.01%       8.083us        91.74%      91.917ms      91.917ms             1  
        aten::slow_conv3d_forward        91.46%      91.631ms        91.74%      91.909ms      91.909ms             1  
                 aten::batch_norm       