In [16]:
import sys
sys.path.append('../../')
sys.path.append('../')
import time
import dfm

In [17]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.fx import subgraph_rewriter, symbolic_trace
import utils
from torch.fx import Proxy, Graph, GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

In [18]:
from torch.profiler import profile, record_function, ProfilerActivity
import time
import torch._dynamo as dynamo


In [19]:
def gen_pattern_replace_and_matcher_for_DFM(traced,
                                                  redency_part_slice,unredency_part_slice,
                                                  key_node_name,match_func = None
                                                ):
  from torch.fx.passes.utils.matcher_utils import SubgraphMatcher


  def _match(match,ori,pat):
    return True 
  env  = utils.get_env(traced)
  target_node = env[key_node_name]
  target_node_mod = utils.get_target_mod(traced,target_node.target)
  shape_info = target_node_mod.weight.data.shape
  class PatternClass(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.embed = torch.nn.Embedding(1, 1)
          self.embed_output_dim = shape_info[1]
          self.mlp = nn.Linear(shape_info[0],shape_info[1])


      def forward(self,x):
        x = self.embed(x)
        # x = x.view(-1,self.embed_output_dim)
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        ix = square_of_sum - sum_of_square    
        ix = torch.sum(ix, dim=1, keepdim=True)  
        return self.mlp(x.view(-1,self.embed_output_dim)), 0.5 * ix
  pattern = PatternClass()  
  pattern_trace = symbolic_trace(pattern)
  pattern_graph = pattern_trace.graph
  original_graph = traced.graph
  matcher =  SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True)
  _matches = matcher.match(original_graph)
  match_filters = [_match if match_func is None else match_func]
  _matches = [
      m for m in _matches
      if all(match_filter(m, original_graph, pattern_graph)
              for match_filter in match_filters)
  ]  
  # 因为在过滤器中做了限制应该只有一个符合要求的
  _matched = _matches[0]
  pattern_env = utils.get_env(pattern_trace)
  node_map = _matched.nodes_map
  
  embed_node = node_map[pattern_env['embed']]
  embed_node_module = utils.get_target_mod(traced,embed_node.target)
  
  linear_node = node_map[pattern_env['mlp']]
  linear_node_module = utils.get_target_mod(traced,linear_node.target)
  linear_node_weight = linear_node_module.weight.data
  linear_node_bias = linear_node_module.bias.data
  
  class ReplacementClass(torch.nn.Module):
    def __init__(self):
      super().__init__()
      self.embed = embed_node_module
      self.embed_dim = self.embed.weight.data.shape[1]
      self.redency_weight_len = self.embed_dim * redency_part_slice[1].stop
      redency_weight = linear_node_weight[:,:self.redency_weight_len]
      unredency_weight = linear_node_weight[:,self.redency_weight_len:]
      self.unredency_weight_len = unredency_weight.shape[1]
      self.redency_linear = nn.Linear(redency_weight.shape[1],redency_weight.shape[0])
      self.redency_linear.weight.data.copy_(redency_weight)
      self.redency_linear.bias.data.copy_(linear_node_bias)

      self.unredency_linear = nn.Linear(unredency_weight.shape[1],unredency_weight.shape[0],bias=False)
      self.unredency_linear.weight.data.copy_(unredency_weight)

      

    def forward(self,x):
      redency_part = x[redency_part_slice] 
      unredency_part = x[unredency_part_slice] 
      redency_embed = self.embed(redency_part)
      unredency_embed = self.embed(unredency_part)
      redency_embed_sum = torch.sum(redency_embed,dim=0)
      unredency_embed_sum = torch.sum(unredency_embed,dim=1)
      square_of_sum = (redency_embed_sum + unredency_embed_sum) ** 2
      redency_embed_square_sum = torch.sum(redency_embed ** 2,dim=0)
      unredency_embed_square_sum = torch.sum(unredency_embed ** 2,dim=1)
      sum_of_square = redency_embed_square_sum + unredency_embed_square_sum
      ix = square_of_sum - sum_of_square
      ix = torch.sum(ix,dim = 1,keepdim=True)
      return self.redency_linear(redency_embed.view(-1,self.redency_weight_len)) + self.unredency_linear(unredency_embed.view(-1,self.unredency_weight_len)), 0.5 * ix
      # return unredency_sum
    
  
  return pattern,ReplacementClass(),_match

