diff --git a/README.md b/README.md index 74051a3493fe..5f1917b67176 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Note that as network usage increases, hardware requirements may change. - CPU: Equivalent of 8 AWS vCPU - RAM: 16 GB - Storage: 200 GB -- OS: Ubuntu 18.04/20.04 or MacOS >= Catalina +- OS: Ubuntu 18.04/20.04 or macOS >= 10.15 (Catalina) - Network: Reliable IPv4 or IPv6 network connection, with an open public port. - Software Dependencies: - [Go](https://golang.org/doc/install) version >= 1.16.8 and set up [`$GOPATH`](https://github.com/golang/go/wiki/SettingGOPATH). diff --git a/RELEASES.md b/RELEASES.md index c7422d73c34b..7c4213cf1f5c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,37 @@ # Release Notes +## [v1.7.2](https://github.com/ava-labs/avalanchego/releases/tag/v1.7.2) + +This version is backwards compatible to [v1.7.0](https://github.com/ava-labs/avalanchego/releases/tag/v1.7.0). It is optional, but encouraged. + +### Coreth + +- Fixed memory leak in the estimate gas API. +- Reduced the default RPC gas limit to 50,000,000 gas. +- Improved RPC logging. +- Removed pre-AP5 legacy code. + +### PlatformVM + +- Optimized validator set change calculations. +- Removed storage of non-decided blocks. +- Simplified error handling. +- Removed pre-AP5 legacy code. + +### Networking + +- Explicitly fail requests with responses that failed to be parsed. +- Removed pre-AP5 legacy code. + +### Configs + +- Introduced the ability for a delayed graceful node shutdown. +- Added the ability to take all configs as environment variables for containerized deployments. + +### Utils + +- Fixed panic bug in logging library when importing from external projects. + ## [v1.7.1](https://github.com/ava-labs/avalanchego/releases/tag/v1.7.1) This update is backwards compatible with [v1.7.0](https://github.com/ava-labs/avalanchego/releases/tag/v1.7.0). Please see the expected update times in the v1.7.0 release. 
diff --git a/api/auth/auth.go b/api/auth/auth.go index 45eccf513008..c4574e9528b8 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -261,7 +261,7 @@ func (a *auth) CreateHandler() (http.Handler, error) { server.RegisterCodec(codec, "application/json") server.RegisterCodec(codec, "application/json;charset=UTF-8") return server, server.RegisterService( - &service{auth: a}, + &Service{auth: a}, "auth", ) } diff --git a/api/auth/service.go b/api/auth/service.go index ee1b681f6323..d9ed5d84eef6 100644 --- a/api/auth/service.go +++ b/api/auth/service.go @@ -9,8 +9,8 @@ import ( "github.com/ava-labs/avalanchego/api" ) -// service that serves the Auth API functionality. -type service struct { +// Service that serves the Auth API functionality. +type Service struct { auth *auth } @@ -32,7 +32,7 @@ type Token struct { Token string `json:"token"` // The new token. Expires in [TokenLifespan]. } -func (s *service) NewToken(_ *http.Request, args *NewTokenArgs, reply *Token) error { +func (s *Service) NewToken(_ *http.Request, args *NewTokenArgs, reply *Token) error { s.auth.log.Debug("Auth: NewToken called") var err error @@ -45,7 +45,7 @@ type RevokeTokenArgs struct { Token } -func (s *service) RevokeToken(_ *http.Request, args *RevokeTokenArgs, reply *api.SuccessResponse) error { +func (s *Service) RevokeToken(_ *http.Request, args *RevokeTokenArgs, reply *api.SuccessResponse) error { s.auth.log.Debug("Auth: RevokeToken called") reply.Success = true @@ -57,7 +57,7 @@ type ChangePasswordArgs struct { NewPassword string `json:"newPassword"` // New authorization password } -func (s *service) ChangePassword(_ *http.Request, args *ChangePasswordArgs, reply *api.SuccessResponse) error { +func (s *Service) ChangePassword(_ *http.Request, args *ChangePasswordArgs, reply *api.SuccessResponse) error { s.auth.log.Debug("Auth: ChangePassword called") reply.Success = true diff --git a/api/health/health.go b/api/health/health.go new file mode 100644 index 000000000000..a7c4ee59653c 
--- /dev/null +++ b/api/health/health.go @@ -0,0 +1,92 @@ +// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package health + +import ( + "net/http" + "time" + + stdjson "encoding/json" + + "github.com/gorilla/rpc/v2" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/ava-labs/avalanchego/snow/engine/common" + "github.com/ava-labs/avalanchego/utils/json" + "github.com/ava-labs/avalanchego/utils/logging" + + healthlib "github.com/ava-labs/avalanchego/health" +) + +var _ Health = &health{} + +// Health wraps a [healthlib.Service]. Handler() returns a handler that handles +// incoming HTTP API requests. We have this in a separate package from +// [healthlib] to avoid a circular import where this service imports +// snow/engine/common but that package imports [healthlib].Checkable +type Health interface { + healthlib.Service + + Handler() (*common.HTTPHandler, error) +} + +func New(checkFreq time.Duration, log logging.Logger, namespace string, registry prometheus.Registerer) (Health, error) { + service, err := healthlib.NewService(checkFreq, log, namespace, registry) + return &health{ + Service: service, + log: log, + }, err +} + +type health struct { + healthlib.Service + log logging.Logger +} + +func (h *health) Handler() (*common.HTTPHandler, error) { + newServer := rpc.NewServer() + codec := json.NewCodec() + newServer.RegisterCodec(codec, "application/json") + newServer.RegisterCodec(codec, "application/json;charset=UTF-8") + + // If a GET request is sent, we respond with a 200 if the node is healthy or + // a 503 if the node isn't healthy. + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + newServer.ServeHTTP(w, r) + return + } + + // Make sure the content type is set before writing the header. 
+ w.Header().Set("Content-Type", "application/json") + + checks, healthy := h.Results() + if !healthy { + // If a health check has failed, we should return a 503. + w.WriteHeader(http.StatusServiceUnavailable) + } + // The encoder will call write on the writer, which will write the + // header with a 200. + err := stdjson.NewEncoder(w).Encode(APIHealthServerReply{ + Checks: checks, + Healthy: healthy, + }) + if err != nil { + h.log.Debug("failed to encode the health check response due to %s", err) + } + }) + + err := newServer.RegisterService( + &Service{ + log: h.log, + health: h.Service, + }, + "health", + ) + return &common.HTTPHandler{ + LockOptions: common.NoLock, + Handler: handler, + }, err +} diff --git a/api/health/no_op.go b/api/health/no_op.go new file mode 100644 index 000000000000..69203e63c11b --- /dev/null +++ b/api/health/no_op.go @@ -0,0 +1,30 @@ +// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package health + +import ( + healthback "github.com/AppsFlyer/go-sundheit" + + "github.com/ava-labs/avalanchego/snow/engine/common" + + healthlib "github.com/ava-labs/avalanchego/health" +) + +var _ Health = &noOp{} + +type noOp struct{} + +// NewNoOp returns a noop version of the health interface that does nothing for +// when the Health API is disabled +func NewNoOp() Health { return &noOp{} } + +func (n *noOp) Results() (map[string]healthback.Result, bool) { + return map[string]healthback.Result{}, true +} + +func (n *noOp) RegisterCheck(string, healthlib.Check) error { return nil } + +func (n *noOp) RegisterMonotonicCheck(string, healthlib.Check) error { return nil } + +func (n *noOp) Handler() (*common.HTTPHandler, error) { return nil, nil } diff --git a/api/health/service.go b/api/health/service.go index 14b6dec0e0ea..78320f5a5a36 100644 --- a/api/health/service.go +++ b/api/health/service.go @@ -5,87 +5,19 @@ package health import ( "net/http" - "time" stdjson "encoding/json" - health 
"github.com/AppsFlyer/go-sundheit" + healthback "github.com/AppsFlyer/go-sundheit" - "github.com/gorilla/rpc/v2" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/ava-labs/avalanchego/snow/engine/common" - "github.com/ava-labs/avalanchego/utils/json" "github.com/ava-labs/avalanchego/utils/logging" healthlib "github.com/ava-labs/avalanchego/health" ) -var _ Service = &apiServer{} - -// Service wraps a [healthlib.Service]. Handler() returns a handler -// that handles incoming HTTP API requests. We have this in a separate -// package from [healthlib] to avoid a circular import where this service -// imports snow/engine/common but that package imports [healthlib].Checkable -type Service interface { - healthlib.Service - Handler() (*common.HTTPHandler, error) -} - -func NewService(checkFreq time.Duration, log logging.Logger, namespace string, registry prometheus.Registerer) (Service, error) { - service, err := healthlib.NewService(checkFreq, log, namespace, registry) - if err != nil { - return nil, err - } - return &apiServer{ - Service: service, - log: log, - }, nil -} - -// APIServer serves HTTP for a health service -type apiServer struct { - healthlib.Service - log logging.Logger -} - -func (as *apiServer) Handler() (*common.HTTPHandler, error) { - newServer := rpc.NewServer() - codec := json.NewCodec() - newServer.RegisterCodec(codec, "application/json") - newServer.RegisterCodec(codec, "application/json;charset=UTF-8") - if err := newServer.RegisterService(as, "health"); err != nil { - return nil, err - } - - // If a GET request is sent, we respond with a 200 if the node is healthy or - // a 503 if the node isn't healthy. - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - newServer.ServeHTTP(w, r) - return - } - - // Make sure the content type is set before writing the header. 
- w.Header().Set("Content-Type", "application/json") - - checks, healthy := as.Results() - if !healthy { - // If a health check has failed, we should return a 503. - w.WriteHeader(http.StatusServiceUnavailable) - } - // The encoder will call write on the writer, which will write the - // header with a 200. - err := stdjson.NewEncoder(w).Encode(APIHealthServerReply{ - Checks: checks, - Healthy: healthy, - }) - if err != nil { - as.log.Debug("failed to encode the health check response due to %s", err) - } - }) - return &common.HTTPHandler{LockOptions: common.NoLock, Handler: handler}, nil +type Service struct { + log logging.Logger + health healthlib.Service } // APIHealthArgs are the arguments for Health @@ -93,46 +25,18 @@ type APIHealthArgs struct{} // APIHealthReply is the response for Health type APIHealthServerReply struct { - Checks map[string]health.Result `json:"checks"` - Healthy bool `json:"healthy"` + Checks map[string]healthback.Result `json:"checks"` + Healthy bool `json:"healthy"` } // Health returns a summation of the health of the node -func (as *apiServer) Health(_ *http.Request, _ *APIHealthArgs, reply *APIHealthServerReply) error { - as.log.Debug("Health.health called") - reply.Checks, reply.Healthy = as.Results() +func (s *Service) Health(_ *http.Request, _ *APIHealthArgs, reply *APIHealthServerReply) error { + s.log.Debug("Health.health called") + reply.Checks, reply.Healthy = s.health.Results() if reply.Healthy { return nil } replyStr, err := stdjson.Marshal(reply.Checks) - as.log.Warn("Health.health is returning an error: %s", string(replyStr)) + s.log.Warn("Health.health is returning an error: %s", string(replyStr)) return err } - -type noOp struct{} - -// NewNoOpService returns a NoOp version of health check -// for when the Health API is disabled -func NewNoOpService() Service { - return &noOp{} -} - -// RegisterCheck implements the Service interface -func (n *noOp) Results() (map[string]health.Result, bool) { - return 
map[string]health.Result{}, true -} - -// RegisterCheck implements the Service interface -func (n *noOp) Handler() (_ *common.HTTPHandler, _ error) { - return nil, nil -} - -// RegisterCheckFn implements the Service interface -func (n *noOp) RegisterCheck(_ string, _ healthlib.Check) error { - return nil -} - -// RegisterMonotonicCheckFn implements the Service interface -func (n *noOp) RegisterMonotonicCheck(_ string, _ healthlib.Check) error { - return nil -} diff --git a/api/server/server.go b/api/server/server.go index 7b1e7dfd1135..fcff57da4418 100644 --- a/api/server/server.go +++ b/api/server/server.go @@ -5,6 +5,7 @@ package server import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -28,14 +29,12 @@ import ( "github.com/ava-labs/avalanchego/utils/logging" ) -const ( - baseURL = "/ext" - serverShutdownTimeout = 10 * time.Second -) +const baseURL = "/ext" var ( - errUnknownLockOption = errors.New("invalid lock options") - _ RouteAdder = &Server{} + errUnknownLockOption = errors.New("invalid lock options") + + _ RouteAdder = &Server{} ) type RouteAdder interface { @@ -44,21 +43,21 @@ type RouteAdder interface { // Server maintains the HTTP router type Server struct { - // This node's ID - nodeID ids.ShortID // log this server writes to log logging.Logger // generates new logs for chains to write to factory logging.Factory - // Maps endpoints to handlers - router *router // points the the router handlers handler http.Handler // Listens for HTTP traffic on this address listenHost string listenPort uint16 - // http server + shutdownTimeout time.Duration + + // Maps endpoints to handlers + router *router + srv *http.Server } @@ -69,6 +68,7 @@ func (s *Server) Initialize( host string, port uint16, allowedOrigins []string, + shutdownTimeout time.Duration, nodeID ids.ShortID, wrappers ...Wrapper, ) { @@ -76,8 +76,8 @@ func (s *Server) Initialize( s.factory = factory s.listenHost = host s.listenPort = port + s.shutdownTimeout = shutdownTimeout s.router = 
newRouter() - s.nodeID = nodeID s.log.Info("API created with allowed origins: %v", allowedOrigins) @@ -119,9 +119,18 @@ func (s *Server) Dispatch() error { } // DispatchTLS starts the API server with the provided TLS certificate -func (s *Server) DispatchTLS(certFile, keyFile string) error { +func (s *Server) DispatchTLS(certBytes, keyBytes []byte) error { listenAddress := fmt.Sprintf("%s:%d", s.listenHost, s.listenPort) - listener, err := net.Listen("tcp", listenAddress) + cert, err := tls.X509KeyPair(certBytes, keyBytes) + if err != nil { + return err + } + config := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + + listener, err := tls.Listen("tcp", listenAddress, config) if err != nil { return err } @@ -133,7 +142,8 @@ func (s *Server) DispatchTLS(certFile, keyFile string) error { s.log.Info("HTTPS API server listening on \"%s:%d\"", s.listenHost, ipDesc.Port) } - return http.ServeTLS(listener, s.handler, certFile, keyFile) + s.srv = &http.Server{Addr: listenAddress, Handler: s.handler} + return s.srv.Serve(listener) } // RegisterChain registers the API endpoints associated with this chain. That is, @@ -274,39 +284,17 @@ func (s *Server) AddAliasesWithReadLock(endpoint string, aliases ...string) erro return s.AddAliases(endpoint, aliases...) 
} -func (s *Server) Call( - writer http.ResponseWriter, - method, - base, - endpoint string, - body io.Reader, - headers map[string]string, -) error { - url := fmt.Sprintf("%s/vm/%s", baseURL, base) - - handler, err := s.router.GetHandler(url, endpoint) - if err != nil { - return err - } - req, err := http.NewRequest("POST", "*", body) - if err != nil { - return err - } - for key, value := range headers { - req.Header.Set(key, value) - } - - handler.ServeHTTP(writer, req) - - return nil -} - // Shutdown this server func (s *Server) Shutdown() error { if s.srv == nil { return nil } - ctx, cancel := context.WithTimeout(context.Background(), serverShutdownTimeout) - defer cancel() - return s.srv.Shutdown(ctx) + + ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout) + err := s.srv.Shutdown(ctx) + cancel() + + // If shutdown times out, make sure the server is still shutdown. + _ = s.srv.Close() + return err } diff --git a/api/server/server_test.go b/api/server/server_test.go deleted file mode 100644 index dd02ab067d26..000000000000 --- a/api/server/server_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. 
- -package server - -import ( - "bytes" - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/gorilla/rpc/v2" - "github.com/gorilla/rpc/v2/json2" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/snow/engine/common" - "github.com/ava-labs/avalanchego/utils/logging" -) - -type Service struct{ called bool } - -type Args struct{} - -type Reply struct{} - -func (s *Service) Call(_ *http.Request, args *Args, reply *Reply) error { - s.called = true - return nil -} - -func TestCall(t *testing.T) { - s := Server{} - s.Initialize( - logging.NoLog{}, - logging.NoFactory{}, - "localhost", - 8080, - []string{"*"}, - ids.GenerateTestShortID(), - ) - - serv := &Service{} - newServer := rpc.NewServer() - newServer.RegisterCodec(json2.NewCodec(), "application/json") - newServer.RegisterCodec(json2.NewCodec(), "application/json;charset=UTF-8") - if err := newServer.RegisterService(serv, "test"); err != nil { - t.Fatal(err) - } - - err := s.AddRoute( - &common.HTTPHandler{Handler: newServer}, - new(sync.RWMutex), - "vm/lol", - "", - logging.NoLog{}, - ) - if err != nil { - t.Fatal(err) - } - - buf, err := json2.EncodeClientRequest("test.Call", &Args{}) - if err != nil { - t.Fatal(err) - } - - writer := httptest.NewRecorder() - body := bytes.NewBuffer(buf) - headers := map[string]string{ - "Content-Type": "application/json", - } - err = s.Call(writer, "POST", "lol", "", body, headers) - if err != nil { - t.Fatal(err) - } - - if !serv.called { - t.Fatalf("Should have been called") - } -} diff --git a/chains/manager.go b/chains/manager.go index 32f8364276b2..8876fa5f3204 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -148,7 +148,7 @@ type ManagerConfig struct { CriticalChains ids.Set // Chains that can't exit gracefully WhitelistedSubnets ids.Set // Subnets to validate TimeoutManager *timeout.Manager // Manages request timeouts when sending messages to other validators - HealthService health.Service + HealthService health.Health 
RetryBootstrap bool // Should Bootstrap be retried RetryBootstrapWarnFrequency int // Max number of times to retry bootstrap before warning the node operator SubnetConfigs map[ids.ID]SubnetConfig // ID -> SubnetConfig diff --git a/config/config.go b/config/config.go index 1633e853d691..a1dd288e408d 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ package config import ( "crypto/tls" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -66,6 +67,8 @@ var ( errMinStakeDurationAboveMax = errors.New("max stake duration can't be less than min stake duration") errStakeMintingPeriodBelowMin = errors.New("stake minting period can't be less than max stake duration") errCannotWhitelistPrimaryNetwork = errors.New("cannot whitelist primary network") + errStakingKeyContentUnset = fmt.Errorf("%s key not set but %s set", StakingKeyContentKey, StakingCertContentKey) + errStakingCertContentUnset = fmt.Errorf("%s key set but %s not set", StakingKeyContentKey, StakingCertContentKey) ) func GetRunnerConfig(v *viper.Viper) (runner.Config, error) { @@ -160,12 +163,18 @@ func getAPIAuthConfig(v *viper.Viper) (node.APIAuthConfig, error) { if !config.APIRequireAuthToken { return config, nil } - passwordFilePath := v.GetString(APIAuthPasswordFileKey) - pwBytes, err := ioutil.ReadFile(passwordFilePath) - if err != nil { - return node.APIAuthConfig{}, fmt.Errorf("API auth password file %q failed to be read: %w", passwordFilePath, err) + + if v.IsSet(APIAuthPasswordKey) { + config.APIAuthPassword = v.GetString(APIAuthPasswordKey) + } else { + passwordFilePath := v.GetString(APIAuthPasswordFileKey) // picks flag value or default + passwordBytes, err := ioutil.ReadFile(passwordFilePath) + if err != nil { + return node.APIAuthConfig{}, fmt.Errorf("API auth password file %q failed to be read: %w", passwordFilePath, err) + } + config.APIAuthPassword = strings.TrimSpace(string(passwordBytes)) } - config.APIAuthPassword = strings.TrimSpace(string(pwBytes)) + if 
!password.SufficientlyStrong(config.APIAuthPassword, password.OK) { return node.APIAuthConfig{}, errAuthPasswordTooWeak } @@ -187,6 +196,39 @@ func getIPCConfig(v *viper.Viper) node.IPCConfig { } func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) { + var ( + httpsKey []byte + httpsCert []byte + err error + ) + switch { + case v.IsSet(HTTPSKeyContentKey): + rawContent := v.GetString(HTTPSKeyContentKey) + httpsKey, err = base64.StdEncoding.DecodeString(rawContent) + if err != nil { + return node.HTTPConfig{}, fmt.Errorf("unable to decode base64 content: %w", err) + } + case v.IsSet(HTTPSKeyFileKey): + httpsKeyFilepath := os.ExpandEnv(v.GetString(HTTPSKeyFileKey)) + if httpsKey, err = ioutil.ReadFile(filepath.Clean(httpsKeyFilepath)); err != nil { + return node.HTTPConfig{}, err + } + } + + switch { + case v.IsSet(HTTPSCertContentKey): + rawContent := v.GetString(HTTPSCertContentKey) + httpsCert, err = base64.StdEncoding.DecodeString(rawContent) + if err != nil { + return node.HTTPConfig{}, fmt.Errorf("unable to decode base64 content: %w", err) + } + case v.IsSet(HTTPSCertFileKey): + httpsKeyFilepath := os.ExpandEnv(v.GetString(HTTPSCertFileKey)) + if httpsCert, err = ioutil.ReadFile(filepath.Clean(httpsKeyFilepath)); err != nil { + return node.HTTPConfig{}, err + } + } + config := node.HTTPConfig{ APIConfig: node.APIConfig{ APIIndexerConfig: node.APIIndexerConfig{ @@ -202,11 +244,14 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) { HTTPHost: v.GetString(HTTPHostKey), HTTPPort: uint16(v.GetUint(HTTPPortKey)), HTTPSEnabled: v.GetBool(HTTPSEnabledKey), - HTTPSKeyFile: os.ExpandEnv(v.GetString(HTTPSKeyFileKey)), - HTTPSCertFile: os.ExpandEnv(v.GetString(HTTPSCertFileKey)), + HTTPSKey: httpsKey, + HTTPSCert: httpsCert, APIAllowedOrigins: v.GetStringSlice(HTTPAllowedOrigins), + + ShutdownTimeout: v.GetDuration(HTTPShutdownTimeoutKey), + ShutdownWait: v.GetDuration(HTTPShutdownWaitKey), } - var err error + config.APIAuthConfig, err = 
getAPIAuthConfig(v) if err != nil { return node.HTTPConfig{}, err @@ -498,16 +543,28 @@ func getProfilerConfig(v *viper.Viper) (profiler.Config, error) { return config, nil } -func getStakingTLSCert(v *viper.Viper) (tls.Certificate, error) { - if v.GetBool(StakingEphemeralCertEnabledKey) { - // Use an ephemeral staking key/cert - cert, err := staking.NewTLSCert() - if err != nil { - return tls.Certificate{}, fmt.Errorf("couldn't generate ephemeral staking key/cert: %w", err) - } - return *cert, nil +func getStakingTLSCertFromFlag(v *viper.Viper) (tls.Certificate, error) { + stakingKeyRawContent := v.GetString(StakingKeyContentKey) + stakingKeyContent, err := base64.StdEncoding.DecodeString(stakingKeyRawContent) + if err != nil { + return tls.Certificate{}, fmt.Errorf("unable to decode base64 content: %w", err) + } + + stakingCertRawContent := v.GetString(StakingCertContentKey) + stakingCertContent, err := base64.StdEncoding.DecodeString(stakingCertRawContent) + if err != nil { + return tls.Certificate{}, fmt.Errorf("unable to decode base64 content: %w", err) + } + + cert, err := staking.LoadTLSCertFromBytes(stakingKeyContent, stakingCertContent) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed creating cert: %w", err) } + return *cert, nil +} + +func getStakingTLSCertFromFile(v *viper.Viper) (tls.Certificate, error) { // Parse the staking key/cert paths and expand environment variables stakingKeyPath := os.ExpandEnv(v.GetString(StakingKeyPathKey)) stakingCertPath := os.ExpandEnv(v.GetString(StakingCertPathKey)) @@ -527,13 +584,35 @@ func getStakingTLSCert(v *viper.Viper) (tls.Certificate, error) { } // Load and parse the staking key/cert - cert, err := staking.LoadTLSCert(stakingKeyPath, stakingCertPath) + cert, err := staking.LoadTLSCertFromFiles(stakingKeyPath, stakingCertPath) if err != nil { return tls.Certificate{}, fmt.Errorf("couldn't read staking certificate: %w", err) } return *cert, nil } +func getStakingTLSCert(v *viper.Viper) 
(tls.Certificate, error) { + if v.GetBool(StakingEphemeralCertEnabledKey) { + // Use an ephemeral staking key/cert + cert, err := staking.NewTLSCert() + if err != nil { + return tls.Certificate{}, fmt.Errorf("couldn't generate ephemeral staking key/cert: %w", err) + } + return *cert, nil + } + + switch { + case v.IsSet(StakingKeyContentKey) && !v.IsSet(StakingCertContentKey): + return tls.Certificate{}, errStakingCertContentUnset + case !v.IsSet(StakingKeyContentKey) && v.IsSet(StakingCertContentKey): + return tls.Certificate{}, errStakingKeyContentUnset + case v.IsSet(StakingKeyContentKey) && v.IsSet(StakingCertContentKey): + return getStakingTLSCertFromFlag(v) + default: + return getStakingTLSCertFromFile(v) + } +} + func getStakingConfig(v *viper.Viper, networkID uint32) (node.StakingConfig, error) { config := node.StakingConfig{ EnableStaking: v.GetBool(StakingEnabledKey), @@ -591,6 +670,24 @@ func getTxFeeConfig(v *viper.Viper, networkID uint32) genesis.TxFeeConfig { return genesis.GetTxFeeConfig(networkID) } +func getGenesisData(v *viper.Viper, networkID uint32) ([]byte, ids.ID, error) { + // try first loading genesis content directly from flag/env-var + if v.IsSet(GenesisConfigContentKey) { + genesisData := v.GetString(GenesisConfigContentKey) + return genesis.FromFlag(networkID, genesisData) + } + + // if content is not specified go for the file + if v.IsSet(GenesisConfigFileKey) { + genesisFileName := os.ExpandEnv(v.GetString(GenesisConfigFileKey)) + return genesis.FromFile(networkID, genesisFileName) + } + + // finally if file is not specified/readable go for the predefined config + config := genesis.GetConfig(networkID) + return genesis.FromConfig(config) +} + func getWhitelistedSubnets(v *viper.Viper) (ids.Set, error) { whitelistedSubnetIDs := ids.Set{} for _, subnet := range strings.Split(v.GetString(WhitelistedSubnetsKey), ",") { @@ -610,10 +707,18 @@ func getWhitelistedSubnets(v *viper.Viper) (ids.Set, error) { } func getDatabaseConfig(v 
*viper.Viper, networkID uint32) (node.DatabaseConfig, error) { - var configBytes []byte - if v.IsSet(DBConfigFileKey) { + var ( + configBytes []byte + err error + ) + if v.IsSet(DBConfigContentKey) { + dbConfigContent := v.GetString(DBConfigContentKey) + configBytes, err = base64.StdEncoding.DecodeString(dbConfigContent) + if err != nil { + return node.DatabaseConfig{}, fmt.Errorf("unable to decode base64 content: %w", err) + } + } else if v.IsSet(DBConfigFileKey) { path := os.ExpandEnv(v.GetString(DBConfigFileKey)) - var err error configBytes, err = ioutil.ReadFile(path) if err != nil { return node.DatabaseConfig{}, err @@ -631,22 +736,32 @@ func getDatabaseConfig(v *viper.Viper, networkID uint32) (node.DatabaseConfig, e } func getVMAliases(v *viper.Viper) (map[ids.ID][]string, error) { - aliasFilePath := filepath.Clean(v.GetString(VMAliasesFileKey)) - exists, err := storage.FileExists(aliasFilePath) - if err != nil { - return nil, err - } + var fileBytes []byte + if v.IsSet(VMAliasesContentKey) { + var err error + aliasFlagContent := v.GetString(VMAliasesContentKey) + fileBytes, err = base64.StdEncoding.DecodeString(aliasFlagContent) + if err != nil { + return nil, fmt.Errorf("unable to decode base64 content: %w", err) + } + } else { + aliasFilePath := filepath.Clean(v.GetString(VMAliasesFileKey)) + exists, err := storage.FileExists(aliasFilePath) + if err != nil { + return nil, err + } - if !exists { - if v.IsSet(VMAliasesFileKey) { - return nil, fmt.Errorf("vm alias file does not exist in %v", aliasFilePath) + if !exists { + if v.IsSet(VMAliasesFileKey) { + return nil, fmt.Errorf("vm alias file does not exist in %v", aliasFilePath) + } + return nil, nil } - return nil, nil - } - fileBytes, err := ioutil.ReadFile(aliasFilePath) - if err != nil { - return nil, err + fileBytes, err = ioutil.ReadFile(aliasFilePath) + if err != nil { + return nil, err + } } vmAliasMap := make(map[ids.ID][]string) @@ -691,8 +806,21 @@ func getPathFromDirKey(v *viper.Viper, configKey 
string) (string, error) { return "", nil } -// getChainConfigs reads & puts chainConfigs to node config -func getChainConfigs(v *viper.Viper) (map[string]chains.ChainConfig, error) { +func getChainConfigsFromFlag(v *viper.Viper) (map[string]chains.ChainConfig, error) { + chainConfigContentB64 := v.GetString(ChainConfigContentKey) + chainConfigContent, err := base64.StdEncoding.DecodeString(chainConfigContentB64) + if err != nil { + return nil, fmt.Errorf("unable to decode base64 content: %w", err) + } + + chainConfigs := make(map[string]chains.ChainConfig) + if err := json.Unmarshal(chainConfigContent, &chainConfigs); err != nil { + return nil, fmt.Errorf("could not unmarshal JSON: %w", err) + } + return chainConfigs, nil +} + +func getChainConfigsFromDir(v *viper.Viper) (map[string]chains.ChainConfig, error) { chainConfigPath, err := getPathFromDirKey(v, ChainConfigDirKey) if err != nil { return nil, err @@ -709,6 +837,14 @@ func getChainConfigs(v *viper.Viper) (map[string]chains.ChainConfig, error) { return chainConfigs, nil } +// getChainConfigs reads & puts chainConfigs to node config +func getChainConfigs(v *viper.Viper) (map[string]chains.ChainConfig, error) { + if v.IsSet(ChainConfigContentKey) { + return getChainConfigsFromFlag(v) + } + return getChainConfigsFromDir(v) +} + // ReadsChainConfigs reads chain config files from static directories and returns map with contents, // if successful. 
func readChainConfigPath(chainConfigPath string) (map[string]chains.ChainConfig, error) { @@ -747,8 +883,34 @@ func readChainConfigPath(chainConfigPath string) (map[string]chains.ChainConfig, return chainConfigMap, nil } +func getSubnetConfigsFromFlags(v *viper.Viper, subnetIDs []ids.ID) (map[ids.ID]chains.SubnetConfig, error) { + subnetConfigContentB64 := v.GetString(SubnetConfigContentKey) + subnetConfigContent, err := base64.StdEncoding.DecodeString(subnetConfigContentB64) + if err != nil { + return nil, fmt.Errorf("unable to decode base64 content: %w", err) + } + + // Note: no default values are loaded here. All Subnet parameters must be + // explicitly defined. + subnetConfigs := make(map[ids.ID]chains.SubnetConfig, len(subnetIDs)) + if err := json.Unmarshal(subnetConfigContent, &subnetConfigs); err != nil { + return nil, fmt.Errorf("could not unmarshal JSON: %w", err) + } + + res := make(map[ids.ID]chains.SubnetConfig) + for _, subnetID := range subnetIDs { + if subnetConfig, ok := subnetConfigs[subnetID]; ok { + if err := subnetConfig.ConsensusParameters.Valid(); err != nil { + return nil, err + } + res[subnetID] = subnetConfig + } + } + return res, nil +} + // getSubnetConfigs reads SubnetConfigs to node config map -func getSubnetConfigs(v *viper.Viper, subnetIDs []ids.ID) (map[ids.ID]chains.SubnetConfig, error) { +func getSubnetConfigsFromDir(v *viper.Viper, subnetIDs []ids.ID) (map[ids.ID]chains.SubnetConfig, error) { subnetConfigPath, err := getPathFromDirKey(v, SubnetConfigDirKey) if err != nil { return nil, err @@ -765,6 +927,13 @@ func getSubnetConfigs(v *viper.Viper, subnetIDs []ids.ID) (map[ids.ID]chains.Sub return subnetConfigs, nil } +func getSubnetConfigs(v *viper.Viper, subnetIDs []ids.ID) (map[ids.ID]chains.SubnetConfig, error) { + if v.IsSet(SubnetConfigContentKey) { + return getSubnetConfigsFromFlags(v, subnetIDs) + } + return getSubnetConfigsFromDir(v, subnetIDs) +} + // readSubnetConfigs reads subnet config files from a path and given 
subnetIDs and returns a map. func readSubnetConfigs(subnetConfigPath string, subnetIDs []ids.ID, defaultSubnetConfig chains.SubnetConfig) (map[ids.ID]chains.SubnetConfig, error) { subnetConfigs := make(map[ids.ID]chains.SubnetConfig) @@ -921,10 +1090,8 @@ func GetNodeConfig(v *viper.Viper, buildDir string) (node.Config, error) { nodeConfig.TxFeeConfig = getTxFeeConfig(v, nodeConfig.NetworkID) // Genesis Data - nodeConfig.GenesisBytes, nodeConfig.AvaxAssetID, err = genesis.Genesis( - nodeConfig.NetworkID, - os.ExpandEnv(v.GetString(GenesisConfigFileKey)), - ) + + nodeConfig.GenesisBytes, nodeConfig.AvaxAssetID, err = getGenesisData(v, nodeConfig.NetworkID) if err != nil { return node.Config{}, fmt.Errorf("unable to load genesis file: %w", err) } diff --git a/config/config_test.go b/config/config_test.go index 5007e2093da3..37c44fd5e91d 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,6 +4,8 @@ package config import ( + "encoding/base64" + "encoding/json" "fmt" "io/ioutil" "log" @@ -18,9 +20,11 @@ import ( "github.com/ava-labs/avalanchego/chains" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/consensus/avalanche" + "github.com/ava-labs/avalanchego/snow/consensus/snowball" ) -func TestSetChainConfigs(t *testing.T) { +func TestGetChainConfigsFromFiles(t *testing.T) { tests := map[string]struct { configs map[string]string upgrades map[string]string @@ -96,7 +100,7 @@ func TestSetChainConfigs(t *testing.T) { } } -func TestSetChainConfigsDirNotExist(t *testing.T) { +func TestGetChainConfigsDirNotExist(t *testing.T) { tests := map[string]struct { structure string file map[string]string @@ -178,7 +182,84 @@ func TestSetChainConfigDefaultDir(t *testing.T) { assert.Equal(expected, chainConfigs) } -func TestGetVMAliases(t *testing.T) { +func TestGetChainConfigsFromFlags(t *testing.T) { + tests := map[string]struct { + fullConfigs map[string]chains.ChainConfig + errMessage string + expected map[string]chains.ChainConfig + }{ + 
"no chain configs": { + fullConfigs: map[string]chains.ChainConfig{}, + expected: map[string]chains.ChainConfig{}, + }, + "valid chain-id": { + fullConfigs: func() map[string]chains.ChainConfig { + m := map[string]chains.ChainConfig{} + id1, err := ids.FromString("yH8D7ThNJkxmtkuv2jgBa4P1Rn3Qpr4pPr7QYNfcdoS6k6HWp") + assert.NoError(t, err) + m[id1.String()] = chains.ChainConfig{Config: []byte("hello"), Upgrade: []byte("helloUpgrades")} + + id2, err := ids.FromString("2JVSBoinj9C2J33VntvzYtVJNZdN2NKiwwKjcumHUWEb5DbBrm") + assert.NoError(t, err) + m[id2.String()] = chains.ChainConfig{Config: []byte("world"), Upgrade: []byte(nil)} + + return m + }(), + expected: func() map[string]chains.ChainConfig { + m := map[string]chains.ChainConfig{} + id1, err := ids.FromString("yH8D7ThNJkxmtkuv2jgBa4P1Rn3Qpr4pPr7QYNfcdoS6k6HWp") + assert.NoError(t, err) + m[id1.String()] = chains.ChainConfig{Config: []byte("hello"), Upgrade: []byte("helloUpgrades")} + + id2, err := ids.FromString("2JVSBoinj9C2J33VntvzYtVJNZdN2NKiwwKjcumHUWEb5DbBrm") + assert.NoError(t, err) + m[id2.String()] = chains.ChainConfig{Config: []byte("world"), Upgrade: []byte(nil)} + + return m + }(), + }, + "valid alias": { + fullConfigs: map[string]chains.ChainConfig{ + "C": {Config: []byte("hello"), Upgrade: []byte("upgradess")}, + "X": {Config: []byte("world"), Upgrade: []byte(nil)}, + }, + expected: func() map[string]chains.ChainConfig { + m := map[string]chains.ChainConfig{} + m["C"] = chains.ChainConfig{Config: []byte("hello"), Upgrade: []byte("upgradess")} + m["X"] = chains.ChainConfig{Config: []byte("world"), Upgrade: []byte(nil)} + + return m + }(), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + jsonMaps, err := json.Marshal(test.fullConfigs) + assert.NoError(err) + encodedFileContent := base64.StdEncoding.EncodeToString(jsonMaps) + + // build viper config + v := setupViperFlags() + v.Set(ChainConfigContentKey, encodedFileContent) + + // Parse 
config + chainConfigs, err := getChainConfigs(v) + if len(test.errMessage) > 0 { + assert.Error(err) + if err != nil { + assert.Contains(err.Error(), test.errMessage) + } + } else { + assert.NoError(err) + } + assert.Equal(test.expected, chainConfigs) + }) + } +} + +func TestGetVMAliasesFromFile(t *testing.T) { tests := map[string]struct { givenJSON string expected map[ids.ID][]string @@ -225,6 +306,53 @@ func TestGetVMAliases(t *testing.T) { } } +func TestGetVMAliasesFromFlag(t *testing.T) { + tests := map[string]struct { + givenJSON string + expected map[ids.ID][]string + errMessage string + }{ + "wrong vm id": { + givenJSON: `{"wrongVmId": ["vm1","vm2"]}`, + expected: nil, + errMessage: "problem unmarshaling vmAliases", + }, + "vm id": { + givenJSON: `{"2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i": ["vm1","vm2"], + "Gmt4fuNsGJAd2PX86LBvycGaBpgCYKbuULdCLZs3SEs1Jx1LU": ["vm3", "vm4"] }`, + expected: func() map[ids.ID][]string { + m := map[ids.ID][]string{} + id1, _ := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + id2, _ := ids.FromString("Gmt4fuNsGJAd2PX86LBvycGaBpgCYKbuULdCLZs3SEs1Jx1LU") + m[id1] = []string{"vm1", "vm2"} + m[id2] = []string{"vm3", "vm4"} + return m + }(), + errMessage: "", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + encodedFileContent := base64.StdEncoding.EncodeToString([]byte(test.givenJSON)) + + // build viper config + v := setupViperFlags() + v.Set(VMAliasesContentKey, encodedFileContent) + + vmAliases, err := getVMAliases(v) + if len(test.errMessage) > 0 { + assert.Error(err) + assert.Contains(err.Error(), test.errMessage) + } else { + assert.NoError(err) + assert.Equal(test.expected, vmAliases) + } + }) + } +} + func TestGetVMAliasesDefaultDir(t *testing.T) { assert := assert.New(t) root := t.TempDir() @@ -267,7 +395,7 @@ func TestGetVMAliasesDirNotExists(t *testing.T) { assert.NoError(err) } -func TestGetSubnetConfigs(t *testing.T) { +func 
TestGetSubnetConfigsFromFile(t *testing.T) { tests := map[string]struct { givenJSON string testF func(*assert.Assertions, map[ids.ID]chains.SubnetConfig) @@ -345,6 +473,136 @@ func TestGetSubnetConfigs(t *testing.T) { } } +func TestGetSubnetConfigsFromFlags(t *testing.T) { + tests := map[string]struct { + cfgsMap map[ids.ID]chains.SubnetConfig + testF func(*assert.Assertions, map[ids.ID]chains.SubnetConfig) + errMessage string + }{ + "no configs": { + cfgsMap: func() map[ids.ID]chains.SubnetConfig { + res := make(map[ids.ID]chains.SubnetConfig) + return res + }(), + testF: func(assert *assert.Assertions, given map[ids.ID]chains.SubnetConfig) { + assert.Empty(given) + }, + errMessage: "", + }, + "entry with no config": { + cfgsMap: func() map[ids.ID]chains.SubnetConfig { + res := make(map[ids.ID]chains.SubnetConfig) + id, _ := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + res[id] = chains.SubnetConfig{} + return res + }(), + testF: func(assert *assert.Assertions, given map[ids.ID]chains.SubnetConfig) { + assert.True(len(given) == 1) + id, _ := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + _, ok := given[id] + assert.True(ok) + }, + errMessage: "Fails the condition that: 1 < Parents", + }, + "subnet is not whitelisted": { + cfgsMap: func() map[ids.ID]chains.SubnetConfig { + res := make(map[ids.ID]chains.SubnetConfig) + id, _ := ids.FromString("Gmt4fuNsGJAd2PX86LBvycGaBpgCYKbuULdCLZs3SEs1Jx1LU") + res[id] = chains.SubnetConfig{ValidatorOnly: true} + return res + }(), + testF: func(assert *assert.Assertions, given map[ids.ID]chains.SubnetConfig) { + assert.Empty(given) + }, + }, + "invalid consensus parameters": { + cfgsMap: func() map[ids.ID]chains.SubnetConfig { + res := make(map[ids.ID]chains.SubnetConfig) + id, _ := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + res[id] = chains.SubnetConfig{ + ConsensusParameters: avalanche.Parameters{ + Parents: 2, + BatchSize: 1, + Parameters: 
snowball.Parameters{ + K: 111, + Alpha: 1234, + BetaVirtuous: 1, + }, + }, + } + return res + }(), + testF: func(assert *assert.Assertions, given map[ids.ID]chains.SubnetConfig) { + assert.Empty(given) + }, + errMessage: "fails the condition that: alpha <= k", + }, + "correct config": { + cfgsMap: func() map[ids.ID]chains.SubnetConfig { + res := make(map[ids.ID]chains.SubnetConfig) + id, _ := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + res[id] = chains.SubnetConfig{ + ValidatorOnly: true, + ConsensusParameters: avalanche.Parameters{ + Parents: 111, + BatchSize: 1, + Parameters: snowball.Parameters{ + Alpha: 20, + K: 30, + BetaVirtuous: 5, + BetaRogue: 6, + ConcurrentRepolls: 6, + OptimalProcessing: 2, + MaxOutstandingItems: 2, + MaxItemProcessingTime: 2, + }, + }, + } + return res + }(), + testF: func(assert *assert.Assertions, given map[ids.ID]chains.SubnetConfig) { + id, _ := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + config, ok := given[id] + assert.True(ok) + assert.Equal(true, config.ValidatorOnly) + assert.Equal(111, config.ConsensusParameters.Parents) + assert.Equal(20, config.ConsensusParameters.Alpha) + // must still respect defaults + assert.Equal(30, config.ConsensusParameters.K) + }, + errMessage: "", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + subnetID, err := ids.FromString("2Ctt6eGAeo4MLqTmGa7AdRecuVMPGWEX9wSsCLBYrLhX4a394i") + assert.NoError(err) + cfgsMapBytes, err := json.Marshal(test.cfgsMap) + assert.NoError(err) + encodedFileContent := base64.StdEncoding.EncodeToString(cfgsMapBytes) + + // build viper config + v := setupViperFlags() + v.Set(SubnetConfigContentKey, encodedFileContent) + + // setup configs for default + v.Set(SnowAvalancheNumParentsKey, 1) + v.Set(SnowAvalancheBatchSizeKey, 1) + + subnetConfigs, err := getSubnetConfigs(v, []ids.ID{subnetID}) + if len(test.errMessage) > 0 { + assert.Error(err) + 
assert.Contains(err.Error(), test.errMessage) + } else { + assert.NoError(err) + test.testF(assert, subnetConfigs) + } + }) + } +} + // setups config json file and writes content func setupConfigJSON(t *testing.T, rootPath string, value string) string { configFilePath := filepath.Join(rootPath, "config.json") @@ -359,7 +617,7 @@ func setupFile(t *testing.T, path string, fileName string, value string) { assert.NoError(t, ioutil.WriteFile(filePath, []byte(value), 0o600)) } -func setupViper(configFilePath string) *viper.Viper { +func setupViperFlags() *viper.Viper { v := viper.New() fs := BuildFlagSet() pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.PanicOnError) // flags are now reset @@ -368,6 +626,11 @@ func setupViper(configFilePath string) *viper.Viper { if err := v.BindPFlags(pflag.CommandLine); err != nil { log.Fatal(err) } + return v +} + +func setupViper(configFilePath string) *viper.Viper { + v := setupViperFlags() // need to set it since in tests executable dir is somewhere /var/tmp/ (or wherever is designated by go) // thus it searches buildDir in /var/tmp/ // but actual buildDir resides under project_root/build diff --git a/config/flags.go b/config/flags.go index 8db8804dc573..ae4277f77d53 100644 --- a/config/flags.go +++ b/config/flags.go @@ -76,10 +76,14 @@ func addNodeFlags(fs *flag.FlagSet) { fs.Uint64(FdLimitKey, ulimit.DefaultFDLimit, "Attempts to raise the process file descriptor limit to at least this value.") // Config File - fs.String(ConfigFileKey, "", "Specifies a config file") + fs.String(ConfigFileKey, "", fmt.Sprintf("Specifies a config file. 
Ignored if %s is specified.", ConfigContentKey)) + fs.String(ConfigContentKey, "", "Specifies base64 encoded config content") + fs.String(ConfigContentTypeKey, "", "Specifies the format of the base64 encoded config content") - // Genesis Config File - fs.String(GenesisConfigFileKey, "", "Specifies a genesis config file (ignored when running standard networks)") + // Genesis + fs.String(GenesisConfigFileKey, "", fmt.Sprintf("Specifies a genesis config file (ignored when running standard networks or if %s is specified).", + GenesisConfigContentKey)) + fs.String(GenesisConfigContentKey, "", "Specifies base64 encoded genesis content") // Network ID fs.String(NetworkNameKey, defaultNetworkName, "Network ID this node will connect to") @@ -93,7 +97,8 @@ func addNodeFlags(fs *flag.FlagSet) { // Database fs.String(DBTypeKey, leveldb.Name, fmt.Sprintf("Database type to use. Should be one of {%s, %s, %s}", leveldb.Name, rocksdb.Name, memdb.Name)) fs.String(DBPathKey, defaultDBDir, "Path to database directory") - fs.String(DBConfigFileKey, "", "Path to database config file") + fs.String(DBConfigFileKey, "", fmt.Sprintf("Path to database config file. Ignored if %s is specified.", DBConfigContentKey)) + fs.String(DBConfigContentKey, "", "Specifies base64 encoded database config content") // Logging fs.String(LogsDirKey, "", "Logging directory for Avalanche") @@ -147,8 +152,7 @@ func addNodeFlags(fs *flag.FlagSet) { fs.Bool(NetworkAllowPrivateIPsKey, true, "Allows the node to connect peers with private IPs") fs.Bool(NetworkRequireValidatorToConnectKey, false, "If true, this node will only maintain a connection with another node if this node is a validator, the other node is a validator, or the other node is a beacon") // Peer alias configuration - fs.Duration(PeerAliasTimeoutKey, 10*time.Minute, "How often the node will attempt to connect "+ - "to an IP address previously associated with a peer (i.e. 
a peer alias).") + fs.Duration(PeerAliasTimeoutKey, 10*time.Minute, "How often the node will attempt to connect to an IP address previously associated with a peer (i.e. a peer alias).") // Benchlist fs.Int(BenchlistFailThresholdKey, 10, "Number of consecutive failed queries before benchlisting a node.") @@ -181,11 +185,19 @@ func addNodeFlags(fs *flag.FlagSet) { fs.String(HTTPHostKey, "127.0.0.1", "Address of the HTTP server") fs.Uint(HTTPPortKey, 9650, "Port of the HTTP server") fs.Bool(HTTPSEnabledKey, false, "Upgrade the HTTP server to HTTPs") - fs.String(HTTPSKeyFileKey, "", "TLS private key file for the HTTPs server") - fs.String(HTTPSCertFileKey, "", "TLS certificate file for the HTTPs server") + fs.String(HTTPSKeyFileKey, "", fmt.Sprintf("TLS private key file for the HTTPs server. Ignored if %s is specified.", HTTPSKeyContentKey)) + fs.String(HTTPSKeyContentKey, "", "Specifies base64 encoded TLS private key for the HTTPs server.") + fs.String(HTTPSCertFileKey, "", fmt.Sprintf("TLS certificate file for the HTTPs server. Ignored if %s is specified.", HTTPSCertContentKey)) + fs.String(HTTPSCertContentKey, "", "Specifies base64 encoded TLS certificate for the HTTPs server.") fs.String(HTTPAllowedOrigins, "*", "Origins to allow on the HTTP port. Defaults to * which allows all origins. Example: https://*.avax.network https://*.avax-test.network") + fs.Duration(HTTPShutdownWaitKey, 0, "Duration to wait after receiving SIGTERM or SIGINT before initiating shutdown. The /health endpoint will return unhealthy during this duration.") + fs.Duration(HTTPShutdownTimeoutKey, 10*time.Second, "Maximum duration to wait for existing connections to complete during node shutdown.") fs.Bool(APIAuthRequiredKey, false, "Require authorization token to call HTTP APIs") - fs.String(APIAuthPasswordFileKey, "", "Password file used to initially create/validate API authorization tokens. Leading and trailing whitespace is removed from the password. 
Can be changed via API call.") + fs.String(APIAuthPasswordFileKey, "", + fmt.Sprintf("Password file used to initially create/validate API authorization tokens. Ignored if %s is specified. Leading and trailing whitespace is removed from the password. Can be changed via API call.", + APIAuthPasswordKey)) + fs.String(APIAuthPasswordKey, "", "Specifies password for API authorization tokens.") + // Enable/Disable APIs fs.Bool(AdminAPIEnabledKey, false, "If true, this node exposes the Admin API") fs.Bool(InfoAPIEnabledKey, true, "If true, this node exposes the Info API") @@ -212,8 +224,10 @@ func addNodeFlags(fs *flag.FlagSet) { fs.Uint(StakingPortKey, 9651, "Port of the consensus server") fs.Bool(StakingEnabledKey, true, "Enable staking. If enabled, Network TLS is required.") fs.Bool(StakingEphemeralCertEnabledKey, false, "If true, the node uses an ephemeral staking key and certificate, and has an ephemeral node ID.") - fs.String(StakingKeyPathKey, defaultStakingKeyPath, "Path to the TLS private key for staking") - fs.String(StakingCertPathKey, defaultStakingCertPath, "Path to the TLS certificate for staking") + fs.String(StakingKeyPathKey, defaultStakingKeyPath, fmt.Sprintf("Path to the TLS private key for staking. Ignored if %s is specified.", StakingKeyContentKey)) + fs.String(StakingKeyContentKey, "", "Specifies base64 encoded TLS private key for staking.") + fs.String(StakingCertPathKey, defaultStakingCertPath, fmt.Sprintf("Path to the TLS certificate for staking. 
Ignored if %s is specified.", StakingCertContentKey)) + fs.String(StakingCertContentKey, "", "Specifies base64 encoded TLS certificate for staking.") fs.Uint64(StakingDisabledWeightKey, 100, "Weight to provide to each peer when staking is disabled") // Uptime Requirement fs.Float64(UptimeRequirementKey, genesis.LocalParams.UptimeRequirement, "Fraction of time a validator must be online to receive rewards") @@ -268,15 +282,18 @@ func addNodeFlags(fs *flag.FlagSet) { fs.Bool(IndexAllowIncompleteKey, false, "If true, allow running the node in such a way that could cause an index to miss transactions. Ignored if index is disabled.") // Config Directories - fs.String(ChainConfigDirKey, defaultChainConfigDir, "Chain specific configurations parent directory. Defaults to $HOME/.avalanchego/configs/chains/") - fs.String(SubnetConfigDirKey, defaultSubnetConfigDir, "Subnet specific configurations parent directory. Defaults to $HOME/.avalanchego/configs/subnets/") + fs.String(ChainConfigDirKey, defaultChainConfigDir, fmt.Sprintf("Chain specific configurations parent directory. Ignored if %s is specified. Defaults to $HOME/.avalanchego/configs/chains/0", ChainConfigContentKey)) + fs.String(ChainConfigContentKey, "", "Specifies base64 encoded chains configurations.") + fs.String(SubnetConfigDirKey, defaultSubnetConfigDir, fmt.Sprintf("Subnet specific configurations parent directory. Ignored if %s is specified. 
Defaults to $HOME/.avalanchego/configs/subnets/", SubnetConfigContentKey)) + fs.String(SubnetConfigContentKey, "", "Specifies base64 encoded subnets configurations.") // Profiles fs.String(ProfileDirKey, defaultProfileDir, "Path to the profile directory") fs.Bool(ProfileContinuousEnabledKey, false, "Whether the app should continuously produce performance profiles") fs.Duration(ProfileContinuousFreqKey, 15*time.Minute, "How frequently to rotate performance profiles") fs.Int(ProfileContinuousMaxFilesKey, 5, "Maximum number of historical profiles to keep") - fs.String(VMAliasesFileKey, defaultVMAliasFilePath, "Specifies a JSON file that maps vmIDs with custom aliases.") + fs.String(VMAliasesFileKey, defaultVMAliasFilePath, fmt.Sprintf("Specifies a JSON file that maps vmIDs with custom aliases. Ignored if %s is specified.", VMAliasesContentKey)) + fs.String(VMAliasesContentKey, "", "Specifies base64 encoded maps vmIDs with custom aliases.") // Delays fs.Duration(NetworkInitialReconnectDelayKey, time.Second, "Initial delay duration must be waited before attempting to reconnect a peer.") diff --git a/config/keys.go b/config/keys.go index 0170e5c45313..5044fc0ac71f 100644 --- a/config/keys.go +++ b/config/keys.go @@ -6,8 +6,11 @@ package config // #nosec G101 const ( ConfigFileKey = "config-file" + ConfigContentKey = "config-file-content" + ConfigContentTypeKey = "config-file-content-type" VersionKey = "version" GenesisConfigFileKey = "genesis" + GenesisConfigContentKey = "genesis-content" NetworkNameKey = "network-id" TxFeeKey = "tx-fee" CreateAssetTxFeeKey = "create-asset-tx-fee" @@ -26,6 +29,7 @@ const ( DBTypeKey = "db-type" DBPathKey = "db-dir" DBConfigFileKey = "db-config-file" + DBConfigContentKey = "db-config-file-content" PublicIPKey = "public-ip" DynamicUpdateDurationKey = "dynamic-update-duration" DynamicPublicIPResolverKey = "dynamic-public-ip" @@ -38,9 +42,14 @@ const ( HTTPPortKey = "http-port" HTTPSEnabledKey = "http-tls-enabled" HTTPSKeyFileKey = 
"http-tls-key-file" + HTTPSKeyContentKey = "http-tls-key-file-content" HTTPSCertFileKey = "http-tls-cert-file" + HTTPSCertContentKey = "http-tls-cert-file-content" HTTPAllowedOrigins = "http-allowed-origins" + HTTPShutdownTimeoutKey = "http-shutdown-timeout" + HTTPShutdownWaitKey = "http-shutdown-wait" APIAuthRequiredKey = "api-auth-required" + APIAuthPasswordKey = "api-auth-password" APIAuthPasswordFileKey = "api-auth-password-file" BootstrapIPsKey = "bootstrap-ips" BootstrapIDsKey = "bootstrap-ids" @@ -48,7 +57,9 @@ const ( StakingEnabledKey = "staking-enabled" StakingEphemeralCertEnabledKey = "staking-ephemeral-cert-enabled" StakingKeyPathKey = "staking-tls-key-file" + StakingKeyContentKey = "staking-tls-key-file-content" StakingCertPathKey = "staking-tls-cert-file" + StakingCertContentKey = "staking-tls-cert-file-content" StakingDisabledWeightKey = "staking-disabled-weight" NetworkInitialTimeoutKey = "network-initial-timeout" NetworkMinimumTimeoutKey = "network-minimum-timeout" @@ -126,7 +137,9 @@ const ( BootstrapMultiputMaxContainersSentKey = "bootstrap-multiput-max-containers-sent" BootstrapMultiputMaxContainersReceivedKey = "bootstrap-multiput-max-containers-received" ChainConfigDirKey = "chain-config-dir" + ChainConfigContentKey = "chain-config-content" SubnetConfigDirKey = "subnet-config-dir" + SubnetConfigContentKey = "subnet-config-content" ProfileDirKey = "profile-dir" ProfileContinuousEnabledKey = "profile-continuous-enabled" ProfileContinuousFreqKey = "profile-continuous-freq" @@ -142,4 +155,5 @@ const ( OutboundThrottlerNodeMaxAtLargeBytesKey = "throttler-outbound-node-max-at-large-bytes" UptimeMetricFreqKey = "uptime-metric-freq" VMAliasesFileKey = "vm-aliases-file" + VMAliasesContentKey = "vm-aliases-file-content" ) diff --git a/config/viper.go b/config/viper.go index cf22d62b73a2..3b99d06839fe 100644 --- a/config/viper.go +++ b/config/viper.go @@ -4,6 +4,9 @@ package config import ( + "bytes" + "encoding/base64" + "errors" "flag" "fmt" "io" @@ 
-13,6 +16,8 @@ import ( "github.com/spf13/viper" ) +var errMissingConfigFormat = errors.New("config content format not specified") + // BuildViper returns the viper environment from parsing config file from // default search paths and any parsed command line flags func BuildViper(fs *flag.FlagSet, args []string) (*viper.Viper, error) { @@ -31,8 +36,28 @@ func BuildViper(fs *flag.FlagSet, args []string) (*viper.Viper, error) { if err := v.BindPFlags(pfs); err != nil { return nil, err } - if v.IsSet(ConfigFileKey) { - v.SetConfigFile(os.ExpandEnv(v.GetString(ConfigFileKey))) + + // load node configs from flags or file, depending on which flags are set + switch { + case v.IsSet(ConfigContentKey): + if !v.IsSet(ConfigContentTypeKey) { + return nil, errMissingConfigFormat + } + + configContentB64 := v.GetString(ConfigContentKey) + configBytes, err := base64.StdEncoding.DecodeString(configContentB64) + if err != nil { + return nil, fmt.Errorf("unable to decode base64 content: %w", err) + } + + v.SetConfigType(v.GetString(ConfigContentTypeKey)) + if err := v.ReadConfig(bytes.NewBuffer(configBytes)); err != nil { + return nil, err + } + + case v.IsSet(ConfigFileKey): + filename := os.ExpandEnv(v.GetString(ConfigFileKey)) + v.SetConfigFile(filename) if err := v.ReadInConfig(); err != nil { return nil, err } diff --git a/database/corruptabledb/db.go b/database/corruptabledb/db.go new file mode 100644 index 000000000000..0a423b6c512d --- /dev/null +++ b/database/corruptabledb/db.go @@ -0,0 +1,112 @@ +// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package corruptabledb + +import ( + "sync/atomic" + + "github.com/ava-labs/avalanchego/database" +) + +var ( + _ database.Database = &Database{} + _ database.Batch = &batch{} +) + +// CorruptableDB is a wrapper around Database +// it prevents any future calls in case of a corruption occurs +type Database struct { + database.Database + // 1 if there was previously an error other than "not found" or "closed" + // while performing a db operation. If [errored] == 1, Has, Get, Put, + // Delete and batch writes fail with ErrAvoidCorruption. + errored uint64 +} + +// New returns a new prefixed database +func New(db database.Database) *Database { + return &Database{Database: db} +} + +// Has returns if the key is set in the database +func (db *Database) Has(key []byte) (bool, error) { + if db.corrupted() { + return false, database.ErrAvoidCorruption + } + has, err := db.Database.Has(key) + return has, db.handleError(err) +} + +// Get returns the value the key maps to in the database +func (db *Database) Get(key []byte) ([]byte, error) { + if db.corrupted() { + return nil, database.ErrAvoidCorruption + } + value, err := db.Database.Get(key) + return value, db.handleError(err) +} + +// Put sets the value of the provided key to the provided value +func (db *Database) Put(key []byte, value []byte) error { + if db.corrupted() { + return database.ErrAvoidCorruption + } + return db.handleError(db.Database.Put(key, value)) +} + +// Delete removes the key from the database +func (db *Database) Delete(key []byte) error { + if db.corrupted() { + return database.ErrAvoidCorruption + } + return db.handleError(db.Database.Delete(key)) +} + +// Stat returns a particular internal stat of the database. 
+func (db *Database) Stat(property string) (string, error) { + stat, err := db.Database.Stat(property) + return stat, db.handleError(err) +} + +func (db *Database) Compact(start []byte, limit []byte) error { + return db.handleError(db.Database.Compact(start, limit)) +} + +func (db *Database) Close() error { return db.handleError(db.Database.Close()) } + +func (db *Database) NewBatch() database.Batch { + return &batch{ + Batch: db.Database.NewBatch(), + db: db, + } +} + +func (db *Database) corrupted() bool { + return atomic.LoadUint64(&db.errored) == 1 +} + +func (db *Database) handleError(err error) error { + switch err { + case nil, database.ErrNotFound, database.ErrClosed: + // If we get an error other than "not found" or "closed", disallow future + // database operations to avoid possible corruption + default: + atomic.StoreUint64(&db.errored, 1) + } + return err +} + +// batch is a wrapper around the batch to contain sizes. +type batch struct { + database.Batch + db *Database +} + +// Write flushes any accumulated data to disk. +func (b *batch) Write() error { + if b.db.corrupted() { + return database.ErrAvoidCorruption + } + return b.db.handleError(b.Batch.Write()) +} diff --git a/database/corruptabledb/db_test.go b/database/corruptabledb/db_test.go new file mode 100644 index 000000000000..bf25b390beea --- /dev/null +++ b/database/corruptabledb/db_test.go @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package corruptabledb + +import ( + "errors" + "testing" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/stretchr/testify/assert" +) + +func TestInterface(t *testing.T) { + for _, test := range database.Tests { + baseDB := memdb.New() + db := New(baseDB) + test(t, db) + } +} + +// TestCorruption tests to make sure corruptabledb wrapper works as expected. 
+func TestCorruption(t *testing.T) { + key := []byte("hello") + value := []byte("world") + tests := map[string]func(db database.Database) error{ + "corrupted has": func(db database.Database) error { + _, err := db.Has(key) + return err + }, + "corrupted get": func(db database.Database) error { + _, err := db.Get(key) + return err + }, + "corrupted put": func(db database.Database) error { + return db.Put(key, value) + }, + "corrupted delete": func(db database.Database) error { + return db.Delete(key) + }, + "corrupted batch": func(db database.Database) error { + corruptableBatch := db.NewBatch() + assert.NotNil(t, corruptableBatch) + + err := corruptableBatch.Put(key, value) + assert.NoError(t, err) + + return corruptableBatch.Write() + }, + } + baseDB := memdb.New() + // wrap this db + corruptableDB := New(baseDB) + _ = corruptableDB.handleError(errors.New("corruption error")) + assert.True(t, corruptableDB.corrupted()) + for name, testFn := range tests { + t.Run(name, func(tt *testing.T) { + err := testFn(corruptableDB) + assert.ErrorIsf(tt, err, database.ErrAvoidCorruption, "not received the corruption error") + }) + } +} diff --git a/database/leveldb/db.go b/database/leveldb/db.go index ff16c9753945..c64ecd8892d9 100644 --- a/database/leveldb/db.go +++ b/database/leveldb/db.go @@ -7,7 +7,6 @@ import ( "bytes" "encoding/json" "fmt" - "sync/atomic" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/errors" @@ -54,12 +53,6 @@ var ( // in binary-alphabetical order. type Database struct { *leveldb.DB - log logging.Logger - - // 1 if there was previously an error other than "not found" or "closed" - // while performing a db operation. If [errored] == 1, Has, Get, Put, - // Delete and batch writes fail with ErrAvoidCorruption. 
- errored uint64 } type config struct { @@ -146,43 +139,30 @@ func New(file string, configBytes []byte, log logging.Logger) (database.Database db, err = leveldb.RecoverFile(file, nil) } return &Database{ - DB: db, - log: log, + DB: db, }, err } // Has returns if the key is set in the database func (db *Database) Has(key []byte) (bool, error) { - if db.corrupted() { - return false, database.ErrAvoidCorruption - } has, err := db.DB.Has(key, nil) - return has, db.handleError(err) + return has, updateError(err) } // Get returns the value the key maps to in the database func (db *Database) Get(key []byte) ([]byte, error) { - if db.corrupted() { - return nil, database.ErrAvoidCorruption - } value, err := db.DB.Get(key, nil) - return value, db.handleError(err) + return value, updateError(err) } // Put sets the value of the provided key to the provided value func (db *Database) Put(key []byte, value []byte) error { - if db.corrupted() { - return database.ErrAvoidCorruption - } - return db.handleError(db.DB.Put(key, value, nil)) + return updateError(db.DB.Put(key, value, nil)) } // Delete removes the key from the database func (db *Database) Delete(key []byte) error { - if db.corrupted() { - return database.ErrAvoidCorruption - } - return db.handleError(db.DB.Delete(key, nil)) + return updateError(db.DB.Delete(key, nil)) } // NewBatch creates a write/delete-only buffer that is atomically committed to @@ -220,7 +200,7 @@ func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database // Stat returns a particular internal stat of the database. func (db *Database) Stat(property string) (string, error) { stat, err := db.DB.GetProperty(property) - return stat, db.handleError(err) + return stat, updateError(err) } // This comment is basically copy pasted from the underlying levelDB library: @@ -235,28 +215,11 @@ func (db *Database) Stat(property string) (string, error) { // And a nil limit is treated as a key after all keys in the DB. 
// Therefore if both are nil then it will compact entire DB. func (db *Database) Compact(start []byte, limit []byte) error { - return db.handleError(db.DB.CompactRange(util.Range{Start: start, Limit: limit})) + return updateError(db.DB.CompactRange(util.Range{Start: start, Limit: limit})) } // Close implements the Database interface -func (db *Database) Close() error { return db.handleError(db.DB.Close()) } - -func (db *Database) corrupted() bool { - return atomic.LoadUint64(&db.errored) == 1 -} - -func (db *Database) handleError(err error) error { - err = updateError(err) - switch err { - case nil, database.ErrNotFound, database.ErrClosed: - // If we get an error other than "not found" or "closed", disallow future - // database operations to avoid possible corruption - default: - db.log.Fatal("leveldb error: %s", err) - atomic.StoreUint64(&db.errored, 1) - } - return err -} +func (db *Database) Close() error { return updateError(db.DB.Close()) } // batch is a wrapper around a levelDB batch to contain sizes. type batch struct { @@ -284,10 +247,7 @@ func (b *batch) Size() int { return b.size } // Write flushes any accumulated data to disk. func (b *batch) Write() error { - if b.db.corrupted() { - return database.ErrAvoidCorruption - } - return b.db.handleError(b.db.DB.Write(&b.Batch, nil)) + return updateError(b.db.DB.Write(&b.Batch, nil)) } // Reset resets the batch for reuse. 
diff --git a/database/manager/manager.go b/database/manager/manager.go index be1fbefdb6d3..efb6d8e5f3d1 100644 --- a/database/manager/manager.go +++ b/database/manager/manager.go @@ -13,6 +13,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/corruptabledb" "github.com/ava-labs/avalanchego/database/leveldb" "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/database/meterdb" @@ -130,10 +131,12 @@ func new( return nil, fmt.Errorf("couldn't create db at %s: %w", currentDBPath, err) } + wrappedDB := corruptabledb.New(currentDB) + manager := &manager{ databases: []*VersionedDatabase{ { - Database: currentDB, + Database: wrappedDB, Version: currentVersion, }, }, @@ -177,8 +180,10 @@ func new( return fmt.Errorf("couldn't create db at %s: %w", path, err) } + wrappedDB := corruptabledb.New(db) + manager.databases = append(manager.databases, &VersionedDatabase{ - Database: db, + Database: wrappedDB, Version: version, }) diff --git a/database/rocksdb/db.go b/database/rocksdb/db.go index 746533b6f878..d79966a34d10 100644 --- a/database/rocksdb/db.go +++ b/database/rocksdb/db.go @@ -12,7 +12,6 @@ import ( "os" "runtime" "sync" - "sync/atomic" "github.com/linxGnu/grocksdb" @@ -52,13 +51,6 @@ type Database struct { readOptions *grocksdb.ReadOptions iteratorOptions *grocksdb.ReadOptions writeOptions *grocksdb.WriteOptions - - log logging.Logger - - // 1 if there was previously an error other than "not found" or "closed" - // while performing a db operation. If [errored] == 1, Has, Get, Put, - // Delete and batch writes fail with ErrAvoidCorruption. - errored uint64 } // New returns a wrapped RocksDB object. 
@@ -93,7 +85,6 @@ func New(file string, configBytes []byte, log logging.Logger) (database.Database readOptions: grocksdb.NewDefaultReadOptions(), iteratorOptions: iteratorOptions, writeOptions: grocksdb.NewDefaultWriteOptions(), - log: log, }, nil } @@ -115,16 +106,12 @@ func (db *Database) Get(key []byte) ([]byte, error) { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return nil, database.ErrClosed - case db.corrupted(): - return nil, database.ErrAvoidCorruption } value, err := db.db.GetBytes(db.readOptions, key) if err != nil { - atomic.StoreUint64(&db.errored, 1) return nil, err } if value != nil { @@ -138,18 +125,10 @@ func (db *Database) Put(key []byte, value []byte) error { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return database.ErrClosed - case db.corrupted(): - return database.ErrAvoidCorruption - } - - err := db.db.Put(db.writeOptions, key, value) - if err != nil { - atomic.StoreUint64(&db.errored, 1) } - return err + return db.db.Put(db.writeOptions, key, value) } // Delete removes the key from the database @@ -157,18 +136,10 @@ func (db *Database) Delete(key []byte) error { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return database.ErrClosed - case db.corrupted(): - return database.ErrAvoidCorruption } - - err := db.db.Delete(db.writeOptions, key) - if err != nil { - atomic.StoreUint64(&db.errored, 1) - } - return err + return db.db.Delete(db.writeOptions, key) } // NewBatch creates a write/delete-only buffer that is atomically committed to @@ -192,11 +163,8 @@ func (db *Database) NewIterator() database.Iterator { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return &nodb.Iterator{Err: database.ErrClosed} - case db.corrupted(): - return &nodb.Iterator{Err: database.ErrAvoidCorruption} } it := db.db.NewIterator(db.iteratorOptions) @@ -216,11 +184,8 @@ func (db 
*Database) NewIteratorWithStart(start []byte) database.Iterator { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return &nodb.Iterator{Err: database.ErrClosed} - case db.corrupted(): - return &nodb.Iterator{Err: database.ErrAvoidCorruption} } it := db.db.NewIterator(db.iteratorOptions) @@ -240,11 +205,8 @@ func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return &nodb.Iterator{Err: database.ErrClosed} - case db.corrupted(): - return &nodb.Iterator{Err: database.ErrAvoidCorruption} } it := db.db.NewIterator(db.iteratorOptions) @@ -266,11 +228,8 @@ func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return &nodb.Iterator{Err: database.ErrClosed} - case db.corrupted(): - return &nodb.Iterator{Err: database.ErrAvoidCorruption} } it := db.db.NewIterator(db.iteratorOptions) @@ -307,11 +266,8 @@ func (db *Database) Compact(start []byte, limit []byte) error { db.lock.RLock() defer db.lock.RUnlock() - switch { - case db.db == nil: + if db.db == nil { return database.ErrClosed - case db.corrupted(): - return database.ErrAvoidCorruption } db.db.CompactRange(grocksdb.Range{Start: start, Limit: limit}) @@ -336,10 +292,6 @@ func (db *Database) Close() error { return nil } -func (db *Database) corrupted() bool { - return atomic.LoadUint64(&db.errored) == 1 -} - // batch is a wrapper around a levelDB batch to contain sizes. 
type batch struct { batch *grocksdb.WriteBatch @@ -369,12 +321,10 @@ func (b *batch) Write() error { b.db.lock.RLock() defer b.db.lock.RUnlock() - switch { - case b.db.db == nil: + if b.db.db == nil { return database.ErrClosed - case b.db.corrupted(): - return database.ErrAvoidCorruption } + return b.db.db.Write(b.db.writeOptions, b.batch) } diff --git a/genesis/config.go b/genesis/config.go index 7f7e0824846f..9cc2a60c8c66 100644 --- a/genesis/config.go +++ b/genesis/config.go @@ -4,6 +4,7 @@ package genesis import ( + "encoding/base64" "encoding/hex" "encoding/json" "fmt" @@ -201,16 +202,27 @@ func GetConfig(networkID uint32) *Config { } } -// GetConfigFile loads a *Config from a provided -// filepath. +// GetConfigFile loads a *Config from a provided filepath. func GetConfigFile(fp string) (*Config, error) { - b, err := ioutil.ReadFile(filepath.Clean(fp)) + bytes, err := ioutil.ReadFile(filepath.Clean(fp)) if err != nil { return nil, fmt.Errorf("unable to load file %s: %w", fp, err) } + return parseGenesisJSONBytesToConfig(bytes) +} + +// GetConfigContent loads a *Config from a provided environment variable +func GetConfigContent(genesisContent string) (*Config, error) { + bytes, err := base64.StdEncoding.DecodeString(genesisContent) + if err != nil { + return nil, fmt.Errorf("unable to decode base64 content: %w", err) + } + return parseGenesisJSONBytesToConfig(bytes) +} +func parseGenesisJSONBytesToConfig(bytes []byte) (*Config, error) { var unparsedConfig UnparsedConfig - if err := json.Unmarshal(b, &unparsedConfig); err != nil { + if err := json.Unmarshal(bytes, &unparsedConfig); err != nil { return nil, fmt.Errorf("could not unmarshal JSON: %w", err) } @@ -218,6 +230,5 @@ func GetConfigFile(fp string) (*Config, error) { if err != nil { return nil, fmt.Errorf("unable to parse config: %w", err) } - return &config, nil } diff --git a/genesis/genesis.go b/genesis/genesis.go index b061f7219780..4dfa3d475c60 100644 --- a/genesis/genesis.go +++ b/genesis/genesis.go 
@@ -165,43 +165,38 @@ func validateConfig(networkID uint32, config *Config) error { return nil } -// Genesis returns the genesis data of the Platform Chain. +// FromFile returns the genesis data of the Platform Chain. // // Since an Avalanche network has exactly one Platform Chain, and the Platform // Chain defines the genesis state of the network (who is staking, which chains // exist, etc.), defining the genesis state of the Platform Chain is the same as // defining the genesis state of the network. // -// Genesis accepts: +// FromFile accepts: // 1) The ID of the new network. [networkID] // 2) The location of a custom genesis config to load. [filepath] // -// If [filepath] is empty or the given network ID is Mainnet, Testnet, or Local, loads the -// network genesis state from predefined configs. If [filepath] is non-empty and networkID -// isn't Mainnet, Testnet, or Local, loads the network genesis data from the config at [filepath]. +// If [filepath] is empty or the given network ID is Mainnet, Testnet, or Local, returns error. +// If [filepath] is non-empty and networkID isn't Mainnet, Testnet, or Local, +// loads the network genesis data from the config at [filepath]. 
// -// Genesis returns: +// FromFile returns: // 1) The byte representation of the genesis state of the platform chain // (ie the genesis state of the network) // 2) The asset ID of AVAX -func Genesis(networkID uint32, filepath string) ([]byte, ids.ID, error) { - config := GetConfig(networkID) - if len(filepath) > 0 { - switch networkID { - case constants.MainnetID, constants.TestnetID, constants.LocalID: - return nil, ids.ID{}, fmt.Errorf( - "cannot override genesis config for standard network %s (%d)", - constants.NetworkName(networkID), - networkID, - ) - } - - customConfig, err := GetConfigFile(filepath) - if err != nil { - return nil, ids.ID{}, fmt.Errorf("unable to load provided genesis config at %s: %w", filepath, err) - } +func FromFile(networkID uint32, filepath string) ([]byte, ids.ID, error) { + switch networkID { + case constants.MainnetID, constants.TestnetID, constants.LocalID: + return nil, ids.ID{}, fmt.Errorf( + "cannot override genesis config for standard network %s (%d)", + constants.NetworkName(networkID), + networkID, + ) + } - config = customConfig + config, err := GetConfigFile(filepath) + if err != nil { + return nil, ids.ID{}, fmt.Errorf("unable to load provided genesis config at %s: %w", filepath, err) } if err := validateConfig(networkID, config); err != nil { @@ -211,6 +206,47 @@ func Genesis(networkID uint32, filepath string) ([]byte, ids.ID, error) { return FromConfig(config) } +// FromFlag returns the genesis data of the Platform Chain. +// +// Since an Avalanche network has exactly one Platform Chain, and the Platform +// Chain defines the genesis state of the network (who is staking, which chains +// exist, etc.), defining the genesis state of the Platform Chain is the same as +// defining the genesis state of the network. +// +// FromFlag accepts: +// 1) The ID of the new network. [networkID] +// 2) The content of a custom genesis config to load. 
[genesisContent] +// +// If [genesisContent] is empty or the given network ID is Mainnet, Testnet, or Local, returns error. +// If [genesisContent] is non-empty and networkID isn't Mainnet, Testnet, or Local, +// loads the network genesis data from [genesisContent]. +// +// FromFlag returns: +// 1) The byte representation of the genesis state of the platform chain +// (ie the genesis state of the network) +// 2) The asset ID of AVAX +func FromFlag(networkID uint32, genesisContent string) ([]byte, ids.ID, error) { + switch networkID { + case constants.MainnetID, constants.TestnetID, constants.LocalID: + return nil, ids.ID{}, fmt.Errorf( + "cannot override genesis config for standard network %s (%d)", + constants.NetworkName(networkID), + networkID, + ) + } + + customConfig, err := GetConfigContent(genesisContent) + if err != nil { + return nil, ids.ID{}, fmt.Errorf("unable to load genesis content from flag: %w", err) + } + + if err := validateConfig(networkID, customConfig); err != nil { + return nil, ids.ID{}, fmt.Errorf("genesis config validation failed: %w", err) + } + + return FromConfig(customConfig) +} + // FromConfig returns: // 1) The byte representation of the genesis state of the platform chain // (ie the genesis state of the network) diff --git a/genesis/genesis_test.go b/genesis/genesis_test.go index c8f57cb07f52..3cfb14248973 100644 --- a/genesis/genesis_test.go +++ b/genesis/genesis_test.go @@ -4,6 +4,8 @@ package genesis import ( + "encoding/base64" + "encoding/json" "fmt" "path/filepath" "testing" @@ -230,7 +232,7 @@ var ( }` ) -func TestGenesis(t *testing.T) { +func TestGenesisFromFile(t *testing.T) { tests := map[string]struct { networkID uint32 customConfig string @@ -239,12 +241,14 @@ func TestGenesis(t *testing.T) { expected string }{ "mainnet": { - networkID: constants.MainnetID, - expected: "3e6662fdbd88bcf4c7dd82cb4699c0807f1d7315d493bc38532697e11b226276", + networkID: constants.MainnetID, + customConfig: customGenesisConfigJSON, + err: 
"cannot override genesis config for standard network mainnet (1)", }, "fuji": { - networkID: constants.FujiID, - expected: "2e6b699298a664793bff42dae9c1af8d9c54645d8b376fd331e0b67475578e0a", + networkID: constants.FujiID, + customConfig: customGenesisConfigJSON, + err: "cannot override genesis config for standard network fuji (5)", }, "fuji (with custom specified)": { networkID: constants.FujiID, @@ -252,8 +256,9 @@ func TestGenesis(t *testing.T) { err: "cannot override genesis config for standard network fuji (5)", }, "local": { - networkID: constants.LocalID, - expected: "0495fd22c09aa8551664f0874abea2d90628c28ca897091e69188ed6052dc768", + networkID: constants.LocalID, + customConfig: customGenesisConfigJSON, + err: "cannot override genesis config for standard network local (12345)", }, "local (with custom specified)": { networkID: constants.LocalID, @@ -284,8 +289,9 @@ func TestGenesis(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - assert := assert.New(t) + // test loading of genesis from file + assert := assert.New(t) var customFile string if len(test.customConfig) > 0 { customFile = filepath.Join(t.TempDir(), "config.json") @@ -296,7 +302,97 @@ func TestGenesis(t *testing.T) { customFile = test.missingFilepath } - genesisBytes, _, err := Genesis(test.networkID, customFile) + genesisBytes, _, err := FromFile(test.networkID, customFile) + if len(test.err) > 0 { + assert.Error(err) + assert.Contains(err.Error(), test.err) + return + } + assert.NoError(err) + + genesisHash := fmt.Sprintf("%x", hashing.ComputeHash256(genesisBytes)) + assert.Equal(test.expected, genesisHash, "genesis hash mismatch") + + genesis := platformvm.Genesis{} + _, err = platformvm.GenesisCodec.Unmarshal(genesisBytes, &genesis) + assert.NoError(err) + }) + } +} + +func TestGenesisFromFlag(t *testing.T) { + tests := map[string]struct { + networkID uint32 + customConfig string + err string + expected string + }{ + "mainnet": { + networkID: 
constants.MainnetID, + err: "cannot override genesis config for standard network mainnet (1)", + }, + "fuji": { + networkID: constants.FujiID, + err: "cannot override genesis config for standard network fuji (5)", + }, + "local": { + networkID: constants.LocalID, + err: "cannot override genesis config for standard network local (12345)", + }, + "local (with custom specified)": { + networkID: constants.LocalID, + customConfig: customGenesisConfigJSON, + err: "cannot override genesis config for standard network local (12345)", + }, + "custom": { + networkID: 9999, + customConfig: customGenesisConfigJSON, + expected: "a1d1838586db85fe94ab1143560c3356df9ba2445794b796bba050be89f4fcb4", + }, + "custom (networkID mismatch)": { + networkID: 9999, + customConfig: localGenesisConfigJSON, + err: "networkID 9999 specified but genesis config contains networkID 12345", + }, + "custom (invalid format)": { + networkID: 9999, + customConfig: invalidGenesisConfigJSON, + err: "unable to load genesis content from flag", + }, + "custom (missing content)": { + networkID: 9999, + err: "unable to load genesis content from flag", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + // test loading of genesis content from flag/env-var + + assert := assert.New(t) + var genBytes []byte + if len(test.customConfig) == 0 { + // try loading a default config + var err error + switch test.networkID { + case constants.MainnetID: + genBytes, err = json.Marshal(&MainnetConfig) + assert.NoError(err) + case constants.TestnetID: + genBytes, err = json.Marshal(&FujiConfig) + assert.NoError(err) + case constants.LocalID: + genBytes, err = json.Marshal(&LocalConfig) + assert.NoError(err) + default: + genBytes = make([]byte, 0) + } + } else { + genBytes = []byte(test.customConfig) + } + content := base64.StdEncoding.EncodeToString(genBytes) + + genesisBytes, _, err := FromFlag(test.networkID, content) if len(test.err) > 0 { assert.Error(err) assert.Contains(err.Error(), test.err) 
@@ -373,7 +469,8 @@ func TestVMGenesis(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - genesisBytes, _, err := Genesis(test.networkID, "") + config := GetConfig(test.networkID) + genesisBytes, _, err := FromConfig(config) assert.NoError(err) genesisTx, err := VMGenesis(genesisBytes, vmTest.vmID) @@ -414,7 +511,8 @@ func TestAVAXAssetID(t *testing.T) { t.Run(constants.NetworkIDToNetworkName[test.networkID], func(t *testing.T) { assert := assert.New(t) - _, avaxAssetID, err := Genesis(test.networkID, "") + config := GetConfig(test.networkID) + _, avaxAssetID, err := FromConfig(config) assert.NoError(err) assert.Equal( diff --git a/health/service.go b/health/service.go index c228d2ad617b..b13cb281107e 100644 --- a/health/service.go +++ b/health/service.go @@ -53,7 +53,7 @@ type service struct { checkFreq time.Duration } -// RegisterCheckFn adds a check that calls [checkFn] to evaluate health +// RegisterCheck adds a check that calls [checkFn] to evaluate health func (s *service) RegisterCheck(name string, checkFn Check) error { check := &check{ name: name, diff --git a/message/codec_test.go b/message/codec_test.go index c901bf6caf10..bc0d4c43013f 100644 --- a/message/codec_test.go +++ b/message/codec_test.go @@ -108,11 +108,7 @@ func TestCodecPackParseGzip(t *testing.T) { fields: map[Field]interface{}{}, }, { - op: Pong, - fields: map[Field]interface{}{}, - }, - { - op: UptimePong, + op: Pong, fields: map[Field]interface{}{ Uptime: uint8(80), }, diff --git a/message/ops.go b/message/ops.go index 64b032c1b8d8..b2678e3ae12d 100644 --- a/message/ops.go +++ b/message/ops.go @@ -14,11 +14,9 @@ const ( GetVersion Op = iota _ GetPeerList - // TODO: NetworkUpgrade/Rename this to Pong - UptimePong - Ping - // TODO: NetworkUpgrade/delete this in favor of UptimePong Pong + Ping + _ // Bootstrapping: GetAcceptedFrontier AcceptedFrontier @@ -63,7 +61,6 @@ var ( PeerList, Ping, Pong, - UptimePong, } // List of all consensus request message types @@ -154,8 
+151,7 @@ var ( GetPeerList: {}, PeerList: {SignedPeers}, Ping: {}, - Pong: {}, - UptimePong: {Uptime}, + Pong: {Uptime}, // Bootstrapping: GetAcceptedFrontier: {ChainID, RequestID, Deadline}, AcceptedFrontier: {ChainID, RequestID, ContainerIDs}, @@ -199,8 +195,6 @@ func (op Op) String() string { return "ping" case Pong: return "pong" - case UptimePong: - return "uptime_pong" case GetAcceptedFrontier: return "get_accepted_frontier" case AcceptedFrontier: diff --git a/message/outbound_msg_builder.go b/message/outbound_msg_builder.go index 75ff0616f5ca..451f76d10f3e 100644 --- a/message/outbound_msg_builder.go +++ b/message/outbound_msg_builder.go @@ -35,9 +35,7 @@ type OutboundMsgBuilder interface { Ping() (OutboundMessage, error) - Pong() (OutboundMessage, error) - - UptimePong(uptimePercentage uint8) (OutboundMessage, error) + Pong(uptimePercentage uint8) (OutboundMessage, error) GetAcceptedFrontier( chainID ids.ID, @@ -208,21 +206,13 @@ func (b *outMsgBuilder) Ping() (OutboundMessage, error) { ) } -func (b *outMsgBuilder) Pong() (OutboundMessage, error) { +func (b *outMsgBuilder) Pong(uptimePercentage uint8) (OutboundMessage, error) { return b.c.Pack( Pong, - nil, - Pong.Compressable(), // Pong messages can't be compressed - ) -} - -func (b *outMsgBuilder) UptimePong(uptimePercentage uint8) (OutboundMessage, error) { - return b.c.Pack( - UptimePong, map[Field]interface{}{ Uptime: uptimePercentage, }, - UptimePong.Compressable(), // UptimePong messages can't be compressed + Pong.Compressable(), // Pong messages can't be compressed ) } diff --git a/network/network.go b/network/network.go index 34f8d243f04c..cd1dae54b40c 100644 --- a/network/network.go +++ b/network/network.go @@ -728,17 +728,8 @@ func (n *network) NodeUptime() (UptimeResult, bool) { continue } - weightFloat := float64(weight) - - peerVersion := peer.versionStruct.GetValue().(version.Application) - if peerVersion.Before(version.MinUptimeVersion) { - // If the peer is running an earlier version, then 
ignore their - // stake - totalWeight -= weightFloat - continue - } - percent := float64(peer.observedUptime) + weightFloat := float64(weight) totalWeightedPercent += percent * weightFloat // if this peer thinks we're above requirement add the weight diff --git a/network/peer.go b/network/peer.go index 32d7807da1f8..4c58d26990ae 100644 --- a/network/peer.go +++ b/network/peer.go @@ -436,10 +436,6 @@ func (p *peer) handle(msg message.InboundMessage, msgLen float64) { p.handlePong(msg) msg.OnFinishedHandling() return - case message.UptimePong: - p.handleUptimePong(msg) - msg.OnFinishedHandling() - return case message.GetPeerList: p.handleGetPeerList(msg) msg.OnFinishedHandling() @@ -569,14 +565,6 @@ func (p *peer) sendPing() { // assumes the [stateLock] is not held func (p *peer) sendPong() { - msg, err := p.net.mc.Pong() - p.net.log.AssertNoError(err) - - p.net.send(msg, false, []*peer{p}) -} - -// assumes the [stateLock] is not held -func (p *peer) sendUptimePong() { uptimePercent, err := p.net.config.UptimeCalculator.CalculateUptimePercent(p.nodeID) if err != nil { uptimePercent = 0 @@ -585,7 +573,7 @@ func (p *peer) sendUptimePong() { // with this way we can pack it into a single byte flooredPercentage := math.Floor(uptimePercent * 100) percentage := uint8(flooredPercentage) - msg, err := p.net.mc.UptimePong(percentage) + msg, err := p.net.mc.Pong(percentage) p.net.log.AssertNoError(err) p.net.send(msg, false, []*peer{p}) @@ -824,21 +812,10 @@ func (p *peer) handlePeerList(msg message.InboundMessage) { // assumes the [stateLock] is not held func (p *peer) handlePing(_ message.InboundMessage) { p.sendPong() - p.sendUptimePong() } // assumes the [stateLock] is not held func (p *peer) handlePong(msg message.InboundMessage) { - p.pongHandle(msg, false) -} - -// assumes the [stateLock] is not held -func (p *peer) handleUptimePong(msg message.InboundMessage) { - p.pongHandle(msg, true) -} - -// assumes the [stateLock] is not held -func (p *peer) pongHandle(msg 
message.InboundMessage, isUptime bool) { if !p.net.shouldHoldConnection(p.nodeID) { p.net.log.Debug("disconnecting from peer %s%s at %s because the peer is not a validator", constants.NodeIDPrefix, p.nodeID, p.getIP()) p.discardIP() @@ -855,9 +832,9 @@ func (p *peer) pongHandle(msg message.InboundMessage, isUptime bool) { p.net.log.Debug("disconnecting from peer %s%s at %s version (%s) not compatible: %s", constants.NodeIDPrefix, p.nodeID, p.getIP(), peerVersion, err) p.discardIP() } - if isUptime && - // if the peer or this node is not a validator, we don't need their uptime. - p.net.config.Validators.Contains(constants.PrimaryNetworkID, p.nodeID) && + + // if the peer or this node is not a validator, we don't need their uptime. + if p.net.config.Validators.Contains(constants.PrimaryNetworkID, p.nodeID) && p.net.config.Validators.Contains(constants.PrimaryNetworkID, p.net.config.MyNodeID) { uptime := msg.Get(message.Uptime).(uint8) if uptime <= 100 { diff --git a/node/config.go b/node/config.go index be82b6f3e88a..d02a20dc845d 100644 --- a/node/config.go +++ b/node/config.go @@ -44,11 +44,14 @@ type HTTPConfig struct { HTTPHost string `json:"httpHost"` HTTPPort uint16 `json:"httpPort"` - HTTPSEnabled bool `json:"httpsEnabled"` - HTTPSKeyFile string `json:"httpsKeyFile"` - HTTPSCertFile string `json:"httpsCertFile"` + HTTPSEnabled bool `json:"httpsEnabled"` + HTTPSKey []byte `json:"-"` + HTTPSCert []byte `json:"-"` APIAllowedOrigins []string `json:"apiAllowedOrigins"` + + ShutdownTimeout time.Duration `json:"shutdownTimeout"` + ShutdownWait time.Duration `json:"shutdownWait"` } type APIConfig struct { diff --git a/node/node.go b/node/node.go index 74f7dd8d79f8..433c37cb5415 100644 --- a/node/node.go +++ b/node/node.go @@ -10,8 +10,8 @@ import ( "fmt" "net" "path/filepath" - "strings" "sync" + "time" "github.com/hashicorp/go-plugin" @@ -68,11 +68,12 @@ var ( genesisHashKey = []byte("genesisID") indexerDBPrefix = []byte{0x00} - errInvalidTLSKey = errors.New("invalid 
TLS key") - errPNotCreated = errors.New("P-Chain not created") - errXNotCreated = errors.New("X-Chain not created") - errCNotCreated = errors.New("C-Chain not created") - errFailedToRegisterHealthCheck = errors.New("couldn't register network health check") + errInvalidTLSKey = errors.New("invalid TLS key") + errPNotCreated = errors.New("P-Chain not created") + errXNotCreated = errors.New("X-Chain not created") + errCNotCreated = errors.New("C-Chain not created") + errNotBootstrapped = errors.New("primary subnet has not finished bootstrapping") + errShuttingDown = errors.New("server shutting down") ) // Node is an instance of an Avalanche node. @@ -102,7 +103,7 @@ type Node struct { sharedMemory atomic.Memory // Monitors node health and runs health checks - healthService health.Service + healthService health.Health // Build and parse messages, for both network layer and chain manager msgCreator message.Creator @@ -334,7 +335,7 @@ func (n *Node) Dispatch() error { var err error if n.Config.HTTPSEnabled { n.Log.Debug("initializing API server with TLS") - err = n.APIServer.DispatchTLS(n.Config.HTTPSCertFile, n.Config.HTTPSKeyFile) + err = n.APIServer.DispatchTLS(n.Config.HTTPSCert, n.Config.HTTPSKey) } else { n.Log.Debug("initializing API server without TLS") err = n.APIServer.Dispatch() @@ -495,6 +496,7 @@ func (n *Node) initAPIServer() error { n.Config.HTTPHost, n.Config.HTTPPort, n.Config.APIAllowedOrigins, + n.Config.ShutdownTimeout, n.ID, ) return nil @@ -511,6 +513,7 @@ func (n *Node) initAPIServer() error { n.Config.HTTPHost, n.Config.HTTPPort, n.Config.APIAllowedOrigins, + n.Config.ShutdownTimeout, n.ID, a, ) @@ -913,13 +916,13 @@ func (n *Node) initInfoAPI() error { // Assumes n.Log, n.Net, n.APIServer, n.HTTPLog already initialized func (n *Node) initHealthAPI() error { if !n.Config.HealthAPIEnabled { - n.healthService = health.NewNoOpService() + n.healthService = health.NewNoOp() n.Log.Info("skipping health API initialization because it has been disabled") 
return nil } n.Log.Info("initializing Health API") - healthService, err := health.NewService( + healthService, err := health.New( n.Config.HealthCheckFreq, n.Log, "health", @@ -931,7 +934,7 @@ func (n *Node) initHealthAPI() error { n.healthService = healthService chainsNotBootstrapped := func(pChainID ids.ID, xChainID ids.ID, cChainID ids.ID) []string { - var chains []string + chains := make([]string, 0, 3) if !n.chainManager.IsBootstrapped(pChainID) { chains = append(chains, "'P'") } @@ -944,35 +947,43 @@ func (n *Node) initHealthAPI() error { return chains } - isBootstrappedFunc := func() (interface{}, error) { - if pChainID, err := n.chainManager.Lookup("P"); err != nil { + // Passes if the P, X and C chains are finished bootstrapping + isNotBootstrappedFunc := func() (interface{}, error) { + pChainID, err := n.chainManager.Lookup("P") + if err != nil { return nil, errPNotCreated - } else if xChainID, err := n.chainManager.Lookup("X"); err != nil { + } + + xChainID, err := n.chainManager.Lookup("X") + if err != nil { return nil, errXNotCreated - } else if cChainID, err := n.chainManager.Lookup("C"); err != nil { + } + + cChainID, err := n.chainManager.Lookup("C") + if err != nil { return nil, errCNotCreated - } else if chains := chainsNotBootstrapped(pChainID, xChainID, cChainID); len(chains) != 0 { - return nil, fmt.Errorf("primary subnet has not finished bootstrapping %s", strings.Join(chains, ", ")) } - return nil, nil + chains := chainsNotBootstrapped(pChainID, xChainID, cChainID) + if len(chains) != 0 { + return chains, errNotBootstrapped + } + return chains, nil } - // Passes if the P, X and C chains are finished bootstrapping - err = n.healthService.RegisterMonotonicCheck("isBootstrapped", isBootstrappedFunc) + + err = n.healthService.RegisterMonotonicCheck("isNotBootstrapped", isNotBootstrappedFunc) if err != nil { - return fmt.Errorf("couldn't register isBootstrapped health check: %w", err) + return fmt.Errorf("couldn't register isNotBootstrapped health 
check: %w", err) } - // Register the network layer with the health service err = n.healthService.RegisterCheck("network", n.Net.HealthCheck) if err != nil { - return errFailedToRegisterHealthCheck + return fmt.Errorf("couldn't register network health check: %w", err) } - // Register the router with the health service err = n.healthService.RegisterCheck("router", n.Config.ConsensusRouter.HealthCheck) if err != nil { - return errFailedToRegisterHealthCheck + return fmt.Errorf("couldn't register router health check: %w", err) } handler, err := n.healthService.Handler() @@ -1151,6 +1162,23 @@ func (n *Node) Shutdown(exitCode int) { func (n *Node) shutdown() { n.Log.Info("shutting down node with exit code %d", n.ExitCode()) + + if n.healthService != nil { + // Passes if the node is not shutting down + shuttingDownCheckFunc := func() (interface{}, error) { + return map[string]interface{}{ + "isShuttingDown": true, + }, errShuttingDown + } + + err := n.healthService.RegisterCheck("shuttingDown", shuttingDownCheckFunc) + if err != nil { + n.Log.Debug("couldn't register shuttingDown health check: %s", err) + } + + time.Sleep(n.Config.ShutdownWait) + } + if n.IPCs != nil { if err := n.IPCs.Shutdown(); err != nil { n.Log.Debug("error during IPC shutdown: %s", err) diff --git a/scripts/versions.sh b/scripts/versions.sh index 16e9a8c254d6..a12fa388ead3 100644 --- a/scripts/versions.sh +++ b/scripts/versions.sh @@ -7,4 +7,4 @@ # Set up the versions to be used # Don't export them as their used in the context of other calls -coreth_version=${CORETH_VERSION:-'v0.8.1-rc.0'} +coreth_version=${CORETH_VERSION:-'v0.8.2-rc.0'} diff --git a/snow/engine/snowman/transitive.go b/snow/engine/snowman/transitive.go index aa473d25fa29..34e33a75c93b 100644 --- a/snow/engine/snowman/transitive.go +++ b/snow/engine/snowman/transitive.go @@ -342,7 +342,7 @@ func (t *Transitive) Chits(vdr ids.ShortID, requestID uint32, votes []ids.ID) er if err != nil { return err } - // Wait until [blkID] has been 
issued to consensus before for applying this chit. + // Wait until [blkID] has been issued to consensus before applying this chit. if !added { v.deps.Add(blkID) } diff --git a/snow/networking/router/handler.go b/snow/networking/router/handler.go index 30f7d9ca4f3f..5289ccb1ec2a 100644 --- a/snow/networking/router/handler.go +++ b/snow/networking/router/handler.go @@ -69,7 +69,7 @@ func (h *Handler) Initialize( h.validators = validators var lock sync.Mutex h.unprocessedMsgsCond = sync.NewCond(&lock) - h.cpuTracker = tracker.NewCPUTracker(uptime.IntervalFactory{}, defaultCPUInterval) + h.cpuTracker = tracker.NewCPUTracker(uptime.ContinuousFactory{}, defaultCPUInterval) var err error h.unprocessedMsgs, err = newUnprocessedMsgs(h.ctx.Log, h.validators, h.cpuTracker, "handler", h.ctx.Registerer) return err @@ -231,7 +231,7 @@ func (h *Handler) handleConsensusMsg(msg message.InboundMessage) error { if err != nil { h.ctx.Log.Debug("Malformed message %s from (%s, %s, %d) dropped. Error: %s", msg.Op(), nodeID, h.engine.Context().ChainID, reqID, err) - return nil + return h.engine.GetAcceptedFrontierFailed(nodeID, reqID) } return h.engine.AcceptedFrontier(nodeID, reqID, containerIDs) @@ -255,7 +255,7 @@ func (h *Handler) handleConsensusMsg(msg message.InboundMessage) error { if err != nil { h.ctx.Log.Debug("Malformed message %s from (%s, %s, %d) dropped. Error: %s", msg.Op(), nodeID, h.engine.Context().ChainID, reqID, err) - return nil + return h.engine.GetAcceptedFailed(nodeID, reqID) } return h.engine.Accepted(nodeID, reqID, containerIDs) @@ -332,7 +332,7 @@ func (h *Handler) handleConsensusMsg(msg message.InboundMessage) error { if err != nil { h.ctx.Log.Debug("Malformed message %s from (%s, %s, %d) dropped. 
Error: %s", msg.Op(), nodeID, h.engine.Context().ChainID, reqID, err) - return nil + return h.engine.QueryFailed(nodeID, reqID) } return h.engine.Chits(nodeID, reqID, votes) @@ -362,7 +362,7 @@ func (h *Handler) handleConsensusMsg(msg message.InboundMessage) error { if !ok { h.ctx.Log.Debug("Malformed message %s from (%s, %s, %d) dropped. Error: could not parse AppBytes", msg.Op(), nodeID, h.engine.Context().ChainID, reqID) - return nil + return h.engine.AppRequestFailed(nodeID, reqID) } return h.engine.AppResponse(nodeID, reqID, appBytes) diff --git a/snow/networking/tracker/cpu_tracker_test.go b/snow/networking/tracker/cpu_tracker_test.go index ca0e6f686f56..87d52e1728af 100644 --- a/snow/networking/tracker/cpu_tracker_test.go +++ b/snow/networking/tracker/cpu_tracker_test.go @@ -13,7 +13,7 @@ import ( func TestCPUTracker(t *testing.T) { halflife := time.Second - cpuTracker := NewCPUTracker(uptime.IntervalFactory{}, halflife) + cpuTracker := NewCPUTracker(uptime.ContinuousFactory{}, halflife) vdr1 := ids.ShortID{1} vdr2 := ids.ShortID{2} diff --git a/staking/tls.go b/staking/tls.go index 69ca99c9c681..f230a10f5c88 100644 --- a/staking/tls.go +++ b/staking/tls.go @@ -70,20 +70,26 @@ func InitNodeStakingKeyPair(keyPath, certPath string) error { if err := os.Chmod(keyPath, perms.ReadOnly); err != nil { // Make key read-only return fmt.Errorf("couldn't change permissions on key: %w", err) } - return nil } -func LoadTLSCert(keyPath, certPath string) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) +func LoadTLSCertFromBytes(keyBytes, certBytes []byte) (*tls.Certificate, error) { + cert, err := tls.X509KeyPair(certBytes, keyBytes) if err != nil { - return nil, err + return nil, fmt.Errorf("failed creating cert: %w", err) } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + return &cert, err +} + +func LoadTLSCertFromFiles(keyPath, certPath string) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certPath, 
keyPath) if err != nil { return nil, err } - return &cert, nil + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + return &cert, err } func NewTLSCert() (*tls.Certificate, error) { @@ -96,10 +102,7 @@ func NewTLSCert() (*tls.Certificate, error) { return nil, err } cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - return nil, err - } - return &cert, nil + return &cert, err } // Creates a new staking private key / staking certificate pair. @@ -137,6 +140,5 @@ func NewCertAndKeyBytes() ([]byte, []byte, error) { if err := pem.Encode(&keyBuff, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { return nil, nil, fmt.Errorf("couldn't write private key: %w", err) } - return certBuff.Bytes(), keyBuff.Bytes(), nil } diff --git a/utils/logging/log.go b/utils/logging/log.go index e359eb3440e5..8bd179945a9c 100644 --- a/utils/logging/log.go +++ b/utils/logging/log.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "sync" "time" @@ -194,7 +195,7 @@ func (l *Log) log(level Level, format string, args ...interface{}) { func (l *Log) format(level Level, format string, args ...interface{}) string { loc := "?" if _, file, no, ok := runtime.Caller(3); ok { - localFile := file[len(filePrefix):] + localFile := strings.TrimPrefix(file, filePrefix) loc = fmt.Sprintf("%s#%d", localFile, no) } text := fmt.Sprintf("%s: %s", loc, fmt.Sprintf(format, args...)) diff --git a/utils/uptime/interval_meter.go b/utils/uptime/interval_meter.go deleted file mode 100644 index 6565524518a3..000000000000 --- a/utils/uptime/interval_meter.go +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package uptime - -import ( - "time" -) - -const ( - maxSkippedIntervals = 32 -) - -var ( - _ Meter = &intervalMeter{} - _ Factory = &IntervalFactory{} -) - -// IntervalFactory implements the Factory interface by returning an interval -// based meter. 
-type IntervalFactory struct{} - -// New implements the Factory interface. -func (IntervalFactory) New(halflife time.Duration) Meter { - return NewIntervalMeter(halflife) -} - -type intervalMeter struct { - running bool - - halflife time.Duration - previousValues time.Duration - currentValue time.Duration - nextHalvening time.Time - lastUpdated time.Time -} - -// NewIntervalMeter returns a new Meter with the provided halflife -func NewIntervalMeter(halflife time.Duration) Meter { - return &intervalMeter{halflife: halflife} -} - -func (a *intervalMeter) Start(currentTime time.Time) { - a.update(currentTime, true) -} - -func (a *intervalMeter) Stop(currentTime time.Time) { - a.update(currentTime, false) -} - -func (a *intervalMeter) update(currentTime time.Time, running bool) { - if a.running == running { - return - } - a.Read(currentTime) - a.running = running -} - -func (a *intervalMeter) Read(currentTime time.Time) float64 { - if currentTime.After(a.lastUpdated) { - // try to finish the current round - if !currentTime.Before(a.nextHalvening) { - if a.running { - a.currentValue += a.nextHalvening.Sub(a.lastUpdated) - } - a.lastUpdated = a.nextHalvening - a.nextHalvening = a.nextHalvening.Add(a.halflife) - a.previousValues += a.currentValue >> 1 - a.currentValue = 0 - a.previousValues >>= 1 - - // try to skip future rounds - if totalTime := currentTime.Sub(a.lastUpdated); totalTime >= a.halflife { - numSkippedPeriods := totalTime / a.halflife - if numSkippedPeriods > maxSkippedIntervals { - a.lastUpdated = currentTime - a.nextHalvening = currentTime.Add(a.halflife) - - // If this meter hasn't been read in a long time, avoid - // potential shifting overflow issues and just jump to a - // reasonable value. 
- a.currentValue = 0 - if a.running { - a.previousValues = a.halflife >> 1 - return 1 - } - a.previousValues = 0 - return 0 - } - - if numSkippedPeriods > 0 { - a.previousValues >>= numSkippedPeriods - if a.running { - additionalRunningTime := a.halflife - (a.halflife >> numSkippedPeriods) - a.previousValues += additionalRunningTime >> 1 - } - skippedDuration := numSkippedPeriods * a.halflife - a.lastUpdated = a.lastUpdated.Add(skippedDuration) - a.nextHalvening = a.nextHalvening.Add(skippedDuration) - } - } - } - - // increment the value for the current round - if a.running { - a.currentValue += currentTime.Sub(a.lastUpdated) - } - a.lastUpdated = currentTime - } - - spentTime := a.halflife - a.nextHalvening.Sub(a.lastUpdated) - if spentTime == 0 { - return float64(2*a.previousValues) / float64(a.halflife) - } - spentTime <<= 1 - expectedValue := float64(a.currentValue) / float64(spentTime) - return float64(a.previousValues)/float64(a.halflife) + expectedValue -} diff --git a/utils/uptime/meter_test.go b/utils/uptime/meter_test.go index 872a34fd724d..43ffe7d4368c 100644 --- a/utils/uptime/meter_test.go +++ b/utils/uptime/meter_test.go @@ -22,10 +22,6 @@ var ( name: "continuous", factory: ContinuousFactory{}, }, - { - name: "interval", - factory: IntervalFactory{}, - }, } meterTests = []struct { @@ -131,14 +127,14 @@ func StandardUsageTest(t *testing.T, factory Factory) { t.Fatalf("Wrong uptime value. Expected %f got %f", .625, uptime) } - currentTime = currentTime.Add((maxSkippedIntervals + 2) * halflife) + currentTime = currentTime.Add(34 * halflife) if uptime := m.Read(currentTime); math.Abs(uptime-1) > epsilon { t.Fatalf("Wrong uptime value. Expected %d got %f", 1, uptime) } m.Stop(currentTime) - currentTime = currentTime.Add((maxSkippedIntervals + 2) * halflife) + currentTime = currentTime.Add(34 * halflife) if uptime := m.Read(currentTime); math.Abs(uptime-0) > epsilon { t.Fatalf("Wrong uptime value. 
Expected %d got %f", 0, uptime) } diff --git a/version/constants.go b/version/constants.go index a0797ae8c5ae..bc9fbf26c7e0 100644 --- a/version/constants.go +++ b/version/constants.go @@ -11,7 +11,7 @@ import ( // These are globals that describe network upgrades and node versions var ( - Current = NewDefaultVersion(1, 7, 1) + Current = NewDefaultVersion(1, 7, 2) CurrentApp = NewDefaultApplication(constants.PlatformName, Current.Major(), Current.Minor(), Current.Patch()) MinimumCompatibleVersion = NewDefaultApplication(constants.PlatformName, 1, 7, 0) PrevMinimumCompatibleVersion = NewDefaultApplication(constants.PlatformName, 1, 6, 0) @@ -19,8 +19,6 @@ var ( PrevMinimumUnmaskedVersion = NewDefaultApplication(constants.PlatformName, 1, 0, 0) VersionParser = NewDefaultApplicationParser() - MinUptimeVersion = NewDefaultApplication(constants.PlatformName, 1, 6, 5) - CurrentDatabase = DatabaseVersion1_4_5 PrevDatabase = DatabaseVersion1_0_0 diff --git a/vms/platformvm/abort_block.go b/vms/platformvm/abort_block.go index dea4b3c711ca..07c67cfbaff5 100644 --- a/vms/platformvm/abort_block.go +++ b/vms/platformvm/abort_block.go @@ -43,10 +43,6 @@ func (a *AbortBlock) Verify() error { blkID := a.ID() if err := a.DoubleDecisionBlock.Verify(); err != nil { - a.vm.ctx.Log.Trace("rejecting block %s due to a failed verification: %s", blkID, err) - if err := a.Reject(); err != nil { - a.vm.ctx.Log.Error("failed to reject abort block %s due to %s", blkID, err) - } return err } @@ -58,14 +54,10 @@ func (a *AbortBlock) Verify() error { // The parent of an Abort block should always be a proposal parent, ok := parentIntf.(*ProposalBlock) if !ok { - a.vm.ctx.Log.Trace("rejecting block %s due to an incorrect parent type", blkID) - if err := a.Reject(); err != nil { - a.vm.ctx.Log.Error("failed to reject abort block %s due to %s", blkID, err) - } return errInvalidBlockType } - a.onAcceptState, a.onAcceptFunc = parent.onAbort() + a.onAcceptState = parent.onAbortState a.timestamp = 
a.onAcceptState.GetTimestamp() a.vm.currentBlocks[blkID] = a diff --git a/vms/platformvm/add_delegator_tx.go b/vms/platformvm/add_delegator_tx.go index a47cd28e3159..172bf8803beb 100644 --- a/vms/platformvm/add_delegator_tx.go +++ b/vms/platformvm/add_delegator_tx.go @@ -110,7 +110,7 @@ func (tx *UnsignedAddDelegatorTx) SyntacticVerify(ctx *snow.Context) error { // Attempts to verify this transaction with the provided state. func (tx *UnsignedAddDelegatorTx) SemanticVerify(vm *VM, parentState MutableState, stx *Tx) error { - _, _, _, _, err := tx.Execute(vm, parentState, stx) + _, _, err := tx.Execute(vm, parentState, stx) // We ignore [errFutureStakeTime] here because an advanceTimeTx will be // issued before this transaction is issued. if errors.Is(err, errFutureStakeTime) { @@ -127,24 +127,22 @@ func (tx *UnsignedAddDelegatorTx) Execute( ) ( VersionedState, VersionedState, - func() error, - func() error, - TxError, + error, ) { // Verify the tx is well-formed if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } duration := tx.Validator.Duration() switch { case duration < vm.MinStakeDuration: // Ensure staking length is not too short - return nil, nil, nil, nil, permError{errStakeTooShort} + return nil, nil, errStakeTooShort case duration > vm.MaxStakeDuration: // Ensure staking length is not too long - return nil, nil, nil, nil, permError{errStakeTooLong} + return nil, nil, errStakeTooLong case tx.Validator.Wght < vm.MinDelegatorStake: // Ensure validator is staking at least the minimum amount - return nil, nil, nil, nil, permError{errWeightTooSmall} + return nil, nil, errWeightTooSmall } outs := make([]*avax.TransferableOutput, len(tx.Outs)+len(tx.Stake)) @@ -159,24 +157,20 @@ func (tx *UnsignedAddDelegatorTx) Execute( // Ensure the proposed validator starts after the current timestamp validatorStartTime := tx.StartTime() if !currentTimestamp.Before(validatorStartTime) { - return nil, nil, nil, 
nil, permError{ - fmt.Errorf( - "chain timestamp (%s) not before validator's start time (%s)", - currentTimestamp, - validatorStartTime, - ), - } + return nil, nil, fmt.Errorf( + "chain timestamp (%s) not before validator's start time (%s)", + currentTimestamp, + validatorStartTime, + ) } currentValidator, err := currentStakers.GetValidator(tx.Validator.NodeID) if err != nil && err != database.ErrNotFound { - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to find whether %s is a validator: %w", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - err, - ), - } + return nil, nil, fmt.Errorf( + "failed to find whether %s is a validator: %w", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + err, + ) } pendingValidator := pendingStakers.GetValidator(tx.Validator.NodeID) @@ -199,22 +193,20 @@ func (tx *UnsignedAddDelegatorTx) Execute( vdrTx, err = pendingStakers.GetValidatorTx(tx.Validator.NodeID) if err != nil { if err == database.ErrNotFound { - return nil, nil, nil, nil, permError{errDelegatorSubset} - } - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to find whether %s is a validator: %w", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - err, - ), + return nil, nil, errDelegatorSubset } + return nil, nil, fmt.Errorf( + "failed to find whether %s is a validator: %w", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + err, + ) } } // Ensure that the period this delegator delegates is a subset of the // time the validator validates. 
if !tx.Validator.BoundedBy(vdrTx.StartTime(), vdrTx.EndTime()) { - return nil, nil, nil, nil, permError{errDelegatorSubset} + return nil, nil, errDelegatorSubset } // Ensure that the period this delegator delegates wouldn't become over @@ -222,12 +214,12 @@ func (tx *UnsignedAddDelegatorTx) Execute( vdrWeight := vdrTx.Weight() currentWeight, err := math.Add64(vdrWeight, currentDelegatorWeight) if err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } maximumWeight, err := math.Mul64(MaxValidatorWeightFactor, vdrWeight) if err != nil { - return nil, nil, nil, nil, permError{errStakeOverflow} + return nil, nil, errStakeOverflow } if !currentTimestamp.Before(vm.ApricotPhase3Time) { @@ -242,24 +234,15 @@ func (tx *UnsignedAddDelegatorTx) Execute( maximumWeight, ) if err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } if !canDelegate { - return nil, nil, nil, nil, permError{errOverDelegated} + return nil, nil, errOverDelegated } // Verify the flowcheck if err := vm.semanticVerifySpend(parentState, tx, tx.Ins, outs, stx.Creds, vm.AddStakerTxFee, vm.ctx.AVAXAssetID); err != nil { - switch err.(type) { - case permError: - return nil, nil, nil, nil, permError{ - fmt.Errorf("failed semanticVerifySpend: %w", err), - } - default: - return nil, nil, nil, nil, tempError{ - fmt.Errorf("failed semanticVerifySpend: %w", err), - } - } + return nil, nil, fmt.Errorf("failed semanticVerifySpend: %w", err) } // Make sure the tx doesn't start too far in the future. This is done @@ -267,7 +250,7 @@ func (tx *UnsignedAddDelegatorTx) Execute( // error. 
maxStartTime := currentTimestamp.Add(maxFutureStartTime) if validatorStartTime.After(maxStartTime) { - return nil, nil, nil, nil, permError{errFutureStakeTime} + return nil, nil, errFutureStakeTime } } @@ -288,7 +271,7 @@ func (tx *UnsignedAddDelegatorTx) Execute( // Produce the UTXOS produceOutputs(onAbortState, txID, vm.ctx.AVAXAssetID, outs) - return onCommitState, onAbortState, nil, nil, nil + return onCommitState, onAbortState, nil } // InitiallyPrefersCommit returns true if the proposed validators start time is diff --git a/vms/platformvm/add_delegator_tx_test.go b/vms/platformvm/add_delegator_tx_test.go index ed9b704b009a..fca1fb486a9e 100644 --- a/vms/platformvm/add_delegator_tx_test.go +++ b/vms/platformvm/add_delegator_tx_test.go @@ -329,7 +329,7 @@ func TestAddDelegatorTxExecute(t *testing.T) { if tt.setup != nil { tt.setup(vm) } - if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err != nil && !tt.shouldErr { + if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err != nil && !tt.shouldErr { t.Fatalf("shouldn't have errored but got %s", err) } else if err == nil && tt.shouldErr { t.Fatalf("expected test to error but got none") diff --git a/vms/platformvm/add_subnet_validator_tx.go b/vms/platformvm/add_subnet_validator_tx.go index 7da1cfdf53f2..0b36e24afcde 100644 --- a/vms/platformvm/add_subnet_validator_tx.go +++ b/vms/platformvm/add_subnet_validator_tx.go @@ -72,7 +72,7 @@ func (tx *UnsignedAddSubnetValidatorTx) SyntacticVerify(ctx *snow.Context) error // Attempts to verify this transaction with the provided state. func (tx *UnsignedAddSubnetValidatorTx) SemanticVerify(vm *VM, parentState MutableState, stx *Tx) error { - _, _, _, _, err := tx.Execute(vm, parentState, stx) + _, _, err := tx.Execute(vm, parentState, stx) // We ignore [errFutureStakeTime] here because an advanceTimeTx will be // issued before this transaction is issued. 
if errors.Is(err, errFutureStakeTime) { @@ -89,23 +89,21 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( ) ( VersionedState, VersionedState, - func() error, - func() error, - TxError, + error, ) { // Verify the tx is well-formed if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } duration := tx.Validator.Duration() switch { case duration < vm.MinStakeDuration: // Ensure staking length is not too short - return nil, nil, nil, nil, permError{errStakeTooShort} + return nil, nil, errStakeTooShort case duration > vm.MaxStakeDuration: // Ensure staking length is not too long - return nil, nil, nil, nil, permError{errStakeTooLong} + return nil, nil, errStakeTooLong case len(stx.Creds) == 0: - return nil, nil, nil, nil, permError{errWrongNumberOfCredentials} + return nil, nil, errWrongNumberOfCredentials } currentStakers := parentState.CurrentStakerChainState() @@ -116,24 +114,20 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( // Ensure the proposed validator starts after the current timestamp validatorStartTime := tx.StartTime() if !currentTimestamp.Before(validatorStartTime) { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "validator's start time (%s) is at or after current chain timestamp (%s)", - currentTimestamp, - validatorStartTime, - ), - } + return nil, nil, fmt.Errorf( + "validator's start time (%s) is at or after current chain timestamp (%s)", + validatorStartTime, + currentTimestamp, + ) } currentValidator, err := currentStakers.GetValidator(tx.Validator.NodeID) if err != nil && err != database.ErrNotFound { - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to find whether %s is a validator: %w", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - err, - ), - } + return nil, nil, fmt.Errorf( + "failed to find whether %s is a validator: %w", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + err, + ) } var vdrTx *UnsignedAddValidatorTx @@
-145,12 +139,10 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( // Ensure that this transaction isn't a duplicate add validator tx. subnets := currentValidator.SubnetValidators() if _, validates := subnets[tx.Validator.Subnet]; validates { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "already validating subnet %s", - tx.Validator.Subnet, - ), - } + return nil, nil, fmt.Errorf( + "already validating subnet %s", + tx.Validator.Subnet, + ) } } else { // This validator is attempting to validate with a node that hasn't @@ -158,34 +150,30 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( vdrTx, err = pendingStakers.GetValidatorTx(tx.Validator.NodeID) if err != nil { if err == database.ErrNotFound { - return nil, nil, nil, nil, permError{errDSValidatorSubset} - } - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to find whether %s is a validator: %w", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - err, - ), + return nil, nil, errDSValidatorSubset } + return nil, nil, fmt.Errorf( + "failed to find whether %s is a validator: %w", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + err, + ) } } // Ensure that the period this validator validates the specified subnet // is a subset of the time they validate the primary network. if !tx.Validator.BoundedBy(vdrTx.StartTime(), vdrTx.EndTime()) { - return nil, nil, nil, nil, permError{errDSValidatorSubset} + return nil, nil, errDSValidatorSubset } // Ensure that this transaction isn't a duplicate add validator tx. 
pendingValidator := pendingStakers.GetValidator(tx.Validator.NodeID) subnets := pendingValidator.SubnetValidators() if _, validates := subnets[tx.Validator.Subnet]; validates { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "already validating subnet %s", - tx.Validator.Subnet, - ), - } + return nil, nil, fmt.Errorf( + "already validating subnet %s", + tx.Validator.Subnet, + ) } baseTxCredsLen := len(stx.Creds) - 1 @@ -195,34 +183,30 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( subnetIntf, _, err := parentState.GetTx(tx.Validator.Subnet) if err != nil { if err == database.ErrNotFound { - return nil, nil, nil, nil, permError{errDSValidatorSubset} - } - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "couldn't find subnet %s with %w", - tx.Validator.Subnet, - err, - ), + return nil, nil, errDSValidatorSubset } + return nil, nil, fmt.Errorf( + "couldn't find subnet %s with %w", + tx.Validator.Subnet, + err, + ) } subnet, ok := subnetIntf.UnsignedTx.(*UnsignedCreateSubnetTx) if !ok { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "%s is not a subnet", - tx.Validator.Subnet, - ), - } + return nil, nil, fmt.Errorf( + "%s is not a subnet", + tx.Validator.Subnet, + ) } if err := vm.fx.VerifyPermission(tx, tx.SubnetAuth, subnetCred, subnet.Owner); err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } // Verify the flowcheck if err := vm.semanticVerifySpend(parentState, tx, tx.Ins, tx.Outs, baseTxCreds, vm.TxFee, vm.ctx.AVAXAssetID); err != nil { - return nil, nil, nil, nil, err + return nil, nil, err } // Make sure the tx doesn't start too far in the future. This is done @@ -230,7 +214,7 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( // error. 
maxStartTime := currentTimestamp.Add(maxFutureStartTime) if validatorStartTime.After(maxStartTime) { - return nil, nil, nil, nil, permError{errFutureStakeTime} + return nil, nil, errFutureStakeTime } } @@ -251,7 +235,7 @@ func (tx *UnsignedAddSubnetValidatorTx) Execute( // Produce the UTXOS produceOutputs(onAbortState, txID, vm.ctx.AVAXAssetID, tx.Outs) - return onCommitState, onAbortState, nil, nil, nil + return onCommitState, onAbortState, nil } // InitiallyPrefersCommit returns true if the proposed validators start time is diff --git a/vms/platformvm/add_subnet_validator_tx_test.go b/vms/platformvm/add_subnet_validator_tx_test.go index bb9dc092cd35..4253007bc423 100644 --- a/vms/platformvm/add_subnet_validator_tx_test.go +++ b/vms/platformvm/add_subnet_validator_tx_test.go @@ -153,7 +153,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed because validator stops validating primary network earlier than subnet") } @@ -171,7 +171,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err != nil { + } else if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err != nil { t.Fatal(err) } @@ -212,7 +212,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed 
because validator not in the current or pending validator sets of the primary network") } @@ -239,7 +239,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed because validator starts validating primary " + "network before starting to validate primary network") } @@ -256,7 +256,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed because validator stops validating primary " + "network after stops validating primary network") } @@ -273,7 +273,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err != nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err != nil { t.Fatalf("should have passed verification") } @@ -292,7 +292,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed verification because starts validating at current timestamp") } @@ -337,7 +337,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { t.Fatal(err) } - if 
_, _, _, _, err := duplicateSubnetTx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, duplicateSubnetTx); err == nil { + if _, _, err := duplicateSubnetTx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, duplicateSubnetTx); err == nil { t.Fatal("should have failed verification because validator already validating the specified subnet") } @@ -362,7 +362,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { if err != nil { t.Fatal(err) } - if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed verification because tx has 3 signatures but only 2 needed") } @@ -384,7 +384,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { tx.UnsignedTx.(*UnsignedAddSubnetValidatorTx).SubnetAuth.(*secp256k1fx.Input).SigIndices[1:] // This tx was syntactically verified when it was created...pretend it wasn't so we don't use cache tx.UnsignedTx.(*UnsignedAddSubnetValidatorTx).syntacticallyVerified = false - if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed verification because not enough control sigs") } @@ -407,7 +407,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { t.Fatal(err) } copy(tx.Creds[0].(*secp256k1fx.Credential).Sigs[0][:], sig) - if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed verification because a control sig is invalid") } @@ -435,7 +435,7 @@ func TestAddSubnetValidatorTxExecute(t *testing.T) { t.Fatal(err) } - if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, 
vm.internalState, tx); err == nil { + if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed verification because validator already in pending validator set of the specified subnet") } } diff --git a/vms/platformvm/add_validator_tx.go b/vms/platformvm/add_validator_tx.go index efb99619206d..fb1ed81be6f4 100644 --- a/vms/platformvm/add_validator_tx.go +++ b/vms/platformvm/add_validator_tx.go @@ -120,7 +120,7 @@ func (tx *UnsignedAddValidatorTx) SyntacticVerify(ctx *snow.Context) error { // Attempts to verify this transaction with the provided state. func (tx *UnsignedAddValidatorTx) SemanticVerify(vm *VM, parentState MutableState, stx *Tx) error { - _, _, _, _, err := tx.Execute(vm, parentState, stx) + _, _, err := tx.Execute(vm, parentState, stx) // We ignore [errFutureStakeTime] here because an advanceTimeTx will be // issued before this transaction is issued. if errors.Is(err, errFutureStakeTime) { @@ -137,30 +137,28 @@ func (tx *UnsignedAddValidatorTx) Execute( ) ( VersionedState, VersionedState, - func() error, - func() error, - TxError, + error, ) { // Verify the tx is well-formed if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } switch { case tx.Validator.Wght < vm.MinValidatorStake: // Ensure validator is staking at least the minimum amount - return nil, nil, nil, nil, permError{errWeightTooSmall} + return nil, nil, errWeightTooSmall case tx.Validator.Wght > vm.MaxValidatorStake: // Ensure validator isn't staking too much - return nil, nil, nil, nil, permError{errWeightTooLarge} + return nil, nil, errWeightTooLarge case tx.Shares < vm.MinDelegationFee: - return nil, nil, nil, nil, permError{errInsufficientDelegationFee} + return nil, nil, errInsufficientDelegationFee } duration := tx.Validator.Duration() switch { case duration < vm.MinStakeDuration: // Ensure staking length is not too short - return nil, nil, nil, nil, 
permError{errStakeTooShort} + return nil, nil, errStakeTooShort case duration > vm.MaxStakeDuration: // Ensure staking length is not too long - return nil, nil, nil, nil, permError{errStakeTooLong} + return nil, nil, errStakeTooLong } currentStakers := parentState.CurrentStakerChainState() @@ -175,67 +173,48 @@ func (tx *UnsignedAddValidatorTx) Execute( // Ensure the proposed validator starts after the current time startTime := tx.StartTime() if !currentTimestamp.Before(startTime) { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "validator's start time (%s) at or before current timestamp (%s)", - startTime, - currentTimestamp, - ), - } + return nil, nil, fmt.Errorf( + "validator's start time (%s) at or before current timestamp (%s)", + startTime, + currentTimestamp, + ) } // Ensure this validator isn't currently a validator. _, err := currentStakers.GetValidator(tx.Validator.NodeID) if err == nil { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "%s is already a primary network validator", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - ), - } + return nil, nil, fmt.Errorf( + "%s is already a primary network validator", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + ) } if err != database.ErrNotFound { - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to find whether %s is a validator: %w", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - err, - ), - } + return nil, nil, fmt.Errorf( + "failed to find whether %s is a validator: %w", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + err, + ) } // Ensure this validator isn't about to become a validator. 
_, err = pendingStakers.GetValidatorTx(tx.Validator.NodeID) if err == nil { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "%s is about to become a primary network validator", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - ), - } + return nil, nil, fmt.Errorf( + "%s is about to become a primary network validator", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + ) } if err != database.ErrNotFound { - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to find whether %s is about to become a validator: %w", - tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), - err, - ), - } + return nil, nil, fmt.Errorf( + "failed to find whether %s is about to become a validator: %w", + tx.Validator.NodeID.PrefixedString(constants.NodeIDPrefix), + err, + ) } // Verify the flowcheck if err := vm.semanticVerifySpend(parentState, tx, tx.Ins, outs, stx.Creds, vm.AddStakerTxFee, vm.ctx.AVAXAssetID); err != nil { - switch err.(type) { - case permError: - return nil, nil, nil, nil, permError{ - fmt.Errorf("failed semanticVerifySpend: %w", err), - } - default: - return nil, nil, nil, nil, tempError{ - fmt.Errorf("failed semanticVerifySpend: %w", err), - } - } + return nil, nil, fmt.Errorf("failed semanticVerifySpend: %w", err) } // Make sure the tx doesn't start too far in the future. This is done @@ -243,7 +222,7 @@ func (tx *UnsignedAddValidatorTx) Execute( // error. 
maxStartTime := currentTimestamp.Add(maxFutureStartTime) if startTime.After(maxStartTime) { - return nil, nil, nil, nil, permError{errFutureStakeTime} + return nil, nil, errFutureStakeTime } } @@ -264,7 +243,7 @@ func (tx *UnsignedAddValidatorTx) Execute( // Produce the UTXOS produceOutputs(onAbortState, txID, vm.ctx.AVAXAssetID, outs) - return onCommitState, onAbortState, nil, nil, nil + return onCommitState, onAbortState, nil } // InitiallyPrefersCommit returns true if the proposed validators start time is diff --git a/vms/platformvm/add_validator_tx_test.go b/vms/platformvm/add_validator_tx_test.go index c2b86dd4e2f6..304cde3a31b4 100644 --- a/vms/platformvm/add_validator_tx_test.go +++ b/vms/platformvm/add_validator_tx_test.go @@ -185,7 +185,7 @@ func TestAddValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should've errored because start time too early") } @@ -201,7 +201,7 @@ func TestAddValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should've errored because start time too far in the future") } @@ -217,7 +217,7 @@ func TestAddValidatorTxExecute(t *testing.T) { ids.ShortEmpty, // change addr ); err != nil { t.Fatal(err) - } else if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should've errored because validator already validating") } @@ -250,7 
+250,7 @@ func TestAddValidatorTxExecute(t *testing.T) { t.Fatal(err) } - if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed because validator in pending validator set") } @@ -276,7 +276,7 @@ func TestAddValidatorTxExecute(t *testing.T) { vm.internalState.DeleteUTXO(utxoID) } // Now keys[0] has no funds - if _, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + if _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should have failed because tx fee paying key has no funds") } } diff --git a/vms/platformvm/advance_time_tx.go b/vms/platformvm/advance_time_tx.go index 0fd91e1b8c39..2ab80fd83414 100644 --- a/vms/platformvm/advance_time_tx.go +++ b/vms/platformvm/advance_time_tx.go @@ -46,7 +46,7 @@ func (tx *UnsignedAdvanceTimeTx) SyntacticVerify(*snow.Context) error { // Attempts to verify this transaction with the provided state. 
func (tx *UnsignedAdvanceTimeTx) SemanticVerify(vm *VM, parentState MutableState, stx *Tx) error { - _, _, _, _, err := tx.Execute(vm, parentState, stx) + _, _, err := tx.Execute(vm, parentState, stx) return err } @@ -58,55 +58,47 @@ func (tx *UnsignedAdvanceTimeTx) Execute( ) ( VersionedState, VersionedState, - func() error, - func() error, - TxError, + error, ) { switch { case tx == nil: - return nil, nil, nil, nil, tempError{errNilTx} + return nil, nil, errNilTx case len(stx.Creds) != 0: - return nil, nil, nil, nil, permError{errWrongNumberOfCredentials} + return nil, nil, errWrongNumberOfCredentials } txTimestamp := tx.Timestamp() localTimestamp := vm.clock.Time() localTimestampPlusSync := localTimestamp.Add(syncBound) if localTimestampPlusSync.Before(txTimestamp) { - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "proposed time (%s) is too far in the future relative to local time (%s)", - txTimestamp, - localTimestamp, - ), - } + return nil, nil, fmt.Errorf( + "proposed time (%s) is too far in the future relative to local time (%s)", + txTimestamp, + localTimestamp, + ) } if chainTimestamp := parentState.GetTimestamp(); !txTimestamp.After(chainTimestamp) { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "proposed timestamp (%s), not after current timestamp (%s)", - txTimestamp, - chainTimestamp, - ), - } + return nil, nil, fmt.Errorf( + "proposed timestamp (%s), not after current timestamp (%s)", + txTimestamp, + chainTimestamp, + ) } // Only allow timestamp to move forward as far as the time of next staker // set change time nextStakerChangeTime, err := vm.nextStakerChangeTime(parentState) if err != nil { - return nil, nil, nil, nil, tempError{err} + return nil, nil, err } if txTimestamp.After(nextStakerChangeTime) { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "proposed timestamp (%s) later than next staker change time (%s)", - txTimestamp, - nextStakerChangeTime, - ), - } + return nil, nil, fmt.Errorf( + "proposed timestamp (%s) 
later than next staker change time (%s)", + txTimestamp, + nextStakerChangeTime, + ) } currentSupply := parentState.GetCurrentSupply() @@ -136,7 +128,7 @@ pendingStakerLoop: ) currentSupply, err = safemath.Add64(currentSupply, r) if err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } toAddDelegatorsWithRewardToCurrent = append(toAddDelegatorsWithRewardToCurrent, &validatorReward{ @@ -157,7 +149,7 @@ pendingStakerLoop: ) currentSupply, err = safemath.Add64(currentSupply, r) if err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } toAddValidatorsWithRewardToCurrent = append(toAddValidatorsWithRewardToCurrent, &validatorReward{ @@ -177,9 +169,7 @@ pendingStakerLoop: } numToRemoveFromPending++ default: - return nil, nil, nil, nil, permError{ - fmt.Errorf("expected validator but got %T", tx.UnsignedTx), - } + return nil, nil, fmt.Errorf("expected validator but got %T", tx.UnsignedTx) } } newlyPendingStakers := pendingStakers.DeleteStakers(numToRemoveFromPending) @@ -202,7 +192,7 @@ currentStakerLoop: // We shouldn't be removing any primary network validators here break currentStakerLoop default: - return nil, nil, nil, nil, permError{errWrongTxType} + return nil, nil, errWrongTxType } } newlyCurrentStakers, err := currentStakers.UpdateStakers( @@ -212,7 +202,7 @@ currentStakerLoop: numToRemoveFromCurrent, ) if err != nil { - return nil, nil, nil, nil, tempError{err} + return nil, nil, err } onCommitState := newVersionedState(parentState, newlyCurrentStakers, newlyPendingStakers) @@ -222,15 +212,7 @@ currentStakerLoop: // State doesn't change if this proposal is aborted onAbortState := newVersionedState(parentState, currentStakers, pendingStakers) - // If this block is committed, update the validator sets. - // onCommitDB will be committed to vm.DB before this is called. 
- onCommitFunc := func() error { - // For each Subnet, update the node's validator manager to reflect - // current Subnet membership - return vm.updateValidators(false) - } - - return onCommitState, onAbortState, onCommitFunc, nil, nil + return onCommitState, onAbortState, nil } // InitiallyPrefersCommit returns true if the proposed time is at diff --git a/vms/platformvm/advance_time_tx_test.go b/vms/platformvm/advance_time_tx_test.go index 1677238e6f5f..9d54fc5eab48 100644 --- a/vms/platformvm/advance_time_tx_test.go +++ b/vms/platformvm/advance_time_tx_test.go @@ -4,11 +4,14 @@ package platformvm import ( + "fmt" "testing" "time" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto" + "github.com/stretchr/testify/assert" ) // Ensure semantic verification fails when proposed timestamp is at or before current timestamp @@ -24,7 +27,7 @@ func TestAdvanceTimeTxTimestampTooEarly(t *testing.T) { if tx, err := vm.newAdvanceTimeTx(defaultGenesisTime); err != nil { t.Fatal(err) - } else if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should've failed verification because proposed timestamp same as current timestamp") } } @@ -40,33 +43,13 @@ func TestAdvanceTimeTxTimestampTooLate(t *testing.T) { pendingValidatorEndTime := pendingValidatorStartTime.Add(defaultMinStakingDuration) nodeIDKey, _ := vm.factory.NewPrivateKey() nodeID := nodeIDKey.PublicKey().Address() - addPendingValidatorTx, err := vm.newAddValidatorTx( - vm.MinValidatorStake, - uint64(pendingValidatorStartTime.Unix()), - uint64(pendingValidatorEndTime.Unix()), - nodeID, - nodeID, - PercentDenominator, - []*crypto.PrivateKeySECP256K1R{keys[0]}, - ids.ShortEmpty, // change addr - ) - if err != nil { - t.Fatal(err) - } - - 
vm.internalState.AddPendingStaker(addPendingValidatorTx) - vm.internalState.AddTx(addPendingValidatorTx, Committed) - if err := vm.internalState.Commit(); err != nil { - t.Fatal(err) - } - if err := vm.internalState.(*internalStateImpl).loadPendingValidators(); err != nil { - t.Fatal(err) - } + _, err := addPendingValidator(vm, pendingValidatorStartTime, pendingValidatorEndTime, nodeID, []*crypto.PrivateKeySECP256K1R{keys[0]}) + assert.NoError(t, err) tx, err := vm.newAdvanceTimeTx(pendingValidatorStartTime.Add(1 * time.Second)) if err != nil { t.Fatal(err) - } else if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should've failed verification because proposed timestamp is after pending validator start time") } if err := vm.Shutdown(); err != nil { @@ -90,7 +73,7 @@ func TestAdvanceTimeTxTimestampTooLate(t *testing.T) { // Proposes advancing timestamp to 1 second after genesis validators stop validating if tx, err := vm.newAdvanceTimeTx(defaultValidateEndTime.Add(1 * time.Second)); err != nil { t.Fatal(err) - } else if _, _, _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx); err == nil { t.Fatal("should've failed verification because proposed timestamp is after pending validator start time") } } @@ -113,34 +96,14 @@ func TestAdvanceTimeTxUpdatePrimaryNetworkStakers(t *testing.T) { pendingValidatorEndTime := pendingValidatorStartTime.Add(defaultMinStakingDuration) nodeIDKey, _ := vm.factory.NewPrivateKey() nodeID := nodeIDKey.PublicKey().Address() - addPendingValidatorTx, err := vm.newAddValidatorTx( - vm.MinValidatorStake, - uint64(pendingValidatorStartTime.Unix()), - uint64(pendingValidatorEndTime.Unix()), - nodeID, - nodeID, - PercentDenominator, - 
[]*crypto.PrivateKeySECP256K1R{keys[0]}, - ids.ShortEmpty, // change addr - ) - if err != nil { - t.Fatal(err) - } - - vm.internalState.AddPendingStaker(addPendingValidatorTx) - vm.internalState.AddTx(addPendingValidatorTx, Committed) - if err := vm.internalState.Commit(); err != nil { - t.Fatal(err) - } - if err := vm.internalState.(*internalStateImpl).loadPendingValidators(); err != nil { - t.Fatal(err) - } + addPendingValidatorTx, err := addPendingValidator(vm, pendingValidatorStartTime, pendingValidatorEndTime, nodeID, []*crypto.PrivateKeySECP256K1R{keys[0]}) + assert.NoError(t, err) tx, err := vm.newAdvanceTimeTx(pendingValidatorStartTime) if err != nil { t.Fatal(err) } - onCommit, onAbort, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + onCommit, onAbort, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) if err != nil { t.Fatal(err) } @@ -180,22 +143,34 @@ func TestAdvanceTimeTxUpdatePrimaryNetworkStakers(t *testing.T) { if vdr.ID() != addPendingValidatorTx.ID() { t.Fatalf("Added the wrong tx to the pending validator set") } + + // Test VM validators + onCommit.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + assert.True(t, vm.Validators.Contains(constants.PrimaryNetworkID, nodeID)) } // Ensure semantic verification updates the current and pending staker sets correctly. // Namely, it should add pending stakers whose start time is at or before the timestamp. // It will not remove primary network stakers; that happens in rewardTxs. 
-func TestAdvanceTimeTxUpdatePrimaryNetworkStakers2(t *testing.T) { +func TestAdvanceTimeTxUpdateStakers(t *testing.T) { + type stakerStatus uint + const ( + pending stakerStatus = iota + current + ) + type staker struct { nodeID ids.ShortID startTime, endTime time.Time } type test struct { - description string - stakers []staker - advanceTimeTo []time.Time - expectedCurrent []ids.ShortID - expectedPending []ids.ShortID + description string + stakers []staker + subnetStakers []staker + advanceTimeTo []time.Time + expectedStakers map[ids.ShortID]stakerStatus + expectedSubnetStakers map[ids.ShortID]stakerStatus } // Chronological order: staker1 start, staker2 start, staker3 start and staker 4 start, @@ -215,6 +190,11 @@ func TestAdvanceTimeTxUpdatePrimaryNetworkStakers2(t *testing.T) { startTime: staker2.startTime.Add(1 * time.Minute), endTime: staker2.endTime.Add(1 * time.Minute), } + staker3Sub := staker{ + nodeID: staker3.nodeID, + startTime: staker3.startTime.Add(1 * time.Minute), + endTime: staker3.endTime.Add(-1 * time.Minute), + } staker4 := staker{ nodeID: ids.GenerateTestShortID(), startTime: staker3.startTime, @@ -228,136 +208,157 @@ func TestAdvanceTimeTxUpdatePrimaryNetworkStakers2(t *testing.T) { tests := []test{ { - description: "advance time to before staker1 start", - stakers: []staker{ - staker1, - staker2, - staker3, - staker4, - staker5, + description: "advance time to before staker1 start with subnet", + stakers: []staker{staker1, staker2, staker3, staker4, staker5}, + subnetStakers: []staker{staker1, staker2, staker3, staker4, staker5}, + advanceTimeTo: []time.Time{staker1.startTime.Add(-1 * time.Second)}, + expectedStakers: map[ids.ShortID]stakerStatus{ + staker1.nodeID: pending, staker2.nodeID: pending, staker3.nodeID: pending, staker4.nodeID: pending, staker5.nodeID: pending, + }, + expectedSubnetStakers: map[ids.ShortID]stakerStatus{ + staker1.nodeID: pending, staker2.nodeID: pending, staker3.nodeID: pending, staker4.nodeID: pending, 
staker5.nodeID: pending, + }, + }, + { + description: "advance time to staker 1 start with subnet", + stakers: []staker{staker1, staker2, staker3, staker4, staker5}, + subnetStakers: []staker{staker1}, + advanceTimeTo: []time.Time{staker1.startTime}, + expectedStakers: map[ids.ShortID]stakerStatus{ + staker2.nodeID: pending, staker3.nodeID: pending, staker4.nodeID: pending, staker5.nodeID: pending, + staker1.nodeID: current, + }, + expectedSubnetStakers: map[ids.ShortID]stakerStatus{ + staker2.nodeID: pending, staker3.nodeID: pending, staker4.nodeID: pending, staker5.nodeID: pending, + staker1.nodeID: current, }, - advanceTimeTo: []time.Time{staker1.startTime.Add(-1 * time.Second)}, - expectedPending: []ids.ShortID{staker1.nodeID, staker2.nodeID, staker3.nodeID, staker4.nodeID, staker5.nodeID}, }, { - description: "advance time to staker 1 start", - stakers: []staker{ - staker1, - staker2, - staker3, - staker4, - staker5, + description: "advance time to the staker2 start", + stakers: []staker{staker1, staker2, staker3, staker4, staker5}, + advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime}, + expectedStakers: map[ids.ShortID]stakerStatus{ + staker3.nodeID: pending, staker4.nodeID: pending, staker5.nodeID: pending, + staker1.nodeID: current, staker2.nodeID: current, }, - advanceTimeTo: []time.Time{staker1.startTime}, - expectedCurrent: []ids.ShortID{staker1.nodeID}, - expectedPending: []ids.ShortID{staker2.nodeID, staker3.nodeID, staker4.nodeID, staker5.nodeID}, }, { - description: "advance time to the staker2 start", - stakers: []staker{ - staker1, - staker2, - staker3, - staker4, - staker5, + description: "staker3 should validate only primary network", + stakers: []staker{staker1, staker2, staker3, staker4, staker5}, + subnetStakers: []staker{staker1, staker2, staker3Sub, staker4, staker5}, + advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime, staker3.startTime}, + expectedStakers: map[ids.ShortID]stakerStatus{ + staker5.nodeID: pending, 
+ staker1.nodeID: current, staker2.nodeID: current, staker3.nodeID: current, staker4.nodeID: current, + }, + expectedSubnetStakers: map[ids.ShortID]stakerStatus{ + staker5.nodeID: pending, staker3Sub.nodeID: pending, + staker1.nodeID: current, staker2.nodeID: current, staker4.nodeID: current, }, - advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime}, - expectedCurrent: []ids.ShortID{staker1.nodeID}, - expectedPending: []ids.ShortID{staker3.nodeID, staker4.nodeID, staker5.nodeID}, }, { - description: "advance time to staker3 and staker4 start", - stakers: []staker{ - staker1, - staker2, - staker3, - staker4, - staker5, + description: "advance time to staker3 start with subnet", + stakers: []staker{staker1, staker2, staker3, staker4, staker5}, + subnetStakers: []staker{staker1, staker2, staker3Sub, staker4, staker5}, + advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime, staker3.startTime, staker3Sub.startTime}, + expectedStakers: map[ids.ShortID]stakerStatus{ + staker5.nodeID: pending, + staker1.nodeID: current, staker2.nodeID: current, staker3.nodeID: current, staker4.nodeID: current, + }, + expectedSubnetStakers: map[ids.ShortID]stakerStatus{ + staker5.nodeID: pending, + staker1.nodeID: current, staker2.nodeID: current, staker3.nodeID: current, staker4.nodeID: current, }, - advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime, staker3.startTime}, - expectedCurrent: []ids.ShortID{staker2.nodeID, staker3.nodeID, staker4.nodeID}, - expectedPending: []ids.ShortID{staker5.nodeID}, }, { - description: "advance time to staker5 start", - stakers: []staker{ - staker1, - staker2, - staker3, - staker4, - staker5, + description: "advance time to staker5 end", + stakers: []staker{staker1, staker2, staker3, staker4, staker5}, + advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime, staker3.startTime, staker5.startTime}, + expectedStakers: map[ids.ShortID]stakerStatus{ + staker1.nodeID: current, staker2.nodeID: current, 
staker3.nodeID: current, staker4.nodeID: current, staker5.nodeID: current, }, - advanceTimeTo: []time.Time{staker1.startTime, staker2.startTime, staker3.startTime, staker5.startTime}, - expectedCurrent: []ids.ShortID{staker3.nodeID, staker4.nodeID, staker5.nodeID}, }, } - for _, tt := range tests { - vm, _, _ := defaultVM() - vm.ctx.Lock.Lock() - defer func() { - if err := vm.Shutdown(); err != nil { - t.Fatal(err) - } - vm.ctx.Lock.Unlock() - }() - - for _, staker := range tt.stakers { - tx, err := vm.newAddValidatorTx( - vm.MinValidatorStake, - uint64(staker.startTime.Unix()), - uint64(staker.endTime.Unix()), - staker.nodeID, // validator ID - ids.ShortEmpty, // reward address - PercentDenominator, - []*crypto.PrivateKeySECP256K1R{keys[0]}, - ids.ShortEmpty, // change addr - ) - if err != nil { - t.Fatal(err) + for _, test := range tests { + t.Run(test.description, func(ts *testing.T) { + assert := assert.New(ts) + vm, _, _ := defaultVM() + vm.ctx.Lock.Lock() + defer func() { + if err := vm.Shutdown(); err != nil { + t.Fatal(err) + } + vm.ctx.Lock.Unlock() + }() + vm.WhitelistedSubnets.Add(testSubnet1.ID()) + + for _, staker := range test.stakers { + _, err := addPendingValidator(vm, staker.startTime, staker.endTime, staker.nodeID, []*crypto.PrivateKeySECP256K1R{keys[0]}) + assert.NoError(err) } - vm.internalState.AddPendingStaker(tx) - vm.internalState.AddTx(tx, Committed) + for _, staker := range test.subnetStakers { + tx, err := vm.newAddSubnetValidatorTx( + 10, // Weight + uint64(staker.startTime.Unix()), + uint64(staker.endTime.Unix()), + staker.nodeID, // validator ID + testSubnet1.ID(), // Subnet ID + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, // Keys + ids.ShortEmpty, // reward address + ) + assert.NoError(err) + vm.internalState.AddPendingStaker(tx) + vm.internalState.AddTx(tx, Committed) + } if err := vm.internalState.Commit(); err != nil { t.Fatal(err) } if err := vm.internalState.(*internalStateImpl).loadPendingValidators(); err != nil { 
t.Fatal(err) } - } - for _, newTime := range tt.advanceTimeTo { - vm.clock.Set(newTime) - tx, err := vm.newAdvanceTimeTx(newTime) - if err != nil { - t.Fatal(err) - } + for _, newTime := range test.advanceTimeTo { + vm.clock.Set(newTime) + tx, err := vm.newAdvanceTimeTx(newTime) + if err != nil { + t.Fatal(err) + } - onCommitState, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) - if err != nil { - t.Fatalf("failed test '%s': %s", tt.description, err) + onCommitState, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + assert.NoError(err) + onCommitState.Apply(vm.internalState) } - onCommitState.Apply(vm.internalState) - } - // Check that the validators we expect to be in the current staker set are there - currentStakers := vm.internalState.CurrentStakerChainState() - for _, stakerNodeID := range tt.expectedCurrent { - _, err := currentStakers.GetValidator(stakerNodeID) - if err != nil { - t.Fatalf("failed test '%s': expected validator to be in current validator set but it isn't", tt.description) + assert.NoError(vm.internalState.Commit()) + + // Check that the validators we expect to be in the current staker set are there + currentStakers := vm.internalState.CurrentStakerChainState() + // Check that the validators we expect to be in the pending staker set are there + pendingStakers := vm.internalState.PendingStakerChainState() + for stakerNodeID, status := range test.expectedStakers { + switch status { + case pending: + _, err := pendingStakers.GetValidatorTx(stakerNodeID) + assert.NoError(err) + assert.False(vm.Validators.Contains(constants.PrimaryNetworkID, stakerNodeID)) + case current: + _, err := currentStakers.GetValidator(stakerNodeID) + assert.NoError(err) + assert.True(vm.Validators.Contains(constants.PrimaryNetworkID, stakerNodeID)) + } } - } - // Check that the validators we expect to be in the pending staker set are there - pendingStakers := vm.internalState.PendingStakerChainState() - 
for _, stakerNodeID := range tt.expectedPending { - _, err := pendingStakers.GetValidatorTx(stakerNodeID) - if err != nil { - t.Fatalf("failed test '%s': expected validator to be in pending validator set but it isn't", tt.description) + for stakerNodeID, status := range test.expectedSubnetStakers { + switch status { + case pending: + assert.False(vm.Validators.Contains(testSubnet1.ID(), stakerNodeID)) + case current: + assert.True(vm.Validators.Contains(testSubnet1.ID(), stakerNodeID)) + } } - } + }) } } @@ -374,7 +375,7 @@ func TestAdvanceTimeTxRemoveSubnetValidator(t *testing.T) { } vm.ctx.Lock.Unlock() }() - + vm.WhitelistedSubnets.Add(testSubnet1.ID()) // Add a subnet validator to the staker set subnetValidatorNodeID := keys[0].PublicKey().Address() // Starts after the corre @@ -405,13 +406,14 @@ func TestAdvanceTimeTxRemoveSubnetValidator(t *testing.T) { // The above validator is now part of the staking set // Queue a staker that joins the staker set after the above validator leaves + subnetVdr2NodeID := keys[1].PublicKey().Address() tx, err = vm.newAddSubnetValidatorTx( 1, // Weight uint64(subnetVdr1EndTime.Add(time.Second).Unix()), // Start time uint64(subnetVdr1EndTime.Add(time.Second).Add(defaultMinStakingDuration).Unix()), // end time - keys[1].PublicKey().Address(), // Node ID - testSubnet1.ID(), // Subnet ID - []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, // Keys + subnetVdr2NodeID, // Node ID + testSubnet1.ID(), // Subnet ID + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, // Keys ids.ShortEmpty, // reward address ) if err != nil { @@ -435,7 +437,7 @@ func TestAdvanceTimeTxRemoveSubnetValidator(t *testing.T) { if err != nil { t.Fatal(err) } - onCommitState, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + onCommitState, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) if err != nil { t.Fatal(err) } @@ -451,6 +453,197 @@ func TestAdvanceTimeTxRemoveSubnetValidator(t *testing.T) { 
if exists { t.Fatal("should have been removed from validator set") } + // Check VM Validators are removed successfully + onCommitState.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + assert.False(t, vm.Validators.Contains(testSubnet1.ID(), subnetVdr2NodeID)) + assert.False(t, vm.Validators.Contains(testSubnet1.ID(), subnetValidatorNodeID)) +} + +func TestWhitelistedSubnet(t *testing.T) { + for _, whitelist := range []bool{true, false} { + t.Run(fmt.Sprintf("whitelisted %t", whitelist), func(ts *testing.T) { + vm, _, _ := defaultVM() + vm.ctx.Lock.Lock() + defer func() { + if err := vm.Shutdown(); err != nil { + t.Fatal(err) + } + vm.ctx.Lock.Unlock() + }() + + if whitelist { + vm.WhitelistedSubnets.Add(testSubnet1.ID()) + } + // Add a subnet validator to the staker set + subnetValidatorNodeID := keys[0].PublicKey().Address() + + subnetVdr1StartTime := defaultGenesisTime.Add(1 * time.Minute) + subnetVdr1EndTime := defaultGenesisTime.Add(10 * defaultMinStakingDuration).Add(1 * time.Minute) + tx, err := vm.newAddSubnetValidatorTx( + 1, // Weight + uint64(subnetVdr1StartTime.Unix()), // Start time + uint64(subnetVdr1EndTime.Unix()), // end time + subnetValidatorNodeID, // Node ID + testSubnet1.ID(), // Subnet ID + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1]}, // Keys + ids.ShortEmpty, // reward address + ) + if err != nil { + t.Fatal(err) + } + + vm.internalState.AddPendingStaker(tx) + vm.internalState.AddTx(tx, Committed) + if err := vm.internalState.Commit(); err != nil { + t.Fatal(err) + } + if err := vm.internalState.(*internalStateImpl).loadPendingValidators(); err != nil { + t.Fatal(err) + } + + // Advance time to the staker's start time. 
+ vm.clock.Set(subnetVdr1StartTime) + tx, err = vm.newAdvanceTimeTx(subnetVdr1StartTime) + if err != nil { + t.Fatal(err) + } + onCommitState, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + if err != nil { + t.Fatal(err) + } + + onCommitState.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + assert.Equal(t, whitelist, vm.Validators.Contains(testSubnet1.ID(), subnetValidatorNodeID)) + }) + } +} + +func TestAdvanceTimeTxDelegatorStakerWeight(t *testing.T) { + vm, _, _ := defaultVM() + vm.ctx.Lock.Lock() + defer func() { + if err := vm.Shutdown(); err != nil { + t.Fatal(err) + } + vm.ctx.Lock.Unlock() + }() + + // Case: Timestamp is after next validator start time + // Add a pending validator + pendingValidatorStartTime := defaultGenesisTime.Add(1 * time.Second) + pendingValidatorEndTime := pendingValidatorStartTime.Add(defaultMaxStakingDuration) + nodeIDKey, _ := vm.factory.NewPrivateKey() + nodeID := nodeIDKey.PublicKey().Address() + _, err := addPendingValidator(vm, pendingValidatorStartTime, pendingValidatorEndTime, nodeID, []*crypto.PrivateKeySECP256K1R{keys[0]}) + assert.NoError(t, err) + + tx, err := vm.newAdvanceTimeTx(pendingValidatorStartTime) + assert.NoError(t, err) + onCommit, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + assert.NoError(t, err) + onCommit.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + + // Test validator weight before delegation + primarySet, ok := vm.Validators.GetValidators(constants.PrimaryNetworkID) + assert.True(t, ok) + vdrWeight, _ := primarySet.GetWeight(nodeID) + assert.Equal(t, vm.MinValidatorStake, vdrWeight) + + // Add delegator + pendingDelegatorStartTime := pendingValidatorStartTime.Add(1 * time.Second) + pendingDelegatorEndTime := pendingDelegatorStartTime.Add(1 * time.Second) + addDelegatorTx, err := vm.newAddDelegatorTx( + vm.MinDelegatorStake, + uint64(pendingDelegatorStartTime.Unix()), + 
uint64(pendingDelegatorEndTime.Unix()), + nodeID, + keys[0].PublicKey().Address(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1], keys[4]}, + ids.ShortEmpty, // change addr + ) + assert.NoError(t, err) + vm.internalState.AddPendingStaker(addDelegatorTx) + vm.internalState.AddTx(addDelegatorTx, Committed) + assert.NoError(t, vm.internalState.Commit()) + assert.NoError(t, vm.internalState.(*internalStateImpl).loadPendingValidators()) + + // Advance Time + tx, err = vm.newAdvanceTimeTx(pendingDelegatorStartTime) + assert.NoError(t, err) + onCommit, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + assert.NoError(t, err) + onCommit.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + + // Test validator weight after delegation + vdrWeight, _ = primarySet.GetWeight(nodeID) + assert.Equal(t, vm.MinDelegatorStake+vm.MinValidatorStake, vdrWeight) +} + +func TestAdvanceTimeTxDelegatorStakers(t *testing.T) { + vm, _, _ := defaultVM() + vm.ctx.Lock.Lock() + defer func() { + if err := vm.Shutdown(); err != nil { + t.Fatal(err) + } + vm.ctx.Lock.Unlock() + }() + + // Case: Timestamp is after next validator start time + // Add a pending validator + pendingValidatorStartTime := defaultGenesisTime.Add(1 * time.Second) + pendingValidatorEndTime := pendingValidatorStartTime.Add(defaultMinStakingDuration) + nodeIDKey, _ := vm.factory.NewPrivateKey() + nodeID := nodeIDKey.PublicKey().Address() + _, err := addPendingValidator(vm, pendingValidatorStartTime, pendingValidatorEndTime, nodeID, []*crypto.PrivateKeySECP256K1R{keys[0]}) + assert.NoError(t, err) + + tx, err := vm.newAdvanceTimeTx(pendingValidatorStartTime) + assert.NoError(t, err) + onCommit, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + assert.NoError(t, err) + onCommit.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + + // Test validator weight before delegation + primarySet, ok := 
vm.Validators.GetValidators(constants.PrimaryNetworkID) + assert.True(t, ok) + vdrWeight, _ := primarySet.GetWeight(nodeID) + assert.Equal(t, vm.MinValidatorStake, vdrWeight) + + // Add delegator + pendingDelegatorStartTime := pendingValidatorStartTime.Add(1 * time.Second) + pendingDelegatorEndTime := pendingDelegatorStartTime.Add(defaultMinStakingDuration) + addDelegatorTx, err := vm.newAddDelegatorTx( + vm.MinDelegatorStake, + uint64(pendingDelegatorStartTime.Unix()), + uint64(pendingDelegatorEndTime.Unix()), + nodeID, + keys[0].PublicKey().Address(), + []*crypto.PrivateKeySECP256K1R{keys[0], keys[1], keys[4]}, + ids.ShortEmpty, // change addr + ) + assert.NoError(t, err) + vm.internalState.AddPendingStaker(addDelegatorTx) + vm.internalState.AddTx(addDelegatorTx, Committed) + assert.NoError(t, vm.internalState.Commit()) + assert.NoError(t, vm.internalState.(*internalStateImpl).loadPendingValidators()) + + // Advance Time + tx, err = vm.newAdvanceTimeTx(pendingDelegatorStartTime) + assert.NoError(t, err) + onCommit, _, err = tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + assert.NoError(t, err) + onCommit.Apply(vm.internalState) + assert.NoError(t, vm.internalState.Commit()) + + // Test validator weight after delegation + vdrWeight, _ = primarySet.GetWeight(nodeID) + assert.Equal(t, vm.MinDelegatorStake+vm.MinValidatorStake, vdrWeight) } // Test method InitiallyPrefersCommit @@ -511,3 +704,29 @@ func TestAdvanceTimeTxUnmarshal(t *testing.T) { t.Fatal("should have same timestamp") } } + +func addPendingValidator(vm *VM, startTime time.Time, endTime time.Time, nodeID ids.ShortID, keys []*crypto.PrivateKeySECP256K1R) (*Tx, error) { + addPendingValidatorTx, err := vm.newAddValidatorTx( + vm.MinValidatorStake, + uint64(startTime.Unix()), + uint64(endTime.Unix()), + nodeID, + nodeID, + PercentDenominator, + keys, + ids.ShortEmpty, // change addr + ) + if err != nil { + return nil, err + } + + 
vm.internalState.AddPendingStaker(addPendingValidatorTx) + vm.internalState.AddTx(addPendingValidatorTx, Committed) + if err := vm.internalState.Commit(); err != nil { + return nil, err + } + if err := vm.internalState.(*internalStateImpl).loadPendingValidators(); err != nil { + return nil, err + } + return addPendingValidatorTx, err +} diff --git a/vms/platformvm/atomic_block.go b/vms/platformvm/atomic_block.go index b80eb580ed02..118603f6248e 100644 --- a/vms/platformvm/atomic_block.go +++ b/vms/platformvm/atomic_block.go @@ -72,14 +72,6 @@ func (ab *AtomicBlock) Verify() error { blkID := ab.ID() if err := ab.CommonDecisionBlock.Verify(); err != nil { - ab.vm.ctx.Log.Trace("rejecting block %s due to a failed verification: %s", blkID, err) - if err := ab.Reject(); err != nil { - ab.vm.ctx.Log.Error( - "failed to reject atomic block %s due to %s", - blkID, - err, - ) - } return err } @@ -133,7 +125,7 @@ func (ab *AtomicBlock) Verify() error { ab.onAcceptState = onAccept ab.timestamp = onAccept.GetTimestamp() - ab.vm.blockBuilder.RemoveAtomicTx(&ab.Tx) + ab.vm.blockBuilder.RemoveDecisionTxs([]*Tx{&ab.Tx}) ab.vm.currentBlocks[blkID] = ab parentIntf.addChild(ab) return nil diff --git a/vms/platformvm/block_builder.go b/vms/platformvm/block_builder.go index 3c9c74537d74..079d198272be 100644 --- a/vms/platformvm/block_builder.go +++ b/vms/platformvm/block_builder.go @@ -125,6 +125,7 @@ func (m *blockBuilder) BuildBlock() (snowman.Block, error) { m.dropIncoming = true defer func() { m.dropIncoming = false + m.ResetTimer() }() m.vm.ctx.Log.Debug("in BuildBlock") @@ -151,55 +152,10 @@ func (m *blockBuilder) BuildBlock() (snowman.Block, error) { return nil, errEndOfTime } - // TODO: remove after AP5. 
- enabledAP5 := !currentChainTimestamp.Before(m.vm.ApricotPhase5Time) - // If there are pending decision txs, build a block with a batch of them - if m.HasDecisionTxs() || (enabledAP5 && m.HasAtomicTx()) { - txs := make([]*Tx, 0, BatchSize) - if m.HasDecisionTxs() { - decisionTxs := m.PopDecisionTxs(BatchSize) - txs = append(txs, decisionTxs...) - } - if enabledAP5 && m.HasAtomicTx() { - atomicTxs := m.PopAtomicTxs(BatchSize - len(txs)) - txs = append(txs, atomicTxs...) - } - - blk, err := m.vm.newStandardBlock(preferredID, nextHeight, txs) - if err != nil { - m.ResetTimer() - return nil, err - } - m.vm.ctx.Log.Debug("Built Standard Block %s: %s", blk.ID(), jsonFormatter{obj: blk}) - - if err := blk.Verify(); err != nil { - m.ResetTimer() - return nil, err - } - - m.vm.internalState.AddBlock(blk) - return blk, m.vm.internalState.Commit() - } - - // If there is a pending atomic tx, build a block with it - if !enabledAP5 && m.HasAtomicTx() { - tx := m.PopAtomicTx() - - blk, err := m.vm.newAtomicBlock(preferredID, nextHeight, *tx) - if err != nil { - m.ResetTimer() - return nil, err - } - m.vm.ctx.Log.Debug("Built Atomic Block %s: %s", blk.ID(), jsonFormatter{obj: blk}) - - if err := blk.Verify(); err != nil { - m.ResetTimer() - return nil, err - } - - m.vm.internalState.AddBlock(blk) - return blk, m.vm.internalState.Commit() + if m.HasDecisionTxs() { + txs := m.PopDecisionTxs(BatchSize) + return m.vm.newStandardBlock(preferredID, nextHeight, txs) } currentStakers := preferredState.CurrentStakerChainState() @@ -221,14 +177,7 @@ func (m *blockBuilder) BuildBlock() (snowman.Block, error) { if err != nil { return nil, err } - blk, err := m.vm.newProposalBlock(preferredID, nextHeight, *rewardValidatorTx) - if err != nil { - return nil, err - } - m.vm.ctx.Log.Debug("Built Proposal Block %s: %s", blk.ID(), jsonFormatter{obj: blk}) - - m.vm.internalState.AddBlock(blk) - return blk, m.vm.internalState.Commit() + return m.vm.newProposalBlock(preferredID, nextHeight, 
*rewardValidatorTx) } // If local time is >= time of the next staker set change, @@ -245,13 +194,7 @@ func (m *blockBuilder) BuildBlock() (snowman.Block, error) { if err != nil { return nil, err } - blk, err := m.vm.newProposalBlock(preferredID, nextHeight, *advanceTimeTx) - if err != nil { - return nil, err - } - - m.vm.internalState.AddBlock(blk) - return blk, m.vm.internalState.Commit() + return m.vm.newProposalBlock(preferredID, nextHeight, *advanceTimeTx) } // Propose adding a new validator but only if their start time is in the @@ -291,29 +234,9 @@ func (m *blockBuilder) BuildBlock() (snowman.Block, error) { if err != nil { return nil, err } - blk, err := m.vm.newProposalBlock(preferredID, nextHeight, *advanceTimeTx) - if err != nil { - return nil, err - } - - m.vm.internalState.AddBlock(blk) - return blk, m.vm.internalState.Commit() - } - - // Attempt to issue the transaction - blk, err := m.vm.newProposalBlock(preferredID, nextHeight, *tx) - if err != nil { - m.ResetTimer() - return nil, err + return m.vm.newProposalBlock(preferredID, nextHeight, *advanceTimeTx) } - - if err := blk.Verify(); err != nil { - m.ResetTimer() - return nil, err - } - - m.vm.internalState.AddBlock(blk) - return blk, m.vm.internalState.Commit() + return m.vm.newProposalBlock(preferredID, nextHeight, *tx) } m.vm.ctx.Log.Debug("BuildBlock returning error (no blocks)") @@ -325,7 +248,7 @@ func (m *blockBuilder) BuildBlock() (snowman.Block, error) { func (m *blockBuilder) ResetTimer() { // If there is a pending transaction trigger building of a block with that // transaction - if m.HasDecisionTxs() || m.HasAtomicTx() { + if m.HasDecisionTxs() { m.vm.NotifyBlockReady() return } diff --git a/vms/platformvm/cache_internal_state.go b/vms/platformvm/cache_internal_state.go index 4679cd29d763..f5a1a376b35d 100644 --- a/vms/platformvm/cache_internal_state.go +++ b/vms/platformvm/cache_internal_state.go @@ -1017,6 +1017,19 @@ func (st *internalStateImpl) writeCurrentStakers() error { 
delete(nodeUpdates, nodeID) continue } + + if subnetID == constants.PrimaryNetworkID || st.vm.WhitelistedSubnets.Contains(subnetID) { + var err error + if nodeDiff.Decrease { + err = st.vm.Validators.RemoveWeight(subnetID, nodeID, nodeDiff.Amount) + } else { + err = st.vm.Validators.AddWeight(subnetID, nodeID, nodeDiff.Amount) + } + if err != nil { + return err + } + } + nodeDiffBytes, err := GenesisCodec.Marshal(CodecVersion, nodeDiff) if err != nil { return err @@ -1030,6 +1043,15 @@ func (st *internalStateImpl) writeCurrentStakers() error { } st.validatorDiffsCache.Put(string(prefixBytes), nodeUpdates) } + + // Attempt to update the stake metrics + primaryValidators, ok := st.vm.Validators.GetValidators(constants.PrimaryNetworkID) + if !ok { + return nil + } + weight, _ := primaryValidators.GetWeight(st.vm.ctx.NodeID) + st.vm.localStake.Set(float64(weight)) + st.vm.totalStake.Set(float64(primaryValidators.Weight())) return nil } diff --git a/vms/platformvm/commit_block.go b/vms/platformvm/commit_block.go index b366cec6b6f9..30389037b604 100644 --- a/vms/platformvm/commit_block.go +++ b/vms/platformvm/commit_block.go @@ -43,10 +43,6 @@ func (c *CommitBlock) Verify() error { blkID := c.ID() if err := c.DoubleDecisionBlock.Verify(); err != nil { - c.vm.ctx.Log.Trace("rejecting block %s due to a failed verification: %s", blkID, err) - if err := c.Reject(); err != nil { - c.vm.ctx.Log.Error("failed to reject commit block %s due to %s", blkID, err) - } return err } @@ -58,14 +54,10 @@ func (c *CommitBlock) Verify() error { // The parent of a Commit block should always be a proposal parent, ok := parentIntf.(*ProposalBlock) if !ok { - c.vm.ctx.Log.Trace("rejecting block %s due to an incorrect parent type", blkID) - if err := c.Reject(); err != nil { - c.vm.ctx.Log.Error("failed to reject commit block %s due to %s", blkID, err) - } return errInvalidBlockType } - c.onAcceptState, c.onAcceptFunc = parent.onCommit() + c.onAcceptState = parent.onCommitState c.timestamp = 
c.onAcceptState.GetTimestamp() c.vm.currentBlocks[blkID] = c diff --git a/vms/platformvm/create_chain_tx.go b/vms/platformvm/create_chain_tx.go index fc706d7c82b1..dacf79f9fd14 100644 --- a/vms/platformvm/create_chain_tx.go +++ b/vms/platformvm/create_chain_tx.go @@ -112,15 +112,15 @@ func (tx *UnsignedCreateChainTx) Execute( stx *Tx, ) ( func() error, - TxError, + error, ) { // Make sure this transaction is well formed. if len(stx.Creds) == 0 { - return nil, permError{errWrongNumberOfCredentials} + return nil, errWrongNumberOfCredentials } if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, permError{err} + return nil, err } // Select the credentials for each purpose @@ -137,24 +137,20 @@ func (tx *UnsignedCreateChainTx) Execute( subnetIntf, _, err := vs.GetTx(tx.SubnetID) if err == database.ErrNotFound { - return nil, permError{ - fmt.Errorf("%s isn't a known subnet", tx.SubnetID), - } + return nil, fmt.Errorf("%s isn't a known subnet", tx.SubnetID) } if err != nil { - return nil, tempError{err} + return nil, err } subnet, ok := subnetIntf.UnsignedTx.(*UnsignedCreateSubnetTx) if !ok { - return nil, permError{ - fmt.Errorf("%s isn't a subnet", tx.SubnetID), - } + return nil, fmt.Errorf("%s isn't a subnet", tx.SubnetID) } // Verify that this chain is authorized by the subnet if err := vm.fx.VerifyPermission(tx, tx.SubnetAuth, subnetCred, subnet.Owner); err != nil { - return nil, permError{err} + return nil, err } // Consume the UTXOS diff --git a/vms/platformvm/create_subnet_tx.go b/vms/platformvm/create_subnet_tx.go index 07519c9606bd..1b4e7b640112 100644 --- a/vms/platformvm/create_subnet_tx.go +++ b/vms/platformvm/create_subnet_tx.go @@ -78,11 +78,11 @@ func (tx *UnsignedCreateSubnetTx) Execute( stx *Tx, ) ( func() error, - TxError, + error, ) { // Make sure this transaction is well formed. 
if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, permError{err} + return nil, err } // Verify the flowcheck diff --git a/vms/platformvm/error.go b/vms/platformvm/error.go deleted file mode 100644 index ab3c85e16a55..000000000000 --- a/vms/platformvm/error.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package platformvm - -var ( - _ TxError = &tempError{} - _ TxError = &permError{} -) - -// TxError provides the ability for errors to be distinguished as permanent or -// temporary -type TxError interface { - error - Unwrap() error - Temporary() bool -} - -type tempError struct{ error } - -func (e tempError) Unwrap() error { return e.error } -func (tempError) Temporary() bool { return true } - -type permError struct{ error } - -func (e permError) Unwrap() error { return e.error } -func (permError) Temporary() bool { return false } diff --git a/vms/platformvm/export_tx.go b/vms/platformvm/export_tx.go index a9851d81076b..b2fdcebaf4b7 100644 --- a/vms/platformvm/export_tx.go +++ b/vms/platformvm/export_tx.go @@ -14,6 +14,7 @@ import ( "github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/vms/components/avax" + "github.com/ava-labs/avalanchego/vms/components/verify" "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) @@ -21,7 +22,6 @@ var ( errNoExportOutputs = errors.New("no export outputs") errOutputsNotSorted = errors.New("outputs not sorted") errOverflowExport = errors.New("overflow when computing export amount + txFee") - errWrongChainID = errors.New("tx has wrong chain ID") _ UnsignedAtomicTx = &UnsignedExportTx{} ) @@ -93,9 +93,9 @@ func (tx *UnsignedExportTx) Execute( vm *VM, vs VersionedState, stx *Tx, -) (func() error, TxError) { +) (func() error, error) { if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, permError{err} + return nil, err } outs := 
make([]*avax.TransferableOutput, len(tx.Outs)+len(tx.ExportedOutputs)) @@ -103,23 +103,14 @@ func (tx *UnsignedExportTx) Execute( copy(outs[len(tx.Outs):], tx.ExportedOutputs) if vm.bootstrapped.GetValue() { - if err := vm.isValidCrossChainID(vs, tx.DestinationChain); err != nil { + if err := verify.SameSubnet(vm.ctx, tx.DestinationChain); err != nil { return nil, err } } // Verify the flowcheck if err := vm.semanticVerifySpend(vs, tx, tx.Ins, outs, stx.Creds, vm.TxFee, vm.ctx.AVAXAssetID); err != nil { - switch err.(type) { - case permError: - return nil, permError{ - fmt.Errorf("failed semanticVerifySpend: %w", err), - } - default: - return nil, tempError{ - fmt.Errorf("failed semanticVerifySpend: %w", err), - } - } + return nil, fmt.Errorf("failed semanticVerifySpend: %w", err) } // Consume the UTXOS @@ -168,7 +159,7 @@ func (tx *UnsignedExportTx) AtomicExecute( vm *VM, parentState MutableState, stx *Tx, -) (VersionedState, TxError) { +) (VersionedState, error) { // Set up the state if this tx is committed newState := newVersionedState( parentState, diff --git a/vms/platformvm/export_tx_test.go b/vms/platformvm/export_tx_test.go index aeccc2e9a939..a359c62dda37 100644 --- a/vms/platformvm/export_tx_test.go +++ b/vms/platformvm/export_tx_test.go @@ -44,15 +44,7 @@ func TestNewExportTx(t *testing.T) { shouldVerify: true, }, { - description: "P->C export before AP5", - destinationChainID: cChainID, - sourceKeys: []*crypto.PrivateKeySECP256K1R{sourceKey}, - timestamp: defaultValidateStartTime, - shouldErr: false, - shouldVerify: false, - }, - { - description: "P->C export after AP5", + description: "P->C export", destinationChainID: cChainID, sourceKeys: []*crypto.PrivateKeySECP256K1R{sourceKey}, timestamp: vm.ApricotPhase5Time, diff --git a/vms/platformvm/import_tx.go b/vms/platformvm/import_tx.go index 5be1ed0688dd..f286f3516a35 100644 --- a/vms/platformvm/import_tx.go +++ b/vms/platformvm/import_tx.go @@ -14,6 +14,7 @@ import ( 
"github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/vms/components/avax" + "github.com/ava-labs/avalanchego/vms/components/verify" "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) @@ -102,24 +103,22 @@ func (tx *UnsignedImportTx) Execute( vm *VM, vs VersionedState, stx *Tx, -) (func() error, TxError) { +) (func() error, error) { if err := tx.SyntacticVerify(vm.ctx); err != nil { - return nil, permError{err} + return nil, err } utxos := make([]*avax.UTXO, len(tx.Ins)+len(tx.ImportedInputs)) for index, input := range tx.Ins { utxo, err := vs.GetUTXO(input.InputID()) if err != nil { - return nil, tempError{ - fmt.Errorf("failed to get UTXO %s: %w", &input.UTXOID, err), - } + return nil, fmt.Errorf("failed to get UTXO %s: %w", &input.UTXOID, err) } utxos[index] = utxo } if vm.bootstrapped.GetValue() { - if err := vm.isValidCrossChainID(vs, tx.SourceChain); err != nil { + if err := verify.SameSubnet(vm.ctx, tx.SourceChain); err != nil { return nil, err } @@ -130,17 +129,13 @@ func (tx *UnsignedImportTx) Execute( } allUTXOBytes, err := vm.ctx.SharedMemory.Get(tx.SourceChain, utxoIDs) if err != nil { - return nil, tempError{ - fmt.Errorf("failed to get shared memory: %w", err), - } + return nil, fmt.Errorf("failed to get shared memory: %w", err) } for i, utxoBytes := range allUTXOBytes { utxo := &avax.UTXO{} if _, err := Codec.Unmarshal(utxoBytes, utxo); err != nil { - return nil, tempError{ - fmt.Errorf("failed to unmarshal UTXO: %w", err), - } + return nil, fmt.Errorf("failed to unmarshal UTXO: %w", err) } utxos[i+len(tx.Ins)] = utxo } @@ -177,7 +172,7 @@ func (tx *UnsignedImportTx) AtomicExecute( vm *VM, parentState MutableState, stx *Tx, -) (VersionedState, TxError) { +) (VersionedState, error) { // Set up the state if this tx is committed newState := newVersionedState( parentState, diff --git a/vms/platformvm/import_tx_test.go b/vms/platformvm/import_tx_test.go index 
a39ed3437c45..225fce2995c7 100644 --- a/vms/platformvm/import_tx_test.go +++ b/vms/platformvm/import_tx_test.go @@ -112,16 +112,7 @@ func TestNewImportTx(t *testing.T) { shouldVerify: true, }, { - description: "attempting to import from C-chain before AP5", - sourceChainID: cChainID, - sharedMemory: fundedSharedMemory(cChainID, vm.TxFee), - sourceKeys: []*crypto.PrivateKeySECP256K1R{sourceKey}, - timestamp: defaultValidateStartTime, - shouldErr: false, - shouldVerify: false, - }, - { - description: "attempting to import from C-chain after AP5", + description: "attempting to import from C-chain", sourceChainID: cChainID, sharedMemory: fundedSharedMemory(cChainID, vm.TxFee), sourceKeys: []*crypto.PrivateKeySECP256K1R{sourceKey}, diff --git a/vms/platformvm/json.go b/vms/platformvm/json.go deleted file mode 100644 index bf5481bf33b5..000000000000 --- a/vms/platformvm/json.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (C) 2019-2021, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. 
- -package platformvm - -import ( - "encoding/json" - "fmt" -) - -type jsonFormatter struct { - obj interface{} -} - -func (f jsonFormatter) String() string { - jsonBytes, err := json.Marshal(f.obj) - if err != nil { - return fmt.Sprintf("error marshalling: %s", err) - } - return string(jsonBytes) -} diff --git a/vms/platformvm/mempool.go b/vms/platformvm/mempool.go index 2e908bca5aad..679846442dfe 100644 --- a/vms/platformvm/mempool.go +++ b/vms/platformvm/mempool.go @@ -39,21 +39,15 @@ type Mempool interface { Get(txID ids.ID) *Tx AddDecisionTx(tx *Tx) - AddAtomicTx(tx *Tx) AddProposalTx(tx *Tx) HasDecisionTxs() bool - HasAtomicTx() bool HasProposalTx() bool RemoveDecisionTxs(txs []*Tx) - RemoveAtomicTx(tx *Tx) - RemoveAtomicTxs(tx []*Tx) RemoveProposalTx(tx *Tx) PopDecisionTxs(numTxs int) []*Tx - PopAtomicTx() *Tx - PopAtomicTxs(numTxs int) []*Tx PopProposalTx() *Tx MarkDropped(txID ids.ID) @@ -67,7 +61,6 @@ type mempool struct { bytesAvailable int unissuedDecisionTxs TxHeap - unissuedAtomicTxs TxHeap unissuedProposalTxs TxHeap unknownTxs prometheus.Counter @@ -95,15 +88,6 @@ func NewMempool(namespace string, registerer prometheus.Registerer) (Mempool, er return nil, err } - unissuedAtomicTxs, err := NewTxHeapWithMetrics( - NewTxHeapByAge(), - fmt.Sprintf("%s_atomic_txs", namespace), - registerer, - ) - if err != nil { - return nil, err - } - unissuedProposalTxs, err := NewTxHeapWithMetrics( NewTxHeapByStartTime(), fmt.Sprintf("%s_proposal_txs", namespace), @@ -127,7 +111,6 @@ func NewMempool(namespace string, registerer prometheus.Registerer) (Mempool, er bytesAvailableMetric: bytesAvailableMetric, bytesAvailable: maxMempoolSize, unissuedDecisionTxs: unissuedDecisionTxs, - unissuedAtomicTxs: unissuedAtomicTxs, unissuedProposalTxs: unissuedProposalTxs, unknownTxs: unknownTxs, droppedTxIDs: &cache.LRU{Size: droppedTxIDsCacheSize}, @@ -153,8 +136,6 @@ func (m *mempool) Add(tx *Tx) error { switch tx.UnsignedTx.(type) { case TimedTx: m.AddProposalTx(tx) - case 
UnsignedAtomicTx: - m.AddAtomicTx(tx) case UnsignedDecisionTx: m.AddDecisionTx(tx) default: @@ -178,9 +159,6 @@ func (m *mempool) Get(txID ids.ID) *Tx { if tx := m.unissuedDecisionTxs.Get(txID); tx != nil { return tx } - if tx := m.unissuedAtomicTxs.Get(txID); tx != nil { - return tx - } return m.unissuedProposalTxs.Get(txID) } @@ -189,11 +167,6 @@ func (m *mempool) AddDecisionTx(tx *Tx) { m.register(tx) } -func (m *mempool) AddAtomicTx(tx *Tx) { - m.unissuedAtomicTxs.Add(tx) - m.register(tx) -} - func (m *mempool) AddProposalTx(tx *Tx) { m.unissuedProposalTxs.Add(tx) m.register(tx) @@ -201,8 +174,6 @@ func (m *mempool) AddProposalTx(tx *Tx) { func (m *mempool) HasDecisionTxs() bool { return m.unissuedDecisionTxs.Len() > 0 } -func (m *mempool) HasAtomicTx() bool { return m.unissuedAtomicTxs.Len() > 0 } - func (m *mempool) HasProposalTx() bool { return m.unissuedProposalTxs.Len() > 0 } func (m *mempool) RemoveDecisionTxs(txs []*Tx) { @@ -214,22 +185,6 @@ func (m *mempool) RemoveDecisionTxs(txs []*Tx) { } } -func (m *mempool) RemoveAtomicTxs(txs []*Tx) { - for _, tx := range txs { - txID := tx.ID() - if m.unissuedAtomicTxs.Remove(txID) != nil { - m.deregister(tx) - } - } -} - -func (m *mempool) RemoveAtomicTx(tx *Tx) { - txID := tx.ID() - if m.unissuedAtomicTxs.Remove(txID) != nil { - m.deregister(tx) - } -} - func (m *mempool) RemoveProposalTx(tx *Tx) { txID := tx.ID() if m.unissuedProposalTxs.Remove(txID) != nil { @@ -251,27 +206,6 @@ func (m *mempool) PopDecisionTxs(numTxs int) []*Tx { return txs } -func (m *mempool) PopAtomicTx() *Tx { - tx := m.unissuedAtomicTxs.RemoveTop() - m.deregister(tx) - return tx -} - -// Pops a batch of atomic txs -func (m *mempool) PopAtomicTxs(numTxs int) []*Tx { - if maxLen := m.unissuedAtomicTxs.Len(); numTxs > maxLen { - numTxs = maxLen - } - - txs := make([]*Tx, numTxs) - for i := range txs { - tx := m.unissuedAtomicTxs.RemoveTop() - m.deregister(tx) - txs[i] = tx - } - return txs -} - func (m *mempool) PopProposalTx() *Tx { tx := 
m.unissuedProposalTxs.RemoveTop() m.deregister(tx) diff --git a/vms/platformvm/proposal_block.go b/vms/platformvm/proposal_block.go index 6f8584b028d9..a67df02e39ea 100644 --- a/vms/platformvm/proposal_block.go +++ b/vms/platformvm/proposal_block.go @@ -32,10 +32,6 @@ type ProposalBlock struct { onCommitState VersionedState // The state that the chain will have if this block's proposal is aborted onAbortState VersionedState - // The function to execute if this block's proposal is committed - onCommitFunc func() error - // The function to execute if this block's proposal is aborted - onAbortFunc func() error } func (pb *ProposalBlock) free() { @@ -103,30 +99,6 @@ func (pb *ProposalBlock) setBaseState() { pb.onAbortState.SetBase(pb.vm.internalState) } -// onCommit should only be called after Verify is called. -// -// returns: -// 1. The state of the chain assuming this proposal is enacted. (That is, if -// this block is accepted and followed by an accepted Commit block.) -// 2. A function to be executed when this block's proposal is committed. This -// function should not write to state. This function should only be called -// after the state has been updated. -func (pb *ProposalBlock) onCommit() (VersionedState, func() error) { - return pb.onCommitState, pb.onCommitFunc -} - -// onAbort should only be called after Verify is called. -// -// returns: -// 1. The state of the chain assuming this proposal is not enacted. (That is, -// if this block is accepted and followed by an accepted Abort block.) -// 2. A function to be executed when this block's proposal is aborted. This -// function should not write to state. This function should only be called -// after the state has been updated. -func (pb *ProposalBlock) onAbort() (VersionedState, func() error) { - return pb.onAbortState, pb.onAbortFunc -} - // Verify this block is valid. // // The parent block must either be a Commit or an Abort block. 
@@ -136,14 +108,6 @@ func (pb *ProposalBlock) Verify() error { blkID := pb.ID() if err := pb.CommonBlock.Verify(); err != nil { - pb.vm.ctx.Log.Trace("rejecting block %s due to a failed verification: %s", blkID, err) - if err := pb.Reject(); err != nil { - pb.vm.ctx.Log.Error( - "failed to reject proposal block %s due to %s", - blkID, - err, - ) - } return err } @@ -160,38 +124,17 @@ func (pb *ProposalBlock) Verify() error { // The parent of a proposal block (ie this block) must be a decision block parent, ok := parentIntf.(decision) if !ok { - pb.vm.ctx.Log.Trace("rejecting block %s due to an incorrect parent type", blkID) - if err := pb.Reject(); err != nil { - pb.vm.ctx.Log.Error( - "failed to reject proposal block %s due to %s", - blkID, - err, - ) - } return errInvalidBlockType } // parentState is the state if this block's parent is accepted parentState := parent.onAccept() - var err TxError - pb.onCommitState, pb.onAbortState, pb.onCommitFunc, pb.onAbortFunc, err = tx.Execute(pb.vm, parentState, &pb.Tx) + var err error + pb.onCommitState, pb.onAbortState, err = tx.Execute(pb.vm, parentState, &pb.Tx) if err != nil { txID := tx.ID() pb.vm.droppedTxCache.Put(txID, err.Error()) // cache tx as dropped - // If this block's transaction proposes to advance the timestamp, the - // transaction may fail verification now but be valid in the future, so - // don't (permanently) mark the block as rejected. 
- if !err.Temporary() { - pb.vm.ctx.Log.Trace("rejecting block %s due to a permanent verification error: %s", blkID, err) - if err := pb.Reject(); err != nil { - pb.vm.ctx.Log.Error( - "failed to reject proposal block %s due to %s", - blkID, - err, - ) - } - } return err } pb.onCommitState.AddTx(&pb.Tx, Committed) diff --git a/vms/platformvm/reward_validator_tx.go b/vms/platformvm/reward_validator_tx.go index 78ceb2489618..882b36dc5cc9 100644 --- a/vms/platformvm/reward_validator_tx.go +++ b/vms/platformvm/reward_validator_tx.go @@ -55,7 +55,7 @@ func (tx *UnsignedRewardValidatorTx) SyntacticVerify(*snow.Context) error { // Attempts to verify this transaction with the provided state. func (tx *UnsignedRewardValidatorTx) SemanticVerify(vm *VM, parentState MutableState, stx *Tx) error { - _, _, _, _, err := tx.Execute(vm, parentState, stx) + _, _, err := tx.Execute(vm, parentState, stx) return err } @@ -72,60 +72,52 @@ func (tx *UnsignedRewardValidatorTx) Execute( ) ( VersionedState, VersionedState, - func() error, - func() error, - TxError, + error, ) { switch { case tx == nil: - return nil, nil, nil, nil, tempError{errNilTx} + return nil, nil, errNilTx case tx.TxID == ids.Empty: - return nil, nil, nil, nil, tempError{errInvalidID} + return nil, nil, errInvalidID case len(stx.Creds) != 0: - return nil, nil, nil, nil, permError{errWrongNumberOfCredentials} + return nil, nil, errWrongNumberOfCredentials } currentStakers := parentState.CurrentStakerChainState() stakerTx, stakerReward, err := currentStakers.GetNextStaker() if err == database.ErrNotFound { - return nil, nil, nil, nil, permError{ - fmt.Errorf("failed to get next staker stop time: %w", err), - } + return nil, nil, fmt.Errorf("failed to get next staker stop time: %w", err) } if err != nil { - return nil, nil, nil, nil, tempError{err} + return nil, nil, err } stakerID := stakerTx.ID() if stakerID != tx.TxID { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "attempting to remove TxID: %s. 
Should be removing %s", - tx.TxID, - stakerID, - ), - } + return nil, nil, fmt.Errorf( + "attempting to remove TxID: %s. Should be removing %s", + tx.TxID, + stakerID, + ) } // Verify that the chain's timestamp is the validator's end time currentTime := parentState.GetTimestamp() staker, ok := stakerTx.UnsignedTx.(TimedTx) if !ok { - return nil, nil, nil, nil, permError{errWrongTxType} + return nil, nil, errWrongTxType } if endTime := staker.EndTime(); !endTime.Equal(currentTime) { - return nil, nil, nil, nil, permError{ - fmt.Errorf( - "attempting to remove TxID: %s before their end time %s", - tx.TxID, - endTime, - ), - } + return nil, nil, fmt.Errorf( + "attempting to remove TxID: %s before their end time %s", + tx.TxID, + endTime, + ) } newlyCurrentStakers, err := currentStakers.DeleteNextStaker() if err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } pendingStakers := parentState.PendingStakerChainState() @@ -136,7 +128,7 @@ func (tx *UnsignedRewardValidatorTx) Execute( currentSupply := onAbortState.GetCurrentSupply() newSupply, err := math.Sub64(currentSupply, stakerReward) if err != nil { - return nil, nil, nil, nil, permError{err} + return nil, nil, err } onAbortState.SetCurrentSupply(newSupply) @@ -164,13 +156,11 @@ func (tx *UnsignedRewardValidatorTx) Execute( if stakerReward > 0 { outIntf, err := vm.fx.CreateOutput(stakerReward, uStakerTx.RewardsOwner) if err != nil { - return nil, nil, nil, nil, permError{ - fmt.Errorf("failed to create output: %w", err), - } + return nil, nil, fmt.Errorf("failed to create output: %w", err) } out, ok := outIntf.(verify.State) if !ok { - return nil, nil, nil, nil, permError{errInvalidState} + return nil, nil, errInvalidState } utxo := &avax.UTXO{ @@ -208,13 +198,11 @@ func (tx *UnsignedRewardValidatorTx) Execute( // are delgated to. 
vdr, err := currentStakers.GetValidator(uStakerTx.Validator.NodeID) if err != nil { - return nil, nil, nil, nil, tempError{ - fmt.Errorf( - "failed to get whether %s is a validator: %w", - uStakerTx.Validator.NodeID, - err, - ), - } + return nil, nil, fmt.Errorf( + "failed to get whether %s is a validator: %w", + uStakerTx.Validator.NodeID, + err, + ) } vdrTx := vdr.AddValidatorTx() @@ -234,13 +222,11 @@ func (tx *UnsignedRewardValidatorTx) Execute( if delegatorReward > 0 { outIntf, err := vm.fx.CreateOutput(delegatorReward, uStakerTx.RewardsOwner) if err != nil { - return nil, nil, nil, nil, permError{ - fmt.Errorf("failed to create output: %w", err), - } + return nil, nil, fmt.Errorf("failed to create output: %w", err) } out, ok := outIntf.(verify.State) if !ok { - return nil, nil, nil, nil, permError{errInvalidState} + return nil, nil, errInvalidState } utxo := &avax.UTXO{ UTXOID: avax.UTXOID{ @@ -261,13 +247,11 @@ func (tx *UnsignedRewardValidatorTx) Execute( if delegateeReward > 0 { outIntf, err := vm.fx.CreateOutput(delegateeReward, vdrTx.RewardsOwner) if err != nil { - return nil, nil, nil, nil, permError{ - fmt.Errorf("failed to create output: %w", err), - } + return nil, nil, fmt.Errorf("failed to create output: %w", err) } out, ok := outIntf.(verify.State) if !ok { - return nil, nil, nil, nil, permError{errInvalidState} + return nil, nil, errInvalidState } utxo := &avax.UTXO{ UTXOID: avax.UTXOID{ @@ -285,22 +269,16 @@ func (tx *UnsignedRewardValidatorTx) Execute( nodeID = uStakerTx.Validator.ID() startTime = vdrTx.StartTime() default: - return nil, nil, nil, nil, permError{errShouldBeDSValidator} + return nil, nil, errShouldBeDSValidator } uptime, err := vm.uptimeManager.CalculateUptimePercentFrom(nodeID, startTime) if err != nil { - return nil, nil, nil, nil, tempError{ - fmt.Errorf("failed to calculate uptime: %w", err), - } + return nil, nil, fmt.Errorf("failed to calculate uptime: %w", err) } tx.shouldPreferCommit = uptime >= vm.UptimePercentage - // 
Regardless of whether this tx is committed or aborted, update the - // validator set to remove the staker. onAbortDB or onCommitDB should commit - // (flush to vm.DB) before this is called - updateValidators := func() error { return vm.updateValidators(false) } - return onCommitState, onAbortState, updateValidators, updateValidators, nil + return onCommitState, onAbortState, nil } // InitiallyPrefersCommit returns true if this node thinks the validator diff --git a/vms/platformvm/reward_validator_tx_test.go b/vms/platformvm/reward_validator_tx_test.go index 8956f78883e9..4ce02f8c458f 100644 --- a/vms/platformvm/reward_validator_tx_test.go +++ b/vms/platformvm/reward_validator_tx_test.go @@ -15,6 +15,7 @@ import ( "github.com/ava-labs/avalanchego/snow/engine/common" "github.com/ava-labs/avalanchego/snow/uptime" "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto" "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/version" @@ -41,7 +42,7 @@ func TestUnsignedRewardValidatorTxExecuteOnCommit(t *testing.T) { // Case 1: Chain timestamp is wrong if tx, err := vm.newRewardValidatorTx(toRemove.ID()); err != nil { t.Fatal(err) - } else if _, _, _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { t.Fatalf("should have failed because validator end time doesn't match chain timestamp") } @@ -51,7 +52,7 @@ func TestUnsignedRewardValidatorTxExecuteOnCommit(t *testing.T) { // Case 2: Wrong validator if tx, err := vm.newRewardValidatorTx(ids.GenerateTestID()); err != nil { t.Fatal(err) - } else if _, _, _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { t.Fatalf("should have failed because validator ID is wrong") } @@ -61,7 +62,7 @@ func 
TestUnsignedRewardValidatorTxExecuteOnCommit(t *testing.T) { t.Fatal(err) } - onCommitState, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + onCommitState, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) if err != nil { t.Fatal(err) } @@ -89,10 +90,6 @@ func TestUnsignedRewardValidatorTxExecuteOnCommit(t *testing.T) { t.Fatal(err) } - if err := vm.updateValidators(false); err != nil { - t.Fatal(err) - } - onCommitBalance, err := vm.getBalance(stakeOwners) if err != nil { t.Fatal(err) @@ -124,7 +121,7 @@ func TestUnsignedRewardValidatorTxExecuteOnAbort(t *testing.T) { // Case 1: Chain timestamp is wrong if tx, err := vm.newRewardValidatorTx(toRemove.ID()); err != nil { t.Fatal(err) - } else if _, _, _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { t.Fatalf("should have failed because validator end time doesn't match chain timestamp") } @@ -134,7 +131,7 @@ func TestUnsignedRewardValidatorTxExecuteOnAbort(t *testing.T) { // Case 2: Wrong validator if tx, err := vm.newRewardValidatorTx(ids.GenerateTestID()); err != nil { t.Fatal(err) - } else if _, _, _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { + } else if _, _, err := toRemove.Execute(vm, vm.internalState, tx); err == nil { t.Fatalf("should have failed because validator ID is wrong") } @@ -144,7 +141,7 @@ func TestUnsignedRewardValidatorTxExecuteOnAbort(t *testing.T) { t.Fatal(err) } - _, onAbortState, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + _, onAbortState, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) if err != nil { t.Fatal(err) } @@ -172,10 +169,6 @@ func TestUnsignedRewardValidatorTxExecuteOnAbort(t *testing.T) { t.Fatal(err) } - if err := vm.updateValidators(false); err != nil { - t.Fatal(err) - } - onAbortBalance, err := 
vm.getBalance(stakeOwners) if err != nil { t.Fatal(err) @@ -239,11 +232,17 @@ func TestRewardDelegatorTxExecuteOnCommit(t *testing.T) { assert.NoError(err) err = vm.internalState.(*internalStateImpl).loadCurrentValidators() assert.NoError(err) + // test validator stake + set, ok := vm.Validators.GetValidators(constants.PrimaryNetworkID) + assert.True(ok) + stake, ok := set.GetWeight(vdrNodeID) + assert.True(ok) + assert.Equal(vm.MinValidatorStake+vm.MinDelegatorStake, stake) tx, err := vm.newRewardValidatorTx(delTx.ID()) assert.NoError(err) - onCommitState, _, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + onCommitState, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) assert.NoError(err) vdrDestSet := ids.ShortSet{} @@ -278,6 +277,10 @@ func TestRewardDelegatorTxExecuteOnCommit(t *testing.T) { assert.Less(vdrReward, delReward, "the delegator's reward should be greater than the delegatee's because the delegatee's share is 25%") assert.Equal(expectedReward, delReward+vdrReward, "expected total reward to be %d but is %d", expectedReward, delReward+vdrReward) + + stake, ok = set.GetWeight(vdrNodeID) + assert.True(ok) + assert.Equal(vm.MinValidatorStake, stake) } func TestRewardDelegatorTxExecuteOnAbort(t *testing.T) { @@ -338,7 +341,7 @@ func TestRewardDelegatorTxExecuteOnAbort(t *testing.T) { tx, err := vm.newRewardValidatorTx(delTx.ID()) assert.NoError(err) - _, onAbortState, _, _, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) + _, onAbortState, err := tx.UnsignedTx.(UnsignedProposalTx).Execute(vm, vm.internalState, tx) assert.NoError(err) vdrDestSet := ids.ShortSet{} diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 94af3c1c1652..c73a4dddeeb2 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -279,8 +279,8 @@ func TestGetTxStatus(t *testing.T) { t.Fatal(err) } else if block, err := 
service.vm.BuildBlock(); err != nil { t.Fatal(err) - } else if blk, ok := block.(*AtomicBlock); !ok { - t.Fatalf("should be *AtomicBlock but is %T", block) + } else if blk, ok := block.(*StandardBlock); !ok { + t.Fatalf("should be *StandardBlock but is %T", block) } else if err := blk.Verify(); err != nil { t.Fatal(err) } else if err := blk.Accept(); err != nil { diff --git a/vms/platformvm/spend.go b/vms/platformvm/spend.go index b1bb33e231ca..77059eece95c 100644 --- a/vms/platformvm/spend.go +++ b/vms/platformvm/spend.go @@ -332,18 +332,16 @@ func (vm *VM) semanticVerifySpend( creds []verify.Verifiable, feeAmount uint64, feeAssetID ids.ID, -) TxError { +) error { utxos := make([]*avax.UTXO, len(ins)) for index, input := range ins { utxo, err := utxoDB.GetUTXO(input.InputID()) if err != nil { - return tempError{ - fmt.Errorf( - "failed to read consumed UTXO %s due to: %w", - &input.UTXOID, - err, - ), - } + return fmt.Errorf( + "failed to read consumed UTXO %s due to: %w", + &input.UTXOID, + err, + ) } utxos[index] = utxo } @@ -365,18 +363,24 @@ func (vm *VM) semanticVerifySpendUTXOs( creds []verify.Verifiable, feeAmount uint64, feeAssetID ids.ID, -) TxError { +) error { if len(ins) != len(creds) { - return permError{fmt.Errorf("there are %d inputs but %d credentials. Should be same number", - len(ins), len(creds))} + return fmt.Errorf( + "there are %d inputs but %d credentials. Should be same number", + len(ins), + len(creds), + ) } if len(ins) != len(utxos) { - return permError{fmt.Errorf("there are %d inputs but %d utxos. Should be same number", - len(ins), len(utxos))} + return fmt.Errorf( + "there are %d inputs but %d utxos. Should be same number", + len(ins), + len(utxos), + ) } for _, cred := range creds { // Verify credentials are well-formed. 
if err := cred.Verify(); err != nil { - return permError{err} + return err } } @@ -396,10 +400,10 @@ func (vm *VM) semanticVerifySpendUTXOs( utxo := utxos[index] // The UTXO consumed by [input] if assetID := utxo.AssetID(); assetID != feeAssetID { - return permError{errAssetIDMismatch} + return errAssetIDMismatch } if assetID := input.AssetID(); assetID != feeAssetID { - return permError{errAssetIDMismatch} + return errAssetIDMismatch } out := utxo.Out @@ -415,20 +419,18 @@ func (vm *VM) semanticVerifySpendUTXOs( // consumes it, is not locked even though [locktime] hasn't passed. This // is invalid. if inner, ok := in.(*StakeableLockIn); now < locktime && !ok { - return permError{errLockedFundsNotMarkedAsLocked} + return errLockedFundsNotMarkedAsLocked } else if ok { if inner.Locktime != locktime { // This input is locked, but its locktime is wrong - return permError{errWrongLocktime} + return errWrongLocktime } in = inner.TransferableIn } // Verify that this tx's credentials allow [in] to be spent if err := vm.fx.VerifyTransfer(tx, in, creds[index], out); err != nil { - return permError{ - fmt.Errorf("failed to verify transfer: %w", err), - } + return fmt.Errorf("failed to verify transfer: %w", err) } amount := in.Amount() @@ -436,7 +438,7 @@ func (vm *VM) semanticVerifySpendUTXOs( if now >= locktime { newUnlockedConsumed, err := math.Add64(unlockedConsumed, amount) if err != nil { - return permError{err} + return err } unlockedConsumed = newUnlockedConsumed continue @@ -444,14 +446,12 @@ func (vm *VM) semanticVerifySpendUTXOs( owned, ok := out.(Owned) if !ok { - return permError{errUnknownOwners} + return errUnknownOwners } owner := owned.Owners() ownerBytes, err := Codec.Marshal(CodecVersion, owner) if err != nil { - return tempError{ - fmt.Errorf("couldn't marshal owner: %w", err), - } + return fmt.Errorf("couldn't marshal owner: %w", err) } ownerID := hashing.ComputeHash256Array(ownerBytes) owners, ok := lockedConsumed[locktime] @@ -461,14 +461,14 @@ func (vm 
*VM) semanticVerifySpendUTXOs( } newAmount, err := math.Add64(owners[ownerID], amount) if err != nil { - return permError{err} + return err } owners[ownerID] = newAmount } for _, out := range outs { if assetID := out.AssetID(); assetID != feeAssetID { - return permError{errAssetIDMismatch} + return errAssetIDMismatch } output := out.Output() @@ -484,7 +484,7 @@ func (vm *VM) semanticVerifySpendUTXOs( if locktime == 0 { newUnlockedProduced, err := math.Add64(unlockedProduced, amount) if err != nil { - return permError{err} + return err } unlockedProduced = newUnlockedProduced continue @@ -492,14 +492,12 @@ func (vm *VM) semanticVerifySpendUTXOs( owned, ok := output.(Owned) if !ok { - return permError{errUnknownOwners} + return errUnknownOwners } owner := owned.Owners() ownerBytes, err := Codec.Marshal(CodecVersion, owner) if err != nil { - return tempError{ - fmt.Errorf("couldn't marshal owner: %w", err), - } + return fmt.Errorf("couldn't marshal owner: %w", err) } ownerID := hashing.ComputeHash256Array(ownerBytes) owners, ok := lockedProduced[locktime] @@ -509,7 +507,7 @@ func (vm *VM) semanticVerifySpendUTXOs( } newAmount, err := math.Add64(owners[ownerID], amount) if err != nil { - return permError{err} + return err } owners[ownerID] = newAmount } @@ -523,7 +521,13 @@ func (vm *VM) semanticVerifySpendUTXOs( if producedAmount > consumedAmount { increase := producedAmount - consumedAmount if increase > unlockedConsumed { - return permError{fmt.Errorf("address %s produces %d unlocked and consumes %d unlocked for locktime %d", ownerID, increase, unlockedConsumed, locktime)} + return fmt.Errorf( + "address %s produces %d unlocked and consumes %d unlocked for locktime %d", + ownerID, + increase, + unlockedConsumed, + locktime, + ) } unlockedConsumed -= increase } @@ -532,9 +536,12 @@ func (vm *VM) semanticVerifySpendUTXOs( // More unlocked tokens produced than consumed. Invalid. 
if unlockedProduced > unlockedConsumed { - return permError{fmt.Errorf("tx produces more unlocked (%d) than it consumes (%d)", unlockedProduced, unlockedConsumed)} + return fmt.Errorf( + "tx produces more unlocked (%d) than it consumes (%d)", + unlockedProduced, + unlockedConsumed, + ) } - return nil } diff --git a/vms/platformvm/standard_block.go b/vms/platformvm/standard_block.go index e1c31ee17e1b..2d60777b480a 100644 --- a/vms/platformvm/standard_block.go +++ b/vms/platformvm/standard_block.go @@ -69,13 +69,6 @@ func (sb *StandardBlock) Verify() error { blkID := sb.ID() if err := sb.CommonDecisionBlock.Verify(); err != nil { - if err := sb.Reject(); err != nil { - sb.vm.ctx.Log.Error( - "failed to reject standard block %s due to %s", - blkID, - err, - ) - } return err } @@ -88,13 +81,6 @@ func (sb *StandardBlock) Verify() error { // be a decision. parent, ok := parentIntf.(decision) if !ok { - if err := sb.Reject(); err != nil { - sb.vm.ctx.Log.Error( - "failed to reject standard block %s due to %s", - blkID, - err, - ) - } return errInvalidBlockType } @@ -108,9 +94,6 @@ func (sb *StandardBlock) Verify() error { // clear inputs so that multiple [Verify] calls can be made sb.inputs.Clear() - currentTimestamp := parentState.GetTimestamp() - enabledAP5 := !currentTimestamp.Before(sb.vm.ApricotPhase5Time) - funcs := make([]func() error, 0, len(sb.Txs)) for _, tx := range sb.Txs { txID := tx.ID() @@ -120,17 +103,8 @@ func (sb *StandardBlock) Verify() error { return errWrongTxType } - // TODO: remove after AP5. 
- if _, ok := tx.UnsignedTx.(UnsignedAtomicTx); ok && !enabledAP5 { - return fmt.Errorf( - "the chain timestamp (%d) is before the apricot phase 5 time (%d), hence atomic transactions should go through the atomic block", - currentTimestamp.Unix(), - sb.vm.ApricotPhase5Time.Unix(), - ) - } - inputUTXOs := utx.InputUTXOs() - // ensure it doesnt overlap with current input batch + // ensure it doesn't overlap with current input batch if sb.inputs.Overlaps(inputUTXOs) { return errConflictingBatchTxs } @@ -140,13 +114,6 @@ func (sb *StandardBlock) Verify() error { onAccept, err := utx.Execute(sb.vm, sb.onAcceptState, tx) if err != nil { sb.vm.droppedTxCache.Put(txID, err.Error()) // cache tx as dropped - if err := sb.Reject(); err != nil { - sb.vm.ctx.Log.Error( - "failed to reject standard block %s due to %s", - blkID, - err, - ) - } return err } @@ -183,7 +150,6 @@ func (sb *StandardBlock) Verify() error { sb.timestamp = sb.onAcceptState.GetTimestamp() sb.vm.blockBuilder.RemoveDecisionTxs(sb.Txs) - sb.vm.blockBuilder.RemoveAtomicTxs(sb.Txs) sb.vm.currentBlocks[blkID] = sb parentIntf.addChild(sb) return nil diff --git a/vms/platformvm/standard_block_test.go b/vms/platformvm/standard_block_test.go index 1be360d3950a..4299dcbfe96b 100644 --- a/vms/platformvm/standard_block_test.go +++ b/vms/platformvm/standard_block_test.go @@ -80,7 +80,7 @@ func TestAtomicTxImports(t *testing.T) { } vm.internalState.SetTimestamp(vm.ApricotPhase5Time.Add(100 * time.Second)) - vm.mempool.AddAtomicTx(tx) + vm.mempool.AddDecisionTx(tx) b, err := vm.BuildBlock() assert.NoError(err) // Test multiple verify calls work diff --git a/vms/platformvm/tx.go b/vms/platformvm/tx.go index eb422d6b0146..4a0705c5b2fb 100644 --- a/vms/platformvm/tx.go +++ b/vms/platformvm/tx.go @@ -54,7 +54,7 @@ type UnsignedDecisionTx interface { // Execute this transaction with the provided state. 
Execute(vm *VM, vs VersionedState, stx *Tx) ( onAcceptFunc func() error, - err TxError, + err error, ) // To maintain consistency with the Atomic txs @@ -72,9 +72,7 @@ type UnsignedProposalTx interface { Execute(vm *VM, state MutableState, stx *Tx) ( onCommitState VersionedState, onAbortState VersionedState, - onCommitFunc func() error, - onAbortFunc func() error, - err TxError, + err error, ) InitiallyPrefersCommit(vm *VM) bool } @@ -84,7 +82,7 @@ type UnsignedAtomicTx interface { UnsignedDecisionTx // Execute this transaction with the provided state. - AtomicExecute(vm *VM, parentState MutableState, stx *Tx) (VersionedState, TxError) + AtomicExecute(vm *VM, parentState MutableState, stx *Tx) (VersionedState, error) // Accept this transaction with the additionally provided state transitions. AtomicAccept(ctx *snow.Context, batch database.Batch) error diff --git a/vms/platformvm/vm.go b/vms/platformvm/vm.go index 3b49baf3adfc..a46b00a72445 100644 --- a/vms/platformvm/vm.go +++ b/vms/platformvm/vm.go @@ -36,7 +36,6 @@ import ( "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/vms/components/avax" - "github.com/ava-labs/avalanchego/vms/components/verify" "github.com/ava-labs/avalanchego/vms/secp256k1fx" safemath "github.com/ava-labs/avalanchego/utils/math" @@ -139,8 +138,6 @@ type VM struct { // Key: block ID // Value: the block currentBlocks map[ids.ID]Block - - lastVdrUpdate time.Time } // Initialize this blockchain. 
@@ -206,7 +203,7 @@ func (vm *VM) Initialize( vm.uptimeManager = uptime.NewManager(is) vm.UptimeLockedCalculator.SetCalculator(&vm.bootstrapped, &ctx.Lock, vm.uptimeManager) - if err := vm.updateValidators(true); err != nil { + if err := vm.updateValidators(); err != nil { return fmt.Errorf( "failed to initialize validator sets: %w", err, @@ -309,13 +306,8 @@ func (vm *VM) Bootstrapped() error { } vm.bootstrapped.SetValue(true) - errs := wrappers.Errs{} - errs.Add( - vm.updateValidators(false), - vm.fx.Bootstrapped(), - ) - if errs.Errored() { - return errs.Err + if err := vm.fx.Bootstrapped(); err != nil { + return err } primaryValidatorSet, exist := vm.Validators.GetValidators(constants.PrimaryNetworkID) @@ -582,13 +574,7 @@ func (vm *VM) GetCurrentHeight() (uint64, error) { return lastAccepted.Height(), nil } -func (vm *VM) updateValidators(force bool) error { - now := vm.clock.Time() - if !force && !vm.bootstrapped.GetValue() && now.Sub(vm.lastVdrUpdate) < 5*time.Second { - return nil - } - vm.lastVdrUpdate = now - +func (vm *VM) updateValidators() error { currentValidators := vm.internalState.CurrentStakerChainState() primaryValidators, err := currentValidators.ValidatorSet(constants.PrimaryNetworkID) if err != nil { @@ -677,17 +663,3 @@ func (vm *VM) getPercentConnected() (float64, error) { } return float64(connectedStake) / float64(vdrSet.Weight()), nil } - -// TODO: remove after AP5 -func (vm *VM) isValidCrossChainID(vs VersionedState, peerChainID ids.ID) TxError { - currentTimestamp := vs.GetTimestamp() - enabledAP5 := !currentTimestamp.Before(vm.ApricotPhase5Time) - if enabledAP5 { - if err := verify.SameSubnet(vm.ctx, peerChainID); err != nil { - return tempError{err} - } - } else if peerChainID != vm.ctx.XChainID { - return permError{errWrongChainID} - } - return nil -} diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index 9e7adbdc655a..3364172b498d 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -734,12 
+734,9 @@ func TestInvalidAddValidatorCommit(t *testing.T) { if err != nil { t.Fatal(err) } + preferredID := preferred.ID() preferredHeight := preferred.Height() - lastAcceptedID, err := vm.LastAccepted() - if err != nil { - t.Fatal(err) - } - blk, err := vm.newProposalBlock(lastAcceptedID, preferredHeight+1, *tx) + blk, err := vm.newProposalBlock(preferredID, preferredHeight+1, *tx) if err != nil { t.Fatal(err) } @@ -753,20 +750,9 @@ func TestInvalidAddValidatorCommit(t *testing.T) { if err := parsedBlock.Verify(); err == nil { t.Fatalf("Should have errored during verification") } - if status := parsedBlock.Status(); status != choices.Rejected { - t.Fatalf("Should have marked the block as rejected") - } if _, ok := vm.droppedTxCache.Get(blk.Tx.ID()); !ok { t.Fatal("tx should be in dropped tx cache") } - - parsedBlk, err := vm.GetBlock(blk.ID()) - if err != nil { - t.Fatal(err) - } - if status := parsedBlk.Status(); status != choices.Rejected { - t.Fatalf("Should have marked the block as rejected") - } } // Reject proposal to add validator to primary network diff --git a/vms/proposervm/vm_byzantine_test.go b/vms/proposervm/vm_byzantine_test.go index 691b158ec0cc..6000a2070d23 100644 --- a/vms/proposervm/vm_byzantine_test.go +++ b/vms/proposervm/vm_byzantine_test.go @@ -5,6 +5,7 @@ package proposervm import ( "bytes" + "encoding/hex" "testing" "time" @@ -618,3 +619,127 @@ func TestBlockVerify_InvalidPostForkOption(t *testing.T) { t.Fatal(err) } } + +func TestGetBlock_MutatedSignature(t *testing.T) { + coreVM, valState, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) + + // Make sure that we will be sampled to perform the proposals. + valState.GetValidatorSetF = func(height uint64, subnetID ids.ID) (map[ids.ShortID]uint64, error) { + res := make(map[ids.ShortID]uint64) + res[proVM.ctx.NodeID] = uint64(10) + return res, nil + } + + proVM.Set(coreGenBlk.Timestamp()) + + // Create valid core blocks to build our chain on. 
+ coreBlk0 := &snowman.TestBlock{ + TestDecidable: choices.TestDecidable{ + IDV: ids.Empty.Prefix(1111), + StatusV: choices.Processing, + }, + BytesV: []byte{1}, + ParentV: coreGenBlk.ID(), + HeightV: coreGenBlk.Height() + 1, + TimestampV: coreGenBlk.Timestamp(), + } + + coreBlk1 := &snowman.TestBlock{ + TestDecidable: choices.TestDecidable{ + IDV: ids.Empty.Prefix(2222), + StatusV: choices.Processing, + }, + BytesV: []byte{2}, + ParentV: coreBlk0.ID(), + HeightV: coreBlk0.Height() + 1, + TimestampV: coreGenBlk.Timestamp(), + } + + coreVM.GetBlockF = func(blkID ids.ID) (snowman.Block, error) { + switch blkID { + case coreGenBlk.ID(): + return coreGenBlk, nil + case coreBlk0.ID(): + return coreBlk0, nil + case coreBlk1.ID(): + return coreBlk1, nil + default: + return nil, database.ErrNotFound + } + } + coreVM.ParseBlockF = func(b []byte) (snowman.Block, error) { + switch { + case bytes.Equal(b, coreGenBlk.Bytes()): + return coreGenBlk, nil + case bytes.Equal(b, coreBlk0.Bytes()): + return coreBlk0, nil + case bytes.Equal(b, coreBlk1.Bytes()): + return coreBlk1, nil + default: + return nil, errUnknownBlock + } + } + + // Build the first proposal block + coreVM.BuildBlockF = func() (snowman.Block, error) { return coreBlk0, nil } + + builtBlk0, err := proVM.BuildBlock() + if err != nil { + t.Fatalf("could not build post fork block %s", err) + } + + if err := builtBlk0.Verify(); err != nil { + t.Fatalf("failed to verify newly created block %s", err) + } + + if err := proVM.SetPreference(builtBlk0.ID()); err != nil { + t.Fatal(err) + } + + // The second propsal block will need to be signed because the timestamp + // hasn't moved forward + + // Craft what would be the next block, but with an invalid signature: + // ID: 2R3Uz98YmxHUJARWv6suApPdAbbZ7X7ipat1gZuZNNhC5wPwJW + // Valid Bytes: 
000000000000fd81ce4f1ab2650176d46a3d1fbb593af5717a2ada7dabdcef19622325a8ce8400000000000003e800000000000006d0000004a13082049d30820285a003020102020100300d06092a864886f70d01010b050030003020170d3939313233313030303030305a180f32313231313132333130313030305a300030820222300d06092a864886f70d01010105000382020f003082020a0282020100b9c3615c42d501f3b9d21ed127b31855827dbe12652e6e6f278991a3ad1ca55e2241b1cac69a0aeeefdd913db8ae445ff847789fdcbc1cbe6cce0a63109d1c1fb9d441c524a6eb1412f9b8090f1507e3e50a725f9d0a9d5db424ea229a7c11d8b91c73fecbad31c7b216bb2ac5e4d5ff080a80fabc73b34beb8fa46513ab59d489ce3f273c0edab43ded4d4914e081e6e850f9e502c3c4a54afc8a3a89d889aec275b7162a7616d53a61cd3ee466394212e5bef307790100142ad9e0b6c95ad2424c6e84d06411ad066d0c37d4d14125bae22b49ad2a761a09507bbfe43d023696d278d9fbbaf06c4ff677356113d3105e248078c33caed144d85929b1dd994df33c5d3445675104659ca9642c269b5cfa39c7bad5e399e7ebce3b5e6661f989d5f388006ebd90f0e035d533f5662cb925df8744f61289e66517b51b9a2f54792dca9078d5e12bf8ad79e35a68d4d661d15f0d3029d6c5903c845323d5426e49deaa2be2bc261423a9cd77df9a2706afaca27f589cc2c8f53e2a1f90eb5a3f8bcee0769971db6bacaec265d86b39380f69e3e0e06072de986feede26fe856c55e24e88ee5ac342653ac55a04e21b8517310c717dff0e22825c0944c6ba263f8f060099ea6e44a57721c7aa54e2790a4421fb85e3347e4572cba44e62b2cad19c1623c1cab4a715078e56458554cef8442769e6d5dd7f99a6234653a46828804f0203010001a320301e300e0603551d0f0101ff0404030204b0300c0603551d130101ff04023000300d06092a864886f70d01010b050003820201004ee2229d354720a751e2d2821134994f5679997113192626cf61594225cfdf51e6479e2c17e1013ab9dceb713bc0f24649e5cab463a8cf8617816ed736ac5251a853ff35e859ac6853ebb314f967ff7867c53512d42e329659375682c854ca9150cfa4c3964680e7650beb93e8b4a0d6489a9ca0ce0104752ba4d9cf3e2dc9436b56ecd0bd2e33cbbeb5a107ec4fd6f41a943c8bee06c0b32f4291a3e3759a7984d919a97d5d6517b841053df6e795ed33b52ed5e41357c3e431beb725e4e4f2ef956c44fd1f76fa4d847602e491c3585a90cdccfff982405d388b83d6f32ea16da2f5e4595926a7d26078e32992179032d30831b1f1b42de1781c507536a49adb4c95bad04c171911eed30d63
c73712873d1e8094355efb9aeee0c16f8599575fd7f8bb027024bad63b097d2230d8f0ba12a8ed23e618adc3d7cb6a63e02b82a6d4d74b21928dbcb6d3788c6fd45022d69f3ab94d914d97cd651db662e92918a5d891ef730a813f03aade2fe385b61f44840f8925ad3345df1c82c9de882bb7184b4cd0bbd9db8322aaedb4ff86e5be9635987e6c40455ab9b063cdb423bee2edcac47cf654487e9286f33bdbad10018f4db9564cee6e048570e1517a2e396501b5978a53d10a548aed26938c2f9aada3ae62d3fdae486deb9413dffb6524666453633d665c3712d0fec9f844632b2b3eaf0267ca495eb41dba8273862609de00000001020000020098147a41989d8626f63d0966b39376143e45ea6e21b62761a115660d88db9cba37be71d1e1153e7546eb075749122449f2f3f5984e51773f082700d847334da35babe72a66e5a49c9a96cd763bdd94258263ae92d30da65d7c606482d0afe9f4f884f4f6c33d6d8e1c0c71061244ebec6a9dbb9b78bfbb71dec572aa0c0d8e532bf779457e05412b75acf12f35c75917a3eda302aaa27c3090e93bf5de0c3e30968cf8ba025b91962118bbdb6612bf682ba6e87ae6cd1a5034c89559b76af870395dc17ec592e9dbb185633aa1604f8d648f82142a2d1a4dabd91f816b34e73120a70d061e64e6da62ba434fd0cdf7296aa67fd5e0432ef8cee67c1b59aee91c99288c17a8511d96ba7339fb4ae5da453289aa7a9fab00d37035accae24eef0eaf517148e67bdc76adaac2429508d642df3033ad6c9e3fb53057244c1295f2ed3ac66731f77178fccb7cc4fd40778ccb061e5d53cd0669371d8d355a4a733078a9072835b5564a52a50f5db8525d2ee00466124a8d40d9959281b86a789bd0769f3fb0deb89f0eb9cfe036ff8a0011f52ca551c30202f46680acfa656ccf32a4e8a7121ef52442128409dc40d21d61205839170c7b022f573c2cfdaa362df22e708e7572b9b77f4fb20fe56b122bcb003566e20caef289f9d7992c2f1ad0c8366f71e8889390e0d14e2e76c56b515933b0c337ac6bfcf76d33e2ba50cb62eb71 + // Invalid Bytes: 
000000000000fd81ce4f1ab2650176d46a3d1fbb593af5717a2ada7dabdcef19622325a8ce8400000000000003e800000000000006d0000004a13082049d30820285a003020102020100300d06092a864886f70d01010b050030003020170d3939313233313030303030305a180f32313231313132333130313030305a300030820222300d06092a864886f70d01010105000382020f003082020a0282020100b9c3615c42d501f3b9d21ed127b31855827dbe12652e6e6f278991a3ad1ca55e2241b1cac69a0aeeefdd913db8ae445ff847789fdcbc1cbe6cce0a63109d1c1fb9d441c524a6eb1412f9b8090f1507e3e50a725f9d0a9d5db424ea229a7c11d8b91c73fecbad31c7b216bb2ac5e4d5ff080a80fabc73b34beb8fa46513ab59d489ce3f273c0edab43ded4d4914e081e6e850f9e502c3c4a54afc8a3a89d889aec275b7162a7616d53a61cd3ee466394212e5bef307790100142ad9e0b6c95ad2424c6e84d06411ad066d0c37d4d14125bae22b49ad2a761a09507bbfe43d023696d278d9fbbaf06c4ff677356113d3105e248078c33caed144d85929b1dd994df33c5d3445675104659ca9642c269b5cfa39c7bad5e399e7ebce3b5e6661f989d5f388006ebd90f0e035d533f5662cb925df8744f61289e66517b51b9a2f54792dca9078d5e12bf8ad79e35a68d4d661d15f0d3029d6c5903c845323d5426e49deaa2be2bc261423a9cd77df9a2706afaca27f589cc2c8f53e2a1f90eb5a3f8bcee0769971db6bacaec265d86b39380f69e3e0e06072de986feede26fe856c55e24e88ee5ac342653ac55a04e21b8517310c717dff0e22825c0944c6ba263f8f060099ea6e44a57721c7aa54e2790a4421fb85e3347e4572cba44e62b2cad19c1623c1cab4a715078e56458554cef8442769e6d5dd7f99a6234653a46828804f0203010001a320301e300e0603551d0f0101ff0404030204b0300c0603551d130101ff04023000300d06092a864886f70d01010b050003820201004ee2229d354720a751e2d2821134994f5679997113192626cf61594225cfdf51e6479e2c17e1013ab9dceb713bc0f24649e5cab463a8cf8617816ed736ac5251a853ff35e859ac6853ebb314f967ff7867c53512d42e329659375682c854ca9150cfa4c3964680e7650beb93e8b4a0d6489a9ca0ce0104752ba4d9cf3e2dc9436b56ecd0bd2e33cbbeb5a107ec4fd6f41a943c8bee06c0b32f4291a3e3759a7984d919a97d5d6517b841053df6e795ed33b52ed5e41357c3e431beb725e4e4f2ef956c44fd1f76fa4d847602e491c3585a90cdccfff982405d388b83d6f32ea16da2f5e4595926a7d26078e32992179032d30831b1f1b42de1781c507536a49adb4c95bad04c171911eed30d63
c73712873d1e8094355efb9aeee0c16f8599575fd7f8bb027024bad63b097d2230d8f0ba12a8ed23e618adc3d7cb6a63e02b82a6d4d74b21928dbcb6d3788c6fd45022d69f3ab94d914d97cd651db662e92918a5d891ef730a813f03aade2fe385b61f44840f8925ad3345df1c82c9de882bb7184b4cd0bbd9db8322aaedb4ff86e5be9635987e6c40455ab9b063cdb423bee2edcac47cf654487e9286f33bdbad10018f4db9564cee6e048570e1517a2e396501b5978a53d10a548aed26938c2f9aada3ae62d3fdae486deb9413dffb6524666453633d665c3712d0fec9f844632b2b3eaf0267ca495eb41dba8273862609de00000001020000000101 + invalidBlkBytesHex := "000000000000fd81ce4f1ab2650176d46a3d1fbb593af5717a2ada7dabdcef19622325a8ce8400000000000003e800000000000006d0000004a13082049d30820285a003020102020100300d06092a864886f70d01010b050030003020170d3939313233313030303030305a180f32313231313132333130313030305a300030820222300d06092a864886f70d01010105000382020f003082020a0282020100b9c3615c42d501f3b9d21ed127b31855827dbe12652e6e6f278991a3ad1ca55e2241b1cac69a0aeeefdd913db8ae445ff847789fdcbc1cbe6cce0a63109d1c1fb9d441c524a6eb1412f9b8090f1507e3e50a725f9d0a9d5db424ea229a7c11d8b91c73fecbad31c7b216bb2ac5e4d5ff080a80fabc73b34beb8fa46513ab59d489ce3f273c0edab43ded4d4914e081e6e850f9e502c3c4a54afc8a3a89d889aec275b7162a7616d53a61cd3ee466394212e5bef307790100142ad9e0b6c95ad2424c6e84d06411ad066d0c37d4d14125bae22b49ad2a761a09507bbfe43d023696d278d9fbbaf06c4ff677356113d3105e248078c33caed144d85929b1dd994df33c5d3445675104659ca9642c269b5cfa39c7bad5e399e7ebce3b5e6661f989d5f388006ebd90f0e035d533f5662cb925df8744f61289e66517b51b9a2f54792dca9078d5e12bf8ad79e35a68d4d661d15f0d3029d6c5903c845323d5426e49deaa2be2bc261423a9cd77df9a2706afaca27f589cc2c8f53e2a1f90eb5a3f8bcee0769971db6bacaec265d86b39380f69e3e0e06072de986feede26fe856c55e24e88ee5ac342653ac55a04e21b8517310c717dff0e22825c0944c6ba263f8f060099ea6e44a57721c7aa54e2790a4421fb85e3347e4572cba44e62b2cad19c1623c1cab4a715078e56458554cef8442769e6d5dd7f99a6234653a46828804f0203010001a320301e300e0603551d0f0101ff0404030204b0300c0603551d130101ff04023000300d06092a864886f70d01010b050003820201004ee222
9d354720a751e2d2821134994f5679997113192626cf61594225cfdf51e6479e2c17e1013ab9dceb713bc0f24649e5cab463a8cf8617816ed736ac5251a853ff35e859ac6853ebb314f967ff7867c53512d42e329659375682c854ca9150cfa4c3964680e7650beb93e8b4a0d6489a9ca0ce0104752ba4d9cf3e2dc9436b56ecd0bd2e33cbbeb5a107ec4fd6f41a943c8bee06c0b32f4291a3e3759a7984d919a97d5d6517b841053df6e795ed33b52ed5e41357c3e431beb725e4e4f2ef956c44fd1f76fa4d847602e491c3585a90cdccfff982405d388b83d6f32ea16da2f5e4595926a7d26078e32992179032d30831b1f1b42de1781c507536a49adb4c95bad04c171911eed30d63c73712873d1e8094355efb9aeee0c16f8599575fd7f8bb027024bad63b097d2230d8f0ba12a8ed23e618adc3d7cb6a63e02b82a6d4d74b21928dbcb6d3788c6fd45022d69f3ab94d914d97cd651db662e92918a5d891ef730a813f03aade2fe385b61f44840f8925ad3345df1c82c9de882bb7184b4cd0bbd9db8322aaedb4ff86e5be9635987e6c40455ab9b063cdb423bee2edcac47cf654487e9286f33bdbad10018f4db9564cee6e048570e1517a2e396501b5978a53d10a548aed26938c2f9aada3ae62d3fdae486deb9413dffb6524666453633d665c3712d0fec9f844632b2b3eaf0267ca495eb41dba8273862609de00000001020000000101" + invalidBlkBytes, err := hex.DecodeString(invalidBlkBytesHex) + if err != nil { + t.Fatal(err) + } + + invalidBlk, err := proVM.ParseBlock(invalidBlkBytes) + if err != nil { + // Not being able to parse an invalid block is fine. + t.Skip(err) + } + + if err := invalidBlk.Verify(); err == nil { + t.Fatalf("verified block without valid signature") + } + + // Note that the invalidBlk.ID() is the same as the correct blk ID because + // the signature isn't part of the blk ID. 
+ blkID, err := ids.FromString("2R3Uz98YmxHUJARWv6suApPdAbbZ7X7ipat1gZuZNNhC5wPwJW") + if err != nil { + t.Fatal(err) + } + + if blkID != invalidBlk.ID() { + t.Fatalf("unexpected block ID; expected = %s , got = %s", blkID, invalidBlk.ID()) + } + + // GetBlock shouldn't really be able to succeed, as we don't have a valid + // representation of [blkID] + fetchedBlk, err := proVM.GetBlock(blkID) + if err != nil { + t.Skip(err) + } + + // GetBlock returned, so it must have somehow gotten a valid representation + // of [blkID]. + if err := fetchedBlk.Verify(); err != nil { + t.Fatalf("GetBlock returned an invalid block when the ID represented a potentially valid block: %s", err) + } +} diff --git a/vms/rpcchainvm/vm_server.go b/vms/rpcchainvm/vm_server.go index 2fc9e7a8ad94..661026f618df 100644 --- a/vms/rpcchainvm/vm_server.go +++ b/vms/rpcchainvm/vm_server.go @@ -18,6 +18,7 @@ import ( "github.com/ava-labs/avalanchego/api/metrics" "github.com/ava-labs/avalanchego/chains/atomic/gsharedmemory" "github.com/ava-labs/avalanchego/chains/atomic/gsharedmemory/gsharedmemoryproto" + "github.com/ava-labs/avalanchego/database/corruptabledb" "github.com/ava-labs/avalanchego/database/manager" "github.com/ava-labs/avalanchego/database/rpcdb" "github.com/ava-labs/avalanchego/database/rpcdb/rpcdbproto" @@ -106,9 +107,9 @@ func (vm *VMServer) Initialize(_ context.Context, req *vmproto.InitializeRequest return nil, err } vm.connCloser.Add(dbConn) - + db := rpcdb.NewClient(rpcdbproto.NewDatabaseClient(dbConn)) versionedDBs[i] = &manager.VersionedDatabase{ - Database: rpcdb.NewClient(rpcdbproto.NewDatabaseClient(dbConn)), + Database: corruptabledb.New(db), Version: version, } }