diff --git a/gen_methods.go b/gen_methods.go index cf552dcf..02c42af2 100755 --- a/gen_methods.go +++ b/gen_methods.go @@ -955,7 +955,7 @@ func (bot *Bot) DeleteMessage(chatId int64, messageId int64, opts *DeleteMessage // DeleteMyCommandsOpts is the set of optional fields for Bot.DeleteMyCommands. type DeleteMyCommandsOpts struct { // A JSON-serialized object, describing scope of users for which the commands are relevant. Defaults to BotCommandScopeDefault. - Scope *BotCommandScope + Scope BotCommandScope // A two-letter ISO 639-1 language code. If empty, commands will be applied to all users from the given scope, for whose language there are no dedicated commands LanguageCode string // RequestOpts are an additional optional field to configure timeouts for individual requests @@ -969,13 +969,11 @@ type DeleteMyCommandsOpts struct { func (bot *Bot) DeleteMyCommands(opts *DeleteMyCommandsOpts) (bool, error) { v := map[string]string{} if opts != nil { - if opts.Scope != nil { - bs, err := json.Marshal(opts.Scope) - if err != nil { - return false, fmt.Errorf("failed to marshal field scope: %w", err) - } - v["scope"] = string(bs) + bs, err := json.Marshal(opts.Scope) + if err != nil { + return false, fmt.Errorf("failed to marshal field scope: %w", err) } + v["scope"] = string(bs) v["language_code"] = opts.LanguageCode } @@ -1937,7 +1935,7 @@ func (bot *Bot) GetMe(opts *GetMeOpts) (*User, error) { // GetMyCommandsOpts is the set of optional fields for Bot.GetMyCommands. type GetMyCommandsOpts struct { // A JSON-serialized object, describing scope of users. Defaults to BotCommandScopeDefault. - Scope *BotCommandScope + Scope BotCommandScope // A two-letter ISO 639-1 language code or an empty string LanguageCode string // RequestOpts are an additional optional field to configure timeouts for individual requests @@ -1951,13 +1949,11 @@ type GetMyCommandsOpts struct { func (bot *Bot) GetMyCommands(opts *GetMyCommandsOpts) ([]BotCommand, error) { v := map[string]string{} if opts != nil { - if opts.Scope != nil { - bs, err := json.Marshal(opts.Scope) - if err != nil { - return nil, fmt.Errorf("failed to marshal field scope: %w", err) - } - v["scope"] = string(bs) + bs, err := json.Marshal(opts.Scope) + if err != nil { + return nil, fmt.Errorf("failed to marshal field scope: %w", err) } + v["scope"] = string(bs) v["language_code"] = opts.LanguageCode } @@ -4369,7 +4365,7 @@ type SetChatMenuButtonOpts struct { // Unique identifier for the target private chat. If not specified, default bot's menu button will be changed ChatId *int64 // A JSON-serialized object for the bot's new menu button. Defaults to MenuButtonDefault - MenuButton *MenuButton + MenuButton MenuButton // RequestOpts are an additional optional field to configure timeouts for individual requests RequestOpts *RequestOpts } @@ -4384,13 +4380,11 @@ func (bot *Bot) SetChatMenuButton(opts *SetChatMenuButtonOpts) (bool, error) { if opts.ChatId != nil { v["chat_id"] = strconv.FormatInt(*opts.ChatId, 10) } - if opts.MenuButton != nil { - bs, err := json.Marshal(opts.MenuButton) - if err != nil { - return false, fmt.Errorf("failed to marshal field menu_button: %w", err) - } - v["menu_button"] = string(bs) + bs, err := json.Marshal(opts.MenuButton) + if err != nil { + return false, fmt.Errorf("failed to marshal field menu_button: %w", err) } + v["menu_button"] = string(bs) } var reqOpts *RequestOpts @@ -4655,7 +4649,7 @@ func (bot *Bot) SetGameScore(userId int64, score int64, opts *SetGameScoreOpts) // SetMyCommandsOpts is the set of optional fields for Bot.SetMyCommands. type SetMyCommandsOpts struct { // A JSON-serialized object, describing scope of users for which the commands are relevant. Defaults to BotCommandScopeDefault. - Scope *BotCommandScope + Scope BotCommandScope // A two-letter ISO 639-1 language code. If empty, commands will be applied to all users from the given scope, for whose language there are no dedicated commands LanguageCode string // RequestOpts are an additional optional field to configure timeouts for individual requests @@ -4677,13 +4671,11 @@ func (bot *Bot) SetMyCommands(commands []BotCommand, opts *SetMyCommandsOpts) (b v["commands"] = string(bs) } if opts != nil { - if opts.Scope != nil { - bs, err := json.Marshal(opts.Scope) - if err != nil { - return false, fmt.Errorf("failed to marshal field scope: %w", err) - } - v["scope"] = string(bs) + bs, err := json.Marshal(opts.Scope) + if err != nil { + return false, fmt.Errorf("failed to marshal field scope: %w", err) } + v["scope"] = string(bs) v["language_code"] = opts.LanguageCode } diff --git a/scripts/generate/common.go b/scripts/generate/common.go index fc0e4287..edfee24f 100644 --- a/scripts/generate/common.go +++ b/scripts/generate/common.go @@ -69,16 +69,6 @@ func toGoType(s string) string { return pref + s } -func isBaseGoType(s string) bool { - s = stripPointersAndArrays(s) - for _, v := range tgToGoTypeMap { - if s == v { - return true - } - } - return false -} - func stripPointersAndArrays(retType string) string { for isPointer(retType) { retType = strings.TrimPrefix(retType, "*") diff --git a/scripts/generate/gen.go b/scripts/generate/gen.go index 999c9aaa..6acc22b6 100644 --- a/scripts/generate/gen.go +++ b/scripts/generate/gen.go @@ -256,7 +256,7 @@ func isTgType(d APIDescription, goType string) bool { return ok } -func (f Field) getPreferredType() (string, error) { +func (f Field) getPreferredType(d APIDescription) (string, error) { if f.Name == "media" { if len(f.Types) == 1 && f.Types[0] == "String" { return tgTypeInputFile, nil @@ -299,9 +299,16 @@ func (f Field) getPreferredType() (string, error) { if len(f.Types) == 1 { goType := toGoType(f.Types[0]) - // Optional JSON fields should always be pointers. - if !f.Required && strings.Contains(f.Description, "JSON-serialized object") && !isBaseGoType(goType) { - return "*" + goType, nil + // Optional TG types should be pointers, unless they're already an interface type. + if !f.Required && isTgType(d, f.Types[0]) && !isArray(goType) && goType != tgTypeInputFile { + rawType, err := getTypeByName(d, f.Types[0]) + if err != nil { + return "", fmt.Errorf("failed to get parent for %s: %w", f.Types[0], err) + } + + if len(rawType.Subtypes) == 0 { + return "*" + goType, nil + } } // Some fields are marked as "May be empty", in which case the empty values are still meaningful. diff --git a/scripts/generate/helpers.go b/scripts/generate/helpers.go index 6e1019e5..a38c3c77 100644 --- a/scripts/generate/helpers.go +++ b/scripts/generate/helpers.go @@ -114,7 +114,7 @@ func generateHelperArguments(d APIDescription, tgMethod MethodDescription, recei for _, mf := range tgMethod.Fields { hasOpts = hasOpts || !mf.Required - prefType, err := mf.getPreferredType() + prefType, err := mf.getPreferredType(d) if err != nil { return nil, nil, "", fmt.Errorf("failed to get preferred type for field %s of %s: %w", mf.Name, tgMethod.Name, err) } @@ -158,7 +158,7 @@ func getMethodFieldsSubtypeMatches(d APIDescription, tgMethod MethodDescription, } for _, mf := range tgMethod.Fields { - prefType, err := f.getPreferredType() + prefType, err := f.getPreferredType(d) if err != nil { return "", fmt.Errorf("failed to get preferred type for field %s of %s: %w", mf.Name, tgMethod.Name, err) } diff --git a/scripts/generate/methods.go b/scripts/generate/methods.go index 7fc0bd9e..d8c7d9e5 100644 --- a/scripts/generate/methods.go +++ b/scripts/generate/methods.go @@ -57,7 +57,7 @@ func generateMethodDef(d APIDescription, tgMethod MethodDescription) (string, er } // Generate method description - desc, err := tgMethod.description() + desc, err := tgMethod.description(d) if err != nil { return "", fmt.Errorf("failed to generate method description for %s: %w", tgMethod.Name, err) } @@ -113,7 +113,7 @@ func generateMethodSignature(d APIDescription, tgMethod MethodDescription) (stri return "", nil, "", fmt.Errorf("failed to get return for %s: %w", tgMethod.Name, err) } - args, optionalsStruct, err := tgMethod.getArgs() + args, optionalsStruct, err := tgMethod.getArgs(d) if err != nil { return "", nil, "", fmt.Errorf("failed to get args for method %s: %w", tgMethod.Name, err) } @@ -169,7 +169,7 @@ return %s, true, nil return returnString.String(), nil } -func (m MethodDescription) description() (string, error) { +func (m MethodDescription) description(d APIDescription) (string, error) { description := strings.Builder{} description.WriteString(m.docs()) @@ -179,7 +179,7 @@ func (m MethodDescription) description() (string, error) { continue } - prefType, err := f.getPreferredType() + prefType, err := f.getPreferredType(d) if err != nil { return "", err } @@ -236,7 +236,7 @@ func (m MethodDescription) argsToValues(d APIDescription, methodName string, def } func generateValue(d APIDescription, methodName string, f Field, goParam string, defaultRetVal string) (string, bool, error) { - fieldType, err := f.getPreferredType() + fieldType, err := f.getPreferredType(d) if err != nil { return "", false, fmt.Errorf("failed to get preferred type: %w", err) } @@ -367,12 +367,12 @@ func getRetVarName(retType string) string { return strings.ToLower(retType[:1]) } -func (m MethodDescription) getArgs() (string, string, error) { +func (m MethodDescription) getArgs(d APIDescription) (string, string, error) { var requiredArgs []string optionals := strings.Builder{} for _, f := range m.Fields { - fieldType, err := f.getPreferredType() + fieldType, err := f.getPreferredType(d) if err != nil { return "", "", fmt.Errorf("failed to get preferred type: %w", err) } diff --git a/scripts/generate/types.go b/scripts/generate/types.go index b4783213..2f84db9a 100644 --- a/scripts/generate/types.go +++ b/scripts/generate/types.go @@ -88,7 +88,7 @@ func generateTypeDef(d APIDescription, tgType TypeDescription) (string, error) { ok, fieldName, err := containsInputFile(d, tgType, map[string]bool{}) if err != nil { - return "", fmt.Errorf("failed to check if type requires special handling") + return "", fmt.Errorf("failed to check if type requires special handling: %w", err) } if ok { err = inputParamsTmpl.Execute(&typeDef, inputParamsMethodData{ @@ -105,7 +105,7 @@ func generateTypeDef(d APIDescription, tgType TypeDescription) (string, error) { // fieldContainsInputFile checks whether the field's type contains any inputfiles, and thus might be used to send data. func fieldContainsInputFile(d APIDescription, field Field) (bool, error) { - goType, err := field.getPreferredType() + goType, err := field.getPreferredType(d) if err != nil { return false, err } @@ -133,7 +133,7 @@ func containsInputFile(d APIDescription, tgType TypeDescription, checked map[str } for _, f := range tgType.Fields { - goType, err := f.getPreferredType() + goType, err := f.getPreferredType(d) if err != nil { return false, "", err } @@ -145,7 +145,7 @@ func containsInputFile(d APIDescription, tgType TypeDescription, checked map[str if isTgType(d, goType) { ok, _, err := containsInputFile(d, d.Types[goType], checked) if err != nil { - return false, "", fmt.Errorf("failed to check if %s contains inputfiles", goType) + return false, "", fmt.Errorf("failed to check if %s contains inputfiles: %w", goType, err) } if ok { // We return an error, because we can't actually handle this case yet. @@ -192,7 +192,7 @@ func setupCustomUnmarshal(d APIDescription, tgType TypeDescription) (string, err var fields []customUnmarshalFieldData generateCustomMarshal := false for idx, f := range tgType.Fields { - prefType, err := f.getPreferredType() + prefType, err := f.getPreferredType(d) if err != nil { return "", err } @@ -331,7 +331,7 @@ func commonFieldGenerator(d APIDescription, tgType TypeDescription, parentType T bd := strings.Builder{} if len(commonFields) > 0 { - commonGetMethods, err := generateAllCommonGetMethods(tgType.Name, commonFields, constantField, shortName) + commonGetMethods, err := generateAllCommonGetMethods(d, tgType.Name, commonFields, constantField, shortName) if err != nil { return "", err } @@ -357,7 +357,7 @@ func commonFieldGenerator(d APIDescription, tgType TypeDescription, parentType T return bd.String(), nil } -func generateAllCommonGetMethods(typeName string, commonFields []Field, constantField string, shortName string) (string, error) { +func generateAllCommonGetMethods(d APIDescription, typeName string, commonFields []Field, constantField string, shortName string) (string, error) { bd := strings.Builder{} for _, commonField := range commonFields { commonValueName := "v." + snakeToTitle(commonField.Name) @@ -365,7 +365,7 @@ func generateAllCommonGetMethods(typeName string, commonFields []Field, constant commonValueName = strconv.Quote(shortName) } - prefType, err := commonField.getPreferredType() + prefType, err := commonField.getPreferredType(d) if err != nil { return "", fmt.Errorf("failed to get preferred type for field %s of %s: %w", commonField.Name, typeName, err) } @@ -399,7 +399,7 @@ func generateTypeFields(d APIDescription, tgType TypeDescription) (string, error func generateStructFields(d APIDescription, fields []Field, constantFields []string) (string, error) { typeFields := strings.Builder{} for _, f := range fields { - fieldType, err := f.getPreferredType() + fieldType, err := f.getPreferredType(d) if err != nil { return "", fmt.Errorf("failed to get preferred type: %w", err) } @@ -448,7 +448,7 @@ func generateGenericInterfaceType(d APIDescription, name string, subtypes []Type bd := strings.Builder{} bd.WriteString(fmt.Sprintf("\ntype %s interface{", name)) for _, f := range commonFields { - prefType, err := f.getPreferredType() + prefType, err := f.getPreferredType(d) if err != nil { return "", err } @@ -478,7 +478,7 @@ func generateGenericInterfaceType(d APIDescription, name string, subtypes []Type bd.WriteString("\n" + mergedStruct) - commonGetMethods, err := generateAllCommonGetMethods("Merged"+name, commonFields, "", "") + commonGetMethods, err := generateAllCommonGetMethods(d, "Merged"+name, commonFields, "", "") if err != nil { return "", err } @@ -540,7 +540,7 @@ func generateMergeFunc(d APIDescription, typeName string, shortname string, fiel deref := false for _, parentField := range allParentFields { if parentField.Name == f.Name { - fieldType, err := f.getPreferredType() + fieldType, err := f.getPreferredType(d) if err != nil { return "", fmt.Errorf("failed to get preferred type: %w", err) }