# Tom's Data Onion

## Setup

Retrieve (but cache) the HTML page and grab the initial layer from between the `<pre></pre>` tags.

```python
import html
import pathlib
import urllib.request


def start_toms_data_onion() -> None:
    path = pathlib.Path().resolve() / "toms-data-onion.html"

    # Download the page only once; later runs read the cached copy.
    if not path.exists():
        request = urllib.request.Request("https://www.tomdalling.com/toms-data-onion/")
        request.add_header("User-Agent", "toms-data-onion")
        with urllib.request.urlopen(request) as r:
            path.write_bytes(r.read())

    body = path.read_text(encoding="utf-8")

    # The first layer sits between the page's <pre> and </pre> tags.
    start = body.index("<pre>") + len("<pre>")
    end = body.index("</pre>", start)

    # Unescape HTML entities (e.g. &lt;~ becomes <~) in the extracted text.
    with open("layer-0.txt", "w", encoding="utf-8") as o:
        o.write(html.unescape(body[start:end]))


start_toms_data_onion()
```

We're going to be doing this a lot, so define a function to retrieve the Ascii85-encoded block from each layer:

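```python
def get_payload(path: str) -> str:
    """Return the Ascii85 block (including its <~ ~> delimiters) from a layer file."""
    with open(pathlib.Path().resolve() / path, encoding="utf-8") as i:
        txt = i.read()

    start = txt.index("<~")
    end = txt.index("~>", start) + 2

    # Strip line breaks so the payload is a single continuous string.
    return "".join(txt[start:end].splitlines())
```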

## Layer 0/6: ASCII85

It turns out that Python's standard library already supports decoding Ascii85 via the `base64` module, including the Adobe-flavoured variant.
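
A quick sanity check of the Adobe framing, as a minimal sketch on an arbitrary sample input: encoding with `adobe=True` wraps the output in `<~` and `~>`, and `a85decode()` skips ASCII whitespace by default (its `ignorechars` parameter), so line breaks inside a block are tolerated when decoding.

```python
import base64

sample = b"Man "  # arbitrary 4-byte sample input

encoded = base64.a85encode(sample, adobe=True)
print(encoded)  # b'<~9jqo^~>'

# A line break inside the block is ignored when decoding.
wrapped = b"<~9j\nqo^~>"
assert base64.a85decode(wrapped, adobe=True) == sample
```

If that holds, the `splitlines()` join in `get_payload()` is harmless but probably not strictly required.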

```python
import base64


layer0 = get_payload("layer-0.txt")
decoded = base64.a85decode(layer0.encode("utf-8"), adobe=True)

with open("layer-1.txt", "wb") as o:
    o.write(decoded)
```

## Layer 1/6: Bitwise Operations

```python
layer1 = get_payload("layer-1.txt")
decoded = base64.a85decode(layer1.encode("utf-8"), adobe=True)
```

As the instructions say, we need to _"Flip every second bit"_. To do so, we can XOR each byte with `01010101`:

```python
xord = [byte ^ int("01010101", 2) for byte in decoded]
```

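Rotation (one position to the right) is more complicated:

- shift everything one bit to the right;
- move the bit that was previously in the least significant position to the most significant position.

That translates to the following for each _byte_, with the two results combined using a bitwise OR:

- `byte >> 1`
- `(byte & 1) << 7`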

```python
rotated = [(byte >> 1) | ((byte & 1) << 7) for byte in xord]
```
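
To check both steps by hand, here is a small sketch tracing one arbitrary byte, `0xB4`, through the same XOR and rotation:

```python
byte = 0xB4  # 0b10110100, an arbitrary example byte

flipped = byte ^ 0b01010101  # flip every second bit
rotated = (flipped >> 1) | ((flipped & 1) << 7)  # rotate right by one

print(f"{byte:08b} -> {flipped:08b} -> {rotated:08b}")
# 10110100 -> 11100001 -> 11110000
```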

That will give us the next layer.

```python
with open("layer-2.txt", "wb") as o:
    o.write(bytearray(rotated))
```

## Layer 2/6: Parity Bit

```python
layer2 = get_payload("layer-2.txt")
decoded = base64.a85decode(layer2.encode("utf-8"), adobe=True)
```

As in the previous layer, the _parity_ bit is the least significant bit (`byte & 1`), and the remaining 7 bits of data are `byte >> 1`.

I'm unsure of a simple way to count occurrences of binary digits in Python. However, we can easily convert each `byte` to its binary representation via the [`bin()`](https://docs.python.org/3.8/library/functions.html#bin) function, then call the [`count()`](https://docs.python.org/3.8/library/stdtypes.html#str.count) method on the resulting `str`.

Given the above and the specification for verifying the parity bit, a validation function will be:

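```python
def is_valid_byte(byte: int) -> bool:
    parity = byte & 1  # least significant bit
    first_seven = byte >> 1  # the 7 data bits
    ones = bin(first_seven).count("1")

    # Valid when the parity bit is 1 exactly if the count of 1s is odd.
    if ones % 2 == 0 and parity == 0:
        return True
    elif ones % 2 != 0 and parity == 1:
        return True

    return False
```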

We can then use that function to discard any invalid bytes from the decoded input and, as per the instructions, divide the remaining bytes into batches of 8:

```python
valid = list(filter(is_valid_byte, decoded))
chunks = [valid[i : i + 8] for i in range(0, len(valid), 8)]
```

We start with a batch of 8 bytes:

```txt
00000000
11111111
00000000
11111111
00000000
11111111
00000000
11111111
```

Each byte is bit-shifted one place to the right:

```txt
 0000000
 1111111
 0000000
 1111111
 0000000
 1111111
 0000000
 1111111
```

The trick then is combining the eight 7-bit values into a single 56-bit integer. (We need to combine them because we can only write whole bytes to our output, and eight 7-bit values make exactly seven bytes.)

Combining them would look something like this:

```txt
         49     42     35     28     21     14     7
0 0000000
1        1111111
2               0000000
3                      1111111
4                             0000000
5                                    1111111
6                                           0000000
7                                                  1111111
```

That is: the first value must be shifted 49 bits to the left, the second 42, the third 35, and so on, down to 0 for the eighth. The pattern translates to `7 * (7 - offset)`, where `offset` is the byte's position in the batch.

```python
with open("layer-3.txt", "wb") as o:
    for chunk in chunks:
        result = 0
        for offset, byte in enumerate(chunk):
            result |= (byte >> 1) << (7 * (7 - offset))
        o.write(result.to_bytes(7, byteorder="big"))
```
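
As a cross-check of the `7 * (7 - offset)` shifts, here is a sketch that packs the alternating example batch from the diagrams above by two routes: the shifting used here, and concatenating each 7-bit value as a binary string. `pack_by_shifting` and `pack_by_string` are hypothetical helper names, and both assume a full batch of 8 bytes.

```python
from typing import List


def pack_by_shifting(chunk: List[int]) -> bytes:
    # Drop each parity bit, then place the 7-bit values from the top down.
    result = 0
    for offset, byte in enumerate(chunk):
        result |= (byte >> 1) << (7 * (7 - offset))
    return result.to_bytes(7, byteorder="big")


def pack_by_string(chunk: List[int]) -> bytes:
    # Write each 7-bit value as exactly seven binary digits and join them.
    bits = "".join(f"{byte >> 1:07b}" for byte in chunk)
    return int(bits, 2).to_bytes(7, byteorder="big")


# The alternating 00000000 / 11111111 batch from the diagrams.
batch = [0b00000000, 0b11111111] * 4
print(pack_by_shifting(batch).hex())  # 01fc07f01fc07f
assert pack_by_shifting(batch) == pack_by_string(batch)
```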

## Layer 3/6: XOR Encryption

```python
layer3 = get_payload("layer-3.txt")
decoded = base64.a85decode(layer3.encode("utf-8"), adobe=True)
```

To start with, we already know two pieces of the decrypted data: the start of the next layer and the payload indicator within it:

```python
known_start = b"==[ Layer 4/6: "
known_text = b"==[ Payload ]==============================================="
```

As indicated in the instructions, for each byte of the encrypted payload, `encrypted_byte ^ key_byte == decrypted_byte`. Because XOR is its own inverse, it also follows that `decrypted_byte ^ encrypted_byte == key_byte`.
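
A minimal sketch of that known-plaintext recovery on made-up data: the 32-byte key is random and the plaintext is invented for illustration, with only the `==[ Layer 4/6: ` prefix taken from above.

```python
import os
from itertools import cycle

# Hypothetical key and plaintext, for illustration only.
key = os.urandom(32)
plaintext = b"==[ Layer 4/6: Network Traffic ]" + b"=" * 32

# Encrypt: XOR each plaintext byte with the repeating key.
ciphertext = bytes(p ^ k for p, k in zip(plaintext, cycle(key)))

# Recover key bytes by XORing the ciphertext with the plaintext we already know.
known_start = b"==[ Layer 4/6: "
recovered = bytes(c ^ p for c, p in zip(ciphertext, known_start))
assert recovered == key[: len(known_start)]

# Decrypting is the same XOR again.
assert bytes(c ^ k for c, k in zip(ciphertext, cycle(key))) == plaintext
```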

Given the known plaintext above, we can recover the first 15 bytes of the key (padding it with zeros to the full 32 bytes):

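```python
# XOR the ciphertext with the known plaintext to get the first 15 key bytes.
key = bytearray(d ^ p for (d, p) in zip(decoded, known_start))
# Pad with zero bytes as placeholders for the rest of the 32-byte key.
key.extend(bytearray(32 - len(known_start)))
```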

Armed with our partial key, we can iterate through each 32-byte block of the data (each of which is encrypted with the same full 32-byte key) and decrypt the first 15 bytes of each block with our partial key.

If, at any point, those 15 decrypted bytes appear somewhere in the text we know to be present in the output (`known_text`, above), we know exactly which part of `known_text` that block lines up with. XORing the whole 32-byte block against that part of `known_text` then gives us the full key!

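```python
# Skip the first two 32-byte blocks and stop before the final block.
for i in range(64, len(decoded) - 32, 32):
    # Blocks start at multiples of 32, so key[j] lines up with decoded[i + j].
    partial = bytearray(decoded[i + j] ^ key[j] for j in range(len(known_start)))

    if (index := known_text.find(partial)) == -1:
        continue

    # This block matches known_text from `index` onwards: recover all 32 key bytes.
    key = bytearray(decoded[i + k] ^ known_text[index + k] for k in range(32))
    break
else:
    raise ValueError("no block matched known_text")
```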

We can cycle the key, as per the instructions, to the length of the input:

```python
# Repeat the 32-byte key until it is at least as long as the input.
cycled_key = key * ((len(decoded) // 32) + 1)

with open("layer-4.txt", "wb") as o:
    o.write(bytearray(c ^ k for (c, k) in zip(decoded, cycled_key)))
```
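
`itertools.cycle` is another way to repeat the key without working out how many copies are needed, since `zip()` stops at the shorter input. A sketch on hypothetical data, checking that both approaches agree:

```python
from itertools import cycle

key = bytes(range(32))  # hypothetical 32-byte key
data = bytes(100)  # hypothetical 100-byte input (all zeros)

# Multiplication approach, as above: enough copies to cover the input.
repeated = key * ((len(data) // 32) + 1)
by_multiplying = bytes(c ^ k for c, k in zip(data, repeated))

# cycle() yields key bytes endlessly; zip() stops when data runs out.
by_cycling = bytes(c ^ k for c, k in zip(data, cycle(key)))

assert by_multiplying == by_cycling
```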

## Layer 4/6: Network Traffic

```python
layer4 = get_payload("layer-4.txt")
decoded = base64.a85decode(layer4.encode("utf-8"), adobe=True)
```

I _really_ wanted to stick to the standard library for all this. Parsing IP/UDP packets? How often do you get to do that?

This is what I came up with for parsing each packet using Python's [`struct`](https://docs.python.org/3/library/struct.html#format-characters) module, mostly thanks to Wikipedia's pages on [_IPv4_](https://en.wikipedia.org/wiki/IPv4#Header) and [_UDP_](https://en.wikipedia.org/wiki/User_Datagram_Protocol#UDP_datagram_structure):

```python
import struct


class Ipv4Header:
    """Parses a 20-byte IPv4 header (one without options)."""

    def __init__(self, data: bytes):
        self._raw_bytes = data
        (
            _ihl_version,
            _dscp_ecn,
            self.length,
            self.identification,
            _flags_offset,
            self.ttl,
            self.protocol,
            self.checksum,
            self.source,
            self.dest,
        ) = struct.unpack(">BBHHHBBHII", data)
        self.version = _ihl_version >> 4
        self.ihl = _ihl_version & int("00001111", 2)
        self.dscp = _dscp_ecn >> 2
        self.ecn = _dscp_ecn & 0x3
        self.flags = _flags_offset >> 13
        self.offset = _flags_offset & int("0001111111111111", 2)


class UdpHeader:
    """Parses an 8-byte UDP header."""

    def __init__(self, data: bytes):
        self._raw_bytes = data
        (
            self.source_port,
            self.dest_port,
            self.length,
            self.checksum,
        ) = struct.unpack(">HHHH", data)
```
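HHHH", data)">

The `Ipv4Header` class above leaves `source` and `dest` as plain 32-bit integers. A sketch of turning such values into dotted-quad strings with the standard library's `ipaddress` module, using a hand-built sample header (all field values are hypothetical):

```python
import ipaddress
import struct

# Hypothetical 20-byte header: version 4, IHL 5, protocol 17 (UDP),
# source 10.1.1.10, destination 10.1.1.200; other fields are arbitrary.
header = struct.pack(
    ">BBHHHBBH4s4s",
    0x45, 0, 30, 0, 0, 64, 17, 0,
    bytes([10, 1, 1, 10]),
    bytes([10, 1, 1, 200]),
)

# Same format string as Ipv4Header: the addresses unpack as unsigned ints.
*_, source, dest = struct.unpack(">BBHHHBBHII", header)

print(ipaddress.IPv4Address(source))  # 10.1.1.10
print(ipaddress.IPv4Address(dest))  # 10.1.1.200
```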

However, when validating the header checksums, I bailed:

> _"The checksum field is the 16 bit one's complement of the one's complement sum of all 16 bit words in the header. For purposes of computing the checksum, the value of the checksum field is zero."_
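
For what it's worth, that rule fits in a few lines of standard-library Python. A minimal sketch (`internet_checksum` and `ipv4_header_is_valid` are hypothetical names): summing every 16-bit word of a valid header, checksum field included, gives `0xFFFF`, so its complement is zero. It is exercised here on a sample header whose correct checksum works out to `0xB861`.

```python
import struct


def internet_checksum(data: bytes) -> int:
    # One's complement of the one's complement sum of 16-bit big-endian words.
    if len(data) % 2:
        data += b"\x00"  # pad odd-length input with a zero byte
    total = 0
    for (word,) in struct.iter_unpack(">H", data):
        total += word
        total = (total & 0xFFFF) + (total >> 16)  # fold the carry back in
    return ~total & 0xFFFF


def ipv4_header_is_valid(header: bytes) -> bool:
    # IHL (low nibble of the first byte) is the header length in 32-bit words.
    ihl = header[0] & 0x0F
    # A valid header, checksum included, sums to 0xFFFF; its complement is 0.
    return internet_checksum(header[: ihl * 4]) == 0


# Sample header with the checksum field (bytes 10-11) zeroed out.
unchecked = bytes.fromhex("450000730000400040110000c0a80001c0a800c7")
print(hex(internet_checksum(unchecked)))  # 0xb861

checked = unchecked[:10] + bytes.fromhex("b861") + unchecked[12:]
assert ipv4_header_is_valid(checked)
```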

After several attempts, I found myself looking at the implementation in [`scapy`](https://scapy.net/)…at which point I realised I was staring at a library intended for this very purpose.

```python
from scapy.all import IP, UDP
```

We can easily wrap our input in a stream and, as per the instructions, read it as a sequence of packets: the first 20 bytes parse into an `IP` header and the next 8 into a `UDP` header:

```python
import io


stream = io.BytesIO(decoded)

ip = IP(stream.read(20))
udp = UDP(stream.read(8))
```

The resulting headers will give us the source (`IP.src`) and destination (`IP.dst`) addresses, along with the destination port (`UDP.dport`).

Validation of the checksum, however, is somewhat complicated.

Reviewing the [`scapy.packet.Packet.show2()`](https://github.com/secdev/scapy/blob/a86ad5c/scapy/packet.py#L1449) function shows how `scapy` calculates the checksum for a packet: by first removing the `chksum` field, then building a new packet. Leveraging this, comparing the `chksum` of the original packet to one recalculated this way should tell us whether a packet is valid:

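```python
import copy
from typing import Union

from scapy.compat import raw


def is_valid(packet: Union[IP, UDP]) -> bool:
    # Work on a copy so the original packet's chksum is left untouched.
    _copy = copy.deepcopy(packet)
    # Deleting chksum makes scapy recompute it when the packet is built.
    del _copy.chksum
    # Build raw bytes, re-parse them, and compare the recomputed checksum.
    return _copy.__class__(raw(_copy)).chksum == packet.chksum
```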

That said, it doesn't _quite_ work as described.

The following is the method I used to write out the final, valid packet payloads.

Firstly, create the empty file (as I'll be appending to the output, I want to make sure the file is empty on subsequent runs):

```python
with open("layer-5.txt", "w") as o:
    ...
```

Then simply loop over the input stream, reading `IP` and `UDP` packets and checking the criteria for validity.

Except my `is_valid()` function _never_ returns `True` for `UDP` instances. In fact, omitting the validity check for `UDP` headers entirely seems to give a valid result:

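```python
stream.seek(0)

while True:
    ip = IP(stream.read(20))
    # Once the stream is exhausted there is no header left, so ip.len is empty.
    if not ip.len:
        break

    udp = UDP(stream.read(8))
    # ip.len is the total length: 20-byte IP header (no options) + 8-byte UDP header + data.
    data = stream.read(ip.len - 28)

    if ip.src == "10.1.1.10":
        if ip.dst == "10.1.1.200" and udp.dport == 42069:
            if is_valid(ip):  # and is_valid(udp):
                with open("layer-5.txt", "ab") as o:
                    o.write(data)
```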
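
One likely reason `is_valid()` never passes for the `UDP` headers: the UDP checksum is not computed over the 8-byte header alone. Per RFC 768 it covers a pseudo-header (source and destination IP addresses, a zero byte, the protocol number and the UDP length) followed by the UDP header and its data, whereas the `UDP` objects above are built from the header bytes only, with no `IP` layer or payload attached. A standard-library sketch of the full check, assuming a 20-byte IP header and the complete UDP datagram; `udp_checksum_is_valid` is a hypothetical name and the sample packet is made up:

```python
import struct


def ones_complement_sum(data: bytes) -> int:
    if len(data) % 2:
        data += b"\x00"  # pad odd-length input with a zero byte
    total = 0
    for (word,) in struct.iter_unpack(">H", data):
        total += word
        total = (total & 0xFFFF) + (total >> 16)  # fold the carry back in
    return total


def udp_checksum_is_valid(ip_header: bytes, datagram: bytes) -> bool:
    # Pseudo-header: source IP, destination IP, zero, protocol, UDP length.
    # (A transmitted checksum of 0 means "no checksum"; not special-cased here.)
    pseudo = ip_header[12:20] + struct.pack(">BBH", 0, ip_header[9], len(datagram))
    # With the checksum field included, a valid datagram sums to 0xFFFF.
    return ones_complement_sum(pseudo + datagram) == 0xFFFF


# Hypothetical packet: 10.1.1.10 -> 10.1.1.200, port 12345 -> 42069, data b"hi".
ip_header = struct.pack(
    ">BBHHHBBH4s4s", 0x45, 0, 30, 0, 0, 64, 17, 0,
    bytes([10, 1, 1, 10]), bytes([10, 1, 1, 200]),
)
datagram = struct.pack(">HHHH", 12345, 42069, 10, 0xAC0E) + b"hi"
print(udp_checksum_is_valid(ip_header, datagram))  # True
```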
