refactor: implement native bao support
pcfreak30 committed Mar 30, 2024
1 parent ba67d19 commit 8d98f13
Showing 13 changed files with 93 additions and 1,509 deletions.
30 changes: 16 additions & 14 deletions api/s5/s5.go
@@ -473,7 +473,7 @@ func (rsnc readSeekNopCloser) Close() error {
func (s *S5API) smallFileUpload(jc jape.Context) {
user := middleware.GetUserFromContext(jc.Request.Context())

-	file, err := s.prepareFileUpload(jc)
+	file, size, err := s.prepareFileUpload(jc)
if err != nil {
s.sendErrorResponse(jc, err)
return
@@ -485,7 +485,7 @@ func (s *S5API) smallFileUpload(jc jape.Context) {
}
}(file)

-	newUpload, err2 := s.storage.UploadObject(jc.Request.Context(), s5.GetStorageProtocol(s.protocol), file, nil, nil)
+	newUpload, err2 := s.storage.UploadObject(jc.Request.Context(), s5.GetStorageProtocol(s.protocol), file, size, nil, nil)

if err2 != nil {
s.sendErrorResponse(jc, NewS5Error(ErrKeyFileUploadFailed, err2))
@@ -525,33 +525,35 @@ func (s *S5API) smallFileUpload(jc jape.Context) {
})
}

-func (s *S5API) prepareFileUpload(jc jape.Context) (file io.ReadSeekCloser, s5Err *S5Error) {
+func (s *S5API) prepareFileUpload(jc jape.Context) (file io.ReadSeekCloser, size uint64, s5Err *S5Error) {
r := jc.Request
contentType := r.Header.Get("Content-Type")

// Handle multipart form data uploads
if strings.HasPrefix(contentType, "multipart/form-data") {
if err := r.ParseMultipartForm(int64(s.config.Config().Core.PostUploadLimit)); err != nil {
-	return nil, NewS5Error(ErrKeyFileUploadFailed, err)
+	return nil, size, NewS5Error(ErrKeyFileUploadFailed, err)
}

multipartFile, _, err := r.FormFile("file")
if err != nil {
-	return nil, NewS5Error(ErrKeyFileUploadFailed, err)
+	return nil, size, NewS5Error(ErrKeyFileUploadFailed, err)
}

-	return multipartFile, nil
+	return multipartFile, size, nil
}

// Handle raw body uploads
data, err := io.ReadAll(r.Body)
if err != nil {
-	return nil, NewS5Error(ErrKeyFileUploadFailed, err)
+	return nil, size, NewS5Error(ErrKeyFileUploadFailed, err)
}

buffer := readSeekNopCloser{bytes.NewReader(data)}

-	return buffer, nil
+	size = uint64(len(data))
+
+	return buffer, size, nil
}

func (s *S5API) accountRegisterChallenge(jc jape.Context) {
@@ -1275,7 +1277,7 @@ func (s *S5API) pinEntity(ctx context.Context, userId uint, userIp string, cid *

data = append(data, dataCont...)

-	proof, err := s.storage.HashObject(ctx, bytes.NewReader(data))
+	proof, err := s.storage.HashObject(ctx, bytes.NewReader(data), uint64(len(data)))
if err != nil {
return nil, false
}
@@ -1444,7 +1446,7 @@ func (s *S5API) processMultipartFiles(r *http.Request) (map[string]*metadata.Upl
}
}(file)

-	upload, err := s.storage.UploadObject(r.Context(), s5.GetStorageProtocol(s.protocol), file, nil, nil)
+	upload, err := s.storage.UploadObject(r.Context(), s5.GetStorageProtocol(s.protocol), file, uint64(fileHeader.Size), nil, nil)
if err != nil {
return nil, NewS5Error(ErrKeyStorageOperationFailed, err)
}
@@ -1515,7 +1517,7 @@ func (s *S5API) uploadAppMetadata(appData *s5libmetadata.WebAppMetadata, r *http

file := bytes.NewReader(appDataRaw)

-	upload, err := s.storage.UploadObject(r.Context(), s5.GetStorageProtocol(s.protocol), file, nil, nil)
+	upload, err := s.storage.UploadObject(r.Context(), s5.GetStorageProtocol(s.protocol), file, uint64(len(appDataRaw)), nil, nil)
if err != nil {
return "", NewS5Error(ErrKeyStorageOperationFailed, err)
}
@@ -2163,7 +2165,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
return err // Error logged in fetchAndProcess
}

-	hash, err := s.storage.HashObject(ctx, bytes.NewReader(fileData))
+	hash, err := s.storage.HashObject(ctx, bytes.NewReader(fileData), uint64(len(fileData)))
if err != nil {
s.logger.Error("error hashing object", zap.Error(err))
return err
@@ -2178,7 +2180,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
return err
}

-	upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), bytes.NewReader(fileData), nil, hash)
+	upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), bytes.NewReader(fileData), parsedCid.Size, nil, hash)
if err != nil {
return err
}
@@ -2255,7 +2257,7 @@ func (s *S5API) pinImportCronJob(cid string, url string, proofUrl string, userId
return err
}

-	upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), nil, &renter.MultiPartUploadParams{
+	upload, err := s.storage.UploadObject(ctx, s5.GetStorageProtocol(s.protocol), nil, 0, &renter.MultiPartUploadParams{
ReaderFactory: func(start uint, end uint) (io.ReadCloser, error) {
rangeHeader := "bytes=%d-"
if end != 0 {
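Why the new size parameter: the native bao encoder (bao/bao.go below) must know the total payload length up front to allocate its outboard proof buffer, so every UploadObject and HashObject call site now passes an explicit size next to the reader. A minimal sketch of the new call shape — the interface names (StorageService, StorageProtocol) are illustrative stand-ins, not the repository's actual types:

// Sketch only: parameter order inferred from the call sites above.
func uploadBytes(ctx context.Context, store StorageService, proto StorageProtocol, data []byte) error {
	// HashObject now takes the exact length alongside the reader.
	proof, err := store.HashObject(ctx, bytes.NewReader(data), uint64(len(data)))
	if err != nil {
		return err
	}
	// UploadObject likewise: reader, explicit size, optional multipart params, precomputed proof.
	_, err = store.UploadObject(ctx, proto, bytes.NewReader(data), uint64(len(data)), nil, proof)
	return err
}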
150 changes: 54 additions & 96 deletions bao/bao.go
@@ -1,36 +1,27 @@
package bao

import (
-	"bufio"
	"bytes"
	_ "embed"
	"errors"
	"io"
-	"math"
-	"os"
-	"os/exec"
	"time"

	"github.com/samber/lo"

	"go.uber.org/zap"

-	"github.com/docker/go-units"
-	"github.com/hashicorp/go-plugin"
+	"lukechampine.com/blake3/bao"
)
-
-//go:generate buf generate
-//go:generate bash -c "cd rust && cargo build -r"
-//go:embed rust/target/release/rust
-var pluginBin []byte
-
-var bao Bao
-var client *plugin.Client

var _ io.ReadCloser = (*Verifier)(nil)
+var _ io.WriterAt = (*proofWriter)(nil)

var ErrVerifyFailed = errors.New("verification failed")

+const groupLog = 8
+const groupChunks = 1 << groupLog
+
type Verifier struct {
r io.ReadCloser
proof Result
@@ -41,6 +32,12 @@ type Verifier struct {
verifyTime time.Duration
}

+type Result struct {
+	Hash   []byte
+	Proof  []byte
+	Length uint
+}
+
func (v *Verifier) Read(p []byte) (int, error) {
// Initial attempt to read from the buffer
n, err := v.buffer.Read(p)
@@ -52,7 +49,7 @@ func (v *Verifier) Read(p []byte) (int, error) {
return n, err
}

-	buf := make([]byte, VERIFY_CHUNK_SIZE)
+	buf := make([]byte, groupChunks)
// Continue reading from the source and verifying until we have enough data or hit an error
for v.buffer.Len() < len(p)-n {
readStart := time.Now()
@@ -68,7 +65,7 @@ func (v *Verifier) Read(p []byte) (int, error) {
timeStart := time.Now()

if bytesRead > 0 {
-	if status, err := bao.Verify(buf[:bytesRead], v.read, v.proof.Proof, v.proof.Hash); err != nil || !status {
+	if status := bao.VerifyChunk(buf[:bytesRead], v.proof.Proof, groupChunks, v.read, [32]byte(v.proof.Hash)); !status {
return n, errors.Join(ErrVerifyFailed, err)
}
v.read += uint64(bytesRead)
@@ -92,7 +89,7 @@ func (v *Verifier) Read(p []byte) (int, error) {
v.logger.Debug("Read time", zap.Duration("average", averageReadTime))
}

-	averageVerifyTime := v.verifyTime / time.Duration(v.read/VERIFY_CHUNK_SIZE)
+	averageVerifyTime := v.verifyTime / time.Duration(v.read/groupChunks)
v.logger.Debug("Verification time", zap.Duration("average", averageVerifyTime))

// Attempt to read the remainder of the data from the buffer
@@ -103,101 +100,62 @@ func (v *Verifier) Read(p []byte) (int, error) {
func (v *Verifier) Close() error {
return v.r.Close()
}
+func Hash(r io.Reader, size uint64) (*Result, error) {
+	reader := newSizeReader(r)
+	writer := newProofWriter(int(size))

-func init() {
-	temp, err := os.CreateTemp(os.TempDir(), "bao")
+	hash, err := bao.Encode(writer, reader, int64(size), groupLog, true)
if err != nil {
-	panic(err)
+	return nil, err
}

-	err = temp.Chmod(1755)
-	if err != nil {
-	panic(err)
-	}
+	return &Result{
+	Hash:   hash[:],
+	Proof:  writer.buf,
+	Length: uint(size),
+	}, nil
+}

-	_, err = io.Copy(temp, bytes.NewReader(pluginBin))
-	if err != nil {
-	panic(err)
+func NewVerifier(r io.ReadCloser, proof Result, logger *zap.Logger) *Verifier {
+	return &Verifier{
+	r:      r,
+	proof:  proof,
+	buffer: new(bytes.Buffer),
+	logger: logger,
}
+}

-	err = temp.Close()
-	if err != nil {
-	panic(err)
-	}
+type proofWriter struct {
+	buf []byte
+}

-	clientInst := plugin.NewClient(&plugin.ClientConfig{
-	HandshakeConfig: plugin.HandshakeConfig{
-	ProtocolVersion: 1,
-	},
-	Plugins: plugin.PluginSet{
-	"bao": &BaoPlugin{},
-	},
-	Cmd: exec.Command(temp.Name()),
-	AllowedProtocols: []plugin.Protocol{plugin.ProtocolGRPC},
-	})
-
-	rpcClient, err := clientInst.Client()
-	if err != nil {
-	panic(err)
+func (p proofWriter) WriteAt(b []byte, off int64) (n int, err error) {
+	if copy(p.buf[off:], b) != len(b) {
+	panic("bad buffer size")
}
+	return len(b), nil
+}

-	pluginInst, err := rpcClient.Dispense("bao")
-	if err != nil {
-	panic(err)
+func newProofWriter(size int) *proofWriter {
+	return &proofWriter{
+	buf: make([]byte, bao.EncodedSize(size, groupLog, true)),
}
-
-	bao = pluginInst.(Bao)
}

-func Shutdown() {
-	client.Kill()
+type sizeReader struct {
+	reader io.Reader
+	read   int64
}

-func Hash(r io.Reader) (*Result, error) {
-	hasherId := bao.NewHasher()
-	initialSize := 4 * units.KiB
-	maxSize := 3.5 * units.MiB
-	bufSize := initialSize
-
-	reader := bufio.NewReaderSize(r, bufSize)
-	var totalReadSize int
-
-	buf := make([]byte, bufSize)
-	for {
-
-	n, err := reader.Read(buf)
-	if err != nil {
-	if err == io.EOF {
-	break
-	}
-	return nil, err
-	}
-	totalReadSize += n
-
-	if !bao.Hash(hasherId, buf[:n]) {
-	return nil, errors.New("hashing failed")
-	}
-
-	// Adaptively adjust buffer size based on read patterns
-	if n == bufSize && float64(bufSize) < maxSize {
-	// If buffer was fully used, consider increasing buffer size
-	bufSize = int(math.Min(float64(bufSize*2), float64(maxSize))) // Double the buffer size, up to a maximum
-	buf = make([]byte, bufSize) // Apply new buffer size
-	reader = bufio.NewReaderSize(r, bufSize) // Apply new buffer size
-	}
-	}
-
-	result := bao.Finish(hasherId)
-	result.Length = uint(totalReadSize)
-
-	return &result, nil
+func (s sizeReader) Read(p []byte) (int, error) {
+	n, err := s.reader.Read(p)
+	s.read += int64(n)
+	return n, err
}

-func NewVerifier(r io.ReadCloser, proof Result, logger *zap.Logger) *Verifier {
-	return &Verifier{
-	r:      r,
-	proof:  proof,
-	buffer: new(bytes.Buffer),
-	logger: logger,
+func newSizeReader(r io.Reader) *sizeReader {
+	return &sizeReader{
+	reader: r,
+	read:   0,
}
}
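Net effect of this file: hashing no longer shells out to the embedded Rust plugin over hashicorp/go-plugin gRPC; it calls lukechampine.com/blake3/bao in-process. Hash streams the input once through bao.Encode, producing the BLAKE3 root plus an outboard proof whose size is fixed by bao.EncodedSize, and Verifier re-checks the stream against that proof group by group (groupLog = 8, i.e. 256 BLAKE3 chunks per group) via bao.VerifyChunk. A round-trip usage sketch — the import path is illustrative, not the repository's real module path:

package main

import (
	"bytes"
	"io"

	"go.uber.org/zap"

	"example.com/portal/bao" // illustrative import path
)

func main() {
	data := bytes.Repeat([]byte{0xAB}, 1<<20) // 1 MiB of test data

	// Hash needs the exact size up front; it returns the root hash,
	// the outboard proof bytes, and the length it hashed.
	res, err := bao.Hash(bytes.NewReader(data), uint64(len(data)))
	if err != nil {
		panic(err)
	}

	// Verifier wraps the data stream and verifies it against the proof
	// incrementally; a corrupted stream surfaces ErrVerifyFailed.
	v := bao.NewVerifier(io.NopCloser(bytes.NewReader(data)), *res, zap.NewNop())
	defer v.Close()

	if _, err := io.Copy(io.Discard, v); err != nil {
		panic(err)
	}
}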
12 changes: 0 additions & 12 deletions bao/buf.gen.yaml

This file was deleted.

1 change: 0 additions & 1 deletion bao/buf.yaml

This file was deleted.
