From bd63678fae95ec218b7f2b758c02c5c208c9fbbf Mon Sep 17 00:00:00 2001 From: Sachin Goyal Date: Wed, 4 Mar 2015 16:15:57 -0800 Subject: [PATCH] AVRO-1568: Allow Java polymorphism in Avro for third-party code The fix simply adds two APIs to ReflectData: 1) setSchema (Class clazz, Schema s); 2) setSchema (Field field, Schema s); With these two APIs, clients can create UNION schemas for any class/field and set them accordingly. With the UNION schema, avro can easily handle derived objects' presence on base-class fields. --- .../org/apache/avro/reflect/ReflectData.java | 21 + .../reflect/TestPolymorphicSetSchema.java | 413 ++++++++++++++++++ 2 files changed, 434 insertions(+) create mode 100644 lang/java/avro/src/test/java/org/apache/avro/reflect/TestPolymorphicSetSchema.java diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java index e72b1e72d1a..66d6829ef87 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java @@ -60,6 +60,15 @@ /** Utilities to use existing Java classes and interfaces via reflection. */ public class ReflectData extends SpecificData { + + /** Lookup to store user specified schemas for classes */ + private static final Map, Schema> CLASS_SCHEMAS = + new HashMap, Schema>(); + + /** Lookup to store user specified schemas for fields */ + private static final Map FIELD_SCHEMAS = + new HashMap(); + /** {@link ReflectData} implementation that permits null field values. The * schema generated for each field is a union of its declared type and * null. */ @@ -117,6 +126,14 @@ public DatumWriter createDatumWriter(Schema schema) { return new ReflectDatumWriter(schema, this); } + public Schema setSchema (Class clazz, Schema s) { + return CLASS_SCHEMAS.put (clazz, s); + } + + public Schema setSchema (Field field, Schema s) { + return FIELD_SCHEMAS.put (field, s); + } + @Override public void setField(Object record, String name, int position, Object o) { setField(record, name, position, o, null); @@ -432,6 +449,8 @@ protected Schema createSchema(Type type, Map names) { AvroSchema explicit = c.getAnnotation(AvroSchema.class); if (explicit != null) // explicit schema return Schema.parse(explicit.value()); + if (CLASS_SCHEMAS.containsKey(c)) // Set explicitly by setSchema() + return CLASS_SCHEMAS.get(c); if (CharSequence.class.isAssignableFrom(c)) // String return Schema.create(Schema.Type.STRING); if (ByteBuffer.class.isAssignableFrom(c)) // bytes @@ -609,6 +628,8 @@ protected Schema createFieldSchema(Field field, Map names) { AvroSchema explicit = field.getAnnotation(AvroSchema.class); if (explicit != null) // explicit schema return Schema.parse(explicit.value()); + if (FIELD_SCHEMAS.containsKey(field)) // Set explicitly by setSchema() + return FIELD_SCHEMAS.get(field); Schema schema = createSchema(field.getGenericType(), names); if (field.isAnnotationPresent(Stringable.class)) { // Stringable diff --git a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestPolymorphicSetSchema.java b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestPolymorphicSetSchema.java new file mode 100644 index 00000000000..e8e3b98edd8 --- /dev/null +++ b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestPolymorphicSetSchema.java @@ -0,0 +1,413 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.avro.reflect; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; + +import java.lang.reflect.Field; + +import static org.junit.Assert.*; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileReader; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.file.SeekableByteArrayInput; +import org.apache.avro.generic.GenericArray; +import org.apache.avro.generic.GenericDatumReader; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.reflect.ReflectDatumReader; +import org.apache.avro.reflect.ReflectDatumWriter; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.util.Utf8; +import org.junit.Test; + +/** + * Test serialization and de-serialization for Java polymorphism + * For example, when a field of type base-class holds a derived-class object. + */ +public class TestPolymorphicSetSchema { + + private boolean applyFieldSchema = true; + + @Test + public void testFieldAndClassLevelSchema() throws Exception { + applyFieldSchema = true; + testPolymorphicFields("testFieldLevelSchema"); + applyFieldSchema = false; + testPolymorphicFields("testClassLevelSchema"); + } + + public void testPolymorphicFields(String testType) throws Exception { + + Showroom entityObj1 = buildShowroom(); + Showroom entityObj2 = buildShowroom(); + + Showroom [] entityObjs = {entityObj1, entityObj2}; + byte[] bytes = testSerialization(testType, entityObj1, entityObj2); + List records = + (List) testGenericDatumRead(testType, bytes, entityObjs); + + GenericRecord record = records.get(0); + Object trending = record.get("trending"); + assertTrue ("Unable to read 'trending' list", trending instanceof GenericArray); + + GenericArray arrayTrending = ((GenericArray)trending); + + // Test chair specific properties + Object chair = arrayTrending.get(0); + assertTrue (chair instanceof GenericRecord); + Object color = ((GenericRecord)chair).get("color"); + Object height = ((GenericRecord)chair).get("height"); + Object percentDiscount = ((GenericRecord)chair).get("percentDiscount"); + assertNotNull (color); + assertNull (height); + assertNull (percentDiscount); + + // Test table specific properties + Object table = arrayTrending.get(1); + assertTrue (table instanceof GenericRecord); + color = ((GenericRecord)table).get("color"); + height = ((GenericRecord)table).get("height"); + percentDiscount = ((GenericRecord)table).get("percentDiscount"); + assertNull (color); + assertNotNull (height); + assertNull (percentDiscount); + + // Test swivel-chair specific properties + Object swivelChair = arrayTrending.get(2); + assertTrue (swivelChair instanceof GenericRecord); + color = ((GenericRecord)swivelChair).get("color"); + height = ((GenericRecord)swivelChair).get("height"); + percentDiscount = ((GenericRecord)swivelChair).get("percentDiscount"); + assertNotNull (color); // because swivel is also a chair + assertNull (height); + assertNotNull (percentDiscount); + + // Test reflect data + List records2 = + (List) testReflectDatumRead(testType, bytes, entityObjs); + Showroom showroom = records2.get(0); + log ("Read: " + showroom); + List trendingItems = showroom.getTrending(); + assertNotNull (trendingItems); + assertEquals (3, trendingItems.size()); + + Item item1 = trendingItems.get(0); + Item item2 = trendingItems.get(1); + Item item3 = trendingItems.get(2); + assertTrue (item1 instanceof Chair); + assertTrue (item2 instanceof Table); + assertTrue (item3 instanceof SwivelChair); + + assertTrue (showroom.getMostSelling() instanceof Chair); + assertTrue (showroom.getLeastSelling() instanceof SwivelChair); + + // Test json encoder/decoder + byte[] jsonBytes = testJsonEncoder (testType, entityObj1); + assertNotNull ("Unable to serialize using jsonEncoder", jsonBytes); + GenericRecord jsonRecord = testJsonDecoder(testType, jsonBytes, entityObj1); + assertEquals ("JSON decoder output not same as Binary Decoder", record, jsonRecord); + } + + private ReflectData getReflectData () { + if (applyFieldSchema) + return getReflectDataWithFieldLevelSchema(); + else + return getReflectDataWithClassLevelSchema(); + } + + private ReflectData getReflectDataWithClassLevelSchema () { + ReflectData rdata = ReflectData.AllowNull.get(); + + // Get schemas for all hierarchies + Schema chairSchema = rdata.getSchema(Chair.class); + Schema tableSchema = rdata.getSchema(Table.class); + Schema swivelSchema = rdata.getSchema(SwivelChair.class); + + // Since the list can contain any type of derived classes, + // we create a union for all of the possible types here. + // And then we create an array of the union and make it nullable + List unionTypes = new ArrayList (2); + unionTypes.add(chairSchema); + unionTypes.add(tableSchema); + unionTypes.add(swivelSchema); + Schema unionSchema = Schema.createUnion(unionTypes); + + try { + rdata.setSchema (Item.class, unionSchema); + } catch (Exception e) { + throw new RuntimeException (e); + } + return rdata; + } + + private ReflectData getReflectDataWithFieldLevelSchema () { + ReflectData rdata = ReflectData.AllowNull.get(); + + // Get schemas for all hierarchies + Schema chairSchema = rdata.getSchema(Chair.class); + Schema tableSchema = rdata.getSchema(Table.class); + Schema swivelSchema = rdata.getSchema(SwivelChair.class); + + // Since the list can contain any type of derived classes, + // we create a union for all of the possible types here. + // And then we create an array of the union and make it nullable + List unionTypes = new ArrayList (2); + unionTypes.add(chairSchema); + unionTypes.add(tableSchema); + unionTypes.add(swivelSchema); + Schema unionSchema = Schema.createUnion(unionTypes); + Schema listSchema = Schema.createArray (unionSchema); + Schema nullableListSchema = rdata.makeNullable (listSchema); + + try { + // Get maverick fields + Field mostSelling = Showroom.class.getDeclaredField("mostSelling"); + Field leastSelling = Showroom.class.getDeclaredField("leastSelling"); + Field trending = Showroom.class.getDeclaredField("trending"); + + // Set the schema for each of the fields + rdata.setSchema (mostSelling, chairSchema); + rdata.setSchema (leastSelling, swivelSchema); + rdata.setSchema (trending, nullableListSchema); + } catch (Exception e) { + throw new RuntimeException (e); + } + return rdata; + } + + /** + * Test serialization of non-string map-key POJOs + */ + public byte[] testSerialization(String testType, T ... entityObjs) throws Exception { + + log ("---- Beginning " + testType + " ----"); + T entityObj1 = entityObjs[0]; + ReflectData rdata = getReflectData(); + + Schema schema = rdata.getSchema(entityObj1.getClass()); + assertNotNull("Unable to get schema for " + testType, schema); + log (schema.toString(true)); + + ReflectDatumWriter datumWriter = + new ReflectDatumWriter (entityObj1.getClass(), rdata); + DataFileWriter fileWriter = new DataFileWriter (datumWriter); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + fileWriter.create(schema, baos); + for (T entityObj : entityObjs) { + fileWriter.append(entityObj); + } + fileWriter.close(); + + byte[] bytes = baos.toByteArray(); + return bytes; + } + + /** + * Test that non-string map-keys are readable through GenericDatumReader + * This methoud should read as array of {key, value} and not as a map + */ + private List testGenericDatumRead + (String testType, byte[] bytes, T ... entityObjs) throws IOException { + + GenericDatumReader datumReader = + new GenericDatumReader (); + SeekableByteArrayInput avroInputStream = new SeekableByteArrayInput(bytes); + DataFileReader fileReader = + new DataFileReader(avroInputStream, datumReader); + + Schema schema = fileReader.getSchema(); + assertNotNull("Unable to get schema for " + testType, schema); + GenericRecord record = null; + List records = new ArrayList (); + while (fileReader.hasNext()) { + records.add (fileReader.next(record)); + } + return records; + } + + /** + * Test that non-string map-keys are readable through ReflectDatumReader + * This methoud should form the original map and should not return any + * array of {key, value} as done by {@link #testGenericDatumRead()} + */ + private List testReflectDatumRead + (String testType, byte[] bytes, T ... entityObjs) throws IOException { + + ReflectDatumReader datumReader = new ReflectDatumReader (); + SeekableByteArrayInput avroInputStream = new SeekableByteArrayInput(bytes); + DataFileReader fileReader = new DataFileReader(avroInputStream, datumReader); + + Schema schema = fileReader.getSchema(); + T record = null; + List records = new ArrayList (); + while (fileReader.hasNext()) { + records.add (fileReader.next(record)); + } + return records; + } + + private byte[] testJsonEncoder + (String testType, T entityObj) throws IOException { + + ReflectData rdata = getReflectData(); + + Schema schema = rdata.getSchema(entityObj.getClass()); + ByteArrayOutputStream os = new ByteArrayOutputStream(); + Encoder encoder = EncoderFactory.get().jsonEncoder(schema, os); + ReflectDatumWriter datumWriter = new ReflectDatumWriter(schema, rdata); + datumWriter.write(entityObj, encoder); + encoder.flush(); + + byte[] bytes = os.toByteArray(); + System.out.println ("JSON encoder output:\n" + new String(bytes)); + return bytes; + } + + private GenericRecord testJsonDecoder + (String testType, byte[] bytes, T entityObj) throws IOException { + + ReflectData rdata = getReflectData(); + + Schema schema = rdata.getSchema(entityObj.getClass()); + GenericDatumReader datumReader = + new GenericDatumReader(schema); + + Decoder decoder = DecoderFactory.get().jsonDecoder(schema, new String(bytes)); + GenericRecord r = datumReader.read(null, decoder); + return r; + } + + /** + * Create a POJO having polymorphic fields + */ + private Showroom buildShowroom () { + Showroom sr = new Showroom (); + + Item mostSelling = new Chair(); + mostSelling.setPrice (20); + Item leastSelling = new SwivelChair(); + leastSelling.setPrice (30); + + List trending = new ArrayList(); + trending.add (new Chair()); + trending.add (new Table()); + trending.add (new SwivelChair()); + + sr.setMostSelling(mostSelling); + sr.setLeastSelling(leastSelling); + sr.setTrending(trending); + + return sr; + } + + private void log (String msg) { + System.out.println (msg); + } + + private static class Item { + String id = "500"; + int price = 10; + public String getId() { + return id; + } + public void setId(String id) { + this.id = id; + } + public int getPrice() { + return price; + } + public void setPrice(int price) { + this.price = price; + } + } + + private static class Chair extends Item { + String color = "blue"; + + public String getColor() { + return color; + } + public void setColor(String color) { + this.color = color; + } + } + + private static class Table extends Item { + int height = 10; + boolean hasDrawers = true; + + public int getHeight() { + return height; + } + public void setHeight(int height) { + this.height = height; + } + public boolean isHasDrawers() { + return hasDrawers; + } + public void setHasDrawers(boolean hasDrawers) { + this.hasDrawers = hasDrawers; + } + } + + private static class SwivelChair extends Chair { + int percentDiscount = 10; + + public int getPercentDiscount() { + return percentDiscount; + } + public void setPercentDiscount(int percentDiscount) { + this.percentDiscount = percentDiscount; + } + } + + private static class Showroom { + Item mostSelling; + Item leastSelling; + List trending; + public Item getMostSelling() { + return mostSelling; + } + public void setMostSelling(Item mostSelling) { + this.mostSelling = mostSelling; + } + public Item getLeastSelling() { + return leastSelling; + } + public void setLeastSelling(Item leastSelling) { + this.leastSelling = leastSelling; + } + public List getTrending() { + return trending; + } + public void setTrending(List trending) { + this.trending = trending; + } + } +}