Skip to content

Commit

Permalink
KAFKA-12396: added null check for state stores key (#10548)
Browse files Browse the repository at this point in the history
Reviewers: Bruno Cadonna <bruno@confluent.io>, Matthias J. Sax <matthias@confluent.io>
  • Loading branch information
Nathan22177 committed Apr 30, 2021
1 parent 9dbf222 commit e454bec
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 7 deletions.
Expand Up @@ -32,7 +32,6 @@
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.Objects;

public class InMemoryKeyValueStore implements KeyValueStore<Bytes, byte[]> {

Expand Down Expand Up @@ -107,8 +106,6 @@ public void putAll(final List<KeyValue<Bytes, byte[]>> entries) {

@Override
public <PS extends Serializer<P>, P> KeyValueIterator<Bytes, byte[]> prefixScan(final P prefix, final PS prefixKeySerializer) {
Objects.requireNonNull(prefix, "prefix cannot be null");
Objects.requireNonNull(prefixKeySerializer, "prefixKeySerializer cannot be null");

final Bytes from = Bytes.wrap(prefixKeySerializer.serialize(null, prefix));
final Bytes to = Bytes.increment(from);
Expand Down
Expand Up @@ -38,6 +38,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerde;
import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
Expand Down Expand Up @@ -185,6 +186,7 @@ public boolean setFlushListener(final CacheFlushListener<K, V> listener,

@Override
public V get(final K key) {
Objects.requireNonNull(key, "key cannot be null");
try {
return maybeMeasureLatency(() -> outerValue(wrapped().get(keyBytes(key))), time, getSensor);
} catch (final ProcessorStateException e) {
Expand All @@ -196,6 +198,7 @@ public V get(final K key) {
@Override
public void put(final K key,
final V value) {
Objects.requireNonNull(key, "key cannot be null");
try {
maybeMeasureLatency(() -> wrapped().put(keyBytes(key), serdes.rawValue(value)), time, putSensor);
maybeRecordE2ELatency();
Expand All @@ -208,6 +211,7 @@ public void put(final K key,
@Override
public V putIfAbsent(final K key,
final V value) {
Objects.requireNonNull(key, "key cannot be null");
final V currentValue = maybeMeasureLatency(
() -> outerValue(wrapped().putIfAbsent(keyBytes(key), serdes.rawValue(value))),
time,
Expand All @@ -219,11 +223,13 @@ public V putIfAbsent(final K key,

@Override
public void putAll(final List<KeyValue<K, V>> entries) {
entries.forEach(entry -> Objects.requireNonNull(entry.key, "key cannot be null"));
maybeMeasureLatency(() -> wrapped().putAll(innerEntries(entries)), time, putAllSensor);
}

@Override
public V delete(final K key) {
Objects.requireNonNull(key, "key cannot be null");
try {
return maybeMeasureLatency(() -> outerValue(wrapped().delete(keyBytes(key))), time, deleteSensor);
} catch (final ProcessorStateException e) {
Expand All @@ -234,13 +240,16 @@ public V delete(final K key) {

@Override
public <PS extends Serializer<P>, P> KeyValueIterator<K, V> prefixScan(final P prefix, final PS prefixKeySerializer) {

Objects.requireNonNull(prefix, "key cannot be null");
Objects.requireNonNull(prefixKeySerializer, "prefixKeySerializer cannot be null");
return new MeteredKeyValueIterator(wrapped().prefixScan(prefix, prefixKeySerializer), prefixScanSensor);
}

@Override
public KeyValueIterator<K, V> range(final K from,
final K to) {
Objects.requireNonNull(from, "keyFrom cannot be null");
Objects.requireNonNull(to, "keyTo cannot be null");
return new MeteredKeyValueIterator(
wrapped().range(Bytes.wrap(serdes.rawKey(from)), Bytes.wrap(serdes.rawKey(to))),
rangeSensor
Expand All @@ -250,6 +259,8 @@ public KeyValueIterator<K, V> range(final K from,
@Override
public KeyValueIterator<K, V> reverseRange(final K from,
final K to) {
Objects.requireNonNull(from, "keyFrom cannot be null");
Objects.requireNonNull(to, "keyTo cannot be null");
return new MeteredKeyValueIterator(
wrapped().reverseRange(Bytes.wrap(serdes.rawKey(from)), Bytes.wrap(serdes.rawKey(to))),
rangeSensor
Expand Down
Expand Up @@ -159,6 +159,9 @@ public boolean setFlushListener(final CacheFlushListener<Windowed<K>, V> listene
public void put(final Windowed<K> sessionKey,
final V aggregate) {
Objects.requireNonNull(sessionKey, "sessionKey can't be null");
Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null");
Objects.requireNonNull(sessionKey.window(), "sessionKey.window() can't be null");

try {
maybeMeasureLatency(
() -> {
Expand All @@ -178,6 +181,9 @@ public void put(final Windowed<K> sessionKey,
@Override
public void remove(final Windowed<K> sessionKey) {
Objects.requireNonNull(sessionKey, "sessionKey can't be null");
Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null");
Objects.requireNonNull(sessionKey.window(), "sessionKey.window() can't be null");

try {
maybeMeasureLatency(
() -> {
Expand Down
Expand Up @@ -36,6 +36,8 @@
import org.apache.kafka.streams.state.WindowStoreIterator;
import org.apache.kafka.streams.state.internals.metrics.StateStoreMetrics;

import java.util.Objects;

import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerde;
import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;

Expand Down Expand Up @@ -161,6 +163,7 @@ public boolean setFlushListener(final CacheFlushListener<Windowed<K>, V> listene
public void put(final K key,
final V value,
final long windowStartTimestamp) {
Objects.requireNonNull(key, "key cannot be null");
try {
maybeMeasureLatency(
() -> wrapped().put(keyBytes(key), serdes.rawValue(value), windowStartTimestamp),
Expand All @@ -177,6 +180,7 @@ public void put(final K key,
@Override
public V fetch(final K key,
final long timestamp) {
Objects.requireNonNull(key, "key cannot be null");
return maybeMeasureLatency(
() -> {
final byte[] result = wrapped().fetch(keyBytes(key), timestamp);
Expand All @@ -195,6 +199,7 @@ public V fetch(final K key,
public WindowStoreIterator<V> fetch(final K key,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(key, "key cannot be null");
return new MeteredWindowStoreIterator<>(
wrapped().fetch(keyBytes(key), timeFrom, timeTo),
fetchSensor,
Expand All @@ -208,6 +213,7 @@ public WindowStoreIterator<V> fetch(final K key,
public WindowStoreIterator<V> backwardFetch(final K key,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(key, "key cannot be null");
return new MeteredWindowStoreIterator<>(
wrapped().backwardFetch(keyBytes(key), timeFrom, timeTo),
fetchSensor,
Expand All @@ -223,6 +229,8 @@ public KeyValueIterator<Windowed<K>, V> fetch(final K keyFrom,
final K keyTo,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
Objects.requireNonNull(keyTo, "keyTo cannot be null");
return new MeteredWindowedKeyValueIterator<>(
wrapped().fetch(keyBytes(keyFrom), keyBytes(keyTo), timeFrom, timeTo),
fetchSensor,
Expand All @@ -236,6 +244,8 @@ public KeyValueIterator<Windowed<K>, V> backwardFetch(final K keyFrom,
final K keyTo,
final long timeFrom,
final long timeTo) {
Objects.requireNonNull(keyFrom, "keyFrom cannot be null");
Objects.requireNonNull(keyTo, "keyTo cannot be null");
return new MeteredWindowedKeyValueIterator<>(
wrapped().backwardFetch(keyBytes(keyFrom), keyBytes(keyTo), timeFrom, timeTo),
fetchSensor,
Expand Down
Expand Up @@ -307,9 +307,6 @@ public void putAll(final List<KeyValue<Bytes, byte[]>> entries) {
@Override
public <PS extends Serializer<P>, P> KeyValueIterator<Bytes, byte[]> prefixScan(final P prefix,
final PS prefixKeySerializer) {
Objects.requireNonNull(prefix, "prefix cannot be null");
Objects.requireNonNull(prefixKeySerializer, "prefixKeySerializer cannot be null");

validateStoreOpen();
final Bytes prefixBytes = Bytes.wrap(prefixKeySerializer.serialize(null, prefix));

Expand Down
Expand Up @@ -40,6 +40,7 @@
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

public class InMemoryKeyValueStoreTest extends AbstractKeyValueStoreTest {

Expand Down Expand Up @@ -224,4 +225,9 @@ public void shouldReturnNoKeys() {
}
assertThat(numberOfKeysReturned, is(0));
}

@Test
public void shouldThrowNullPointerIfPrefixKeySerializerIsNull() {
assertThrows(NullPointerException.class, () -> byteStore.prefixScan("bb", null));
}
}
Expand Up @@ -435,6 +435,57 @@ public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
verify(inner);
}

@Test
public void shouldThrowNullPointerOnGetIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.get(null));
}

@Test
public void shouldThrowNullPointerOnPutIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.put(null, VALUE));
}

@Test
public void shouldThrowNullPointerOnPutIfAbsentIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.putIfAbsent(null, VALUE));
}

@Test
public void shouldThrowNullPointerOnDeleteIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.delete(null));
}

@Test
public void shouldThrowNullPointerOnPutAllIfAnyKeyIsNull() {
assertThrows(NullPointerException.class, () -> metered.putAll(Collections.singletonList(KeyValue.pair(null, VALUE))));
}

@Test
public void shouldThrowNullPointerOnPrefixScanIfPrefixIsNull() {
final StringSerializer stringSerializer = new StringSerializer();
assertThrows(NullPointerException.class, () -> metered.prefixScan(null, stringSerializer));
}

@Test
public void shouldThrowNullPointerOnRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> metered.range(null, "to"));
}

@Test
public void shouldThrowNullPointerOnRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> metered.range("from", null));
}

@Test
public void shouldThrowNullPointerOnReverseRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> metered.reverseRange(null, "to"));
}

@Test
public void shouldThrowNullPointerOnReverseRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> metered.reverseRange("from", null));
}

@Test
public void shouldGetRecordsWithPrefixKey() {
final StringSerializer stringSerializer = new StringSerializer();
Expand Down
Expand Up @@ -472,11 +472,36 @@ public void shouldThrowNullPointerOnRemoveIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.remove(null));
}

@Test
public void shouldThrowNullPointerOnPutIfWrappedKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.put(new Windowed<>(null, new SessionWindow(0, 0)), "a"));
}

@Test
public void shouldThrowNullPointerOnRemoveIfWrappedKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.remove(new Windowed<>(null, new SessionWindow(0, 0))));
}

@Test
public void shouldThrowNullPointerOnPutIfWindowIsNull() {
assertThrows(NullPointerException.class, () -> store.put(new Windowed<>(KEY, null), "a"));
}

@Test
public void shouldThrowNullPointerOnRemoveIfWindowIsNull() {
assertThrows(NullPointerException.class, () -> store.remove(new Windowed<>(KEY, null)));
}

@Test
public void shouldThrowNullPointerOnFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null));
}

@Test
public void shouldThrowNullPointerOnFetchSessionIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.fetchSession(null, 0, Long.MAX_VALUE));
}

@Test
public void shouldThrowNullPointerOnFetchRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null, "to"));
Expand All @@ -487,6 +512,21 @@ public void shouldThrowNullPointerOnFetchRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch("from", null));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null, "to"));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch("from", null));
}

@Test
public void shouldThrowNullPointerOnFindSessionsIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.findSessions(null, 0, 0));
Expand All @@ -502,6 +542,21 @@ public void shouldThrowNullPointerOnFindSessionsRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.findSessions("a", null, 0, 0));
}

@Test
public void shouldThrowNullPointerOnBackwardFindSessionsIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFindSessions(null, 0, 0));
}

@Test
public void shouldThrowNullPointerOnBackwardFindSessionsRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFindSessions(null, "a", 0, 0));
}

@Test
public void shouldThrowNullPointerOnBackwardFindSessionsRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFindSessions("a", null, 0, 0));
}

private interface CachedSessionStore extends SessionStore<Bytes, byte[]>, CachedStateStore<byte[], byte[]> { }

@SuppressWarnings("unchecked")
Expand Down
Expand Up @@ -468,6 +468,42 @@ public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
verify(innerStoreMock);
}

@Test
public void shouldThrowNullPointerOnPutIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.put(null, "a", 1L));
}

@Test
public void shouldThrowNullPointerOnFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null, 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnBackwardFetchIfKeyIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null, 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnFetchRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch(null, "to", 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnFetchRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.fetch("from", null, 0L, 1L));
}


@Test
public void shouldThrowNullPointerOnbackwardFetchRangeIfFromIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch(null, "to", 0L, 1L));
}

@Test
public void shouldThrowNullPointerOnbackwardFetchRangeIfToIsNull() {
assertThrows(NullPointerException.class, () -> store.backwardFetch("from", null, 0L, 1L));
}

private List<MetricName> storeMetrics() {
return metrics.metrics()
.keySet()
Expand Down

0 comments on commit e454bec

Please sign in to comment.