In [None]:
class FusionNet(nn.Module):
    """Fuse a 3-D target feature volume with a global image feature vector.

    The target volume is flattened into one token per voxel. These tokens
    first cross-attend to 10 tokens derived from the image vector (a 60-d
    projection split into ``d_model``-wide chunks when ``d_model`` is 6) and
    then self-attend with a trainable positional encoding added to the
    queries and keys. The attended volume is reduced to a 128-d vector by a
    3-D CNN with a global max-pool, concatenated with a 128-d embedding of
    the image vector, and passed through a final MLP.

    Args:
        d_model: Channel width of every attention token. It must equal the
            channel count of ``target_features`` and must divide 60 (the
            width of the image projection).
        num_heads: Number of heads used by both attention layers.
    """

    # Spatial size (depth, height, width) of the target feature volume.
    _GRID = (9, 11, 11)

    def __init__(self, d_model: int = 6, num_heads: int = 1):
        super().__init__()
        self.d_model = d_model
        # Trainable positional encoding, initialised from PositionalEncoder.
        # NOTE(review): presumably broadcastable to (batch, 9*11*11, d_model);
        # confirm against PositionalEncoder.
        self.pe = nn.Parameter(PositionalEncoder(d_model)())

        # Projects the image vector to 60 values, later split into tokens.
        self.img_preproc = nn.Sequential(
            nn.Linear(512, 60),
            nn.ELU(),
        )

        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        # The input channel count follows d_model so that it always matches
        # the token width used when forward() rebuilds the volume.
        self.conv_net = nn.Sequential(
            nn.Conv3d(d_model, 8, kernel_size=3, padding=1),  # receptive field = 3
            nn.ELU(),
            nn.Conv3d(8, 16, kernel_size=3, padding=1),       # receptive field = 5
            nn.ELU(),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),      # receptive field = 7
            nn.ELU(),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),      # receptive field = 9
            nn.ELU(),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),     # receptive field = 11
            nn.ELU(),
            # Global max-pool: collapses the whole volume to one value per channel.
            nn.MaxPool3d(kernel_size=self._GRID),
        )

        self.img_mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, 128),
            nn.ELU(),
        )

        self.mlp = nn.Sequential(
            nn.Linear(128 + 128, 256),
            nn.ELU(),
            nn.Linear(256, 256),
            nn.ELU(),
        )

    def forward(self, target_features: torch.Tensor, img_features: torch.Tensor) -> torch.Tensor:
        """Compute the fused feature vector.

        Args:
            target_features: Tensor of shape ``(batch, d_model, 9, 11, 11)``.
                It is not modified.
            img_features: Tensor of shape ``(batch, 512)``.

        Returns:
            Tensor of shape ``(batch, 256)``.
        """
        batch_size = target_features.shape[0]
        depth, height, width = self._GRID

        # (B, C, D, H, W) -> (B, D*H*W, C): one token per voxel.
        voxel_tokens = target_features.permute(0, 2, 3, 4, 1).reshape(
            batch_size, depth * height * width, self.d_model
        )
        # (B, 512) -> (B, 60) -> (B, 60 / d_model, d_model) image tokens.
        img_tokens = self.img_preproc(img_features).reshape(batch_size, -1, self.d_model)

        # The residual adds are deliberately out of place. An in-place `+=`
        # would overwrite a tensor that autograd may still need for the
        # backward pass of the attention projections, and it would write
        # through to the caller's input whenever the reshape above is a view.
        attended = self.cross_attn(query=voxel_tokens, key=img_tokens, value=img_tokens)[0]
        voxel_tokens = voxel_tokens + attended

        # Positions are injected into queries and keys only; values stay raw.
        positioned = voxel_tokens + self.pe
        attended = self.self_attn(query=positioned, key=positioned, value=voxel_tokens)[0]
        voxel_tokens = voxel_tokens + attended

        # (B, D*H*W, C) -> (B, C, D, H, W) for the 3-D CNN.
        volume = voxel_tokens.reshape(
            batch_size, depth, height, width, self.d_model
        ).permute(0, 4, 1, 2, 3)
        pooled = self.conv_net(volume).reshape(batch_size, -1)

        img_embedding = self.img_mlp(img_features)

        fused = torch.cat([pooled, img_embedding], dim=1)
        return self.mlp(fused)

