In [1]:
from itertools import product
from inspect import signature
    
def hrep_to_ineq(hrep, args):
    """
    Returns a list of inequality constraints given
    `hrep` -> a H-representation and
    `args` -> a list of MILP vars
    """
    assert len(args) == len(hrep[0]) - 1, \
        "Number of arguments does not match dimensionality of `hrep`"
    return [
        sum(a*c for c,a in zip([*l][1:], args)) + l.b() >= 0 
        for l in hrep
    ]

def get_boolexpr_hrep(boolfunc):
    """
    Get H-Representation of points representing
    (a0,a1,...,an, boolfunc(a0,a1,...,an))
    
    boolfunc is a boolean function of n-args with a boolean output
    """
    nargs = len(signature(boolfunc).parameters)
    space = []
    for nb in product([0,1], repeat=nargs):
        space.append((*nb, boolfunc(*nb)))
        
    p = Polyhedron(vertices=space)
    hrep = p.Hrepresentation()
    return simplify_binary_hrep(space, hrep)
    
def simplify_binary_hrep(space, hrep):
    
    """
    Computes a small subset of inequalities from `hrep`
    which integral points in [0,1]^n is exactly `space`
    """

    npts_space = len(space)
    dim = len(space[0])
    all_eq = [*hrep]
    new_hrep = []
    all_pts = set(product([0,1], repeat=dim)) - set(space)
    eq_pts = [set(pts for pts in all_pts if not hrep_to_ineq([l], pts)[0]) for l in all_eq]
    while len(all_eq) > 0:
        npts = [len(i) for i in eq_pts]
        midx = npts.index(max(npts))
        if npts[midx] == 0: break
        new_hrep.append(all_eq.pop(midx))
        pts = eq_pts.pop(midx)
        eq_pts = [i - pts for i in eq_pts]
    return (*new_hrep,)
    
def gethrep_sbox(sbox):
    # Big endian (inp, out)
    nbits = len(sbox).bit_length() - 1
    assert 1<<nbits == len(sbox)
    space = [(*pt, *map(int, format(sbox[i], f"0{nbits}b")),) for i,pt in enumerate(product([0,1], repeat=nbits))]
    p = Polyhedron(vertices=space)
    hrep = p.Hrepresentation()
    return simplify_binary_hrep(space, hrep)

In [2]:
# Concrete execution

SBOX = [0xc, 0x5, 0x6, 0xb, 0x9, 0x0, 0xa, 0xd, 0x3, 0xe, 0xf, 0x8, 0x4, 0x7, 0x1, 0x2]
PERM = [0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
        4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
        8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59,
        12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63]
DEPERM = [PERM.index(i) for i in range(64)]
DESBOX = [SBOX.index(i) for i in range(16)]

def sbox_bit(x):
    # big endian
    return [*map(int, format(SBOX[int("".join(map(str, x)), 2)], "04b"))]

def desbox_bit(x):
    # big endian
    return [*map(int, format(DESBOX[int("".join(map(str, x)), 2)], "04b"))]

def update_keyreg(k):
    nk = [k[(i+61)%len(k)] for i in range(len(k))]
    nk[:4] = sbox_bit(nk[:4])
    return nk

def keyscheduler(k, nrounds):
    for _ in range(nrounds):
        yield k[:64]
        k = update_keyreg(k)
        
def add_present(pt, k):
    return [x^^y for x,y in zip(pt,k)]
        
