From 8fce893ba7143251b91f5c633008eb17bee9b0be Mon Sep 17 00:00:00 2001 From: "rust.dev" <102041955+RustNinja@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:53:01 +0100 Subject: [PATCH] Introduce custom logic into OnRecvPacket in PFM module to charge fee (#520) Introduce custom logic into OnRecvPacket in PFM module to charge fee this code in custom pfm inside IbcMiddleware `OnRecvPacket` method introduced by me. https://github.com/ComposableFi/composable-cosmos/blob/rustninja/pmf-middleware/custom/custompfm/keeper/keeper.go#L185-L216 all other code in method `OnRecvPacket` is taken from original version of `OnRecvPacket`. --------- Co-authored-by: Hoa Nguyen Co-authored-by: dzmitry-lahoda Co-authored-by: kienn6034 Co-authored-by: kkast Co-authored-by: Kanstantsin Kastsevich Co-authored-by: rjonczy Co-authored-by: tungle --- app/app.go | 1 + app/keepers/keepers.go | 6 +- custom/custompfm/keeper/keeper.go | 314 +++++++++++++++++++++++ x/ibctransfermiddleware/keeper/keeper.go | 114 ++++++++ 4 files changed, 433 insertions(+), 2 deletions(-) create mode 100644 custom/custompfm/keeper/keeper.go diff --git a/app/app.go b/app/app.go index 5739e007..f9e084cb 100644 --- a/app/app.go +++ b/app/app.go @@ -330,6 +330,7 @@ func NewComposableApp( appOpts, ) + // custompfm.NewIBCMiddleware() // transferModule := transfer.NewAppModule(app.TransferKeeper) transferModule := customibctransfer.NewAppModule(appCodec, app.TransferKeeper, app.BankKeeper) pfmModule := pfm.NewAppModule(app.PfmKeeper, app.GetSubspace(pfmtypes.ModuleName)) diff --git a/app/keepers/keepers.go b/app/keepers/keepers.go index 7d562856..7056d342 100644 --- a/app/keepers/keepers.go +++ b/app/keepers/keepers.go @@ -79,7 +79,6 @@ import ( custombankkeeper "github.com/notional-labs/composable/v6/custom/bank/keeper" - pfm "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v8/packetforward" pfmkeeper "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v8/packetforward/keeper" pfmtypes "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v8/packetforward/types" @@ -116,6 +115,7 @@ import ( stakingmiddleware "github.com/notional-labs/composable/v6/x/stakingmiddleware/keeper" stakingmiddlewaretypes "github.com/notional-labs/composable/v6/x/stakingmiddleware/types" + custompfm "github.com/notional-labs/composable/v6/custom/custompfm/keeper" ibctransfermiddleware "github.com/notional-labs/composable/v6/x/ibctransfermiddleware/keeper" ibctransfermiddlewaretypes "github.com/notional-labs/composable/v6/x/ibctransfermiddleware/types" ) @@ -396,12 +396,14 @@ func (appKeepers *AppKeepers) InitNormalKeepers( appKeepers.TransferMiddlewareKeeper, ) - ibcMiddlewareStack := pfm.NewIBCMiddleware( + ibcMiddlewareStack := custompfm.NewIBCMiddleware( transfermiddlewareStack, appKeepers.PfmKeeper, 0, pfmkeeper.DefaultForwardTransferPacketTimeoutTimestamp, pfmkeeper.DefaultRefundTransferPacketTimeoutTimestamp, + &appKeepers.IbcTransferMiddlewareKeeper, + &appKeepers.BankKeeper, ) ratelimitMiddlewareStack := ratelimitmodule.NewIBCMiddleware(appKeepers.RatelimitKeeper, ibcMiddlewareStack) hooksTransferMiddleware := ibc_hooks.NewIBCMiddleware(ratelimitMiddlewareStack, &appKeepers.HooksICS4Wrapper) diff --git a/custom/custompfm/keeper/keeper.go b/custom/custompfm/keeper/keeper.go new file mode 100644 index 00000000..30791dc7 --- /dev/null +++ b/custom/custompfm/keeper/keeper.go @@ -0,0 +1,314 @@ +package keeper + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + sdkmath "cosmossdk.io/math" + + "github.com/hashicorp/go-metrics" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/address" + router "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v8/packetforward" + "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v8/packetforward/keeper" + "github.com/cosmos/ibc-apps/middleware/packet-forward-middleware/v8/packetforward/types" + transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" + channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types" + ibcexported "github.com/cosmos/ibc-go/v8/modules/core/exported" + custombankkeeper "github.com/notional-labs/composable/v6/custom/bank/keeper" + ibctransfermiddlewarekeeper "github.com/notional-labs/composable/v6/x/ibctransfermiddleware/keeper" +) + +var _ porttypes.Middleware = &IBCMiddleware{} + +// IBCMiddleware implements the ICS26 callbacks for the forward middleware given the +// forward keeper and the underlying application. +type IBCMiddleware struct { + router.IBCMiddleware + + app1 porttypes.IBCModule + keeper1 *keeper.Keeper + + retriesOnTimeout1 uint8 + forwardTimeout1 time.Duration + refundTimeout1 time.Duration + ibcfeekeeper *ibctransfermiddlewarekeeper.Keeper + bank *custombankkeeper.Keeper +} + +func NewIBCMiddleware( + app porttypes.IBCModule, + k *keeper.Keeper, + retriesOnTimeout uint8, + forwardTimeout time.Duration, + refundTimeout time.Duration, + ibcfeekeeper *ibctransfermiddlewarekeeper.Keeper, + bankkeeper *custombankkeeper.Keeper, +) IBCMiddleware { + return IBCMiddleware{ + IBCMiddleware: router.NewIBCMiddleware(app, k, retriesOnTimeout, forwardTimeout, refundTimeout), + ibcfeekeeper: ibcfeekeeper, + + app1: app, + keeper1: k, + retriesOnTimeout1: retriesOnTimeout, + forwardTimeout1: forwardTimeout, + refundTimeout1: refundTimeout, + bank: bankkeeper, + } +} + +func (im IBCMiddleware) OnRecvPacket( + ctx sdk.Context, + packet channeltypes.Packet, + relayer sdk.AccAddress, +) ibcexported.Acknowledgement { + logger := im.keeper1.Logger(ctx) + + var data transfertypes.FungibleTokenPacketData + if err := transfertypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &data); err != nil { + logger.Debug(fmt.Sprintf("packetForwardMiddleware OnRecvPacket payload is not a FungibleTokenPacketData: %s", err.Error())) + return im.IBCMiddleware.OnRecvPacket(ctx, packet, relayer) + } + + logger.Debug("packetForwardMiddleware OnRecvPacket", + "sequence", packet.Sequence, + "src-channel", packet.SourceChannel, "src-port", packet.SourcePort, + "dst-channel", packet.DestinationChannel, "dst-port", packet.DestinationPort, + "amount", data.Amount, "denom", data.Denom, "memo", data.Memo, + ) + + d := make(map[string]interface{}) + err := json.Unmarshal([]byte(data.Memo), &d) + if err != nil || d["forward"] == nil { + // not a packet that should be forwarded + logger.Debug("packetForwardMiddleware OnRecvPacket forward metadata does not exist") + return im.app1.OnRecvPacket(ctx, packet, relayer) + } + m := &types.PacketMetadata{} + err = json.Unmarshal([]byte(data.Memo), m) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error parsing forward metadata", "error", err) + return newErrorAcknowledgement(fmt.Errorf("error parsing forward metadata: %w", err)) + } + + metadata := m.Forward + + goCtx := ctx.Context() + processed := getBoolFromAny(goCtx.Value(types.ProcessedKey{})) + nonrefundable := getBoolFromAny(goCtx.Value(types.NonrefundableKey{})) + disableDenomComposition := getBoolFromAny(goCtx.Value(types.DisableDenomCompositionKey{})) + + if err := metadata.Validate(); err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket forward metadata is invalid", "error", err) + return newErrorAcknowledgement(err) + } + + // override the receiver so that senders cannot move funds through arbitrary addresses. + overrideReceiver, err := getReceiver(packet.DestinationChannel, data.Sender) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket failed to construct override receiver", "error", err) + return newErrorAcknowledgement(fmt.Errorf("failed to construct override receiver: %w", err)) + } + + // if this packet has been handled by another middleware in the stack there may be no need to call into the + // underlying app, otherwise the transfer module's OnRecvPacket callback could be invoked more than once + // which would mint/burn vouchers more than once + if !processed { + if err := im.receiveFunds(ctx, packet, data, overrideReceiver, relayer); err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error receiving packet", "error", err) + return newErrorAcknowledgement(fmt.Errorf("error receiving packet: %w", err)) + } + } + + // if this packet's token denom is already the base denom for some native token on this chain, + // we do not need to do any further composition of the denom before forwarding the packet + denomOnThisChain := data.Denom + // Check if the packet was sent from Picasso + paraChainIBCTokenInfo, found := im.keeper1.GetParachainTokenInfoByAssetID(ctx, data.Denom) + if found && (paraChainIBCTokenInfo.ChannelID == packet.DestinationChannel) { + disableDenomComposition = true + denomOnThisChain = paraChainIBCTokenInfo.NativeDenom + } + + if !disableDenomComposition { + denomOnThisChain = getDenomForThisChain( + packet.DestinationPort, packet.DestinationChannel, + packet.SourcePort, packet.SourceChannel, + data.Denom, + ) + } + + amountInt, ok := sdkmath.NewIntFromString(data.Amount) + if !ok { + logger.Error("packetForwardMiddleware OnRecvPacket error parsing amount for forward", "amount", data.Amount) + return newErrorAcknowledgement(fmt.Errorf("error parsing amount for forward: %s", data.Amount)) + } + + token := sdk.NewCoin(denomOnThisChain, amountInt) + + timeout := time.Duration(metadata.Timeout) + + if timeout.Nanoseconds() <= 0 { + timeout = im.forwardTimeout1 + } + + var retries uint8 + if metadata.Retries != nil { + retries = *metadata.Retries + } else { + retries = im.retriesOnTimeout1 + } + + memo := "" + + // set memo for next transfer with next from this transfer. + if metadata.Next != nil { + memoBz, err := json.Marshal(metadata.Next) + if err != nil { + im.keeper1.Logger(ctx).Error("packetForwardMiddleware error marshaling next as JSON", + "error", err, + ) + logger.Error("packetForwardMiddleware OnRecvPacket error marshaling next as JSON", "error", err) + return newErrorAcknowledgement(fmt.Errorf("error marshaling next as JSON: %w", err)) + } + memo = string(memoBz) + } + + tr := transfertypes.NewMsgTransfer( + metadata.Port, + metadata.Channel, + token, + overrideReceiver, + metadata.Receiver, + clienttypes.Height{ + RevisionNumber: 0, + RevisionHeight: 0, + }, + uint64(ctx.BlockTime().UnixNano())+uint64(timeout.Nanoseconds()), + memo, + ) + + result, err := im.ibcfeekeeper.GetBridgeFeeBasedOnConfigForChannelAndDenom(ctx, tr) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error charging fee", "error", err) + return newErrorAcknowledgement(fmt.Errorf("error charging fee: %w", err)) + } + if result != nil { + if token.Amount.GTE(result.Fee.Amount) { + send_err := im.bank.SendCoins(ctx, result.Sender, result.Receiver, sdk.NewCoins(result.Fee)) + if send_err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error sending fee", "error", send_err) + return newErrorAcknowledgement(fmt.Errorf("error charging fee: %w", send_err)) + } + } else { + logger.Error("packetForwardMiddleware OnRecvPacket error charging fee", "error", err) + return newErrorAcknowledgement(fmt.Errorf("incorrect fee %w for channel id %s and denom %s", err, tr.SourceChannel, tr.Token.Denom)) + } + if result.Fee.Amount.LT(token.Amount) { + token = token.SubAmount(result.Fee.Amount) + } else { + ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) + return ack + } + } + + err = im.keeper1.ForwardTransferPacket(ctx, nil, packet, data.Sender, overrideReceiver, metadata, token, retries, timeout, []metrics.Label{}, nonrefundable) + if err != nil { + logger.Error("packetForwardMiddleware OnRecvPacket error forwarding packet", "error", err) + return newErrorAcknowledgement(err) + } + + // returning nil ack will prevent WriteAcknowledgement from occurring for forwarded packet. + // This is intentional so that the acknowledgement will be written later based on the ack/timeout of the forwarded packet. + return nil +} + +func newErrorAcknowledgement(err error) channeltypes.Acknowledgement { + return channeltypes.Acknowledgement{ + Response: &channeltypes.Acknowledgement_Error{ + Error: fmt.Sprintf("packet-forward-middleware error: %s", err.Error()), + }, + } +} + +func getBoolFromAny(value any) bool { + if value == nil { + return false + } + boolVal, ok := value.(bool) + if !ok { + return false + } + return boolVal +} + +func getReceiver(channel, originalSender string) (string, error) { + senderStr := fmt.Sprintf("%s/%s", channel, originalSender) + senderHash32 := address.Hash(types.ModuleName, []byte(senderStr)) + sender := sdk.AccAddress(senderHash32[:20]) + bech32Prefix := sdk.GetConfig().GetBech32AccountAddrPrefix() + return sdk.Bech32ifyAddressBytes(bech32Prefix, sender) +} + +func (im IBCMiddleware) receiveFunds( + ctx sdk.Context, + packet channeltypes.Packet, + data transfertypes.FungibleTokenPacketData, + overrideReceiver string, + relayer sdk.AccAddress, +) error { + overrideData := transfertypes.FungibleTokenPacketData{ + Denom: data.Denom, + Amount: data.Amount, + Sender: data.Sender, + Receiver: overrideReceiver, // override receiver + // Memo explicitly zeroed + } + overrideDataBz := transfertypes.ModuleCdc.MustMarshalJSON(&overrideData) + overridePacket := channeltypes.Packet{ + Sequence: packet.Sequence, + SourcePort: packet.SourcePort, + SourceChannel: packet.SourceChannel, + DestinationPort: packet.DestinationPort, + DestinationChannel: packet.DestinationChannel, + Data: overrideDataBz, // override data + TimeoutHeight: packet.TimeoutHeight, + TimeoutTimestamp: packet.TimeoutTimestamp, + } + + ack := im.app1.OnRecvPacket(ctx, overridePacket, relayer) + + if ack == nil { + return fmt.Errorf("ack is nil") + } + + if !ack.Success() { + return fmt.Errorf("ack error: %s", string(ack.Acknowledgement())) + } + + return nil +} + +func getDenomForThisChain(port, channel, counterpartyPort, counterpartyChannel, denom string) string { + counterpartyPrefix := transfertypes.GetDenomPrefix(counterpartyPort, counterpartyChannel) + if strings.HasPrefix(denom, counterpartyPrefix) { + // unwind denom + unwoundDenom := denom[len(counterpartyPrefix):] + denomTrace := transfertypes.ParseDenomTrace(unwoundDenom) + if denomTrace.Path == "" { + // denom is now unwound back to native denom + return unwoundDenom + } + // denom is still IBC denom + return denomTrace.IBCDenom() + } + // append port and channel from this chain to denom + prefixedDenom := transfertypes.GetDenomPrefix(port, channel) + denom + return transfertypes.ParseDenomTrace(prefixedDenom).IBCDenom() +} diff --git a/x/ibctransfermiddleware/keeper/keeper.go b/x/ibctransfermiddleware/keeper/keeper.go index 6922415d..1429d5ff 100644 --- a/x/ibctransfermiddleware/keeper/keeper.go +++ b/x/ibctransfermiddleware/keeper/keeper.go @@ -1,12 +1,18 @@ package keeper import ( + "encoding/json" + "fmt" + "time" + "cosmossdk.io/log" + "github.com/notional-labs/composable/v6/x/ibctransfermiddleware/types" storetypes "cosmossdk.io/store/types" "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + ibctypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" ) // Keeper of the staking middleware store @@ -104,3 +110,111 @@ func (k Keeper) GetChannelFeeAddress(ctx sdk.Context, targetChannelID string) st } return channelFee.FeeAddress } + +type BridgeFee struct { + Fee sdk.Coin + Sender sdk.AccAddress + Receiver sdk.AccAddress +} + +func (k Keeper) GetBridgeFeeBasedOnConfigForChannelAndDenom(ctx sdk.Context, msg *ibctypes.MsgTransfer) (*BridgeFee, error) { + params := k.GetParams(ctx) + // charge_coin := sdk.NewCoin(msg.Token.Denom, sdk.ZeroInt()) + if params.ChannelFees != nil && len(params.ChannelFees) > 0 { + channelFee := findChannelParams(params.ChannelFees, msg.SourceChannel) + if channelFee != nil { + if channelFee.MinTimeoutTimestamp > 0 { + + blockTime := ctx.BlockTime() + + timeoutTimeInFuture := time.Unix(0, int64(msg.TimeoutTimestamp)) + if timeoutTimeInFuture.Before(blockTime) { + return nil, fmt.Errorf("incorrect timeout timestamp found during ibc transfer. timeout timestamp is in the past") + } + + difference := timeoutTimeInFuture.Sub(blockTime).Nanoseconds() + if difference < channelFee.MinTimeoutTimestamp { + return nil, fmt.Errorf("incorrect timeout timestamp found during ibc transfer. too soon") + } + } + coin := findCoinByDenom(channelFee.AllowedTokens, msg.Token.Denom) + if coin == nil { + return nil, fmt.Errorf("token not allowed to be transferred in this channel") + } + + minFee := coin.MinFee.Amount + priority := GetPriority(msg.Memo) + if priority != nil { + p := findPriority(coin.TxPriorityFee, *priority) + if p != nil && coin.MinFee.Denom == p.PriorityFee.Denom { + minFee = minFee.Add(p.PriorityFee.Amount) + } + } + + charge := minFee + if charge.GT(msg.Token.Amount) { + charge = msg.Token.Amount + } + + newAmount := msg.Token.Amount.Sub(charge) + + if newAmount.IsPositive() { + percentageCharge := newAmount.QuoRaw(coin.Percentage) + newAmount = newAmount.Sub(percentageCharge) + charge = charge.Add(percentageCharge) + } + + msgSender, err := sdk.AccAddressFromBech32(msg.Sender) + if err != nil { + return nil, err + } + + feeAddress, err := sdk.AccAddressFromBech32(channelFee.FeeAddress) + if err != nil { + return nil, err + } + + charge_coin := sdk.NewCoin(msg.Token.Denom, charge) + // send_err := k.bank.SendCoins(ctx, msgSender, feeAddress, sdk.NewCoins(charge_coin)) + // if send_err != nil { + // return nil, send_err + // } + msg.Token.Amount = newAmount + return &BridgeFee{Fee: charge_coin, Sender: msgSender, Receiver: feeAddress}, nil + + // if newAmount.LTE(sdk.ZeroInt()) { + // zeroTransfer := sdk.NewCoin(msg.Token.Denom, sdk.ZeroInt()) + // return &zeroTransfer, nil + // } + } + } + // ret, err := k.Keeper.Transfer(goCtx, msg) + // if err == nil && ret != nil && !charge_coin.IsZero() { + // if !charge_coin.IsZero() { + // k.SetSequenceFee(ctx, ret.Sequence, charge_coin) + // } + return nil, nil +} + +func GetPriority(jsonString string) *string { + var data map[string]interface{} + if err := json.Unmarshal([]byte(jsonString), &data); err != nil { + return nil + } + + priority, ok := data["priority"].(string) + if !ok { + return nil + } + + return &priority +} + +func findPriority(priorities []*types.TxPriorityFee, priority string) *types.TxPriorityFee { + for _, p := range priorities { + if p.Priority == priority { + return p + } + } + return nil +}