In [1]:
import torch 
from torch import nn

In [103]:
class TransformerEncoderLayer(nn.Module):
    """Attention sub-block followed by a position-wise feed-forward sub-block.

    Both sub-blocks are wrapped in residual connections; no normalisation or
    dropout is applied. ``query``, ``key`` and ``value`` are passed separately,
    so the layer can be used for cross-attention as well as self-attention.

    Args:
        d_model: Embedding width of the tokens.
        num_heads: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-block. Defaults
            to ``4 * d_model`` when ``None``.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            builds the activation used inside the feed-forward sub-block.
    """

    def __init__(self, d_model, num_heads, dim_feedforward=None, activation=nn.ReLU):
        super().__init__()
        if dim_feedforward is None:
            dim_feedforward = 4 * d_model
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=0.0, batch_first=True)
        # Implementation of feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.activation = activation()

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value`` and apply the feed-forward block.

        Args:
            query: Tokens of shape ``(batch, query_len, d_model)``.
            key: Tokens of shape ``(batch, source_len, d_model)``.
            value: Tokens of shape ``(batch, source_len, d_model)``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        # Index 0 is the attention output; index 1 holds the attention weights.
        attn_out = self.self_attn(query=query, key=key, value=value)[0]
        src = query + attn_out
        # The feed-forward block must consume the residual stream
        # (query + attention), not the raw attention output; otherwise the
        # query information bypasses the feed-forward transformation entirely.
        ff_out = self.linear2(self.activation(self.linear1(src)))
        return src + ff_out

