# Static analysis of GAB block

In [1]:
import sys
import torch
from torch.fx import symbolic_trace

sys.path.append('..')

import model_discovery.model.lab as lab
from model_discovery.model.library import *
from exec_utils import BuildTool
from model_discovery.configs.gam_config import ( 
    GAMConfig,GAMConfig_14M,GAMConfig_31M,GAMConfig_70M,GAMConfig_125M,GAMConfig_350M,GAMConfig_760M,
    GAMConfig_1300M,GAMConfig_2700M,GAMConfig_6700M,GAMConfig_13B,GAMConfig_175B,GAMConfig_1T,GAMConfig_debug
)


def load_gab(model_name: str, scale='14M'):
    """Instantiate the GAB block of a library model at the requested scale.

    The model source is looked up in ``MODEL2CODE``, passed through the
    format checker, and executed in a fresh namespace with ``class GAB``
    textually renamed to ``class GABCustom``. The resulting class is then
    instantiated with the ``d_model`` / ``n_block`` values of the selected
    config and the ``gab_config`` dictionary defined by the model source.

    Args:
        model_name: Key used to look up the model source in ``MODEL2CODE``.
        scale: Suffix selecting the config class ``GAMConfig_<scale>``
            (for example ``'14M'``).

    Returns:
        A ``(gab, cfg)`` tuple: the instantiated block and the config
        instance it was built from.

    Raises:
        AssertionError: If the format check fails, or if the executed source
            defines no ``GAB`` class or no ``gab_config`` dictionary.
        NameError: If no ``GAMConfig_<scale>`` name is defined.
    """
    gab_code = MODEL2CODE[model_name]
    checker = BuildTool(tool_type="checker")
    try:
        checkpass, gab_code = checker._check_format_and_reformat(gab_code)
        if not checkpass:
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            raise AssertionError('Model does not pass the format checker')
    except AssertionError:
        print('Model does not pass the format checker')
        raise

    # NOTE(review): `exec` runs the model source with full privileges; this is
    # only acceptable because the code comes from the project's own library.
    namespace = {}
    exec(gab_code.replace("class GAB", "class GABCustom"), namespace)
    if "GABCustom" not in namespace:
        raise AssertionError(
            "Class GAB not found in module. You should never ever change the class name of GAB and it should always inherit from GABBase."
        )
    if "gab_config" not in namespace:
        raise AssertionError("Dictionary gab_config not found in module.")
    gab_cls = namespace["GABCustom"]
    gab_config = namespace["gab_config"]

    # Resolve the config class by name instead of evaluating a built string.
    config_name = f"GAMConfig_{scale}"
    config_cls = globals().get(config_name)
    if config_cls is None:
        raise NameError(f"name '{config_name}' is not defined")
    cfg = config_cls()

    gab = gab_cls(cfg.d_model, block_loc=(0, cfg.n_block), device=None, dtype=None, **gab_config)

    return gab, cfg

  @custom_fwd
  @custom_bwd


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Failed to login to HuggingFace Hub, some datasets may not be available to download.


  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
2024-08-19:19:18:24,404 INFO     [__init__.py:29] Skipping import of cpp extensions


In [2]:
import torch
import torch.fx
from types import MethodType
from dataclasses import dataclass
import copy
from functools import partial

torch._dynamo.config.cache_size_limit = 64  # Increase the limit as needed


# Redefining the necessary classes with torch imports


class ModuleNode:
    """A node in the tree of analyzed modules.

    Each node stores the ``name`` it was created with, an optional
    ``graph_module`` object, the ``kwargs`` value handed to the constructor,
    and a ``children`` list that starts out empty.
    """

    def __init__(self, name, graph_module=None,kwarg=None):
        self.children = []
        self.kwargs = kwarg
        self.graph_module = graph_module
        self.name = name

    def print_tree(self, indent=""):
        """Print this node and all descendants, indenting two spaces per level."""
        deeper = indent + "  "
        print(indent + self.name)
        if self.graph_module:
            # Only the presence of a captured graph is reported, not its content.
            print(deeper + "(GraphModule captured)")
        for descendant in self.children:
            descendant.print_tree(deeper)

@dataclass
class BlockAnalysis:
    """Result bundle returned by ``BlockAnalyzer.analyze``."""
    # Root ModuleNode of the analyzed module tree.
    root: ModuleNode
    # Mapping from module path to the ModuleNode created for it.
    nodes: dict
    # The config object that was passed to ``BlockAnalyzer.analyze``.
    config: GAMConfig

