/
packet.go
143 lines (117 loc) · 3.66 KB
/
packet.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// SPDX-FileCopyrightText: 2020 SAP SE
// SPDX-FileCopyrightText: 2021 SAP SE
// SPDX-FileCopyrightText: 2022 SAP SE
// SPDX-FileCopyrightText: 2023 SAP SE
//
// SPDX-License-Identifier: Apache-2.0
package tds
import (
"context"
"errors"
"fmt"
"io"
"time"
)
// ErrEOFAfterZeroRead is returned by (*Packet).ReadFrom when a body read
// yields io.EOF without any data after the inactivity timeout has elapsed.
var ErrEOFAfterZeroRead = errors.New("received io.EOF after reading 0 bytes")
// Packet represents a single packet in a message: a fixed-size header
// followed by a variable-length body.
type Packet struct {
	// Header carries the packet metadata; its Length field is the total
	// serialized size of the packet, header included.
	Header PacketHeader
	// Data holds the packet body, i.e. everything after the header.
	Data []byte
}
// NewPacket allocates a Packet for a packet of packetSize bytes in total.
//
// The header is built by NewPacketHeader(packetSize), and Data is a zeroed
// buffer of packetSize-PacketHeaderSize bytes, i.e. the room left for the
// body. packetSize must be at least PacketHeaderSize; otherwise the buffer
// length is negative and make panics.
func NewPacket(packetSize int) *Packet {
	return &Packet{
		Header: NewPacketHeader(packetSize),
		Data:   make([]byte, packetSize-PacketHeaderSize),
	}
}
// Bytes serializes the packet into a newly allocated slice whose length is
// packet.Header.Length: the header occupies the first PacketHeaderSize
// bytes, and Data is copied after it.
//
// Only as much of Data as fits within the length announced by the header is
// copied. Surplus bytes in Data are omitted, and if Data is shorter the tail
// stays zero.
//
// An error is returned if the header's Length is smaller than
// PacketHeaderSize, which cannot describe a valid packet and previously
// caused a slice-bounds panic. An error is also returned if writing the
// header into the slice fails.
func (packet Packet) Bytes() ([]byte, error) {
	if packet.Header.Length < PacketHeaderSize {
		return nil, fmt.Errorf("invalid packet length %d: smaller than header size %d",
			packet.Header.Length, PacketHeaderSize)
	}

	bs := make([]byte, int(packet.Header.Length))
	if _, err := packet.Header.Read(bs[:PacketHeaderSize]); err != nil {
		return nil, fmt.Errorf("error reading header into byte slice: %w", err)
	}
	copy(bs[PacketHeaderSize:], packet.Data)
	return bs, nil
}
// ReadFrom reads one packet (header and body) from reader into packet and
// returns the total number of bytes read, header included.
//
// The header is read first; its Length field determines the total packet
// size, and Data is reallocated to hold Length-PacketHeaderSize body bytes.
// The body is then read in a loop until Length bytes have been consumed.
//
// timeout is an inactivity timeout, not a deadline for the whole packet: it
// is restarted after every read that returns data. A read that returns
// io.EOF with no data only fails once the current timeout has elapsed.
//
// Returned errors:
//   - header read failures, wrapped as "failed to read header: ...";
//   - a header whose Length is smaller than PacketHeaderSize;
//   - ctx.Err() if ctx is done before a body read is attempted;
//   - ErrEOFAfterZeroRead if a read returns io.EOF with no data after the
//     timeout elapsed;
//   - a bare io.EOF if a complete TDS_BUF_CLOSE packet is followed by io.EOF;
//   - any other read error, wrapped as "error reading body: ...".
func (packet *Packet) ReadFrom(ctx context.Context, reader io.Reader, timeout time.Duration) (int64, error) {
	var totalBytes int64

	packet.Header = PacketHeader{}
	n, err := packet.Header.ReadFrom(reader)
	if err != nil {
		return n, fmt.Errorf("failed to read header: %w", err)
	}
	totalBytes += n

	// A Length that cannot even hold the header would make the subtraction
	// below wrap around (unsigned Length) or go negative (signed Length),
	// resulting in a huge allocation or a panic in make.
	if packet.Header.Length < PacketHeaderSize {
		return totalBytes, fmt.Errorf("invalid packet length %d: smaller than header size %d",
			packet.Header.Length, PacketHeaderSize)
	}
	packet.Data = make([]byte, packet.Header.Length-PacketHeaderSize)

	// The timeout is refreshed (replaced) on every successful read. This
	// way it only triggers if no data at all arrived from the server in
	// that window. That prevents failures when the PDU is split over
	// multiple responses that arrive slowly due to the network (e.g.
	// erroneous scheduling, packet inspection, overloaded firewalls).
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	// cancel is reassigned whenever the timeout is refreshed. The closure
	// releases whichever timeout context is current when the function returns.
	defer func() { cancel() }()

	for {
		if err := ctx.Err(); err != nil {
			return totalBytes, err
		}

		// totalBytes-n is the number of body bytes read so far.
		m, err := reader.Read(packet.Data[totalBytes-n:])
		totalBytes += int64(m)
		if m > 0 {
			// Release the previous timer before replacing it. Deferring
			// each cancel here would accumulate one deferred call (and one
			// live timer) per read until the function returns.
			cancel()
			timeoutCtx, cancel = context.WithTimeout(ctx, timeout)
		}

		if err != nil {
			if errors.Is(err, io.EOF) {
				// Fail only if the timeout was exceeded _and_ the last
				// read returned 0 bytes. The timeout may be exceeded, but
				// it isn't triggered until the reader also has no more data.
				if timeoutErr := timeoutCtx.Err(); timeoutErr != nil && m == 0 {
					return totalBytes, ErrEOFAfterZeroRead
				}
				// The PDU is split over multiple responses; keep reading.
				// NOTE(review): this retries immediately without waiting,
				// so it spins until data arrives or the timeout elapses.
				if totalBytes != int64(packet.Header.Length) {
					continue
				}
				if packet.Header.MsgType == TDS_BUF_CLOSE {
					return totalBytes, err
				}
			}
			return totalBytes, fmt.Errorf("error reading body: %w", err)
		}

		if totalBytes == int64(packet.Header.Length) {
			// Read the expected amount of bytes.
			break
		}
	}

	return totalBytes, nil
}
// WriteTo serializes the packet via Bytes and writes the result to writer
// in a single Write call, returning the number of bytes written.
//
// If serialization fails, nothing is written and the error is wrapped.
// Errors from writer.Write are returned unwrapped.
func (packet Packet) WriteTo(writer io.Writer) (int64, error) {
	raw, bytesErr := packet.Bytes()
	if bytesErr != nil {
		return 0, fmt.Errorf("error compiling packet bytes: %w", bytesErr)
	}

	var written int
	var writeErr error
	written, writeErr = writer.Write(raw)
	return int64(written), writeErr
}
// String returns a one-line, human-readable summary of the packet. It lists
// the header fields (type, status, length, channel, packet number, window)
// and the length of Data.
//
// The status is rendered through deBitmaskString with "no status" as the
// fallback text. Presumably this yields one name per set bit up to
// TDS_BUFSTAT_SYMENCRYPT — NOTE(review): confirm against deBitmaskString.
func (packet Packet) String() string {
	hdr := packet.Header

	status := deBitmaskString(
		int(hdr.Status),
		int(TDS_BUFSTAT_SYMENCRYPT),
		func(bit int) string { return PacketHeaderStatus(bit).String() },
		"no status",
	)

	const layout = "Type: %s, Status: %s, Length: %d, Channel: %d, PacketNr: %d, Window: %d, DataLen: %d"
	return fmt.Sprintf(layout,
		hdr.MsgType,
		status,
		hdr.Length,
		hdr.Channel,
		hdr.PacketNr,
		hdr.Window,
		len(packet.Data),
	)
}