Skip to content

Commit

Permalink
Implement sub-byte addressing support to the memory model
Browse files Browse the repository at this point in the history
When load/store have different bit-width, the result is poison

In ASM mode, we zero out the leftover bits to enforce an ABI where stores
must zero-extend values.

Closes #1019
  • Loading branch information
nunoplopes committed Jun 13, 2024
1 parent 60bf5b1 commit 405c202
Show file tree
Hide file tree
Showing 12 changed files with 135 additions and 40 deletions.
1 change: 1 addition & 0 deletions ir/globals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ unsigned bits_program_pointer = 64;
unsigned bits_size_t = 64;
unsigned bits_ptr_address = 64;
unsigned bits_byte = 8;
unsigned num_sub_byte_bits = 6;
unsigned strlen_unroll_cnt = 8;
unsigned memcmp_unroll_cnt = 8;
bool little_endian = true;
Expand Down
4 changes: 4 additions & 0 deletions ir/globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ extern unsigned bits_ptr_address;
/// Number of bits for a byte.
extern unsigned bits_byte;

/// Required bits to store the size of sub-byte accesses
/// (e.g., store i5 -> we record 4, so 3 bits)
extern unsigned num_sub_byte_bits;

extern unsigned strlen_unroll_cnt;
extern unsigned memcmp_unroll_cnt;

Expand Down
4 changes: 3 additions & 1 deletion ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3503,11 +3503,13 @@ MemInstr::ByteAccessInfo::get(const Type &t, bool store, unsigned align) {
info.doesPtrStore = ptr_access && store;
info.doesPtrLoad = ptr_access && !store;
info.byteSize = gcd(align, getCommonAccessSize(t));
if (auto intTy = t.getAsIntType())
info.subByteAccess = intTy->maxSubBitAccess();
return info;
}

MemInstr::ByteAccessInfo MemInstr::ByteAccessInfo::full(unsigned byteSize) {
return { true, true, true, true, byteSize };
return { true, true, true, true, byteSize, 0 };
}


