Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 230 additions & 86 deletions ucm/sparse/gsa/gsa.py

Large diffs are not rendered by default.

88 changes: 38 additions & 50 deletions ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
Original file line number Diff line number Diff line change
@@ -1,58 +1,48 @@
#include "select_topk_block.h"
#include <algorithm>
#include <utility>
#include <cmath>
#include <limits>
#include <stdexcept>
#include <cmath>
#include "select_topk_block.h"
#include <utility>

namespace SelectTopkBlock {
// Unsigned constant 16 shared by this translation unit.
// NOTE(review): the name suggests an OpenMP thread count, but no use is visible
// in this excerpt - confirm where it is consumed before changing the value.
#define OMP_THREAD_NUM 16u

// Sanity-checks the raw arguments of a top-k selection request.
//
// Returns true only when both buffers are non-null and every dimension
// (numBlock, kHead, qHead, numKrepre, headSize) is strictly positive.
// No relationship between the dimensions (for example qHead versus kHead)
// is checked here.
bool TopkBlockSelector::ValidateParameters(float* q, const float* kRepre, uint32_t numBlock,
                                           uint32_t kHead, uint32_t qHead, uint32_t numKrepre,
                                           uint32_t headSize)
{
    return (q != nullptr) && (kRepre != nullptr) && (numBlock > 0) && (kHead > 0) && (qHead > 0) &&
           (numKrepre > 0) && (headSize > 0);
}

void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32_t k, int32_t* topkIndices)
void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32_t k,
int32_t* topkIndices)
{
if (startWindow_ + endWindow_ >= numScores || k >= numScores || k == 0) {
for (uint32_t i = 0; i < numScores; ++i) {
topkIndices[i] = i;
}
for (uint32_t i = 0; i < numScores; ++i) { topkIndices[i] = i; }
return;
}
uint32_t idx = 0;
for (uint32_t i = 0; i < startWindow_; ++i) {
topkIndices[idx++] = i;
}
for (uint32_t i = 0; i < startWindow_; ++i) { topkIndices[idx++] = i; }
for (uint32_t i = 0; i < endWindow_; ++i) { topkIndices[idx++] = numScores - endWindow_ + i; }
int32_t midCount = k - startWindow_ - endWindow_;
if (midCount > 0) {
std::vector<uint32_t> middleIndices;
middleIndices.reserve(numScores - startWindow_ - endWindow_);
for (uint32_t i = startWindow_; i < numScores - endWindow_; ++i) {
middleIndices.push_back(i);
}
std::stable_sort(middleIndices.begin(), middleIndices.end(),
[scores](uint32_t lhs, uint32_t rhs) {
return scores[lhs] > scores[rhs];
});
for (int32_t i = 0; i < midCount; ++i) {
topkIndices[idx++] = middleIndices[i];
}
}
for (uint32_t i = 0; i < endWindow_; ++i) {
topkIndices[idx++] = numScores - endWindow_ + i;
std::stable_sort(
middleIndices.begin(), middleIndices.end(),
[scores](uint32_t lhs, uint32_t rhs) { return scores[lhs] > scores[rhs]; });
for (int32_t i = 0; i < midCount; ++i) { topkIndices[idx++] = middleIndices[i]; }
}
std::sort(topkIndices, topkIndices + k);
}

float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase,
uint32_t kHead, uint32_t numKrepre,
uint32_t headSize, const VecProductClass& vecProduct)
float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase, uint32_t kHead,
uint32_t numKrepre, uint32_t headSize,
const VecProductClass& vecProduct)
{
const size_t headOffset = headSize;
const size_t normOffset = headSize;
Expand Down Expand Up @@ -81,8 +71,10 @@ const VecProductClass& TopkBlockSelector::ThreadLocalVecProduct::GetInstance()
return instance;
}

