Skip to content

Commit

Permalink
[FLINK-2924] [streaming] Use timestamps to store checkpoints so it su…
Browse files Browse the repository at this point in the history
…pports job shutdown/restart
  • Loading branch information
gyfora committed Nov 24, 2015
1 parent 347e6f7 commit c254bda
Show file tree
Hide file tree
Showing 24 changed files with 143 additions and 152 deletions.
Expand Up @@ -137,7 +137,7 @@ void deleteCheckpoint(String jobId, Connection con, long checkpointId, long chec


/** /**
* Retrieve the latest value from the database for a given key and * Retrieve the latest value from the database for a given key and
* checkpointId. * timestamp.
* *
* @param stateId * @param stateId
* Unique identifier of the kvstate (usually the table name). * Unique identifier of the kvstate (usually the table name).
Expand All @@ -153,25 +153,24 @@ byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key,
throws SQLException; throws SQLException;


/** /**
* Clean up states between the current and next checkpoint id. Everything * Clean up states between the checkpoint and recovery timestamp.
* with larger than current and smaller than next should be removed.
* *
*/ */
void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointId, void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTimestamp,
long nextId) throws SQLException; long recoveryTimestamp) throws SQLException;


/** /**
* Insert a list of Key-Value pairs into the database. The suggested * Insert a list of Key-Value pairs into the database. The suggested
* approach is to use idempotent inserts(updates) as 1 batch operation. * approach is to use idempotent inserts(updates) as 1 batch operation.
* *
*/ */
void insertBatch(String stateId, DbBackendConfig conf, Connection con, PreparedStatement insertStatement, void insertBatch(String stateId, DbBackendConfig conf, Connection con, PreparedStatement insertStatement,
long checkpointId, List<Tuple2<byte[], byte[]>> toInsert) throws IOException; long checkpointTimestamp, List<Tuple2<byte[], byte[]>> toInsert) throws IOException;


/** /**
* Compact the states between two checkpoint ids by only keeping the most * Compact the states between two checkpoint timestamp by only keeping the most
* recent. * recent.
*/ */
void compactKvStates(String kvStateId, Connection con, long lowerId, long upperId) throws SQLException; void compactKvStates(String kvStateId, Connection con, long lowerTs, long upperTs) throws SQLException;


} }
Expand Up @@ -93,7 +93,8 @@ public class LazyDbKvState<K, V> implements KvState<K, V, DbStateBackend>, Check
// LRU cache for the key-value states backed by the database // LRU cache for the key-value states backed by the database
private final StateCache cache; private final StateCache cache;


private long nextCheckpointId; private long nextTs;
private Map<Long, Long> completedCheckpoints = new HashMap<>();


// ------------------------------------------------------ // ------------------------------------------------------


