diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index ffb5011324f5..da9aaddfb0d6 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -25,6 +25,11 @@ export interface ArtifactCacheTemplate { */ fetchWithCache(url: string); + /** + * add ey url to cache + */ + addToCache(url: string); + /** * check if cache has all keys in Cache */ diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 8df48c43a5f9..ea022d1b3e9d 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -165,6 +165,7 @@ class RuntimeContext implements Disposable { makeShapeTuple: PackedFunc; ndarrayCreateView: PackedFunc; sampleTopPFromLogits: PackedFunc; + sampleTopPFromProb: PackedFunc; applyRepetitionPenalty: PackedFunc; applyPresenceAndFrequencyPenalty: PackedFunc; applySoftmaxWithTemperature: PackedFunc; @@ -188,6 +189,7 @@ class RuntimeContext implements Disposable { this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); + this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob"); this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); @@ -1020,6 +1022,17 @@ export class ArtifactCache implements ArtifactCacheTemplate { return result; } + async addToCache(url: string) { + const request = new Request(url); + if (this.cache === undefined) { + this.cache = await caches.open(this.scope); + } + const result = await this.cache.match(request); + if (result === undefined) { + await this.cache.add(request); + } + } + async hasAllKeys(keys: string[]) { if (this.cache === undefined) { this.cache = await caches.open(this.scope); @@ -1534,20 +1547,24 @@ export class Instance implements Disposable { const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)) - const reportCallback = (iter: number) => { + const reportCallback = (iter: number, loading = false) => { // report for (let j = 0; j < this.initProgressCallback.length; ++j) { - let text = "Fetching param cache[" + iter + "/" + list.length + "]: "; - text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. " - text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " - text += timeElapsed + " secs elapsed."; - text += " It can take a while when we first visit this page to populate the cache." - text += " Later refreshes will become faster."; - if (cacheOnly) { + let text: string; + if (loading) { + text = "Finished fetching params, loading onto WebGPU."; + } else if (cacheOnly) { text = "Loading model from cache[" + iter + "/" + list.length + "]: "; text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. " text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " text += timeElapsed + " secs elapsed."; + } else { + text = "Fetching param cache[" + iter + "/" + list.length + "]: "; + text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. " + text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " + text += timeElapsed + " secs elapsed."; + text += " It can take a while when we first visit this page to populate the cache." + text += " Later refreshes will become faster."; } this.initProgressCallback[j]({ progress: fetchedBytes / totalBytes, @@ -1567,7 +1584,35 @@ export class Instance implements Disposable { }); } - const processShard = async (i: number) => { + // First download all shards to cache parallely if not yet in cache + const downloadCache = async (start: number, end: number) => { + // Download params [start, end) from `list` + for (let i = start; i < end; i++) { + const shard = list[i]; + const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; + try { + await artifactCache.addToCache(dataUrl); + } catch (err) { + this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); + throw err; + } + timeElapsed = Math.ceil((perf.now() - tstart) / 1000); + fetchedBytes += shard.nbytes; + reportCallback(fetchedShards++); + } + } + // We launch 4 parallel for loops to limit the max concurrency to 4 download + const loopSize = Math.floor(list.length / 4); + await Promise.all([ + downloadCache(0, loopSize), + downloadCache(loopSize, 2 * loopSize), + downloadCache(2 * loopSize, 3 * loopSize), + downloadCache(3 * loopSize, list.length) + ]); + reportCallback(list.length, /*loading=*/true); + + // Then iteratively, load the shard from cache + for (let i = 0; i < list.length; ++i) { const shard = list[i]; const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href; let buffer; @@ -1579,39 +1624,42 @@ export class Instance implements Disposable { } const shardRecords = shard.records; for (let j = 0; j < shardRecords.length; ++j) { - const rec = shardRecords[j]; - const cpu_arr = this.withNewScope(() => { - return this.detachFromCurrentScope( - this.empty(rec.shape, rec.dtype, this.cpu()) - ) - }); - const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); - // first sync copy to cpu. - this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); - // then async stream into GPU if needed - if (device.deviceType === DeviceStrToEnum.cpu) { - this.ndarrayCacheUpdate(rec.name, cpu_arr, false); - cpu_arr.dispose(); - } else { - // allocate a gpu arr and async copy to it. - const gpu_arr = this.withNewScope(() => { + try { + const rec = shardRecords[j]; + const cpu_arr = this.withNewScope(() => { return this.detachFromCurrentScope( - this.empty(rec.shape, rec.dtype, device) + this.empty(rec.shape, rec.dtype, this.cpu()) ) }); - gpu_arr.copyFrom(cpu_arr); - await device.sync(); - this.ndarrayCacheUpdate(rec.name, gpu_arr, false); - cpu_arr.dispose(); - gpu_arr.dispose(); + const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); + // first sync copy to cpu. + this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); + // then async stream into GPU if needed + if (device.deviceType === DeviceStrToEnum.cpu) { + this.ndarrayCacheUpdate(rec.name, cpu_arr, false); + cpu_arr.dispose(); + } else { + // allocate a gpu arr and async copy to it. + const gpu_arr = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.empty(rec.shape, rec.dtype, device) + ) + }); + gpu_arr.copyFrom(cpu_arr); + await device.sync(); + this.ndarrayCacheUpdate(rec.name, gpu_arr, false); + cpu_arr.dispose(); + gpu_arr.dispose(); + } + } catch (err) { + this.env.logger( + "Failed to load shard " + i + "'s record: " + JSON.stringify(shardRecords[j]) + "\n" + + "Error: " + err + ); + throw err; } } - timeElapsed = Math.ceil((perf.now() - tstart) / 1000); - fetchedBytes += shard.nbytes; - reportCallback(fetchedShards++); } - await Promise.all(list.map((_, index) => processShard(index))); - reportCallback(list.length); } /** @@ -1780,6 +1828,17 @@ export class Instance implements Disposable { return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random()); } + /** + * Sample index via top-p sampling. + * + * @param prob The distribution, i.e. logits after `applySoftmaxWithTemperature()` is performed. + * @param top_p The top_p + * @returns The sampled index. + */ + sampleTopPFromProb(prob: NDArray, top_p: number): number { + return this.ctx.sampleTopPFromProb(prob, top_p, Math.random()); + } + /** * Apply repetition penalty to the logits. * @param logits The input logits before penalty. @@ -2549,7 +2608,7 @@ export async function deleteNDArrayCache( const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href; const result = await artifactCache.fetchWithCache(jsonUrl); let list; - if (result instanceof Response){ + if (result instanceof Response) { list = await result.json(); } const arrayentry = list["records"] as Array;