Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -975,12 +975,17 @@ export type InitProgressCallback = (report: InitProgressReport) => void;
* Cache to store model related data.
*/
export class ArtifactCache {
private scope: string;
private cache?: Cache;

constructor(scope: string) {
this.scope = scope;
}

async fetchWithCache(url: string) {
const request = new Request(url);
if (this.cache === undefined) {
this.cache = await caches.open("tvmjs");
this.cache = await caches.open(this.scope);
}
let result = await this.cache.match(request);
if (result === undefined) {
Expand Down Expand Up @@ -1020,7 +1025,6 @@ export class Instance implements Disposable {
private objFactory: Map<number, FObjectConstructor>;
private ctx: RuntimeContext;
private initProgressCallback: Array<InitProgressCallback> = [];
private artifactCache = new ArtifactCache();

/**
* Internal function(registered by the runtime)
Expand Down Expand Up @@ -1416,19 +1420,25 @@ export class Instance implements Disposable {
*
* @param ndarrayCacheUrl The cache url.
* @param device The device to be fetched to.
* @param cacheScope The scope identifier of the cache
* @returns The meta data
*/
async fetchNDArrayCache(ndarrayCacheUrl: string, device: DLDevice) : Promise<any> {
async fetchNDArrayCache(
ndarrayCacheUrl: string,
device: DLDevice,
cacheScope: string = "tvmjs"
): Promise<any> {
const artifactCache = new ArtifactCache(cacheScope);
const jsonUrl = new URL("ndarray-cache.json", ndarrayCacheUrl).href;
const result = await this.artifactCache.fetchWithCache(jsonUrl);
const result = await artifactCache.fetchWithCache(jsonUrl);

let list;
if (result instanceof Response) {
list = await result.json();
}
await this.fetchNDArrayCacheInternal(
ndarrayCacheUrl,
list["records"] as Array<NDArrayShardEntry>, device);
list["records"] as Array<NDArrayShardEntry>, device, artifactCache);
this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record<string, any>) };
}

Expand All @@ -1438,8 +1448,14 @@ export class Instance implements Disposable {
* @param ndarrayCacheUrl The cache url.
* @param list The list of array data.
* @param device The device to store the data to.
* @param artifactCache The artifact cache
*/
private async fetchNDArrayCacheInternal(ndarrayCacheUrl: string, list: Array<NDArrayShardEntry>, device: DLDevice) {
private async fetchNDArrayCacheInternal(
ndarrayCacheUrl: string,
list: Array<NDArrayShardEntry>,
device: DLDevice,
artifactCache: ArtifactCache
) {
const perf = compact.getPerformance();
let tstart = perf.now();

Expand Down Expand Up @@ -1481,7 +1497,7 @@ export class Instance implements Disposable {
const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href;
let buffer;
try {
buffer = await (await this.artifactCache.fetchWithCache(dataUrl)).arrayBuffer();
buffer = await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer();
} catch (err) {
this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
throw err;
Expand Down