Skip to content

Commit

Permalink
Add better pointer checks, to handle interface situations (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulSonOfLars committed Apr 23, 2023
1 parent ff805ad commit 38c2331
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 63 deletions.
48 changes: 20 additions & 28 deletions gen_methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ func (bot *Bot) DeleteMessage(chatId int64, messageId int64, opts *DeleteMessage
// DeleteMyCommandsOpts is the set of optional fields for Bot.DeleteMyCommands.
type DeleteMyCommandsOpts struct {
// A JSON-serialized object, describing scope of users for which the commands are relevant. Defaults to BotCommandScopeDefault.
Scope *BotCommandScope
Scope BotCommandScope
// A two-letter ISO 639-1 language code. If empty, commands will be applied to all users from the given scope, for whose language there are no dedicated commands
LanguageCode string
// RequestOpts are an additional optional field to configure timeouts for individual requests
Expand All @@ -969,13 +969,11 @@ type DeleteMyCommandsOpts struct {
func (bot *Bot) DeleteMyCommands(opts *DeleteMyCommandsOpts) (bool, error) {
v := map[string]string{}
if opts != nil {
if opts.Scope != nil {
bs, err := json.Marshal(opts.Scope)
if err != nil {
return false, fmt.Errorf("failed to marshal field scope: %w", err)
}
v["scope"] = string(bs)
bs, err := json.Marshal(opts.Scope)
if err != nil {
return false, fmt.Errorf("failed to marshal field scope: %w", err)
}
v["scope"] = string(bs)
v["language_code"] = opts.LanguageCode
}

Expand Down Expand Up @@ -1937,7 +1935,7 @@ func (bot *Bot) GetMe(opts *GetMeOpts) (*User, error) {
// GetMyCommandsOpts is the set of optional fields for Bot.GetMyCommands.
type GetMyCommandsOpts struct {
// A JSON-serialized object, describing scope of users. Defaults to BotCommandScopeDefault.
Scope *BotCommandScope
Scope BotCommandScope
// A two-letter ISO 639-1 language code or an empty string
LanguageCode string
// RequestOpts are an additional optional field to configure timeouts for individual requests
Expand All @@ -1951,13 +1949,11 @@ type GetMyCommandsOpts struct {
func (bot *Bot) GetMyCommands(opts *GetMyCommandsOpts) ([]BotCommand, error) {
v := map[string]string{}
if opts != nil {
if opts.Scope != nil {
bs, err := json.Marshal(opts.Scope)
if err != nil {
return nil, fmt.Errorf("failed to marshal field scope: %w", err)
}
v["scope"] = string(bs)
bs, err := json.Marshal(opts.Scope)
if err != nil {
return nil, fmt.Errorf("failed to marshal field scope: %w", err)
}
v["scope"] = string(bs)
v["language_code"] = opts.LanguageCode
}

Expand Down Expand Up @@ -4369,7 +4365,7 @@ type SetChatMenuButtonOpts struct {
// Unique identifier for the target private chat. If not specified, default bot's menu button will be changed
ChatId *int64
// A JSON-serialized object for the bot's new menu button. Defaults to MenuButtonDefault
MenuButton *MenuButton
MenuButton MenuButton
// RequestOpts are an additional optional field to configure timeouts for individual requests
RequestOpts *RequestOpts
}
Expand All @@ -4384,13 +4380,11 @@ func (bot *Bot) SetChatMenuButton(opts *SetChatMenuButtonOpts) (bool, error) {
if opts.ChatId != nil {
v["chat_id"] = strconv.FormatInt(*opts.ChatId, 10)
}
if opts.MenuButton != nil {
bs, err := json.Marshal(opts.MenuButton)
if err != nil {
return false, fmt.Errorf("failed to marshal field menu_button: %w", err)
}
v["menu_button"] = string(bs)
bs, err := json.Marshal(opts.MenuButton)
if err != nil {
return false, fmt.Errorf("failed to marshal field menu_button: %w", err)
}
v["menu_button"] = string(bs)
}

var reqOpts *RequestOpts
Expand Down Expand Up @@ -4655,7 +4649,7 @@ func (bot *Bot) SetGameScore(userId int64, score int64, opts *SetGameScoreOpts)
// SetMyCommandsOpts is the set of optional fields for Bot.SetMyCommands.
type SetMyCommandsOpts struct {
// A JSON-serialized object, describing scope of users for which the commands are relevant. Defaults to BotCommandScopeDefault.
Scope *BotCommandScope
Scope BotCommandScope
// A two-letter ISO 639-1 language code. If empty, commands will be applied to all users from the given scope, for whose language there are no dedicated commands
LanguageCode string
// RequestOpts are an additional optional field to configure timeouts for individual requests
Expand All @@ -4677,13 +4671,11 @@ func (bot *Bot) SetMyCommands(commands []BotCommand, opts *SetMyCommandsOpts) (b
v["commands"] = string(bs)
}
if opts != nil {
if opts.Scope != nil {
bs, err := json.Marshal(opts.Scope)
if err != nil {
return false, fmt.Errorf("failed to marshal field scope: %w", err)
}
v["scope"] = string(bs)
bs, err := json.Marshal(opts.Scope)
if err != nil {
return false, fmt.Errorf("failed to marshal field scope: %w", err)
}
v["scope"] = string(bs)
v["language_code"] = opts.LanguageCode
}

Expand Down
10 changes: 0 additions & 10 deletions scripts/generate/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,6 @@ func toGoType(s string) string {
return pref + s
}

func isBaseGoType(s string) bool {
s = stripPointersAndArrays(s)
for _, v := range tgToGoTypeMap {
if s == v {
return true
}
}
return false
}

func stripPointersAndArrays(retType string) string {
for isPointer(retType) {
retType = strings.TrimPrefix(retType, "*")
Expand Down
15 changes: 11 additions & 4 deletions scripts/generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func isTgType(d APIDescription, goType string) bool {
return ok
}

func (f Field) getPreferredType() (string, error) {
func (f Field) getPreferredType(d APIDescription) (string, error) {
if f.Name == "media" {
if len(f.Types) == 1 && f.Types[0] == "String" {
return tgTypeInputFile, nil
Expand Down Expand Up @@ -299,9 +299,16 @@ func (f Field) getPreferredType() (string, error) {
if len(f.Types) == 1 {
goType := toGoType(f.Types[0])

// Optional JSON fields should always be pointers.
if !f.Required && strings.Contains(f.Description, "JSON-serialized object") && !isBaseGoType(goType) {
return "*" + goType, nil
// Optional TG types should be pointers, unless they're already an interface type.
if !f.Required && isTgType(d, f.Types[0]) && !isArray(goType) && goType != tgTypeInputFile {
rawType, err := getTypeByName(d, f.Types[0])
if err != nil {
return "", fmt.Errorf("failed to get parent for %s: %w", f.Types[0], err)
}

if len(rawType.Subtypes) == 0 {
return "*" + goType, nil
}
}

// Some fields are marked as "May be empty", in which case the empty values are still meaningful.
Expand Down
4 changes: 2 additions & 2 deletions scripts/generate/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func generateHelperArguments(d APIDescription, tgMethod MethodDescription, recei
for _, mf := range tgMethod.Fields {
hasOpts = hasOpts || !mf.Required

prefType, err := mf.getPreferredType()
prefType, err := mf.getPreferredType(d)
if err != nil {
return nil, nil, "", fmt.Errorf("failed to get preferred type for field %s of %s: %w", mf.Name, tgMethod.Name, err)
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func getMethodFieldsSubtypeMatches(d APIDescription, tgMethod MethodDescription,
}

for _, mf := range tgMethod.Fields {
prefType, err := f.getPreferredType()
prefType, err := f.getPreferredType(d)
if err != nil {
return "", fmt.Errorf("failed to get preferred type for field %s of %s: %w", mf.Name, tgMethod.Name, err)
}
Expand Down
14 changes: 7 additions & 7 deletions scripts/generate/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func generateMethodDef(d APIDescription, tgMethod MethodDescription) (string, er
}

// Generate method description
desc, err := tgMethod.description()
desc, err := tgMethod.description(d)
if err != nil {
return "", fmt.Errorf("failed to generate method description for %s: %w", tgMethod.Name, err)
}
Expand Down Expand Up @@ -113,7 +113,7 @@ func generateMethodSignature(d APIDescription, tgMethod MethodDescription) (stri
return "", nil, "", fmt.Errorf("failed to get return for %s: %w", tgMethod.Name, err)
}

args, optionalsStruct, err := tgMethod.getArgs()
args, optionalsStruct, err := tgMethod.getArgs(d)
if err != nil {
return "", nil, "", fmt.Errorf("failed to get args for method %s: %w", tgMethod.Name, err)
}
Expand Down Expand Up @@ -169,7 +169,7 @@ return %s, true, nil
return returnString.String(), nil
}

func (m MethodDescription) description() (string, error) {
func (m MethodDescription) description(d APIDescription) (string, error) {
description := strings.Builder{}

description.WriteString(m.docs())
Expand All @@ -179,7 +179,7 @@ func (m MethodDescription) description() (string, error) {
continue
}

prefType, err := f.getPreferredType()
prefType, err := f.getPreferredType(d)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -236,7 +236,7 @@ func (m MethodDescription) argsToValues(d APIDescription, methodName string, def
}

func generateValue(d APIDescription, methodName string, f Field, goParam string, defaultRetVal string) (string, bool, error) {
fieldType, err := f.getPreferredType()
fieldType, err := f.getPreferredType(d)
if err != nil {
return "", false, fmt.Errorf("failed to get preferred type: %w", err)
}
Expand Down Expand Up @@ -367,12 +367,12 @@ func getRetVarName(retType string) string {
return strings.ToLower(retType[:1])
}

func (m MethodDescription) getArgs() (string, string, error) {
func (m MethodDescription) getArgs(d APIDescription) (string, string, error) {
var requiredArgs []string
optionals := strings.Builder{}

for _, f := range m.Fields {
fieldType, err := f.getPreferredType()
fieldType, err := f.getPreferredType(d)
if err != nil {
return "", "", fmt.Errorf("failed to get preferred type: %w", err)
}
Expand Down
24 changes: 12 additions & 12 deletions scripts/generate/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func generateTypeDef(d APIDescription, tgType TypeDescription) (string, error) {

ok, fieldName, err := containsInputFile(d, tgType, map[string]bool{})
if err != nil {
return "", fmt.Errorf("failed to check if type requires special handling")
return "", fmt.Errorf("failed to check if type requires special handling: %w", err)
}
if ok {
err = inputParamsTmpl.Execute(&typeDef, inputParamsMethodData{
Expand All @@ -105,7 +105,7 @@ func generateTypeDef(d APIDescription, tgType TypeDescription) (string, error) {

// fieldContainsInputFile checks whether the field's type contains any inputfiles, and thus might be used to send data.
func fieldContainsInputFile(d APIDescription, field Field) (bool, error) {
goType, err := field.getPreferredType()
goType, err := field.getPreferredType(d)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -133,7 +133,7 @@ func containsInputFile(d APIDescription, tgType TypeDescription, checked map[str
}

for _, f := range tgType.Fields {
goType, err := f.getPreferredType()
goType, err := f.getPreferredType(d)
if err != nil {
return false, "", err
}
Expand All @@ -145,7 +145,7 @@ func containsInputFile(d APIDescription, tgType TypeDescription, checked map[str
if isTgType(d, goType) {
ok, _, err := containsInputFile(d, d.Types[goType], checked)
if err != nil {
return false, "", fmt.Errorf("failed to check if %s contains inputfiles", goType)
return false, "", fmt.Errorf("failed to check if %s contains inputfiles: %w", goType, err)
}
if ok {
// We return an error, because we can't actually handle this case yet.
Expand Down Expand Up @@ -192,7 +192,7 @@ func setupCustomUnmarshal(d APIDescription, tgType TypeDescription) (string, err
var fields []customUnmarshalFieldData
generateCustomMarshal := false
for idx, f := range tgType.Fields {
prefType, err := f.getPreferredType()
prefType, err := f.getPreferredType(d)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -331,7 +331,7 @@ func commonFieldGenerator(d APIDescription, tgType TypeDescription, parentType T

bd := strings.Builder{}
if len(commonFields) > 0 {
commonGetMethods, err := generateAllCommonGetMethods(tgType.Name, commonFields, constantField, shortName)
commonGetMethods, err := generateAllCommonGetMethods(d, tgType.Name, commonFields, constantField, shortName)
if err != nil {
return "", err
}
Expand All @@ -357,15 +357,15 @@ func commonFieldGenerator(d APIDescription, tgType TypeDescription, parentType T
return bd.String(), nil
}

func generateAllCommonGetMethods(typeName string, commonFields []Field, constantField string, shortName string) (string, error) {
func generateAllCommonGetMethods(d APIDescription, typeName string, commonFields []Field, constantField string, shortName string) (string, error) {
bd := strings.Builder{}
for _, commonField := range commonFields {
commonValueName := "v." + snakeToTitle(commonField.Name)
if commonField.Name == constantField {
commonValueName = strconv.Quote(shortName)
}

prefType, err := commonField.getPreferredType()
prefType, err := commonField.getPreferredType(d)
if err != nil {
return "", fmt.Errorf("failed to get preferred type for field %s of %s: %w", commonField.Name, typeName, err)
}
Expand Down Expand Up @@ -399,7 +399,7 @@ func generateTypeFields(d APIDescription, tgType TypeDescription) (string, error
func generateStructFields(d APIDescription, fields []Field, constantFields []string) (string, error) {
typeFields := strings.Builder{}
for _, f := range fields {
fieldType, err := f.getPreferredType()
fieldType, err := f.getPreferredType(d)
if err != nil {
return "", fmt.Errorf("failed to get preferred type: %w", err)
}
Expand Down Expand Up @@ -448,7 +448,7 @@ func generateGenericInterfaceType(d APIDescription, name string, subtypes []Type
bd := strings.Builder{}
bd.WriteString(fmt.Sprintf("\ntype %s interface{", name))
for _, f := range commonFields {
prefType, err := f.getPreferredType()
prefType, err := f.getPreferredType(d)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -478,7 +478,7 @@ func generateGenericInterfaceType(d APIDescription, name string, subtypes []Type

bd.WriteString("\n" + mergedStruct)

commonGetMethods, err := generateAllCommonGetMethods("Merged"+name, commonFields, "", "")
commonGetMethods, err := generateAllCommonGetMethods(d, "Merged"+name, commonFields, "", "")
if err != nil {
return "", err
}
Expand Down Expand Up @@ -540,7 +540,7 @@ func generateMergeFunc(d APIDescription, typeName string, shortname string, fiel
deref := false
for _, parentField := range allParentFields {
if parentField.Name == f.Name {
fieldType, err := f.getPreferredType()
fieldType, err := f.getPreferredType(d)
if err != nil {
return "", fmt.Errorf("failed to get preferred type: %w", err)
}
Expand Down

0 comments on commit 38c2331

Please sign in to comment.