diff --git a/src/go/constant.ts b/src/go/constant.ts index 7ac39d8..32f8553 100644 --- a/src/go/constant.ts +++ b/src/go/constant.ts @@ -34,20 +34,24 @@ export const translations: Map = new Map([ export const IMPORTS = { context: "context", + embed: "embed", errors: "errors", json: "encoding/json", net: "net", os: "os", + http: "net/http", fiber: "github.com/gofiber/fiber/v2", - tfiber: "github.com/apexlang/api-go/transport/tfiber", - httpresponse: "github.com/apexlang/api-go/transport/httpresponse", emptypb: "google.golang.org/protobuf/types/known/emptypb", - errorz: "github.com/apexlang/api-go/errorz", - convert: "github.com/apexlang/api-go/convert", timestamppb: "google.golang.org/protobuf/types/known/timestamppb", wrapperspb: "google.golang.org/protobuf/types/known/wrapperspb", grpc: "google.golang.org/grpc", + convert: "github.com/apexlang/api-go/convert", + errorz: "github.com/apexlang/api-go/errorz", + httpresponse: "github.com/apexlang/api-go/transport/httpresponse", + authorization: "github.com/apexlang/api-go/transport/authorization", + tfiber: "github.com/apexlang/api-go/transport/tfiber", tgrpc: "github.com/apexlang/api-go/transport/tgrpc", + thttp: "github.com/apexlang/api-go/transport/thttp", msgpack: "github.com/wapc/tinygo-msgpack", msgpackconvert: ["convert", "github.com/wapc/tinygo-msgpack/convert"], zap: "go.uber.org/zap", diff --git a/src/go/embed_visitor.ts b/src/go/embed_visitor.ts new file mode 100644 index 0000000..11542cc --- /dev/null +++ b/src/go/embed_visitor.ts @@ -0,0 +1,24 @@ +import { Context } from "../../deps/@apexlang/core/model/mod.ts"; +import { getImports, GoVisitor } from "./go_visitor.ts"; + +interface Embed { + path: string; + var: string; + type: string; +} + +export class EmbedVisitor extends GoVisitor { + public override visitInterfacesBefore(context: Context): void { + const importer = getImports(context); + const config = context.config.embed as Embed[]; + + if (config && config.length && config.length > 0) { + importer.stdlib("embed", "_"); + config.forEach((value) => { + this.write(`//go:embed ${value.path} +var ${value.var} ${value.type}\n`); + }); + this.write(`\n`); + } + } +} diff --git a/src/go/mod.ts b/src/go/mod.ts index 39f6a6c..9ab2f8f 100644 --- a/src/go/mod.ts +++ b/src/go/mod.ts @@ -17,6 +17,7 @@ limitations under the License. export * from "./go_visitor.ts"; export * from "./alias_visitor.ts"; export * from "./constant.ts"; +export * from "./embed_visitor.ts"; export * from "./enum_visitor.ts"; export * from "./fiber_visitor.ts"; export * from "./grpc_visitor.ts"; @@ -26,6 +27,7 @@ export * from "./interfaces_visitor.ts"; export * from "./main_visitor.ts"; export * from "./scaffold_visitor.ts"; export * from "./struct_visitor.ts"; +export * from "./servemux_visitor.ts"; export * from "./union_visitor.ts"; export * from "./msgpack_visitor.ts"; export * from "./msgpack_constants.ts"; diff --git a/src/go/servemux_visitor.ts b/src/go/servemux_visitor.ts new file mode 100644 index 0000000..942a545 --- /dev/null +++ b/src/go/servemux_visitor.ts @@ -0,0 +1,186 @@ +import { + Annotation, + AnyType, + Context, + Kind, + Type, +} from "../../deps/@apexlang/core/model/mod.ts"; +import { + capitalize, + convertOperationToType, + isKinds, + isObject, + isService, + unwrapKinds, +} from "../utils/mod.ts"; +import { getMethods, getPath, hasBody, ScopesDirective } from "../rest/mod.ts"; +import { StructVisitor } from "./struct_visitor.ts"; +import { expandType, fieldName, methodName } from "./helpers.ts"; +import { translateAlias } from "./alias_visitor.ts"; +import { getImporter, GoVisitor } from "./go_visitor.ts"; +import { IMPORTS } from "./constant.ts"; + +export class ServeMuxVisitor extends GoVisitor { + public override visitInterfaceBefore(context: Context): void { + if (!isService(context)) { + return; + } + + const { interface: iface } = context; + const visitor = new ServeMuxServiceVisitor(this.writer); + iface.accept(context, visitor); + } +} + +class ServeMuxServiceVisitor extends GoVisitor { + public override visitInterfaceBefore(context: Context): void { + const { interface: iface } = context; + const $ = getImporter(context, IMPORTS); + this + .write( + `func ${iface.name}ServeMux(service ${iface.name}) func(*${$.http}.ServeMux) { + return func(mux *${$.http}.ServeMux) {\n`, + ); + } + + public override visitOperation(context: Context): void { + const { interface: iface, operation } = context; + const $ = getImporter(context, IMPORTS); + const path = getPath(context); + if (path == "") { + return; + } + const methods = getMethods(operation).map((m) => + capitalize(m.toLowerCase()) + ); + const translate = translateAlias(context); + + let scopes: string[] = []; + iface.annotation("scopes", (a) => { + scopes = getScopes(a); + }); + // Operation scopes override interface scopes + operation.annotation("scopes", (a) => { + scopes = getScopes(a); + }); + + methods.forEach((method) => { + let paramType: AnyType | undefined; + this.write( + `mux.HandleFunc("${method.toUpperCase()} ${path}", func(w ${$.http}.ResponseWriter, r *${$.http}.Request) {\n`, + ); + + if (scopes.length > 0) { + this.write( + `if err := ${$.authorization}.CheckScopes(r.Context(), "write:clusters"); err != nil { + ${$.thttp}.Error(w, nil, err, ${$.errorz}.PermissionDenied) + return +}\n`, + ); + } + + this.write(`resp := ${$.httpresponse}.New() + ctx := ${$.httpresponse}.NewContext(r.Context(), resp)\n`); + if (operation.isUnary()) { + // TODO: check type + paramType = operation.parameters[0].type; + } else if (operation.parameters.length > 0) { + const argsType = convertOperationToType( + context.getType.bind(context), + iface, + operation, + ); + paramType = argsType; + const structVisitor = new StructVisitor(this.writer); + argsType.accept(context.clone({ type: argsType }), structVisitor); + } + + const operMethod = methodName(operation, operation.name); + + if (paramType) { + // TODO + this.write( + `var args ${expandType(paramType, undefined, false, translate)}\n`, + ); + if (hasBody(method)) { + this.write( + `if err := ${$.json}.NewDecoder(r.Body).Decode(&args); err != nil { + ${$.thttp}.Error(w, resp, err, ${$.errorz}.Internal) + return + }\n`, + ); + } + + switch (paramType.kind) { + case Kind.Type: { + let foundQuery = false; + const t = paramType as Type; + t.fields.forEach((f) => { + if (path.indexOf(`{${f.name}}`) != -1) { + // Set path argument + this.write( + `args.${fieldName(f, f.name)} = r.PathValue("${f.name}")\n`, + ); + } else if (f.annotation("query") != undefined) { + if (!foundQuery) { + this.write(`query := r.URL.Query()\n`); + foundQuery = true; + } + this.write( + `args.${fieldName(f, f.name)} = query.Get("${f.name}")\n`, + ); + } + }); + + break; + } + } + + if (operation.type.kind != Kind.Void) { + this.write(`result, `); + } + if (operation.isUnary()) { + const pt = unwrapKinds(paramType, Kind.Alias); + const share = isKinds(pt, Kind.Primitive, Kind.Enum) ? "" : "&"; + this.write(`err := service.${operMethod}(ctx, ${share}args)\n`); + } else { + const args = (paramType as Type).fields + .map( + (f) => + `, ${isObject(f.type, false) ? "&" : ""}args.${ + fieldName( + f, + f.name, + ) + }`, + ) + .join(""); + this.write(`err := service.${operMethod}(ctx${args})\n`); + } + } else { + this.write(`err := service.${operMethod}(ctx)\n`); + } + + if (operation.type.kind != Kind.Void) { + this.write(`${$.thttp}.Response(w, resp, result, err)\n`); + } else { + this.write(`${$.thttp}.NoContent(w, resp, err)\n`); + } + this.write(`})\n`); + }); + } + + public override visitInterfaceAfter(_context: Context): void { + this.write(` } +}\n`); + } +} + +function getScopes(a: Annotation): string[] { + let scopes = a.convert().value; + // Convert single value to array + if (typeof scopes === "string") { + scopes = [scopes as string]; + } + return scopes || []; +} diff --git a/src/openapiv3/openapiv3.ts b/src/openapiv3/openapiv3.ts index 91923b4..ec0c962 100644 --- a/src/openapiv3/openapiv3.ts +++ b/src/openapiv3/openapiv3.ts @@ -17,20 +17,24 @@ limitations under the License. import { Alias, + Annotation, AnyType, BaseVisitor, Context, Field, + Interface, Kind, List as ListType, Map as MapType, Named, + Operation, Optional, Primitive, Type, Writer, } from "../../deps/@apexlang/core/model/mod.ts"; import { + ComponentsObject, Document, ExternalDocumentationObject, InfoObject, @@ -51,7 +55,7 @@ import { ExposedTypesVisitor, isService, } from "../utils/mod.ts"; -import { getPath, ResponseDirective } from "../rest/mod.ts"; +import { getPath, ResponseDirective, ScopesDirective } from "../rest/mod.ts"; type Method = "get" | "post" | "options" | "put" | "delete" | "patch"; @@ -104,6 +108,25 @@ const removeEmpty = (obj: any): any => { return newObj; }; +interface Config { + securitySchemes: SecurityScheme; +} + +interface SecurityScheme { + oauth2: OAuth2Scheme; +} + +interface OAuth2Scheme { + flows: OAuth2Flow[]; + hasRefreshURL: boolean; +} + +type OAuth2Flow = + | "implicit" + | "password" + | "clientCredentials" + | "authorizationCode"; + export class OpenAPIV3Visitor extends BaseVisitor { private root: Mutable = { openapi: "3.0.3", @@ -137,7 +160,81 @@ export class OpenAPIV3Visitor extends BaseVisitor { public override visitNamespaceAfter(context: Context): void { const filename = context.config["$filename"]; this.root.paths = this.paths; - this.root.components = { schemas: this.schemas }; + + const components: Mutable = { + schemas: this.schemas, + }; + + const ns = context.namespace; + const config = context.config as Config; + if (config.securitySchemes && config.securitySchemes.oauth2) { + const allScopes: Set = new Set(); + Object.values(ns.interfaces).forEach((i: Interface) => { + i.annotation("scopes", (a: Annotation) => { + getScopes(a).forEach((v) => allScopes.add(v)); + }); + i.operations.forEach((o: Operation) => { + o.annotation("scopes", (a: Annotation) => { + getScopes(a).forEach((v) => allScopes.add(v)); + }); + }); + }); + const scopes: { [name: string]: string } = {}; + Array.from(allScopes).sort().forEach((v) => { + scopes[v] = v; + }); + const oauth2 = config.securitySchemes.oauth2; + const flows: { [name: string]: any } = {}; + oauth2.flows.forEach((flow) => { + switch (flow) { + case "implicit": + flows.implicit = { + authorizationUrl: "{{OAUTH_AUTHORIZATION_URL}}", + scopes: scopes, + }; + if (oauth2.hasRefreshURL) { + flows.implicit.refreshUrl = "{{OAUTH_REFRESH_URL}}"; + } + break; + case "password": + flows.password = { + tokenUrl: "{{OAUTH_ACCESS_TOKEN_URL}}", + scopes: scopes, + }; + if (oauth2.hasRefreshURL) { + flows.password.refreshUrl = "{{OAUTH_REFRESH_URL}}"; + } + break; + case "clientCredentials": + flows.clientCredentials = { + tokenUrl: "{{OAUTH_ACCESS_TOKEN_URL}}", + scopes: scopes, + }; + if (oauth2.hasRefreshURL) { + flows.clientCredentials.refreshUrl = "{{OAUTH_REFRESH_URL}}"; + } + break; + case "authorizationCode": + flows.authorizationCode = { + authorizationUrl: "{{OAUTH_AUTHORIZATION_URL}}", + tokenUrl: "{{OAUTH_ACCESS_TOKEN_URL}}", + scopes: scopes, + }; + if (oauth2.hasRefreshURL) { + flows.authorizationCode.refreshUrl = "{{OAUTH_REFRESH_URL}}"; + } + break; + } + }); + components.securitySchemes = { + oauth2: { + type: "oauth2", + flows: flows, + }, + }; + } + + this.root.components = components; const contents = removeEmpty(this.root); if (filename.toLowerCase().endsWith(".json")) { this.write(JSON.stringify(contents, null, 2)); @@ -206,6 +303,15 @@ export class OpenAPIV3Visitor extends BaseVisitor { summary = a.convert().value; }); + let scopes: string[] = []; + iface.annotation("scopes", (a) => { + scopes = getScopes(a); + }); + // Operation scopes override interface scopes + operation.annotation("scopes", (a) => { + scopes = getScopes(a); + }); + this.path = path; this.method = method; this.operation = { @@ -218,6 +324,9 @@ export class OpenAPIV3Visitor extends BaseVisitor { operation.annotation("deprecated", (_a) => { this.operation!.deprecated = true; }); + if (scopes.length > 0) { + this.operation!.security = [{ oauth2: scopes }]; + } pathItem[method] = this.operation; } @@ -709,3 +818,12 @@ const primitiveTypeMap = new Map([ ["any", {}], ["value", {}], ]); + +function getScopes(a: Annotation): string[] { + let scopes = a.convert().value; + // Convert single value to array + if (typeof scopes === "string") { + scopes = [scopes as string]; + } + return scopes || []; +} diff --git a/src/rest/mod.ts b/src/rest/mod.ts index 0958217..356b329 100644 --- a/src/rest/mod.ts +++ b/src/rest/mod.ts @@ -27,20 +27,18 @@ export interface ResponseDirective { examples?: { [k: string]: string }; } +export interface ScopesDirective { + value: string[]; +} + export function getPath(context: Context): string { const ns = context.namespace; - const inter = context.interface; const { interface: iface, operation } = context; let path = ""; ns.annotation("path", (a) => { path += a.convert().value; }); - if (inter) { - inter.annotation("path", (a) => { - path += a.convert().value; - }); - } if (iface) { iface.annotation("path", (a) => { path += a.convert().value;