In [None]:
from __future__ import annotations
from inspect import isclass
from torch import nn
from transformer_lens.hook_points import HookPoint, HookedRootModule
from typing import List, Optional, TypeVar, Type, Union, cast, overload
from utils import iterate_module
from abc import ABC, abstractmethod
import torch
from functools import partial
from fastcore.basics import *
from fastcore.foundation import *
from torch import nn
from transformer_lens.hook_points import HookPoint
from typing import TypeVar, Generic, Union, Type, Any, Callable, get_type_hints, ParamSpec, Protocol
from inspect import isclass, signature
import functools
from fastapi import FastAPI



In [None]:
""" 
class AutoHookedRootModule(HookedRootModule):
    '''
    This class automatically builds hooks for every submodule that is not itself a hook.
    NOTE: this does not mean every edge in the graph is hooked, only that module outputs are hooked.
    For instance, torch.softmax(x) is not hooked, but self.softmax(x) would be.
    ''' """

In [None]:

class ModelTest(nn.Module):
    '''Toy model (a Linear inside a ModuleList followed by a bare Linear) used to exercise the hooking utilities.'''
    def __init__(self):
        super().__init__()
        self.bla = nn.ModuleList([nn.Linear(10, 10)])
        self.lala = nn.Linear(10, 10)

    def forward(self, x):
        # AutoHookedRootModule is only referenced in a commented-out cell; look it up lazily
        # so forward() does not raise NameError on a fresh kernel when it is not defined.
        auto_hooked_cls = globals().get('AutoHookedRootModule')
        if auto_hooked_cls is not None and isinstance(self, auto_hooked_cls):
            print(f'{self.__class__.__name__}.mod_dict', self.mod_dict)
            print(self.bla[0], self.bla[0].hook_dict)
        # Call the submodules (not `.forward`) so registered forward hooks actually fire.
        x = self.bla[0](x)
        x = self.lala(x)
        return x

In [None]:

T = TypeVar('T', bound=nn.Module)
P = ParamSpec('P')
_T = TypeVar("_T", bound=Callable)


def same_definition_as_in(t: _T) -> Callable[[Callable], _T]:
    def decorator(f: Callable) -> _T:
        return f  # type: ignore

    return decorator

class MyFastAPI(FastAPI):
    '''FastAPI subclass that prints "get" whenever ``get`` is called.

    Both overrides forward their arguments unchanged to FastAPI; the
    ``same_definition_as_in`` decorators let type checkers see FastAPI's own
    signatures instead of ``*args, **kwargs``.
    '''

    @same_definition_as_in(FastAPI.__init__)
    def __init__(self, *args, **kwargs):
        super(MyFastAPI, self).__init__(*args, **kwargs)

    @same_definition_as_in(FastAPI.get)
    def get(self, *args, **kwargs):
        print('get')
        route_decorator = super(MyFastAPI, self).get(*args, **kwargs)
        return route_decorator

In [118]:

T = TypeVar('T', bound=nn.Module)
P = ParamSpec('P')
R = TypeVar('R')

class WrappedInstance(HookedRootModule, Generic[T]):
    '''
    Wrap an ``nn.Module`` instance so that its output is routed through a ``HookPoint``.

    The wrapped module is stored as the submodule ``_module`` and the hook as
    ``hook_point``. ``setup()`` (inherited from ``HookedRootModule``) is called last,
    after both are registered.
    '''
    def __init__(self, module: T):
        super().__init__()
        self._module = module
        self.hook_point = HookPoint()
        self._create_forward()
        self.setup()

    def unwrap(self) -> T:
        '''Return the original, unwrapped module.'''
        return self._module

    def _create_forward(self):
        '''
        Install a per-instance ``forward`` that calls the wrapped module's forward and
        passes its output through ``self.hook_point``.

        The function is stored on the instance rather than the class: assigning it to
        ``WrappedInstance`` would make every wrapper share the forward (and hook point)
        of the most recently created wrapper.
        NOTE(review): deepcopy will share this closure with the source object; confirm
        whether wrappers need to support deepcopy.
        '''
        original_forward = self._module.forward
        original_type_hints = get_type_hints(original_forward)

        @functools.wraps(original_forward)
        def new_forward(*args: Any, **kwargs: Any) -> Any:
            # Stored on the instance, so no `self` is bound into args.
            return self.hook_point(original_forward(*args, **kwargs))

        # Expose the wrapped forward's resolved annotations on the replacement.
        new_forward.__annotations__ = original_type_hints
        # nn.Module.__setattr__ stores plain callables in the instance __dict__, which
        # takes precedence over the class-level `forward` used by nn.Module.__call__.
        self.forward = new_forward  # type: ignore[method-assign]

