diff --git a/rust/benches/car_benchmark.rs b/rust/benches/car_benchmark.rs index c6425ad379..7ca9c2e5c8 100644 --- a/rust/benches/car_benchmark.rs +++ b/rust/benches/car_benchmark.rs @@ -28,7 +28,7 @@ fn encode(state: &mut State) -> SbeResult { let mut acceleration = AccelerationEncoder::default(); let mut extras = OptionalExtras::default(); - car = car.wrap(WriteBuf::new(buffer), message_header::ENCODED_LENGTH); + car = car.wrap(WriteBuf::new(buffer), message_header_codec::ENCODED_LENGTH); car = car.header(0).parent()?; car.code(Model::A); diff --git a/rust/benches/md_benchmark.rs b/rust/benches/md_benchmark.rs index 3b59709a06..b429622f92 100644 --- a/rust/benches/md_benchmark.rs +++ b/rust/benches/md_benchmark.rs @@ -22,7 +22,7 @@ fn encode_md(state: &mut State) -> SbeResult { let mut market_data = MarketDataIncrementalRefreshTradesEncoder::default(); let mut md_inc_grp = MdIncGrpEncoder::default(); - market_data = market_data.wrap(WriteBuf::new(buffer), message_header::ENCODED_LENGTH); + market_data = market_data.wrap(WriteBuf::new(buffer), message_header_codec::ENCODED_LENGTH); market_data = market_data.header(0).parent()?; market_data.transact_time(1234); diff --git a/rust/tests/baseline_test.rs b/rust/tests/baseline_test.rs index 8156b71e53..a19b39a96e 100644 --- a/rust/tests/baseline_test.rs +++ b/rust/tests/baseline_test.rs @@ -33,7 +33,7 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> SbeResult<()> { let buf = ReadBuf::new(buffer); let header = MessageHeaderDecoder::default().wrap(buf, 0); - assert_eq!(car::SBE_TEMPLATE_ID, header.template_id()); + assert_eq!(car_codec::SBE_TEMPLATE_ID, header.template_id()); car = car.header(header); // Car... @@ -167,7 +167,7 @@ fn encode_car_from_scratch() -> SbeResult<(usize, Vec)> { car = car.wrap( WriteBuf::new(buffer.as_mut_slice()), - message_header::ENCODED_LENGTH, + message_header_codec::ENCODED_LENGTH, ); car = car.header(0).parent()?; diff --git a/rust/tests/big_endian_test.rs b/rust/tests/big_endian_test.rs index ae677dce21..40b652c1be 100644 --- a/rust/tests/big_endian_test.rs +++ b/rust/tests/big_endian_test.rs @@ -30,7 +30,7 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> SbeResult<()> { let buf = ReadBuf::new(buffer); let header = MessageHeaderDecoder::default().wrap(buf, 0); - assert_eq!(car::SBE_TEMPLATE_ID, header.template_id()); + assert_eq!(car_codec::SBE_TEMPLATE_ID, header.template_id()); car = car.header(header); // Car... @@ -161,7 +161,7 @@ fn encode_car_from_scratch() -> SbeResult<(usize, Vec)> { car = car.wrap( WriteBuf::new(buffer.as_mut_slice()), - message_header::ENCODED_LENGTH, + message_header_codec::ENCODED_LENGTH, ); car = car.header(0).parent()?; diff --git a/rust/tests/extension_test.rs b/rust/tests/extension_test.rs index 56496d271c..5e993f7701 100644 --- a/rust/tests/extension_test.rs +++ b/rust/tests/extension_test.rs @@ -32,7 +32,7 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> SbeResult<()> { let buf = ReadBuf::new(buffer); let header = MessageHeaderDecoder::default().wrap(buf, 0); - assert_eq!(car::SBE_TEMPLATE_ID, header.template_id()); + assert_eq!(car_codec::SBE_TEMPLATE_ID, header.template_id()); car = car.header(header); // Car... @@ -168,7 +168,7 @@ fn encode_car_from_scratch() -> SbeResult<(usize, Vec)> { car = car.wrap( WriteBuf::new(buffer.as_mut_slice()), - message_header::ENCODED_LENGTH, + message_header_codec::ENCODED_LENGTH, ); car = car.header(0).parent()?; diff --git a/rust/tests/issue_435_test.rs b/rust/tests/issue_435_test.rs index e4e6c1e3c5..3c9b43b0da 100644 --- a/rust/tests/issue_435_test.rs +++ b/rust/tests/issue_435_test.rs @@ -3,7 +3,7 @@ use ::issue_435::*; fn create_encoder(buffer: &mut Vec) -> Issue435Encoder { let issue_435 = Issue435Encoder::default().wrap( WriteBuf::new(buffer.as_mut_slice()), - message_header::ENCODED_LENGTH, + message_header_codec::ENCODED_LENGTH, ); let mut header = issue_435.header(0); header.s(*SetRef::default().set_one(true)); @@ -12,9 +12,9 @@ fn create_encoder(buffer: &mut Vec) -> Issue435Encoder { #[test] fn issue_435_ref_test() -> SbeResult<()> { - assert_eq!(9, message_header::ENCODED_LENGTH); - assert_eq!(1, issue_435::SBE_BLOCK_LENGTH); - assert_eq!(0, issue_435::SBE_SCHEMA_VERSION); + assert_eq!(9, message_header_codec::ENCODED_LENGTH); + assert_eq!(1, issue_435_codec::SBE_BLOCK_LENGTH); + assert_eq!(0, issue_435_codec::SBE_SCHEMA_VERSION); // encode... let mut buffer = vec![0u8; 256]; @@ -24,10 +24,10 @@ fn issue_435_ref_test() -> SbeResult<()> { // decode... let buf = ReadBuf::new(buffer.as_slice()); let header = MessageHeaderDecoder::default().wrap(buf, 0); - assert_eq!(issue_435::SBE_BLOCK_LENGTH, header.block_length()); - assert_eq!(issue_435::SBE_SCHEMA_VERSION, header.version()); - assert_eq!(issue_435::SBE_TEMPLATE_ID, header.template_id()); - assert_eq!(issue_435::SBE_SCHEMA_ID, header.schema_id()); + assert_eq!(issue_435_codec::SBE_BLOCK_LENGTH, header.block_length()); + assert_eq!(issue_435_codec::SBE_SCHEMA_VERSION, header.version()); + assert_eq!(issue_435_codec::SBE_TEMPLATE_ID, header.template_id()); + assert_eq!(issue_435_codec::SBE_SCHEMA_ID, header.schema_id()); assert_eq!(*SetRef::default().set_one(true), header.s()); let decoder = Issue435Decoder::default().header(header); diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/SbeTool.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/SbeTool.java index 565aea2145..ba8fb99784 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/SbeTool.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/SbeTool.java @@ -164,7 +164,7 @@ public class SbeTool * Specifies token that should be appended to keywords to avoid compilation errors. *

* If none is supplied then use of keywords results in an error during schema parsing. The - * underscore character is a good example fo a token to use. + * underscore character is a good example of a token to use. */ public static final String KEYWORD_APPEND_TOKEN = "sbe.keyword.append.token"; diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/LibRsDef.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/LibRsDef.java index aab9f6a592..6799c64a27 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/LibRsDef.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/LibRsDef.java @@ -15,17 +15,12 @@ */ package uk.co.real_logic.sbe.generation.rust; -import org.agrona.generation.OutputManager; -import uk.co.real_logic.sbe.generation.rust.RustGenerator.CodecType; - import java.io.IOException; import java.io.Writer; import java.nio.ByteOrder; +import java.nio.file.Files; import java.util.ArrayList; -import java.util.HashSet; -import java.util.LinkedHashMap; import java.util.LinkedHashSet; -import java.util.Map; import static java.nio.ByteOrder.LITTLE_ENDIAN; import static uk.co.real_logic.sbe.generation.rust.RustGenerator.*; @@ -36,11 +31,7 @@ */ class LibRsDef { - private final LinkedHashMap> modules = new LinkedHashMap<>(); - private final ArrayList enumDefs = new ArrayList<>(); - private final ArrayList bitSetDefs = new ArrayList<>(); - - private final OutputManager outputManager; + private final RustOutputManager outputManager; private final ByteOrder byteOrder; /** @@ -50,28 +41,13 @@ class LibRsDef * @param byteOrder for the Encoding. */ LibRsDef( - final OutputManager outputManager, + final RustOutputManager outputManager, final ByteOrder byteOrder) { this.outputManager = outputManager; this.byteOrder = byteOrder; } - void addMod(final String modName, final CodecType codecType) - { - modules.computeIfAbsent(modName, __ -> new HashSet<>()).add(codecType); - } - - void addEnum(final String enumDef) - { - enumDefs.add(enumDef); - } - - void addBitSet(final String bitSetDef) - { - bitSetDefs.add(bitSetDef); - } - void generate() throws IOException { try (Writer libRs = outputManager.createOutput("lib")) @@ -79,31 +55,28 @@ void generate() throws IOException indent(libRs, 0, "#![forbid(unsafe_code)]\n"); indent(libRs, 0, "#![allow(clippy::upper_case_acronyms)]\n"); indent(libRs, 0, "#![allow(non_camel_case_types)]\n"); - indent(libRs, 0, "use core::{convert::TryInto};\n\n"); + indent(libRs, 0, "use ::core::{convert::TryInto};\n\n"); + + final ArrayList modules = new ArrayList<>(); + Files.walk(outputManager.getSrcDirPath()) + .filter(Files::isRegularFile) + .map(path -> path.getFileName().toString()) + .filter(fileName -> fileName.endsWith(".rs")) + .filter(fileName -> !fileName.equals("lib.rs")) + .map(fileName -> fileName.substring(0, fileName.length() - 3)) + .forEach(modules::add); // add modules - for (final String mod : modules.keySet()) + for (final String mod : modules) { indent(libRs, 0, "pub mod %s;\n", toLowerSnakeCase(mod)); } indent(libRs, 0, "\n"); // add re-export of modules - for (final Map.Entry> entry : modules.entrySet()) + for (final String module : modules) { - final String mod = entry.getKey(); - final HashSet codecTypes = entry.getValue(); - - if (codecTypes.size() == 1) - { - indent(libRs, 0, "pub use %s::%s::*;\n", - toLowerSnakeCase(mod), - toLowerSnakeCase(codecTypes.toArray()[0].toString())); - } - else - { - indent(libRs, 0, "pub use %s::{decoder::*, encoder::*};\n", toLowerSnakeCase(mod)); - } + indent(libRs, 0, "pub use %s::*;\n", toLowerSnakeCase(module)); } indent(libRs, 0, "\n"); @@ -115,20 +88,6 @@ void generate() throws IOException generateReadBuf(libRs, byteOrder); generateWriteBuf(libRs, byteOrder); - - // append generated enums - for (final String code : enumDefs) - { - libRs.append(code); - indent(libRs, 0, "\n"); - } - - // append generated bitSets - for (final String code : bitSetDefs) - { - libRs.append(code); - indent(libRs, 0, "\n"); - } } } diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/MessageCoderDef.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/MessageCoderDef.java index 1a7f7a2315..e885735663 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/MessageCoderDef.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/MessageCoderDef.java @@ -147,7 +147,7 @@ void appendMessageHeaderDecoderFn(final Appendable out) throws IOException indent(out, 3, "let acting_version = header.version();\n\n"); indent(out, 3, "self.wrap(\n"); indent(out, 4, "header.parent().unwrap(),\n"); - indent(out, 4, "message_header::ENCODED_LENGTH,\n"); + indent(out, 4, "message_header_codec::ENCODED_LENGTH,\n"); indent(out, 4, "acting_block_length,\n"); indent(out, 4, "acting_version,\n"); indent(out, 3, ")\n"); diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java index a1f253b8ce..c58a38dc27 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java @@ -126,17 +126,14 @@ public void generate() throws IOException // lib.rs final LibRsDef libRsDef = new LibRsDef(outputManager, ir.byteOrder()); - generateEnums(ir, libRsDef); - generateBitSets(ir, libRsDef); - generateComposites(ir, libRsDef, outputManager); + generateEnums(ir, outputManager); + generateBitSets(ir, outputManager); + generateComposites(ir, outputManager); for (final List tokens : ir.messages()) { final Token msgToken = tokens.get(0); - - libRsDef.addMod(msgToken.name(), Encoder); - libRsDef.addMod(msgToken.name(), Decoder); - + final String codecModName = codecModName(msgToken.name()); final List messageBody = getMessageBody(tokens); int i = 0; @@ -149,9 +146,11 @@ public void generate() throws IOException final List varData = new ArrayList<>(); collectVarData(messageBody, i, varData); - try (Writer out = outputManager.createOutput(msgToken.name())) + try (Writer out = outputManager.createOutput(codecModName)) { indent(out, 0, "use crate::*;\n\n"); + indent(out, 0, "pub use encoder::*;\n"); + indent(out, 0, "pub use decoder::*;\n\n"); final String blockLengthType = blockLengthType(); final String templateIdType = rustTypeName(ir.headerStructure().templateIdType()); final String schemaIdType = rustTypeName(ir.headerStructure().schemaIdType()); @@ -981,32 +980,36 @@ static void generateDecoderVarData( private static void generateBitSets( final Ir ir, - final LibRsDef libRsDef) throws IOException + final RustOutputManager outputManager) throws IOException { for (final List tokens : ir.types()) { if (!tokens.isEmpty() && tokens.get(0).signal() == BEGIN_SET) { - final StringBuilder sb = new StringBuilder(); - generateSingleBitSet(tokens, sb); - libRsDef.addBitSet(sb.toString()); + final Token beginToken = tokens.get(0); + final String bitSetType = formatStructName(beginToken.applicableTypeName()); + + try (Writer out = outputManager.createOutput(bitSetType)) + { + generateSingleBitSet(bitSetType, tokens, out); + } } } } private static void generateSingleBitSet( + final String bitSetType, final List tokens, final Appendable writer) throws IOException { final Token beginToken = tokens.get(0); - final String setType = formatStructName(beginToken.applicableTypeName()); final String rustPrimitiveType = rustTypeName(beginToken.encoding().primitiveType()); indent(writer, 0, "#[derive(Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]\n"); - indent(writer, 0, "pub struct %s(pub %s);\n", setType, rustPrimitiveType); - indent(writer, 0, "impl %s {\n", setType); + indent(writer, 0, "pub struct %s(pub %s);\n", bitSetType, rustPrimitiveType); + indent(writer, 0, "impl %s {\n", bitSetType); indent(writer, 1, "#[inline]\n"); indent(writer, 1, "pub fn new(value: %s) -> Self {\n", rustPrimitiveType); - indent(writer, 2, "%s(value)\n", setType); + indent(writer, 2, "%s(value)\n", bitSetType); indent(writer, 1, "}\n\n"); indent(writer, 1, "#[inline]\n"); @@ -1044,10 +1047,10 @@ private static void generateSingleBitSet( } indent(writer, 0, "}\n"); - indent(writer, 0, "impl core::fmt::Debug for %s {\n", setType); + indent(writer, 0, "impl core::fmt::Debug for %s {\n", bitSetType); indent(writer, 1, "#[inline]\n"); indent(writer, 1, "fn fmt(&self, fmt: &mut core::fmt::Formatter) -> core::fmt::Result {\n"); - indent(writer, 2, "write!(fmt, \"%s[", setType); + indent(writer, 2, "write!(fmt, \"%s[", bitSetType); final StringBuilder builder = new StringBuilder(); final StringBuilder arguments = new StringBuilder(); @@ -1128,7 +1131,7 @@ static void appendImplDecoderTrait( private static void generateEnums( final Ir ir, - final LibRsDef libRsDef) throws IOException + final RustOutputManager outputManager) throws IOException { final Set enumTypeNames = new HashSet<>(); for (final List tokens : ir.types()) @@ -1145,16 +1148,15 @@ private static void generateEnums( } final String typeName = beginToken.applicableTypeName(); - if (enumTypeNames.contains(typeName)) + if (!enumTypeNames.add(typeName)) { continue; } - final StringBuilder sb = new StringBuilder(); - generateEnum(tokens, sb); - enumTypeNames.add(typeName); - - libRsDef.addEnum(sb.toString()); + try (Writer out = outputManager.createOutput(typeName)) + { + generateEnum(tokens, out); + } } } @@ -1219,39 +1221,41 @@ private static void generateEnum( private static void generateComposites( final Ir ir, - final LibRsDef libRsDef, - final OutputManager outputManager) throws IOException + final RustOutputManager outputManager) throws IOException { for (final List tokens : ir.types()) { if (!tokens.isEmpty() && tokens.get(0).signal() == Signal.BEGIN_COMPOSITE) { - generateComposite(tokens, libRsDef, outputManager); + generateComposite(tokens, outputManager); } } } private static void generateComposite( final List tokens, - final LibRsDef libRsDef, - final OutputManager outputManager) throws IOException + final RustOutputManager outputManager) throws IOException { final Token token = tokens.get(0); final String compositeName = token.applicableTypeName(); + final String compositeModName = codecModName(compositeName); - try (Writer out = outputManager.createOutput(compositeName)) + try (Writer out = outputManager.createOutput(compositeModName)) { indent(out, 0, "use crate::*;\n\n"); + indent(out, 0, "pub use encoder::*;\n"); + indent(out, 0, "pub use decoder::*;\n\n"); + final int encodedLength = tokens.get(0).encodedLength(); if (encodedLength > 0) { indent(out, 0, "pub const ENCODED_LENGTH: usize = %d;\n\n", encodedLength); } - generateCompositeEncoder(tokens, libRsDef, compositeName, encoderName(compositeName), out); + generateCompositeEncoder(tokens, encoderName(compositeName), out); indent(out, 0, "\n"); - generateCompositeDecoder(tokens, libRsDef, compositeName, decoderName(compositeName), out); + generateCompositeDecoder(tokens, decoderName(compositeName), out); } } @@ -1331,13 +1335,9 @@ static void appendImplDecoderForComposite( private static void generateCompositeEncoder( final List tokens, - final LibRsDef libRsDef, - final String compositeName, final String encoderName, final Writer out) throws IOException { - libRsDef.addMod(compositeName, Encoder); - indent(out, 0, "pub mod encoder {\n"); indent(out, 1, "use super::*;\n\n"); @@ -1395,16 +1395,12 @@ private static void generateCompositeEncoder( private static void generateCompositeDecoder( final List tokens, - final LibRsDef libRsDef, - final String compositeName, final String decoderName, final Writer out) throws IOException { indent(out, 0, "pub mod decoder {\n"); indent(out, 1, "use super::*;\n\n"); - libRsDef.addMod(compositeName, Decoder); - indent(out, 1, "#[derive(Debug, Default)]\n"); indent(out, 1, "pub struct %s

{\n", decoderName); indent(out, 2, "parent: Option

,\n"); diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustOutputManager.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustOutputManager.java index e1cd8d9941..5fa9c4bf93 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustOutputManager.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustOutputManager.java @@ -21,10 +21,11 @@ import java.io.File; import java.io.IOException; import java.io.Writer; -import java.nio.charset.StandardCharsets; import java.nio.file.Files; +import java.nio.file.Path; import static java.io.File.separatorChar; +import static java.nio.charset.StandardCharsets.UTF_8; import static uk.co.real_logic.sbe.generation.rust.RustUtil.toLowerSnakeCase; /** @@ -36,8 +37,19 @@ public class RustOutputManager implements OutputManager private final File rootDir; private final File srcDir; + static File createDir(final String dirName) + { + final File dir = new File(dirName); + if (!dir.exists() && !dir.mkdirs()) + { + throw new IllegalStateException("Unable to create directory: " + dirName); + } + return dir; + } + /** * Create a new {@link OutputManager} for generating rust source files into a given module. + * * @param baseDirName for the generated source code. * @param packageName for the generated source code relative to the baseDirName. */ @@ -52,12 +64,7 @@ public RustOutputManager(final String baseDirName, final String packageName) rootDir = new File(libDirName); final String srcDirName = libDirName + separatorChar + "src"; - - srcDir = new File(srcDirName); - if (!srcDir.exists() && !srcDir.mkdirs()) - { - throw new IllegalStateException("Unable to create directory: " + srcDirName); - } + srcDir = createDir(srcDirName); } /** @@ -70,15 +77,17 @@ public RustOutputManager(final String baseDirName, final String packageName) * @return a {@link java.io.Writer} to which the source code should be written. * @throws IOException if an issue occurs when creating the file. */ - @Override public Writer createOutput(final String name) throws IOException + @Override + public Writer createOutput(final String name) throws IOException { final String fileName = toLowerSnakeCase(name) + ".rs"; final File targetFile = new File(srcDir, fileName); - return Files.newBufferedWriter(targetFile.toPath(), StandardCharsets.UTF_8); + return Files.newBufferedWriter(targetFile.toPath(), UTF_8); } /** - * + * Creates a new Cargo.toml file + *

* @return a {@link java.io.Writer} to which the crate definition should be written. * @throws IOException if an issue occurs when creating the file. */ @@ -86,7 +95,12 @@ public Writer createCargoToml() throws IOException { final String fileName = "Cargo.toml"; final File targetFile = new File(rootDir, fileName); - return Files.newBufferedWriter(targetFile.toPath(), StandardCharsets.UTF_8); + return Files.newBufferedWriter(targetFile.toPath(), UTF_8); + } + + Path getSrcDirPath() + { + return srcDir.toPath(); } } diff --git a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustUtil.java b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustUtil.java index 9ee5bd204c..19aff8c97d 100644 --- a/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustUtil.java +++ b/sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustUtil.java @@ -122,6 +122,11 @@ static String formatStructName(final String structName) return Generators.toUpperFirstChar(structName); } + static String codecModName(final String prefix) + { + return toLowerSnakeCase(prefix + "Codec"); + } + static String codecName(final String structName, final CodecType codecType) { return formatStructName(structName + codecType.name());