Skip to content

Commit

Permalink
Merge pull request #1048 from dota17/UpdatePMDE
Browse files Browse the repository at this point in the history
  • Loading branch information
marci4 committed Jul 28, 2020
2 parents b7e2c94 + 7022fca commit e38d9a9
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 32 deletions.
2 changes: 2 additions & 0 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ non-blocking event-driven model (similar to the
Implemented WebSocket protocol versions are:

* [RFC 6455](http://tools.ietf.org/html/rfc6455)
* [RFC 7692](http://tools.ietf.org/html/rfc7692)

[Here](https://github.com/TooTallNate/Java-WebSocket/wiki/Drafts) some more details about protocol versions/drafts.
[PerMessageDeflateExample](https://github.com/TooTallNate/Java-WebSocket/wiki/PerMessageDeflateExample) enable the extension with reference to both a server and client example.


## Getting Started
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import org.java_websocket.extensions.CompressionExtension;
import org.java_websocket.extensions.ExtensionRequestData;
import org.java_websocket.extensions.IExtension;
import org.java_websocket.framing.*;
import org.java_websocket.framing.BinaryFrame;
import org.java_websocket.framing.CloseFrame;
import org.java_websocket.framing.ContinuousFrame;
import org.java_websocket.framing.DataFrame;
import org.java_websocket.framing.Framedata;
import org.java_websocket.framing.FramedataImpl1;
import org.java_websocket.framing.TextFrame;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
Expand All @@ -16,6 +22,12 @@
import java.util.zip.Deflater;
import java.util.zip.Inflater;

/**
* PerMessage Deflate Extension (<a href="https://tools.ietf.org/html/rfc7692#section-7">7&#46; The "permessage-deflate" Extension</a> in
* <a href="https://tools.ietf.org/html/rfc7692">RFC 7692</a>).
*
* @see <a href="https://tools.ietf.org/html/rfc7692#section-7">7&#46; The "permessage-deflate" Extension in RFC 7692</a>
*/
public class PerMessageDeflateExtension extends CompressionExtension {

// Name of the extension as registered by IETF https://tools.ietf.org/html/rfc7692#section-9.
Expand All @@ -28,7 +40,7 @@ public class PerMessageDeflateExtension extends CompressionExtension {
private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
private static final int serverMaxWindowBits = 1 << 15;
private static final int clientMaxWindowBits = 1 << 15;
private static final byte[] TAIL_BYTES = {0x00, 0x00, (byte)0xFF, (byte)0xFF};
private static final byte[] TAIL_BYTES = { (byte)0x00, (byte)0x00, (byte)0xFF, (byte)0xFF };
private static final int BUFFER_SIZE = 1 << 10;

private boolean serverNoContextTakeover = true;
Expand All @@ -37,9 +49,60 @@ public class PerMessageDeflateExtension extends CompressionExtension {
// For WebSocketServers, this variable holds the extension parameters that the peer client has requested.
// For WebSocketClients, this variable holds the extension parameters that client himself has requested.
private Map<String, String> requestedParameters = new LinkedHashMap<String, String>();

private Inflater inflater = new Inflater(true);
private Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);

public Inflater getInflater() {
return inflater;
}

public void setInflater(Inflater inflater) {
this.inflater = inflater;
}

public Deflater getDeflater() {
return deflater;
}

public void setDeflater(Deflater deflater) {
this.deflater = deflater;
}

/**
*
* @return serverNoContextTakeover
*/
public boolean isServerNoContextTakeover()
{
return serverNoContextTakeover;
}

/**
*
* @param serverNoContextTakeover
*/
public void setServerNoContextTakeover(boolean serverNoContextTakeover) {
this.serverNoContextTakeover = serverNoContextTakeover;
}

/**
*
* @return clientNoContextTakeover
*/
public boolean isClientNoContextTakeover()
{
return clientNoContextTakeover;
}

/**
*
* @param clientNoContextTakeover
*/
public void setClientNoContextTakeover(boolean clientNoContextTakeover) {
this.clientNoContextTakeover = clientNoContextTakeover;
}

/*
An endpoint uses the following algorithm to decompress a message.
1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the
Expand All @@ -50,11 +113,11 @@ public class PerMessageDeflateExtension extends CompressionExtension {
@Override
public void decodeFrame(Framedata inputFrame) throws InvalidDataException {
// Only DataFrames can be decompressed.
if(!(inputFrame instanceof DataFrame))
if (!(inputFrame instanceof DataFrame))
return;

// RSV1 bit must be set only for the first frame.
if(inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1())
if (inputFrame.getOpcode() == Opcode.CONTINUOUS && inputFrame.isRSV1())
throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, "RSV1 bit can only be set for the first frame.");

// Decompressed output buffer.
Expand All @@ -70,47 +133,53 @@ We can check the getRemaining() method to see whether the data we supplied has b
And if not, we just reset the inflater and decompress again.
Note that this behavior doesn't occur if the message is "first compressed and then fragmented".
*/
if(inflater.getRemaining() > 0){
if (inflater.getRemaining() > 0) {
inflater = new Inflater(true);
decompress(inputFrame.getPayloadData().array(), output);
}

if(inputFrame.isFin()) {
if (inputFrame.isFin()) {
decompress(TAIL_BYTES, output);
// If context takeover is disabled, inflater can be reset.
if(clientNoContextTakeover)
if (clientNoContextTakeover)
inflater = new Inflater(true);
}
} catch (DataFormatException e) {
throw new InvalidDataException(CloseFrame.POLICY_VALIDATION, e.getMessage());
}

// RSV1 bit must be cleared after decoding, so that other extensions don't throw an exception.
if(inputFrame.isRSV1())
if (inputFrame.isRSV1())
((DataFrame) inputFrame).setRSV1(false);

// Set frames payload to the new decompressed data.
((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(output.toByteArray(), 0, output.size()));
}

private void decompress(byte[] data, ByteArrayOutputStream outputBuffer) throws DataFormatException{
/**
*
* @param data the bytes of data
* @param outputBuffer the output stream
* @throws DataFormatException
*/
private void decompress(byte[] data, ByteArrayOutputStream outputBuffer) throws DataFormatException {
inflater.setInput(data);
byte[] buffer = new byte[BUFFER_SIZE];

int bytesInflated;
while((bytesInflated = inflater.inflate(buffer)) > 0){
while ((bytesInflated = inflater.inflate(buffer)) > 0) {
outputBuffer.write(buffer, 0, bytesInflated);
}
}

@Override
public void encodeFrame(Framedata inputFrame) {
// Only DataFrames can be decompressed.
if(!(inputFrame instanceof DataFrame))
if (!(inputFrame instanceof DataFrame))
return;

// Only the first frame's RSV1 must be set.
if(!(inputFrame instanceof ContinuousFrame))
if (!(inputFrame instanceof ContinuousFrame))
((DataFrame) inputFrame).setRSV1(true);

deflater.setInput(inputFrame.getPayloadData().array());
Expand All @@ -119,7 +188,7 @@ public void encodeFrame(Framedata inputFrame) {
// Temporary buffer to hold compressed output.
byte[] buffer = new byte[1024];
int bytesCompressed;
while((bytesCompressed = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH)) > 0) {
while ((bytesCompressed = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH)) > 0) {
output.write(buffer, 0, bytesCompressed);
}

Expand All @@ -132,11 +201,11 @@ public void encodeFrame(Framedata inputFrame) {
To simulate removal, we just pass 4 bytes less to the new payload
if the frame is final and outputBytes ends with 0x00 0x00 0xff 0xff.
*/
if(inputFrame.isFin()) {
if(endsWithTail(outputBytes))
if (inputFrame.isFin()) {
if (endsWithTail(outputBytes))
outputLength -= TAIL_BYTES.length;

if(serverNoContextTakeover) {
if (serverNoContextTakeover) {
deflater.end();
deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
}
Expand All @@ -146,13 +215,18 @@ public void encodeFrame(Framedata inputFrame) {
((FramedataImpl1) inputFrame).setPayload(ByteBuffer.wrap(outputBytes, 0, outputLength));
}

private boolean endsWithTail(byte[] data){
if(data.length < 4)
/**
*
* @param data the bytes of data
* @return true if the data is OK
*/
private boolean endsWithTail(byte[] data) {
if (data.length < 4)
return false;

int length = data.length;
for(int i = 0; i < TAIL_BYTES.length; i++){
if(TAIL_BYTES[i] != data[length - TAIL_BYTES.length + i])
for (int i = 0; i < TAIL_BYTES.length; i++) {
if (TAIL_BYTES[i] != data[length - TAIL_BYTES.length + i])
return false;
}

Expand All @@ -162,15 +236,15 @@ private boolean endsWithTail(byte[] data){
@Override
public boolean acceptProvidedExtensionAsServer(String inputExtension) {
String[] requestedExtensions = inputExtension.split(",");
for(String extension : requestedExtensions) {
for (String extension : requestedExtensions) {
ExtensionRequestData extensionData = ExtensionRequestData.parseExtensionRequest(extension);
if(!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName()))
if (!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName()))
continue;

// Holds parameters that peer client has sent.
Map<String, String> headers = extensionData.getExtensionParameters();
requestedParameters.putAll(headers);
if(requestedParameters.containsKey(CLIENT_NO_CONTEXT_TAKEOVER))
if (requestedParameters.containsKey(CLIENT_NO_CONTEXT_TAKEOVER))
clientNoContextTakeover = true;

return true;
Expand All @@ -182,9 +256,9 @@ public boolean acceptProvidedExtensionAsServer(String inputExtension) {
@Override
public boolean acceptProvidedExtensionAsClient(String inputExtension) {
String[] requestedExtensions = inputExtension.split(",");
for(String extension : requestedExtensions) {
for (String extension : requestedExtensions) {
ExtensionRequestData extensionData = ExtensionRequestData.parseExtensionRequest(extension);
if(!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName()))
if (!EXTENSION_REGISTERED_NAME.equalsIgnoreCase(extensionData.getExtensionName()))
continue;

// Holds parameters that are sent by the server, as a response to our initial extension request.
Expand Down Expand Up @@ -222,9 +296,9 @@ public IExtension copyInstance() {
*/
@Override
public void isFrameValid(Framedata inputFrame) throws InvalidDataException {
if((inputFrame instanceof TextFrame || inputFrame instanceof BinaryFrame) && !inputFrame.isRSV1())
if ((inputFrame instanceof TextFrame || inputFrame instanceof BinaryFrame) && !inputFrame.isRSV1())
throw new InvalidFrameException("RSV1 bit must be set for DataFrames.");
if((inputFrame instanceof ContinuousFrame) && (inputFrame.isRSV1() || inputFrame.isRSV2() || inputFrame.isRSV3()))
if ((inputFrame instanceof ContinuousFrame) && (inputFrame.isRSV1() || inputFrame.isRSV2() || inputFrame.isRSV3()))
throw new InvalidFrameException( "bad rsv RSV1: " + inputFrame.isRSV1() + " RSV2: " + inputFrame.isRSV2() + " RSV3: " + inputFrame.isRSV3() );
super.isFrameValid(inputFrame);
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/java_websocket/framing/FramedataImpl1.java
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void setFin(boolean fin) {
/**
* Set the rsv1 of this frame to the provided boolean
*
* @param rsv1 true if fin has to be set
* @param rsv1 true if rsv1 has to be set
*/
public void setRSV1(boolean rsv1) {
this.rsv1 = rsv1;
Expand All @@ -192,7 +192,7 @@ public void setRSV1(boolean rsv1) {
/**
* Set the rsv2 of this frame to the provided boolean
*
* @param rsv2 true if fin has to be set
* @param rsv2 true if rsv2 has to be set
*/
public void setRSV2(boolean rsv2) {
this.rsv2 = rsv2;
Expand All @@ -201,7 +201,7 @@ public void setRSV2(boolean rsv2) {
/**
* Set the rsv3 of this frame to the provided boolean
*
* @param rsv3 true if fin has to be set
* @param rsv3 true if rsv3 has to be set
*/
public void setRSV3(boolean rsv3) {
this.rsv3 = rsv3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
import org.junit.Test;

import java.nio.ByteBuffer;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class PerMessageDeflateExtensionTest {

Expand Down Expand Up @@ -113,8 +119,67 @@ public void testGetProvidedExtensionAsServer() {
}

@Test
public void testToString() throws Exception {
public void testToString() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertEquals( "PerMessageDeflateExtension", deflateExtension.toString() );
}

@Test
public void testIsServerNoContextTakeover() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertTrue(deflateExtension.isServerNoContextTakeover());
}

@Test
public void testSetServerNoContextTakeover() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
deflateExtension.setServerNoContextTakeover(false);
assertFalse(deflateExtension.isServerNoContextTakeover());
}

@Test
public void testIsClientNoContextTakeover() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertFalse(deflateExtension.isClientNoContextTakeover());
}

@Test
public void testSetClientNoContextTakeover() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
deflateExtension.setClientNoContextTakeover(true);
assertTrue(deflateExtension.isClientNoContextTakeover());
}

@Test
public void testCopyInstance() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
IExtension newDeflateExtension = deflateExtension.copyInstance();
assertEquals(deflateExtension.toString(), newDeflateExtension.toString());
}

@Test
public void testGetInflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertEquals(deflateExtension.getInflater().getRemaining(), new Inflater(true).getRemaining());
}

@Test
public void testSetInflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
deflateExtension.setInflater(new Inflater(false));
assertEquals(deflateExtension.getInflater().getRemaining(), new Inflater(false).getRemaining());
}

@Test
public void testGetDeflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertEquals(deflateExtension.getDeflater().finished(), new Deflater(Deflater.DEFAULT_COMPRESSION, true).finished());
}

@Test
public void testSetDeflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
deflateExtension.setDeflater(new Deflater(Deflater.DEFAULT_COMPRESSION, false));
assertEquals(deflateExtension.getDeflater().finished(),new Deflater(Deflater.DEFAULT_COMPRESSION, false).finished());
}
}

0 comments on commit e38d9a9

Please sign in to comment.