Permalink
Browse files

Working Loadable model dump from Trainer

  • Loading branch information...
adityakusupati committed Nov 28, 2017
1 parent d7ceeb4 commit f5407b1d49ad4049887fc3dbdd58d9633a6f000a
Showing with 37 additions and 18 deletions.
  1. +3 −3 drivers/Bonsai/trainer/BonsaiTrainDriver.cpp
  2. +5 −0 src/Bonsai/Bonsai.h
  3. +29 −15 src/Bonsai/BonsaiTrainer.cpp
@@ -30,9 +30,9 @@ int main(int argc, char **argv)
auto meanVarBytes = trainer.getMeanVarSize();
auto meanVar = new char[meanVarBytes];
trainer.exportModel(modelBytes, model, currResultsPath); // use exportSparseModel(...) if you need sparse model
trainer.exportMeanVar(meanVarBytes, meanVar, currResultsPath);
trainer.exportModel(modelBytes, model); // use exportSparseModel(...) if you need sparse model
trainer.exportMeanVar(meanVarBytes, meanVar);
trainer.getLoadableModelMeanVar(model, modelBytes, meanVar, meanVarBytes, currResultsPath);
trainer.dumpModelMeanVar(currResultsPath);
delete[] model, meanVar;
View
@@ -535,6 +535,11 @@ namespace EdgeML
///
void dumpModelMeanVar(const std::string& currResultsPath);
///
/// Function to Dump Loadable Mean, Variance and Model
///
void getLoadableModelMeanVar(char *const modelBuffer, const size_t& modelBytes, char *const meanVarBuffer, const size_t& meanVarBytes, const std::string& currResultsPath);
size_t totalNonZeros();
};
@@ -28,11 +28,11 @@ BonsaiTrainer::BonsaiTrainer(
model(numBytes, fromModel, isDense), // Initialize model
data(dataIngestType,
DataFormatParams{
model.hyperParams.ntrain,
model.hyperParams.nvalidation,
model.hyperParams.ntest,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
model.hyperParams.ntrain,
model.hyperParams.nvalidation,
model.hyperParams.ntest,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
{
assert(dataIngestType == FileIngest);
@@ -62,7 +62,7 @@ BonsaiTrainer::BonsaiTrainer(
model(argc, argv, dataDir), // Initialize model
data(dataIngestType,
DataFormatParams{
model.hyperParams.ntrain,
model.hyperParams.ntrain,
model.hyperParams.nvalidation,
model.hyperParams.ntest,
model.hyperParams.numClasses,
@@ -97,11 +97,11 @@ BonsaiTrainer::BonsaiTrainer(
: model(fromHyperParams),
data(dataIngestType,
DataFormatParams{
model.hyperParams.ntrain,
model.hyperParams.nvalidation,
model.hyperParams.ntest,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
model.hyperParams.ntrain,
model.hyperParams.nvalidation,
model.hyperParams.ntest,
model.hyperParams.numClasses,
model.hyperParams.dataDimension })
{
assert(dataIngestType == InterfaceIngest);
assert(model.hyperParams.normalizationType == none);
@@ -218,7 +218,7 @@ FP_TYPE BonsaiTrainer::computeObjective(
if (ZX.cols() == data.Xtrain.cols())
LOG_INFO(infoStr);
/* else
LOG_TRACE(infoStr);*/
LOG_TRACE(infoStr);*/
return normAdd + (FP_TYPE)marginLoss / ZX.cols();
}
@@ -280,12 +280,26 @@ void BonsaiTrainer::exportModel(
{
std::string loadableModelPath = currResultsPath + "/loadableModel";
model.exportModel(modelSize, buffer);
std::ofstream modelExporter(loadableModelPath);
std::ofstream modelExporter(loadableModelPath, std::ios::out | std::ios::binary);
modelExporter.write(buffer, modelSize);
modelExporter.close();
}
void BonsaiTrainer::getLoadableModelMeanVar(
char *const modelBuffer,
const size_t& modelBytes,
char *const meanVarBuffer,
const size_t& meanVarBytes,
const std::string& currResultsPath)
{
std::ofstream modelExporter(currResultsPath + "/loadableModel", std::ios::out | std::ios::binary);
modelExporter.write(modelBuffer, modelBytes);
modelExporter.close();
std::ofstream meanVarExporter(currResultsPath + "/loadableMeanVar", std::ios::out | std::ios::binary);
meanVarExporter.write(meanVarBuffer, meanVarBytes);
meanVarExporter.close();
}
void BonsaiTrainer::exportSparseModel(
const size_t& modelSize,
char *const buffer,
@@ -487,7 +501,7 @@ void BonsaiTrainer::TreeCache::fillNodeProbability(
tanhThetaXCache = MatrixXuf::Zero(model.hyperParams.internalNodes, Xdata.cols());
nodeProbability = MatrixXuf::Ones(model.hyperParams.totalNodes, Xdata.cols());
if(model.hyperParams.internalNodes > 0)
if (model.hyperParams.internalNodes > 0)
mm(tanhThetaXCache, Thetamat, CblasNoTrans, Xdata, CblasNoTrans, (FP_TYPE)1.0, (FP_TYPE)0.0L);
// Scale VXClassIDScratch by scalar sigma_i
scal(tanhThetaXCache.rows()*tanhThetaXCache.cols(), model.hyperParams.sigma_i, tanhThetaXCache.data(), 1);

0 comments on commit f5407b1

Please sign in to comment.