Skip to content

Commit

Permalink
Merge pull request #1749 from cherrylzhao/dev
Browse files Browse the repository at this point in the history
Refactor XA-Transaction & Using ShardingTransactionEngine for proxy.
  • Loading branch information
terrymanu committed Jan 15, 2019
2 parents 9532e67 + 69615c3 commit 7e583c0
Show file tree
Hide file tree
Showing 21 changed files with 156 additions and 158 deletions.
Expand Up @@ -81,9 +81,6 @@ protected AbstractConnectionAdapter(final TransactionType transactionType) {
rootInvokeHook.start();
this.transactionType = transactionType;
shardingTransactionEngine = ShardingTransactionEngineRegistry.getEngine(transactionType);
if (TransactionType.LOCAL != transactionType) {
Preconditions.checkNotNull(shardingTransactionEngine, "Cannot find transaction manager of [%s]", transactionType);
}
}

/**
Expand Down
Expand Up @@ -183,9 +183,7 @@ private List<Connection> createNewConnections(final ConnectionMode connectionMod
}

private List<Connection> getConnectionFromUnderlying(final ConnectionMode connectionMode, final String dataSourceName, final int connectionSize) throws SQLException {
return TransactionType.XA == transactionType
? logicSchema.getBackendDataSource().getConnections(connectionMode, dataSourceName, connectionSize, TransactionType.XA)
: logicSchema.getBackendDataSource().getConnections(connectionMode, dataSourceName, connectionSize);
return logicSchema.getBackendDataSource().getConnections(connectionMode, dataSourceName, connectionSize, transactionType);
}

/**
Expand Down
Expand Up @@ -18,8 +18,6 @@
package io.shardingsphere.shardingproxy.backend.jdbc.connection;

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import io.shardingsphere.transaction.api.TransactionType;
import io.shardingsphere.transaction.core.ShardingTransactionEngineRegistry;
import io.shardingsphere.transaction.spi.ShardingTransactionEngine;
import lombok.RequiredArgsConstructor;
Expand All @@ -43,19 +41,19 @@ public void begin() {
connection.getStateHandler().getAndSetStatus(ConnectionStatus.TRANSACTION);
connection.releaseConnections(false);
}
if (!shardingTransactionEngine.isPresent() || !shardingTransactionEngine.get().isInTransaction()) {
if (!shardingTransactionEngine.isPresent()) {
new LocalTransactionManager(connection).begin();
} else if (TransactionType.XA == shardingTransactionEngine.get().getTransactionType()) {
} else {
shardingTransactionEngine.get().begin();
}
}

@Override
public void commit() throws SQLException {
Optional<ShardingTransactionEngine> shardingTransactionEngine = getShardingTransactionEngine(connection);
if (!shardingTransactionEngine.isPresent() || !shardingTransactionEngine.get().isInTransaction()) {
if (!shardingTransactionEngine.isPresent()) {
new LocalTransactionManager(connection).commit();
} else if (TransactionType.XA == shardingTransactionEngine.get().getTransactionType()) {
} else {
shardingTransactionEngine.get().commit();
connection.getStateHandler().getAndSetStatus(ConnectionStatus.TERMINATED);
}
Expand All @@ -64,20 +62,15 @@ public void commit() throws SQLException {
@Override
public void rollback() throws SQLException {
Optional<ShardingTransactionEngine> shardingTransactionEngine = getShardingTransactionEngine(connection);
if (!shardingTransactionEngine.isPresent() || !shardingTransactionEngine.get().isInTransaction()) {
if (!shardingTransactionEngine.isPresent()) {
new LocalTransactionManager(connection).rollback();
} else if (TransactionType.XA == shardingTransactionEngine.get().getTransactionType()) {
} else {
shardingTransactionEngine.get().rollback();
connection.getStateHandler().getAndSetStatus(ConnectionStatus.TERMINATED);
}
}

private Optional<ShardingTransactionEngine> getShardingTransactionEngine(final BackendConnection connection) {
TransactionType transactionType = connection.getTransactionType();
ShardingTransactionEngine result = ShardingTransactionEngineRegistry.getEngine(transactionType);
if (null != transactionType && transactionType != TransactionType.LOCAL) {
Preconditions.checkNotNull(result, String.format("Cannot find transaction manager of [%s]", transactionType));
}
return Optional.fromNullable(result);
return Optional.fromNullable(ShardingTransactionEngineRegistry.getEngine(connection.getTransactionType()));
}
}
Expand Up @@ -18,11 +18,14 @@
package io.shardingsphere.shardingproxy.backend.jdbc.datasource;

import io.shardingsphere.core.constant.ConnectionMode;
import io.shardingsphere.core.constant.DatabaseType;
import io.shardingsphere.core.exception.ShardingException;
import io.shardingsphere.core.util.ReflectiveUtil;
import io.shardingsphere.shardingproxy.backend.BackendDataSource;
import io.shardingsphere.shardingproxy.util.DataSourceParameter;
import io.shardingsphere.transaction.api.TransactionType;
import io.shardingsphere.transaction.core.ShardingTransactionEngineRegistry;
import io.shardingsphere.transaction.spi.ShardingTransactionEngine;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;

Expand Down Expand Up @@ -50,31 +53,25 @@ public final class JDBCBackendDataSource implements BackendDataSource, AutoClose

private Map<String, DataSource> dataSources;

private Map<String, DataSource> xaDataSources;

private JDBCBackendDataSourceFactory hikariDataSourceFactory = JDBCRawBackendDataSourceFactory.getInstance();

private JDBCBackendDataSourceFactory xaDataSourceFactory = JDBCXABackendDataSourceFactory.getInstance();

public JDBCBackendDataSource(final Map<String, DataSourceParameter> dataSourceParameters) {
createDataSourceMap(dataSourceParameters);
}

private void createDataSourceMap(final Map<String, DataSourceParameter> dataSourceParameters) {
Map<String, DataSource> dataSourceMap = new LinkedHashMap<>(dataSourceParameters.size(), 1);
Map<String, DataSource> xaDataSourceMap = new LinkedHashMap<>(dataSourceParameters.size(), 1);
for (Entry<String, DataSourceParameter> entry : dataSourceParameters.entrySet()) {
try {
dataSourceMap.put(entry.getKey(), hikariDataSourceFactory.build(entry.getKey(), entry.getValue()));
xaDataSourceMap.put(entry.getKey(), xaDataSourceFactory.build(entry.getKey(), entry.getValue()));
// CHECKSTYLE:OFF
} catch (final Exception ex) {
// CHECKSTYLE:ON
throw new ShardingException(String.format("Can not build data source, name is `%s`.", entry.getKey()), ex);
}
}
this.dataSources = dataSourceMap;
this.xaDataSources = xaDataSourceMap;
ShardingTransactionEngineRegistry.init(DatabaseType.MySQL, dataSourceMap);
}

/**
Expand Down Expand Up @@ -113,23 +110,23 @@ public List<Connection> getConnections(final ConnectionMode connectionMode, fina
*/
@SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
public List<Connection> getConnections(final ConnectionMode connectionMode, final String dataSourceName, final int connectionSize, final TransactionType transactionType) throws SQLException {
DataSource dataSource = TransactionType.XA == transactionType ? xaDataSources.get(dataSourceName) : dataSources.get(dataSourceName);
DataSource dataSource = dataSources.get(dataSourceName);
if (1 == connectionSize) {
return Collections.singletonList(dataSource.getConnection());
return Collections.singletonList(createConnection(transactionType, dataSourceName, dataSource));
}
if (ConnectionMode.CONNECTION_STRICTLY == connectionMode) {
return createConnections(dataSource, connectionSize);
return createConnections(transactionType, dataSourceName, dataSource, connectionSize);
}
synchronized (dataSource) {
return createConnections(dataSource, connectionSize);
return createConnections(transactionType, dataSourceName, dataSource, connectionSize);
}
}

