Skip to content

Commit

Permalink
Handle error for deferred close calls (#1326)
Browse files Browse the repository at this point in the history
Handle error for deferred close()
  • Loading branch information
skotambkar committed Jul 29, 2021
1 parent 4b1cbb9 commit c698c9b
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 95 deletions.
14 changes: 14 additions & 0 deletions .changelog/b18b4ffbdf9a45239616f48de1ef2096.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"id": "b18b4ffb-df9a-4523-9616-f48de1ef2096",
"type": "feature",
"collapse": true,
"description": "adds error handling for defered close calls",
"modules": [
".",
"config",
"feature/cloudfront/sign",
"feature/ec2/imds",
"feature/s3/manager",
"internal/ini"
]
}
14 changes: 11 additions & 3 deletions aws/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@ func main() {
}
}

func generateFile(filename string, tmplName string, types ptr.Scalars) error {
func generateFile(filename string, tmplName string, types ptr.Scalars) (err error) {
f, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create %s file, %v", filename, err)
}
defer f.Close()

if err := ptrTmpl.ExecuteTemplate(f, tmplName, types); err != nil {
defer func() {
closeErr := f.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("close error: %v, original error: %w", closeErr, err)
}
}()

if err = ptrTmpl.ExecuteTemplate(f, tmplName, types); err != nil {
return fmt.Errorf("failed to generate %s file, %v", filename, err)
}

Expand Down
12 changes: 10 additions & 2 deletions config/resolve_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func setupCredentialsEndpoints(t *testing.T) (aws.EndpointResolver, func()) {
}
}

