Skip to content

Commit

Permalink
Fix e2e tests (#25851)
Browse files Browse the repository at this point in the history
* Update downloadPolicies to reflect latest policy download endpoint

* remove ioutil

* Consider other custom policies files
  • Loading branch information
mftoure committed May 24, 2024
1 parent d89b0de commit 451bf20
Showing 1 changed file with 70 additions and 3 deletions.
73 changes: 70 additions & 3 deletions cmd/security-agent/subcommands/runtime/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package runtime

import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"errors"
Expand All @@ -17,6 +19,7 @@ import (
"os"
"path"
"runtime"
"strings"
"time"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -186,6 +189,7 @@ type downloadPolicyCliParams struct {

check bool
outputPath string
source string
}

func downloadPolicyCommands(globalParams *command.GlobalParams) []*cobra.Command {
Expand All @@ -210,6 +214,7 @@ func downloadPolicyCommands(globalParams *command.GlobalParams) []*cobra.Command

downloadPolicyCmd.Flags().BoolVar(&downloadPolicyArgs.check, "check", false, "Check policies after downloading")
downloadPolicyCmd.Flags().StringVar(&downloadPolicyArgs.outputPath, "output-path", "", "Output path for downloaded policies")
downloadPolicyCmd.Flags().StringVar(&downloadPolicyArgs.source, "source", "all", `Specify wether should download the custom, default or all policies. allowed: "all", "default", "custom"`)

return []*cobra.Command{downloadPolicyCmd}
}
Expand Down Expand Up @@ -765,26 +770,88 @@ func downloadPolicy(log log.Component, config config.Component, _ secrets.Compon
if err != nil {
return err
}

// Unzip the downloaded file containing both default and custom policies
resBytes := []byte(res)
reader, err := zip.NewReader(bytes.NewReader(resBytes), int64(len(resBytes)))
if err != nil {
return err
}

var defaultPolicy []byte
var customPolicies []string

for _, file := range reader.File {
if strings.HasSuffix(file.Name, ".policy") {
pf, err := file.Open()
if err != nil {
return err
}
policyData, err := io.ReadAll(pf)
pf.Close()
if err != nil {
return err
}

if file.Name == "default.policy" {
defaultPolicy = policyData
} else {
customPolicies = append(customPolicies, string(policyData))
}
}
}

tempDir, err := os.MkdirTemp("", "policy_check")
if err != nil {
return err
}
defer os.RemoveAll(tempDir)

tempOutputPath := path.Join(tempDir, "check.policy")
if err := os.WriteFile(tempOutputPath, resBytes, 0644); err != nil {
if err := os.WriteFile(path.Join(tempDir, "default.policy"), defaultPolicy, 0644); err != nil {
return err
}
for i, customPolicy := range customPolicies {
if err := os.WriteFile(path.Join(tempDir, fmt.Sprintf("custom%d.policy", i+1)), []byte(customPolicy), 0644); err != nil {
return err
}
}

if downloadPolicyArgs.check {
if err := checkPolicies(log, config, &checkPoliciesCliParams{dir: tempDir}); err != nil {
return err
}
}

_, err = outputWriter.Write(resBytes)
// Extract and merge rules from custom policies
var customRules string
for _, customPolicy := range customPolicies {
customPolicyLines := strings.Split(customPolicy, "\n")
rulesIndex := -1
for i, line := range customPolicyLines {
if strings.TrimSpace(line) == "rules:" {
rulesIndex = i
break
}
}
if rulesIndex != -1 && rulesIndex+1 < len(customPolicyLines) {
customRules += "\n" + strings.Join(customPolicyLines[rulesIndex+1:], "\n")
}
}

// Output depending on user's specification
var outputContent string
switch downloadPolicyArgs.source {
case "all":
outputContent = string(defaultPolicy) + customRules
case "default":
outputContent = string(defaultPolicy)
case "custom":
outputContent = string(customRules)
default:
return errors.New("invalid source specified")
}

_, err = outputWriter.Write([]byte(outputContent))
if err != nil {
return err
}
Expand Down

0 comments on commit 451bf20

Please sign in to comment.