Expand All @@ -110,7 +111,7 @@ public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons,
* Initialize the {@link LazyDbKvState} from a snapshot. * Initialize the {@link LazyDbKvState} from a snapshot.
*/ */
public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, final DbBackendConfig conf, public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, final DbBackendConfig conf,
TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer, V defaultValue, long nextId) TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer, V defaultValue, long nextTs)
throws IOException { throws IOException {


this.kvStateId = kvStateId; this.kvStateId = kvStateId;
Expand All @@ -128,7 +129,7 @@ public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons,
this.numSqlRetries = conf.getMaxNumberOfSqlRetries(); this.numSqlRetries = conf.getMaxNumberOfSqlRetries();
this.sqlRetrySleep = conf.getSleepBetweenSqlRetries(); this.sqlRetrySleep = conf.getSleepBetweenSqlRetries();


this.nextCheckpointId = nextId; this.nextTs = nextTs;


this.cache = new StateCache(conf.getKvCacheSize(), conf.getNumElementsToEvict()); this.cache = new StateCache(conf.getKvCacheSize(), conf.getNumElementsToEvict());


Expand Down Expand Up @@ -175,18 +176,23 @@ public V value() throws IOException {
@Override @Override
public DbKvStateSnapshot<K, V> snapshot(long checkpointId, long timestamp) throws IOException { public DbKvStateSnapshot<K, V> snapshot(long checkpointId, long timestamp) throws IOException {


// We insert the modified elements to the database with the current id // Validate timing assumptions
// then clear the modified states if (timestamp <= nextTs) {
throw new RuntimeException("Checkpoint timestamp is smaller than previous ts + 1, "
+ "this should not happen.");
}

// We insert the modified elements to the database with the current
// timestamp then clear the modified states
for (Entry<K, Optional<V>> state : cache.modified.entrySet()) { for (Entry<K, Optional<V>> state : cache.modified.entrySet()) {
batchInsert.add(state, checkpointId); batchInsert.add(state, timestamp);
} }
batchInsert.flush(checkpointId); batchInsert.flush(timestamp);
cache.modified.clear(); cache.modified.clear();


// We increase the next checkpoint id nextTs = timestamp + 1;
nextCheckpointId = checkpointId + 1; completedCheckpoints.put(checkpointId, timestamp);

return new DbKvStateSnapshot<K, V>(kvStateId, timestamp);
return new DbKvStateSnapshot<K, V>(kvStateId, checkpointId);
} }


/** /**
Expand Down Expand Up @@ -230,11 +236,16 @@ public Void call() throws Exception {


@Override @Override
public void notifyCheckpointComplete(long checkpointId) { public void notifyCheckpointComplete(long checkpointId) {
Long ts = completedCheckpoints.remove(checkpointId);
if (ts == null) {
LOG.warn("Complete notification for missing checkpoint: " + checkpointId);
ts = 0L;
}
// If compaction is turned on we compact on the first subtask // If compaction is turned on we compact on the first subtask
if (compactEvery > 0 && compact && checkpointId % compactEvery == 0) { if (compactEvery > 0 && compact && checkpointId % compactEvery == 0) {
try { try {
for (Connection c : connections.connections()) { for (Connection c : connections.connections()) {
dbAdapter.compactKvStates(kvStateId, c, 0, checkpointId); dbAdapter.compactKvStates(kvStateId, c, 0, ts);
} }
if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("State succesfully compacted for {}.", kvStateId); LOG.debug("State succesfully compacted for {}.", kvStateId);
Expand Down Expand Up @@ -294,17 +305,25 @@ private static class DbKvStateSnapshot<K, V> implements KvStateSnapshot<K, V, Db
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;


private final String kvStateId; private final String kvStateId;
private final long checkpointId; private final long checkpointTimestamp;


public DbKvStateSnapshot(String kvStateId, long checkpointId) { public DbKvStateSnapshot(String kvStateId, long checkpointTimestamp) {
this.checkpointId = checkpointId; this.checkpointTimestamp = checkpointTimestamp;
this.kvStateId = kvStateId; this.kvStateId = kvStateId;
} }


@Override @Override
public LazyDbKvState<K, V> restoreState(final DbStateBackend stateBackend, public LazyDbKvState<K, V> restoreState(final DbStateBackend stateBackend,
final TypeSerializer<K> keySerializer, final TypeSerializer<V> valueSerializer, final V defaultValue, final TypeSerializer<K> keySerializer, final TypeSerializer<V> valueSerializer, final V defaultValue,
ClassLoader classLoader, final long nextId) throws IOException { ClassLoader classLoader, final long recoveryTimestamp) throws IOException {

// Validate timing assumptions
if (recoveryTimestamp <= checkpointTimestamp) {
throw new RuntimeException(
"Recovery timestamp is smaller or equal to checkpoint timestamp. "
+ "This might happen if the job was started with a new JobManager "
+ "and the clocks got really out of sync.");
}


// First we clean up the states written by partially failed // First we clean up the states written by partially failed
// snapshots // snapshots
Expand All @@ -314,7 +333,7 @@ public Void call() throws Exception {
// We need to perform cleanup on all shards to be safe here // We need to perform cleanup on all shards to be safe here
for (Connection c : stateBackend.getConnections().connections()) { for (Connection c : stateBackend.getConnections().connections()) {
stateBackend.getConfiguration().getDbAdapter().cleanupFailedCheckpoints(kvStateId, stateBackend.getConfiguration().getDbAdapter().cleanupFailedCheckpoints(kvStateId,
c, checkpointId, nextId); c, checkpointTimestamp, recoveryTimestamp);
} }


return null; return null;
Expand All @@ -327,10 +346,10 @@ public Void call() throws Exception {
// Restore the KvState // Restore the KvState
LazyDbKvState<K, V> restored = new LazyDbKvState<K, V>(kvStateId, cleanup, LazyDbKvState<K, V> restored = new LazyDbKvState<K, V>(kvStateId, cleanup,
stateBackend.getConnections(), stateBackend.getConfiguration(), keySerializer, valueSerializer, stateBackend.getConnections(), stateBackend.getConfiguration(), keySerializer, valueSerializer,
defaultValue, nextId); defaultValue, recoveryTimestamp);


if (LOG.isDebugEnabled()) { if (LOG.isDebugEnabled()) {
LOG.debug("KV state({},{}) restored.", kvStateId, nextId); LOG.debug("KV state({},{}) restored.", kvStateId, recoveryTimestamp);
} }


return restored; return restored;
Expand Down Expand Up @@ -410,7 +429,7 @@ public V call() throws Exception {
// We lookup using the adapter and serialize/deserialize // We lookup using the adapter and serialize/deserialize
// with the TypeSerializers // with the TypeSerializers
byte[] serializedVal = dbAdapter.lookupKey(kvStateId, byte[] serializedVal = dbAdapter.lookupKey(kvStateId,
selectStatements.getForKey(key), serializedKey, nextCheckpointId); selectStatements.getForKey(key), serializedKey, nextTs);


return serializedVal != null return serializedVal != null
? InstantiationUtil.deserializeFromByteArray(valueSerializer, serializedVal) : null; ? InstantiationUtil.deserializeFromByteArray(valueSerializer, serializedVal) : null;
Expand Down Expand Up @@ -443,13 +462,13 @@ private void evictIfFull() {


// We only need to write to the database if modified // We only need to write to the database if modified
if (modified.remove(next.getKey()) != null) { if (modified.remove(next.getKey()) != null) {
batchInsert.add(next, nextCheckpointId); batchInsert.add(next, nextTs);
} }


entryIterator.remove(); entryIterator.remove();
} }


batchInsert.flush(nextCheckpointId); batchInsert.flush(nextTs);


} catch (IOException e) { } catch (IOException e) {
// We need to re-throw this exception to conform to the map // We need to re-throw this exception to conform to the map
Expand Down Expand Up @@ -492,7 +511,7 @@ public BatchInserter(int numShards) {
} }
} }


public void add(Entry<K, Optional<V>> next, long checkpointId) throws IOException { public void add(Entry<K, Optional<V>> next, long timestamp) throws IOException {


K key = next.getKey(); K key = next.getKey();
V value = next.getValue().orNull(); V value = next.getValue().orNull();
Expand All @@ -512,19 +531,19 @@ public void add(Entry<K, Optional<V>> next, long checkpointId) throws IOExceptio
dbAdapter.insertBatch(kvStateId, conf, dbAdapter.insertBatch(kvStateId, conf,
connections.getForIndex(shardIndex), connections.getForIndex(shardIndex),
insertStatements.getForIndex(shardIndex), insertStatements.getForIndex(shardIndex),
checkpointId, insertPartition); timestamp, insertPartition);


insertPartition.clear(); insertPartition.clear();
} }
} }


public void flush(long checkpointId) throws IOException { public void flush(long timestamp) throws IOException {
// We flush all non-empty partitions // We flush all non-empty partitions
for (int i = 0; i < inserts.length; i++) { for (int i = 0; i < inserts.length; i++) {
List<Tuple2<byte[], byte[]>> insertPartition = inserts[i]; List<Tuple2<byte[], byte[]>> insertPartition = inserts[i];
if (!insertPartition.isEmpty()) { if (!insertPartition.isEmpty()) {
dbAdapter.insertBatch(kvStateId, conf, connections.getForIndex(i), dbAdapter.insertBatch(kvStateId, conf, connections.getForIndex(i),
insertStatements.getForIndex(i), checkpointId, insertPartition); insertStatements.getForIndex(i), timestamp, insertPartition);
insertPartition.clear(); insertPartition.clear();
} }
} }
Expand Down
Expand Up @@ -120,18 +120,18 @@ public void createKVStateTable(String stateId, Connection con) throws SQLExcepti
smt.executeUpdate( smt.executeUpdate(
"CREATE TABLE IF NOT EXISTS kvstate_" + stateId "CREATE TABLE IF NOT EXISTS kvstate_" + stateId
+ " (" + " ("
+ "id bigint, " + "timestamp bigint, "
+ "k varbinary(256), " + "k varbinary(256), "
+ "v blob, " + "v blob, "
+ "PRIMARY KEY (k, id) " + "PRIMARY KEY (k, timestamp) "
+ ")"); + ")");
} }
} }


@Override @Override
public String prepareKVCheckpointInsert(String stateId) throws SQLException { public String prepareKVCheckpointInsert(String stateId) throws SQLException {
validateStateId(stateId); validateStateId(stateId);
return "INSERT INTO kvstate_" + stateId + " (id, k, v) VALUES (?,?,?) " return "INSERT INTO kvstate_" + stateId + " (timestamp, k, v) VALUES (?,?,?) "
+ "ON DUPLICATE KEY UPDATE v=? "; + "ON DUPLICATE KEY UPDATE v=? ";
} }


Expand All @@ -141,15 +141,13 @@ public String prepareKeyLookup(String stateId) throws SQLException {
return "SELECT v" return "SELECT v"
+ " FROM kvstate_" + stateId + " FROM kvstate_" + stateId
+ " WHERE k = ?" + " WHERE k = ?"
+ " AND id <= ?" + " ORDER BY timestamp DESC LIMIT 1";
+ " ORDER BY id DESC LIMIT 1";
} }


@Override @Override
public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupId) public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupTs)
throws SQLException { throws SQLException {
lookupStatement.setBytes(1, key); lookupStatement.setBytes(1, key);
lookupStatement.setLong(2, lookupId);


ResultSet res = lookupStatement.executeQuery(); ResultSet res = lookupStatement.executeQuery();


Expand All @@ -161,13 +159,13 @@ public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[
} }


@Override @Override
public void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointId, public void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTs,
long nextId) throws SQLException { long recoveryTs) throws SQLException {
validateStateId(stateId); validateStateId(stateId);
try (Statement smt = con.createStatement()) { try (Statement smt = con.createStatement()) {
smt.executeUpdate("DELETE FROM kvstate_" + stateId smt.executeUpdate("DELETE FROM kvstate_" + stateId
+ " WHERE id > " + checkpointId + " WHERE timestamp > " + checkpointTs
+ " AND id < " + nextId); + " AND timestamp < " + recoveryTs);
} }
} }


Expand All @@ -180,12 +178,12 @@ public void compactKvStates(String stateId, Connection con, long lowerId, long u
smt.executeUpdate("DELETE state.* FROM kvstate_" + stateId + " AS state" smt.executeUpdate("DELETE state.* FROM kvstate_" + stateId + " AS state"
+ " JOIN" + " JOIN"
+ " (" + " ("
+ " SELECT MAX(id) AS maxts, k FROM kvstate_" + stateId + " SELECT MAX(timestamp) AS maxts, k FROM kvstate_" + stateId
+ " WHERE id BETWEEN " + lowerId + " AND " + upperId + " WHERE timestamp BETWEEN " + lowerId + " AND " + upperId
+ " GROUP BY k" + " GROUP BY k"
+ " ) m" + " ) m"
+ " ON state.k = m.k" + " ON state.k = m.k"
+ " AND state.id >= " + lowerId); + " AND state.timestamp >= " + lowerId);
} }
} }


Expand All @@ -201,13 +199,13 @@ protected static void validateStateId(String name) {


@Override @Override
public void insertBatch(final String stateId, final DbBackendConfig conf, public void insertBatch(final String stateId, final DbBackendConfig conf,
final Connection con, final PreparedStatement insertStatement, final long checkpointId, final Connection con, final PreparedStatement insertStatement, final long checkpointTs,
final List<Tuple2<byte[], byte[]>> toInsert) throws IOException { final List<Tuple2<byte[], byte[]>> toInsert) throws IOException {


SQLRetrier.retry(new Callable<Void>() { SQLRetrier.retry(new Callable<Void>() {
public Void call() throws Exception { public Void call() throws Exception {
for (Tuple2<byte[], byte[]> kv : toInsert) { for (Tuple2<byte[], byte[]> kv : toInsert) {
setKvInsertParams(stateId, insertStatement, checkpointId, kv.f0, kv.f1); setKvInsertParams(stateId, insertStatement, checkpointTs, kv.f0, kv.f1);
insertStatement.addBatch(); insertStatement.addBatch();
} }
insertStatement.executeBatch(); insertStatement.executeBatch();
Expand All @@ -222,9 +220,9 @@ public Void call() throws Exception {
}, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries()); }, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries());
} }


private void setKvInsertParams(String stateId, PreparedStatement insertStatement, long checkpointId, private void setKvInsertParams(String stateId, PreparedStatement insertStatement, long checkpointTs,
byte[] key, byte[] value) throws SQLException { byte[] key, byte[] value) throws SQLException {
insertStatement.setLong(1, checkpointId); insertStatement.setLong(1, checkpointTs);
insertStatement.setBytes(2, key); insertStatement.setBytes(2, key);
if (value != null) { if (value != null) {
insertStatement.setBytes(3, value); insertStatement.setBytes(3, value);
Expand Down
Expand Up @@ -196,19 +196,19 @@ public void testKeyValueState() throws Exception {
kv.setCurrentKey(3); kv.setCurrentKey(3);
kv.update("u3"); kv.update("u3");


assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462378L)); assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 100));


kv.notifyCheckpointComplete(682375462378L); kv.notifyCheckpointComplete(682375462378L);


// draw another snapshot // draw another snapshot
KvStateSnapshot<Integer, String, DbStateBackend> snapshot2 = kv.snapshot(682375462379L, KvStateSnapshot<Integer, String, DbStateBackend> snapshot2 = kv.snapshot(682375462379L,
200); 200);
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462378L)); assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 100));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462379L)); assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 200));
kv.notifyCheckpointComplete(682375462379L); kv.notifyCheckpointComplete(682375462379L);
// Compaction should be performed // Compaction should be performed
assertFalse(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462378L)); assertFalse(containsKey(backend.getConnections().getFirst(), tableName, 1, 100));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462379L)); assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 200));


// validate the original state // validate the original state
assertEquals(3, kv.size()); assertEquals(3, kv.size());
Expand Down Expand Up @@ -426,7 +426,7 @@ private static boolean isTableEmpty(Connection con, String tableName) throws SQL
private static boolean containsKey(Connection con, String tableName, int key, long ts) private static boolean containsKey(Connection con, String tableName, int key, long ts)
throws SQLException, IOException { throws SQLException, IOException {
try (PreparedStatement smt = con try (PreparedStatement smt = con
.prepareStatement("select * from " + tableName + " where k=? and id=?")) { .prepareStatement("select * from " + tableName + " where k=? and timestamp=?")) {
smt.setBytes(1, InstantiationUtil.serializeToByteArray(IntSerializer.INSTANCE, key)); smt.setBytes(1, InstantiationUtil.serializeToByteArray(IntSerializer.INSTANCE, key));
smt.setLong(2, ts); smt.setLong(2, ts);
return smt.executeQuery().next(); return smt.executeQuery().next();
Expand Down

0 comments on commit c254bda

Please sign in to comment.