In [1]:
import numpy as np
import torch
import itertools

In [2]:
# The ciphertext file holds one matrix row per line as whitespace-separated
# numbers; parse it straight into a float32 matrix.
with open('challenge_files/ciphertext.txt', 'r') as f:
    ciphertext = np.array(
        [list(map(float, line.split())) for line in f.read().splitlines()],
        dtype=np.float32,
    )

In [3]:
# The 2x2 / stride-1 / padding-1 conv grows each spatial dimension by one,
# so the ciphertext is (n+1) wide for an n-wide pre-conv feature map.
n = ciphertext.shape[1] - 1
# NOTE(review): presumably the total plaintext bit count (n 3x3 patches x 9
# bits); not used by the cells below.
n_bits = n * 9

# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine (and keeps .numpy() below valid); weights_only=True
# restricts unpickling to tensors/primitive containers.
state_dict = torch.load('challenge_files/model.pth', map_location='cpu', weights_only=True)

conv2d_2x2_s1_p1_w = state_dict['conv.weight'].numpy()
conv2d_2x2_s1_p1_b = state_dict['conv.bias'].numpy()
linear_w           = state_dict['linear.weight'].numpy()
linear_b           = state_dict['linear.bias'].numpy()
conv2d_3x3_s3_w    = state_dict['conv1.weight'].numpy()
conv2d_3x3_s3_b    = state_dict['conv1.bias'].numpy()

In [4]:
def invert_conv2d_2x2_s1_p1(y, c_w, c_b):
    """Invert a single-channel 2x2 convolution with stride 1 and zero padding 1.

    With padding 1, output element (i, j) sees input elements
    (i-1, j-1), (i-1, j), (i, j-1), (i, j), weighted by w00, w01, w10, w11
    (out-of-range inputs are the zero padding):

        y[i, j] = w00*x[i-1, j-1] + w01*x[i-1, j] + w10*x[i, j-1] + w11*x[i, j] + b

    x can therefore be recovered in raster order (row by row, left to right).
    Each element depends only on already-recovered neighbours above and to
    the left.

    Args:
        y: 2-D conv output, shape (H+1, W+1) for an (H, W) input.
        c_w: conv weight indexed as c_w[0, 0, kh, kw]; only channel (0, 0) is used.
        c_b: conv bias; only c_b[0] is used.

    Returns:
        Array with the same shape as ``y``. Its top-left (H, W) part is the
        recovered input. For an exact forward output, the last row and column
        (which correspond to the padding) come out as approximately zero.
        The dtype is ``y``'s dtype when that is floating point, else float64.

    Raises:
        ValueError: if w11 is zero, in which case forward substitution is impossible.
    """
    # Plain Python floats keep all arithmetic in double precision
    # (a numpy float32 scalar would pull the computation back to float32).
    w00 = float(c_w[0, 0, 0, 0])
    w01 = float(c_w[0, 0, 0, 1])
    w10 = float(c_w[0, 0, 1, 0])
    w11 = float(c_w[0, 0, 1, 1])
    if w11 == 0.0:
        raise ValueError("c_w[0, 0, 1, 1] is zero; the convolution cannot be inverted by forward substitution")
    bias = float(c_b[0])

    y = np.asarray(y)
    h, width = y.shape
    # float64 workspace:
    #  - each recovered element feeds into later ones, so rounding error can
    #    accumulate;
    #  - an integer workspace (np.zeros_like on an integer y) would silently
    #    truncate the divisions.
    x = np.zeros((h, width), dtype=np.float64)

    for i in range(h):
        for j in range(width):
            val = float(y[i, j]) - bias
            if i > 0 and j > 0:
                val -= w00 * x[i - 1, j - 1]
            if i > 0:
                val -= w01 * x[i - 1, j]
            if j > 0:
                val -= w10 * x[i, j - 1]
            x[i, j] = val / w11

    out_dtype = y.dtype if np.issubdtype(y.dtype, np.floating) else np.float64
    return x.astype(out_dtype, copy=False)

def invert_linear_48_2304(y, l_w, l_b):
    """Recover the input of a linear layer ``y = x @ l_w.T + l_b`` by least squares.

    Args:
        y: layer output(s), shape (out_features,) or (batch, out_features).
        l_w: weight matrix, shape (out_features, in_features).
        l_b: bias, shape (out_features,).

    Returns:
        The least-squares (minimum-norm) estimate of x, with shape
        (in_features,) or (batch, in_features). Its dtype is the same floating
        dtype the previous pinv-based version produced (e.g. float32 for
        float32 inputs).
    """
    y_ = np.asarray(y) - np.asarray(l_b)
    w = np.asarray(l_w)
    out_dtype = np.result_type(y_.dtype, w.dtype, np.float32)

    # Solve in float64 because callers round the result to integers.
    # lstsq also avoids forming the explicit pseudo-inverse, which is less
    # accurate. Each column of y_.T is one right-hand side.
    sol, *_ = np.linalg.lstsq(w.astype(np.float64), y_.astype(np.float64).T, rcond=None)
    return sol.T.astype(out_dtype, copy=False)

