Skip to content

Commit

Permalink
Adding amlClient#predict
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwanthkumar-avalara committed Jul 6, 2020
1 parent e9c2af5 commit f42d0db
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
40 changes: 40 additions & 0 deletions frontend/gcloud-apis/aml.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ export type Model = {
export type ListModelResponse = {
model: Array<Model>,
}

export type PredictResponse = {
payload: Array<{
annotationSpecId: string,
classification?: {
score: number,
},
displayName: string,
}>
}
export class AutoMLClient extends BaseClient {

constructor(settings: UseSettingsHook, endpoint?: string) {
Expand Down Expand Up @@ -151,4 +161,34 @@ export class AutoMLClient extends BaseClient {
}
}

async predict(projectId: string, modelId: string, imageAsBlob: Blob, scoreThreshold: number = 0.5): Promise<PredictResponse> {
const converToBase64 = (input) => new Promise<[string, string]>(function (resolve, reject) {
const reader = new FileReader();
reader.onload = () => {
const base64 = 'base64';
const dataUri = reader.result;
const pos = (dataUri as string).indexOf(base64);
const base64Payload = dataUri.slice(pos + base64.length + 1)
resolve([base64Payload as string, dataUri as string]);
};
reader.onerror = reject;
reader.readAsDataURL(input);
});

const [imageAsBase64, dataUri] = await converToBase64(imageAsBlob);
const payload = {
payload: {
image: {
imageBytes: imageAsBase64
},
},
params: {
score_threshold: scoreThreshold.toString()
}
};

const response = await this._makeRequestPost(`/v1/projects/${projectId}/locations/us-central1/models/${modelId}:predict`, payload);
URL.revokeObjectURL(dataUri);
return response;
}
}
3 changes: 2 additions & 1 deletion frontend/gcloud-apis/base.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export abstract class BaseClient {

protected async _makeRequestPost(resource, body) {
const accessToken = await this.accessToken();
const requestBody = JSON.stringify(body);

const response = await fetch(`${this.endpoint}${resource}`, {
method: 'POST',
Expand All @@ -51,7 +52,7 @@ export abstract class BaseClient {
},
redirect: 'follow',
referrerPolicy: 'no-referrer',
body: JSON.stringify(body),
body: requestBody,
});

return this.handleResponse(response);
Expand Down

0 comments on commit f42d0db

Please sign in to comment.