Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions deep_inject_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package inject

import "testing"

// 测试深度注入的各种场景
func TestDeepInjectAdvanced(t *testing.T) {
// 场景1: 多层嵌套的手动注入实例
type Level3 struct {
Value string
}

type Level2 struct {
L3 *Level3 `inject:""`
}

type Level1 struct {
L2 *Level2 `inject:""`
}

type Root struct {
L1 *Level1 `inject:""`
}

var g Graph

// 手动创建嵌套结构
root := &Root{
L1: &Level1{
L2: &Level2{
L3: &Level3{Value: "manually created"},
},
},
}

// 提供根对象
if err := g.Provide(&Object{Value: root}); err != nil {
t.Fatal("failed to provide root:", err)
}

// 执行深度注入
if err := g.Populate(); err != nil {
t.Fatal("failed to populate:", err)
}

// 验证深度注入是否成功
if root.L1 == nil {
t.Fatal("root.L1 should not be nil")
}
if root.L1.L2 == nil {
t.Fatal("root.L1.L2 should not be nil")
}
if root.L1.L2.L3 == nil {
t.Fatal("root.L1.L2.L3 should not be nil")
}
if root.L1.L2.L3.Value != "manually created" {
t.Fatal("deep injected value should be preserved")
}
}

func TestDeepInjectWithMixedProvision(t *testing.T) {
// 场景2: 混合手动创建和依赖注入提供的实例
type ServiceA struct {
Name string
}

type ServiceB struct {
A *ServiceA `inject:""`
}

type ServiceC struct {
B *ServiceB `inject:""`
}

var g Graph

// 手动提供ServiceA
serviceA := &ServiceA{Name: "ServiceA"}
if err := g.Provide(&Object{Value: serviceA}); err != nil {
t.Fatal("failed to provide serviceA:", err)
}

// 手动创建包含部分依赖的ServiceC
serviceC := &ServiceC{
B: &ServiceB{}, // B没有A的依赖
}

if err := g.Provide(&Object{Value: serviceC}); err != nil {
t.Fatal("failed to provide serviceC:", err)
}

// 执行注入
if err := g.Populate(); err != nil {
t.Fatal("failed to populate:", err)
}

// 验证混合注入结果
if serviceC.B == nil {
t.Fatal("serviceC.B should not be nil")
}
if serviceC.B.A == nil {
t.Fatal("serviceC.B.A should be injected")
}
if serviceC.B.A != serviceA {
t.Fatal("serviceC.B.A should be the same instance as serviceA")
}
if serviceC.B.A.Name != "ServiceA" {
t.Fatal("injected service should preserve its properties")
}
}

func TestDeepInjectCircularDependency(t *testing.T) {
// 场景3: 测试循环依赖的处理(这应该能正常工作)
type CircularB struct {
A interface{} `inject:""`
}

type CircularA struct {
B *CircularB `inject:""`
}

var g Graph

// 手动创建一个带有循环依赖的结构
a := &CircularA{}
b := &CircularB{A: a}
a.B = b

// 提供到依赖图
if err := g.Provide(&Object{Value: a}); err != nil {
t.Fatal("failed to provide a:", err)
}

if err := g.Provide(&Object{Value: b}); err != nil {
t.Fatal("failed to provide b:", err)
}

// 执行注入
if err := g.Populate(); err != nil {
t.Fatal("failed to populate:", err)
}

// 验证循环依赖保持完整
if a.B != b {
t.Fatal("circular dependency should be preserved")
}
if b.A.(*CircularA) != a {
t.Fatal("circular dependency should be preserved")
}
}
68 changes: 60 additions & 8 deletions inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,56 @@ func (g *Graph) Provide(objects ...*Object) error {
return nil
}

// provideForDeepInject 专门用于深度注入时提供对象,允许同类型的多个实例
func (g *Graph) provideForDeepInject(o *Object) error {
o.reflectType = reflect.TypeOf(o.Value)
o.reflectValue = reflect.ValueOf(o.Value)

if o.Fields != nil {
return fmt.Errorf("fields were specified on object %v when it was provided", o)
}

if o.Name == "" {
if !isStructPtr(o.reflectType) {
return fmt.Errorf(
"expected unnamed object value to be a pointer to a struct but got type %s with value %v",
o.reflectType,
o.Value)
}

// 对于深度注入,我们不检查类型重复,直接添加到unnamed列表
// 但我们需要检查是否已经存在相同的实例(相同的指针)
for _, existing := range g.unnamed {
if existing.Value == o.Value {
// 相同的实例已存在,不需要重复添加
return nil
}
}

g.unnamed = append(g.unnamed, o)
} else {
if g.named == nil {
g.named = make(map[string]*Object)
}

if g.named[o.Name] != nil {
return fmt.Errorf("provided two instances named %s", o.Name)
}
g.named[o.Name] = o
}

if g.Logger != nil {
if o.created {
g.Logger.Info("created %v", o)
} else if o.embedded {
g.Logger.Info("provided embedded %v", o)
} else {
g.Logger.Info("provided %v for deep injection", o)
}
}
return nil
}

// Populate 填充不完整的对象
func (g *Graph) Populate() error {
for _, o := range g.named {
Expand Down Expand Up @@ -235,14 +285,16 @@ StructLoop:
private: false,
created: false,
}
if err := g.Provide(existingObject); err == nil {
// 递归填充现有对象的依赖(深度注入)
if err := g.populateExplicit(existingObject); err != nil {
return err
}
if g.Logger != nil {
g.Logger.Info("deep injected existing %v in field %s of %v", existingObject, o.reflectType.Elem().Field(i).Name, o)
}
// 对于深度注入,我们需要特殊处理类型重复的情况
if err := g.provideForDeepInject(existingObject); err != nil {
return fmt.Errorf("failed to provide existing object for deep injection: %v", err)
}
// 递归填充现有对象的依赖(深度注入)
if err := g.populateExplicit(existingObject); err != nil {
return err
}
if g.Logger != nil {
g.Logger.Info("deep injected existing %v in field %s of %v", existingObject, o.reflectType.Elem().Field(i).Name, o)
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,7 @@ type TypeAForTestDeepInject struct {
}

type TypeBForTestDeepInject struct {
A *TypeAForTestDeepInject `inject:""`
}

type TypeCForTestDeepInject struct {
Expand Down Expand Up @@ -1046,4 +1047,7 @@ func TestForDeepInject(t *testing.T) {
if d.A.C.B == nil {
t.Fatal("d.A.C.B is nil")
}
if b.A == nil {
t.Fatal("b.A is nil")
}
}
Loading