## BilinearInteraction

Bilinear interaction layer from FiBiNET.

In [None]:
import torch
from torch import nn
import itertools


class Bilinear(nn.Module):
    """Vectorised bilinear interaction layer (FiBiNET style).

    For every unordered field pair ``(i, j)`` with ``i < j`` the layer
    computes ``(x_i @ W) * x_j``, i.e. a linear projection of the left
    field followed by an element-wise product with the right field.

    ref: https://github.com/shenweichen/DeepCTR-Torch/blob/master/deepctr_torch/layers/interaction.py#L104C1-L156C35
    ref: https://arxiv.org/abs/2209.05016

    No LoRA
    """

    def __init__(self, field_num: int, field_emb_dim: int, method: str = 'each_ip'):
        """
        Args:
            field_num: number of fields ``f``; only used to size the
                per-field weights of the ``each`` / ``each_ip`` methods.
            field_emb_dim: embedding dimension ``d`` of every field.
            method: one of ``share``, ``each``, ``each_ip``, ``share_ip``.
                ``share*`` uses a single ``(d, d)`` matrix for all fields,
                ``each*`` uses one ``(d, d)`` matrix per left field.
                Only ``each_ip`` and ``share_ip`` are implemented in
                ``forward``; ``share`` and ``each`` allocate a weight but
                raise ``NotImplementedError`` when the layer is called.
                P2P (one matrix per pair) is costly and not supported
                currently.

        Raises:
            NotImplementedError: if ``method`` is not one of the names above.
        """
        super().__init__()
        self.method = method
        # NOTE(review): torch.rand samples from U[0, 1), so every weight
        # starts non-negative; kept as-is to preserve existing behaviour.
        if method in ('share', 'share_ip'):
            self.W = nn.Parameter(torch.rand(field_emb_dim, field_emb_dim))
        elif method in ('each', 'each_ip'):
            self.W = nn.Parameter(torch.rand(field_num, field_emb_dim, field_emb_dim))
        else:
            raise NotImplementedError(f"unknown bilinear method: {method!r}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the bilinear interaction of all field pairs.

        Args:
            x (torch.Tensor): (bs, f, d)

        Raises:
            NotImplementedError: if ``self.method`` is not ``each_ip`` or
                ``share_ip``.

        Returns:
            (torch.Tensor): (bs, f(f-1)//2, d), pairs ordered as
            (0, 1), (0, 2), ..., (0, f-1), (1, 2), ... . The pair axis is
            empty when there are fewer than two fields.
        """
        if self.method == 'each_ip':
            # (bs, f, 1, d) @ (f, d, d) broadcasts over the batch axis, so
            # field i is projected with its own matrix W[i].
            projected = torch.matmul(x.unsqueeze(2), self.W).squeeze(2)
        elif self.method == 'share_ip':
            projected = torch.matmul(x, self.W)  # (bs, f, d)
        else:
            raise NotImplementedError(
                f"forward is not implemented for method {self.method!r}; "
                "use 'each_ip' or 'share_ip'"
            )

        num_fields = x.shape[1]
        # Strict upper triangle == itertools.combinations(range(f), 2) in
        # the same order. It is created on x's device and stays a valid
        # (empty) long tensor when f < 2.
        left_idx, right_idx = torch.triu_indices(
            num_fields, num_fields, offset=1, device=x.device
        )
        # Gather the pair members first so that only f(f-1)/2 products are
        # materialised instead of the full (bs, f, f, d) tensor.
        return projected[:, left_idx] * x[:, right_idx]