add support for GPT2, BERT
AndrewZhaoLuo committed Jun 3, 2021
1 parent 3d871ba commit 8bb3ad6
Showing 3 changed files with 239 additions and 16 deletions.
37 changes: 24 additions & 13 deletions src/relay/transforms/fp32_to_fp16.cc
@@ -179,21 +179,26 @@ class RewriteBasedOnColors : public ExprMutator {
}
}

Array<Expr> get_new_args(const CallNode* call, DataType arg_cast_datatype) {
Array<Expr> ret;
std::pair<Array<Expr>, Array<Type>> get_new_args(const CallNode* call,
DataType arg_cast_datatype) {
Array<Expr> args;
Array<Type> types;
for (Expr arg : call->args) {
arg = VisitExpr(arg);
Type arg_type = GetTypedExpr(arg)->checked_type();
ret.push_back(arg_cast_helper(arg, arg_type, arg_cast_datatype));
Expr new_expr = arg_cast_helper(arg, arg_type, arg_cast_datatype);
args.push_back(new_expr);
types.push_back(GetTypedExpr(new_expr)->checked_type());
}

return ret;
return {args, types};
}

Attrs get_new_attrs(const CallNode* call, DataType accumulation_dtype) {
Attrs new_attrs = Attrs(call->attrs);
if (new_attrs.get() != nullptr) {
// TODO: Figure out a better way to do this
// out_dtype attributes (accumulation dtype)
if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
modify_output_dtype(attrs, accumulation_dtype);
} else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
@@ -220,6 +225,7 @@ class RewriteBasedOnColors : public ExprMutator {
modify_output_dtype(attrs, accumulation_dtype);
}

// dtype attributes (creating new tensors of type dtype)
if (auto attrs = new_attrs.as<InitOpAttrs>()) {
modify_dtype(attrs, accumulation_dtype);
}
@@ -286,9 +292,19 @@ class RewriteBasedOnColors : public ExprMutator {
// TODO: extend to bfloat types
DataType arg_cast_dtype = color == GREEN ? DataType::Float(16) : DataType::Float(32);

Array<Expr> new_args = get_new_args(call, arg_cast_dtype);
auto new_args_and_types = get_new_args(call, arg_cast_dtype);
Array<Expr> new_args = new_args_and_types.first;
Array<Type> new_types;

if (call->op.as<FunctionNode>()) {
// Function nodes don't store type info in the Call; type_args should remain an empty []
new_types = call->type_args;
} else {
new_types = new_args_and_types.second;
}

Attrs new_attrs = get_new_attrs(call, output_dtypes.accumulation_dtype);
Expr output = Call(new_op, new_args, new_attrs, call->type_args, call->span);
Expr output = Call(new_op, new_args, new_attrs, new_types, call->span);

color_map[output.as<CallNode>()] = color_map[call];
if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
Expand All @@ -303,7 +319,8 @@ class RewriteBasedOnColors : public ExprMutator {
Expr VisitExpr_(const FunctionNode* func) final {
// Erase the ret_type annotation and let the pass recalculate
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
return ExprMutator::VisitExpr_(func);
Expr result = ExprMutator::VisitExpr_(func);
return result;
}
};

@@ -352,12 +369,6 @@ Expr RewriteFp16Graph(const Expr& expr, bool debug) {

// Insert an extraneous cast to FP32 to match old module output
Expr result = rewriter.Mutate(expr);

// Old type annotations may no longer be accurate so rewrite
if (const FunctionNode* func = result.as<FunctionNode>()) {
const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
}

return result;
}

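A minimal sketch of driving the new pass from Python, mirroring the test helper further down. It assumes RewriteFP16, InferType, and AnnotateSpans are importable from tvm.relay.transform (the actual import lines are collapsed in this diff); the toy dense graph is purely illustrative.

import tvm
from tvm import relay
from tvm.relay.transform import AnnotateSpans, InferType, RewriteFP16  # assumed import path

# Build a tiny fp32 graph containing one GREEN op (nn.dense).
data = relay.var("data", shape=[1, 20])
weight = relay.var("weight", shape=[20, 20])
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight, units=20))
mod = AnnotateSpans()(InferType()(mod))

# The rewrite casts the dense arguments to fp16, sets out_dtype to the
# accumulation dtype (fp32 by default), and casts the result back to fp16.
fp16_mod = RewriteFP16()(mod)
print(fp16_mod)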
48 changes: 48 additions & 0 deletions src/relay/transforms/fp32_to_fp16.h
@@ -32,22 +32,58 @@ OpStringSet DEFAULT_GREEN_LIST({
"nn.conv2d_transpose",
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
});
// TODO: make a list of ops which don't care about the dtypes of the tensors coming in,
// e.g. "where" and "strided_slice"
OpStringSet DEFAULT_GRAY_LIST({
// These ops add new data or change shape
"nn.pad",
"nn.batch_flatten",
"concatenate",
"zeros",
"split",
"squeeze",
"transpose",
"expand_dims",
"reshape",
"dyn.reshape",
"broadcast_to_like",
"strided_slice",
"dyn.strided_slice",
"take",
"argwhere",
"where",
"tile",
"dyn.tile",
"scatter",
"full",
"dyn.full",
// Comparison
"less",
"greater",
"less_equal",
"greater_equal",
// By definition copy and cast will become green or red based on inputs
"copy",
"cast",
"cast_like",
// Simple arithmetic
"add",
"subtract",
"multiply",
"divide",
"nn.bias_add",
"nn.batch_norm",
"sum",
"mean",
"sqrt",
"shape_of",
// Simple activations
"max",
"min",
"maximum",
"minimum",
"nn.relu",
"nn.leaky_relu",
"nn.prelu",
@@ -78,6 +114,8 @@ OpStringSet DEFAULT_GRAY_LIST({
OpStringSet DEFAULT_RED_LIST({
// In general, if |f(x)| >> |x| for some expected inputs to the op, then put it here.
// Activations with exponents or dividing by small numbers
"exp",
"power",
"nn.cross_entropy",
"nn.cross_entropy_with_logits",
"nn.softmax",
@@ -132,6 +170,16 @@ class DefaultFP16Colorer {
class DefaultFP16OpDefinition {
public:
FP16OpDType operator()(const CallNode* call) {
// TODO: remove when batch_matmul handles accumulation dtypes properly.
// Batched matmul has inconsistent support for mixed-precision operations:
// many schedules ignore the out_dtype attribute, which leads to errors when
// the input types do not match out_dtype. Therefore, accumulate to fp16.
if (auto op_node = call->op.as<OpNode>()) {
if (op_node->name == "nn.batch_matmul") {
return {DataType::Float(16), DataType::Float(16)};
}
}

if (call->attrs != NullValue<Attrs>()) {
Array<AttrFieldInfo> fields = call->attrs->ListFieldInfo();
for (AttrFieldInfo field_info : fields) {
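Before the tests, a hedged sketch of what the accumulation-dtype/output-dtype split above produces, modeled on the expected modules built in test_where_simple and test_batch_matmul_simple below; the variable names are illustrative.

from tvm import relay

# Default GREEN rule (e.g. nn.dense): cast the inputs to fp16, accumulate in
# fp32 via out_dtype, then cast the output back down to fp16.
data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
dense_fp16 = relay.cast(
    relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"
)

# nn.batch_matmul is special-cased in DefaultFP16OpDefinition to accumulate
# directly in fp16, since many TOPI schedules ignore out_dtype.
bdata = relay.cast(relay.var("bdata", shape=[1, 1, 20]), "float16")
bweight = relay.cast(relay.var("bweight", shape=[1, 20, 20]), "float16")
bmm_fp16 = relay.nn.batch_matmul(bdata, bweight, out_dtype="float16")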
170 changes: 167 additions & 3 deletions tests/python/relay/test_fp32_to_fp16_transform.py
@@ -1,7 +1,11 @@
import tempfile
from collections import defaultdict
from typing import *

import numpy as np
import onnx
import torch.onnx
import torchvision
import tvm
from tvm import relay
from tvm.relay.op.tensor import exp
@@ -13,18 +17,23 @@
def run_module(mod, mod_params):
dev = tvm.device("llvm", 0)
intrp = relay.create_executor("debug", mod, device=dev, target="llvm")
return intrp.evaluate()(**mod_params).asnumpy()
result = intrp.evaluate()(**mod_params)
if isinstance(result, tvm.runtime.container.ADT):
result = [r.asnumpy() for r in result]
return result
else:
return [result.asnumpy()]


def verify_fp32_fp16_output_close(mod, mod_params, rtol=1e-3, atol=0):
mod = InferType()(mod)
mod = AnnotateSpans()(mod)
result_fp32 = run_module(mod, mod_params)
fp16_mod = RewriteFP16()(mod)
result_fp16 = run_module(fp16_mod, mod_params)

# Ensure the results are close
np.testing.assert_allclose(result_fp32, result_fp16, rtol=rtol, atol=atol)
for fp32, fp16 in zip(result_fp32, result_fp16):
np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)

return fp16_mod

@@ -240,3 +249,158 @@ def test_let_statement_simple():
expected_mod = InferType()(expected_mod)

assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_where_simple():
# Where can be a little tricky due to the mixing of dtypes
data = relay.var("data", shape=[1, 20])
weight = relay.var("weight", shape=[20, 20])
a = relay.nn.dense(data, weight, units=20)
b = relay.where(data, a, a)
mod = tvm.IRModule.from_expr(b)
mod_params = {
"data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
"weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
}

output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)

# Create expected module
data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
b = relay.where(data, a, a)
expected_mod = tvm.IRModule.from_expr(b)
expected_mod = InferType()(expected_mod)

assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_batch_matmul_simple():
# Batch matmul is a special case where we accumulate to fp16,
# since many TOPI schedules currently ignore out_dtype.
data = relay.var("data", shape=[1, 1, 20])
weight = relay.var("weight", shape=[1, 20, 20])
a = relay.nn.batch_matmul(data, weight)
mod = tvm.IRModule.from_expr(a)
mod_params = {
"data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"),
"weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"),
}
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
# Create expected module
data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16")
weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16")
a = relay.nn.batch_matmul(data, weight, out_dtype="float16")
expected_mod = tvm.IRModule.from_expr(a)
expected_mod = InferType()(expected_mod)
assert tvm.ir.structural_equal(expected_mod, output_mod)


# Straight image classification models
def test_onnx_resnet18():
model_path = "/Users/andrewzhaoluo/Downloads/resnet18-v1-7.onnx"
# assumes resnet18-v1-7.onnx has already been downloaded to the path above
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model)
mod_params["data"] = np.random.uniform(0, 1, size=[1, 3, 224, 224]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


def test_onnx_efficientnet():
model_path = "/Users/andrewzhaoluo/Downloads/efficientnet-lite4-11.onnx"
# assumes efficientnet-lite4-11.onnx has already been downloaded to the path above
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model)
mod_params["images:0"] = np.random.uniform(0, 1, size=[1, 224, 224, 3]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


def test_onnx_densenet():
model_path = "/Users/andrewzhaoluo/Downloads/densenet-3.onnx"
# assumes densenet-3.onnx has already been downloaded to the path above
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model)
mod_params["data_0"] = np.random.uniform(0, 1, size=[1, 3, 224, 224]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


def test_onnx_inceptionv3():
model_path = "/Users/andrewzhaoluo/Downloads/inceptionv3.onnx"
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model, shape={"input.1": [1, 3, 299, 299]})
mod_params["input.1"] = np.random.uniform(0, 1, size=[1, 3, 299, 299]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


# Object detection models
def test_onnx_tinyyolo2():
model_path = "/Users/andrewzhaoluo/Downloads/tinyyolov2-7.onnx"
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model, shape={"image": [1, 3, 416, 416]})
mod_params["image"] = np.random.uniform(0, 1, size=[1, 3, 416, 416]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


def test_onnx_yolo2():
model_path = "/Users/andrewzhaoluo/Downloads/yolov2-coco-9.onnx"
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model, shape={"input.1": [1, 3, 416, 416]})
mod_params["input.1"] = np.random.uniform(0, 1, size=[1, 3, 416, 416]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


# Face recognition / embedding
def test_onnx_arcfaceresnet():
model_path = "/Users/andrewzhaoluo/Downloads/arcfaceresnet100-8.onnx"
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model)
mod_params["data"] = np.random.uniform(0, 1, size=[1, 3, 112, 112]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


def test_onnx_rfb():
model_path = "/Users/andrewzhaoluo/Downloads/version-RFB-320.onnx"
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model)
mod_params["input"] = np.random.uniform(0, 1, size=[1, 3, 240, 320]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


# Super resolution
def test_onnx_superresolution():
model_path = "/Users/andrewzhaoluo/Downloads/super-resolution-10.onnx"
onnx_model = onnx.load(model_path)
mod, mod_params = relay.frontend.from_onnx(onnx_model, shape={"input": [1, 1, 224, 224]})
mod_params["input"] = np.random.uniform(0, 1, size=[1, 1, 224, 224]).astype("float32")
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


# NLP models (ruh roh!)
def test_onnx_gpt2():
model_path = "/Users/andrewzhaoluo/Downloads/gpt2-10.onnx"
onnx_model = onnx.load(model_path)

mod, mod_params = relay.frontend.from_onnx(onnx_model, shape={"input1": [1, 1, 1]})
mod_params["input1"] = np.random.randint(0, 100, size=[1, 1, 1])
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)


def test_onnx_distillbert():
model_path = "/Users/andrewzhaoluo/Downloads/distilbert.onnx"
onnx_model = onnx.load(model_path)

mod, mod_params = relay.frontend.from_onnx(onnx_model, shape={"input.1": [10, 100]})
mod_params["input.1"] = np.random.randint(0, 100, size=[10, 100])
output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.05, rtol=0.01)
assert not tvm.ir.structural_equal(mod, output_mod)
