In [None]:
from __future__ import annotations
from inspect import isclass
from torch import nn
from transformer_lens.hook_points import HookPoint, HookedRootModule
from typing import List, Optional, TypeVar, Type, Union, cast, overload
from utils import iterate_module
from abc import ABC, abstractmethod
import torch
from functools import partial
from fastcore.basics import *
from fastcore.foundation import *
from torch import nn
from transformer_lens.hook_points import HookPoint
from typing import TypeVar, Generic, Union, Type, Any, Callable, get_type_hints, ParamSpec, Protocol
from inspect import isclass, signature
import functools
from fastapi import FastAPI



In [None]:
""" 
class AutoHookedRootModule(HookedRootModule):
    '''
    This class automatically builds hooks for all modules that are not hooks.
    NOTE this does not mean all edges in the graph are hooked only that the outputs of the modules are hooked.
    for instance torch.softmax(x) is not hooked but self.softmax(x) would be
    ''' """

In [None]:

class ModelTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.bla = nn.ModuleList([nn.Linear(10, 10)])
        self.lala = nn.Linear(10, 10)

    def forward(self, x):
        if isinstance(self, AutoHookedRootModule):
            print(f'{self.__class__.__name__}.mod_dict', self.mod_dict)
            print(self.bla[0], self.bla[0].hook_dict)
        x = self.bla[0].forward(x)
        x = self.lala.forward(x)
        return x

In [None]:

T = TypeVar('T', bound=nn.Module)
P = ParamSpec('P')
_T = TypeVar("_T", bound=Callable)


def same_definition_as_in(t: _T) -> Callable[[Callable], _T]:
    def decorator(f: Callable) -> _T:
        return f  # type: ignore

    return decorator

class MyFastAPI(FastAPI):
    @same_definition_as_in(FastAPI.__init__)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs) 

    @same_definition_as_in(FastAPI.get)
    def get(self, *args, **kwargs):
        print('get')
        return super().get(*args, **kwargs)

In [None]:

from typing import Optional, Set


from torch.nn.modules.module import Module


T = TypeVar('T', bound=nn.Module)
P = ParamSpec('P')
R = TypeVar('R')


class WrappedClass(Generic[T]):
    def __init__(self, module_class: Type[T]) -> T: # type: ignore
        self.module_class = module_class

    def __call__(self, *args: Any, **kwargs: Any) -> HookedInstance[T]:
        instance = self.module_class(*args, **kwargs)
        return auto_wrap(instance)

    def __getattr__(self, name: str) -> Any:
        return getattr(self.module_class, name)

    def unwrap(self) -> Type[T]:
        '''recursively unwraps the module class'''
        for attr in self.module_class.__dict__:
            if isinstance(self.module_class.__dict__[attr], WrappedClass):
                self.module_class.__dict__[attr] = self.module_class.__dict__[attr].unwrap()
        return self.module_class

@overload
def auto_wrap(module_or_class: Type[T]) -> WrappedClass[T]: ...

@overload
def auto_wrap(module_or_class: T) -> HookedInstance[T]: ...

def auto_wrap(module_or_class: Union[T, Type[T]]) -> Union[HookedInstance[T], WrappedClass[T]]:
    '''
    This function wraps either a module instance or a module class and returns a type that
    preserves the original module's interface plus an additional unwrap method.
    '''
    if isclass(module_or_class):
        return WrappedClass(module_or_class)
    else:
        wrapped = HookedInstance(module_or_class)
        #NOTE we set the unwrap method to just return module_or_class
        wrapped.unwrap = lambda: module_or_class # type: ignore
        return cast(HookedInstance[T], wrapped)

