Commit 6ef267c

Use TFile::Serialize, TFile::DeSerialize
Signed-off-by: Stefan Weil <sw@weilnetz.de>
stweil committed Jul 18, 2018
1 parent c383b1a commit 6ef267c
Showing 5 changed files with 39 additions and 42 deletions.
32 changes: 15 additions & 17 deletions src/lstm/network.cpp
@@ -150,26 +150,26 @@ bool Network::SetupNeedsBackprop(bool needs_backprop) {
 // Writes to the given file. Returns false in case of error.
 bool Network::Serialize(TFile* fp) const {
   int8_t data = NT_NONE;
-  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
+  if (!fp->Serialize(&data)) return false;
   STRING type_name = kTypeNames[type_];
   if (!type_name.Serialize(fp)) return false;
   data = training_;
-  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
+  if (!fp->Serialize(&data)) return false;
   data = needs_to_backprop_;
-  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
-  if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
-  if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false;
-  if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false;
-  if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
+  if (!fp->Serialize(&data)) return false;
+  if (!fp->Serialize(&network_flags_)) return false;
+  if (!fp->Serialize(&ni_)) return false;
+  if (!fp->Serialize(&no_)) return false;
+  if (!fp->Serialize(&num_weights_)) return false;
   if (!name_.Serialize(fp)) return false;
   return true;
 }

 // Reads from the given file. Returns false in case of error.
 // Should be overridden by subclasses, but NOT called by their DeSerialize.
 bool Network::DeSerialize(TFile* fp) {
-  int8_t data = 0;
-  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
+  int8_t data;
+  if (!fp->DeSerialize(&data)) return false;
   if (data == NT_NONE) {
     STRING type_name;
     if (!type_name.DeSerialize(fp)) return false;
@@ -181,16 +181,14 @@ bool Network::DeSerialize(TFile* fp) {
     }
   }
   type_ = static_cast<NetworkType>(data);
-  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
+  if (!fp->DeSerialize(&data)) return false;
   training_ = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
-  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
+  if (!fp->DeSerialize(&data)) return false;
   needs_to_backprop_ = data != 0;
-  if (fp->FReadEndian(&network_flags_, sizeof(network_flags_), 1) != 1)
-    return false;
-  if (fp->FReadEndian(&ni_, sizeof(ni_), 1) != 1) return false;
-  if (fp->FReadEndian(&no_, sizeof(no_), 1) != 1) return false;
-  if (fp->FReadEndian(&num_weights_, sizeof(num_weights_), 1) != 1)
-    return false;
+  if (!fp->DeSerialize(&network_flags_)) return false;
+  if (!fp->DeSerialize(&ni_)) return false;
+  if (!fp->DeSerialize(&no_)) return false;
+  if (!fp->DeSerialize(&num_weights_)) return false;
   if (!name_.DeSerialize(fp)) return false;
   return true;
 }
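
The same pattern repeats in every file below: explicit FWrite/FRead/FReadEndian calls, which spell out sizeof(...) and an element count, are replaced by TFile's templated Serialize/DeSerialize, which infer the element size from the pointer type and still report failure through a bool. The following is a minimal, self-contained sketch of that calling convention, not Tesseract's TFile: SketchFile is an invented, buffer-backed stand-in, and the real TFile is also expected to keep the byte-order handling that FReadEndian provided on the read path.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for TFile, backed by an in-memory buffer. The point
// is the calling convention: the element size comes from the template
// parameter, so call sites no longer pass sizeof(...) and a count.
class SketchFile {
 public:
  template <typename T>
  bool Serialize(const T* data, std::size_t count = 1) {
    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(data);
    buffer_.insert(buffer_.end(), bytes, bytes + sizeof(T) * count);
    return true;
  }

  template <typename T>
  bool DeSerialize(T* data, std::size_t count = 1) {
    const std::size_t n = sizeof(T) * count;
    if (pos_ + n > buffer_.size()) return false;  // truncated input
    std::copy(buffer_.begin() + pos_, buffer_.begin() + pos_ + n,
              reinterpret_cast<uint8_t*>(data));
    pos_ += n;
    return true;
  }

 private:
  std::vector<uint8_t> buffer_;
  std::size_t pos_ = 0;
};

int main() {
  SketchFile fp;
  // Write side, mirroring the shape of Network::Serialize above.
  int8_t flag = 1;
  int32_t ni = 48, no = 111;
  if (!fp.Serialize(&flag) || !fp.Serialize(&ni) || !fp.Serialize(&no))
    return 1;
  // Read side, mirroring Network::DeSerialize above.
  int8_t flag_in;
  int32_t ni_in, no_in;
  if (!fp.DeSerialize(&flag_in) || !fp.DeSerialize(&ni_in) ||
      !fp.DeSerialize(&no_in))
    return 1;
  std::printf("flag=%d ni=%d no=%d\n", flag_in, static_cast<int>(ni_in),
              static_cast<int>(no_in));
  return 0;
}
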
12 changes: 6 additions & 6 deletions src/lstm/plumbing.cpp
@@ -181,10 +181,10 @@ float* Plumbing::LayerLearningRatePtr(const char* id) const {
 // Writes to the given file. Returns false in case of error.
 bool Plumbing::Serialize(TFile* fp) const {
   if (!Network::Serialize(fp)) return false;
-  int32_t size = stack_.size();
+  uint32_t size = stack_.size();
   // Can't use PointerVector::Serialize here as we need a special DeSerialize.
-  if (fp->FWrite(&size, sizeof(size), 1) != 1) return false;
-  for (int i = 0; i < size; ++i)
+  if (!fp->Serialize(&size)) return false;
+  for (uint32_t i = 0; i < size; ++i)
     if (!stack_[i]->Serialize(fp)) return false;
   if ((network_flags_ & NF_LAYER_SPECIFIC_LR) &&
       !learning_rates_.Serialize(fp)) {
@@ -197,9 +197,9 @@ bool Plumbing::Serialize(TFile* fp) const {
 bool Plumbing::DeSerialize(TFile* fp) {
   stack_.truncate(0);
   no_ = 0;  // We will be modifying this as we AddToStack.
-  int32_t size;
-  if (fp->FReadEndian(&size, sizeof(size), 1) != 1) return false;
-  for (int i = 0; i < size; ++i) {
+  uint32_t size;
+  if (!fp->DeSerialize(&size)) return false;
+  for (uint32_t i = 0; i < size; ++i) {
     Network* network = CreateFromFile(fp);
     if (network == nullptr) return false;
     AddToStack(network);
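
Plumbing holds a variable-length stack of sub-networks, so it writes an element count first and then each element; the commit also switches that count from int32_t to uint32_t so the on-disk field and the loop variable agree in signedness. Below is a minimal sketch of that length-prefix idea using plain C stdio; the function names and the double payload are invented for illustration and are not Tesseract code.

#include <cstdint>
#include <cstdio>
#include <vector>

// Write a uint32_t element count, then the raw elements.
static bool WriteVector(const std::vector<double>& v, std::FILE* fp) {
  uint32_t size = static_cast<uint32_t>(v.size());
  if (std::fwrite(&size, sizeof(size), 1, fp) != 1) return false;
  return size == 0 ||
         std::fwrite(v.data(), sizeof(double), size, fp) == size;
}

// Read the count back, then resize and fill the vector.
static bool ReadVector(std::vector<double>* v, std::FILE* fp) {
  uint32_t size = 0;
  if (std::fread(&size, sizeof(size), 1, fp) != 1) return false;
  v->resize(size);
  return size == 0 ||
         std::fread(v->data(), sizeof(double), size, fp) == size;
}

int main() {
  std::FILE* fp = std::tmpfile();
  if (fp == nullptr) return 1;
  std::vector<double> rates = {1e-4, 5e-5, 2.5e-5};
  if (!WriteVector(rates, fp)) return 1;
  std::rewind(fp);
  std::vector<double> restored;
  if (!ReadVector(&restored, fp)) return 1;
  std::printf("restored %zu rates\n", restored.size());
  std::fclose(fp);
  return 0;
}
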
11 changes: 5 additions & 6 deletions src/lstm/reconfig.cpp
@@ -49,16 +49,15 @@ int Reconfig::XScaleFactor() const {

 // Writes to the given file. Returns false in case of error.
 bool Reconfig::Serialize(TFile* fp) const {
-  if (!Network::Serialize(fp)) return false;
-  if (fp->FWrite(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
-  if (fp->FWrite(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
-  return true;
+  return Network::Serialize(fp) &&
+         fp->Serialize(&x_scale_) &&
+         fp->Serialize(&y_scale_);
 }

 // Reads from the given file. Returns false in case of error.
 bool Reconfig::DeSerialize(TFile* fp) {
-  if (fp->FReadEndian(&x_scale_, sizeof(x_scale_), 1) != 1) return false;
-  if (fp->FReadEndian(&y_scale_, sizeof(y_scale_), 1) != 1) return false;
+  if (!fp->DeSerialize(&x_scale_)) return false;
+  if (!fp->DeSerialize(&y_scale_)) return false;
   no_ = ni_ * x_scale_ * y_scale_;
   return true;
 }
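
Reconfig::Serialize also collapses the early-return chain into a single && expression. That is behavior-preserving because && evaluates left to right and stops at the first false, so a failed write still suppresses the writes after it, just as the if (...) return false; version did. A tiny self-contained illustration; the Step function is invented for the example.

#include <cstdio>

static bool Step(const char* name, bool ok) {
  std::printf("%s\n", name);  // runs only if evaluation reaches this step
  return ok;
}

int main() {
  // Prints "a" and "b" but never "c": the chain stops at the first failure,
  // exactly like a sequence of if (...) return false; checks.
  bool result = Step("a", true) && Step("b", false) && Step("c", true);
  std::printf("result=%d\n", result);
  return 0;
}
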
20 changes: 10 additions & 10 deletions src/lstm/static_shape.h
@@ -64,23 +64,23 @@ class StaticShape {
   bool DeSerialize(TFile *fp) {
     int32_t tmp = LT_NONE;
     bool result =
-        fp->FReadEndian(&batch_, sizeof(batch_), 1) == 1 &&
-        fp->FReadEndian(&height_, sizeof(height_), 1) == 1 &&
-        fp->FReadEndian(&width_, sizeof(width_), 1) == 1 &&
-        fp->FReadEndian(&depth_, sizeof(depth_), 1) == 1 &&
-        fp->FReadEndian(&tmp, sizeof(tmp), 1) == 1;
+        fp->DeSerialize(&batch_) &&
+        fp->DeSerialize(&height_) &&
+        fp->DeSerialize(&width_) &&
+        fp->DeSerialize(&depth_) &&
+        fp->DeSerialize(&tmp);
     loss_type_ = static_cast<LossType>(tmp);
     return result;
   }

   bool Serialize(TFile *fp) const {
     int32_t tmp = loss_type_;
     return
-        fp->FWrite(&batch_, sizeof(batch_), 1) == 1 &&
-        fp->FWrite(&height_, sizeof(height_), 1) == 1 &&
-        fp->FWrite(&width_, sizeof(width_), 1) == 1 &&
-        fp->FWrite(&depth_, sizeof(depth_), 1) == 1 &&
-        fp->FWrite(&tmp, sizeof(tmp), 1) == 1;
+        fp->Serialize(&batch_) &&
+        fp->Serialize(&height_) &&
+        fp->Serialize(&width_) &&
+        fp->Serialize(&depth_) &&
+        fp->Serialize(&tmp);
   }

  private:
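
StaticShape keeps serializing loss_type_ through an int32_t temporary and casts back on read, so the on-disk field stays a fixed four bytes regardless of the enum's underlying type. The following minimal sketch shows that round-trip; the enumerator values here are invented for illustration, not the real LossType definition.

#include <cstdint>
#include <cstdio>
#include <cstring>

enum LossType { LT_NONE = -1, LT_CTC = 0, LT_SOFTMAX = 1 };

int main() {
  uint8_t file[4];  // pretend this is the serialized byte stream

  // Write: widen the enum to a fixed-width integer before storing it.
  int32_t tmp = LT_SOFTMAX;
  std::memcpy(file, &tmp, sizeof(tmp));

  // Read: pull the int32_t back out, then cast to the enum type.
  int32_t read_back = LT_NONE;
  std::memcpy(&read_back, file, sizeof(read_back));
  LossType loss_type = static_cast<LossType>(read_back);

  std::printf("loss_type=%d\n", loss_type);
  return 0;
}
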
6 changes: 3 additions & 3 deletions src/lstm/weightmatrix.cpp
@@ -148,7 +148,7 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
   // format, without errs, so we can detect and read old format weight matrices.
   uint8_t mode =
       (int_mode_ ? kInt8Flag : 0) | (use_adam_ ? kAdamFlag : 0) | kDoubleFlag;
-  if (fp->FWrite(&mode, sizeof(mode), 1) != 1) return false;
+  if (!fp->Serialize(&mode)) return false;
   if (int_mode_) {
     if (!wi_.Serialize(fp)) return false;
     if (!scales_.Serialize(fp)) return false;
@@ -163,8 +163,8 @@ bool WeightMatrix::Serialize(bool training, TFile* fp) const {
 // Reads from the given file. Returns false in case of error.

 bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
-  uint8_t mode = 0;
-  if (fp->FRead(&mode, sizeof(mode), 1) != 1) return false;
+  uint8_t mode;
+  if (!fp->DeSerialize(&mode)) return false;
   int_mode_ = (mode & kInt8Flag) != 0;
   use_adam_ = (mode & kAdamFlag) != 0;
   if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, fp);
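
The mode byte packs three booleans (integer weights, Adam state, new double format) into single-bit flags, which is why the reader can branch on (mode & kDoubleFlag) to fall back to DeSerializeOld. Below is a minimal sketch of that packing and unpacking; the flag values are assumptions for illustration and may differ from the constants defined in weightmatrix.cpp.

#include <cstdint>
#include <cstdio>

// Illustrative flag values; the real kInt8Flag/kAdamFlag/kDoubleFlag
// constants live in weightmatrix.cpp and may differ.
constexpr uint8_t kInt8Flag = 1 << 0;
constexpr uint8_t kAdamFlag = 1 << 1;
constexpr uint8_t kDoubleFlag = 1 << 2;

int main() {
  bool int_mode = true, use_adam = false;

  // Pack: each boolean contributes one bit; kDoubleFlag marks the new format.
  uint8_t mode = (int_mode ? kInt8Flag : 0) |
                 (use_adam ? kAdamFlag : 0) |
                 kDoubleFlag;

  // Unpack: test each bit independently.
  bool int_mode_in = (mode & kInt8Flag) != 0;
  bool use_adam_in = (mode & kAdamFlag) != 0;
  bool old_format = (mode & kDoubleFlag) == 0;  // would trigger DeSerializeOld

  std::printf("int_mode=%d use_adam=%d old_format=%d\n",
              int_mode_in, use_adam_in, old_format);
  return 0;
}
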
