In [1]:
from typing import Literal, overload

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from modules.df_conv2d import conv2d_b, conv2d_d
from utils.debug import print_diff
from utils.functions import stable_softmax
from utils.size import Size2_p, Size2_t, to_size


In [2]:
def get_mask(attn: Tensor) -> Tensor:
    mask = torch.ones(attn.shape[-2:], device=attn.device, dtype=attn.dtype).tril(0)
    return mask


def attention_b(
    inp: Tensor,
    v_size: int,
    k_size: int,
    weight: Tensor,
) -> Tensor:
    *b, c, h, w = inp.shape
    assert weight.shape == (2 * k_size + v_size, c, 1, 1)
    qkv = conv2d_b(inp, weight)
    q, k, v = (
        t.view(*b, -1, h * w)
        for t in torch.split_with_sizes(qkv, [k_size, k_size, v_size], -3)
    )

    attn = q.transpose(-2, -1) @ k
    mask = get_mask(attn)
    attn = attn + mask.where(mask == 0, -torch.inf)
    attn = stable_softmax(attn, -2).nan_to_num()
    attn = attn.triu(1)

    out = v @ attn
    return out.view(*b, v_size, h, w)


def attention_d(
    inp: Tensor,
    v_size: int,
    k_size: int,
    weight: Tensor,
) -> Tensor:
    *b, c, h, w = inp.shape
    assert weight.shape == (2 * k_size + v_size, c, 1, 1)
    out = torch.zeros(*b, v_size, h, w)
    qkv = torch.zeros(*b, 2 * k_size + v_size, h, w)
    for y in range(h):
        for x in range(w):
            qkv[..., y, x] = conv2d_d(inp, weight, pos=(y, x)).squeeze()
            q, k, v = (
                t.view(*b, -1, h * w)
                for t in torch.split_with_sizes(qkv, [k_size, k_size, v_size], -3)
            )
            i = y * w + x
            q = q[..., :i]
            k = k[..., i : i + 1]
            v = v[..., :i]
            attn = q.transpose(-2, -1) @ k
            attn = stable_softmax(attn, -3).nan_to_num()
            out[..., y, x] = (v @ attn).squeeze()
    return out


inp = torch.randn(64, 64, 28, 28)
weight = torch.randn(8 + 8 + 64, 64, 1, 1)
ab = attention_b(inp, 64, 8, weight)
ad = attention_d(inp, 64, 8, weight)
print_diff(ab, ad)


Equal 80.24%
Close 97.50%
MaxDiff 1.68e-03
A:
    Range [-4.29e+01, 4.26e+01]
    Mean -9.45e-04
    Std 8.50e+00
B:
    Range [-4.29e+01, 4.26e+01]
    Mean -9.45e-04
    Std 8.50e+00    


In [2]:
a = torch.randn(256, 32, 256, device="cuda")
b = torch.randn(256, 256, 64, device="cuda")

c1 = a @ b
c2 = torch.stack([a[:, i:i+1, :] @ b for i in range(a.shape[1])], dim=1).squeeze()

print_diff(c1, c2)
# torch.allclose(c1, c2)
# torch.equal(c1, c2)
# equal = ((c1 == c2).count_nonzero() / c1.numel()).item() * 100
# close = (torch.isclose(c1, c2).count_nonzero() / c1.numel()).item() * 100
# maxdiff = (c1 - c2).abs().max().item()
# print(f"Equal {equal:.2f}%")
# print(f"Close {close:.2f}%")
# print(f"MaxDiff {maxdiff:.2e}")

Equal 10.00%
Close 98.95%
MaxDiff 5.7220458984375e-05
A:
    Range: [-7.59e+01, 7.46e+01]
    Mean: -2.49e-02
    Std: 1.60e+01
B:
    Range: [-7.59e+01, 7.46e+01]
    Mean: -2.49e-02
    Std: 1.60e+01
