Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer committed Dec 1, 2023
1 parent d2980b7 commit 1726063
Show file tree
Hide file tree
Showing 16 changed files with 182 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.TaskResources;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
* Built-in toolkit for managing native memory allocations. To use the facility, one should import
Expand All @@ -44,12 +46,12 @@ private CHNativeMemoryAllocators() {}
private static CHNativeMemoryAllocatorManager createNativeMemoryAllocatorManager(
String name,
TaskMemoryManager taskMemoryManager,
Spiller spiller,
List<Spiller> spillers,
SimpleMemoryUsageRecorder usage) {

CHManagedCHReservationListener rl =
new CHManagedCHReservationListener(
MemoryTargets.newConsumer(taskMemoryManager, name, spiller, Collections.emptyMap()),
MemoryTargets.newConsumer(taskMemoryManager, name, spillers, Collections.emptyMap()),
usage);
return new CHNativeMemoryAllocatorManagerImpl(CHNativeMemoryAllocator.createListenable(rl));
}
Expand All @@ -65,7 +67,7 @@ public static CHNativeMemoryAllocator contextInstance() {
createNativeMemoryAllocatorManager(
"ContextInstance",
TaskResources.getLocalTaskContext().taskMemoryManager(),
Spiller.NO_OP,
Collections.emptyList(),
TaskResources.getSharedUsage());
TaskResources.addResource(id, manager);
}
Expand All @@ -76,7 +78,7 @@ public static CHNativeMemoryAllocator contextInstanceForUT() {
return CHNativeMemoryAllocator.getDefaultForUT();
}

