Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support user's app quota level limit #311

Merged
merged 13 commits into from
Nov 22, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ public Thread newThread(Runnable r) {
long heartbeatInterval = conf.getLong(RssMRConfig.RSS_HEARTBEAT_INTERVAL,
RssMRConfig.RSS_HEARTBEAT_INTERVAL_DEFAULT_VALUE);
long heartbeatTimeout = conf.getLong(RssMRConfig.RSS_HEARTBEAT_TIMEOUT, heartbeatInterval / 2);
client.registerApplicationInfo(appId, heartbeatTimeout, "user");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we remove this if mr don't need to register application info?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this because applicationManager remove expired app need user. When our Spark does not use the AppQuota checker, it also needs register application info, so it will not have much impact here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that this isn't compatible feature. If we use old client to connect new service, something wrong will happen.

Copy link
Contributor Author

@smallzhongfeng smallzhongfeng Nov 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because the old client does not have user information or method of registerApplicationInfo, app can only be updated at refreshApp, which I am compatible with.

scheduledExecutorService.scheduleAtFixedRate(
() -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ public void sendAppHeartbeat(String appId, long timeoutMs) {

}

@Override
public void registerApplicationInfo(String appId, long timeoutMs, String user) {

}

@Override
public void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,11 @@ public void sendAppHeartbeat(String appId, long timeoutMs) {

}

@Override
public void registerApplicationInfo(String appId, long timeoutMs, String user) {

}

@Override
public void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
Expand All @@ -48,6 +49,8 @@ public class DelegationRssShuffleManager implements ShuffleManager {
private final List<CoordinatorClient> coordinatorClients;
private final int accessTimeoutMs;
private final SparkConf sparkConf;
private String user;
private String uuid;

public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws Exception {
this.sparkConf = sparkConf;
Expand All @@ -67,10 +70,20 @@ public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws

private ShuffleManager createShuffleManagerInDriver() throws RssException {
ShuffleManager shuffleManager;

user = "user";
try {
user = UserGroupInformation.getCurrentUser().getShortUserName();
} catch (Exception e) {
LOG.error("Error on getting user from ugi." + e);
}
boolean canAccess = tryAccessCluster();
if (uuid == null || "".equals(uuid)) {
uuid = String.valueOf(System.currentTimeMillis());
}
if (canAccess) {
try {
sparkConf.set("spark.rss.quota.user", user);
sparkConf.set("spark.rss.quota.uuid", uuid);
shuffleManager = new RssShuffleManager(sparkConf, true);
sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
sparkConf.set("spark.shuffle.manager", RssShuffleManager.class.getCanonicalName());
Expand Down Expand Up @@ -113,9 +126,10 @@ private boolean tryAccessCluster() {
try {
canAccess = RetryUtils.retry(() -> {
RssAccessClusterResponse response = coordinatorClient.accessCluster(new RssAccessClusterRequest(
accessId, assignmentTags, accessTimeoutMs, extraProperties));
accessId, assignmentTags, accessTimeoutMs, extraProperties, user));
if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
LOG.warn("Success to access cluster {} using {}", coordinatorClient.getDesc(), accessId);
uuid = response.getUuid();
return true;
} else if (response.getStatusCode() == ResponseStatusCode.ACCESS_DENIED) {
throw new RssException("Request to access cluster " + coordinatorClient.getDesc() + " is denied using "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public class RssShuffleManager implements ShuffleManager {
private final int dataCommitPoolSize;
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
private final String user;
private final String uuid;
private ThreadPoolExecutor threadPoolExecutor;
private EventLoop eventLoop = new EventLoop<AddBlockEvent>("ShuffleDataQueue") {

Expand Down Expand Up @@ -142,7 +144,8 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
throw new IllegalArgumentException("Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be false.");
}
this.sparkConf = sparkConf;

this.user = sparkConf.get("spark.rss.quota.user", "user");
this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
// set & check replica config
this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
Expand Down Expand Up @@ -204,12 +207,12 @@ public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
@Override
public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, ShuffleDependency<K, V, C> dependency) {
// If yarn enable retry ApplicationMaster, appId will be not unique and shuffle data will be incorrect,
// appId + timestamp can avoid such problem,
// appId + uuid can avoid such problem,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need uuid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't get the appId when we try Access, because the appId is generated after the RssManager is created. In order to support push down, we maintain the uuid as a substitute for the appId, and replace the uuid with the appId after the app heartbeat is reported to the coordinator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can put the appId to the AcessInfo when we try to access coordinator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean to generate the uuid on the driver?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My fault. We ignore that I can't get the appId in the construct method. Let me think twice.

// can't get appId in construct because SparkEnv is not created yet,
// appId will be initialized only once in this method which
// will be called many times depend on how many shuffle stage
if ("".equals(appId)) {
appId = SparkEnv.get().conf().getAppId() + "_" + System.currentTimeMillis();
appId = SparkEnv.get().conf().getAppId() + "_" + uuid;
LOG.info("Generate application id used in rss: " + appId);
}

Expand Down Expand Up @@ -260,6 +263,7 @@ public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, int numMaps, Shuff
}

private void startHeartbeat() {
shuffleWriteClient.registerApplicationInfo(appId, heartbeatTimeout, user);
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) && !heartbeatStarted) {
heartBeatScheduledExecutorService.scheduleAtFixedRate(
() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
Expand All @@ -48,6 +49,8 @@ public class DelegationRssShuffleManager implements ShuffleManager {
private final List<CoordinatorClient> coordinatorClients;
private final int accessTimeoutMs;
private final SparkConf sparkConf;
private String user;
private String uuid;

public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws Exception {
this.sparkConf = sparkConf;
Expand All @@ -67,10 +70,20 @@ public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws

private ShuffleManager createShuffleManagerInDriver() throws RssException {
ShuffleManager shuffleManager;

user = "user";
try {
user = UserGroupInformation.getCurrentUser().getShortUserName();
} catch (Exception e) {
LOG.error("Error on getting user from ugi." + e);
}
boolean canAccess = tryAccessCluster();
if (uuid == null || "".equals(uuid)) {
uuid = String.valueOf(System.currentTimeMillis());
}
if (canAccess) {
try {
sparkConf.set("spark.rss.quota.user", user);
sparkConf.set("spark.rss.quota.uuid", uuid);
shuffleManager = new RssShuffleManager(sparkConf, true);
sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
sparkConf.set("spark.shuffle.manager", RssShuffleManager.class.getCanonicalName());
Expand Down Expand Up @@ -113,9 +126,10 @@ private boolean tryAccessCluster() {
try {
canAccess = RetryUtils.retry(() -> {
RssAccessClusterResponse response = coordinatorClient.accessCluster(new RssAccessClusterRequest(
accessId, assignmentTags, accessTimeoutMs, extraProperties));
accessId, assignmentTags, accessTimeoutMs, extraProperties, user));
if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
LOG.warn("Success to access cluster {} using {}", coordinatorClient.getDesc(), accessId);
uuid = response.getUuid();
return true;
} else if (response.getStatusCode() == ResponseStatusCode.ACCESS_DENIED) {
throw new RssException("Request to access cluster " + coordinatorClient.getDesc() + " is denied using "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ public class RssShuffleManager implements ShuffleManager {
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
private final ShuffleDataDistributionType dataDistributionType;
private String user;
private String uuid;
private final EventLoop eventLoop;
private final EventLoop defaultEventLoop = new EventLoop<AddBlockEvent>("ShuffleDataQueue") {

Expand Down Expand Up @@ -144,7 +146,8 @@ private synchronized void putBlockId(

public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;

this.user = sparkConf.get("spark.rss.quota.user", "user");
this.uuid = sparkConf.get("spark.rss.quota.uuid", Long.toString(System.currentTimeMillis()));
// set & check replica config
this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
this.dataReplicaWrite = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
Expand Down Expand Up @@ -266,7 +269,7 @@ public RssShuffleManager(SparkConf conf, boolean isDriver) {
public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<K, V, C> dependency) {

if (id.get() == null) {
id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + System.currentTimeMillis());
id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid);
}
LOG.info("Generate application id used in rss: " + id.get());

Expand Down Expand Up @@ -662,11 +665,13 @@ public SparkConf getSparkConf() {
}

private synchronized void startHeartbeat() {
shuffleWriteClient.registerApplicationInfo(id.get(), heartbeatTimeout, user);
if (!heartbeatStarted) {
heartBeatScheduledExecutorService.scheduleAtFixedRate(
() -> {
try {
shuffleWriteClient.sendAppHeartbeat(id.get(), heartbeatTimeout);
String appId = id.get();
shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
LOG.info("Finish send heartbeat to coordinator and servers");
} catch (Exception e) {
LOG.warn("Fail to send heartbeat to coordinator and servers", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public interface ShuffleWriteClient {

void sendAppHeartbeat(String appId, long timeoutMs);

void registerApplicationInfo(String appId, long timeoutMs, String user);

void registerShuffle(
ShuffleServerInfo shuffleServerInfo,
String appId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
import org.apache.uniffle.client.request.RssAppHeartBeatRequest;
import org.apache.uniffle.client.request.RssApplicationInfoRequest;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.request.RssFetchRemoteStorageRequest;
import org.apache.uniffle.client.request.RssFinishShuffleRequest;
Expand All @@ -59,6 +60,7 @@
import org.apache.uniffle.client.response.ClientResponse;
import org.apache.uniffle.client.response.ResponseStatusCode;
import org.apache.uniffle.client.response.RssAppHeartBeatResponse;
import org.apache.uniffle.client.response.RssApplicationInfoResponse;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.client.response.RssFetchRemoteStorageResponse;
import org.apache.uniffle.client.response.RssFinishShuffleResponse;
Expand Down Expand Up @@ -549,6 +551,37 @@ public Roaring64NavigableMap getShuffleResultForMultiPart(String clientType,
return blockIdBitmap;
}

@Override
public void registerApplicationInfo(String appId, long timeoutMs, String user) {
RssApplicationInfoRequest request = new RssApplicationInfoRequest(appId, timeoutMs, user);
List<Callable<Void>> callableList = Lists.newArrayList();
coordinatorClients.forEach(coordinatorClient -> {
callableList.add(() -> {
try {
RssApplicationInfoResponse response = coordinatorClient.sendApplicationInfo(request);
if (response.getStatusCode() != ResponseStatusCode.SUCCESS) {
LOG.error("Failed to send applicationInfo to " + coordinatorClient.getDesc());
} else {
LOG.info("Successfully send applicationInfo to " + coordinatorClient.getDesc());
}
} catch (Exception e) {
LOG.warn("Error happened when send applicationInfo to " + coordinatorClient.getDesc(), e);
}
return null;
});
});
try {
List<Future<Void>> futures = heartBeatExecutorService.invokeAll(callableList, timeoutMs, TimeUnit.MILLISECONDS);
for (Future<Void> future : futures) {
if (!future.isDone()) {
future.cancel(true);
}
}
} catch (InterruptedException ie) {
LOG.warn("register application is interrupted", ie);
}
}

@Override
public void sendAppHeartbeat(String appId, long timeoutMs) {
RssAppHeartBeatRequest request = new RssAppHeartBeatRequest(appId, timeoutMs);
Expand All @@ -571,7 +604,7 @@ public void sendAppHeartbeat(String appId, long timeoutMs) {
}
);

coordinatorClients.stream().forEach(coordinatorClient -> {
coordinatorClients.forEach(coordinatorClient -> {
callableList.add(() -> {
try {
RssAppHeartBeatResponse response = coordinatorClient.sendAppHeartBeat(request);
Expand Down Expand Up @@ -683,7 +716,6 @@ public ShuffleServerClient getShuffleServerClient(ShuffleServerInfo shuffleServe
return ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType, shuffleServerInfo);
}

@VisibleForTesting
void addShuffleServer(String appId, int shuffleId, ShuffleServerInfo serverInfo) {
Map<Integer, Set<ShuffleServerInfo>> appServerMap = shuffleServerInfoMap.get(appId);
if (appServerMap == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ public static <T> List<T> loadExtensions(

List<T> extensions = Lists.newArrayList();
for (String name : classes) {
name = name.trim();
try {
Class<?> klass = Class.forName(name);
if (!extClass.isAssignableFrom(klass)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@ public class AccessCheckResult {

private final boolean success;
private final String msg;
private final String uuid;

public AccessCheckResult(boolean success, String msg, String uuid) {
this.success = success;
this.msg = msg;
this.uuid = uuid;
}

public AccessCheckResult(boolean success, String msg) {
this.success = success;
this.msg = msg;
this.uuid = "";
}

public boolean isSuccess() {
Expand All @@ -34,4 +42,8 @@ public boolean isSuccess() {
public String getMsg() {
return msg;
}

public String getUuid() {
return uuid;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,25 @@
import java.util.Map;
import java.util.Set;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;

public class AccessInfo {
private final String accessId;
private final Set<String> tags;
private final Map<String, String> extraProperties;
private final String user;

public AccessInfo(String accessId, Set<String> tags, Map<String, String> extraProperties) {
public AccessInfo(String accessId, Set<String> tags, Map<String, String> extraProperties, String user) {
this.accessId = accessId;
this.tags = tags;
this.extraProperties = extraProperties == null ? Collections.emptyMap() : extraProperties;
this.user = user;
}

@VisibleForTesting
public AccessInfo(String accessId) {
this(accessId, Sets.newHashSet(), Collections.emptyMap());
this(accessId, Sets.newHashSet(), Collections.emptyMap(), "");
}

public String getAccessId() {
Expand All @@ -50,10 +54,15 @@ public Map<String, String> getExtraProperties() {
return extraProperties;
}

public String getUser() {
return user;
}

@Override
public String toString() {
return "AccessInfo{"
+ "accessId='" + accessId + '\''
+ ", user= " + user
+ ", tags=" + tags
+ ", extraProperties=" + extraProperties
+ '}';
Expand Down
Loading