From 25e616303edcd76da862a5750b7f0e687c34a65a Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 20 Jun 2016 09:40:31 -0700 Subject: [PATCH 1/2] Make JAXBCoder Thread Safe JAXB Marshaller and Unmarshaller are not thread safe, but coders are required to be. Create context only once, and create Marshaller and Unmarshaller locally on-demand. --- .../org/apache/beam/sdk/coders/JAXBCoder.java | 26 ++++--- .../apache/beam/sdk/coders/JAXBCoderTest.java | 69 ++++++++++++++++++- 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java index 6fc8fcf85d98..1c2f37b2bbff 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java @@ -46,8 +46,7 @@ public class JAXBCoder extends AtomicCoder { private final Class jaxbClass; - private transient Marshaller jaxbMarshaller = null; - private transient Unmarshaller jaxbUnmarshaller = null; + private transient JAXBContext jaxbContext; public Class getJAXBClass() { return jaxbClass; @@ -70,10 +69,8 @@ public static JAXBCoder of(Class jaxbClass) { public void encode(T value, OutputStream outStream, Context context) throws CoderException, IOException { try { - if (jaxbMarshaller == null) { - JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); - jaxbMarshaller = jaxbContext.createMarshaller(); - } + JAXBContext jaxbContext = getContext(); + Marshaller jaxbMarshaller = jaxbContext.createMarshaller(); if (!context.isWholeStream) { try { long size = getEncodedElementByteSize(value, Context.OUTER); @@ -95,10 +92,8 @@ public void encode(T value, OutputStream outStream, Context context) @Override public T decode(InputStream inStream, Context context) throws CoderException, IOException { try { - if (jaxbUnmarshaller == null) { - JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); - jaxbUnmarshaller = jaxbContext.createUnmarshaller(); - } + JAXBContext jaxbContext = getContext(); + Unmarshaller jaxbUnmarshaller = jaxbContext.createUnmarshaller(); InputStream stream = inStream; if (!context.isWholeStream) { @@ -113,6 +108,17 @@ public T decode(InputStream inStream, Context context) throws CoderException, IO } } + private final JAXBContext getContext() throws JAXBException { + if (jaxbContext == null) { + synchronized (this) { + if (jaxbContext == null) { + jaxbContext = JAXBContext.newInstance(jaxbClass); + } + } + } + return jaxbContext; + } + @Override public String getEncodingId() { return getJAXBClass().getName(); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java index 1a00417cebad..6b59e525d965 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java @@ -17,12 +17,15 @@ */ package org.apache.beam.sdk.coders; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + import org.apache.beam.sdk.testing.CoderProperties; import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.SerializableUtils; import com.google.common.collect.ImmutableList; -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -31,6 +34,11 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import javax.xml.bind.annotation.XmlRootElement; @@ -91,7 +99,15 @@ public void testEncodeDecodeOuter() throws Exception { JAXBCoder coder = JAXBCoder.of(TestType.class); byte[] encoded = CoderUtils.encodeToByteArray(coder, new TestType("abc", 9999)); - Assert.assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); + assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); + } + + @Test + public void testEncodeDecodeAfterClone() throws Exception { + JAXBCoder coder = SerializableUtils.clone(JAXBCoder.of(TestType.class)); + + byte[] encoded = CoderUtils.encodeToByteArray(coder, new TestType("abc", 9999)); + assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); } @Test @@ -100,10 +116,57 @@ public void testEncodeDecodeNested() throws Exception { TestCoder nesting = new TestCoder(jaxbCoder); byte[] encoded = CoderUtils.encodeToByteArray(nesting, new TestType("abc", 9999)); - Assert.assertEquals( + assertEquals( new TestType("abc", 9999), CoderUtils.decodeFromByteArray(nesting, encoded)); } + @Test + public void testEncodeDecodeMultithreaded() throws Throwable { + final JAXBCoder coder = JAXBCoder.of(TestType.class); + int numThreads = 1000; + + final CountDownLatch ready = new CountDownLatch(numThreads); + final CountDownLatch start = new CountDownLatch(1); + final CountDownLatch done = new CountDownLatch(numThreads); + + final AtomicReference thrown = new AtomicReference<>(); + + Executor executor = Executors.newCachedThreadPool(); + for (int i = 0; i < numThreads; i++) { + final TestType elem = new TestType("abc", i); + final int index = i; + executor.execute( + new Runnable() { + @Override + public void run() { + ready.countDown(); + try { + start.await(); + } catch (InterruptedException e) { + } + + try { + byte[] encoded = CoderUtils.encodeToByteArray(coder, elem); + assertEquals( + new TestType("abc", index), CoderUtils.decodeFromByteArray(coder, encoded)); + } catch (Throwable e) { + thrown.compareAndSet(null, e); + } + done.countDown(); + } + }); + } + ready.await(); + start.countDown(); + + if (!done.await(10L, TimeUnit.SECONDS)) { + fail("Should be able to clone " + numThreads + " elements in 10 seconds"); + } + if (thrown.get() != null) { + throw thrown.get(); + } + } + /** * A coder that surrounds the value with two values, to demonstrate nesting. */ From 7b1e78cff32743d91c6fb6b5fac1e4a71aff12f1 Mon Sep 17 00:00:00 2001 From: Thomas Groh Date: Mon, 20 Jun 2016 10:05:30 -0700 Subject: [PATCH 2/2] fixup! Make JAXBCoder Thread Safe --- .../src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java index 1c2f37b2bbff..f90eb54d53e4 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java @@ -70,6 +70,7 @@ public void encode(T value, OutputStream outStream, Context context) throws CoderException, IOException { try { JAXBContext jaxbContext = getContext(); + // TODO: Consider caching in a ThreadLocal if this impacts performance Marshaller jaxbMarshaller = jaxbContext.createMarshaller(); if (!context.isWholeStream) { try { @@ -93,6 +94,7 @@ public void encode(T value, OutputStream outStream, Context context) public T decode(InputStream inStream, Context context) throws CoderException, IOException { try { JAXBContext jaxbContext = getContext(); + // TODO: Consider caching in a ThreadLocal if this impacts performance Unmarshaller jaxbUnmarshaller = jaxbContext.createUnmarshaller(); InputStream stream = inStream;