Skip to content

Commit

Permalink
[grid] Review code and null checks for registration secret
Browse files Browse the repository at this point in the history
  • Loading branch information
shs96c committed Oct 20, 2020
1 parent b845a3c commit 004be30
Show file tree
Hide file tree
Showing 9 changed files with 18 additions and 11 deletions.
Expand Up @@ -52,7 +52,7 @@ class AddNode implements HttpHandler {
this.distributor = Require.nonNull("Distributor", distributor);
this.json = Require.nonNull("Json converter", json);
this.httpFactory = Require.nonNull("HTTP Factory", httpFactory);
this.registrationSecret = registrationSecret;
this.registrationSecret = Require.nonNull("Registration secret", registrationSecret);
}

@Override
Expand Down
Expand Up @@ -135,6 +135,8 @@ protected Distributor(
this.slotSelector = Require.nonNull("Slot selector", slotSelector);
this.sessions = Require.nonNull("Session map", sessions);

Require.nonNull("Registration secret", registrationSecret);

RequiresSecretFilter requiresSecret = new RequiresSecretFilter(registrationSecret);

Json json = new Json();
Expand Down
Expand Up @@ -62,6 +62,7 @@ public class GridModel {

public GridModel(EventBus events, Secret registrationSecret) {
this.events = Require.nonNull("Event bus", events);
Require.nonNull("Registration secret", registrationSecret);

events.addListener(NodeDrainStarted.listener(nodeId -> setAvailability(nodeId, DRAINING)));
events.addListener(NodeDrainComplete.listener(this::remove));
Expand Down Expand Up @@ -101,9 +102,10 @@ public GridModel add(NodeStatus node) {
}

public GridModel refresh(Secret registrationSecret, NodeStatus status) {
Require.nonNull("Registration secret", registrationSecret);
Require.nonNull("Node status", status);

Secret statusSecret = status.getRegistrationSecret() == null ? null : new Secret(status.getRegistrationSecret());
Secret statusSecret = new Secret(status.getRegistrationSecret());
if (!Secret.matches(registrationSecret, statusSecret)) {
LOG.severe(String.format("Node at %s failed to send correct registration secret. Node NOT refreshed.", status.getUri()));
events.fire(new NodeRejectedEvent(status.getUri()));
Expand Down Expand Up @@ -261,7 +263,7 @@ private NodeStatus rewrite(NodeStatus status, Availability availability) {
status.getMaxSessionCount(),
status.getSlots(),
availability,
status.getRegistrationSecret() == null ? null : new Secret(status.getRegistrationSecret()));
new Secret(status.getRegistrationSecret()));
}

private void release(SessionId id) {
Expand Down Expand Up @@ -365,7 +367,7 @@ private void amend(Availability availability, NodeStatus status, Slot slot) {
status.getMaxSessionCount(),
newSlots,
status.getAvailability(),
status.getRegistrationSecret() == null ? null : new Secret(status.getRegistrationSecret())));
new Secret(status.getRegistrationSecret())));
}

private static class AvailabilityAndNode {
Expand Down
Expand Up @@ -131,9 +131,10 @@ public boolean isReady() {
}

private void register(Secret registrationSecret, NodeStatus status) {
Require.nonNull("Registration secret", registrationSecret);
Require.nonNull("Node", status);

Secret nodeSecret = status.getRegistrationSecret() == null ? null : new Secret(status.getRegistrationSecret());
Secret nodeSecret = new Secret(status.getRegistrationSecret());
if (!Secret.matches(registrationSecret, nodeSecret)) {
LOG.severe(String.format("Node at %s failed to send correct registration secret. Node NOT registered.", status.getUri()));
bus.fire(new NodeRejectedEvent(status.getUri()));
Expand Down
1 change: 1 addition & 0 deletions java/server/src/org/openqa/selenium/grid/node/Node.java
Expand Up @@ -110,6 +110,7 @@ protected Node(Tracer tracer, NodeId id, URI uri, Secret registrationSecret) {
this.tracer = Require.nonNull("Tracer", tracer);
this.id = Require.nonNull("Node id", id);
this.uri = Require.nonNull("URI", uri);
Require.nonNull("Registration secret", registrationSecret);

RequiresSecretFilter requiresSecret = new RequiresSecretFilter(registrationSecret);

Expand Down
Expand Up @@ -108,7 +108,7 @@ private OneShotNode(
WebDriverInfo driverInfo) {
super(tracer, id, uri, registrationSecret);

this.registrationSecret = registrationSecret;
this.registrationSecret = Require.nonNull("Registration secret", registrationSecret);
this.events = Require.nonNull("Event bus", events);
this.gridUri = Require.nonNull("Public Grid URI", gridUri);
this.stereotype = ImmutableCapabilities.copyOf(Require.nonNull("Stereotype", stereotype));
Expand Down
Expand Up @@ -84,11 +84,11 @@ public RemoteNode(
this.externalUri = Require.nonNull("External URI", externalUri);
this.capabilities = ImmutableSet.copyOf(capabilities);

this.client = Require
.nonNull("HTTP client factory", clientFactory).createClient(fromUri(externalUri));
this.client = Require.nonNull("HTTP client factory", clientFactory).createClient(fromUri(externalUri));

this.healthCheck = new RemoteCheck();

Require.nonNull("Registration secret", registrationSecret);
this.addSecret = new AddSecretFilter(registrationSecret);
}

Expand Down
Expand Up @@ -17,6 +17,7 @@

package org.openqa.selenium.grid.security;

import org.openqa.selenium.internal.Require;
import org.openqa.selenium.remote.http.Filter;
import org.openqa.selenium.remote.http.HttpHandler;

Expand All @@ -26,13 +27,13 @@ public class AddSecretFilter implements Filter {
private final Secret secret;

public AddSecretFilter(Secret secret) {
this.secret = secret;
this.secret = Require.nonNull("Secret", secret);
}

@Override
public HttpHandler apply(HttpHandler httpHandler) {
return req -> {
if (secret != null && req.getHeader(HEADER_NAME) == null) {
if (req.getHeader(HEADER_NAME) == null) {
req.addHeader(HEADER_NAME, secret.encode());
}

Expand Down
Expand Up @@ -37,7 +37,7 @@ public class RequiresSecretFilter implements Filter {
private final Secret secret;

public RequiresSecretFilter(Secret secret) {
this.secret = secret;
this.secret = Require.nonNull("Secret", secret);
}

@Override
Expand Down

0 comments on commit 004be30

Please sign in to comment.