Skip to content
This repository has been archived by the owner on Jul 11, 2023. It is now read-only.

Commit

Permalink
No commit message
Browse files Browse the repository at this point in the history
  • Loading branch information
Eun committed Dec 7, 2020
1 parent 0ea20aa commit 128ed9b
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 28 deletions.
90 changes: 69 additions & 21 deletions template.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ import (
// Template represents a template.
type Template struct {
options interp.Options
use []interp.Exports
templateReader io.Reader
use interp.Exports
imports importSymbols
templateReader io.Reader
StartTokens []rune
EndTokens []rune
interp *interp.Interpreter
Expand Down Expand Up @@ -107,21 +107,10 @@ func New(
use ...interp.Exports) (*Template, error) {
t := &Template{
options: options,
use: make([]interp.Exports, len(use)),
use: mergeExports(use...),
StartTokens: []rune("<$"),
EndTokens: []rune("$>"),
}

// copy use so we can be sure not to modify them
for i := range use {
t.use[i] = make(interp.Exports)
for packageName, funcMap := range use[i] {
t.use[i][packageName] = make(map[string]reflect.Value)
for funcName, funcReference := range funcMap {
t.use[i][packageName][funcName] = funcReference
}
}
}
return t, nil
}

Expand Down Expand Up @@ -178,8 +167,18 @@ func (t *Template) LazyParse(reader io.Reader) error {

t.interp = interp.New(t.options)

for i := 0; i < len(t.use); i++ {
t.interp.Use(t.use[i])
// if we already have some uses
// use them
if len(t.use) != 0 {
t.interp.Use(t.use)
}

// if we already have some imports
// import them
if len(t.imports) != 0 {
if _, err := t.safeEval(t.imports.ImportBlock()); err != nil {
return err
}
}

// import fmt
Expand Down Expand Up @@ -436,9 +435,6 @@ func (*Template) hasPackage(s string) (bool, error) {

// Import imports the specified imports to the interpreter.
func (t *Template) Import(imports ...Import) error {
if t.interp == nil {
return errors.New("template must be parsed before Import can be used")
}
var symbolsToImport importSymbols
for _, symbol := range imports {
if !t.imports.Contains(symbol) {
Expand All @@ -450,8 +446,10 @@ func (t *Template) Import(imports ...Import) error {
return nil
}

if _, err := t.safeEval(symbolsToImport.ImportBlock()); err != nil {
return err
if t.interp != nil { // if we have an interpreter, import right now
if _, err := t.safeEval(symbolsToImport.ImportBlock()); err != nil {
return err
}
}
t.imports = append(t.imports, symbolsToImport...)
return nil
Expand All @@ -464,3 +462,53 @@ func (t *Template) MustImport(imports ...Import) *Template {
}
return t
}

// Use loads binary runtime symbols in the interpreter context so
// they can be used in interpreted code.
func (t *Template) Use(values ...interp.Exports) error {
return t.useExports(mergeExports(values...))

}

func (t *Template) useExports(values interp.Exports) error {
if len(values) == 0 {
return nil
}

t.use = mergeExports(t.use, values)
// if we have an interpreter, use right now
if t.interp != nil {
t.interp.Use(t.use)
}
return nil
}

// MustUse is like Use, except it panics on failure.
func (t *Template) MustUse(values ...interp.Exports) *Template {
if err := t.Use(values...); err != nil {
panic(err)
}
return t
}

func mergeExports(values ...interp.Exports) interp.Exports {
result := make(map[string]*map[string]reflect.Value)
for i := range values {
for packageName, funcMap := range values[i] {
existingFuncMap, ok := result[packageName]
if !ok {
m := make(map[string]reflect.Value)
existingFuncMap = &m
result[packageName] = existingFuncMap
}
for funcName, funcReference := range funcMap {
(*existingFuncMap)[funcName] = funcReference
}
}
}
r := make(interp.Exports, len(result))
for s, m := range result {
r[s] = *m
}
return r
}
134 changes: 127 additions & 7 deletions template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package yaegi_template

import (
"io"
"reflect"
"testing"

"bytes"
Expand Down Expand Up @@ -507,7 +508,7 @@ func TestTemplate_ExecToNilWriter(t *testing.T) {
func TestTemplate_Import(t *testing.T) {
t.Run("single", func(t *testing.T) {
tm := MustNew(DefaultOptions(), DefaultSymbols()...).
MustLazyParse(bytes.NewReader([]byte(`Hello <$ fmt.Print(http.StatusOK) $>`))).
MustParseString(`Hello <$ fmt.Print(http.StatusOK) $>`).
MustImport(Import{
Path: "net/http",
})
Expand All @@ -517,7 +518,7 @@ func TestTemplate_Import(t *testing.T) {
})
t.Run("double import", func(t *testing.T) {
tm := MustNew(DefaultOptions(), DefaultSymbols()...).
MustLazyParse(bytes.NewReader([]byte(`Hello <$ fmt.Print(http.StatusOK) $>`))).
MustParseString(`Hello <$ fmt.Print(http.StatusOK) $>`).
MustImport(Import{
Path: "net/http",
}).
Expand All @@ -530,7 +531,7 @@ func TestTemplate_Import(t *testing.T) {
})
t.Run("alias import", func(t *testing.T) {
tm := MustNew(DefaultOptions(), DefaultSymbols()...).
MustLazyParse(bytes.NewReader([]byte(`Hello <$ fmt.Print(h.StatusOK) $>`))).
MustParseString(`Hello <$ fmt.Print(h.StatusOK) $>`).
MustImport(Import{
Name: "h",
Path: "net/http",
Expand All @@ -543,10 +544,129 @@ func TestTemplate_Import(t *testing.T) {
require.Equal(t, "Hello 200", buf.String())
})
t.Run("import before parse", func(t *testing.T) {
err := MustNew(DefaultOptions(), DefaultSymbols()...).
Import(Import{
tm := MustNew(DefaultOptions(), DefaultSymbols()...).
MustImport(Import{
Path: "net/http",
})
require.EqualError(t, err, "template must be parsed before Import can be used")
}).
MustParseString(`Hello <$ fmt.Print(http.StatusOK) $>`)
var buf bytes.Buffer
tm.MustExec(&buf, nil)
require.Equal(t, "Hello 200", buf.String())
})
}

func TestTemplateWithAdditionalSymbols(t *testing.T) {
t.Run("using New+Import", func(t *testing.T) {
t.Run("separate namespace", func(t *testing.T) {
var buf bytes.Buffer
MustNew(
DefaultOptions(),
append(DefaultSymbols(), interp.Exports{
"ext": map[string]reflect.Value{
"Foo": reflect.ValueOf(func() string {
return "foo"
}),
},
})...).
MustImport(Import{
Path: "ext",
}).
MustParseString(`Hello <$ fmt.Print(ext.Foo()) $>`).
MustExec(&buf, nil)
require.Equal(t, "Hello foo", buf.String())
})
t.Run("in own namespace (dot import)", func(t *testing.T) {
var buf bytes.Buffer
MustNew(
DefaultOptions(),
append(DefaultSymbols(), interp.Exports{
"ext": map[string]reflect.Value{
"Foo": reflect.ValueOf(func() string {
return "foo"
}),
},
})...).
MustImport(Import{
Name: ".",
Path: "ext",
}).
MustParseString(`Hello <$ fmt.Print(Foo()) $>`).
MustExec(&buf, nil)
require.Equal(t, "Hello foo", buf.String())
})
t.Run("in own namespace (private dot import)", func(t *testing.T) {
var buf bytes.Buffer
MustNew(
DefaultOptions(),
append(DefaultSymbols(), interp.Exports{
"ext": map[string]reflect.Value{
"foo": reflect.ValueOf(func() string {
return "foo"
}),
},
})...).
MustImport(Import{
Name: ".",
Path: "ext",
}).
MustParseString(`Hello <$ fmt.Print(foo()) $>`).
MustExec(&buf, nil)
require.Equal(t, "Hello foo", buf.String())
})
})

t.Run("Use() func", func(t *testing.T) {
t.Run("separate namespace", func(t *testing.T) {
var buf bytes.Buffer
MustNew(
DefaultOptions(),
DefaultSymbols()...).
MustUse(interp.Exports{
"ext": map[string]reflect.Value{
"Foo": reflect.ValueOf(func() string {
return "foo"
}),
},
}).
MustImport(Import{Path: "ext"}).
MustParseString(`Hello <$ fmt.Print(ext.Foo()) $>`).
MustExec(&buf, nil)
require.Equal(t, "Hello foo", buf.String())
})
t.Run("in own namespace (dot import)", func(t *testing.T) {
var buf bytes.Buffer
MustNew(
DefaultOptions(),
DefaultSymbols()...).
MustUse(interp.Exports{
"ext": map[string]reflect.Value{
"Foo": reflect.ValueOf(func() string {
return "foo"
}),
},
}).
MustImport(Import{Name: ".", Path: "ext"}).
MustParseString(`Hello <$ fmt.Print(Foo()) $>`).
MustExec(&buf, nil)
require.Equal(t, "Hello foo", buf.String())
})
t.Run("in own namespace (private dot import)", func(t *testing.T) {
var buf bytes.Buffer
MustNew(
DefaultOptions(),
DefaultSymbols()...).
MustUse(interp.Exports{
"ext": map[string]reflect.Value{
"foo": reflect.ValueOf(func() string {
return "foo"
}),
},
}).
MustImport(Import{Name: ".", Path: "ext"}).
MustParseString(`Hello <$ fmt.Print(foo()) $>`).
MustExec(&buf, nil)
require.Equal(t, "Hello foo", buf.String())
})
})

}

0 comments on commit 128ed9b

Please sign in to comment.