Skip to content
Permalink
Browse files
ARTEMIS-3770 refactor MQTT handling of client ID
It would be useful for security manager implementations to be able to
alter the client ID of MQTT connections.

This commit supports this functionality by moving the code which handles
the client ID *ahead* of the authentication code. There it sets the
client ID on the connection and thereafter any component (e.g. security
managers) which needs to inspect or modify it can do so on the
connection.

This commit also refactors the MQTT connection class to extend the
abstract connection class. This greatly simplifies the MQTT connection
class and will make it easier to maintain in the future.
  • Loading branch information
jbertram committed May 9, 2022
1 parent e420eb4 commit 446ff61542f47f50c2299d8ef1cae8fe2b98a5ad
Showing 7 changed files with 287 additions and 273 deletions.
@@ -161,20 +161,16 @@ public boolean removeCloseListener(final CloseListener listener) {

@Override
public List<CloseListener> removeCloseListeners() {
List<CloseListener> ret = new ArrayList<>(closeListeners);

List<CloseListener> deletedCloseListeners = new ArrayList<>(closeListeners);
closeListeners.clear();

return ret;
return deletedCloseListeners;
}

@Override
public List<FailureListener> removeFailureListeners() {
List<FailureListener> ret = getFailureListeners();

List<FailureListener> deletedFailureListeners = getFailureListeners();
failureListeners.clear();

return ret;
return deletedFailureListeners;
}

@Override
@@ -19,153 +19,43 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.activemq.artemis.api.core.ActiveMQBuffer;
import org.apache.activemq.artemis.api.core.ActiveMQException;
import org.apache.activemq.artemis.api.core.SimpleString;
import org.apache.activemq.artemis.core.remoting.CloseListener;
import org.apache.activemq.artemis.core.remoting.FailureListener;
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection;
import org.apache.activemq.artemis.spi.core.protocol.AbstractRemotingConnection;
import org.apache.activemq.artemis.spi.core.remoting.Connection;
import org.apache.activemq.artemis.spi.core.remoting.ReadyListener;

import javax.security.auth.Subject;

public class MQTTConnection implements RemotingConnection {

private final Connection transportConnection;

private final long creationTime;

private AtomicBoolean dataReceived;
public class MQTTConnection extends AbstractRemotingConnection {

private boolean destroyed;

private boolean connected;

private String clientID;

private final List<FailureListener> failureListeners = new CopyOnWriteArrayList<>();

private final List<CloseListener> closeListeners = new CopyOnWriteArrayList<>();

private Subject subject;

private int receiveMaximum = -1;

private String protocolVersion;

private boolean clientIdAssignedByBroker = false;

public MQTTConnection(Connection transportConnection) throws Exception {
this.transportConnection = transportConnection;
this.creationTime = System.currentTimeMillis();
this.dataReceived = new AtomicBoolean();
super(transportConnection, null);
this.destroyed = false;
transportConnection.setProtocolConnection(this);
}


@Override
public void scheduledFlush() {
flush();
}

@Override
public boolean isWritable(ReadyListener callback) {
return transportConnection.isWritable(callback) && transportConnection.isOpen();
}

@Override
public Object getID() {
return transportConnection.getID();
}

@Override
public long getCreationTime() {
return creationTime;
}

@Override
public String getRemoteAddress() {
return transportConnection.getRemoteAddress();
}

@Override
public void addFailureListener(FailureListener listener) {
failureListeners.add(listener);
}

@Override
public boolean removeFailureListener(FailureListener listener) {
return failureListeners.remove(listener);
}

@Override
public void addCloseListener(CloseListener listener) {
closeListeners.add(listener);
}

@Override
public boolean removeCloseListener(CloseListener listener) {
return closeListeners.remove(listener);
}

@Override
public List<CloseListener> removeCloseListeners() {
List<CloseListener> deletedCloseListeners = copyCloseListeners();
closeListeners.clear();
return deletedCloseListeners;
}

@Override
public void setCloseListeners(List<CloseListener> listeners) {
closeListeners.clear();
closeListeners.addAll(listeners);
}

@Override
public List<FailureListener> getFailureListeners() {
return failureListeners;
}

@Override
public List<FailureListener> removeFailureListeners() {
List<FailureListener> deletedFailureListeners = copyFailureListeners();
failureListeners.clear();
return deletedFailureListeners;
}

@Override
public void setFailureListeners(List<FailureListener> listeners) {
failureListeners.clear();
failureListeners.addAll(listeners);
}

@Override
public ActiveMQBuffer createTransportBuffer(int size) {
return transportConnection.createTransportBuffer(size);
}

@Override
public void fail(ActiveMQException me) {
List<FailureListener> copy = copyFailureListeners();
List<FailureListener> copy = new ArrayList<>(failureListeners);
for (FailureListener listener : copy) {
listener.connectionFailed(me, false);
}
transportConnection.close();
}

private List<FailureListener> copyFailureListeners() {
return new ArrayList<>(failureListeners);
}

private List<CloseListener> copyCloseListeners() {
return new ArrayList<>(closeListeners);
}

@Override
public void fail(ActiveMQException me, String scaleDownTargetNodeID) {
synchronized (failureListeners) {
@@ -198,11 +88,6 @@ public void destroy() {
disconnect(false);
}

@Override
public Connection getTransportConnection() {
return transportConnection;
}

@Override
public boolean isClient() {
return false;
@@ -224,12 +109,7 @@ public void disconnect(String scaleDownNodeID, boolean criticalError) {
}

protected void dataReceived() {
dataReceived.set(true);
}

@Override
public boolean checkDataReceived() {
return dataReceived.compareAndSet(true, false);
dataReceived = true;
}

@Override
@@ -254,31 +134,11 @@ public void killMessage(SimpleString nodeID) {
//unsupported
}

@Override
public boolean isSupportReconnect() {
return false;
}

@Override
public boolean isSupportsFlowControl() {
return false;
}

@Override
public void setAuditSubject(Subject subject) {
this.subject = subject;
}

@Override
public Subject getAuditSubject() {
return subject;
}

@Override
public Subject getSubject() {
return null;
}

/**
* Returns the name of the protocol for this Remoting Connection
*
@@ -289,26 +149,6 @@ public String getProtocolName() {
return MQTTProtocolManagerFactory.MQTT_PROTOCOL_NAME + (protocolVersion != null ? protocolVersion : "");
}

/**
* Sets the client ID associated with this connection
*
* @param cID
*/
@Override
public void setClientID(String cID) {
this.clientID = cID;
}

/**
* Returns the Client ID associated with this connection
*
* @return
*/
@Override
public String getClientID() {
return clientID;
}

@Override
public String getTransportLocalAddress() {
return getTransportConnection().getLocalAddress();
@@ -325,4 +165,12 @@ public void setReceiveMaximum(int maxReceive) {
public void setProtocolVersion(String protocolVersion) {
this.protocolVersion = protocolVersion;
}

public void setClientIdAssignedByBroker(boolean clientIdAssignedByBroker) {
this.clientIdAssignedByBroker = clientIdAssignedByBroker;
}

public boolean isClientIdAssignedByBroker() {
return clientIdAssignedByBroker;
}
}
@@ -17,15 +17,12 @@

package org.apache.activemq.artemis.core.protocol.mqtt;

import java.util.UUID;
import java.util.List;

import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttProperties;
import io.netty.handler.codec.mqtt.MqttVersion;
import io.netty.util.CharsetUtil;
import org.apache.activemq.artemis.api.core.Pair;
import org.apache.activemq.artemis.api.core.client.ActiveMQClient;
import org.apache.activemq.artemis.core.server.ActiveMQServer;
import org.apache.activemq.artemis.core.server.ServerSession;
@@ -58,7 +55,7 @@ public MQTTConnectionManager(MQTTSession session) {
session.getConnection().addFailureListener(failureListener);
}

void connect(MqttConnectMessage connect, String validatedUser) throws Exception {
void connect(MqttConnectMessage connect, String validatedUser, String username, String password) throws Exception {
if (session.getVersion() == MQTTVersion.MQTT_5) {
session.getConnection().setProtocolVersion(Byte.toString(MqttVersion.MQTT_5.protocolLevel()));
String authenticationMethod = MQTTUtil.getProperty(String.class, connect.variableHeader().properties(), AUTHENTICATION_METHOD);
@@ -70,32 +67,14 @@ void connect(MqttConnectMessage connect, String validatedUser) throws Exception
}
}

String password = connect.payload().passwordInBytes() == null ? null : new String( connect.payload().passwordInBytes(), CharsetUtil.UTF_8);
String username = connect.payload().userName();

// the Netty codec uses "CleanSession" for both 3.1.1 "clean session" and 5 "clean start" which have slightly different semantics
boolean cleanStart = connect.variableHeader().isCleanSession();

Pair<String, Boolean> clientIdValidation = validateClientId(connect.payload().clientIdentifier(), cleanStart);
if (clientIdValidation == null) {
// this represents an invalid client ID for MQTT 5 clients
session.getProtocolHandler().sendConnack(MQTTReasonCodes.CLIENT_IDENTIFIER_NOT_VALID);
disconnect(true);
return;
} else if (clientIdValidation.getA() == null) {
// this represents an invalid client ID for MQTT 3.x clients
session.getProtocolHandler().sendConnack(MQTTReasonCodes.IDENTIFIER_REJECTED_3);
disconnect(true);
return;
}
String clientId = clientIdValidation.getA();
boolean assignedClientId = clientIdValidation.getB();

String clientId = session.getConnection().getClientID();
boolean sessionPresent = session.getProtocolManager().getSessionStates().containsKey(clientId);
MQTTSessionState sessionState = getSessionState(clientId);
synchronized (sessionState) {
session.setSessionState(sessionState);
session.getConnection().setClientID(clientId);
sessionState.setFailed(false);
ServerSessionImpl serverSession = createServerSession(username, password, validatedUser);
serverSession.start();
@@ -143,7 +122,7 @@ void connect(MqttConnectMessage connect, String validatedUser) throws Exception
sessionState.setClientMaxPacketSize(MQTTUtil.getProperty(Integer.class, connect.variableHeader().properties(), MAXIMUM_PACKET_SIZE, 0));
sessionState.setClientTopicAliasMaximum(MQTTUtil.getProperty(Integer.class, connect.variableHeader().properties(), TOPIC_ALIAS_MAXIMUM));

connackProperties = getConnackProperties(clientId, assignedClientId);
connackProperties = getConnackProperties();
} else {
connackProperties = MqttProperties.NO_PROPERTIES;
}
@@ -155,11 +134,11 @@ void connect(MqttConnectMessage connect, String validatedUser) throws Exception
}
}

private MqttProperties getConnackProperties(String clientId, boolean assignedClientId) {
private MqttProperties getConnackProperties() {
MqttProperties connackProperties = new MqttProperties();

if (assignedClientId) {
connackProperties.add(new MqttProperties.StringProperty(ASSIGNED_CLIENT_IDENTIFIER.value(), clientId));
if (this.session.getConnection().isClientIdAssignedByBroker()) {
connackProperties.add(new MqttProperties.StringProperty(ASSIGNED_CLIENT_IDENTIFIER.value(), this.session.getConnection().getClientID()));
}

if (this.session.getProtocolManager().getTopicAliasMaximum() != -1) {
@@ -227,30 +206,4 @@ void disconnect(boolean failure) {
private synchronized MQTTSessionState getSessionState(String clientId) {
return session.getProtocolManager().getSessionState(clientId);
}

private Pair<String, Boolean> validateClientId(String clientId, boolean cleanSession) {
Boolean assigned = Boolean.FALSE;
if (clientId == null || clientId.isEmpty()) {
// [MQTT-3.1.3-7] [MQTT-3.1.3-6] If client does not specify a client ID and clean session is set to 1 create it.
if (cleanSession) {
assigned = Boolean.TRUE;
clientId = UUID.randomUUID().toString();
} else {
// [MQTT-3.1.3-8] Return ID rejected and disconnect if clean session = false and client id is null
return null;
}
} else {
MQTTConnection connection = session.getProtocolManager().addConnectedClient(clientId, session.getConnection());

if (connection != null) {
MQTTSession existingSession = session.getProtocolManager().getSessionState(clientId).getSession();
if (session.getVersion() == MQTTVersion.MQTT_5) {
existingSession.getProtocolHandler().sendDisconnect(MQTTReasonCodes.SESSION_TAKEN_OVER);
}
// [MQTT-3.1.4-2] If the client ID represents a client already connected to the server then the server MUST disconnect the existing client
existingSession.getConnectionManager().disconnect(false);
}
}
return new Pair<>(clientId, assigned);
}
}

0 comments on commit 446ff61

Please sign in to comment.