Skip to content

Commit

Permalink
feat(plc4py): Work on the templates
Browse files Browse the repository at this point in the history
  • Loading branch information
hutcheb committed May 21, 2023
1 parent a60e492 commit ad7fbdb
Show file tree
Hide file tree
Showing 54 changed files with 360 additions and 515 deletions.
Expand Up @@ -212,6 +212,11 @@ public String getLanguageTypeNameForTypeReference(TypeReference typeReference, S
}
}

public String getReservedValue(ReservedField reservedField) {
final String languageTypeName = getLanguageTypeNameForTypeReference(reservedField.getType());
return languageTypeName + "(" + reservedField.getReferenceValue() + ")";
}

public String getFieldOptions(TypedField field, List<Argument> parserArguments) {
StringBuilder sb = new StringBuilder();
final Optional<Term> encodingOptional = field.getEncoding();
Expand Down Expand Up @@ -276,36 +281,33 @@ public String getDataReaderCall(SimpleTypeReference simpleTypeReference) {
final int sizeInBits = simpleTypeReference.getSizeInBits();
switch (simpleTypeReference.getBaseType()) {
case BIT:
return "read_boolean(read_buffer)";
return "read_boolean";
case BYTE:
return "read_byte(read_buffer, " + sizeInBits + ")";
return "read_byte";
case UINT:
if (sizeInBits <= 4) return "read_unsigned_byte(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 8) return "read_unsigned_short(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 16) return "read_unsigned_int(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 32) return "read_unsigned_long(read_buffer, " + sizeInBits + ")";
return "read_unsigned_big_integer(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 4) return "read_unsigned_byte";
if (sizeInBits <= 8) return "read_unsigned_short";
if (sizeInBits <= 16) return "read_unsigned_int";
return "read_unsigned_long";
case INT:
if (sizeInBits <= 8) return "read_signed_byte(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 16) return "read_signed_short(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 32) return "read_signed_int(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 64) return "read_signed_long(read_buffer, " + sizeInBits + ")";
return "read_signed_big_integer(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 8) return "read_signed_byte";
if (sizeInBits <= 16) return "read_signed_short";
if (sizeInBits <= 32) return "read_signed_int";
return "read_signed_long";
case FLOAT:
if (sizeInBits <= 32) return "read_float(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 64) return "read_double(read_buffer, " + sizeInBits + ")";
return "read_big_decimal(read_buffer, " + sizeInBits + ")";
if (sizeInBits <= 32) return "read_float";
return "read_double";
case STRING:
return "read_string(read_buffer, " + sizeInBits + ")";
return "read_string";
case VSTRING:
VstringTypeReference vstringTypeReference = (VstringTypeReference) simpleTypeReference;
return "read_string(read_buffer, " + toParseExpression(null, INT_TYPE_REFERENCE, vstringTypeReference.getLengthExpression(), null) + ")";
return "read_string";
case TIME:
return "read_time(read_buffer)";
return "read_time";
case DATE:
return "read_date(read_buffer)";
return "read_date";
case DATETIME:
return "read_date_time(read_buffer)";
return "read_date_time";
default:
throw new UnsupportedOperationException("Unsupported type " + simpleTypeReference.getBaseType());
}
Expand Down Expand Up @@ -674,31 +676,6 @@ public String getWriteBufferWriteMethodCall(String logicalName, SimpleTypeRefere
}
}

public String getReservedValue(ReservedField reservedField) {
final String languageTypeName = getLanguageTypeNameForTypeReference(reservedField.getType());
switch (languageTypeName) {
case "*big.Int":
emitRequiredImport("math/big");
return "big.NewInt(" + reservedField.getReferenceValue() + ")";
case "*big.Float":
emitRequiredImport("math/big");
return "*big.Float(" + reservedField.getReferenceValue() + ")";
default:
return languageTypeName + "(" + reservedField.getReferenceValue() + ")";
}
}

public String toTypeSafeCompare(ReservedField reservedField) {
final String languageTypeName = getLanguageTypeNameForTypeReference(reservedField.getType());
switch (languageTypeName) {
case "*big.Int":
case "*big.Float":
emitRequiredImport("math/big");
return "reserved.Cmp(" + getReservedValue(reservedField) + ") != 0";
default:
return "reserved != " + getReservedValue(reservedField);
}
}

public String toParseExpression(Field field, TypeReference resultType, Term term, List<Argument> parserArguments) {
Tracer tracer = Tracer.start("toParseExpression");
Expand Down Expand Up @@ -815,8 +792,8 @@ private String toBinaryTermExpression(Field field, TypeReference fieldType, Bina
switch (operation) {
case "^":
tracer = tracer.dive("^");
emitRequiredImport("math");
return tracer + "Math.pow(" +
emitRequiredImport("from math import pow");
return tracer + "pow(" +
castExpressionForTypeReference + "(" + toExpression(field, fieldType, a, parserArguments, serializerArguments, serialize, false) + "), " +
castExpressionForTypeReference + "(" + toExpression(field, fieldType, b, parserArguments, serializerArguments, serialize, false) + "))";
// If we start casting for comparisons, equals or non equals, really messy things happen.
Expand Down Expand Up @@ -853,9 +830,9 @@ private String toBinaryTermExpression(Field field, TypeReference fieldType, Bina
toExpression(field, fieldType, b, parserArguments, serializerArguments, serialize, false);
}
return tracer +
castExpressionForTypeReference + "(" + toExpression(field, fieldType, a, parserArguments, serializerArguments, serialize, false) + ") " +
toExpression(field, fieldType, a, parserArguments, serializerArguments, serialize, false) +
operation + " " +
castExpressionForTypeReference + "(" + toExpression(field, fieldType, b, parserArguments, serializerArguments, serialize, false) + ")";
toExpression(field, fieldType, b, parserArguments, serializerArguments, serialize, false);
}
}

Expand Down Expand Up @@ -884,7 +861,8 @@ private String toLiteralTermExpression(Field field, TypeReference fieldType, Ter
return tracer + "None";
} else if (term instanceof BooleanLiteral) {
tracer = tracer.dive("boolean literal instanceOf");
return tracer + getCastExpressionForTypeReference(fieldType) + "(" + ((BooleanLiteral) term).getValue() + ")";
String bool = Boolean.toString(((BooleanLiteral) term).getValue());
return tracer + bool.substring(0,1).toUpperCase() + bool.substring(1);
} else if (term instanceof NumericLiteral) {
tracer = tracer.dive("numeric literal instanceOf");
if (getCastExpressionForTypeReference(fieldType).equals("string")) {
Expand Down Expand Up @@ -1130,8 +1108,8 @@ private String toCeilVariableExpression(Field field, VariableLiteral variableLit
.stream().findFirst().orElseThrow(IllegalStateException::new);
// The Ceil function expects 64 bit floating point values.
TypeReference tr = new DefaultFloatTypeReference(SimpleTypeReference.SimpleBaseType.FLOAT, 64);
emitRequiredImport("math");
return tracer + "math.Ceil(" + toExpression(field, tr, va, parserArguments, serializerArguments, serialize, suppressPointerAccess) + ")";
emitRequiredImport("from math import ceil");
return tracer + "ceil(" + toExpression(field, tr, va, parserArguments, serializerArguments, serialize, suppressPointerAccess) + ")";
}

private String toArraySizeInBytesVariableExpression(Field field, TypeReference typeReference, VariableLiteral variableLiteral, List<Argument> parserArguments, List<Argument> serializerArguments, boolean suppressPointerAccess, Tracer tracer) {
Expand All @@ -1144,7 +1122,7 @@ private String toArraySizeInBytesVariableExpression(Field field, TypeReference t
.asVariableLiteral()
.orElseThrow(() -> new RuntimeException("ARRAY_SIZE_IN_BYTES needs a variable literal"));
// "io" and "m" are always available in every parser.
boolean isSerializerArg = "readBuffer".equals(va.getName()) || "writeBuffer".equals(va.getName()) || "m".equals(va.getName()) || "element".equals(va.getName());
boolean isSerializerArg = "read_buffer".equals(va.getName()) || "write_buffer".equals(va.getName()) || "self".equals(va.getName()) || "element".equals(va.getName());
if (!isSerializerArg && serializerArguments != null) {
for (Argument serializerArgument : serializerArguments) {
if (serializerArgument.getName().equals(va.getName())) {
Expand All @@ -1159,7 +1137,8 @@ private String toArraySizeInBytesVariableExpression(Field field, TypeReference t
} else {
sb.append(toVariableExpression(field, typeReference, va, parserArguments, serializerArguments, true, suppressPointerAccess));
}
return tracer + getCastExpressionForTypeReference(typeReference) + "(" + va.getName() + "ArraySizeInBytes(" + sb + "))";
emitRequiredImport("from sys import getsizeof");
return tracer + getCastExpressionForTypeReference(typeReference) + "(getsizeof(" + sb + "))";
}

private String toCountVariableExpression(Field field, TypeReference typeReference, VariableLiteral variableLiteral, List<Argument> parserArguments, List<Argument> serializerArguments, boolean serialize, boolean suppressPointerAccess, Tracer tracer) {
Expand Down Expand Up @@ -1764,26 +1743,6 @@ public String capitalize(String str) {
return extractedTrace + StringUtils.capitalize(cleanedString);
}

public String getEndiannessOptions(boolean read, boolean separatorPrefix) {
return getEndiannessOptions(read, separatorPrefix, Collections.emptyList());
}

public String getEndiannessOptions(boolean read, boolean separatorPrefix, List<Argument> parserArguments) {
Optional<Term> byteOrder = thisType.getAttribute("byteOrder");
if (byteOrder.isPresent()) {
emitRequiredImport("encoding/binary");
if(read) {
return (separatorPrefix ? ", " : "") + "utils.WithByteOrderForReadBufferByteBased(" +
toParseExpression(null, new DefaultByteOrderTypeReference(), byteOrder.orElseThrow(), parserArguments) +
")";
} else {
return (separatorPrefix ? ", " : "") + "utils.WithByteOrderForByteBasedBuffer(" +
toSerializationExpression(null, new DefaultByteOrderTypeReference(), byteOrder.orElseThrow(), parserArguments) +
")";
}
}
return "";
}

/**
* Converts a camel-case string to snake-case.
Expand Down
Expand Up @@ -220,7 +220,7 @@ class ${type.name}<#if type.isDiscriminatedParentTypeDefinition()></#if>(<#if ty

# Array Field (${arrayField.name})
<#if arrayField.type.elementTypeReference.isByteBased()>
write_buffer.write_byte_array(self.${helper.camelCaseToSnakeCase(namedField.name)}, 8, logical_name="${namedField.name}")
write_buffer.write_byte_array(self.${helper.camelCaseToSnakeCase(namedField.name)}, logical_name="${namedField.name}")
<#elseif arrayField.type.elementTypeReference.isSimpleTypeReference()>
write_buffer.write_simple_array(self.${helper.camelCaseToSnakeCase(namedField.name)}, ${helper.getDataWriterCall(arrayField.type.elementTypeReference, namedField.name)}, logical_name="${namedField.name}")
<#else>
Expand Down Expand Up @@ -369,12 +369,12 @@ class ${type.name}<#if type.isDiscriminatedParentTypeDefinition()></#if>(<#if ty
if self.${helper.camelCaseToSnakeCase(arrayField.name)} is not None:
<#if arrayElementTypeReference.isSimpleTypeReference()>
<#assign simpleTypeReference = arrayElementTypeReference.asSimpleTypeReference().orElseThrow()>
length_in_bits += ${simpleTypeReference.sizeInBits} * self.${helper.camelCaseToSnakeCase(arrayField.name)}.<#if arrayElementTypeReference.isByteBased()>length<#else>size()</#if>
length_in_bits += ${simpleTypeReference.sizeInBits} * len(self.${helper.camelCaseToSnakeCase(arrayField.name)})
<#elseif arrayField.isCountArrayField()>
i: int = 0
<#assign nonSimpleTypeReference = arrayElementTypeReference.asNonSimpleTypeReference().orElseThrow()>
for element in self.${helper.camelCaseToSnakeCase(arrayField.name)}:
last: bool = ++i >= self.${helper.camelCaseToSnakeCase(arrayField.name)}.size()
last: bool = ++i >= len(self.${helper.camelCaseToSnakeCase(arrayField.name)})
length_in_bits += element.get_length_in_bits()

<#else>
Expand Down
6 changes: 3 additions & 3 deletions sandbox/plc4py/plc4py/protocols/modbus/readwrite/DataItem.py
Expand Up @@ -99,15 +99,15 @@ def static_parse(
if EvaluationHelper.equals(data_type, ModbusDataType.get_byte()): # List
# Array field (value)
# Count array
if c_int32(numberOfValues) * c_int32(c_int32(8)) > Integer.MAX_VALUE:
if numberOfValues * c_int32(8) > Integer.MAX_VALUE:
raise ParseException(
"Array count of "
+ (c_int32(numberOfValues) * c_int32(c_int32(8)))
+ (numberOfValues * c_int32(8))
+ " exceeds the maximum allowed count of "
+ Integer.MAX_VALUE
)

item_count: int = int(c_int32(numberOfValues) * c_int32(c_int32(8)))
item_count: int = int(numberOfValues * c_int32(8))
value: List[PlcValue] = []
for cur_item in range(item_count):
value.append(PlcBOOL(c_bool(read_buffer.readBit(""))))
Expand Down
Expand Up @@ -88,9 +88,7 @@ def static_parse_builder(
cur_pos: int = 0

address: c_uint8 = read_simple_field(
"address",
read_unsigned_short(read_buffer, 8),
WithOption.WithByteOrder(get_bi_g__endian()),
"address", read_unsigned_short, WithOption.WithByteOrder(get_bi_g__endian())
)

pdu: ModbusPDU = read_simple_field(
Expand All @@ -103,7 +101,7 @@ def static_parse_builder(

crc: c_uint8 = read_checksum_field(
"crc",
read_unsigned_short(read_buffer, 8),
read_unsigned_short,
(c_uint8)(ascii_lrc_check(address, pdu)),
WithOption.WithByteOrder(get_bi_g__endian()),
)
Expand Down
Expand Up @@ -66,7 +66,7 @@ def static_parse_context(read_buffer: ReadBuffer):

modbus_tcp_default_port: c_uint16 = read_const_field(
"modbusTcpDefaultPort",
read_unsigned_int(read_buffer, 16),
read_unsigned_int,
ModbusConstants.MODBUSTCPDEFAULTPORT,
)

Expand Down
Expand Up @@ -47,7 +47,7 @@ def serialize(self, write_buffer: WriteBuffer):
write_buffer.write_unsigned_byte(object_length, logical_name="objectLength")

# Array Field (data)
write_buffer.write_byte_array(self.data, 8, logical_name="data")
write_buffer.write_byte_array(self.data, logical_name="data")

write_buffer.pop_context("ModbusDeviceInformationObject")

Expand All @@ -66,7 +66,7 @@ def get_length_in_bits(self) -> int:

# Array field
if self.data is not None:
length_in_bits += 8 * self.data.length
length_in_bits += 8 * len(self.data)

return length_in_bits

Expand All @@ -79,12 +79,10 @@ def static_parse_context(read_buffer: ReadBuffer):
start_pos: int = read_buffer.get_pos()
cur_pos: int = 0

object_id: c_uint8 = read_simple_field(
"objectId", read_unsigned_short(read_buffer, 8)
)
object_id: c_uint8 = read_simple_field("objectId", read_unsigned_short)

object_length: c_uint8 = read_implicit_field(
"objectLength", read_unsigned_short(read_buffer, 8)
"objectLength", read_unsigned_short
)

data: List[c_byte] = read_buffer.read_byte_array("data", int(objectLength))
Expand Down

0 comments on commit ad7fbdb

Please sign in to comment.