diff --git a/imap/strong_types.go b/imap/strong_types.go index d8f40899..f35afc83 100644 --- a/imap/strong_types.go +++ b/imap/strong_types.go @@ -4,7 +4,6 @@ import ( "fmt" "strconv" - "github.com/ProtonMail/gluon/internal/utils" "github.com/google/uuid" ) @@ -13,11 +12,11 @@ type MailboxID string type MessageID string func (l MailboxID) ShortID() string { - return utils.ShortID(string(l)) + return ShortID(string(l)) } func (m MessageID) ShortID() string { - return utils.ShortID(string(m)) + return ShortID(string(m)) } type InternalMessageID struct { @@ -31,7 +30,7 @@ func (i InternalMailboxID) ShortID() string { } func (i InternalMessageID) ShortID() string { - return utils.ShortID(i.String()) + return ShortID(i.String()) } func (i InternalMailboxID) String() string { diff --git a/imap/update_mailbox_created.go b/imap/update_mailbox_created.go index b5b7948e..99919e29 100644 --- a/imap/update_mailbox_created.go +++ b/imap/update_mailbox_created.go @@ -3,8 +3,6 @@ package imap import ( "fmt" "strings" - - "github.com/ProtonMail/gluon/internal/utils" ) type MailboxCreated struct { @@ -26,6 +24,6 @@ func (u *MailboxCreated) String() string { return fmt.Sprintf( "MailboxCreated: Mailbox.ID = %v, Mailbox.Name = %v", u.Mailbox.ID.ShortID(), - utils.ShortID(strings.Join(u.Mailbox.Name, "/")), + ShortID(strings.Join(u.Mailbox.Name, "/")), ) } diff --git a/imap/update_mailbox_updated.go b/imap/update_mailbox_updated.go index fdea14cf..76578da5 100644 --- a/imap/update_mailbox_updated.go +++ b/imap/update_mailbox_updated.go @@ -3,8 +3,6 @@ package imap import ( "fmt" "strings" - - "github.com/ProtonMail/gluon/internal/utils" ) type MailboxUpdated struct { @@ -28,6 +26,6 @@ func (u *MailboxUpdated) String() string { return fmt.Sprintf( "MailboxUpdated: MailboxID = %v, MailboxName = %v", u.MailboxID.ShortID(), - utils.ShortID(strings.Join(u.MailboxName, "/")), + ShortID(strings.Join(u.MailboxName, "/")), ) } diff --git a/imap/utils.go b/imap/utils.go new file mode 100644 
index 00000000..c6b1b51c --- /dev/null +++ b/imap/utils.go @@ -0,0 +1,12 @@ +package imap + +// ShortID returns a string containing a short version of the given ID. Use only for debug display. +func ShortID(id string) string { + const l = 12 + + if len(id) < l { + return id + } + + return id[0:l] + "..." +} diff --git a/internal/backend/state_user_interface_impl.go b/internal/backend/state_user_interface_impl.go index 4a7a4c7f..aebe3dd7 100644 --- a/internal/backend/state_user_interface_impl.go +++ b/internal/backend/state_user_interface_impl.go @@ -2,6 +2,7 @@ package backend import ( "context" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db" @@ -89,3 +90,7 @@ func (s *StateUserInterfaceImpl) GetRecoveryMailboxID() ids.MailboxIDPair { RemoteID: ids.GluonInternalRecoveryMailboxRemoteID, } } + +func (s *StateUserInterfaceImpl) GetRecoveredMessageHashesMap() *utils.MessageHashesMap { + return s.u.recoveredMessageHashes +} diff --git a/internal/backend/user.go b/internal/backend/user.go index 811f3278..76604b18 100644 --- a/internal/backend/user.go +++ b/internal/backend/user.go @@ -3,6 +3,7 @@ package backend import ( "context" "fmt" + "github.com/ProtonMail/gluon/internal/utils" "sync" "github.com/ProtonMail/gluon/async" @@ -45,6 +46,8 @@ type user struct { uidValidityGenerator imap.UIDValidityGenerator panicHandler async.PanicHandler + + recoveredMessageHashes *utils.MessageHashesMap } func newUser( @@ -62,6 +65,8 @@ func newUser( return nil, err } + recoveredMessageHashes := utils.NewMessageHashesMap() + // Create recovery mailbox if it does not exist recoveryMBox, err := db.WriteResult(ctx, database, func(ctx context.Context, tx *ent.Tx) (*ent.Mailbox, error) { uidValidity, err := uidValidityGenerator.Generate() @@ -78,7 +83,30 @@ func newUser( Attributes: imap.NewFlagSet(imap.AttrNoInferiors), } - return db.GetOrCreateMailbox(ctx, tx, mbox, delimiter, uidValidity) + recoveryMBox, err := 
db.GetOrCreateMailbox(ctx, tx, mbox, delimiter, uidValidity) + if err != nil { + return nil, err + } + + // Pre-fill the message hashes map + messages, err := db.GetMailboxMessageIDPairs(ctx, tx.Client(), recoveryMBox.ID) + if err != nil { + return nil, err + } + + for _, m := range messages { + literal, err := st.Get(m.InternalID) + if err != nil { + logrus.WithError(err).Errorf("Failed to load %v for store for recovered message hashes map", m.InternalID) + continue + } + + if _, err := recoveredMessageHashes.Insert(m.InternalID, literal); err != nil { + logrus.WithError(err).Errorf("Failed insert literal for %v into recovered message hashes map", m.InternalID) + } + } + + return recoveryMBox, nil }) if err != nil { return nil, err @@ -104,6 +132,8 @@ func newUser( uidValidityGenerator: uidValidityGenerator, panicHandler: panicHandler, + + recoveredMessageHashes: recoveredMessageHashes, } if err := user.deleteAllMessagesMarkedDeleted(ctx); err != nil { diff --git a/internal/state/actions.go b/internal/state/actions.go index d1fb6973..797d0d19 100644 --- a/internal/state/actions.go +++ b/internal/state/actions.go @@ -201,6 +201,12 @@ func (state *State) actionCreateRecoveredMessage( return err } + alreadyKnown, err := state.user.GetRecoveredMessageHashesMap().Insert(internalID, literal) + if err == nil && alreadyKnown { + // Message is already known to us, so we ignore it. + return nil + } + if err := state.user.GetStore().SetUnchecked(internalID, bytes.NewReader(literal)); err != nil { return fmt.Errorf("failed to store message literal: %w", err) } @@ -436,6 +442,8 @@ func (state *State) actionMoveMessagesOutOfRecoveryMailbox( return nil, err } + state.user.GetRecoveredMessageHashesMap().Erase(oldInternalIDs...) + updates = append(updates, removeUpdates...) 
} @@ -470,6 +478,8 @@ func (state *State) actionRemoveMessagesFromMailboxUnchecked( if err := state.user.GetRemote().RemoveMessagesFromMailbox(ctx, remoteIDs, mboxID.RemoteID); err != nil { return err } + } else { + state.user.GetRecoveredMessageHashesMap().Erase(internalIDs...) } updates, err := RemoveMessagesFromMailbox(ctx, tx, mboxID.InternalID, internalIDs) diff --git a/internal/state/user_interface.go b/internal/state/user_interface.go index be1e126d..f7174f8b 100644 --- a/internal/state/user_interface.go +++ b/internal/state/user_interface.go @@ -2,6 +2,7 @@ package state import ( "context" + "github.com/ProtonMail/gluon/internal/utils" "github.com/ProtonMail/gluon/imap" "github.com/ProtonMail/gluon/internal/db" @@ -31,4 +32,6 @@ type UserInterface interface { GetRecoveryMailboxID() ids.MailboxIDPair GenerateUIDValidity() (imap.UID, error) + + GetRecoveredMessageHashesMap() *utils.MessageHashesMap } diff --git a/internal/utils/message_hashmap.go b/internal/utils/message_hashmap.go new file mode 100644 index 00000000..c19721f1 --- /dev/null +++ b/internal/utils/message_hashmap.go @@ -0,0 +1,61 @@ +package utils + +import ( + "crypto/sha256" + "encoding/hex" + "github.com/ProtonMail/gluon/imap" + "sync" +) + +// MessageHashesMap tracks the hashes for a literal and its associated internal IMAP ID. +type MessageHashesMap struct { + lock sync.Mutex + idToHash map[imap.InternalMessageID]string + hashes map[string]struct{} +} + +func NewMessageHashesMap() *MessageHashesMap { + return &MessageHashesMap{ + idToHash: make(map[imap.InternalMessageID]string), + hashes: make(map[string]struct{}), + } +} + +// Insert inserts the hash of the current message literal into the map and returns true if an existing value was already 
+func (m *MessageHashesMap) Insert(id imap.InternalMessageID, literal []byte) (bool, error) { + hash := sha256.New() + + if _, err := hash.Write(literal); err != nil { + return false, err + } + + literalHash := hash.Sum(nil) + literalHashStr := hex.EncodeToString(literalHash) + + m.lock.Lock() + defer m.lock.Unlock() + + if _, ok := m.hashes[literalHashStr]; ok { + return true, nil + } + + m.idToHash[id] = literalHashStr + m.hashes[literalHashStr] = struct{}{} + + return false, nil +} + +// Erase removes the info associated with a given id. +func (m *MessageHashesMap) Erase(ids ...imap.InternalMessageID) { + m.lock.Lock() + defer m.lock.Unlock() + + for _, id := range ids { + if v, ok := m.idToHash[id]; ok { + delete(m.hashes, v) + } + + delete(m.idToHash, id) + } +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 85ee1ad4..8e8cd68b 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -20,17 +20,6 @@ func NewRandomMessageID() string { return "msg-" + uuid.NewString() } -// ShortID return a string containing a short version of the given ID. Use only for debug display. -func ShortID(id string) string { - const l = 12 - - if len(id) < l { - return id - } - - return id[0:l] + "..." -} - // ErrCause returns the cause of the error, the inner-most error in the wrapped chain. 
func ErrCause(err error) error { cause := err diff --git a/tests/recovery_mailbox_test.go b/tests/recovery_mailbox_test.go index 1494dad6..e18490c3 100644 --- a/tests/recovery_mailbox_test.go +++ b/tests/recovery_mailbox_test.go @@ -200,6 +200,44 @@ func TestFailedAppendEndsInRecovery(t *testing.T) { }) } +func TestFailedAppendAreDedupedInRecoveryMailbox(t *testing.T) { + runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withConnectorBuilder(&failAppendLabelConnectorBuilder{})), func(client *client.Client, s *testSession) { + { + status, err := client.Status(ids.GluonRecoveryMailboxName, []goimap.StatusItem{goimap.StatusMessages}) + require.NoError(t, err) + require.Equal(t, uint32(0), status.Messages) + } + + status, err := client.Select("INBOX", false) + require.NoError(t, err) + require.Equal(t, uint32(0), status.Messages) + require.Error(t, doAppendWithClient(client, "INBOX", "To: Foo@bar.com", time.Now())) + require.Error(t, doAppendWithClient(client, "INBOX", "To: Foo@bar.com", time.Now())) + require.Error(t, doAppendWithClient(client, "INBOX", "To: Bar@bar.com", time.Now())) + + { + status, err := client.Status(ids.GluonRecoveryMailboxName, []goimap.StatusItem{goimap.StatusMessages}) + require.NoError(t, err) + require.Equal(t, uint32(2), status.Messages) + } + { + status, err := client.Status("INBOX", []goimap.StatusItem{goimap.StatusMessages}) + require.NoError(t, err) + require.Equal(t, uint32(0), status.Messages) + } + + { + _, err := client.Select(ids.GluonRecoveryMailboxName, false) + require.NoError(t, err) + // Check that no custom headers are appended to the message. 
+ newFetchCommand(t, client).withItems("BODY[]").fetch("1").forSeqNum(1, func(builder *validatorBuilder) { + builder.ignoreFlags() + builder.wantSection("BODY[]", "To: Foo@bar.com") + }).checkAndRequireMessageCount(1) + } + }) +} + func TestRecoveryMBoxCanBeCopiedOutOfDedup(t *testing.T) { runOneToOneTestClientWithAuth(t, defaultServerOptions(t, withConnectorBuilder(&recoveryDedupConnectorConnectorBuilder{})), func(client *client.Client, s *testSession) { // Insert first message, fails.