In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class FeatureModulationBlock(nn.Module):
    """Scale a single-channel feature map ``x`` by a learned gate and a mask.

    The forward pass computes::

        gate   = sigmoid(relu(conv_g(g) + conv_x(x)))
        alpha  = bilinear_resize(mask, spatial size of gate) * gate
        output = alpha * x

    ``conv_g`` and ``conv_x`` are 1x1 convolutions with one input and one
    output channel, so ``g`` and ``x`` must be 4-D ``(N, 1, H, W)`` tensors.

    NOTE(review): ``alpha_conv`` is registered but never applied in
    ``forward``. It is kept so the module's parameters and ``state_dict``
    keys stay unchanged -- confirm against the diagram whether it was meant
    to be applied (e.g. between the ReLU and the sigmoid).
    """

    def __init__(self):
        super().__init__()

        # 1x1 convolutions (one channel in, one channel out) for each input.
        self.conv_g = nn.Conv2d(1, 1, kernel_size=1)  # For the g input
        self.conv_x = nn.Conv2d(1, 1, kernel_size=1)  # For the x input

        # Currently unused by forward(); see the class docstring.
        self.alpha_conv = nn.Conv2d(1, 1, kernel_size=1)

    def forward(self, g, x, mask):
        """Return ``x`` scaled by the resampled mask and the sigmoid gate.

        Args:
            g: Gating tensor of shape ``(N, 1, H, W)``.
            x: Feature tensor of shape ``(N, 1, H, W)``.
            mask: 2-D array-like of shape ``(h, w)`` (nested list, NumPy
                array or tensor), or a 4-D tensor of shape
                ``(1 or N, 1, h, w)``. Any numeric dtype is accepted; it is
                converted to the dtype and device of ``x``. The mask is
                treated as a constant: no gradient flows back into it.

        Returns:
            ``resampled_mask * gate * x`` -- the shape of ``x`` when ``g``
            and ``x`` have matching shapes.

        Raises:
            ValueError: If ``mask`` is neither 2-D nor 4-D.
        """
        # as_tensor accepts lists, arrays and tensors without the
        # "copy construct" warning that torch.tensor(tensor) emits. Casting to
        # x's dtype/device lets integer (e.g. binary) masks go through
        # bilinear interpolation, which rejects integer tensors, and keeps the
        # mask on the same device as the features.
        mask_tensor = torch.as_tensor(mask, dtype=x.dtype, device=x.device).detach()
        if mask_tensor.dim() == 2:
            # (h, w) -> (1, 1, h, w), the layout F.interpolate expects.
            mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)
        elif mask_tensor.dim() != 4:
            raise ValueError(
                "mask must be 2-D (h, w) or 4-D (N, 1, h, w); "
                f"got {mask_tensor.dim()} dimension(s)"
            )

        # Project both inputs and combine them element-wise (the ⊕ node).
        combined = self.conv_g(g) + self.conv_x(x)

        # ReLU (σ₁) followed by sigmoid. Because the sigmoid sees only
        # non-negative values, the gate always lies in [0.5, 1].
        gate = torch.sigmoid(F.relu(combined))

        # Resample the mask to the gate's spatial size (bilinear).
        resampled = F.interpolate(
            mask_tensor, size=gate.shape[2:], mode='bilinear', align_corners=False
        )

        # Scaling factor α = resampled mask * gate, then modulate the input x.
        alpha_out = resampled * gate
        output = alpha_out * x

        return output

In [3]:
# 3x3 toy feature map, lifted to a 1x1x3x3 (batch, channel, height, width) tensor.
initialMatrix = torch.tensor(
    [
        [1, 1, 1],
        [1, 10, 10],
        [1, 10, 10],
    ],
    dtype=torch.float32,
)[None, None]

# Constant 2x2 mask; every entry is 0.1.
mask = [[0.1, 0.1] for _ in range(2)]


In [4]:
# Seed PyTorch's RNG so the randomly initialised 1x1 convolutions -- and
# therefore the printed output -- are identical on every Restart & Run All.
torch.manual_seed(0)

# Gating signal: all ones, same shape and dtype as the input (1x1x3x3).
g_input = torch.ones_like(initialMatrix)

# Create the model and run a single forward pass.
model = FeatureModulationBlock()
output = model(g_input, initialMatrix, mask)

In [5]:
# Show the result as a plain 2-D NumPy array: drop it from the autograd graph,
# remove the singleton batch/channel axes, and print it under a header line.
print("Output:", output.detach().squeeze().numpy(), sep="\n")

Output:
[[0.05       0.05       0.05      ]
 [0.05       0.5        0.5       ]
 [0.05       0.5        0.49999997]]
