[FLINK-2924] [streaming] Use timestamps to store checkpoints so it supports job shutdown/restart
gyfora committed Nov 24, 2015
1 parent 347e6f7 commit c254bda
Showing 24 changed files with 143 additions and 152 deletions.
@@ -137,7 +137,7 @@ void deleteCheckpoint(String jobId, Connection con, long checkpointId, long chec

/**
* Retrieve the latest value from the database for a given key and
* checkpointId.
* timestamp.
*
* @param stateId
* Unique identifier of the kvstate (usually the table name).
@@ -153,25 +153,24 @@ byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key,
throws SQLException;

/**
* Clean up states between the current and next checkpoint id. Everything
* larger than the current id and smaller than the next should be removed.
* Clean up states between the checkpoint and recovery timestamp.
*
*/
void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointId,
long nextId) throws SQLException;
void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTimestamp,
long recoveryTimestamp) throws SQLException;

/**
* Insert a list of Key-Value pairs into the database. The suggested
* approach is to use idempotent inserts (updates) as one batch operation.
*
*/
void insertBatch(String stateId, DbBackendConfig conf, Connection con, PreparedStatement insertStatement,
long checkpointId, List<Tuple2<byte[], byte[]>> toInsert) throws IOException;
long checkpointTimestamp, List<Tuple2<byte[], byte[]>> toInsert) throws IOException;

/**
* Compact the states between two checkpoint ids by only keeping the most
* Compact the states between two checkpoint timestamps by only keeping the most
* recent.
*/
void compactKvStates(String kvStateId, Connection con, long lowerId, long upperId) throws SQLException;
void compactKvStates(String kvStateId, Connection con, long lowerTs, long upperTs) throws SQLException;

}
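For orientation, here is a minimal in-memory sketch (illustrative only, not part of this commit) of the contract the timestamp-keyed adapter methods above are expected to satisfy: lookups return the newest value for a key, and cleanup removes rows written strictly between the last completed checkpoint timestamp and the recovery timestamp. The class name and the TreeMap-based storage are assumptions made for the example.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

// Illustrative only, not part of this commit. Rows of (key, timestamp, value) are
// modelled as key -> (timestamp -> value); the names are assumptions for the example.
class TimestampContractSketch {

    private final Map<String, TreeMap<Long, String>> rows = new HashMap<>();

    void insert(String key, long timestamp, String value) {
        rows.computeIfAbsent(key, k -> new TreeMap<>()).put(timestamp, value);
    }

    // lookupKey: return the newest value written for the key at or before lookupTs
    // (the MySQL adapter further down simply takes the newest row for the key).
    String lookup(String key, long lookupTs) {
        TreeMap<Long, String> versions = rows.get(key);
        if (versions == null) {
            return null;
        }
        Map.Entry<Long, String> newest = versions.floorEntry(lookupTs);
        return newest == null ? null : newest.getValue();
    }

    // cleanupFailedCheckpoints: drop rows written after the last completed checkpoint
    // timestamp but before the recovery timestamp (exclusive on both ends).
    void cleanupFailedCheckpoints(long checkpointTimestamp, long recoveryTimestamp) {
        for (TreeMap<Long, String> versions : rows.values()) {
            versions.subMap(checkpointTimestamp, false, recoveryTimestamp, false).clear();
        }
    }
}
```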
@@ -93,7 +93,8 @@ public class LazyDbKvState<K, V> implements KvState<K, V, DbStateBackend>, Check
// LRU cache for the key-value states backed by the database
private final StateCache cache;

private long nextCheckpointId;
private long nextTs;
private Map<Long, Long> completedCheckpoints = new HashMap<>();

// ------------------------------------------------------

@@ -110,7 +111,7 @@ public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons,
* Initialize the {@link LazyDbKvState} from a snapshot.
*/
public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons, final DbBackendConfig conf,
TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer, V defaultValue, long nextId)
TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer, V defaultValue, long nextTs)
throws IOException {

this.kvStateId = kvStateId;
@@ -128,7 +129,7 @@ public LazyDbKvState(String kvStateId, boolean compact, ShardedConnection cons,
this.numSqlRetries = conf.getMaxNumberOfSqlRetries();
this.sqlRetrySleep = conf.getSleepBetweenSqlRetries();

this.nextCheckpointId = nextId;
this.nextTs = nextTs;

this.cache = new StateCache(conf.getKvCacheSize(), conf.getNumElementsToEvict());

@@ -175,18 +176,23 @@ public V value() throws IOException {
@Override
public DbKvStateSnapshot<K, V> snapshot(long checkpointId, long timestamp) throws IOException {

// We insert the modified elements to the database with the current id
// then clear the modified states
// Validate timing assumptions
if (timestamp <= nextTs) {
throw new RuntimeException("Checkpoint timestamp is smaller than previous ts + 1, "
+ "this should not happen.");
}

// We insert the modified elements to the database with the current
// timestamp then clear the modified states
for (Entry<K, Optional<V>> state : cache.modified.entrySet()) {
batchInsert.add(state, checkpointId);
batchInsert.add(state, timestamp);
}
batchInsert.flush(checkpointId);
batchInsert.flush(timestamp);
cache.modified.clear();

// We increase the next checkpoint id
nextCheckpointId = checkpointId + 1;

return new DbKvStateSnapshot<K, V>(kvStateId, checkpointId);
nextTs = timestamp + 1;
completedCheckpoints.put(checkpointId, timestamp);
return new DbKvStateSnapshot<K, V>(kvStateId, timestamp);
}
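Because notifyCheckpointComplete only receives the checkpointId while the rows are now stored under timestamps, the snapshot call above records a checkpointId to timestamp mapping. A small standalone sketch of that bookkeeping, with assumed values (not part of this commit):

```java
import java.util.HashMap;
import java.util.Map;

// Assumed values for illustration; not part of this commit.
public class CheckpointBookkeepingSketch {
    public static void main(String[] args) {
        Map<Long, Long> completedCheckpoints = new HashMap<>();
        long nextTs;

        // snapshot(checkpointId = 5, timestamp = 1000): modified entries are written with ts = 1000
        completedCheckpoints.put(5L, 1000L);
        nextTs = 1000L + 1;

        // snapshot(checkpointId = 6, timestamp = 2000): written with ts = 2000
        completedCheckpoints.put(6L, 2000L);
        nextTs = 2000L + 1;

        // notifyCheckpointComplete(6): recover the timestamp the rows were written under,
        // so compaction can run on timestamps even though Flink only hands back the id.
        Long ts = completedCheckpoints.remove(6L);
        System.out.println("compact kv states up to ts = " + ts + ", nextTs = " + nextTs);
    }
}
```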

/**
@@ -230,11 +236,16 @@ public Void call() throws Exception {

@Override
public void notifyCheckpointComplete(long checkpointId) {
Long ts = completedCheckpoints.remove(checkpointId);
if (ts == null) {
LOG.warn("Complete notification for missing checkpoint: " + checkpointId);
ts = 0L;
}
// If compaction is turned on we compact on the first subtask
if (compactEvery > 0 && compact && checkpointId % compactEvery == 0) {
try {
for (Connection c : connections.connections()) {
dbAdapter.compactKvStates(kvStateId, c, 0, checkpointId);
dbAdapter.compactKvStates(kvStateId, c, 0, ts);
}
if (LOG.isDebugEnabled()) {
LOG.debug("State succesfully compacted for {}.", kvStateId);
@@ -294,17 +305,25 @@ private static class DbKvStateSnapshot<K, V> implements KvStateSnapshot<K, V, Db
private static final long serialVersionUID = 1L;

private final String kvStateId;
private final long checkpointId;
private final long checkpointTimestamp;

public DbKvStateSnapshot(String kvStateId, long checkpointId) {
this.checkpointId = checkpointId;
public DbKvStateSnapshot(String kvStateId, long checkpointTimestamp) {
this.checkpointTimestamp = checkpointTimestamp;
this.kvStateId = kvStateId;
}

@Override
public LazyDbKvState<K, V> restoreState(final DbStateBackend stateBackend,
final TypeSerializer<K> keySerializer, final TypeSerializer<V> valueSerializer, final V defaultValue,
ClassLoader classLoader, final long nextId) throws IOException {
ClassLoader classLoader, final long recoveryTimestamp) throws IOException {

// Validate timing assumptions
if (recoveryTimestamp <= checkpointTimestamp) {
throw new RuntimeException(
"Recovery timestamp is smaller or equal to checkpoint timestamp. "
+ "This might happen if the job was started with a new JobManager "
+ "and the clocks got really out of sync.");
}

// First we clean up the states written by partially failed
// snapshots
@@ -314,7 +333,7 @@ public Void call() throws Exception {
// We need to perform cleanup on all shards to be safe here
for (Connection c : stateBackend.getConnections().connections()) {
stateBackend.getConfiguration().getDbAdapter().cleanupFailedCheckpoints(kvStateId,
c, checkpointId, nextId);
c, checkpointTimestamp, recoveryTimestamp);
}

return null;
@@ -327,10 +346,10 @@ public Void call() throws Exception {
// Restore the KvState
LazyDbKvState<K, V> restored = new LazyDbKvState<K, V>(kvStateId, cleanup,
stateBackend.getConnections(), stateBackend.getConfiguration(), keySerializer, valueSerializer,
defaultValue, nextId);
defaultValue, recoveryTimestamp);

if (LOG.isDebugEnabled()) {
LOG.debug("KV state({},{}) restored.", kvStateId, nextId);
LOG.debug("KV state({},{}) restored.", kvStateId, recoveryTimestamp);
}

return restored;
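A concrete walk-through of this restore path with assumed timestamps (illustrative only, not part of this commit):

```java
// Assumed values for illustration; not part of this commit.
public class RestoreTimelineSketch {
    public static void main(String[] args) {
        long checkpointTimestamp = 1000L; // timestamp stored in the snapshot being restored
        long recoveryTimestamp = 3000L;   // assigned at restore time, must be strictly larger

        if (recoveryTimestamp <= checkpointTimestamp) {
            throw new RuntimeException("Recovery timestamp must be larger than the checkpoint timestamp");
        }

        // cleanupFailedCheckpoints(kvStateId, con, 1000, 3000) removes every row with
        // 1000 < timestamp < 3000, i.e. anything written by snapshots that never completed.
        System.out.println("cleanup range: (" + checkpointTimestamp + ", " + recoveryTimestamp + ")");

        // The restored state then starts from nextTs = recoveryTimestamp, and snapshot()
        // rejects timestamps <= nextTs, so rows written after recovery always land above
        // the cleaned range.
        long nextTs = recoveryTimestamp;
        System.out.println("restored state continues from nextTs = " + nextTs);
    }
}
```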
@@ -410,7 +429,7 @@ public V call() throws Exception {
// We lookup using the adapter and serialize/deserialize
// with the TypeSerializers
byte[] serializedVal = dbAdapter.lookupKey(kvStateId,
selectStatements.getForKey(key), serializedKey, nextCheckpointId);
selectStatements.getForKey(key), serializedKey, nextTs);

return serializedVal != null
? InstantiationUtil.deserializeFromByteArray(valueSerializer, serializedVal) : null;
@@ -443,13 +462,13 @@ private void evictIfFull() {

// We only need to write to the database if modified
if (modified.remove(next.getKey()) != null) {
batchInsert.add(next, nextCheckpointId);
batchInsert.add(next, nextTs);
}

entryIterator.remove();
}

batchInsert.flush(nextCheckpointId);
batchInsert.flush(nextTs);

} catch (IOException e) {
// We need to re-throw this exception to conform to the map
@@ -492,7 +511,7 @@ public BatchInserter(int numShards) {
}
}

public void add(Entry<K, Optional<V>> next, long checkpointId) throws IOException {
public void add(Entry<K, Optional<V>> next, long timestamp) throws IOException {

K key = next.getKey();
V value = next.getValue().orNull();
@@ -512,19 +531,19 @@ public void add(Entry<K, Optional<V>> next, long checkpointId) throws IOExceptio
dbAdapter.insertBatch(kvStateId, conf,
connections.getForIndex(shardIndex),
insertStatements.getForIndex(shardIndex),
checkpointId, insertPartition);
timestamp, insertPartition);

insertPartition.clear();
}
}

public void flush(long checkpointId) throws IOException {
public void flush(long timestamp) throws IOException {
// We flush all non-empty partitions
for (int i = 0; i < inserts.length; i++) {
List<Tuple2<byte[], byte[]>> insertPartition = inserts[i];
if (!insertPartition.isEmpty()) {
dbAdapter.insertBatch(kvStateId, conf, connections.getForIndex(i),
insertStatements.getForIndex(i), checkpointId, insertPartition);
insertStatements.getForIndex(i), timestamp, insertPartition);
insertPartition.clear();
}
}
@@ -120,18 +120,18 @@ public void createKVStateTable(String stateId, Connection con) throws SQLExcepti
smt.executeUpdate(
"CREATE TABLE IF NOT EXISTS kvstate_" + stateId
+ " ("
+ "id bigint, "
+ "timestamp bigint, "
+ "k varbinary(256), "
+ "v blob, "
+ "PRIMARY KEY (k, id) "
+ "PRIMARY KEY (k, timestamp) "
+ ")");
}
}

@Override
public String prepareKVCheckpointInsert(String stateId) throws SQLException {
validateStateId(stateId);
return "INSERT INTO kvstate_" + stateId + " (id, k, v) VALUES (?,?,?) "
return "INSERT INTO kvstate_" + stateId + " (timestamp, k, v) VALUES (?,?,?) "
+ "ON DUPLICATE KEY UPDATE v=? ";
}

@@ -141,15 +141,13 @@ public String prepareKeyLookup(String stateId) throws SQLException {
return "SELECT v"
+ " FROM kvstate_" + stateId
+ " WHERE k = ?"
+ " AND id <= ?"
+ " ORDER BY id DESC LIMIT 1";
+ " ORDER BY timestamp DESC LIMIT 1";
}
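A hedged JDBC sketch of how the two statements prepared above behave together (not part of this commit; the connection URL, credentials, and the "wordcount" state id are assumptions, and the kvstate_wordcount table is assumed to have been created by createKVStateTable): after writes at two checkpoint timestamps, the lookup simply returns the value from the newest row for the key.

```java
import java.sql.*;

// Illustrative only; assumes a local MySQL database and an existing kvstate_wordcount table.
public class MySqlLookupSketch {
    public static void main(String[] args) throws Exception {
        try (Connection con = DriverManager.getConnection(
                "jdbc:mysql://localhost:3306/flink", "root", "")) {

            // Two versions of the same key, written by checkpoints at ts = 100 and ts = 200.
            try (PreparedStatement insert = con.prepareStatement(
                    "INSERT INTO kvstate_wordcount (timestamp, k, v) VALUES (?,?,?) "
                            + "ON DUPLICATE KEY UPDATE v=?")) {
                insertRow(insert, 100L, new byte[]{1}, new byte[]{10});
                insertRow(insert, 200L, new byte[]{1}, new byte[]{20});
            }

            // The lookup no longer bounds the timestamp; it takes the newest row for the key.
            try (PreparedStatement lookup = con.prepareStatement(
                    "SELECT v FROM kvstate_wordcount WHERE k = ? "
                            + "ORDER BY timestamp DESC LIMIT 1")) {
                lookup.setBytes(1, new byte[]{1});
                try (ResultSet res = lookup.executeQuery()) {
                    res.next();
                    System.out.println(res.getBytes(1)[0]); // prints 20, the ts = 200 value
                }
            }
        }
    }

    private static void insertRow(PreparedStatement smt, long ts, byte[] k, byte[] v)
            throws SQLException {
        smt.setLong(1, ts);
        smt.setBytes(2, k);
        smt.setBytes(3, v);
        smt.setBytes(4, v);
        smt.executeUpdate();
    }
}
```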

@Override
public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupId)
public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[] key, long lookupTs)
throws SQLException {
lookupStatement.setBytes(1, key);
lookupStatement.setLong(2, lookupId);

ResultSet res = lookupStatement.executeQuery();

@@ -161,13 +159,13 @@ public byte[] lookupKey(String stateId, PreparedStatement lookupStatement, byte[
}

@Override
public void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointId,
long nextId) throws SQLException {
public void cleanupFailedCheckpoints(String stateId, Connection con, long checkpointTs,
long recoveryTs) throws SQLException {
validateStateId(stateId);
try (Statement smt = con.createStatement()) {
smt.executeUpdate("DELETE FROM kvstate_" + stateId
+ " WHERE id > " + checkpointId
+ " AND id < " + nextId);
+ " WHERE timestamp > " + checkpointTs
+ " AND timestamp < " + recoveryTs);
}
}

@@ -180,12 +178,12 @@ public void compactKvStates(String stateId, Connection con, long lowerId, long u
smt.executeUpdate("DELETE state.* FROM kvstate_" + stateId + " AS state"
+ " JOIN"
+ " ("
+ " SELECT MAX(id) AS maxts, k FROM kvstate_" + stateId
+ " WHERE id BETWEEN " + lowerId + " AND " + upperId
+ " SELECT MAX(timestamp) AS maxts, k FROM kvstate_" + stateId
+ " WHERE timestamp BETWEEN " + lowerId + " AND " + upperId
+ " GROUP BY k"
+ " ) m"
+ " ON state.k = m.k"
+ " AND state.id >= " + lowerId);
+ " AND state.timestamp >= " + lowerId);
}
}
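Because the DELETE above is a self-join, it can help to see the end state it aims for per the javadoc (and which the test further down asserts after checkpoints at timestamps 100 and 200): within the compacted range, only the newest row per key survives. A self-contained sketch with assumed rows (not part of this commit, and only modelling the intended outcome, not the SQL mechanics):

```java
import java.util.*;

// Assumed rows for illustration; not part of this commit.
public class CompactionEffectSketch {
    public static void main(String[] args) {
        // key -> timestamps of the rows currently stored for it
        Map<Integer, List<Long>> rows = new HashMap<>();
        rows.put(1, new ArrayList<>(Arrays.asList(100L, 200L)));       // written by both checkpoints
        rows.put(2, new ArrayList<>(Collections.singletonList(100L))); // only written at ts = 100

        long lowerTs = 0L, upperTs = 200L;
        for (List<Long> versions : rows.values()) {
            long newest = Collections.max(versions);
            // everything older than the newest row in [lowerTs, upperTs] is compacted away
            versions.removeIf(ts -> ts >= lowerTs && ts <= upperTs && ts < newest);
        }

        System.out.println(rows); // {1=[200], 2=[100]}
    }
}
```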

@@ -201,13 +199,13 @@ protected static void validateStateId(String name) {

@Override
public void insertBatch(final String stateId, final DbBackendConfig conf,
final Connection con, final PreparedStatement insertStatement, final long checkpointId,
final Connection con, final PreparedStatement insertStatement, final long checkpointTs,
final List<Tuple2<byte[], byte[]>> toInsert) throws IOException {

SQLRetrier.retry(new Callable<Void>() {
public Void call() throws Exception {
for (Tuple2<byte[], byte[]> kv : toInsert) {
setKvInsertParams(stateId, insertStatement, checkpointId, kv.f0, kv.f1);
setKvInsertParams(stateId, insertStatement, checkpointTs, kv.f0, kv.f1);
insertStatement.addBatch();
}
insertStatement.executeBatch();
@@ -222,9 +220,9 @@ public Void call() throws Exception {
}, conf.getMaxNumberOfSqlRetries(), conf.getSleepBetweenSqlRetries());
}

private void setKvInsertParams(String stateId, PreparedStatement insertStatement, long checkpointId,
private void setKvInsertParams(String stateId, PreparedStatement insertStatement, long checkpointTs,
byte[] key, byte[] value) throws SQLException {
insertStatement.setLong(1, checkpointId);
insertStatement.setLong(1, checkpointTs);
insertStatement.setBytes(2, key);
if (value != null) {
insertStatement.setBytes(3, value);
@@ -196,19 +196,19 @@ public void testKeyValueState() throws Exception {
kv.setCurrentKey(3);
kv.update("u3");

assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462378L));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 100));

kv.notifyCheckpointComplete(682375462378L);

// draw another snapshot
KvStateSnapshot<Integer, String, DbStateBackend> snapshot2 = kv.snapshot(682375462379L,
200);
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462378L));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462379L));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 100));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 200));
kv.notifyCheckpointComplete(682375462379L);
// Compaction should be performed
assertFalse(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462378L));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 682375462379L));
assertFalse(containsKey(backend.getConnections().getFirst(), tableName, 1, 100));
assertTrue(containsKey(backend.getConnections().getFirst(), tableName, 1, 200));

// validate the original state
assertEquals(3, kv.size());
@@ -426,7 +429,7 @@ private static boolean isTableEmpty(Connection con, String tableName) throws SQL
private static boolean containsKey(Connection con, String tableName, int key, long ts)
throws SQLException, IOException {
try (PreparedStatement smt = con
.prepareStatement("select * from " + tableName + " where k=? and id=?")) {
.prepareStatement("select * from " + tableName + " where k=? and timestamp=?")) {
smt.setBytes(1, InstantiationUtil.serializeToByteArray(IntSerializer.INSTANCE, key));
smt.setLong(2, ts);
return smt.executeQuery().next();
