Skip to content

Commit

Permalink
Aesthetic clean-up / format OAuth code (#7782)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopheDuong committed Nov 9, 2021
1 parent 3d4f730 commit 21d6dd9
Show file tree
Hide file tree
Showing 37 changed files with 690 additions and 697 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
public class ElasticsearchAirbyteMessageConsumerFactory {

private static final Logger log = LoggerFactory.getLogger(ElasticsearchAirbyteMessageConsumerFactory.class);
private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4 ; // 256mib
private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4; // 256mib
private static final ObjectMapper mapper = new ObjectMapper();

private static AtomicLong recordsWritten = new AtomicLong(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ static void insertRawRecordsInSingleQuery(final String insertQueryComponent,
// string. Thus there will be two loops below.
// 1) Loop over records to build the full string.
// 2) Loop over the records and bind the appropriate values to the string.
// We also partition the query to run on 10k records at a time, since some DBs set a max limit on how many records can be inserted at once
// TODO(sherif) this should use a smarter, destination-aware partitioning scheme instead of 10k by default
for (List<AirbyteRecordMessage> partition : Iterables.partition(records, 10_000)){
// We also partition the query to run on 10k records at a time, since some DBs set a max limit on
// how many records can be inserted at once
// TODO(sherif) this should use a smarter, destination-aware partitioning scheme instead of 10k by
// default
for (List<AirbyteRecordMessage> partition : Iterables.partition(records, 10_000)) {
final StringBuilder sql = new StringBuilder(insertQueryComponent);
partition.forEach(r -> sql.append(recordQueryComponent));
final String s = sql.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class CopyConsumerFactory {

private static final Logger LOGGER = LoggerFactory.getLogger(CopyConsumerFactory.class);

private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4 ; // 256 mib
private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4; // 256 mib

public static <T> AirbyteMessageConsumer create(final Consumer<AirbyteMessage> outputRecordCollector,
final JdbcDatabase database,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public class MeiliSearchDestination extends BaseConnector implements Destination

private static final Logger LOGGER = LoggerFactory.getLogger(MeiliSearchDestination.class);

private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4 ; //256mib
private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4; // 256mib
private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("uuuu-MM-dd'T'HH:mm:ss.SSSSSSSSS");

public static final String AB_PK_COLUMN = "_ab_pk";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.airbyte.commons.io.IOs;
import io.airbyte.commons.json.Jsons;
import io.airbyte.commons.resources.MoreResources;
import io.airbyte.commons.string.Strings;
import io.airbyte.db.jdbc.JdbcUtils;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.JavaBaseConstants;
import io.airbyte.integrations.destination.ExtendedNameTransformer;
import io.airbyte.integrations.standardtest.destination.DataArgumentsProvider;
Expand All @@ -22,19 +20,14 @@
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.CatalogHelpers;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Scanner;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ArgumentsSource;

Expand Down Expand Up @@ -170,11 +163,12 @@ public void testSyncWithBillionRecords(final String messagesFilename, final Stri
runSyncAndVerifyStateOutput(config, largeNumberRecords, configuredCatalog, false);
}


private <T> T parseConfig(final String path, Class<T> clazz) throws IOException {
return Jsons.deserialize(MoreResources.readResource(path), clazz);
}
private JsonNode parseConfig(final String path) throws IOException {

private JsonNode parseConfig(final String path) throws IOException {
return Jsons.deserialize(MoreResources.readResource(path));
}

}
244 changes: 244 additions & 0 deletions airbyte-oauth/src/main/java/io/airbyte/oauth/BaseOAuth2Flow.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright (c) 2021 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.oauth;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.persistence.ConfigNotFoundException;
import io.airbyte.config.persistence.ConfigRepository;
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.commons.lang3.RandomStringUtils;

/**
* Abstract Class factoring common behavior for oAuth 2.0 flow implementations
*/
public abstract class BaseOAuth2Flow extends BaseOAuthFlow {

/**
* Simple enum of content type strings and their respective encoding functions used for POSTing the
* access token request
*/
public enum TOKEN_REQUEST_CONTENT_TYPE {

URL_ENCODED("application/x-www-form-urlencoded", BaseOAuth2Flow::toUrlEncodedString),
JSON("application/json", BaseOAuth2Flow::toJson);

String contentType;
Function<Map<String, String>, String> converter;

TOKEN_REQUEST_CONTENT_TYPE(final String contentType, final Function<Map<String, String>, String> converter) {
this.contentType = contentType;
this.converter = converter;
}

}

protected final HttpClient httpClient;
private final TOKEN_REQUEST_CONTENT_TYPE tokenReqContentType;
private final Supplier<String> stateSupplier;

public BaseOAuth2Flow(final ConfigRepository configRepository, final HttpClient httpClient) {
this(configRepository, httpClient, BaseOAuth2Flow::generateRandomState);
}

public BaseOAuth2Flow(final ConfigRepository configRepository, final HttpClient httpClient, final Supplier<String> stateSupplier) {
this(configRepository, httpClient, stateSupplier, TOKEN_REQUEST_CONTENT_TYPE.URL_ENCODED);
}

public BaseOAuth2Flow(final ConfigRepository configRepository,
final HttpClient httpClient,
final Supplier<String> stateSupplier,
final TOKEN_REQUEST_CONTENT_TYPE tokenReqContentType) {
super(configRepository);
this.httpClient = httpClient;
this.stateSupplier = stateSupplier;
this.tokenReqContentType = tokenReqContentType;
}

@Override
public String getSourceConsentUrl(final UUID workspaceId, final UUID sourceDefinitionId, final String redirectUrl)
throws IOException, ConfigNotFoundException {
final JsonNode oAuthParamConfig = getSourceOAuthParamConfig(workspaceId, sourceDefinitionId);
return formatConsentUrl(sourceDefinitionId, getClientIdUnsafe(oAuthParamConfig), redirectUrl);
}

@Override
public String getDestinationConsentUrl(final UUID workspaceId, final UUID destinationDefinitionId, final String redirectUrl)
throws IOException, ConfigNotFoundException {
final JsonNode oAuthParamConfig = getDestinationOAuthParamConfig(workspaceId, destinationDefinitionId);
return formatConsentUrl(destinationDefinitionId, getClientIdUnsafe(oAuthParamConfig), redirectUrl);
}

/**
* Depending on the OAuth flow implementation, the URL to grant user's consent may differ,
* especially in the query parameters to be provided. This function should generate such consent URL
* accordingly.
*/
protected abstract String formatConsentUrl(UUID definitionId, String clientId, String redirectUrl) throws IOException;

private static String generateRandomState() {
return RandomStringUtils.randomAlphanumeric(7);
}

/**
* Generate a string to use as state in the OAuth process.
*/
protected String getState() {
return stateSupplier.get();
}

@Override
public Map<String, Object> completeSourceOAuth(final UUID workspaceId,
final UUID sourceDefinitionId,
final Map<String, Object> queryParams,
final String redirectUrl)
throws IOException, ConfigNotFoundException {
final JsonNode oAuthParamConfig = getSourceOAuthParamConfig(workspaceId, sourceDefinitionId);
return formatOAuthOutput(
oAuthParamConfig,
completeOAuthFlow(
getClientIdUnsafe(oAuthParamConfig),
getClientSecretUnsafe(oAuthParamConfig),
extractCodeParameter(queryParams),
redirectUrl,
oAuthParamConfig),
getDefaultOAuthOutputPath());
}

@Override
public Map<String, Object> completeDestinationOAuth(final UUID workspaceId,
final UUID destinationDefinitionId,
final Map<String, Object> queryParams,
final String redirectUrl)
throws IOException, ConfigNotFoundException {
final JsonNode oAuthParamConfig = getDestinationOAuthParamConfig(workspaceId, destinationDefinitionId);
return formatOAuthOutput(
oAuthParamConfig,
completeOAuthFlow(
getClientIdUnsafe(oAuthParamConfig),
getClientSecretUnsafe(oAuthParamConfig),
extractCodeParameter(queryParams),
redirectUrl,
oAuthParamConfig),
getDefaultOAuthOutputPath());
}

protected Map<String, Object> completeOAuthFlow(final String clientId,
final String clientSecret,
final String authCode,
final String redirectUrl,
final JsonNode oAuthParamConfig)
throws IOException {
final var accessTokenUrl = getAccessTokenUrl();
final HttpRequest request = HttpRequest.newBuilder()
.POST(HttpRequest.BodyPublishers
.ofString(tokenReqContentType.converter.apply(getAccessTokenQueryParameters(clientId, clientSecret, authCode, redirectUrl))))
.uri(URI.create(accessTokenUrl))
.header("Content-Type", tokenReqContentType.contentType)
.header("Accept", "application/json")
.build();
// TODO: Handle error response to report better messages
try {
final HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
return extractOAuthOutput(Jsons.deserialize(response.body()), accessTokenUrl);
} catch (final InterruptedException e) {
throw new IOException("Failed to complete OAuth flow", e);
}
}

/**
* Query parameters to provide the access token url with.
*/
protected Map<String, String> getAccessTokenQueryParameters(final String clientId,
final String clientSecret,
final String authCode,
final String redirectUrl) {
return ImmutableMap.<String, String>builder()
// required
.put("client_id", clientId)
.put("redirect_uri", redirectUrl)
.put("client_secret", clientSecret)
.put("code", authCode)
.build();
}

/**
* Once the user is redirected after getting their consent, the API should redirect them to a
* specific redirection URL along with query parameters. This function should parse and extract the
* code from these query parameters in order to continue the OAuth Flow.
*/
protected String extractCodeParameter(final Map<String, Object> queryParams) throws IOException {
if (queryParams.containsKey("code")) {
return (String) queryParams.get("code");
} else {
throw new IOException("Undefined 'code' from consent redirected url.");
}
}

/**
* Returns the URL where to retrieve the access token from.
*/
protected abstract String getAccessTokenUrl();

/**
* Extract all OAuth outputs from distant API response and store them in a flat map.
*/
protected Map<String, Object> extractOAuthOutput(final JsonNode data, final String accessTokenUrl) throws IOException {
final Map<String, Object> result = new HashMap<>();
if (data.has("refresh_token")) {
result.put("refresh_token", data.get("refresh_token").asText());
} else {
throw new IOException(String.format("Missing 'refresh_token' in query params from %s", accessTokenUrl));
}
return result;
}

@Override
protected List<String> getDefaultOAuthOutputPath() {
return List.of("credentials");
}

private static String urlEncode(final String s) {
try {
return URLEncoder.encode(s, StandardCharsets.UTF_8);
} catch (final Exception e) {
throw new RuntimeException(e);
}
}

private static String toUrlEncodedString(final Map<String, String> body) {
final StringBuilder result = new StringBuilder();
for (final var entry : body.entrySet()) {
if (result.length() > 0) {
result.append("&");
}
result.append(entry.getKey()).append("=").append(urlEncode(entry.getValue()));
}
return result.toString();
}

protected static String toJson(final Map<String, String> body) {
final Gson gson = new Gson();
final Type gsonType = new TypeToken<Map<String, String>>() {}.getType();
return gson.toJson(body, gsonType);
}

}

0 comments on commit 21d6dd9

Please sign in to comment.