Skip to content

Commit

Permalink
[mlir] Fix use-after-free bugs in {RankedTensorType|VectorType}::Buil…
Browse files Browse the repository at this point in the history
…der (llvm#68969)

Previously, these would set their ArrayRef members to reference their
storage SmallVectors after a copy-on-write (COW) operation. This leads
to a use-after-free if the builder is copied and the original destroyed
(as the new builder would still reference the old SmallVector).

This could easily accidentally occur in code like (annotated):
```c++
// 1. `VectorType::Builder(type)` constructs a new temporary builder
// 2. `.dropDim(0)` updates the temporary builder by reference, and returns a `VectorType::Builder&`
//    - Modifying the shape is a COW operation, so `storage` is used, and `shape` is updated to reference it
// 3. Assigning the reference to `auto` copies the builder (via the default C++ copy ctor)
//    -  There's no special handling for `shape` and `storage`, so the new shape points to the old builder's `storage`
auto newType = VectorType::Builder(type).dropDim(0);
// 4. When this line is reached the original temporary builder is destroyed
//    - Actually constructing the vector type is now a use-after-free
VectorType newVectorType = VectorType(newType);
```

This is fixed with these changes by using `CopyOnWriteArrayRef<T>`,
which implements the same functionality, but ensures no
dangling references are possible if it's copied. 

---

The VectorType::Builder also set the ArrayRef<bool> scalableDims member
to a temporary SmallVector when the provided scalableDims are empty.
This again leads to a use-after-free, and is unnecessary as
VectorType::get already handles being passed an empty scalableDims
array.

These bugs were caught in part by UBSan; see:
https://lab.llvm.org/buildbot/#/builders/5/builds/37355
  • Loading branch information
MacDue committed Oct 18, 2023
1 parent 28e4f97 commit b0b8e83
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 44 deletions.
59 changes: 15 additions & 44 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"

namespace llvm {
class BitVector;
Expand Down Expand Up @@ -274,20 +275,14 @@ class RankedTensorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.erase(storage.begin() + pos);
shape = {storage.data(), storage.size()};
shape.erase(pos);
return *this;
}

/// Insert a val into shape @pos.
Builder &insertDim(int64_t val, unsigned pos) {
assert(pos <= shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.insert(storage.begin() + pos, val);
shape = {storage.data(), storage.size()};
shape.insert(pos, val);
return *this;
}

Expand All @@ -296,9 +291,7 @@ class RankedTensorType::Builder {
}

private:
ArrayRef<int64_t> shape;
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
CopyOnWriteArrayRef<int64_t> shape;
Type elementType;
Attribute encoding;
};
Expand All @@ -313,27 +306,18 @@ class VectorType::Builder {
public:
/// Build from another VectorType.
explicit Builder(VectorType other)
: shape(other.getShape()), elementType(other.getElementType()),
: elementType(other.getElementType()), shape(other.getShape()),
scalableDims(other.getScalableDims()) {}

/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
: shape(shape), elementType(elementType) {
if (scalableDims.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
this->scalableDims = scalableDims;
}
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}

Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
if (newIsScalableDim.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
scalableDims = newIsScalableDim;

shape = newShape;
scalableDims = newIsScalableDim;
return *this;
}

Expand All @@ -345,25 +329,16 @@ class VectorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
if (storageScalableDims.empty())
storageScalableDims.append(scalableDims.begin(), scalableDims.end());
storage.erase(storage.begin() + pos);
storageScalableDims.erase(storageScalableDims.begin() + pos);
shape = {storage.data(), storage.size()};
scalableDims =
ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
shape.erase(pos);
if (!scalableDims.empty())
scalableDims.erase(pos);
return *this;
}

/// Set a dim in shape @pos to val.
Builder &setDim(unsigned pos, int64_t val) {
if (storage.empty())
storage.append(shape.begin(), shape.end());
assert(pos < storage.size() && "overflow");
storage[pos] = val;
shape = {storage.data(), storage.size()};
assert(pos < shape.size() && "overflow");
shape.set(pos, val);
return *this;
}

Expand All @@ -372,13 +347,9 @@ class VectorType::Builder {
}

private:
ArrayRef<int64_t> shape;
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
Type elementType;
ArrayRef<bool> scalableDims;
// Owning scalableDims data for copy-on-write operations.
SmallVector<bool> storageScalableDims;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
Expand Down
82 changes: 82 additions & 0 deletions mlir/include/mlir/Support/ADTExtras.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===- ADTExtras.h - Extra ADTs for use in MLIR -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_SUPPORT_ADTEXTRAS_H
#define MLIR_SUPPORT_ADTEXTRAS_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstddef>

namespace mlir {

//===----------------------------------------------------------------------===//
// CopyOnWriteArrayRef<T>
//===----------------------------------------------------------------------===//

// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
// modification. This is for use in the mlir::<Type>::Builders.
template <typename T>
class CopyOnWriteArrayRef {
public:
CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array){};

CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
nonOwning = array;
owningStorage = {};
return *this;
}

void insert(size_t index, T value) {
SmallVector<T> &vector = ensureCopy();
vector.insert(vector.begin() + index, value);
}

void erase(size_t index) {
// Note: A copy can be avoided when just dropping the front/back dims.
if (isNonOwning() && index == 0) {
nonOwning = nonOwning.drop_front();
} else if (isNonOwning() && index == size() - 1) {
nonOwning = nonOwning.drop_back();
} else {
SmallVector<T> &vector = ensureCopy();
vector.erase(vector.begin() + index);
}
}

void set(size_t index, T value) { ensureCopy()[index] = value; }

size_t size() const { return ArrayRef<T>(*this).size(); }

bool empty() const { return ArrayRef<T>(*this).empty(); }

operator ArrayRef<T>() const {
return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
}

private:
bool isNonOwning() const { return !nonOwning.empty(); }

SmallVector<T> &ensureCopy() {
// Empty non-owning storage signals the array has been copied to the owning
// storage (or both are empty). Note: `nonOwning` should never reference
// `owningStorage`. This can lead to dangling references if the
// CopyOnWriteArrayRef<T> is copied.
if (isNonOwning()) {
owningStorage = SmallVector<T>(nonOwning);
nonOwning = {};
}
return owningStorage;
}

ArrayRef<T> nonOwning;
SmallVector<T> owningStorage;
};

} // namespace mlir

