Skip to content

Commit

Permalink
fix(GODT-1757): Remove Snapshot locks
Browse files Browse the repository at this point in the history
Prior to this patch the updates to other session states were applied
immediately by going through all the available states in `backend.user`
and applying the changes immediately.

This patch changes the architecture so that all state updates for other
session states are queued via a channel rather than executed
immediately. This removes the need for snapshot lock, snapshot messages
lock, IDLE lock and responder lock as we no longer have two sessions
accessing state from one another.

To ensure the current state is still update, we mark all context with
the stateID (when available) and if the stateID matches one of the
states in the list, the update is applied immediately.

The downside is that every state now has to verify that this update can
be executed locally rather than checking in advance.

Finally this patch also fixes some update code to remove unnecessary
memory duplication.

Note: The `userWrapper` is a temporary type that will be remove in the
next patch.
  • Loading branch information
LBeernaertProton committed Aug 30, 2022
1 parent 5014519 commit 6347efd
Show file tree
Hide file tree
Showing 18 changed files with 696 additions and 516 deletions.
35 changes: 35 additions & 0 deletions internal/backend/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,38 @@ func AsSilent(parent context.Context) context.Context {
func isSilent(ctx context.Context) bool {
return ctx.Value(handleSilentKey) != nil
}

type handleRemoteUpdateCtxType struct{}

var handleRemoteUpdateCtxKey handleRemoteUpdateCtxType

func isRemoteUpdateCtx(ctx context.Context) bool {
return ctx.Value(handleRemoteUpdateCtxKey) != nil
}

func NewRemoteUpdateCtx(ctx context.Context) context.Context {
return context.WithValue(ctx, handleRemoteUpdateCtxKey, struct{}{})
}

type stateContextType struct{}

var stateContextKey stateContextType

func NewStateContext(ctx context.Context, s *State) context.Context {
if s == nil {
return ctx
}

return context.WithValue(ctx, stateContextKey, s.stateID)
}

func getStateIDFromContext(ctx context.Context) (int, bool) {
v := ctx.Value(stateContextKey)
if v == nil {
return 0, false
}

stateID, ok := v.(int)

return stateID, ok
}
85 changes: 27 additions & 58 deletions internal/backend/mailbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,21 @@ type Mailbox struct {
mbox *ent.Mailbox

state *State
snap *snapshotWrapper
snap *snapshot

selected bool
readOnly bool
}

func newMailbox(mbox *ent.Mailbox, state *State, wrapper *snapshotWrapper) *Mailbox {
selected := snapshotRead(wrapper, func(s *snapshot) bool {
return s != nil
})

func newMailbox(mbox *ent.Mailbox, state *State, snap *snapshot) *Mailbox {
return &Mailbox{
mbox: mbox,

state: state,

selected: selected,
selected: snap != nil,
readOnly: state.ro,
snap: wrapper,
snap: snap,
}
}

Expand Down Expand Up @@ -66,9 +62,7 @@ func (m *Mailbox) ExpungeIssued() bool {
}

func (m *Mailbox) Count() int {
return snapshotRead(m.snap, func(s *snapshot) int {
return len(s.getAllMessages())
})
return len(m.snap.getAllMessages())
}

func (m *Mailbox) Flags(ctx context.Context) (imap.FlagSet, error) {
Expand Down Expand Up @@ -117,18 +111,14 @@ func (m *Mailbox) Subscribed() bool {
}

func (m *Mailbox) GetMessagesWithFlag(flag string) []int {
return snapshotRead(m.snap, func(s *snapshot) []int {
return xslices.Map(s.getMessagesWithFlag(flag), func(msg *snapMsg) int {
return msg.Seq
})
return xslices.Map(m.snap.getMessagesWithFlag(flag), func(msg *snapMsg) int {
return msg.Seq
})
}

func (m *Mailbox) GetMessagesWithoutFlag(flag string) []int {
return snapshotRead(m.snap, func(s *snapshot) []int {
return xslices.Map(s.getMessagesWithoutFlag(flag), func(msg *snapMsg) int {
return msg.Seq
})
return xslices.Map(m.snap.getMessagesWithoutFlag(flag), func(msg *snapMsg) int {
return msg.Seq
})
}

Expand All @@ -141,11 +131,11 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet
if len(internalID) > 0 {
msgID := imap.InternalMessageID(internalID)

if exists, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (bool, error) {
if exists, err := DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (bool, error) {
return DBHasMessageWithID(ctx, client, msgID)
}); err != nil || !exists {
logrus.WithError(err).Warn("The message has an unknown internal ID")
} else if res, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
} else if res, err := DBWriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
return m.state.actionAddMessagesToMailbox(ctx, tx, []MessageIDPair{NewMessageIDPairWithoutRemote(msgID)}, NewMailboxIDPair(m.mbox))
}); err != nil {
return 0, err
Expand All @@ -154,29 +144,23 @@ func (m *Mailbox) Append(ctx context.Context, literal []byte, flags imap.FlagSet
}
}

snapMBoxID := snapshotRead(m.snap, func(s *snapshot) MailboxIDPair {
return s.mboxID
})

return DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (int, error) {
return m.state.actionCreateMessage(ctx, tx, snapMBoxID, literal, flags, date)
return DBWriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (int, error) {
return m.state.actionCreateMessage(ctx, tx, m.snap.mboxID, literal, flags, date)
})
}

// Copy copies the messages represented by the given sequence set into the mailbox with the given name.
// If the context is a UID context, the sequence set refers to message UIDs.
// If no items are copied the response object will be nil.
func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string) (response.Item, error) {
mbox, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
mbox, err := DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
return DBGetMailboxByName(ctx, client, name)
})
if err != nil {
return nil, ErrNoSuchMailbox
}

messages, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
messages, err := m.snap.getMessagesInRange(ctx, seq)
if err != nil {
return nil, err
}
Expand All @@ -189,7 +173,7 @@ func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string)
return msg.UID
})

