diff --git a/SingularityClusterCoordinator/src/main/java/com/hubspot/singularity/proxy/ProxyResource.java b/SingularityClusterCoordinator/src/main/java/com/hubspot/singularity/proxy/ProxyResource.java index 4fe143485d..0ed04e84b1 100644 --- a/SingularityClusterCoordinator/src/main/java/com/hubspot/singularity/proxy/ProxyResource.java +++ b/SingularityClusterCoordinator/src/main/java/com/hubspot/singularity/proxy/ProxyResource.java @@ -1,7 +1,7 @@ package com.hubspot.singularity.proxy; +import java.util.ArrayList; import java.util.Enumeration; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -70,8 +70,8 @@ public Response getMergedListResult(HttpServletRequest request) { } public Response getMergedListResult(HttpServletRequest request, T body) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); // TODO - parallelize List combined = Lists.newArrayList(); @@ -111,8 +111,8 @@ public Response routeByRequestId(HttpServletRequest request, String requestId) { } public Response routeByRequestId(HttpServletRequest request, String requestId, T body) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = getDataCenterForRequest(requestId); @@ -123,8 +123,8 @@ public Response routeByRequestId(HttpServletRequest request, String requestI * Route a request to a particular dataCenter using the request group Id to locate the correct Singularity cluster */ Response routeByRequestGroupId(HttpServletRequest request, String requestGroupId) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = getDataCenterForRequestGroup(requestGroupId); @@ -135,8 +135,8 @@ Response routeByRequestGroupId(HttpServletRequest request, String requestGroupId * Route a request to a particular dataCenter using the slaveId/hostname to locate the correct Singularity cluster */ Response routeBySlaveId(HttpServletRequest request, String slaveId) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = getDataCenterForSlaveId(slaveId); @@ -144,8 +144,8 @@ Response routeBySlaveId(HttpServletRequest request, String slaveId) { } Response routeByHostname(HttpServletRequest request, String hostname) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = getDataCenterForSlaveHostname(hostname); @@ -156,8 +156,8 @@ Response routeByHostname(HttpServletRequest request, String hostname) { * Route a request to a particular dataCenter using the rack ID to locate the correct Singularity cluster */ Response routeByRackId(HttpServletRequest request, String rackId) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = getDataCenterForRackId(rackId); @@ -168,8 +168,8 @@ Response routeByRackId(HttpServletRequest request, String rackId) { * Route a request to a particular dataCenter by name, failing if it is not present */ Response routeByDataCenter(HttpServletRequest request, String dataCenterName, T body) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = getDataCenter(dataCenterName); @@ -184,8 +184,8 @@ Response routeToDefaultDataCenter(HttpServletRequest request) { } Response routeToDefaultDataCenter(HttpServletRequest request, T body) { - Map headers = getHeaders(request); - Map params = getParams(request); + List headers = getHeaders(request); + List params = getParams(request); DataCenter dataCenter = configuration.getDataCenters().get(0); @@ -226,7 +226,7 @@ private DataCenter getDataCenter(String name) { /* * Generic methods for proxying requests */ - private HttpResponse proxyAndGetResponse(DataCenter dc, HttpServletRequest request, T body, Map headers, Map params) { + private HttpResponse proxyAndGetResponse(DataCenter dc, HttpServletRequest request, T body, List headers, List params) { String fullPath = request.getContextPath() + request.getPathInfo(); String url = String.format("%s://%s%s", dc.getScheme(), getHost(dc), fullPath.replace(contextPath, dc.getContextPath())); @@ -243,8 +243,8 @@ private HttpResponse proxyAndGetResponse(DataCenter dc, HttpServletRequest r LOG.error("Could not write body from object {}", body); throw new WebApplicationException(jpe, 500); } - headers.forEach(requestBuilder::addHeader); - params.forEach((k, v) -> requestBuilder.setQueryParam(k).to(v)); + headers.forEach((h) -> requestBuilder.addHeader(h.getKey(), h.getValue())); + params.forEach((h) -> requestBuilder.setQueryParam(h.getKey()).to(h.getValue())); try { return httpClient.execute(requestBuilder.build()).get(); @@ -253,7 +253,7 @@ private HttpResponse proxyAndGetResponse(DataCenter dc, HttpServletRequest r } } - private T proxyAndGetResponseAs(DataCenter dc, HttpServletRequest request, Q body, TypeReference clazz, Map headers, Map params) { + private T proxyAndGetResponseAs(DataCenter dc, HttpServletRequest request, Q body, TypeReference clazz, List headers, List params) { HttpResponse response = proxyAndGetResponse(dc, request, body, headers, params); if (response.getStatusCode() > 399) { throw new WebApplicationException(response.getAsString(), response.getStatusCode()); @@ -269,29 +269,27 @@ private T proxyAndGetResponseAs(DataCenter dc, HttpServletRequest request } } - private Map getHeaders(HttpServletRequest request) { - Map headers = new HashMap<>(); + private List getHeaders(HttpServletRequest request) { + List headers = new ArrayList<>(); Enumeration headerNames = request.getHeaderNames(); if (headerNames != null) { while (headerNames.hasMoreElements()) { String headerName = headerNames.nextElement(); - headers.put(headerName, request.getHeader(headerName)); + headers.add(new Param(headerName, request.getHeader(headerName))); } } LOG.trace("Found headers: {}", headers); return headers; } - private Map getParams(HttpServletRequest request) { - Map params = new HashMap<>(); - Enumeration parameterNames = request.getParameterNames(); - if (parameterNames != null) { - while (parameterNames.hasMoreElements()) { - String parameterName = parameterNames.nextElement(); - params.put(parameterName, request.getParameter(parameterName)); + private List getParams(HttpServletRequest request) { + List params = new ArrayList<>(); + for (Map.Entry entry : request.getParameterMap().entrySet()) { + for (String value : entry.getValue()) { + params.add(new Param(entry.getKey(), value)); } } - LOG.trace("Found query params: {}", params); + LOG.trace("Found params {}", params); return params; } @@ -301,4 +299,22 @@ private Response toResponse(HttpResponse original) { original.getHeaders().forEach((h) -> builder.header(h.getName(), h.getValue())); return builder.build(); } + + private class Param { + private final String key; + private final String value; + + Param(String key, String value) { + this.key = key; + this.value = value; + } + + String getKey() { + return key; + } + + String getValue() { + return value; + } + } }