In [None]:
from __future__ import annotations
from inspect import isclass
from torch import nn
from transformer_lens.hook_points import HookPoint, HookedRootModule
from typing import List, Optional, TypeVar, Type, Union, cast, overload
from utils import iterate_module
from abc import ABC, abstractmethod
import torch
from functools import partial
from fastcore.basics import *
from fastcore.foundation import *
from torch import nn
from transformer_lens.hook_points import HookPoint
from typing import TypeVar, Generic, Union, Type, Any, Callable, get_type_hints, ParamSpec, Protocol
from inspect import isclass, signature
import functools
from fastapi import FastAPI



In [None]:
# Disabled draft: the AutoHookedRootModule class below lives inside a string
# literal, so running this cell defines nothing and the name
# AutoHookedRootModule stays undefined. The draft docstring records the intended
# scope: only the outputs of submodules would be hooked, not every edge of the
# computation graph (e.g. a bare torch.softmax(x) call would not be hooked).
""" 
class AutoHookedRootModule(HookedRootModule):
    '''
    This class automatically builds hooks for all modules that are not hooks.
    NOTE this does not mean all edges in the graph are hooked only that the outputs of the modules are hooked.
    for instance torch.softmax(x) is not hooked but self.softmax(x) would be
    ''' """

In [None]:

class ModelTest(nn.Module):
    """Tiny two-layer model used to try out automatic hooking.

    ``bla`` is an ``nn.ModuleList`` holding one ``nn.Linear(10, 10)`` and ``lala`` is a
    direct ``nn.Linear(10, 10)`` child, so the model has both a container-nested and a
    top-level submodule.
    """

    def __init__(self):
        super().__init__()
        self.bla = nn.ModuleList([nn.Linear(10, 10)])
        self.lala = nn.Linear(10, 10)

    def forward(self, x):
        """Apply ``bla[0]`` and then ``lala`` to ``x`` and return the result.

        When this instance is an ``AutoHookedRootModule`` (only if that class has been
        defined), first print its ``mod_dict`` and the inner layer's ``hook_dict`` for
        debugging.
        """
        # AutoHookedRootModule may not exist (its draft is kept in a string literal).
        # Referencing it directly raised NameError on every forward call, so look
        # it up lazily instead.
        auto_hooked_cls = globals().get('AutoHookedRootModule')
        if auto_hooked_cls is not None and isinstance(self, auto_hooked_cls):
            print(f'{self.__class__.__name__}.mod_dict', self.mod_dict)
            print(self.bla[0], self.bla[0].hook_dict)
        # Call the submodules rather than their .forward so registered forward
        # hooks actually run; .forward() silently skips them.
        x = self.bla[0](x)
        x = self.lala(x)
        return x

In [None]:

T = TypeVar('T', bound=nn.Module)
P = ParamSpec('P')
_T = TypeVar("_T", bound=Callable)


def same_definition_as_in(t: _T) -> Callable[[Callable], _T]:
    def decorator(f: Callable) -> _T:
        return f  # type: ignore

    return decorator

class MyFastAPI(FastAPI):
    """``FastAPI`` subclass that prints ``'get'`` each time ``get`` is called.

    ``same_definition_as_in`` only re-types the overrides with the parent's
    signatures for static checkers. At runtime, both methods defer straight to
    ``FastAPI``.
    """

    @same_definition_as_in(FastAPI.__init__)
    def __init__(self, *args, **kwargs):
        super(MyFastAPI, self).__init__(*args, **kwargs)

    @same_definition_as_in(FastAPI.get)
    def get(self, *args, **kwargs):
        print('get')
        parent_result = super(MyFastAPI, self).get(*args, **kwargs)
        return parent_result

In [81]:
from typing import ParamSpec

T = TypeVar('T', bound=nn.Module)
P = ParamSpec('P')
R = TypeVar('R')



class WrappedModule(nn.Module, Generic[T]):
    def __init__(self, module: T):
        super().__init__()
        self._module = module
        self.hook_point = HookPoint()
        self._create_forward()

    def _create_forward(self):
        original_forward = self._module.forward
        original_type_hints = get_type_hints(original_forward)

        @functools.wraps(original_forward)
        def new_forward(*args: Any, **kwargs: Any) -> Any:
            # Remove 'self' from args if it's present
            if args and isinstance(args[0], WrappedModule):
                args = args[1:]
            return self.hook_point(original_forward(*args, **kwargs))

        new_forward.__annotations__ = original_type_hints
        setattr(self.__class__, 'forward', new_forward)

