In [18]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [19]:
class MultiheadAttention(nn.Module):
    def __init__(self) -> None: #  *args, **kwargs
        super(MultiheadAttention, self).__init__()
        embed_dim = embed_dim
        num_heads = num_heads
        dropout = dropout
        head_dim = embed_dim // num_heads
        assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

In [21]:
h_ct = torch.rand((1, 64, 128, 128, 128))
h_pet = torch.rand((1, 64, 128, 128, 128))
print(f'128x128x128={128*128*128}')

attn = nn.MultiheadAttention(embed_dim=64, num_heads=4)
transformer = nn.TransformerDecoderLayer(d_model=64, nhead=4, dim_feedforward=64*4, activation=F.selu)

### Attention test
bz, d, *size = h_pet.shape
h_pet = h_pet.view(bz, d, -1).contiguous().transpose(1,2)
h_ct = h_ct.view(bz, d, -1).contiguous().transpose(1,2)

ct_inter, w_ = attn(h_pet, h_ct, h_ct)
pet_inter, w2 = attn(h_ct, h_pet, h_pet)

print(ct_inter.shape, pet_inter.shape)

### Transformer

ct_inter = transformer(tgt=h_pet, memory=h_pet)

128x128x128=2097152
torch.Size([1, 2097152, 64]) torch.Size([1, 2097152, 64])


In [None]:
class IntraInter_Attention(nn.Module):
    def __init__(self, d_model, n_head=4) -> None:
        super(IntraInter_Attention, self).__init__()
        # transformer 사용 vs Multihead Attn 사용 비교 필요 
        assert d_model % n_head == 0 

        self.PET_inter = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=d_model*4, activation=F.selu)
        self.CT_inter = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=d_model*4, activation=F.selu)

        self.PET_intra = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=d_model*4, activation=F.selu)
        self.CT_intra = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=d_model*4, activation=F.selu)

    def forward(self, h_ct, h_pet):
        bz, d, *size = h_pet.shape

        h_pet = h_pet.view(bz, d, -1).contiguous().transpose(1,2)
        h_ct = h_ct.view(bz, d, -1).contiguous().transpose(1,2)

        pet_inter = self.PET_inter(h_pet, h_ct)
        ct_inter  = self.CT_inter(h_ct, h_pet)

        pet_intra = self.PET_intra(h_pet, h_pet)
        ct_intra  = self.CT_intra(h_ct, h_ct)

        out = torch.concat([pet_inter, ct_inter, pet_intra, ct_intra], dim=2)
        out = out.transpose(1, 2).view(bz, d, size[0], size[1], size[2])

        return out # return shape: bz, d*4, *size_
