In [1]:
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.fx import subgraph_rewriter, symbolic_trace
from torch.fx import Proxy, Graph, GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

In [3]:
import time

In [4]:
class ReFeaturesLinear(torch.nn.Module):

    def __init__(self, field_dims, prefix,batch,output_dim=1):
        super().__init__()
        self.fc = torch.nn.Embedding(sum(field_dims), output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))
        self.offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64))
        self.prefix = prefix
        self.batch = batch
        self.prefix_offsets = self.offsets[:self.prefix]
        self.rest_offsets = self.offsets[self.prefix:]
        self.prefix_slice = [0,slice(None,prefix)]
        self.rest_slice = [slice(None,None),slice(prefix,None)]
    def forward(self, x):
        """
        :param x: tuple of tensor ``(prefix_index, rest_index)``
        prefix_index ``(prefix_field)``
        rest_index  ``(batch_szie,rest_field)``
        """
        prefix_index, rest_index = x[self.prefix_slice], x[self.rest_slice]
        prefix_index = prefix_index + self.prefix_offsets
        rest_index = rest_index + self.rest_offsets
        prefix_index = prefix_index.repeat(self.batch,1)
        index = torch.concat([prefix_index,rest_index],dim = 1)
        prefix_index, rest_index = index[self.prefix_slice], index[self.rest_slice]
        
        prefix_fc = self.fc(prefix_index) + self.bias
        prefix_fc = prefix_fc.repeat(self.batch,1,1)
        rest_fc = self.fc(rest_index)
        fc_result = torch.concat([prefix_fc,rest_fc],dim = 1)
        prefix_fc = fc_result[self.prefix_slice]
        rest_fc = fc_result[self.rest_slice]
        prefix_sum = torch.sum(prefix_fc,dim = 0)
        rest_sum = torch.sum(rest_fc,dim=1)
        return prefix_sum + rest_sum

In [5]:
import sys
sys.path.append('/home/yssun/pytorch-fm/torchfm/')
sys.path.append('/home/yssun/pytorch-fm/torchfm/model/')
sys.path.append('/home/yssun/pytorch-fm/torchfm/model/test_fx')
import utils
from layer import FeaturesLinear

replace success!


In [6]:
class ReFeaturesEmbedding(torch.nn.Module):

    def __init__(self, field_dims, embed_dim,prefix,batch):
        super().__init__()
        self.prefix = prefix
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        self.offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64))
        self.prefix_offsets = self.offsets[:self.prefix]
        self.rest_offsets = self.offsets[self.prefix:]
        self.batch = batch
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)
        self.prefix_slice = [0,slice(None,prefix)]
        self.rest_slice = [slice(None,None),slice(prefix,None)]
    def forward(self, x):
        """
        :param x: tuple of tensor ``(prefix_index, rest_index)``
        prefix_index ``(prefix_field)``
        rest_index  ``(batch_szie,rest_field)``
        """
        prefix_index, rest_index = x[self.prefix_slice],x[self.rest_slice]
        prefix_index = prefix_index + self.prefix_offsets
        rest_index = rest_index + self.rest_offsets
        prefix_index = prefix_index.repeat(self.batch,1)
        index = torch.concat([prefix_index,rest_index],dim = 1)
        prefix_index, rest_index = index[self.prefix_slice], index[self.rest_slice]
        return torch.concat([self.embedding(prefix_index).repeat(self.batch,1,1),self.embedding(rest_index)],dim = 1)

