Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KAFKA-12396 added a nullcheck before trying to retrieve a key #10548

Merged
merged 10 commits into from Apr 30, 2021
Expand Up @@ -38,6 +38,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerde;
import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
Expand Down Expand Up @@ -185,6 +187,7 @@ public boolean setFlushListener(final CacheFlushListener<K, V> listener,

@Override
public V get(final K key) {
Objects.requireNonNull(key, "key cannot be null");
try {
return maybeMeasureLatency(() -> outerValue(wrapped().get(keyBytes(key))), time, getSensor);
} catch (final ProcessorStateException e) {
Expand All @@ -196,6 +199,7 @@ public V get(final K key) {
@Override
public void put(final K key,
final V value) {
Objects.requireNonNull(key, "key cannot be null");
try {
maybeMeasureLatency(() -> wrapped().put(keyBytes(key), serdes.rawValue(value)), time, putSensor);
maybeRecordE2ELatency();
Expand All @@ -208,6 +212,7 @@ public void put(final K key,
@Override
public V putIfAbsent(final K key,
final V value) {
Objects.requireNonNull(key, "key cannot be null");
final V currentValue = maybeMeasureLatency(
() -> outerValue(wrapped().putIfAbsent(keyBytes(key), serdes.rawValue(value))),
time,
Expand All @@ -219,11 +224,19 @@ public V putIfAbsent(final K key,

@Override
public void putAll(final List<KeyValue<K, V>> entries) {
final List<KeyValue<K, V>> possiblyNullKeys = entries
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could simplify this to a one liner?

entries.forEach(entry -> Objects.requireNonNull(entry.key, "key cannot be null"));

.stream()
.filter(entry -> entry.key == null)
.collect(Collectors.toList());
if (!possiblyNullKeys.isEmpty()) {
Objects.requireNonNull(possiblyNullKeys.get(0).key, "key cannot be null");
}
maybeMeasureLatency(() -> wrapped().putAll(innerEntries(entries)), time, putAllSensor);
}

@Override
public V delete(final K key) {
Objects.requireNonNull(key, "key cannot be null");
try {
return maybeMeasureLatency(() -> outerValue(wrapped().delete(keyBytes(key))), time, deleteSensor);
} catch (final ProcessorStateException e) {
Expand All @@ -234,13 +247,15 @@ public V delete(final K key) {

@Override
public <PS extends Serializer<P>, P> KeyValueIterator<K, V> prefixScan(final P prefix, final PS prefixKeySerializer) {

Objects.requireNonNull(prefix, "key cannot be null");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned by @cadonna the wrapped stores, also check prefixKeySerializer for null -- thus might be good to move both check here.

I think we can also remove both checks in RocksDBStore and InMemoryKeyValueStore -- they seems to be redundant now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can also remove both checks in RocksDBStore and InMemoryKeyValueStore -- they seem to be redundant now?

they are both different implementations, aren't they?
image
I don't understand how they will be checked if we only leave it in MeteredKeyValueStore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
I wrote a quick test to see if it still throws NPE in InMemoryKeyValueStore without the check - it did, I am confused, but they are, indeed, redundant.
I'll leave the tests in both RocksDBStoreTest and InMemoryKeyValueStore bc why not.

Copy link
Member

@mjsax mjsax Apr 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KafkaStreams runtime always "wraps" any store with a corresponding MeteredXxxStore (cf https://github.com/apache/kafka/blob/trunk/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilder.java#L42-L49) -- those MeteredXxxStores do the transaction from objects to bytes (ie they use the serdes) and also track state store metrics. (Note that stores provided to the runtime always have type <Bytes, byte[]> while they are exposed to Processors as <K,V> types.)

Thus, when you call context.stateStore(...) you always get a MeteredXxxStore object -- of course, those details are hidden behind the interface type.

This architecture allows us to unify code and separate concerns. In fact, it also allows us to add/remove more "layers": we can insert a "caching layer" (cf. https://kafka.apache.org/28/documentation/streams/developer-guide/memory-mgmt.html) and a "change logging layer" (both are inserted by default).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oooooh, I see.

return new MeteredKeyValueIterator(wrapped().prefixScan(prefix, prefixKeySerializer), prefixScanSensor);
}

@Override
public KeyValueIterator<K, V> range(final K from,
final K to) {
Objects.requireNonNull(from, "keyFrom cannot be null");
Objects.requireNonNull(to, "keyTo cannot be null");
return new MeteredKeyValueIterator(
wrapped().range(Bytes.wrap(serdes.rawKey(from)), Bytes.wrap(serdes.rawKey(to))),
rangeSensor
Expand All @@ -250,6 +265,8 @@ public KeyValueIterator<K, V> range(final K from,
@Override
public KeyValueIterator<K, V> reverseRange(final K from,
final K to) {
Objects.requireNonNull(from, "keyFrom cannot be null");
Objects.requireNonNull(to, "keyTo cannot be null");
return new MeteredKeyValueIterator(
wrapped().reverseRange(Bytes.wrap(serdes.rawKey(from)), Bytes.wrap(serdes.rawKey(to))),
rangeSensor
Expand Down
Expand Up @@ -159,6 +159,8 @@ public boolean setFlushListener(final CacheFlushListener<Windowed<K>, V> listene
public void put(final Windowed<K> sessionKey,
final V aggregate) {
Objects.requireNonNull(sessionKey, "sessionKey can't be null");
Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null");

try {
maybeMeasureLatency(
() -> {
Expand All @@ -178,6 +180,8 @@ public void put(final Windowed<K> sessionKey,
@Override
public void remove(final Windowed<K> sessionKey) {
Objects.requireNonNull(sessionKey, "sessionKey can't be null");
Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null");

try {
maybeMeasureLatency(
() -> {
Expand Down
Expand Up @@ -36,6 +36,8 @@
import org.apache.kafka.streams.state.WindowStoreIterator;
import org.apache.kafka.streams.state.internals.metrics.StateStoreMetrics;

import java.util.Objects;

import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerde;
import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;

Expand Down Expand Up @@ -161,6 +163,7 @@ public boolean setFlushListener(final CacheFlushListener<Windowed<K>, V> listene
public void put(final K key,
final V value,
final long windowStartTimestamp) {
Objects.requireNonNull(key, "key cannot be null");
try {
maybeMeasureLatency(
() -> wrapped().put(keyBytes(key), serdes.rawValue(value), windowStartTimestamp),
Expand All @@ -177,6 +180,7 @@ public void put(final K key,
@Override
public V fetch(final K key,
final long timestamp) {
Objects.requireNonNull(key, "key cannot be null");
return maybeMeasureLatency(
() -> {
final byte[] result = wrapped().fetch(keyBytes(key), timestamp);
Expand All @@ -195,6 +199,7 @@ public V fetch(final K key,
public WindowStoreIterator<V> fetch(final K key,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(key, "key cannot be null");
return new MeteredWindowStoreIterator<>(
wrapped().fetch(keyBytes(key), timeFrom, timeTo),
fetchSensor,
Expand All @@ -208,6 +213,7 @@ public WindowStoreIterator<V> fetch(final K key,
public WindowStoreIterator<V> backwardFetch(final K key,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(key, "key cannot be null");
return new MeteredWindowStoreIterator<>(
wrapped().backwardFetch(keyBytes(key), timeFrom, timeTo),
fetchSensor,
Expand All @@ -223,6 +229,8 @@ public KeyValueIterator<Windowed<K>, V> fetch(final K keyFrom,
final K keyTo,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
Objects.requireNonNull(keyTo, "keyTo cannot be null");
return new MeteredWindowedKeyValueIterator<>(
wrapped().fetch(keyBytes(keyFrom), keyBytes(keyTo), timeFrom, timeTo),
fetchSensor,
Expand All @@ -236,6 +244,8 @@ public KeyValueIterator<Windowed<K>, V> backwardFetch(final K keyFrom,
final K keyTo,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
Objects.requireNonNull(keyTo, "keyTo cannot be null");
return new MeteredWindowedKeyValueIterator<>(
wrapped().backwardFetch(keyBytes(keyFrom), keyBytes(keyTo), timeFrom, timeTo),
fetchSensor,
Expand Down
Expand Up @@ -435,6 +435,57 @@ public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
verify(inner);
}

@Test
public void shouldThrowNullPointerOnGetIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.get(null));
}

@Test
public void shouldThrowNullPointerOnPutIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.put(null, VALUE));
}

@Test
public void shouldThrowNullPointerOnPutIfAbsentIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.putIfAbsent(null, VALUE));
}

@Test
public void shouldThrowNullPointerOnDeleteIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.delete(null));
}

@Test
public void shouldThrowNullPointerOnPutAllIfAnyKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.putAll(Collections.singletonList(KeyValue.pair(null, VALUE))));
}

@Test
public void shouldThrowNullPointerOnPrefixScanIfPrefixIsNull() {
final StringSerializer stringSerializer = new StringSerializer();
assertThrows(NullPointerException.class, () -> metered.prefixScan(null, stringSerializer));
}

@Test
public void shouldThrowNullPointerOnRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> metered.range(null, "to"));
}

@Test
public void shouldThrowNullPointerOnRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> metered.range("from", null));
}

@Test
public void shouldThrowNullPointerOnReverseRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> metered.reverseRange(null, "to"));
}

@Test
public void shouldThrowNullPointerOnReverseRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> metered.reverseRange("from", null));
}

@Test
public void shouldGetRecordsWithPrefixKey() {
final StringSerializer stringSerializer = new StringSerializer();
Expand Down
Expand Up @@ -472,11 +472,26 @@ public void shouldThrowNullPointerOnRemoveIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.remove(null));
}

@Test
public void shouldThrowNullPointerOnPutIfWrappedKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.put(new Windowed<>(null, new SessionWindow(0, 0)), "a"));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test remind me, that the SessionWindow what is wrapped should not be null either.

}

@Test
public void shouldThrowNullPointerOnRemoveIfWrappedKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.remove(new Windowed<>(null, new SessionWindow(0, 0))));
}

@Test
public void shouldThrowNullPointerOnFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null));
}

@Test
public void shouldThrowNullPointerOnFetchSessionIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.fetchSession(null, 0, Long.MAX_VALUE));
}

@Test
public void shouldThrowNullPointerOnFetchRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null, "to"));
Expand All @@ -487,6 +502,21 @@ public void shouldThrowNullPointerOnFetchRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch("from", null));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null, "to"));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch("from", null));
}

@Test
public void shouldThrowNullPointerOnFindSessionsIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.findSessions(null, 0, 0));
Expand All @@ -502,6 +532,21 @@ public void shouldThrowNullPointerOnFindSessionsRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.findSessions("a", null, 0, 0));
}

@Test
public void shouldThrowNullPointerOnBackwardFindSessionsIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFindSessions(null, 0, 0));
}

@Test
public void shouldThrowNullPointerOnBackwardFindSessionsRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFindSessions(null, "a", 0, 0));
}

@Test
public void shouldThrowNullPointerOnBackwardFindSessionsRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFindSessions("a", null, 0, 0));
}

private interface CachedSessionStore extends SessionStore<Bytes, byte[]>, CachedStateStore<byte[], byte[]> { }

@SuppressWarnings("unchecked")
Expand Down
Expand Up @@ -468,6 +468,42 @@ public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
verify(innerStoreMock);
}

@Test
public void shouldThrowNullPointerOnPutIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.put(null, "a", 1L));
}

@Test
public void shouldThrowNullPointerOnFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null, 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null, 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnFetchRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null, "to", 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnFetchRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch("from", null, 0L, 1L));
}


@Test
public void shouldThrowNullPointerOnbackwardFetchRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null, "to", 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnbackwardFetchRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch("from", null, 0L, 1L));
}

private List<MetricName> storeMetrics() {
return metrics.metrics()
.keySet()
Expand Down