From 3fd288dc112b9777e244b633ea12cd3f14ca581b Mon Sep 17 00:00:00 2001 From: Joe Kramer Date: Wed, 14 Feb 2024 15:28:27 -0700 Subject: [PATCH] Refactor mounting (#149) Refactoring the node and mount code for better separation of concerns and testability. Some of the mount logic was moved from the node server to the mount module. The Mounter class is no longer based on the k8s library code. --- .../templates/node.yaml | 4 +- deploy/kubernetes/base/node-daemonset.yaml | 4 +- go.mod | 13 +- go.sum | 40 ++- pkg/driver/driver.go | 15 +- pkg/driver/fakes.go | 42 +-- pkg/driver/mocks/mock_mount.go | 255 ++++++++++-------- pkg/driver/mount.go | 248 ++++++++++------- pkg/driver/mount_test.go | 178 ++++++++++-- pkg/driver/node.go | 115 ++++---- pkg/driver/node_test.go | 86 +++--- pkg/{util/util.go => driver/server.go} | 18 +- pkg/system/systemd_test.go | 4 +- tests/sanity/sanity_test.go | 13 +- 14 files changed, 601 insertions(+), 434 deletions(-) rename pkg/{util/util.go => driver/server.go} (55%) diff --git a/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml b/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml index 7602221..b1ec58a 100644 --- a/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml +++ b/charts/aws-mountpoint-s3-csi-driver/templates/node.yaml @@ -75,8 +75,8 @@ spec: valueFrom: fieldRef: fieldPath: spec.nodeName - - name: HOST_TOKEN_PATH - value: {{ trimSuffix "/" .Values.node.kubeletPath }}/plugins/s3.csi.aws.com/token + - name: HOST_PLUGIN_DIR + value: {{ trimSuffix "/" .Values.node.kubeletPath }}/plugins/s3.csi.aws.com/ {{- with .Values.awsAccessSecret }} - name: AWS_ACCESS_KEY_ID valueFrom: diff --git a/deploy/kubernetes/base/node-daemonset.yaml b/deploy/kubernetes/base/node-daemonset.yaml index a1101ce..f7f3a70 100644 --- a/deploy/kubernetes/base/node-daemonset.yaml +++ b/deploy/kubernetes/base/node-daemonset.yaml @@ -72,8 +72,8 @@ spec: valueFrom: fieldRef: fieldPath: spec.nodeName - - name: HOST_TOKEN_PATH - value: /var/lib/kubelet/plugins/s3.csi.aws.com/token + - name: HOST_PLUGIN_DIR + value: /var/lib/kubelet/plugins/s3.csi.aws.com/ volumeMounts: - name: kubelet-dir mountPath: /var/lib/kubelet diff --git a/go.mod b/go.mod index e1954d1..1673969 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,12 @@ go 1.21 require ( github.com/container-storage-interface/spec v1.9.0 - github.com/coreos/go-systemd/v22 v22.5.0 github.com/godbus/dbus/v5 v5.1.0 github.com/golang/mock v1.6.0 github.com/kubernetes-csi/csi-test v2.2.0+incompatible github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.27.6 - github.com/stretchr/testify v1.8.2 + github.com/onsi/gomega v1.29.0 + github.com/stretchr/testify v1.8.4 google.golang.org/grpc v1.59.0 k8s.io/klog/v2 v2.110.1 k8s.io/mount-utils v0.28.4 @@ -27,17 +26,17 @@ require ( github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/go-cmp v0.5.9 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.4.0 github.com/moby/sys/mountinfo v0.7.1 // indirect github.com/nxadm/tail v1.4.8 // indirect - golang.org/x/net v0.18.0 - golang.org/x/sys v0.14.0 + golang.org/x/net v0.19.0 + golang.org/x/sys v0.15.0 golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.31.0 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/apimachinery v0.28.4 + k8s.io/apimachinery v0.29.1 k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect ) diff --git a/go.sum b/go.sum index 0e26b08..277bf81 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/container-storage-interface/spec v1.9.0 h1:zKtX4STsq31Knz3gciCYCi1SXtO2HJDecIjDVboYavY= github.com/container-storage-interface/spec v1.9.0/go.mod h1:ZfDu+3ZRyeVqxZM0Ds19MVLkN2d1XJ5MAfi1L3VjlT0= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -14,7 +12,6 @@ github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4 github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -33,8 +30,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= @@ -55,24 +52,20 @@ github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= -github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= +github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= +github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= -github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= +github.com/onsi/gomega v1.29.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -86,8 +79,8 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -105,8 +98,8 @@ golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -116,8 +109,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= -golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -147,11 +140,10 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/apimachinery v0.28.4 h1:zOSJe1mc+GxuMnFzD4Z/U1wst50X28ZNsn5bhgIIao8= -k8s.io/apimachinery v0.28.4/go.mod h1:wI37ncBvfAoswfq626yPTe6Bz1c22L7uaJ8dho83mgg= +k8s.io/apimachinery v0.29.1 h1:KY4/E6km/wLBguvCZv8cKTeOwwOBqFNjwJIdMkMbbRc= +k8s.io/apimachinery v0.29.1/go.mod h1:6HVkd1FwxIagpYrHSwJlQqZI3G9LfYWRPAkUvLnXTKU= k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= k8s.io/mount-utils v0.28.4 h1:5GOZLm2dXi2fr+MKY8hS6kdV5reXrZBiK7848O5MVD0= diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 6c103d7..41796be 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -25,7 +25,6 @@ import ( "os" "time" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc" "k8s.io/klog/v2" @@ -51,9 +50,9 @@ var ( type Driver struct { Endpoint string Srv *grpc.Server + NodeID string - NodeID string - Mounter Mounter + NodeServer *S3NodeServer } func NewDriver(endpoint string, mpVersion string, nodeID string) *Driver { @@ -66,9 +65,9 @@ func NewDriver(endpoint string, mpVersion string, nodeID string) *Driver { } return &Driver{ - Endpoint: endpoint, - NodeID: nodeID, - Mounter: mounter, + Endpoint: endpoint, + NodeID: nodeID, + NodeServer: &S3NodeServer{NodeID: nodeID, Mounter: mounter}, } } @@ -81,7 +80,7 @@ func (d *Driver) Run() error { go tokenFileTender(ctx, tokenFile, "/csi/token") } - scheme, addr, err := util.ParseEndpoint(d.Endpoint) + scheme, addr, err := ParseEndpoint(d.Endpoint) if err != nil { return err } @@ -105,7 +104,7 @@ func (d *Driver) Run() error { csi.RegisterIdentityServer(d.Srv, d) csi.RegisterControllerServer(d.Srv, d) - csi.RegisterNodeServer(d.Srv, d) + csi.RegisterNodeServer(d.Srv, d.NodeServer) klog.Infof("Listening for connections on address: %#v", listener.Addr()) return d.Srv.Serve(listener) diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 4b0f4f2..517720d 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -1,38 +1,16 @@ -/* -Copyright 2019 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package driver -import ( - "k8s.io/mount-utils" -) +type FakeMounter struct{} + +func (m *FakeMounter) Mount(bucketName string, target string, + credentials *MountCredentials, options []string) error { + return nil +} -func NewFakeMounter() Mounter { - return &S3Mounter{ - Interface: &mount.FakeMounter{ - MountPoints: []mount.MountPoint{}, - }, - } +func (m *FakeMounter) Unmount(target string) error { + return nil } -// NewFakeDriver creates a new mock driver used for testing -func NewFakeDriver(endpoint string) *Driver { - return &Driver{ - Endpoint: endpoint, - NodeID: "fake_id", - Mounter: NewFakeMounter(), - } +func (m *FakeMounter) IsMountPoint(target string) (bool, error) { + return false, nil } diff --git a/pkg/driver/mocks/mock_mount.go b/pkg/driver/mocks/mock_mount.go index b3f269d..2d3aa4a 100644 --- a/pkg/driver/mocks/mock_mount.go +++ b/pkg/driver/mocks/mock_mount.go @@ -5,218 +5,235 @@ package mock_driver import ( + context "context" + os "os" reflect "reflect" + driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver" + system "github.com/awslabs/aws-s3-csi-driver/pkg/system" gomock "github.com/golang/mock/gomock" mount "k8s.io/mount-utils" ) -// MockMounter is a mock of Mounter interface. -type MockMounter struct { +// MockFs is a mock of Fs interface. +type MockFs struct { ctrl *gomock.Controller - recorder *MockMounterMockRecorder + recorder *MockFsMockRecorder } -// MockMounterMockRecorder is the mock recorder for MockMounter. -type MockMounterMockRecorder struct { - mock *MockMounter +// MockFsMockRecorder is the mock recorder for MockFs. +type MockFsMockRecorder struct { + mock *MockFs } -// NewMockMounter creates a new mock instance. -func NewMockMounter(ctrl *gomock.Controller) *MockMounter { - mock := &MockMounter{ctrl: ctrl} - mock.recorder = &MockMounterMockRecorder{mock} +// NewMockFs creates a new mock instance. +func NewMockFs(ctrl *gomock.Controller) *MockFs { + mock := &MockFs{ctrl: ctrl} + mock.recorder = &MockFsMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockMounter) EXPECT() *MockMounterMockRecorder { +func (m *MockFs) EXPECT() *MockFsMockRecorder { return m.recorder } -// CanSafelySkipMountPointCheck mocks base method. -func (m *MockMounter) CanSafelySkipMountPointCheck() bool { +// MkdirAll mocks base method. +func (m *MockFs) MkdirAll(path string, perm os.FileMode) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CanSafelySkipMountPointCheck") - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "MkdirAll", path, perm) + ret0, _ := ret[0].(error) return ret0 } -// CanSafelySkipMountPointCheck indicates an expected call of CanSafelySkipMountPointCheck. -func (mr *MockMounterMockRecorder) CanSafelySkipMountPointCheck() *gomock.Call { +// MkdirAll indicates an expected call of MkdirAll. +func (mr *MockFsMockRecorder) MkdirAll(path, perm interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CanSafelySkipMountPointCheck", reflect.TypeOf((*MockMounter)(nil).CanSafelySkipMountPointCheck)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockFs)(nil).MkdirAll), path, perm) } -// GetMountRefs mocks base method. -func (m *MockMounter) GetMountRefs(pathname string) ([]string, error) { +// Remove mocks base method. +func (m *MockFs) Remove(name string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMountRefs", pathname) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Remove", name) + ret0, _ := ret[0].(error) + return ret0 } -// GetMountRefs indicates an expected call of GetMountRefs. -func (mr *MockMounterMockRecorder) GetMountRefs(pathname interface{}) *gomock.Call { +// Remove indicates an expected call of Remove. +func (mr *MockFsMockRecorder) Remove(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMountRefs", reflect.TypeOf((*MockMounter)(nil).GetMountRefs), pathname) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockFs)(nil).Remove), name) } -// IsCorruptedMnt mocks base method. -func (m *MockMounter) IsCorruptedMnt(err error) bool { +// Stat mocks base method. +func (m *MockFs) Stat(name string) (os.FileInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsCorruptedMnt", err) - ret0, _ := ret[0].(bool) - return ret0 + ret := m.ctrl.Call(m, "Stat", name) + ret0, _ := ret[0].(os.FileInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// IsCorruptedMnt indicates an expected call of IsCorruptedMnt. -func (mr *MockMounterMockRecorder) IsCorruptedMnt(err interface{}) *gomock.Call { +// Stat indicates an expected call of Stat. +func (mr *MockFsMockRecorder) Stat(name interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCorruptedMnt", reflect.TypeOf((*MockMounter)(nil).IsCorruptedMnt), err) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockFs)(nil).Stat), name) } -// IsLikelyNotMountPoint mocks base method. -func (m *MockMounter) IsLikelyNotMountPoint(file string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsLikelyNotMountPoint", file) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 +// MockMounter is a mock of Mounter interface. +type MockMounter struct { + ctrl *gomock.Controller + recorder *MockMounterMockRecorder } -// IsLikelyNotMountPoint indicates an expected call of IsLikelyNotMountPoint. -func (mr *MockMounterMockRecorder) IsLikelyNotMountPoint(file interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLikelyNotMountPoint", reflect.TypeOf((*MockMounter)(nil).IsLikelyNotMountPoint), file) +// MockMounterMockRecorder is the mock recorder for MockMounter. +type MockMounterMockRecorder struct { + mock *MockMounter } -// IsMountPoint mocks base method. -func (m *MockMounter) IsMountPoint(file string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsMountPoint", file) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 +// NewMockMounter creates a new mock instance. +func NewMockMounter(ctrl *gomock.Controller) *MockMounter { + mock := &MockMounter{ctrl: ctrl} + mock.recorder = &MockMounterMockRecorder{mock} + return mock } -// IsMountPoint indicates an expected call of IsMountPoint. -func (mr *MockMounterMockRecorder) IsMountPoint(file interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMountPoint", reflect.TypeOf((*MockMounter)(nil).IsMountPoint), file) +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMounter) EXPECT() *MockMounterMockRecorder { + return m.recorder } -// List mocks base method. -func (m *MockMounter) List() ([]mount.MountPoint, error) { +// IsMountPoint mocks base method. +func (m *MockMounter) IsMountPoint(target string) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List") - ret0, _ := ret[0].([]mount.MountPoint) + ret := m.ctrl.Call(m, "IsMountPoint", target) + ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } -// List indicates an expected call of List. -func (mr *MockMounterMockRecorder) List() *gomock.Call { +// IsMountPoint indicates an expected call of IsMountPoint. +func (mr *MockMounterMockRecorder) IsMountPoint(target interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockMounter)(nil).List)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsMountPoint", reflect.TypeOf((*MockMounter)(nil).IsMountPoint), target) } -// MakeDir mocks base method. -func (m *MockMounter) MakeDir(pathname string) error { +// Mount mocks base method. +func (m *MockMounter) Mount(bucketName, target string, credentials *driver.MountCredentials, options []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MakeDir", pathname) + ret := m.ctrl.Call(m, "Mount", bucketName, target, credentials, options) ret0, _ := ret[0].(error) return ret0 } -// MakeDir indicates an expected call of MakeDir. -func (mr *MockMounterMockRecorder) MakeDir(pathname interface{}) *gomock.Call { +// Mount indicates an expected call of Mount. +func (mr *MockMounterMockRecorder) Mount(bucketName, target, credentials, options interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeDir", reflect.TypeOf((*MockMounter)(nil).MakeDir), pathname) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), bucketName, target, credentials, options) } -// Mount mocks base method. -func (m *MockMounter) Mount(source, target, fstype string, options []string) error { +// Unmount mocks base method. +func (m *MockMounter) Unmount(target string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Mount", source, target, fstype, options) + ret := m.ctrl.Call(m, "Unmount", target) ret0, _ := ret[0].(error) return ret0 } -// Mount indicates an expected call of Mount. -func (mr *MockMounterMockRecorder) Mount(source, target, fstype, options interface{}) *gomock.Call { +// Unmount indicates an expected call of Unmount. +func (mr *MockMounterMockRecorder) Unmount(target interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Mount", reflect.TypeOf((*MockMounter)(nil).Mount), source, target, fstype, options) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmount", reflect.TypeOf((*MockMounter)(nil).Unmount), target) } -// MountSensitive mocks base method. -func (m *MockMounter) MountSensitive(source, target, fstype string, options, sensitiveOptions []string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MountSensitive", source, target, fstype, options, sensitiveOptions) - ret0, _ := ret[0].(error) - return ret0 +// MockServiceRunner is a mock of ServiceRunner interface. +type MockServiceRunner struct { + ctrl *gomock.Controller + recorder *MockServiceRunnerMockRecorder } -// MountSensitive indicates an expected call of MountSensitive. -func (mr *MockMounterMockRecorder) MountSensitive(source, target, fstype, options, sensitiveOptions interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MountSensitive", reflect.TypeOf((*MockMounter)(nil).MountSensitive), source, target, fstype, options, sensitiveOptions) +// MockServiceRunnerMockRecorder is the mock recorder for MockServiceRunner. +type MockServiceRunnerMockRecorder struct { + mock *MockServiceRunner } -// MountSensitiveWithoutSystemd mocks base method. -func (m *MockMounter) MountSensitiveWithoutSystemd(source, target, fstype string, options, sensitiveOptions []string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MountSensitiveWithoutSystemd", source, target, fstype, options, sensitiveOptions) - ret0, _ := ret[0].(error) - return ret0 +// NewMockServiceRunner creates a new mock instance. +func NewMockServiceRunner(ctrl *gomock.Controller) *MockServiceRunner { + mock := &MockServiceRunner{ctrl: ctrl} + mock.recorder = &MockServiceRunnerMockRecorder{mock} + return mock } -// MountSensitiveWithoutSystemd indicates an expected call of MountSensitiveWithoutSystemd. -func (mr *MockMounterMockRecorder) MountSensitiveWithoutSystemd(source, target, fstype, options, sensitiveOptions interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MountSensitiveWithoutSystemd", reflect.TypeOf((*MockMounter)(nil).MountSensitiveWithoutSystemd), source, target, fstype, options, sensitiveOptions) +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockServiceRunner) EXPECT() *MockServiceRunnerMockRecorder { + return m.recorder } -// MountSensitiveWithoutSystemdWithMountFlags mocks base method. -func (m *MockMounter) MountSensitiveWithoutSystemdWithMountFlags(source, target, fstype string, options, sensitiveOptions, mountFlags []string) error { +// RunOneshot mocks base method. +func (m *MockServiceRunner) RunOneshot(ctx context.Context, config *system.ExecConfig) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MountSensitiveWithoutSystemdWithMountFlags", source, target, fstype, options, sensitiveOptions, mountFlags) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "RunOneshot", ctx, config) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// MountSensitiveWithoutSystemdWithMountFlags indicates an expected call of MountSensitiveWithoutSystemdWithMountFlags. -func (mr *MockMounterMockRecorder) MountSensitiveWithoutSystemdWithMountFlags(source, target, fstype, options, sensitiveOptions, mountFlags interface{}) *gomock.Call { +// RunOneshot indicates an expected call of RunOneshot. +func (mr *MockServiceRunnerMockRecorder) RunOneshot(ctx, config interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MountSensitiveWithoutSystemdWithMountFlags", reflect.TypeOf((*MockMounter)(nil).MountSensitiveWithoutSystemdWithMountFlags), source, target, fstype, options, sensitiveOptions, mountFlags) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunOneshot", reflect.TypeOf((*MockServiceRunner)(nil).RunOneshot), ctx, config) } -// PathExists mocks base method. -func (m *MockMounter) PathExists(path string) (bool, error) { +// StartService mocks base method. +func (m *MockServiceRunner) StartService(ctx context.Context, config *system.ExecConfig) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PathExists", path) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "StartService", ctx, config) + ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } -// PathExists indicates an expected call of PathExists. -func (mr *MockMounterMockRecorder) PathExists(path interface{}) *gomock.Call { +// StartService indicates an expected call of StartService. +func (mr *MockServiceRunnerMockRecorder) StartService(ctx, config interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PathExists", reflect.TypeOf((*MockMounter)(nil).PathExists), path) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartService", reflect.TypeOf((*MockServiceRunner)(nil).StartService), ctx, config) } -// Unmount mocks base method. -func (m *MockMounter) Unmount(target string) error { +// MockMountLister is a mock of MountLister interface. +type MockMountLister struct { + ctrl *gomock.Controller + recorder *MockMountListerMockRecorder +} + +// MockMountListerMockRecorder is the mock recorder for MockMountLister. +type MockMountListerMockRecorder struct { + mock *MockMountLister +} + +// NewMockMountLister creates a new mock instance. +func NewMockMountLister(ctrl *gomock.Controller) *MockMountLister { + mock := &MockMountLister{ctrl: ctrl} + mock.recorder = &MockMountListerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMountLister) EXPECT() *MockMountListerMockRecorder { + return m.recorder +} + +// ListMounts mocks base method. +func (m *MockMountLister) ListMounts() ([]mount.MountPoint, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Unmount", target) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "ListMounts") + ret0, _ := ret[0].([]mount.MountPoint) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// Unmount indicates an expected call of Unmount. -func (mr *MockMounterMockRecorder) Unmount(target interface{}) *gomock.Call { +// ListMounts indicates an expected call of ListMounts. +func (mr *MockMountListerMockRecorder) ListMounts() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmount", reflect.TypeOf((*MockMounter)(nil).Unmount), target) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListMounts", reflect.TypeOf((*MockMountLister)(nil).ListMounts)) } diff --git a/pkg/driver/mount.go b/pkg/driver/mount.go index b3e8f76..37e1661 100644 --- a/pkg/driver/mount.go +++ b/pkg/driver/mount.go @@ -39,27 +39,94 @@ const ( defaultRegionEnv = "AWS_DEFAULT_REGION" stsEndpointsEnv = "AWS_STS_REGIONAL_ENDPOINTS" MountS3PathEnv = "MOUNT_S3_PATH" - hostTokenPathEnv = "HOST_TOKEN_PATH" defaultMountS3Path = "/usr/bin/mount-s3" procMounts = "/host/proc/mounts" userAgentPrefix = "--user-agent-prefix" csiDriverPrefix = "s3-csi-driver/" ) +type MountCredentials struct { + AccessKeyID string + SecretAccessKey string + Region string + DefaultRegion string + WebTokenPath string + StsEndpoints string + AwsRoleArn string +} + +// Get environment variables to pass to mount-s3 for authentication. +func (mc *MountCredentials) Env() []string { + env := []string{} + + if mc.AccessKeyID != "" && mc.SecretAccessKey != "" { + env = append(env, keyIdEnv+"="+mc.AccessKeyID) + env = append(env, accessKeyEnv+"="+mc.SecretAccessKey) + } + if mc.WebTokenPath != "" { + env = append(env, webIdentityTokenEnv+"="+mc.WebTokenPath) + env = append(env, roleArnEnv+"="+mc.AwsRoleArn) + } + if mc.Region != "" { + env = append(env, regionEnv+"="+mc.Region) + } + if mc.DefaultRegion != "" { + env = append(env, defaultRegionEnv+"="+mc.DefaultRegion) + } + if mc.StsEndpoints != "" { + env = append(env, stsEndpointsEnv+"="+mc.StsEndpoints) + } + + return env +} + +type Fs interface { + Stat(name string) (os.FileInfo, error) + MkdirAll(path string, perm os.FileMode) error + Remove(name string) error +} + +type OsFs struct{} + +func (OsFs) Stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} + +func (OsFs) MkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(path, perm) +} + +func (OsFs) Remove(path string) error { + return os.Remove(path) +} + // Mounter is an interface for mount operations type Mounter interface { - mount.Interface - IsCorruptedMnt(err error) bool - PathExists(path string) (bool, error) - MakeDir(pathname string) error + Mount(bucketName string, target string, credentials *MountCredentials, options []string) error + Unmount(target string) error + IsMountPoint(target string) (bool, error) +} + +type ServiceRunner interface { + StartService(ctx context.Context, config *system.ExecConfig) (string, error) + RunOneshot(ctx context.Context, config *system.ExecConfig) (string, error) +} + +type MountLister interface { + ListMounts() ([]mount.MountPoint, error) +} + +type ProcMountLister struct { + ProcMountPath string } type S3Mounter struct { - mount.Interface - ctx context.Context - supervisor *system.SystemdSupervisor - mpVersion string - mountS3Path string + Ctx context.Context + Runner ServiceRunner + Fs Fs + MountLister MountLister + MpVersion string + MountS3Path string } func MountS3Path() string { @@ -72,74 +139,110 @@ func MountS3Path() string { func NewS3Mounter(mpVersion string) (*S3Mounter, error) { ctx := context.Background() - supervisor, err := system.StartOsSystemdSupervisor() + runner, err := system.StartOsSystemdSupervisor() if err != nil { return nil, fmt.Errorf("failed to start systemd supervisor: %w", err) } return &S3Mounter{ - Interface: mount.New(""), - ctx: ctx, - supervisor: supervisor, - mpVersion: mpVersion, - mountS3Path: MountS3Path(), + Ctx: ctx, + Runner: runner, + Fs: &OsFs{}, + MountLister: &ProcMountLister{ProcMountPath: procMounts}, + MpVersion: mpVersion, + MountS3Path: MountS3Path(), }, nil } -func (m *S3Mounter) MakeDir(pathname string) error { - err := os.MkdirAll(pathname, os.FileMode(0755)) - if err != nil { - if !os.IsExist(err) { - return err - } - } - return nil -} - -// IsCorruptedMnt return true if err is about corrupted mount point -func (m *S3Mounter) IsCorruptedMnt(err error) bool { - return mount.IsCorruptedMnt(err) -} - -func (m *S3Mounter) List() ([]mount.MountPoint, error) { - mounts, err := os.ReadFile(procMounts) +func (pml *ProcMountLister) ListMounts() ([]mount.MountPoint, error) { + mounts, err := os.ReadFile(pml.ProcMountPath) if err != nil { return nil, fmt.Errorf("Failed to read %s: %w", procMounts, err) } return parseProcMounts(mounts) } -func (m *S3Mounter) IsMountPoint(file string) (bool, error) { - mountPoints, err := m.List() +func (m *S3Mounter) IsMountPoint(target string) (bool, error) { + if _, err := m.Fs.Stat(target); os.IsNotExist(err) { + return false, err + } + + mountPoints, err := m.MountLister.ListMounts() if err != nil { - return false, fmt.Errorf("Failed to cat /proc/mounts: %w", err) + return false, fmt.Errorf("Failed to list mounts: %w", err) } for _, mp := range mountPoints { - if mp.Path == file { + if mp.Path == target { return true, nil } } return false, nil } -func (m *S3Mounter) PathExists(path string) (bool, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return false, nil - } else if err != nil { - return false, err - } - return true, nil -} +// Mount the given bucket at the target path. Options will be passed through mostly unchanged, +// with the exception of the user agent prefix which will be added to the Mountpoint headers. +// This method will create the target path if it does not exist and if there is an existing corrupt +// mount, it will attempt an unmount before attempting the mount. +func (m *S3Mounter) Mount(bucketName string, target string, + credentials *MountCredentials, options []string) error { -func (m *S3Mounter) Mount(source string, target string, _ string, options []string) error { - timeoutCtx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + if bucketName == "" { + return fmt.Errorf("bucket name is empty") + } + if target == "" { + return fmt.Errorf("target is empty") + } + timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() - env := passthroughEnv() - output, err := m.supervisor.StartService(timeoutCtx, &system.ExecConfig{ - Name: "mount-s3-" + m.mpVersion + "-" + uuid.New().String() + ".service", + cleanupDir := false + + // check if the target path exists + _, statErr := m.Fs.Stat(target) + if statErr != nil { + // does not exist, create the directory + if os.IsNotExist(statErr) { + if err := m.Fs.MkdirAll(target, 0755); err != nil { + return fmt.Errorf("Failed to create target directory: %w", err) + } + cleanupDir = true + defer func() { + if cleanupDir { + if err := m.Fs.Remove(target); err != nil { + klog.V(4).Infof("Mount: Failed to delete target dir: %s.", target) + } + } + }() + } + // Corrupted mount, try unmounting + if mount.IsCorruptedMnt(statErr) { + klog.V(4).Infof("Mount: Target path %q is a corrupted mount. Trying to unmount.", target) + if mntErr := m.Unmount(target); mntErr != nil { + return fmt.Errorf("Unable to unmount the target %q : %v, %v", target, statErr, mntErr) + } + } + } + + mounts, err := m.MountLister.ListMounts() + if err != nil { + return fmt.Errorf("Could not check if %q is a mount point: %v, %v", target, statErr, err) + } + for _, m := range mounts { + if m.Path == target { + klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) + return nil + } + } + + env := []string{} + if credentials != nil { + env = credentials.Env() + } + + output, err := m.Runner.StartService(timeoutCtx, &system.ExecConfig{ + Name: "mount-s3-" + m.MpVersion + "-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver FUSE daemon", - ExecPath: m.mountS3Path, - Args: append(addUserAgentToOptions(options), source, target), + ExecPath: m.MountS3Path, + Args: append(addUserAgentToOptions(options), bucketName, target), Env: env, }) @@ -149,6 +252,7 @@ func (m *S3Mounter) Mount(source string, target string, _ string, options []stri if output != "" { klog.V(5).Infof("mount-s3 output: %s", output) } + cleanupDir = false return nil } @@ -166,10 +270,10 @@ func addUserAgentToOptions(options []string) []string { } func (m *S3Mounter) Unmount(target string) error { - timeoutCtx, cancel := context.WithTimeout(m.ctx, 30*time.Second) + timeoutCtx, cancel := context.WithTimeout(m.Ctx, 30*time.Second) defer cancel() - output, err := m.supervisor.RunOneshot(timeoutCtx, &system.ExecConfig{ + output, err := m.Runner.RunOneshot(timeoutCtx, &system.ExecConfig{ Name: "mount-s3-umount-" + uuid.New().String() + ".service", Description: "Mountpoint for Amazon S3 CSI driver unmount", ExecPath: "/usr/bin/umount", @@ -184,42 +288,6 @@ func (m *S3Mounter) Unmount(target string) error { return nil } -func passthroughEnv() []string { - env := []string{} - - keyId := os.Getenv(keyIdEnv) - accessKey := os.Getenv(accessKeyEnv) - if keyId != "" && accessKey != "" { - env = append(env, keyIdEnv+"="+keyId) - env = append(env, accessKeyEnv+"="+accessKey) - } - webIdentityFile := os.Getenv(webIdentityTokenEnv) - awsRoleArn := os.Getenv(roleArnEnv) - hostTokenPath := os.Getenv(hostTokenPathEnv) - if hostTokenPath == "" { - // set the default in case the env variable isn't found - hostTokenPath = "/var/lib/kubelet/plugins/s3.csi.aws.com/token" - } - if webIdentityFile != "" { - env = append(env, webIdentityTokenEnv+"="+hostTokenPath) - env = append(env, roleArnEnv+"="+awsRoleArn) - } - region := os.Getenv(regionEnv) - if region != "" { - env = append(env, regionEnv+"="+region) - } - defaultRegion := os.Getenv(defaultRegionEnv) - if defaultRegion != "" { - env = append(env, defaultRegionEnv+"="+defaultRegion) - } - stsEndpoints := os.Getenv(stsEndpointsEnv) - if stsEndpoints != "" { - env = append(env, stsEndpointsEnv+"="+stsEndpoints) - } - - return env -} - func parseProcMounts(data []byte) ([]mount.MountPoint, error) { var mounts []mount.MountPoint diff --git a/pkg/driver/mount_test.go b/pkg/driver/mount_test.go index 0ccd604..640218f 100644 --- a/pkg/driver/mount_test.go +++ b/pkg/driver/mount_test.go @@ -1,31 +1,179 @@ -package driver +package driver_test import ( + "context" + "errors" + "strings" "testing" - "github.com/stretchr/testify/assert" + driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver" + mock_driver "github.com/awslabs/aws-s3-csi-driver/pkg/driver/mocks" + "github.com/awslabs/aws-s3-csi-driver/pkg/system" + "github.com/golang/mock/gomock" + "k8s.io/mount-utils" ) -func TestUserAgentPrefix(t *testing.T) { +type TestMountLister struct { + Mounts []mount.MountPoint + Err error +} + +func (l *TestMountLister) ListMounts() ([]mount.MountPoint, error) { + return l.Mounts, l.Err +} + +type mounterTestEnv struct { + ctx context.Context + mockCtl *gomock.Controller + mockRunner *mock_driver.MockServiceRunner + mockFs *mock_driver.MockFs + mockMountLister *mock_driver.MockMountLister + mounter *driver.S3Mounter +} + +func initMounterTestEnv(t *testing.T) *mounterTestEnv { + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockRunner := mock_driver.NewMockServiceRunner(mockCtl) + mockFs := mock_driver.NewMockFs(mockCtl) + mockMountLister := mock_driver.NewMockMountLister(mockCtl) + mountpointVersion := "TEST_MP_VERSION-v1.1" + + return &mounterTestEnv{ + ctx: ctx, + mockCtl: mockCtl, + mockRunner: mockRunner, + mockFs: mockFs, + mockMountLister: mockMountLister, + mounter: &driver.S3Mounter{ + Ctx: ctx, + Runner: mockRunner, + Fs: mockFs, + MountLister: mockMountLister, + MpVersion: mountpointVersion, + MountS3Path: driver.MountS3Path(), + }, + } +} + +func TestS3MounterMount(t *testing.T) { + testBucketName := "test-bucket" + testTargetPath := "/mnt/my-mountpoint/bucket/" + testCredentials := &driver.MountCredentials{ + AccessKeyID: "test-access-key", + SecretAccessKey: "test-secret-key", + Region: "test-region", + DefaultRegion: "test-region", + WebTokenPath: "test-web-token-path", + StsEndpoints: "test-sts-endpoint", + AwsRoleArn: "test-aws-role", + } + testCases := []struct { - name string - input []string - expected []string + name string + bucketName string + targetPath string + credentials *driver.MountCredentials + options []string + expectedErr bool + before func(*testing.T, *mounterTestEnv) }{ { - name: "success: add user agent prefix to mount call", - input: []string{"--read-only"}, - expected: []string{"--read-only", userAgentPrefix + "=" + csiDriverPrefix + GetVersion().DriverVersion}, + name: "success: mounts without empty options", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: testCredentials, + options: []string{}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) + }, + }, + { + name: "success: mounts with nil credentials", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: nil, + options: []string{}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) + }, + }, + { + name: "success: replaces user agent prefix", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: nil, + options: []string{"--user-agent-prefix=mycustomuseragent"}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, config *system.ExecConfig) (string, error) { + for _, a := range config.Args { + if strings.Contains(a, "mycustomuseragent") { + t.Fatal("Bad user agent") + } + } + return "success", nil + }) + }, + }, + { + name: "failure: fails on mount failure", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: nil, + options: []string{}, + expectedErr: true, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("fail", errors.New("test failure")) + }, + }, + { + name: "failure: won't mount empty bucket name", + targetPath: testTargetPath, + credentials: testCredentials, + options: []string{}, + expectedErr: true, + }, + { + name: "failure: won't mount empty target", + bucketName: testBucketName, + credentials: testCredentials, + options: []string{}, + expectedErr: true, }, { - name: "success: replacing customer user agent prefix", - input: []string{"--read-only", "--user-agent-prefix testing"}, - expected: []string{"--read-only", userAgentPrefix + "=" + csiDriverPrefix + GetVersion().DriverVersion}, + name: "failure: mounts without empty options", + bucketName: testBucketName, + targetPath: testTargetPath, + credentials: testCredentials, + options: []string{}, + before: func(t *testing.T, env *mounterTestEnv) { + env.mockFs.EXPECT().Stat(gomock.Any()).Return(nil, nil) + env.mockMountLister.EXPECT().ListMounts().Return(nil, nil) + env.mockRunner.EXPECT().StartService(gomock.Any(), gomock.Any()).Return("success", nil) + }, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, addUserAgentToOptions(tc.input), tc.expected) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + env := initMounterTestEnv(t) + if testCase.before != nil { + testCase.before(t, env) + } + err := env.mounter.Mount(testCase.bucketName, testCase.targetPath, + testCase.credentials, testCase.options) + env.mockCtl.Finish() + if err != nil && !testCase.expectedErr { + t.Fatal(err) + } }) } } diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 3cf6342..724a7cb 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -19,34 +19,50 @@ package driver import ( "context" "os" + "path" "strings" + "time" "github.com/container-storage-interface/spec/lib/go/csi" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" + "k8s.io/mount-utils" ) const ( - fstype = "unused" - bucketName = "bucketName" + hostTokenDirEnv = "HOST_TOKEN_DIR" + fstype = "unused" + bucketName = "bucketName" ) var ( nodeCaps = []csi.NodeServiceCapability_RPC_Type{} ) -func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { +// S3NodeServer is the implementation of the csi.NodeServer interface +type S3NodeServer struct { + NodeID string + BaseCredentials *MountCredentials + Mounter Mounter +} + +type Token struct { + Token string `json:"token"` + ExpirationTimestamp time.Time `json:"expirationTimestamp"` +} + +func (ns *S3NodeServer) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { +func (ns *S3NodeServer) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { - klog.V(4).Infof("NodePublishVolume: called with args %+v", req) +func (ns *S3NodeServer) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { + klog.V(4).Infof("NodePublishVolume: req: %+v", req) volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -68,21 +84,15 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu return nil, status.Error(codes.InvalidArgument, "Volume capability not provided") } - if !d.isValidVolumeCapabilities([]*csi.VolumeCapability{volCap}) { + if !ns.isValidVolumeCapabilities([]*csi.VolumeCapability{volCap}) { return nil, status.Error(codes.InvalidArgument, "Volume capability not supported") } mountpointArgs := []string{} - if req.GetReadonly() || volCap.GetAccessMode().GetMode() == csi.VolumeCapability_AccessMode_MULTI_NODE_READER_ONLY { mountpointArgs = append(mountpointArgs, "--read-only") } - klog.V(4).Infof("NodePublishVolume: creating dir %s", target) - if err := d.Mounter.MakeDir(target); err != nil { - return nil, status.Errorf(codes.Internal, "Could not create dir %q: %v", target, err) - } - // get the mount(point) options (yaml list) if capMount := volCap.GetMount(); capMount != nil { mountFlags := capMount.GetMountFlags() @@ -99,19 +109,27 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu mountpointArgs = compileMountOptions(mountpointArgs, mountFlags) } - //Checking if the target directory is already mounted with a volume. - mounted, err := d.isMounted(target) - if err != nil { - return nil, status.Errorf(codes.Internal, "Could not check if %q is mounted: %v", target, err) + klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) + hostTokenDir := os.Getenv(hostTokenDirEnv) + if hostTokenDir == "" { + // set the default in case the env variable isn't found + hostTokenDir = "/var/lib/kubelet/plugins/s3.csi.aws.com/" } - if !mounted { - klog.V(4).Infof("NodePublishVolume: mounting %s at %s with options %v", bucket, target, mountpointArgs) - if err := d.Mounter.Mount(bucket, target, fstype, mountpointArgs); err != nil { - os.Remove(target) - return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) - } - klog.V(4).Infof("NodePublishVolume: %s was mounted", target) + hostTokenPath := path.Join(hostTokenDir, "token") + credentials := &MountCredentials{ + AccessKeyID: os.Getenv(keyIdEnv), + SecretAccessKey: os.Getenv(accessKeyEnv), + Region: os.Getenv(regionEnv), + DefaultRegion: os.Getenv(defaultRegionEnv), + WebTokenPath: hostTokenPath, + StsEndpoints: os.Getenv(stsEndpointsEnv), + AwsRoleArn: os.Getenv(roleArnEnv), + } + if err := ns.Mounter.Mount(bucket, target, credentials, mountpointArgs); err != nil { + os.Remove(target) + return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", bucket, target, err) } + klog.V(4).Infof("NodePublishVolume: %s was mounted", target) return &csi.NodePublishVolumeResponse{}, nil } @@ -140,7 +158,7 @@ func compileMountOptions(currentOptions []string, newOptions []string) []string return allMountOptions.List() } -func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { +func (ns *S3NodeServer) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { klog.V(4).Infof("NodeUnpublishVolume: called with args %+v", req) volumeID := req.GetVolumeId() @@ -152,11 +170,11 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish return nil, status.Error(codes.InvalidArgument, "Target path not provided") } - mounted, err := d.Mounter.IsMountPoint(target) + mounted, err := ns.Mounter.IsMountPoint(target) if err != nil && os.IsNotExist(err) { klog.V(4).Infof("NodeUnpublishVolume: target path %s does not exist, skipping unmount", target) return &csi.NodeUnpublishVolumeResponse{}, nil - } else if err != nil && d.Mounter.IsCorruptedMnt(err) { + } else if err != nil && mount.IsCorruptedMnt(err) { klog.V(4).Infof("NodeUnpublishVolume: target path %s is corrupted: %v, will try to unmount", target, err) mounted = true } else if err != nil { @@ -168,7 +186,7 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish } klog.V(4).Infof("NodeUnpublishVolume: unmounting %s", target) - err = d.Mounter.Unmount(target) + err = ns.Mounter.Unmount(target) if err != nil { return nil, status.Errorf(codes.Internal, "Could not unmount %q: %v", target, err) } @@ -176,15 +194,15 @@ func (d *Driver) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpublish return &csi.NodeUnpublishVolumeResponse{}, nil } -func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { +func (ns *S3NodeServer) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeStatsRequest) (*csi.NodeGetVolumeStatsResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { +func (ns *S3NodeServer) NodeExpandVolume(ctx context.Context, req *csi.NodeExpandVolumeRequest) (*csi.NodeExpandVolumeResponse, error) { return nil, status.Error(codes.Unimplemented, "") } -func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { +func (ns *S3NodeServer) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { klog.V(4).Infof("NodeGetCapabilities: called with args %+v", req) var caps []*csi.NodeServiceCapability for _, cap := range nodeCaps { @@ -200,44 +218,15 @@ func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabi return &csi.NodeGetCapabilitiesResponse{Capabilities: caps}, nil } -func (d *Driver) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { +func (ns *S3NodeServer) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { klog.V(4).Infof("NodeGetInfo: called with args %+v", req) return &csi.NodeGetInfoResponse{ - NodeId: d.NodeID, + NodeId: ns.NodeID, }, nil } -// isMounted checks if target is a valid mountpoint -// inexistent target directory is NOT an error -// method will try to unmount the directory if it was detected to be corrupted -func (d *Driver) isMounted(target string) (bool, error) { - notMnt, err := d.Mounter.IsLikelyNotMountPoint(target) - if err != nil && !os.IsNotExist(err) { - _, pathErr := d.Mounter.PathExists(target) - if pathErr != nil && d.Mounter.IsCorruptedMnt(pathErr) { - klog.V(4).Infof("NodePublishVolume: Target path %q is a corrupted mount. Trying to unmount.", target) - if mntErr := d.Mounter.Unmount(target); mntErr != nil { - return false, status.Errorf(codes.Internal, "Unable to unmount the target %q : %v", target, mntErr) - } - return false, nil - } - return false, status.Errorf(codes.Internal, "Could not check if %q is a mount point: %v, %v", target, err, pathErr) - } - - if err != nil && os.IsNotExist(err) { - klog.V(5).Infof("[Debug] NodePublishVolume: Target path %q does not exist", target) - return false, nil - } - - if !notMnt { - klog.V(4).Infof("NodePublishVolume: Target path %q is already mounted", target) - } - - return !notMnt, nil -} - -func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool { +func (ns *S3NodeServer) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool { hasSupport := func(cap *csi.VolumeCapability) bool { for _, c := range volumeCaps { if c.GetMode() == cap.AccessMode.GetMode() { diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index e02a482..7cf3832 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -15,22 +15,22 @@ import ( type nodeServerTestEnv struct { mockCtl *gomock.Controller mockMounter *mock_driver.MockMounter - driver *driver.Driver + server *driver.S3NodeServer } func initNodeServerTestEnv(t *testing.T) *nodeServerTestEnv { mockCtl := gomock.NewController(t) defer mockCtl.Finish() mockMounter := mock_driver.NewMockMounter(mockCtl) - driver := &driver.Driver{ - Endpoint: "unix://tmp/csi.sock", - NodeID: "test-nodeID", - Mounter: mockMounter, + server := &driver.S3NodeServer{ + NodeID: "test-nodeID", + BaseCredentials: &driver.MountCredentials{}, + Mounter: mockMounter, } return &nodeServerTestEnv{ mockCtl: mockCtl, mockMounter: mockMounter, - driver: driver, + server: server, } } @@ -64,10 +64,8 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(false, nil) - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), gomock.Any()) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Any()) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -94,10 +92,8 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(false, nil) - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), gomock.Eq([]string{"--read-only"})) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only"})) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -127,10 +123,8 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(true, nil) - nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + nodeTestEnv.mockMounter.EXPECT().Mount(gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--bar", "--foo", "--read-only", "--test=123"})) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -160,12 +154,10 @@ func TestNodePublishVolume(t *testing.T) { Readonly: true, } - nodeTestEnv.mockMounter.EXPECT().MakeDir(gomock.Eq(targetPath)).Return(nil) - nodeTestEnv.mockMounter.EXPECT().IsLikelyNotMountPoint(gomock.Eq(targetPath)).Return(true, nil) nodeTestEnv.mockMounter.EXPECT().Mount( - gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Eq("unused"), + gomock.Eq(bucketName), gomock.Eq(targetPath), gomock.Any(), gomock.Eq([]string{"--read-only", "--test=123"})).Return(nil) - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -184,7 +176,7 @@ func TestNodePublishVolume(t *testing.T) { VolumeContext: map[string]string{"bucketName": bucketName}, } - _, err := nodeTestEnv.driver.NodePublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodePublishVolume(ctx, req) if err == nil { t.Fatalf("NodePublishVolume is failed: %v", err) } @@ -219,29 +211,7 @@ func TestNodeUnpublishVolume(t *testing.T) { nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(true, nil) nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(nil) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) - if err != nil { - t.Fatalf("NodePublishVolume failed: %v", err) - } - - nodeTestEnv.mockCtl.Finish() - }, - }, - { - name: "success: corrupted volume", - testFunc: func(t *testing.T) { - nodeTestEnv := initNodeServerTestEnv(t) - ctx := context.Background() - req := &csi.NodeUnpublishVolumeRequest{ - VolumeId: volumeId, - TargetPath: targetPath, - } - - expectedErr := errors.New("") - nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(false, expectedErr) - nodeTestEnv.mockMounter.EXPECT().IsCorruptedMnt(expectedErr).Return(true) - nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(nil) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) } @@ -260,7 +230,7 @@ func TestNodeUnpublishVolume(t *testing.T) { } nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(false, nil) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) } @@ -280,7 +250,7 @@ func TestNodeUnpublishVolume(t *testing.T) { nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(true, nil) nodeTestEnv.mockMounter.EXPECT().Unmount(gomock.Eq(targetPath)).Return(errors.New("")) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err == nil { t.Fatalf("NodePublishVolume must fail") } @@ -300,7 +270,7 @@ func TestNodeUnpublishVolume(t *testing.T) { expectedError := fs.ErrNotExist nodeTestEnv.mockMounter.EXPECT().IsMountPoint(gomock.Eq(targetPath)).Return(false, expectedError) - _, err := nodeTestEnv.driver.NodeUnpublishVolume(ctx, req) + _, err := nodeTestEnv.server.NodeUnpublishVolume(ctx, req) if err != nil { t.Fatalf("NodePublishVolume failed: %v", err) } @@ -314,3 +284,21 @@ func TestNodeUnpublishVolume(t *testing.T) { t.Run(tc.name, tc.testFunc) } } + +func TestNodeGetCapabilities(t *testing.T) { + nodeTestEnv := initNodeServerTestEnv(t) + ctx := context.Background() + req := &csi.NodeGetCapabilitiesRequest{} + + resp, err := nodeTestEnv.server.NodeGetCapabilities(ctx, req) + if err != nil { + t.Fatalf("NodeGetCapabilities failed: %v", err) + } + + capabilities := resp.GetCapabilities() + if len(capabilities) != 0 { + t.Fatalf("NodeGetCapabilities failed: capabilities not empty") + } + + nodeTestEnv.mockCtl.Finish() +} diff --git a/pkg/util/util.go b/pkg/driver/server.go similarity index 55% rename from pkg/util/util.go rename to pkg/driver/server.go index ab71adc..3886a9b 100644 --- a/pkg/util/util.go +++ b/pkg/driver/server.go @@ -1,20 +1,4 @@ -/* -Copyright 2022 The Kubernetes Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package util +package driver import ( "fmt" diff --git a/pkg/system/systemd_test.go b/pkg/system/systemd_test.go index 00605c2..ed39b94 100644 --- a/pkg/system/systemd_test.go +++ b/pkg/system/systemd_test.go @@ -165,11 +165,9 @@ func TestSystemdConnection(t *testing.T) { } }() }) - _, err := conn.ListUnits(ctx) - if err == nil { + if _, err := conn.ListUnits(ctx); err == nil { t.Fatalf("Expected error, got nil") } - }, }, } diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index acd6580..94b5618 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -28,7 +28,6 @@ import ( sanity "github.com/kubernetes-csi/csi-test/pkg/sanity" "github.com/awslabs/aws-s3-csi-driver/pkg/driver" - "github.com/awslabs/aws-s3-csi-driver/pkg/util" ) const ( @@ -66,7 +65,15 @@ func waitDriverIsUp(endpoint string) { } var _ = BeforeSuite(func() { - s3Driver = driver.NewFakeDriver(endpoint) + s3Driver = &driver.Driver{ + Endpoint: endpoint, + NodeID: "fake_id", + NodeServer: &driver.S3NodeServer{ + NodeID: "fake_id", + BaseCredentials: &driver.MountCredentials{}, + Mounter: &driver.FakeMounter{}, + }, + } go func() { Expect(s3Driver.Run()).NotTo(HaveOccurred()) }() @@ -84,7 +91,7 @@ var _ = Describe("Amazon S3 CSI Driver", func() { Address: endpoint, TargetPath: mountPath, StagingPath: stagePath, - TestVolumeSize: 2000 * util.GiB, + TestVolumeSize: 2000 * driver.GiB, IDGen: &sanity.DefaultIDGenerator{}, } sanity.GinkgoTest(config)