Skip to content

Commit

Permalink
Closes ytsaurus#552
Browse files Browse the repository at this point in the history
  • Loading branch information
andrey-vasilyev committed May 11, 2024
1 parent 90172f7 commit 85f3b1c
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

Expand Down Expand Up @@ -36,22 +38,49 @@ private EntityTableSchemaCreator() {

public static <T> TableSchema create(Class<T> annotatedClass, @Nullable TableSchema schema) {
if (!anyOfAnnotationsPresent(annotatedClass, JavaPersistenceApi.entityAnnotations())) {
throw new IllegalArgumentException("Class must be annotated with @Entity");
throw new IllegalArgumentException(String.format("Class %s must be annotated with @Entity", annotatedClass.getName()));
}

TableSchema.Builder tableSchemaBuilder = TableSchema.builder();
StructType tableSchemaAsStructType = schema != null ?
TiTypeUtil.tableSchemaToStructTiType(schema).asStruct() : null;
for (Field field : getAllDeclaredFields(annotatedClass)) {
var fieldsChain = new LinkedHashMap<Class<?>, String>();
try {
processFieldsRecursively(annotatedClass, tableSchemaBuilder, tableSchemaAsStructType, fieldsChain);
} catch (InfiniteLoopException e) {
var loopChain = fieldsChain.entrySet().stream()
.map(it -> String.format("%s.%s", it.getKey().getName(), it.getValue()))
.collect(Collectors.joining("->"));
throw new IllegalArgumentException(String.format("Entity %s contains a loop in fields hierarchy: %s",
annotatedClass.getName(), loopChain));
}
return tableSchemaBuilder.build();
}

private static <T> void processFieldsRecursively(Class<T> clazz,
TableSchema.Builder tableSchemaBuilder,
StructType tableSchemaAsStructType,
Map<Class<?>, String> fieldsChain) {
if (fieldsChain.containsKey(clazz)) {
throw new InfiniteLoopException();
}
for (Field field : getAllDeclaredFields(clazz)) {
if (isFieldTransient(field, JavaPersistenceApi.transientAnnotations())) {
continue;
}
if (anyOfAnnotationsPresent(field.getType(), JavaPersistenceApi.embeddableAnnotations())) {
fieldsChain.put(clazz, field.getName());
processFieldsRecursively(field.getType(), tableSchemaBuilder, tableSchemaAsStructType, fieldsChain);
fieldsChain.remove(clazz);
continue;
} else if (anyOfAnnotationsPresent(field, JavaPersistenceApi.embeddedAnnotations())) {
throw new IllegalArgumentException(String.format("%s.%s field is annotated with @Embedded, but %s " +
"in not annotated with @Embeddable", clazz.getName(), field.getName(), field.getType().getName()));
}
tableSchemaBuilder.add(
getFieldColumnSchema(field, tableSchemaAsStructType)
);
}

return tableSchemaBuilder.build();
}

private static ColumnSchema getFieldColumnSchema(Field field, @Nullable StructType structTypeInSchema) {
Expand Down Expand Up @@ -173,13 +202,12 @@ private static <T> TiType getClassTiType(Class<T> clazz,
.orElse(null)
);
}
return tiTypeIfSimple.orElseGet(() -> getEntityTiType(
clazz,
Optional.ofNullable(tiTypeInSchema)
.filter(TiType::isStruct)
.map(TiType::asStruct)
.orElse(null)
)
return getEntityTiType(
clazz,
Optional.ofNullable(tiTypeInSchema)
.filter(TiType::isStruct)
.map(TiType::asStruct)
.orElse(null)
);
}

Expand Down Expand Up @@ -304,4 +332,8 @@ public static class PrecisionAndScaleNotSpecifiedException extends RuntimeExcept
public static class MismatchEntityAndTableSchemaDecimalException extends RuntimeException {

}

private static class InfiniteLoopException extends RuntimeException {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ class JavaPersistenceApi {
private static final String COLUMN_PRECISION = "precision";
private static final String COLUMN_SCALE = "scale";
private static final String COLUMN_DEFINITION = "columnDefinition";
private static final String EMBEDDABLE = "Embeddable";
private static final String EMBEDDED = "Embedded";
private static final Set<String> ENTITY_ANNOTATIONS = getAnnotationsFor(ENTITY);
private static final Set<String> TRANSIENT_ANNOTATIONS = getAnnotationsFor(TRANSIENT);
private static final Set<String> COLUMN_ANNOTATIONS = getAnnotationsFor(COLUMN);
private static final Set<String> EMBEDDABLE_ANNOTATIONS = getAnnotationsFor(EMBEDDABLE);
private static final Set<String> EMBEDDED_ANNOTATIONS = getAnnotationsFor(EMBEDDED);

private JavaPersistenceApi() {
}
Expand All @@ -42,6 +46,13 @@ static Set<String> columnAnnotations() {
return COLUMN_ANNOTATIONS;
}

static Set<String> embeddableAnnotations() {
return EMBEDDABLE_ANNOTATIONS;
}

static Set<String> embeddedAnnotations() {
return EMBEDDED_ANNOTATIONS;
}
static boolean isColumnAnnotationPresent(@Nullable Annotation annotation) {
return annotation != null &&
anyMatchWithAnnotation(annotation, JavaPersistenceApi.columnAnnotations());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package tech.ytsaurus.client.rows;

import org.junit.Test;
import tech.ytsaurus.skiff.SkiffSchema;
import tech.ytsaurus.skiff.WireType;

import javax.persistence.Column;
import javax.persistence.Embeddable;
import javax.persistence.Embedded;
import javax.persistence.Entity;
import java.time.ZonedDateTime;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

public class SchemaOfEmbeddedEntityTest {

@Entity
static class Student {
@Column(nullable = false, name = "student-name")
private String name;
@Embedded
private University university;
}

@Embeddable
static class University {
@Column(nullable = false, name = "university-name")
private String name;
private Address address;
}

@Embeddable
static class Address {
@Column(nullable = false)
private String country;
@Column(nullable = false)
private String city;
private String street;
private transient String fullAddress;
}

@Test
public void testCreateSchema() {
var entitySchema = SchemaConverter.toSkiffSchema(
EntityTableSchemaCreator.create(Student.class, null)
);

SkiffSchema expectedSchema = SkiffSchema.tuple(
List.of(
SkiffSchema.simpleType(WireType.STRING_32).setName("student-name"),
SkiffSchema.simpleType(WireType.STRING_32).setName("university-name"),
SkiffSchema.simpleType(WireType.STRING_32).setName("country"),
SkiffSchema.simpleType(WireType.STRING_32).setName("city"),
SkiffSchema.variant8(List.of(
SkiffSchema.nothing(),
SkiffSchema.simpleType(WireType.STRING_32)
)).setName("street")
));

assertEquals(expectedSchema, entitySchema);
}

@Entity
static class PostV1 {
private long id;
private String title;
@Embedded
private PostDetailsV1 details;
}

static class PostDetailsV1 {
private long id;
private String text;
private ZonedDateTime createdAt;
private String createdBy;
}

@Test
public void testExceptionWhenEmbeddedFieldIsNotEmbeddable() {
assertThrows(IllegalArgumentException.class, () -> EntityTableSchemaCreator.create(PostV1.class, null));
}

@Entity
static class PostV2{
private long id;
private String title;
@Embedded
private PostDetailsV2 details;
}

@Embeddable
static class PostDetailsV2 {
private long id;
private Content content;
private ZonedDateTime createdAt;
private String createdBy;
}

@Embeddable
static class Content {
private String text;
private PostDetailsV2 details;
}

@Test
public void testExceptionWhenEntityHasEmbeddableLoop() {
assertThrows(IllegalArgumentException.class, () -> EntityTableSchemaCreator.create(PostV2.class, null));
}
}

0 comments on commit 85f3b1c

Please sign in to comment.