diff --git a/backend/backend.go b/backend/backend.go index 6f1d430..18d8655 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -1,7 +1,10 @@ package backend +import "sync" + type Backend struct { - url string + url string + healthy bool } func NewBackend(url string) *Backend { @@ -11,3 +14,24 @@ func NewBackend(url string) *Backend { func (b *Backend) GetUrl() string { return b.url } + +func (b *Backend) IsHealthy() bool { + return b.healthy +} + +type BackendPool struct { + bs []*Backend + mu sync.RWMutex +} + +func NewBackendPool(bs []*Backend) *BackendPool { + return &BackendPool{bs: bs} +} + +func (bp *BackendPool) GetPool() []*Backend { + bp.mu.RLock() + defer bp.mu.RUnlock() + bs := make([]*Backend, len(bp.bs)) + copy(bs, bp.bs) + return bs +} diff --git a/main.go b/main.go index 075a27a..d056418 100644 --- a/main.go +++ b/main.go @@ -2,18 +2,22 @@ package main import ( "fmt" - backend2 "load-balancer/backend" - listener2 "load-balancer/listener" - proxy2 "load-balancer/proxy" - router2 "load-balancer/router" + "load-balancer/backend" + "load-balancer/listener" + "load-balancer/proxy" + "load-balancer/router" + "load-balancer/router/roundrobin" ) func main() { port := 8080 host := "[::1]" - backend := backend2.NewBackend("localhost:80") - router := router2.NewRouter(host+fmt.Sprintf(":%d", port), backend) - proxy := proxy2.NewProxy(router) - listener := listener2.NewListener(proxy) - listener.Listen(int64(port)) + b := backend.NewBackend("localhost:80") + b1 := backend.NewBackend("localhost:8081") + bp := backend.NewBackendPool([]*backend.Backend{b, b1}) + algo := roundrobin.NewRoundRobin(bp) + r := router.NewRouter(host+fmt.Sprintf(":%d", port), algo) + p := proxy.NewProxy(r) + l := listener.NewListener(p) + l.Listen(int64(port)) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 524270e..a2b6131 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -3,12 +3,12 @@ package proxy import ( "fmt" "io" - "load-balancer/router" + "load-balancer/backend" "net" ) type RouterIO interface { - Route(string) router.BackendIO + Route(string) *backend.Backend } type Proxy struct { @@ -21,11 +21,12 @@ func NewProxy(rt RouterIO) *Proxy { func (p *Proxy) Handle(conn net.Conn) error { localAddr := conn.LocalAddr().String() - backend := p.router.Route(localAddr) - if backend == nil { + b := p.router.Route(localAddr) + if b == nil { return fmt.Errorf("no available backend") } - backendConn, err := net.Dial("tcp", backend.GetUrl()) + fmt.Println(b.GetUrl()) + backendConn, err := net.Dial("tcp", b.GetUrl()) if err != nil { return err } diff --git a/router/roundrobin/round_robin.go b/router/roundrobin/round_robin.go new file mode 100644 index 0000000..a4069ba --- /dev/null +++ b/router/roundrobin/round_robin.go @@ -0,0 +1,29 @@ +package roundrobin + +import ( + "load-balancer/backend" + "sync" +) + +type BackendPoolIO interface { + GetPool() []*backend.Backend +} + +type RoundRobin struct { + bp BackendPoolIO + index int + mu sync.Mutex +} + +func NewRoundRobin(bs BackendPoolIO) *RoundRobin { + return &RoundRobin{bp: bs} +} + +func (rr *RoundRobin) GetBackend() *backend.Backend { + rr.mu.Lock() + defer rr.mu.Unlock() + bp := rr.bp.GetPool() + b := bp[rr.index] + rr.index = (rr.index + 1) % len(bp) + return b +} diff --git a/router/roundrobin/round_robin_test.go b/router/roundrobin/round_robin_test.go new file mode 100644 index 0000000..2a9363c --- /dev/null +++ b/router/roundrobin/round_robin_test.go @@ -0,0 +1,63 @@ +package roundrobin + +import ( + "load-balancer/backend" + "sync" + "testing" +) + +type stubPool struct { + backends []*backend.Backend +} + +func (s *stubPool) GetPool() []*backend.Backend { + return s.backends +} + +func newPool(urls ...string) *stubPool { + bs := make([]*backend.Backend, len(urls)) + for i, u := range urls { + bs[i] = backend.NewBackend(u) + } + return &stubPool{backends: bs} +} + +func TestRoundRobinCyclesInOrder(t *testing.T) { + rr := NewRoundRobin(newPool("a", "b", "c")) + + want := []string{"a", "b", "c", "a", "b"} + for _, expected := range want { + got := rr.GetBackend().GetUrl() + if got != expected { + t.Errorf("got %q, want %q", got, expected) + } + } +} + +func TestRoundRobinWrapsAround(t *testing.T) { + pool := newPool("a", "b") + rr := NewRoundRobin(pool) + + for i := 0; i < len(pool.backends); i++ { + rr.GetBackend() + } + + got := rr.GetBackend().GetUrl() + if got != "a" { + t.Errorf("expected wrap-around to first backend, got %q", got) + } +} + +func TestRoundRobinConcurrentSafety(t *testing.T) { + rr := NewRoundRobin(newPool("a", "b", "c")) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rr.GetBackend() + }() + } + wg.Wait() +} diff --git a/router/router.go b/router/router.go index e8a0bf1..e74dcb9 100644 --- a/router/router.go +++ b/router/router.go @@ -1,19 +1,21 @@ package router -type BackendIO interface { - GetUrl() string +import "load-balancer/backend" + +type AlgoIO interface { + GetBackend() *backend.Backend } type Router struct { - router map[string]BackendIO + router map[string]AlgoIO } -func NewRouter(path string, be BackendIO) *Router { - router := make(map[string]BackendIO) +func NewRouter(path string, be AlgoIO) *Router { + router := make(map[string]AlgoIO) router[path] = be return &Router{router: router} } -func (r *Router) Route(path string) BackendIO { - return r.router[path] +func (r *Router) Route(path string) *backend.Backend { + return r.router[path].GetBackend() }