diff --git a/template.go b/template.go index 29c785c..371619f 100644 --- a/template.go +++ b/template.go @@ -74,9 +74,9 @@ import ( // Template represents a template. type Template struct { options interp.Options - use []interp.Exports - templateReader io.Reader + use interp.Exports imports importSymbols + templateReader io.Reader StartTokens []rune EndTokens []rune interp *interp.Interpreter @@ -107,21 +107,10 @@ func New( use ...interp.Exports) (*Template, error) { t := &Template{ options: options, - use: make([]interp.Exports, len(use)), + use: mergeExports(use...), StartTokens: []rune("<$"), EndTokens: []rune("$>"), } - - // copy use so we can be sure not to modify them - for i := range use { - t.use[i] = make(interp.Exports) - for packageName, funcMap := range use[i] { - t.use[i][packageName] = make(map[string]reflect.Value) - for funcName, funcReference := range funcMap { - t.use[i][packageName][funcName] = funcReference - } - } - } return t, nil } @@ -178,8 +167,18 @@ func (t *Template) LazyParse(reader io.Reader) error { t.interp = interp.New(t.options) - for i := 0; i < len(t.use); i++ { - t.interp.Use(t.use[i]) + // if we already have some uses + // use them + if len(t.use) != 0 { + t.interp.Use(t.use) + } + + // if we already have some imports + // import them + if len(t.imports) != 0 { + if _, err := t.safeEval(t.imports.ImportBlock()); err != nil { + return err + } } // import fmt @@ -436,9 +435,6 @@ func (*Template) hasPackage(s string) (bool, error) { // Import imports the specified imports to the interpreter. func (t *Template) Import(imports ...Import) error { - if t.interp == nil { - return errors.New("template must be parsed before Import can be used") - } var symbolsToImport importSymbols for _, symbol := range imports { if !t.imports.Contains(symbol) { @@ -450,8 +446,10 @@ func (t *Template) Import(imports ...Import) error { return nil } - if _, err := t.safeEval(symbolsToImport.ImportBlock()); err != nil { - return err + if t.interp != nil { // if we have an interpreter, import right now + if _, err := t.safeEval(symbolsToImport.ImportBlock()); err != nil { + return err + } } t.imports = append(t.imports, symbolsToImport...) return nil @@ -464,3 +462,53 @@ func (t *Template) MustImport(imports ...Import) *Template { } return t } + +// Use loads binary runtime symbols in the interpreter context so +// they can be used in interpreted code. +func (t *Template) Use(values ...interp.Exports) error { + return t.useExports(mergeExports(values...)) + +} + +func (t *Template) useExports(values interp.Exports) error { + if len(values) == 0 { + return nil + } + + t.use = mergeExports(t.use, values) + // if we have an interpreter, use right now + if t.interp != nil { + t.interp.Use(t.use) + } + return nil +} + +// MustUse is like Use, except it panics on failure. +func (t *Template) MustUse(values ...interp.Exports) *Template { + if err := t.Use(values...); err != nil { + panic(err) + } + return t +} + +func mergeExports(values ...interp.Exports) interp.Exports { + result := make(map[string]*map[string]reflect.Value) + for i := range values { + for packageName, funcMap := range values[i] { + existingFuncMap, ok := result[packageName] + if !ok { + m := make(map[string]reflect.Value) + existingFuncMap = &m + result[packageName] = existingFuncMap + } + for funcName, funcReference := range funcMap { + (*existingFuncMap)[funcName] = funcReference + } + } + } + r := make(interp.Exports, len(result)) + for s, m := range result { + r[s] = *m + } + return r +} diff --git a/template_test.go b/template_test.go index 7d28164..6502f7a 100644 --- a/template_test.go +++ b/template_test.go @@ -2,6 +2,7 @@ package yaegi_template import ( "io" + "reflect" "testing" "bytes" @@ -507,7 +508,7 @@ func TestTemplate_ExecToNilWriter(t *testing.T) { func TestTemplate_Import(t *testing.T) { t.Run("single", func(t *testing.T) { tm := MustNew(DefaultOptions(), DefaultSymbols()...). - MustLazyParse(bytes.NewReader([]byte(`Hello <$ fmt.Print(http.StatusOK) $>`))). + MustParseString(`Hello <$ fmt.Print(http.StatusOK) $>`). MustImport(Import{ Path: "net/http", }) @@ -517,7 +518,7 @@ func TestTemplate_Import(t *testing.T) { }) t.Run("double import", func(t *testing.T) { tm := MustNew(DefaultOptions(), DefaultSymbols()...). - MustLazyParse(bytes.NewReader([]byte(`Hello <$ fmt.Print(http.StatusOK) $>`))). + MustParseString(`Hello <$ fmt.Print(http.StatusOK) $>`). MustImport(Import{ Path: "net/http", }). @@ -530,7 +531,7 @@ func TestTemplate_Import(t *testing.T) { }) t.Run("alias import", func(t *testing.T) { tm := MustNew(DefaultOptions(), DefaultSymbols()...). - MustLazyParse(bytes.NewReader([]byte(`Hello <$ fmt.Print(h.StatusOK) $>`))). + MustParseString(`Hello <$ fmt.Print(h.StatusOK) $>`). MustImport(Import{ Name: "h", Path: "net/http", @@ -543,10 +544,129 @@ func TestTemplate_Import(t *testing.T) { require.Equal(t, "Hello 200", buf.String()) }) t.Run("import before parse", func(t *testing.T) { - err := MustNew(DefaultOptions(), DefaultSymbols()...). - Import(Import{ + tm := MustNew(DefaultOptions(), DefaultSymbols()...). + MustImport(Import{ Path: "net/http", - }) - require.EqualError(t, err, "template must be parsed before Import can be used") + }). + MustParseString(`Hello <$ fmt.Print(http.StatusOK) $>`) + var buf bytes.Buffer + tm.MustExec(&buf, nil) + require.Equal(t, "Hello 200", buf.String()) }) } + +func TestTemplateWithAdditionalSymbols(t *testing.T) { + t.Run("using New+Import", func(t *testing.T) { + t.Run("separate namespace", func(t *testing.T) { + var buf bytes.Buffer + MustNew( + DefaultOptions(), + append(DefaultSymbols(), interp.Exports{ + "ext": map[string]reflect.Value{ + "Foo": reflect.ValueOf(func() string { + return "foo" + }), + }, + })...). + MustImport(Import{ + Path: "ext", + }). + MustParseString(`Hello <$ fmt.Print(ext.Foo()) $>`). + MustExec(&buf, nil) + require.Equal(t, "Hello foo", buf.String()) + }) + t.Run("in own namespace (dot import)", func(t *testing.T) { + var buf bytes.Buffer + MustNew( + DefaultOptions(), + append(DefaultSymbols(), interp.Exports{ + "ext": map[string]reflect.Value{ + "Foo": reflect.ValueOf(func() string { + return "foo" + }), + }, + })...). + MustImport(Import{ + Name: ".", + Path: "ext", + }). + MustParseString(`Hello <$ fmt.Print(Foo()) $>`). + MustExec(&buf, nil) + require.Equal(t, "Hello foo", buf.String()) + }) + t.Run("in own namespace (private dot import)", func(t *testing.T) { + var buf bytes.Buffer + MustNew( + DefaultOptions(), + append(DefaultSymbols(), interp.Exports{ + "ext": map[string]reflect.Value{ + "foo": reflect.ValueOf(func() string { + return "foo" + }), + }, + })...). + MustImport(Import{ + Name: ".", + Path: "ext", + }). + MustParseString(`Hello <$ fmt.Print(foo()) $>`). + MustExec(&buf, nil) + require.Equal(t, "Hello foo", buf.String()) + }) + }) + + t.Run("Use() func", func(t *testing.T) { + t.Run("separate namespace", func(t *testing.T) { + var buf bytes.Buffer + MustNew( + DefaultOptions(), + DefaultSymbols()...). + MustUse(interp.Exports{ + "ext": map[string]reflect.Value{ + "Foo": reflect.ValueOf(func() string { + return "foo" + }), + }, + }). + MustImport(Import{Path: "ext"}). + MustParseString(`Hello <$ fmt.Print(ext.Foo()) $>`). + MustExec(&buf, nil) + require.Equal(t, "Hello foo", buf.String()) + }) + t.Run("in own namespace (dot import)", func(t *testing.T) { + var buf bytes.Buffer + MustNew( + DefaultOptions(), + DefaultSymbols()...). + MustUse(interp.Exports{ + "ext": map[string]reflect.Value{ + "Foo": reflect.ValueOf(func() string { + return "foo" + }), + }, + }). + MustImport(Import{Name: ".", Path: "ext"}). + MustParseString(`Hello <$ fmt.Print(Foo()) $>`). + MustExec(&buf, nil) + require.Equal(t, "Hello foo", buf.String()) + }) + t.Run("in own namespace (private dot import)", func(t *testing.T) { + var buf bytes.Buffer + MustNew( + DefaultOptions(), + DefaultSymbols()...). + MustUse(interp.Exports{ + "ext": map[string]reflect.Value{ + "foo": reflect.ValueOf(func() string { + return "foo" + }), + }, + }). + MustImport(Import{Name: ".", Path: "ext"}). + MustParseString(`Hello <$ fmt.Print(foo()) $>`). + MustExec(&buf, nil) + require.Equal(t, "Hello foo", buf.String()) + }) + }) + +}