destUIDs, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
destUIDs, err := DBWriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
return m.state.actionAddMessagesToMailbox(ctx, tx, msgIDs, NewMailboxIDPair(mbox))
})
if err != nil {
Expand All @@ -211,20 +195,15 @@ func (m *Mailbox) Copy(ctx context.Context, seq *proto.SequenceSet, name string)
// If the context is a UID context, the sequence set refers to message UIDs.
// If no items are moved the response object will be nil.
func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string) (response.Item, error) {
mbox, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
mbox, err := DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Mailbox, error) {
return DBGetMailboxByName(ctx, client, name)
})

if err != nil {
return nil, ErrNoSuchMailbox
}

var snapMBoxID MailboxIDPair

messages, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
snapMBoxID = s.mboxID
return s.getMessagesInRange(ctx, seq)
})
messages, err := m.snap.getMessagesInRange(ctx, seq)
if err != nil {
return nil, err
}
Expand All @@ -237,8 +216,8 @@ func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string)
return msg.UID
})

destUIDs, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
return m.state.actionMoveMessages(ctx, tx, msgIDs, snapMBoxID, NewMailboxIDPair(mbox))
destUIDs, err := DBWriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]int, error) {
return m.state.actionMoveMessages(ctx, tx, msgIDs, m.snap.mboxID, NewMailboxIDPair(mbox))
})
if err != nil {
return nil, err
Expand All @@ -256,9 +235,7 @@ func (m *Mailbox) Move(ctx context.Context, seq *proto.SequenceSet, name string)
}

func (m *Mailbox) Store(ctx context.Context, seq *proto.SequenceSet, operation proto.Operation, flags imap.FlagSet) error {
messages, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
messages, err := m.snap.getMessagesInRange(ctx, seq)
if err != nil {
return err
}
Expand All @@ -267,7 +244,7 @@ func (m *Mailbox) Store(ctx context.Context, seq *proto.SequenceSet, operation p
return msg.ID
})

return m.state.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
return m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
switch operation {
case proto.Operation_Add:
if _, err := m.state.actionAddMessageFlags(ctx, tx, msgIDs, flags); err != nil {
Expand All @@ -293,18 +270,14 @@ func (m *Mailbox) Expunge(ctx context.Context, seq *proto.SequenceSet) error {
var msg []*snapMsg

if seq != nil {
snapMsgs, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
snapMsgs, err := m.snap.getMessagesInRange(ctx, seq)
if err != nil {
return err
}

msg = snapMsgs
} else {
msg = snapshotRead(m.snap, func(s *snapshot) []*snapMsg {
return s.getAllMessages()
})
msg = m.snap.getAllMessages()
}

return m.expunge(ctx, msg)
Expand All @@ -319,17 +292,13 @@ func (m *Mailbox) expunge(ctx context.Context, messages []*snapMsg) error {
return msg.ID
})

mboxID := snapshotRead(m.snap, func(s *snapshot) MailboxIDPair {
return s.mboxID
})

return m.state.db.Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, mboxID)
return m.state.db().Write(ctx, func(ctx context.Context, tx *ent.Tx) error {
return m.state.actionRemoveMessagesFromMailbox(ctx, tx, msgIDs, m.snap.mboxID)
})
}

func (m *Mailbox) Flush(ctx context.Context, permitExpunge bool) ([]response.Response, error) {
return DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) ([]response.Response, error) {
return DBWriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) ([]response.Response, error) {
return m.state.flushResponses(ctx, tx, permitExpunge)
})
}
Expand Down
8 changes: 3 additions & 5 deletions internal/backend/mailbox_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ import (
)

