Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Collection;
import java.util.Map;

Expand Down Expand Up @@ -92,8 +94,16 @@ protected Object newArray(Object old, int size, Schema schema) {
((Collection<?>) old).clear();
return old;
}

if (collectionClass.isAssignableFrom(ArrayList.class))
return new ArrayList<>();

if (collectionClass.isAssignableFrom(HashSet.class))
return new HashSet<>();

if (collectionClass.isAssignableFrom(HashMap.class))
return new HashMap<>();

return SpecificData.newInstance(collectionClass, schema);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@

package org.apache.avro.reflect;

import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.Map;

import org.apache.avro.io.Decoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.junit.Test;
import org.junit.jupiter.api.Test;

public class TestReflectDatumReader {

Expand Down Expand Up @@ -78,6 +82,49 @@ public void testRead_PojoWithArray() throws IOException {
assertEquals(pojoWithArray, deserialized);
}

@Test
public void testRead_PojoWithSet() throws IOException {
PojoWithSet pojoWithSet = new PojoWithSet();
pojoWithSet.setId(42);
Set<Integer> relatedIds = new HashSet<>();
relatedIds.add(1);
relatedIds.add(2);
relatedIds.add(3);
pojoWithSet.setRelatedIds(relatedIds);

byte[] serializedBytes = serializeWithReflectDatumWriter(pojoWithSet, PojoWithSet.class);

Decoder decoder = DecoderFactory.get().binaryDecoder(serializedBytes, null);
ReflectDatumReader<PojoWithSet> reflectDatumReader = new ReflectDatumReader<>(PojoWithSet.class);

PojoWithSet deserialized = new PojoWithSet();
reflectDatumReader.read(deserialized, decoder);

assertEquals(pojoWithSet, deserialized);

}

@Test
public void testRead_PojoWithMap() throws IOException {
PojoWithMap pojoWithMap = new PojoWithMap();
pojoWithMap.setId(42);
Map<Integer, Integer> relatedIds = new HashMap<>();
relatedIds.put(1, 11);
relatedIds.put(2, 22);
relatedIds.put(3, 33);
pojoWithMap.setRelatedIds(relatedIds);

byte[] serializedBytes = serializeWithReflectDatumWriter(pojoWithMap, PojoWithMap.class);

Decoder decoder = DecoderFactory.get().binaryDecoder(serializedBytes, null);
ReflectDatumReader<PojoWithMap> reflectDatumReader = new ReflectDatumReader<>(PojoWithMap.class);

PojoWithMap deserialized = new PojoWithMap();
reflectDatumReader.read(deserialized, decoder);

assertEquals(pojoWithMap, deserialized);
}

public static class PojoWithList {
private int id;
private List<Integer> relatedIds;
Expand Down Expand Up @@ -167,6 +214,99 @@ public boolean equals(Object obj) {
return false;
return Arrays.equals(relatedIds, other.relatedIds);
}
}

public static class PojoWithSet {
private int id;
private Set<Integer> relatedIds;

public int getId() {
return id;
}

public void setId(int id) {
this.id = id;
}

public Set<Integer> getRelatedIds() {
return relatedIds;
}

public void setRelatedIds(Set<Integer> relatedIds) {
this.relatedIds = relatedIds;
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + id;
result = prime * result + ((relatedIds == null) ? 0 : relatedIds.hashCode());
return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
PojoWithSet other = (PojoWithSet) obj;
if (id != other.id)
return false;
if (relatedIds == null) {
return other.relatedIds == null;
} else
return relatedIds.equals(other.relatedIds);
}
}

public static class PojoWithMap {
private int id;
private Map<Integer, Integer> relatedIds;

public int getId() {
return id;
}

public void setId(int id) {
this.id = id;
}

public Map<Integer, Integer> getRelatedIds() {
return relatedIds;
}

public void setRelatedIds(Map<Integer, Integer> relatedIds) {
this.relatedIds = relatedIds;
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + id;
result = prime * result + ((relatedIds == null) ? 0 : relatedIds.hashCode());
return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
PojoWithMap other = (PojoWithMap) obj;
if (id != other.id)
return false;
if (relatedIds == null) {
return other.relatedIds == null;
} else
return relatedIds.equals(other.relatedIds);
}
}
}