In [20]:
def workload_dfm(num_field, prefix,dim = 64,l = [1024,512,256]):
  print(f"now gen workload of DFM with config: dim: {dim}, num_field: {num_field}, prefix: {prefix}")
  DFM_model_ori = dfm.DeepFactorizationMachineModel([100 for i in range(num_field)],dim,l,0.1)
  ori_traced = symbolic_trace(DFM_model_ori)
  
  DFM_model_modify = dfm.DeepFactorizationMachineModel([100 for i in range(num_field)],dim,l,0.1)
  modify_traced = symbolic_trace(DFM_model_modify)
  pattern,replace,match = gen_pattern_replace_and_matcher_for_DFM(modify_traced,
                                                                      (0,slice(None,prefix,None)),(slice(None,None,None),slice(prefix,None,None)),
                                                                      "mlp_mlp_0")
  matches = subgraph_rewriter.replace_pattern_with_filters(modify_traced, pattern, replace,[match])
  return ori_traced,modify_traced

In [21]:
def calculate_mean_and_variance_manual(data):
    n = len(data)
    mean = sum(data) / n
    variance = sum((x - mean) ** 2 for x in data) / n
    return mean, variance

In [22]:
def gen_and_test(num_field = 22,prefix = 10, batch = 4096, dim = 32, workload_func = workload_dfm,l = [1024,512,256]):
  def run(model):
    t = torch.randint(low=0, high=88, size=(batch,num_field), dtype=torch.long).cuda()
    traced_model = torch.jit.trace(model.cuda(), t)
    compiled_model = torch.compile(traced_model, backend="inductor")
    compiled_model.eval()
    total_time = []
    for i in range(100):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)  
        start_event.record()

        with torch.no_grad():
          soutput = compiled_model(t)
        end_event.record()
        torch.cuda.synchronize()
        # 计算并打印函数执行所需的时间
        elapsed_time = start_event.elapsed_time(end_event)
        total_time.append(elapsed_time)
    print(calculate_mean_and_variance_manual(total_time[2:]))
  print(f"now gen workload of DFM with config: dim: {dim}, num_field: {num_field}, prefix: {prefix}, batch :{batch}")
  ori, modify = workload_func(num_field,prefix,dim,l)
  run(ori)
  run(modify)
  torch.cuda.empty_cache()

