Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream policy updates to GRPC clients #556

Merged
merged 2 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/libs/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"bytes"
"crypto/rand"
"flag"
"fmt"
"math/big"
"net"
"os"
"os/exec"
"os/signal"
"reflect"
"sort"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -584,3 +586,56 @@ func ConvertStrToUnixTime(strTime string) int64 {
t, _ := time.Parse(TimeFormSimple, strTime)
return t.UTC().Unix()
}

// IsLabelMapSubset check whether m2 is a subset of m1
func IsLabelMapSubset(m1, m2 types.LabelMap) bool {
match := true
for k, v := range m2 {
if m1[k] != v {
match = false
break
}
}
return match
}

// LabelMapFromLabelArray converts []string to map[string]string
func LabelMapFromLabelArray(labels []string) types.LabelMap {
labelMap := types.LabelMap{}
for _, label := range labels {
kvPair := strings.FieldsFunc(label, labelKVSplitter)
if len(kvPair) != 2 {
continue
}
labelMap[kvPair[0]] = kvPair[1]
}
return labelMap
}

// LabelMapToLabelArray converts map[string]string to sorted []string
func LabelMapToLabelArray(labelMap types.LabelMap) (labels []string) {
for k, v := range labelMap {
labels = append(labels, fmt.Sprintf("%s=%s", k, v))
}

sort.Strings(labels)
return
}

// LabelMapToString converts map[string]string to string
func LabelMapToString(lm types.LabelMap) string {
return strings.Join(LabelMapToLabelArray(lm), ",")
}

// LabelMapFromString converts string to map[string]string
func LabelMapFromString(labels string) types.LabelMap {
return LabelMapFromLabelArray(strings.FieldsFunc(labels, labelArrSplitter))
}

func labelKVSplitter(r rune) bool {
return r == ':' || r == '='
}

func labelArrSplitter(r rune) bool {
return r == ',' || r == ';'
}
180 changes: 180 additions & 0 deletions src/libs/consumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package libs

import (
"sync"

"github.com/accuknox/auto-policy-discovery/src/types"
"google.golang.org/grpc"

dpb "github.com/accuknox/auto-policy-discovery/src/protobuf/v1/discovery"
)

// PolicyConsumer stores filter information provided in v1.Discovery.GetFlow RPC request
type PolicyConsumer struct {
policyType []string
Kind []string
Filter types.PolicyFilter
Events chan *types.PolicyYaml
}

func (pc *PolicyConsumer) IsTypeNetwork() bool {
return ContainsElement(pc.policyType, types.PolicyTypeNetwork)
}

func (pc *PolicyConsumer) IsTypeSystem() bool {
return ContainsElement(pc.policyType, types.PolicyTypeSystem)
}

func NewPolicyConsumer(req *dpb.GetPolicyRequest) *PolicyConsumer {
kind := req.GetKind()
return &PolicyConsumer{
Kind: kind,
policyType: getPolicyTypeFromKind(kind),
Filter: convertGrpcRequestToPolicyFilter(req),
Events: make(chan *types.PolicyYaml, 64),
}
}

func getPolicyTypeFromKind(kind []string) []string {
isTypeNetwork := false
isTypeSystem := false

for _, k := range kind {
switch k {
case types.KindCiliumNetworkPolicy,
types.KindK8sNetworkPolicy,
types.KindCiliumClusterwideNetworkPolicy:
isTypeNetwork = true
case types.KindKubeArmorPolicy,
types.KindKubeArmorHostPolicy:
isTypeSystem = true
}
}

var res []string
if isTypeNetwork {
res = append(res, types.PolicyTypeNetwork)
}
if isTypeSystem {
res = append(res, types.PolicyTypeSystem)
}

return res
}

// PolicyStore is used for support v1.Discovery.GetFlow RPC requests
type PolicyStore struct {
Consumers map[*PolicyConsumer]struct{}
Mutex sync.Mutex
}

// AddConsumer adds a new PolicyConsumer to the store
func (pc *PolicyStore) AddConsumer(c *PolicyConsumer) {
pc.Mutex.Lock()
defer pc.Mutex.Unlock()

pc.Consumers[c] = struct{}{}
return
}

// RemoveConsumer removes a PolicyConsumer from the store
func (pc *PolicyStore) RemoveConsumer(c *PolicyConsumer) {
pc.Mutex.Lock()
defer pc.Mutex.Unlock()

delete(pc.Consumers, c)
}