func ssoTestSetup() (func(), error) {
func ssoTestSetup() (fn func(), err error) {
dir, err := ioutil.TempDir("", "sso-test")
if err != nil {
return nil, err
Expand All @@ -139,7 +139,15 @@ func ssoTestSetup() (func(), error) {
os.RemoveAll(dir)
return nil, err
}
defer tokenFile.Close()

defer func() {
closeErr := tokenFile.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("close error: %v, original error: %w", closeErr, err)
}
}()

_, err = tokenFile.WriteString(fmt.Sprintf(ssoTokenCacheFile, time.Now().
Add(15*time.Minute).
Expand Down
12 changes: 10 additions & 2 deletions feature/cloudfront/sign/privkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ import (

// LoadPEMPrivKeyFile reads a PEM encoded RSA private key from the file name.
// A new RSA private key will be returned if no error.
func LoadPEMPrivKeyFile(name string) (*rsa.PrivateKey, error) {
func LoadPEMPrivKeyFile(name string) (key *rsa.PrivateKey, err error) {
file, err := os.Open(name)
if err != nil {
return nil, err
}
defer file.Close()

defer func() {
closeErr := file.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("close error: %v, original error: %w", closeErr, err)
}
}()

return LoadPEMPrivKey(file)
}
Expand Down
13 changes: 10 additions & 3 deletions feature/ec2/imds/api_op_GetIAMInfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,22 @@ func buildGetIAMInfoPath(params interface{}) (string, error) {
return getIAMInfoPath, nil
}

func buildGetIAMInfoOutput(resp *smithyhttp.Response) (interface{}, error) {
defer resp.Body.Close()
func buildGetIAMInfoOutput(resp *smithyhttp.Response) (v interface{}, err error) {
defer func() {
closeErr := resp.Body.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("response body close error: %v, original error: %w", closeErr, err)
}
}()

var buff [1024]byte
ringBuffer := smithyio.NewRingBuffer(buff[:])
body := io.TeeReader(resp.Body, ringBuffer)

imdsResult := &GetIAMInfoOutput{}
if err := json.NewDecoder(body).Decode(&imdsResult.IAMInfo); err != nil {
if err = json.NewDecoder(body).Decode(&imdsResult.IAMInfo); err != nil {
return nil, &smithy.DeserializationError{
Err: fmt.Errorf("failed to decode instance identity document, %w", err),
Snapshot: ringBuffer.Bytes(),
Expand Down
13 changes: 10 additions & 3 deletions feature/ec2/imds/api_op_GetInstanceIdentityDocument.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,22 @@ func buildGetInstanceIdentityDocumentPath(params interface{}) (string, error) {
return getInstanceIdentityDocumentPath, nil
}

func buildGetInstanceIdentityDocumentOutput(resp *smithyhttp.Response) (interface{}, error) {
defer resp.Body.Close()
func buildGetInstanceIdentityDocumentOutput(resp *smithyhttp.Response) (v interface{}, err error) {
defer func() {
closeErr := resp.Body.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("response body close error: %v, original error: %w", closeErr, err)
}
}()

var buff [1024]byte
ringBuffer := smithyio.NewRingBuffer(buff[:])
body := io.TeeReader(resp.Body, ringBuffer)

output := &GetInstanceIdentityDocumentOutput{}
if err := json.NewDecoder(body).Decode(&output.InstanceIdentityDocument); err != nil {
if err = json.NewDecoder(body).Decode(&output.InstanceIdentityDocument); err != nil {
return nil, &smithy.DeserializationError{
Err: fmt.Errorf("failed to decode instance identity document, %w", err),
Snapshot: ringBuffer.Bytes(),
Expand Down
13 changes: 10 additions & 3 deletions feature/ec2/imds/api_op_GetToken.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,15 @@ func buildGetTokenPath(interface{}) (string, error) {
return getTokenPath, nil
}

func buildGetTokenOutput(resp *smithyhttp.Response) (interface{}, error) {
defer resp.Body.Close()
func buildGetTokenOutput(resp *smithyhttp.Response) (v interface{}, err error) {
defer func() {
closeErr := resp.Body.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("response body close error: %v, original error: %w", closeErr, err)
}
}()

ttlHeader := resp.Header.Get(tokenTTLHeader)
tokenTTL, err := strconv.ParseInt(ttlHeader, 10, 64)
Expand All @@ -77,7 +84,7 @@ func buildGetTokenOutput(resp *smithyhttp.Response) (interface{}, error) {
}

var token strings.Builder
if _, err := io.Copy(&token, resp.Body); err != nil {
if _, err = io.Copy(&token, resp.Body); err != nil {
return nil, fmt.Errorf("unable to read API token, %w", err)
}

Expand Down
24 changes: 21 additions & 3 deletions feature/s3/manager/upload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,13 @@ func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3Upl
}

func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
defer func() {
closeErr := r.Body.Close()
if closeErr != nil {
failRequest(w, 0, "BodyCloseError",
fmt.Sprintf("request body close error: %v", closeErr))
}
}()

_, hasUploads := r.URL.Query()["uploads"]

Expand Down Expand Up @@ -1091,7 +1097,13 @@ type successPartHandler struct {
}

func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
defer func() {
closeErr := r.Body.Close()
if closeErr != nil {
failRequest(w, 0, "BodyCloseError",
fmt.Sprintf("request body close error: %v", closeErr))
}
}()

n, err := io.Copy(ioutil.Discard, r.Body)
if err != nil {
Expand Down Expand Up @@ -1128,7 +1140,13 @@ type failPartHandler struct {
}

func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
defer func() {
closeErr := r.Body.Close()
if closeErr != nil {
failRequest(w, 0, "BodyCloseError",
fmt.Sprintf("request body close error: %v", closeErr))
}
}()

if h.failsRemaining == 0 && h.successHandler != nil {
h.successHandler.ServeHTTP(w, r)
Expand Down
24 changes: 20 additions & 4 deletions internal/awstesting/custom_ca_bundle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package awstesting

import (
"fmt"
"io/ioutil"
"net"
"net/http"
Expand All @@ -9,12 +10,19 @@ import (
"time"
)

func availableLocalAddr(ip string) (string, error) {
func availableLocalAddr(ip string) (v string, err error) {
l, err := net.Listen("tcp", ip+":0")
if err != nil {
return "", err
}
defer l.Close()
defer func() {
closeErr := l.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("ip listener close error: %v, original error: %w", closeErr, err)
}
}()

return l.Addr().String(), nil
}
Expand Down Expand Up @@ -82,7 +90,7 @@ func CleanupTLSBundleFiles(files ...string) error {
return nil
}

func createTmpFile(b []byte) (string, error) {
func createTmpFile(b []byte) (v string, err error) {
bundleFile, err := ioutil.TempFile(os.TempDir(), "aws-sdk-go-session-test")
if err != nil {
return "", err
Expand All @@ -93,7 +101,15 @@ func createTmpFile(b []byte) (string, error) {
return "", err
}

defer bundleFile.Close()
defer func() {
closeErr := bundleFile.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("file close error: %v, original error: %w", closeErr, err)
}
}()

return bundleFile.Name(), nil
}

Expand Down
19 changes: 14 additions & 5 deletions internal/ini/ini.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package ini

import (
"fmt"
"io"
"os"
)

// OpenFile takes a path to a given file, and will open and parse
// that file.
func OpenFile(path string) (Sections, error) {
f, err := os.Open(path)
if err != nil {
return Sections{}, &UnableToReadFile{Err: err}
func OpenFile(path string) (sections Sections, err error) {
f, oerr := os.Open(path)
if oerr != nil {
return Sections{}, &UnableToReadFile{Err: oerr}
}
defer f.Close()

defer func() {
closeErr := f.Close()
if err == nil {
err = closeErr
} else if closeErr != nil {
err = fmt.Errorf("close error: %v, original error: %w", closeErr, err)
}
}()

return Parse(f, path)
}
Expand Down

0 comments on commit c698c9b

Please sign in to comment.