class WrappedClass(Generic[T]):
    '''
    Stand-in for an ``nn.Module`` subclass whose instances come back wrapped.

    Calling it constructs ``module_class(*args, **kwargs)`` and passes the result
    through ``auto_wrap``. Any other attribute lookup is forwarded to ``module_class``.
    '''
    def __init__(self, module_class: Type[T]) -> None:
        self.module_class = module_class

    def __call__(self, *args: Any, **kwargs: Any) -> WrappedInstance[T]:
        instance = self.module_class(*args, **kwargs)
        return auto_wrap(instance)

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails. If `module_class` itself is
        # missing (e.g. an object built via __new__ by copy/pickle), `self.module_class`
        # would re-enter __getattr__ forever; raise AttributeError instead.
        try:
            module_class = self.__dict__['module_class']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(module_class, name)

    def unwrap(self) -> Type[T]:
        '''Return the original, unwrapped module class.'''
        return self.module_class

@overload
def auto_wrap(module_or_class: Type[T]) -> WrappedClass[T]: ...

@overload
def auto_wrap(module_or_class: T) -> WrappedInstance[T]: ...

def auto_wrap(module_or_class: Union[T, Type[T]]) -> Union[WrappedInstance[T], WrappedClass[T]]:
    '''
    Wrap a module instance or a module class.

    - A class yields a ``WrappedClass`` whose calls produce wrapped instances.
    - An instance yields a ``WrappedInstance`` whose ``unwrap()`` returns exactly the
      object that was passed in.
    '''
    if not isclass(module_or_class):
        instance_wrapper = WrappedInstance(module_or_class)
        # unwrap() hands back the very object the caller supplied
        instance_wrapper.unwrap = lambda: module_or_class # type: ignore
        return cast(WrappedInstance[T], instance_wrapper)
    return WrappedClass(module_or_class)

# Smoke check: wrap a class and instantiate it, then wrap an existing instance.
# NOTE: the basic-tests cell below rebinds both of these names.
WrappedLinear1 = auto_wrap(nn.Linear)(1, 1)
WrappedLinear2 = auto_wrap(nn.Linear(10, 1))



In [124]:
#BASIC TESTS
WrappedLinear1 = auto_wrap(nn.Linear)
WrappedLinear2 = auto_wrap(nn.Linear(10,1))

def test_types():
    '''Class-wrapping then instantiating must give the same wrapper and unwrapped types as instance-wrapping.'''
    wrapped_linear1_instance = WrappedLinear1(1, 1)
    # Use identity (`is`) for exact type comparisons.
    assert type(wrapped_linear1_instance) is type(WrappedLinear2), f"{type(wrapped_linear1_instance)} != {type(WrappedLinear2)}"
    assert type(wrapped_linear1_instance.unwrap()) is type(WrappedLinear2.unwrap()), f"{type(wrapped_linear1_instance.unwrap())} != {type(WrappedLinear2.unwrap())}"
    assert WrappedLinear1.unwrap() is type(WrappedLinear2.unwrap()), f"{WrappedLinear1.unwrap()} != {type(WrappedLinear2.unwrap())}"

def test_type_hints():
    '''Wrapped forwards (instance- and class-wrapped) must expose nn.Linear.forward's type hints.'''
    linear_type_hints = get_type_hints(nn.Linear(10, 20).forward)
    wrapped_linear_instance_type_hints = get_type_hints(WrappedLinear2.forward)
    wrapped_linear_class_type_hints = get_type_hints(WrappedLinear1.forward)
    assert linear_type_hints == wrapped_linear_instance_type_hints, f"{linear_type_hints} != {wrapped_linear_instance_type_hints}"
    assert linear_type_hints == wrapped_linear_class_type_hints, f"{linear_type_hints} != {wrapped_linear_class_type_hints}"

test_types()
test_type_hints()