class HookedInstance(HookedRootModule, Generic[T]):
    def __init__(self, module: T):
        super().__init__()
        # NOTE we need to name it in this way to not 
        # to avoid infinite regress and override 
        self._module = module
        self.hook_point = HookPoint()
        self._create_forward()
        self._wrap_submodules()
        self.setup()

    def new_attr_fn(self, name: str) -> Any:
        return getattr(self._module, name)

    #NOTE we override the nn.Module implementation to use _module only
    def named_modules(self, memo: Set[Module] | None = None, prefix: str = '', remove_duplicate: bool = True):
        #NOTE BE VERY CAREFUL HERE
        
        if memo is None:
            memo = set()

        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._module.named_children():
                if module not in memo:
                    submodule_prefix = prefix + ('.' if prefix else '') + name
                    if isinstance(module, HookedInstance):
                        yield from module.named_modules(memo, submodule_prefix)
                    else:
                        yield submodule_prefix, module
                        if hasattr(module, 'named_modules'):
                            yield from module.named_modules(memo, submodule_prefix)

            if hasattr(self, 'hook_point'):
                hook_point_prefix = prefix + ('.' if prefix else '') + 'hook_point'
                yield hook_point_prefix, self.hook_point

    #def unwrap(self) -> T: ...

    def _wrap_submodules(self):
        for name, submodule in self._module.named_children():
            if isinstance(submodule, (nn.ModuleList, nn.Sequential)):
                wrapped_container = type(submodule)() #initialize the container
                for i, m in enumerate(submodule):
                    wrapped_container.append(auto_wrap(m))
                setattr(self._module, name, wrapped_container)
            elif isinstance(submodule, nn.ModuleDict):
                wrapped_container = type(submodule)()
                for key, m in submodule.items():
                    wrapped_container[key] = auto_wrap(m)
                setattr(self._module, name, wrapped_container)
            else:
                setattr(self._module, name, auto_wrap(submodule))

    def _create_forward(self):
        original_forward = self._module.forward
        original_type_hints = get_type_hints(original_forward)

        @functools.wraps(original_forward)
        def new_forward(*args: Any, **kwargs: Any) -> Any:
            return self.hook_point(original_forward(*args, **kwargs))

        new_forward.__annotations__ = original_type_hints
        self.forward = new_forward  # Assign to instance, not class

    def get_hooks(self):
        return [(hook, hook_point) for hook, hook_point in self.hook_dict.items()] 

args = (1,1)

class Linear(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.wtf = nn.Linear(*args)
    def forward(self, x):
        x = self.wtf(x)
        return x

class Meta2Linear(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.modules_lst = nn.ModuleDict({'linear1': Linear(*args), 'linear2': Linear(*args)})
    def forward(self, x):
        x = self.modules_lst['linear1'](x)
        x = self.modules_lst['linear2'](x)
        return x
        #for module in self.modules_lst:
        #    x = module(x)
        #return x

Wrapped = auto_wrap(Meta2Linear(*args))
print(Wrapped.unwrap())




In [None]:
#BASIC TESTS

class NestedLinear(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.linear1 = nn.Linear(*args)
        self.linear2 = nn.Linear(*args)
    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def Generic_test_types(Model : Type[T], args):
    pre_init_Model = auto_wrap(Model(*args))
    post_init_Model = auto_wrap(Model)(*args)

    assert type(pre_init_Model) == type(post_init_Model), f"{type(pre_init_Model)} != {type(post_init_Model)}"
    assert type(pre_init_Model.unwrap()) == type(post_init_Model.unwrap()) , f"{type(pre_init_Model.unwrap())} != {type(post_init_Model.unwrap())}"
    assert WrappedLinear1.unwrap() == type(WrappedLinear2.unwrap()), f"{WrappedLinear1.unwrap()} != {type(WrappedLinear2.unwrap())}"

def Generic_test_type_hints(Model : Type[T], args):
    pre_init_Model = auto_wrap(Model(*args))
    post_init_Model = auto_wrap(Model)(*args)

    orig_type_hints = get_type_hints(Model(*args).forward)
    wrapped_pre_init_type_hints = get_type_hints(pre_init_Model.forward)
    wrapped_post_init_type_hints = get_type_hints(post_init_Model.forward)
    assert orig_type_hints == wrapped_pre_init_type_hints, f"{orig_type_hints} != {wrapped_pre_init_type_hints}"
    assert orig_type_hints == wrapped_post_init_type_hints, f"{orig_type_hints} != {wrapped_post_init_type_hints}"

def Generic_test_hook(Model : Type[T], args):
    WrappedLinear1 = auto_wrap(Model)(*args)
    WrappedLinear2 = auto_wrap(Model(*args))

    counter = {'data': 0}

    def hook_fn(x, hook=None, hook_name=None):
        counter['data'] += 1

    WrappedLinear1.run_with_hooks(
        torch.rand(1,1),
        fwd_hooks=[('hook_point', partial(hook_fn, hook_name='hook_point'))],
    )

    assert counter['data'] == 1
    WrappedLinear2.run_with_hooks(
        torch.rand(1,1),
        fwd_hooks=[('hook_point', partial(hook_fn, hook_name='hook_point'))],
    )
    assert counter['data'] == 2

TEST_CLASSES = [(nn.Linear, (1,1)), (NestedLinear, (1,1))]
for Model, args in TEST_CLASSES:
    print(f'Testing {Model.__name__} with args {args}')
    Generic_test_types(Model, args)
    Generic_test_type_hints(Model, args)
    Generic_test_hook(Model, args)
    print('SUCCESS\n')



In [17]:
from torch import nn

class Bruh(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1,1)

isinstance(Bruh(), nn.Module)

True

In [1]:
a = {'a' : 1}
a.a

AttributeError: 'dict' object has no attribute 'a'