Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/bidirectional_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ package main
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"log/slog"
"net/http"
"sync"
"time"

"github.com/google/gopacket"
"github.com/google/gopacket/tcpassembly"
"github.com/google/gopacket/tcpassembly/tcpreader"
"golang.org/x/sync/semaphore"
)

type bidirectionalStreamFactory struct {
Expand Down Expand Up @@ -57,17 +60,22 @@ type bidirectionalStream struct {

func (s *bidirectionalStream) run() {
defer s.closeCallback()
defer s.clientToServer.Close()
defer s.serverToClient.Close()

wg := sync.WaitGroup{}
wg.Add(2)
sem := semaphore.NewWeighted(2)

requestChannel := make(chan *http.Request, 1)
responseChannel := make(chan *http.Response, 1)
defer close(requestChannel)
defer close(responseChannel)

err := sem.Acquire(context.Background(), 1)
if err != nil {
slog.Error("Failed to acquire semaphore for clientToServer reader:", "Err", err.Error())
return
}
go func() {
defer wg.Done()
defer sem.Release(1)
defer close(requestChannel)
defer func() {
if r := recover(); r != nil {
slog.Error("Recovered from panic in clientToServer reader:", "Err", r)
Expand All @@ -89,8 +97,14 @@ func (s *bidirectionalStream) run() {
requestChannel <- request
}()

err = sem.Acquire(context.Background(), 1)
if err != nil {
slog.Error("Failed to acquire semaphore for serverToClient reader:", "Err", err.Error())
return
}
go func() {
defer wg.Done()
defer sem.Release(1)
defer close(responseChannel)
defer func() {
if r := recover(); r != nil {
slog.Error("Recovered from panic in serverToClient reader:", "Err", r)
Expand All @@ -110,7 +124,15 @@ func (s *bidirectionalStream) run() {
responseChannel <- response
}()

wg.Wait()
// Wait for both goroutines to finish with timeout of 2 minutes
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
if err := sem.Acquire(ctx, 2); err != nil {
if err != context.DeadlineExceeded {
slog.Error("Failed to acquire semaphore for both readers:", "Err", err.Error())
}
return
}

var capturedRequest *http.Request
var capturedResponse *http.Response
Expand Down
1 change: 1 addition & 0 deletions src/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ require (
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/net v0.39.0 // indirect
golang.org/x/oauth2 v0.27.0 // indirect
golang.org/x/sync v0.14.0
golang.org/x/sys v0.32.0 // indirect
golang.org/x/term v0.31.0 // indirect
golang.org/x/text v0.24.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions src/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down