Adapt multihead_attention
akropp authored and akropp committed Oct 18, 2023
1 parent 202e7f5 commit 20f646e
Showing 2 changed files with 86 additions and 89 deletions.
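For context, here is a minimal usage sketch (not part of this commit; it assumes the single `<mlpack.hpp>` convenience header pulls the layer in, which may not yet be the case while the file still lives under `not_adapted/`, and the sequence lengths, embedding size and head count are arbitrary). It shows the effect of the template change: the separate InputType/OutputType parameters collapse into a single MatType.

#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  // Before this commit the layer took separate input and output matrix types:
  //   MultiheadAttentionType<arma::mat, arma::mat, NoRegularizer> attn(...);
  // After this commit a single MatType serves both roles:
  MultiheadAttentionType<arma::mat, NoRegularizer> attn(
      /* tgtSeqLen */ 10, /* srcSeqLen */ 10, /* embedDim */ 64,
      /* numHeads */ 8);

  // The convenience typedef keeps its old spelling:
  MultiheadAttention attn2(10, 10, 64, 8);

  return 0;
}
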
77 changes: 37 additions & 40 deletions src/mlpack/methods/ann/layer/not_adapted/multihead_attention.hpp
@@ -48,18 +48,15 @@ namespace mlpack {
* of shape `(embedDim * tgtSeqLen, batchSize)`. The embeddings are stored
 * consecutively.
*
* @tparam InputType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputType Type of the output data (arma::colvec, arma::mat,
* @tparam MatType Type of the input/output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam RegularizerType Type of the regularizer to be used.
*/
template <
typename InputType = arma::mat,
typename OutputType = arma::mat,
typename MatType = arma::mat,
typename RegularizerType = NoRegularizer
>
class MultiheadAttentionType : public Layer<InputType, OutputType>
class MultiheadAttentionType : public Layer<MatType>
{
public:
/**
@@ -82,20 +79,20 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
const size_t srcSeqLen,
const size_t embedDim,
const size_t numHeads,
const InputType& attnmask = InputType(),
const InputType& keyPaddingMask = InputType());
const MatType& attnmask = MatType(),
const MatType& keyPaddingMask = MatType());

//! Clone the MultiheadAttentionType object. This handles polymorphism
//! correctly.
MultiheadAttentionType* Clone() const
MultiheadAttentionType* Clone() const override
{
return new MultiheadAttentionType(*this);
}

/**
* Reset the layer parameters.
*/
void SetWeights(typename OutputType::elem_type* weightsPtr);
void SetWeights(typename MatType::elem_type* weightsPtr);

/**
* Ordinary feed forward pass of a neural network, evaluating the function
@@ -104,7 +101,7 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
* @param input The query matrix.
* @param output Resulting output activation.
*/
void Forward(const InputType& input, OutputType& output);
void Forward(const MatType& input, MatType& output) override;

/**
* Ordinary feed backward pass of a neural network, calculating the function
@@ -114,9 +111,9 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
* @param gy The backpropagated error.
* @param g The calculated gradient.
*/
void Backward(const InputType& /* input */,
const OutputType& gy,
OutputType& g);
void Backward(const MatType& /* input */,
const MatType& gy,
MatType& g) override;

/**
* Calculate the gradient using the output delta and the input activation.
@@ -125,12 +122,12 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
* @param error The calculated error.
* @param gradient The calculated gradient.
*/
void Gradient(const InputType& input,
const OutputType& error,
OutputType& gradient);
void Gradient(const MatType& input,
const MatType& error,
MatType& gradient) override;

//! Get the size of the weights.
size_t WeightSize() const { return 4 * (embedDim + 1) * embedDim; }
size_t WeightSize() const override { return 4 * (embedDim + 1) * embedDim; }
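  // Note: 4 * (embedDim + 1) * embedDim == (4 * embedDim + 4) * embedDim, the
  // form used by SetWeights(): four embedDim x embedDim projection matrices
  // (query, key, value, out) plus four bias vectors of length embedDim.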

/**
* Serialize the layer.
@@ -159,22 +156,20 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
size_t& NumHeads() { return numHeads; }

//! Get the two dimensional Attention Mask.
OutputType const& AttentionMask() const { return attnMask; }
MatType const& AttentionMask() const { return attnMask; }
//! Modify the two dimensional Attention Mask.
OutputType& AttentionMask() { return attnMask; }
MatType& AttentionMask() { return attnMask; }

//! Get Key Padding Mask.
OutputType const& KeyPaddingMask() const { return keyPaddingMask; }
MatType const& KeyPaddingMask() const { return keyPaddingMask; }
//! Modify the Key Padding Mask.
OutputType& KeyPaddingMask() { return keyPaddingMask; }

const size_t WeightSize() const { return (4 * embedDim + 4) * embedDim; }
MatType& KeyPaddingMask() { return keyPaddingMask; }

const std::vector<size_t> OutputDimensions() const
{
    // This returns the output as a 2-dimensional (embedDim x tgtSeqLen)
// matrix.
std::vector<size_t> outputDimensions(inputDimensions.size(), 1);
std::vector<size_t> outputDimensions(this->inputDimensions.size(), 1);
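    // (this-> is needed here because inputDimensions is inherited from the
    // dependent base class Layer<MatType>.)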
outputDimensions[0] = embedDim;
outputDimensions[1] = tgtSeqLen;

@@ -188,7 +183,7 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>

private:
//! Element Type of the output.
typedef typename OutputType::elem_type ElemType;
typedef typename MatType::elem_type ElemType;

//! Target sequence length.
size_t tgtSeqLen;
@@ -206,37 +201,37 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
size_t headDim;

//! Two dimensional Attention Mask of shape (tgtSeqLen, srcSeqLen).
OutputType attnMask;
MatType attnMask;

//! Key Padding Mask.
OutputType keyPaddingMask;
MatType keyPaddingMask;

//! Locally-stored weight matrix associated with query.
OutputType queryWt;
MatType queryWt;

//! Locally-stored weight matrix associated with key.
OutputType keyWt;
MatType keyWt;

//! Locally-stored weight matrix associated with value.
OutputType valueWt;
MatType valueWt;

//! Locally-stored weight matrix associated with attnWt.
OutputType outWt;
MatType outWt;

//! Locally-stored bias associated with query.
OutputType qBias;
MatType qBias;

//! Locally-stored bias associated with key.
OutputType kBias;
MatType kBias;

  //! Locally-stored bias associated with value.
OutputType vBias;
MatType vBias;

//! Locally-stored bias associated with attnWt.
OutputType outBias;
MatType outBias;

//! Locally-stored weights parameter.
OutputType weights;
MatType weights;

//! Locally-stored projected query matrix over linear layer.
arma::Cube<ElemType> qProj;
@@ -254,15 +249,17 @@ class MultiheadAttentionType : public Layer<InputType, OutputType>
arma::Cube<ElemType> attnOut;

  //! Softmax layer used to compute the attention probabilities.
Softmax softmax;
SoftmaxType<MatType> softmax;

  //! Temporary storage for the softmax output.
MatType softmaxOutput;

//! Locally-stored regularizer object.
RegularizerType regularizer;
}; // class MultiheadAttention

// Standard MultiheadAttention layer using no regularization.
typedef MultiheadAttentionType<arma::mat, arma::mat, NoRegularizer>
MultiheadAttention;
typedef MultiheadAttentionType<arma::mat, NoRegularizer> MultiheadAttention;

} // namespace mlpack

@@ -20,29 +20,29 @@

namespace mlpack {

template <typename InputType, typename OutputType, typename RegularizerType>
MultiheadAttentionType<InputType, OutputType, RegularizerType>::
template <typename MatType, typename RegularizerType>
MultiheadAttentionType<MatType, RegularizerType>::
MultiheadAttentionType() :
tgtSeqLen(0),
srcSeqLen(0),
embedDim(0),
numHeads(0),
headDim(0),
attnMask(InputType()),
keyPaddingMask(InputType())
attnMask(MatType()),
keyPaddingMask(MatType())
{
// Nothing to do here.
}

template <typename InputType, typename OutputType, typename RegularizerType>
MultiheadAttentionType<InputType, OutputType, RegularizerType>::
template <typename MatType, typename RegularizerType>
MultiheadAttentionType<MatType, RegularizerType>::
MultiheadAttentionType(
const size_t tgtSeqLen,
const size_t srcSeqLen,
const size_t embedDim,
const size_t numHeads,
const InputType& attnMask,
const InputType& keyPaddingMask) :
const MatType& attnMask,
const MatType& keyPaddingMask) :
tgtSeqLen(tgtSeqLen),
srcSeqLen(srcSeqLen),
embedDim(embedDim),
@@ -59,36 +59,36 @@ MultiheadAttentionType(
headDim = embedDim / numHeads;
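  // Integer division: embedDim is assumed to be an exact multiple of numHeads.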
}

template <typename InputType, typename OutputType, typename RegularizerType>
void MultiheadAttentionType<InputType, OutputType, RegularizerType>::SetWeights(
typename OutputType::elem_type* weightsPtr)
template <typename MatType, typename RegularizerType>
void MultiheadAttentionType<MatType, RegularizerType>::SetWeights(
typename MatType::elem_type* weightsPtr)
{
weights = OutputType(weightsPtr, 1, (4 * embedDim + 4) * embedDim, false,
weights = MatType(weightsPtr, 1, (4 * embedDim + 4) * embedDim, false,
true);

queryWt = OutputType(weightsPtr, embedDim, embedDim, false, true);
keyWt = OutputType(weightsPtr + embedDim * embedDim, embedDim, embedDim,
queryWt = MatType(weightsPtr, embedDim, embedDim, false, true);
keyWt = MatType(weightsPtr + embedDim * embedDim, embedDim, embedDim,
false, true);
valueWt = OutputType(weightsPtr + 2 * embedDim * embedDim, embedDim, embedDim,
valueWt = MatType(weightsPtr + 2 * embedDim * embedDim, embedDim, embedDim,
false, true);
outWt = OutputType(weightsPtr + 3 * embedDim * embedDim, embedDim, embedDim,
outWt = MatType(weightsPtr + 3 * embedDim * embedDim, embedDim, embedDim,
false, true);

qBias = OutputType(weightsPtr + 4 * embedDim * embedDim, embedDim, 1, false,
qBias = MatType(weightsPtr + 4 * embedDim * embedDim, embedDim, 1, false,
true);
kBias = OutputType(weightsPtr + (4 * embedDim + 1) * embedDim, embedDim, 1,
kBias = MatType(weightsPtr + (4 * embedDim + 1) * embedDim, embedDim, 1,
false, true);
vBias = OutputType(weightsPtr + (4 * embedDim + 2) * embedDim, embedDim, 1,
vBias = MatType(weightsPtr + (4 * embedDim + 2) * embedDim, embedDim, 1,
false, true);
outBias = OutputType(weightsPtr + (4 * embedDim + 3) * embedDim, 1, embedDim,
outBias = MatType(weightsPtr + (4 * embedDim + 3) * embedDim, 1, embedDim,
false, true);
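  // Resulting layout of the flat weights vector, in element offsets:
  //   [0, embedDim^2)              -> queryWt
  //   [embedDim^2, 2 embedDim^2)   -> keyWt
  //   [2 embedDim^2, 3 embedDim^2) -> valueWt
  //   [3 embedDim^2, 4 embedDim^2) -> outWt
  //   [4 embedDim^2, 4 embedDim^2 + 4 embedDim) -> qBias, kBias, vBias, outBias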
}

template <typename InputType, typename OutputType, typename RegularizerType>
void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
Forward(const InputType& input, OutputType& output)
template <typename MatType, typename RegularizerType>
void MultiheadAttentionType<MatType, RegularizerType>::
Forward(const MatType& input, MatType& output)
{
typedef typename arma::Cube<typename InputType::elem_type> CubeType;
typedef typename arma::Cube<typename MatType::elem_type> CubeType;

if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
{
@@ -104,12 +101,12 @@ Forward(const InputType& input, OutputType& output)
// The shape of q : (embedDim, tgtSeqLen, batchSize).
// The shape of k : (embedDim, srcSeqLen, batchSize).
// The shape of v : (embedDim, srcSeqLen, batchSize).
const CubeType q(const_cast<InputType&>(input).memptr(),
const CubeType q(const_cast<MatType&>(input).memptr(),
embedDim, tgtSeqLen, batchSize, false, false);
const CubeType k(const_cast<InputType&>(input).memptr() +
const CubeType k(const_cast<MatType&>(input).memptr() +
embedDim * tgtSeqLen * batchSize,
embedDim, srcSeqLen, batchSize, false, false);
const CubeType v(const_cast<InputType&>(input).memptr() +
const CubeType v(const_cast<MatType&>(input).memptr() +
embedDim * (tgtSeqLen + srcSeqLen) * batchSize,
embedDim, srcSeqLen, batchSize, false, false);
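  // The single input matrix therefore packs the query, key and value
  // embeddings consecutively, giving embedDim * (tgtSeqLen + 2 * srcSeqLen)
  // rows per batch column, which matches the size check above.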

@@ -167,8 +167,8 @@

for (size_t i = 0; i < numHeads * batchSize; ++i)
{
softmax.Forward(scores.slice(i), softmax.OutputParameter());
scores.slice(i) = softmax.OutputParameter();
softmax.Forward(scores.slice(i), softmaxOutput);
scores.slice(i) = softmaxOutput;
}
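  // softmaxOutput is the temporary member newly declared in the header; the
  // previous code stored the result in softmax.OutputParameter() instead.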

// Calculate the attention output i.e. matrix multiplication of softmax
@@ -188,13 +188,13 @@
}
}

template <typename InputType, typename OutputType, typename RegularizerType>
void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
Backward(const InputType& /* input */,
const OutputType& gy,
OutputType& g)
template <typename MatType, typename RegularizerType>
void MultiheadAttentionType<MatType, RegularizerType>::
Backward(const MatType& /* input */,
const MatType& gy,
MatType& g)
{
typedef typename arma::Cube<typename OutputType::elem_type> CubeType;
typedef typename arma::Cube<typename MatType::elem_type> CubeType;

if (gy.n_rows != tgtSeqLen * embedDim)
{
@@ -208,7 +208,7 @@ Backward(const InputType& /* input */,
// The shape of gyTemp : (tgtSeqLen, embedDim, batchSize).
  // We need not split it into n heads here, because this is where the outputs
  // of the n heads were concatenated.
CubeType gyTemp(const_cast<OutputType&>(gy).memptr(), embedDim,
CubeType gyTemp(const_cast<MatType&>(gy).memptr(), embedDim,
tgtSeqLen, batchSize, true, false);

// The shape of gyTemp : (embedDim, tgtSeqLen, batchSize).
@@ -278,13 +278,13 @@ Backward(const InputType& /* input */,
}
}

template <typename InputType, typename OutputType, typename RegularizerType>
void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
Gradient(const InputType& input,
const OutputType& error,
OutputType& gradient)
template <typename MatType, typename RegularizerType>
void MultiheadAttentionType<MatType, RegularizerType>::
Gradient(const MatType& input,
const MatType& error,
MatType& gradient)
{
typedef typename arma::Cube<typename InputType::elem_type> CubeType;
typedef typename arma::Cube<typename MatType::elem_type> CubeType;

if (input.n_rows != embedDim * (tgtSeqLen + 2 * srcSeqLen))
{
@@ -302,16 +302,16 @@ Gradient(const InputType& input,
// The shape of gradient : (4 * embedDim * embedDim + 4 * embedDim, 1).
gradient.set_size(arma::size(weights));

const CubeType q(const_cast<InputType&>(input).memptr(),
const CubeType q(const_cast<MatType&>(input).memptr(),
embedDim, tgtSeqLen, batchSize, false, false);
const CubeType k(const_cast<InputType&>(input).memptr() + q.n_elem,
const CubeType k(const_cast<MatType&>(input).memptr() + q.n_elem,
embedDim, srcSeqLen, batchSize, false, false);
const CubeType v(const_cast<InputType&>(input).memptr() + q.n_elem + k.n_elem,
const CubeType v(const_cast<MatType&>(input).memptr() + q.n_elem + k.n_elem,
embedDim, srcSeqLen, batchSize, false, false);

// Reshape the propagated error into a cube.
// The shape of errorTemp : (embedDim, tgtSeqLen, batchSize).
CubeType errorTemp(const_cast<OutputType&>(error).memptr(), embedDim,
CubeType errorTemp(const_cast<MatType&>(error).memptr(), embedDim,
tgtSeqLen, batchSize, true, false);

// Gradient wrt. outBias, i.e. dL/d(outBias).
@@ -425,12 +425,12 @@ Gradient(const InputType& input,
regularizer.Evaluate(weights, gradient);
}

template <typename InputType, typename OutputType, typename RegularizerType>
template <typename MatType, typename RegularizerType>
template <typename Archive>
void MultiheadAttentionType<InputType, OutputType, RegularizerType>::
void MultiheadAttentionType<MatType, RegularizerType>::
serialize(Archive& ar, const uint32_t /* version */)
{
ar(cereal::base_class<Layer<InputType, OutputType>>(this));
ar(cereal::base_class<Layer<MatType>>(this));

ar(CEREAL_NVP(tgtSeqLen));
ar(CEREAL_NVP(srcSeqLen));
