From 796d66d413ee9a63fbeb88d7d458891da60b1e84 Mon Sep 17 00:00:00 2001 From: Jakab Zeller Date: Tue, 14 Mar 2023 20:42:30 +0000 Subject: [PATCH] Improving Paginators by making them aware of BindingParams #2 - Added the reddit.RateLimit type for tracking rate limits in the reddit API, as well as the RateLimitsConfig interface for configuring rates (13/03/2023 - 15:16:26) - A lot more reddit.Types taken mostly from https://github.com/vartanbeno/go-reddit/reddit/things.go (13/03/2023 - 15:23:14) - Paginators are now aware of multiple different types of pagination parameters thanks to the Binding.Params method (14/03/2023 - 14:14:10) - The untyped paginator is now the type Paginator[any, any] which makes more sense and allows us to actually use the Afterable interface (14/03/2023 - 16:22:05) --- api/api.go | 80 ++++++- api/api_test.go | 6 +- api/binding.go | 34 ++- api/paginator.go | 373 +++++++++++++++++++++++++++---- config.go | 34 ++- go.mod | 2 +- go.sum | 4 + main.go | 44 ++-- monday/client.go | 2 +- reddit/bindings.go | 37 ++-- reddit/client.go | 85 ++++++-- reddit/config.go | 13 ++ reddit/types.go | 534 ++++++++++++++++++++++++++++++++++----------- 13 files changed, 1000 insertions(+), 248 deletions(-) diff --git a/api/api.go b/api/api.go index bbfd930..dee75cd 100644 --- a/api/api.go +++ b/api/api.go @@ -7,6 +7,7 @@ import ( "github.com/machinebox/graphql" "net/http" "reflect" + "sync" "time" ) @@ -37,8 +38,48 @@ type Request interface { // Client is the API client that will execute the Binding. type Client interface { - // Run should execute the given Request and unmarshal the response into the given response interface. - Run(ctx context.Context, attrs map[string]any, req Request, res any) error + // Run should execute the given Request and unmarshal the response into the given response interface. It is usually + // called from Binding.Execute to execute a Binding, hence why we also pass in the name of the Binding (from + // Binding.Name).
+ Run(ctx context.Context, bindingName string, attrs map[string]any, req Request, res any) error +} + +type RateLimitType int + +const ( + // RequestRateLimit means that the RateLimit is limited by the number of HTTP requests that can be made in a certain + // timespan. + RequestRateLimit RateLimitType = iota + // ResourceRateLimit means that the RateLimit is limited by the number of resources that can be fetched in a certain + // timespan. + ResourceRateLimit +) + +// RateLimit represents a RateLimit for a binding. +type RateLimit interface { + // Reset returns the time at which the RateLimit resets. + Reset() time.Time + // Remaining returns the number of requests remaining/resources that can be fetched for this RateLimit. + Remaining() int + // Used returns the number of requests used/resources fetched so far for this RateLimit. + Used() int + // Type is the type of the RateLimit. See RateLimitType for documentation. + Type() RateLimitType +} + +// RateLimitedClient is an API Client that has a RateLimit for each Binding it has authority over. +type RateLimitedClient interface { + // Client should implement a Client.Run method that sets an internal sync.Map of RateLimit(s). + Client + // RateLimits returns the sync.Map of Binding names to RateLimit instances. + RateLimits() *sync.Map + // AddRateLimit should add a RateLimit to the internal sync.Map within the Client. It should check if the Binding of + // the given name already has a RateLimit, and whether the RateLimit.Reset lies after the currently set RateLimit + // for that Binding. + AddRateLimit(bindingName string, rateLimit RateLimit) + // LatestRateLimit should return the latest RateLimit for the Binding of the given name. If multiple Binding(s) + // share the same RateLimit(s) then this can also be encoded into this method. + LatestRateLimit(bindingName string) RateLimit } // BindingWrapper wraps a Binding value with its name. 
This is used within the Schema map so that we don't have to use @@ -54,11 +95,23 @@ func (bw BindingWrapper) String() string { return fmt.Sprintf("%s/%v", bw.name, bw.binding.Type()) } +// Name returns the name of the underlying Binding. +func (bw BindingWrapper) Name() string { return bw.name } + +func (bw BindingWrapper) bindingName() string { + return bw.binding.MethodByName("Name").Call([]reflect.Value{})[0].Interface().(string) +} + // Paginated calls the Binding.Paginated method for the underlying Binding in the BindingWrapper. func (bw BindingWrapper) Paginated() bool { return bw.binding.MethodByName("Paginated").Call([]reflect.Value{})[0].Bool() } +// Paginator returns an un-typed Paginator for the underlying Binding of the BindingWrapper. +func (bw BindingWrapper) Paginator(client Client, waitTime time.Duration, args ...any) (paginator Paginator[any, any], err error) { + return NewPaginator(client, waitTime, bw, args...) +} + // ArgsFromStrings calls the Binding.ArgsFromStrings method for the underlying Binding in the BindingWrapper. func (bw BindingWrapper) ArgsFromStrings(args ...string) (parsedArgs []any, err error) { values := bw.binding.MethodByName("ArgsFromStrings").Call(slices.Comprehension(args, func(idx int, value string, arr []string) reflect.Value { @@ -72,6 +125,11 @@ func (bw BindingWrapper) ArgsFromStrings(args ...string) (parsedArgs []any, err return } +// Params calls the Binding.Params method for the underlying Binding in the BindingWrapper. +func (bw BindingWrapper) Params() []BindingParam { + return bw.binding.MethodByName("Params").Call([]reflect.Value{})[0].Interface().([]BindingParam) +} + // Execute calls the Binding.Execute method for the underlying Binding in the BindingWrapper. 
func (bw BindingWrapper) Execute(client Client, args ...any) (val any, err error) { arguments := []any{client} @@ -87,14 +145,20 @@ func (bw BindingWrapper) Execute(client Client, args ...any) (val any, err error return } -// WrapBinding will return the BindingWrapper for the given Binding of the given name. -func WrapBinding[ResT any, RetT any](name string, binding Binding[ResT, RetT]) BindingWrapper { +func (bw BindingWrapper) setName(name string) { + fmt.Println("setName", bw.binding.Type()) + bw.binding.MethodByName("SetName").Call([]reflect.Value{reflect.ValueOf(name)}) +} + +// WrapBinding will return the BindingWrapper for the given Binding. The name of the BindingWrapper will be fetched from +// Binding.Name, so make sure to override this before using the Binding. +func WrapBinding[ResT any, RetT any](binding Binding[ResT, RetT]) BindingWrapper { var ( resT ResT retT RetT ) return BindingWrapper{ - name: name, + name: binding.Name(), responseType: reflect.TypeOf(resT), returnType: reflect.TypeOf(retT), binding: reflect.ValueOf(&binding).Elem(), @@ -112,6 +176,10 @@ type API struct { // NewAPI constructs a new API instance for the given Client and Schema combination. func NewAPI(client Client, schema Schema) *API { + for bindingName, bindingWrapper := range schema { + bindingWrapper.name = bindingName + } + return &API{ Client: client, schema: schema, @@ -152,7 +220,7 @@ func (api *API) Execute(name string, args ...any) (val any, err error) { } // Paginator returns a Paginator for the Binding of the given name within the API. 
-func (api *API) Paginator(name string, waitTime time.Duration, args ...any) (paginator Paginator[[]any, []any], err error) { +func (api *API) Paginator(name string, waitTime time.Duration, args ...any) (paginator Paginator[any, any], err error) { var binding BindingWrapper if binding, err = api.checkBindingExists(name); err != nil { return diff --git a/api/api_test.go b/api/api_test.go index cec55f8..a6a390c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -20,7 +20,7 @@ import ( type httpClient struct { } -func (h httpClient) Run(ctx context.Context, attrs map[string]any, req Request, res any) (err error) { +func (h httpClient) Run(ctx context.Context, bindingName string, attrs map[string]any, req Request, res any) (err error) { request := req.(HTTPRequest).Request var response *http.Response @@ -496,12 +496,12 @@ func ExampleNewAPI() { // when creating Bindings. This will execute a similar HTTP request to the "products" Binding but // Binding.Execute will instead return a single Product instance. // Note: how the RetT type param is set to just "Product". - "first_product": WrapBinding("first_product", NewBindingChain(func(binding Binding[[]Product, Product], args ...any) (request Request) { + "first_product": WrapBinding(NewBindingChain(func(binding Binding[[]Product, Product], args ...any) (request Request) { req, _ := http.NewRequest(http.MethodGet, "https://fakestoreapi.com/products?limit=1", nil) return HTTPRequest{req} }).SetResponseMethod(func(binding Binding[[]Product, Product], response []Product, args ...any) Product { return response[0] - })), + }).SetName("first_product")), }) // Then we can execute our "users" binding with a limit of 3... diff --git a/api/binding.go b/api/binding.go index 690a8e4..dd0c517 100644 --- a/api/binding.go +++ b/api/binding.go @@ -84,6 +84,14 @@ type Binding[ResT any, RetT any] interface { // chained with others when creating a new Binding through NewBindingChain. 
SetPaginated(paginated bool) Binding[ResT, RetT] + // Name returns the name of the Binding. When using NewBinding, NewBindingChain, or NewWrappedBinding, this will be + // set to whatever is returned by the following line of code: + // fmt.Sprintf("%T", binding) + // Where "binding" is the referred to Binding. + Name() string + // SetName sets the name of the Binding. This returns the Binding so it can be chained. + SetName(name string) Binding[ResT, RetT] + // Attrs returns the attributes for the Binding. These can be passed in when creating a Binding through the // NewBinding function. Attrs can be used in any of the implemented functions, and they are also passed to // Client.Run when Execute-ing the Binding. @@ -115,6 +123,8 @@ type bindingProto[ResT any, RetT any] struct { checkedParams bool paramsMethod BindingParamsMethod[ResT, RetT] paginated bool + name string + nameSet bool attrs map[string]any attrFuncs []Attr } @@ -244,8 +254,8 @@ func checkParams(params []BindingParam) (err error) { // checkParams will see if the given BindingParam(s) make sense. This means that: // - BindingParam(s) should have a unique BindingParam.name. -// - Non Required BindingParam(s) should trail after all Required BindingParam(s). -// - Variadic BindingParam(s) should trail after all non-required BindingParam(s) +// - Non Required BindingParam(s) should trail after all Required BindingParam(s). +// - Variadic BindingParam(s) should trail after all non-required BindingParam(s) // - Variadic BindingParam(s) should not be Required. // - Variadic BindingParam(s) should have DefaultValue that is an empty reflect.Slice/reflect.Array type.
// @@ -387,7 +397,7 @@ func (b bindingProto[ResT, RetT]) Execute(client Client, args ...any) (response responseWrapperInt := responseWrapper.Interface() ctx := context.Background() - if err = client.Run(ctx, b.attrs, req, &responseWrapperInt); err != nil { + if err = client.Run(ctx, b.Name(), b.attrs, req, &responseWrapperInt); err != nil { err = errors.Wrapf(err, "could not Execute Binding %T", b) return } @@ -406,6 +416,20 @@ func (b bindingProto[ResT, RetT]) SetPaginated(paginated bool) Binding[ResT, Ret b.paginated = paginated return &b } + +func (b bindingProto[ResT, RetT]) Name() string { + if !b.nameSet { + return fmt.Sprintf("%T", b) + } + return b.name +} + +func (b bindingProto[ResT, RetT]) SetName(name string) Binding[ResT, RetT] { + b.name = name + b.nameSet = true + return &b +} + func (b bindingProto[ResT, RetT]) Attrs() map[string]any { return b.attrs } func (b bindingProto[ResT, RetT]) AddAttrs(attrs ...Attr) Binding[ResT, RetT] { @@ -516,7 +540,9 @@ func NewWrappedBinding[ResT any, RetT any]( paginated bool, attrs ...Attr, ) BindingWrapper { - return WrapBinding(name, NewBinding(request, wrap, unwrap, response, params, paginated, attrs...)) + b := NewBinding(request, wrap, unwrap, response, params, paginated, attrs...) + b.SetName(name) + return WrapBinding(b) } // NewBindingChain creates a new Binding for an API via a prototype that implements the Binding interface. Unlike the diff --git a/api/paginator.go b/api/paginator.go index 1d68c17..5045852 100644 --- a/api/paginator.go +++ b/api/paginator.go @@ -2,11 +2,101 @@ package api import ( "fmt" + "github.com/andygello555/gotils/v2/slices" + mapset "github.com/deckarep/golang-set/v2" "github.com/pkg/errors" "reflect" + "strings" "time" ) +// Afterable denotes whether a response type can be used in a Paginator for a Binding that takes an "after" parameter. +type Afterable interface { + // After returns the value of the "after" parameter that should be used for the next page of pagination. 
If this + // returns nil, then it is assumed that pagination has finished. + After() any +} + +type paginatorParamSet int + +const ( + unknownParamSet paginatorParamSet = iota + pageParamSet + afterParamSet +) + +func (pps paginatorParamSet) String() string { + return strings.TrimPrefix(pps.Set().String(), "Set") +} + +func (pps paginatorParamSet) GetPaginatorParamValue(params []BindingParam, resource any, page int) ([]any, error) { + switch pps { + case pageParamSet: + return []any{page}, nil + case afterParamSet: + if resource == nil { + for _, param := range params { + if param.name == "after" { + return []any{reflect.Zero(param.Type()).Interface()}, nil + } + } + return nil, fmt.Errorf("cannot find \"after\" parameter in parameters to use zero value for nil resource") + } + + if afterable, ok := resource.(Afterable); ok { + return []any{afterable.After()}, nil + } else { + return nil, fmt.Errorf("cannot find next \"after\" parameter as return type %T is not Afterable", resource) + } + default: + return nil, fmt.Errorf("%v is not a valid paginatorParamSet", pps) + } +} + +func (pps paginatorParamSet) InsertPaginatorParamValues(params []BindingParam, args []any, paginatorValues []any) []any { + ppsSet := pps.Set() + paramIdxs := make([]int, len(paginatorValues)) + + i := 0 + for paramNo, param := range params { + if ppsSet.Contains(param.name) { + paramIdxs[i] = paramNo + i++ + } + } + args = slices.AddElems(args, paginatorValues, paramIdxs...) 
+ return args +} + +func (pps paginatorParamSet) Set() mapset.Set[string] { + switch pps { + case pageParamSet: + return mapset.NewSet("page") + case afterParamSet: + return mapset.NewSet("after") + default: + return mapset.NewSet[string]() + } +} + +func (pps paginatorParamSet) Sets() []paginatorParamSet { + return []paginatorParamSet{pageParamSet, afterParamSet} +} + +func checkPaginatorParams(params []BindingParam) paginatorParamSet { + paramNameSet := mapset.NewSet(slices.Comprehension(params, func(idx int, value BindingParam, arr []BindingParam) string { + return value.name + })...) + for _, pps := range unknownParamSet.Sets() { + if pps.Set().Difference(paramNameSet).Cardinality() == 0 { + return pps + } + } + return unknownParamSet +} + +var limitParamNames = mapset.NewSet[string]("limit", "count") + // Paginator can fetch resources from a Binding that is paginated. Use NewPaginator or NewTypedPaginator to create a new // one for a given Binding. type Paginator[ResT any, RetT any] interface { @@ -19,16 +109,23 @@ type Paginator[ResT any, RetT any] interface { Next() error // All returns all the return values for the Binding at once. All() (RetT, error) + // Pages fetches the given number of pages from the Binding whilst appending each response slice together. 
+ Pages(pages int) (RetT, error) } type typedPaginator[ResT any, RetT any] struct { - client Client - binding Binding[ResT, RetT] - waitTime time.Duration - args []any - returnType reflect.Type - page int - currentPage RetT + client Client + rateLimitedClient RateLimitedClient + usingRateLimitedClient bool + binding Binding[ResT, RetT] + params []BindingParam + paramSet paginatorParamSet + limitArg *float64 + waitTime time.Duration + args []any + returnType reflect.Type + page int + currentPage RetT } func (p *typedPaginator[ResT, RetT]) Continue() bool { @@ -37,13 +134,117 @@ func (p *typedPaginator[ResT, RetT]) Continue() bool { func (p *typedPaginator[ResT, RetT]) Page() RetT { return p.currentPage } +func paginatorCheckRateLimit( + client Client, + bindingName string, + limitArg **float64, + page int, + currentPage any, + params []BindingParam, + args []any, +) (ignoreFirstRequest bool, ok bool, err error) { + cont := func() bool { + return page == 1 || reflect.ValueOf(currentPage).Len() > 0 + } + + var rateLimitedClient RateLimitedClient + if rateLimitedClient, ok = client.(RateLimitedClient); ok { + rl := rateLimitedClient.LatestRateLimit(bindingName) + if rl != nil && rl.Reset().After(time.Now()) { + sleepTime := rl.Reset().Sub(time.Now()) + switch rl.Type() { + case RequestRateLimit: + if rl.Remaining() == 0 { + time.Sleep(sleepTime) + } + case ResourceRateLimit: + if reflect.ValueOf(currentPage).Len() > rl.Remaining() { + time.Sleep(sleepTime) + } else if cont() { + if limitArg == nil { + for i, param := range params { + if !limitParamNames.Contains(param.name) { + continue + } + + var argVal reflect.Value + if i < len(args) { + argVal = reflect.ValueOf(args[i]) + } else if !param.required && !param.variadic { + argVal = reflect.ValueOf(param.defaultValue) + } + + var val float64 + switch { + case argVal.CanInt(): + val = float64(argVal.Int()) + case argVal.CanUint(): + val = float64(argVal.Uint()) + case argVal.CanFloat(): + val = argVal.Float() + 
default: + continue + } + **limitArg = val + // Break out of the loop if we have found a limit argument + break + } + } + + if **limitArg > float64(rl.Remaining()) { + time.Sleep(sleepTime) + } + } + } + } else if page == 1 { + ignoreFirstRequest = true + } else { + err = fmt.Errorf( + "could not get the latest RateLimit/RateLimit has expired but we are on page %d, check Client.Run", + page, + ) + return + } + } + return +} + func (p *typedPaginator[ResT, RetT]) Next() (err error) { - args := []any{p.page} - args = append(args, p.args...) - if p.currentPage, err = p.binding.Execute(p.client, args...); err != nil { - err = errors.Wrapf(err, "error occurred on page no. %d", p.page) + var paginatorValues []any + if paginatorValues, err = p.paramSet.GetPaginatorParamValue(p.params, p.currentPage, p.page); err != nil { + err = errors.Wrapf( + err, "cannot get paginator param values from %T value on page %d", + p.currentPage, p.page, + ) return } + args := p.paramSet.InsertPaginatorParamValues(p.params, p.args, paginatorValues) + + var ignoreFirstRequest bool + execute := func() (ret RetT, err error) { + if ignoreFirstRequest, p.usingRateLimitedClient, err = paginatorCheckRateLimit( + p.client, p.binding.Name(), &p.limitArg, p.page, p.currentPage, p.params, p.args, + ); err != nil { + return + } + return p.binding.Execute(p.client, args...) + } + + if p.currentPage, err = execute(); err != nil { + if !ignoreFirstRequest { + err = errors.Wrapf(err, "error occurred on page no. %d", p.page) + return + } + + if p.currentPage, err = execute(); err != nil { + err = errors.Wrapf( + err, "error occurred on page no. 
%d, after ignoring the first request due to no rate limit", + p.page, + ) + return + } + } + p.page++ if p.waitTime != 0 { time.Sleep(p.waitTime) @@ -62,9 +263,40 @@ func (p *typedPaginator[ResT, RetT]) All() (RetT, error) { return pages.Interface().(RetT), nil } +func (p *typedPaginator[ResT, RetT]) Pages(pageNo int) (RetT, error) { + pages := reflect.New(p.returnType).Elem() + for p.Continue() && p.page <= pageNo { + if err := p.Next(); err != nil { + return pages.Interface().(RetT), err + } + pages = reflect.AppendSlice(pages, reflect.ValueOf(p.Page())) + } + return pages.Interface().(RetT), nil +} + // NewTypedPaginator creates a new type aware Paginator using the given Client, wait time.Duration, and arguments for -// the given Binding. The first given argument should not be the page parameter. Args should contain everything after the -// page parameter that would usually be passed to a standard Binding.Execute. +// the given Binding. The given Binding's Binding.Paginated method must return true, and the return type (RetT) of the +// Binding must be a slice-type, otherwise an appropriate error will be returned. +// +// The Paginator requires one of the following sets of BindingParam(s) taken by the given Binding: +// 1. ("page",): a singular page argument where each time Paginator.Next is called the page will be incremented +// 2. ("after",): a singular after argument where each time Paginator.Next is called the Afterable.After method will be +// called on the returned response and the returned value will be set as the "after" parameter for the next +// Binding.Execute. This requires the RetT to implement the Afterable interface. +// +// The sets of BindingParam(s) shown above are given in priority order. This means that a Binding that defines multiple +// BindingParam(s) that exist within these sets, only the first complete set will be taken. 
+// +// The args given to NewTypedPaginator should not include the set of BindingParam(s) (listed above), that are going to +// be used to paginate the binding. +// +// If the given Client also implements RateLimitedClient then the given waitTime argument will be ignored in favour of +// waiting (or not) until the RateLimit for the given Binding resets. If the RateLimit that is returned by +// RateLimitedClient.LatestRateLimit is of type ResourceRateLimit, and the Paginator is on the first page. The following +// parameter arguments will be checked for a limit/count value to see whether there is enough RateLimit.Remaining (in +// priority order): +// 1. "limit" +// 2. "count" func NewTypedPaginator[ResT any, RetT any](client Client, waitTime time.Duration, binding Binding[ResT, RetT], args ...any) (paginator Paginator[ResT, RetT], err error) { if !binding.Paginated() { err = fmt.Errorf("cannot create typed Paginator as Binding is not pagenatable") @@ -74,11 +306,21 @@ func NewTypedPaginator[ResT any, RetT any](client Client, waitTime time.Duration p := &typedPaginator[ResT, RetT]{ client: client, binding: binding, + params: binding.Params(), waitTime: waitTime, args: args, page: 1, } + p.rateLimitedClient, p.usingRateLimitedClient = client.(RateLimitedClient) + if p.paramSet = checkPaginatorParams(p.params); p.paramSet == unknownParamSet { + err = fmt.Errorf( + "cannot create typed Paginator as we couldn't find any paginateable params, need one of the following sets of params %v", + unknownParamSet.Sets(), + ) + return + } + returnType := reflect.ValueOf(new(RetT)).Elem().Type() switch returnType.Kind() { case reflect.Slice, reflect.Array: @@ -94,35 +336,62 @@ func NewTypedPaginator[ResT any, RetT any](client Client, waitTime time.Duration } type paginator struct { - client Client - binding *BindingWrapper - waitTime time.Duration - args []any - returnType reflect.Type - page int - currentPage []any + client Client + rateLimitedClient RateLimitedClient + 
usingRateLimitedClient bool + binding *BindingWrapper + params []BindingParam + paramSet paginatorParamSet + limitArg *float64 + waitTime time.Duration + args []any + returnType reflect.Type + page int + currentPage any } -func (p *paginator) Continue() bool { - return p.page == 1 || len(p.currentPage) > 0 -} - -func (p *paginator) Page() []any { return p.currentPage } +func (p *paginator) Continue() bool { return p.page == 1 || reflect.ValueOf(p.currentPage).Len() > 0 } +func (p *paginator) Page() any { return p.currentPage } func (p *paginator) Next() (err error) { - args := []any{p.page} - args = append(args, p.args...) + var paginatorValues []any + if paginatorValues, err = p.paramSet.GetPaginatorParamValue(p.params, p.currentPage, p.page); err != nil { + err = errors.Wrapf( + err, "cannot get paginator param values from %T value on page %d", + p.currentPage, p.page, + ) + return + } + args := p.paramSet.InsertPaginatorParamValues(p.params, p.args, paginatorValues) + fmt.Println("paginatorValues", paginatorValues, len(paginatorValues)) + fmt.Println("args", args, len(args)) - var page any - if page, err = p.binding.Execute(p.client, args...); err != nil { - err = errors.Wrapf(err, "error occurred on page no. %d", p.page) + var ignoreFirstRequest bool + execute := func() (err error) { + if ignoreFirstRequest, p.usingRateLimitedClient, err = paginatorCheckRateLimit( + p.client, p.binding.Name(), &p.limitArg, p.page, p.currentPage, p.params, p.args, + ); err != nil { + return + } + + if p.currentPage, err = p.binding.Execute(p.client, args...); err != nil { + err = errors.Wrapf(err, "error occurred on page no. %d", p.page) + } return } - s := reflect.ValueOf(page) - p.currentPage = make([]any, 0) - for i := 0; i < s.Len(); i++ { - p.currentPage = append(p.currentPage, s.Index(i).Interface()) + if err = execute(); err != nil { + if !ignoreFirstRequest { + return + } + + if err = execute(); err != nil { + err = errors.Wrapf( + err, "error occurred on page no. 
%d, after ignoring the first request due to no rate limit", + p.page, + ) + return + } } p.page++ @@ -132,18 +401,32 @@ func (p *paginator) Next() (err error) { return } -func (p *paginator) All() ([]any, error) { - pages := make([]any, 0) +func (p *paginator) All() (any, error) { + pages := reflect.New(p.returnType).Elem() for p.Continue() { if err := p.Next(); err != nil { return pages, err } - pages = append(pages, p.Page()...) + pages = reflect.AppendSlice(pages, reflect.ValueOf(p.Page())) + } + return pages.Interface(), nil +} + +func (p *paginator) Pages(pageNo int) (any, error) { + pages := reflect.New(p.returnType).Elem() + for p.Continue() && p.page <= pageNo { + if err := p.Next(); err != nil { + return pages, err + } + pages = reflect.AppendSlice(pages, reflect.ValueOf(p.Page())) } - return pages, nil + return pages.Interface(), nil } -func NewPaginator(client Client, waitTime time.Duration, binding BindingWrapper, args ...any) (pag Paginator[[]any, []any], err error) { +// NewPaginator creates an un-typed Paginator for the given BindingWrapper. It creates a Paginator in a similar way as +// NewTypedPaginator, except the return type of the Paginator is any (holding a slice of the Binding's return type). See NewTypedPaginator for more information on +// Paginator construction.
+func NewPaginator(client Client, waitTime time.Duration, binding BindingWrapper, args ...any) (pag Paginator[any, any], err error) { if !binding.Paginated() { err = fmt.Errorf("cannot create a Paginator as Binding is not pagenatable") return @@ -152,18 +435,28 @@ func NewPaginator(client Client, waitTime time.Duration, binding BindingWrapper, p := &paginator{ client: client, binding: &binding, + params: binding.Params(), waitTime: waitTime, args: args, page: 1, } + p.rateLimitedClient, p.usingRateLimitedClient = client.(RateLimitedClient) + if p.paramSet = checkPaginatorParams(p.params); p.paramSet == unknownParamSet { + err = fmt.Errorf( + "cannot create a Paginator as we couldn't find any paginateable params, need one of the following sets of params %v", + unknownParamSet.Sets(), + ) + return + } + switch binding.returnType.Kind() { case reflect.Slice, reflect.Array: p.returnType = binding.returnType pag = p default: err = fmt.Errorf( - "cannot create typed Paginator for Binding[%v, %v] that has a non-slice/array return type", + "cannot create a Paginator for Binding[%v, %v] that has a non-slice/array return type", binding.responseType, binding.returnType, ) } diff --git a/config.go b/config.go index d440fa4..300ffeb 100644 --- a/config.go +++ b/config.go @@ -9,6 +9,7 @@ import ( "github.com/andygello555/game-scout/db/models" "github.com/andygello555/game-scout/email" "github.com/andygello555/game-scout/monday" + "github.com/andygello555/game-scout/reddit" task "github.com/andygello555/game-scout/tasks" myTwitter "github.com/andygello555/game-scout/twitter" "github.com/antonmedv/expr" @@ -345,6 +346,24 @@ func (c *TwitterConfig) TwitterQuery() string { return fmt.Sprintf("(%s) %s", query, strings.Join(hashtags, " ")) } +type RedditRateLimits struct { + RequestsPerMonth uint64 `json:"requests_per_month"` + RequestsPerWeek uint64 `json:"requests_per_week"` + RequestsPerDay uint64 `json:"requests_per_day"` + RequestsPerHour uint64 `json:"requests_per_hour"` + 
RequestsPerMinute uint64 `json:"requests_per_minute"` + RequestsPerSecond uint64 `json:"requests_per_second"` + TimePerRequest Duration `json:"time_per_request"` +} + +func (rl *RedditRateLimits) LimitPerMonth() uint64 { return rl.RequestsPerMonth } +func (rl *RedditRateLimits) LimitPerWeek() uint64 { return rl.RequestsPerWeek } +func (rl *RedditRateLimits) LimitPerDay() uint64 { return rl.RequestsPerDay } +func (rl *RedditRateLimits) LimitPerHour() uint64 { return rl.RequestsPerHour } +func (rl *RedditRateLimits) LimitPerMinute() uint64 { return rl.RequestsPerMinute } +func (rl *RedditRateLimits) LimitPerSecond() uint64 { return rl.RequestsPerSecond } +func (rl *RedditRateLimits) LimitPerRequest() time.Duration { return rl.TimePerRequest.Duration } + type RedditConfig struct { // PersonalUseScript is the ID of the personal use script that was set up for game-scout scraping. PersonalUseScript string `json:"personal_use_script"` @@ -358,14 +377,17 @@ type RedditConfig struct { Password string `json:"password"` // Subreddits is the list of subreddits to scrape in the form: "GameDevelopment" (sans "r/" prefix). Subreddits []string `json:"subreddits"` + // RateLimits contains the rate per-unit of time. 
+ RateLimits *RedditRateLimits `json:"rate_limits"` } -func (rc *RedditConfig) RedditPersonalUseScript() string { return rc.PersonalUseScript } -func (rc *RedditConfig) RedditSecret() string { return rc.Secret } -func (rc *RedditConfig) RedditUserAgent() string { return rc.UserAgent } -func (rc *RedditConfig) RedditUsername() string { return rc.Username } -func (rc *RedditConfig) RedditPassword() string { return rc.Password } -func (rc *RedditConfig) RedditSubreddits() []string { return rc.Subreddits } +func (rc *RedditConfig) RedditPersonalUseScript() string { return rc.PersonalUseScript } +func (rc *RedditConfig) RedditSecret() string { return rc.Secret } +func (rc *RedditConfig) RedditUserAgent() string { return rc.UserAgent } +func (rc *RedditConfig) RedditUsername() string { return rc.Username } +func (rc *RedditConfig) RedditPassword() string { return rc.Password } +func (rc *RedditConfig) RedditSubreddits() []string { return rc.Subreddits } +func (rc *RedditConfig) RedditRateLimits() reddit.RateLimitConfig { return rc.RateLimits } type MondayMappingConfig struct { // ModelName is the name of the model that this MondayMappingConfig is for. 
This should either be "models.SteamApp" diff --git a/go.mod b/go.mod index 1d85dd0..0c12fc7 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae github.com/RichardKnop/machinery v1.10.6 github.com/anaskhan96/soup v1.2.5 - github.com/andygello555/gotils/v2 v2.1.2 + github.com/andygello555/gotils/v2 v2.1.5 github.com/deckarep/golang-set/v2 v2.1.0 github.com/g8rswimmer/go-twitter/v2 v2.1.2 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index a69e5c4..bb5083d 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,10 @@ github.com/andygello555/gotils/v2 v2.1.0 h1:/nL14tDKiVyVNY/iBU8nQZHVYmQF4A5WoUbW github.com/andygello555/gotils/v2 v2.1.0/go.mod h1:w+9GpBTgvBubwl3lrzmE3dUDBagzQQ+R35W9p92dAnY= github.com/andygello555/gotils/v2 v2.1.2 h1:dpZfws5X3y0jxwaoZcSMBDxr0T0KeeenNNjHW029Kf0= github.com/andygello555/gotils/v2 v2.1.2/go.mod h1:w+9GpBTgvBubwl3lrzmE3dUDBagzQQ+R35W9p92dAnY= +github.com/andygello555/gotils/v2 v2.1.4 h1:VuHTJRJKzIeFydvykBKwPPPF2mB8viskXyZj1dHBQ2U= +github.com/andygello555/gotils/v2 v2.1.4/go.mod h1:w+9GpBTgvBubwl3lrzmE3dUDBagzQQ+R35W9p92dAnY= +github.com/andygello555/gotils/v2 v2.1.5 h1:W+CVLajyQDCaZ7u5jLTBpBGQrNYA1YqRISuD8+DZEyM= +github.com/andygello555/gotils/v2 v2.1.5/go.mod h1:w+9GpBTgvBubwl3lrzmE3dUDBagzQQ+R35W9p92dAnY= github.com/antonmedv/expr v1.12.0 h1:hIOn7jjY86E09PXvn9zgdt2FbWVru0ud9Rm5DbNoYNw= github.com/antonmedv/expr v1.12.0/go.mod h1:FPC8iWArxls7axbVLsW+kpg1mz29A1b2M6jt+hZfDkU= github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= diff --git a/main.go b/main.go index aeae4f0..70b2de3 100644 --- a/main.go +++ b/main.go @@ -1171,12 +1171,13 @@ func main() { return all, err } } - if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding("getboards", monday.GetBoards), args...); err != nil { + + if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, 
api.WrapBinding(monday.GetBoards), args...); err != nil { return nil, err } else { all, err := paginator.All() if err == nil { - fmt.Println("paginator items:", len(all)) + fmt.Println("paginator items:", reflect.ValueOf(all).Len()) } return all, err } @@ -1238,12 +1239,12 @@ func main() { return all, err } } - if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding("getitems", monday.GetItems), args...); err != nil { + if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding(monday.GetItems), args...); err != nil { return nil, err } else { all, err := paginator.All() if err == nil { - fmt.Println("paginator items:", len(all)) + fmt.Println("paginator items:", reflect.ValueOf(all).Len()) } return all, err } @@ -1306,12 +1307,12 @@ func main() { return all, err } } - if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding("getgames", models.GetGamesFromMonday), globalConfig.Monday, db.DB); err != nil { + if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding(models.GetGamesFromMonday), globalConfig.Monday, db.DB); err != nil { return nil, err } else { all, err := paginator.All() if err == nil { - fmt.Println("paginator items:", len(all)) + fmt.Println("paginator items:", reflect.ValueOf(all).Len()) } return all, err } @@ -1334,12 +1335,12 @@ func main() { return all, err } } - if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding("getgames", models.GetSteamAppsFromMonday), globalConfig.Monday, db.DB); err != nil { + if paginator, err := api.NewPaginator(monday.DefaultClient, time.Millisecond*100, api.WrapBinding(models.GetSteamAppsFromMonday), globalConfig.Monday, db.DB); err != nil { return nil, err } else { all, err := paginator.All() if err == nil { - fmt.Println("paginator items:", len(all)) + fmt.Println("paginator items:", reflect.ValueOf(all).Len()) } 
return all, err } @@ -1415,9 +1416,10 @@ func main() { Usage: "an arg to execute the binding with", Value: &cli.StringSlice{}, }, - cli.BoolFlag{ - Name: "all", - Usage: "fetch all the resources from the binding using a Paginator", + cli.IntFlag{ + Name: "pages", + Required: false, + Usage: "fetch the given number of pages from the binding using a Paginator", }, }, Action: func(c *cli.Context) (err error) { @@ -1428,13 +1430,25 @@ func main() { response any args []any ) - if args, err = reddit.API.ArgsFromStrings(bindingName, c.StringSlice("arg")...); err != nil { + bindingWrapper, _ := reddit.API.Binding(bindingName) + if args, err = bindingWrapper.ArgsFromStrings(c.StringSlice("arg")...); err != nil { return cli.NewExitError(err.Error(), 1) } - if response, err = reddit.API.Execute(bindingName, args...); err != nil { - return cli.NewExitError(err.Error(), 1) + + if c.Int("pages") > 0 { + var paginator api.Paginator[any, any] + if paginator, err = bindingWrapper.Paginator(reddit.API.Client, time.Millisecond*500, args...); err != nil { + return cli.NewExitError(err.Error(), 1) + } + if response, err = paginator.Pages(c.Int("pages")); err != nil { + return cli.NewExitError(err.Error(), 1) + } + fmt.Println("paginator items:", reflect.ValueOf(response).Len()) + } else { + if response, err = bindingWrapper.Execute(reddit.API.Client, args...); err != nil { + return cli.NewExitError(err.Error(), 1) + } } - bindingWrapper, _ := reddit.API.Binding(bindingName) fmt.Printf("%v(%v): %+v\n", bindingWrapper, args, response) } return diff --git a/monday/client.go b/monday/client.go index 6d0b2c4..b0b89b4 100644 --- a/monday/client.go +++ b/monday/client.go @@ -361,7 +361,7 @@ type Client struct { *graphql.Client } -func (c *Client) Run(ctx context.Context, attrs map[string]any, req api.Request, res any) error { +func (c *Client) Run(ctx context.Context, bindingName string, attrs map[string]any, req api.Request, res any) error { config := attrs["config"].(Config) 
req.Header().Set("Authorization", config.MondayToken()) req.Header().Set("Content-Type", "application/json") diff --git a/reddit/bindings.go b/reddit/bindings.go index 7c4759b..1c45086 100644 --- a/reddit/bindings.go +++ b/reddit/bindings.go @@ -10,7 +10,7 @@ import ( ) var API = api.NewAPI(nil, api.Schema{ - "access_token": api.WrapBinding("access_token", api.NewBindingChain(func(binding api.Binding[accessTokenResponse, AccessToken], args ...any) (request api.Request) { + "access_token": api.WrapBinding(api.NewBindingChain(func(binding api.Binding[accessTokenResponse, AccessToken], args ...any) (request api.Request) { client := binding.Attrs()["client"].(*Client) data := url.Values{ "grant_type": {"password"}, @@ -31,34 +31,35 @@ var API = api.NewAPI(nil, api.Schema{ return *client.AccessToken }).AddAttrs( func(client api.Client) (string, any) { return "client", client }, - func(client api.Client) (string, any) { return "access_token", true }, - func(client api.Client) (string, any) { return "binding", "access_token" }, - )), + ).SetName("access_token")), - "me": api.WrapBinding("me", api.NewBindingChain(func(binding api.Binding[Me, Me], args ...any) (request api.Request) { + "me": api.WrapBinding(api.NewBindingChain(func(binding api.Binding[Me, Me], args ...any) (request api.Request) { req, _ := http.NewRequest(http.MethodGet, "https://oauth.reddit.com/api/v1/me", http.NoBody) return api.HTTPRequest{Request: req} - }).AddAttrs( - func(client api.Client) (string, any) { return "binding", "me" }, - )), + }).SetName("me")), - "top": api.WrapBinding("top", api.NewBindingChain(func(binding api.Binding[listingWrapper, Listing], args ...any) (request api.Request) { + "top": api.WrapBinding(api.NewBindingChain(func(binding api.Binding[listingWrapper, Listings], args ...any) (request api.Request) { subreddit := args[0].(string) timePeriod := args[1].(TimePeriod) + after := args[2].(string) + limit := args[3].(int) req, _ := http.NewRequest( http.MethodGet, fmt.Sprintf( - 
"https://oauth.reddit.com/r/%s/top.json?t=%s", - subreddit, timePeriod, + "https://oauth.reddit.com/r/%s/top.json?t=%s&limit=%d&after=%s", + subreddit, timePeriod, limit, after, ), http.NoBody, ) return api.HTTPRequest{Request: req} - }).SetResponseMethod(func(binding api.Binding[listingWrapper, Listing], response listingWrapper, args ...any) Listing { - return response.Data - }).SetParamsMethod(func(binding api.Binding[listingWrapper, Listing]) []api.BindingParam { - return api.Params("subreddit", "", true, "timePeriod", Day) - }).AddAttrs( - func(client api.Client) (string, any) { return "binding", "top" }, - )), + }).SetResponseMethod(func(binding api.Binding[listingWrapper, Listings], response listingWrapper, args ...any) Listings { + return []Listing{response.Data} + }).SetParamsMethod(func(binding api.Binding[listingWrapper, Listings]) []api.BindingParam { + return api.Params( + "subreddit", "", true, + "timePeriod", Day, + "after", "", + "limit", 25, + ) + }).SetPaginated(true).SetName("top")), }) diff --git a/reddit/client.go b/reddit/client.go index 6d9a734..616daa5 100644 --- a/reddit/client.go +++ b/reddit/client.go @@ -40,9 +40,9 @@ func (at AccessToken) Headers() http.Header { } type RateLimit struct { - Used int64 - Remaining int64 - Reset time.Time + used int64 + remaining int64 + reset time.Time } // RateLimitFromHeader returns a new RateLimit instance from the given http.Header by fetching and parsing the values @@ -52,12 +52,12 @@ type RateLimit struct { // - X-Ratelimit-Used func RateLimitFromHeader(header http.Header) (rl *RateLimit, err error) { rl = &RateLimit{} - if rl.Remaining, err = strconv.ParseInt(header.Get("X-Ratelimit-Remaining"), 10, 64); err != nil { + if rl.remaining, err = strconv.ParseInt(header.Get("X-Ratelimit-Remaining"), 10, 64); err != nil { err = errors.Wrap(err, "cannot parse \"X-Ratelimit-Remaing\" header to int") return } - if rl.Used, err = strconv.ParseInt(header.Get("X-Ratelimit-Used"), 10, 64); err != nil { + if 
rl.used, err = strconv.ParseInt(header.Get("X-Ratelimit-Used"), 10, 64); err != nil { err = errors.Wrap(err, "cannot parse \"X-Ratelimit-Used\" header to int") return } @@ -67,14 +67,58 @@ func RateLimitFromHeader(header http.Header) (rl *RateLimit, err error) { err = errors.Wrap(err, "cannot \"X-Ratelimit-Reset\" header to int") return } - rl.Reset = time.Now().Add(time.Second * time.Duration(resetSeconds)) + rl.reset = time.Now().Add(time.Second * time.Duration(resetSeconds)) return } +func (r *RateLimit) Reset() time.Time { return r.reset } +func (r *RateLimit) Remaining() int { return int(r.remaining) } +func (r *RateLimit) Used() int { return int(r.used) } +func (r *RateLimit) Type() api.RateLimitType { return api.RequestRateLimit } + +func (r *RateLimit) String() string { + return fmt.Sprintf( + "%d/%d used (%.2f%%) resets %s", + r.used, r.used+r.remaining, (float64(r.used)/(float64(r.used)+float64(r.remaining)))*100.0, r.reset.String(), + ) +} + type Client struct { Config Config AccessToken *AccessToken - RateLimits sync.Map + // rateLimits is a sync.Map of binding names to references to api.RateLimit(s). I.e. + // map[string]api.RateLimit + rateLimits sync.Map +} + +func (c *Client) RateLimits() *sync.Map { return &c.rateLimits } + +func (c *Client) AddRateLimit(bindingName string, rateLimit api.RateLimit) { + if rlAny, ok := c.rateLimits.Load(bindingName); ok { + // If there is already a RateLimit for this binding, check if the rate-limit returned by the current request is + // newer. 
+ if rateLimit.Reset().After(rlAny.(api.RateLimit).Reset()) { + c.rateLimits.Store(bindingName, rateLimit) + } + } else { + c.rateLimits.Store(bindingName, rateLimit) + } +} + +func (c *Client) LatestRateLimit(bindingName string) api.RateLimit { + var latestRateLimit api.RateLimit + c.rateLimits.Range(func(key, value any) bool { + rl := value.(api.RateLimit) + if latestRateLimit == nil { + latestRateLimit = rl + } else { + if latestRateLimit.Reset().Before(rl.Reset()) { + latestRateLimit = rl + } + } + return true + }) + return latestRateLimit } func CreateClient(config Config) { @@ -91,13 +135,13 @@ func (c *Client) RefreshToken() (err error) { return } -func (c *Client) Run(ctx context.Context, attrs map[string]any, req api.Request, res any) (err error) { +func (c *Client) Run(ctx context.Context, bindingName string, attrs map[string]any, req api.Request, res any) (err error) { request := req.(api.HTTPRequest).Request request.Header.Set("User-Agent", c.Config.RedditUserAgent()) - binding := attrs["binding"].(string) + fmt.Printf("latest rate limit for %s = %v\n", bindingName, c.LatestRateLimit(bindingName)) // If we are not currently fetching the AccessToken, then we will check if we need to refresh the access token. - if binding != "access_token" { + if bindingName != "access_token" { if err = c.RefreshToken(); err != nil { return } @@ -117,7 +161,7 @@ func (c *Client) Run(ctx context.Context, attrs map[string]any, req api.Request, return } - var rl *RateLimit + var rl api.RateLimit if rl, err = RateLimitFromHeader(response.Header); err != nil { err = errors.Wrapf( err, "could not parse RateLimit from response headers to %s %s", @@ -125,16 +169,15 @@ func (c *Client) Run(ctx context.Context, attrs map[string]any, req api.Request, ) return } - - if rlAny, ok := c.RateLimits.Load(binding); ok { - // If there is already a RateLimit for this binding, check if the rate-limit returned by the current request is - // newer. 
- if rl.Reset.After(rlAny.(*RateLimit).Reset) { - c.RateLimits.Store(binding, rl) - } - } else { - c.RateLimits.Store(binding, rl) - } + c.AddRateLimit(bindingName, rl) + + i := 0 + c.rateLimits.Range(func(key, value any) bool { + fmt.Printf("%d: %s - %q", i+1, key, value) + i++ + return true + }) + fmt.Println() if response.Body != nil { defer func(body io.ReadCloser) { diff --git a/reddit/config.go b/reddit/config.go index 4a9791c..28b144b 100644 --- a/reddit/config.go +++ b/reddit/config.go @@ -1,5 +1,17 @@ package reddit +import "time" + +type RateLimitConfig interface { + LimitPerMonth() uint64 + LimitPerWeek() uint64 + LimitPerDay() uint64 + LimitPerHour() uint64 + LimitPerMinute() uint64 + LimitPerSecond() uint64 + LimitPerRequest() time.Duration +} + type Config interface { RedditPersonalUseScript() string RedditSecret() string @@ -7,4 +19,5 @@ type Config interface { RedditUsername() string RedditPassword() string RedditSubreddits() []string + RedditRateLimits() RateLimitConfig } diff --git a/reddit/types.go b/reddit/types.go index e5ced02..b6b0ee8 100644 --- a/reddit/types.go +++ b/reddit/types.go @@ -1,5 +1,38 @@ package reddit +import ( + "encoding/json" + "fmt" + "github.com/anaskhan96/soup" + "html" + "strconv" + "time" +) + +const ( + kindComment = "t1" + kindUser = "t2" + kindPost = "t3" + kindMessage = "t4" + kindSubreddit = "t5" + kindTrophy = "t6" + kindListing = "Listing" + kindSubredditSettings = "subreddit_settings" + kindKarmaList = "KarmaList" + kindTrophyList = "TrophyList" + kindUserList = "UserList" + kindMore = "more" + kindLiveThread = "LiveUpdateEvent" + kindLiveThreadUpdate = "LiveUpdate" + kindModAction = "modaction" + kindMulti = "LabeledMulti" + kindMultiDescription = "LabeledMultiDescription" + kindWikiPage = "wikipage" + kindWikiPageListing = "wikipagelisting" + kindWikiPageSettings = "wikipagesettings" + kindStyleSheet = "stylesheet" +) + type TimePeriod string const ( @@ -34,6 +67,47 @@ func (tp TimePeriod) String() string { 
return string(tp) } +// Timestamp represents a time that can be unmarshalled from a JSON string +// formatted as either an RFC3339 or Unix timestamp. +type Timestamp struct { + time.Time +} + +// MarshalJSON implements the json.Marshaler interface. +func (t *Timestamp) MarshalJSON() ([]byte, error) { + if t == nil || t.Time.IsZero() { + return []byte(`false`), nil + } + + parsed := t.Time.Format(time.RFC3339) + return []byte(`"` + parsed + `"`), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +// Time is expected in RFC3339 or Unix format. +func (t *Timestamp) UnmarshalJSON(data []byte) (err error) { + str := string(data) + + // "edited" for posts and comments is either false, or a timestamp. + if str == "false" { + return + } + + f, err := strconv.ParseFloat(str, 64) + if err == nil { + t.Time = time.Unix(int64(f), 0).UTC() + } else { + t.Time, err = time.Parse(`"`+time.RFC3339+`"`, str) + } + + return +} + +// Equal reports whether t and u are equal based on time.Equal +func (t Timestamp) Equal(u Timestamp) bool { + return t.Time.Equal(u.Time) +} + type Me struct { IsEmployee bool `json:"is_employee"` SeenLayoutSwitch bool `json:"seen_layout_switch"` @@ -190,154 +264,348 @@ type listingWrapper struct { Kind string `json:"kind"` } +type Thing struct { + Kind string `json:"kind"` + Data any `json:"data"` +} + +// UnmarshalJSON implements the json.Unmarshaler interface. 
+func (t *Thing) UnmarshalJSON(b []byte) (err error) { + root := new(struct { + Kind string `json:"kind"` + Data json.RawMessage `json:"data"` + }) + + if err = json.Unmarshal(b, root); err != nil { + return err + } + + t.Kind = root.Kind + var v any + switch t.Kind { + case kindListing: + v = new(Listing) + case kindPost: + v = new(Post) + case kindComment: + v = new(Comment) + case kindMore: + v = new(More) + default: + err = fmt.Errorf("unrecognised kind %s", t.Kind) + return + } + + if err = json.Unmarshal(root.Data, v); err != nil { + return err + } + t.Data = v + return +} + +type Things struct { + Comments []*Comment + Posts []*Post + Mores []*More +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (t *Things) UnmarshalJSON(b []byte) error { + var things []Thing + if err := json.Unmarshal(b, &things); err != nil { + return err + } + + t.add(things...) + return nil +} + +func (t *Things) add(things ...Thing) { + for _, thing := range things { + switch v := thing.Data.(type) { + case *Post: + t.Posts = append(t.Posts, v) + case *Comment: + t.Comments = append(t.Comments, v) + case *More: + t.Mores = append(t.Mores, v) + } + } +} + +type Listings []Listing + +func (l Listings) After() any { + if len(l) > 0 { + return l[len(l)-1].After() + } + return "" +} + type Listing struct { - After string `json:"after"` - Before string `json:"before"` - Children []ListingChild `json:"children"` - Dist int `json:"dist"` - GeoFilter string `json:"geo_filter"` - Modhash string `json:"modhash"` + after string `json:"after"` + Before string `json:"before"` + Children Things `json:"children"` + Dist int `json:"dist"` + GeoFilter string `json:"geo_filter"` + Modhash string `json:"modhash"` } -type ListingChild struct { - Data ListingData `json:"data"` - Kind string `json:"kind"` +func (l *Listing) After() any { return l.after } + +// UnmarshalJSON implements the json.Unmarshaler interface. 
+func (l *Listing) UnmarshalJSON(b []byte) error { + root := new(struct { + After string `json:"after"` + Before string `json:"before"` + Children Things `json:"children"` + Dist int `json:"dist"` + GeoFilter string `json:"geo_filter"` + Modhash string `json:"modhash"` + }) + + err := json.Unmarshal(b, root) + if err != nil { + return err + } + + l.after = root.After + l.Before = root.Before + l.Children = root.Children + l.Dist = root.Dist + l.GeoFilter = root.GeoFilter + l.Modhash = root.Modhash + + return nil } -type ListingGildings struct { +// Post represents a post on a subreddit. +type Post struct { + ID string `json:"id,omitempty"` + FullID string `json:"name,omitempty"` + Created *Timestamp `json:"created_utc,omitempty"` + Edited *Timestamp `json:"edited,omitempty"` + Permalink string `json:"permalink,omitempty"` + URL string `json:"url,omitempty"` + Title string `json:"title,omitempty"` + Body string `json:"selftext,omitempty"` + BodyHTML string `json:"selftext_html,omitempty"` + // Indicates if you've upvoted/downvoted (true/false). + // If neither, it will be nil. 
+ Likes *bool `json:"likes"` + Ups int `json:"ups"` + Downs int `json:"downs"` + Score int `json:"score"` + UpvoteRatio float32 `json:"upvote_ratio"` + NumberOfComments int `json:"num_comments"` + SubredditName string `json:"subreddit,omitempty"` + SubredditNamePrefixed string `json:"subreddit_name_prefixed,omitempty"` + SubredditID string `json:"subreddit_id,omitempty"` + SubredditSubscribers int `json:"subreddit_subscribers"` + Author string `json:"author,omitempty"` + AuthorID string `json:"author_fullname,omitempty"` + Spoiler bool `json:"spoiler"` + Locked bool `json:"locked"` + NSFW bool `json:"over_18"` + IsSelfPost bool `json:"is_self"` + Saved bool `json:"saved"` + Stickied bool `json:"stickied"` } -type RedditVideo struct { - BitrateKbps int `json:"bitrate_kbps"` - DashURL string `json:"dash_url"` - Duration int `json:"duration"` - FallbackURL string `json:"fallback_url"` - Height int `json:"height"` - HlsURL string `json:"hls_url"` - IsGif bool `json:"is_gif"` - ScrubberMediaURL string `json:"scrubber_media_url"` - TranscodingStatus string `json:"transcoding_status"` - Width int `json:"width"` +func (p *Post) String() string { + return fmt.Sprintf( + `{ID: %v, FullID: %v, Created: %v, Edited: %v, Permalink: %v, URL: %v, Title: %v, Body: %v, Likes: %v, Ups: %v, Downs: %v, Score: %v, UpvoteRatio: %v, NumberOfComments: %v, SubredditName: %v, SubredditNamePrefixed: %v, SubredditID: %v, SubredditSubscribers: %v, Author: %v, AuthorID: %v, Spoiler: %v, Locked: %v, NSFW: %v, IsSelfPost: %v, Saved: %v, Stickied: %v}`, + p.ID, + p.FullID, + p.Created, + p.Edited, + p.Permalink, + p.URL, + p.Title, + p.Body, + p.Likes, + p.Ups, + p.Downs, + p.Score, + p.UpvoteRatio, + p.NumberOfComments, + p.SubredditName, + p.SubredditNamePrefixed, + p.SubredditID, + p.SubredditSubscribers, + p.Author, + p.AuthorID, + p.Spoiler, + p.Locked, + p.NSFW, + p.IsSelfPost, + p.Saved, + p.Stickied, + ) } -type ListingMedia struct { - RedditVideo RedditVideo `json:"reddit_video"` +func (p 
*Post) Soup() soup.Root { + return soup.HTMLParse(html.UnescapeString(p.BodyHTML)) } -type ListingMediaEmbed struct { +// Comment is a comment on a post. +type Comment struct { + ID string `json:"id,omitempty"` + FullID string `json:"name,omitempty"` + Created *Timestamp `json:"created_utc,omitempty"` + Edited *Timestamp `json:"edited,omitempty"` + ParentID string `json:"parent_id,omitempty"` + Permalink string `json:"permalink,omitempty"` + Body string `json:"body,omitempty"` + BodyHTML string `json:"body_html,omitempty"` + Author string `json:"author,omitempty"` + AuthorID string `json:"author_fullname,omitempty"` + AuthorFlairText string `json:"author_flair_text,omitempty"` + AuthorFlairID string `json:"author_flair_template_id,omitempty"` + SubredditName string `json:"subreddit,omitempty"` + SubredditNamePrefixed string `json:"subreddit_name_prefixed,omitempty"` + SubredditID string `json:"subreddit_id,omitempty"` + // Indicates if you've upvote/downvoted (true/false). + // If neither, it will be nil. + Likes *bool `json:"likes"` + Score int `json:"score"` + Controversiality int `json:"controversiality"` + PostID string `json:"link_id,omitempty"` + // This doesn't appear consistently. + PostTitle string `json:"link_title,omitempty"` + // This doesn't appear consistently. + PostPermalink string `json:"link_permalink,omitempty"` + // This doesn't appear consistently. + PostAuthor string `json:"link_author,omitempty"` + // This doesn't appear consistently. 
+ PostNumComments *int `json:"num_comments,omitempty"` + IsSubmitter bool `json:"is_submitter"` + ScoreHidden bool `json:"score_hidden"` + Saved bool `json:"saved"` + Stickied bool `json:"stickied"` + Locked bool `json:"locked"` + CanGild bool `json:"can_gild"` + NSFW bool `json:"over_18"` + Replies Replies `json:"replies"` } -type ListingSecureMedia struct { - RedditVideo RedditVideo `json:"reddit_video"` +func (c *Comment) Soup() soup.Root { + return soup.HTMLParse(html.UnescapeString(c.BodyHTML)) +} + +// HasMore determines whether the comment has more replies to load in its reply tree. +func (c *Comment) HasMore() bool { + return c.Replies.More != nil && len(c.Replies.More.Children) > 0 +} + +// addCommentToReplies traverses the comment tree to find the one +// that the 2nd comment is replying to. It then adds it to its replies. +func (c *Comment) addCommentToReplies(comment *Comment) { + if c.FullID == comment.ParentID { + c.Replies.Comments = append(c.Replies.Comments, comment) + return + } + + for _, reply := range c.Replies.Comments { + reply.addCommentToReplies(comment) + } } -type ListingSecureMediaEmbed struct { +func (c *Comment) addMoreToReplies(more *More) { + if c.FullID == more.ParentID { + c.Replies.More = more + return + } + + for _, reply := range c.Replies.Comments { + reply.addMoreToReplies(more) + } +} + +func (c *Comment) String() string { + return fmt.Sprintf( + `{ID: %v, FullID: %v, Created: %v, Edited: %v, ParentID: %v, Permalink: %v, Body: %v, BodyHTML: %v, Author: %v, AuthorID: %v, AuthorFlairText: %v, AuthorFlairID: %v, SubredditName: %v, SubredditNamePrefixed: %v, SubredditID: %v, Likes: %v, Score: %v, Controversiality: %v, PostID: %v, PostTitle: %v, PostPermalink: %v, PostAuthor: %v, PostNumComments: %v, IsSubmitter: %v, ScoreHidden: %v, Saved: %v, Stickied: %v, Locked: %v, CanGild: %v, NSFW: %v, Replies: %v}`, + c.ID, + c.FullID, + c.Created, + c.Edited, + c.ParentID, + c.Permalink, + c.Body, + c.BodyHTML, + c.Author, + c.AuthorID, + 
c.AuthorFlairText, + c.AuthorFlairID, + c.SubredditName, + c.SubredditNamePrefixed, + c.SubredditID, + c.Likes, + c.Score, + c.Controversiality, + c.PostID, + c.PostTitle, + c.PostPermalink, + c.PostAuthor, + c.PostNumComments, + c.IsSubmitter, + c.ScoreHidden, + c.Saved, + c.Stickied, + c.Locked, + c.CanGild, + c.NSFW, + c.Replies, + ) +} + +// Replies holds replies to a comment. +// It contains both comments and "more" comments, which are entrypoints to other +// comments that were left out. +type Replies struct { + Comments []*Comment `json:"comments,omitempty"` + More *More `json:"-"` +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (r *Replies) UnmarshalJSON(data []byte) error { + // if a comment has no replies, its "replies" field is set to "" + if string(data) == `""` { + r = nil + return nil + } + + root := new(Thing) + err := json.Unmarshal(data, root) + if err != nil { + return err + } + + listing, _ := root.Data.(*Listing) + + r.Comments = listing.Children.Comments + if len(listing.Children.Mores) > 0 { + r.More = listing.Children.Mores[0] + } + return nil +} + +// MarshalJSON implements the json.Marshaler interface. 
+func (r *Replies) MarshalJSON() ([]byte, error) { + if r == nil || len(r.Comments) == 0 { + return []byte(`null`), nil + } + return json.Marshal(r.Comments) } -type ListingData struct { - AllAwardings []interface{} `json:"all_awardings"` - AllowLiveComments bool `json:"allow_live_comments"` - ApprovedAtUtc interface{} `json:"approved_at_utc"` - ApprovedBy interface{} `json:"approved_by"` - Archived bool `json:"archived"` - Author string `json:"author"` - AuthorFlairBackgroundColor interface{} `json:"author_flair_background_color"` - AuthorFlairCSSClass interface{} `json:"author_flair_css_class"` - AuthorFlairRichtext []interface{} `json:"author_flair_richtext"` - AuthorFlairTemplateID interface{} `json:"author_flair_template_id"` - AuthorFlairText interface{} `json:"author_flair_text"` - AuthorFlairTextColor interface{} `json:"author_flair_text_color"` - AuthorFlairType string `json:"author_flair_type"` - AuthorFullname string `json:"author_fullname"` - AuthorIsBlocked bool `json:"author_is_blocked"` - AuthorPatreonFlair bool `json:"author_patreon_flair"` - AuthorPremium bool `json:"author_premium"` - Awarders []interface{} `json:"awarders"` - BannedAtUtc interface{} `json:"banned_at_utc"` - BannedBy interface{} `json:"banned_by"` - CanGild bool `json:"can_gild"` - CanModPost bool `json:"can_mod_post"` - Category interface{} `json:"category"` - Clicked bool `json:"clicked"` - ContentCategories interface{} `json:"content_categories"` - ContestMode bool `json:"contest_mode"` - Created float64 `json:"created"` - CreatedUtc float64 `json:"created_utc"` - DiscussionType interface{} `json:"discussion_type"` - Distinguished interface{} `json:"distinguished"` - Domain string `json:"domain"` - Downs int `json:"downs"` - Edited bool `json:"edited"` - Gilded int `json:"gilded"` - Gildings ListingGildings `json:"gildings"` - Hidden bool `json:"hidden"` - HideScore bool `json:"hide_score"` - ID string `json:"id"` - IsCreatedFromAdsUI bool `json:"is_created_from_ads_ui"` - 
IsCrosspostable bool `json:"is_crosspostable"` - IsMeta bool `json:"is_meta"` - IsOriginalContent bool `json:"is_original_content"` - IsRedditMediaDomain bool `json:"is_reddit_media_domain"` - IsRobotIndexable bool `json:"is_robot_indexable"` - IsSelf bool `json:"is_self"` - IsVideo bool `json:"is_video"` - Likes interface{} `json:"likes"` - LinkFlairBackgroundColor string `json:"link_flair_background_color"` - LinkFlairCSSClass interface{} `json:"link_flair_css_class"` - LinkFlairRichtext []interface{} `json:"link_flair_richtext"` - LinkFlairText interface{} `json:"link_flair_text"` - LinkFlairTextColor string `json:"link_flair_text_color"` - LinkFlairType string `json:"link_flair_type"` - Locked bool `json:"locked"` - Media ListingMedia `json:"media"` - MediaEmbed ListingMediaEmbed `json:"media_embed"` - MediaOnly bool `json:"media_only"` - ModNote interface{} `json:"mod_note"` - ModReasonBy interface{} `json:"mod_reason_by"` - ModReasonTitle interface{} `json:"mod_reason_title"` - ModReports []interface{} `json:"mod_reports"` - Name string `json:"name"` - NoFollow bool `json:"no_follow"` - NumComments int `json:"num_comments"` - NumCrossposts int `json:"num_crossposts"` - NumReports interface{} `json:"num_reports"` - Over18 bool `json:"over_18"` - ParentWhitelistStatus string `json:"parent_whitelist_status"` - Permalink string `json:"permalink"` - Pinned bool `json:"pinned"` - Pwls int `json:"pwls"` - Quarantine bool `json:"quarantine"` - RemovalReason interface{} `json:"removal_reason"` - RemovedBy interface{} `json:"removed_by"` - RemovedByCategory interface{} `json:"removed_by_category"` - ReportReasons interface{} `json:"report_reasons"` - Saved bool `json:"saved"` - Score int `json:"score"` - SecureMedia ListingSecureMedia `json:"secure_media"` - SecureMediaEmbed ListingSecureMediaEmbed `json:"secure_media_embed"` - Selftext string `json:"selftext"` - SelftextHTML interface{} `json:"selftext_html"` - SendReplies bool `json:"send_replies"` - Spoiler bool 
`json:"spoiler"` - Stickied bool `json:"stickied"` - Subreddit string `json:"subreddit"` - SubredditID string `json:"subreddit_id"` - SubredditNamePrefixed string `json:"subreddit_name_prefixed"` - SubredditSubscribers int `json:"subreddit_subscribers"` - SubredditType string `json:"subreddit_type"` - SuggestedSort interface{} `json:"suggested_sort"` - Thumbnail string `json:"thumbnail"` - Title string `json:"title"` - TopAwardedType interface{} `json:"top_awarded_type"` - TotalAwardsReceived int `json:"total_awards_received"` - TreatmentTags []interface{} `json:"treatment_tags"` - Ups int `json:"ups"` - UpvoteRatio float64 `json:"upvote_ratio"` - URL string `json:"url"` - URLOverriddenByDest string `json:"url_overridden_by_dest"` - UserReports []interface{} `json:"user_reports"` - ViewCount interface{} `json:"view_count"` - Visited bool `json:"visited"` - WhitelistStatus string `json:"whitelist_status"` - Wls int `json:"wls"` +// More holds information used to retrieve additional comments omitted from a base comment tree. +type More struct { + ID string `json:"id"` + FullID string `json:"name"` + ParentID string `json:"parent_id"` + // Count is the total number of replies to the parent + replies to those replies (recursively). + Count int `json:"count"` + // Depth is the number of comment nodes from the parent down to the furthest comment node. + Depth int `json:"depth"` + Children []string `json:"children"` }