In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [2]:
# (N, C, D, H, W)
inputs = torch.rand(1, 1, 2, 6, 6)
in_channel = 1
key_filters = 2
value_filters = 2
output_filters = 3
num_heads = 1
dropout_prob = 0.5 
layer_type = 'SAME'

# QKV transfrom

In [3]:
if layer_type == 'SAME':
    q = nn.Conv3d(in_channel, key_filters,
                           kernel_size=1, stride=1,
                           padding=0, bias=True)(inputs)
elif layer_type == 'DOWN':
    q = nn.Conv3d(in_channel, key_filters,
                           kernel_size=3, stride=2,
                           padding=1, bias=True)(inputs)
elif layer_type == 'UP':
    q = nn.ConvTranspose3d(in_channel, key_filters, kernel_size=3, 
                           stride=2, padding=1, bias=True)(inputs, output_size=(inputs.shape[2] * 2,
                                                                                inputs.shape[3] * 2,
                                                                                inputs.shape[4] * 2))

k = nn.Conv3d(in_channel, key_filters, kernel_size=1, 
                         stride=1, padding=0, bias=True)(inputs)

v = nn.Conv3d(in_channel, key_filters, kernel_size=1, 
                           stride=1, padding=0, bias=True)(inputs)

Batch, Dq, Hq, Wq = q.shape[0], q.shape[2], q.shape[3], q.shape[4]
print("q.shape: ", q.shape, "\nk.shape: ", k.shape, "\nv.shape: ", v.shape)


q.shape:  torch.Size([1, 2, 2, 6, 6]) 
k.shape:  torch.Size([1, 2, 2, 6, 6]) 
v.shape:  torch.Size([1, 2, 2, 6, 6])


# Split to Multi Heads

In [4]:
q = q.permute(0, 2, 3, 4, 1)
k = k.permute(0, 2, 3, 4, 1)
v = v.permute(0, 2, 3, 4, 1)
print("Switch channel position.")
print("q.shape: ", q.shape, "\nk.shape: ", k.shape, "\nv.shape: ", v.shape)

q = q.view(q.shape[0], q.shape[1], q.shape[2], q.shape[3], num_heads,
           int(q.shape[-1] / num_heads))

k = k.view(k.shape[0], k.shape[1], k.shape[2], k.shape[3], num_heads,
           int(k.shape[-1] / num_heads))

v = v.view(v.shape[0], v.shape[1], v.shape[2], v.shape[3], num_heads,
           int(v.shape[-1] / num_heads))
print("\nSplit to Multi Heads.")
print("q.shape: ", q.shape, "\nk.shape: ", k.shape, "\nv.shape: ", v.shape)

q = torch.flatten(q, start_dim=0, end_dim=4)
k = torch.flatten(k, start_dim=0, end_dim=4)
v = torch.flatten(v, start_dim=0, end_dim=4)
print("\nFlatten to shape (N * D * H * W) x C")
print("q.shape: ", q.shape, "\nk.shape: ", k.shape, "\nv.shape", v.shape)

Switch channel position.
q.shape:  torch.Size([1, 2, 6, 6, 2]) 
k.shape:  torch.Size([1, 2, 6, 6, 2]) 
v.shape:  torch.Size([1, 2, 6, 6, 2])

Split to Multi Heads.
q.shape:  torch.Size([1, 2, 6, 6, 1, 2]) 
k.shape:  torch.Size([1, 2, 6, 6, 1, 2]) 
v.shape:  torch.Size([1, 2, 6, 6, 1, 2])

Flatten to shape (N * D * H * W) x C
q.shape:  torch.Size([72, 2]) 
k.shape:  torch.Size([72, 2]) 
v.shape torch.Size([72, 2])


# Attention

In [5]:
scale = (key_filters // num_heads) ** 0.5
# normalize
q = q / scale

# attention
# [(B, Dq, Hq, Wq, N), (B, D, H, W, N)]
A = torch.matmul(q, k.transpose(0, 1))
A = torch.softmax(A, dim=1)
A = nn.Dropout(dropout_prob)(A)
print("Compute softmax(Q * K.Trans) along dim 1.")
print("A.shape: ", A.shape)

# [(B, Dq, Hq, Wq, N), C]
out = torch.matmul(A, v)
print("\nCompute matmul(A, v)")
print("Out.shape: ", out.shape)

# [B, Dq, Hq, Wq, C]
out = out.view(Batch, Dq, Hq, Wq, v.shape[-1] * num_heads)
out = out.permute(0, 4, 1, 2, 3)
print("\nCombine multi-heads back to shape of Q.")
print("Out.shape: ", out.shape)


Compute softmax(Q * K.Trans) along dim 1.
A.shape:  torch.Size([72, 72])

Compute matmul(A, v)
Out.shape:  torch.Size([72, 2])

Combine multi-heads back to shape of Q.
Out.shape:  torch.Size([1, 2, 2, 6, 6])


# Output Convolution

In [6]:
out = nn.Conv3d(value_filters, output_filters, kernel_size=1, stride=1, padding=0, bias=True)(out)
print("O.shape: ", out.shape)

O.shape:  torch.Size([1, 3, 2, 6, 6])
