diff --git a/pgjdbc/src/main/java/org/postgresql/core/SocketFactoryFactory.java b/pgjdbc/src/main/java/org/postgresql/core/SocketFactoryFactory.java index 09efa75f07..fe56354a21 100644 --- a/pgjdbc/src/main/java/org/postgresql/core/SocketFactoryFactory.java +++ b/pgjdbc/src/main/java/org/postgresql/core/SocketFactoryFactory.java @@ -36,7 +36,7 @@ public static SocketFactory getSocketFactory(Properties info) throws PSQLExcepti return SocketFactory.getDefault(); } try { - return (SocketFactory) ObjectFactory.instantiate(socketFactoryClassName, info, true, + return ObjectFactory.instantiate(SocketFactory.class, socketFactoryClassName, info, true, PGProperty.SOCKET_FACTORY_ARG.get(info)); } catch (Exception e) { throw new PSQLException( @@ -61,7 +61,7 @@ public static SSLSocketFactory getSslSocketFactory(Properties info) throws PSQLE return new LibPQFactory(info); } try { - return (SSLSocketFactory) ObjectFactory.instantiate(classname, info, true, + return ObjectFactory.instantiate(SSLSocketFactory.class, classname, info, true, PGProperty.SSL_FACTORY_ARG.get(info)); } catch (Exception e) { throw new PSQLException( diff --git a/pgjdbc/src/main/java/org/postgresql/core/v3/AuthenticationPluginManager.java b/pgjdbc/src/main/java/org/postgresql/core/v3/AuthenticationPluginManager.java index 938f24632f..bcc8b3a9b6 100644 --- a/pgjdbc/src/main/java/org/postgresql/core/v3/AuthenticationPluginManager.java +++ b/pgjdbc/src/main/java/org/postgresql/core/v3/AuthenticationPluginManager.java @@ -66,11 +66,12 @@ public static T withPassword(AuthenticationRequestType type, Properties info } else { AuthenticationPlugin authPlugin; try { - authPlugin = (AuthenticationPlugin) ObjectFactory.instantiate(authPluginClassName, info, + authPlugin = ObjectFactory.instantiate(AuthenticationPlugin.class, authPluginClassName, info, false, null); } catch (Exception ex) { - LOGGER.log(Level.FINE, "Unable to load Authentication Plugin " + ex.toString()); - throw new PSQLException(ex.getMessage(), PSQLState.UNEXPECTED_ERROR); + String msg = GT.tr("Unable to load Authentication Plugin {0}", authPluginClassName); + LOGGER.log(Level.FINE, msg, ex); + throw new PSQLException(msg, PSQLState.INVALID_PARAMETER_VALUE, ex); } password = authPlugin.getPassword(type); @@ -106,7 +107,8 @@ public static T withEncodedPassword(AuthenticationRequestType type, Properti byte[] encodedPassword = withPassword(type, info, password -> { if (password == null) { throw new PSQLException( - GT.tr("The server requested password-based authentication, but no password was provided."), + GT.tr("The server requested password-based authentication, but no password was provided by plugin {0}", + PGProperty.AUTHENTICATION_PLUGIN_CLASS_NAME.get(info)), PSQLState.CONNECTION_REJECTED); } ByteBuffer buf = StandardCharsets.UTF_8.encode(CharBuffer.wrap(password)); diff --git a/pgjdbc/src/main/java/org/postgresql/ssl/LibPQFactory.java b/pgjdbc/src/main/java/org/postgresql/ssl/LibPQFactory.java index 67e1196b15..4249dcda77 100644 --- a/pgjdbc/src/main/java/org/postgresql/ssl/LibPQFactory.java +++ b/pgjdbc/src/main/java/org/postgresql/ssl/LibPQFactory.java @@ -56,7 +56,7 @@ private CallbackHandler getCallbackHandler( String sslpasswordcallback = PGProperty.SSL_PASSWORD_CALLBACK.get(info); if (sslpasswordcallback != null) { try { - cbh = (CallbackHandler) ObjectFactory.instantiate(sslpasswordcallback, info, false, null); + cbh = ObjectFactory.instantiate(CallbackHandler.class, sslpasswordcallback, info, false, null); } catch (Exception e) { throw new PSQLException( GT.tr("The password callback class provided {0} could not be instantiated.", diff --git a/pgjdbc/src/main/java/org/postgresql/ssl/MakeSSL.java b/pgjdbc/src/main/java/org/postgresql/ssl/MakeSSL.java index bf64673f41..849d107321 100644 --- a/pgjdbc/src/main/java/org/postgresql/ssl/MakeSSL.java +++ b/pgjdbc/src/main/java/org/postgresql/ssl/MakeSSL.java @@ -64,7 +64,7 @@ private static void verifyPeerName(PGStream stream, Properties info, SSLSocket n sslhostnameverifier = "PgjdbcHostnameVerifier"; } else { try { - hvn = (HostnameVerifier) instantiate(sslhostnameverifier, info, false, null); + hvn = instantiate(HostnameVerifier.class, sslhostnameverifier, info, false, null); } catch (Exception e) { throw new PSQLException( GT.tr("The HostnameVerifier class provided {0} could not be instantiated.", diff --git a/pgjdbc/src/main/java/org/postgresql/util/ObjectFactory.java b/pgjdbc/src/main/java/org/postgresql/util/ObjectFactory.java index ef24770e6f..e4fec28fc4 100644 --- a/pgjdbc/src/main/java/org/postgresql/util/ObjectFactory.java +++ b/pgjdbc/src/main/java/org/postgresql/util/ObjectFactory.java @@ -36,14 +36,15 @@ public class ObjectFactory { * @throws IllegalAccessException if something goes wrong * @throws InvocationTargetException if something goes wrong */ - public static Object instantiate(String classname, Properties info, boolean tryString, + public static T instantiate(Class expectedClass, String classname, Properties info, + boolean tryString, @Nullable String stringarg) throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException, InstantiationException, IllegalAccessException, InvocationTargetException { @Nullable Object[] args = {info}; - Constructor ctor = null; - Class cls = Class.forName(classname); + Constructor ctor = null; + Class cls = Class.forName(classname).asSubclass(expectedClass); try { ctor = cls.getConstructor(Properties.class); } catch (NoSuchMethodException ignored) { diff --git a/pgjdbc/src/test/java/org/postgresql/test/util/ObjectFactoryTest.java b/pgjdbc/src/test/java/org/postgresql/test/util/ObjectFactoryTest.java new file mode 100644 index 0000000000..e0a9d1f4b4 --- /dev/null +++ b/pgjdbc/src/test/java/org/postgresql/test/util/ObjectFactoryTest.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022, PostgreSQL Global Development Group + * See the LICENSE file in the project root for more information. + */ + +package org.postgresql.test.util; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +import org.postgresql.PGProperty; +import org.postgresql.jdbc.SslMode; +import org.postgresql.test.TestUtil; +import org.postgresql.util.ObjectFactory; +import org.postgresql.util.PSQLState; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Assertions; +import org.opentest4j.MultipleFailuresError; + +import java.sql.SQLException; +import java.util.Properties; + +import javax.net.SocketFactory; + +public class ObjectFactoryTest { + Properties props = new Properties(); + + static class BadObject { + static boolean wasInstantiated = false; + + BadObject() { + wasInstantiated = true; + throw new RuntimeException("I should not be instantiated"); + } + } + + private void testInvalidInstantiation(PGProperty prop, PSQLState expectedSqlState) { + prop.set(props, BadObject.class.getName()); + + BadObject.wasInstantiated = false; + SQLException ex = assertThrows(SQLException.class, () -> { + TestUtil.openDB(props); + }); + + try { + Assertions.assertAll( + () -> assertFalse(BadObject.wasInstantiated, "ObjectFactory should not have " + + "instantiated bad object for " + prop), + () -> assertEquals(expectedSqlState.getState(), ex.getSQLState(), () -> "#getSQLState()"), + () -> { + assertThrows( + ClassCastException.class, + () -> { + throw ex.getCause(); + }, + () -> "Wrong class specified for " + prop.name() + + " => ClassCastException is expected in SQLException#getCause()" + ); + } + ); + } catch (MultipleFailuresError e) { + // Add the original exception so it is easier to understand the reason for the test to fail + e.addSuppressed(ex); + throw e; + } + } + + @Test + public void testInvalidSocketFactory() { + testInvalidInstantiation(PGProperty.SOCKET_FACTORY, PSQLState.CONNECTION_FAILURE); + } + + @Test + public void testInvalidSSLFactory() { + TestUtil.assumeSslTestsEnabled(); + // We need at least "require" to trigger SslSockerFactory instantiation + PGProperty.SSL_MODE.set(props, SslMode.REQUIRE.value); + testInvalidInstantiation(PGProperty.SSL_FACTORY, PSQLState.CONNECTION_FAILURE); + } + + @Test + public void testInvalidAuthenticationPlugin() { + testInvalidInstantiation(PGProperty.AUTHENTICATION_PLUGIN_CLASS_NAME, + PSQLState.INVALID_PARAMETER_VALUE); + } + + @Test + public void testInvalidSslHostnameVerifier() { + TestUtil.assumeSslTestsEnabled(); + // Hostname verification is done at verify-full level only + PGProperty.SSL_MODE.set(props, SslMode.VERIFY_FULL.value); + PGProperty.SSL_ROOT_CERT.set(props, TestUtil.getSslTestCertPath("goodroot.crt")); + testInvalidInstantiation(PGProperty.SSL_HOSTNAME_VERIFIER, PSQLState.CONNECTION_FAILURE); + } + + @Test + public void testInstantiateInvalidSocketFactory() { + Properties props = new Properties(); + assertThrows(ClassCastException.class, () -> { + ObjectFactory.instantiate(SocketFactory.class, BadObject.class.getName(), props, + false, null); + }); + } +}