In [1]:
import torch
import torch.nn as nn
# from fusion import fuse_conv_and_bn
from torch.nn.utils import fuse_conv_bn_eval

In [2]:
# Turn autograd off globally: the cells shown here only run forward passes to
# compare outputs. This is process-wide state, so it also applies to every
# cell executed afterwards.
torch.set_grad_enabled(False)

class ConvBNModel(nn.Module):
    """Conv2d followed by BatchNorm2d; the unfused reference for Conv+BN fusion.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution, also used
            as the BatchNorm feature count.
        kernel_size: Side length of the square convolution kernel.
        groups: Number of convolution groups.
        padding: Zero padding added to each spatial side. The default of 1
            keeps the spatial size unchanged only for ``kernel_size=3``; pass
            ``kernel_size // 2`` to keep it for other odd kernel sizes.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        kernel_size: int = 3,
        groups: int = 1,
        padding: int = 1,
    ):
        super().__init__()
        # The conv bias is enabled on purpose so that fusion has to fold the
        # BatchNorm shift into an existing bias term.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            bias=True,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the convolution, then batch normalisation, and return the result."""
        return self.bn(self.conv(x))

# Seed before anything random happens: both the layer initialisation and the
# input tensor below draw from the global RNG, in that order.
torch.manual_seed(42)

# With groups=1 every output value sums over 4096 * 3 * 3 products
# (presumably chosen so float32 accumulation error becomes visible).
layer_params = dict(in_channels=4096, out_channels=8192, kernel_size=3, groups=1)

# eval() makes BatchNorm use its running statistics, which is the mode in
# which the layer is fused and compared below.
layer = ConvBNModel(**layer_params).eval()
# Float64 variant of the experiment (also switch `dummy` below):
# layer.conv.to(dtype=torch.double)
# layer.bn.to(dtype=torch.double)

# Fold the BatchNorm into the convolution; `layer` is still used afterwards as
# the unfused reference.
# Alternative using the local `fusion` helper:
# fused_layer = fuse_conv_and_bn(layer.conv, layer.bn)
# fused_layer.eval()
fused_layer = fuse_conv_bn_eval(layer.conv, layer.bn).eval()

# Same random input for both paths.
# dummy = torch.randn((1, layer.conv.in_channels, 16, 16)).double()
dummy = torch.randn(1, layer.conv.in_channels, 16, 16)
out1, out2 = layer(dummy), fused_layer(dummy)

# 2-norm of the element-wise difference over the whole output tensor.
print(f"Output diff: {torch.linalg.norm(out1 - out2)}")

Output diff: 0.0030515091493725777


In [14]:
# Force both models and the input to float32. nn.Module.float() converts the
# module in place and returns it, so rebinding the names changes nothing.
layer, fused_layer = layer.float(), fused_layer.float()
dummy = dummy.float()

res1, res2 = layer(dummy), fused_layer(dummy)
print(f"Output diff: {torch.linalg.norm(res1 - res2)}")
# Slightly smaller than the diff from the full-flow float32 conversion.
# NOTE(review): the casts above only change anything if the models/input were
# float64 beforehand; on float32 objects they are no-ops. This cell's execution
# count is out of order with its neighbours - confirm which precision the
# models had when this output was produced.

Output diff: 0.003050478408113122


In [3]:
# 257 is one past the uint8 range (0..255); used below to see what a narrowing
# cast does with an out-of-range value.
a = torch.tensor([257])

In [4]:
# Cast to uint8 and display the result: the output shows 257 wrapping to 1.
# The result is not assigned, so `a` keeps referring to the original tensor.
a.to(dtype=torch.uint8)

tensor([1], dtype=torch.uint8)

In [5]:
# Display `a` again: it still holds 257, i.e. the cast did not modify it.
a

tensor([257])