Skip to content

Commit

Permalink
plugin: don't load prebuilt "imagenet" weights when inferencing (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
CandyQiu committed Aug 4, 2020
1 parent ac00e87 commit 729700f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
Expand Up @@ -66,11 +66,23 @@ const mobilenetDefine: ModelDefineType = async (data: ImageDataset, args: ModelD
await download(MODEL_URL, MODEL_PATH);
}

model = MobileNetV2(boa.kwargs({
include_top: false,
weights: 'imagenet',
input_shape: inputShape
}));
if (recoverPath) {
model = MobileNetV2(
boa.kwargs({
include_top: false,
weights: 'none',
input_shape: inputShape
})
);
} else {
model = MobileNetV2(
boa.kwargs({
include_top: false,
weights: 'imagenet',
input_shape: inputShape
})
);
}

let output = model.output;
output = GlobalAveragePooling2D()(output);
Expand Down
Expand Up @@ -64,11 +64,23 @@ const resnetModelDefine: ModelDefineType = async (data: ImageDataset, args: Mode
await download(MODEL_URL, MODEL_PATH);
}

model = ResNet50(boa.kwargs({
include_top: false,
weights: 'imagenet',
input_shape: inputShape
}));
if (recoverPath) {
model = ResNet50(
boa.kwargs({
include_top: false,
weights: 'none',
input_shape: inputShape
})
);
} else {
model = ResNet50(
boa.kwargs({
include_top: false,
weights: 'imagenet',
input_shape: inputShape
})
);
}

let output = model.output;
output = GlobalAveragePooling2D()(output);
Expand Down

0 comments on commit 729700f

Please sign in to comment.