Skip to content

Commit

Permalink
Fix race condition in test
Browse files Browse the repository at this point in the history
  • Loading branch information
mirokuratczyk committed May 12, 2023
1 parent 45f4594 commit b2c25ec
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
16 changes: 9 additions & 7 deletions psiphon/common/transforms/httpNormalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,15 @@ func runHTTPNormalizerTest(tt *httpNormalizerTest, useNormalizer bool) error {
}
}

// Calling Read on an instance of HTTPNormalizer will return io.EOF once a
// passthrough has been activated.
if tt.validateMeekCookie != nil && err == io.EOF {

// wait for passthrough to complete

timeout := time.After(time.Second)

for len(passthroughConn.readBuffer) != 0 || len(conn.readBuffer) != 0 {
for len(passthroughConn.ReadBuffer()) != 0 || len(conn.ReadBuffer()) != 0 {

select {
case <-timeout:
Expand Down Expand Up @@ -149,20 +151,20 @@ func runHTTPNormalizerTest(tt *httpNormalizerTest, useNormalizer bool) error {
return errors.TraceNew("expected to read no bytes")
}

if string(passthroughConn.readBuffer) != "" {
if string(passthroughConn.ReadBuffer()) != "" {
return errors.TraceNew("expected read buffer to be emptied")
}

if string(passthroughConn.writeBuffer) != tt.wantOutput {
return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(passthroughConn.writeBuffer)), len(passthroughConn.writeBuffer))
if string(passthroughConn.WriteBuffer()) != tt.wantOutput {
return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(passthroughConn.WriteBuffer())), len(passthroughConn.WriteBuffer()))
}

if string(conn.readBuffer) != "" {
if string(conn.ReadBuffer()) != "" {
return errors.TraceNew("expected read buffer to be emptied")
}

if string(conn.writeBuffer) != passthroughMessage {
return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(passthroughMessage), len(passthroughMessage), escapeNewlines(string(conn.writeBuffer)), len(conn.writeBuffer))
if string(conn.WriteBuffer()) != passthroughMessage {
return errors.Tracef("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(passthroughMessage), len(passthroughMessage), escapeNewlines(string(conn.WriteBuffer())), len(conn.WriteBuffer()))
}
}

Expand Down
49 changes: 42 additions & 7 deletions psiphon/common/transforms/httpTransformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"net"
"net/http"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -271,8 +272,8 @@ func TestHTTPTransformerHTTPRequest(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error %v", err)
}
if string(conn.writeBuffer) != tt.wantOutput {
t.Fatalf("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(conn.writeBuffer)), len(conn.writeBuffer))
if string(conn.WriteBuffer()) != tt.wantOutput {
t.Fatalf("expected \"%s\" of len %d but got \"%s\" of len %d", escapeNewlines(tt.wantOutput), len(tt.wantOutput), escapeNewlines(string(conn.WriteBuffer())), len(conn.WriteBuffer()))
}
} else {
// tt.wantError != nil
Expand Down Expand Up @@ -461,10 +462,16 @@ func escapeNewlines(s string) string {
}

type testConn struct {
// writeBuffer are the accumulated bytes from Write() calls.
writeBuffer []byte
readLock sync.Mutex
// readBuffer are the bytes to return from Read() calls.
readBuffer []byte
// readErrs are returned from Read() calls in order. If empty, then a nil
// error is returned.
readErrs []error

writeLock sync.Mutex
// writeBuffer are the accumulated bytes from Write() calls.
writeBuffer []byte
// writeLimit is the max number of bytes that will be written in a Write()
// call.
writeLimit int
Expand All @@ -475,15 +482,28 @@ type testConn struct {
// writeErrs are returned from Write() calls in order. If empty, then a nil
// error is returned.
writeErrs []error
// readErrs are returned from Read() calls in order. If empty, then a nil
// error is returned.
readErrs []error

net.Conn
}

// ReadBuffer returns a copy of the underlying readBuffer. The length of the
// returned buffer is also the number of bytes remaining to be Read when Conn
// is not set.
func (c *testConn) ReadBuffer() []byte {
c.readLock.Lock()
defer c.readLock.Unlock()

readBufferCopy := make([]byte, len(c.readBuffer))
copy(readBufferCopy, c.readBuffer)

return readBufferCopy
}

func (c *testConn) Read(b []byte) (n int, err error) {

c.readLock.Lock()
defer c.readLock.Unlock()

if len(c.readErrs) > 0 {
err = c.readErrs[0]
c.readErrs = c.readErrs[1:]
Expand All @@ -509,8 +529,23 @@ func (c *testConn) Read(b []byte) (n int, err error) {
return
}

// WriteBuffer returns a copy of the underlying writeBuffer, which is the
// accumulation of all bytes written with Write.
func (c *testConn) WriteBuffer() []byte {
c.readLock.Lock()
defer c.readLock.Unlock()

writeBufferCopy := make([]byte, len(c.writeBuffer))
copy(writeBufferCopy, c.writeBuffer)

return writeBufferCopy
}

func (c *testConn) Write(b []byte) (n int, err error) {

c.writeLock.Lock()
defer c.writeLock.Unlock()

if len(c.writeErrs) > 0 {
err = c.writeErrs[0]
c.writeErrs = c.writeErrs[1:]
Expand Down

0 comments on commit b2c25ec

Please sign in to comment.