Skip to content

Commit

Permalink
Added unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
cgivre committed May 15, 2022
1 parent e40ab3d commit 335dc63
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 26 deletions.
Expand Up @@ -93,6 +93,18 @@ public TokenRegistry getTokenRegistry() {
return tokenRegistry;
}

/**
* This method returns the {@link TokenRegistry} for a given user. It is only used for testing user translation
* with OAuth 2.0.
* @param username A {@link String} of the current active user.
* @return A {@link TokenRegistry} for the given user.
*/
@VisibleForTesting
public TokenRegistry getTokenRegistry(String username) {
initializeOauthTokenTable(username);
return tokenRegistry;
}

public PersistentTokenTable getTokenTable() { return tokenRegistry.getTokenTable(getName()); }

@Override
Expand Down
Expand Up @@ -20,13 +20,9 @@

import okhttp3.Cookie;
import okhttp3.CookieJar;
import okhttp3.FormBody;
import okhttp3.Headers;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
Expand All @@ -37,25 +33,35 @@
import org.apache.drill.common.logical.CredentialedStoragePluginConfig;
import org.apache.drill.common.logical.OAuthConfig;
import org.apache.drill.common.logical.StoragePluginConfig.AuthMode;
import org.apache.drill.common.logical.security.CredentialsProvider;
import org.apache.drill.common.logical.security.PlainCredentialsProvider;
import org.apache.drill.common.types.TypeProtos.DataMode;
import org.apache.drill.common.types.TypeProtos.MinorType;
import org.apache.drill.common.util.DrillFileUtils;
import org.apache.drill.exec.ExecConstants;
import org.apache.drill.exec.oauth.PersistentTokenTable;
import org.apache.drill.exec.physical.rowSet.RowSet;
import org.apache.drill.exec.rpc.user.security.testing.UserAuthenticatorTestImpl;
import org.apache.drill.exec.physical.rowSet.RowSetBuilder;
import org.apache.drill.exec.record.metadata.SchemaBuilder;
import org.apache.drill.exec.record.metadata.TupleMetadata;
import org.apache.drill.exec.store.StoragePlugin;
import org.apache.drill.exec.store.StoragePluginRegistry;
import org.apache.drill.exec.store.security.oauth.OAuthTokenCredentials;
import org.apache.drill.shaded.guava.com.google.common.base.Charsets;
import org.apache.drill.shaded.guava.com.google.common.io.Files;
import org.apache.drill.test.BaseDirTestWatcher;
import org.apache.drill.test.ClientFixture;
import org.apache.drill.test.ClusterFixtureBuilder;
import org.apache.drill.test.ClusterTest;
import org.apache.drill.test.QueryBuilder.QuerySummary;
import org.apache.drill.test.rowSet.RowSetUtilities;
import org.jetbrains.annotations.NotNull;
import org.junit.After;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -75,6 +81,7 @@
import static org.junit.Assert.fail;

