In [9]:
import sys
sys.path.append('../../')
sys.path.append('../')
import time
import dfm

In [10]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.fx import subgraph_rewriter, symbolic_trace
import utils
from torch.fx import Proxy, Graph, GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

In [11]:
from torch.profiler import profile, record_function, ProfilerActivity
import time
import torch._dynamo as dynamo


In [12]:
def gen_pattern_replace_and_matcher_for_SingleMLP(traced,
                                                  redency_part_slice, unredency_part_slice,
                                                  key_node_name, match_func=None
                                                 ):
  """Build an FX (pattern, replacement, filter) triple that splits an MLP Linear in two.

  The pattern is: embedding -> FM second-order term, plus embedding flattened
  into a Linear. The replacement keeps the same embedding module and FM term.
  It splits the Linear's input columns into two parts:
    * a "redundant" part, fed by ``emb[redency_part_slice]``;
    * the remaining part, fed by ``emb[unredency_part_slice]``.
  The two partial products are summed. When the redundant slice selects fewer
  batch rows (e.g. a single row), its product broadcasts over the batch.
  Presumably this lets fields shared across the batch be computed once; confirm
  against the callers.

  Args:
    traced: ``torch.fx.GraphModule`` whose graph is searched. It is not modified here.
    redency_part_slice: index applied to the embedding output to select the
      redundant fields. Element ``[1]`` must be a ``slice`` whose ``stop`` is the
      number of prefix fields. The prefix is assumed to start at field 0, because
      the first ``embed_dim * stop`` weight columns are assigned to it.
    unredency_part_slice: index applied to the embedding output to select the
      remaining fields.
    key_node_name: name of a node in ``traced`` whose module's weight shape
      parametrizes the pattern (the pattern's flatten width and Linear sizes).
    match_func: optional filter ``f(match, original_graph, pattern_graph) -> bool``.
      Defaults to accepting every match.

  Returns:
    ``(pattern_module, replacement_module, match_filter)``. ``match_filter`` is the
    same filter used here to pick the match the replacement weights come from, so
    it can be passed straight to ``subgraph_rewriter.replace_pattern_with_filters``.

  Raises:
    ValueError: if ``redency_part_slice[1].stop`` is None, or no subgraph of
      ``traced`` matches the pattern and passes the filter.
  """

  def _match(match, ori, pat):
    # Default filter: accept every candidate match.
    return True

  # Use the same filter both for choosing the weight-source match here and for
  # the rewriter (it is returned below).
  match_filter = _match if match_func is None else match_func

  env = utils.get_env(traced)
  target_node = env[key_node_name]
  target_node_mod = utils.get_target_mod(traced, target_node.target)
  shape_info = target_node_mod.weight.data.shape

  class PatternClass(torch.nn.Module):
    # Structural pattern only; its parameter values are never used. The literal
    # ``embed_output_dim`` appears as an argument of ``view`` in the traced pattern.
    def __init__(self):
      super().__init__()
      self.embed = torch.nn.Embedding(1, 1)
      self.embed_output_dim = shape_info[1]
      self.mlp = nn.Linear(shape_info[0], shape_info[1])

    def forward(self, x):
      x = self.embed(x)
      square_of_sum = torch.sum(x, dim=1) ** 2
      sum_of_square = torch.sum(x ** 2, dim=1)
      ix = square_of_sum - sum_of_square
      ix = torch.sum(ix, dim=1, keepdim=True)
      return self.mlp(x.view(-1, self.embed_output_dim)), 0.5 * ix

  pattern = PatternClass()
  pattern_trace = symbolic_trace(pattern)
  pattern_graph = pattern_trace.graph
  original_graph = traced.graph
  matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                            remove_overlapping_matches=True)
  _matches = [m for m in matcher.match(original_graph)
              if match_filter(m, original_graph, pattern_graph)]
  if not _matches:
    raise ValueError(
        f"no subgraph matching the SingleMLP pattern (key node {key_node_name!r}) "
        "passed the match filter")
  # The filter is expected to leave exactly one match; weights come from the first.
  _matched = _matches[0]
  pattern_env = utils.get_env(pattern_trace)
  node_map = _matched.nodes_map

  embed_node = node_map[pattern_env['embed']]
  embed_node_module = utils.get_target_mod(traced, embed_node.target)

  linear_node = node_map[pattern_env['mlp']]
  linear_node_module = utils.get_target_mod(traced, linear_node.target)
  linear_node_weight = linear_node_module.weight.data
  linear_node_bias = linear_node_module.bias  # None for a bias-free Linear

  prefix_fields = redency_part_slice[1].stop
  if prefix_fields is None:
    raise ValueError("redency_part_slice[1] must be a slice with an explicit stop")

  class ReplacementClass(torch.nn.Module):
    def __init__(self):
      super().__init__()
      # Reuse the matched embedding module itself (shared, not copied).
      self.embed = embed_node_module
      self.embed_dim = self.embed.weight.data.shape[1]
      # Flattened width of the prefix fields. The first `redency_weight_len`
      # input columns of the original Linear act on them.
      self.redency_weight_len = self.embed_dim * prefix_fields
      redency_weight = linear_node_weight[:, :self.redency_weight_len]
      unredency_weight = linear_node_weight[:, self.redency_weight_len:]
      self.unredency_weight_len = unredency_weight.shape[1]
      # Keep the split layers on the same device/dtype as the original weight.
      factory = dict(device=linear_node_weight.device, dtype=linear_node_weight.dtype)

      # The original bias (if any) is added exactly once, on this branch.
      self.redency_linear = nn.Linear(redency_weight.shape[1], redency_weight.shape[0],
                                      bias=linear_node_bias is not None, **factory)
      self.redency_linear.weight.data.copy_(redency_weight)
      if linear_node_bias is not None:
        self.redency_linear.bias.data.copy_(linear_node_bias.data)

      self.unredency_linear = nn.Linear(unredency_weight.shape[1], unredency_weight.shape[0],
                                        bias=False, **factory)
      self.unredency_linear.weight.data.copy_(unredency_weight)

    def forward(self, x):
      emb = self.embed(x)
      # Same FM second-order term as the pattern.
      square_of_sum = torch.sum(emb, dim=1) ** 2
      sum_of_square = torch.sum(emb ** 2, dim=1)
      ix = square_of_sum - sum_of_square
      ix = torch.sum(ix, dim=1, keepdim=True)
      redency_part = emb[redency_part_slice]
      unredency_part = emb[unredency_part_slice]
      # W @ [a; b] + bias == (W_a @ a + bias) + W_b @ b. The first term broadcasts
      # if the redundant slice selects fewer rows than the remaining slice.
      mlp_out = (self.redency_linear(redency_part.view(-1, self.redency_weight_len))
                 + self.unredency_linear(unredency_part.view(-1, self.unredency_weight_len)))
      return mlp_out, 0.5 * ix

  return pattern, ReplacementClass(), match_filter