class BlockAnalyzer:
    """Trace a model and its submodules using inputs captured from a real forward pass.

    ``analyze`` first runs a wrapped copy of the model once to record, for
    every submodule, the positional and keyword arguments its ``forward``
    received. It then traces the model and each submodule that was actually
    called with ``torch.jit.trace`` and arranges the results in a tree of
    ``ModuleNode`` objects.
    """

    def __init__(self):
        self.module_tree_root = None
        self.current_inputs = {}  # module path -> (positional inputs, kwargs) seen in forward
        self.current_nodes = {}  # module path -> ModuleNode created for that module

    def track_input_wrapper(self, original_forward, module_path):
        """Return a ``forward`` replacement that records its arguments under ``module_path``.

        Args:
            original_forward: The module's already-bound ``forward`` method.
            module_path: Key under which the call arguments are stored in
                ``self.current_inputs``.

        Returns:
            A function ``(module_self, *inputs, **kwargs)`` suitable for
            binding to the module with ``MethodType``.
        """
        def wrapped_forward(module_self, *inputs, **kwargs):
            self.current_inputs[module_path] = (inputs, kwargs)
            # `original_forward` is already bound, so `module_self` is not passed on.
            return original_forward(*inputs, **kwargs)

        return wrapped_forward

    def wrap_forward_methods(self, model,wrapper):
        """Replace ``forward`` on every descendant of ``model`` (not ``model`` itself).

        Args:
            model: Module whose descendants are wrapped in place.
            wrapper: Callable ``(original_forward, module_path)`` returning the
                replacement function, e.g. ``self.track_input_wrapper``.
        """
        for module_path, module in self._get_full_module_paths(model):
            module.forward = MethodType(wrapper(module.forward, module_path), module)

    def _get_full_module_paths(self, model):
        """Return ``(dotted_path, module)`` for every descendant, parents before children."""
        module_paths = []

        def recursive_collect_modules(parent, prefix):
            for name, module in parent.named_children():
                full_path = f"{prefix}.{name}" if prefix else name
                module_paths.append((full_path, module))
                recursive_collect_modules(module, full_path)

        recursive_collect_modules(model, "")
        return module_paths

    @staticmethod
    def _detach_if_possible(value):
        """Return ``value.detach()`` when ``value`` has a callable ``detach``, else ``value`` unchanged."""
        detach = getattr(value, "detach", None)
        return detach() if callable(detach) else value

    @staticmethod
    def _child_lookup_path(parent_path, name):
        """Return the ``current_inputs`` key for child ``name`` of ``parent_path``.

        Captured paths are relative to the model, while the top-level entry is
        stored under the synthetic key ``'root'``; only that *leading*
        ``'root.'`` is stripped. A plain ``str.replace`` would also mangle
        names that merely contain ``'root.'`` (``'subroot.x'`` -> ``'subx'``).
        NOTE(review): a real top-level child literally named ``root`` still
        collides with the synthetic key.
        """
        child_path = f"{parent_path}.{name}"
        if child_path.startswith('root.'):
            return child_path[len('root.'):]
        return child_path

    def analyze_submodule(self, module_path, module):
        """Trace ``module`` with its captured inputs and recurse into its children.

        Args:
            module_path: Key into ``self.current_inputs`` for this module.
            module: The module to trace; a deep copy is traced, the argument
                itself is left untouched.

        Returns:
            The ``ModuleNode`` for this module, or ``None`` when no inputs were
            recorded for ``module_path``.
        """
        inputs, kwargs = self.current_inputs.get(module_path, (None, None))
        if inputs is None:
            # Nothing was recorded for this path; skip before paying for a deep copy.
            return None

        module = copy.deepcopy(module)
        if kwargs:
            # Build a new dict so the captured record is not mutated.
            kwargs = {key: self._detach_if_possible(value) for key, value in kwargs.items()}
            # Bind keyword arguments up front so only positional inputs go to the tracer.
            module.forward = partial(module.forward, **kwargs)

        if isinstance(inputs, tuple):
            inputs = tuple(self._detach_if_possible(inp) for inp in inputs)
        else:
            inputs = self._detach_if_possible(inputs)
        traced_module = torch.jit.trace(module, inputs)

        node = ModuleNode(module_path, traced_module, kwargs)
        self.current_nodes[module_path] = node

        for name, submodule in module.named_children():
            child_node = self.analyze_submodule(self._child_lookup_path(module_path, name), submodule)
            if child_node is not None:
                node.children.append(child_node)

        return node

    def analyze(self, model, cfg, *, batch_size=2, seq_len=100):
        """Capture per-module inputs from one forward pass, then trace the module tree.

        Args:
            model: The module to analyze. It is deep-copied before its
                forwards are wrapped, so the original keeps its own methods.
            cfg: Config object; ``cfg.d_model`` sets the last dimension of the
                random example input.
            batch_size: First dimension of the example input.
            seq_len: Second dimension of the example input.

        Returns:
            A ``BlockAnalysis`` holding the root node, the path-to-node
            mapping and ``cfg``.
        """
        self.current_inputs = {}
        self.current_nodes = {}
        model_wrap = copy.deepcopy(model)
        self.wrap_forward_methods(model_wrap, self.track_input_wrapper)

        # One forward pass through the wrapped copy records every submodule's arguments.
        input_tensor = torch.randn(batch_size, seq_len, cfg.d_model)
        model_wrap(input_tensor)
        del model_wrap

        # The top-level module is not wrapped, so register its input by hand.
        self.current_inputs['root'] = (input_tensor, None)
        self.module_tree_root = self.analyze_submodule('root', model)

        return BlockAnalysis(self.module_tree_root, self.current_nodes, cfg)