#endif
95 changes: 95 additions & 0 deletions mlir/unittests/IR/ShapedTypeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,99 @@ TEST(ShapedTypeTest, CloneVector) {
VectorType::get(vectorNewShape, vectorNewType));
}

// Exercises VectorType::Builder: dropping dims, setting dims, and the
// builder lifetime scenarios that previously triggered use-after-free bugs.
TEST(ShapedTypeTest, VectorTypeBuilder) {
MLIRContext context;
Type f32 = FloatType::getF32(&context);

// Reference 5-D vector type with a mix of scalable and fixed dims.
SmallVector<int64_t> shape{2, 4, 8, 9, 1};
SmallVector<bool> scalableDims{true, false, true, false, false};
VectorType vectorType = VectorType::get(shape, f32, scalableDims);

{
// Drop some dims: removing dim 0 twice must keep the element type and drop
// the first two entries of both the shape and the scalable dims.
VectorType dropFrontTwoDims =
VectorType::Builder(vectorType).dropDim(0).dropDim(0);
ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
dropFrontTwoDims.getScalableDims());
}

{
// Set some dims: only the shape entries at positions 0 and 3 change; the
// element type and the scalable dims must be left untouched.
VectorType setTwoDims =
VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
}

{
// Test for bug from:
// https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
// Constructs a temporary builder, modifies it, copies it to `builder`.
// This used to lead to a use-after-free. Running under sanitizers will
// catch any issues.
// Keep the temporary-then-copy shape of the next statement: it is the
// scenario under test.
VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
VectorType newVectorType = VectorType(builder);
ASSERT_EQ(newVectorType.getDimSize(0), 16);
}

{
// Make builder from scratch (without scalable dims) -- this used to lead to
// a use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
// Running under sanitizers will catch any issues.
// Note: this `shape` shadows the one declared at the top of the test.
SmallVector<int64_t> shape{1, 2, 3, 4};
VectorType::Builder builder(shape, f32);
ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
}

{
// Set vector shape (without scalable dims) -- this used to lead to
// a use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
// Running under sanitizers will catch any issues.
VectorType::Builder builder(vectorType);
SmallVector<int64_t> newShape{2, 2};
builder.setShape(newShape);
ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
}
}

// Exercises RankedTensorType::Builder: dropping dims, inserting dims, and the
// builder-copy scenario that previously triggered a use-after-free.
TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

  // Reference 5-D tensor type.
  SmallVector<int64_t> shape{2, 4, 8, 16, 32};
  RankedTensorType tensorType = RankedTensorType::get(shape, f32);

  {
    // Drop some dims. Three dims are removed, one of them from the middle:
    //   {2, 4, 8, 16, 32} -> {4, 8, 16, 32} -> {4, 16, 32} -> {16, 32}
    RankedTensorType dropThreeDims =
        RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
    ASSERT_EQ(tensorType.getElementType(), dropThreeDims.getElementType());
    ASSERT_EQ(dropThreeDims.getShape(), ArrayRef<int64_t>({16, 32}));
  }

  {
    // Insert some dims: 7 goes in at position 2, then 9 at position 3.
    RankedTensorType insertTwoDims =
        RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
    ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
    ASSERT_EQ(insertTwoDims.getShape(),
              ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
  }

  {
    // Test for bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // Constructs a temporary builder, modifies it, copies it to `builder`.
    // This used to lead to a use-after-free. Running under sanitizers will
    // catch any issues.
    // Keep the temporary-then-copy shape of the next statement: it is the
    // scenario under test.
    RankedTensorType::Builder builder =
        RankedTensorType::Builder(tensorType).dropDim(0);
    RankedTensorType newTensorType = RankedTensorType(builder);
    ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
  }
}

} // namespace

0 comments on commit b0b8e83

Please sign in to comment.