In [5]:
import torch
import torch.nn as nn

class BinaryConv2d(nn.Module):
    """2-D convolution whose weights are binarized with ``torch.sign`` on each forward pass.

    The real-valued weights stay stored in ``self.conv.weight``, so optimizers
    and ``state_dict`` see them unchanged. Only the forward computation uses
    their signs. Gradients reach the real-valued weights through a
    straight-through estimator, which treats the sign function as identity in
    the backward pass.

    Note: ``torch.sign`` maps an exact zero to 0, so a weight that is exactly
    zero contributes 0 rather than +1/-1.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()

        # The wrapped layer owns the real-valued weight and the optional bias.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, x):
        """Convolve ``x`` with the sign of the stored real-valued weights.

        Args:
            x: input accepted by ``nn.Conv2d`` (presumably ``(N, C_in, H, W)``
               or ``(C_in, H, W)`` -- confirm against callers).

        Returns:
            The convolution output computed with binarized weights plus the
            (real-valued) bias, if any.
        """
        weight = self.conv.weight

        # Straight-through estimator. ``weight - weight.detach()`` is exactly
        # zero in value but has gradient 1 w.r.t. ``weight``, so the forward
        # value is exactly sign(weight) while gradients pass straight through
        # to the real-valued parameter. Unlike swapping ``weight.data`` in and
        # out, this:
        #   - never mutates the parameter;
        #   - keeps the input gradient consistent with the binarized weights
        #     actually used in the forward pass;
        #   - cannot leave the layer binarized if the convolution raises.
        binary_weight = torch.sign(weight).detach() + (weight - weight.detach())

        # This class never sets padding_mode, so nn.Conv2d uses 'zeros' and the
        # functional call is equivalent to self.conv with the swapped weight.
        return nn.functional.conv2d(
            x,
            binary_weight,
            self.conv.bias,
            self.conv.stride,
            self.conv.padding,
            self.conv.dilation,
            self.conv.groups,
        )

# --- Example: run the binary convolution on a random RGB mini-batch ---
# Input layout: (batch=8, channels=3 for RGB, height=32, width=32).
input_tensor = torch.randn(size=(8, 3, 32, 32))

# 3 -> 64 channels with a 3x3 kernel; stride 1 (default) and padding=1 keep
# the spatial size at 32x32.
binary_conv = BinaryConv2d(3, 64, 3, padding=1)
output = binary_conv(input_tensor)
output  # last expression of the cell, so the notebook displays the tensor

tensor([[[[-6.0132e+00,  4.2320e+00,  4.0855e+00,  ..., -1.0014e+00,
           -2.3532e+00,  2.9424e-01],
          [ 6.4939e-01,  2.1306e+00, -8.6384e+00,  ..., -1.0533e+00,
           -5.3747e+00, -3.3081e+00],
          [ 7.0863e+00,  3.5534e+00, -1.1742e+01,  ...,  1.1391e+00,
           -5.6110e+00,  3.4204e+00],
          ...,
          [-1.1030e+00,  1.6746e+00, -2.2885e+00,  ...,  9.1610e+00,
            9.1672e-01,  2.6167e+00],
          [ 2.6493e+00,  3.2729e+00,  9.0303e-01,  ...,  1.7693e+00,
           -1.0004e+00,  4.2441e+00],
          [-2.4600e-01,  7.4067e-01, -3.7850e+00,  ...,  5.5284e+00,
            3.4767e+00,  4.6582e-01]],

         [[ 1.4240e+00, -6.0525e+00,  6.6013e+00,  ...,  4.3341e+00,
           -7.5690e+00, -4.6511e+00],
          [-4.7679e+00,  4.6895e+00,  8.4140e+00,  ..., -4.8949e+00,
           -8.0917e-01,  3.6194e-01],
          [-3.5534e+00,  9.1067e+00, -1.2436e+00,  ...,  4.0450e+00,
            6.3616e+00, -4.6767e+00],
          ...,
     