forked from drakkan/sftpgo
/
subsystem.go
87 lines (74 loc) · 2.52 KB
/
subsystem.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package sftpd
import (
"io"
"net"
"github.com/pkg/sftp"
"github.com/aKardasz/sftpgo/v2/internal/common"
"github.com/aKardasz/sftpgo/v2/internal/dataprovider"
"github.com/aKardasz/sftpgo/v2/internal/logger"
)
// subsystemChannel joins an independent reader and writer into the single
// io.ReadWriteCloser the SFTP request server is built on. NOTE(review): the
// reader/writer are presumably the subsystem's standard streams — confirm
// with callers.
type subsystemChannel struct {
	reader io.Reader // incoming bytes are read from here
	writer io.Writer // outgoing bytes are written here
}

// Compile-time check that *subsystemChannel satisfies io.ReadWriteCloser.
var _ io.ReadWriteCloser = (*subsystemChannel)(nil)

// newSubsystemChannel returns a channel that reads from r and writes to w.
func newSubsystemChannel(r io.Reader, w io.Writer) *subsystemChannel {
	return &subsystemChannel{reader: r, writer: w}
}

// Read delegates directly to the wrapped reader.
func (c *subsystemChannel) Read(buf []byte) (int, error) {
	return c.reader.Read(buf)
}

// Write delegates directly to the wrapped writer.
func (c *subsystemChannel) Write(buf []byte) (int, error) {
	return c.writer.Write(buf)
}

// Close always succeeds and leaves the wrapped reader and writer open;
// presumably their lifetime is managed by whoever supplied them.
func (c *subsystemChannel) Close() error {
	return nil
}
// ServeSubSystemConnection handles a connection as SSH subsystem.
//
// It serves the SFTP protocol for user over the supplied reader/writer pair
// and blocks until the SFTP request server stops, returning the error from
// its Serve method.
//
// Before serving, the user's filesystem root is checked and the connection is
// registered in common.Connections. If either step fails, the user's Fs is
// closed, a warning is logged, and the original error is returned. The
// connection is unregistered once serving ends.
func ServeSubSystemConnection(user *dataprovider.User, connectionID string, reader io.Reader, writer io.Writer) error {
	if err := user.CheckFsRoot(connectionID); err != nil {
		closeErr := user.CloseFs()
		logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, closeErr)
		return err
	}

	// A bare reader/writer pair carries no network endpoints, so the
	// addresses are zero-valued placeholders.
	conn := &Connection{
		BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolSFTP, "", "", *user),
		ClientVersion:  "",
		RemoteAddr:     &net.IPAddr{},
		LocalAddr:      &net.IPAddr{},
		channel:        newSubsystemChannel(reader, writer),
	}
	if err := common.Connections.Add(conn); err != nil {
		closeErr := user.CloseFs()
		logger.Warn(logSender, connectionID, "unable to add connection: %v close fs error: %v", err, closeErr)
		return err
	}
	// NOTE(review): on the success path the Fs is not closed in this function;
	// presumably Connections.Remove releases it — confirm.
	defer common.Connections.Remove(conn.GetID())

	dataprovider.UpdateLastLogin(user)
	sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck

	// The same Connection value serves as every request handler.
	handlers := sftp.Handlers{
		FileGet:  conn,
		FilePut:  conn,
		FileCmd:  conn,
		FileList: conn,
	}
	rs := sftp.NewRequestServer(conn.channel, handlers, sftp.WithRSAllocator())
	// Deferred after Remove, so the server is closed before the connection is
	// unregistered.
	defer rs.Close()

	return rs.Serve()
}