In [239]:
import torch
from torch import nn, Tensor

In [240]:
# Problem dimensions shared by every cell below.
batch: int = 1  # number of images in the test batch
in_channels: int = 3  # channels of the input image
out_channels: int = 512  # features produced for every patch
patch_size: int = 16  # side length of one square patch, in pixels
num_patches: int = 2  # patches along each side of the image

# init image

In [241]:
# Constant test image laid out as (batch, channels, height, width).
# NOTE: with an all-ones image every patch is identical, so the comparison
# further down cannot reveal a wrong element ordering in this flattening.
image_size = num_patches * patch_size
image: Tensor = torch.ones(batch, in_channels, image_size, image_size)

# Split height and width into (patch index, offset inside the patch) pairs.
patch_grid = image.reshape(
    batch, in_channels, num_patches, patch_size, num_patches, patch_size
)
# Move both patch indices in front of (channel, row, col) so the values of one
# patch sit together in the same order as a flattened conv kernel.
patch_major = patch_grid.permute(0, 2, 4, 1, 3, 5)
# One row per patch, one column per (channel, row, col) value.
image_flatten: Tensor = patch_major.reshape(
    batch, num_patches * num_patches, in_channels * patch_size * patch_size
)
image_flatten.shape

torch.Size([1, 4, 768])

# init parameters

In [242]:
# Deterministic kernel filled with consecutive integers, shaped
# (out_c, in_c, k_h, k_w) like a Conv2d weight.
weight_shape = (out_channels, in_channels, patch_size, patch_size)
num_weights = out_channels * in_channels * patch_size * patch_size
weight: Tensor = torch.arange(num_weights, dtype=torch.float32).reshape(weight_shape)
# The same values as an (out_c, in_c * k_h * k_w) matrix for the linear layer.
weight_flatten: Tensor = weight.flatten(start_dim=1)
print(weight.shape, weight_flatten.shape, sep="\n")

torch.Size([512, 3, 16, 16])
torch.Size([512, 768])


In [243]:
# One bias value per output channel: 0, 1, ..., out_channels - 1.
# Sized from `out_channels` so it cannot drift out of sync with the layers
# when the configuration cell is edited.
bias: Tensor = torch.arange(out_channels, dtype=torch.float32)
bias.shape

torch.Size([512])

# init conv

In [244]:
# Patch-embedding convolution: kernel size and stride both equal the patch
# size, so every output position covers exactly one non-overlapping patch.
# Both are taken from `patch_size` so they stay consistent with the patch
# flattening and the weight shape defined in the other cells.
conv: nn.Conv2d = nn.Conv2d(
    in_channels, out_channels, kernel_size=patch_size, stride=patch_size
).eval()
conv

Conv2d(3, 512, kernel_size=(16, 16), stride=(16, 16))

In [245]:
# Parameter layouts: weight is (out_c, in_c, k_h, k_w), bias is (out_c,).
print(conv.weight.data.shape, conv.bias.data.shape, sep="\n")

torch.Size([512, 3, 16, 16])
torch.Size([512])


## replace conv weight and bias

In [246]:
# Load the hand-built parameters into the convolution.
# Copying in place under no_grad (instead of rebinding `.data`) keeps the
# layer's own storage, so it is not aliased with `weight`/`bias` or with any
# other layer loaded from them, and it raises if the source cannot be
# broadcast to the parameter's shape.
with torch.no_grad():
    conv.weight.copy_(weight)
    conv.bias.copy_(bias)
print(conv.weight.shape)
print(conv.bias.shape)

torch.Size([512, 3, 16, 16])
torch.Size([512])


# init linear

In [247]:
# Linear layer mapping one flattened patch (in_c * k_h * k_w values) to
# `out_channels` features.
linear: nn.Linear = nn.Linear(
    in_features=in_channels * patch_size * patch_size,
    out_features=out_channels,
)
linear.eval()  # returns the module itself; kept as a statement, not displayed
linear

Linear(in_features=768, out_features=512, bias=True)

In [248]:
# Parameter layouts: weight is (out_features, in_features), bias is (out_features,).
print(linear.weight.data.shape, linear.bias.data.shape, sep="\n")

torch.Size([512, 768])
torch.Size([512])


## replace linear weight and bias

In [249]:
# Load the flattened kernel and the bias into the linear layer.
# Copying in place under no_grad (instead of rebinding `.data`) keeps the
# layer's own storage: `weight_flatten` is a view of `weight`, so rebinding
# would make this layer share memory with the source tensors. The copy also
# raises if the source cannot be broadcast to the parameter's shape.
with torch.no_grad():
    linear.weight.copy_(weight_flatten)
    linear.bias.copy_(bias)
print(linear.weight.shape)
print(linear.bias.shape)

torch.Size([512, 768])
torch.Size([512])


# run

In [257]:
# Run the patch-embedding convolution without autograd bookkeeping.
with torch.inference_mode():
    conv_ret: Tensor = conv(image)
print(conv_ret.shape)  # (batch, out_channels, patches_h, patches_w)

# Collapse the patch grid into a single axis, then put channels last so the
# layout matches the linear path: (batch, num_patches ** 2, out_channels).
conv_ret = torch.transpose(
    torch.reshape(conv_ret, (batch, out_channels, -1)), 1, 2
)
print(conv_ret.shape)

torch.Size([1, 512, 2, 2])
torch.Size([1, 4, 512])


In [258]:
# Project every flattened patch with the linear layer, again without autograd.
with torch.inference_mode():
    linear_ret: Tensor = linear(image_flatten)
# Expected layout: (batch, num_patches ** 2, out_channels).
linear_ret.size()

torch.Size([1, 4, 512])

In [269]:
# First 32 output features of the first patch, convolution path.
# Values above 2**24 can no longer represent every integer in float32, so the
# trailing entries are sensitive to the order in which products are summed.
conv_ret[0][0][:32]

tensor([  294528.,   884353.,  1474178.,  2064003.,  2653828.,  3243653.,
         3833478.,  4423303.,  5013128.,  5602953.,  6192778.,  6782603.,
         7372428.,  7962253.,  8552078.,  9141903.,  9731728., 10321553.,
        10911378., 11501203., 12091028., 12680853., 13270678., 13860503.,
        14450328., 15040153., 15629978., 16219803., 16809628., 17399452.,
        17989272., 18579092.])

In [270]:
# First 32 output features of the first patch, linear path.
# Compare with the convolution output: only the last entries, which exceed
# 2**24 and are therefore rounded in float32, differ slightly.
linear_ret[0][0][:32]

tensor([  294528.,   884353.,  1474178.,  2064003.,  2653828.,  3243653.,
         3833478.,  4423303.,  5013128.,  5602953.,  6192778.,  6782603.,
         7372428.,  7962253.,  8552078.,  9141903.,  9731728., 10321553.,
        10911378., 11501203., 12091028., 12680853., 13270678., 13860503.,
        14450328., 15040153., 15629978., 16219803., 16809628., 17399452.,
        17989278., 18579104.])

In [267]:
# Bit-for-bit comparison of the two paths. This is not expected to hold in
# general, because the two float32 computations may round differently.
(conv_ret == linear_ret).all()

tensor(False)

In [268]:
# Tolerance-based comparison; the values below are torch.allclose's defaults,
# spelled out so the acceptance criterion is visible:
#   |conv_ret - linear_ret| <= atol + rtol * |linear_ret|
torch.allclose(conv_ret, linear_ret, rtol=1e-05, atol=1e-08)

True