-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f38bf5a
commit c8f7288
Showing
4 changed files
with
370 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
package client | ||
|
||
import ( | ||
"log" | ||
"net" | ||
"time" | ||
|
||
"github.com/brianlusina/tclient/pkg/bitfield" | ||
"github.com/brianlusina/tclient/pkg/message" | ||
"github.com/brianlusina/tclient/pkg/peers" | ||
) | ||
|
||
// Client is a TCP connection with a Peer | ||
type Client struct { | ||
Conn net.Conn | ||
Choked bool | ||
Bitfield bitfield.Bitfield | ||
peer peers.Peer | ||
infoHash [20]byte | ||
peerID [20]byte | ||
} | ||
|
||
// New connects with a peer, completes a handshake, and receives a handshake returns an error if any of those fail | ||
func New(peer peers.Peer, peerID, infoHash [20]byte) (*Client, error) { | ||
conn, err := net.DialTimeout("tcp", peer.String(), 3*time.Second) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
_, err = completeHandshake(conn, infoHash, peerID) | ||
if err != nil { | ||
if e := conn.Close(); e != nil { | ||
log.Fatalf("Failed to close connection %s", e) | ||
} | ||
return nil, err | ||
} | ||
|
||
bf, err := recvBitfield(conn) | ||
if err != nil { | ||
if e := conn.Close(); e != nil { | ||
log.Fatalf("Failed to close connection %s", e) | ||
} | ||
return nil, err | ||
} | ||
|
||
return &Client{ | ||
Conn: conn, | ||
Choked: true, | ||
Bitfield: bf, | ||
peer: peer, | ||
infoHash: infoHash, | ||
peerID: peerID, | ||
}, nil | ||
} | ||
|
||
// Read reads and consumes a message from the connection | ||
func (c *Client) Read() (*message.Message, error) { | ||
msg, err := message.Read(c.Conn) | ||
return msg, err | ||
} | ||
|
||
// SendRequest sends a request message to the peer | ||
func (c *Client) SendRequest(index, begin, length int) error { | ||
req := message.FormatRequest(index, begin, length) | ||
_, err := c.Conn.Write(req.Serialize()) | ||
return err | ||
} | ||
|
||
// SendInterested sends an Interested message to the peer | ||
func (c *Client) SendInterested() error { | ||
msg := message.Message{ID: message.MsgInterested} | ||
_, err := c.Conn.Write(msg.Serialize()) | ||
return err | ||
} | ||
|
||
// SendNotInterested sends a NotInterested message to the peer | ||
func (c *Client) SendNotInterested() error { | ||
msg := message.Message{ID: message.MsgNotInterested} | ||
_, err := c.Conn.Write(msg.Serialize()) | ||
return err | ||
} | ||
|
||
// SendUnchoke sends an Unchoke message to the peer | ||
func (c *Client) SendUnchoke() error { | ||
msg := message.Message{ID: message.MsgUnchoke} | ||
_, err := c.Conn.Write(msg.Serialize()) | ||
return err | ||
} | ||
|
||
// SendHave sends a Have message to the peer | ||
func (c *Client) SendHave(index int) error { | ||
msg := message.FormatHave(index) | ||
_, err := c.Conn.Write(msg.Serialize()) | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package client | ||
|
||
import ( | ||
"net" | ||
"testing" | ||
|
||
"github.com/brianlusina/tclient/pkg/message" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func createClientAndServer(t *testing.T) (clientConn, serverConn net.Conn) { | ||
ln, err := net.Listen("tcp", "127.0.0.1:0") | ||
require.Nil(t, err) | ||
|
||
// net.Dial does not block, so we need this signalling channel to make sure | ||
// we don't return before serverConn is ready | ||
done := make(chan struct{}) | ||
go func() { | ||
defer ln.Close() | ||
serverConn, err = ln.Accept() | ||
require.Nil(t, err) | ||
done <- struct{}{} | ||
}() | ||
|
||
clientConn, err = net.Dial("tcp", ln.Addr().String()) | ||
<-done | ||
|
||
return clientConn, serverConn | ||
} | ||
|
||
func TestRead(t *testing.T) { | ||
clientConn, serverConn := createClientAndServer(t) | ||
client := Client{Conn: clientConn} | ||
|
||
msgBytes := []byte{ | ||
0x00, 0x00, 0x00, 0x05, | ||
4, | ||
0x00, 0x00, 0x05, 0x3c, | ||
} | ||
expected := &message.Message{ | ||
ID: message.MsgHave, | ||
Payload: []byte{0x00, 0x00, 0x05, 0x3c}, | ||
} | ||
_, err := serverConn.Write(msgBytes) | ||
require.Nil(t, err) | ||
|
||
msg, err := client.Read() | ||
assert.Equal(t, expected, msg) | ||
} | ||
|
||
func TestSendRequest(t *testing.T) { | ||
clientConn, serverConn := createClientAndServer(t) | ||
client := Client{Conn: clientConn} | ||
err := client.SendRequest(1, 2, 3) | ||
assert.Nil(t, err) | ||
expected := []byte{ | ||
0x00, 0x00, 0x00, 0x0d, | ||
6, | ||
0x00, 0x00, 0x00, 0x01, | ||
0x00, 0x00, 0x00, 0x02, | ||
0x00, 0x00, 0x00, 0x03, | ||
} | ||
buf := make([]byte, len(expected)) | ||
_, err = serverConn.Read(buf) | ||
assert.Nil(t, err) | ||
assert.Equal(t, expected, buf) | ||
} | ||
|
||
func TestSendInterested(t *testing.T) { | ||
clientConn, serverConn := createClientAndServer(t) | ||
client := Client{Conn: clientConn} | ||
err := client.SendInterested() | ||
assert.Nil(t, err) | ||
expected := []byte{ | ||
0x00, 0x00, 0x00, 0x01, | ||
2, | ||
} | ||
buf := make([]byte, len(expected)) | ||
_, err = serverConn.Read(buf) | ||
assert.Nil(t, err) | ||
assert.Equal(t, expected, buf) | ||
} | ||
|
||
func TestSendNotInterested(t *testing.T) { | ||
clientConn, serverConn := createClientAndServer(t) | ||
client := Client{Conn: clientConn} | ||
err := client.SendNotInterested() | ||
assert.Nil(t, err) | ||
expected := []byte{ | ||
0x00, 0x00, 0x00, 0x01, | ||
3, | ||
} | ||
buf := make([]byte, len(expected)) | ||
_, err = serverConn.Read(buf) | ||
assert.Nil(t, err) | ||
assert.Equal(t, expected, buf) | ||
} | ||
|
||
func TestSendUnchoke(t *testing.T) { | ||
clientConn, serverConn := createClientAndServer(t) | ||
client := Client{Conn: clientConn} | ||
err := client.SendUnchoke() | ||
assert.Nil(t, err) | ||
expected := []byte{ | ||
0x00, 0x00, 0x00, 0x01, | ||
1, | ||
} | ||
buf := make([]byte, len(expected)) | ||
_, err = serverConn.Read(buf) | ||
assert.Nil(t, err) | ||
assert.Equal(t, expected, buf) | ||
} | ||
|
||
func TestSendHave(t *testing.T) { | ||
clientConn, serverConn := createClientAndServer(t) | ||
client := Client{Conn: clientConn} | ||
err := client.SendHave(1340) | ||
assert.Nil(t, err) | ||
expected := []byte{ | ||
0x00, 0x00, 0x00, 0x05, | ||
4, | ||
0x00, 0x00, 0x05, 0x3c, | ||
} | ||
buf := make([]byte, len(expected)) | ||
_, err = serverConn.Read(buf) | ||
assert.Nil(t, err) | ||
assert.Equal(t, expected, buf) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package client | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"net" | ||
"time" | ||
|
||
"github.com/brianlusina/tclient/pkg/bitfield" | ||
"github.com/brianlusina/tclient/pkg/handshake" | ||
"github.com/brianlusina/tclient/pkg/message" | ||
) | ||
|
||
func recvBitfield(conn net.Conn) (bitfield.Bitfield, error) { | ||
conn.SetDeadline(time.Now().Add(5 * time.Second)) | ||
defer conn.SetDeadline(time.Time{}) // disable the deadline | ||
|
||
msg, err := message.Read(conn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if msg == nil { | ||
err := fmt.Errorf("Expected bitfield but got %s", msg) | ||
return nil, err | ||
} | ||
|
||
if msg.ID != message.MsgBitfield { | ||
err := fmt.Errorf("Expected bitfield but got id %d", msg.ID) | ||
return nil, err | ||
} | ||
|
||
return msg.Payload, nil | ||
} | ||
|
||
func completeHandshake(conn net.Conn, infohash, peerID [20]byte) (*handshake.Handshake, error) { | ||
conn.SetDeadline(time.Now().Add(3 * time.Second)) | ||
defer conn.SetDeadline(time.Time{}) // disable the deadline | ||
|
||
req := handshake.New(infohash, peerID) | ||
_, err := conn.Write(req.Serialize()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
res, err := handshake.Read(conn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if !bytes.Equal(res.InfoHash[:], infohash[:]) { | ||
return nil, fmt.Errorf("Expected infohash %x but gor %x", res.InfoHash, infohash) | ||
} | ||
|
||
return res, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package client | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/brianlusina/tclient/pkg/bitfield" | ||
"github.com/brianlusina/tclient/pkg/handshake" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestRecvBitfield(t *testing.T) { | ||
tests := map[string]struct { | ||
msg []byte | ||
output bitfield.Bitfield | ||
fails bool | ||
}{ | ||
"successful bitfield": { | ||
msg: []byte{0x00, 0x00, 0x00, 0x06, 5, 1, 2, 3, 4, 5}, | ||
output: bitfield.Bitfield{1, 2, 3, 4, 5}, | ||
fails: false, | ||
}, | ||
"message is not a bitfield": { | ||
msg: []byte{0x00, 0x00, 0x00, 0x06, 99, 1, 2, 3, 4, 5}, | ||
output: nil, | ||
fails: true, | ||
}, | ||
"message is keep-alive": { | ||
msg: []byte{0x00, 0x00, 0x00, 0x00}, | ||
output: nil, | ||
fails: true, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
clientConn, serverConn := createClientAndServer(t) | ||
serverConn.Write(test.msg) | ||
|
||
bf, err := recvBitfield(clientConn) | ||
|
||
if test.fails { | ||
assert.NotNil(t, err) | ||
} else { | ||
assert.Nil(t, err) | ||
assert.Equal(t, bf, test.output) | ||
} | ||
} | ||
} | ||
|
||
func TestCompleteHandshake(t *testing.T) { | ||
tests := map[string]struct { | ||
clientInfohash [20]byte | ||
clientPeerID [20]byte | ||
serverHandshake []byte | ||
output *handshake.Handshake | ||
fails bool | ||
}{ | ||
"successful handshake": { | ||
clientInfohash: [20]byte{134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116}, | ||
clientPeerID: [20]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, | ||
serverHandshake: []byte{19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 0, 0, 0, 134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116, 45, 83, 89, 48, 48, 49, 48, 45, 192, 125, 147, 203, 136, 32, 59, 180, 253, 168, 193, 19}, | ||
output: &handshake.Handshake{ | ||
Pstr: "BitTorrent protocol", | ||
InfoHash: [20]byte{134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116}, | ||
PeerID: [20]byte{45, 83, 89, 48, 48, 49, 48, 45, 192, 125, 147, 203, 136, 32, 59, 180, 253, 168, 193, 19}, | ||
}, | ||
fails: false, | ||
}, | ||
"wrong infohash": { | ||
clientInfohash: [20]byte{134, 212, 200, 0, 36, 164, 105, 190, 76, 80, 188, 90, 16, 44, 247, 23, 128, 49, 0, 116}, | ||
clientPeerID: [20]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, | ||
serverHandshake: []byte{19, 66, 105, 116, 84, 111, 114, 114, 101, 110, 116, 32, 112, 114, 111, 116, 111, 99, 111, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0xde, 0xe8, 0x6a, 0x7f, 0xa6, 0xf2, 0x86, 0xa9, 0xd7, 0x4c, 0x36, 0x20, 0x14, 0x61, 0x6a, 0x0f, 0xf5, 0xe4, 0x84, 0x3d, 45, 83, 89, 48, 48, 49, 48, 45, 192, 125, 147, 203, 136, 32, 59, 180, 253, 168, 193, 19}, | ||
output: nil, | ||
fails: true, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
clientConn, serverConn := createClientAndServer(t) | ||
serverConn.Write(test.serverHandshake) | ||
|
||
h, err := completeHandshake(clientConn, test.clientInfohash, test.clientPeerID) | ||
|
||
if test.fails { | ||
assert.NotNil(t, err) | ||
} else { | ||
assert.Nil(t, err) | ||
assert.Equal(t, h, test.output) | ||
} | ||
} | ||
} |