Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Common/Cpp/Concurrency/Mutex.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,11 @@ namespace PokemonAutomation{

#endif


namespace PokemonAutomation{
template <typename LockType>
using LockGuard = std::lock_guard<LockType>;
}


#endif
24 changes: 18 additions & 6 deletions Common/Cpp/StreamConnections/PollingStreamConnections.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class UnreliableStreamConnectionPolling : public UnreliableStreamSender{

class ReliableStreamConnectionPolling{
public:
virtual void lock() noexcept = 0;
virtual void unlock() noexcept = 0;

//
// These 3 functions are not thread/IRQ safe with anything else.
// If needed call these under the lock above.
//

// Enqueue the specified data into the uncommitted stream.
// On success, returns true.
// On fail, return false and aborts all uncommitted sends.
Expand All @@ -38,12 +46,16 @@ class ReliableStreamConnectionPolling{
// Commits all uncommitted sends to stream.
virtual void commit_uncommitted_reliable_sends() noexcept = 0;


public:
virtual bool reliable_send_all_or_nothing(const void* data, size_t bytes) noexcept{
if (!enqueue_uncommitted_reliable_sends(data, bytes)){
return false;
lock();
bool success = enqueue_uncommitted_reliable_sends(data, bytes);
if (success){
commit_uncommitted_reliable_sends();
}
commit_uncommitted_reliable_sends();
return true;
unlock();
return success;
}


Expand All @@ -53,10 +65,10 @@ class ReliableStreamConnectionPolling{
virtual bool reset_flag_set() const{ return false; }
virtual void clear_reset_flag(){}

virtual bool run_send_events(const WallDuration& timeout){
virtual bool run_send_events(const WallDuration& timeout) noexcept{
return false;
}
virtual bool run_recv_events(const WallDuration& timeout){
virtual bool run_recv_events(const WallDuration& timeout) noexcept{
return false;
}
};
Expand Down
16 changes: 3 additions & 13 deletions Common/PABotBase2/PABotBase2CC_MessageDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ std::string tostr(const PacketHeader* header){
case PABB2_CONNECTION_OPCODE_RET_PACKET_SIZE:
str += "PABB2_CONNECTION_OPCODE_RET_PACKET_SIZE: seqnum = ";
str += std::to_string(header->seqnum);
str += ", bytes = " + std::to_string(((const PacketHeader_u16*)header)->data);
str += ", bytes = " + std::to_string(((const PacketHeader_u32*)header)->data);
return str;

case PABB2_CONNECTION_OPCODE_ASK_BUFFER_SLOTS:
Expand All @@ -61,7 +61,7 @@ std::string tostr(const PacketHeader* header){
case PABB2_CONNECTION_OPCODE_RET_BUFFER_SLOTS:
str += "PABB2_CONNECTION_OPCODE_RET_BUFFER_SLOTS: seqnum = ";
str += std::to_string(header->seqnum);
str += ", slots = " + std::to_string(((const PacketHeader_u8*)header)->data);
str += ", slots = " + std::to_string(((const PacketHeader_u32*)header)->data);
return str;

case PABB2_CONNECTION_OPCODE_ASK_BUFFER_BYTES:
Expand All @@ -71,7 +71,7 @@ std::string tostr(const PacketHeader* header){
case PABB2_CONNECTION_OPCODE_RET_BUFFER_BYTES:
str += "PABB2_CONNECTION_OPCODE_RET_BUFFER_BYTES: seqnum = ";
str += std::to_string(header->seqnum);
str += ", bytes = " + std::to_string(((const PacketHeader_u16*)header)->data);
str += ", bytes = " + std::to_string(((const PacketHeader_u32*)header)->data);
return str;

case PABB2_CONNECTION_OPCODE_ASK_STREAM_DATA:
Expand Down Expand Up @@ -112,16 +112,6 @@ std::string tostr(const PacketHeader* header){
str += "PABB2_CONNECTION_OPCODE_INFO: seqnum = ";
str += std::to_string(header->seqnum);
return str;
case PABB2_CONNECTION_OPCODE_INFO_U8:
str += "PABB2_CONNECTION_OPCODE_INFO_U8: seqnum = ";
str += std::to_string(header->seqnum);
str += ", data = " + std::to_string(((const PacketHeader_u8*)header)->data);
return str;
case PABB2_CONNECTION_OPCODE_INFO_U16:
str += "PABB2_CONNECTION_OPCODE_INFO_U16: seqnum = ";
str += std::to_string(header->seqnum);
str += ", data = " + std::to_string(((const PacketHeader_u16*)header)->data);
return str;
case PABB2_CONNECTION_OPCODE_INFO_H32:
str += "PABB2_CONNECTION_OPCODE_INFO_H32: seqnum = ";
str += std::to_string(header->seqnum);
Expand Down
2 changes: 1 addition & 1 deletion Common/PABotBase2/PABotBase2_MessageProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace PABotBase2{



#define PABB2_MESSAGE_PROTOCOL_VERSION 2026050905
#define PABB2_MESSAGE_PROTOCOL_VERSION 2026052100


struct PABB_PACK MessageHeader{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,10 @@ void ReliableStreamConnection::on_cancellable_cancel(
// Send Path
//

bool ReliableStreamConnection::reset(bool random_session_id, WallDuration timeout){
bool ReliableStreamConnection::reset(WallDuration timeout){
{
std::lock_guard<Mutex> lg(m_lock);
if (!random_session_id){
m_reliable_sender.reset(0xffffffff);
}else if (m_reliable_sender.session_id() == 0xffffffff){
if (m_reliable_sender.session_id() == 0xffffffff){
m_reliable_sender.reset(random_u32());
}else{
m_reliable_sender.reset(m_reliable_sender.session_id() + 1);
Expand All @@ -215,11 +213,7 @@ bool ReliableStreamConnection::reset(bool random_session_id, WallDuration timeou
m_parser.reset();
m_stream_coalescer.reset();
throw_if_cancelled();
if (random_session_id){
m_reliable_sender.send_reset();
}else{
m_reliable_sender.send_packet(PABB2_CONNECTION_OPCODE_ASK_RESET, 0, nullptr);
}
m_reliable_sender.send_reset();
}
m_cv.notify_all();
return wait_for_pending(timeout);
Expand Down Expand Up @@ -267,10 +261,10 @@ void ReliableStreamConnection::send_ack(uint8_t seqnum, uint8_t opcode){
pabb_crc32_write_to_message(m_reliable_sender.session_id(), &packet, sizeof(packet));
unreliable_send(&packet, sizeof(packet));
}
void ReliableStreamConnection::send_ack_u16(uint8_t seqnum, uint8_t opcode, uint16_t data){
void ReliableStreamConnection::send_ack_u32(uint8_t seqnum, uint8_t opcode, uint32_t data){
// Must call inside lock.
struct{
PacketHeader_u16 header;
PacketHeader_u32 header;
uint8_t crc[sizeof(uint32_t)];
} packet;
packet.header.magic_number = PABB2_CONNECTION_MAGIC_NUMBER;
Expand Down Expand Up @@ -414,8 +408,6 @@ void ReliableStreamConnection::on_packet(const PacketHeader* packet){
case PABB2_CONNECTION_OPCODE_INFO_STREAM_SEND_FULL:
case PABB2_CONNECTION_OPCODE_INFO_STREAM_RECV_FULL:
case PABB2_CONNECTION_OPCODE_INFO:
case PABB2_CONNECTION_OPCODE_INFO_U8:
case PABB2_CONNECTION_OPCODE_INFO_U16:
case PABB2_CONNECTION_OPCODE_INFO_H32:
case PABB2_CONNECTION_OPCODE_INFO_U32:
case PABB2_CONNECTION_OPCODE_INFO_I32:
Expand All @@ -440,13 +432,13 @@ void ReliableStreamConnection::on_packet(const PacketHeader* packet){
}
void ReliableStreamConnection::process_UNKNOWN_OPCODE(const PacketHeader* packet){
std::lock_guard<Mutex> lg(m_lock);
if (packet->packet_bytes < sizeof(PacketHeader_u8) + sizeof(uint32_t)){
if (packet->packet_bytes < sizeof(PacketHeader_u32) + sizeof(uint32_t)){
m_error = "Unknown opcode packet is too small: " + std::to_string(packet->packet_bytes);
m_logger.log("[RSC]: " + m_error, COLOR_RED);
return;
}

const PacketHeader_u8* message = (const PacketHeader_u8*)packet;
const PacketHeader_u32* message = (const PacketHeader_u32*)packet;
m_logger.log(
"[RSC]: PABB2_CONNECTION_OPCODE_INVALID_OPCODE: Device reported an invalid opcode: " +
std::to_string(message->data),
Expand Down Expand Up @@ -493,14 +485,14 @@ void ReliableStreamConnection::process_RET_VERSION(const PacketHeader* packet){
m_cv.notify_all();
}
void ReliableStreamConnection::process_RET_PACKET_SIZE(const PacketHeader* packet){
if (packet->packet_bytes < sizeof(PacketHeader_u16) + sizeof(uint32_t)){
if (packet->packet_bytes < sizeof(PacketHeader_u32) + sizeof(uint32_t)){
m_logger.log(
"[RSC]: Packet size response is too small: " + std::to_string(packet->packet_bytes),
COLOR_RED
);
return;
}
const PacketHeader_u16* message = (const PacketHeader_u16*)packet;
const PacketHeader_u32* message = (const PacketHeader_u32*)packet;
m_logger.log(
"[RSC]: Setting Packet Size: " + std::to_string(message->data) + " bytes",
COLOR_BLUE
Expand All @@ -513,14 +505,14 @@ void ReliableStreamConnection::process_RET_PACKET_SIZE(const PacketHeader* packe
m_cv.notify_all();
}
void ReliableStreamConnection::process_RET_BUFFER_SLOTS(const PacketHeader* packet){
if (packet->packet_bytes < sizeof(PacketHeader_u8) + sizeof(uint32_t)){
if (packet->packet_bytes < sizeof(PacketHeader_u32) + sizeof(uint32_t)){
m_logger.log(
"[RSC]: Buffer slot response is too small: " + std::to_string(packet->packet_bytes),
COLOR_RED
);
return;
}
const PacketHeader_u8* message = (const PacketHeader_u8*)packet;
const PacketHeader_u32* message = (const PacketHeader_u32*)packet;
{
std::lock_guard<Mutex> lg(m_lock);
m_reliable_sender.remove(packet->seqnum);
Expand All @@ -533,14 +525,14 @@ void ReliableStreamConnection::process_RET_BUFFER_SLOTS(const PacketHeader* pack
m_cv.notify_all();
}
void ReliableStreamConnection::process_RET_BUFFER_BYTES(const PacketHeader* packet){
if (packet->packet_bytes < sizeof(PacketHeader_u16) + sizeof(uint32_t)){
if (packet->packet_bytes < sizeof(PacketHeader_u32) + sizeof(uint32_t)){
m_logger.log(
"[RSC]: Buffer slot response is too small: " + std::to_string(packet->packet_bytes),
COLOR_RED
);
return;
}
const PacketHeader_u16* message = (const PacketHeader_u16*)packet;
const PacketHeader_u32* message = (const PacketHeader_u32*)packet;
{
std::lock_guard<Mutex> lg(m_lock);
m_reliable_sender.remove(packet->seqnum);
Expand Down Expand Up @@ -570,7 +562,7 @@ void ReliableStreamConnection::process_ASK_STREAM_DATA(const PacketHeader* packe
// cout << "Calling: send_ack_u16()" << endl;
{
std::lock_guard<Mutex> lg(m_lock);
send_ack_u16(
send_ack_u32(
packet->seqnum,
PABB2_CONNECTION_OPCODE_RET_STREAM_DATA,
m_stream_coalescer.free_bytes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ReliableStreamConnection final
}
virtual bool cancel(std::exception_ptr exception) noexcept override;

bool reset(bool random_session_id, WallDuration timeout = WallDuration::max());
bool reset(WallDuration timeout = WallDuration::max());

bool remote_protocol_is_compatible() const{
return m_remote_protocol_compatible;
Expand Down Expand Up @@ -89,7 +89,7 @@ class ReliableStreamConnection final
// Send

void send_ack(uint8_t seqnum, uint8_t opcode);
void send_ack_u16(uint8_t seqnum, uint8_t opcode, uint16_t data);
void send_ack_u32(uint8_t seqnum, uint8_t opcode, uint32_t data);

void retransmit_thread();

Expand Down
Loading
Loading