Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add memory logging support to fl_lm_train and fl_img_imagenet_resnet34 #444

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 2 additions & 13 deletions flashlight/app/asr/Train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,9 @@ int main(int argc, char** argv) {
FL_LOG_MASTER(INFO) << "Experiment runidx: " << runIdx;

// log memory manager operations.
std::ofstream memLog;
if (FLAGS_fl_log_mem_ops_interval > 0 && isMaster) {
auto* curMemMgr =
fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
if (curMemMgr) {
memLog.open(getRunFile("mem", runIdx, runPath));
if (!memLog) {
LOG(FATAL) << "failed to open memory log file="
<< getRunFile("mem", runIdx, runPath) << " for writing";
}
curMemMgr->setLogStream(&memLog);
curMemMgr->setLoggingEnabled(true);
curMemMgr->setLogFlushInterval(FLAGS_fl_log_mem_ops_interval);
}
fl::MemoryManagerInstaller::logIfInstalled(
getRunFile("mem", runIdx, runPath), FLAGS_fl_log_mem_ops_interval);
}

// flashlight optim mode
Expand Down
13 changes: 13 additions & 0 deletions flashlight/app/imgclass/examples/ImageNetResnet34.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ DEFINE_string(
DEFINE_uint64(data_batch_size, 256, "Total batch size across all gpus");
DEFINE_string(exp_checkpoint_path, "/tmp/model", "Checkpointing prefix path");
DEFINE_int64(exp_checkpoint_epoch, -1, "Checkpoint epoch to load from");
DEFINE_int64(
exp_log_mem_ops_interval,
0,
"Flushes memory manager logs after a specified number of log entries. "
"1000000 is a reasonable value which will reduces overhead. "
"Logs when > 0");

using namespace fl;
using fl::ext::image::compose;
Expand Down Expand Up @@ -115,6 +121,13 @@ int main(int argc, char** argv) {
af::setDevice(worldRank);
af::setSeed(worldSize);

// log memory manager operations.
if (FLAGS_exp_log_mem_ops_interval > 0 && isMaster) {
fl::MemoryManagerInstaller::logIfInstalled(
lib::pathsConcat(FLAGS_exp_checkpoint_path, "mem.log"),
FLAGS_exp_log_mem_ops_interval);
}

auto reducer =
std::make_shared<fl::CoalescingReducer>(1.0 / worldSize, true, true);

Expand Down
17 changes: 17 additions & 0 deletions flashlight/app/lm/Trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include "flashlight/app/lm/Trainer.h"

#include "flashlight/fl/memory/MemoryManagerInstaller.h"

using namespace fl::ext;
using namespace fl::lib;

Expand Down Expand Up @@ -64,6 +66,12 @@ DEFINE_string(
exp_init_model_path,
"",
"Initialization model full path, used as init model to start training.");
DEFINE_int64(
exp_log_mem_ops_interval,
0,
"Flushes memory manager logs after a specified "
"number of log entries. 1000000 is a reasonable "
"value which will reduces overhead. Logs when > 0");

/* DATA OPTIONS */
DEFINE_string(
Expand Down Expand Up @@ -231,6 +239,15 @@ void Trainer::runTraining() {
logWriter_ = createOutputStream(
pathsConcat(FLAGS_exp_rundir, FLAGS_exp_model_name + ".log"),
std::ios_base::app);

// log memory manager operations.
if (FLAGS_exp_log_mem_ops_interval > 0) {
logWriter_ = createOutputStream(
pathsConcat(FLAGS_exp_rundir, FLAGS_exp_model_name + "_mem.log"),
std::ios_base::trunc);
fl::MemoryManagerInstaller::logIfInstalled(
&logWriter_, FLAGS_exp_log_mem_ops_interval);
}
}

FL_LOG_MASTER(INFO) << "training started (epoch=" << epoch_
Expand Down
2 changes: 2 additions & 0 deletions flashlight/app/lm/Trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ DECLARE_string(distributed_rndv_filepath);
DECLARE_string(exp_rundir);
DECLARE_string(exp_model_name);
DECLARE_string(exp_init_model_path);
DECLARE_int64(exp_log_mem_ops_interval);

/* DATA OPTIONS */
DECLARE_string(data_dir);
Expand Down Expand Up @@ -132,6 +133,7 @@ class Trainer {
fl::AverageValueMeter tokenCountMeter_;

std::ofstream logWriter_;
std::ofstream memLogWriter_;

/* Initializers */
void initTrain();
Expand Down
26 changes: 26 additions & 0 deletions flashlight/fl/memory/MemoryManagerInstaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@
#include <cstdlib>
#include <stdexcept>

#include "flashlight/fl/common/Logging.h"
#include "flashlight/fl/common/Utils.h"
#include "flashlight/fl/memory/managers/CachingMemoryManager.h"
#include "flashlight/lib/common/System.h"

using ::fl::lib::createOutputStream;

namespace fl {

// Statics from MemoryManagerInstaller
std::shared_ptr<MemoryManagerAdapter>
MemoryManagerInstaller::currentlyInstalledMemoryManager_;
std::unique_ptr<std::ofstream> MemoryManagerInstaller::log_;

MemoryManagerAdapter* MemoryManagerInstaller::getImpl(
af_memory_manager manager) {
Expand Down Expand Up @@ -234,4 +239,25 @@ void MemoryManagerInstaller::unsetMemoryManager() {
}
}

void MemoryManagerInstaller::logIfInstalled(
const std::string& logFilename,
size_t interval) {
if (currentlyInstalledMemoryManager_) {
log_ = std::make_unique<std::ofstream>(
createOutputStream(logFilename, std::ios_base::trunc));
FL_LOG(fl::INFO) << "Saving memory log to file=" << logFilename;
logIfInstalled(log_.get(), interval);
}
}

void MemoryManagerInstaller::logIfInstalled(
std::ofstream* log,
size_t interval) {
if (currentlyInstalledMemoryManager_) {
currentlyInstalledMemoryManager_->setLogStream(log);
currentlyInstalledMemoryManager_->setLoggingEnabled(true);
currentlyInstalledMemoryManager_->setLogFlushInterval(interval);
}
}

} // namespace fl
17 changes: 17 additions & 0 deletions flashlight/fl/memory/MemoryManagerInstaller.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <af/memory.h>

#include <fstream>
#include <memory>
#include <mutex>

Expand Down Expand Up @@ -95,9 +96,25 @@ class MemoryManagerInstaller {
*/
static void unsetMemoryManager();

/**
* the currentlyInstalledMemoryManager is set to flush the log every
* 'interval' memory manager api calls. Each operation is written as a text
* line into memLogWriter.
*/
static void logIfInstalled(const std::string& logFilename, size_t interval);

/**
* the currentlyInstalledMemoryManager is set to flush the log every
* 'interval' memory manager api calls. Each operation is written as a text
* line into memLogWriter
*/
static void logIfInstalled(std::ofstream* log, size_t interval);

private:
// The given memory manager implementation
std::shared_ptr<MemoryManagerAdapter> impl_;
// Used to keep file opened by logIfInstalled(string, size_t)
static std::unique_ptr<std::ofstream> log_;
// Points to the impl_ of the most recently installed manager.
static std::shared_ptr<MemoryManagerAdapter> currentlyInstalledMemoryManager_;
};
Expand Down