Skip to content

Commit

Permalink
symbole cleaning in session package
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Apr 17, 2024
1 parent 89635ac commit 2d7c77c
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
4 changes: 2 additions & 2 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,14 @@ func (node *Node) IsHelperNode() bool {
// SessionProvider interface implementation

// GetSessionFromID returns the session with the given ID.
func (node *Node) GetSessionFromID(sessionID session.SessionID) (*session.Session, bool) {
func (node *Node) GetSessionFromID(sessionID session.ID) (*session.Session, bool) {
return node.sessions.GetSessionFromID(sessionID)
}

// GetSessionFromContext returns the session by extracting the session id from the
// provided context.
func (node *Node) GetSessionFromContext(ctx context.Context) (*session.Session, bool) {
sessID, has := session.SessionIDFromContext(ctx)
sessID, has := session.IDFromContext(ctx)
if !has {
return nil, false
}
Expand Down
2 changes: 1 addition & 1 deletion services/compute/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ func (s *Service) PutCiphertext(ctx context.Context, ct session.Ciphertext) erro

_, exists := s.sessions.GetSessionFromContext(ctx)
if !exists {
sessid, _ := session.SessionIDFromContext(ctx)
sessid, _ := session.IDFromContext(ctx)
return fmt.Errorf("invalid session id \"%s\"", sessid)
}

Expand Down
8 changes: 4 additions & 4 deletions services/setup/resultbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func newObjStoreResultBackend(os objectstore.ObjectStore) *objStoreResultBackend

func (osrb objStoreResultBackend) Has(ctx context.Context, sig protocol.Signature) (has bool, err error) {

sessid, has := session.SessionIDFromContext(ctx)
sessid, has := session.IDFromContext(ctx)
if !has {
return false, fmt.Errorf("session id not found in context")
}
Expand All @@ -39,7 +39,7 @@ func (osrb objStoreResultBackend) Has(ctx context.Context, sig protocol.Signatur

func (osrb objStoreResultBackend) Put(ctx context.Context, pd protocol.Descriptor, aggShare protocol.Share) error {

sessid, has := session.SessionIDFromContext(ctx)
sessid, has := session.IDFromContext(ctx)
if !has {
return fmt.Errorf("session id not found in context")
}
Expand All @@ -54,7 +54,7 @@ func (osrb objStoreResultBackend) Put(ctx context.Context, pd protocol.Descripto
}

func (osrb objStoreResultBackend) GetShare(ctx context.Context, sig protocol.Signature, share protocol.LattigoShare) (err error) { // TODO: replace by binary unmasharller and remove interface
sessid, has := session.SessionIDFromContext(ctx)
sessid, has := session.IDFromContext(ctx)
if !has {
return fmt.Errorf("session id not found in context")
}
Expand All @@ -64,7 +64,7 @@ func (osrb objStoreResultBackend) GetShare(ctx context.Context, sig protocol.Sig
}

func (osrb objStoreResultBackend) GetProtocolDesc(ctx context.Context, sig protocol.Signature) (pd *protocol.Descriptor, err error) {
sessid, has := session.SessionIDFromContext(ctx)
sessid, has := session.IDFromContext(ctx)
if !has {
return nil, fmt.Errorf("session id not found in context")
}
Expand Down
10 changes: 5 additions & 5 deletions session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var (
)

// NewContext returns a new context derived from ctx with the given session and circuit IDs.
func NewContext(ctx context.Context, sessID SessionID, circID ...CircuitID) context.Context {
func NewContext(ctx context.Context, sessID ID, circID ...CircuitID) context.Context {
ctx = context.WithValue(ctx, CtxSessionID, sessID)
if len(circID) != 0 {
ctx = ContextWithCircuitID(ctx, circID[0])
Expand All @@ -27,7 +27,7 @@ func NewContext(ctx context.Context, sessID SessionID, circID ...CircuitID) cont

// NewBackgroundContext returns a new context derived from context.Background with
// the given session and circuit IDs.
func NewBackgroundContext(sessID SessionID, circID ...CircuitID) context.Context {
func NewBackgroundContext(sessID ID, circID ...CircuitID) context.Context {
return NewContext(context.Background(), sessID, circID...)
}

Expand All @@ -47,9 +47,9 @@ func NodeIDFromContext(ctx context.Context) (NodeID, bool) {
return nid, ok
}

// SessionIDFromContext returns the session ID from the context.
func SessionIDFromContext(ctx context.Context) (SessionID, bool) {
sessID, ok := ctx.Value(CtxSessionID).(SessionID)
// IDFromContext returns the session ID from the context.
func IDFromContext(ctx context.Context) (ID, bool) {
sessID, ok := ctx.Value(CtxSessionID).(ID)
return sessID, ok
}

Expand Down
10 changes: 5 additions & 5 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ type FHEParameters interface { // TODO: Lattigo could have a common interface fo
// NodeID is the unique identifier of a node.
type NodeID string

// SessionID is the unique identifier of a session.
type SessionID string
// ID is the unique identifier of a session.
type ID string

// CircuitID is the unique identifier of a running circuit.
type CircuitID string
Expand Down Expand Up @@ -52,7 +52,7 @@ type FHEParamerersLiteralProvider interface {

// Parameters contains data used to initialize a Session.
type Parameters struct {
ID SessionID
ID ID
Nodes []NodeID
FHEParameters FHEParamerersLiteralProvider
Threshold int
Expand Down Expand Up @@ -249,15 +249,15 @@ func (sess *Session) Contains(nodeID NodeID) bool {
return utils.NewSet(sess.Nodes).Contains(nodeID)
}

func (sess *Session) GetSessionFromID(sessionID SessionID) (*Session, bool) {
func (sess *Session) GetSessionFromID(sessionID ID) (*Session, bool) {
if sess.ID == sessionID {
return sess, true
}
return nil, false
}

func (sess *Session) GetSessionFromContext(ctx context.Context) (*Session, bool) {
sessID, has := SessionIDFromContext(ctx)
sessID, has := IDFromContext(ctx)
if !has {
return nil, false
}
Expand Down
8 changes: 4 additions & 4 deletions session/sessionstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ import (
)

type SessionProvider interface {
GetSessionFromID(sessionID SessionID) (*Session, bool)
GetSessionFromID(sessionID ID) (*Session, bool)
GetSessionFromContext(ctx context.Context) (*Session, bool)
}

type SessionStore struct {
lock sync.RWMutex
sessions map[SessionID]*Session
sessions map[ID]*Session
}

func NewSessionStore() *SessionStore {
ss := new(SessionStore)
ss.sessions = make(map[SessionID]*Session)
ss.sessions = make(map[ID]*Session)
return ss
}

Expand All @@ -38,7 +38,7 @@ func (s *SessionStore) NewRLWESession(sessParams Parameters, nodeID NodeID) (ses
return sess, err
}

func (s *SessionStore) GetSessionFromID(id SessionID) (*Session, bool) {
func (s *SessionStore) GetSessionFromID(id ID) (*Session, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
sess, ok := s.sessions[id]
Expand Down
6 changes: 3 additions & 3 deletions transport/centralized/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var (
func getOutgoingContext(ctx context.Context, senderID session.NodeID) context.Context {
md := metadata.New(nil)
md.Append("sender_id", string(senderID))
if sessID, hasSessID := session.SessionIDFromContext(ctx); hasSessID {
if sessID, hasSessID := session.IDFromContext(ctx); hasSessID {
md.Append(string(ctxSessionID), string(sessID))
}
if circID, hasCircID := session.CircuitIDFromContext(ctx); hasCircID {
Expand Down Expand Up @@ -58,8 +58,8 @@ func senderIDFromIncomingContext(ctx context.Context) session.NodeID {
return session.NodeID(valueFromIncomingContext(ctx, "sender_id"))
}

func sessionIDFromIncomingContext(ctx context.Context) session.SessionID {
return session.SessionID(valueFromIncomingContext(ctx, string(ctxSessionID)))
func sessionIDFromIncomingContext(ctx context.Context) session.ID {
return session.ID(valueFromIncomingContext(ctx, string(ctxSessionID)))
}

func circuitIDFromIncomingContext(ctx context.Context) session.CircuitID {
Expand Down

0 comments on commit 2d7c77c

Please sign in to comment.