Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generator patches (#235) #237

Merged
merged 21 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ ts/
docusaurus/video/docusaurus/shared
node_modules/.yarn-integrity
yarn.lock
openapi-gen/test/models/*
openapi-gen/test/*.py
42 changes: 42 additions & 0 deletions openapi-gen/builtin_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,51 @@ import (
"github.com/iancoleman/strcase"
)

type set map[string]struct{}

func newSet() set {
return make(set)
}

func (s set) Add(value string) bool {
if _, ok := s[value]; ok {
return false
}
s[value] = struct{}{}
return true
}

func (s set) Contains(value string) bool {
_, ok := s[value]
return ok
}

func toConstant(s string) string {
return strings.ToUpper(strings.ReplaceAll(s, "-", "_"))
}

func PrepareBuiltinFunctions(config *Config) template.FuncMap {
return template.FuncMap{
"refToName": refToName,
"toSnake": strcase.ToSnake,
"toCamel": strcase.ToCamel,
"lower": strings.ToLower,
"newSet": newSet,
"append": func(value []string, elems ...string) []string {
return append(value, elems...)
},
"list": func(value ...string) []string {
return value
},
"contains": func(value []string, elem string) bool {
for _, el := range value {
if el == elem {
return true
}
}
return false
},
"join": strings.Join,
"has": func(sl []string, str string) bool {
for _, s := range sl {
if s == str {
Expand All @@ -21,6 +61,8 @@ func PrepareBuiltinFunctions(config *Config) template.FuncMap {
}
return false
},
"toUpper": strings.ToUpper,
"toConstant": toConstant,
"successfulResponse": func(responses openapi3.Responses) *openapi3.SchemaRef {
for code, response := range responses {
if code == "200" || code == "201" {
Expand Down
3 changes: 2 additions & 1 deletion openapi-gen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
type Config struct {
AdditionalParameters map[string]any `yaml:"additionalParameters" json:"additionalParameters"`
// FileNameModifier is function name to be used for generating the file name.
FileNameModifier string `yaml:"fileNameModifier" json:"fileNameModifier"`
FileNameModifier string `yaml:"fileNameModifier" json:"fileNameModifier"`
// prefix imports with this string
FileExtension string `yaml:"fileExtension" json:"fileExtension"`
CopyAdditionalFiles []string `yaml:"copyAdditionalFiles" json:"copyAdditionalFiles"`
GenerateRequestTypes bool `yaml:"generateRequestTypes" json:"generateRequestTypes"`
Expand Down
86 changes: 59 additions & 27 deletions openapi-gen/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ func (t *TemplateLoader) LoadTemplate(kind templateKind) *template.Template {
}

type TypeContext struct {
Schema *openapi3.Schema
Name string
Additional map[string]any
References []string
Schema *openapi3.Schema
Name string
Additional map[string]any
References []string
HasNonRequired bool
}

type RequestContext struct {
Expand All @@ -71,6 +72,15 @@ type RequestContext struct {
Body *openapi3.RequestBodyRef
}

func findInSlice(val string, slice []string) bool {
for _, item := range slice {
if item == val {
return true
}
}
return false
}

// go run . -i ../openapi/video-openapi.yaml -o ./go-generated -l go
func main() {
inputFile := flag.String("i", "", "yaml file to use for generating code")
Expand Down Expand Up @@ -107,6 +117,8 @@ func main() {
os.Exit(1)
}

// we have invalid spec according to this validation
// for example, we use extensions with reference
// err = doc.Validate(context.Background())
// if err != nil {
// fmt.Println("error validating doc", err)
Expand Down Expand Up @@ -159,30 +171,35 @@ func main() {
fmt.Println("error checking models subpackage", err)
os.Exit(1)
}
} else {
modelsDir = *outputDir
}

for _, filePath := range config.ModelsCopyFiles {
filename := filepath.Base(filePath)
dst, err := os.Create(path.Join(modelsDir, filename))
if err != nil {
fmt.Println("error creating file", err)
os.Exit(1)
}
src, err := os.Open(path.Join("templates", *targetLanguage, filePath))
if err != nil {
fmt.Println("error opening file", err)
os.Exit(1)
}

_, err = io.Copy(dst, src)
if err != nil {
fmt.Println("error copying file", err)
os.Exit(1)
}
for _, filePath := range config.ModelsCopyFiles {
filename := filepath.Base(filePath)
dst, err := os.Create(path.Join(modelsDir, filename))
if err != nil {
fmt.Println("error creating file", err)
os.Exit(1)
}
src, err := os.Open(path.Join("templates", *targetLanguage, filePath))
if err != nil {
fmt.Println("error opening file", err)
os.Exit(1)
}

_, err = io.Copy(dst, src)
if err != nil {
fmt.Println("error copying file", err)
os.Exit(1)
}
}

for name, schema := range doc.Components.Schemas {
if len(schema.Value.Properties) == 0 && len(schema.Value.Enum) == 0 && schema.Value.OneOf == nil {
fmt.Println("skipping", name, "because it has no properties or OneOf definitions")
continue
}
ext := config.FileExtension
f, err := os.Create(path.Join(modelsDir, config.getNameModifier()(name)+ext))
if err != nil {
Expand All @@ -191,11 +208,20 @@ func main() {
}
defer f.Close()

hasNonRequired := false
for propName := range schema.Value.Properties {
if !findInSlice(propName, schema.Value.Required) {
hasNonRequired = true
break
}
}

err = tmpl.Execute(f, TypeContext{
Name: name,
Schema: schema.Value,
Additional: config.AdditionalParameters,
References: getReferencesFromTypes(schema.Value),
Name: name,
Schema: schema.Value,
Additional: config.AdditionalParameters,
References: getReferencesFromTypes(schema.Value),
HasNonRequired: hasNonRequired,
})

if err != nil {
Expand Down Expand Up @@ -247,7 +273,6 @@ func main() {
}
}
}

f, err := os.Create(path.Join(*outputDir, "client"+config.FileExtension))
if err != nil {
fmt.Println("error creating file", err)
Expand Down Expand Up @@ -364,5 +389,12 @@ func getReferencesFromTypes(schema *openapi3.Schema) []string {
}
}

if schema.AdditionalProperties.Schema != nil {
if schema.AdditionalProperties.Schema.Ref != "" {
refs = append(refs, schema.AdditionalProperties.Schema.Ref)
} else if schema.AdditionalProperties.Schema.Value != nil {
refs = append(refs, getReferencesFromTypes(schema.AdditionalProperties.Schema.Value)...)
}
}
return refs
}
8 changes: 8 additions & 0 deletions openapi-gen/templates/go/type.tmpl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package {{index .Additional "package"}}

import (
"time"
)

{{/* this make imports valid even if time package is not needed in this file */}}
var _ = time.Time{}

