Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 70 additions & 23 deletions internal/services/osinstaller/linux-mac-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,77 @@ import (
"compress/gzip"
"fmt"
"io"
"io/fs"
"log"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"

"github.com/checkmarx/ast-cli/internal/logger"
)

const dirDefault int = 0755
// dirDefault is the permission bits applied to directories created during extraction.
const dirDefault os.FileMode = 0755

// fileNonExec and fileExec are permission bits applied to extracted regular files.
// Only the executable bit is taken from the archive; other bits are normalized.
const fileNonExec os.FileMode = 0644
const fileExec os.FileMode = 0755

// maxExtractBytes caps how many bytes a single tar entry may expand to,
// preventing decompression-bomb (tar-bomb) attacks.
const maxExtractBytes int64 = 500 * 1024 * 1024 // 500 MB

// UnzipOrExtractFiles Extracts all the files from the tar.gz file
func UnzipOrExtractFiles(installationConfiguration *InstallationConfiguration) error {
func UnzipOrExtractFiles(installationConfiguration *InstallationConfiguration) (err error) {
logger.PrintIfVerbose("Extracting files in: " + installationConfiguration.WorkingDir())
filePath := filepath.Join(installationConfiguration.WorkingDir(), installationConfiguration.FileName)

gzipStream, err := os.Open(filePath)
if err != nil {
fmt.Println("error when open file ", filePath, err)
return err
}
defer func() {
if cerr := gzipStream.Close(); cerr != nil && err == nil {
err = cerr
}
}()

uncompressedStream, err := gzip.NewReader(gzipStream)
if err != nil {
log.Println("ExtractTarGz: NewReader failed ", err)
return err
}
defer func() {
if cerr := uncompressedStream.Close(); cerr != nil && err == nil {
err = cerr
}
}()

tarReader := tar.NewReader(uncompressedStream)
return extractFiles(installationConfiguration, tar.NewReader(uncompressedStream))
}

err = extractFiles(installationConfiguration, tarReader)
if err != nil {
return err
// safeJoin validates that name is a relative path and that the resolved
// destination stays inside workingDir, preventing path traversal attacks.
func safeJoin(workingDir, name string) (string, error) {
if name == "" || name == "." {
return "", fmt.Errorf("illegal file path (empty or dot): %s", name)
}
return nil
if filepath.IsAbs(name) {
return "", fmt.Errorf("illegal file path (absolute): %s", name)
}
dst := filepath.Join(workingDir, name)
cleanBase := filepath.Clean(workingDir) + string(os.PathSeparator)
if !strings.HasPrefix(dst, cleanBase) {
return "", fmt.Errorf("illegal file path (traversal): %s", name)
}
return dst, nil
}

func extractFiles(installationConfiguration *InstallationConfiguration, tarReader *tar.Reader) error {
workingDir := installationConfiguration.WorkingDir()
for {
header, err := tarReader.Next()

Expand All @@ -52,36 +86,49 @@ func extractFiles(installationConfiguration *InstallationConfiguration, tarReade
}

if err != nil {
log.Fatalf("ExtractTarGz: Next() failed: %s", err.Error())
return fmt.Errorf("ExtractTarGz: Next() failed: %w", err)
}

switch header.Typeflag {
case tar.TypeDir:
if err := os.Mkdir(header.Name, os.FileMode(dirDefault)); err != nil {
log.Fatalf("ExtractTarGz: Mkdir() failed: %s", err.Error())
dst, err := safeJoin(workingDir, header.Name)
if err != nil {
return err
}
if err := os.MkdirAll(dst, dirDefault); err != nil {
return fmt.Errorf("ExtractTarGz: Mkdir() failed: %w", err)
}

case tar.TypeReg:
extractedFilePath := filepath.Join(installationConfiguration.WorkingDir(), header.Name)
outFile, err := os.Create(extractedFilePath)
extractedFilePath, err := safeJoin(workingDir, header.Name)
if err != nil {
log.Fatalf("ExtractTarGz: Create() failed: %s", err.Error())
return err
}
if _, err = io.Copy(outFile, tarReader); err != nil {
log.Fatalf("ExtractTarGz: Copy() failed: %s", err.Error())
if err := os.MkdirAll(filepath.Dir(extractedFilePath), dirDefault); err != nil {
return fmt.Errorf("ExtractTarGz: MkdirAll() failed: %w", err)
}
err = outFile.Close()
outFile, err := os.Create(extractedFilePath)
if err != nil {
return fmt.Errorf("ExtractTarGz: Create() failed: %w", err)
}
if _, err = io.Copy(outFile, io.LimitReader(tarReader, maxExtractBytes)); err != nil {
_ = outFile.Close()
return fmt.Errorf("ExtractTarGz: Copy() failed: %w", err)
}
if err := outFile.Close(); err != nil {
return err
}
err = os.Chmod(extractedFilePath, fs.ModePerm)
if err != nil {
// Preserve only the executable bit from the archive entry; never grant world-write.
mode := fileNonExec
if header.FileInfo().Mode()&0111 != 0 {
mode = fileExec
}
if err := os.Chmod(extractedFilePath, mode); err != nil {
return err
}

default:
log.Fatalf(
"ExtractTarGz: uknown type: %v in %s",
header.Typeflag,
header.Name)
log.Printf("ExtractTarGz: unknown type: %v in %s", header.Typeflag, header.Name)
}
}
return nil
Expand Down
Loading
Loading