diff --git a/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/ConnectionCountManager.java b/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/ConnectionCountManager.java new file mode 100644 index 0000000000..db5df3aa22 --- /dev/null +++ b/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/ConnectionCountManager.java @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2023 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.odc.core.datasource; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import com.oceanbase.odc.core.shared.exception.OverLimitException; + +import lombok.extern.slf4j.Slf4j; + +/** + * Connection count manager for tracking database connections by url+username + */ +@Slf4j +public class ConnectionCountManager { + + private static volatile ConnectionCountManager instance; + private final Map connectionCountMap = new ConcurrentHashMap<>(); + private volatile long maxConnectionCount = -1; + + private ConnectionCountManager() {} + + public static ConnectionCountManager getInstance() { + if (instance == null) { + synchronized (ConnectionCountManager.class) { + if (instance == null) { + instance = new ConnectionCountManager(); + } + } + } + return instance; + } + + /** + * Initialize the max connection count limit + * + * @param maxConnectionCount max connection count, -1 means no limit + */ + public void setMaxConnectionCount(long maxConnectionCount) { + this.maxConnectionCount = maxConnectionCount; + log.info("Set max connection count to {}", maxConnectionCount); + } + + /** + * Get the max connection count limit + * + * @return max connection count, -1 means no limit + */ + public long getMaxConnectionCount() { + return maxConnectionCount; + } + + /** + * Generate key from url and username + * Extracts host:port from JDBC URL and combines with username + * + * @param url database url (JDBC URL format) + * @param username database username + * @return key string in format "host:port:username" + */ + public static String generateKey(String url, String username) { + if (url == null) { + url = ""; + } + if (username == null) { + username = ""; + } + String hostPort = extractHostAndPort(url); + return hostPort + ":" + username; + } + + /** + * Extract host:port from JDBC URL + * Supports formats: + * - jdbc:mysql://host:port/database?params + * - jdbc:oceanbase://host:port/database?params + * - jdbc:postgresql://host:port/database?params + * - jdbc:oracle:thin:@host:port:database + * - jdbc:oracle:thin:@//host:port/database + * + * @param jdbcUrl JDBC URL + * @return host:port string, or original url if parsing fails + */ + private static String extractHostAndPort(String jdbcUrl) { + if (jdbcUrl == null || jdbcUrl.isEmpty()) { + return ""; + } + + // Pattern for MySQL/OceanBase/PostgreSQL: jdbc:type://host:port/... + Pattern mysqlPattern = Pattern.compile("jdbc:(mysql|oceanbase|postgresql)://([^/:]+)(?::([0-9]+))?"); + Matcher mysqlMatcher = mysqlPattern.matcher(jdbcUrl); + if (mysqlMatcher.find()) { + String host = mysqlMatcher.group(2); + String port = mysqlMatcher.group(3); + if (port != null && !port.isEmpty()) { + return host + ":" + port; + } else { + // Use default port based on database type + String dbType = mysqlMatcher.group(1); + int defaultPort = getDefaultPort(dbType); + return host + ":" + defaultPort; + } + } + + // Pattern for Oracle SID format: jdbc:oracle:thin:@host:port:database + Pattern oracleSidPattern = Pattern.compile("jdbc:oracle:thin:@([^:/]+):([0-9]+)"); + Matcher oracleSidMatcher = oracleSidPattern.matcher(jdbcUrl); + if (oracleSidMatcher.find()) { + return oracleSidMatcher.group(1) + ":" + oracleSidMatcher.group(2); + } + + // Pattern for Oracle Service Name format: jdbc:oracle:thin:@//host:port/database + Pattern oracleServicePattern = Pattern.compile("jdbc:oracle:thin:@//([^:/]+):([0-9]+)"); + Matcher oracleServiceMatcher = oracleServicePattern.matcher(jdbcUrl); + if (oracleServiceMatcher.find()) { + return oracleServiceMatcher.group(1) + ":" + oracleServiceMatcher.group(2); + } + + // If no pattern matches, log warning and return original URL + log.warn("Failed to extract host:port from JDBC URL: {}, using original URL as key", jdbcUrl); + return jdbcUrl; + } + + /** + * Get default port for database type + * + * @param dbType database type (mysql, oceanbase, postgresql, etc.) + * @return default port number + */ + private static int getDefaultPort(String dbType) { + if (dbType == null) { + return 3306; // Default to MySQL port + } + switch (dbType.toLowerCase()) { + case "mysql": + return 3306; + case "oceanbase": + return 2883; + case "postgresql": + return 5432; + default: + return 3306; // Default to MySQL port + } + } + + /** + * Increment connection count for the given key + * + * @param key connection key (url+username) + * @return current count after increment + * @throws OverLimitException if connection count exceeds limit + */ + public int incrementConnectionCount(String key) { + if (maxConnectionCount > 0) { + AtomicInteger count = connectionCountMap.computeIfAbsent(key, k -> new AtomicInteger(0)); + int currentCount = count.incrementAndGet(); + log.debug("Increment connection count, key={}, currentCount={}, maxCount={}", key, currentCount, + maxConnectionCount); + if (currentCount > maxConnectionCount) { + count.decrementAndGet(); + String message = String.format("数据库连接数超限, 当前值=%d, 最大值=%d", + currentCount, maxConnectionCount); + throw new IllegalStateException(String.format(message, maxConnectionCount)); + } + return currentCount; + } else { + // No limit, just increment + AtomicInteger count = connectionCountMap.computeIfAbsent(key, k -> new AtomicInteger(0)); + int currentCount = count.incrementAndGet(); + log.debug("Increment connection count (no limit), key={}, currentCount={}", key, currentCount); + return currentCount; + } + } + + /** + * Decrement connection count for the given key + * + * @param key connection key (url+username) + * @return current count after decrement + */ + public int decrementConnectionCount(String key) { + AtomicInteger count = connectionCountMap.get(key); + if (count == null) { + log.warn("Attempt to decrement connection count for non-existent key: {}", key); + return 0; + } + int currentCount = count.decrementAndGet(); + log.debug("Decrement connection count, key={}, currentCount={}", key, currentCount); + if (currentCount <= 0) { + connectionCountMap.remove(key); + log.debug("Removed connection count entry for key: {}", key); + } + return currentCount; + } + + /** + * Get current connection count for the given key + * + * @param key connection key (url+username) + * @return current count + */ + public int getConnectionCount(String key) { + AtomicInteger count = connectionCountMap.get(key); + return count == null ? 0 : count.get(); + } + + /** + * Clear all connection counts (mainly for testing) + */ + public void clear() { + connectionCountMap.clear(); + log.info("Cleared all connection counts"); + } + + /** + * Get all connection counts (mainly for monitoring) + * + * @return snapshot of connection count map + */ + public Map getAllConnectionCounts() { + Map snapshot = new ConcurrentHashMap<>(); + connectionCountMap.forEach((key, count) -> snapshot.put(key, count.get())); + return snapshot; + } + +} diff --git a/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/SingleConnectionDataSource.java b/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/SingleConnectionDataSource.java index a58943a422..8c6d9a859b 100644 --- a/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/SingleConnectionDataSource.java +++ b/server/odc-core/src/main/java/com/oceanbase/odc/core/datasource/SingleConnectionDataSource.java @@ -82,6 +82,8 @@ public class SingleConnectionDataSource extends BaseClassBasedDataSource impleme private ScheduledExecutorService keepAliveScheduler; @Setter private long timeOutMillis = 10 * 1000; + private String connectionKey; + private static final ConnectionCountManager connectionCountManager = ConnectionCountManager.getInstance(); public SingleConnectionDataSource() { this(false, false); @@ -189,6 +191,16 @@ private void closeConnection() { log.error("Failed to close the connection", throwable); } } + // Decrement connection count when connection is closed + if (this.connectionKey != null) { + try { + connectionCountManager.decrementConnectionCount(this.connectionKey); + log.info("Decremented connection count, key={}", this.connectionKey); + } catch (Exception e) { + log.warn("Failed to decrement connection count", e); + } + this.connectionKey = null; + } } private boolean tryLock(Lock lock) { @@ -262,13 +274,26 @@ private synchronized Connection innerCreateConnection() throws SQLException { throw new IllegalStateException("Connection is not null"); } try { + log.info("Incremented connection count, key={}", this.connectionKey); Connection connection = newConnectionFromDriver(getUsername(), getPassword()); + // Generate connection key and check/increment connection count + this.connectionKey = ConnectionCountManager.generateKey(getUrl(), getUsername()); + connectionCountManager.incrementConnectionCount(this.connectionKey); prepareConnection(connection); this.connection = connection; this.lock = new ReentrantLock(); log.info("Established shared JDBC Connection, lock={}", this.lock.hashCode()); return getConnectionProxy(this.connection, this.lock); } catch (Throwable e) { + // If connection creation fails, decrement the count + if (this.connectionKey != null) { + try { + connectionCountManager.decrementConnectionCount(this.connectionKey); + log.info("Decremented connection count due to creation failure, key={}", this.connectionKey); + } catch (Exception ex) { + log.warn("Failed to decrement connection count after creation failure", ex); + } + } publishEvent(new GetConnectionFailedEvent(Optional.ofNullable(connection))); throw new SQLException(e); } diff --git a/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/LimitMetric.java b/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/LimitMetric.java index 35422dce29..0ed328486b 100644 --- a/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/LimitMetric.java +++ b/server/odc-core/src/main/java/com/oceanbase/odc/core/shared/constant/LimitMetric.java @@ -42,6 +42,8 @@ public enum LimitMetric implements Translatable { TRANSACTION_QUERY_LIMIT, SESSION_COUNT, USER_COUNT, + USER_DATASOURCE_SESSION_COUNT, + DATASOURCE_CONNECTION_COUNT, EXPORT_OBJECT_COUNT, TABLE_NAME_LENGTH, WORKSHEET_CHANGE_COUNT, diff --git a/server/odc-core/src/main/resources/i18n/BusinessMessages.properties b/server/odc-core/src/main/resources/i18n/BusinessMessages.properties index 28d7783b1c..0ab32973b0 100644 --- a/server/odc-core/src/main/resources/i18n/BusinessMessages.properties +++ b/server/odc-core/src/main/resources/i18n/BusinessMessages.properties @@ -141,6 +141,8 @@ com.oceanbase.odc.LimitMetric.FILE_COUNT=file count com.oceanbase.odc.LimitMetric.TRANSACTION_QUERY_LIMIT=query limit com.oceanbase.odc.LimitMetric.SESSION_COUNT=session count com.oceanbase.odc.LimitMetric.USER_COUNT=user count +com.oceanbase.odc.LimitMetric.USER_DATASOURCE_SESSION_COUNT=user datasource session count +com.oceanbase.odc.LimitMetric.DATASOURCE_CONNECTION_COUNT=datasource connection count com.oceanbase.odc.LimitMetric.EXPORT_OBJECT_COUNT=export object count com.oceanbase.odc.LimitMetric.TABLE_NAME_LENGTH=table name length com.oceanbase.odc.LimitMetric.WORKSHEET_CHANGE_COUNT=The number of changed worksheets diff --git a/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_CN.properties b/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_CN.properties index 6dc4c3c2fb..50c5f9e796 100644 --- a/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_CN.properties +++ b/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_CN.properties @@ -141,6 +141,8 @@ com.oceanbase.odc.LimitMetric.FILE_COUNT=文件数量 com.oceanbase.odc.LimitMetric.TRANSACTION_QUERY_LIMIT=查询结果集大小 com.oceanbase.odc.LimitMetric.SESSION_COUNT=数据库 SESSION 数量 com.oceanbase.odc.LimitMetric.USER_COUNT=数据库连接用户数 +com.oceanbase.odc.LimitMetric.USER_DATASOURCE_SESSION_COUNT=用户对数据源的连接数 +com.oceanbase.odc.LimitMetric.DATASOURCE_CONNECTION_COUNT=数据源连接数 com.oceanbase.odc.LimitMetric.EXPORT_OBJECT_COUNT=导出对象数量 com.oceanbase.odc.LimitMetric.TABLE_NAME_LENGTH=表名长度 com.oceanbase.odc.LimitMetric.WORKSHEET_CHANGE_COUNT=变更的工作簿数量 diff --git a/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_TW.properties b/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_TW.properties index 69bb6cb787..3947e0bb91 100644 --- a/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_TW.properties +++ b/server/odc-core/src/main/resources/i18n/BusinessMessages_zh_TW.properties @@ -141,6 +141,8 @@ com.oceanbase.odc.LimitMetric.FILE_COUNT=文件數量 com.oceanbase.odc.LimitMetric.TRANSACTION_QUERY_LIMIT=查詢結果集大小 com.oceanbase.odc.LimitMetric.SESSION_COUNT=數據庫 SESSION 數量 com.oceanbase.odc.LimitMetric.USER_COUNT=數據庫連接用戶數 +com.oceanbase.odc.LimitMetric.USER_DATASOURCE_SESSION_COUNT=用戶對數據源的連接數 +com.oceanbase.odc.LimitMetric.DATASOURCE_CONNECTION_COUNT=數據源連接數 com.oceanbase.odc.LimitMetric.EXPORT_OBJECT_COUNT=導出對象數量 com.oceanbase.odc.LimitMetric.TABLE_NAME_LENGTH=表名長度 com.oceanbase.odc.LimitMetric.WORKSHEET_CHANGE_COUNT=变更的工作簿数量 diff --git a/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/ConnectSessionController.java b/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/ConnectSessionController.java index 9ce0307e8c..a2719a1906 100644 --- a/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/ConnectSessionController.java +++ b/server/odc-server/src/main/java/com/oceanbase/odc/server/web/controller/v2/ConnectSessionController.java @@ -259,4 +259,21 @@ public ListResponse preview(@PathVariable String sessi SidUtils.getSessionId(sessionId), req.getTableConfigs(), req.isOnlyForPartitionName())); } + /** + * 删除并释放某个用户某个数据源下的所有数据库连接,当 dataSourceId 为空时关闭该用户的所有连接 + * + * @param userId 用户ID + * @param dataSourceId 数据源ID,可为空 + * @return 关闭的会话数量 + */ + @ApiOperation(value = "closeUserDatasourceSessions", notes = "删除并释放指定用户指定数据源下的数据库连接;若未指定数据源则关闭该用户的全部连接") + @RequestMapping(value = {"/users/{userId:[\\d]+}/datasources/sessions", + "/users/{userId:[\\d]+}/datasources/{dataSourceId:[\\d]+}/sessions"}, method = RequestMethod.DELETE) + public SuccessResponse closeUserDatasourceSessions( + @PathVariable Long userId, + @PathVariable(value = "dataSourceId", required = false) Long dataSourceId) { + int closedCount = sessionService.closeUserDatasourceSessions(userId, dataSourceId); + return Responses.success(closedCount); + } + } diff --git a/server/odc-server/src/main/resources/data.sql b/server/odc-server/src/main/resources/data.sql index 60dbef2854..9631947776 100644 --- a/server/odc-server/src/main/resources/data.sql +++ b/server/odc-server/src/main/resources/data.sql @@ -444,6 +444,8 @@ INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('o -- 连接管理 INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('odc.connect.temp.expire-after-inactive-interval-seconds', '86400', '临时连接不活跃之后的保留周期,单位:秒,默认值 86400') ON DUPLICATE KEY UPDATE `id`=`id`; +INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('odc.connect.datasource.max-connection-count', + '-1', '单个数据源(url+username)的最大连接数,默认 -1,表示不限制') ON DUPLICATE KEY UPDATE `id`=`id`; INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('odc.connect.temp.expire-check-interval-millis', '600000', '临时连接配置清理检查周期,单位:毫秒,默认值 600000 表示 10 分钟') ON DUPLICATE KEY UPDATE `id`=`id`; update config_system_configuration set `value`='120000',`description`= @@ -523,6 +525,9 @@ INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('o INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('odc.security.file.upload.safe-suffix-list', '*', '允许上传的文件名扩展名,默认 *,表示允许所有文件扩展名') ON DUPLICATE KEY UPDATE `id`=`id`; +INSERT INTO config_system_configuration(`key`, `value`, `description`) VALUES('odc.session.sql-execute.user-datasource-max-count', + '-1', '用户对单个数据源的最大连接数,默认 -1,表示不限制') ON DUPLICATE KEY UPDATE `id`=`id`; + -- -- cloud object-storage -- diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/connection/model/ConnectProperties.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/connection/model/ConnectProperties.java index 33ad189c30..7397209d60 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/connection/model/ConnectProperties.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/connection/model/ConnectProperties.java @@ -70,6 +70,12 @@ public class ConnectProperties { @Value("${odc.connect.persistent-connection-operations:create,delete,update,read}") private Set persistentConnectionOperations = new HashSet<>(); + /** + * 单个数据源(url+username)的最大连接数,默认 -1,表示不限制 + */ + @Value("${odc.connect.datasource.max-connection-count:-1}") + private long datasourceMaxConnectionCount = -1; + public Set getConnectionSupportedOperations(boolean temp, Set permittedActions) { // temp connection can only be private connection, skip permittedActions heere if (temp) { diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectSessionService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectSessionService.java index eae30d44f2..5802ca1400 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectSessionService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/ConnectSessionService.java @@ -181,6 +181,9 @@ public void init() { this.connectionSessionManager.enableAsyncRefreshSessionManager(); this.connectionSessionManager.addSessionValidator( new SessionValidatorPredicate(sessionProperties.getTimeoutMins(), TimeUnit.MINUTES)); + // Initialize connection count manager + com.oceanbase.odc.core.datasource.ConnectionCountManager.getInstance() + .setMaxConnectionCount(connectProperties.getDatasourceMaxConnectionCount()); log.info("Initialization of the connection session module is complete"); } @@ -283,7 +286,7 @@ public ConnectionSession create(@NotNull CreateSessionReq req) { schemaName = null; dataSourceId = req.getDsId(); } - preCheckSessionLimit(); + preCheckSessionLimit(dataSourceId); ConnectionConfig connection = connectionService.getForConnectionSkipPermissionCheck(dataSourceId); cloudMetadataClient.checkPermission(OBTenant.of(connection.getClusterName(), connection.getTenantName()), connection.getInstanceType(), false, CloudPermissionAction.READONLY); @@ -415,6 +418,68 @@ public Collection listAllSessions() { return this.connectionSessionManager.retrieveAllSessions(); } + /** + * 关闭并释放某个用户某个数据源下的所有数据库连接,当 dataSourceId 为空时关闭该用户的全部连接 + * + * @param userId 用户ID + * @param dataSourceId 数据源ID,可为空 + * @return 关闭的会话数量 + */ + @SkipAuthorize("check permission internally") + public int closeUserDatasourceSessions(@NotNull Long userId, Long dataSourceId) { + PreConditions.notNull(userId, "userId"); + + Collection allSessions = listAllSessions(); + int closedCount = 0; + Set affectedDataSourceIds = new HashSet<>(); + + for (ConnectionSession session : allSessions) { + try { + Long sessionUserId = ConnectionSessionUtil.getUserId(session); + if (sessionUserId == null || !sessionUserId.equals(userId)) { + continue; + } + + Object connectionConfigObj = ConnectionSessionUtil.getConnectionConfig(session); + if (!(connectionConfigObj instanceof ConnectionConfig)) { + continue; + } + ConnectionConfig connectionConfig = (ConnectionConfig) connectionConfigObj; + Long sessionDataSourceId = connectionConfig.id(); + if (dataSourceId != null && !Objects.equals(sessionDataSourceId, dataSourceId)) { + continue; + } + + try { + session.expire(); + closedCount++; + log.info("Closed session for user {} and datasource {}, sessionId={}", + userId, sessionDataSourceId, session.getId()); + if (sessionDataSourceId != null) { + affectedDataSourceIds.add(sessionDataSourceId); + } + } catch (Exception e) { + log.warn("Failed to close session, sessionId={}, userId={}, dataSourceId={}", + session.getId(), userId, sessionDataSourceId, e); + } + } catch (Exception e) { + log.warn("Error processing session, sessionId={}", session.getId(), e); + } + } + + if (dataSourceId != null) { + limitService.clearUserDatasourceSessionCount(userId.toString(), dataSourceId); + log.info("Closed {} sessions and cleared session count limit for user {} and datasource {}", + closedCount, userId, dataSourceId); + } else { + affectedDataSourceIds.forEach(id -> limitService.clearUserDatasourceSessionCount(userId.toString(), id)); + log.info("Closed {} sessions for user {} across datasources {}", closedCount, userId, + affectedDataSourceIds); + } + + return closedCount; + } + /** * Common upload method * @@ -523,7 +588,7 @@ private ConnectionSession getWithCreatorCheck(@NonNull String sessionId) { return session; } - private void preCheckSessionLimit() { + private void preCheckSessionLimit(Long dataSourceId) { long maxCount = sessionProperties.getUserMaxCount(); if (labProperties.isSessionLimitEnabled()) { if (!limitService.allowCreateSession(authenticationFacade.currentUserIdStr())) { @@ -544,6 +609,20 @@ private void preCheckSessionLimit() { throw ex; } } + // 检查用户对数据源的连接数限制 + long userDatasourceMaxCount = sessionProperties.getUserDatasourceMaxCount(); + if (userDatasourceMaxCount > 0 && dataSourceId != null) { + String userId = authenticationFacade.currentUserIdStr(); + try { + int currentCount = limitService.incrementUserDatasourceSessionCount(userId, dataSourceId); + PreConditions.lessThanOrEqualTo("userDatasourceSessionCount", + LimitMetric.USER_DATASOURCE_SESSION_COUNT, + currentCount, userDatasourceMaxCount); + } catch (OverLimitException ex) { + limitService.decrementUserDatasourceSessionCount(userId, dataSourceId); + throw ex; + } + } } private void checkDBPermission(Database database) { diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitListener.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitListener.java index 18e4e1a660..b94b8207e9 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitListener.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitListener.java @@ -18,6 +18,7 @@ import com.oceanbase.odc.core.session.ConnectionSession; import com.oceanbase.odc.core.session.ConnectionSessionUtil; import com.oceanbase.odc.core.session.DefaultSessionEventListener; +import com.oceanbase.odc.service.connection.model.ConnectionConfig; import lombok.NonNull; @@ -36,6 +37,16 @@ public void onExpire(ConnectionSession session) { return; } this.limitService.decrementSessionCount(userId + ""); + + // 减少用户对数据源的会话计数 + Object connectionConfigObj = ConnectionSessionUtil.getConnectionConfig(session); + if (connectionConfigObj instanceof ConnectionConfig) { + ConnectionConfig connectionConfig = (ConnectionConfig) connectionConfigObj; + Long dataSourceId = connectionConfig.id(); + if (dataSourceId != null) { + this.limitService.decrementUserDatasourceSessionCount(userId + "", dataSourceId); + } + } } } diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitService.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitService.java index bff5b6a58d..879254422a 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitService.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionLimitService.java @@ -62,6 +62,11 @@ public class SessionLimitService { private Map userId2SessionCountMap = new ConcurrentHashMap<>(); + /** + * 记录 userId + dataSourceId 的会话数 Key格式: userId:dataSourceId + */ + private Map userId2DataSourceId2SessionCountMap = new ConcurrentHashMap<>(); + public boolean isResourceAvailable() { long userMaxCount = sessionProperties.getUserMaxCount(); return userMaxCount <= 0 || allowCreateSessionUserMap.size() < userMaxCount; @@ -171,6 +176,69 @@ public void decrementSessionCount(String userId) { }); } + /** + * 增加用户对指定数据源的会话计数 + * + * @param userId 用户ID + * @param dataSourceId 数据源ID + * @return 增加后的会话数 + */ + public int incrementUserDatasourceSessionCount(String userId, Long dataSourceId) { + PreConditions.notNull(userId, "userId"); + PreConditions.notNull(dataSourceId, "dataSourceId"); + String key = userId + ":" + dataSourceId; + return this.userId2DataSourceId2SessionCountMap + .computeIfAbsent(key, t -> new AtomicInteger(0)) + .incrementAndGet(); + } + + /** + * 减少用户对指定数据源的会话计数 + * + * @param userId 用户ID + * @param dataSourceId 数据源ID + */ + public void decrementUserDatasourceSessionCount(String userId, Long dataSourceId) { + PreConditions.notNull(userId, "userId"); + PreConditions.notNull(dataSourceId, "dataSourceId"); + String key = userId + ":" + dataSourceId; + userId2DataSourceId2SessionCountMap.computeIfPresent(key, (k, sessionCount) -> { + if (sessionCount.decrementAndGet() < 0) { + log.warn("user datasource session count is less than 0, userId={}, dataSourceId={}", + userId, dataSourceId); + throw new UnexpectedException("user datasource session count is less than 0"); + } + return sessionCount; + }); + } + + /** + * 获取用户对指定数据源的当前会话数 + * + * @param userId 用户ID + * @param dataSourceId 数据源ID + * @return 当前会话数 + */ + public int getUserDatasourceSessionCount(String userId, Long dataSourceId) { + String key = userId + ":" + dataSourceId; + AtomicInteger count = userId2DataSourceId2SessionCountMap.get(key); + return count == null ? 0 : count.get(); + } + + /** + * 清空用户对指定数据源的连接数限制 + * + * @param userId 用户ID + * @param dataSourceId 数据源ID + */ + public void clearUserDatasourceSessionCount(String userId, Long dataSourceId) { + PreConditions.notNull(userId, "userId"); + PreConditions.notNull(dataSourceId, "dataSourceId"); + String key = userId + ":" + dataSourceId; + userId2DataSourceId2SessionCountMap.remove(key); + log.info("Cleared user datasource session count, userId={}, dataSourceId={}", userId, dataSourceId); + } + /** * 用户入队列,记录index * diff --git a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionProperties.java b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionProperties.java index 5d7c019e5e..15262512a7 100644 --- a/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionProperties.java +++ b/server/odc-service/src/main/java/com/oceanbase/odc/service/session/SessionProperties.java @@ -62,6 +62,12 @@ public class SessionProperties { @Value("${odc.session.sql-execute.user-max-count:-1}") private long userMaxCount = -1; + /** + * 用户对单个数据源的最大连接数,默认 -1,表示不限制 + */ + @Value("${odc.session.sql-execute.user-datasource-max-count:-1}") + private long userDatasourceMaxCount = -1; + /** * 单次执行的最大 SQL 语句长度,默认值 -1, <=0 表示不限制 */