In [100]:
from typing import Literal, overload
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from utils.size import Size2_t, Size2_p, to_size


def conv2d_b(
    inp: Tensor,
    weight: Tensor,
    dilation: Size2_p = 1,
    padding: Size2_p = 0,
    stride: Size2_p = 1,
) -> Tensor:
    """
    Bias-free 2-D convolution computed as ONE batched matmul over unfolded patches.

    Every output position is produced by a single ``blocks @ weight`` call.
    ``conv2d_d`` computes the same quantity one output position at a time.

    Shape:
        inp: (*, InC, H, W) -- any number of leading batch dims, including none
        weight: (OutC, InC, KH, KW)
        return: (*, OutC, H_out, W_out)
    """
    assert weight.dim() == 4
    assert inp.dim() >= 3, "inp must have shape (*, InC, H, W)"
    assert inp.shape[-3] == weight.shape[1], "inp channels must match weight InC"
    kernel_size = weight.shape[-2:]
    dilation = to_size(2, dilation)
    padding = to_size(2, padding)
    stride = to_size(2, stride)

    # F.unfold works on a single (N, C, H, W) batch, so collapse all leading dims into N.
    # For a plain 4-D input this is a no-op view, i.e. the same computation as before.
    batch_shape = inp.shape[:-3]
    inp_4d = inp.reshape(batch_shape.numel(), *inp.shape[-3:])

    # (N, L, InC*KH*KW): one row of patch values per output position.
    blocks = F.unfold(inp_4d, kernel_size, dilation, padding, stride).transpose(-2, -1)
    # (InC*KH*KW, OutC): each filter flattened into a column.
    # reshape (not view) also accepts non-contiguous weights.
    weight_mat = weight.reshape(weight.shape[0], -1).t()
    out = blocks @ weight_mat  # (N, L, OutC)

    # Same output-size formula as torch.nn.Conv2d.
    h_out = (
        inp.shape[-2] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
    ) // stride[0] + 1
    w_out = (
        inp.shape[-1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
    ) // stride[1] + 1
    assert h_out * w_out == blocks.shape[-2], "output size does not match unfold's L"

    # (N, OutC, L) -> (*, OutC, H_out, W_out), restoring the caller's leading dims.
    out = out.transpose(-2, -1).reshape(*batch_shape, weight_mat.shape[1], h_out, w_out)
    return out


def conv2d_d(
    inp: Tensor,
    weight: Tensor,
    dilation: Size2_p = 1,
    padding: Size2_p = 0,
    stride: Size2_p = 1,
) -> Tensor:
    """
    Bias-free 2-D convolution computed with one small matmul PER output position.

    Mathematically identical to ``conv2d_b``. Because each position gets its own
    (N, 1, K) @ (K, OutC) matmul, the floating-point result may differ from the
    single batched matmul.

    Shape:
        inp: (*, InC, H, W) -- any number of leading batch dims, including none
        weight: (OutC, InC, KH, KW)
        return: (*, OutC, H_out, W_out)
    """
    assert weight.dim() == 4
    assert inp.dim() >= 3, "inp must have shape (*, InC, H, W)"
    assert inp.shape[-3] == weight.shape[1], "inp channels must match weight InC"
    kernel_size = weight.shape[-2:]
    dilation = to_size(2, dilation)
    padding = to_size(2, padding)
    stride = to_size(2, stride)

    # F.unfold works on a single (N, C, H, W) batch, so collapse all leading dims into N.
    # For a plain 4-D input this is a no-op view, i.e. the same computation as before.
    batch_shape = inp.shape[:-3]
    inp_4d = inp.reshape(batch_shape.numel(), *inp.shape[-3:])

    # (N, L, InC*KH*KW): one row of patch values per output position.
    blocks = F.unfold(inp_4d, kernel_size, dilation, padding, stride).transpose(-2, -1)
    # (InC*KH*KW, OutC): each filter flattened into a column.
    # reshape (not view) also accepts non-contiguous weights.
    weight_mat = weight.reshape(weight.shape[0], -1).t()

    # split(1) yields the views blocks[..., i:i+1, :], so each matmul covers exactly one position.
    out = torch.cat(
        [position @ weight_mat for position in blocks.split(1, dim=-2)],
        dim=-2,
    )  # (N, L, OutC)

    # Same output-size formula as torch.nn.Conv2d.
    h_out = (
        inp.shape[-2] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
    ) // stride[0] + 1
    w_out = (
        inp.shape[-1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
    ) // stride[1] + 1
    assert h_out * w_out == blocks.shape[-2], "output size does not match unfold's L"

    # (N, OutC, L) -> (*, OutC, H_out, W_out), restoring the caller's leading dims.
    out = out.transpose(-2, -1).reshape(*batch_shape, weight_mat.shape[1], h_out, w_out)
    return out


# Prefer the GPU (the original runs used CUDA), but fall back to CPU so the cell
# still executes on machines without one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
# Seed before drawing any random tensors so the Equal/Close/MaxDiff numbers are reproducible.
torch.manual_seed(42)

inp = torch.randn(1, 16, 16, 16, dtype=dtype, device=device)

# Filter banks (OutC, InC, KH, KW). kaiming_normal_ fills them in place,
# so their initial contents are irrelevant.
w1 = torch.empty(32, 16, 3, 3, dtype=dtype, device=device)
w2 = torch.empty(64, 32, 3, 3, dtype=dtype, device=device)
w3 = torch.empty(64, 64, 3, 3, dtype=dtype, device=device)
w4 = torch.empty(64, 64, 3, 3, dtype=dtype, device=device)
w5 = torch.empty(64, 64, 3, 3, dtype=dtype, device=device)
for w in (w1, w2, w3, w4, w5):
    nn.init.kaiming_normal_(w)

# 16 stacked 3x3 convs with padding=1 (spatial size stays 16x16):
# w1..w4, then w5 reused for the remaining 12 layers.
layer_weights = [w1, w2, w3, w4] + [w5] * 12

# Push the same input through the same layer stack with both implementations:
# b = one batched matmul per layer, d = one matmul per output position per layer.
b = inp
for w in layer_weights:
    b = conv2d_b(b, w, padding=1)

d = inp
for w in layer_weights:
    d = conv2d_d(d, w, padding=1)

assert b.shape == d.shape

# Percentage of bit-identical elements, percentage within torch.isclose's
# default tolerances, and the largest absolute difference.
equal = (b == d).count_nonzero().item() / b.numel() * 100
close = torch.isclose(b, d).count_nonzero().item() / b.numel() * 100
print(f"Equal {equal:.2f}%")
print(f"Close {close:.2f}%")
print("MaxDiff", (b - d).abs().max().item())


Equal 4.55%
Close 96.12%
MaxDiff 0.000640869140625


In [23]:
import torch

# torch.use_deterministic_algorithms(True)
dtype = torch.float16
device = torch.device("cuda")

# set random seed for reproducibility
torch.manual_seed(42)

# define the tensors for matrix multiplication
batch_size = 1
input_dim = 64
hidden_dim = 512

x = torch.randn(batch_size, input_dim, hidden_dim, dtype=dtype, device=device)
y = torch.randn(batch_size, hidden_dim, input_dim, dtype=dtype, device=device)

# perform batch matrix multiplication using torch.bmm()
output = torch.bmm(x, y)

# verify determinism
output2 = torch.bmm(x[..., :1, :], y)

b, d = output[..., 0, :], output2[..., 0, :]
equal = (b == d).count_nonzero().item() / b.numel() * 100
close = torch.isclose(b, d).count_nonzero().item() / b.numel() * 100
print(f"Equal {equal:.2f}%")
print(f"Close {close:.2f}%")
print("MaxDiff", (b - d).abs().max().item())


Equal 100.00%
Close 100.00%
MaxDiff 0.0
