In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import benchmark
import os
import json

In [25]:
# Notebook-wide configuration.
# DEBUG: set to a truthy value to benchmark a single small configuration.
DEBUG = 0
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Benchmark results are appended to this file as one JSON object per line.
logpath = "benchmark.json"
# Start every run from an empty log; a missing file is not an error.
try:
    os.remove(logpath)
except FileNotFoundError:
    pass

In [26]:


# Spatial sizes (height == width) of the benchmark inputs, one per stage;
# only the first one is kept when DEBUG is set.
widths = [56] if DEBUG else [56, 28, 14, 7]

def regnet_parameters(num):
    """Return the per-stage channel counts and the group value for a variant.

    Args:
        num: Variant identifier string; one of "002", "004", "006" or "008".

    Returns:
        A ``(channels, group)`` tuple, where ``channels`` is a list with one
        channel count per stage and ``group`` is an int. When the module-level
        ``DEBUG`` flag is truthy, the fixed pair ``([64], 16)`` is returned
        instead.

    Raises:
        NotImplementedError: If ``num`` is not a known variant. This is raised
            even in DEBUG mode.
    """
    # Built per call so that callers always receive fresh list objects.
    table = {
        "002": ([24, 56, 152, 368], 8),
        "004": ([32, 64, 160, 384], 16),
        "006": ([48, 96, 240, 528], 24),
        "008": ([64, 128, 288, 672], 16),
    }
    if num not in table:
        raise NotImplementedError

    channels, group = table[num]
    # The parentheses matter: without them the conditional expression binds
    # only to `group`, and DEBUG mode would return (channels, ([64], 16)).
    return (channels, group) if not DEBUG else ([64], 16)

In [35]:
def run_conv(conv, x, repeat=100):
    """Time ``conv(x)`` and return the average seconds per call.

    Args:
        conv: Callable applied to ``x`` (e.g. a convolution module).
        x: Input passed to ``conv``.
        repeat: Number of calls made in the single timed run.

    Returns:
        Total time of the timed run divided by ``repeat``.
    """
    bench = benchmark.Timer(stmt="conv(x)", globals={"conv": conv, "x": x})
    measurement = bench.timeit(repeat)
    # timeit() performs one timed run of `repeat` calls, so the first raw
    # time is the total for that run.
    total = measurement.raw_times[0]
    return total / repeat

def run_add(x, repeat=100):
    """Time ``x + x`` and return the average seconds per call.

    Args:
        x: Operand added to itself.
        repeat: Number of additions made in the single timed run.

    Returns:
        Total time of the timed run divided by ``repeat``.
    """
    bench = benchmark.Timer(stmt="x + x", globals={"x": x})
    measurement = bench.timeit(repeat)
    # One timed run of `repeat` additions; average it down to a single call.
    total = measurement.raw_times[0]
    return total / repeat


class ResidualBlock(nn.Module):
    """Residual block of three bias-free convolutions used for benchmarking.

    The block applies a 1x1 conv, a grouped 3x3 conv (padding 1) and another
    1x1 conv, all with ``channels`` input and output channels, and adds the
    input back onto the result.

    Args:
        channels: Number of input and output channels of every convolution.
        width: Spatial size (height and width) of the input that
            ``speedtest`` generates.
        group: ``groups`` argument of the 3x3 convolution.
    """

    def __init__(self, channels, width, group):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 1, stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, stride=1, padding=1, dilation=1, groups=group, bias=False)
        self.conv3 = nn.Conv2d(channels, channels, 1, stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.channels = channels
        self.width = width
        self.group = group

    def forward(self, x):
        """Run the three convolutions in sequence and add the input ``x``."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out += x
        return out

    def speedtest(self):
        """Time each convolution and an elementwise add, then log the result.

        A random ``(1, channels, width, width)`` input is fed to each of the
        three convolutions independently (they all share the same input
        shape), and ``x + x`` is timed as a stand-in for the residual add.
        One JSON object with the configuration and the four timings is
        appended as a line to the file named by the module-level ``logpath``.
        """
        # Create the input on the same device and with the same dtype as the
        # weights, so the benchmark still works after the block has been
        # moved with .to(...) or converted with .half().
        weight = self.conv1.weight
        x = torch.randn(
            (1, self.channels, self.width, self.width),
            device=weight.device,
            dtype=weight.dtype,
        )
        t1 = run_conv(self.conv1, x)
        t2 = run_conv(self.conv2, x)
        t3 = run_conv(self.conv3, x)
        ta = run_add(x)
        res = {
            "channel": self.channels,
            "width": self.width,
            "group": self.group,
            "conv1": t1,
            "conv2": t2,
            "conv3": t3,
            "add": ta,
        }
        # Append-only: the log is never read back here.
        with open(logpath, "a") as f:
            f.write(json.dumps(res) + "\n")



In [36]:
# Benchmark one residual block per (variant, stage) pair; every speedtest()
# call appends one JSON line to the log file.
for num in ("002", "004", "006", "008"):
    channels, group = regnet_parameters(num)
    for i in range(len(widths)):
        width = widths[i]
        channel = channels[i]  # stage i pairs the i-th width with the i-th channel count
        rb = ResidualBlock(channel, width, group)
        rb.speedtest()