Add clone functionality for module and containers (#1111)
Summary:
**Original Issue**: #1110

closes #1110

Adds cloning functionality to modules and containers via a new pure virtual `clone()` method on `Module`.

- `clone()` performs a deep copy of parameters and modules, taking advantage of the underlying tensor's copy-on-write semantics, where the backend implements them, to minimise memory usage.
- Every module and container must now implement `clone()`. This has been done for the core library and some simpler modules in pkg; the remaining modules throw a runtime error indicating `clone()` is unimplemented.
- Core modules have been updated where necessary with appropriate copy/move constructors and assignment operators that perform a deep copy of `Variable` parameters.
- Core containers have been updated to use a new macro `FL_BASIC_CONTAINER_CLONING` where possible, which implements `clone()` along with the appropriate copy/move constructors and assignment operators.
- Users must be aware of and manage the lifetimes and cloning behaviour of their modules/containers. If a module has custom or shared lifetime requirements, it should not use the `FL_BASIC_CONTAINER_CLONING` macro; it should instead implement its own `clone()` override, plus the copy/move constructors and assignment operators needed for the desired behaviour (see the sketch below).
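
For illustration, a minimal sketch (not part of this diff; `MyBlock` is hypothetical) of the macro route for a user-defined container with plain value-like ownership:

```cpp
#include <memory>

#include "flashlight/fl/flashlight.h"

// Hypothetical container with no shared-lifetime requirements: the macro
// generates clone() plus the copy/move constructors and assignment
// operators, so cloning deep-copies every parameter.
class MyBlock : public fl::Container {
 public:
  MyBlock() {
    add(std::make_shared<fl::Linear>(128, 128));
    add(std::make_shared<fl::ReLU>());
  }
  FL_BASIC_CONTAINER_CLONING(MyBlock)
};

int main() {
  fl::init();
  auto block = std::make_shared<MyBlock>();
  std::unique_ptr<fl::Module> copy = block->clone();
  // The clone owns independent parameters; with a copy-on-write tensor
  // backend, memory is only materialised when one side is mutated.
  return 0;
}
```

A container with shared or otherwise custom lifetime requirements would instead hand-write `clone()` and the copy operations, as the `Conformer` changes below do.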

### Test Plan (required)
Added tests and ran them successfully locally; CI provides further coverage.

Pull Request resolved: #1111

Reviewed By: bwasti

Differential Revision: D46599854

Pulled By: jacobkahn

fbshipit-source-id: 8dac9b49847e7519f6ad600d66b3eff5886fa6ae
benborder authored and facebook-github-bot committed Oct 22, 2023
1 parent 4fdc336 commit f354e7f
Showing 107 changed files with 1,584 additions and 151 deletions.
35 changes: 35 additions & 0 deletions flashlight/app/benchmark/models/AsrTransformer.cpp
@@ -34,6 +34,41 @@ AsrTransformer::AsrTransformer(int64_t nFeature, int64_t nLabel) {
add(linear_);
}

AsrTransformer::AsrTransformer(const AsrTransformer& other) {
convFrontend_ = std::make_shared<fl::Sequential>(*other.convFrontend_);
// nFeature x Time x Batch x 1
add(convFrontend_);
sinpos_ = std::make_shared<fl::SinusoidalPositionEmbedding>(*other.sinpos_);
add(sinpos_);
for (const auto& transformer : other.transformers_) {
auto layer = std::make_shared<fl::Transformer>(*transformer);
transformers_.push_back(layer);
add(std::move(layer));
}
linear_ = std::make_shared<fl::Linear>(*other.linear_);
add(linear_);
}

AsrTransformer& AsrTransformer::operator=(const AsrTransformer& other) {
// Reset existing modules before deep-copying from `other`.
clear();
transformers_.clear();
convFrontend_ = std::make_shared<fl::Sequential>(*other.convFrontend_);
// nFeature x Time x Batch x 1
add(convFrontend_);
sinpos_ = std::make_shared<fl::SinusoidalPositionEmbedding>(*other.sinpos_);
add(sinpos_);
for (const auto& transformer : other.transformers_) {
auto layer = std::make_shared<fl::Transformer>(*transformer);
transformers_.push_back(layer);
add(std::move(layer));
}
linear_ = std::make_shared<fl::Linear>(*other.linear_);
add(linear_);
return *this;
}

std::unique_ptr<Module> AsrTransformer::clone() const {
return std::make_unique<AsrTransformer>(*this);
}

std::vector<fl::Variable> AsrTransformer::forward(
const std::vector<fl::Variable>& input) {
auto out = input[0];
3 changes: 3 additions & 0 deletions flashlight/app/benchmark/models/AsrTransformer.h
@@ -21,6 +21,9 @@ namespace benchmark {
class AsrTransformer : public fl::Container {
public:
AsrTransformer(int64_t nFeature, int64_t nLabel);
AsrTransformer(const AsrTransformer& other);
AsrTransformer& operator=(const AsrTransformer& other);
std::unique_ptr<Module> clone() const override;

std::vector<fl::Variable> forward(
const std::vector<fl::Variable>& input) override;
27 changes: 27 additions & 0 deletions flashlight/app/benchmark/models/LmTransformer.cpp
@@ -27,6 +27,33 @@ LmTransformer::LmTransformer(int64_t nLabel, bool fp16) : fp16_(fp16) {
}
}

LmTransformer::LmTransformer(const LmTransformer& other) {
fp16_ = other.fp16_;
frontend_ = std::make_shared<fl::Sequential>(*other.frontend_);
// nFeature x Time x Batch x 1
add(frontend_);
for (const auto& transformer : other.transformers_) {
auto layer = std::make_shared<fl::Transformer>(*transformer);
transformers_.push_back(layer);
add(std::move(layer));
}
}

LmTransformer& LmTransformer::operator=(const LmTransformer& other) {
// Reset existing modules before deep-copying from `other`.
clear();
transformers_.clear();
fp16_ = other.fp16_;
frontend_ = std::make_shared<fl::Sequential>(*other.frontend_);
// nFeature x Time x Batch x 1
add(frontend_);
for (const auto& transformer : other.transformers_) {
auto layer = std::make_shared<fl::Transformer>(*transformer);
transformers_.push_back(layer);
add(std::move(layer));
}
return *this;
}

std::unique_ptr<Module> LmTransformer::clone() const {
return std::make_unique<LmTransformer>(*this);
}

std::vector<fl::Variable> LmTransformer::forward(
const std::vector<fl::Variable>& input) {
auto out = input[0];
3 changes: 3 additions & 0 deletions flashlight/app/benchmark/models/LmTransformer.h
@@ -22,6 +22,9 @@ namespace benchmark {
class LmTransformer : public fl::Container {
public:
explicit LmTransformer(int64_t nLabel, bool fp16 = false);
LmTransformer(const LmTransformer& other);
LmTransformer& operator=(const LmTransformer& other);
std::unique_ptr<Module> clone() const override;

std::vector<fl::Variable> forward(
const std::vector<fl::Variable>& input) override;
4 changes: 4 additions & 0 deletions flashlight/fl/autograd/Variable.cpp
@@ -84,6 +84,10 @@ Tensor& Variable::tensor() const {
return sharedData_->data;
}

Variable Variable::copy() const {
return Variable(sharedData_->data, sharedGrad_->calcGrad);
}

Variable Variable::astype(fl::dtype newType) const {
auto output = tensor().astype(newType);
auto gradFunc = [](std::vector<Variable>& inputs,
6 changes: 6 additions & 0 deletions flashlight/fl/autograd/Variable.h
@@ -114,6 +114,12 @@ class FL_API Variable {
*/
Tensor& tensor() const;

/**
* Creates a copy of this variable that is detached from the computation
* graph.
* @return the copied and detached variable.
*/
Variable copy() const;

/**
* Creates a new variable based on the current variable whose type will be
* adjusted based on the input type.
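
As context for the new `Variable::copy()` (a sketch, not part of the diff; shapes are illustrative):

```cpp
#include "flashlight/fl/flashlight.h"

// Sketch of copy() semantics: the copy keeps the data and the calcGrad
// flag but carries no autograd history, so it is a fresh graph root.
void copySemantics() {
  fl::Variable x(fl::rand({3, 3}), /*calcGrad=*/true);
  auto y = x * x; // y is connected to x in the computation graph
  fl::Variable d = x.copy(); // same values; no edge to x or y
  // y.backward() would populate x.grad() but never d.grad().
}
```
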
4 changes: 4 additions & 0 deletions flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp
@@ -93,6 +93,10 @@ Variable AdaptiveEmbedding::forward(const Variable& input) {
return moddims(result(fl::span, tmpIndices), outShape);
}

std::unique_ptr<Module> AdaptiveEmbedding::clone() const {
return std::make_unique<AdaptiveEmbedding>(*this);
}

std::string AdaptiveEmbedding::prettyString() const {
std::ostringstream ss;
ss << "AdaptiveEmbedding (dim: " << embeddingDim_ << "), (cutoff: ";
2 changes: 2 additions & 0 deletions flashlight/fl/contrib/modules/AdaptiveEmbedding.h
@@ -52,6 +52,8 @@ class FL_API AdaptiveEmbedding : public UnaryModule {

Variable forward(const Variable& input) override;

std::unique_ptr<Module> clone() const override;

std::string prettyString() const override;
};

4 changes: 4 additions & 0 deletions flashlight/fl/contrib/modules/AsymmetricConv1D.cpp
@@ -103,6 +103,10 @@ Variable AsymmetricConv1D::forward(const Variable& input) {
return output;
}

std::unique_ptr<Module> AsymmetricConv1D::clone() const {
return std::make_unique<AsymmetricConv1D>(*this);
}

std::string AsymmetricConv1D::prettyString() const {
std::ostringstream ss;
ss << "AsymmetricConv1D";
2 changes: 2 additions & 0 deletions flashlight/fl/contrib/modules/AsymmetricConv1D.h
@@ -52,6 +52,8 @@ class FL_API AsymmetricConv1D : public fl::Conv2D {

fl::Variable forward(const fl::Variable& input) override;

std::unique_ptr<Module> clone() const override;

std::string prettyString() const override;

private:
54 changes: 52 additions & 2 deletions flashlight/fl/contrib/modules/Conformer.cpp
@@ -92,6 +92,52 @@ Conformer::Conformer(
if (posEmbContextSize_ > 0) {
params_.push_back(uniform(2 * posEmbContextSize_ - 1, headDim, -0.1, 0.1));
}
createLayers();
}

Conformer::Conformer(const Conformer& other) {
copy(other);
createLayers();
}

Conformer& Conformer::operator=(const Conformer& other) {
clear();
copy(other);
createLayers();
return *this;
}

void Conformer::copy(const Conformer& other) {
train_ = other.train_;
nHeads_ = other.nHeads_;
posEmbContextSize_ = other.posEmbContextSize_;
convKernelSize_ = other.convKernelSize_;
pDropout_ = other.pDropout_;
pLayerDropout_ = other.pLayerDropout_;
w11_ = std::make_shared<Linear>(*other.w11_);
w12_ = std::make_shared<Linear>(*other.w12_);
w21_ = std::make_shared<Linear>(*other.w21_);
w22_ = std::make_shared<Linear>(*other.w22_);
wq_ = std::make_shared<Linear>(*other.wq_);
wk_ = std::make_shared<Linear>(*other.wk_);
wv_ = std::make_shared<Linear>(*other.wv_);
wf_ = std::make_shared<Linear>(*other.wf_);
conv1_ = std::make_shared<Linear>(*other.conv1_);
conv2_ = std::make_shared<Linear>(*other.conv2_);
norm1_ = std::make_shared<LayerNorm>(*other.norm1_);
norm2_ = std::make_shared<LayerNorm>(*other.norm2_);
normMhsa_ = std::make_shared<LayerNorm>(*other.normMhsa_);
normConv1_ = std::make_shared<LayerNorm>(*other.normConv1_);
normConv2_ = std::make_shared<LayerNorm>(*other.normConv2_);
norm3_ = std::make_shared<LayerNorm>(*other.norm3_);
convDepthWise_ = std::make_shared<Conv2D>(*other.convDepthWise_);
if (posEmbContextSize_ > 0) {
const auto& p = other.param(0);
params_.emplace_back(p.copy());
}
}

void Conformer::createLayers() {
// first feed-forward module
add(w11_);
add(w12_);
@@ -165,8 +211,8 @@ Variable Conformer::conv(const Variable& _input) {
float pDropout = train_ ? pDropout_ : 0.0;
// input C x T x B x 1
// apply first pointwise conv
-auto result =
-gatedlinearunit((*conv1_)(((*normConv1_)(input)).astype(input.type())), 0);
+auto result = gatedlinearunit(
+(*conv1_)(((*normConv1_)(input)).astype(input.type())), 0);
result = reorder(result, {1, 3, 0, 2});
// T x 1 x C x B
// apply depthwise separable convolutions
@@ -219,6 +265,10 @@ std::vector<Variable> Conformer::forward(const std::vector<Variable>& input) {
return {x};
}

std::unique_ptr<Module> Conformer::clone() const {
return std::make_unique<Conformer>(*this);
}

std::string Conformer::prettyString() const {
std::ostringstream ss;
ss << "Conformer "
8 changes: 8 additions & 0 deletions flashlight/fl/contrib/modules/Conformer.h
@@ -44,8 +44,14 @@ class FL_API Conformer : public Container {
int32_t convKernelSize,
float pDropout,
float pLayerDropout = 0.);
Conformer(const Conformer& other);
Conformer(Conformer&& other) = default;

Conformer& operator=(const Conformer& other);
Conformer& operator=(Conformer&& other) = default;

std::vector<Variable> forward(const std::vector<Variable>& input) override;
std::unique_ptr<Module> clone() const override;
std::string prettyString() const override;

private:
@@ -61,6 +67,8 @@ class FL_API Conformer : public Container {
norm3_;
std::shared_ptr<Conv2D> convDepthWise_;

void copy(const Conformer& other);
void createLayers();
static Variable conformerInitLinear(int32_t inDim, int32_t outDim);
Variable mhsa(const Variable& input, const Variable& inputPadMask);
Variable conv(const Variable& input);
17 changes: 17 additions & 0 deletions flashlight/fl/contrib/modules/PositionEmbedding.cpp
@@ -25,6 +25,19 @@ PositionEmbedding::PositionEmbedding(
params_ = {embeddings};
}

PositionEmbedding::PositionEmbedding(const PositionEmbedding& other)
: Module(other.copyParams()), dropout_(other.dropout_) {
train_ = other.train_;
}

PositionEmbedding& PositionEmbedding::operator=(
const PositionEmbedding& other) {
params_ = other.copyParams();
train_ = other.train_;
dropout_ = other.dropout_;
return *this;
}

std::vector<Variable> PositionEmbedding::forward(
const std::vector<Variable>& input) {
if (input[0].ndim() != 3) {
@@ -48,6 +61,10 @@ std::vector<Variable> PositionEmbedding::operator()(
return forward(input);
}

std::unique_ptr<Module> PositionEmbedding::clone() const {
return std::make_unique<PositionEmbedding>(*this);
}

std::string PositionEmbedding::prettyString() const {
return "Position Embedding Layer";
}
14 changes: 12 additions & 2 deletions flashlight/fl/contrib/modules/PositionEmbedding.h
@@ -23,10 +23,18 @@ namespace fl {
* B is the batch size.
*
*/
-class FL_API PositionEmbedding : public Container {
+class FL_API PositionEmbedding : public Module {
public:
PositionEmbedding(int32_t layerDim, int32_t maxLen, double dropout = 0);

PositionEmbedding(const PositionEmbedding& other);

PositionEmbedding& operator=(const PositionEmbedding& other);

PositionEmbedding(PositionEmbedding&& other) = default;

PositionEmbedding& operator=(PositionEmbedding&& other) = default;

/**
* PositionEmbedding::forward(input) expects input[0] to be of
* dimensions C x T x B with C = layerDim and T <= maxLen.
@@ -40,10 +48,12 @@ class FL_API PositionEmbedding : public Container {

std::vector<Variable> operator()(const std::vector<Variable>& input);

std::unique_ptr<Module> clone() const override;

std::string prettyString() const override;

private:
-FL_SAVE_LOAD_WITH_BASE(Container, dropout_)
+FL_SAVE_LOAD_WITH_BASE(Module, dropout_)

double dropout_;

6 changes: 5 additions & 1 deletion flashlight/fl/contrib/modules/RawWavSpecAugment.cpp
@@ -6,9 +6,9 @@
*/

#include <cmath>
+#include <numeric>
#include <sstream>
#include <stdexcept>
-#include <numeric>

#include "flashlight/fl/common/Logging.h"
#include "flashlight/fl/contrib/modules/RawWavSpecAugment.h"
@@ -185,6 +185,10 @@ int RawWavSpecAugment::generateRandomInt(int low, int high) {
return uniformDist(eng_);
}

std::unique_ptr<Module> RawWavSpecAugment::clone() const {
return std::make_unique<RawWavSpecAugment>(*this);
}

std::string RawWavSpecAugment::prettyString() const {
std::ostringstream ss;
ss << "RawWavSpecAugment ( ";
1 change: 1 addition & 0 deletions flashlight/fl/contrib/modules/RawWavSpecAugment.h
@@ -53,6 +53,7 @@ class FL_API RawWavSpecAugment : public UnaryModule {
MaskingStrategy mStrategy = MaskingStrategy::ZERO);

Variable forward(const Variable& input) override;
std::unique_ptr<Module> clone() const override;
std::string prettyString() const override;

private:
2 changes: 2 additions & 0 deletions flashlight/fl/contrib/modules/Residual.h
@@ -135,6 +135,8 @@ class FL_API Residual : public Container {
Variable forward(const Variable& input);

std::string prettyString() const override;

FL_BASIC_CONTAINER_CLONING(Residual)
};

} // namespace fl