In [7]:
class ReFactorizationMachine(torch.nn.Module):

    def __init__(self,prefix,batch, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum
        self.batch = batch
        self.prefix_slice = [0,slice(None,prefix)]
        self.rest_slice = [slice(None,None),slice(prefix,None)]
    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        prefix_embed, rest_embed = x[self.prefix_slice],x[self.rest_slice]
        # square_of_sum = torch.sum(x, dim=1) ** 2
        square_of_sum = (torch.sum(rest_embed,dim = 1) + torch.sum(prefix_embed,dim = 0)) ** 2
        prefix_embed, rest_embed = prefix_embed ** 2, rest_embed ** 2
        pow_embed = torch.concat([prefix_embed.repeat(self.batch,1,1),rest_embed],dim = 1)
        pow_prefix_embed, pow_rest_embed = pow_embed[self.prefix_slice],pow_embed[self.rest_slice]
        sum_of_square = torch.sum(pow_rest_embed, dim=1) + torch.sum(pow_prefix_embed, dim = 0)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix

In [9]:
class ReFactorizationMachineModel(torch.nn.Module):
  def __init__(self, field_dims, embed_dim,prefix,batch):
    super().__init__()
    self.embedding = ReFeaturesEmbedding(field_dims, embed_dim,prefix,batch)
    self.linear = ReFeaturesLinear(field_dims,prefix,batch)
    self.fm = ReFactorizationMachine(prefix,batch,reduce_sum=True)

  def forward(self, x):
      """
      :param x: Long tensor of size ``(batch_size, num_fields)``
      """
      x = self.linear(x) + self.fm(self.embedding(x))
      return torch.sigmoid(x.squeeze(1))


In [10]:
fm_model = ReFactorizationMachineModel([100 for _ in range(100)],32,30,4096)

In [11]:
import fm

In [12]:
fm_ori_model = fm.FactorizationMachineModel([100 for _ in range(100)],64)

In [13]:
fm_ori_model_traced = symbolic_trace(fm_ori_model)

In [14]:
fm_model = ReFactorizationMachineModel([100 for _ in range(100)],64,50,4096)

In [15]:
fm_modify_model_traced = symbolic_trace(fm_model)

In [18]:
interp = utils.ProfilingInterpreter(fm_ori_model_traced)
input_data = torch.zeros((4096,100),dtype=torch.long)
interp.run(input_data)
print(interp.summary(True))

total true time 23.090600967407227 ms
total time: 33.370018005371094 ms
Op type        Op                     Average runtime (ms)    Pct total runtime
-------------  -------------------  ----------------------  -------------------
call_module    embedding_embedding               8.65006             25.9217
call_function  pow_2                             6.76298             20.2666
call_function  sum_2                             2.67673              8.02135
call_function  sum_3                             2.20108              6.59598
call_module    linear_fc                         0.791311             2.37132
call_function  add_2                             0.470877             1.41108
call_function  sub                               0.423908             1.27033
call_function  pow_1                             0.338316             1.01383
call_function  add                               0.255108             0.764482
call_function  sum_1                             0.104427          

In [67]:
fm_modify_model_traced.print_readable()

class ReFactorizationMachineModel(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        getitem = x[[0, slice(None, 50, None)]]
        getitem_1 = x[[slice(None, None, None), slice(50, None, None)]]
        linear_prefix_offsets = self.linear.prefix_offsets
        add = getitem + linear_prefix_offsets;  getitem = linear_prefix_offsets = None
        linear_rest_offsets = self.linear.rest_offsets
        add_1 = getitem_1 + linear_rest_offsets;  getitem_1 = linear_rest_offsets = None
        repeat = add.repeat(4096, 1);  add = None
        concat = torch.concat([repeat, add_1], dim = 1);  repeat = add_1 = None
        getitem_2 = concat[[0, slice(None, 50, None)]]
        getitem_3 = concat[[slice(None, None, None), slice(50, None, None)]];  concat = None
        linear_fc = self.linear.fc(getitem_2);  getitem_2 = None
        linear_bias = self.linear.bias
        add_2 = linear_fc + linear_bias;  linear_fc = linear_bias = None
       

'class ReFactorizationMachineModel(torch.nn.Module):\n    def forward(self, x):\n        # No stacktrace found for following nodes\n        getitem = x[[0, slice(None, 50, None)]]\n        getitem_1 = x[[slice(None, None, None), slice(50, None, None)]]\n        linear_prefix_offsets = self.linear.prefix_offsets\n        add = getitem + linear_prefix_offsets;  getitem = linear_prefix_offsets = None\n        linear_rest_offsets = self.linear.rest_offsets\n        add_1 = getitem_1 + linear_rest_offsets;  getitem_1 = linear_rest_offsets = None\n        repeat = add.repeat(4096, 1);  add = None\n        concat = torch.concat([repeat, add_1], dim = 1);  repeat = add_1 = None\n        getitem_2 = concat[[0, slice(None, 50, None)]]\n        getitem_3 = concat[[slice(None, None, None), slice(50, None, None)]];  concat = None\n        linear_fc = self.linear.fc(getitem_2);  getitem_2 = None\n        linear_bias = self.linear.bias\n        add_2 = linear_fc + linear_bias;  linear_fc = linear_bia

In [19]:
interp = utils.ProfilingInterpreter(fm_modify_model_traced)
input_data = torch.zeros((4096,100),dtype=torch.long)
interp.run(input_data)
print(interp.summary(True))

total true time 35.35819053649902 ms
total time: 63.648223876953125 ms
Op type        Op                          Average runtime (ms)    Pct total runtime
-------------  ------------------------  ----------------------  -------------------
call_function  concat_3                               6.55317             10.2959
call_function  concat_4                               5.97405              9.38605
call_method    repeat_3                               5.29456              8.31847
call_module    embedding_embedding_1                  4.0741               6.40096
call_function  pow_3                                  2.9192               4.58645
call_method    repeat_4                               2.7144               4.26468
call_function  sum_3                                  1.13702              1.78641
call_function  sum_5                                  1.13702              1.78641
call_function  concat                                 0.916958             1.44067
call_function

In [20]:
class ReFactorizationMachine_pow_sum(torch.nn.Module):
    # 对 sum(x ** 2)进行改写后的模块
    def __init__(self,prefix,batch, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum
        self.batch = batch
        self.prefix_slice = [0,slice(None,prefix)]
        self.rest_slice = [slice(None,None),slice(prefix,None)]
    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        prefix_embed, rest_embed = x[self.prefix_slice],x[self.rest_slice]
        # square_of_sum = torch.sum(x, dim=1) ** 2
        square_of_sum = (torch.sum(rest_embed,dim = 1) + torch.sum(prefix_embed,dim = 0)) ** 2
        prefix_embed, rest_embed = prefix_embed ** 2, rest_embed ** 2
        # pow_embed = torch.concat([prefix_embed.repeat(self.batch,1,1),rest_embed],dim = 1)
        # pow_prefix_embed, pow_rest_embed = pow_embed[self.prefix_slice],pow_embed[self.rest_slice]
        sum_of_square = torch.sum(rest_embed, dim=1) + torch.sum(prefix_embed, dim = 0)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix

In [37]:
class ReFactorizationMachineModel_pow_sum(torch.nn.Module):
  # 对 sum(x ** 2)进行改写后的FM
  def __init__(self, field_dims, embed_dim,prefix,batch):
    super().__init__()
    self.embedding = ReFeaturesEmbedding(field_dims, embed_dim,prefix,batch)
    self.linear = ReFeaturesLinear(field_dims,prefix,batch)
    self.fm = ReFactorizationMachine_pow_sum(prefix,batch,reduce_sum=True)

  def forward(self, x):
      """
      :param x: Long tensor of size ``(batch_size, num_fields)``
      """
      x = self.linear(x) + self.fm(self.embedding(x))
      return torch.sigmoid(x.squeeze(1))

In [38]:
fm_model_pow_sum = ReFactorizationMachineModel_pow_sum([100 for _ in range(100)],64,50,4096)

In [39]:
fm_modify_model_pow_sum_traced = symbolic_trace(fm_model_pow_sum)

In [45]:
interp = utils.ProfilingInterpreter(fm_modify_model_pow_sum_traced)
input_data = torch.zeros((4096,100),dtype=torch.long)
interp.run(input_data)
print(interp.summary(True))

total true time 23.701190948486328 ms
total time: 37.69659996032715 ms
Op type        Op                          Average runtime (ms)    Pct total runtime
-------------  ------------------------  ----------------------  -------------------
call_function  concat_3                               6.93965             18.4092
call_method    repeat_3                               4.57954             12.1484
call_module    embedding_embedding_1                  4.15516             11.0226
call_function  pow_3                                  2.63143              6.98054
call_function  sum_3                                  0.971794             2.57794
call_function  sum_5                                  0.807047             2.1409
call_module    linear_fc_1                            0.622034             1.65011
call_function  concat                                 0.262976             0.697611
call_function  concat_2                               0.213146             0.565426
call_function 

In [26]:
from layer import FeaturesEmbedding

In [None]:
class ReFactorizationMachineModel_pow_sum_embed(torch.nn.Module):
  # 对 sum(x ** 2)和Embedding没有改写进行改写后的FM
  def __init__(self, field_dims, embed_dim,prefix,batch):
    super().__init__()
    self.embedding = FeaturesEmbedding(field_dims, embed_dim)
    self.linear = ReFeaturesLinear(field_dims,prefix,batch)
    self.fm = ReFactorizationMachine_pow_sum(prefix,batch,reduce_sum=True)

  def forward(self, x):
      """
      :param x: Long tensor of size ``(batch_size, num_fields)``
      """
      x = self.linear(x) + self.fm(self.embedding(x))
      return torch.sigmoid(x.squeeze(1))

In [47]:
fm_model_pow_sum_embed = ReFactorizationMachineModel_pow_sum_embed([100 for _ in range(100)],64,50,4096)

In [48]:
fm_modify_model_pow_sum_embed_traced = symbolic_trace(fm_model_pow_sum_embed)

In [49]:
interp = utils.ProfilingInterpreter(fm_modify_model_pow_sum_embed_traced)
input_data = torch.zeros((4096,100),dtype=torch.long)
interp.run(input_data)
print(interp.summary(True))

total true time 19.382238388061523 ms
total time: 25.19083023071289 ms
Op type        Op                       Average runtime (ms)    Pct total runtime
-------------  ---------------------  ----------------------  -------------------
call_module    embedding_embedding                 9.34649             37.1027
call_function  pow_3                               3.6881              14.6406
call_function  sum_3                               1.40762              5.58784
call_function  sum_5                               1.07837              4.28079
call_module    linear_fc_1                         0.644207             2.55731
call_function  add_4                               0.507593             2.01499
call_function  concat                              0.336409             1.33544
call_function  add_6                               0.252247             1.00134
call_function  add_1                               0.203609             0.808268
call_method    repeat                         

In [57]:
class ReFeaturesEmbedding_embed(torch.nn.Module):

    def __init__(self, field_dims, embed_dim,prefix,batch):
        super().__init__()
        self.prefix = prefix
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        self.offsets = torch.as_tensor(np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64))
        self.prefix_offsets = self.offsets[:self.prefix]
        self.rest_offsets = self.offsets[self.prefix:]
        self.batch = batch
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)
        self.prefix_slice = [0,slice(None,prefix)]
        self.rest_slice = [slice(None,None),slice(prefix,None)]
    def forward(self, x):
        """
        :param x: tuple of tensor ``(prefix_index, rest_index)``
        prefix_index ``(prefix_field)``
        rest_index  ``(batch_szie,rest_field)``
        """
        prefix_index, rest_index = x[self.prefix_slice],x[self.rest_slice]
        prefix_index = prefix_index + self.prefix_offsets
        rest_index = rest_index + self.rest_offsets
        prefix_index = prefix_index.repeat(self.batch,1)
        index = torch.concat([prefix_index,rest_index],dim = 1)
        prefix_index, rest_index = index[self.prefix_slice], index[self.rest_slice]
        return self.embedding(prefix_index), self.embedding(rest_index)

In [58]:
class ReFactorizationMachine_pow_sum_pow(torch.nn.Module):
    # 对 sum(x ** 2)进行改写后的模块
    def __init__(self,prefix,batch, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum
        self.batch = batch
        self.prefix_slice = [0,slice(None,prefix)]
        self.rest_slice = [slice(None,None),slice(prefix,None)]
    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        prefix_embed, rest_embed = x
        # square_of_sum = torch.sum(x, dim=1) ** 2
        square_of_sum = (torch.sum(rest_embed,dim = 1) + torch.sum(prefix_embed,dim = 0)) ** 2
        prefix_embed, rest_embed = prefix_embed ** 2, rest_embed ** 2
        # pow_embed = torch.concat([prefix_embed.repeat(self.batch,1,1),rest_embed],dim = 1)
        # pow_prefix_embed, pow_rest_embed = pow_embed[self.prefix_slice],pow_embed[self.rest_slice]
        sum_of_square = torch.sum(rest_embed, dim=1) + torch.sum(prefix_embed, dim = 0)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix

In [59]:
class ReFactorizationMachineModel_pow_sum_pow_embed(torch.nn.Module):
  # 对 sum(x ** 2)和Embedding改写进行改写后的FM
  def __init__(self, field_dims, embed_dim,prefix,batch):
    super().__init__()
    self.embedding = ReFeaturesEmbedding_embed(field_dims, embed_dim,prefix,batch)
    self.linear = ReFeaturesLinear(field_dims,prefix,batch)
    self.fm = ReFactorizationMachine_pow_sum_pow(prefix,batch,reduce_sum=True)

  def forward(self, x):
      """
      :param x: Long tensor of size ``(batch_size, num_fields)``
      """
      x = self.linear(x) + self.fm(self.embedding(x))
      return torch.sigmoid(x.squeeze(1))

In [60]:
fm_model_pow_sum_pow_embed = ReFactorizationMachineModel_pow_sum_pow_embed([100 for _ in range(100)],64,50,4096)

In [66]:
fm_model_pow_sum_pow_embed_trace = symbolic_trace(fm_model_pow_sum_pow_embed)
fm_model_pow_sum_pow_embed_trace.print_readable()

class ReFactorizationMachineModel_pow_sum_pow_embed(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        getitem = x[[0, slice(None, 50, None)]]
        getitem_1 = x[[slice(None, None, None), slice(50, None, None)]]
        linear_prefix_offsets = self.linear.prefix_offsets
        add = getitem + linear_prefix_offsets;  getitem = linear_prefix_offsets = None
        linear_rest_offsets = self.linear.rest_offsets
        add_1 = getitem_1 + linear_rest_offsets;  getitem_1 = linear_rest_offsets = None
        repeat = add.repeat(4096, 1);  add = None
        concat = torch.concat([repeat, add_1], dim = 1);  repeat = add_1 = None
        getitem_2 = concat[[0, slice(None, 50, None)]]
        getitem_3 = concat[[slice(None, None, None), slice(50, None, None)]];  concat = None
        linear_fc = self.linear.fc(getitem_2);  getitem_2 = None
        linear_bias = self.linear.bias
        add_2 = linear_fc + linear_bias;  linear_fc = linear_b

'class ReFactorizationMachineModel_pow_sum_pow_embed(torch.nn.Module):\n    def forward(self, x):\n        # No stacktrace found for following nodes\n        getitem = x[[0, slice(None, 50, None)]]\n        getitem_1 = x[[slice(None, None, None), slice(50, None, None)]]\n        linear_prefix_offsets = self.linear.prefix_offsets\n        add = getitem + linear_prefix_offsets;  getitem = linear_prefix_offsets = None\n        linear_rest_offsets = self.linear.rest_offsets\n        add_1 = getitem_1 + linear_rest_offsets;  getitem_1 = linear_rest_offsets = None\n        repeat = add.repeat(4096, 1);  add = None\n        concat = torch.concat([repeat, add_1], dim = 1);  repeat = add_1 = None\n        getitem_2 = concat[[0, slice(None, 50, None)]]\n        getitem_3 = concat[[slice(None, None, None), slice(50, None, None)]];  concat = None\n        linear_fc = self.linear.fc(getitem_2);  getitem_2 = None\n        linear_bias = self.linear.bias\n        add_2 = linear_fc + linear_bias;  line

In [65]:
interp = utils.ProfilingInterpreter(fm_model_pow_sum_pow_embed)
input_data = torch.zeros((4096,100),dtype=torch.long)
interp.run(input_data)
print(interp.summary(True))

total true time 15.653610229492188 ms
total time: 20.832061767578125 ms
Op type        Op                          Average runtime (ms)    Pct total runtime
-------------  ------------------------  ----------------------  -------------------
call_module    embedding_embedding_1                  4.71473             22.6321
call_function  pow_3                                  3.35765             16.1177
call_function  sum_3                                  1.31798              6.32668
call_function  sum_5                                  1.22976              5.90322
call_function  concat                                 0.803232             3.85575
call_function  concat_2                               0.786066             3.77335
call_module    linear_fc_1                            0.569582             2.73416
call_function  add_5                                  0.457525             2.19626
call_method    repeat_2                               0.393867             1.89068
call_function