# Place for experimenting the progressive design

In [28]:
import os,sys
import yaml
import inspect
import importlib

sys.path.append('..')

import model_discovery.utils as U
from model_discovery.configs.gam_config import GAMConfig, GAMConfig_14M
from model_discovery.model.composer import GAUTree,GAUBase,GAUNode
# from model_discovery.evolution import BuildEvolution
from model_discovery.agents.flow.gau_utils import check_and_reformat_gau_code

ckpt_dir = os.environ['CKPT_DIR']
lib_dir = U.pjoin(ckpt_dir, 'test_composer', 'lib')
test_tree = GAUTree('TestTree', None, None, None,None,lib_dir)

prompts_dir='../model_discovery/agents/prompts/'
gab_py = U.read_file(U.pjoin(prompts_dir,'gab_template.py'))
gam_py = U.read_file(U.pjoin(prompts_dir,'gam_prompt.py'))
GAU_TEMPLATE = U.read_file(U.pjoin(prompts_dir,'gau_template.py'))
GAU_BASE=inspect.getsource(GAUBase)


## Parsers

In [2]:
import json

code='''
# gau.py   # DO NOT CHANGE OR REMOVE THE MAKK HERE, KEEP IT ALWAYS THE FIRST LINE #

import torch
import torch.nn as nn

from model_discovery.model.utils.modules import GAUBase # DO NOT CHANGE THIS IMPORT STATEMENT #


# YOU CAN IMPORT MORE MODULES HERE #

# YOU CAN DEFINE MORE CLASSES OR FUNCTIONS HERE #


class GAU(GAUBase): # DO NOT CHANGE THE NAME OF THIS CLASS
    """Generalized Autoregressive Block Unit
        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}
        Output:       Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}
        Constraints:  Causal, differentiable, parameter number, complexity, parallelizable
    """
    def __init__(self, embed_dim: int, device=None, dtype=None,**kwargs): # YOU CAN ADD MORE ARGUMENTS WITH OPTIONAL DEFAULT VALUES, BUT YOU HAVE TO HAVE embed_dim, device, dtype AS THE ARGUTMENTS #
        # argv: list of hyperparameters
        factory_kwargs = {"device": device, "dtype": dtype} # remember to pass it to all nn layers
        super().__init__(embed_dim) # DO NOT CHANGE THIS LINE #

        # COMPLETING THE CODE HERE #
        self.token_scorer: GAUBase = TokenScoringGAU(embed_dim, **factory_kwargs)
        self.dual_path: GAUBase = DualPathGAU(embed_dim, **factory_kwargs)
        self.latent_attention: GAUBase = LatentAttentionGAU(embed_dim, **factory_kwargs)


    # YOU CAN ADD MORE FUNCTIONS HERE #


    def _forward(self, X, **Z): 

        # THE CODE HERE MUST BE COMPLETED #
        # Step 1: Score tokens
        X, Z = self.token_scorer(X, **Z)
        # Step 2: Route through dual paths
        # Step 3: Apply latent attention
        Y, Z = self.latent_attention(X, **Z)
        X, Z = self.dual_path(X, **Z)

        return Y, Z
'''


In [29]:
# Example usage
code2 = """
# gau.py

import torch
import torch.nn as nn

from model_discovery.model.utils.modules import GAUBase

# Placeholder classes for future implementation
class MemoryAccessUnit(nn.Module):
    def __init__(self, embed_dim, memory_size, device=None, dtype=None):
        super().__init__(embed_dim)

    def _forward(self, X, **Z):
        return X, {}

class DownsamplingUnit(nn.Module):
    def __init__(self, embed_dim, downsample_factor, device=None, dtype=None):
        super().__init__(embed_dim)

    def _forward(self, X, **Z):
        return X, {}

class XAEU(GAUBase):  # This class will be renamed to the unit_name
    def __init__(self, embed_dim: int, device=None, dtype=None):
        super().__init__(embed_dim)
        self.unit: GAUBase = MemoryAccessUnit(embed_dim=embed_dim, device=device)

    def _forward(self, X, **Z):
        return X, Z
"""