class Wrapper(Generic[T]):
    """Wrap either a module class or a module instance in a ``WrappedModule``.

    - Given a class, each call instantiates it with the call's arguments and returns a
      fresh ``WrappedModule`` around the new instance.
    - Given an instance, the ``WrappedModule`` is built once in ``__init__`` and every
      call returns that same wrapper.

    Attributes not found on the ``Wrapper`` are looked up on the original class/instance.
    """

    def __init__(self, module: Union[T, Type[T]]) -> None:
        self.is_class = isclass(module)
        self.module = module
        if not self.is_class:
            # The wrapper is a WrappedModule[T], not a T (the old cast mislabelled it).
            self.wrapped: WrappedModule[T] = WrappedModule(module)

    def __call__(self, *args: Any, **kwargs: Any) -> WrappedModule[T]:
        if self.is_class:
            return WrappedModule(self.module(*args, **kwargs))
        return self.wrapped

    def unwrap(self) -> Union[T, Type[T]]:
        """Return the original class or instance passed to ``__init__``."""
        return self.module

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ after normal lookup fails. Read ``module`` from
        # the instance dict directly. On an instance whose __init__ never ran (e.g. one
        # made via ``Wrapper.__new__``, copy or unpickling), ``self.module`` would
        # re-enter __getattr__ and recurse until RecursionError.
        try:
            module = self.__dict__['module']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(module, name)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.module})"

class Unwrappable(Protocol[T]):
    """Structural type for any object exposing ``unwrap()`` that yields a ``T``."""

    def unwrap(self) -> T:
        """Return the original, unwrapped object."""

class UnwrappableModule(Protocol[T]):
    """Structural type for a wrapped module.

    Implementers are expected to be callable, to answer arbitrary attribute lookups,
    and to provide ``unwrap()`` returning the underlying ``T``.
    """

    def unwrap(self) -> T:
        """Return the original, unwrapped module."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped module."""

    def __getattr__(self, name: str) -> Any:
        """Resolve attributes not found by normal lookup."""


class WrappedModuleClass(Generic[T]):
    """Stand-in for a module class whose instances come back auto-wrapped.

    Calling it constructs ``module_class(*args, **kwargs)`` and returns
    ``auto_wrap(instance)``. Attributes not found on this object (e.g. ``forward``)
    are looked up on the original class, and ``unwrap()`` returns that class.
    """

    def __init__(self, module_class: Type[T]) -> None:
        # __init__ returns None; the previous ``-> T`` annotation was wrong.
        self.module_class = module_class

    def __call__(self, *args: Any, **kwargs: Any) -> UnwrappableModule[T]:
        instance = self.module_class(*args, **kwargs)
        return auto_wrap(instance)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. Read ``module_class`` from the instance
        # dict so an instance whose __init__ never ran raises AttributeError instead of
        # recursing through ``self.module_class`` until RecursionError.
        try:
            module_class = self.__dict__['module_class']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(module_class, name)

    def unwrap(self) -> Type[T]:
        """Return the original (unwrapped) module class."""
        return self.module_class

@overload
def auto_wrap(module_or_class: Type[T]) -> WrappedModuleClass[T]: ...

@overload
def auto_wrap(module_or_class: T) -> UnwrappableModule[T]: ...

def auto_wrap(module_or_class: Union[T, Type[T]]) -> Union[UnwrappableModule[T], WrappedModuleClass[T]]:
    '''
    Wrap a module class or a module instance.

    - A class is returned as a ``WrappedModuleClass``; calling that constructs an
      instance and wraps it in turn.
    - An instance is wrapped in a ``WrappedModule`` that also gets an ``unwrap()``
      method returning the original instance.
    '''
    if not isclass(module_or_class):
        module_wrapper = WrappedModule(module_or_class)
        # Attach unwrap() to this wrapper instance only.
        module_wrapper.unwrap = lambda: module_or_class  # type: ignore
        return cast(UnwrappableModule[T], module_wrapper)
    return WrappedModuleClass(module_or_class)


# Usage example: wrap the nn.Linear *class*, check that the wrapper still exposes the
# original class and its forward() type hints, then build and run a wrapped instance.
# Seed first so the random layer weights and input (and hence the printed tensor) are
# reproducible on Restart & Run All.
torch.manual_seed(0)
WrappedLinear = auto_wrap(nn.Linear)
# unwrap() returns the original class itself, despite the "type of" label.
print("type of WrappedLinear", WrappedLinear.unwrap())
print('get_type_hints(nn.Linear(10, 20).forward)', get_type_hints(nn.Linear(10, 20).forward))
# Attribute lookups on the wrapper fall through to nn.Linear, so this is nn.Linear.forward.
print('get_type_hints(WrappedLinear.forward)', get_type_hints(WrappedLinear.forward))
print('WrappedLinear.forward(torch.randn(10))', WrappedLinear(10, 10).forward(torch.randn(10)))

type of WrappedLinear <class 'torch.nn.modules.linear.Linear'>
get_type_hints(nn.Linear(10, 20).forward) {'input': <class 'torch.Tensor'>, 'return': <class 'torch.Tensor'>}
get_type_hints(WrappedLinear.forward) {'input': <class 'torch.Tensor'>, 'return': <class 'torch.Tensor'>}
WrappedLinear.forward(torch.randn(10)) tensor([ 0.0644,  0.6165,  0.2268,  0.8688, -0.2364,  0.3074, -0.6121, -0.1512,
         0.5288, -1.2037], grad_fn=<ViewBackward0>)