public static CHNativeMemoryAllocator createSpillable(String name, Spiller spiller) {
public static CHNativeMemoryAllocator createSpillable(String name, Spiller... spillers) {
if (!TaskResources.inSparkTask()) {
throw new IllegalStateException("spiller must be used in a Spark task");
}
Expand All @@ -85,7 +87,7 @@ public static CHNativeMemoryAllocator createSpillable(String name, Spiller spill
createNativeMemoryAllocatorManager(
name,
TaskResources.getLocalTaskContext().taskMemoryManager(),
spiller,
Arrays.asList(spillers),
TaskResources.getSharedUsage());
TaskResources.addAnonymousResource(manager);
// force-add the memory consumer to the task memory manager; it will be released by inactivate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.shuffle
import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.clickhouse.CHBackendSettings
import io.glutenproject.memory.alloc.CHNativeMemoryAllocators
import io.glutenproject.memory.memtarget.{MemoryTarget, Spiller}
import io.glutenproject.memory.memtarget.{MemoryTarget, Spiller, Spillers}
import io.glutenproject.vectorized._

import org.apache.spark.SparkEnv
Expand All @@ -29,6 +29,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkDirectoryUtil, Utils}

import java.io.IOException
import java.util
import java.util.{Locale, UUID}

class CHColumnarShuffleWriter[K, V](
Expand Down Expand Up @@ -121,6 +122,8 @@ class CHColumnarShuffleWriter[K, V](
logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
spilled
}

// Restrict this spiller to the SPILL phase: PHASE_SET_SPILL_ONLY contains only Spiller.Phase.SPILL.
override def applicablePhases(): util.Set[Spiller.Phase] = Spillers.PHASE_SET_SPILL_ONLY
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.glutenproject.memory.alloc.CHNativeMemoryAllocator;
import io.glutenproject.memory.alloc.CHNativeMemoryAllocatorManagerImpl;
import io.glutenproject.memory.memtarget.MemoryTargets;
import io.glutenproject.memory.memtarget.Spiller;

import org.apache.spark.SparkConf;
import org.apache.spark.internal.config.package$;
Expand Down Expand Up @@ -53,7 +52,7 @@ public void initMemoryManager() {
listener =
new CHManagedCHReservationListener(
MemoryTargets.newConsumer(
taskMemoryManager, "test", Spiller.NO_OP, Collections.emptyMap()),
taskMemoryManager, "test", Collections.emptyList(), Collections.emptyMap()),
new SimpleMemoryUsageRecorder());

manager = new CHNativeMemoryAllocatorManagerImpl(new CHNativeMemoryAllocator(-1L, listener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import io.glutenproject.backendsapi.clickhouse.CHBackendSettings
import io.glutenproject.memory.alloc.CHNativeMemoryAllocators
import io.glutenproject.memory.memtarget.MemoryTarget
import io.glutenproject.memory.memtarget.Spiller
import io.glutenproject.memory.memtarget.Spillers
import io.glutenproject.vectorized._

import org.apache.spark._
Expand All @@ -32,6 +33,7 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf

import java.io.IOException
import java.util
import java.util.Locale

class CHCelebornHashBasedColumnarShuffleWriter[K, V](
Expand Down Expand Up @@ -89,6 +91,8 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
spilled
}

// Restrict this spiller to the SPILL phase: PHASE_SET_SPILL_ONLY contains only Spiller.Phase.SPILL.
override def applicablePhases(): util.Set[Spiller.Phase] = Spillers.PHASE_SET_SPILL_ONLY
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.columnarbatch.ColumnarBatches
import io.glutenproject.memory.memtarget.MemoryTarget
import io.glutenproject.memory.memtarget.Spiller
import io.glutenproject.memory.memtarget.Spillers
import io.glutenproject.memory.nmm.NativeMemoryManagers
import io.glutenproject.vectorized._

Expand All @@ -34,6 +35,7 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf

import java.io.IOException
import java.util

class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
handle: CelebornShuffleHandle[K, V, V],
Expand Down Expand Up @@ -99,6 +101,9 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
pushed
}

// Restrict this spiller to the SPILL phase: PHASE_SET_SPILL_ONLY contains only Spiller.Phase.SPILL.
override def applicablePhases(): util.Set[Spiller.Phase] =
Spillers.PHASE_SET_SPILL_ONLY
}
)
.getNativeInstanceHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import org.apache.spark.memory.TaskMemoryManager;

import java.util.List;
import java.util.Map;

public final class MemoryTargets {
Expand All @@ -45,14 +46,14 @@ public static MemoryTarget overAcquire(
public static MemoryTarget newConsumer(
TaskMemoryManager tmm,
String name,
Spiller spiller,
List<Spiller> spillers,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final TreeMemoryConsumers.Factory factory;
if (GlutenConfig.getConf().memoryIsolation()) {
factory = TreeMemoryConsumers.isolated();
} else {
factory = TreeMemoryConsumers.shared();
}
return factory.newConsumer(tmm, name, spiller, virtualChildren);
return factory.newConsumer(tmm, name, spillers, virtualChildren);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
*/
package io.glutenproject.memory.memtarget;

import java.util.Set;

public interface Spiller {
Spiller NO_OP =
new Spiller() {
@Override
public long spill(MemoryTarget self, long size) {
return 0L;
}
};

/**
 * Asks this spiller to release memory on behalf of {@code self}.
 *
 * @param self the memory target the spill is performed for
 * @param size the number of bytes requested to be released
 * @return the number of bytes reported as released; {@code TreeMemoryTargets.spillTree} subtracts
 *     this value from the amount it still has to reclaim
 */
long spill(MemoryTarget self, long size);

/**
 * Returns the phases in which this spiller takes part. {@code TreeMemoryTargets.spillTree} only
 * invokes spillers whose returned set contains the phase it is currently running.
 */
Set<Phase> applicablePhases();

/** Reclamation phases, listed in the order {@code TreeMemoryTargets.spillTree} runs them. */
enum Phase {
// Run first.
SHRINK,
// Run afterwards, and only if the requested bytes were not fully released during SHRINK.
SPILL
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,25 @@
*/
package io.glutenproject.memory.memtarget;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

public final class Spillers {
private Spillers() {
// static utility holder: the private constructor prevents instantiation
}

// calls the spillers one by one within the order
public static Spiller withOrder(Spiller... spillers) {
return (self, size) -> {
long remaining = size;
for (int i = 0; i < spillers.length && remaining > 0; i++) {
Spiller spiller = spillers[i];
remaining -= spiller.spill(self, remaining);
}
return size - remaining;
};
}
/** Unmodifiable set containing both {@code SHRINK} and {@code SPILL}. */
public static final Set<Spiller.Phase> PHASE_SET_ALL =
Collections.unmodifiableSet(
new HashSet<>(Arrays.asList(Spiller.Phase.SHRINK, Spiller.Phase.SPILL)));

/** Immutable single-element set containing only {@code SHRINK}. */
public static final Set<Spiller.Phase> PHASE_SET_SHRINK_ONLY =
Collections.singleton(Spiller.Phase.SHRINK);

/** Immutable single-element set containing only {@code SPILL}. */
public static final Set<Spiller.Phase> PHASE_SET_SPILL_ONLY =
Collections.singleton(Spiller.Phase.SPILL);

public static Spiller withMinSpillSize(Spiller spiller, long minSize) {
return new WithMinSpillSize(spiller, minSize);
Expand All @@ -53,5 +56,10 @@ private WithMinSpillSize(Spiller delegated, long minSize) {
public long spill(MemoryTarget self, long size) {
return delegated.spill(self, Math.max(size, minSize));
}

// The wrapper only raises the requested spill size to at least minSize; the applicable phases are
// taken unchanged from the delegate.
@Override
public Set<Phase> applicablePhases() {
return delegated.applicablePhases();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.glutenproject.memory.MemoryUsageStatsBuilder;
import io.glutenproject.memory.memtarget.spark.TreeMemoryConsumer;

import java.util.List;
import java.util.Map;

/** An abstraction for both {@link TreeMemoryConsumer} and its non-consumer child nodes. */
Expand All @@ -28,12 +29,12 @@ public interface TreeMemoryTarget extends MemoryTarget, KnownNameAndStats {
TreeMemoryTarget newChild(
String name,
long capacity,
Spiller spiller,
List<Spiller> spillers,
Map<String, MemoryUsageStatsBuilder> virtualChildren);

Map<String, TreeMemoryTarget> children();

TreeMemoryTarget parent();

Spiller getNodeSpiller();
List<Spiller> getNodeSpillers();
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@
import com.google.common.base.Preconditions;
import org.apache.spark.util.Utils;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.function.Predicate;
import java.util.stream.Collectors;

public class TreeMemoryTargets {
/**
 * Phases run by {@link #spillTree(TreeMemoryTarget, long)}, in execution order: {@code SHRINK}
 * first, then {@code SPILL}.
 *
 * <p>The list is wrapped as unmodifiable because a bare {@link Arrays#asList} still permits
 * {@code set()}, which would let any caller reorder or replace the phases of this shared public
 * constant.
 */
public static final List<Spiller.Phase> SPILL_PHASES =
    Collections.unmodifiableList(Arrays.asList(Spiller.Phase.SHRINK, Spiller.Phase.SPILL));

private TreeMemoryTargets() {
// static utility holder: the private constructor prevents instantiation
}
Expand All @@ -39,12 +45,26 @@ public static TreeMemoryTarget newChild(
TreeMemoryTarget parent,
String name,
long capacity,
Spiller spiller,
List<Spiller> spillers,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
return new Node(parent, name, capacity, spiller, virtualChildren);
return new Node(parent, name, capacity, spillers, virtualChildren);
}

/**
 * Tries to reclaim {@code bytes} from the tree rooted at {@code node}. One pass over the tree is
 * made per entry of {@link #SPILL_PHASES}, in list order, and no further pass is started once the
 * request has been satisfied.
 *
 * @param node root of the subtree to reclaim from
 * @param bytes number of bytes requested
 * @return total number of bytes reported as released across all passes
 */
public static long spillTree(TreeMemoryTarget node, final long bytes) {
  long reclaimed = 0L;
  for (final Spiller.Phase current : SPILL_PHASES) {
    final long outstanding = bytes - reclaimed;
    if (outstanding <= 0) {
      // Request already satisfied; later phases are not attempted.
      break;
    }
    // Only spillers that declare the current phase take part in this pass.
    reclaimed += spillTree(node, outstanding, s -> s.applicablePhases().contains(current));
  }
  return reclaimed;
}

private static long spillTree(
TreeMemoryTarget node, final long bytes, Predicate<Spiller> spillerFilter) {
// sort children by used bytes, descending
Queue<TreeMemoryTarget> q =
new PriorityQueue<>(
Expand All @@ -63,8 +83,13 @@ public static long spillTree(TreeMemoryTarget node, final long bytes) {

if (remainingBytes > 0) {
// if still doesn't fit, spill self
final long spilled = node.getNodeSpiller().spill(node, remainingBytes);
remainingBytes -= spilled;
final List<Spiller> applicableSpillers =
node.getNodeSpillers().stream().filter(spillerFilter).collect(Collectors.toList());
for (int i = 0; i < applicableSpillers.size() && remainingBytes > 0; i++) {
final Spiller spiller = applicableSpillers.get(i);
long spilled = spiller.spill(node, remainingBytes);
remainingBytes -= spilled;
}
}

return bytes - remainingBytes;
Expand All @@ -76,15 +101,15 @@ public static class Node implements TreeMemoryTarget, KnownNameAndStats {
private final TreeMemoryTarget parent;
private final String name;
private final long capacity;
private final Spiller spiller;
private final List<Spiller> spillers;
private final Map<String, MemoryUsageStatsBuilder> virtualChildren;
private final SimpleMemoryUsageRecorder selfRecorder = new SimpleMemoryUsageRecorder();

private Node(
TreeMemoryTarget parent,
String name,
long capacity,
Spiller spiller,
List<Spiller> spillers,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
this.parent = parent;
this.capacity = capacity;
Expand All @@ -94,7 +119,7 @@ private Node(
} else {
this.name = String.format("%s, %s", uniqueName, Utils.bytesToString(capacity));
}
this.spiller = spiller;
this.spillers = Collections.unmodifiableList(spillers);
this.virtualChildren = virtualChildren;
}

Expand All @@ -114,8 +139,8 @@ private long borrow0(long size) {
return granted;
}

public Spiller getNodeSpiller() {
return spiller;
public List<Spiller> getNodeSpillers() {
return spillers;
}

private boolean ensureFreeCapacity(long bytesNeeded) {
Expand Down Expand Up @@ -183,9 +208,9 @@ public MemoryUsageStats stats() {
public TreeMemoryTarget newChild(
String name,
long capacity,
Spiller spiller,
List<Spiller> spillers,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final Node child = new Node(this, name, capacity, spiller, virtualChildren);
final Node child = new Node(this, name, capacity, spillers, virtualChildren);
if (children.containsKey(child.name())) {
throw new IllegalArgumentException("Child already registered: " + child.name());
}
Expand Down
Loading

0 comments on commit 1726063

Please sign in to comment.