std::vector<float> TopkBlockSelector::ComputeKQDotScores(const float* __restrict qMean, const float* __restrict kRepre,
uint32_t numBlock, uint32_t kHead, uint32_t numKrepre, uint32_t headSize)
std::vector<float> TopkBlockSelector::ComputeKQDotScores(const float* __restrict qMean,
const float* __restrict kRepre,
uint32_t numBlock, uint32_t kHead,
uint32_t numKrepre, uint32_t headSize)
{
std::vector<float> blockScores(numBlock, 0.0f);
const size_t blockOffset = static_cast<size_t>(kHead * numKrepre * headSize);
Expand All @@ -93,16 +85,16 @@ std::vector<float> TopkBlockSelector::ComputeKQDotScores(const float* __restrict
if (idxBlock + 1 < numBlock) {
__builtin_prefetch(kRepre + (idxBlock + 1) * blockOffset, 0, 1);
}
blockScores[idxBlock] = ComputeBlockScore(const_cast<float*>(qMean), blockBase, kHead, numKrepre, headSize, vecProduct);
blockScores[idxBlock] = ComputeBlockScore(const_cast<float*>(qMean), blockBase, kHead,
numKrepre, headSize, vecProduct);
}
return blockScores;
}

void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, uint32_t qHead, uint32_t headSize)
void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, uint32_t qHead,
uint32_t headSize)
{
if (kHead == qHead) {
return;
}
if (kHead == qHead) { return; }
const VecProductClass& vecProduct = ThreadLocalVecProduct::GetInstance();
const uint32_t groupSize = qHead / kHead;
for (uint32_t kIdx = 0; kIdx < kHead; ++kIdx) {
Expand All @@ -113,28 +105,25 @@ void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, ui
}
}

void TopkBlockSelector::SelectTopK(float* q, const float* kRepre,
uint32_t numBlock, uint32_t kHead, uint32_t qHead,
uint32_t numKrepre, uint32_t headSize,
void TopkBlockSelector::SelectTopK(float* q, const float* kRepre, uint32_t numBlock, uint32_t kHead,
uint32_t qHead, uint32_t numKrepre, uint32_t headSize,
uint32_t topkLength, int32_t* topkResult)
{
if (!ValidateParameters(q, kRepre, numBlock, kHead, qHead, numKrepre, headSize) ||
topkResult == nullptr || topkLength == 0) {
return;
return;
}
ComputeQHeadMean(q, kHead, qHead, headSize);
const std::vector<float> scores = ComputeKQDotScores(q, kRepre, numBlock,
kHead, numKrepre, headSize);
const std::vector<float> scores =
ComputeKQDotScores(q, kRepre, numBlock, kHead, numKrepre, headSize);
TopKImpl(scores.data(), numBlock, topkLength, topkResult);
}

void TopkBlockSelector::SelectTopKBS(const std::vector<float*>& qCacheVec,
const std::vector<const float*>& kfCacheVec,
const std::vector<int32_t*>& topkCacheVec,
uint32_t numBatch,
const std::vector<uint32_t>& numBlockVec,
uint32_t kHead, uint32_t qHead,
uint32_t numKrepre, uint32_t headSize,
const std::vector<int32_t*>& topkCacheVec, uint32_t numBatch,
const std::vector<uint32_t>& numBlockVec, uint32_t kHead,
uint32_t qHead, uint32_t numKrepre, uint32_t headSize,
const std::vector<uint32_t>& topkLengthVec)
{
for (uint32_t bs = 0; bs < numBatch; ++bs) {
Expand All @@ -143,9 +132,8 @@ void TopkBlockSelector::SelectTopKBS(const std::vector<float*>& qCacheVec,
float* q = qCacheVec[bs];
const float* kRepre = kfCacheVec[bs];
int32_t* topkResult = topkCacheVec[bs];
SelectTopK(q, kRepre, numBlock, kHead, qHead,
numKrepre, headSize, topkLength, topkResult);
SelectTopK(q, kRepre, numBlock, kHead, qHead, numKrepre, headSize, topkLength, topkResult);
}
}

}
} // namespace SelectTopkBlock
88 changes: 36 additions & 52 deletions ucm/sparse/gsa/prefetch/include/kvcache_log.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
#ifndef ATB_KV_LOG_H
#define ATB_KV_LOG_H
#include <iostream>
#include <fstream>
#include <string>
#include <ctime>
#include <mutex>
#include <sstream>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mutex>
#include <omp.h>
enum class LogLevel {
DEBUG,
INFO,
WARNING,
ERROR
};
#include <sstream>
#include <stdarg.h>
#include <string>
enum class LogLevel { DEBUG, INFO, WARNING, ERROR };

class Logger
{
class Logger {
private:
std::ofstream mLogFile;
LogLevel mMinLevel;
Expand All @@ -25,31 +20,30 @@ class Logger

static std::string LevelToString(LogLevel level)
{
switch (level)
{
case LogLevel::DEBUG: return "DEBUG";
case LogLevel::INFO: return "INFO";
case LogLevel::WARNING: return "WARNING";
case LogLevel::ERROR: return "ERROR";
default: return "UNKNOWN";
switch (level) {
case LogLevel::DEBUG: return "DEBUG";
case LogLevel::INFO: return "INFO";
case LogLevel::WARNING: return "WARNING";
case LogLevel::ERROR: return "ERROR";
default: return "UNKNOWN";
}
}

static std::string GetTimesTamp()
{
auto now = std::chrono::system_clock::now();
auto nowC = std::chrono::system_clock::to_time_t(now);
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch()) % 1000;
auto ms =
std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()) % 1000;
std::stringstream oss;
oss << std::put_time(std::localtime(&nowC), "%Y-%m-%d %H:%M:%S");
oss << '.' << std::setfill('0') << std::setw(3) << ms.count();
return oss.str();
}

public:
Logger(const std::string &fileName, LogLevel level = LogLevel::INFO, bool enable = true)
:mMinLevel(level), mEnable(enable)
Logger(const std::string& fileName, LogLevel level = LogLevel::INFO, bool enable = true)
: mMinLevel(level), mEnable(enable)
{
if (enable) {
mLogFile.open(fileName, std::ios::app);
Expand All @@ -59,62 +53,52 @@ class Logger
}
}

Logger(){}
Logger() {}

~Logger()
{
if (mLogFile.is_open()) {
mLogFile.close();
}
if (mLogFile.is_open()) { mLogFile.close(); }
}

void SetLevel(LogLevel level)
{
mMinLevel = level;
}
void SetLevel(LogLevel level) { mMinLevel = level; }

void log(LogLevel level, const char* format, ...)
{
if (level < mMinLevel || !mLogFile.is_open() || !mEnable) {
return;
}
if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { return; }
std::lock_guard<std::mutex> lock(mMutex);
auto now = std::chrono::system_clock::now();
auto nowC = std::chrono::system_clock::to_time_t(now);
auto duration = now.time_since_epoch();
auto millis = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() % 1000;
auto micros = std::chrono::duration_cast<std::chrono::microseconds>(duration).count() % 1000;
auto millis =
std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() % 1000;
auto micros =
std::chrono::duration_cast<std::chrono::microseconds>(duration).count() % 1000;

std::tm localTime = *std::localtime(&nowC);
char timeBuffer[26];
std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", &localTime);
const char *levelStr = "";
switch (level)
{
case LogLevel::DEBUG: levelStr = "DEBUG"; break;
case LogLevel::INFO: levelStr = "INFO"; break;
case LogLevel::WARNING: levelStr = "WARNING"; break;
case LogLevel::ERROR: levelStr = "ERROR"; break;
default: levelStr = "UNKNOWN"; break;
const char* levelStr = "";
switch (level) {
case LogLevel::DEBUG: levelStr = "DEBUG"; break;
case LogLevel::INFO: levelStr = "INFO"; break;
case LogLevel::WARNING: levelStr = "WARNING"; break;
case LogLevel::ERROR: levelStr = "ERROR"; break;
default: levelStr = "UNKNOWN"; break;
}
char messageBuffer[4096];
va_list args;
va_start(args, format);
vsnprintf(messageBuffer, sizeof(messageBuffer), format, args);
va_end(args);

mLogFile << timeBuffer << "."
<< std::setfill('0') << std::setw(3) << millis << std::setw(3)
<< micros << " " << "[" << levelStr << "]"
<< messageBuffer;
mLogFile << timeBuffer << "." << std::setfill('0') << std::setw(3) << millis << std::setw(3)
<< micros << " " << "[" << levelStr << "]" << messageBuffer;
mLogFile.flush();
}

void LogWOPrefix(LogLevel level, const char* format, ...)
{
if (level < mMinLevel || !mLogFile.is_open() || !mEnable) {
return;
}
if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { return; }
std::lock_guard<std::mutex> lock(mMutex);
char messageBuffer[2048];
va_list args;
Expand Down
Loading
Loading