operatorspy/tests/random_sample.py (13 changes: 9 additions & 4 deletions)
@@ -76,14 +76,18 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
if(random_val < sum_s):
return indices[i]


def random_sample_0(data):
return torch.argmax(data)
def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
print(
f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
)

data = torch.rand((voc), dtype=x_dtype).to(torch_device)
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
if(topp > 0 and topk > 0):
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
else:
ans = random_sample_0(data)
if(torch_device == 'mlu'):

indices = torch.zeros([1], dtype = torch.int64).to(torch_device)
@@ -123,8 +127,6 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
None,
)
)


assert indices[0].type(ans.dtype) == ans or abs(data[indices[0]] - data[ans]) == 0.0, "compute error"


@@ -164,6 +166,9 @@ def test_bang(lib, test_cases):
(512, 0.92, 0.8, 3, 0.5),
(4096, 0.95, 0.9, 5, 1.0),
(16384, 0.85, 0.85, 10, 2.0),
(512, 0.92, 0, 3, 0.5),
(4096, 0.95, 0.9, 0, 1.0),
(16384, 0.85, 0, 0, 2.0),
]

args = get_args()
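The test change above adds a greedy fallback: when topp or topk is 0, the expected result is the plain argmax of the logits (random_sample_0) rather than the top-k/top-p weighted draw. A minimal sketch of that reference-side dispatch, assuming random_sample is the existing top-k/top-p reference defined earlier in this test file:

    import torch

    def reference_sample(data, random_val, topp, topk, temperature):
        # Greedy fallback: with top-p/top-k disabled (== 0), the expected
        # answer is simply the index of the largest logit.
        if topp <= 0 or topk <= 0:
            return torch.argmax(data)
        # Otherwise use the full top-k/top-p reference from this test file.
        return random_sample(data.to("cpu"), random_val, topp, topk,
                             data.numel(), temperature, "cpu")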
src/ops/random_sample/bang/random_sample_bang.mlu (78 changes: 68 additions & 10 deletions)
@@ -398,7 +398,59 @@ __mlu_global__ void random_sampleD(T const *source, uint64_t *indices, uint64_t
__memcpy(globalTopk, srcGlobal, topk * sizeof(T), NRAM2GDRAM);
}
}
template<typename T>
__mlu_global__ void random_sample(T const *source, uint64_t *indices, uint64_t *indGdram, int voc){
const uint64_t maxNum = SRC_MAX_SIZE/sizeof(T);

uint64_t taskSize = taskDim * maxNum;
uint64_t remain = voc % taskSize;
uint64_t repeat = (voc - remain) / taskSize;

uint64_t remainT = remain % taskDim;
uint64_t stepEasy = (remain - remainT) / taskDim;
uint64_t stepHard = stepEasy + 1;
uint64_t step = (taskId < remainT ? stepHard : stepEasy);
uint64_t indStart = repeat * taskSize + (taskId < remainT ? taskId * stepHard : remainT * stepHard + (taskId - remainT) * stepEasy);

T *src = (T *)nram_buffer;
T *srcMax = src + maxNum;
uint64_t index = 0;

T newMax = -INFINITY;
for(uint64_t r = 0; r < repeat; r++){
__memcpy(src, source + r * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
__bang_argmax(srcMax, src, maxNum);
if(newMax < srcMax[0]){
newMax = srcMax[0];
index = r * taskSize + taskId * maxNum + *((int64_t*)&srcMax[1]);
}

}
if(step){
__bang_write_value(src, maxNum, -INFINITY);
__memcpy(src, source + indStart, step * sizeof(T), GDRAM2NRAM);
__bang_argmax(srcMax, src, maxNum);
if(newMax < srcMax[0]){
newMax = srcMax[0];
index = indStart + *((int64_t*)&srcMax[1]);
}

}

indGdram[taskId] = index;
__sync_all();
if(taskId == 0){
uint64_t globalInd = indGdram[0];
T globalM = source[globalInd];
for(uint64_t id = 0; id < taskDim; id++){
if(globalM < source[indGdram[id]]){
globalM = source[indGdram[id]];
globalInd = indGdram[id];
}
}
indices[0] = globalInd;
}
}
template<typename T>
void random_sampleUnion(cnrtQueue_t queue, void *workspace, void const *source, void *indices, float random_val, float topp, int topk, float temperature, int voc) {
auto logits_ = reinterpret_cast<const T *>(source);
@@ -412,18 +464,24 @@ void random_sampleUnion(cnrtQueue_t queue, void *workspace, void const *source,
k_type = CNRT_FUNC_TYPE_UNION1;

int taskNum = k_dim.x * k_dim.y * k_dim.z;
const int maxNum = SRC_MAX_SIZE/sizeof(T);
char *origin = reinterpret_cast<char *>(workspace);
char *indTmp = origin + taskNum * topk * sizeof(uint64_t);
uint64_t *indGdram = (uint64_t *)origin;
T *globalTopk = (T *)indTmp;
T *globalSum = globalTopk + taskNum * topk;

if(voc >= taskNum * maxNum){
random_sampleD<T><<<k_dim, k_type, queue>>>(logits_, index_, indGdram, globalTopk, globalSum, random_val, topp, topk, temperature, voc);
if(topp > 0 && topk > 0){
const int maxNum = SRC_MAX_SIZE/sizeof(T);
char *origin = reinterpret_cast<char *>(workspace);
char *indTmp = origin + taskNum * topk * sizeof(uint64_t);
uint64_t *indGdram = (uint64_t *)origin;
T *globalTopk = (T *)indTmp;
T *globalSum = globalTopk + taskNum * topk;

if(voc >= taskNum * maxNum){
random_sampleD<T><<<k_dim, k_type, queue>>>(logits_, index_, indGdram, globalTopk, globalSum, random_val, topp, topk, temperature, voc);
}
else{
random_sampleX<T><<<k_dim, k_type, queue>>>(logits_, index_, indGdram, globalTopk, globalSum, random_val, topp, topk, temperature, voc);
}
}
else{
random_sampleX<T><<<k_dim, k_type, queue>>>(logits_, index_, indGdram, globalTopk, globalSum, random_val, topp, topk, temperature, voc);
uint64_t *indGdram = reinterpret_cast<uint64_t *>(workspace);
random_sample<T><<<k_dim, k_type, queue>>>(logits_, index_, indGdram, voc);
}
cnrtQueueSync(queue);

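The new random_sample kernel above implements a two-stage argmax: every task scans its NRAM-sized tiles with __bang_argmax and records a per-task candidate index in indGdram, then task 0 reduces those candidates to the single global index. A NumPy sketch of the same reduction pattern (the remainder split between tasks is simplified here, so this is an illustration rather than a line-for-line port of the BANG code):

    import numpy as np

    def two_stage_argmax(source, task_dim, max_num):
        # Stage 1: each "task" scans the tiles assigned to it and keeps the
        # index of its local maximum (mirrors the per-task __bang_argmax loop).
        candidates = []
        for task_id in range(task_dim):
            best_idx, best_val = 0, -np.inf
            for start in range(task_id * max_num, len(source), task_dim * max_num):
                chunk = source[start:start + max_num]
                local = int(np.argmax(chunk))
                if chunk[local] > best_val:
                    best_val, best_idx = chunk[local], start + local
            candidates.append(best_idx)          # indGdram[task_id] = index
        # Stage 2: "task 0" picks the best candidate (the final loop over indGdram).
        return max(candidates, key=lambda i: source[i])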
src/ops/random_sample/cpu/random_sample.cc (47 changes: 39 additions & 8 deletions)
@@ -127,6 +127,30 @@ void random_sample_cpu_f16(RandomSampleCpuDescriptor_t desc,
}
}
}
void random_sample_cpu_f16(RandomSampleCpuDescriptor_t desc,
void *workspace,
void *result,
void const *probs) {
int voc = desc->voc;
auto index_ = reinterpret_cast<uint64_t *>(result);
auto source = reinterpret_cast<const uint16_t *>(probs);

char *origin = reinterpret_cast<char *>(workspace);
uint16_t *logits_ = (uint16_t *) origin;

std::copy(source, source + voc, logits_);

float M = f16_to_f32(logits_[0]);
int index = 0;
for (int j = 1; j < voc; j++) {
if (M < f16_to_f32(logits_[j])) {
M = f16_to_f32(logits_[j]);
index = j;
}
}

index_[0] = index;
}

infiniopStatus_t cpuRandomSample(RandomSampleCpuDescriptor_t desc,
void *workspace,
@@ -139,14 +163,21 @@ infiniopStatus_t cpuRandomSample(RandomSampleCpuDescriptor_t desc,
float temperature,
void *stream) {
if (dtype_eq(desc->dtype, F16)) {
random_sample_cpu_f16(desc,
workspace,
result,
probs,
random_val,
topp,
topk,
temperature);
if (topp > 0 && topk > 0) {
random_sample_cpu_f16(desc,
workspace,
result,
probs,
random_val,
topp,
topk,
temperature);
} else {
random_sample_cpu_f16(desc,
workspace,
result,
probs);
}
return STATUS_SUCCESS;
}

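The overloaded random_sample_cpu_f16 above is the CPU greedy path: the fp16 logits are copied into the workspace as raw uint16_t bit patterns and compared after conversion through f16_to_f32. A rough NumPy equivalent of that scan (the function name and the float16 view are illustrative, not the repository's API):

    import numpy as np

    def cpu_greedy_sample_f16(raw_logits_u16):
        # raw_logits_u16: 1-D uint16 array holding IEEE fp16 bit patterns,
        # analogous to the uint16_t workspace buffer in the C++ code.
        logits = raw_logits_u16.view(np.float16).astype(np.float32)  # f16_to_f32
        best = 0
        for j in range(1, logits.size):   # same linear scan as the C++ loop
            if logits[j] > logits[best]:
                best = j
        return np.uint64(best)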
src/ops/random_sample/cuda/random_sample.cu (37 changes: 23 additions & 14 deletions)
@@ -99,6 +99,10 @@ void random_sample_workspace(size_t &size_radix_sort, size_t &size_scan,
nullptr, voc,
stream);
}
__global__ void random_sample_kernel(uint64_t *result,
uint64_t *key_out) {
result[0] = key_out[0];
}
void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace, void *result,
void const *probs,
float random_val,
@@ -129,23 +133,28 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
key_in, key_out,
voc, (cudaStream_t) stream);// the sort stores the sorted values in val_out and their original indices in key_out
//sorting is done; now apply the softmax transform
if (topp > 0 && topk > 0) {
int BLOCK_DIM = 1024;
int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>(val_out, topk,
temperature, voc);

int BLOCK_DIM = 1024;
int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>(val_out, topk,
temperature, voc);

inclusive_sum<half>(
workspace_extra, size_scan,
val_out, voc,
(cudaStream_t) stream);// this call performs an inclusive scan, i.e. a running prefix sum over the values
random_sample_kernel<half><<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result,
val_out,
random_val,
topp,
topk,
key_out);

inclusive_sum<half>(
workspace_extra, size_scan,
val_out, voc,
(cudaStream_t) stream);// this call performs an inclusive scan, i.e. a running prefix sum over the values
random_sample_kernel<half><<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result,
val_out,
random_val,
topp,
topk,
key_out);
} else {
random_sample_kernel<<<1, 1, 0, (cudaStream_t) stream>>>((uint64_t *) result,
key_out);
}
cudaFree(workspace_extra);
}

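On CUDA the logits have already been sorted in descending order by the radix sort, with key_out holding the original indices, so the greedy branch only needs to copy key_out[0] into result; that is all the new two-argument random_sample_kernel does. A NumPy sketch of why that single copy is the argmax (the sort call and names below are illustrative):

    import numpy as np

    def cuda_greedy_equivalent(logits):
        # SortPairsDescending analogue: sort values descending and keep the
        # permutation of original indices (key_out).
        key_out = np.argsort(-logits, kind="stable")
        # Greedy path: the first key after the descending sort is the argmax
        # index, which the one-thread kernel writes to result[0].
        return np.uint64(key_out[0])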