Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Web] Seperate parallel shard download and iterative shard loading #16650

Merged
merged 8 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions web/src/artifact_cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ export interface ArtifactCacheTemplate {
*/
fetchWithCache(url: string);

/**
* add ey url to cache
*/
addToCache(url: string);

/**
* check if cache has all keys in Cache
*/
Expand Down
133 changes: 96 additions & 37 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class RuntimeContext implements Disposable {
makeShapeTuple: PackedFunc;
ndarrayCreateView: PackedFunc;
sampleTopPFromLogits: PackedFunc;
sampleTopPFromProb: PackedFunc;
applyRepetitionPenalty: PackedFunc;
applyPresenceAndFrequencyPenalty: PackedFunc;
applySoftmaxWithTemperature: PackedFunc;
Expand All @@ -188,6 +189,7 @@ class RuntimeContext implements Disposable {
this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob");
this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
Expand Down Expand Up @@ -1020,6 +1022,17 @@ export class ArtifactCache implements ArtifactCacheTemplate {
return result;
}

async addToCache(url: string) {
const request = new Request(url);
if (this.cache === undefined) {
this.cache = await caches.open(this.scope);
}
const result = await this.cache.match(request);
if (result === undefined) {
await this.cache.add(request);
}
}

async hasAllKeys(keys: string[]) {
if (this.cache === undefined) {
this.cache = await caches.open(this.scope);
Expand Down Expand Up @@ -1534,20 +1547,24 @@ export class Instance implements Disposable {

const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href))

const reportCallback = (iter: number) => {
const reportCallback = (iter: number, loading = false) => {
// report
for (let j = 0; j < this.initProgressCallback.length; ++j) {
let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
text += timeElapsed + " secs elapsed.";
text += " It can take a while when we first visit this page to populate the cache."
text += " Later refreshes will become faster.";
if (cacheOnly) {
let text: string;
if (loading) {
text = "Finished fetching params, loading onto WebGPU.";
} else if (cacheOnly) {
text = "Loading model from cache[" + iter + "/" + list.length + "]: ";
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
text += timeElapsed + " secs elapsed.";
} else {
text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
text += timeElapsed + " secs elapsed.";
text += " It can take a while when we first visit this page to populate the cache."
text += " Later refreshes will become faster.";
}
this.initProgressCallback[j]({
progress: fetchedBytes / totalBytes,
Expand All @@ -1567,7 +1584,35 @@ export class Instance implements Disposable {
});
}

const processShard = async (i: number) => {
// First download all shards to cache parallely if not yet in cache
const downloadCache = async (start: number, end: number) => {
// Download params [start, end) from `list`
for (let i = start; i < end; i++) {
const shard = list[i];
const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
try {
await artifactCache.addToCache(dataUrl);
} catch (err) {
this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
throw err;
}
timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
fetchedBytes += shard.nbytes;
reportCallback(fetchedShards++);
}
}
// We launch 4 parallel for loops to limit the max concurrency to 4 download
const loopSize = Math.floor(list.length / 4);
await Promise.all([
downloadCache(0, loopSize),
downloadCache(loopSize, 2 * loopSize),
downloadCache(2 * loopSize, 3 * loopSize),
downloadCache(3 * loopSize, list.length)
]);
reportCallback(list.length, /*loading=*/true);

// Then iteratively, load the shard from cache
for (let i = 0; i < list.length; ++i) {
const shard = list[i];
const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
let buffer;
Expand All @@ -1579,39 +1624,42 @@ export class Instance implements Disposable {
}
const shardRecords = shard.records;
for (let j = 0; j < shardRecords.length; ++j) {
const rec = shardRecords[j];
const cpu_arr = this.withNewScope(() => {
return this.detachFromCurrentScope(
this.empty(rec.shape, rec.dtype, this.cpu())
)
});
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
// first sync copy to cpu.
this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
// then async stream into GPU if needed
if (device.deviceType === DeviceStrToEnum.cpu) {
this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
cpu_arr.dispose();
} else {
// allocate a gpu arr and async copy to it.
const gpu_arr = this.withNewScope(() => {
try {
const rec = shardRecords[j];
const cpu_arr = this.withNewScope(() => {
return this.detachFromCurrentScope(
this.empty(rec.shape, rec.dtype, device)
this.empty(rec.shape, rec.dtype, this.cpu())
)
});
gpu_arr.copyFrom(cpu_arr);
await device.sync();
this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
cpu_arr.dispose();
gpu_arr.dispose();
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
// first sync copy to cpu.
this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
// then async stream into GPU if needed
if (device.deviceType === DeviceStrToEnum.cpu) {
this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
cpu_arr.dispose();
} else {
// allocate a gpu arr and async copy to it.
const gpu_arr = this.withNewScope(() => {
return this.detachFromCurrentScope(
this.empty(rec.shape, rec.dtype, device)
)
});
gpu_arr.copyFrom(cpu_arr);
await device.sync();
this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
cpu_arr.dispose();
gpu_arr.dispose();
}
} catch (err) {
this.env.logger(
"Failed to load shard " + i + "'s record: " + JSON.stringify(shardRecords[j]) + "\n" +
"Error: " + err
);
throw err;
}
}
timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
fetchedBytes += shard.nbytes;
reportCallback(fetchedShards++);
}
await Promise.all(list.map((_, index) => processShard(index)));
reportCallback(list.length);
}

/**
Expand Down Expand Up @@ -1780,6 +1828,17 @@ export class Instance implements Disposable {
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
}

/**
* Sample index via top-p sampling.
*
* @param prob The distribution, i.e. logits after `applySoftmaxWithTemperature()` is performed.
* @param top_p The top_p
* @returns The sampled index.
*/
sampleTopPFromProb(prob: NDArray, top_p: number): number {
return this.ctx.sampleTopPFromProb(prob, top_p, Math.random());
}

/**
* Apply repetition penalty to the logits.
* @param logits The input logits before penalty.
Expand Down Expand Up @@ -2549,7 +2608,7 @@ export async function deleteNDArrayCache(
const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href;
const result = await artifactCache.fetchWithCache(jsonUrl);
let list;
if (result instanceof Response){
if (result instanceof Response) {
list = await result.json();
}
const arrayentry = list["records"] as Array<NDArrayShardEntry>;
Expand Down
Loading