unit_name = "XAU"  # Provide the unit_name to rename GAU class
reformatted_code, children_units, new_args, called, errors, warnings = check_and_reformat_gau_code(code, unit_name)
print("Reformatted Code:\n" + reformatted_code)
print("Errors:\n", errors)
print("Warnings:\n", warnings)
print("Children Units:\n", children_units)
print("New Arguments:\n", new_args)
print("Called Children:\n", called)

test_tree.add_unit(
    unit_name,reformatted_code,new_args,None,called,None,None,children_units,None
)
test_tree.root=test_tree.units['XAU']

Reformatted Code:
import torch
import torch.nn as nn
from model_discovery.model.utils.modules import GAUBase


class XAU(GAUBase):
    """Generalized Autoregressive Block Unit
        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}
        Output:       Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}
        Constraints:  Causal, differentiable, parameter number, complexity, parallelizable
    """

    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim)
        self.token_scorer: GAUBase = TokenScoringGAU(embed_dim=embed_dim,
            device=device, dtype=dtype, **kwargs)
        self.dual_path: GAUBase = DualPathGAU(embed_dim=embed_dim, device=
            device, dtype=dtype, **kwargs)
        self.latent_attention: GAUBase = LatentAttentionGAU(embed_dim=
         

In [33]:
children={
    'unit1': 'MockUnit1',
    'unit2': 'MockUnit2',
}
called=['unit1', 'unit2', 'unit1']
test_tree.units['DualPathGAU'] = GAUNode('DualPathGAU', None, None, None, called, None, None, children, None)
children={
    'unit1': 'MockUnit3',
    'unit2': 'MockUnit4',
}
called=['unit2', 'unit2']
test_tree.units['MockUnit1'] = GAUNode('MockUnit1', None, None, None, called, None, None, children, None)

In [6]:
test_tree.units

{'XAU': GAUNode(name='XAU', code='import torch\nimport torch.nn as nn\nfrom model_discovery.model.utils.modules import GAUBase\n\n\nclass XAU(GAUBase):\n    """Generalized Autoregressive Block Unit\n        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}\n        Output:       Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}\n        Constraints:  Causal, differentiable, parameter number, complexity, parallelizable\n    """\n\n    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):\n        factory_kwargs = {\'device\': device, \'dtype\': dtype}\n        super().__init__(embed_dim)\n        self.token_scorer: GAUBase = TokenScoringGAU(embed_dim=embed_dim,\n            device=device, dtype=dtype, **kwargs)\n        self.dual_path: GAUBase = DualPathGAU(embed_dim=embed_dim, device=\n            device, dtype=dtype, **kwargs)\n        self.latent_attention: GAUBase = 

In [34]:
def _view(self,_name,obj='root',path='',node=None,pstr='',unimplemented=set()):
    # create a string representation of the tree
    name=obj+': '+_name
    if node is None: 
        name += ' (Unimplemented)'
        unimplemented.add(_name)
    else:
        name += f' (Exec path: {'->'.join(node.path)})'
    if path!='':
        level=len(path.split('.'))
        name='    '*level+' |- '+name
        path+='.'+_name
    else:
        pstr+=f'GAU Tree Map of {self.name}:\n'
        path=_name
    pstr+='  '+name+'\n'
    if node is not None:
        for child, child_unit in node.children.items():
            child_node = self.units.get(child_unit,None)
            pstr,unimplemented=_view(self,child_unit,child,path,child_node,pstr,unimplemented)
    return pstr,unimplemented

def view(self):
    pstr,unimplemented=_view(self,self.root.name,node=self.root)
    implemented = set(self.units.keys())
    pstr+='\nImplemented Units: '+', '.join(implemented)
    if len(unimplemented)>0:
        pstr+='\nUnimplemented Units: '+', '.join(unimplemented)
    else:
        pstr+='\nAll units are implemented.'
    return pstr

print(view(test_tree))

GAU Tree Map of TestTree:
  root: XAU (Exec path: token_scorer->latent_attention->dual_path)
       |- token_scorer: TokenScoringGAU (Unimplemented)
       |- dual_path: DualPathGAU (Exec path: unit1->unit2->unit1)
           |- unit1: MockUnit1 (Exec path: unit2->unit2)
               |- unit1: MockUnit3 (Unimplemented)
               |- unit2: MockUnit4 (Unimplemented)
           |- unit2: MockUnit2 (Unimplemented)
       |- latent_attention: LatentAttentionGAU (Unimplemented)

Implemented Units: XAU, MockUnit1, DualPathGAU
Unimplemented Units: LatentAttentionGAU, MockUnit4, MockUnit3, MockUnit2, TokenScoringGAU


In [5]:
print(reformatted_code)

import torch
import torch.nn as nn
from model_discovery.model.utils.modules import GAUBase


class XAU(GAUBase):
    """Generalized Autoregressive Block Unit
        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}
        Output:       Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}
        Constraints:  Causal, differentiable, parameter number, complexity, parallelizable
    """

    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim)
        self.token_scorer: GAUBase = TokenScoringGAU(embed_dim=embed_dim,
            device=device, dtype=dtype, **kwargs)
        self.dual_path: GAUBase = DualPathGAU(embed_dim=embed_dim, device=
            device, dtype=dtype, **kwargs)
        self.latent_attention: GAUBase = LatentAttentionGAU(embed_dim=
            embed_dim, devi

In [6]:
gab_template='''
# gab.py    # DO NOT CHANGE OR REMOVE THE MAKK HERE, KEEP IT ALWAYS THE FIRST LINE #

import torch
import torch.nn as nn

from model_discovery.model.utils.modules import GABBase # DO NOT CHANGE THIS IMPORT STATEMENT #


class GAB(GABBase):
    def __init__(self,embed_dim: int, block_loc: tuple, device=None,dtype=None,**kwargs): # YOU CAN ADD MORE ARGUMENTS, BUT YOU HAVE TO HAVE embed_dim, device, dtype AS THE ARGUTMENTS #
        factory_kwargs = {{"device": device, "dtype": dtype}} # remember to pass it to nn layers
        super().__init__(embed_dim, block_loc) # DO NOT CHANGE THIS LINE #
        self.root = {ROOT_UNIT_NAME}(embed_dim, embed_dim=embed_dim, device=device, dtype=dtype, **kwargs)

    def _forward(self, X, *Z): 
        X, Z = self.root(X, **Z)
        return X, Z
'''

In [7]:
import ast
import astor
from typing import List

def replace_from_second(text, old, new):
    first_part, remaining = text.split(old, 1)
    remaining = remaining.replace(old, new)
    return first_part + old + remaining

class GABComposer:
    
    def generate_gab_code(self,tree):
        root_node = tree.root
        generated_code = []
        
        # Recursively generate code for the root and its children
        self.generate_node_code(root_node.name, generated_code, tree.units)
        
        # Combine all generated code into a single Python file content
        gau_code = "\n".join(generated_code)

        gathered_args={}
        for unit in tree.units.values():
            gathered_args.update(unit.args)
        gab_code=gab_template.format(ROOT_UNIT_NAME=root_node.name)

        cfg_code=f'gab_config = {json.dumps(gathered_args)}'

        compoesed_code = f'{gab_code}\n\n{gau_code}\n\n{cfg_code}'

        compoesed_code=replace_from_second(compoesed_code,'import torch\n','')
        compoesed_code=replace_from_second(compoesed_code,'import torch.nn as nn\n','')
        compoesed_code=replace_from_second(compoesed_code,'from model_discovery.model.utils.modules import GAUBase\n','')

        return compoesed_code


    # Recursive function to generate code for a node and its children
    def generate_node_code(self, unit_name, generated_code: List[str], units):
        # Check if the node exists in units
        if unit_name not in units:
            # If the node does not exist in units, create a placeholder
            generated_code.append(self.create_placeholder_class(unit_name))
        else:
            node = units[unit_name]
            generated_code.append(node.code)
            
            # Recursively generate code for children
            children_units=set()
            for child_name, child_unit_name in node.children.items():
                children_units.add(child_unit_name)
            for child_unit in children_units:
                self.generate_node_code(child_unit, generated_code, units)

    # Function to create a placeholder class for a GAUNode
    def create_placeholder_class(self, unit_name) -> str:
        class_template = f"""
class {unit_name}(GAUBase): 
    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs): 
        factory_kwargs = {{"device": device, "dtype": dtype}} 
        super().__init__(embed_dim) 
        
    def _forward(self, X, **Z): 
        return X
"""
        return class_template

    # Function to convert the generated code to AST using ast and astor
    def convert_code_to_ast(self, code: str):
        try:
            return ast.parse(code)
        except SyntaxError as e:
            print(f"Syntax error in code: {code}")
            raise e

    # Function to convert AST back to Python code using astor
    def convert_ast_to_code(self, ast_tree: ast.AST) -> str:
        return astor.to_source(ast_tree)
    

# Example usage
generated_code = GABComposer().generate_gab_code(test_tree)
print(generated_code)  # This will print the final Python code for the entire GAUTree



# gab.py    # DO NOT CHANGE OR REMOVE THE MAKK HERE, KEEP IT ALWAYS THE FIRST LINE #

import torch
import torch.nn as nn

from model_discovery.model.utils.modules import GABBase # DO NOT CHANGE THIS IMPORT STATEMENT #


class GAB(GABBase):
    def __init__(self,embed_dim: int, block_loc: tuple, device=None,dtype=None,**kwargs): # YOU CAN ADD MORE ARGUMENTS, BUT YOU HAVE TO HAVE embed_dim, device, dtype AS THE ARGUTMENTS #
        factory_kwargs = {"device": device, "dtype": dtype} # remember to pass it to nn layers
        super().__init__(embed_dim, block_loc) # DO NOT CHANGE THIS LINE #
        self.root = XAU(embed_dim, embed_dim=embed_dim, device=device, dtype=dtype, **kwargs)

    def _forward(self, X, *Z): 
        X, Z = self.root(X, **Z)
        return X, Z


from model_discovery.model.utils.modules import GAUBase


class XAU(GAUBase):
    """Generalized Autoregressive Block Unit
        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variable

In [56]:
print(code)

# gau.py

import torch
import torch.nn as nn

from model_discovery.model.utils.modules import GAUBase

# Placeholder imports for future GAUs
# from gau import RandomizedAttentionUnit, HierarchicalCompositionUnit

class HRAB(GAUBase):
 """Hierarchical Randomized Attention Block Unit
 Input: X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}
 Output: Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}
 Constraints: Causal, differentiable, parameter number, complexity, parallelizable
 """
 def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(embed_dim)
 
 # Initialize the Randomized Attention Unit
 self.randomized_attention: GAUBase = RandomizedAttentionUnit(embed_dim, **factory_kwargs, **kwargs)
 
 # Initialize the Hierarchical Composition Unit
 self.hierarchical_composition: GAUBase = HierarchicalCompositionUnit(embe

In [48]:
check_report=r'''
Checking the designed model...
Checking code format...
Code format is correct and reformatted.


Warnings:

The super().__init__(embed_dim, block_loc) call in GAB is force overwritten by the reformatter. It may cause error if you modified this line.

Error: Model initialization failed with error: Expected size for first two dimensions of batch2 tensor to be: [2, 128] but got: [2, 2048].
Full Traceback: 
Traceback (most recent call last):
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\agents\roles\checker.py", line 835, in check
    glm(mock_input) # super slow as well, why??? but its only for the first time initialization
    ^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\model\gam.py", line 399, in forward
    hidden_states = self.backbone(input_ids, **gab_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\model\gam.py", line 285, in forward
    hidden_states, residual, intermediate_vars = block(
                                                 ^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\model\gam.py", line 106, in forward
    hidden_states,intermediate_vars = self.gab(hidden_states, **intermediate_vars)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\model\utils\modules.py", line 29, in forward
    Y = self._forward(X, **Z)
        ^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 16, in _forward
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\model\utils\modules.py", line 58, in forward
    Y = self._forward(X, **_Z)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 41, in _forward
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\model\utils\modules.py", line 58, in forward
    Y = self._forward(X, **_Z)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 77, in _forward
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 128] but got: [2, 2048].

Hint: 1. if it is a dtype or device error, check whether the factory kwargs are passed to the layers, and whether you manually designate a type instead of apply the type from factory kwargs or the input's type during conversion or creating of an variable. 2. If it is a shape error, check whether the output shape is equal to the input shape. The output shape of GAB should be the same as the input. 3. Always remember to follow the template and do not implement redundant part like embedding layer.
'''

In [47]:
gabcode=r'''import torch
import torch.nn as nn
from model_discovery.model.utils.modules import GABBase


class GAB(GABBase):

    def __init__(self, embed_dim: int, block_loc: tuple, device=None, dtype
        =None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim, block_loc)
        self.root = RootGAU(embed_dim=embed_dim, device=device, dtype=dtype,
            **kwargs)

    def _forward(self, X, **Z):
        X, Z = self.root(X, **Z)
        return X, Z


from model_discovery.model.utils.modules import GAUBase


class RootGAU(GAUBase):
    """Generalized Autoregressive Block Unit
        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}
        Output:       Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}
        Constraints:  Causal, differentiable, parameter number, complexity, parallelizable
    """

    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim)
        self.meta_sorting: GAUBase = MetaSortingGAU(embed_dim=embed_dim,
            device=device, dtype=dtype, **kwargs)
        self.chunked_attention: GAUBase = ChunkedAttentionGAU(embed_dim=
            embed_dim, device=device, dtype=dtype, **kwargs)
        self.hybrid_attention: GAUBase = HybridAttentionGAU(embed_dim=
            embed_dim, device=device, dtype=dtype, **kwargs)

    def _forward(self, X, **Z):
        X, Z = self.meta_sorting(X, **Z)
        X, Z = self.chunked_attention(X, **Z)
        Y, Z = self.hybrid_attention(X, **Z)
        return Y, Z


import torch.nn.functional as F


def sinkhorn_balancing(log_alpha, n_iters=5):
    """Applies Sinkhorn balancing to ensure doubly stochastic matrix."""
    for _ in range(n_iters):
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True
            )
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True
            )
    return log_alpha


class MetaSortingGAU(GAUBase):
    """Meta Sorting Generalized Autoregressive Unit
        Input:        X: (batch, seqlen, embed_dim), Z: {dict of all current intermediate variables}
        Output:       Y: (batch, seqlen, embed_dim), Z_: Optional, {dict of *new* intermediate variables to update the current Z}
        Constraints:  Causal, differentiable, parameter number, complexity, parallelizable
    """

    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim)
        self.sorting_network = nn.Linear(embed_dim, embed_dim, **factory_kwargs
            )

    def _forward(self, X, **Z):
        latent_permutations = self.sorting_network(X)
        balanced_permutations = sinkhorn_balancing(latent_permutations)
        balanced_permutations = F.softmax(balanced_permutations, dim=-1)
        sorted_sequence = torch.matmul(balanced_permutations, X)
        truncated_sequence = sorted_sequence
        assert truncated_sequence.shape == X.shape, 'Output shape must match input shape'
        return truncated_sequence, Z


class HybridAttentionGAU(GAUBase):

    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim)

    def _forward(self, X, **Z):
        return X


class ChunkedAttentionGAU(GAUBase):

    def __init__(self, embed_dim: int, device=None, dtype=None, **kwargs):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(embed_dim)

    def _forward(self, X, **Z):
        return X


gab_config = {}
'''

In [54]:
gabcode_lines=gabcode.split('\n')
new_check_report=[]
for line in check_report.split('\n'):
    if 'File "<string>", line' in line:
        line=line.replace('File "<string>", line','File "gab.py", line')
        line_num=int(line.split('File "gab.py", line ')[-1].split(',')[0].strip())
        line=line.replace(f'line {line_num}',f'line {line_num}: {gabcode_lines[line_num-1]}')
    new_check_report.append(line)
new_check_report='\n'.join(new_check_report)
print(new_check_report)



Checking the designed model...
Checking code format...
Code format is correct and reformatted.



The super().__init__(embed_dim, block_loc) call in GAB is force overwritten by the reformatter. It may cause error if you modified this line.

Error: Model initialization failed with error: Expected size for first two dimensions of batch2 tensor to be: [2, 128] but got: [2, 2048].
Full Traceback: 
Traceback (most recent call last):
  File "C:\ChengJunyan1\Research\model_discovery\model_discovery\agents\roles\checker.py", line 835, in check
    glm(mock_input) # super slow as well, why??? but its only for the first time initialization
    ^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ChengJunyan1\anaconda3\envs\modis\Lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl