# ⚖️ Load Balancing & API Gateway Design

**Phase 3: System Design - Core Infrastructure Patterns**

**Master scalable API architectures with functional programming patterns**

---

## 🏗️ Distributed API Gateway Implementation

Scalable API routing through functional composition, using Cats Effect and http4s

In [None]:
// Complete API Gateway implementation
import cats.effect._
import cats.syntax.all._
import cats.data._
import org.http4s._
import org.http4s.circe._
import org.http4s.client.Client
import scala.concurrent.duration._

// Domain models
/** An API call as seen by the gateway before it is routed.
  *
  * @param endpoint target endpoint of the call
  * @param method   HTTP method name
  * @param payload  request fields as string key/value pairs
  * @param headers  request headers as string key/value pairs
  */
case class ApiRequest(endpoint: String, method: String,
                      payload: Map[String, String], headers: Map[String, String])

/** Routing entry describing one backend service.
  *
  * @param serviceName         logical name of the service
  * @param baseUri             base URI of the service
  * @param endpoints           endpoints associated with this route
  * @param loadBalanceStrategy how an instance of the service is chosen
  */
case class ServiceRoute(serviceName: String, baseUri: Uri, endpoints: List[String],
                        loadBalanceStrategy: LoadBalanceStrategy)

// Load balancing strategies
/** How a load balancer chooses among the healthy instances of a service. */
sealed trait LoadBalanceStrategy

/** Cycle through the instances in order. */
case object RoundRobin extends LoadBalanceStrategy

/** Choose the instance with the fewest active connections. */
case object LeastConnections extends LoadBalanceStrategy

/** Choose randomly, with probability proportional to each instance's weight. */
case object WeightedRandom extends LoadBalanceStrategy

// Print an overview of the gateway's key components.
// The check marks are written as real UTF-8 "✓" (they had been garbled into Mac-Roman mojibake).
println("=== DISTRIBUTED API GATEWAY ARCHITECTURE ===")
println("Implementing functional API gateway with load balancing")
println("\nKey components:")
println("✓ Distribution of API traffic across multiple instances")
println("✓ Service discovery and registration")
println("✓ Circuit breaker and health checks")
println("✓ Rate limiting and throttling")
println("✓ Request routing and authentication")
println()


In [None]:
// Advanced Load Balancer with Cats Effect
/** Client-side load balancer over a shared registry of service instances.
  *
  * Only instances whose `isHealthy` flag is set are eligible for selection; the flag is
  * refreshed by [[startHealthMonitoring]].
  *
  * @param services         registry of instances keyed by service name
  * @param healthChecker    probe used to refresh each instance's `isHealthy` flag
  * @param metricsCollector receives one `recordLoadBalance` call per selection
  */
class FunctionalLoadBalancer[F[_]: Concurrent: Timer](
  services: Ref[F, Map[String, List[ServiceInstance]]],
  healthChecker: HealthChecker[F],
  metricsCollector: MetricsCollector[F]
) {

  // Per-service round-robin cursor. `Ref.unsafe` allocates it synchronously for any Sync F.
  // `Ref.of(...).unsafeRunSync()` only exists on a concrete IO, so it cannot compile here.
  private val roundRobinIndex: Ref[F, Map[String, Int]] =
    Ref.unsafe[F, Map[String, Int]](Map.empty)

  /** Picks a healthy instance of `serviceName` with `strategy` and records the decision.
    *
    * @return the chosen instance, or `None` when the service has no healthy instances
    *         (or, for `WeightedRandom`, when their total weight is not positive)
    */
  def selectInstance(serviceName: String, strategy: LoadBalanceStrategy): F[Option[ServiceInstance]] = {
    for {
      serviceInstances <- getHealthyInstances(serviceName)
      instance <- strategy match {
        case RoundRobin       => selectRoundRobin(serviceName, serviceInstances)
        case LeastConnections => selectLeastConnections(serviceInstances)
        case WeightedRandom   => selectWeightedRandom(serviceInstances)
      }
      _ <- metricsCollector.recordLoadBalance(serviceName, instance.map(_.uri.toString))
    } yield instance
  }

  // Registered instances of `serviceName` whose `isHealthy` flag is set.
  private def getHealthyInstances(serviceName: String): F[List[ServiceInstance]] =
    services.get.map(_.getOrElse(serviceName, Nil).filter(_.isHealthy))

  // Atomically reads and advances the service's cursor.
  private def selectRoundRobin(serviceName: String, instances: List[ServiceInstance]): F[Option[ServiceInstance]] =
    if (instances.isEmpty) Option.empty[ServiceInstance].pure[F]
    else
      roundRobinIndex.modify { cursors =>
        // The healthy set can shrink between calls, so a stored cursor may point past the
        // end. Reduce it modulo the current size before indexing.
        val current = cursors.getOrElse(serviceName, 0) % instances.length
        val next = (current + 1) % instances.length
        (cursors.updated(serviceName, next), instances.lift(current))
      }

  private def selectLeastConnections(instances: List[ServiceInstance]): F[Option[ServiceInstance]] =
    instances.minimumByOption(_.activeConnections).pure[F]

  // Weight-proportional random choice. The random draw happens inside F, and a
  // non-positive total weight yields None (Random.nextInt rejects bounds <= 0).
  private def selectWeightedRandom(instances: List[ServiceInstance]): F[Option[ServiceInstance]] = {
    val totalWeight = instances.map(_.weight).sum
    if (totalWeight <= 0) Option.empty[ServiceInstance].pure[F]
    else
      Concurrent[F].delay(scala.util.Random.nextInt(totalWeight)).map { randomPoint =>
        // Upper bound of each instance's weight interval, computed once in O(n). Positions
        // are used rather than equality, so duplicate instances are weighted separately.
        val upperBounds = instances.scanLeft(0)(_ + _.weight).tail
        instances.zip(upperBounds).collectFirst {
          case (instance, upperBound) if randomPoint < upperBound => instance
        }
      }
  }

  /** Appends `instance` to the instances registered under `serviceName`. */
  def registerService(serviceName: String, instance: ServiceInstance): F[Unit] = {
    services.update { currentServices =>
      val currentInstances = currentServices.getOrElse(serviceName, Nil)
      currentServices + (serviceName -> (currentInstances :+ instance))
    }
  }

  /** Removes instances with `instanceUri`; drops the service entry once it has none left. */
  def deregisterService(serviceName: String, instanceUri: Uri): F[Unit] = {
    services.update { currentServices =>
      val updatedInstances = currentServices.getOrElse(serviceName, Nil)
        .filterNot(_.uri == instanceUri)
      if (updatedInstances.isEmpty) currentServices - serviceName
      else currentServices + (serviceName -> updatedInstances)
    }
  }

  /** Periodically probes every registered instance and stores the result in `isHealthy`.
    *
    * A probe that fails with an error marks its instance unhealthy; it does not terminate
    * the monitoring stream.
    *
    * @param interval            time between probe rounds
    * @param maxConcurrentChecks upper bound on probes in flight within one round
    */
  def startHealthMonitoring(
    interval: FiniteDuration = 30.seconds,
    maxConcurrentChecks: Int = 16
  ): fs2.Stream[F, Unit] =
    fs2.Stream.fixedRate[F](interval).evalMap { _ =>
      services.get.flatMap { serviceMap =>
        val probes = for {
          (serviceName, instances) <- serviceMap.toList
          instance <- instances
        } yield (serviceName, instance)

        fs2.Stream
          .emits(probes)
          .covary[F]
          .parEvalMapUnordered(maxConcurrentChecks) { case (serviceName, instance) =>
            healthChecker
              .checkHealth(instance)
              .handleError(_ => false)
              .flatMap(isHealthy => markHealth(serviceName, instance.uri, isHealthy))
          }
          .compile
          .drain
      }
    }

  // Writes one probe result back into the registry. It re-reads the current map, so
  // registrations made while probes were in flight are kept. It also does not re-insert a
  // service that was deregistered in the meantime.
  private def markHealth(serviceName: String, uri: Uri, isHealthy: Boolean): F[Unit] =
    services.update { current =>
      current.get(serviceName).fold(current) { instances =>
        current.updated(
          serviceName,
          instances.map(inst => if (inst.uri == uri) inst.copy(isHealthy = isHealthy) else inst)
        )
      }
    }
}

// Print what the load balancer cell provides.
// The emoji and bullets are written as real UTF-8 (they had been garbled into Mac-Roman mojibake).
println("✅ Load Balancer with advanced strategies implemented")
println("  • RoundRobin, LeastConnections, WeightedRandom")
println("  • Service registration/deregistration")
println("  • Health monitoring with circuit breaker")
println()


## 🔐 Advanced API Gateway with Middleware

Functional composition of concerns: authentication, rate limiting, logging

In [None]:
// Complete API Gateway with CORS, authentication, and monitoring
import cats.data.{Kleisli, OptionT}
import org.http4s.server.AuthMiddleware

// Middleware pipeline using ReaderT/Kleisli
type GatewayMiddleware[F[_], A] = Kleisli[OptionT[F, *], GatewayContext[F], A]

/** Per-request state threaded through the gateway middleware.
  *
  * Parameterised over the effect type because it holds an http4s `Request[F]`. A top-level
  * case class has no `F` in scope, so the unparameterised form could not compile.
  *
  * @param request         the incoming request
  * @param authToken       authentication token for the request, if any
  * @param rateLimitStatus rate-limit budget associated with the request
  * @param startTime       epoch milliseconds at construction unless supplied
  */
case class GatewayContext[F[_]](
  request: Request[F],
  authToken: Option[AuthToken],
  rateLimitStatus: RateLimitStatus,
  startTime: Long = System.currentTimeMillis()
)

/** An authenticated principal.
  * `expiresAt` units are not fixed here — presumably epoch millis; confirm with AuthService.
  */
case class AuthToken(
  userId: String,
  roles: Set[String],
  expiresAt: Long
)

/** Remaining request budget for a client.
  * NOTE(review): `resetTime` units are not established in this file — confirm with RateLimiter.
  */
case class RateLimitStatus(
  remainingRequests: Int,
  resetTime: Long
)

/** The gateway's cross-cutting HTTP concerns, each exposed as an
  * `HttpRoutes[F] => HttpRoutes[F]` transformer, plus bearer-token authentication as an
  * http4s `AuthMiddleware`.
  *
  * Use [[apply]] to wrap public routes and [[authed]] to wrap routes that need an `AuthToken`.
  */
class GatewayMiddlewareStack[F[_]: Concurrent: Timer](
  authService: AuthService[F],
  rateLimiter: RateLimiter[F],
  metrics: MetricsCollector[F],
  corsHandler: CORSHandler[F]
) {

  // Monotonic clock in milliseconds. It is read when the request runs, not when the
  // route value is built.
  private val nowMillis: OptionT[F, Long] =
    OptionT.liftF(Timer[F].clock.monotonic(MILLISECONDS))

  // Prints a line to stdout inside the effect.
  private def log(line: String): OptionT[F, Unit] =
    OptionT.liftF(Concurrent[F].delay(println(line)))

  // Request timing middleware: logs the duration of matched requests and of failures.
  // Failures are re-raised after logging.
  val timingMiddleware: HttpRoutes[F] => HttpRoutes[F] = routes => {
    Kleisli { req: Request[F] =>
      nowMillis.flatMap { startTime =>
        routes.run(req).flatMap { response =>
          nowMillis.flatMap { end =>
            val duration = end - startTime
            log(f"Request ${req.uri} completed in ${duration}ms")
          }.map(_ => response)
        }.handleErrorWith { error =>
          nowMillis.flatMap { end =>
            val duration = end - startTime
            log(f"Request ${req.uri} failed after ${duration}ms: $error")
          } *> OptionT.liftF(error.raiseError[F, Response[F]])
        }
      }
    }
  }

  // Authentication middleware. A missing or non-Bearer Authorization header yields None,
  // which AuthMiddleware answers with 401 Unauthorized. A present token goes to
  // authService.validateToken, and errors from that call propagate.
  val authMiddleware: AuthMiddleware[F, AuthToken] = {
    val authUser: Kleisli[OptionT[F, *], Request[F], AuthToken] = Kleisli { req =>
      val tokenOpt = req.headers.get(header"Authorization")
        .flatMap(h => if (h.value.startsWith("Bearer ")) Some(h.value.drop(7)) else None)

      OptionT.fromOption[F](tokenOpt).semiflatMap(authService.validateToken)
    }

    AuthMiddleware(authUser)
  }

  // Rate limiting middleware. Effects in F are lifted into OptionT so they compose with
  // routes.run. A failure of the wrapped route is reported to the limiter and re-raised.
  val rateLimitMiddleware: HttpRoutes[F] => HttpRoutes[F] = routes => {
    Kleisli { req: Request[F] =>
      for {
        clientKey <- OptionT.liftF(extractClientKey(req))
        rateLimit <- OptionT.liftF(rateLimiter.checkLimit(clientKey))
        response <- if (rateLimit.canProceed) {
          routes.run(req).onError { case _ =>
            OptionT.liftF(rateLimiter.recordFailure(clientKey).void)
          }
        } else {
          OptionT.pure[F](
            Response[F](Status.TooManyRequests)
              .withEntity(s"Rate limit exceeded. Try after ${rateLimit.resetTime} seconds")
          )
        }
      } yield response
    }
  }

  // CORS middleware. It answers every OPTIONS request as a preflight and adds the allow-origin
  // header to every other matched response.
  val corsMiddleware: HttpRoutes[F] => HttpRoutes[F] = routes => {
    Kleisli { req: Request[F] =>
      val allowHeaders = Set("Accept", "Authorization", "Content-Type", "Origin", "X-Requested-With")
      val allowMethods = Set(Method.GET, Method.POST, Method.PUT, Method.DELETE, Method.OPTIONS)

      if (req.method == Method.OPTIONS) {
        OptionT.pure[F](
          Response[F](Status.Ok)
            .putHeaders(
              header"Access-Control-Allow-Origin"("*"),
              header"Access-Control-Allow-Methods"(allowMethods.mkString(", ")),
              header"Access-Control-Allow-Headers"(allowHeaders.mkString(", "))
            )
        )
      } else {
        routes.run(req).map { response =>
          response.putHeaders(
            header"Access-Control-Allow-Origin"("*")
          )
        }
      }
    }
  }

  // Metrics collection middleware: records matched responses and turns failures into a
  // recorded 500. Unmatched requests (None) pass through unrecorded.
  val metricsMiddleware: HttpRoutes[F] => HttpRoutes[F] = routes => {
    Kleisli { req: Request[F] =>
      routes.run(req).attempt.flatMap {
        case Right(response) =>
          OptionT.liftF(metrics.recordSuccess(req.uri.toString, response.status.code).as(response))
        case Left(error) =>
          OptionT.liftF(
            metrics.recordError(req.uri.toString, error.getMessage).as(
              Response[F](Status.InternalServerError)
                .withEntity("Internal server error")
            )
          )
      }
    }
  }

  /** Wraps `routes` with rate limiting (innermost), then metrics, CORS and timing (outermost).
    *
    * Authentication works on `AuthedRoutes`, not `HttpRoutes`, so it is applied by [[authed]].
    */
  def apply(routes: HttpRoutes[F]): HttpRoutes[F] =
    timingMiddleware(corsMiddleware(metricsMiddleware(rateLimitMiddleware(routes))))

  /** Like [[apply]], but first guards `routes` with bearer-token authentication. */
  def authed(routes: AuthedRoutes[AuthToken, F]): HttpRoutes[F] =
    apply(authMiddleware(routes))

  private def extractClientKey(req: Request[F]): F[String] = {
    // Extract client key from IP, user agent, or API key
    req.remoteAddr.getOrElse("unknown").pure[F]
  }
}

// Print what the middleware cell provides.
// The emoji and bullets are written as real UTF-8 (they had been garbled into Mac-Roman mojibake).
println("✅ Advanced middleware stack implemented")
println("  • Authentication with JWT tokens")
println("  • Rate limiting with sliding windows")
println("  • CORS handling across domains")
println("  • Request/response timing and metrics")
println()


## 🚀 Microservices Communication Patterns

Synchronous and asynchronous communication with resilience patterns

In [None]:
// Circuit Breaker for resilient service communication
import scala.concurrent.duration._
import cats.effect.{Deferred, Ref}

/** State of a circuit breaker. */
sealed trait CircuitState

/** Normal operation: calls pass through. */
case object CircuitClosed extends CircuitState

/** Tripped after failures: calls are rejected. */
case object CircuitOpen extends CircuitState

/** Probing whether the downstream service has recovered. */
case object CircuitHalfOpen extends CircuitState

/** Tuning knobs for a circuit breaker.
  *
  * @param failureThreshold    consecutive failures that trip the circuit open
  * @param successThreshold    consecutive successes that close the circuit again
  * @param timeoutMs           per-call timeout in milliseconds
  * @param retryTimeoutSeconds how long the circuit stays open before a recovery probe
  */
case class CircuitConfig(
  failureThreshold: Int = 5,
  successThreshold: Int = 3,
  timeoutMs: Long = 5000L,
  retryTimeoutSeconds: Int = 60
)

/** Call counters kept by a circuit breaker; all counts start at zero. */
case class CircuitStats(
  calls: Long = 0L,                  // calls that were actually attempted
  failures: Long = 0L,               // total failed calls
  successes: Long = 0L,              // total successful calls
  consecutiveFailures: Long = 0L,    // failures since the last success
  consecutiveSuccesses: Long = 0L,   // successes since the last failure (or last close)
  lastFailure: Option[Long] = None,  // time of the most recent failure, epoch millis
  lastSuccess: Option[Long] = None   // time of the most recent success, epoch millis
)

/** Circuit breaker guarding calls to one downstream service.
  *
  * Transitions:
  *  - Closed -> Open after `failureThreshold` consecutive failures;
  *  - Open -> HalfOpen on the first call after the `retryTimeoutSeconds` cooldown;
  *  - HalfOpen -> Closed after `successThreshold` consecutive successes;
  *  - HalfOpen -> Open on any failure, which starts a fresh cooldown.
  *
  * @param config      thresholds and timeouts
  * @param serviceName service named in the rejection error
  */
class CircuitBreaker[F[_]: Concurrent: Timer](
  config: CircuitConfig,
  serviceName: String
) {

  // `Ref.unsafe` allocates synchronously for any Sync F. `Ref.of(...).unsafeRunSync()` only
  // exists on a concrete IO, so it cannot compile for an abstract F.
  private val circuitState = Ref.unsafe[F, CircuitState](CircuitClosed)
  private val stats = Ref.unsafe[F, CircuitStats](CircuitStats())
  private val openUntilTime = Ref.unsafe[F, Long](0L)

  // Wall-clock time in epoch milliseconds, read inside F.
  private val nowMillis: F[Long] = Timer[F].clock.realTime(MILLISECONDS)

  /** Runs `operation` through the breaker.
    *
    * @return `Left(CircuitOpenException)` when the call is rejected without running, because
    *         the circuit is open and its cooldown has not elapsed. Otherwise `Right` with the
    *         operation's result. A failure (including a timeout after `config.timeoutMs`) is
    *         recorded and then re-raised in F, so callers see the real cause.
    */
  def execute[A](operation: F[A]): F[Either[CircuitOpenException, A]] =
    circuitState.get.flatMap {
      case CircuitOpen =>
        openUntilTime.get.flatMap { openUntil =>
          nowMillis.flatMap { now =>
            if (now >= openUntil)
              // Cooldown elapsed: let this call through as a recovery probe.
              circuitState.set(CircuitHalfOpen) *> attemptOperation(operation)
            else
              CircuitOpenException(serviceName).asLeft[A].pure[F]
          }
        }

      case CircuitHalfOpen | CircuitClosed =>
        attemptOperation(operation)
    }

  // Counts the call, runs it under the configured timeout and records the outcome.
  // A failure is re-raised instead of being masked as a "circuit open" rejection.
  private def attemptOperation[A](operation: F[A]): F[Either[CircuitOpenException, A]] =
    stats.update(s => s.copy(calls = s.calls + 1)) *>
      Concurrent.timeout(operation, config.timeoutMs.millis).attempt.flatMap {
        case Right(result) =>
          recordSuccess.as(result.asRight[CircuitOpenException])
        case Left(error) =>
          recordFailure *> error.raiseError[F, Either[CircuitOpenException, A]]
      }

  // Updates success counters. Closes the circuit once enough consecutive successes accrue.
  private def recordSuccess: F[Unit] =
    for {
      now <- nowMillis
      updatedStats <- stats.updateAndGet(s => s.copy(
        successes = s.successes + 1,
        consecutiveSuccesses = s.consecutiveSuccesses + 1,
        consecutiveFailures = 0,
        lastSuccess = Some(now)
      ))
      _ <- (circuitState.set(CircuitClosed) *> stats.update(_.copy(consecutiveSuccesses = 0)))
        .whenA(updatedStats.consecutiveSuccesses >= config.successThreshold)
    } yield ()

  // Updates failure counters. Trips the circuit when the threshold is reached, or on any
  // failure while half-open, and always starts a new cooldown when it does.
  private def recordFailure: F[Unit] =
    for {
      now <- nowMillis
      updatedStats <- stats.updateAndGet(s => s.copy(
        failures = s.failures + 1,
        consecutiveFailures = s.consecutiveFailures + 1,
        consecutiveSuccesses = 0,
        lastFailure = Some(now)
      ))
      state <- circuitState.get
      _ <- tripOpen(now).whenA(
        updatedStats.consecutiveFailures >= config.failureThreshold || state == CircuitHalfOpen
      )
    } yield ()

  // Opens the circuit and rejects calls until `retryTimeoutSeconds` after `now`.
  private def tripOpen(now: Long): F[Unit] =
    circuitState.set(CircuitOpen) *>
      openUntilTime.set(now + config.retryTimeoutSeconds * 1000L)

  def getState: F[CircuitState] = circuitState.get
  def getStats: F[CircuitStats] = stats.get

  /** Returns the breaker to Closed with zeroed statistics. */
  def reset(): F[Unit] = {
    circuitState.set(CircuitClosed) *>
    stats.set(CircuitStats()) *>
    openUntilTime.set(0L)
  }
}

/** Rejection returned by [[CircuitBreaker.execute]] while the circuit for `serviceName` is open.
  * It is named distinctly from the `CircuitOpen` state object: two definitions named
  * `CircuitOpen` in one scope do not compile.
  */
case class CircuitOpenException(serviceName: String) extends Exception(s"Circuit breaker open for service: $serviceName")

// Print what the circuit breaker cell provides.
// The emoji, arrows and bullets are written as real UTF-8 (they had been garbled into
// Mac-Roman mojibake).
println("✅ Circuit Breaker with state machine implemented")
println("  • Closed → Open → Half-Open → Closed transitions")
println("  • Consecutive failure/success tracking")
println("  • Configurable thresholds and timeouts")
println()


## 🌊 Async Communication with Event Streams

Event-driven microservices using FS2 streams and Kafka patterns

In [None]:
// Event-driven async communication using fs2 streams
import fs2.{Pipe, Stream, Pull}
import cats.effect.concurrent.Ref
import scala.concurrent.duration._

// Domain events for async communication
/** Domain events exchanged asynchronously between services. */
sealed trait ServiceEvent

/** A user account was created.
  * NOTE(review): `timestamp` units are not fixed here — presumably epoch millis; confirm.
  */
case class UserCreated(userId: String, email: String, timestamp: Long)
  extends ServiceEvent

/** An order was placed by `userId` for `total`. */
case class OrderPlaced(orderId: String, userId: String, total: BigDecimal)
  extends ServiceEvent

/** A payment for `orderId` finished; EventProcessor reacts to "COMPLETED" and "FAILED". */
case class PaymentProcessed(paymentId: String, orderId: String, status: String)
  extends ServiceEvent

/** Stock for `productId` changed (presumably a signed delta — confirm with the producer). */
case class InventoryUpdated(productId: String, quantityChange: Int)
  extends ServiceEvent

/** A `ServiceEvent` together with its position in the event log.
  *
  * NOTE(review): `offset`/`partition` look like Kafka-style coordinates and `timestamp`
  * is presumably epoch millis; none of this is established in this file — confirm.
  */
case class EventEnvelope(event: ServiceEvent, offset: Long, partition: Int, timestamp: Long)

// Stream processing pipeline
/** Event-driven coordinator. Each envelope's event is fed to the user handler, the order
  * handler and the order saga.
  *
  * @param eventStream source of event envelopes; it is re-run from the start when
  *                    [[start]] restarts after an error
  */
class EventProcessor[F[_]: Concurrent: Timer](
  eventStream: Stream[F, EventEnvelope],
  userService: UserService[F],
  orderService: OrderService[F],
  paymentService: PaymentService[F],
  inventoryService: InventoryService[F]
) {

  // Unwraps each envelope. Every ServiceEvent variant is forwarded unchanged, so a plain
  // map suffices. The old per-type match returned bare events from `flatMap`, which
  // requires a Stream.
  val eventRouter: Pipe[F, EventEnvelope, ServiceEvent] = _.map(_.event)

  // User aggregate processor
  val userProcessor: Pipe[F, ServiceEvent, Unit] = _.flatMap {
    case UserCreated(userId, email, _) =>
      Stream.eval(userService.createUser(User(userId, email)))
    case OrderPlaced(_, userId, _) =>
      Stream.eval(userService.updateOrderHistory(userId))
    case _ => Stream.empty
  }

  // Order aggregate processor
  val orderProcessor: Pipe[F, ServiceEvent, Unit] = _.flatMap {
    case OrderPlaced(orderId, userId, total) =>
      Stream.eval(orderService.createOrder(Order(orderId, userId, total)))
    case PaymentProcessed(_, orderId, "COMPLETED") =>
      Stream.eval(orderService.markOrderPaid(orderId))
    case PaymentProcessed(_, orderId, "FAILED") =>
      Stream.eval(orderService.cancelOrder(orderId))
    case _ => Stream.empty
  }

  // Saga for the order -> payment -> inventory flow. Saga state lives in a Ref created once
  // per run of the pipe, and every event is handled in order. The previous Pull version
  // stopped after the first event and discarded the rest of the stream.
  val orderSagaProcessor: Pipe[F, ServiceEvent, Unit] = events =>
    Stream.eval(Ref.of[F, Map[String, SagaState]](Map.empty)).flatMap { sagaStates =>
      events.evalMap(event => advanceSaga(sagaStates, event))
    }

  // Applies one event to the saga state. When a payment outcome arrives for a known order,
  // it reserves inventory or cancels the order (compensation).
  private def advanceSaga(sagaStates: Ref[F, Map[String, SagaState]], event: ServiceEvent): F[Unit] =
    event match {
      case OrderPlaced(orderId, _, _) =>
        sagaStates.update(_ + (orderId -> SagaState(orderId, WaitingForPayment)))

      case PaymentProcessed(_, orderId, status) =>
        sagaStates.get.map(_.get(orderId)).flatMap {
          case Some(sagaState) if status == "COMPLETED" =>
            sagaStates.update(_ + (orderId -> sagaState.copy(phase = Completed))) *>
              inventoryService.reserveForOrder(orderId).void // Compensatable
          case Some(_) if status == "FAILED" =>
            // NOTE(review): orderProcessor also cancels the order on a FAILED payment, so
            // cancelOrder runs twice for the same event — confirm it is idempotent.
            sagaStates.update(_ - orderId) *>
              orderService.cancelOrder(orderId).void // Compensation
          case _ => Concurrent[F].unit
        }

      case _ => Concurrent[F].unit
    }

  // Runs the user and order handlers alongside the saga. `observe` expects a pipe that
  // emits nothing, hence the `.drain` on each handler.
  val processingPipeline: Stream[F, Unit] = {
    eventStream
      .through(eventRouter)
      .observe(userProcessor.andThen(_.drain))
      .observe(orderProcessor.andThen(_.drain))
      .through(orderSagaProcessor)
      .drain // Convert to Unit stream
  }

  /** Runs the pipeline. After a failure it logs, waits 2 seconds and restarts.
    * It restarts at most `maxRestarts` times; a further failure propagates.
    */
  def start(maxRestarts: Int = 1): Stream[F, Unit] = {
    def log(line: String): Stream[F, Unit] = Stream.eval(Concurrent[F].delay(println(line)))

    def runWithRestarts(remaining: Int): Stream[F, Unit] =
      processingPipeline.handleErrorWith { error =>
        if (remaining <= 0) Stream.raiseError[F](error)
        else
          // `sleep_` emits nothing, so the restart must be appended with `++`.
          // A flatMap-based `*>` would never run it.
          log(s"❌ Processing error: $error") >>
            (Stream.sleep_[F](2.seconds) ++ runWithRestarts(remaining - 1))
      }

    log("🚀 Starting event processing pipeline") >> runWithRestarts(maxRestarts)
  }
}

// Saga states for distributed transactions
/** Progress of an order saga.
  * NOTE(review): only WaitingForPayment and Completed are set by the code in this file.
  */
sealed trait SagaPhase

case object WaitingForPayment extends SagaPhase // order placed, payment outstanding
case object PaymentReceived extends SagaPhase
case object InventoryReserved extends SagaPhase
case object Completed extends SagaPhase

/** The saga tracked for one order. */
case class SagaState(
  orderId: String,
  phase: SagaPhase
)

// Print what the event-processing cell provides.
// The emoji and bullets are written as real UTF-8 (they had been garbled into Mac-Roman mojibake).
println("✅ Event-driven architecture implemented")
println("  • Async communication with fs2 streams")
println("  • Event sourcing patterns")
println("  • Saga coordination for transactions")
println("  • CQRS with separate read/write models")
println()


## 📊 Performance & Metrics

**Measuring and optimizing API Gateway performance**

In [None]:
// Comprehensive metrics and monitoring
/** Per-service request statistics; everything starts at zero.
  *
  * NOTE(review): MetricsAggregator in this file never writes `p95ResponseTime`,
  * `p99ResponseTime`, `activeConnections` or `queueSize`.
  */
case class APIMetrics(
  requestCount: Long = 0,           // successful requests recorded
  errorCount: Long = 0,             // errors recorded
  averageResponseTime: Double = 0.0,
  p95ResponseTime: Double = 0.0,
  p99ResponseTime: Double = 0.0,
  throughput: Double = 0.0,         // requests per second
  activeConnections: Long = 0,
  queueSize: Int = 0
)

/** Aggregates per-service request metrics plus a bounded window of recent response times.
  *
  * @param metrics      per-service counters keyed by service name
  * @param requestTimes recent response times in whole milliseconds, across all services,
  *                     oldest first; trimmed to the last `RequestTimesWindow` entries
  */
class MetricsAggregator[F[_]: Concurrent](
  metrics: Ref[F, Map[String, APIMetrics]],
  requestTimes: Ref[F, List[Long]]
) {

  // Number of recent response times kept for percentile calculations.
  private val RequestTimesWindow = 1000

  // Wall-clock time this aggregator was created; throughput is averaged over the time since.
  private val startedAtMillis: Long = System.currentTimeMillis()

  /** Records one request for `serviceName`.
    *
    * @param responseTime response time in nanoseconds
    */
  def recordRequest(serviceName: String, responseTime: Long): F[Unit] = {
    val requestTimeMs = responseTime.toDouble / 1e6 // Convert nanos to millis

    for {
      now <- Concurrent[F].delay(System.currentTimeMillis())
      _ <- updateMetrics(serviceName, requestTimeMs, now)
      // Append and trim in a single atomic update. Concurrent callers then never see, or
      // interleave with, an over-long window.
      _ <- requestTimes.update(times => (times :+ requestTimeMs.toLong).takeRight(RequestTimesWindow))
    } yield ()
  }

  def recordError(serviceName: String): F[Unit] = {
    metrics.update { current =>
      val currentMetrics = current.getOrElse(serviceName, APIMetrics())
      current + (serviceName -> currentMetrics.copy(errorCount = currentMetrics.errorCount + 1))
    }
  }

  def getMetrics(serviceName: String): F[APIMetrics] = {
    metrics.get.map(_.getOrElse(serviceName, APIMetrics()))
  }

  def getAllMetrics: F[Map[String, APIMetrics]] = metrics.get

  /** P95 and P99 response times (ms) over the retained window, across all services.
    * Both are 0 when nothing has been recorded yet.
    */
  def responseTimePercentiles: F[(Long, Long)] =
    requestTimes.get.map { times =>
      (calculatePercentile(times, 0.95), calculatePercentile(times, 0.99))
    }

  /** Healthy while the overall error rate stays below 5% (or nothing has been recorded). */
  def healthCheck: F[ServiceHealth] = {
    metrics.get.map { metricsMap =>
      val totalRequests = metricsMap.values.map(_.requestCount).sum
      val totalErrors = metricsMap.values.map(_.errorCount).sum

      if (totalRequests > 0) {
        val errorRate = totalErrors.toDouble / totalRequests
        if (errorRate < 0.05) ServiceHealthy else ServiceDegraded
      } else {
        ServiceHealthy
      }
    }
  }

  // Folds one response time into the service's count, running mean and throughput.
  private def updateMetrics(serviceName: String, responseTime: Double, nowMillis: Long): F[Unit] = {
    metrics.update { current =>
      val currentMetrics = current.getOrElse(serviceName, APIMetrics())
      val newRequestCount = currentMetrics.requestCount + 1
      val newAvgTime = ((currentMetrics.averageResponseTime * currentMetrics.requestCount) + responseTime) / newRequestCount

      current + (serviceName -> currentMetrics.copy(
        requestCount = newRequestCount,
        averageResponseTime = newAvgTime,
        throughput = calculateThroughput(newRequestCount, nowMillis)
      ))
    }
  }

  // Nearest-rank (rounded down) percentile of `times`; 0 for an empty list.
  private def calculatePercentile(times: List[Long], percentile: Double): Long = {
    if (times.isEmpty) 0L
    else {
      val sorted = times.sorted
      val index = ((sorted.length - 1) * percentile).toInt
      sorted(index)
    }
  }

  // Average requests per second since this aggregator was created. Elapsed time is clamped
  // to 1 ms so an immediate first request cannot divide by zero. The old code divided by
  // the epoch timestamp, which always gave a value near 0.
  private def calculateThroughput(requestCount: Long, nowMillis: Long): Double = {
    val elapsedMillis = math.max(nowMillis - startedAtMillis, 1L)
    requestCount.toDouble * 1000 / elapsedMillis
  }
}

/** Coarse health verdict reported by `MetricsAggregator.healthCheck`. */
sealed trait ServiceHealth

/** Error rate below 5%, or no traffic yet. */
case object ServiceHealthy extends ServiceHealth

/** Error rate at or above 5%. */
case object ServiceDegraded extends ServiceHealth

// Print what the metrics cell provides.
// The emoji and bullets are written as real UTF-8 (they had been garbled into Mac-Roman mojibake).
println("✅ Comprehensive metrics system implemented")
println("  • Response time percentiles (P95, P99)")
println("  • Throughput and request rates")
println("  • Error rate monitoring")
println("  • Health check integration")
println()
