diff --git a/src/bidirectional_stream.go b/src/bidirectional_stream.go index 5f58994..e2e2aa0 100644 --- a/src/bidirectional_stream.go +++ b/src/bidirectional_stream.go @@ -3,15 +3,18 @@ package main import ( "bufio" "bytes" + "context" "fmt" "io" "log/slog" "net/http" "sync" + "time" "github.com/google/gopacket" "github.com/google/gopacket/tcpassembly" "github.com/google/gopacket/tcpassembly/tcpreader" + "golang.org/x/sync/semaphore" ) type bidirectionalStreamFactory struct { @@ -57,17 +60,22 @@ type bidirectionalStream struct { func (s *bidirectionalStream) run() { defer s.closeCallback() + defer s.clientToServer.Close() + defer s.serverToClient.Close() - wg := sync.WaitGroup{} - wg.Add(2) + sem := semaphore.NewWeighted(2) requestChannel := make(chan *http.Request, 1) responseChannel := make(chan *http.Response, 1) - defer close(requestChannel) - defer close(responseChannel) + err := sem.Acquire(context.Background(), 1) + if err != nil { + slog.Error("Failed to acquire semaphore for clientToServer reader:", "Err", err.Error()) + return + } go func() { - defer wg.Done() + defer sem.Release(1) + defer close(requestChannel) defer func() { if r := recover(); r != nil { slog.Error("Recovered from panic in clientToServer reader:", "Err", r) @@ -89,8 +97,14 @@ func (s *bidirectionalStream) run() { requestChannel <- request }() + err = sem.Acquire(context.Background(), 1) + if err != nil { + slog.Error("Failed to acquire semaphore for serverToClient reader:", "Err", err.Error()) + return + } go func() { - defer wg.Done() + defer sem.Release(1) + defer close(responseChannel) defer func() { if r := recover(); r != nil { slog.Error("Recovered from panic in serverToClient reader:", "Err", r) @@ -110,7 +124,15 @@ func (s *bidirectionalStream) run() { responseChannel <- response }() - wg.Wait() + // Wait for both goroutines to finish with timeout of 2 minutes + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + if err := sem.Acquire(ctx, 2); err != nil { + if err != context.DeadlineExceeded { + slog.Error("Failed to acquire semaphore for both readers:", "Err", err.Error()) + } + return + } var capturedRequest *http.Request var capturedResponse *http.Response diff --git a/src/go.mod b/src/go.mod index 0aa4594..64a55d9 100644 --- a/src/go.mod +++ b/src/go.mod @@ -37,6 +37,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect golang.org/x/net v0.39.0 // indirect golang.org/x/oauth2 v0.27.0 // indirect + golang.org/x/sync v0.14.0 golang.org/x/sys v0.32.0 // indirect golang.org/x/term v0.31.0 // indirect golang.org/x/text v0.24.0 // indirect diff --git a/src/go.sum b/src/go.sum index e0e25f1..86949ce 100644 --- a/src/go.sum +++ b/src/go.sum @@ -117,6 +117,8 @@ golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=