In [13]:
def workload_dfm(num_field, prefix, dim=64, l=None, field_dim=100, dropout=0.1):
  """Build the original and the FX-rewritten DeepFactorizationMachineModel workloads.

  Args:
    num_field: number of categorical fields. ``[field_dim] * num_field`` is the
      first constructor argument of ``dfm.DeepFactorizationMachineModel``
      (presumably per-field vocabulary sizes; confirm in ``dfm``).
    prefix: number of leading fields routed to the "redundant" Linear branch by
      ``gen_pattern_replace_and_matcher_for_SingleMLP``. Must satisfy
      ``0 < prefix < num_field``; otherwise one branch would have zero width.
    dim: embedding dimension passed to the model constructor.
    l: layer-size list passed to the model constructor (presumably MLP hidden
      sizes). Defaults to ``[1024, 512, 256]``.
    field_dim: value repeated ``num_field`` times as the first constructor
      argument (previously hard-coded to 100).
    dropout: last positional constructor argument (previously hard-coded to 0.1;
      presumably the dropout rate).

  Returns:
    ``(ori_traced, modify_traced)``: the traced original model, and a traced copy
    with identical weights whose first MLP Linear has been split by
    ``subgraph_rewriter.replace_pattern_with_filters``.

  Raises:
    ValueError: if ``prefix`` is not in ``(0, num_field)``.
  """
  import copy

  if not 0 < prefix < num_field:
    raise ValueError(f"prefix must satisfy 0 < prefix < num_field, got prefix={prefix}, num_field={num_field}")
  if l is None:
    l = [1024, 512, 256]  # fresh list per call; avoids a shared mutable default

  print(f"now gen workload of DFM with config: dim: {dim}, num_field: {num_field}, prefix: {prefix}")
  DFM_model_ori = dfm.DeepFactorizationMachineModel([field_dim] * num_field, dim, l, dropout)
  ori_traced = symbolic_trace(DFM_model_ori)

  # Deep-copy instead of re-instantiating so both workloads share the same weights
  # and their outputs can be compared directly.
  DFM_model_modify = copy.deepcopy(DFM_model_ori)
  modify_traced = symbolic_trace(DFM_model_modify)
  pattern, replace, match = gen_pattern_replace_and_matcher_for_SingleMLP(
      modify_traced,
      (0, slice(None, prefix, None)),                   # redundant part: batch row 0, first `prefix` fields
      (slice(None, None, None), slice(prefix, None, None)),  # remaining fields for every row
      "mlp_mlp_0")
  # Rewrites modify_traced in place.
  subgraph_rewriter.replace_pattern_with_filters(modify_traced, pattern, replace, [match])
  return ori_traced, modify_traced

