diff --git a/deep_inject_test.go b/deep_inject_test.go new file mode 100644 index 0000000..7b38060 --- /dev/null +++ b/deep_inject_test.go @@ -0,0 +1,149 @@ +package inject + +import "testing" + +// 测试深度注入的各种场景 +func TestDeepInjectAdvanced(t *testing.T) { + // 场景1: 多层嵌套的手动注入实例 + type Level3 struct { + Value string + } + + type Level2 struct { + L3 *Level3 `inject:""` + } + + type Level1 struct { + L2 *Level2 `inject:""` + } + + type Root struct { + L1 *Level1 `inject:""` + } + + var g Graph + + // 手动创建嵌套结构 + root := &Root{ + L1: &Level1{ + L2: &Level2{ + L3: &Level3{Value: "manually created"}, + }, + }, + } + + // 提供根对象 + if err := g.Provide(&Object{Value: root}); err != nil { + t.Fatal("failed to provide root:", err) + } + + // 执行深度注入 + if err := g.Populate(); err != nil { + t.Fatal("failed to populate:", err) + } + + // 验证深度注入是否成功 + if root.L1 == nil { + t.Fatal("root.L1 should not be nil") + } + if root.L1.L2 == nil { + t.Fatal("root.L1.L2 should not be nil") + } + if root.L1.L2.L3 == nil { + t.Fatal("root.L1.L2.L3 should not be nil") + } + if root.L1.L2.L3.Value != "manually created" { + t.Fatal("deep injected value should be preserved") + } +} + +func TestDeepInjectWithMixedProvision(t *testing.T) { + // 场景2: 混合手动创建和依赖注入提供的实例 + type ServiceA struct { + Name string + } + + type ServiceB struct { + A *ServiceA `inject:""` + } + + type ServiceC struct { + B *ServiceB `inject:""` + } + + var g Graph + + // 手动提供ServiceA + serviceA := &ServiceA{Name: "ServiceA"} + if err := g.Provide(&Object{Value: serviceA}); err != nil { + t.Fatal("failed to provide serviceA:", err) + } + + // 手动创建包含部分依赖的ServiceC + serviceC := &ServiceC{ + B: &ServiceB{}, // B没有A的依赖 + } + + if err := g.Provide(&Object{Value: serviceC}); err != nil { + t.Fatal("failed to provide serviceC:", err) + } + + // 执行注入 + if err := g.Populate(); err != nil { + t.Fatal("failed to populate:", err) + } + + // 验证混合注入结果 + if serviceC.B == nil { + t.Fatal("serviceC.B should not be nil") + } + if serviceC.B.A == nil { + t.Fatal("serviceC.B.A should be injected") + } + if serviceC.B.A != serviceA { + t.Fatal("serviceC.B.A should be the same instance as serviceA") + } + if serviceC.B.A.Name != "ServiceA" { + t.Fatal("injected service should preserve its properties") + } +} + +func TestDeepInjectCircularDependency(t *testing.T) { + // 场景3: 测试循环依赖的处理(这应该能正常工作) + type CircularB struct { + A interface{} `inject:""` + } + + type CircularA struct { + B *CircularB `inject:""` + } + + var g Graph + + // 手动创建一个带有循环依赖的结构 + a := &CircularA{} + b := &CircularB{A: a} + a.B = b + + // 提供到依赖图 + if err := g.Provide(&Object{Value: a}); err != nil { + t.Fatal("failed to provide a:", err) + } + + if err := g.Provide(&Object{Value: b}); err != nil { + t.Fatal("failed to provide b:", err) + } + + // 执行注入 + if err := g.Populate(); err != nil { + t.Fatal("failed to populate:", err) + } + + // 验证循环依赖保持完整 + if a.B != b { + t.Fatal("circular dependency should be preserved") + } + if b.A.(*CircularA) != a { + t.Fatal("circular dependency should be preserved") + } +} diff --git a/inject.go b/inject.go index 215f8b6..e46db4f 100644 --- a/inject.go +++ b/inject.go @@ -114,6 +114,56 @@ func (g *Graph) Provide(objects ...*Object) error { return nil } +// provideForDeepInject 专门用于深度注入时提供对象,允许同类型的多个实例 +func (g *Graph) provideForDeepInject(o *Object) error { + o.reflectType = reflect.TypeOf(o.Value) + o.reflectValue = reflect.ValueOf(o.Value) + + if o.Fields != nil { + return fmt.Errorf("fields were specified on object %v when it was provided", o) + } + + if o.Name == "" { + if !isStructPtr(o.reflectType) { + return fmt.Errorf( + "expected unnamed object value to be a pointer to a struct but got type %s with value %v", + o.reflectType, + o.Value) + } + + // 对于深度注入,我们不检查类型重复,直接添加到unnamed列表 + // 但我们需要检查是否已经存在相同的实例(相同的指针) + for _, existing := range g.unnamed { + if existing.Value == o.Value { + // 相同的实例已存在,不需要重复添加 + return nil + } + } + + g.unnamed = append(g.unnamed, o) + } else { + if g.named == nil { + g.named = make(map[string]*Object) + } + + if g.named[o.Name] != nil { + return fmt.Errorf("provided two instances named %s", o.Name) + } + g.named[o.Name] = o + } + + if g.Logger != nil { + if o.created { + g.Logger.Info("created %v", o) + } else if o.embedded { + g.Logger.Info("provided embedded %v", o) + } else { + g.Logger.Info("provided %v for deep injection", o) + } + } + return nil +} + // Populate 填充不完整的对象 func (g *Graph) Populate() error { for _, o := range g.named { @@ -235,14 +285,16 @@ StructLoop: private: false, created: false, } - if err := g.Provide(existingObject); err == nil { - // 递归填充现有对象的依赖(深度注入) - if err := g.populateExplicit(existingObject); err != nil { - return err - } - if g.Logger != nil { - g.Logger.Info("deep injected existing %v in field %s of %v", existingObject, o.reflectType.Elem().Field(i).Name, o) - } + // 对于深度注入,我们需要特殊处理类型重复的情况 + if err := g.provideForDeepInject(existingObject); err != nil { + return fmt.Errorf("failed to provide existing object for deep injection: %v", err) + } + // 递归填充现有对象的依赖(深度注入) + if err := g.populateExplicit(existingObject); err != nil { + return err + } + if g.Logger != nil { + g.Logger.Info("deep injected existing %v in field %s of %v", existingObject, o.reflectType.Elem().Field(i).Name, o) } } } diff --git a/inject_test.go b/inject_test.go index 0976fbc..d2e422c 100644 --- a/inject_test.go +++ b/inject_test.go @@ -998,6 +998,7 @@ type TypeAForTestDeepInject struct { } type TypeBForTestDeepInject struct { + A *TypeAForTestDeepInject `inject:""` } type TypeCForTestDeepInject struct { @@ -1046,4 +1047,7 @@ func TestForDeepInject(t *testing.T) { if d.A.C.B == nil { t.Fatal("d.A.C.B is nil") } + if b.A == nil { + t.Fatal("b.A is nil") + } }