In [23]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 1024, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50, batch :1024
now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 375 / 1024 (36.6%)
Greatest absolute difference: 0.0031673312187194824 at index (114,) (up to 1e-05 allowed)
Greatest relative difference: 0.015811055797185038 at index (247,) (up to 1e-05 allowed)
  _check_trace(


(0.7701939575526179, 0.00013752196257895465)


Tensor-likes are not close!

Mismatched elements: 376 / 1024 (36.7%)
Greatest absolute difference: 0.0031235218048095703 at index (298,) (up to 1e-05 allowed)
Greatest relative difference: 0.013729514335591013 at index (252,) (up to 1e-05 allowed)
  _check_trace(


(0.6043637528711435, 9.298631220342191e-05)


In [24]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 2048, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50, batch :2048
now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 752 / 2048 (36.7%)
Greatest absolute difference: 0.0034317970275878906 at index (1256,) (up to 1e-05 allowed)
Greatest relative difference: 0.013067086886115416 at index (24,) (up to 1e-05 allowed)
  _check_trace(


(1.453819755388766, 0.0006762274911022042)


Tensor-likes are not close!

Mismatched elements: 798 / 2048 (39.0%)
Greatest absolute difference: 0.003990471363067627 at index (673,) (up to 1e-05 allowed)
Greatest relative difference: 0.015548503604045472 at index (55,) (up to 1e-05 allowed)
  _check_trace(


(1.0519719184661398, 0.00034827077409320677)


In [25]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 4096, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50, batch :4096
now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 1552 / 4096 (37.9%)
Greatest absolute difference: 0.0033912062644958496 at index (3823,) (up to 1e-05 allowed)
Greatest relative difference: 0.01790787183319697 at index (3136,) (up to 1e-05 allowed)
  _check_trace(


(2.6780365632504832, 0.0038125604902510177)


Tensor-likes are not close!

Mismatched elements: 1578 / 4096 (38.5%)
Greatest absolute difference: 0.004133433103561401 at index (2742,) (up to 1e-05 allowed)
Greatest relative difference: 0.01691574097619152 at index (1778,) (up to 1e-05 allowed)
  _check_trace(


(1.843475587513982, 0.001162220164734185)


In [26]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 8192, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50, batch :8192
now gen workload of DFM with config: dim: 32, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 3196 / 8192 (39.0%)
Greatest absolute difference: 0.004728376865386963 at index (1303,) (up to 1e-05 allowed)
Greatest relative difference: 0.0190084456736153 at index (6667,) (up to 1e-05 allowed)
  _check_trace(


(5.187429870877947, 0.009569111782654165)


Tensor-likes are not close!

Mismatched elements: 3075 / 8192 (37.5%)
Greatest absolute difference: 0.004712343215942383 at index (4939,) (up to 1e-05 allowed)
Greatest relative difference: 0.020412850317887588 at index (88,) (up to 1e-05 allowed)
  _check_trace(


(3.428110991205488, 0.0028803468356389367)


In [27]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 1024, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50, batch :1024
now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 389 / 1024 (38.0%)
Greatest absolute difference: 0.0037508010864257812 at index (116,) (up to 1e-05 allowed)
Greatest relative difference: 0.012746036225854532 at index (53,) (up to 1e-05 allowed)
  _check_trace(


(1.2363608199722913, 0.001333068081767655)


Tensor-likes are not close!

Mismatched elements: 378 / 1024 (36.9%)
Greatest absolute difference: 0.004143655300140381 at index (1000,) (up to 1e-05 allowed)
Greatest relative difference: 0.01355061765429789 at index (834,) (up to 1e-05 allowed)
  _check_trace(


(0.8817067784922463, 0.00033840813114588533)


In [28]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 2048, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50, batch :2048
now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 767 / 2048 (37.5%)
Greatest absolute difference: 0.003098726272583008 at index (1393,) (up to 1e-05 allowed)
Greatest relative difference: 0.014466151208444194 at index (1653,) (up to 1e-05 allowed)
  _check_trace(


(2.4609116729424922, 0.002649942458709102)


Tensor-likes are not close!

Mismatched elements: 768 / 2048 (37.5%)
Greatest absolute difference: 0.003360748291015625 at index (1929,) (up to 1e-05 allowed)
Greatest relative difference: 0.020016769820366518 at index (1640,) (up to 1e-05 allowed)
  _check_trace(


(1.6354530611816718, 0.0006587272460439262)


In [29]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 4096, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50, batch :4096
now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 1606 / 4096 (39.2%)
Greatest absolute difference: 0.003762573003768921 at index (4026,) (up to 1e-05 allowed)
Greatest relative difference: 0.017128239289190023 at index (2803,) (up to 1e-05 allowed)
  _check_trace(


(4.696010117628137, 0.007481103968815433)


Tensor-likes are not close!

Mismatched elements: 1572 / 4096 (38.4%)
Greatest absolute difference: 0.003855586051940918 at index (3529,) (up to 1e-05 allowed)
Greatest relative difference: 0.020650739246520014 at index (205,) (up to 1e-05 allowed)
  _check_trace(


(2.9025074219217104, 0.0026891782272568197)


In [30]:
gen_and_test(num_field = 22 * 5,prefix = 10 * 5, batch = 8192, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50, batch :8192
now gen workload of DFM with config: dim: 64, num_field: 110, prefix: 50


Tensor-likes are not close!

Mismatched elements: 3103 / 8192 (37.9%)
Greatest absolute difference: 0.004856526851654053 at index (361,) (up to 1e-05 allowed)
Greatest relative difference: 0.018484574826173696 at index (6604,) (up to 1e-05 allowed)
  _check_trace(


(9.362633130988296, 0.01449921515971199)


Tensor-likes are not close!

Mismatched elements: 3142 / 8192 (38.4%)
Greatest absolute difference: 0.004033505916595459 at index (7974,) (up to 1e-05 allowed)
Greatest relative difference: 0.016660282813649233 at index (668,) (up to 1e-05 allowed)
  _check_trace(


(5.5894938780336965, 0.0077903637390282455)


In [31]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 1024, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145, batch :1024
now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 336 / 1024 (32.8%)
Greatest absolute difference: 0.003484964370727539 at index (975,) (up to 1e-05 allowed)
Greatest relative difference: 0.01513034643365099 at index (313,) (up to 1e-05 allowed)
  _check_trace(


(1.0086746161081352, 0.0009598554242039722)


Tensor-likes are not close!

Mismatched elements: 325 / 1024 (31.7%)
Greatest absolute difference: 0.003077775239944458 at index (830,) (up to 1e-05 allowed)
Greatest relative difference: 0.012863294354531575 at index (208,) (up to 1e-05 allowed)
  _check_trace(


(0.49135738702452914, 5.141858758416337e-05)


In [32]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 2048, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145, batch :2048
now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 606 / 2048 (29.6%)
Greatest absolute difference: 0.0027855634689331055 at index (616,) (up to 1e-05 allowed)
Greatest relative difference: 0.014614265990509264 at index (656,) (up to 1e-05 allowed)
  _check_trace(


(1.9563098817455524, 0.0022269685679972536)


Tensor-likes are not close!

Mismatched elements: 613 / 2048 (29.9%)
Greatest absolute difference: 0.0034360289573669434 at index (189,) (up to 1e-05 allowed)
Greatest relative difference: 0.014406042004683019 at index (1431,) (up to 1e-05 allowed)
  _check_trace(


(0.7535464483864454, 6.674785440649764e-05)


In [33]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 4096, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145, batch :4096
now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 1292 / 4096 (31.5%)
Greatest absolute difference: 0.0047768354415893555 at index (1061,) (up to 1e-05 allowed)
Greatest relative difference: 0.01971065956639185 at index (953,) (up to 1e-05 allowed)
  _check_trace(


(3.7771464537601083, 0.003646208813686822)


Tensor-likes are not close!

Mismatched elements: 1252 / 4096 (30.6%)
Greatest absolute difference: 0.0035182833671569824 at index (3517,) (up to 1e-05 allowed)
Greatest relative difference: 0.017989128280492258 at index (2227,) (up to 1e-05 allowed)
  _check_trace(


(1.1911033464937795, 0.0006123158414863787)


In [34]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 8192, dim = 32)

now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145, batch :8192
now gen workload of DFM with config: dim: 32, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 2504 / 8192 (30.6%)
Greatest absolute difference: 0.004373878240585327 at index (6621,) (up to 1e-05 allowed)
Greatest relative difference: 0.01872289456927342 at index (7636,) (up to 1e-05 allowed)
  _check_trace(


(7.5373727642759984, 0.006104925622159843)


Tensor-likes are not close!

Mismatched elements: 2614 / 8192 (31.9%)
Greatest absolute difference: 0.00418362021446228 at index (4361,) (up to 1e-05 allowed)
Greatest relative difference: 0.01660952900252025 at index (7000,) (up to 1e-05 allowed)
  _check_trace(


(2.2104542231073183, 0.0013156600716968058)


In [35]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 1024, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145, batch :1024
now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 320 / 1024 (31.2%)
Greatest absolute difference: 0.0037676990032196045 at index (283,) (up to 1e-05 allowed)
Greatest relative difference: 0.013588792539999978 at index (535,) (up to 1e-05 allowed)
  _check_trace(


(1.7701841629281336, 0.00262452377834797)


Tensor-likes are not close!

Mismatched elements: 310 / 1024 (30.3%)
Greatest absolute difference: 0.003512144088745117 at index (780,) (up to 1e-05 allowed)
Greatest relative difference: 0.016305951526588244 at index (26,) (up to 1e-05 allowed)
  _check_trace(


(0.6086390194844227, 0.00016000300854885065)


In [36]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 2048, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145, batch :2048
now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 652 / 2048 (31.8%)
Greatest absolute difference: 0.0031566321849823 at index (1435,) (up to 1e-05 allowed)
Greatest relative difference: 0.014952242614919719 at index (1728,) (up to 1e-05 allowed)
  _check_trace(


(3.5266109291388066, 0.003940729863629173)


Tensor-likes are not close!

Mismatched elements: 646 / 2048 (31.5%)
Greatest absolute difference: 0.004623115062713623 at index (345,) (up to 1e-05 allowed)
Greatest relative difference: 0.015400739416719933 at index (379,) (up to 1e-05 allowed)
  _check_trace(


(0.984047673794688, 0.00038198126294393603)


In [37]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 4096, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145, batch :4096
now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 1182 / 4096 (28.9%)
Greatest absolute difference: 0.003252476453781128 at index (3010,) (up to 1e-05 allowed)
Greatest relative difference: 0.01765136951497521 at index (3190,) (up to 1e-05 allowed)
  _check_trace(


(6.890155115906073, 0.008961491695780938)


Tensor-likes are not close!

Mismatched elements: 1279 / 4096 (31.2%)
Greatest absolute difference: 0.0034065842628479004 at index (726,) (up to 1e-05 allowed)
Greatest relative difference: 0.01765894135335728 at index (449,) (up to 1e-05 allowed)
  _check_trace(


(1.689119343854943, 0.0010405757937635178)


In [38]:
gen_and_test(num_field = 34 * 5,prefix = 29 * 5, batch = 8192, dim = 64)

now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145, batch :8192
now gen workload of DFM with config: dim: 64, num_field: 170, prefix: 145


Tensor-likes are not close!

Mismatched elements: 2616 / 8192 (31.9%)
Greatest absolute difference: 0.0044345855712890625 at index (8038,) (up to 1e-05 allowed)
Greatest relative difference: 0.020443083238905504 at index (1588,) (up to 1e-05 allowed)
  _check_trace(


(13.50084148134504, 0.02551305685008684)


Tensor-likes are not close!

Mismatched elements: 2502 / 8192 (30.5%)
Greatest absolute difference: 0.0032520592212677 at index (1040,) (up to 1e-05 allowed)
Greatest relative difference: 0.014105015624222577 at index (8082,) (up to 1e-05 allowed)
  _check_trace(


(3.154724576035324, 0.0020048053606321575)
