In [1]:
import torch

In [2]:
import numpy as np
import torch
import torch.nn.functional as F


In [3]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.fx import subgraph_rewriter, symbolic_trace
import utils
from torch.fx import Proxy, Graph, GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

replace success!


In [4]:
class FeaturesLinear(torch.nn.Module):

    def __init__(self, field_dims, output_dim=1):
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
        # self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long)
        self.offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64))

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        # x = x + x.new_tensor(self.offsets).unsqueeze(0)
        x = x + self.offsets

        return torch.sum(self.fc(x), dim=1) + self.bias

In [5]:
class FieldAwareFactorizationMachine(torch.nn.Module):

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.num_fields = len(field_dims)
        self.embeddings = torch.nn.ModuleList([
            torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields)
        ])
        # self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long)
        self.offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64))

        for embedding in self.embeddings:
            torch.nn.init.xavier_uniform_(embedding.weight.data)
    def ffm_interaction(self,field_wise_emb_list):
        dot = 0
        num_fields = self.num_fields
        for i in range(num_fields - 1):
            for j in range(i + 1, num_fields):
                v_ij = field_wise_emb_list[j - 1][:, i, :]
                v_ji = field_wise_emb_list[i][:, j, :]
                dot += torch.sum(v_ij * v_ji, dim=1, keepdim=True)
        return dot
    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        # x = x + x.new_tensor(self.offsets).unsqueeze(0)
        x = x + self.offsets

        xs = [self.embeddings[i](x) for i in range(self.num_fields)]
        return self.ffm_interaction(xs)
        # ix = list()
        # for i in range(self.num_fields - 1):
        #     for j in range(i + 1, self.num_fields):
        #         ix.append(xs[j][:, i] * xs[i][:, j])
        # ix = torch.stack(ix, dim=1)
        # return ix

In [6]:
class FieldAwareFactorizationMachineModel(torch.nn.Module):
    """
    A pytorch implementation of Field-aware Factorization Machine.

    Reference:
        Y Juan, et al. Field-aware Factorization Machines for CTR Prediction, 2015.
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.linear = FeaturesLinear(field_dims)
        self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        # ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True)
        ffm_term = self.ffm(x)
        x = self.linear(x) + ffm_term
        return torch.sigmoid(x.squeeze(1))

In [7]:
ffm = FieldAwareFactorizationMachineModel([100 for _ in range(100)],128)

In [8]:
ffm_model_traced = symbolic_trace(ffm)

In [9]:
interp = utils.ProfilingInterpreter(ffm_model_traced)
interp.run(torch.randint(low=0, high=88, size=(4096,100), dtype=torch.long))
print(interp.summary(True))

total true time 2175.6441593170166 ms
total time: 2269.789695739746 ms
Op type        Op                   Average runtime (ms)    Pct total runtime
-------------  -----------------  ----------------------  -------------------
call_module    ffm_embeddings_0               16.6414             0.733168
call_module    ffm_embeddings_48               8.66795            0.381883
call_module    ffm_embeddings_20               8.00157            0.352525
call_module    ffm_embeddings_2                7.96247            0.350802
call_module    ffm_embeddings_1                7.94125            0.349867
call_module    ffm_embeddings_29               7.92313            0.349069
call_module    ffm_embeddings_14               7.92193            0.349016
call_module    ffm_embeddings_17               7.91669            0.348785
call_module    ffm_embeddings_7                7.91621            0.348764
call_module    ffm_embeddings_18               7.91597            0.348754
call_module    ffm_embe

In [17]:
ffm = FieldAwareFactorizationMachineModel([100 for _ in range(22)],128)

In [18]:
ffm_model_traced = symbolic_trace(ffm)

In [19]:
interp = utils.ProfilingInterpreter(ffm_model_traced)
interp.run(torch.randint(low=0, high=88, size=(4096,22), dtype=torch.long))
print(interp.summary(True))

total true time 90.44122695922852 ms
total time: 95.56388854980469 ms
Op type        Op                   Average runtime (ms)    Pct total runtime
-------------  -----------------  ----------------------  -------------------
call_module    ffm_embeddings_10               1.27244              1.33151
call_module    ffm_embeddings_11               1.2064               1.2624
call_module    ffm_embeddings_2                1.158                1.21175
call_module    ffm_embeddings_1                1.14608              1.19928
call_module    ffm_embeddings_5                1.14608              1.19928
call_module    ffm_embeddings_4                1.14441              1.19753
call_module    ffm_embeddings_13               1.14083              1.19379
call_module    ffm_embeddings_9                1.1375               1.1903
call_module    ffm_embeddings_18               1.13416              1.18681
call_module    ffm_embeddings_7                1.11508              1.16685
call_module    f

In [10]:
class ReFieldAwareFactorizationMachine(torch.nn.Module):

    def __init__(self, field_dims, embed_dim,redundancy_len):
        super().__init__()
        self.num_fields = len(field_dims)
        self.redundancy_len = redundancy_len
        self.embeddings = torch.nn.ModuleList([
            torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields)
        ])
        # self.offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64))
        self.redundancy_offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims[:self.redundancy_len])[:-1]), dtype=np.int64))
        self.non_redundancy_offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims[self.redundancy_len:])[:-1]), dtype=np.int64))

    def ffm_interaction(self,redundancy_field_wise_emb_list,non_redundancy_field_wise_emb_list):
        dot = 0
        # num_fields = self.num_fields
        redundancy_len = self.redundancy_len
        non_redundancy_len = self.num_fields - self.redundancy_len
        for i in range(non_redundancy_len - 1):
            for j in range(i + 1, non_redundancy_len):
                v_ij = non_redundancy_field_wise_emb_list[j + self.redundancy_len][:, i, :]
                v_ji = non_redundancy_field_wise_emb_list[i + self.redundancy_len][:, j, :]
                dot += torch.sum(v_ij * v_ji, dim=1, keepdim=True)
        for i in range(redundancy_len - 1):
          for j in range(i + 1,redundancy_len):
                v_ij = redundancy_field_wise_emb_list[j][:,i, :]
                v_ji = redundancy_field_wise_emb_list[i][:,j, :]
                dot += torch.sum(v_ij * v_ji, dim=1)            

        for i in range(redundancy_len):
          for j in range(non_redundancy_len):
            v_ij = redundancy_field_wise_emb_list[j + self.redundancy_len][:,i, :]
            v_ji = non_redundancy_field_wise_emb_list[i][:, j, :]
            dot += torch.sum(v_ij * v_ji, dim=1, keepdim=True)
        return dot
    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        # x = x + x.new_tensor(self.offsets).unsqueeze(0)
        # x = x + self.offsets
        redundancy_x, non_redundancy_x = x[0,:self.redundancy_len], x[:,self.redundancy_len:]
        redundancy_xs = [self.embeddings[i](redundancy_x).unsqueeze(0) for i in range(self.num_fields)]
        non_redundancy_xs = [self.embeddings[i](non_redundancy_x) for i in range(self.num_fields)]
        
        # xs = [self.embeddings[i](x) for i in range(self.num_fields)]
        return self.ffm_interaction(redundancy_xs,non_redundancy_xs)
        # ix = list()
        # for i in range(self.num_fields - 1):
        #     for j in range(i + 1, self.num_fields):
        #         ix.append(xs[j][:, i] * xs[i][:, j])
        # ix = torch.stack(ix, dim=1)
        # return ix

In [11]:
class ReFieldAwareFactorizationMachineModel(torch.nn.Module):
    """
    A pytorch implementation of Field-aware Factorization Machine.

    Reference:
        Y Juan, et al. Field-aware Factorization Machines for CTR Prediction, 2015.
    """

    def __init__(self, field_dims, embed_dim,redundancy_len):
        super().__init__()
        self.linear = FeaturesLinear(field_dims)
        self.ffm = ReFieldAwareFactorizationMachine(field_dims, embed_dim,redundancy_len)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        """
        # ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True)
        ffm_term = self.ffm(x)
        x = self.linear(x) + ffm_term
        return torch.sigmoid(x.squeeze(1))