def invert_conv2d_3x3_s3(y, c_w, c_b, tol=None):
    """Invert a single-channel 3x3, stride-3 convolution over a binary (0/1) image.

    Each output value comes from one non-overlapping 3x3 patch of 0/1 inputs,
    so there are only 2**9 = 512 candidate patches. The function enumerates
    them and computes their responses. Each element of ``y`` is then decoded
    to the patch whose response is nearest, which tolerates small numerical
    noise instead of requiring exact float equality.

    Args:
        y: 2-D array of conv outputs, shape (H, W).
        c_w: conv weight with exactly 9 elements (read row-major as a 3x3 kernel).
        c_b: conv bias; only c_b[0] is used.
        tol: maximum allowed |y - response| for a match. The default (None)
            uses half the smallest gap between candidate responses, which
            makes every accepted match unambiguous.

    Returns:
        float32 array of shape (3H, 3W) holding the recovered 0/1 input.

    Raises:
        ValueError: if two different patches give the same response (the
            layer is not invertible), or if an element of ``y`` is farther
            than ``tol`` from every candidate response (or is NaN).
    """
    kernel = np.asarray(c_w, dtype=np.float64).reshape(9)
    bias = float(c_b[0])

    # Every binary 3x3 patch, flattened row-major, with the response it produces.
    patches = np.array(list(itertools.product([0, 1], repeat=9)), dtype=np.float32)
    responses = patches.astype(np.float64) @ kernel + bias

    # Sort the responses so each y value can be matched with a binary search.
    order = np.argsort(responses)
    sorted_resp = responses[order]
    gaps = np.diff(sorted_resp)
    # Equal responses would make decoding ambiguous. Raise explicitly,
    # since an assert would vanish under `python -O`.
    if np.any(gaps <= 0):
        raise ValueError("conv weights map two different 3x3 binary patches to the same output; the layer is not invertible")
    if tol is None:
        tol = gaps.min() / 2.0

    y = np.asarray(y, dtype=np.float64)
    y_h, y_w = y.shape

    # Nearest candidate: compare the neighbours on either side of the
    # insertion point. An exact match is found with distance 0.
    right = np.clip(np.searchsorted(sorted_resp, y), 1, len(sorted_resp) - 1)
    left = right - 1
    pick = np.where(np.abs(y - sorted_resp[left]) <= np.abs(sorted_resp[right] - y), left, right)
    dist = np.abs(y - sorted_resp[pick])
    # `~(dist <= tol)` also rejects NaN inputs.
    if np.any(~(dist <= tol)):
        raise ValueError(f"some outputs are farther than tol={tol} from every 3x3 binary patch response")

    # (H, W, 9) -> (H, W, 3, 3) -> (H, 3, W, 3) -> (3H, 3W): the patch decoded
    # from y[i, j] lands in rows 3i..3i+2 and columns 3j..3j+2.
    blocks = patches[order[pick]].reshape(y_h, y_w, 3, 3)
    return blocks.transpose(0, 2, 1, 3).reshape(y_h * 3, y_w * 3)


In [5]:
# Undo the layers last-to-first.
# 1) 2x2 conv: drop the extra row/column produced by the padding, then
#    flatten into a single row for the linear layer.
x3 = invert_conv2d_2x2_s1_p1(ciphertext, conv2d_2x2_s1_p1_w, conv2d_2x2_s1_p1_b)[:-1, :-1]
x3 = x3.reshape(1, n * n)
# 2) linear layer: round the estimate (its inputs are presumably
#    integer-valued conv responses).
x2 = np.round(invert_linear_48_2304(x3, linear_w, linear_b))
# 3) 3x3 stride-3 conv: decode each value back into its 3x3 binary patch.
x1 = invert_conv2d_3x3_s3(x2, conv2d_3x3_s3_w, conv2d_3x3_s3_b)
x = x1.flatten()

In [6]:
# Read the recovered bits in consecutive 9-bit big-endian groups. Each group
# becomes one output byte; bytes() requires every value to be below 256.
secret_key = bytes(
    int(''.join(str(int(bit)) for bit in x[start:start + 9]), 2)
    for start in range(0, len(x), 9)
)
print(secret_key)

b'SUCTF{Mi_sika_mosi!Mi_muhe_mita,mita_movo_lata!}'