# Note: the wrapped forward does not re-bind `self`; it calls the already-bound original forward directly with the captured inputs and kwargs.

# Example usage:
# Build the 'ttt' block at the default scale, then trace it and every
# submodule that is called during one forward pass.
gab,cfg = load_gab('ttt')
analyzer = BlockAnalyzer()
analysis = analyzer.analyze(gab, cfg)

Loaded tokenized dataset wikitext-2 from /home/junyanc/model_discovery/data/wikitext-2/tokenized/meta-llama/Llama-2-7b-hf/2048
Checking code format...
Code after reformatted:

import torch
import torch.nn as nn
from model_discovery.model.utils.modules import GABBase
from typing import Any, Dict, Optional, Tuple, Union
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils._pytree import tree_map
from transformers.utils import logging
from transformers.activations import ACT2FN
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except:
    causal_conv1d_update, causal_conv1d_fn = None, None


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def permute_qk(q, k):
    bsz, num_head, seq_len, head_dim = q.shape
    q = q.reshape(bsz, num_head, seq_len, head_dim // 2, 2).transpose(3, 4
        ).reshape(bsz, nu

  assert X.shape[-1] == self.embed_dim
  assert Y.shape == X.shape, f"GAB Output shape must be the same as input shape, got {Y.shape} instead"


In [3]:
# analysis.nodes
# TODO: 1. modularize everything; 2. track the flow of tensors
# Print the graph of the traced module stored for the 'seq_modeling_block' path.
print(analysis.nodes['seq_modeling_block'].graph_module.graph)

graph(%self.1 : __torch__.builtins.___torch_mangle_32.TTTLinear,
      %hidden_states : Float(2, 100, 128, strides=[12800, 128, 1], requires_grad=0, device=cpu),
      %position_ids.1 : Long(1, 100, strides=[100, 1], requires_grad=0, device=cpu)):
  %o_proj : __torch__.torch.nn.modules.linear.___torch_mangle_29.Linear = prim::GetAttr[name="o_proj"](%self.1)
  %post_norm : __torch__.torch.nn.modules.normalization.___torch_mangle_31.LayerNorm = prim::GetAttr[name="post_norm"](%self.1)
  %ttt_norm_bias : Tensor = prim::GetAttr[name="ttt_norm_bias"](%self.1)
  %ttt_norm_weight : Tensor = prim::GetAttr[name="ttt_norm_weight"](%self.1)
  %b1 : Tensor = prim::GetAttr[name="b1"](%self.1)
  %W1 : Tensor = prim::GetAttr[name="W1"](%self.1)
  %learnable_token_idx : Tensor = prim::GetAttr[name="learnable_token_idx"](%self.1)
  %token_idx : Tensor = prim::GetAttr[name="token_idx"](%self.1)
  %learnable_ttt_lr_bias : Tensor = prim::GetAttr[name="learnable_ttt_lr_bias"](%self.1)
  %learnable_ttt_lr_w

## Analysis of Flow