{{- define "generateObject"}} struct {
{{- $struct := . -}}
{{- range $key, $value := $struct.Properties}}
Expand All @@ -10,6 +17,7 @@ package {{index .Additional "package"}}
{{- end}}

{{- end}}
}
{{- end}}

{{- define "generateType"}}
Expand Down
6 changes: 4 additions & 2 deletions openapi-gen/templates/python/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from urllib.parse import quote

def build_path(path: str, path_params: dict) -> str:
for k, v in path_params:
path_params[k] = quote(v, safe='') # in case of special characters in the path. Known cases: chat message ids.
for k, v in path_params.items():
path_params[k] = quote(
v, safe=""
) # in case of special characters in the path. Known cases: chat message ids.

return path.format(**path_params)

Expand Down
34 changes: 32 additions & 2 deletions openapi-gen/templates/python/client.tmpl
Original file line number Diff line number Diff line change
@@ -1,11 +1,41 @@
from base_client import BaseClient, StreamResponse
from dataclasses import asdict
from typing import Optional
from datetime import datetime
{{- range $i, $ref := (clientReferences .Paths)}}
from models.{{refToName $ref | toSnake}} import {{refToName $ref | toCamel}}
from {{with index additionalParameters "modelImportPrefix"}}{{.}}{{else}}models.{{end}}models.{{refToName $ref | toSnake}} import {{refToName $ref | toCamel}}
{{- end}}

{{/* TODO: make class name configurable */ -}}
class {{with index additionalParameters "clientClassName"}}{{.}}{{else}}Client{{end}}(BaseClient):
def __init__(self, api_key: str, base_url, token, timeout, user_agent):
"""
Initializes VideoClient with BaseClient instance
:param api_key: A string representing the client's API key
:param base_url: A string representing the base uniform resource locator
:param token: A string instance representing the client's token
:param timeout: A number representing the time limit for a request
:param user_agent: A string representing the user agent
"""
super().__init__(
api_key=api_key,
base_url=base_url,
token=token,
timeout=timeout,
user_agent=user_agent,
)

def call(self, call_type: str, call_id: str):
"""
Returns instance of Call class
param call_type: A string representing the call type
:param call_id: A string representing a unique call identifier
:return: Instance of Call class
"""
return Call(self, call_type, call_id)

{{- range $path, $item := .Paths -}}{{- range $method, $operation := $item.Operations}}
{{template "requestFunction" (operationContext $operation $method $path)}}
{{- end -}}{{- end -}}
{{- end -}}

{{- end -}}
3 changes: 2 additions & 1 deletion openapi-gen/templates/python/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ modelsSubpackage: models
modelsCopyFiles:
- models/__init__.py
additionalParameters:
clientClassName: VideoClient
clientClassName: VideoClient
modelImportPrefix: getstream.
5 changes: 3 additions & 2 deletions openapi-gen/templates/python/request.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def {{toSnake .OperationID}}(self{{range $index, $param := (requiredParameters .
{{- end}}
{{- end}}

self.request("{{.Method}}", "{{.Path}}",
return self.{{.Method | lower}}("{{.Path}}",
{{with (successfulResponse .Responses)}} {{template "generateSchemaRef" .}},{{end}}
query_params=query_params,
path_params=path_params{{- if (requestSchema .Operation)}},
body=data{{end -}}
json=asdict(data){{end -}}
)

{{ end -}}
Loading