public class TestUserTranslationInHttpPlugin extends ClusterTest {
private static final Logger logger = LoggerFactory.getLogger(TestUserTranslationInHttpPlugin.class);

private static final int MOCK_SERVER_PORT = 47775;
private static final int TIMEOUT = 30;
Expand All @@ -86,6 +93,7 @@ public class TestUserTranslationInHttpPlugin extends ClusterTest {
.cookieJar(new TestCookieJar())
.build();
private static String TEST_JSON_RESPONSE_WITH_DATATYPES;
private static String ACCESS_TOKEN_RESPONSE;
private static int portNumber;


Expand All @@ -100,11 +108,11 @@ public void cleanup() throws Exception {
@BeforeClass
public static void setup() throws Exception {
TEST_JSON_RESPONSE_WITH_DATATYPES = Files.asCharSource(DrillFileUtils.getResourceAsFile("/data/response2.json"), Charsets.UTF_8).read();
ACCESS_TOKEN_RESPONSE = Files.asCharSource(DrillFileUtils.getResourceAsFile("/data/oauth_access_token_response.json"), Charsets.UTF_8).read();

ClusterFixtureBuilder builder = new ClusterFixtureBuilder(dirTestWatcher)
.configProperty(ExecConstants.HTTP_ENABLE, true)
.configProperty(ExecConstants.HTTP_PORT_HUNT, true)
.configProperty(ExecConstants.USER_AUTHENTICATOR_IMPL, UserAuthenticatorTestImpl.TYPE)
.configProperty(ExecConstants.IMPERSONATION_ENABLED, true);

startCluster(builder);
Expand All @@ -123,6 +131,15 @@ public static void setup() throws Exception {
.callbackURL(makeUrl("http://localhost:%d") + "/update_oauth2_authtoken")
.build();

Map<String, String> oauthCreds = new HashMap<>();
oauthCreds.put("clientID", "12345");
oauthCreds.put("clientSecret", "54321");;
oauthCreds.put(OAuthTokenCredentials.TOKEN_URI, "http://localhost:" + MOCK_SERVER_PORT + "/get_access_token");

CredentialsProvider oauthCredentialProvider = new PlainCredentialsProvider(oauthCreds);



Map<String, HttpApiConfig> configs = new HashMap<>();
configs.put("sharedEndpoint", testEndpoint);

Expand All @@ -132,11 +149,16 @@ public static void setup() throws Exception {

PlainCredentialsProvider credentialsProvider = new PlainCredentialsProvider(TEST_USER_2, credentials);

HttpStoragePluginConfig mockStorageConfigWithWorkspace = new HttpStoragePluginConfig(false, configs, 2, null, null, "", 80, "", "", "", oAuthConfig, credentialsProvider,
HttpStoragePluginConfig mockStorageConfigWithWorkspace = new HttpStoragePluginConfig(false, configs, 2, null, null, "", 80, "", "", "", null, credentialsProvider,
AuthMode.USER_TRANSLATION.name());
mockStorageConfigWithWorkspace.setEnabled(true);

HttpStoragePluginConfig mockOAuthPlugin = new HttpStoragePluginConfig(false, configs, 2, null, null, "", 80, "", "", "", oAuthConfig, oauthCredentialProvider,
AuthMode.USER_TRANSLATION.name());
mockOAuthPlugin.setEnabled(true);

cluster.defineStoragePlugin("local", mockStorageConfigWithWorkspace);
cluster.defineStoragePlugin("oauth", mockOAuthPlugin);
}

@Test
Expand Down Expand Up @@ -203,31 +225,57 @@ public void testQueryWithMissingCredentials() throws Exception {
}
}

private boolean makeLoginRequest(String username, String password) throws IOException {
String loginURL = "http://localhost:" + portNumber + "/j_security_check";

RequestBody formBody = new FormBody.Builder()
.add("j_username", username)
.add("j_password", password)
@Test
public void testQueryWithOAuth() throws Exception {
ClientFixture client = cluster
.clientBuilder()
.property(DrillProperties.USER, TEST_USER_2)
.property(DrillProperties.PASSWORD, TEST_USER_2_PASSWORD)
.build();

Request request = new Request.Builder()
.url(loginURL)
.post(formBody)
.addHeader("Content-Type", "application/x-www-form-urlencoded")
.addHeader("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8")
.build();
try (MockWebServer server = startServer()) {
// Get the token table for test user 2, which should be empty
PersistentTokenTable tokenTable = ((HttpStoragePlugin) cluster.storageRegistry()
.getPlugin("oauth"))
.getTokenRegistry(TEST_USER_2)
.getTokenTable("oauth");

Response response = httpClient.newCall(request).execute();
return response.code() == 200;
}
// Add the access tokens for user 2
tokenTable.setAccessToken("you_have_access_2");
tokenTable.setRefreshToken("refresh_me_2");

@Test
public void testOAuthWithUserTranslation() throws Exception {
makeLoginRequest(TEST_USER_2, TEST_USER_2_PASSWORD);
assertEquals("you_have_access_2", tokenTable.getAccessToken());
assertEquals("refresh_me_2", tokenTable.getRefreshToken());

// Now execute a query and get query results.
server.enqueue(new MockResponse()
.setResponseCode(200)
.setBody(TEST_JSON_RESPONSE_WITH_DATATYPES));

// Now get credentials for this user
String sql = "SELECT * FROM oauth.sharedEndpoint";
RowSet results = queryBuilder().sql(sql).rowSet();

TupleMetadata expectedSchema = new SchemaBuilder()
.add("col_1", MinorType.FLOAT8, DataMode.OPTIONAL)
.add("col_2", MinorType.BIGINT, DataMode.OPTIONAL)
.add("col_3", MinorType.VARCHAR, DataMode.OPTIONAL)
.build();

RowSet expected = new RowSetBuilder(client.allocator(), expectedSchema)
.addRow(1.0, 2, "3.0")
.addRow(4.0, 5, "6.0")
.build();

RowSetUtilities.verify(expected, results);

// Verify the correct tokens were passed
RecordedRequest recordedRequest = server.takeRequest();
String authToken = recordedRequest.getHeader("Authorization");
assertEquals("you_have_access_2", authToken);
} catch (Exception e) {
logger.debug(e.getMessage());
fail();
}
}

@Test
Expand Down

0 comments on commit 335dc63

Please sign in to comment.