Skip to content

Commit

Permalink
Add test cases for the compile and run
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Feb 21, 2024
1 parent 183cf98 commit 1db7d03
Show file tree
Hide file tree
Showing 2 changed files with 460 additions and 209 deletions.
259 changes: 175 additions & 84 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,26 +2057,31 @@ def cumsum(
return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)


def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor):
def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
"""Returns a tensor where each row contains the index sampled from the multinomial
probability distribution located in the corresponding row of tensor prob.
Notes
-----
For better cpu performance, use 'vm.builtin.multinomial_from_uniform'.
For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
Parameters
----------
prob : Tensor
The probability tensor.
A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
Each row is a distribution across vocabulary for a batch, where:
Values range from [0, 1], indicating the probability of each vocabulary item.
The sum of values in each row is 1, forming a valid distribution.
uniform_sample : Tensor
The uniform sample.
The uniformly sampled 2-D tensor with the shape (batch, 1).
Values range from 0 to 1, indicating probabilities sampled uniformly.
Returns
-------
result : Tensor
The computed result.
The computed tensor with shape (batch, 1).
Examples
--------
Expand All @@ -2085,85 +2090,147 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor):
prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
usample = [[0.4], [0.9]]
multinomial_from_uniform(a)
multinomial_from_uniform(prob, usample)
-> [[1], [2]]
"""

prob_dtype = prob.dtype
sample_dtype = uniform_sample.dtype
batch = prob.shape[0]
cumsum_prob = cumsum(prob, axis=1, exclusive=False)

@T.prim_func(private=True)
def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
batch, vocab_size = T.int64(), T.int64()
prob = T.match_buffer(A, (batch, vocab_size), "float32")
usample = T.match_buffer(B, (batch, 1), "float32")
output_index = T.match_buffer(C, (batch, 1), "int64")

for ax0 in T.parallel(batch):
for ax1 in T.parallel(vocab_size):
with T.block("T_get_sample_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
if v_ax1 == 0:
output_index[v_ax0, 0] = 0
else:
if not (usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1 - 1]):
output_index[v_ax0, 0] = v_ax1
prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
usample = T.match_buffer(B, (batch, 1), sample_dtype)
output_index = T.match_buffer(C, (batch, 1), dtype)

for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_sample_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
if v_ax1 == 0:
output_index[v_ax0, 0] = 0
elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
output_index[v_ax0, 0] = v_ax1

return tensor_ir_op(
_get_sample_index,
"get_sample_index",
args=[cumsum_prob, uniform_sample],
out=Tensor.placeholder([batch, 1], "int64"),
out=Tensor.placeholder([batch, 1], dtype),
)


def sample_top_p_top_k_from_sorted_prob(sorted_prob, sorted_index, top_p, top_k, uniform_sample):
def sample_top_p_top_k_from_sorted_prob(
sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
):
"""Samples indices from a sorted probability tensor based on top_p and top_k criteria.
Notes
-----
For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
Parameters
----------
sorted_prob : Tensor
A 2-D tensor, with shape (batch, vocab_size), contains probabilities
sorted in descending order.
sorted_index: Tensor
The indices tensor with shape (batch, vocab_size), corresponding to the
sorted_prob. Potentially from applying argsort on the original probability
tensor in descending order.
top_p : Tensor
The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
top_k :Tensor
A tensor with shape (batch, 1), representing the number of top probabilities
to consider for top-k sampling.
uniform_sample : Tensor
Uniformly sampled values with shape (batch, 1) are used to select the output indices.
Returns
-------
result : Tensor
The selected indices with shape (batch, 1).
Examples
--------
.. code-block:: python
prob = [[0.1 , 0.4, 0.5],
[0.3, 0.3, 0.4]]
sorted_prob = [[0.5, 0.4, 0.1],
[0.4, 0.3, 0.3]]
sorted_index = [[2, 1, 0],
[2, 0, 1]]
top_p = [[0.6],[0.9]]
top_k = [[3],[2]]
uniform_sample = [[0.5], [0.6]]
sample_top_p_top_k_from_sorted_prob(
sorted_prob, sorted_index,top_p, top_k, uniform_sample)
-> [2, 0]
"""
prob_dtype = sorted_prob.dtype
index_dtype = sorted_index.dtype
batch = sorted_prob.shape[0]

@T.prim_func(private=True)
def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle):
batch, vocab_size = T.int64(), T.int64()
cumsum_prob = T.match_buffer(A, (batch, vocab_size), "float32")
cumsum_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
cumsum_mask = T.match_buffer(B, (batch, vocab_size), "bool")
renorm_prob = T.match_buffer(C, (batch, 1), "float32")
for ax0 in range(batch):
for ax1 in range(vocab_size):
with T.block("T_get_renorm_prob"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if cumsum_mask[v_ax0, v_ax1] == 1 and cumsum_mask[v_ax0, v_ax1 + 1] == 0:
renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, v_ax1 + 1]
if cumsum_mask[v_ax0, v_ax1] == 1 and v_ax1 + 1 == vocab_size:
renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, v_ax1]
renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_renorm_prob"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if cumsum_mask[v_ax0, 0] == 0:
renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, 0]
elif cumsum_mask[v_ax0, v_ax1] == 1 and cumsum_mask[v_ax0, v_ax1 + 1] == 0:
renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, v_ax1 + 1]
elif cumsum_mask[v_ax0, v_ax1] == 1 and v_ax1 + 1 == vocab_size:
renorm_prob[v_ax0, 0] = cumsum_prob[v_ax0, v_ax1]

