Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.CheckpointStreamFactory.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.RunnableWithException;

Expand All @@ -40,10 +42,12 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;

import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.UUID.randomUUID;
import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
Expand All @@ -58,8 +62,8 @@ class ChannelStateCheckpointWriter {
private final DataOutputStream dataStream;
private final CheckpointStateOutputStream checkpointStream;
private final ChannelStateWriteResult result;
private final Map<InputChannelInfo, List<Long>> inputChannelOffsets = new HashMap<>();
private final Map<ResultSubpartitionInfo, List<Long>> resultSubpartitionOffsets = new HashMap<>();
private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>();
private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets = new HashMap<>();
private final ChannelStateSerializer serializer;
private final long checkpointId;
private boolean allInputsReceived = false;
Expand Down Expand Up @@ -112,17 +116,19 @@ void writeOutput(ResultSubpartitionInfo info, Buffer... flinkBuffers) throws Exc
write(resultSubpartitionOffsets, info, flinkBuffers, !allOutputsReceived);
}

private <K> void write(Map<K, List<Long>> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
private <K> void write(Map<K, StateContentMetaInfo> offsets, K key, Buffer[] flinkBuffers, boolean precondition) throws Exception {
try {
if (result.isDone()) {
return;
}
runWithChecks(() -> {
checkState(precondition);
offsets
.computeIfAbsent(key, unused -> new ArrayList<>())
.add(checkpointStream.getPos());
long offset = checkpointStream.getPos();
serializer.writeData(dataStream, flinkBuffers);
long size = checkpointStream.getPos() - offset;
offsets
.computeIfAbsent(key, unused -> new StateContentMetaInfo())
.withDataAdded(offset, size);
});
} finally {
for (Buffer flinkBuffer : flinkBuffers) {
Expand Down Expand Up @@ -159,14 +165,8 @@ private void finishWriteAndResult() throws IOException {
}
dataStream.flush();
StreamStateHandle underlying = checkpointStream.closeAndGetHandle();
complete(
result.inputChannelStateHandles,
inputChannelOffsets,
(chan, offsets) -> new InputChannelStateHandle(chan, underlying, offsets));
complete(
result.resultSubpartitionStateHandles,
resultSubpartitionOffsets,
(chan, offsets) -> new ResultSubpartitionStateHandle(chan, underlying, offsets));
complete(underlying, result.inputChannelStateHandles, inputChannelOffsets, HandleFactory.INPUT_CHANNEL);
complete(underlying, result.resultSubpartitionStateHandles, resultSubpartitionOffsets, HandleFactory.RESULT_SUBPARTITION);
}

private void doComplete(boolean precondition, RunnableWithException complete, RunnableWithException... callbacks) throws Exception {
Expand All @@ -180,17 +180,38 @@ private void doComplete(boolean precondition, RunnableWithException complete, Ru
}

private <I, H extends AbstractChannelStateHandle<I>> void complete(
StreamStateHandle underlying,
CompletableFuture<Collection<H>> future,
Map<I, List<Long>> offsets,
BiFunction<I, List<Long>, H> buildHandle) {
Map<I, StateContentMetaInfo> offsets,
HandleFactory<I, H> handleFactory) throws IOException {
final Collection<H> handles = new ArrayList<>();
for (Map.Entry<I, List<Long>> e : offsets.entrySet()) {
handles.add(buildHandle.apply(e.getKey(), e.getValue()));
for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
}
future.complete(handles);
LOG.debug("channel state write completed, checkpointId: {}, handles: {}", checkpointId, handles);
}

/**
 * Builds the state handle for a single channel.
 *
 * <p>When the underlying handle exposes its content as in-memory bytes, only this channel's
 * data is extracted and re-serialized into a fresh {@link ByteStreamStateHandle}; the resulting
 * handle then has a single offset equal to the serializer's header length and a size equal to
 * the size of the extracted state. Otherwise the shared underlying handle is referenced
 * directly, together with the offsets and size recorded in {@code contentMetaInfo}.
 *
 * @param handleFactory factory producing the concrete handle type
 * @param underlying handle of the stream all channels were written to
 * @param channelInfo channel the handle is created for
 * @param contentMetaInfo offsets and size of this channel's data within {@code underlying}
 * @return the handle for {@code channelInfo}
 * @throws IOException if reading or re-serializing the in-memory bytes fails
 */
private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
		HandleFactory<I, H> handleFactory,
		StreamStateHandle underlying,
		I channelInfo,
		StateContentMetaInfo contentMetaInfo) throws IOException {
	// todo: consider restructuring channel state and removing this method: https://issues.apache.org/jira/browse/FLINK-17972
	final Optional<byte[]> inMemoryBytes = underlying.asBytesIfInMemory();
	if (!inMemoryBytes.isPresent()) {
		// Not in memory: point into the shared stream using the recorded offsets.
		return handleFactory.create(channelInfo, underlying, contentMetaInfo.getOffsets(), contentMetaInfo.getSize());
	}

	// In memory: give this channel its own handle holding only its data.
	final String handleName = randomUUID().toString();
	final byte[] channelBytes = serializer.extractAndMerge(inMemoryBytes.get(), contentMetaInfo.getOffsets());
	final StreamStateHandle extracted = new ByteStreamStateHandle(handleName, channelBytes);
	final List<Long> singleOffset = singletonList(serializer.getHeaderLength());
	return handleFactory.create(channelInfo, extracted, singleOffset, extracted.getStateSize());
}

private void runWithChecks(RunnableWithException r) throws Exception {
try {
checkState(!result.isDone(), "result is already completed", result);
Expand All @@ -206,4 +227,11 @@ public void fail(Throwable e) throws Exception {
checkpointStream.close();
}

/**
 * Creates a channel state handle of type {@code H} for a channel described by {@code I}.
 *
 * @param <I> type of the channel info
 * @param <H> type of the created handle
 */
@FunctionalInterface
private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {

	/** Factory backed by the {@link InputChannelStateHandle} constructor. */
	HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL = InputChannelStateHandle::new;

	/** Factory backed by the {@link ResultSubpartitionStateHandle} constructor. */
	HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION = ResultSubpartitionStateHandle::new;

	/**
	 * @param info channel the handle belongs to
	 * @param underlying stream handle holding the data
	 * @param offsets start offsets of the channel's data within {@code underlying}
	 * @param size size of the channel's data
	 * @return the new handle
	 */
	H create(I info, StreamStateHandle underlying, List<Long> offsets, long size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public ChannelStateReaderImpl(TaskStateSnapshot snapshot) {
this(snapshot, new ChannelStateSerializerImpl());
}

ChannelStateReaderImpl(TaskStateSnapshot snapshot, ChannelStateDeserializer serializer) {
ChannelStateReaderImpl(TaskStateSnapshot snapshot, ChannelStateSerializer serializer) {
RefCountingFSDataInputStreamFactory streamFactory = new RefCountingFSDataInputStreamFactory(serializer);
final HashMap<InputChannelInfo, ChannelStateStreamReader> inputChannelHandleReadersTmp = new HashMap<>();
final HashMap<ResultSubpartitionInfo, ChannelStateStreamReader> resultSubpartitionHandleReadersTmp = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@

import javax.annotation.concurrent.NotThreadSafe;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;

import static java.lang.Math.addExact;
import static java.lang.Math.min;
Expand All @@ -39,15 +42,16 @@ interface ChannelStateSerializer {
void writeHeader(DataOutputStream dataStream) throws IOException;

void writeData(DataOutputStream stream, Buffer... flinkBuffers) throws IOException;
}

interface ChannelStateDeserializer {

void readHeader(InputStream stream) throws IOException;

int readLength(InputStream stream) throws IOException;

int readData(InputStream stream, ChannelStateByteBuffer buffer, int bytes) throws IOException;

byte[] extractAndMerge(byte[] bytes, List<Long> offsets) throws IOException;

long getHeaderLength();
}

/**
Expand Down Expand Up @@ -128,7 +132,7 @@ public int writeBytes(InputStream input, int bytesToRead) throws IOException {
}
}

class ChannelStateSerializerImpl implements ChannelStateSerializer, ChannelStateDeserializer {
class ChannelStateSerializerImpl implements ChannelStateSerializer {
private static final int SERIALIZATION_VERSION = 0;

@Override
Expand Down Expand Up @@ -174,4 +178,35 @@ public int readData(InputStream stream, ChannelStateByteBuffer buffer, int bytes
/**
 * Reads one {@code int} from the current position of {@code stream} using
 * {@link DataInputStream#readInt()}.
 *
 * @throws IOException if the stream ends early or the read fails
 */
private static int readInt(InputStream stream) throws IOException {
	final DataInputStream dataInput = new DataInputStream(stream);
	return dataInput.readInt();
}

/**
 * Extracts the data chunks found at {@code offsets} in {@code bytes} and re-serializes them
 * as one chunk.
 *
 * <p>The result consists of the header, followed by the total length of the extracted data
 * as an {@code int}, followed by the extracted data itself.
 *
 * @param bytes serialized channel state
 * @param offsets positions of the chunks to extract
 * @return newly serialized state containing only the extracted data
 * @throws IOException if reading the chunks or writing the result fails
 */
@Override
public byte[] extractAndMerge(byte[] bytes, List<Long> offsets) throws IOException {
	final byte[] payload = extractByOffsets(bytes, offsets);
	final ByteArrayOutputStream target = new ByteArrayOutputStream();
	try (DataOutputStream writer = new DataOutputStream(target)) {
		writeHeader(writer);
		writer.writeInt(payload.length);
		writer.write(payload, 0, payload.length);
	}
	return target.toByteArray();
}

/**
 * Concatenates the data chunks located at the given offsets.
 *
 * <p>Each offset must point at a 4-byte length prefix (as read by
 * {@link DataInputStream#readInt()}) that is immediately followed by that many bytes of data.
 * Every length prefix is read directly at its own offset, so the result does not depend on
 * the offsets being sorted or distinct; chunks are emitted in the order the offsets are given.
 *
 * @param data serialized channel state
 * @param offsets positions of the length prefixes of the chunks to extract
 * @return the chunk payloads (without length prefixes), concatenated
 * @throws IOException if an offset does not leave room for a length prefix inside {@code data}
 * @throws ArithmeticException if an offset does not fit into an {@code int}
 */
private byte[] extractByOffsets(byte[] data, List<Long> offsets) throws IOException {
	ByteArrayOutputStream out = new ByteArrayOutputStream();
	for (long offset : offsets) {
		// Checked conversion: a byte[] cannot be addressed beyond int range, so fail loudly instead of wrapping.
		int lengthOffset = Math.toIntExact(offset);
		if (lengthOffset < 0 || lengthOffset > data.length - Integer.BYTES) {
			throw new IOException("Offset " + offset + " is outside of the data (length " + data.length + ")");
		}
		// Read the prefix from a view positioned exactly at the offset rather than skipping
		// relative to the previous chunk, which silently misreads when offsets are not ascending.
		int length = new DataInputStream(new ByteArrayInputStream(data, lengthOffset, Integer.BYTES)).readInt();
		out.write(data, lengthOffset + Integer.BYTES, length);
	}
	return out.toByteArray();
}

/**
 * Returns the length of the serialized header in bytes, which is the size of one {@code int}.
 *
 * <p>NOTE(review): presumably this mirrors what {@code writeHeader} emits (a single int) —
 * confirm against {@code writeHeader} and keep the two in sync.
 */
@Override
public long getHeaderLength() {
return Integer.BYTES;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
class ChannelStateStreamReader implements Closeable {

private final RefCountingFSDataInputStream stream;
private final ChannelStateDeserializer serializer;
private final ChannelStateSerializer serializer;
private final Queue<Long> offsets;
private int remainingBytes = -1;
private boolean closed = false;
Expand All @@ -54,7 +54,7 @@ class ChannelStateStreamReader implements Closeable {
this(streamFactory.getOrCreate(handle), handle.getOffsets(), streamFactory.getSerializer());
}

private ChannelStateStreamReader(RefCountingFSDataInputStream stream, List<Long> offsets, ChannelStateDeserializer serializer) {
private ChannelStateStreamReader(RefCountingFSDataInputStream stream, List<Long> offsets, ChannelStateSerializer serializer) {
this.stream = stream;
this.stream.incRef();
this.serializer = serializer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ private enum State {NEW, OPENED, CLOSED}

private final SupplierWithException<FSDataInputStream, IOException> streamSupplier;
private FSDataInputStream stream;
private final ChannelStateDeserializer serializer;
private final ChannelStateSerializer serializer;
private int refCount = 0;
private State state = State.NEW;

private RefCountingFSDataInputStream(
SupplierWithException<FSDataInputStream, IOException> streamSupplier,
ChannelStateDeserializer serializer) {
ChannelStateSerializer serializer) {
this.streamSupplier = checkNotNull(streamSupplier);
this.serializer = checkNotNull(serializer);
}
Expand Down Expand Up @@ -105,9 +105,9 @@ private void checkNotClosed() {
@NotThreadSafe
static class RefCountingFSDataInputStreamFactory {
private final Map<StreamStateHandle, RefCountingFSDataInputStream> streams = new HashMap<>(); // not clearing: expecting short life
private final ChannelStateDeserializer serializer;
private final ChannelStateSerializer serializer;

RefCountingFSDataInputStreamFactory(ChannelStateDeserializer serializer) {
RefCountingFSDataInputStreamFactory(ChannelStateSerializer serializer) {
this.serializer = checkNotNull(serializer);
}

Expand All @@ -121,7 +121,7 @@ <T> RefCountingFSDataInputStream getOrCreate(AbstractChannelStateHandle<T> handl
return stream;
}

ChannelStateDeserializer getSerializer() {
ChannelStateSerializer getSerializer() {
return serializer;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
Expand Down Expand Up @@ -51,7 +52,7 @@ ResultSubpartitionStateHandle deserializeResultSubpartitionStateHandle(

return deserializeChannelStateHandle(
is -> new ResultSubpartitionInfo(is.readInt(), is.readInt()),
(streamStateHandle, longs, info) -> new ResultSubpartitionStateHandle(info, streamStateHandle, longs),
(streamStateHandle, contentMetaInfo, info) -> new ResultSubpartitionStateHandle(info, streamStateHandle, contentMetaInfo),
dis,
context);
}
Expand All @@ -69,7 +70,7 @@ InputChannelStateHandle deserializeInputChannelStateHandle(

return deserializeChannelStateHandle(
is -> new InputChannelInfo(is.readInt(), is.readInt()),
(streamStateHandle, longs, inputChannelInfo) -> new InputChannelStateHandle(inputChannelInfo, streamStateHandle, longs),
(streamStateHandle, contentMetaInfo, inputChannelInfo) -> new InputChannelStateHandle(inputChannelInfo, streamStateHandle, contentMetaInfo),
dis,
context);
}
Expand All @@ -83,12 +84,13 @@ private static <I> void serializeChannelStateHandle(
for (long offset : handle.getOffsets()) {
dos.writeLong(offset);
}
dos.writeLong(handle.getStateSize());
serializeStreamStateHandle(handle.getDelegate(), dos);
}

private static <Info, Handle extends AbstractChannelStateHandle<Info>> Handle deserializeChannelStateHandle(
FunctionWithException<DataInputStream, Info, IOException> infoReader,
TriFunctionWithException<StreamStateHandle, List<Long>, Info, Handle, IOException> handleBuilder,
TriFunctionWithException<StreamStateHandle, StateContentMetaInfo, Info, Handle, IOException> handleBuilder,
DataInputStream dis,
MetadataV2V3SerializerBase.DeserializationContext context) throws IOException {

Expand All @@ -98,6 +100,7 @@ private static <Info, Handle extends AbstractChannelStateHandle<Info>> Handle de
for (int i = 0; i < offsetsSize; i++) {
offsets.add(dis.readLong());
}
return handleBuilder.apply(deserializeStreamStateHandle(dis, context), offsets, info);
final long size = dis.readLong();
return handleBuilder.apply(deserializeStreamStateHandle(dis, context), new StateContentMetaInfo(offsets, size), info);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.flink.annotation.Internal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

Expand All @@ -39,11 +41,13 @@ public abstract class AbstractChannelStateHandle<Info> implements StateObject {
* Start offsets in a {@link org.apache.flink.core.fs.FSDataInputStream stream} {@link StreamStateHandle#openInputStream obtained} from {@link #delegate}.
*/
private final List<Long> offsets;
private final long size;

AbstractChannelStateHandle(StreamStateHandle delegate, List<Long> offsets, Info info) {
AbstractChannelStateHandle(StreamStateHandle delegate, List<Long> offsets, Info info, long size) {
this.info = checkNotNull(info);
this.delegate = checkNotNull(delegate);
this.offsets = checkNotNull(offsets);
this.size = size;
}

@Override
Expand All @@ -53,7 +57,7 @@ public void discardState() throws Exception {

@Override
public long getStateSize() {
return delegate.getStateSize();
return size; // can not rely on delegate.getStateSize because it can be shared
}

public List<Long> getOffsets() {
Expand Down Expand Up @@ -84,4 +88,35 @@ public boolean equals(Object o) {
public int hashCode() {
// Hashes info, delegate and offsets only; the size field is not part of the hash.
// NOTE(review): confirm equals() ignores size as well so that the two stay consistent.
return Objects.hash(info, delegate, offsets);
}

/**
 * Describes the underlying content: the start offsets of the data chunks and their
 * accumulated size.
 *
 * <p>Instances are mutable; {@link #withDataAdded(long, long)} updates the instance in place.
 */
public static class StateContentMetaInfo {
	private final List<Long> offsets;
	private long size;

	/** Creates an empty description: no offsets and a size of zero. */
	public StateContentMetaInfo() {
		this(new ArrayList<>(), 0L);
	}

	/**
	 * Creates a description from existing values.
	 *
	 * @param offsets start offsets; the list is kept by reference, not copied
	 * @param size accumulated size of the described content
	 */
	public StateContentMetaInfo(List<Long> offsets, long size) {
		this.size = size;
		this.offsets = offsets;
	}

	/**
	 * Records one more chunk of data.
	 *
	 * @param offset start offset of the chunk
	 * @param size size of the chunk, added to the accumulated size
	 */
	public void withDataAdded(long offset, long size) {
		offsets.add(offset);
		this.size = this.size + size;
	}

	/** Returns a read-only view of the recorded offsets. */
	public List<Long> getOffsets() {
		return Collections.unmodifiableList(offsets);
	}

	/** Returns the accumulated size. */
	public long getSize() {
		return size;
	}
}

}
Loading