In [None]:
class FusionNet(nn.Module):
    """Fuse a 3-D target feature volume with a global image feature vector.

    Wider variant: the image vector is projected to 120 values (20 tokens
    when ``d_model`` is 6), the 3-D CNN ends in 256 channels, and the final
    MLP has three layers.

    The target volume is flattened into one token per voxel. These tokens
    first cross-attend to the image tokens and then self-attend with a
    trainable positional encoding added to the queries and keys. The
    attended volume is reduced to a 256-d vector by the CNN with a global
    max-pool, concatenated with a 256-d embedding of the image vector, and
    passed through the final MLP.

    Args:
        d_model: Channel width of every attention token. It must equal the
            channel count of ``target_features`` and must divide 120 (the
            width of the image projection).
        num_heads: Number of heads used by both attention layers.
    """

    # Spatial size (depth, height, width) of the target feature volume.
    _GRID = (9, 11, 11)

    def __init__(self, d_model: int = 6, num_heads: int = 1):
        super().__init__()
        self.d_model = d_model
        # Trainable positional encoding, initialised from PositionalEncoder.
        # NOTE(review): presumably broadcastable to (batch, 9*11*11, d_model);
        # confirm against PositionalEncoder.
        self.pe = nn.Parameter(PositionalEncoder(d_model)())

        # Projects the image vector to 120 values, later split into tokens.
        self.img_preproc = nn.Sequential(
            nn.Linear(512, 120),
            nn.ELU(),
        )

        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

        # The input channel count follows d_model so that it always matches
        # the token width used when forward() rebuilds the volume.
        self.conv_net = nn.Sequential(
            nn.Conv3d(d_model, 16, kernel_size=3, padding=1),   # receptive field = 3
            nn.ELU(),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),        # receptive field = 5
            nn.ELU(),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),        # receptive field = 7
            nn.ELU(),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),       # receptive field = 9
            nn.ELU(),
            # The kernel has depth 1, so depth needs no padding. Padding it by
            # 1 would grow depth from 9 to 11, and the pool below would then
            # cover a bias-only padding slice and skip the last real slice.
            nn.Conv3d(128, 256, kernel_size=(1, 3, 3), padding=(0, 1, 1)),  # receptive field = (9, 11, 11)
            nn.ELU(),
            # Global max-pool: collapses the whole volume to one value per channel.
            nn.MaxPool3d(kernel_size=self._GRID),
        )

        self.img_mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, 256),
            nn.ELU(),
        )

        self.mlp = nn.Sequential(
            nn.Linear(512, 512),
            nn.ELU(),
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, 256),
            nn.ELU(),
        )

    def forward(self, target_features: torch.Tensor, img_features: torch.Tensor) -> torch.Tensor:
        """Compute the fused feature vector.

        Args:
            target_features: Tensor of shape ``(batch, d_model, 9, 11, 11)``.
                It is not modified.
            img_features: Tensor of shape ``(batch, 512)``.

        Returns:
            Tensor of shape ``(batch, 256)``.
        """
        batch_size = target_features.shape[0]
        depth, height, width = self._GRID

        # (B, C, D, H, W) -> (B, D*H*W, C): one token per voxel.
        voxel_tokens = target_features.permute(0, 2, 3, 4, 1).reshape(
            batch_size, depth * height * width, self.d_model
        )
        # (B, 512) -> (B, 120) -> (B, 120 / d_model, d_model) image tokens.
        img_tokens = self.img_preproc(img_features).reshape(batch_size, -1, self.d_model)

        # The residual adds are deliberately out of place. An in-place `+=`
        # would overwrite a tensor that autograd may still need for the
        # backward pass of the attention projections, and it would write
        # through to the caller's input whenever the reshape above is a view.
        attended = self.cross_attn(query=voxel_tokens, key=img_tokens, value=img_tokens)[0]
        voxel_tokens = voxel_tokens + attended

        # Positions are injected into queries and keys only; values stay raw.
        positioned = voxel_tokens + self.pe
        attended = self.self_attn(query=positioned, key=positioned, value=voxel_tokens)[0]
        voxel_tokens = voxel_tokens + attended

        # (B, D*H*W, C) -> (B, C, D, H, W) for the 3-D CNN.
        volume = voxel_tokens.reshape(
            batch_size, depth, height, width, self.d_model
        ).permute(0, 4, 1, 2, 3)
        pooled = self.conv_net(volume).reshape(batch_size, -1)

        img_embedding = self.img_mlp(img_features)

        fused = torch.cat([pooled, img_embedding], dim=1)
        return self.mlp(fused)