Skip to content

Commit

Permalink
Packet direction restrictions & array reading safety checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Alemiz112 committed Apr 23, 2024
2 parents 46b4ad3 + 571a7aa commit 82a171a
Show file tree
Hide file tree
Showing 73 changed files with 652 additions and 459 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.cloudburstmc.nbt.NBTOutputStream;
import org.cloudburstmc.nbt.NbtType;
import org.cloudburstmc.nbt.NbtUtils;
import org.cloudburstmc.protocol.bedrock.data.EncodingSettings;
import org.cloudburstmc.protocol.bedrock.data.ExperimentData;
import org.cloudburstmc.protocol.bedrock.data.PlayerAbilityHolder;
import org.cloudburstmc.protocol.bedrock.data.definitions.BlockDefinition;
Expand All @@ -34,7 +35,6 @@
import org.cloudburstmc.protocol.bedrock.data.skin.SerializedSkin;
import org.cloudburstmc.protocol.bedrock.data.structure.StructureSettings;
import org.cloudburstmc.protocol.bedrock.packet.InventoryTransactionPacket;
import org.cloudburstmc.protocol.common.Definition;
import org.cloudburstmc.protocol.common.DefinitionRegistry;
import org.cloudburstmc.protocol.common.NamedDefinition;
import org.cloudburstmc.protocol.common.util.TriConsumer;
Expand Down Expand Up @@ -66,14 +66,24 @@ public abstract class BaseBedrockCodecHelper implements BedrockCodecHelper {
@Setter
protected DefinitionRegistry<BlockDefinition> blockDefinitions;

@Getter
@Setter
protected EncodingSettings encodingSettings = EncodingSettings.DEFAULT;

protected static boolean isAir(ItemDefinition definition) {
return definition == null || "minecraft:air".equals(definition.getIdentifier());
}

@Override
public byte[] readByteArray(ByteBuf buffer) {
return this.readByteArray(buffer, this.encodingSettings.maxByteArraySize());
}

public byte[] readByteArray(ByteBuf buffer, int maxLength) {
int length = VarInts.readUnsignedInt(buffer);
checkArgument(buffer.isReadable(length),
"Tried to read %s bytes but only has %s readable", length, buffer.readableBytes());
checkArgument(maxLength <= 0 || length <= maxLength, "Tried to read %s bytes but maximum is %s", length, maxLength);
byte[] bytes = new byte[length];
buffer.readBytes(bytes);
return bytes;
Expand All @@ -100,6 +110,8 @@ public void writeByteBuf(ByteBuf buffer, ByteBuf toWrite) {

public String readString(ByteBuf buffer) {
int length = VarInts.readUnsignedInt(buffer);
checkArgument(this.encodingSettings.maxStringLength() <= 0 || length <= this.encodingSettings.maxStringLength(),
"Tried to read %s bytes but maximum is %s", length, this.encodingSettings.maxStringLength());
return (String) buffer.readCharSequence(length, StandardCharsets.UTF_8);
}

Expand Down Expand Up @@ -191,9 +203,20 @@ public void writeBlockPosition(ByteBuf buffer, Vector3i blockPosition) {
*/

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader,
BiFunction<ByteBuf, BedrockCodecHelper, T> function) {
public <T> void readArray(ByteBuf buffer, Collection<T> array, BiFunction<ByteBuf, BedrockCodecHelper, T> function) {
this.readArray(buffer, array, function, this.encodingSettings.maxListSize());
}

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, BiFunction<ByteBuf, BedrockCodecHelper, T> function) {
this.readArray(buffer, array, lengthReader, function, this.encodingSettings.maxListSize());
}

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, BiFunction<ByteBuf, BedrockCodecHelper, T> function, int maxLength) {
long length = lengthReader.applyAsLong(buffer);
checkArgument(maxLength <= 0 || length <= maxLength, "Tried to read %s bytes but maximum is %s", length, maxLength);

for (int i = 0; i < length; i++) {
array.add(function.apply(buffer, this));
}
Expand All @@ -209,8 +232,13 @@ public <T> void writeArray(ByteBuf buffer, Collection<T> array, ObjIntConsumer<B

@Override
public <T> T[] readArray(ByteBuf buffer, T[] array, BiFunction<ByteBuf, BedrockCodecHelper, T> function) {
return this.readArray(buffer, array, function, this.encodingSettings.maxListSize());
}

@Override
public <T> T[] readArray(ByteBuf buffer, T[] array, BiFunction<ByteBuf, BedrockCodecHelper, T> function, int maxLength) {
ObjectArrayList<T> list = new ObjectArrayList<>();
readArray(buffer, list, function);
readArray(buffer, list, function, maxLength);
return list.toArray(array);
}

Expand All @@ -227,24 +255,51 @@ public <T> void writeArray(ByteBuf buffer, T[] array, TriConsumer<ByteBuf, Bedro

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, Function<ByteBuf, T> function) {
int length = VarInts.readUnsignedInt(buffer);
this.readArray(buffer, array, function, this.encodingSettings.maxListSize());
}

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, Function<ByteBuf, T> function, int maxLength) {
this.readArray(buffer, array, VarInts::readUnsignedInt, function, maxLength);
}

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, Function<ByteBuf, T> function) {
this.readArray(buffer, array, lengthReader, function, this.encodingSettings.maxListSize());
}

@Override
public <T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, Function<ByteBuf, T> function, int maxLength) {
long length = lengthReader.applyAsLong(buffer);
checkArgument(maxLength <= 0 || length <= maxLength, "Tried to read %s bytes but maximum is %s", length, maxLength);

for (int i = 0; i < length; i++) {
array.add(function.apply(buffer));
}
}

@Override
public <T> void writeArray(ByteBuf buffer, Collection<T> array, BiConsumer<ByteBuf, T> biConsumer) {
VarInts.writeUnsignedInt(buffer, array.size());
this.writeArray(buffer, array, VarInts::writeUnsignedInt, biConsumer);
}

@Override
public <T> void writeArray(ByteBuf buffer, Collection<T> array, ObjIntConsumer<ByteBuf> lengthWriter, BiConsumer<ByteBuf, T> consumer) {
lengthWriter.accept(buffer, array.size());
for (T val : array) {
biConsumer.accept(buffer, val);
consumer.accept(buffer, val);
}
}

@Override
public <T> T[] readArray(ByteBuf buffer, T[] array, Function<ByteBuf, T> function) {
return this.readArray(buffer, array, function, this.encodingSettings.maxListSize());
}

@Override
public <T> T[] readArray(ByteBuf buffer, T[] array, Function<ByteBuf, T> function, int maxLength) {
ObjectArrayList<T> list = new ObjectArrayList<>();
readArray(buffer, list, function);
readArray(buffer, list, function, maxLength);
return list.toArray(array);
}

Expand All @@ -256,10 +311,15 @@ public <T> void writeArray(ByteBuf buffer, T[] array, BiConsumer<ByteBuf, T> biC
}
}

