In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import fused_add_relu_ext

class SimpleVGG(nn.Module):
    """Small VGG-style CNN: two conv stages followed by a two-layer MLP head.

    Each conv stage is ``conv3x3 -> ReLU -> conv3x3 -> ReLU -> maxpool2x2``
    (32 channels, then 64). The 3x3 convs use ``padding=1`` so they keep the
    spatial size, and each ``MaxPool2d(2)`` halves it (floor). The feature map
    entering the classifier is therefore ``64 x (H // 4) x (W // 4)``.

    Args:
        num_classes: number of outputs of the final linear layer.
        in_channels: number of channels of the input images (3 for RGB).
        input_size: spatial size of the input. Pass an int for square inputs
            (H == W) or an ``(H, W)`` tuple. It fixes the width of the first
            linear layer, so the model only accepts inputs of exactly this size.
            The default 32 reproduces the original ``64 * 8 * 8`` layer.

    Raises:
        ValueError: if ``input_size`` is smaller than 4 in either dimension,
            since the two pooling layers would leave an empty feature map.
    """

    def __init__(self, num_classes=10, in_channels=3, input_size=32):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # Two MaxPool2d(2) layers: floor(floor(n / 2) / 2) == n // 4 per dimension.
        feat_h, feat_w = height // 4, width // 4
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size must be at least 4 in each dimension, got {input_size!r}"
            )

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feat_h * feat_w, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a batch ``(N, in_channels, H, W)`` to ``(N, num_classes)`` scores.

        No softmax is applied; the result is the raw output of the last
        linear layer.
        """
        x = self.features(x)
        return self.classifier(x)


In [None]:

# Usage: smoke-test the model on a random batch, then check the compiled
# custom op against a pure-PyTorch reference.
if __name__ == "__main__":
    SEED = 0
    # Seed before building the model so both its random init weights and the
    # random inputs (and hence the printed output) are reproducible.
    torch.manual_seed(SEED)

    model = SimpleVGG().eval()
    batch = torch.randn(8, 3, 32, 32)  # batch of 8, 3 channels, 32x32
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(batch)
    print(logits.shape)

    # Own module test: verify the extension instead of only printing it.
    # NOTE(review): assumes fused_add_relu(a, b) computes relu(a + b), as its
    # name suggests — confirm against the extension source.
    a = torch.randn(10)
    b = torch.randn(10)
    fused_out = fused_add_relu_ext.fused_add_relu(a, b)
    torch.testing.assert_close(fused_out, torch.relu(a + b))
    print(fused_out)

torch.Size([8, 10])
tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0966, 0.6735, 0.7647, 0.0000,
        3.3516])
