Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support VPC Lattice as an event source #845

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions aws-serverless-java-container-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
<version>1.2.3</version>
</dependency>

<dependency>
<groupId>org.jetbrains</groupId>
<artifactId>annotations</artifactId>
<version>24.0.1</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>jakarta.servlet</groupId>
<artifactId>jakarta.servlet-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.amazonaws.serverless.proxy;

import com.amazonaws.serverless.proxy.internal.jaxrs.AwsVpcLatticeV2SecurityContext;
import com.amazonaws.serverless.proxy.model.VPCLatticeV2RequestEvent;
import com.amazonaws.services.lambda.runtime.Context;
import jakarta.ws.rs.core.SecurityContext;

public class AwsVPCLatticeV2SecurityContextWriter implements SecurityContextWriter<VPCLatticeV2RequestEvent>{
@Override
public SecurityContext writeSecurityContext(VPCLatticeV2RequestEvent event, Context lambdaContext) {
return new AwsVpcLatticeV2SecurityContext(lambdaContext, event);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public abstract class RequestReader<RequestType, ContainerRequestType> {
*/
public static final String API_GATEWAY_CONTEXT_PROPERTY = "com.amazonaws.apigateway.request.context";

/**
* The key for the <strong>VPC Lattice V2 context</strong> property in the PropertiesDelegate object
*/
public static final String VPC_LATTICE_V2_CONTEXT_PROPERTY = "com.amazonaws.vpclattice.request.context";

/**
* The key for the <strong>API Gateway stage variables</strong> property in the PropertiesDelegate object
*/
Expand All @@ -55,6 +60,11 @@ public abstract class RequestReader<RequestType, ContainerRequestType> {
*/
public static final String API_GATEWAY_EVENT_PROPERTY = "com.amazonaws.apigateway.request";

/**
* The key to store the entire VPC Lattice V2 event
*/
public static final String VPC_LATTICE_V2_EVENT_PROPERTY = "com.amazonaws.vpclattice.request";

/**
* The key for the <strong>AWS Lambda context</strong> property in the PropertiesDelegate object
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.amazonaws.serverless.proxy.internal.jaxrs;

import com.amazonaws.serverless.proxy.model.VPCLatticeV2RequestEvent;
import com.amazonaws.services.lambda.runtime.Context;
import jakarta.ws.rs.core.SecurityContext;

import java.security.Principal;
import java.util.Objects;

/**
* default implementation of the <code>SecurityContext</code> object. This class supports 1 VPC Lattice authentication type:
* AWS_IAM.
*/
public class AwsVpcLatticeV2SecurityContext implements SecurityContext {

static final String AUTH_SCHEME_AWS_IAM = "AWS_IAM";


private final VPCLatticeV2RequestEvent event;

public AwsVpcLatticeV2SecurityContext(Context lambdaContext, VPCLatticeV2RequestEvent event) {
this.event = event;
}

//-------------------------------------------------------------
// Implementation - SecurityContext
//-------------------------------------------------------------
@Override
public Principal getUserPrincipal() {
if (Objects.equals(getAuthenticationScheme(), AUTH_SCHEME_AWS_IAM)) {
return () -> getEvent().getRequestContext().getIdentity().getPrincipal();
}
return null;
}

private VPCLatticeV2RequestEvent getEvent() {
return event;
}


@Override
public boolean isUserInRole(String role) {
return role.equals(event.getRequestContext().getIdentity().getPrincipal());
}

@Override
public boolean isSecure() {
return getAuthenticationScheme() != null;
}

@Override
public String getAuthenticationScheme() {
if (Objects.equals(getEvent().getRequestContext().getIdentity().getType(), AUTH_SCHEME_AWS_IAM)) {
return AUTH_SCHEME_AWS_IAM;
} else {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ public class AwsHttpApiV2ProxyHttpServletRequest extends AwsHttpServletRequest {
private MultiValuedTreeMap<String, String> queryString;
private Headers headers;
private ContainerConfig config;
private SecurityContext securityContext;
private AwsAsyncContext asyncContext;

/**
Expand All @@ -57,10 +56,9 @@ public class AwsHttpApiV2ProxyHttpServletRequest extends AwsHttpServletRequest {
* @param lambdaContext The Lambda function context. This object is used for utility methods such as log
*/
public AwsHttpApiV2ProxyHttpServletRequest(HttpApiV2ProxyRequest req, Context lambdaContext, SecurityContext sc, ContainerConfig cfg) {
super(lambdaContext);
super(lambdaContext, sc);
request = req;
config = cfg;
securityContext = sc;
queryString = parseRawQueryString(request.getRawQueryString());
headers = headersMapToMultiValue(request.getHeaders());
}
Expand All @@ -69,12 +67,6 @@ public HttpApiV2ProxyRequest getRequest() {
return request;
}

@Override
public String getAuthType() {
// TODO
return null;
}

@Override
public Cookie[] getCookies() {
Cookie[] rhc;
Expand Down Expand Up @@ -108,56 +100,27 @@ public Cookie[] getCookies() {

@Override
public long getDateHeader(String s) {
if (headers == null) {
return -1L;
}
String dateString = headers.getFirst(s);
if (dateString == null) {
return -1L;
}
try {
return Instant.from(ZonedDateTime.parse(dateString, dateFormatter)).toEpochMilli();
} catch (DateTimeParseException e) {
log.warn("Invalid date header in request: " + SecurityUtils.crlf(dateString));
return -1L;
}
return getDateHeader(s, headers);
}

@Override
public String getHeader(String s) {
if (headers == null) {
return null;
}
return headers.getFirst(s);
return getHeader(s, headers);
}

@Override
public Enumeration<String> getHeaders(String s) {
if (headers == null || !headers.containsKey(s)) {
return Collections.emptyEnumeration();
}
return Collections.enumeration(headers.get(s));
return getHeaders(s, headers);
}

@Override
public Enumeration<String> getHeaderNames() {
if (headers == null) {
return Collections.emptyEnumeration();
}
return Collections.enumeration(headers.keySet());
return getHeaderNames(headers);
}

@Override
public int getIntHeader(String s) {
if (headers == null) {
return -1;
}
String headerValue = headers.getFirst(s);
if (headerValue == null || "".equals(headerValue)) {
return -1;
}

return Integer.parseInt(headerValue);
return getIntHeader(s, headers);
}

@Override
Expand Down Expand Up @@ -187,28 +150,6 @@ public String getQueryString() {
return request.getRawQueryString();
}

@Override
public String getRemoteUser() {
if (securityContext == null || securityContext.getUserPrincipal() == null) {
return null;
}
return securityContext.getUserPrincipal().getName();
}

@Override
public boolean isUserInRole(String s) {
// TODO: Not supported
return false;
}

@Override
public Principal getUserPrincipal() {
if (securityContext == null) {
return null;
}
return securityContext.getUserPrincipal();
}

@Override
public String getRequestURI() {
return cleanUri(getContextPath()) + cleanUri(request.getRawPath());
Expand All @@ -219,27 +160,6 @@ public StringBuffer getRequestURL() {
return generateRequestURL(request.getRawPath());
}


@Override
public boolean authenticate(HttpServletResponse httpServletResponse) throws IOException, ServletException {
throw new UnsupportedOperationException();
}

@Override
public void login(String s, String s1) throws ServletException {
throw new UnsupportedOperationException();
}

@Override
public void logout() throws ServletException {
throw new UnsupportedOperationException();
}

@Override
public <T extends HttpUpgradeHandler> T upgrade(Class<T> aClass) throws IOException, ServletException {
throw new UnsupportedOperationException();
}

@Override
public String getCharacterEncoding() {
if (headers == null) {
Expand All @@ -250,30 +170,17 @@ public String getCharacterEncoding() {

@Override
public void setCharacterEncoding(String s) throws UnsupportedEncodingException {
if (headers == null || !headers.containsKey(HttpHeaders.CONTENT_TYPE)) {
log.debug("Called set character encoding to " + SecurityUtils.crlf(s) + " on a request without a content type. Character encoding will not be set");
return;
}
String currentContentType = headers.getFirst(HttpHeaders.CONTENT_TYPE);
headers.putSingle(HttpHeaders.CONTENT_TYPE, appendCharacterEncoding(currentContentType, s));
setCharacterEncoding(s, headers);
}

@Override
public int getContentLength() {
String headerValue = headers.getFirst(HttpHeaders.CONTENT_LENGTH);
if (headerValue == null) {
return -1;
}
return Integer.parseInt(headerValue);
return getContentLength(headers);
}

@Override
public long getContentLengthLong() {
String headerValue = headers.getFirst(HttpHeaders.CONTENT_LENGTH);
if (headerValue == null) {
return -1;
}
return Long.parseLong(headerValue);
return getContentLengthLong(headers);
}

@Override
Expand All @@ -286,17 +193,7 @@ public String getContentType() {

@Override
public String getParameter(String s) {
String queryStringParameter = getFirstQueryParamValue(queryString, s, config.isQueryStringCaseSensitive());
if (queryStringParameter != null) {
return queryStringParameter;
}

String[] bodyParams = getFormBodyParameterCaseInsensitive(s);
if (bodyParams.length == 0) {
return null;
} else {
return bodyParams[0];
}
return getParameter(queryString, s, config.isQueryStringCaseSensitive());
}

@Override
Expand All @@ -315,7 +212,7 @@ public String[] getParameterValues(String s) {

values.addAll(Arrays.asList(getFormBodyParameterCaseInsensitive(s)));

if (values.size() == 0) {
if (values.isEmpty()) {
return null;
} else {
return values.toArray(new String[0]);
Expand Down Expand Up @@ -409,16 +306,6 @@ public Enumeration<Locale> getLocales() {
return Collections.enumeration(locales);
}

@Override
public boolean isSecure() {
return securityContext.isSecure();
}

@Override
public RequestDispatcher getRequestDispatcher(String s) {
return getServletContext().getRequestDispatcher(s);
}

@Override
public int getRemotePort() {
return 0;
Expand Down Expand Up @@ -456,6 +343,8 @@ public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse se
return asyncContext;
}



@Override
public AsyncContext getAsyncContext() {
if (asyncContext == null) {
Expand All @@ -475,11 +364,6 @@ public String getProtocolRequestId() {
return "";
}

@Override
public ServletConnection getServletConnection() {
return null;
}

private MultiValuedTreeMap<String, String> parseRawQueryString(String qs) {
if (qs == null || "".equals(qs.trim())) {
return new MultiValuedTreeMap<>();
Expand All @@ -505,7 +389,7 @@ private MultiValuedTreeMap<String, String> parseRawQueryString(String qs) {
return qsMap;
}

private Headers headersMapToMultiValue(Map<String, String> headers) {
protected static Headers headersMapToMultiValue(Map<String, String> headers) {
if (headers == null || headers.size() == 0) {
return new Headers();
}
Expand Down