16
16
from contextlib import ContextDecorator , contextmanager , suppress
17
17
from dataclasses import dataclass , field
18
18
from enum import Enum , auto
19
- from functools import partialmethod , singledispatchmethod , wraps
19
+ from functools import (
20
+ partialmethod ,
21
+ singledispatchmethod ,
22
+ update_wrapper ,
23
+ wraps ,
24
+ )
20
25
from inspect import Signature , isclass
21
- from types import MappingProxyType , UnionType
26
+ from types import UnionType
22
27
from typing import (
23
28
Any ,
29
+ ClassVar ,
24
30
ContextManager ,
25
31
NamedTuple ,
26
32
NoReturn ,
@@ -169,7 +175,7 @@ def get_instance(self) -> _T:
169
175
class SingletonInjectable (BaseInjectable [_T ]):
170
176
__slots__ = ("__dict__" ,)
171
177
172
- __INSTANCE_KEY = "$instance"
178
+ __INSTANCE_KEY : ClassVar [ str ] = "$instance"
173
179
174
180
@property
175
181
def cache (self ) -> MutableMapping [str , Any ]:
@@ -405,21 +411,15 @@ def inject(
405
411
wrapped : Callable [..., Any ] = None ,
406
412
/ ,
407
413
* ,
408
- force : bool = False ,
409
414
return_factory : bool = False ,
410
415
):
411
416
def decorator (wp ):
412
417
if not return_factory and isclass (wp ):
413
- wp .__init__ = self .inject (wp .__init__ , force = force )
418
+ wp .__init__ = self .inject (wp .__init__ )
414
419
return wp
415
420
416
- lazy_binder = Lazy [Binder ](lambda : self .__new_binder (wp ))
417
-
418
- @wraps (wp )
419
- def wrapper (* args , ** kwargs ):
420
- arguments = (~ lazy_binder ).bind (args , kwargs , force )
421
- return wp (* arguments .args , ** arguments .kwargs )
422
-
421
+ wrapper = InjectedFunction (wp ).update (self )
422
+ self .add_listener (wrapper )
423
423
return wrapper
424
424
425
425
return decorator (wrapped ) if wrapped else decorator
@@ -535,21 +535,15 @@ def __move_module(self, module: Module, priority: ModulePriority):
535
535
f"`{ module } ` can't be found in the modules used by `{ self } `."
536
536
) from exc
537
537
538
- def __new_binder (self , target : Callable [..., Any ]) -> Binder :
539
- signature = inspect .signature (target , eval_str = True )
540
- binder = Binder (signature ).update (self )
541
- self .add_listener (binder )
542
- return binder
543
-
544
538
545
539
"""
546
- Binder
540
+ InjectedFunction
547
541
"""
548
542
549
543
550
544
@dataclass (repr = False , frozen = True , slots = True )
551
545
class Dependencies :
552
- mapping : MappingProxyType [str , Injectable ]
546
+ mapping : Mapping [str , Injectable ]
553
547
554
548
def __bool__ (self ) -> bool :
555
549
return bool (self .mapping )
@@ -558,75 +552,137 @@ def __iter__(self) -> Iterator[tuple[str, Any]]:
558
552
for name , injectable in self .mapping .items ():
559
553
yield name , injectable .get_instance ()
560
554
555
+ @property
556
+ def are_resolved (self ) -> bool :
557
+ if isinstance (self .mapping , LazyMapping ) and not self .mapping .is_set :
558
+ return False
559
+
560
+ return bool (self )
561
+
561
562
@property
562
563
def arguments (self ) -> OrderedDict [str , Any ]:
563
564
return OrderedDict (self )
564
565
565
566
@classmethod
566
567
def from_mapping (cls , mapping : Mapping [str , Injectable ]):
567
- return cls (mapping = MappingProxyType ( mapping ) )
568
+ return cls (mapping = mapping )
568
569
569
570
@classmethod
570
571
def empty (cls ):
571
572
return cls .from_mapping ({})
572
573
573
574
@classmethod
574
- def resolve (cls , signature : Signature , module : Module ):
575
- dependencies = LazyMapping (cls .__resolver (signature , module ))
575
+ def resolve (cls , signature : Signature , module : Module , owner : type = None ):
576
+ dependencies = LazyMapping (cls .__resolver (signature , module , owner ))
576
577
return cls .from_mapping (dependencies )
577
578
578
579
@classmethod
579
580
def __resolver (
580
581
cls ,
581
582
signature : Signature ,
582
583
module : Module ,
584
+ owner : type = None ,
583
585
) -> Iterator [tuple [str , Injectable ]]:
584
- for name , parameter in signature . parameters . items ( ):
586
+ for name , annotation in cls . __get_annotations ( signature , owner ):
585
587
try :
586
- injectable = module [parameter . annotation ]
588
+ injectable = module [annotation ]
587
589
except KeyError :
588
590
continue
589
591
590
592
yield name , injectable
591
593
594
+ @staticmethod
595
+ def __get_annotations (
596
+ signature : Signature ,
597
+ owner : type = None ,
598
+ ) -> Iterator [tuple [str , type | Any ]]:
599
+ parameters = iter (signature .parameters .items ())
600
+
601
+ if owner :
602
+ name , _ = next (parameters )
603
+ yield name , owner
604
+
605
+ for name , parameter in parameters :
606
+ yield name , parameter .annotation
607
+
592
608
593
609
class Arguments (NamedTuple ):
594
610
args : Iterable [Any ]
595
611
kwargs : Mapping [str , Any ]
596
612
597
613
598
- class Binder (EventListener ):
599
- __slots__ = ("__signature" , "__dependencies" )
614
+ class InjectedFunction (EventListener ):
615
+ __slots__ = ("__dict__" , "__wrapper" , "__dependencies" , "__owner" )
616
+
617
+ def __init__ (self , wrapped : Callable [..., Any ], / ):
618
+ update_wrapper (self , wrapped )
619
+ self .__signature__ = Lazy [Signature ](
620
+ lambda : inspect .signature (wrapped , eval_str = True )
621
+ )
622
+
623
+ @wraps (wrapped )
624
+ def wrapper (* args , ** kwargs ):
625
+ args , kwargs = self .bind (args , kwargs )
626
+ return wrapped (* args , ** kwargs )
600
627
601
- def __init__ (self , signature : Signature ):
602
- self .__signature = signature
628
+ self .__wrapper = wrapper
603
629
self .__dependencies = Dependencies .empty ()
630
+ self .__owner = None
631
+
632
+ def __repr__ (self ) -> str :
633
+ return repr (self .__wrapper )
634
+
635
+ def __str__ (self ) -> str :
636
+ return str (self .__wrapper )
637
+
638
+ def __call__ (self , / , * args , ** kwargs ) -> Any :
639
+ return self .__wrapper (* args , ** kwargs )
640
+
641
+ def __get__ (self , instance : object | None , owner : type ):
642
+ if instance is None :
643
+ return self
644
+
645
+ return self .__wrapper .__get__ (instance , owner )
646
+
647
+ def __set_name__ (self , owner : type , name : str ):
648
+ if self .__dependencies .are_resolved :
649
+ raise TypeError (
650
+ "`__set_name__` is called after dependencies have been resolved."
651
+ )
652
+
653
+ if self .__owner :
654
+ raise TypeError ("Function owner is already defined." )
655
+
656
+ self .__owner = owner
657
+
658
+ @property
659
+ def signature (self ) -> Signature :
660
+ return self .__signature__ ()
604
661
605
662
def bind (
606
663
self ,
607
664
args : Iterable [Any ] = (),
608
665
kwargs : Mapping [str , Any ] = None ,
609
- force : bool = False ,
610
666
) -> Arguments :
611
667
if kwargs is None :
612
668
kwargs = {}
613
669
614
670
if not self .__dependencies :
615
671
return Arguments (args , kwargs )
616
672
617
- bound = self .__signature .bind_partial (* args , ** kwargs )
673
+ bound = self .signature .bind_partial (* args , ** kwargs )
618
674
dependencies = self .__dependencies .arguments
619
-
620
- if force :
621
- bound .arguments |= dependencies
622
- else :
623
- bound .arguments = dependencies | bound .arguments
675
+ bound .arguments = dependencies | bound .arguments
624
676
625
677
return Arguments (bound .args , bound .kwargs )
626
678
627
679
def update (self , module : Module ):
628
680
with thread_lock :
629
- self .__dependencies = Dependencies .resolve (self .__signature , module )
681
+ self .__dependencies = Dependencies .resolve (
682
+ self .signature ,
683
+ module ,
684
+ self .__owner ,
685
+ )
630
686
631
687
return self
632
688
0 commit comments