Skip to content

Commit

Permalink
Merge pull request #2384 from OffchainLabs/gligneul/init-checksum
Browse files Browse the repository at this point in the history
[NIT-2467][Configuration Changes] Add option to enable the checksum validation
  • Loading branch information
gligneul committed Jun 17, 2024
2 parents ad9ac77 + 7b9cc29 commit 6319948
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 18 deletions.
3 changes: 3 additions & 0 deletions cmd/conf/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type InitConfig struct {
Url string `koanf:"url"`
Latest string `koanf:"latest"`
LatestBase string `koanf:"latest-base"`
ValidateChecksum bool `koanf:"validate-checksum"`
DownloadPath string `koanf:"download-path"`
DownloadPoll time.Duration `koanf:"download-poll"`
DevInit bool `koanf:"dev-init"`
Expand All @@ -39,6 +40,7 @@ var InitConfigDefault = InitConfig{
Url: "",
Latest: "",
LatestBase: "https://snapshot.arbitrum.foundation/",
ValidateChecksum: true,
DownloadPath: "/tmp/",
DownloadPoll: time.Minute,
DevInit: false,
Expand All @@ -62,6 +64,7 @@ func InitConfigAddOptions(prefix string, f *pflag.FlagSet) {
f.String(prefix+".url", InitConfigDefault.Url, "url to download initialization data - will poll if download fails")
f.String(prefix+".latest", InitConfigDefault.Latest, "if set, searches for the latest snapshot of the given kind "+acceptedSnapshotKindsStr)
f.String(prefix+".latest-base", InitConfigDefault.LatestBase, "base url used when searching for the latest")
f.Bool(prefix+".validate-checksum", InitConfigDefault.ValidateChecksum, "if true: validate the checksum after downloading the snapshot")
f.String(prefix+".download-path", InitConfigDefault.DownloadPath, "path to save temp downloaded file")
f.Duration(prefix+".download-poll", InitConfigDefault.DownloadPoll, "how long to wait between polling attempts")
f.Bool(prefix+".dev-init", InitConfigDefault.DevInit, "init with dev data (1 account with balance) instead of file import")
Expand Down
18 changes: 16 additions & 2 deletions cmd/nitro/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ func downloadInit(ctx context.Context, initConfig *conf.InitConfig) (string, err
return initFile, nil
}
log.Info("Downloading initial database", "url", initConfig.Url)
if !initConfig.ValidateChecksum {
file, err := downloadFile(ctx, initConfig, initConfig.Url, nil)
if err != nil && errors.Is(err, notFoundError) {
return downloadInitInParts(ctx, initConfig)
}
return file, err
}
checksum, err := fetchChecksum(ctx, initConfig.Url+".sha256")
if err != nil {
if errors.Is(err, notFoundError) {
Expand All @@ -100,7 +107,10 @@ func downloadFile(ctx context.Context, initConfig *conf.InitConfig, url string,
if err != nil {
panic(err)
}
req.SetChecksum(sha256.New(), checksum, false)
if checksum != nil {
const deleteOnError = true
req.SetChecksum(sha256.New(), checksum, deleteOnError)
}
resp := grabclient.Do(req.WithContext(ctx))
firstPrintTime := time.Now().Add(time.Second * 2)
updateLoop:
Expand Down Expand Up @@ -235,7 +245,11 @@ func downloadInitInParts(ctx context.Context, initConfig *conf.InitConfig) (stri
for i, partName := range partNames {
log.Info("Downloading database part", "part", partName)
partUrl := url.JoinPath("..", partName).String()
partFile, err := downloadFile(ctx, initConfig, partUrl, checksums[i])
var checksum []byte
if initConfig.ValidateChecksum {
checksum = checksums[i]
}
partFile, err := downloadFile(ctx, initConfig, partUrl, checksum)
if err != nil {
return "", fmt.Errorf("error downloading part \"%s\": %w", partName, err)
}
Expand Down
101 changes: 85 additions & 16 deletions cmd/nitro/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,55 @@ import (
"os"
"path"
"path/filepath"
"strings"
"testing"
"time"

"github.com/offchainlabs/nitro/cmd/conf"
"github.com/offchainlabs/nitro/util/testhelpers"
)

func TestDownloadInit(t *testing.T) {
const (
archiveName = "random_data.tar.gz"
dataSize = 1024 * 1024
filePerm = 0600
)
const (
archiveName = "random_data.tar.gz"
numParts = 3
partSize = 1024 * 1024
dataSize = numParts * partSize
filePerm = 0600
dirPerm = 0700
)

func TestDownloadInitWithoutChecksum(t *testing.T) {
// Create archive with random data
serverDir := t.TempDir()
data := testhelpers.RandomSlice(dataSize)

// Write archive file
archiveFile := fmt.Sprintf("%s/%s", serverDir, archiveName)
err := os.WriteFile(archiveFile, data, filePerm)
Require(t, err, "failed to write archive")

// Start HTTP server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
addr := startFileServer(t, ctx, serverDir)

// Download file
initConfig := conf.InitConfigDefault
initConfig.Url = fmt.Sprintf("http://%s/%s", addr, archiveName)
initConfig.DownloadPath = t.TempDir()
initConfig.ValidateChecksum = false
receivedArchive, err := downloadInit(ctx, &initConfig)
Require(t, err, "failed to download")

// Check archive contents
receivedData, err := os.ReadFile(receivedArchive)
Require(t, err, "failed to read received archive")
if !bytes.Equal(receivedData, data) {
t.Error("downloaded archive is different from generated one")
}
}

func TestDownloadInitWithChecksum(t *testing.T) {
// Create archive with random data
serverDir := t.TempDir()
data := testhelpers.RandomSlice(dataSize)
Expand Down Expand Up @@ -65,15 +100,51 @@ func TestDownloadInit(t *testing.T) {
}
}

func TestDownloadInitInParts(t *testing.T) {
const (
archiveName = "random_data.tar.gz"
numParts = 3
partSize = 1024 * 1024
dataSize = numParts * partSize
filePerm = 0600
)
func TestDownloadInitInPartsWithoutChecksum(t *testing.T) {
// Create parts with random data
serverDir := t.TempDir()
data := testhelpers.RandomSlice(dataSize)
manifest := bytes.NewBuffer(nil)
for i := 0; i < numParts; i++ {
partData := data[partSize*i : partSize*(i+1)]
partName := fmt.Sprintf("%s.part%d", archiveName, i)
fmt.Fprintf(manifest, "%s %s\n", strings.Repeat("0", 64), partName)
err := os.WriteFile(path.Join(serverDir, partName), partData, filePerm)
Require(t, err, "failed to write part")
}
manifestFile := fmt.Sprintf("%s/%s.manifest.txt", serverDir, archiveName)
err := os.WriteFile(manifestFile, manifest.Bytes(), filePerm)
Require(t, err, "failed to write manifest file")

// Start HTTP server
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
addr := startFileServer(t, ctx, serverDir)

// Download file
initConfig := conf.InitConfigDefault
initConfig.Url = fmt.Sprintf("http://%s/%s", addr, archiveName)
initConfig.DownloadPath = t.TempDir()
initConfig.ValidateChecksum = false
receivedArchive, err := downloadInit(ctx, &initConfig)
Require(t, err, "failed to download")

// check database contents
receivedData, err := os.ReadFile(receivedArchive)
Require(t, err, "failed to read received archive")
if !bytes.Equal(receivedData, data) {
t.Error("downloaded archive is different from generated one")
}

// Check if the function deleted the temporary files
entries, err := os.ReadDir(initConfig.DownloadPath)
Require(t, err, "failed to read temp dir")
if len(entries) != 1 {
t.Error("download function did not delete temp files")
}
}

func TestDownloadInitInPartsWithChecksum(t *testing.T) {
// Create parts with random data
serverDir := t.TempDir()
data := testhelpers.RandomSlice(dataSize)
Expand Down Expand Up @@ -126,8 +197,6 @@ func TestSetLatestSnapshotUrl(t *testing.T) {
snapshotKind = "archive"
latestDate = "2024/21"
latestFile = "latest-" + snapshotKind + ".txt"
dirPerm = 0700
filePerm = 0600
)

// Create latest file
Expand Down

0 comments on commit 6319948

Please sign in to comment.