def sub_present(pt):
    pt = pt.copy()
    for i in range(64//4):
        pt[i*4:i*4+4] = sbox_bit(pt[i*4:i*4+4])
    return pt

def desub_present(pt):
    pt = pt.copy()
    for i in range(64//4):
        pt[i*4:i*4+4] = desbox_bit(pt[i*4:i*4+4])
    return pt

def perm_present(pt):
    return [pt[PERM[i]] for i in range(64)]

def deperm_present(pt):
    return [pt[DEPERM[i]] for i in range(64)]
        
def present(pt, k, nrounds):
    ct = pt
    for k in keyscheduler(k, nrounds):
        ct = add_present(ct, k)
        ct = sub_present(ct)
        ct = perm_present(ct)
    return ct

def depresent(ct, k, nrounds):
    pt = ct
    for k in [*keyscheduler(k, nrounds)][::-1]:
        pt = deperm_present(pt)
        pt = desub_present(pt)
        pt = add_present(pt, k)
    return pt

def bytes_to_bits(pt):
    return [*map(int, "".join(format(c, "08b") for c in pt))]

def bits_to_bytes(pt):
    return bytes([int("".join(map(str, pt[8*i:8*i+8])), 2) for i in range(len(pt)//8)])

import os
key = bytes_to_bits(os.urandom(10))
ct = present(bytes_to_bits(b"0"*16), key, 30)
bits_to_bytes(depresent(ct, key, 30))

b'00000000'

In [3]:
# MILP execution

SBOX = [0xc, 0x5, 0x6, 0xb, 0x9, 0x0, 0xa, 0xd, 0x3, 0xe, 0xf, 0x8, 0x4, 0x7, 0x1, 0x2]
PERM = [0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
        4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
        8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59,
        12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63]

SBOX_HREP = gethrep_sbox(SBOX)
XOR_HREP = get_boolexpr_hrep(lambda x,y: x^^y)

def update_keyreg_milp(solver, vargen, k):
    nk = [k[(i+61)%len(k)] for i in range(len(k))]
    outbit = [vargen.gen() for _ in range(4)]
    inbit = nk[:4]
    ineqs = hrep_to_ineq(SBOX_HREP, [*inbit, *outbit])
    for i in ineqs:
        solver.add_constraint(i)
    nk[:4] = outbit
    return nk

def keyscheduler_milp(solver, vargen, k, nrounds):
    for _ in range(nrounds):
        yield k[:64]
        k = update_keyreg_milp(solver, vargen, k)
        
def add_present_milp(solver, vargen, pt, k):
    out = [vargen.gen() for _ in range(64)]
    for x,y,z in zip(pt, k, out):
        ineqs = hrep_to_ineq(XOR_HREP, [x,y,z])
        for i in ineqs:
            solver.add_constraint(i)
    return out
        
def sub_present_milp(solver, vargen, pt):
    out = [vargen.gen() for _ in range(64)]
    for i in range(64//4):
        outbit = out[i*4:i*4+4]
        inbit = pt[i*4:i*4+4]
        ineqs = hrep_to_ineq(SBOX_HREP, [*inbit, *outbit])
        for i in ineqs:
            solver.add_constraint(i)
    return out

def perm_present_milp(pt):
    return [pt[PERM[i]] for i in range(64)]
        
def present_milp(solver, vargen, pt, k, nrounds):
    ct = pt
    for k0 in keyscheduler_milp(solver, vargen, k, nrounds):
        ct = add_present_milp(solver, vargen, ct, k0)
        ct = sub_present_milp(solver, vargen, ct)
        ct = perm_present_milp(ct)
    return ct

def bind_milp(solver, milpvars, concretevars):
    assert len(milpvars) == len(concretevars)
    for x,y in zip(milpvars, concretevars):
        solver.add_constraint(x == y)

In [4]:
SBOX_HREP

(An inequality (13, -6, -1, 2, 12, 4, -7, -10) x + 4 >= 0,
 An inequality (-13, 1, 6, -2, -7, -4, 12, -10) x + 16 >= 0,
 An inequality (1, -2, -2, -2, -1, -3, 1, 6) x + 4 >= 0,
 An inequality (-1, 2, 2, 2, 1, 3, -1, 6) x - 4 >= 0,
 An inequality (8, -2, 4, 1, -5, -7, -9, -3) x + 12 >= 0,
 An inequality (3, 3, -6, 4, -2, 1, 4, -1) x + 1 >= 0,
 An inequality (-11, 9, -4, -3, 10, -6, 2, 1) x + 8 >= 0,
 An inequality (-2, -2, 0, -1, -3, 1, -1, -1) x + 6 >= 0,
 An inequality (5, -1, 2, -2, -1, -2, -3, 1) x + 3 >= 0)

In [11]:
class VarGen:
    
    """
    Wrapper class over `solver.new_variable`
    to provide the `gen` method
    """
    
    def __init__(self, solver:MixedIntegerLinearProgram):
        self.vargen = solver.new_variable(binary=True)
        
    def __getitem__(self, idx):
        """Get an existing variable at index `idx`"""
        assert idx < len(self.vargen.keys())
        return self.vargen[idx]
    
    def gen(self):
        """Generates a new variable"""
        return self.vargen[len(self.vargen.keys())]
    
class FakeSolver:
    
    def __init__(self):
        self.constraints = []
    
    def add_constraint(self, e):
        self.constraints.append(e)

In [11]:
nrounds = 30

flag = b"sT1lL_Us!nG_pR3&ENt?"
assert len(flag) == 20
pt1_concrete = bytes_to_bits(flag[:8])
k1_concrete = bytes_to_bits(flag[8:10]*5)
pt2_concrete = bytes_to_bits(flag[10:18])
k2_concrete = bytes_to_bits(flag[18:]*5)
ct_concrete = \
    present(pt1_concrete, k1_concrete, nrounds) \
    + present(pt2_concrete, k2_concrete, nrounds)

bits_to_bytes(ct_concrete)

b'5.\xafG\xaa\x10\xd2k\xf0\xbb\xbc\xe8\xb7e\x0e\x0c'

In [34]:
solver = MixedIntegerLinearProgram(maximization=True, solver="GLPK")
vargen = VarGen(solver)

#solver = FakeSolver()

ptmilp1 = [vargen.gen() for _ in range(64)]
kvars1 = [vargen.gen() for _ in range(16)]
kmilp1 = kvars1*5

ptmilp2 = [vargen.gen() for _ in range(64)]
kvars2 = [vargen.gen() for _ in range(16)]
kmilp2 = kvars2*5

ctmilp1 = present_milp(solver, vargen, ptmilp1, kmilp1, nrounds)
ctmilp2 = present_milp(solver, vargen, ptmilp2, kmilp2, nrounds)

bind_milp(solver, ctmilp1, ct_concrete[:64])
bind_milp(solver, ctmilp2, ct_concrete[64:])

In [35]:
# Test overal model

#bind_milp(solver, ptmilp1+kvars1+ptmilp2+kvars2, bytes_to_bits(flag[:-1] + b"?"))
#solver.solve()
#bits_to_bytes([*map(int, solver.get_values(ptmilp1+kvars1+ptmilp2+kvars2))])

In [36]:
const = solver.constraints()
bindings = [c for c in const if c[0] == c[2]]
const = [c for c in const if c[0] != c[2]]

In [37]:
# No more concrete variables

for bval, (bidx, _), _ in bindings:
    bidx = bidx[0]
    
    for i,c in enumerate(const):
        
        if bidx not in c[1][0]: continue

        nc = (c[1][0].copy(), c[1][1].copy())
        cbidx = nc[0].index(bidx)
        cbval = nc[1][cbidx]
        nc[0].pop(cbidx)
        nc[1].pop(cbidx)
        v = cbval * bval
        const[i] = (
            c[0] - v if c[0] is not None else c[0], 
            nc, 
            c[2] - v if c[2] is not None else c[2]
        )

In [38]:
# Rename input

allin = ptmilp1 + kvars1 + ptmilp2 + kvars2
allidx = [int(str(i)[2:]) for i in allin]
inmap = {int(str(i)[2:]): f"Y[{j}]" for j,i in enumerate(allin)}

for c in const:
    o = c[1][0]
    for i,n in enumerate(o):
        if n in inmap: o[i] = inmap[n]

In [39]:
# renumber everything else

# import requests
# import re
# wordlist = requests.get("https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-no-swears.txt").content.decode().strip().split("\n")
# assert all(re.match(r"[a-z]+", w) for w in wordlist)
# wordlist = [w for w in wordlist if len(w) > 3]

import random
random.seed(1)

restidx = set()
for c in const:
    restidx |= set(i for i in c[1][0] if isinstance(i, int))
restidx = sorted([*restidx])
newidx = [*range(len(restidx))]
random.shuffle(newidx)

restmap = {oidx:f"X[{nidx}]" for nidx,oidx in zip(newidx, restidx)}
for c in const:
    o = c[1][0]
    for i,n in enumerate(o):
        if n in restmap: o[i] = restmap[n]

In [40]:
def const2str(con):
    r = ""
    # x: str name
    # y: coeff
    for i,(x,y) in enumerate(zip(*con[1])):
        y = int(y)
        if i == 0:
            if y < 0: r += "-"
            r += f"{str(abs(y)) + '*' if abs(y) != 1 else ''}{x}"
            continue
        r += f" {'+-'[y < 0]} {str(abs(y)) + '*' if abs(y) != 1 else ''}{x}"
    if con[0] is not None: r = f"{int(con[0])} <= " + r
    if con[2] is not None: r = r + f" <= {int(con[2])}"
    return r

# Shuffle things

# get the input stuff
#sin = set(inmap.values())
#inidx = [j for j,c in enumerate(const) if any(i in sin for i in c[1][0])]
#resti = list(set(range(len(const))) - set(inidx))
#
#inconst = [const[i] for i in inidx]
#inconst = sorted(inconst, key=lambda c: len(const2str(c)))
#restconst = [const[i] for i in resti]
#restconst = sorted(restconst, key=lambda c: len(const2str(c)))
#
#const = inconst + restconst

const = sorted(const, key=lambda c: len(const2str(c)))

In [41]:
repmap = {
    "~~LEN~~": str(len(newidx)),
    "~~INITMODEL~~": str('\n'.join('model += ' + i for i in map(const2str, const)))
}

challenge = open("template.txt").read()
for k,v in repmap.items():
    challenge = challenge.replace(k, v)
open("challenge.py", "w").write(challenge)

1569133

# Testing stuff

In [117]:
repmap = {
    "~~LEN~~": str(len(newidx)),
    "~~INITMODEL~~": str('\n'.join('model += ' + i for i in map(const2str, const)))
}

challenge = open("template2.txt").read()
for k,v in repmap.items():
    challenge = challenge.replace(k, v)
open("test.py", "w").write(challenge)

1619915

In [32]:
from itertools import product

ALLOWED_CHARS = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!?#$%&-_"

flag1 = []
for k in product(ALLOWED_CHARS, repeat=2):
    k = bytes(k)*5
    pt = bits_to_bytes(depresent(ct_concrete[:64], bytes_to_bits(k), 30))
    if all(c in ALLOWED_CHARS for c in pt):
        flag1.append(pt + k[:2])
        
flag2 = []
for k in product(ALLOWED_CHARS, repeat=2):
    k = bytes(k)*5
    pt = bits_to_bytes(depresent(ct_concrete[64:], bytes_to_bits(k), 30))
    if all(c in ALLOWED_CHARS for c in pt):
        flag2.append(pt + k[:2])
        
for x in flag1:
    for y in flag2:
        flag = b"SEE{" + x+y + b"}"
        print(flag.decode())

SEE{5i87,P3&NvG_pR3&ENt?}
SEE{sT1lL_Us!nG_pR3&ENt?}


In [101]:
fs.constraints[:10]

[0 <= x_0 + x_64 - x_160,
 0 <= x_0 - x_64 + x_160,
 0 <= -1*x_0 + x_64 + x_160,
 0 <= 2 - x_0 - x_64 - x_160,
 0 <= x_1 + x_65 - x_161,
 0 <= x_1 - x_65 + x_161,
 0 <= -1*x_1 + x_65 + x_161,
 0 <= 2 - x_1 - x_65 - x_161,
 0 <= x_2 + x_66 - x_162,
 0 <= x_2 - x_66 + x_162]

In [96]:
const[0]

(None, ([160, 64, 0], [1.0, -1.0, -1.0]), 0.0)

In [119]:
constr = "\n".join(map(str, fs.constraints))
open("constraints.txt", "w").write(constr)

1299068

In [64]:
flag

b'2o50_&_uSinG_P3SeN4?'

In [65]:
# Test overal model

bind_milp(solver, ptmilp1+kvars1+ptmilp2+kvars2, bytes_to_bits(flag))
solver.solve()
bits_to_bytes([*map(int, solver.get_values(ptmilp1+kvars1+ptmilp2+kvars2))])

MIPSolverException: GLPK: Problem has no feasible solution

In [99]:
# test sbox

sbox_hrep = gethrep_sbox(SBOX)

solver = MixedIntegerLinearProgram(maximization=True, solver="GLPK")
vargen = VarGen(solver)

for x in range(0x10):
    inp = [*map(int, format(x, "04b"))]
    out = [vargen.gen() for _ in range(4)]
    ineqs = hrep_to_ineq(sbox_hrep, [*inp, *out])
    for i in ineqs:
        solver.add_constraint(i)

    solver.solve()
    y = int("".join([str(int(i)) for i in solver.get_values(out)]), 2)
    assert SBOX[x] == y

In [100]:
ineqs

[0 <= 12 + 12*x_60 + 4*x_61 - 7*x_62 - 10*x_63,
 0 <= 8 - 7*x_60 - 4*x_61 + 12*x_62 - 10*x_63,
 0 <= -1 - x_60 - 3*x_61 + x_62 + 6*x_63,
 0 <= 1 + x_60 + 3*x_61 - x_62 + 6*x_63,
 0 <= 23 - 5*x_60 - 7*x_61 - 9*x_62 - 3*x_63,
 0 <= 5 - 2*x_60 + x_61 + 4*x_62 - x_63,
 0 <= -1 + 10*x_60 - 6*x_61 + 2*x_62 + x_63,
 0 <= 1 - 3*x_60 + x_61 - x_62 - x_63,
 0 <= 7 - x_60 - 2*x_61 - 3*x_62 + x_63]