@SuppressWarnings("unchecked")
@Override
public <T> T readTag(ByteBuf buffer, Class<T> expected) {
try (NBTInputStream reader = NbtUtils.createNetworkReader(new ByteBufInputStream(buffer))) {
return this.readTag(buffer, expected, this.encodingSettings.maxNetworkNBTSize());
}

@SuppressWarnings("unchecked")
@Override
public <T> T readTag(ByteBuf buffer, Class<T> expected, long maxReadSize) {
try (NBTInputStream reader = NbtUtils.createNetworkReader(new ByteBufInputStream(buffer), maxReadSize)) {
Object tag = reader.readTag();
checkArgument(expected.isInstance(tag), "Expected tag of %s type but received %s",
expected, tag.getClass());
Expand All @@ -278,10 +338,15 @@ public void writeTag(ByteBuf buffer, Object tag) {
}
}

@SuppressWarnings("unchecked")
@Override
public <T> T readTagLE(ByteBuf buffer, Class<T> expected) {
try (NBTInputStream reader = NbtUtils.createReaderLE(new ByteBufInputStream(buffer))) {
return this.readTagLE(buffer, expected, this.encodingSettings.maxNetworkNBTSize());
}

@SuppressWarnings("unchecked")
@Override
public <T> T readTagLE(ByteBuf buffer, Class<T> expected, long maxReadSize) {
try (NBTInputStream reader = NbtUtils.createReaderLE(new ByteBufInputStream(buffer), maxReadSize)) {
Object tag = reader.readTag();
checkArgument(expected.isInstance(tag), "Expected tag of %s type but received %s",
expected, tag.getClass());
Expand All @@ -301,7 +366,12 @@ public void writeTagLE(ByteBuf buffer, Object tag) {

@Override
public <T> T readTagValue(ByteBuf buffer, NbtType<T> type) {
try (NBTInputStream reader = NbtUtils.createNetworkReader(new ByteBufInputStream(buffer))) {
return this.readTagValue(buffer, type, this.encodingSettings.maxNetworkNBTSize());
}

@Override
public <T> T readTagValue(ByteBuf buffer, NbtType<T> type, long maxReadSize) {
try (NBTInputStream reader = NbtUtils.createNetworkReader(new ByteBufInputStream(buffer), maxReadSize)) {
return reader.readValue(type);
} catch (IOException e) {
throw new RuntimeException(e);
Expand Down Expand Up @@ -348,7 +418,7 @@ public boolean readInventoryActions(ByteBuf buffer, List<InventoryActionData> ac
ItemData toItem = helper.readItem(buf);

return new InventoryActionData(source, slot, fromItem, toItem);
});
}, 64); // 64 should be enough
return false;
}

Expand Down Expand Up @@ -448,6 +518,10 @@ protected void writeAnimationData(ByteBuf buffer, AnimationData animation) {
}

protected ImageData readImage(ByteBuf buffer) {
return this.readImage(buffer, ImageData.SKIN_PERSONA_SIZE);
}

protected ImageData readImage(ByteBuf buffer, int maxSize) {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.RequiredArgsConstructor;
import org.checkerframework.checker.index.qual.NonNegative;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.cloudburstmc.protocol.bedrock.data.PacketRecipient;
import org.cloudburstmc.protocol.bedrock.packet.BedrockPacket;
import org.cloudburstmc.protocol.bedrock.packet.UnknownPacket;

Expand Down Expand Up @@ -38,9 +39,19 @@ public static Builder builder() {
return new Builder();
}

@SuppressWarnings({"unchecked", "rawtypes"})
public BedrockPacket tryDecode(BedrockCodecHelper helper, ByteBuf buf, int id) throws PacketSerializeException {
return tryDecode(helper, buf, id, null);
}

@SuppressWarnings({"unchecked", "rawtypes"})
public BedrockPacket tryDecode(BedrockCodecHelper helper, ByteBuf buf, int id, PacketRecipient recipient) throws PacketSerializeException {
BedrockPacketDefinition<? extends BedrockPacket> definition = getPacketDefinition(id);

if (definition != null && recipient != null && definition.getRecipient() != PacketRecipient.BOTH &&
definition.getRecipient() != recipient) {
throw new IllegalArgumentException("Packet " + definition.getFactory().get().getClass().getSimpleName() + " was sent to " + recipient + " instead of " + definition.getRecipient());
}

BedrockPacket packet;
BedrockPacketSerializer<BedrockPacket> serializer;
if (definition == null) {
Expand Down Expand Up @@ -119,13 +130,13 @@ public static class Builder {
private String minecraftVersion = null;
private Supplier<BedrockCodecHelper> helperFactory;

public <T extends BedrockPacket> Builder registerPacket(Supplier<T> factory, BedrockPacketSerializer<T> serializer, @NonNegative int id) {
public <T extends BedrockPacket> Builder registerPacket(Supplier<T> factory, BedrockPacketSerializer<T> serializer, @NonNegative int id, PacketRecipient recipient) {
Class<? extends BedrockPacket> packetClass = factory.get().getClass();

checkArgument(id >= 0, "id cannot be negative");
checkArgument(!packets.containsKey(packetClass), "Packet class already registered");

BedrockPacketDefinition<T> info = new BedrockPacketDefinition<>(id, factory, serializer);
BedrockPacketDefinition<T> info = new BedrockPacketDefinition<>(id, factory, serializer, recipient);

packets.put(packetClass, info);

Expand All @@ -135,7 +146,7 @@ public <T extends BedrockPacket> Builder registerPacket(Supplier<T> factory, Bed
public <T extends BedrockPacket> Builder updateSerializer(Class<T> packetClass, BedrockPacketSerializer<T> serializer) {
BedrockPacketDefinition<T> info = (BedrockPacketDefinition<T>) packets.get(packetClass);
checkArgument(info != null, "Packet does not exist");
BedrockPacketDefinition<T> updatedInfo = new BedrockPacketDefinition<>(info.getId(), info.getFactory(), serializer);
BedrockPacketDefinition<T> updatedInfo = new BedrockPacketDefinition<>(info.getId(), info.getFactory(), serializer, info.getRecipient());

packets.replace(packetClass, info, updatedInfo);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.cloudburstmc.math.vector.Vector3f;
import org.cloudburstmc.math.vector.Vector3i;
import org.cloudburstmc.nbt.NbtType;
import org.cloudburstmc.protocol.bedrock.data.EncodingSettings;
import org.cloudburstmc.protocol.bedrock.data.ExperimentData;
import org.cloudburstmc.protocol.bedrock.data.GameRuleData;
import org.cloudburstmc.protocol.bedrock.data.PlayerAbilityHolder;
Expand All @@ -23,7 +24,6 @@
import org.cloudburstmc.protocol.bedrock.data.skin.SerializedSkin;
import org.cloudburstmc.protocol.bedrock.data.structure.StructureSettings;
import org.cloudburstmc.protocol.bedrock.packet.InventoryTransactionPacket;
import org.cloudburstmc.protocol.common.Definition;
import org.cloudburstmc.protocol.common.DefinitionRegistry;
import org.cloudburstmc.protocol.common.NamedDefinition;
import org.cloudburstmc.protocol.common.util.TriConsumer;
Expand All @@ -48,33 +48,52 @@ public interface BedrockCodecHelper {

DefinitionRegistry<NamedDefinition> getCameraPresetDefinitions();

EncodingSettings getEncodingSettings();

void setEncodingSettings(EncodingSettings settings);

// Array serialization (with helper)

default <T> void readArray(ByteBuf buffer, Collection<T> array, BiFunction<ByteBuf, BedrockCodecHelper, T> function) {
readArray(buffer, array, VarInts::readUnsignedInt, function);
<T> void readArray(ByteBuf buffer, Collection<T> array, BiFunction<ByteBuf, BedrockCodecHelper, T> function);

default <T> void readArray(ByteBuf buffer, Collection<T> array, BiFunction<ByteBuf, BedrockCodecHelper, T> function, int maxLength) {
this.readArray(buffer, array, VarInts::readUnsignedInt, function, maxLength);
}

<T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader,
BiFunction<ByteBuf, BedrockCodecHelper, T> function);
<T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, BiFunction<ByteBuf, BedrockCodecHelper, T> function);

<T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, BiFunction<ByteBuf, BedrockCodecHelper, T> function, int maxLength);

default <T> void writeArray(ByteBuf buffer, Collection<T> array, TriConsumer<ByteBuf, BedrockCodecHelper, T> consumer) {
writeArray(buffer, array, VarInts::writeUnsignedInt, consumer);
this.writeArray(buffer, array, VarInts::writeUnsignedInt, consumer);
}

<T> void writeArray(ByteBuf buffer, Collection<T> array, ObjIntConsumer<ByteBuf> lengthWriter, TriConsumer<ByteBuf, BedrockCodecHelper, T> consumer);

<T> T[] readArray(ByteBuf buffer, T[] array, BiFunction<ByteBuf, BedrockCodecHelper, T> function);

<T> T[] readArray(ByteBuf buffer, T[] array, BiFunction<ByteBuf, BedrockCodecHelper, T> function, int maxLength);

<T> void writeArray(ByteBuf buffer, T[] array, TriConsumer<ByteBuf, BedrockCodecHelper, T> consumer);

// Array serialization (without helper)

<T> void readArray(ByteBuf buffer, Collection<T> array, Function<ByteBuf, T> function);

<T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, Function<ByteBuf, T> function);

<T> void readArray(ByteBuf buffer, Collection<T> array, ToLongFunction<ByteBuf> lengthReader, Function<ByteBuf, T> function, int maxLength);

<T> void readArray(ByteBuf buffer, Collection<T> array, Function<ByteBuf, T> function, int maxLength);

<T> void writeArray(ByteBuf buffer, Collection<T> array, BiConsumer<ByteBuf, T> consumer);

<T> void writeArray(ByteBuf buffer, Collection<T> array, ObjIntConsumer<ByteBuf> lengthWriter, BiConsumer<ByteBuf, T> consumer);

<T> T[] readArray(ByteBuf buffer, T[] array, Function<ByteBuf, T> function);

<T> T[] readArray(ByteBuf buffer, T[] array, Function<ByteBuf, T> function, int maxLength);

<T> void writeArray(ByteBuf buffer, T[] array, BiConsumer<ByteBuf, T> consumer);

// Encoding methods
Expand Down Expand Up @@ -121,6 +140,8 @@ default <T> void writeArray(ByteBuf buffer, Collection<T> array, TriConsumer<Byt

byte[] readByteArray(ByteBuf buffer);

byte[] readByteArray(ByteBuf buffer, int maxLength);

void writeByteArray(ByteBuf buffer, byte[] bytes);

ByteBuf readByteBuf(ByteBuf buffer);
Expand Down Expand Up @@ -155,24 +176,22 @@ default <T> void writeArray(ByteBuf buffer, Collection<T> array, TriConsumer<Byt

void writeBlockPosition(ByteBuf buffer, Vector3i blockPosition);

default Object readTag(ByteBuf buffer) {
return readTag(buffer, Object.class);
}

<T> T readTag(ByteBuf buffer, Class<T> expected);

void writeTag(ByteBuf buffer, Object tag);
<T> T readTag(ByteBuf buffer, Class<T> expected, long maxReadSize);

default Object readTagLE(ByteBuf buffer) {
return readTag(buffer, Object.class);
}
void writeTag(ByteBuf buffer, Object tag);

<T> T readTagLE(ByteBuf buffer, Class<T> expected);

<T> T readTagLE(ByteBuf buffer, Class<T> expected, long maxReadSize);

void writeTagLE(ByteBuf buffer, Object tag);

<T> T readTagValue(ByteBuf buffer, NbtType<T> type);

<T> T readTagValue(ByteBuf buffer, NbtType<T> type, long maxReadSize);

void writeTagValue(ByteBuf buffer, Object tag);

void readItemUse(ByteBuf buffer, InventoryTransactionPacket packet);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.cloudburstmc.protocol.bedrock.codec;

import lombok.Value;
import org.cloudburstmc.protocol.bedrock.data.PacketRecipient;
import org.cloudburstmc.protocol.bedrock.packet.BedrockPacket;

import java.util.function.Supplier;
Expand All @@ -10,4 +11,5 @@ public class BedrockPacketDefinition<T extends BedrockPacket> {
int id;
Supplier<T> factory;
BedrockPacketSerializer<T> serializer;
PacketRecipient recipient;
}
Loading

0 comments on commit 82a171a

Please sign in to comment.