func (m *Mailbox) Fetch(ctx context.Context, seq *proto.SequenceSet, attributes []*proto.FetchAttribute, ch chan response.Response) error {
msg, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInRange(ctx, seq)
})
msg, err := m.snap.getMessagesInRange(ctx, seq)
if err != nil {
return err
}
Expand All @@ -42,7 +40,7 @@ func (m *Mailbox) fetchItems(ctx context.Context, msg *snapMsg, attributes []*pr
setSeen bool
)

message, err := DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) (*ent.Message, error) {
message, err := DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) (*ent.Message, error) {
return DBGetMessage(ctx, client, msg.ID.InternalID)
})
if err != nil {
Expand Down Expand Up @@ -91,7 +89,7 @@ func (m *Mailbox) fetchItems(ctx context.Context, msg *snapMsg, attributes []*pr
}

if setSeen {
newFlags, err := DBWriteResult(ctx, m.state.db, func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]imap.FlagSet, error) {
newFlags, err := DBWriteResult(ctx, m.state.db(), func(ctx context.Context, tx *ent.Tx) (map[imap.InternalMessageID]imap.FlagSet, error) {
return m.state.actionAddMessageFlags(ctx, tx, []MessageIDPair{msg.ID}, imap.NewFlagSet(imap.FlagSeen))
})
if err != nil {
Expand Down
22 changes: 8 additions & 14 deletions internal/backend/mailbox_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import (
)

func (m *Mailbox) Search(ctx context.Context, keys []*proto.SearchKey, decoder *encoding.Decoder) ([]int, error) {
snapMessages := snapshotRead(m.snap, func(s *snapshot) []*snapMsg {
return s.getAllMessages()
})
snapMessages := m.snap.getAllMessages()

messages, err := doSearch(ctx, m, snapMessages, keys, decoder)
if err != nil {
Expand Down Expand Up @@ -203,7 +201,7 @@ func (m *Mailbox) matchSearchKeyBefore(ctx context.Context, candidates []*snapMs
return nil, err
}

return DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return filter(candidates, func(message *snapMsg) (bool, error) {
msg, err := DBGetMessage(ctx, client, message.ID.InternalID)
if err != nil {
Expand Down Expand Up @@ -324,7 +322,7 @@ func (m *Mailbox) matchSearchKeyKeyword(ctx context.Context, candidates []*snapM
}

func (m *Mailbox) matchSearchKeyLarger(ctx context.Context, candidates []*snapMsg, key *proto.SearchKey) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return filter(candidates, func(message *snapMsg) (bool, error) {
msg, err := DBGetMessage(ctx, client, message.ID.InternalID)
if err != nil {
Expand Down Expand Up @@ -365,7 +363,7 @@ func (m *Mailbox) matchSearchKeyOn(ctx context.Context, candidates []*snapMsg, k
return nil, err
}

return DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return filter(candidates, func(message *snapMsg) (bool, error) {
msg, err := DBGetMessage(ctx, client, message.ID.InternalID)
if err != nil {
Expand Down Expand Up @@ -499,7 +497,7 @@ func (m *Mailbox) matchSearchKeySince(ctx context.Context, candidates []*snapMsg
return nil, err
}

return DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return filter(candidates, func(message *snapMsg) (bool, error) {
msg, err := DBGetMessage(ctx, client, message.ID.InternalID)
if err != nil {
Expand All @@ -514,7 +512,7 @@ func (m *Mailbox) matchSearchKeySince(ctx context.Context, candidates []*snapMsg
}

func (m *Mailbox) matchSearchKeySmaller(ctx context.Context, candidates []*snapMsg, key *proto.SearchKey) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db, func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return DBReadResult(ctx, m.state.db(), func(ctx context.Context, client *ent.Client) ([]*snapMsg, error) {
return filter(candidates, func(message *snapMsg) (bool, error) {
msg, err := DBGetMessage(ctx, client, message.ID.InternalID)
if err != nil {
Expand Down Expand Up @@ -585,9 +583,7 @@ func (m *Mailbox) matchSearchKeyTo(ctx context.Context, candidates []*snapMsg, k
}

func (m *Mailbox) matchSearchKeyUID(ctx context.Context, candidates []*snapMsg, key *proto.SearchKey) ([]*snapMsg, error) {
left, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInUIDRange(key.GetSequenceSet())
})
left, err := m.snap.getMessagesInUIDRange(key.GetSequenceSet())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -634,9 +630,7 @@ func (m *Mailbox) matchSearchKeyUnseen(ctx context.Context, candidates []*snapMs
}

func (m *Mailbox) matchSearchKeySeqSet(ctx context.Context, candidates []*snapMsg, key *proto.SearchKey) ([]*snapMsg, error) {
left, err := snapshotReadErr(m.snap, func(s *snapshot) ([]*snapMsg, error) {
return s.getMessagesInSeqRange(key.GetSequenceSet())
})
left, err := m.snap.getMessagesInSeqRange(key.GetSequenceSet())
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 6347efd

Please sign in to comment.