In [127]:
class FusionNet(nn.Module):
    """Fuses a volumetric target feature map with a flat image feature vector.

    The network runs four identical stages. In each stage:

    1. the target volume is flattened into one token per voxel and the image
       vector is reinterpreted as tokens of the same width;
    2. target tokens cross-attend to image tokens and vice versa;
    3. the attended target is reshaped back into a volume and passed through
       an unpadded 3x3x3 convolution that doubles the channel count, then ELU.

    Afterwards the flattened target volume and the image vector are
    concatenated and passed through an MLP head that outputs 256 features.

    Sub-modules are registered as ``target_cross_attn_{i}``,
    ``img_cross_attn_{i}``, ``conv_{i}`` and ``act_{i}`` for ``i`` in 1..4,
    followed by ``fusion_head``.

    Args:
        d_model: Channel count of the input target volume.
        num_heads: Number of attention heads in every cross-attention layer.
        target_shape: ``(depth, height, width)`` of the input target volume.
            Used to size the fusion head. Every entry must be at least 9,
            because the four unpadded kernel-3 convolutions remove 8 voxels
            per axis.
        img_dim: Length of the flat image feature vector. Must be divisible by
            ``8 * d_model`` so it can be split into tokens at every stage.

    Raises:
        ValueError: If ``target_shape`` or ``img_dim`` violate the constraints
            above.
    """

    NUM_STAGES = 4

    def __init__(self, d_model=8, num_heads=1, target_shape=(9, 11, 11), img_dim=512):
        super().__init__()

        if len(target_shape) != 3:
            raise ValueError(f"target_shape must have 3 entries, got {tuple(target_shape)}")
        # Each kernel-3, stride-1, unpadded convolution shrinks every axis by 2.
        out_spatial = [size - 2 * self.NUM_STAGES for size in target_shape]
        if min(out_spatial) < 1:
            raise ValueError(
                f"every entry of target_shape must be >= {2 * self.NUM_STAGES + 1}, "
                f"got {tuple(target_shape)}"
            )
        # Token width doubles per stage, so divisibility by the widest
        # attention stage implies divisibility at every earlier stage.
        widest_attn = d_model * 2 ** (self.NUM_STAGES - 1)
        if img_dim % widest_attn != 0:
            raise ValueError(f"img_dim ({img_dim}) must be divisible by {widest_attn}")

        self.img_dim = img_dim

        # Modules are created in the order target-attn, img-attn, conv, act per
        # stage so that parameter ordering and seeded initialisation are stable.
        channels = d_model
        for stage in range(1, self.NUM_STAGES + 1):
            setattr(self, f"target_cross_attn_{stage}",
                    TransformerEncoderLayer(d_model=channels, num_heads=num_heads))
            setattr(self, f"img_cross_attn_{stage}",
                    TransformerEncoderLayer(d_model=channels, num_heads=num_heads))
            setattr(self, f"conv_{stage}",
                    nn.Conv3d(channels, 2 * channels, kernel_size=3, stride=1))
            setattr(self, f"act_{stage}", nn.ELU())
            channels *= 2

        # ``channels`` is now the channel count produced by the last convolution.
        flat_target = channels * out_spatial[0] * out_spatial[1] * out_spatial[2]
        self.fusion_head = nn.Sequential(
            nn.Linear(flat_target + img_dim, 1024),
            nn.ELU(),
            nn.Linear(1024, 512),
            nn.ELU(),
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, 256),
            nn.ELU(),
        )

    def forward(self, target, img_features):
        """Fuse ``target`` and ``img_features`` into a 256-dimensional feature vector.

        Args:
            target: Tensor of shape ``(batch, d_model, depth, height, width)``.
            img_features: Tensor holding ``img_dim`` values per sample.

        Returns:
            Tensor of shape ``(batch, 256)``.
        """
        batch_size = target.shape[0]
        img = img_features.reshape(batch_size, self.img_dim)

        for stage in range(1, self.NUM_STAGES + 1):
            target_attn = getattr(self, f"target_cross_attn_{stage}")
            img_attn = getattr(self, f"img_cross_attn_{stage}")
            conv = getattr(self, f"conv_{stage}")
            act = getattr(self, f"act_{stage}")

            channels = target.shape[1]
            depth, height, width = target.shape[2:]

            # Channels-last, then one token per voxel.
            target_tokens = target.permute(0, 2, 3, 4, 1).reshape(
                batch_size, depth * height * width, channels
            )
            # Re-chunk the flat image vector into tokens of the current width.
            img_tokens = img.reshape(batch_size, self.img_dim // channels, channels)

            target_attended = target_attn(query=target_tokens, key=img_tokens, value=img_tokens)
            img_attended = img_attn(query=img_tokens, key=target_tokens, value=target_tokens)

            # Back to channels-first volume for the convolution.
            target = target_attended.reshape(
                batch_size, depth, height, width, channels
            ).permute(0, 4, 1, 2, 3)
            target = act(conv(target))
            img = img_attended.reshape(batch_size, self.img_dim)

        features = torch.cat([target.reshape(batch_size, -1), img], dim=1)
        features = self.fusion_head(features)

        return features

In [119]:
# Instantiate the fusion network with its default configuration spelled out.
net = FusionNet(d_model=8, num_heads=1)

In [120]:
# Smoke test with a single random sample: an 8-channel 9x11x11 target volume
# and a 512-dimensional image feature vector.
target = torch.rand((1, 8, 9, 11, 11))
img = torch.rand((1, 512))
res = net(target=target, img_features=img)

In [121]:
# Shape check: the fusion head ends in Linear(256, 256), so a batch of one
# sample should yield torch.Size([1, 256]).
res.shape

torch.Size([1, 256])

In [122]:
# Total number of scalar values across all parameters of the network.
sum(map(torch.numel, net.parameters()))

2853600

In [3]:
# Scratch configuration for a stand-alone cross-attention experiment.
d_model, num_heads = 7, 1

# One attention module per direction (target -> image and image -> target).
target_cross_attn = nn.MultiheadAttention(
    embed_dim=d_model, num_heads=num_heads, dropout=0.0, batch_first=True
)
img_cross_attn = nn.MultiheadAttention(
    embed_dim=d_model, num_heads=num_heads, dropout=0.0, batch_first=True
)
# Channel-preserving 3x3x3 convolution matching the token width.
conv = nn.Conv3d(in_channels=d_model, out_channels=d_model, kernel_size=3, stride=1)

In [5]:
# Random token sequences for the scratch attention modules: one token per voxel
# of a 9x11x11 grid for the target, and 50 tokens for the image.
target = torch.rand((1, 9 * 11 * 11, d_model))
img = torch.rand((1, 50, d_model))

In [9]:
# Target tokens attend to the image tokens; [0] selects the attention output
# (index 1 would be the attention weights).
target_2 = target_cross_attn(query=target,
                             key=img,
                             value=img)[0]

# Image tokens attend to the target tokens. This direction uses its own
# module, img_cross_attn, rather than reusing the target-side weights.
img_2 = img_cross_attn(query=img,
                       key=target,
                       value=target)[0]

In [12]:
# The attention output keeps the query's shape: 9*11*11 = 1089 tokens of
# width d_model.
target_2.shape

torch.Size([1, 1089, 7])