Skip to content

Commit

Permalink
feat: Support for Pose models from Teachable Machine
Browse files Browse the repository at this point in the history
Teachable Machine requires TFJS v1, where the rest of ML4K uses
TFJS v2, so I'm moving the pose model support to a separate
iframe so I can run it on-demand.

Signed-off-by: Dale Lane <dale.lane@uk.ibm.com>
  • Loading branch information
dalelane committed Dec 29, 2020
1 parent 3082ee4 commit 02c38b9
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 5 deletions.
3 changes: 2 additions & 1 deletion Gruntfile.js
Expand Up @@ -143,7 +143,8 @@ module.exports = function(grunt) {
expand : true,
cwd : 'public/scratch-components',
src : ['help-scratch3*',
'help-scratch.css'],
'help-scratch.css',
'teachablemachinepose.html'],
dest : 'web/scratch3'
},
indexhtml : {
Expand Down
3 changes: 2 additions & 1 deletion gulpfile.js
Expand Up @@ -165,7 +165,8 @@ gulp.task('scratch3install', gulp.series('crossdomain', function() {
return gulp.src([
'public/scratch3/**',
'public/scratch-components/help-scratch3*',
'public/scratch-components/help-scratch.css'
'public/scratch-components/help-scratch.css',
'public/scratch-components/teachablemachinepose.html'
]).pipe(gulp.dest('web/scratch3'));
}));

Expand Down
2 changes: 1 addition & 1 deletion public/components/pretrained/tfjs.tmpl.html
Expand Up @@ -22,7 +22,7 @@ <h2>Use a pre-trained TensorFlow model in Scratch</h2>
<div style="padding: 1em;">
<div>What type of model is it?</div>
<md-select name="modeltypeid" ng-model="modeltypeid" ng-change="generateScratchKey(modeljson, modeltypeid)" style="padding-left: 20px;">
<md-option value="10">Teachable Machine (image)</md-option>
<md-option value="10">Teachable Machine (image or pose)</md-option>
<md-option value="11">TensorFlow GraphDef (image)</md-option>
</md-select>
</div>
Expand Down
41 changes: 41 additions & 0 deletions public/scratch-components/teachablemachinepose.html
@@ -0,0 +1,41 @@
<html>
<head>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.3.1/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@teachablemachine/pose@0.8/dist/teachablemachine-pose.min.js"></script>
<script type="text/javascript">
const posemodels = {};

function initModel(id, modelurl, metadataurl) {
return tmPose.load(modelurl, metadataurl)
.then(function (model) {
posemodels[id] = model;
});
}

function predict(id, imageelem) {
return posemodels[id].estimatePose(imageelem)
.then(function (resp) {
return posemodels[id].predict(resp.posenetOutput);
})
.then(function (output) {
return output.map(function (tmitem) {
return {
class_name : tmitem.className,
confidence : 100 * tmitem.probability
};
});
});
}

function createImage(b64imagedata, callback) {
const imageElement = document.createElement('img');
imageElement.width = 257;
imageElement.height = 257;
imageElement.onload = function () {
callback(imageElement);
};
imageElement.src = 'data:image/jpeg;base64,' + b64imagedata;
}
</script>
</head>
</html>
4 changes: 4 additions & 0 deletions src/lib/scratchx/extensions.ts
Expand Up @@ -215,6 +215,10 @@ export async function getScratchTfjsExtension(scratchkey: string): Promise<strin
const modelinfo = await scratchtfjs.getModelInfoFromScratchKey(scratchkey);
const metadata = await scratchtfjs.getMetadata(modelinfo);

if (metadata.packageName === '@teachablemachine/pose') {
modelinfo.modeltype = 'teachablemachinepose';
}

const template: string = await fileutils.read('./resources/scratch3-tfjs-classify.js');

Mustache.parse(template);
Expand Down
5 changes: 5 additions & 0 deletions src/lib/scratchx/scratchtfjs.ts
Expand Up @@ -61,6 +61,8 @@ function getModelTypeAsId(type: ScratchTypes.ScratchTfjsModelType): ScratchTypes
return 10;
case 'graphdefimage':
return 11;
case 'teachablemachinepose':
return 12;
default:
return 99;
}
Expand All @@ -72,6 +74,9 @@ function getModelTypeFromId(id: ScratchTypes.ScratchTfjsModelTypeId): ScratchTyp
else if (id === 11) {
return 'graphdefimage';
}
else if (id === 12) {
return 'teachablemachinepose';
}
else {
return 'unknown';
}
Expand Down
5 changes: 3 additions & 2 deletions src/lib/scratchx/scratchx-types.ts
Expand Up @@ -9,9 +9,10 @@ export interface Key {
}

export type ScratchTfjsModelType = 'teachablemachineimage' |
'teachablemachinepose' |
'graphdefimage' |
'unknown';
export type ScratchTfjsModelTypeId = 10 | 11 | 99;
export type ScratchTfjsModelTypeId = 10 | 11 | 12 | 99;


export interface ScratchTfjsExtensionEncoded {
Expand All @@ -21,7 +22,7 @@ export interface ScratchTfjsExtensionEncoded {

export interface ScratchTfjsExtension {
readonly modelurl: string;
readonly modeltype: ScratchTfjsModelType;
modeltype: ScratchTfjsModelType;
}

export interface ScratchTfjsExtensionWithId extends ScratchTfjsExtension {
Expand Down

0 comments on commit 02c38b9

Please sign in to comment.