From 93aac98c3205399359db47e7a84ce01c42bb2dc1 Mon Sep 17 00:00:00 2001 From: Tobias Reiher Date: Fri, 2 Feb 2024 16:16:14 +0100 Subject: [PATCH] Prevent different casings for same entity Ref. eng/recordflux/RecordFlux#563 --- CHANGELOG.md | 4 + examples/specs/tls_handshake.rflx | 267 +++++++++++++++--------------- examples/specs/tls_record.rflx | 56 +++---- rflx/error.py | 4 + rflx/graph.py | 4 +- rflx/model/__init__.py | 1 - rflx/model/message.py | 40 ++++- rflx/model/session.py | 245 ++++++++++++++++----------- rflx/model/type_.py | 43 ++++- tests/conftest.py | 6 +- tests/unit/model/message_test.py | 120 +++++++++++++- tests/unit/model/session_test.py | 117 ++++++++++++- 12 files changed, 635 insertions(+), 272 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f4c04231..80c891872 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Prevent different casings for same entity (eng/recordflux/RecordFlux#563) + ### Fixed - Unexpected errors when using different casings for same entity (AdaCore/RecordFlux#562, eng/recordflux/RecordFlux#1506) diff --git a/examples/specs/tls_handshake.rflx b/examples/specs/tls_handshake.rflx index 49fb6e9ca..4eb48aa57 100644 --- a/examples/specs/tls_handshake.rflx +++ b/examples/specs/tls_handshake.rflx @@ -1,4 +1,4 @@ -with Tls_Common; +with TLS_Common; with Tls_Parameters; with Tls_Extensiontype_Values; @@ -73,7 +73,7 @@ package TLS_Handshake is -- restrictions. -- -- The concrete restrictions for the extension types defined below have been - -- derived semi-automatically from the IANA TLS_ExtensionType_Values + -- derived semi-automatically from the IANA Tls_Extensiontype_Values -- registry. type Data_Length is range 0 .. 2 ** 16 - 1 with Size => 16; @@ -91,7 +91,7 @@ package TLS_Handshake is -- Extensions for the CH (ClientHello) message type CH_Extension_TLS is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -139,7 +139,7 @@ package TLS_Handshake is type CH_Extension_DTLS is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -185,7 +185,7 @@ package TLS_Handshake is -- Extensions for the SH (ServerHello) message type SH_Extension_TLS is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -214,7 +214,7 @@ package TLS_Handshake is type SH_Extension_DTLS is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -245,7 +245,7 @@ package TLS_Handshake is -- Extensions for the HRR (HelloRetryRequest) message type HRR_Extension_TLS is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -273,7 +273,7 @@ package TLS_Handshake is type HRR_Extension_DTLS is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -302,7 +302,7 @@ package TLS_Handshake is -- Extensions for the EE (EncryptedExtensions) message type EE_Extension is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -342,7 +342,7 @@ package TLS_Handshake is -- Extensions for the CR (CertificateRequest) message type CR_Extension is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -376,7 +376,7 @@ package TLS_Handshake is -- Extensions for the CT (CertificateEntry) message type CT_Extension is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -405,7 +405,7 @@ package TLS_Handshake is -- Extensions for the NST (NewSessionTicket) message type NST_Extension is message - Tag : TLS_ExtensionType_Values::TLS_ExtensionType_Values + Tag : Tls_Extensiontype_Values::Tls_Extensiontype_Values then Data_Length if -- Message type specific constraints @@ -501,8 +501,8 @@ package TLS_Handshake is -- For compatibility reasons both TLS 1.3 and DTLS 1.3 require -- indicating the version 1.2 here. then Random - if (Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::DTLS_1_2); + if (Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::DTLS_1_2); Random : Opaque with Size => 32 * 8; Legacy_Session_ID_Length : Legacy_Session_ID_Length; @@ -510,32 +510,32 @@ package TLS_Handshake is with Size => Legacy_Session_ID_Length * 8 -- TLS then Cipher_Suites_Length - if Legacy_Version = Tls_Common::TLS_1_0 - or Legacy_Version = Tls_Common::TLS_1_1 - or Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::TLS_1_3 + if Legacy_Version = TLS_Common::TLS_1_0 + or Legacy_Version = TLS_Common::TLS_1_1 + or Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::TLS_1_3 -- DTLS then Legacy_Cookie_Length - if Legacy_Version /= Tls_Common::TLS_1_0 - and Legacy_Version /= Tls_Common::TLS_1_1 - and Legacy_Version /= Tls_Common::TLS_1_2 - and Legacy_Version /= Tls_Common::TLS_1_3; + if Legacy_Version /= TLS_Common::TLS_1_0 + and Legacy_Version /= TLS_Common::TLS_1_1 + and Legacy_Version /= TLS_Common::TLS_1_2 + and Legacy_Version /= TLS_Common::TLS_1_3; Legacy_Cookie_Length : Legacy_Cookie_Length; Legacy_Cookie : Opaque with Size => Legacy_Cookie_Length * 8; Cipher_Suites_Length : Cipher_Suites_Length -- TLS then Cipher_Suites_TLS - if Legacy_Version = Tls_Common::TLS_1_0 - or Legacy_Version = Tls_Common::TLS_1_1 - or Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::TLS_1_3 + if Legacy_Version = TLS_Common::TLS_1_0 + or Legacy_Version = TLS_Common::TLS_1_1 + or Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::TLS_1_3 -- DTLS then Cipher_Suites_DTLS - if Legacy_Version /= Tls_Common::TLS_1_0 - and Legacy_Version /= Tls_Common::TLS_1_1 - and Legacy_Version /= Tls_Common::TLS_1_2 - and Legacy_Version /= Tls_Common::TLS_1_3; + if Legacy_Version /= TLS_Common::TLS_1_0 + and Legacy_Version /= TLS_Common::TLS_1_1 + and Legacy_Version /= TLS_Common::TLS_1_2 + and Legacy_Version /= TLS_Common::TLS_1_3; -- TLS Cipher_Suites_TLS : Cipher_Suites_TLS with Size => Cipher_Suites_Length * 8 @@ -561,16 +561,16 @@ package TLS_Handshake is Extensions_Length : Client_Hello_Extensions_Length -- TLS then Extensions_TLS - if Legacy_Version = Tls_Common::TLS_1_0 - or Legacy_Version = Tls_Common::TLS_1_1 - or Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::TLS_1_3 + if Legacy_Version = TLS_Common::TLS_1_0 + or Legacy_Version = TLS_Common::TLS_1_1 + or Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::TLS_1_3 -- DTLS then Extensions_DTLS - if Legacy_Version /= Tls_Common::TLS_1_0 - and Legacy_Version /= Tls_Common::TLS_1_1 - and Legacy_Version /= Tls_Common::TLS_1_2 - and Legacy_Version /= Tls_Common::TLS_1_3; + if Legacy_Version /= TLS_Common::TLS_1_0 + and Legacy_Version /= TLS_Common::TLS_1_1 + and Legacy_Version /= TLS_Common::TLS_1_2 + and Legacy_Version /= TLS_Common::TLS_1_3; Extensions_TLS : CH_Extensions_TLS with Size => Extensions_Length * 8 then null; @@ -593,8 +593,8 @@ package TLS_Handshake is -- For compatibility reasons both TLS 1.3 and DTLS 1.3 require -- indicating the version 1.2 here. then Random - if (Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::DTLS_1_2); + if (Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::DTLS_1_2); Random : Opaque with Size => 32 * 8; Legacy_Session_ID_Length : Legacy_Session_ID_Length; @@ -602,16 +602,16 @@ package TLS_Handshake is with Size => Legacy_Session_ID_Length * 8 -- TLS then Cipher_Suite_TLS - if Legacy_Version = Tls_Common::TLS_1_0 - or Legacy_Version = Tls_Common::TLS_1_1 - or Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::TLS_1_3 + if Legacy_Version = TLS_Common::TLS_1_0 + or Legacy_Version = TLS_Common::TLS_1_1 + or Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::TLS_1_3 -- DTLS then Cipher_Suite_DTLS - if Legacy_Version /= Tls_Common::TLS_1_0 - and Legacy_Version /= Tls_Common::TLS_1_1 - and Legacy_Version /= Tls_Common::TLS_1_2 - and Legacy_Version /= Tls_Common::TLS_1_3; + if Legacy_Version /= TLS_Common::TLS_1_0 + and Legacy_Version /= TLS_Common::TLS_1_1 + and Legacy_Version /= TLS_Common::TLS_1_2 + and Legacy_Version /= TLS_Common::TLS_1_3; -- TLS Cipher_Suite_TLS : Cipher_Suite_TLS then Legacy_Compression_Method; @@ -631,10 +631,10 @@ package TLS_Handshake is 16#E2#, 16#C8#, 16#A8#, 16#33#, 16#9C#]) and -- TLS - (Legacy_Version = Tls_Common::TLS_1_0 - or Legacy_Version = Tls_Common::TLS_1_1 - or Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::TLS_1_3) + (Legacy_Version = TLS_Common::TLS_1_0 + or Legacy_Version = TLS_Common::TLS_1_1 + or Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::TLS_1_3) then Extensions_DTLS with Size => Extensions_Length * 8 if (Random /= [16#CF#, 16#21#, 16#AD#, 16#74#, 16#E5#, 16#9A#, 16#61#, 16#11#, 16#BE#, @@ -643,10 +643,10 @@ package TLS_Handshake is 16#E2#, 16#C8#, 16#A8#, 16#33#, 16#9C#]) and -- DTLS - (Legacy_Version /= Tls_Common::TLS_1_0 - and Legacy_Version /= Tls_Common::TLS_1_1 - and Legacy_Version /= Tls_Common::TLS_1_2 - and Legacy_Version /= Tls_Common::TLS_1_3) + (Legacy_Version /= TLS_Common::TLS_1_0 + and Legacy_Version /= TLS_Common::TLS_1_1 + and Legacy_Version /= TLS_Common::TLS_1_2 + and Legacy_Version /= TLS_Common::TLS_1_3) then HRR_Extensions_TLS with Size => Extensions_Length * 8 @@ -656,10 +656,10 @@ package TLS_Handshake is 16#E2#, 16#C8#, 16#A8#, 16#33#, 16#9C#]) and -- TLS - (Legacy_Version = Tls_Common::TLS_1_0 - or Legacy_Version = Tls_Common::TLS_1_1 - or Legacy_Version = Tls_Common::TLS_1_2 - or Legacy_Version = Tls_Common::TLS_1_3) + (Legacy_Version = TLS_Common::TLS_1_0 + or Legacy_Version = TLS_Common::TLS_1_1 + or Legacy_Version = TLS_Common::TLS_1_2 + or Legacy_Version = TLS_Common::TLS_1_3) then HRR_Extensions_DTLS with Size => Extensions_Length * 8 @@ -669,10 +669,10 @@ package TLS_Handshake is 16#E2#, 16#C8#, 16#A8#, 16#33#, 16#9C#]) and -- DTLS - (Legacy_Version /= Tls_Common::TLS_1_0 - and Legacy_Version /= Tls_Common::TLS_1_1 - and Legacy_Version /= Tls_Common::TLS_1_2 - and Legacy_Version /= Tls_Common::TLS_1_3); + (Legacy_Version /= TLS_Common::TLS_1_0 + and Legacy_Version /= TLS_Common::TLS_1_1 + and Legacy_Version /= TLS_Common::TLS_1_2 + and Legacy_Version /= TLS_Common::TLS_1_3); Extensions_TLS : SH_Extensions_TLS then null; Extensions_DTLS : SH_Extensions_DTLS @@ -683,10 +683,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Server_Hello) - if Tag = TLS_Parameters::Server_Hello; + if Tag = Tls_Parameters::Server_Hello; for DTLS_Handshake use (Payload => Server_Hello) - if Tag = TLS_Parameters::Server_Hello; + if Tag = Tls_Parameters::Server_Hello; -- Server Parameters @@ -700,10 +700,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Encrypted_Extensions) - if Tag = TLS_Parameters::Encrypted_Extensions; + if Tag = Tls_Parameters::Encrypted_Extensions; for DTLS_Handshake use (Payload => Encrypted_Extensions) - if Tag = TLS_Parameters::Encrypted_Extensions; + if Tag = Tls_Parameters::Encrypted_Extensions; type Certificate_Request_Context_Length is range 0 .. 2 ** 8 - 1 with Size => 8; type Certificate_Request_Extensions_Length is range 2 .. 2 ** 16 - 1 with Size => 16; @@ -719,10 +719,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Certificate_Request) - if Tag = TLS_Parameters::Certificate_Request; + if Tag = Tls_Parameters::Certificate_Request; for DTLS_Handshake use (Payload => Certificate_Request) - if Tag = TLS_Parameters::Certificate_Request; + if Tag = Tls_Parameters::Certificate_Request; -- Authentication Messages @@ -754,10 +754,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Certificate) - if Tag = TLS_Parameters::Certificate; + if Tag = Tls_Parameters::Certificate; for DTLS_Handshake use (Payload => Certificate) - if Tag = TLS_Parameters::Certificate; + if Tag = Tls_Parameters::Certificate; type Signature_Length is range 0 .. 2 ** 16 - 1 with Size => 16; @@ -770,10 +770,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Certificate_Verify) - if Tag = TLS_Parameters::Certificate_Verify; + if Tag = Tls_Parameters::Certificate_Verify; for DTLS_Handshake use (Payload => Certificate_Verify) - if Tag = TLS_Parameters::Certificate_Verify; + if Tag = Tls_Parameters::Certificate_Verify; type Finished is message @@ -782,20 +782,20 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Finished) - if Tag = TLS_Parameters::Finished; + if Tag = Tls_Parameters::Finished; for DTLS_Handshake use (Payload => Finished) - if Tag = TLS_Parameters::Finished; + if Tag = Tls_Parameters::Finished; -- End of Early Data type End_Of_Early_Data is null message; for TLS_Handshake use (Payload => End_Of_Early_Data) - if Tag = TLS_Parameters::End_Of_Early_Data; + if Tag = Tls_Parameters::End_Of_Early_Data; for DTLS_Handshake use (Payload => End_Of_Early_Data) - if Tag = TLS_Parameters::End_Of_Early_Data; + if Tag = Tls_Parameters::End_Of_Early_Data; -- Post-TLS_Handshake Messages @@ -821,10 +821,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => New_Session_Ticket) - if Tag = TLS_Parameters::New_Session_Ticket; + if Tag = Tls_Parameters::New_Session_Ticket; for DTLS_Handshake use (Payload => New_Session_Ticket) - if Tag = TLS_Parameters::New_Session_Ticket; + if Tag = Tls_Parameters::New_Session_Ticket; type Key_Update_Request is ( Update_Not_Requested => 0, @@ -837,10 +837,10 @@ package TLS_Handshake is end message; for TLS_Handshake use (Payload => Key_Update) - if Tag = TLS_Parameters::Key_Update; + if Tag = Tls_Parameters::Key_Update; for DTLS_Handshake use (Payload => Key_Update) - if Tag = TLS_Parameters::Key_Update; + if Tag = Tls_Parameters::Key_Update; -- Server Name Indication Extension @@ -866,11 +866,11 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Server_Name_List) - if Tag = TLS_ExtensionType_Values::Server_Name; + if Tag = Tls_Extensiontype_Values::Server_Name; for CH_Extension_DTLS use (Data => Server_Name_List) - if Tag = TLS_ExtensionType_Values::Server_Name; + if Tag = Tls_Extensiontype_Values::Server_Name; for EE_Extension use (Data => Server_Name_List) - if Tag = TLS_ExtensionType_Values::Server_Name; + if Tag = Tls_Extensiontype_Values::Server_Name; -- Max Fragment Length @@ -887,11 +887,11 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Max_Fragment_Length) - if Tag = TLS_ExtensionType_Values::Max_Fragment_Length; + if Tag = Tls_Extensiontype_Values::Max_Fragment_Length; for CH_Extension_DTLS use (Data => Max_Fragment_Length) - if Tag = TLS_ExtensionType_Values::Max_Fragment_Length; + if Tag = Tls_Extensiontype_Values::Max_Fragment_Length; for EE_Extension use (Data => Max_Fragment_Length) - if Tag = TLS_ExtensionType_Values::Max_Fragment_Length; + if Tag = Tls_Extensiontype_Values::Max_Fragment_Length; -- Supported Versions Extension @@ -911,17 +911,17 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Supported_Versions) - if Tag = TLS_ExtensionType_Values::Supported_Versions; + if Tag = Tls_Extensiontype_Values::Supported_Versions; for CH_Extension_DTLS use (Data => Supported_Versions) - if Tag = TLS_ExtensionType_Values::Supported_Versions; + if Tag = Tls_Extensiontype_Values::Supported_Versions; for SH_Extension_TLS use (Data => Supported_Version) - if Tag = TLS_ExtensionType_Values::Supported_Versions; + if Tag = Tls_Extensiontype_Values::Supported_Versions; for SH_Extension_DTLS use (Data => Supported_Version) - if Tag = TLS_ExtensionType_Values::Supported_Versions; + if Tag = Tls_Extensiontype_Values::Supported_Versions; for HRR_Extension_TLS use (Data => Supported_Version) - if Tag = TLS_ExtensionType_Values::Supported_Versions; + if Tag = Tls_Extensiontype_Values::Supported_Versions; for HRR_Extension_DTLS use (Data => Supported_Version) - if Tag = TLS_ExtensionType_Values::Supported_Versions; + if Tag = Tls_Extensiontype_Values::Supported_Versions; -- Cookie Extension @@ -935,13 +935,13 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Cookie) - if Tag = TLS_ExtensionType_Values::Cookie; + if Tag = Tls_Extensiontype_Values::Cookie; for CH_Extension_DTLS use (Data => Cookie) - if Tag = TLS_ExtensionType_Values::Cookie; + if Tag = Tls_Extensiontype_Values::Cookie; for HRR_Extension_TLS use (Data => Cookie) - if Tag = TLS_ExtensionType_Values::Cookie; + if Tag = Tls_Extensiontype_Values::Cookie; for HRR_Extension_DTLS use (Data => Cookie) - if Tag = TLS_ExtensionType_Values::Cookie; + if Tag = Tls_Extensiontype_Values::Cookie; -- Signature Algorithms Extension @@ -956,11 +956,11 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Signature_Algorithms) - if Tag = TLS_ExtensionType_Values::Signature_Algorithms; + if Tag = Tls_Extensiontype_Values::Signature_Algorithms; for CH_Extension_DTLS use (Data => Signature_Algorithms) - if Tag = TLS_ExtensionType_Values::Signature_Algorithms; + if Tag = Tls_Extensiontype_Values::Signature_Algorithms; for CR_Extension use (Data => Signature_Algorithms) - if Tag = TLS_ExtensionType_Values::Signature_Algorithms; + if Tag = Tls_Extensiontype_Values::Signature_Algorithms; type Signature_Algorithms_Cert is message @@ -970,11 +970,11 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Signature_Algorithms_Cert) - if Tag = TLS_ExtensionType_Values::Signature_Algorithms_Cert; + if Tag = Tls_Extensiontype_Values::Signature_Algorithms_Cert; for CH_Extension_DTLS use (Data => Signature_Algorithms_Cert) - if Tag = TLS_ExtensionType_Values::Signature_Algorithms_Cert; + if Tag = Tls_Extensiontype_Values::Signature_Algorithms_Cert; for CR_Extension use (Data => Signature_Algorithms_Cert) - if Tag = TLS_ExtensionType_Values::Signature_Algorithms_Cert; + if Tag = Tls_Extensiontype_Values::Signature_Algorithms_Cert; -- Heartbeat Extension @@ -1005,11 +1005,11 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Certificate_Authorities) - if Tag = TLS_ExtensionType_Values::Certificate_Authorities; + if Tag = Tls_Extensiontype_Values::Certificate_Authorities; for CH_Extension_DTLS use (Data => Certificate_Authorities) - if Tag = TLS_ExtensionType_Values::Certificate_Authorities; + if Tag = Tls_Extensiontype_Values::Certificate_Authorities; for CR_Extension use (Data => Certificate_Authorities) - if Tag = TLS_ExtensionType_Values::Certificate_Authorities; + if Tag = Tls_Extensiontype_Values::Certificate_Authorities; -- OID Filters Extension @@ -1037,16 +1037,16 @@ package TLS_Handshake is end message; for CR_Extension use (Data => OID_Filters) - if Tag = TLS_ExtensionType_Values::Oid_Filters; + if Tag = Tls_Extensiontype_Values::Oid_Filters; -- Post-TLS_Handshake Client Authentication Extension type Post_Handshake_Auth is null message; for CH_Extension_TLS use (Data => Post_Handshake_Auth) - if Tag = TLS_ExtensionType_Values::Post_Handshake_Auth; + if Tag = Tls_Extensiontype_Values::Post_Handshake_Auth; for CH_Extension_DTLS use (Data => Post_Handshake_Auth) - if Tag = TLS_ExtensionType_Values::Post_Handshake_Auth; + if Tag = Tls_Extensiontype_Values::Post_Handshake_Auth; -- Supported Groups Extension @@ -1062,11 +1062,11 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Supported_Groups) - if Tag = TLS_ExtensionType_Values::Supported_Groups; + if Tag = Tls_Extensiontype_Values::Supported_Groups; for CH_Extension_DTLS use (Data => Supported_Groups) - if Tag = TLS_ExtensionType_Values::Supported_Groups; + if Tag = Tls_Extensiontype_Values::Supported_Groups; for EE_Extension use (Data => Supported_Groups) - if Tag = TLS_ExtensionType_Values::Supported_Groups; + if Tag = Tls_Extensiontype_Values::Supported_Groups; -- Key Share Extension @@ -1099,23 +1099,23 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Key_Share_CH) - if Tag = TLS_ExtensionType_Values::Key_Share; + if Tag = Tls_Extensiontype_Values::Key_Share; for CH_Extension_DTLS use (Data => Key_Share_CH) - if Tag = TLS_ExtensionType_Values::Key_Share; + if Tag = Tls_Extensiontype_Values::Key_Share; for SH_Extension_TLS use (Data => Key_Share_SH) - if Tag = TLS_ExtensionType_Values::Key_Share; + if Tag = Tls_Extensiontype_Values::Key_Share; for SH_Extension_DTLS use (Data => Key_Share_SH) - if Tag = TLS_ExtensionType_Values::Key_Share; + if Tag = Tls_Extensiontype_Values::Key_Share; for HRR_Extension_TLS use (Data => Key_Share_HRR) - if Tag = TLS_ExtensionType_Values::Key_Share; + if Tag = Tls_Extensiontype_Values::Key_Share; for HRR_Extension_DTLS use (Data => Key_Share_HRR) - if Tag = TLS_ExtensionType_Values::Key_Share; + if Tag = Tls_Extensiontype_Values::Key_Share; -- Pre-Shared Key Exchange Modes Extension type PSK_Key_Exchange_Modes_Length is range 1 .. 255 with Size => 8; - type Key_Exchange_Modes is sequence of TLS_Parameters::TLS_PskKeyExchangeMode; + type Key_Exchange_Modes is sequence of Tls_Parameters::TLS_PskKeyExchangeMode; type Psk_Key_Exchange_Modes is message @@ -1125,9 +1125,9 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Psk_Key_Exchange_Modes) - if Tag = TLS_ExtensionType_Values::Psk_Key_Exchange_Modes; + if Tag = Tls_Extensiontype_Values::Psk_Key_Exchange_Modes; for CH_Extension_DTLS use (Data => Psk_Key_Exchange_Modes) - if Tag = TLS_ExtensionType_Values::Psk_Key_Exchange_Modes; + if Tag = Tls_Extensiontype_Values::Psk_Key_Exchange_Modes; -- Early Data Indication Extension @@ -1141,13 +1141,12 @@ package TLS_Handshake is type Early_Data is null message; for CH_Extension_TLS use (Data => Early_Data) - if Tag = TLS_ExtensionType_Values::Early_Data; + if Tag = Tls_Extensiontype_Values::Early_Data; for CH_Extension_DTLS use (Data => Early_Data) - if Tag = TLS_ExtensionType_Values::Early_Data; + if Tag = Tls_Extensiontype_Values::Early_Data; for EE_Extension use (Data => Early_Data) - if Tag = TLS_ExtensionType_Values::Early_Data; - for NST_Extension use (Data => Early_Data_Indication) - if Tag = TLS_ExtensionType_Values::Early_Data; + if Tag = Tls_Extensiontype_Values::Early_Data; + for NST_Extension use (Data => Early_Data_Indication); -- Pre-Shared Key Extension @@ -1196,13 +1195,13 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Pre_Shared_Key_CH) - if Tag = TLS_ExtensionType_Values::Pre_Shared_Key; + if Tag = Tls_Extensiontype_Values::Pre_Shared_Key; for CH_Extension_DTLS use (Data => Pre_Shared_Key_CH) - if Tag = TLS_ExtensionType_Values::Pre_Shared_Key; + if Tag = Tls_Extensiontype_Values::Pre_Shared_Key; for SH_Extension_TLS use (Data => Pre_Shared_Key_SH) - if Tag = TLS_ExtensionType_Values::Pre_Shared_Key; + if Tag = Tls_Extensiontype_Values::Pre_Shared_Key; for SH_Extension_DTLS use (Data => Pre_Shared_Key_SH) - if Tag = TLS_ExtensionType_Values::Pre_Shared_Key; + if Tag = Tls_Extensiontype_Values::Pre_Shared_Key; -- Application-Layer Protocol Negotiation Extension @@ -1226,10 +1225,10 @@ package TLS_Handshake is end message; for CH_Extension_TLS use (Data => Protocol_Name_List) - if Tag = TLS_ExtensionType_Values::Application_Layer_Protocol_Negotiation; + if Tag = Tls_Extensiontype_Values::Application_Layer_Protocol_Negotiation; for CH_Extension_DTLS use (Data => Protocol_Name_List) - if Tag = TLS_ExtensionType_Values::Application_Layer_Protocol_Negotiation; + if Tag = Tls_Extensiontype_Values::Application_Layer_Protocol_Negotiation; for EE_Extension use (Data => Protocol_Name_List) - if Tag = TLS_ExtensionType_Values::Application_Layer_Protocol_Negotiation; + if Tag = Tls_Extensiontype_Values::Application_Layer_Protocol_Negotiation; end TLS_Handshake; diff --git a/examples/specs/tls_record.rflx b/examples/specs/tls_record.rflx index 8de660a7b..2f9770b5c 100644 --- a/examples/specs/tls_record.rflx +++ b/examples/specs/tls_record.rflx @@ -1,5 +1,5 @@ with TLS_Alert; -with Tls_Common; +with TLS_Common; with Tls_Handshake; with Tls_Parameters; @@ -62,22 +62,22 @@ package TLS_Record is message Prefix : Plaintext_Prefix; Tag : Plaintext_Content_Type; - Legacy_Record_Version : Tls_Common::Protocol_Version + Legacy_Record_Version : TLS_Common::Protocol_Version -- TLS then Length - if (Legacy_Record_Version = Tls_Common::TLS_1_0 - or Legacy_Record_Version = Tls_Common::TLS_1_1 - or Legacy_Record_Version = Tls_Common::TLS_1_2 - or Legacy_Record_Version = Tls_Common::TLS_1_3) + if (Legacy_Record_Version = TLS_Common::TLS_1_0 + or Legacy_Record_Version = TLS_Common::TLS_1_1 + or Legacy_Record_Version = TLS_Common::TLS_1_2 + or Legacy_Record_Version = TLS_Common::TLS_1_3) and -- The following content types are not defined for TLS (Tag /= Tls12_Cid and Tag /= ACK) -- DTLS then Epoch - if Legacy_Record_Version /= Tls_Common::TLS_1_0 - and Legacy_Record_Version /= Tls_Common::TLS_1_1 - and Legacy_Record_Version /= Tls_Common::TLS_1_2 - and Legacy_Record_Version /= Tls_Common::TLS_1_3; + if Legacy_Record_Version /= TLS_Common::TLS_1_0 + and Legacy_Record_Version /= TLS_Common::TLS_1_1 + and Legacy_Record_Version /= TLS_Common::TLS_1_2 + and Legacy_Record_Version /= TLS_Common::TLS_1_3; Epoch : Plaintext_Epoch; Sequence_Number : Plaintext_Sequence_Number; Length : Plaintext_Length @@ -87,8 +87,8 @@ package TLS_Record is then Encrypted_Record with Size => Length * 8 if Tag = Application_Data - and (Legacy_Record_Version = Tls_Common::TLS_1_2 - or Legacy_Record_Version = Tls_Common::DTLS_1_2); + and (Legacy_Record_Version = TLS_Common::TLS_1_2 + or Legacy_Record_Version = TLS_Common::DTLS_1_2); Fragment : Opaque then null; Encrypted_Record : Opaque; @@ -104,18 +104,18 @@ package TLS_Record is -- tls_handshake.rflx for more information. for TLS_Plaintext use (Fragment => TLS_Handshake::TLS_Handshake) if Tag = Handshake - and (Legacy_Record_Version = Tls_Common::TLS_1_0 - or Legacy_Record_Version = Tls_Common::TLS_1_1 - or Legacy_Record_Version = Tls_Common::TLS_1_2 - or Legacy_Record_Version = Tls_Common::TLS_1_3); + and (Legacy_Record_Version = TLS_Common::TLS_1_0 + or Legacy_Record_Version = TLS_Common::TLS_1_1 + or Legacy_Record_Version = TLS_Common::TLS_1_2 + or Legacy_Record_Version = TLS_Common::TLS_1_3); -- DTLS Handshake for TLS_Plaintext use (Fragment => TLS_Handshake::DTLS_Handshake) if Tag = Handshake - and (Legacy_Record_Version /= Tls_Common::TLS_1_0 - and Legacy_Record_Version /= Tls_Common::TLS_1_1 - and Legacy_Record_Version /= Tls_Common::TLS_1_2 - and Legacy_Record_Version /= Tls_Common::TLS_1_3); + and (Legacy_Record_Version /= TLS_Common::TLS_1_0 + and Legacy_Record_Version /= TLS_Common::TLS_1_1 + and Legacy_Record_Version /= TLS_Common::TLS_1_2 + and Legacy_Record_Version /= TLS_Common::TLS_1_3); for TLS_Plaintext use (Fragment => TLS_Alert::Alert) if Tag = Alert; @@ -209,17 +209,17 @@ package TLS_Record is for TLS_Record use (Plaintext_Rec_Fragment => TLS_Handshake::TLS_Handshake) if Plaintext_Rec_Tag = Handshake - and (Plaintext_Rec_Legacy_Record_Version = Tls_Common::TLS_1_0 - or Plaintext_Rec_Legacy_Record_Version = Tls_Common::TLS_1_1 - or Plaintext_Rec_Legacy_Record_Version = Tls_Common::TLS_1_2 - or Plaintext_Rec_Legacy_Record_Version = Tls_Common::TLS_1_3); + and (Plaintext_Rec_Legacy_Record_Version = TLS_Common::TLS_1_0 + or Plaintext_Rec_Legacy_Record_Version = TLS_Common::TLS_1_1 + or Plaintext_Rec_Legacy_Record_Version = TLS_Common::TLS_1_2 + or Plaintext_Rec_Legacy_Record_Version = TLS_Common::TLS_1_3); for TLS_Record use (Plaintext_Rec_Fragment => TLS_Handshake::DTLS_Handshake) if Plaintext_Rec_Tag = Handshake - and (Plaintext_Rec_Legacy_Record_Version /= Tls_Common::TLS_1_0 - and Plaintext_Rec_Legacy_Record_Version /= Tls_Common::TLS_1_1 - and Plaintext_Rec_Legacy_Record_Version /= Tls_Common::TLS_1_2 - and Plaintext_Rec_Legacy_Record_Version /= Tls_Common::TLS_1_3); + and (Plaintext_Rec_Legacy_Record_Version /= TLS_Common::TLS_1_0 + and Plaintext_Rec_Legacy_Record_Version /= TLS_Common::TLS_1_1 + and Plaintext_Rec_Legacy_Record_Version /= TLS_Common::TLS_1_2 + and Plaintext_Rec_Legacy_Record_Version /= TLS_Common::TLS_1_3); for TLS_Record use (Plaintext_Rec_Fragment => TLS_Alert::Alert) if Plaintext_Rec_Tag = Alert; diff --git a/rflx/error.py b/rflx/error.py index ad691ba4b..caa297263 100644 --- a/rflx/error.py +++ b/rflx/error.py @@ -51,6 +51,10 @@ def start(self) -> tuple[int, int]: def end(self) -> Optional[tuple[int, int]]: return self._end + @property + def short(self) -> Location: + return Location(self.start, Path(self.source.name) if self.source else None, self.end) + def __hash__(self) -> int: return hash(self._start) diff --git a/rflx/graph.py b/rflx/graph.py index bca58079f..2e44502b4 100644 --- a/rflx/graph.py +++ b/rflx/graph.py @@ -11,7 +11,7 @@ from rflx.error import RecordFluxError, Severity, Subsystem from rflx.expression import TRUE, UNDEFINED from rflx.identifier import ID -from rflx.model import FINAL_STATE, AbstractSession, Link, Message +from rflx.model import FINAL_STATE, Link, Message, Session log = logging.getLogger(__name__) @@ -104,7 +104,7 @@ def _edge_label(link: Link) -> str: return result -def create_session_graph(session: AbstractSession, ignore: Optional[Sequence[str]] = None) -> Dot: +def create_session_graph(session: Session, ignore: Optional[Sequence[str]] = None) -> Dot: """ Return pydot graph representation of session. diff --git a/rflx/model/__init__.py b/rflx/model/__init__.py index 07e8cf60f..19765de42 100644 --- a/rflx/model/__init__.py +++ b/rflx/model/__init__.py @@ -20,7 +20,6 @@ from .model import Model as Model, UncheckedModel as UncheckedModel from .session import ( FINAL_STATE as FINAL_STATE, - AbstractSession as AbstractSession, Session as Session, State as State, Transition as Transition, diff --git a/rflx/model/message.py b/rflx/model/message.py index 2239c8aab..ae09537b8 100644 --- a/rflx/model/message.py +++ b/rflx/model/message.py @@ -171,6 +171,8 @@ def __init__( # noqa: PLR0913 if not self.error.errors and not skip_verification: self._verify() + self._check_identifiers(structure, types) + self.error.propagate() def __hash__(self) -> int: @@ -1253,7 +1255,6 @@ def _verify(self) -> None: proofs.check(self.error) self._prove_reachability(valid_paths) - self.error.propagate() def _determine_valid_paths(self) -> set[tuple[Link, ...]]: """Return all paths without contradictions.""" @@ -2118,6 +2119,29 @@ def _compute_first(self, lnk: Link) -> tuple[Field, expr.Expr]: return (lnk.source, expr.Size(lnk.source.affixed_name)) return (root, expr.Add(dist, source_size).simplified()) + def _check_identifiers(self, structure: Sequence[Link], types: Iterable[Field]) -> None: + self.error.extend( + mty.check_identifier_notation( + ( + e + for l in sorted(structure) + for e in ( + expr.Variable(l.source.identifier), + expr.Variable(l.target.identifier), + l.condition, + l.size, + l.first, + ) + ), + itertools.chain( + (f.identifier for f in types), + self._unqualified_enum_literals, + self._qualified_enum_literals, + self._type_names, + ), + ), + ) + class DerivedMessage(Message): def __init__( # noqa: PLR0913 @@ -2213,6 +2237,7 @@ def __init__( # noqa: PLR0913 self._qualified_enum_literals = mty.qualified_enum_literals(self.dependencies) self._type_names = mty.qualified_type_names(self.dependencies) + self._check_identifiers() self._normalize() if not skip_verification: @@ -2344,6 +2369,19 @@ def _verify_condition(self) -> None: ], ) + def _check_identifiers(self) -> None: + self.error.extend( + mty.check_identifier_notation( + [expr.Variable(self.field.identifier), self.condition], + itertools.chain( + (f.identifier for f in self.pdu.fields), + self._unqualified_enum_literals, + self._qualified_enum_literals, + self._type_names, + ), + ), + ) + def __str__(self) -> str: condition = f"\n if {self.condition}" if self.condition != expr.TRUE else "" return f"for {self.pdu.name} use ({self.field.name} => {self.sdu.name}){condition}" diff --git a/rflx/model/session.py b/rflx/model/session.py index 3a6515393..bfc838a75 100644 --- a/rflx/model/session.py +++ b/rflx/model/session.py @@ -2,7 +2,6 @@ import contextlib import itertools -from abc import abstractmethod from collections import defaultdict from collections.abc import Generator, Iterable, Mapping, Sequence from copy import deepcopy @@ -344,8 +343,7 @@ def contains_unsupported_feature(name: ID, action: stmt.Statement) -> bool: substituted(transition.condition, message_decl.type_) -class AbstractSession(TopLevelDeclaration): - @abstractmethod +class Session(TopLevelDeclaration): def __init__( # noqa: PLR0913 self, identifier: StrID, @@ -354,6 +352,7 @@ def __init__( # noqa: PLR0913 parameters: Sequence[decl.FormalDeclaration], types: Sequence[mty.Type], location: Optional[Location] = None, + workers: int = 1, ): super().__init__(identifier, location) @@ -365,6 +364,7 @@ def __init__( # noqa: PLR0913 self.direct_dependencies = {t.identifier: t for t in types} self.types = self.direct_dependencies.copy() self.location = location + self._workers = workers refinements = [t for t in types if isinstance(t, Refinement)] @@ -381,10 +381,15 @@ def __init__( # noqa: PLR0913 self._enum_literals = mty.enum_literals(self.types.values(), self.package) self._type_names = mty.qualified_type_names(self.types.values()) + self._check_identifiers() self._normalize() + self._validate() self.error.propagate() + self._optimize() + self.to_ir() + def __hash__(self) -> int: return hash(self.identifier) @@ -415,6 +420,20 @@ def __str__(self) -> str: def initial_state(self) -> State: return self.states[0] + @lru_cache # noqa: B019 + def to_ir(self) -> ir.Session: + variable_id = id_generator() + return ir.Session( + self.identifier, + [state.to_ir(variable_id) for state in self.states], + [d.to_ir(variable_id) for d in self.declarations.values()], + [p.to_ir() for p in self.parameters.values()], + self.types, + self.location, + variable_id, + self._workers, + ) + def _normalize(self) -> None: # noqa: PLR0912 """ Normalize all expressions of the session. @@ -503,105 +522,77 @@ def normalize_identifiers_local( if state.exception_transition and t.target in states_map: state.exception_transition.target = states_map[state.exception_transition.target] + def _optimize(self) -> None: + for state in self.states: + state.optimize() -def normalize_identifiers( - expression: expr.Expr, - variables: Iterable[ID], - enum_literals: Iterable[ID], - type_names: Iterable[ID], - functions: Iterable[ID], -) -> expr.Expr: - variables_map = {v: v for v in variables} - type_names_map = {t: t for t in type_names} - enum_literals_map = {l: l for l in enum_literals} - functions_map = {f: f for f in functions} - - if isinstance(expression, expr.Variable): - if expression.identifier in type_names_map: - return expr.TypeName( - ID(type_names_map[expression.identifier], location=expression.identifier.location), - expression.type_, - location=expression.location, - ) - if expression.identifier in enum_literals_map: - return expr.Literal( - ID( - enum_literals_map[expression.identifier], - location=expression.identifier.location, + def _check_identifiers(self) -> None: + self.error.extend( + mty.check_identifier_notation( + itertools.chain( + ( + expr.Variable(d.type_identifier) + for d in self.declarations.values() + if isinstance(d, decl.TypeCheckableDeclaration) + ), + ( + d.expression + for d in self.declarations.values() + if isinstance(d, decl.VariableDeclaration) and d.expression + ), ), - expression.type_, - location=expression.location, - ) - if expression.identifier in functions_map: - return expr.Call( - ID(functions_map[expression.identifier], location=expression.identifier.location), - [], - expression.negative, - expression.immutable, - expression.type_, - location=expression.location, - ) - if expression.identifier in variables_map: - return expr.Variable( - ID(variables_map[expression.identifier], location=expression.identifier.location), - expression.negative, - expression.immutable, - expression.type_, - location=expression.location, - ) - - if isinstance(expression, expr.Call) and expression.identifier in functions_map: - return expr.Call( - ID(functions_map[expression.identifier], location=expression.identifier.location), - expression.args, - expression.negative, - expression.immutable, - expression.type_, - expression.argument_types, - location=expression.location, - ) - - return expression - - -class Session(AbstractSession): - def __init__( # noqa: PLR0913 - self, - identifier: StrID, - states: Sequence[State], - declarations: Sequence[decl.BasicDeclaration], - parameters: Sequence[decl.FormalDeclaration], - types: Sequence[mty.Type], - location: Optional[Location] = None, - workers: int = 1, - ): - super().__init__(identifier, states, declarations, parameters, types, location) - self._workers = workers - - self._validate() - - self.error.propagate() - - self._optimize() - self.to_ir() - - @lru_cache # noqa: B019 - def to_ir(self) -> ir.Session: - variable_id = id_generator() - return ir.Session( - self.identifier, - [state.to_ir(variable_id) for state in self.states], - [d.to_ir(variable_id) for d in self.declarations.values()], - [p.to_ir() for p in self.parameters.values()], - self.types, - self.location, - variable_id, - self._workers, + itertools.chain( + self.parameters, + self.declarations, + self._enum_literals, + self._type_names, + ), + ), ) - def _optimize(self) -> None: for state in self.states: - state.optimize() + self.error.extend( + mty.check_identifier_notation( + itertools.chain( + ( + d.expression + for d in state.declarations.values() + if isinstance(d, decl.VariableDeclaration) and d.expression + ), + (a.expression for a in state.actions if isinstance(a, stmt.Assignment)), + ( + e + for a in state.actions + if isinstance(a, stmt.AttributeStatement) + for e in [expr.Variable(a.identifier), *a.parameters] + ), + ( + e + for a in state.actions + if isinstance(a, stmt.Reset) + for e in a.associations.values() + ), + ( + e + for t in state.transitions + for e in [expr.Variable(t.target), t.condition] + ), + ( + expr.Variable(t.target) + for t in [state.exception_transition] + if t is not None + ), + ), + itertools.chain( + self.parameters, + self.declarations, + self._enum_literals, + self._type_names, + state.declarations, + (s.identifier for s in self.states), + ), + ), + ) def _validate_states(self) -> None: if all(s == FINAL_STATE for s in self.states): @@ -1012,6 +1003,66 @@ def _validate(self) -> None: self._validate_usage() +def normalize_identifiers( + expression: expr.Expr, + variables: Iterable[ID], + enum_literals: Iterable[ID], + type_names: Iterable[ID], + functions: Iterable[ID], +) -> expr.Expr: + variables_map = {v: v for v in variables} + type_names_map = {t: t for t in type_names} + enum_literals_map = {l: l for l in enum_literals} + functions_map = {f: f for f in functions} + + if isinstance(expression, expr.Variable): + if expression.identifier in type_names_map: + return expr.TypeName( + ID(type_names_map[expression.identifier], location=expression.identifier.location), + expression.type_, + location=expression.location, + ) + if expression.identifier in enum_literals_map: + return expr.Literal( + ID( + enum_literals_map[expression.identifier], + location=expression.identifier.location, + ), + expression.type_, + location=expression.location, + ) + if expression.identifier in functions_map: + return expr.Call( + ID(functions_map[expression.identifier], location=expression.identifier.location), + [], + expression.negative, + expression.immutable, + expression.type_, + location=expression.location, + ) + if expression.identifier in variables_map: + return expr.Variable( + ID(variables_map[expression.identifier], location=expression.identifier.location), + expression.negative, + expression.immutable, + expression.type_, + location=expression.location, + ) + + if isinstance(expression, expr.Call) and expression.identifier in functions_map: + return expr.Call( + ID(functions_map[expression.identifier], location=expression.identifier.location), + expression.args, + expression.negative, + expression.immutable, + expression.type_, + expression.argument_types, + location=expression.location, + ) + + return expression + + @dataclass class UncheckedSession(UncheckedTopLevelDeclaration): identifier: ID diff --git a/rflx/model/type_.py b/rflx/model/type_.py index ac3bf3478..a87b71564 100644 --- a/rflx/model/type_.py +++ b/rflx/model/type_.py @@ -10,7 +10,7 @@ import rflx.typing_ as rty from rflx import const, expression as expr from rflx.common import indent_next, verbose_repr -from rflx.error import Location, Severity, Subsystem, fail +from rflx.error import Location, RecordFluxError, Severity, Subsystem, fail from rflx.identifier import ID, StrID from . import message @@ -831,3 +831,44 @@ def qualified_type_names(types: abc.Iterable[Type]) -> dict[ID, Type]: return { t.identifier.name if t.package == const.BUILTINS_PACKAGE else t.identifier: t for t in types } + + +def check_identifier_notation( + expressions: abc.Iterable[expr.Expr], + identifiers: abc.Iterable[ID], +) -> RecordFluxError: + id_map = {i: i for i in identifiers} + + def verify_identifier_notation(expression: expr.Expr, error: RecordFluxError) -> expr.Expr: + if ( + isinstance(expression, (expr.Variable, expr.Literal, expr.TypeName, expr.Call)) + and expression.identifier in id_map + and str(expression.identifier) != str(id_map[expression.identifier]) + ): + declaration_location = id_map[expression.identifier].location + error.extend( + [ + ( + f'casing of "{expression.identifier}" differs from casing in the' + f' declaration of "{id_map[expression.identifier]}"' + + (f" at {declaration_location.short}" if declaration_location else ""), + Subsystem.MODEL, + Severity.ERROR, + expression.identifier.location, + ), + ( + f'declaration of "{id_map[expression.identifier]}"', + Subsystem.MODEL, + Severity.INFO, + declaration_location, + ), + ], + ) + return expression + + error = RecordFluxError() + + for e in expressions: + e.substituted(lambda e: verify_identifier_notation(e, error)) + + return error diff --git a/tests/conftest.py b/tests/conftest.py index cf6ddc1e4..35dad2635 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,11 +65,7 @@ def pytest_assertrepr_compare(op: str, left: object, right: object) -> Sequence[ " Actual: " + re.sub(r"\n +", " ", str(left)), " Expected: " + re.sub(r"\n +", " ", str(right)), ] - if ( - isinstance(left, model.AbstractSession) - and isinstance(right, model.AbstractSession) - and op == "==" - ): + if isinstance(left, model.Session) and isinstance(right, model.Session) and op == "==": return [ "Session instances", "repr:", diff --git a/tests/unit/model/message_test.py b/tests/unit/model/message_test.py index 16ca6a1a8..2cf13b1b9 100644 --- a/tests/unit/model/message_test.py +++ b/tests/unit/model/message_test.py @@ -2041,7 +2041,8 @@ def test_invalid_use_of_message_attributes() -> None: ) -def test_identifier_normalization() -> None: +def test_message_identifier_normalization(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(Message, "_check_identifiers", lambda _, _s, _t: None) assert str( Message( "P::M", @@ -2096,6 +2097,79 @@ def test_identifier_normalization() -> None: ) +def test_message_inconsistent_identifier_casing() -> None: + with pytest.raises( + RecordFluxError, + match=( + r"^" + r':1:1: model: error: casing of "b" differs from casing' + r' in the declaration of "B" at :11:11\n' + r':11:11: model: info: declaration of "B"\n' + r':2:2: model: error: casing of "c" differs from casing' + r' in the declaration of "C" at :12:12\n' + r':12:12: model: info: declaration of "C"\n' + r':3:3: model: error: casing of "p::e" differs from cas' + r'ing in the declaration of "P::E" at :13:13\n' + r':13:13: model: info: declaration of "P::E"\n' + r':4:4: model: error: casing of "g" differs from casing' + r' in the declaration of "G" at :14:14\n' + r':14:14: model: info: declaration of "G"\n' + r':5:5: model: error: casing of "f" differs from casing' + r' in the declaration of "F" at :15:15\n' + r':15:15: model: info: declaration of "F"\n' + r':6:6: model: error: casing of "f" differs from casing' + r' in the declaration of "F" at :15:15\n' + r':15:15: model: info: declaration of "F"\n' + r':7:7: model: error: casing of "p::e" differs from casing' + r' in the declaration of "P::E" at :13:13\n' + r':13:13: model: info: declaration of "P::E"' + r"$" + ), + ): + Message( + "P::M", + [ + Link(INITIAL, Field("F")), + Link( + Field("F"), + Field(ID("g", location=Location((4, 4)))), + Equal(Variable(ID("f", location=Location((5, 5)))), Literal("A")), + Add( + Size(ID("f", location=Location((6, 6)))), + Size(ID("p::e", location=Location((7, 7)))), + ), + ), + Link( + Field("F"), + FINAL, + Equal(Variable("F"), Literal(ID("b", location=Location((1, 1))))), + ), + Link( + Field("F"), + FINAL, + And( + Equal(Variable("F"), Literal(ID("c", location=Location((2, 2))))), + Equal(Size(ID("p::e", location=Location((3, 3)))), Number(8)), + ), + ), + Link(Field("G"), FINAL), + ], + { + Field(ID("F", location=Location((15, 15)))): Enumeration( + ID("P::E", location=Location((13, 13))), + [ + ("A", Number(0)), + (ID("B", location=Location((11, 11))), Number(1)), + (ID("C", location=Location((12, 12))), Number(2)), + ], + Number(8), + always_valid=False, + ), + Field(ID("G", location=Location((14, 14)))): OPAQUE, + }, + ) + + def test_no_path_to_final() -> None: structure = [ Link(INITIAL, Field("F1")), @@ -4649,7 +4723,8 @@ def test_message_str() -> None: ) -def test_refinement_identifier_normalization() -> None: +def test_refinement_identifier_normalization(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(Refinement, "_check_identifiers", lambda _: None) assert str( Refinement( "R", @@ -4675,6 +4750,47 @@ def test_refinement_identifier_normalization() -> None: ) +def test_refinement_inconsistent_identifier_casing() -> None: + with pytest.raises( + RecordFluxError, + match=( + r"^" + r':1:1: model: error: casing of "value" differs from casing' + r' in the declaration of "Value"\n' + r'model: info: declaration of "Value"\n' + r':2:2: model: error: casing of "tag" differs from casing' + r' in the declaration of "Tag"\n' + r'model: info: declaration of "Tag"\n' + r':3:3: model: error: casing of "tlv::msg_data" differs from casing' + r' in the declaration of "TLV::Msg_Data"\n' + r'model: info: declaration of "TLV::Msg_Data"\n' + r':4:4: model: error: casing of "length" differs from casing' + r' in the declaration of "Length"\n' + r'model: info: declaration of "Length"\n' + r':5:5: model: error: casing of "tlv::length" differs from casing' + r' in the declaration of "TLV::Length"\n' + r'model: info: declaration of "TLV::Length"' + r"$" + ), + ): + Refinement( + "R", + models.tlv_message(), + Field(ID("value", location=Location((1, 1)))), + models.tlv_message(), + And( + Equal( + Variable(ID("tag", location=Location((2, 2)))), + Variable(ID("tlv::msg_data", location=Location((3, 3)))), + ), + Equal( + Variable(ID("length", location=Location((4, 4)))), + Size(ID("tlv::length", location=Location((5, 5)))), + ), + ), + ) + + def test_refinement_dependencies() -> None: assert models.universal_refinement().direct_dependencies == [ models.universal_message(), diff --git a/tests/unit/model/session_test.py b/tests/unit/model/session_test.py index 7f7f6c906..24ffc0035 100644 --- a/tests/unit/model/session_test.py +++ b/tests/unit/model/session_test.py @@ -108,7 +108,8 @@ def test_str() -> None: ) -def test_identifier_normalization() -> None: +def test_identifier_normalization(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(Session, "_check_identifiers", lambda _: None) assert str( Session( "P::S", @@ -184,6 +185,120 @@ def test_identifier_normalization() -> None: ) +def test_inconsistent_identifier_casing() -> None: + with pytest.raises( + RecordFluxError, + match=( + r"^" + r':1:1: model: error: casing of "tlv::message" differs from casing' + r' in the declaration of "TLV::Message"\n' + r'model: info: declaration of "TLV::Message"\n' + r':2:2: model: error: casing of "x" differs from casing' + r' in the declaration of "X" at :12:12\n' + r':12:12: model: info: declaration of "X"\n' + r':3:3: model: error: casing of "m" differs from casing' + r' in the declaration of "M" at :13:13\n' + r':13:13: model: info: declaration of "M"\n' + r':4:4: model: error: casing of "b" differs from casing' + r' in the declaration of "B" at :14:14\n' + r':14:14: model: info: declaration of "B"\n' + r':5:5: model: error: casing of "y" differs from casing' + r' in the declaration of "Y" at :15:15\n' + r':15:15: model: info: declaration of "Y"\n' + r':6:6: model: error: casing of "z" differs from casing' + r' in the declaration of "Z" at :16:16\n' + r':16:16: model: info: declaration of "Z"\n' + r':7:7: model: error: casing of "g" differs from casing' + r' in the declaration of "G" at :17:17\n' + r':17:17: model: info: declaration of "G"\n' + r':8:8: model: error: casing of "f" differs from casing' + r' in the declaration of "F" at :18:18\n' + r':18:18: model: info: declaration of "F"\n' + r':9:9: model: error: casing of "a" differs from casing' + r' in the declaration of "A" at :19:19\n' + r':19:19: model: info: declaration of "A"' + r"$" + ), + ): + Session( + "P::S", + [ + State( + ID("A", location=Location((19, 19))), + declarations=[], + actions=[ + stmt.Read( + ID("x", location=Location((2, 2))), + expr.Variable(ID("m", location=Location((3, 3)))), + ), + ], + transitions=[ + Transition(ID("b", location=Location((4, 4)))), + ], + ), + State( + ID("B", location=Location((14, 14))), + declarations=[ + decl.VariableDeclaration( + ID("Z", location=Location((16, 16))), + BOOLEAN.identifier, + expr.Variable(ID("y", location=Location((5, 5)))), + ), + ], + actions=[], + transitions=[ + Transition( + "null", + condition=expr.And( + expr.Equal( + expr.Variable(ID("z", location=Location((6, 6)))), + expr.TRUE, + ), + expr.Equal( + expr.Call( + ID("g", location=Location((7, 7))), + [expr.Variable(ID("f", location=Location((8, 8))))], + ), + expr.TRUE, + ), + ), + ), + Transition(ID("a", location=Location((9, 9)))), + ], + ), + ], + [ + decl.VariableDeclaration( + ID("M", location=Location((13, 13))), + ID("tlv::message", location=Location((1, 1))), + ), + decl.VariableDeclaration( + ID("Y", location=Location((15, 15))), + BOOLEAN.identifier, + expr.FALSE, + ), + ], + [ + decl.ChannelDeclaration( + ID("X", location=Location((12, 12))), + readable=True, + writable=True, + ), + decl.FunctionDeclaration( + ID("F", location=Location((18, 18))), + [], + BOOLEAN.identifier, + ), + decl.FunctionDeclaration( + ID("G", location=Location((17, 17))), + [decl.Argument("P", BOOLEAN.identifier)], + BOOLEAN.identifier, + ), + ], + [BOOLEAN, models.tlv_message()], + ) + + def test_invalid_name() -> None: with pytest.raises( RecordFluxError,