Skip to content

Commit

Permalink
Fix round tripper
Browse files Browse the repository at this point in the history
  • Loading branch information
atanasdinov committed Dec 12, 2022
1 parent b96503e commit 9ee719c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 28 deletions.
7 changes: 1 addition & 6 deletions cmd/content-rw-elasticsearch/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,8 @@ func main() {
if err != nil {
log.WithError(err).Fatal("Failed to load AWS config")
}
credentials, err := cfg.Credentials.Retrieve(context.TODO())
if err != nil {
log.WithError(err).Fatal("Failed to obtain AWS credentials values")
}
log.Infof("Obtaining AWS credentials by using [%s] as provider", credentials.Source)

accessConfig.Credentials = credentials
accessConfig.AWSConfig = cfg
}

httpClient := pkghttp.NewHTTPClient()
Expand Down
48 changes: 26 additions & 22 deletions pkg/es/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package es

import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
Expand All @@ -23,15 +24,15 @@ type Client interface {
}

type AccessConfig struct {
Credentials aws.Credentials
Endpoint string
Region string
AWSConfig aws.Config
Endpoint string
Region string
}

type AWSSigningTransport struct {
HTTPClient *http.Client
Credentials aws.Credentials
Region string
HTTPClient *http.Client
AWSConfig aws.Config
Region string
}

func (a AWSSigningTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
Expand All @@ -49,35 +50,38 @@ func (a AWSSigningTransport) RoundTrip(req *http.Request) (resp *http.Response,
}
}()

hasher := sha256.New()
payload := []byte("")
credentials, err := a.AWSConfig.Credentials.Retrieve(req.Context())
if err != nil {
return nil, err
}

if req.Body != nil {
payload, err = io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("reading request body: %w", err)
}
internalReq := req.Clone(req.Context())

defer req.Body.Close()
bodyReader, err := req.GetBody()
if err != nil {
return nil, fmt.Errorf("reading request body: %w", err)
}

hash := hex.EncodeToString(hasher.Sum(payload))
hash := sha256.New()

if _, err := io.Copy(hash, bodyReader); err != nil {
return nil, fmt.Errorf("copying request body: %w", err)
}

err = signer.
if err := signer.
NewSigner().
SignHTTP(req.Context(), a.Credentials, req, hash, "es", a.Region, time.Now().UTC())
if err != nil {
SignHTTP(context.Background(), credentials, internalReq, hex.EncodeToString(hash.Sum(nil)), "es", a.Region, time.Now()); err != nil {
return nil, fmt.Errorf("signing request: %w", err)
}

return a.HTTPClient.Do(req)
return a.HTTPClient.Do(internalReq)
}

func NewClient(config AccessConfig, client *http.Client, log *logger.UPPLogger) (Client, error) {
signingTransport := AWSSigningTransport{
Credentials: config.Credentials,
HTTPClient: client,
Region: config.Region,
AWSConfig: config.AWSConfig,
HTTPClient: client,
Region: config.Region,
}
signingClient := &http.Client{Transport: http.RoundTripper(signingTransport)}

Expand Down

0 comments on commit 9ee719c

Please sign in to comment.