In [14]:
# Sweep configuration consumed by the export loop below.
dims = [32]                   # embedding dimensions
batches = [1024, 2048, 4096]  # batch sizes of the dummy export input
# (num_field, prefix) pairs: total number of fields, and how many leading
# fields form the redundant prefix.
num_field_and_prefixs = [
    (34 * 5, 29 * 5),
    (22 * 5, 10 * 5),
]

In [15]:
def genWorkload(num_field=34 * 5, prefix=29 * 5, batch=4096, dim=64,
                model_dir='/home/yssun/pytorch-fm/torchfm/model/test_fx/exp/model_repo/dfm_linear',
                export_ori=False):
  """Export the rewritten (and optionally the original) DFM workload to ONNX.

  Files are named ``DFM_{batch}_{num_field}_{prefix}_{dim}_{ori|modify}.onnx``
  inside ``model_dir``, which is created if missing. The dummy input has the
  fixed shape ``(batch, num_field)`` and no ``dynamic_axes`` are declared, so
  ``batch`` is fixed in the exported graph.

  Args:
    num_field: number of categorical fields.
    prefix: number of leading fields treated as the redundant part.
    batch: batch size of the dummy export input.
    dim: embedding dimension.
    model_dir: output directory (defaults to the original hard-coded location).
    export_ori: also export the un-rewritten model (previously commented out).

  Returns:
    List of the ONNX file paths written.
  """
  import os

  ori, modify = workload_dfm(num_field=num_field, prefix=prefix, dim=dim)
  os.makedirs(model_dir, exist_ok=True)

  def _export(model, suffix):
    # Export one traced model and return the file path it was written to.
    path = os.path.join(model_dir, f'DFM_{batch}_{num_field}_{prefix}_{dim}_{suffix}.onnx')
    torch.onnx.export(model,                    # model being exported
                      torch.randint(low=0, high=20, size=(batch, num_field), dtype=torch.long),  # dummy input
                      path,                     # output file name
                      export_params=True,       # store trained parameters in the file
                      opset_version=10,         # ONNX opset version
                      do_constant_folding=True, # apply constant folding during export
                      input_names=['input'],
                      output_names=['output'],
                      )
    return path

  written = []
  if export_ori:
    written.append(_export(ori, 'ori'))
  written.append(_export(modify, 'modify'))
  return written

In [16]:
# Export one workload per (dim, batch, (num_field, prefix)) combination, visiting
# combinations in the same nested order as the sweep lists (dims outermost).
for dim, batch, (num_field, prefix) in ((d, b, nfp)
                                        for d in dims
                                        for b in batches
                                        for nfp in num_field_and_prefixs):
  genWorkload(num_field=num_field, prefix=prefix, batch=batch, dim=dim)

now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145




now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50




now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145




now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50




now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145




now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50


