Skip to content

Commit

Permalink
Parallel Download, Move ArtifactCache to Interface to support future …
Browse files Browse the repository at this point in the history
…different cache types, fix README path typo, Support delete and batch delete

Co-authored-by: DavidGOrtega <g.ortega.david@gmail.com>
  • Loading branch information
DiegoCao and DavidGOrtega committed Feb 18, 2024
1 parent f21c5a4 commit 70911c5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 11 deletions.
2 changes: 1 addition & 1 deletion web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ Right now we use the SPIRV to generate shaders that can be accepted by Chrome an
- Firefox should be close pending the support of Fence.
- Download vulkan SDK (1.1 or higher) that supports SPIRV 1.3
- Start the WebSocket RPC
- run `python tests/node/webgpu_rpc_test.py`
- run `python tests/python/webgpu_rpc_test.py`
4 changes: 2 additions & 2 deletions web/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions web/src/artifact_cache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
Common Interface for the artifact cache
*/
export interface ArtifactCacheTemplate {
/**
* fetch key url from cache
*/
fetchWithCache(url: string);

/**
* check if cache has all keys in Cache
*/
hasAllKeys(keys: string[]);

/**
* Delete url in cache if url exists
*/
deleteInCache(url: string);
}
2 changes: 1 addition & 1 deletion web/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export {
PackedFunc, Module, NDArray,
TVMArray, TVMObject, VirtualMachine,
InitProgressCallback, InitProgressReport,
ArtifactCache, Instance, instantiate, hasNDArrayInCache
ArtifactCache, Instance, instantiate, hasNDArrayInCache, deleteNDArrayCache
} from "./runtime";
export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
Expand Down
51 changes: 44 additions & 7 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { Memory, CachedCallStack } from "./memory";
import { assert, StringToUint8Array } from "./support";
import { Environment } from "./environment";
import { FunctionInfo, WebGPUContext } from "./webgpu";
import { ArtifactCacheTemplate } from "./artifact_cache";

import * as compact from "./compact";
import * as ctypes from "./ctypes";
Expand Down Expand Up @@ -985,7 +986,7 @@ export type InitProgressCallback = (report: InitProgressReport) => void;
/**
* Cache to store model related data.
*/
export class ArtifactCache {
export class ArtifactCache implements ArtifactCacheTemplate {
private scope: string;
private cache?: Cache;

Expand Down Expand Up @@ -1018,6 +1019,14 @@ export class ArtifactCache {
.then(cacheKeys => keys.every(key => cacheKeys.indexOf(key) !== -1))
.catch(err => false);
}

async deleteInCache(url: string) {
if (this.cache === undefined) {
this.cache = await caches.open(this.scope);
}
const result = await this.cache.delete(url);
return result;
}
}

/**
Expand Down Expand Up @@ -1451,7 +1460,7 @@ export class Instance implements Disposable {
}

/**
* Fetch NDArray cache from url.
* Given cacheUrl, search up items to fetch based on cacheUrl/ndarray-cache.json
*
* @param ndarrayCacheUrl The cache url.
* @param device The device to be fetched to.
Expand All @@ -1477,6 +1486,7 @@ export class Instance implements Disposable {
this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record<string, any>) };
}


/**
* Fetch list of NDArray into the NDArrayCache.
*
Expand All @@ -1489,7 +1499,7 @@ export class Instance implements Disposable {
ndarrayCacheUrl: string,
list: Array<NDArrayShardEntry>,
device: DLDevice,
artifactCache: ArtifactCache
artifactCache: ArtifactCacheTemplate
) {
const perf = compact.getPerformance();
const tstart = perf.now();
Expand Down Expand Up @@ -1536,18 +1546,19 @@ export class Instance implements Disposable {
});
}

for (let i = 0; i < list.length; ++i) {
const processShard = async (i: number) => {
reportCallback(i);
fetchedBytes += list[i].nbytes;
const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href;
const shard = list[i];
fetchedBytes += shard.nbytes;
const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
let buffer;
try {
buffer = await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer();
} catch (err) {
this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
throw err;
}
const shardRecords = list[i].records;
const shardRecords = shard.records;
for (let j = 0; j < shardRecords.length; ++j) {
const rec = shardRecords[j];
const cpu_arr = this.withNewScope(() => {
Expand Down Expand Up @@ -1578,6 +1589,7 @@ export class Instance implements Disposable {
}
timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
}
await Promise.all(list.map((_, index) => processShard(index)));
reportCallback(list.length);
}

Expand Down Expand Up @@ -2432,3 +2444,28 @@ export async function hasNDArrayInCache(
list = list["records"] as Array<NDArrayShardEntry>;
return await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href));
}

/**
* Given cacheUrl, search up items to delete based on cacheUrl/ndarray-cache.json
*
* @param cacheUrl
* @param cacheScope
*/
export async function deleteNDArrayCache(
cacheUrl: string,
cacheScope = "tvmjs"
) {
const artifactCache = new ArtifactCache(cacheScope);
const jsonUrl = new URL("ndarray-cache.json", cacheUrl).href;
const result = await artifactCache.fetchWithCache(jsonUrl);
let list;
if (result instanceof Response){
list = await result.json();
}
const arrayentry = list["records"] as Array<NDArrayShardEntry>;
const processShard = async (i: number) => {
const dataUrl = new URL(arrayentry[i].dataPath, cacheUrl).href;
await artifactCache.deleteInCache(dataUrl);
}
await Promise.all(arrayentry.map((_, index) => processShard(index)));
}

0 comments on commit 70911c5

Please sign in to comment.