diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml index 3d33034..a6ffa4c 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml @@ -72,8 +72,8 @@ spec: valueFrom: fieldRef: fieldPath: spec.nodeName - - name: HOST_TOKEN_PATH - value: {{ trimSuffix "/" .Values.node.kubeletPath }}/plugins/s3.csi.aws.com/token + - name: HOST_PLUGIN_DIR + value: {{ trimSuffix "/" .Values.node.kubeletPath }}/plugins/s3.csi.aws.com/ {{- with .Values.awsAccessSecret }} - name: AWS_ACCESS_KEY_ID valueFrom: diff --git a/deploy/kubernetes/base/node-daemonset.yaml b/deploy/kubernetes/base/node-daemonset.yaml index a1101ce..f7f3a70 100644 --- a/deploy/kubernetes/base/node-daemonset.yaml +++ b/deploy/kubernetes/base/node-daemonset.yaml @@ -72,8 +72,8 @@ spec: valueFrom: fieldRef: fieldPath: spec.nodeName - - name: HOST_TOKEN_PATH - value: /var/lib/kubelet/plugins/s3.csi.aws.com/token + - name: HOST_PLUGIN_DIR + value: /var/lib/kubelet/plugins/s3.csi.aws.com/ volumeMounts: - name: kubelet-dir mountPath: /var/lib/kubelet diff --git a/go.mod b/go.mod index e1954d1..1673969 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,12 @@ go 1.21 require ( github.com/container-storage-interface/spec v1.9.0 - github.com/coreos/go-systemd/v22 v22.5.0 github.com/godbus/dbus/v5 v5.1.0 github.com/golang/mock v1.6.0 github.com/kubernetes-csi/csi-test v2.2.0+incompatible github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.27.6 - github.com/stretchr/testify v1.8.2 + github.com/onsi/gomega v1.29.0 + github.com/stretchr/testify v1.8.4 google.golang.org/grpc v1.59.0 k8s.io/klog/v2 v2.110.1 k8s.io/mount-utils v0.28.4 @@ -27,17 +26,17 @@ require ( github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.4.0 github.com/moby/sys/mountinfo v0.7.1 // indirect github.com/nxadm/tail v1.4.8 // indirect - golang.org/x/net v0.18.0 - golang.org/x/sys v0.14.0 + golang.org/x/net v0.19.0 + golang.org/x/sys v0.15.0 golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.31.0 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/apimachinery v0.28.4 + k8s.io/apimachinery v0.29.1 k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect ) diff --git a/go.sum b/go.sum index 0e26b08..277bf81 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/container-storage-interface/spec v1.9.0 h1:zKtX4STsq31Knz3gciCYCi1SXtO2HJDecIjDVboYavY= github.com/container-storage-interface/spec v1.9.0/go.mod h1:ZfDu+3ZRyeVqxZM0Ds19MVLkN2d1XJ5MAfi1L3VjlT0= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -14,7 +12,6 @@ github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4 github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -33,8 +30,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= @@ -55,24 +52,20 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= -github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= +github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= +github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= -github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= +github.com/onsi/gomega v1.29.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -86,8 +79,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -105,8 +98,8 @@ golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -116,8 +109,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= -golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -147,11 +140,10 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/apimachinery v0.28.4 h1:zOSJe1mc+GxuMnFzD4Z/U1wst50X28ZNsn5bhgIIao8= -k8s.io/apimachinery v0.28.4/go.mod h1:wI37ncBvfAoswfq626yPTe6Bz1c22L7uaJ8dho83mgg= +k8s.io/apimachinery v0.29.1 h1:KY4/E6km/wLBguvCZv8cKTeOwwOBqFNjwJIdMkMbbRc= +k8s.io/apimachinery v0.29.1/go.mod h1:6HVkd1FwxIagpYrHSwJlQqZI3G9LfYWRPAkUvLnXTKU= k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= k8s.io/mount-utils v0.28.4 h1:5GOZLm2dXi2fr+MKY8hS6kdV5reXrZBiK7848O5MVD0= diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 6c103d7..41796be 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -25,7 +25,6 @@ import ( "os" "time" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" "k8s.io/klog/v2" @@ -51,9 +50,9 @@ var ( type Driver struct { Endpoint string Srv *grpc.Server + NodeID string - NodeID string - Mounter Mounter + NodeServer *S3NodeServer } func NewDriver(endpoint string, mpVersion string, nodeID string) *Driver { @@ -66,9 +65,9 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) *Driver { } return &Driver{ - Endpoint: endpoint, - NodeID: nodeID, - Mounter: mounter, + Endpoint: endpoint, + NodeID: nodeID, + NodeServer: &S3NodeServer{NodeID: nodeID, Mounter: mounter}, } } @@ -81,7 +80,7 @@ func (d *Driver) Run() error { go tokenFileTender(ctx, tokenFile, "/csi/token") } - scheme, addr, err := util.ParseEndpoint(d.Endpoint) + scheme, addr, err := ParseEndpoint(d.Endpoint) if err != nil { return err } @@ -105,7 +104,7 @@ func (d *Driver) Run() error { csi.RegisterIdentityServer(d.Srv, d) csi.RegisterControllerServer(d.Srv, d) - csi.RegisterNodeServer(d.Srv, d) + csi.RegisterNodeServer(d.Srv, d.NodeServer) klog.Infof("Listening for connections on address: %#v", listener.Addr()) return d.Srv.Serve(listener) diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 4b0f4f2..517720d 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -1,38 +1,16 @@ -/* -Copyright 2019 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package driver -import ( - "k8s.io/mount-utils" -) +type FakeMounter struct{} + +func (m *FakeMounter) Mount(bucketName string, target string, + credentials *MountCredentials, options []string) error { + return nil +} -func NewFakeMounter() Mounter { - return &S3Mounter{ - Interface: &mount.FakeMounter{ - MountPoints: []mount.MountPoint{}, - }, - } +func (m *FakeMounter) Unmount(target string) error { + return nil } -// NewFakeDriver creates a new mock driver used for testing -func NewFakeDriver(endpoint string) *Driver { - return &Driver{ - Endpoint: endpoint, - NodeID: "fake_id", - Mounter: NewFakeMounter(), - } +func (m *FakeMounter) IsMountPoint(target string) (bool, error) { + return false, nil } diff --git a/pkg/driver/mocks/mock_mount.go b/pkg/driver/mocks/mock_mount.go index b3f269d..2d3aa4a 100644 --- a/pkg/driver/mocks/mock_mount.go +++ b/pkg/driver/mocks/mock_mount.go @@ -5,218 +5,235 @@ package mock_driver import ( + context "context" + os "os" reflect "reflect" + driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver" + system "github.com/awslabs/aws-s3-csi-driver/pkg/system" gomock "github.com/golang/mock/gomock" mount "k8s.io/mount-utils" ) -// MockMounter is a mock of Mounter interface. -type MockMounter struct { +// MockFs is a mock of Fs interface. +type MockFs struct { ctrl *gomock.Controller - recorder *MockMounterMockRecorder + recorder *MockFsMockRecorder } -// MockMounterMockRecorder is the mock recorder for MockMounter. -type MockMounterMockRecorder struct { - mock *MockMounter +// MockFsMockRecorder is the mock recorder for MockFs. +type MockFsMockRecorder struct { + mock *MockFs } -// NewMockMounter creates a new mock instance. -func NewMockMounter(ctrl *gomock.Controller) *MockMounter { - mock := &MockMounter{ctrl: ctrl} - mock.recorder = &MockMounterMockRecorder{mock} +// NewMockFs creates a new mock instance. +func NewMockFs(ctrl *gomock.Controller) *MockFs { + mock := &MockFs{ctrl: ctrl} + mock.recorder = &MockFsMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMounter) EXPECT() *MockMounterMockRecorder { +func (m *MockFs) EXPECT() *MockFsMockRecorder { return m.recorder } -// CanSafelySkipMountPointCheck mocks base method. -func (m *MockMounter) CanSafelySkipMountPointCheck() bool { +// MkdirAll mocks base method. +func (m *MockFs) MkdirAll(path string, perm os.FileMode) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanSafelySkipMountPointCheck") - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "MkdirAll", path, perm) + ret0, _ := ret[0].(error) return ret0 } -// CanSafelySkipMountPointCheck indicates an expected call of CanSafelySkipMountPointCheck. -func (mr *MockMounterMockRecorder) CanSafelySkipMountPointCheck() *gomock.Call { +// MkdirAll indicates an expected call of MkdirAll. +func (mr *MockFsMockRecorder) MkdirAll(path, perm interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSafelySkipMountPointCheck", reflect.TypeOf((*MockMounter)(nil).CanSafelySkipMountPointCheck)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockFs)(nil).MkdirAll), path, perm) } -// GetMountRefs mocks base method. -func (m *MockMounter) GetMountRefs(pathname string) ([]string, error) { +// Remove mocks base method. +func (m *MockFs) Remove(name string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMountRefs", pathname) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Remove", name) + ret0, _ := ret[0].(error) + return ret0 } -// GetMountRefs indicates an expected call of GetMountRefs. -func (mr *MockMounterMockRecorder) GetMountRefs(pathname interface{}) *gomock.Call { +// Remove indicates an expected call of Remove. +func (mr *MockFsMockRecorder) Remove(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMountRefs", reflect.TypeOf((*MockMounter)(nil).GetMountRefs), pathname) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockFs)(nil).Remove), name) } -// IsCorruptedMnt mocks base method. -func (m *MockMounter) IsCorruptedMnt(err error) bool { +// Stat mocks base method. +func (m *MockFs) Stat(name string) (os.FileInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsCorruptedMnt", err) - ret0, _ := ret[0].(bool) - return ret0 + ret := m.ctrl.Call(m, "Stat", name) + ret0, _ := ret[0].(os.FileInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// IsCorruptedMnt indicates an expected call of IsCorruptedMnt. -func (mr *MockMounterMockRecorder) IsCorruptedMnt(err interface{}) *gomock.Call { +// Stat indicates an expected call of Stat. +func (mr *MockFsMockRecorder) Stat(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCorruptedMnt", reflect.TypeOf((*MockMounter)(nil).IsCorruptedMnt), err) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockFs)(nil).Stat), name) } -// IsLikelyNotMountPoint mocks base method. -func (m *MockMounter) IsLikelyNotMountPoint(file string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsLikelyNotMountPoint", file) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 +// MockMounter is a mock of Mounter interface. +type MockMounter struct { + ctrl *gomock.Controller + recorder *MockMounterMockRecorder } -// IsLikelyNotMountPoint indicates an expected call of IsLikelyNotMountPoint. -func (mr *MockMounterMockRecorder) IsLikelyNotMountPoint(file interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLikelyNotMountPoint", reflect.TypeOf((*MockMounter)(nil).IsLikelyNotMountPoint), file) +// MockMounterMockRecorder is the mock recorder for MockMounter. +type MockMounterMockRecorder struct { + mock *MockMounter } -// IsMountPoint mocks base method. -func (m *MockMounter) IsMountPoint(file string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsMountPoint", file) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 +// NewMockMounter creates a new mock instance. +func NewMockMounter(ctrl *gomock.Controller) *MockMounter { + mock := &MockMounter{ctrl: ctrl} + mock.recorder = &MockMounterMockRecorder{mock} + return mock } -// IsMountPoint indicates an expected call of IsMountPoint. -func (mr *MockMounterMockRecorder) IsMountPoint(file interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMountPoint", reflect.TypeOf((*MockMounter)(nil).IsMountPoint), file) +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMounter) EXPECT() *MockMounterMockRecorder { + return m.recorder } -// List mocks base method. -func (m *MockMounter) List() ([]mount.MountPoint, error) { +// IsMountPoint mocks base method. +func (m *MockMounter) IsMountPoint(target string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List") - ret0, _ := ret[0].([]mount.MountPoint) + ret := m.ctrl.Call(m, "IsMountPoint", target) + ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// List indicates an expected call of List. -func (mr *MockMounterMockRecorder) List() *gomock.Call { +// IsMountPoint indicates an expected call of IsMountPoint. +func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockMounter)(nil).List)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMountPoint", reflect.TypeOf((*MockMounter)(nil).IsMountPoint), target) } -// MakeDir mocks base method. -func (m *MockMounter) MakeDir(pathname string) error { +// Mount mocks base method. +func (m *MockMounter) Mount(bucketName, target string, credentials *driver.MountCredentials, options []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeDir", pathname) + ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, options) ret0, _ := ret[0].(error) return ret0 } -// MakeDir indicates an expected call of MakeDir. -func (mr *MockMounterMockRecorder) MakeDir(pathname interface{}) *gomock.Call { +// Mount indicates an expected call of Mount. +func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, options interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeDir", reflect.TypeOf((*MockMounter)(nil).MakeDir), pathname) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, options) } -// Mount mocks base method. -func (m *MockMounter) Mount(source, target, fstype string, options []string) error { +// Unmount mocks base method. +func (m *MockMounter) Unmount(target string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", source, target, fstype, options) + ret := m.ctrl.Call(m, "Unmount", target) ret0, _ := ret[0].(error) return ret0 } -// Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(source, target, fstype, options interface{}) *gomock.Call { +// Unmount indicates an expected call of Unmount. +func (mr *MockMounterMockRecorder) Unmount(target interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), source, target, fstype, options) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmount", reflect.TypeOf((*MockMounter)(nil).Unmount), target) } -// MountSensitive mocks base method. -func (m *MockMounter) MountSensitive(source, target, fstype string, options, sensitiveOptions []string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MountSensitive", source, target, fstype, options, sensitiveOptions) - ret0, _ := ret[0].(error) - return ret0 +// MockServiceRunner is a mock of ServiceRunner interface. +type MockServiceRunner struct { + ctrl *gomock.Controller + recorder *MockServiceRunnerMockRecorder } -// MountSensitive indicates an expected call of MountSensitive. -func (mr *MockMounterMockRecorder) MountSensitive(source, target, fstype, options, sensitiveOptions interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MountSensitive", reflect.TypeOf((*MockMounter)(nil).MountSensitive), source, target, fstype, options, sensitiveOptions) +// MockServiceRunnerMockRecorder is the mock recorder for MockServiceRunner. +type MockServiceRunnerMockRecorder struct { + mock *MockServiceRunner } -// MountSensitiveWithoutSystemd mocks base method. -func (m *MockMounter) MountSensitiveWithoutSystemd(source, target, fstype string, options, sensitiveOptions []string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MountSensitiveWithoutSystemd", source, target, fstype, options, sensitiveOptions) - ret0, _ := ret[0].(error) - return ret0 +// NewMockServiceRunner creates a new mock instance. +func NewMockServiceRunner(ctrl *gomock.Controller) *MockServiceRunner { + mock := &MockServiceRunner{ctrl: ctrl} + mock.recorder = &MockServiceRunnerMockRecorder{mock} + return mock } -// MountSensitiveWithoutSystemd indicates an expected call of MountSensitiveWithoutSystemd. -func (mr *MockMounterMockRecorder) MountSensitiveWithoutSystemd(source, target, fstype, options, sensitiveOptions interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MountSensitiveWithoutSystemd", reflect.TypeOf((*MockMounter)(nil).MountSensitiveWithoutSystemd), source, target, fstype, options, sensitiveOptions) +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockServiceRunner) EXPECT() *MockServiceRunnerMockRecorder { + return m.recorder } -// MountSensitiveWithoutSystemdWithMountFlags mocks base method. -func (m *MockMounter) MountSensitiveWithoutSystemdWithMountFlags(source, target, fstype string, options, sensitiveOptions, mountFlags []string) error { +// RunOneshot mocks base method. +func (m *MockServiceRunner) RunOneshot(ctx context.Context, config *system.ExecConfig) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MountSensitiveWithoutSystemdWithMountFlags", source, target, fstype, options, sensitiveOptions, mountFlags) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "RunOneshot", ctx, config) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// MountSensitiveWithoutSystemdWithMountFlags indicates an expected call of MountSensitiveWithoutSystemdWithMountFlags. -func (mr *MockMounterMockRecorder) MountSensitiveWithoutSystemdWithMountFlags(source, target, fstype, options, sensitiveOptions, mountFlags interface{}) *gomock.Call { +// RunOneshot indicates an expected call of RunOneshot. +func (mr *MockServiceRunnerMockRecorder) RunOneshot(ctx, config interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MountSensitiveWithoutSystemdWithMountFlags", reflect.TypeOf((*MockMounter)(nil).MountSensitiveWithoutSystemdWithMountFlags), source, target, fstype, options, sensitiveOptions, mountFlags) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunOneshot", reflect.TypeOf((*MockServiceRunner)(nil).RunOneshot), ctx, config) } -// PathExists mocks base method. -func (m *MockMounter) PathExists(path string) (bool, error) { +// StartService mocks base method. +func (m *MockServiceRunner) StartService(ctx context.Context, config *system.ExecConfig) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PathExists", path) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "StartService", ctx, config) + ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } -// PathExists indicates an expected call of PathExists. -func (mr *MockMounterMockRecorder) PathExists(path interface{}) *gomock.Call { +// StartService indicates an expected call of StartService. +func (mr *MockServiceRunnerMockRecorder) StartService(ctx, config interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PathExists", reflect.TypeOf((*MockMounter)(nil).PathExists), path) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartService", reflect.TypeOf((*MockServiceRunner)(nil).StartService), ctx, config) } -// Unmount mocks base method. -func (m *MockMounter) Unmount(target string) error { +// MockMountLister is a mock of MountLister interface. +type MockMountLister struct { + ctrl *gomock.Controller + recorder *MockMountListerMockRecorder +} + +// MockMountListerMockRecorder is the mock recorder for MockMountLister. +type MockMountListerMockRecorder struct { + mock *MockMountLister +} + +// NewMockMountLister creates a new mock instance. +func NewMockMountLister(ctrl *gomock.Controller) *MockMountLister { + mock := &MockMountLister{ctrl: ctrl} + mock.recorder = &MockMountListerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMountLister) EXPECT() *MockMountListerMockRecorder { + return m.recorder +} + +// ListMounts mocks base method. +func (m *MockMountLister) ListMounts() ([]mount.MountPoint, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Unmount", target) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "ListMounts") + ret0, _ := ret[0].([]mount.MountPoint) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// Unmount indicates an expected call of Unmount. -func (mr *MockMounterMockRecorder) Unmount(target interface{}) *gomock.Call { +// ListMounts indicates an expected call of ListMounts. +func (mr *MockMountListerMockRecorder) ListMounts() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmount", reflect.TypeOf((*MockMounter)(nil).Unmount), target) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListMounts", reflect.TypeOf((*MockMountLister)(nil).ListMounts)) } diff --git a/pkg/driver/mount.go b/pkg/driver/mount.go index b3e8f76..3a2a59c 100644 --- a/pkg/driver/mount.go +++ b/pkg/driver/mount.go @@ -39,27 +39,95 @@ const ( defaultRegionEnv = "AWS_DEFAULT_REGION" stsEndpointsEnv = "AWS_STS_REGIONAL_ENDPOINTS" MountS3PathEnv = "MOUNT_S3_PATH" - hostTokenPathEnv = "HOST_TOKEN_PATH" + hostPluginDirEnv = "HOST_TOKEN_DIR" defaultMountS3Path = "/usr/bin/mount-s3" procMounts = "/host/proc/mounts" userAgentPrefix = "--user-agent-prefix" csiDriverPrefix = "s3-csi-driver/" ) +type MountCredentials struct { + AccessKeyID string + SecretAccessKey string + Region string + DefaultRegion string + WebTokenPath string + StsEndpoints string + AwsRoleArn string +} + +// Get environment variables to pass to mount-s3 for authentication. +func (mc *MountCredentials) Env() []string { + env := []string{} + + if mc.AccessKeyID != "" && mc.SecretAccessKey != "" { + env = append(env, keyIdEnv+"="+mc.AccessKeyID) + env = append(env, accessKeyEnv+"="+mc.SecretAccessKey) + } + if mc.WebTokenPath != "" { + env = append(env, webIdentityTokenEnv+"="+mc.WebTokenPath) + env = append(env, roleArnEnv+"="+mc.AwsRoleArn) + } + if mc.Region != "" { + env = append(env, regionEnv+"="+mc.Region) + } + if mc.DefaultRegion != "" { + env = append(env, defaultRegionEnv+"="+mc.DefaultRegion) + } + if mc.StsEndpoints != "" { + env = append(env, stsEndpointsEnv+"="+mc.StsEndpoints) + } + + return env +} + +type Fs interface { + Stat(name string) (os.FileInfo, error) + MkdirAll(path string, perm os.FileMode) error + Remove(name string) error +} + +type OsFs struct{} + +func (OsFs) Stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} + +func (OsFs) MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(path, perm) +} + +func (OsFs) Remove(path string) error { + return os.Remove(path) +} + // Mounter is an interface for mount operations type Mounter interface { - mount.Interface - IsCorruptedMnt(err error) bool - PathExists(path string) (bool, error) - MakeDir(pathname string) error + Mount(bucketName string, target string, credentials *MountCredentials, options []string) error + Unmount(target string) error + IsMountPoint(target string) (bool, error) +} + +type ServiceRunner interface { + StartService(ctx context.Context, config *system.ExecConfig) (string, error) + RunOneshot(ctx context.Context, config *system.ExecConfig) (string, error) +} + +type MountLister interface { + ListMounts() ([]mount.MountPoint, error) +} + +type ProcMountLister struct { + ProcMountPath string } type S3Mounter struct { - mount.Interface - ctx context.Context - supervisor *system.SystemdSupervisor - mpVersion string - mountS3Path string + Ctx context.Context + Runner ServiceRunner + Fs Fs + MountLister MountLister + MpVersion string + MountS3Path string } func MountS3Path() string { @@ -72,74 +140,110 @@ func MountS3Path() string { func NewS3Mounter(mpVersion string) (*S3Mounter, error) { ctx := context.Background() - supervisor, err := system.StartOsSystemdSupervisor() + runner, err := system.StartOsSystemdSupervisor() if err != nil { return nil, fmt.Errorf("failed to start systemd supervisor: %w", err) } return &S3Mounter{ - Interface: mount.New(""), - ctx: ctx, - supervisor: supervisor, - mpVersion: mpVersion, - mountS3Path: MountS3Path(), + Ctx: ctx, + Runner: runner, + Fs: &OsFs{}, + MountLister: &ProcMountLister{ProcMountPath: procMounts}, + MpVersion: mpVersion, + MountS3Path: MountS3Path(), }, nil } -func (m *S3Mounter) MakeDir(pathname string) error { - err := os.MkdirAll(pathname, os.FileMode(0755)) - if err != nil { - if !os.IsExist(err) { - return err - } - } - return nil -} - -// IsCorruptedMnt return true if err is about corrupted mount point -func (m *S3Mounter) IsCorruptedMnt(err error) bool { - return mount.IsCorruptedMnt(err) -} - -func (m *S3Mounter) List() ([]mount.MountPoint, error) { - mounts, err := os.ReadFile(procMounts) +func (pml *ProcMountLister) ListMounts() ([]mount.MountPoint, error) { + mounts, err := os.ReadFile(pml.ProcMountPath) if err != nil { return nil, fmt.Errorf("Failed to read %s: %w", procMounts, err) } return parseProcMounts(mounts) } -func (m *S3Mounter) IsMountPoint(file string) (bool, error) { - mountPoints, err := m.List() +func (m *S3Mounter) IsMountPoint(target string) (bool, error) { + if _, err := m.Fs.Stat(target); os.IsNotExist(err) { + return false, err + } + + mountPoints, err := m.MountLister.ListMounts() if err != nil { - return false, fmt.Errorf("Failed to cat /proc/mounts: %w", err) + return false, fmt.Errorf("Failed to list mounts: %w", err) } for _, mp := range mountPoints { - if mp.Path == file { + if mp.Path == target { return true, nil } } return false, nil } -func (m *S3Mounter) PathExists(path string) (bool, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return false, nil - } else if err != nil { - return false, err - } - return true, nil -} +// Mount the given bucket at the target path. Options will be passed through mostly unchanged, +// with the exception of the user agent prefix which will be added to the Mountpoint headers. +// This method will create the target path if it does not exist and if there is an existing corrupt +// mount, it will attempt an unmount before attempting the mount. +func (m *S3Mounter) Mount(bucketName string, target string, + credentials *MountCredentials, options []string) error { -func (m *S3Mounter) Mount(source string, target string, _ string, options []string) error { - timeoutCtx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + if bucketName == "" { + return fmt.Errorf("bucket name is empty") + } + if target == "" { + return fmt.Errorf("target is empty") + } + timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() - env := passthroughEnv() - output, err := m.supervisor.StartService(timeoutCtx, &system.ExecConfig{ - Name: "mount-s3-" + m.mpVersion + "-" + uuid.New().String() + ".service", + cleanupDir := false + + // check if the target path exists + _, statErr := m.Fs.Stat(target) + if statErr != nil { + // does not exist, create the directory + if os.IsNotExist(statErr) { + if err := m.Fs.MkdirAll(target, 0755); err != nil { + return fmt.Errorf("Failed to create target directory: %w", err) + } + cleanupDir = true + defer func() { + if cleanupDir { + if err := m.Fs.Remove(target); err != nil { + klog.V(4).Infof("Mount: Failed to delete target dir: %s.", target) + } + } + }() + } + // Corrupted mount, try unmounting + if mount.IsCorruptedMnt(statErr) { + klog.V(4).Infof("Mount: Target path %q is a corrupted mount. Trying to unmount.", target) + if mntErr := m.Unmount(target); mntErr != nil { + return fmt.Errorf("Unable to unmount the target %q : %v, %v", target, statErr, mntErr) + } + } + } + + mounts, err := m.MountLister.ListMounts() + if err != nil { + return fmt.Errorf("Could not check if %q is a mount point: %v, %v", target, statErr, err) + } + for _, m := range mounts { + if m.Path == target { + klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) + return nil + } + } + + env := []string{} + if credentials != nil { + env = credentials.Env() + } + + output, err := m.Runner.StartService(timeoutCtx, &system.ExecConfig{ + Name: "mount-s3-" + m.MpVersion + "-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver FUSE daemon", - ExecPath: m.mountS3Path, - Args: append(addUserAgentToOptions(options), source, target), + ExecPath: m.MountS3Path, + Args: append(addUserAgentToOptions(options), bucketName, target), Env: env, }) @@ -149,6 +253,7 @@ func (m *S3Mounter) Mount(source string, target string, _ string, options []stri if output != "" { klog.V(5).Infof("mount-s3 output: %s", output) } + cleanupDir = false return nil } @@ -166,10 +271,10 @@ func addUserAgentToOptions(options []string) []string { } func (m *S3Mounter) Unmount(target string) error { - timeoutCtx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() - output, err := m.supervisor.RunOneshot(timeoutCtx, &system.ExecConfig{ + output, err := m.Runner.RunOneshot(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-umount-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver unmount", ExecPath: "/usr/bin/umount", @@ -184,42 +289,6 @@ func (m *S3Mounter) Unmount(target string) error { return nil } -func passthroughEnv() []string { - env := []string{} - - keyId := os.Getenv(keyIdEnv) - accessKey := os.Getenv(accessKeyEnv) - if keyId != "" && accessKey != "" { - env = append(env, keyIdEnv+"="+keyId) - env = append(env, accessKeyEnv+"="+accessKey) - } - webIdentityFile := os.Getenv(webIdentityTokenEnv) - awsRoleArn := os.Getenv(roleArnEnv) - hostTokenPath := os.Getenv(hostTokenPathEnv) - if hostTokenPath == "" { - // set the default in case the env variable isn't found - hostTokenPath = "/var/lib/kubelet/plugins/s3.csi.aws.com/token" - } - if webIdentityFile != "" { - env = append(env, webIdentityTokenEnv+"="+hostTokenPath) - env = append(env, roleArnEnv+"="+awsRoleArn) - } - region := os.Getenv(regionEnv) - if region != "" { - env = append(env, regionEnv+"="+region) - } - defaultRegion := os.Getenv(defaultRegionEnv) - if defaultRegion != "" { - env = append(env, defaultRegionEnv+"="+defaultRegion) - } - stsEndpoints := os.Getenv(stsEndpointsEnv) - if stsEndpoints != "" { - env = append(env, stsEndpointsEnv+"="+stsEndpoints) - } - - return env -} - func parseProcMounts(data []byte) ([]mount.MountPoint, error) { var mounts []mount.MountPoint diff --git a/pkg/driver/mount_test.go b/pkg/driver/mount_test.go index 0ccd604..1e1ad10 100644 --- a/pkg/driver/mount_test.go +++ b/pkg/driver/mount_test.go @@ -1,31 +1,182 @@ -package driver +package driver_test import ( + "context" + "errors" + "strings" "testing" - "github.com/stretchr/testify/assert" + driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver" + mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/system" + "github.com/golang/mock/gomock" + "k8s.io/mount-utils" ) -func TestUserAgentPrefix(t *testing.T) { +type TestMountLister struct { + Mounts []mount.MountPoint + Err error +} + +func (l *TestMountLister) ListMounts() ([]mount.MountPoint, error) { + return l.Mounts, l.Err +} + +type mounterTestEnv struct { + ctx context.Context + mockCtl *gomock.Controller + mockRunner *mock_driver.MockServiceRunner + mockFs *mock_driver.MockFs + mockMountLister *mock_driver.MockMountLister + mounter *driver.S3Mounter +} + +func initMounterTestEnv(t *testing.T) *mounterTestEnv { + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockRunner := mock_driver.NewMockServiceRunner(mockCtl) + mockFs := mock_driver.NewMockFs(mockCtl) + mockMountLister := mock_driver.NewMockMountLister(mockCtl) + mountpointVersion := "TEST_MP_VERSION-v1.1" + + return &mounterTestEnv{ + ctx: ctx, + mockCtl: mockCtl, + mockRunner: mockRunner, + mockFs: mockFs, + mockMountLister: mockMountLister, + mounter: &driver.S3Mounter{ + Ctx: ctx, + Runner: mockRunner, + Fs: mockFs, + MountLister: mockMountLister, + MpVersion: mountpointVersion, + MountS3Path: driver.MountS3Path(), + }, + } +} + +func TestS3MounterMount(t *testing.T) { + testBucketName := "test-bucket" + testTargetPath := "/mnt/my-mountpoint/bucket/" + testCredentials := &driver.MountCredentials{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + Region: "test-region", + DefaultRegion: "test-region", + WebTokenPath: "test-web-token-path", + StsEndpoints: "test-sts-endpoint", + AwsRoleArn: "test-aws-role", + } + testCases := []struct { - name string - input []string - expected []string + name string + bucketName string + targetPath string + credentials *driver.MountCredentials + options []string + expectedErr bool + before func(*testing.T, *mounterTestEnv) }{ { - name: "success: add user agent prefix to mount call", - input: []string{"--read-only"}, - expected: []string{"--read-only", userAgentPrefix + "=" + csiDriverPrefix + GetVersion().DriverVersion}, + name: "success: mounts without empty options", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: testCredentials, + options: []string{}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) + }, }, { - name: "success: replacing customer user agent prefix", - input: []string{"--read-only", "--user-agent-prefix testing"}, - expected: []string{"--read-only", userAgentPrefix + "=" + csiDriverPrefix + GetVersion().DriverVersion}, + name: "success: mounts with nil credentials", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: nil, + options: []string{}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) + }, + }, + { + name: "success: replaces user agent prefix", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: nil, + options: []string{"--user-agent-prefix=mycustomuseragent"}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, config *system.ExecConfig) (string, error) { + for _, a := range config.Args { + if strings.Contains(a, "mycustomuseragent") { + t.Fatal("Bad user agent") + } + } + return "success", nil + }) + }, + }, + { + name: "failure: fails on mount failure", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: nil, + options: []string{}, + expectedErr: true, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("fail", errors.New("test failure")) + }, + }, + { + name: "failure: won't mount empty bucket name", + targetPath: testTargetPath, + credentials: testCredentials, + options: []string{}, + expectedErr: true, + }, + { + name: "failure: won't mount empty target", + bucketName: testBucketName, + credentials: testCredentials, + options: []string{}, + expectedErr: true, + }, + { + name: "failure: mounts without empty options", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: testCredentials, + options: []string{}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) + }, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, addUserAgentToOptions(tc.input), tc.expected) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + env := initMounterTestEnv(t) + if testCase.before != nil { + testCase.before(t, env) + } + err := env.mounter.Mount(testCase.bucketName, testCase.targetPath, + testCase.credentials, testCase.options) + env.mockCtl.Finish() + if err != nil && !testCase.expectedErr { + t.Fatal(err) + } }) } } + +func TestS3MounterIsMountPoint(t *testing.T) { +} diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 3cf6342..6d7a5d5 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -19,13 +19,16 @@ package driver import ( "context" "os" + "path" "strings" + "time" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" + "k8s.io/mount-utils" ) const ( @@ -37,16 +40,28 @@ var ( nodeCaps = []csi.NodeServiceCapability_RPC_Type{} ) -func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { +// S3NodeServer is the implementation of the csi.NodeServer interface +type S3NodeServer struct { + NodeID string + BaseCredentials *MountCredentials + Mounter Mounter +} + +type Token struct { + Token string `json:"token"` + ExpirationTimestamp time.Time `json:"expirationTimestamp"` +} + +func (ns *S3NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { +func (ns *S3NodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { - klog.V(4).Infof("NodePublishVolume: called with args %+v", req) +func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { + klog.V(4).Infof("NodePublishVolume: req: %+v", req) volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -68,21 +83,15 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu return nil, status.Error(codes.InvalidArgument, "Volume capability not provided") } - if !d.isValidVolumeCapabilities([]*csi.VolumeCapability{volCap}) { + if !ns.isValidVolumeCapabilities([]*csi.VolumeCapability{volCap}) { return nil, status.Error(codes.InvalidArgument, "Volume capability not supported") } mountpointArgs := []string{} - if req.GetReadonly() || volCap.GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY { mountpointArgs = append(mountpointArgs, "--read-only") } - klog.V(4).Infof("NodePublishVolume: creating dir %s", target) - if err := d.Mounter.MakeDir(target); err != nil { - return nil, status.Errorf(codes.Internal, "Could not create dir %q: %v", target, err) - } - // get the mount(point) options (yaml list) if capMount := volCap.GetMount(); capMount != nil { mountFlags := capMount.GetMountFlags() @@ -99,19 +108,27 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu mountpointArgs = compileMountOptions(mountpointArgs, mountFlags) } - //Checking if the target directory is already mounted with a volume. - mounted, err := d.isMounted(target) - if err != nil { - return nil, status.Errorf(codes.Internal, "Could not check if %q is mounted: %v", target, err) + klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) + hostPluginDirEnv := os.Getenv(hostPluginDirEnv) + if hostPluginDirEnv == "" { + // set the default in case the env variable isn't found + hostPluginDirEnv = "/var/lib/kubelet/plugins/s3.csi.aws.com/" } - if !mounted { - klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) - if err := d.Mounter.Mount(bucket, target, fstype, mountpointArgs); err != nil { - os.Remove(target) - return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) - } - klog.V(4).Infof("NodePublishVolume: %s was mounted", target) + hostTokenPath := path.Join(hostPluginDirEnv, "token") + credentials := &MountCredentials{ + AccessKeyID: os.Getenv(keyIdEnv), + SecretAccessKey: os.Getenv(accessKeyEnv), + Region: os.Getenv(regionEnv), + DefaultRegion: os.Getenv(defaultRegionEnv), + WebTokenPath: hostTokenPath, + StsEndpoints: os.Getenv(stsEndpointsEnv), + AwsRoleArn: os.Getenv(roleArnEnv), + } + if err := ns.Mounter.Mount(bucket, target, credentials, mountpointArgs); err != nil { + os.Remove(target) + return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } + klog.V(4).Infof("NodePublishVolume: %s was mounted", target) return &csi.NodePublishVolumeResponse{}, nil } @@ -140,7 +157,7 @@ func compileMountOptions(currentOptions []string, newOptions []string) []string return allMountOptions.List() } -func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { +func (ns *S3NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { klog.V(4).Infof("NodeUnpublishVolume: called with args %+v", req) volumeID := req.GetVolumeId() @@ -152,11 +169,11 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish return nil, status.Error(codes.InvalidArgument, "Target path not provided") } - mounted, err := d.Mounter.IsMountPoint(target) + mounted, err := ns.Mounter.IsMountPoint(target) if err != nil && os.IsNotExist(err) { klog.V(4).Infof("NodeUnpublishVolume: target path %s does not exist, skipping unmount", target) return &csi.NodeUnpublishVolumeResponse{}, nil - } else if err != nil && d.Mounter.IsCorruptedMnt(err) { + } else if err != nil && mount.IsCorruptedMnt(err) { klog.V(4).Infof("NodeUnpublishVolume: target path %s is corrupted: %v, will try to unmount", target, err) mounted = true } else if err != nil { @@ -168,7 +185,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish } klog.V(4).Infof("NodeUnpublishVolume: unmounting %s", target) - err = d.Mounter.Unmount(target) + err = ns.Mounter.Unmount(target) if err != nil { return nil, status.Errorf(codes.Internal, "Could not unmount %q: %v", target, err) } @@ -176,15 +193,15 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish return &csi.NodeUnpublishVolumeResponse{}, nil } -func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { +func (ns *S3NodeServer) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { +func (ns *S3NodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { +func (ns *S3NodeServer) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { klog.V(4).Infof("NodeGetCapabilities: called with args %+v", req) var caps []*csi.NodeServiceCapability for _, cap := range nodeCaps { @@ -200,44 +217,15 @@ func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabi return &csi.NodeGetCapabilitiesResponse{Capabilities: caps}, nil } -func (d *Driver) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { +func (ns *S3NodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { klog.V(4).Infof("NodeGetInfo: called with args %+v", req) return &csi.NodeGetInfoResponse{ - NodeId: d.NodeID, + NodeId: ns.NodeID, }, nil } -// isMounted checks if target is a valid mountpoint -// inexistent target directory is NOT an error -// method will try to unmount the directory if it was detected to be corrupted -func (d *Driver) isMounted(target string) (bool, error) { - notMnt, err := d.Mounter.IsLikelyNotMountPoint(target) - if err != nil && !os.IsNotExist(err) { - _, pathErr := d.Mounter.PathExists(target) - if pathErr != nil && d.Mounter.IsCorruptedMnt(pathErr) { - klog.V(4).Infof("NodePublishVolume: Target path %q is a corrupted mount. Trying to unmount.", target) - if mntErr := d.Mounter.Unmount(target); mntErr != nil { - return false, status.Errorf(codes.Internal, "Unable to unmount the target %q : %v", target, mntErr) - } - return false, nil - } - return false, status.Errorf(codes.Internal, "Could not check if %q is a mount point: %v, %v", target, err, pathErr) - } - - if err != nil && os.IsNotExist(err) { - klog.V(5).Infof("[Debug] NodePublishVolume: Target path %q does not exist", target) - return false, nil - } - - if !notMnt { - klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) - } - - return !notMnt, nil -} - -func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool { +func (ns *S3NodeServer) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool { hasSupport := func(cap *csi.VolumeCapability) bool { for _, c := range volumeCaps { if c.GetMode() == cap.AccessMode.GetMode() { diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index e02a482..7cf3832 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -15,22 +15,22 @@ import ( type nodeServerTestEnv struct { mockCtl *gomock.Controller mockMounter *mock_driver.MockMounter - driver *driver.Driver + server *driver.S3NodeServer } func initNodeServerTestEnv(t *testing.T) *nodeServerTestEnv { mockCtl := gomock.NewController(t) defer mockCtl.Finish() mockMounter := mock_driver.NewMockMounter(mockCtl) - driver := &driver.Driver{ - Endpoint: "unix://tmp/csi.sock", - NodeID: "test-nodeID", - Mounter: mockMounter, + server := &driver.S3NodeServer{ + NodeID: "test-nodeID", + BaseCredentials: &driver.MountCredentials{}, + Mounter: mockMounter, } return &nodeServerTestEnv{ mockCtl: mockCtl, mockMounter: mockMounter, - driver: driver, + server: server, } } @@ -64,10 +64,8 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(false, nil) - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), gomock.Any()) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Any()) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -94,10 +92,8 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(false, nil) - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), gomock.Eq([]string{"--read-only"})) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only"})) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -127,10 +123,8 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(true, nil) - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -160,12 +154,10 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(true, nil) nodeTestEnv.mockMounter.EXPECT().Mount( - gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), + gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only", "--test=123"})).Return(nil) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -184,7 +176,7 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err == nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -219,29 +211,7 @@ func TestNodeUnpublishVolume(t *testing.T) { nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(true, nil) nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(nil) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) - if err != nil { - t.Fatalf("NodePublishVolume failed: %v", err) - } - - nodeTestEnv.mockCtl.Finish() - }, - }, - { - name: "success: corrupted volume", - testFunc: func(t *testing.T) { - nodeTestEnv := initNodeServerTestEnv(t) - ctx := context.Background() - req := &csi.NodeUnpublishVolumeRequest{ - VolumeId: volumeId, - TargetPath: targetPath, - } - - expectedErr := errors.New("") - nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(false, expectedErr) - nodeTestEnv.mockMounter.EXPECT().IsCorruptedMnt(expectedErr).Return(true) - nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(nil) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) } @@ -260,7 +230,7 @@ func TestNodeUnpublishVolume(t *testing.T) { } nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(false, nil) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) } @@ -280,7 +250,7 @@ func TestNodeUnpublishVolume(t *testing.T) { nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(true, nil) nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(errors.New("")) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err == nil { t.Fatalf("NodePublishVolume must fail") } @@ -300,7 +270,7 @@ func TestNodeUnpublishVolume(t *testing.T) { expectedError := fs.ErrNotExist nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(false, expectedError) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) } @@ -314,3 +284,21 @@ func TestNodeUnpublishVolume(t *testing.T) { t.Run(tc.name, tc.testFunc) } } + +func TestNodeGetCapabilities(t *testing.T) { + nodeTestEnv := initNodeServerTestEnv(t) + ctx := context.Background() + req := &csi.NodeGetCapabilitiesRequest{} + + resp, err := nodeTestEnv.server.NodeGetCapabilities(ctx, req) + if err != nil { + t.Fatalf("NodeGetCapabilities failed: %v", err) + } + + capabilities := resp.GetCapabilities() + if len(capabilities) != 0 { + t.Fatalf("NodeGetCapabilities failed: capabilities not empty") + } + + nodeTestEnv.mockCtl.Finish() +} diff --git a/pkg/util/util.go b/pkg/driver/server.go similarity index 55% rename from pkg/util/util.go rename to pkg/driver/server.go index ab71adc..3886a9b 100644 --- a/pkg/util/util.go +++ b/pkg/driver/server.go @@ -1,20 +1,4 @@ -/* -Copyright 2022 The Kubernetes Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package util +package driver import ( "fmt" diff --git a/pkg/system/systemd_test.go b/pkg/system/systemd_test.go index 00605c2..ed39b94 100644 --- a/pkg/system/systemd_test.go +++ b/pkg/system/systemd_test.go @@ -165,11 +165,9 @@ func TestSystemdConnection(t *testing.T) { } }() }) - _, err := conn.ListUnits(ctx) - if err == nil { + if _, err := conn.ListUnits(ctx); err == nil { t.Fatalf("Expected error, got nil") } - }, }, } diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index acd6580..94b5618 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -28,7 +28,6 @@ import ( sanity "github.com/kubernetes-csi/csi-test/pkg/sanity" "github.com/awslabs/aws-s3-csi-driver/pkg/driver" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" ) const ( @@ -66,7 +65,15 @@ func waitDriverIsUp(endpoint string) { } var _ = BeforeSuite(func() { - s3Driver = driver.NewFakeDriver(endpoint) + s3Driver = &driver.Driver{ + Endpoint: endpoint, + NodeID: "fake_id", + NodeServer: &driver.S3NodeServer{ + NodeID: "fake_id", + BaseCredentials: &driver.MountCredentials{}, + Mounter: &driver.FakeMounter{}, + }, + } go func() { Expect(s3Driver.Run()).NotTo(HaveOccurred()) }() @@ -84,7 +91,7 @@ var _ = Describe("Amazon S3 CSI Driver", func() { Address: endpoint, TargetPath: mountPath, StagingPath: stagePath, - TestVolumeSize: 2000 * util.GiB, + TestVolumeSize: 2000 * driver.GiB, IDGen: &sanity.DefaultIDGenerator{}, } sanity.GinkgoTest(config)