// Publish converts the given KnoxPolicy to PolicyYaml and pushes to consumer's channels
func (pc *PolicyStore) Publish(policy *types.PolicyYaml) {
pc.Mutex.Lock()
defer pc.Mutex.Unlock()

for consumer := range pc.Consumers {
if matchPolicyYaml(policy, consumer) {
consumer.Events <- policy
}
}
}

func FilterPolicyYamls(policyYamls []types.PolicyYaml, consumer *PolicyConsumer) []types.PolicyYaml {
result := []types.PolicyYaml{}

for i := range policyYamls {
if matchPolicyYaml(&policyYamls[i], consumer) {
result = append(result, policyYamls[i])
}
}

return result
}

func matchPolicyYaml(p *types.PolicyYaml, c *PolicyConsumer) bool {
filter := c.Filter

if filter.Cluster != "" && filter.Cluster != p.Cluster {
return false
}

if filter.Namespace != "" && filter.Cluster != p.Namespace {
return false
}

if len(filter.Labels) != 0 && !IsLabelMapSubset(p.Labels, filter.Labels) {
return false
}

if !ContainsElement(c.Kind, p.Kind) {
return false
}

return true
}

func convertGrpcRequestToPolicyFilter(req *dpb.GetPolicyRequest) types.PolicyFilter {
return types.PolicyFilter{
Cluster: req.GetCluster(),
Namespace: req.GetNamespace(),
Labels: LabelMapFromLabelArray(req.GetLabel()),
}
}

func convertPolicyYamlToGrpcResponse(p *types.PolicyYaml) *dpb.GetPolicyResponse {
return &dpb.GetPolicyResponse{
Kind: p.Kind,
Name: p.Name,
Cluster: p.Cluster,
Namespace: p.Namespace,
Label: LabelMapToLabelArray(p.Labels),
Yaml: p.Yaml,
}
}

func SendPolicyYamlInGrpcStream(stream grpc.ServerStream, policy *types.PolicyYaml) error {
resp := convertPolicyYamlToGrpcResponse(policy)
err := stream.SendMsg(resp)
if err != nil {
log.Error().Msgf("sending network policy yaml in grpc stream failed err=%v", err.Error())
return err
}
return nil
}

func RelayPolicyEventToGrpcStream(stream grpc.ServerStream, consumer *PolicyConsumer) error {
for {
select {
case <-stream.Context().Done():
// client disconnected
return nil
case policy, ok := <-consumer.Events:
if !ok {
// channel closed and all items are consumed
return nil
}
err := SendPolicyYamlInGrpcStream(stream, policy)
if err != nil {
return err
}
}
}
}
29 changes: 26 additions & 3 deletions src/libs/dbHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ func UpdateOutdatedNetworkPolicy(cfg types.ConfigDB, outdatedPolicy string, late
}
}

func UpdateNetworkPolicies(cfg types.ConfigDB, policies []types.KnoxNetworkPolicy) {
for _, policy := range policies {
UpdateNetworkPolicy(cfg, policy)
}
}

func UpdateNetworkPolicy(cfg types.ConfigDB, policy types.KnoxNetworkPolicy) {
if cfg.DBDriver == "mysql" {
if err := UpdateNetworkPolicyToMySQL(cfg, policy); err != nil {
Expand Down Expand Up @@ -349,13 +355,30 @@ func GetPodNames(cfg types.ConfigDB, filter types.ObsPodDetail) ([]string, error
// =============== //
// == Policy DB == //
// =============== //
func GetPolicyYamls(cfg types.ConfigDB, policyType string) ([]types.PolicyYaml, error) {
var err error
var results []types.PolicyYaml

if cfg.DBDriver == "mysql" {
results, err = GetPolicyYamlsMySQL(cfg, policyType)
if err != nil {
return nil, err
}
} else if cfg.DBDriver == "sqlite3" {
results, err = GetPolicyYamlsSQLite(cfg, policyType)
if err != nil {
return nil, err
}
}
return results, nil
}

func UpdateOrInsertPolicies(cfg types.ConfigDB, policies []types.Policy) error {
func UpdateOrInsertPolicyYamls(cfg types.ConfigDB, policies []types.PolicyYaml) error {
var err = errors.New("unknown db driver")
if cfg.DBDriver == "mysql" {
err = UpdateOrInsertPoliciesMySQL(cfg, policies)
err = UpdateOrInsertPolicyYamlsMySQL(cfg, policies)
} else if cfg.DBDriver == "sqlite3" {
err = UpdateOrInsertPoliciesSQLite(cfg, policies)
err = UpdateOrInsertPolicyYamlsSQLite(cfg, policies)
}
return err
}