In [12]:
re_ffm = ReFieldAwareFactorizationMachineModel([100 for _ in range(100)],128,50)

In [13]:
re_ffm(torch.randint(low=0, high=88, size=(4096,100), dtype=torch.long))

tensor([0., 1., 0.,  ..., 1., 1., 1.], grad_fn=<SigmoidBackward0>)

In [14]:
re_ffm_model_traced = symbolic_trace(re_ffm)

In [16]:
interp = utils.ProfilingInterpreter(re_ffm_model_traced)
interp.run(torch.randint(low=0, high=88, size=(4096,100), dtype=torch.long))
print(interp.summary(True))

total true time 1773.083209991455 ms
total time: 1854.3193340301514 ms
Op type        Op                    Average runtime (ms)    Pct total runtime
-------------  ------------------  ----------------------  -------------------
call_function  getitem_8443                   348.3               18.7832
call_module    ffm_embeddings_138               8.00896            0.431908
call_module    ffm_embeddings_139               4.52757            0.244163
call_module    ffm_embeddings_143               4.44198            0.239548
call_module    ffm_embeddings_140               4.44078            0.239483
call_module    ffm_embeddings_162               4.44007            0.239445
call_module    ffm_embeddings_133               4.43959            0.239419
call_module    ffm_embeddings_152               4.43506            0.239175
call_module    ffm_embeddings_171               4.43292            0.239059
call_module    ffm_embeddings_169               4.4291             0.238853
call_module  

In [20]:
re_ffm = ReFieldAwareFactorizationMachineModel([100 for _ in range(22)],128,10)

In [21]:
re_ffm_model_traced = symbolic_trace(re_ffm)

In [22]:
interp = utils.ProfilingInterpreter(re_ffm_model_traced)
interp.run(torch.randint(low=0, high=88, size=(4096,22), dtype=torch.long))
print(interp.summary(True))

total true time 68.71962547302246 ms
total time: 72.9973316192627 ms
Op type        Op                   Average runtime (ms)    Pct total runtime
-------------  -----------------  ----------------------  -------------------
call_module    ffm_embeddings_24               2.12622              2.91273
call_module    ffm_embeddings_22               0.833273             1.14151
call_module    ffm_embeddings_23               0.795841             1.09023
call_module    ffm_embeddings_37               0.75078              1.0285
call_module    ffm_embeddings_25               0.69499              0.952076
call_module    ffm_embeddings_27               0.693321             0.94979
call_module    ffm_embeddings_40               0.679493             0.930846
call_module    ffm_embeddings_36               0.679016             0.930193
call_module    ffm_embeddings_41               0.677824             0.92856
call_module    ffm_embeddings_32               0.67544              0.925294
call_module 