# In [None]:
import torch
import torch.nn as nn

class TFN(nn.Module):
    """Fuse three feature vectors by a three-way outer product, then project.

    Every input vector gets a constant ``1`` appended, the per-sample outer
    product of the three augmented vectors is formed, flattened, and passed
    through a single linear layer (``self.fc``).

    Attributes:
        t_dim, a_dim, v_dim: Input feature sizes *including* the appended 1.
        fc: Linear layer mapping ``t_dim * a_dim * v_dim`` fused features to
            ``output_dim`` outputs.
    """

    def __init__(
        self, t_dim: int, a_dim: int, v_dim: int, output_dim: int = 1
    ) -> None:
        """Build the fusion head.

        Args:
            t_dim: Feature size of the first input (without the appended 1).
            a_dim: Feature size of the second input (without the appended 1).
            v_dim: Feature size of the third input (without the appended 1).
            output_dim: Number of outputs produced by the linear layer.
        """
        super().__init__()

        # Add the bias dimension (+1)
        self.t_dim = t_dim + 1
        self.a_dim = a_dim + 1
        self.v_dim = v_dim + 1

        # Final classifier
        fused_dim = self.t_dim * self.a_dim * self.v_dim
        self.fc = nn.Linear(fused_dim, output_dim)

    def _append_one(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Return ``x`` with a column of ones concatenated along dim 1."""
        # The ones must live on the same device as ``x`` or ``torch.cat``
        # fails once the inputs are moved off the CPU.  They are created in
        # the classifier's dtype so that concatenation promotes the result
        # toward what ``self.fc`` expects.
        ones = torch.ones(
            batch_size, 1, dtype=self.fc.weight.dtype, device=x.device
        )
        return torch.cat([x, ones], dim=1)

    def forward(
        self, t: torch.Tensor, a: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Compute the fused representation and apply the linear layer.

        Args:
            t: Tensor of shape ``(B, t_dim - 1)``.
            a: Tensor of shape ``(B, a_dim - 1)``.
            v: Tensor of shape ``(B, v_dim - 1)``.

        Returns:
            Tensor of shape ``(B, output_dim)``.  ``B`` may be 0.
        """
        batch_size = t.size(0)

        # Add bias (1) and place each modality on its own axis so that the
        # element-wise product broadcasts into the full outer product.
        t_ = self._append_one(t, batch_size)[:, :, None, None]  # (B, T, 1, 1)
        a_ = self._append_one(a, batch_size)[:, None, :, None]  # (B, 1, A, 1)
        v_ = self._append_one(v, batch_size)[:, None, None, :]  # (B, 1, 1, V)

        # Tensor fusion (B, t_dim, a_dim, v_dim)
        z = t_ * a_ * v_

        # Flatten to (B, t_dim * a_dim * v_dim).  ``flatten`` is used instead
        # of ``view(batch_size, -1)`` because ``-1`` cannot be inferred when
        # the batch is empty.
        z = z.flatten(start_dim=1)

        return self.fc(z)


# Example: run one random batch through the model and show the output shape.
batch = 16
# Inputs are drawn in the order t, a, v with feature sizes 300, 74 and 35.
t, a, v = (torch.randn(batch, width) for width in (300, 74, 35))

# Size the model from the feature dimension of each input.
model = TFN(t.size(1), a.size(1), v.size(1))
out = model(t, a, v)
print(out.shape)