@T.prim_func(private=True)
def _get_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
batch, vocab_size = T.int64(), T.int64()
prob = T.match_buffer(A, (batch, vocab_size), "float32")
usample = T.match_buffer(B, (batch, 1), "float32")
indices = T.match_buffer(C, (batch, vocab_size), "int64")
output_index = T.match_buffer(D, (batch, 1), "int64")

for ax0 in T.parallel(batch):
for ax1 in T.grid(vocab_size):
with T.block("T_get_index"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[v_ax0, 0]
else:
if not (usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1 - 1]):
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
usample = T.match_buffer(C, (batch, 1), prob_dtype)
indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
output_index = T.match_buffer(E, (batch, 1), index_dtype)

for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if (
usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[v_ax0, 0]
elif (
usample[v_ax0, T.int64(0)]
>= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
):
output_index[v_ax0, 0] = indices[v_ax0, v_ax1]

batch = sorted_prob.shape[0]
cumsum_sorted = cumsum(sorted_prob, axis=1)

cumsum_mask = tensor_expr_op(
lambda cumsum_sorted, top_p, top_k: te.compute(
cumsum_sorted.shape,
lambda i, j: _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0]),
name="get_cumsum_mask_top_p_top_k",
),
"get_mask_top_p_top_k",
"get_cumsum_mask_top_p_top_k",
args=[cumsum_sorted, top_p, top_k],
)

Expand All @@ -2173,47 +2240,76 @@ def _get_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
args=[cumsum_sorted, cumsum_mask],
out=Tensor.placeholder(
[batch, 1],
"float32",
prob_dtype,
),
)
cumsum_sorted_renorm = cumsum_sorted / renorm_prob
# TODO(yongwww): fuse the division into get_index

out_index_in_sorted = tensor_ir_op(
_get_index,
"get_index",
args=[cumsum_sorted_renorm, uniform_sample, sorted_index],
out=Tensor.placeholder([batch, 1], "int64"),
_get_index_from_sorted,
"get_index_from_sorted",
args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
out=Tensor.placeholder([batch, 1], index_dtype),
)
return out_index_in_sorted


def renormalize_top_p_top_k_prob(prob, sorted_prob, sorted_index, top_p, top_k):
from tvm.script import tir as T
"""Renormalizes probabilities after filtering with top_p and top_k, ensuring
they sum up to 1.
Notes
-----
For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
Parameters
----------
prob : Tensor
A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
sorted_prob : Tensor
Probabilities sorted in descending order.
sorted_index : Tensor
Indices corresponding to the sorted probabilities.
top_p : Tensor
The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
top_k :Tensor
A tensor with shape (batch, 1), representing the number of top probabilities
to consider for top-k sampling.
Returns
-------
result : Tensor
The filtered and nomalized tensor with the sampe shape as input prob.
"""
dtype = prob.dtype
batch = sorted_prob.shape[0]

@T.prim_func(private=True)
def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle):
batch, vocab_size = T.int64(), T.int64()
# cumsum_prob = T.match_buffer(A, (batch, vocab_size), "float32")
sorted_prob = T.match_buffer(A, (batch, vocab_size), "float32")
sorted_prob = T.match_buffer(A, (batch, vocab_size), dtype)
cumsum_mask = T.match_buffer(B, (batch, vocab_size), "bool")
cutoff = T.match_buffer(C, (batch, 1), "float32")
for ax0 in range(batch):
for ax1 in range(vocab_size):
with T.block("T_get_renorm_prob"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if cumsum_mask[v_ax0, v_ax1] == 1 and cumsum_mask[v_ax0, v_ax1 + 1] == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1]
if cumsum_mask[v_ax0, v_ax1] == 1 and v_ax1 + 1 == vocab_size:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
cutoff = T.match_buffer(C, (batch, 1), dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_renorm_prob"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if cumsum_mask[v_ax0, 0] == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
elif cumsum_mask[v_ax0, v_ax1] == 1 and cumsum_mask[v_ax0, v_ax1 + 1] == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1]
elif cumsum_mask[v_ax0, v_ax1] == 1 and v_ax1 + 1 == vocab_size:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]

batch = sorted_prob.shape[0]
cumsum_sorted = cumsum(sorted_prob, axis=1)

cumsum_mask = tensor_expr_op(
lambda cumsum_sorted, top_p, top_k: te.compute(
cumsum_sorted.shape,
lambda i, j: _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0]),
name="get_mask_top_p_top_k",
),
"get_mask_top_p_top_k",
args=[cumsum_sorted, top_p, top_k],
Expand All @@ -2225,23 +2321,18 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle):
args=[sorted_prob, cumsum_mask],
out=Tensor.placeholder(
[batch, 1],
"float32",
dtype,
),
)

new_prob = tensor_expr_op(
filtered_prob = tensor_expr_op(
lambda prob, renorm_cutoff: te.compute(
prob.shape,
lambda i, j: _tir.Select(prob[i, j] >= renorm_cutoff[i, 0], prob[i, j], T.float32(0)),
lambda i, j: _tir.Select(prob[i, j] >= renorm_cutoff[i, 0], prob[i, j], 0.0),
name="filter_with_top_p_top_k",
),
"get_prob_top_p_top_k",
"filter_with_top_p_top_k",
args=[prob, renorm_cutoff],
)

prob_sum = sum(new_prob, axis=1, keepdims=True)
# If the value of sum is zero, replace it with one instead to avoid NAN
safe_prob_sum = where(
equal(new_prob, Tensor.from_scalar(0, new_prob.dtype)), ones(prob_sum.shape), prob_sum
)
new_prob = new_prob / safe_prob_sum
return new_prob
renorm_prob = filtered_prob / sum(filtered_prob, axis=1, keepdims=True)
return renorm_prob

0 comments on commit 1db7d03

Please sign in to comment.