private List<Connection> createConnections(final DataSource dataSource, final int connectionSize) throws SQLException {
private List<Connection> createConnections(final TransactionType transactionType, final String dataSourceName, final DataSource dataSource, final int connectionSize) throws SQLException {
List<Connection> result = new ArrayList<>(connectionSize);
for (int i = 0; i < connectionSize; i++) {
try {
result.add(dataSource.getConnection());
result.add(createConnection(transactionType, dataSourceName, dataSource));
} catch (final SQLException ex) {
for (Connection each : result) {
each.close();
Expand All @@ -140,14 +137,21 @@ private List<Connection> createConnections(final DataSource dataSource, final in
return result;
}

private Connection createConnection(final TransactionType transactionType, final String dataSourceName, final DataSource dataSource) throws SQLException {
ShardingTransactionEngine shardingTransactionEngine = ShardingTransactionEngineRegistry.getEngine(transactionType);
return isInShardingTransaction(shardingTransactionEngine) ? shardingTransactionEngine.getConnection(dataSourceName) : dataSource.getConnection();
}

private boolean isInShardingTransaction(final ShardingTransactionEngine shardingTransactionEngine) {
return null != shardingTransactionEngine && shardingTransactionEngine.isInTransaction();
}

@Override
public void close() {
public void close() throws Exception {
if (null != dataSources) {
closeDataSource(dataSources);
}
if (null != xaDataSources) {
closeDataSource(xaDataSources);
}
ShardingTransactionEngineRegistry.close();
}

private void closeDataSource(final Map<String, DataSource> dataSourceMap) {
Expand Down
Expand Up @@ -74,9 +74,10 @@ protected final Map<String, String> getDataSourceURLs(final Map<String, DataSour
* Renew data source configuration.
*
* @param dataSourceChangedEvent data source changed event.
* @throws Exception exception
*/
@Subscribe
public final synchronized void renew(final DataSourceChangedEvent dataSourceChangedEvent) {
public final synchronized void renew(final DataSourceChangedEvent dataSourceChangedEvent) throws Exception {
if (!name.equals(dataSourceChangedEvent.getShardingSchemaName())) {
return;
}
Expand Down
Expand Up @@ -70,7 +70,7 @@ public void setUp() {
@Test
public void assertGetConnectionCacheIsEmpty() throws SQLException {
backendConnection.getStateHandler().getAndSetStatus(ConnectionStatus.TRANSACTION);
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2))).thenReturn(MockConnectionUtil.mockNewConnections(2));
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2), eq(TransactionType.LOCAL))).thenReturn(MockConnectionUtil.mockNewConnections(2));
List<Connection> actualConnections = backendConnection.getConnections(ConnectionMode.MEMORY_STRICTLY, "ds1", 2);
assertThat(actualConnections.size(), is(2));
assertThat(backendConnection.getConnectionSize(), is(2));
Expand All @@ -91,7 +91,7 @@ public void assertGetConnectionSizeLessThanCache() throws SQLException {
public void assertGetConnectionSizeGreaterThanCache() throws SQLException {
backendConnection.getStateHandler().getAndSetStatus(ConnectionStatus.TRANSACTION);
MockConnectionUtil.setCachedConnections(backendConnection, "ds1", 10);
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2))).thenReturn(MockConnectionUtil.mockNewConnections(2));
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2), eq(TransactionType.LOCAL))).thenReturn(MockConnectionUtil.mockNewConnections(2));
List<Connection> actualConnections = backendConnection.getConnections(ConnectionMode.MEMORY_STRICTLY, "ds1", 12);
assertThat(actualConnections.size(), is(12));
assertThat(backendConnection.getConnectionSize(), is(12));
Expand All @@ -101,7 +101,7 @@ public void assertGetConnectionSizeGreaterThanCache() throws SQLException {
@Test
public void assertGetConnectionWithMethodInvocation() throws SQLException {
backendConnection.getStateHandler().getAndSetStatus(ConnectionStatus.TRANSACTION);
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2))).thenReturn(MockConnectionUtil.mockNewConnections(2));
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2), eq(TransactionType.LOCAL))).thenReturn(MockConnectionUtil.mockNewConnections(2));
setMethodInvocation();
List<Connection> actualConnections = backendConnection.getConnections(ConnectionMode.MEMORY_STRICTLY, "ds1", 2);
verify(backendConnection.getMethodInvocations().iterator().next(), times(2)).invoke(any());
Expand All @@ -123,7 +123,7 @@ private void setMethodInvocation() {
@SneakyThrows
public void assertMultiThreadGetConnection() {
MockConnectionUtil.setCachedConnections(backendConnection, "ds1", 10);
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2))).thenReturn(MockConnectionUtil.mockNewConnections(2));
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2), eq(TransactionType.LOCAL))).thenReturn(MockConnectionUtil.mockNewConnections(2));
Thread thread1 = new Thread(new Runnable() {
@Override
public void run() {
Expand Down Expand Up @@ -156,7 +156,7 @@ public void assertAutoCloseConnectionWithoutTransaction() throws SQLException {
BackendConnection actual;
try (BackendConnection backendConnection = new BackendConnection(TransactionType.LOCAL)) {
backendConnection.setCurrentSchema("schema_0");
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(12))).thenReturn(MockConnectionUtil.mockNewConnections(12));
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(12), eq(TransactionType.LOCAL))).thenReturn(MockConnectionUtil.mockNewConnections(12));
backendConnection.getConnections(ConnectionMode.MEMORY_STRICTLY, "ds1", 12);
assertThat(backendConnection.getStateHandler().getStatus(), is(ConnectionStatus.RUNNING));
mockResultSetAndStatement(backendConnection);
Expand All @@ -175,7 +175,7 @@ public void assertAutoCloseConnectionWithTransaction() throws SQLException {
try (BackendConnection backendConnection = new BackendConnection(TransactionType.LOCAL)) {
backendConnection.setCurrentSchema("schema_0");
MockConnectionUtil.setCachedConnections(backendConnection, "ds1", 10);
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2))).thenReturn(MockConnectionUtil.mockNewConnections(2));
when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2), eq(TransactionType.LOCAL))).thenReturn(MockConnectionUtil.mockNewConnections(2));
backendConnection.getStateHandler().getAndSetStatus(ConnectionStatus.TRANSACTION);
backendConnection.getConnections(ConnectionMode.MEMORY_STRICTLY, "ds1", 12);
mockResultSetAndStatement(backendConnection);
Expand Down
Expand Up @@ -21,8 +21,8 @@
import io.shardingsphere.core.constant.DatabaseType;
import io.shardingsphere.transaction.api.TransactionType;
import io.shardingsphere.transaction.spi.ShardingTransactionEngine;
import io.shardingsphere.transaction.xa.jta.connection.ShardingXAConnection;
import io.shardingsphere.transaction.xa.jta.datasource.ShardingXADataSource;
import io.shardingsphere.transaction.xa.jta.connection.SingleXAConnection;
import io.shardingsphere.transaction.xa.jta.datasource.SingleXADataSource;
import io.shardingsphere.transaction.xa.manager.XATransactionManagerLoader;
import io.shardingsphere.transaction.xa.spi.XATransactionManager;
import lombok.SneakyThrows;
Expand All @@ -41,7 +41,7 @@
*/
public final class XAShardingTransactionEngine implements ShardingTransactionEngine {

private final Map<String, ShardingXADataSource> cachedShardingXADataSourceMap = new HashMap<>();
private final Map<String, SingleXADataSource> cachedSingleXADataSourceMap = new HashMap<>();

private final XATransactionManager xaTransactionManager = XATransactionManagerLoader.getInstance().getTransactionManager();

Expand All @@ -53,9 +53,9 @@ public void init(final DatabaseType databaseType, final Map<String, DataSource>
continue;
}
String resourceName = entry.getKey();
ShardingXADataSource shardingXADataSource = new ShardingXADataSource(databaseType, resourceName, entry.getValue());
cachedShardingXADataSourceMap.put(resourceName, shardingXADataSource);
xaTransactionManager.registerRecoveryResource(resourceName, shardingXADataSource.getXaDataSource());
SingleXADataSource singleXADataSource = new SingleXADataSource(databaseType, resourceName, entry.getValue());
cachedSingleXADataSourceMap.put(resourceName, singleXADataSource);
xaTransactionManager.registerRecoveryResource(resourceName, singleXADataSource.getXaDataSource());
}
xaTransactionManager.init();
}
Expand All @@ -74,9 +74,9 @@ public boolean isInTransaction() {
@SneakyThrows
@Override
public Connection getConnection(final String dataSourceName) {
ShardingXAConnection shardingXAConnection = cachedShardingXADataSourceMap.get(dataSourceName).getXAConnection();
xaTransactionManager.enlistResource(shardingXAConnection.getXAResource());
return shardingXAConnection.getConnection();
SingleXAConnection singleXAConnection = cachedSingleXADataSourceMap.get(dataSourceName).getXAConnection();
xaTransactionManager.enlistResource(singleXAConnection.getXAResource());
return singleXAConnection.getConnection();
}

@Override
Expand All @@ -96,10 +96,10 @@ public void rollback() {

@Override
public void close() throws Exception {
for (ShardingXADataSource each : cachedShardingXADataSourceMap.values()) {
for (SingleXADataSource each : cachedSingleXADataSourceMap.values()) {
xaTransactionManager.removeRecoveryResource(each.getResourceName(), each.getXaDataSource());
}
cachedShardingXADataSourceMap.clear();
cachedSingleXADataSourceMap.clear();
xaTransactionManager.close();
}
}
Expand Up @@ -17,7 +17,7 @@

package io.shardingsphere.transaction.xa.jta.connection;

import io.shardingsphere.transaction.xa.jta.resource.ShardingXAResource;
import io.shardingsphere.transaction.xa.jta.resource.SingleXAResource;
import lombok.RequiredArgsConstructor;

import javax.sql.ConnectionEventListener;
Expand All @@ -28,20 +28,20 @@
import java.sql.SQLException;

/**
* Sharding XA Connection.
* Single XA Connection.
*
* @author zhaojun
*/
@RequiredArgsConstructor
public final class ShardingXAConnection implements XAConnection {
public final class SingleXAConnection implements XAConnection {

private final String resourceName;

private final XAConnection delegate;

@Override
public XAResource getXAResource() throws SQLException {
return new ShardingXAResource(resourceName, delegate.getXAResource());
return new SingleXAResource(resourceName, delegate.getXAResource());
}

@Override
Expand Down

0 comments on commit 7e583c0

Please sign in to comment.