Skip to content

Commit

Permalink
Share a common data structure among the base service manager and indi…
Browse files Browse the repository at this point in the history
…vidual services for session management to avoid race conditions when verifying sessions' timestamps.
  • Loading branch information
kuujo committed Sep 12, 2017
1 parent 52588b0 commit 429fd2a
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 69 deletions.
Expand Up @@ -30,7 +30,6 @@
import io.atomix.protocols.raft.service.ServiceId;
import io.atomix.protocols.raft.service.ServiceType;
import io.atomix.protocols.raft.session.RaftSession;
import io.atomix.protocols.raft.session.RaftSessionListener;
import io.atomix.protocols.raft.session.RaftSessions;
import io.atomix.protocols.raft.session.SessionId;
import io.atomix.protocols.raft.session.impl.RaftSessionContext;
Expand Down Expand Up @@ -102,7 +101,7 @@ public DefaultServiceContext(
this.serviceType = checkNotNull(serviceType);
this.service = checkNotNull(service);
this.server = checkNotNull(server);
this.sessions = new DefaultServiceSessions(sessionManager);
this.sessions = new DefaultServiceSessions(serviceId, sessionManager);
this.serviceExecutor = new ThreadPoolContext(threadPool);
this.snapshotExecutor = new ThreadPoolContext(threadPool);
this.threadPool = checkNotNull(threadPool);
Expand Down Expand Up @@ -199,22 +198,12 @@ private void expireSessions(long timestamp) {
for (RaftSessionContext session : sessions.getSessions()) {

// If the current timestamp minus the session timestamp is greater than the session timeout, expire the session.
if (timestamp - session.getTimestamp() > session.timeout()) {

// Remove the session from the sessions list.
sessions.remove(session);

long lastUpdated = session.getTimestamp();
if (lastUpdated > 0 && timestamp - lastUpdated > session.timeout()) {
log.debug("Detected expired session {}", session);

// Expire the session.
session.expire();

log.debug("Closing session {}", session.sessionId());

// Iterate through and invoke session listeners.
for (RaftSessionListener listener : sessions.getListeners()) {
listener.onExpire(session);
}
// Remove the session from the sessions list.
sessions.expireSession(session);
}
}
}
Expand Down Expand Up @@ -263,7 +252,6 @@ private void maybeInstallSnapshot(long index) {
ServiceType serviceType = ServiceType.from(reader.readString());
String serviceName = reader.readString();
int sessionCount = reader.readInt();
sessions.clear();
for (int i = 0; i < sessionCount; i++) {
SessionId sessionId = SessionId.from(reader.readLong());
MemberId node = MemberId.from(reader.readString());
Expand All @@ -286,7 +274,7 @@ private void maybeInstallSnapshot(long index) {
session.setEventIndex(reader.readLong());
session.setLastCompleted(reader.readLong());
session.setLastApplied(snapshot.index());
sessions.add(session);
sessions.openSession(session);
}
service.install(reader);
} catch (Exception e) {
Expand Down Expand Up @@ -392,12 +380,7 @@ public CompletableFuture<Long> openSession(long index, long timestamp, RaftSessi
expireSessions(currentTimestamp);

// Add the session to the sessions list.
sessions.add(session);

// Iterate through and invoke session listeners.
for (RaftSessionListener listener : sessions.getListeners()) {
listener.onOpen(session);
}
sessions.openSession(session);

// Commit the index, causing events to be sent to clients if necessary.
commit();
Expand Down Expand Up @@ -539,15 +522,7 @@ public CompletableFuture<Void> closeSession(long index, long timestamp, RaftSess
expireSessions(currentTimestamp);

// Remove the session from the sessions list.
sessions.remove(session);

// Close the session.
session.close();

// Iterate through and invoke session listeners.
for (RaftSessionListener listener : sessions.getListeners()) {
listener.onClose(session);
}
sessions.closeSession(session);

// Commit the index, causing events to be sent to clients if necessary.
commit();
Expand Down
Expand Up @@ -15,29 +15,25 @@
*/
package io.atomix.protocols.raft.service.impl;

import io.atomix.protocols.raft.service.ServiceId;
import io.atomix.protocols.raft.session.RaftSession;
import io.atomix.protocols.raft.session.RaftSessionListener;
import io.atomix.protocols.raft.session.RaftSessions;
import io.atomix.protocols.raft.session.SessionId;
import io.atomix.protocols.raft.session.impl.RaftSessionContext;
import io.atomix.protocols.raft.session.impl.RaftSessionManager;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* State machine sessions.
*/
class DefaultServiceSessions implements RaftSessions {
private final ServiceId serviceId;
private final RaftSessionManager sessionManager;
private final Map<Long, RaftSessionContext> sessions = new ConcurrentHashMap<>();
private final Set<RaftSessionListener> listeners = new HashSet<>();

DefaultServiceSessions(RaftSessionManager sessionManager) {
public DefaultServiceSessions(ServiceId serviceId, RaftSessionManager sessionManager) {
this.serviceId = serviceId;
this.sessionManager = sessionManager;
}

Expand All @@ -46,27 +42,26 @@ class DefaultServiceSessions implements RaftSessions {
*
* @param session The session to add.
*/
void add(RaftSessionContext session) {
sessions.put(session.sessionId().id(), session);
void openSession(RaftSessionContext session) {
sessionManager.registerSession(session);
}

/**
* Removes a session from the sessions list.
* Expires and removes a session from the sessions list.
*
* @param session The session to remove.
*/
void remove(RaftSessionContext session) {
sessions.remove(session.sessionId().id());
sessionManager.unregisterSession(session.sessionId().id());
void expireSession(RaftSessionContext session) {
sessionManager.expireSession(session.sessionId());
}

/**
* Clears the sessions.
* Closes and removes a session from the sessions list.
*
* @param session The session to remove.
*/
void clear() {
sessions.values().forEach(session -> sessionManager.unregisterSession(session.sessionId().id()));
sessions.clear();
void closeSession(RaftSessionContext session) {
sessionManager.closeSession(session.sessionId());
}

/**
Expand All @@ -75,38 +70,29 @@ void clear() {
* @return The session contexts.
*/
Collection<RaftSessionContext> getSessions() {
return sessions.values();
return sessionManager.getSessions(serviceId);
}

@Override
public RaftSession getSession(long sessionId) {
return sessions.get(sessionId);
return sessionManager.getSession(sessionId);
}

@Override
public RaftSessions addListener(RaftSessionListener listener) {
listeners.add(listener);
sessionManager.addListener(serviceId, listener);
return this;
}

@Override
public RaftSessions removeListener(RaftSessionListener listener) {
listeners.remove(listener);
sessionManager.removeListener(serviceId, listener);
return this;
}

/**
* Returns the session listeners.
*
* @return The session listeners.
*/
Collection<RaftSessionListener> getListeners() {
return listeners;
}

@Override
@SuppressWarnings("unchecked")
public Iterator<RaftSession> iterator() {
return (Iterator) sessions.values().iterator();
return (Iterator) getSessions().iterator();
}
}
Expand Up @@ -15,28 +15,72 @@
*/
package io.atomix.protocols.raft.session.impl;

import io.atomix.protocols.raft.service.ServiceId;
import io.atomix.protocols.raft.session.RaftSessionListener;
import io.atomix.protocols.raft.session.SessionId;

import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.stream.Collectors;

/**
* Session manager.
*/
public class RaftSessionManager {
private final Map<Long, RaftSessionContext> sessions = new ConcurrentHashMap<>();
private final Map<ServiceId, Set<RaftSessionListener>> listeners = new ConcurrentHashMap<>();

/**
* Registers a session.
*/
public void registerSession(RaftSessionContext session) {
sessions.putIfAbsent(session.sessionId().id(), session);
if (sessions.putIfAbsent(session.sessionId().id(), session) == null) {
Set<RaftSessionListener> listeners = this.listeners.get(session.getService().serviceId());
if (listeners != null) {
listeners.forEach(l -> l.onOpen(session));
}
}
}

/**
* Expires a session.
*/
public void expireSession(SessionId sessionId) {
RaftSessionContext session = sessions.remove(sessionId.id());
if (session != null) {
Set<RaftSessionListener> listeners = this.listeners.get(session.getService().serviceId());
if (listeners != null) {
listeners.forEach(l -> l.onExpire(session));
}
session.expire();
}
}

/**
* Closes a session.
*/
public void closeSession(SessionId sessionId) {
RaftSessionContext session = sessions.remove(sessionId.id());
if (session != null) {
Set<RaftSessionListener> listeners = this.listeners.get(session.getService().serviceId());
if (listeners != null) {
listeners.forEach(l -> l.onClose(session));
}
session.close();
}
}

/**
* Unregisters a session.
* Gets a session by session ID.
*
* @param sessionId The session ID.
* @return The session or {@code null} if the session doesn't exist.
*/
public void unregisterSession(long sessionId) {
sessions.remove(sessionId);
public RaftSessionContext getSession(SessionId sessionId) {
return getSession(sessionId.id());
}

/**
Expand All @@ -58,4 +102,42 @@ public Collection<RaftSessionContext> getSessions() {
return sessions.values();
}

/**
* Returns a set of sessions associated with the given service.
*
* @param serviceId the service identifier
* @return a collection of sessions associated with the given service
*/
public Collection<RaftSessionContext> getSessions(ServiceId serviceId) {
return sessions.values().stream()
.filter(session -> session.getService().serviceId().equals(serviceId))
.collect(Collectors.toSet());
}

/**
* Adds a session listener.
*
* @param serviceId the service ID for which to listen to sessions
* @param sessionListener the session listener
*/
public void addListener(ServiceId serviceId, RaftSessionListener sessionListener) {
Set<RaftSessionListener> sessionListeners = listeners.computeIfAbsent(serviceId, k -> new CopyOnWriteArraySet<>());
sessionListeners.add(sessionListener);
}

/**
* Removes a session listener.
*
* @param serviceId the service ID with which the listener is associated
* @param sessionListener the session listener
*/
public void removeListener(ServiceId serviceId, RaftSessionListener sessionListener) {
Set<RaftSessionListener> sessionListeners = listeners.get(serviceId);
if (sessionListeners != null) {
sessionListeners.remove(sessionListener);
if (sessionListeners.isEmpty()) {
listeners.remove(serviceId);
}
}
}
}

0 comments on commit 429fd2a

Please sign in to comment.