Expand Down
2 changes: 2 additions & 0 deletions ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,8 @@ class MemInstr : public Instr {
// Otherwise, bytes of a memory can be widened to this size.
unsigned byteSize = 0;

unsigned subByteAccess = 0;

bool doesMemAccess() const { return byteSize; }

static ByteAccessInfo intOnly(unsigned byteSize);
Expand Down
90 changes: 70 additions & 20 deletions ir/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ static unsigned padding_ptr_byte() {

static unsigned padding_nonptr_byte() {
return
Byte::bitsByte() - byte_has_ptr_bit() - bits_byte - bits_poison_per_byte;
Byte::bitsByte() - byte_has_ptr_bit() - bits_byte - bits_poison_per_byte
- num_sub_byte_bits;
}

static expr concat_if(const expr &ifvalid, expr &&e) {
Expand Down Expand Up @@ -217,7 +218,8 @@ Byte::Byte(const Memory &m, const StateValue &ptr, unsigned i) : m(m) {
assert(!ptr.isValid() || p.bits() == bitsByte());
}

Byte::Byte(const Memory &m, const StateValue &v) : m(m) {
Byte::Byte(const Memory &m, const StateValue &v, unsigned bits_read, bool)
: m(m) {
assert(!v.isValid() || v.value.bits() == bits_byte);
assert(!v.isValid() || v.non_poison.isBool() ||
v.non_poison.bits() == bits_poison_per_byte);
Expand All @@ -234,12 +236,18 @@ Byte::Byte(const Memory &m, const StateValue &v) : m(m) {
expr::mkInt(-1, bits_poison_per_byte),
expr::mkUInt(0, bits_poison_per_byte))
: v.non_poison;
p = concat_if(p, np.concat(v.value).concat_zeros(padding_nonptr_byte()));
p = concat_if(p, np.concat(v.value));
if (num_sub_byte_bits) {
if ((bits_read % 8) == 0)
bits_read = 0;
p = p.concat(expr::mkUInt(bits_read, num_sub_byte_bits));
}
p = p.concat_zeros(padding_nonptr_byte());
assert(!p.isValid() || p.bits() == bitsByte());
}

Byte Byte::mkPoisonByte(const Memory &m) {
return { m, StateValue(expr::mkUInt(0, bits_byte), false) };
return { m, StateValue(expr::mkUInt(0, bits_byte), false), 0, true };
}

expr Byte::isPtr() const {
Expand Down Expand Up @@ -276,7 +284,7 @@ expr Byte::ptrByteoffset() const {
expr Byte::nonptrNonpoison() const {
if (!does_int_mem_access)
return expr::mkUInt(0, 1);
unsigned start = padding_nonptr_byte() + bits_byte;
unsigned start = padding_nonptr_byte() + bits_byte + num_sub_byte_bits;
return p.extract(start + bits_poison_per_byte - 1, start);
}

Expand All @@ -288,10 +296,15 @@ expr Byte::boolNonptrNonpoison() const {
expr Byte::nonptrValue() const {
if (!does_int_mem_access)
return expr::mkUInt(0, bits_byte);
unsigned start = padding_nonptr_byte();
unsigned start = padding_nonptr_byte() + num_sub_byte_bits;
return p.extract(start + bits_byte - 1, start);
}

expr Byte::numStoredBits() const {
unsigned start = padding_nonptr_byte();
return p.extract(start + num_sub_byte_bits - 1, start);
}

expr Byte::isPoison() const {
if (!does_int_mem_access)
return does_ptr_mem_access ? !ptrNonpoison() : true;
Expand Down Expand Up @@ -403,7 +416,8 @@ expr Byte::refined(const Byte &other) const {
unsigned Byte::bitsByte() {
unsigned ptr_bits = does_ptr_mem_access *
(1 + Pointer::totalBits() + bits_ptr_byte_offset());
unsigned int_bits = does_int_mem_access * (bits_byte + bits_poison_per_byte);
unsigned int_bits = does_int_mem_access * (bits_byte + bits_poison_per_byte)
+ num_sub_byte_bits;
// allow at least 1 bit if there's no memory access
return max(1u, byte_has_ptr_bit() + max(ptr_bits, int_bits));
}
Expand All @@ -419,10 +433,11 @@ ostream& operator<<(ostream &os, const Byte &byte) {
} else {
auto np = byte.nonptrNonpoison();
auto val = byte.nonptrValue();
if (np.isZero())
return os << "poison";

if (np.isAllOnes()) {
val.printHexadecimal(os);
} else if (np.isZero()) {
os << "poison";
} else {
os << "#b";
for (unsigned i = 0; i < bits_poison_per_byte; ++i) {
Expand All @@ -432,6 +447,12 @@ ostream& operator<<(ostream &os, const Byte &byte) {
os << (is_poison ? 'p' : (v ? '1' : '0'));
}
}

uint64_t num_bits;
if (num_sub_byte_bits &&
byte.numStoredBits().isUInt(num_bits) && num_bits != 0) {
os << " / written with " << num_bits << " bits";
}
}
return os;
}
Expand Down Expand Up @@ -484,14 +505,14 @@ static void pad(StateValue &v, unsigned amount, State &s) {
}

static vector<Byte> valueToBytes(const StateValue &val, const Type &fromType,
const Memory &mem, State *s) {
const Memory &mem, State &s) {
vector<Byte> bytes;
if (fromType.isPtrType()) {
Pointer p(mem, val.value);
unsigned bytesize = bits_program_pointer / bits_byte;

// constant global can't store pointers that alias with local blocks
if (s->isInitializationPhase() && !p.isLocal().isFalse()) {
if (s.isInitializationPhase() && !p.isLocal().isFalse()) {
expr bid = expr::mkUInt(0, 1).concat(p.getShortBid());
p = Pointer(mem, bid, p.getOffset(), p.getAttrs());
}
Expand All @@ -500,19 +521,24 @@ static vector<Byte> valueToBytes(const StateValue &val, const Type &fromType,
bytes.emplace_back(mem, StateValue(expr(p()), expr(val.non_poison)), i);
} else {
assert(!fromType.isAggregateType() || isNonPtrVector(fromType));
StateValue bvval = fromType.toInt(*s, val);
StateValue bvval = fromType.toInt(s, val);
unsigned bitsize = bvval.bits();
unsigned bytesize = divide_up(bitsize, bits_byte);

pad(bvval, bytesize * bits_byte - bitsize, *s);
// There are no sub-byte accesses in assembly
if (mem.isAsmMode() && (bitsize % 8) != 0) {
s.addUB(expr(false));
}

pad(bvval, bytesize * bits_byte - bitsize, s);
unsigned np_mul = bits_poison_per_byte;

for (unsigned i = 0; i < bytesize; ++i) {
StateValue data {
bvval.value.extract((i + 1) * bits_byte - 1, i * bits_byte),
bvval.non_poison.extract((i + 1) * np_mul - 1, i * np_mul)
};
bytes.emplace_back(mem, data);
bytes.emplace_back(mem, data, bitsize, true);
}
}
return bytes;
Expand Down Expand Up @@ -581,18 +607,42 @@ static StateValue bytesToValue(const Memory &m, const vector<Byte> &bytes,
bool is_asm = m.isAsmMode();

StateValue val;
expr stored_bits;
bool first = true;
IntType ibyteTy("", bits_byte);

for (auto &b: bytes) {
expr isptr = ub_pre(!b.isPtr());
expr expr_np = ub_pre(!b.isPtr());

if (num_sub_byte_bits) {
unsigned bits = (bitsize % 8) == 0 ? 0 : bitsize;
auto this_stored_bits = b.numStoredBits();
expr_np &= this_stored_bits == bits;
if (first)
stored_bits = std::move(this_stored_bits);
}

StateValue v(is_asm ? b.forceCastToInt() : b.nonptrValue(),
is_asm ? b.nonPoison()
: ibyteTy.combine_poison(isptr, b.nonptrNonpoison()));
: ibyteTy.combine_poison(
expr_np, b.nonptrNonpoison()));
val = first ? std::move(v) : v.concat(val);
first = false;
}
return toType.fromInt(val.trunc(bitsize, toType.np_bits(true)));

val = toType.fromInt(val.trunc(bitsize, toType.np_bits(true)));

// Assume that in ASM mode, the ABI mandates that stores of non-byte-aligned
// stores are zero extended to the next byte boundary
// So here we need to zero out any bit that may not be zero since the
// initial memory is not for ASM.
// It also means we need to ensure the ABI is respected on stores.
if (is_asm && num_sub_byte_bits) {
auto shift
= expr::mkUInt(bitsize, val.value) - stored_bits.zextOrTrunc(val.bits());
val.value = val.value & (expr::mkInt(-1, shift).lshr(shift));
}
return val;
}
}

Expand Down Expand Up @@ -1704,7 +1754,7 @@ void Memory::setState(const Memory::CallState &st,
// TODO: havoc local blocks
// for now, zero out if in non UB-exploitation mode to avoid false positives
if (config::disallow_ub_exploitation) {
expr raw_byte = Byte(*this, {expr::mkUInt(0, bits_byte), true})();
expr raw_byte = Byte(*this, {expr::mkUInt(0, bits_byte), true}, 0, true)();
expr array = expr::mkConstArray(expr::mkUInt(0, bits_for_offset),
raw_byte);

Expand Down Expand Up @@ -1925,7 +1975,7 @@ void Memory::store(const StateValue &v, const Type &type, unsigned offset0,
assert(byteofs == getStoreByteSize(type));

} else {
vector<Byte> bytes = valueToBytes(v, type, *this, state);
vector<Byte> bytes = valueToBytes(v, type, *this, *state);
assert(!v.isValid() || bytes.size() * bytesz == getStoreByteSize(type));

for (unsigned i = 0, e = bytes.size(); i < e; ++i) {
Expand Down Expand Up @@ -2043,7 +2093,7 @@ void Memory::memset(const expr &p, const StateValue &val, const expr &bytesize,
}
assert(!val.isValid() || wval.bits() == bits_byte);

auto bytes = valueToBytes(wval, IntType("", bits_byte), *this, state);
auto bytes = valueToBytes(wval, IntType("", bits_byte), *this, *state);
assert(bytes.size() == 1);
expr raw_byte = std::move(bytes[0])();

Expand Down
19 changes: 10 additions & 9 deletions ir/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ class State;
// A data structure that represents a byte.
// A byte is either a pointer byte or a non-pointer byte.
// Pointer byte's representation:
// +-+-----------+-----------------------------+---------------+---------+
// |1|non-poison?| Pointer (see class below) | byte offset | padding |
// | |(1 bit) | | (0 or 3 bits) | |
// +-+-----------+-----------------------------+---------------+---------+
// +-+-------------+-----------+---------------+----------------------+
// |1| non-poison? | Pointer | byte offset | padding |
// | | (1 bit) | | (0 or 3 bits) | |
// +-+-------------+-----------+---------------+----------------------+
// Non-pointer byte's representation:
// +-+--------------------+--------------------+-------------------------+
// |0| non-poison bit(s) | data | padding |
// | | (bits_byte) | (bits_byte) | |
// +-+--------------------+--------------------+-------------------------+
// +-+-------------------+-------------+--------------------+---------+
// |0| non-poison bit(s) | data | stored bits | padding |
// | | (bits_byte) | (bits_byte) | (num_subbyte_bits) | |
// +-+-------------------+-------------+--------------------+---------+

class Byte {
const Memory &m;
Expand All @@ -52,7 +52,7 @@ class Byte {
// non_poison should be an one-bit vector or boolean.
Byte(const Memory &m, const StateValue &ptr, unsigned i);

Byte(const Memory &m, const StateValue &v);
Byte(const Memory &m, const StateValue &v, unsigned bits_read, bool);

static Byte mkPoisonByte(const Memory &m);

Expand All @@ -64,6 +64,7 @@ class Byte {
smt::expr nonptrNonpoison() const;
smt::expr boolNonptrNonpoison() const;
smt::expr nonptrValue() const;
smt::expr numStoredBits() const;
smt::expr isPoison() const;
smt::expr nonPoison() const;
smt::expr isZero() const; // zero or null
Expand Down
20 changes: 20 additions & 0 deletions ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ expr Type::enforcePtrOrVectorType() const {
[&](auto &ty) { return ty.enforcePtrType(); });
}

const IntType* Type::getAsIntType() const {
return nullptr;
}

const FloatType* Type::getAsFloatType() const {
return nullptr;
}
Expand Down Expand Up @@ -321,6 +325,14 @@ void VoidType::print(ostream &os) const {
}


unsigned IntType::maxSubBitAccess() const {
if (!defined)
return 63;
if (bitwidth % 8)
return bitwidth;
return 0;
}

unsigned IntType::bits() const {
return bitwidth;
}
Expand Down Expand Up @@ -359,6 +371,10 @@ expr IntType::enforceIntType(unsigned bits) const {
return bits ? sizeVar() == bits : true;
}

const IntType* IntType::getAsIntType() const {
return this;
}

pair<expr, expr>
IntType::refines(State &src_s, State &tgt_s, const StateValue &src,
const StateValue &tgt) const {
Expand Down Expand Up @@ -1357,6 +1373,10 @@ expr SymbolicType::enforceVectorType(
return v ? (isVector() && v->enforceVectorType(enforceElem)) : false;
}

const IntType* SymbolicType::getAsIntType() const {
return &*i;
}

const FloatType* SymbolicType::getAsFloatType() const {
return &*f;
}
Expand Down
5 changes: 5 additions & 0 deletions ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace IR {

class AggregateType;
class FloatType;
class IntType;
class StructType;
class SymbolicType;
class VectorType;
Expand Down Expand Up @@ -88,6 +89,7 @@ class Type {
smt::expr enforceFloatOrVectorType() const;
smt::expr enforcePtrOrVectorType() const;

virtual const IntType* getAsIntType() const;
virtual const FloatType* getAsFloatType() const;
virtual const AggregateType* getAsAggregateType() const;
virtual const StructType* getAsStructType() const;
Expand Down Expand Up @@ -157,6 +159,7 @@ class IntType final : public Type {
IntType(std::string &&name, unsigned bitwidth)
: Type(std::move(name)), bitwidth(bitwidth), defined(true) {}

unsigned maxSubBitAccess() const;
unsigned bits() const override;
IR::StateValue getDummyValue(bool non_poison) const override;
smt::expr getTypeConstraints() const override;
Expand All @@ -165,6 +168,7 @@ class IntType final : public Type {
void fixup(const smt::Model &m) override;
bool isIntType() const override;
smt::expr enforceIntType(unsigned bits = 0) const override;
const IntType* getAsIntType() const override;
std::pair<smt::expr, smt::expr>
refines(State &src_s, State &tgt_s, const StateValue &src,
const StateValue &tgt) const override;
Expand Down Expand Up @@ -399,6 +403,7 @@ class SymbolicType final : public Type {
smt::expr enforceFloatType() const override;
smt::expr enforceVectorType(
const std::function<smt::expr(const Type&)> &enforceElem) const override;
const IntType* getAsIntType() const override;
const FloatType* getAsFloatType() const override;
const AggregateType* getAsAggregateType() const override;
const StructType* getAsStructType() const override;
Expand Down
Loading

0 comments on commit 405c202

Please sign in to comment.