33package wireguard
44
55import (
6+ "encoding/json"
67 "fmt"
78 "log"
89 "os"
@@ -18,6 +19,10 @@ const (
1819 nftTableFamily = "ip"
1920 nftTableName = "pg_node_wg_nat"
2021 nftPostroutingChain = "postrouting"
22+ nftFilterTableFamily = "inet"
23+ nftFilterTableName = "pg_node_wg_filter"
24+ nftForwardChain = "forward"
25+ nftForwardRulePrefix = "pg_node_wg_forward "
2126)
2227
2328// applyLinuxHostRouting installs an nftables masquerade rule for traffic from the
@@ -70,6 +75,10 @@ func applyLinuxHostRouting(wgInterfaceName string) {
7075 if err := ensureNFTMasquerade (wgIf , outIf , egressOnly ); err != nil {
7176 log .Printf ("wireguard host routing: nftables masquerade failed: %v" , err )
7277 }
78+
79+ if err := ensureNFTForwarding (wgIf , outIf ); err != nil {
80+ log .Printf ("wireguard host routing: nftables forward rules failed: %v" , err )
81+ }
7382}
7483
7584func envTruthy (s string ) bool {
@@ -107,6 +116,172 @@ func nftMasqueradeConfig(rule string) string {
107116` , nftTableFamily , nftTableName , nftPostroutingChain , rule )
108117}
109118
119+ func ensureNFTForwarding (wgIface , outputIface string ) error {
120+ if err := runNFT ("delete" , "table" , nftFilterTableFamily , nftFilterTableName ); err != nil && ! nftTableMissing (err ) {
121+ return err
122+ }
123+ if err := runNFTScript (nftForwardConfig (wgIface , outputIface )); err != nil {
124+ return err
125+ }
126+
127+ chains , err := nftForwardBaseChains ()
128+ if err != nil {
129+ return err
130+ }
131+ for _ , chain := range chains {
132+ if chain .family == nftFilterTableFamily && chain .table == nftFilterTableName {
133+ continue
134+ }
135+ if err := removeNFTForwardRules (chain ); err != nil {
136+ return err
137+ }
138+ if err := insertNFTForwardRule (chain , wgIface , outputIface , true ); err != nil {
139+ return err
140+ }
141+ if err := insertNFTForwardRule (chain , wgIface , outputIface , false ); err != nil {
142+ return err
143+ }
144+ }
145+ return nil
146+ }
147+
148+ func nftForwardConfig (wgIface , outputIface string ) string {
149+ return fmt .Sprintf (`table %s %s {
150+ chain %s {
151+ type filter hook forward priority 0; policy accept;
152+ iifname %q oifname %q accept comment %q
153+ iifname %q oifname %q ct state established,related accept comment %q
154+ }
155+ }
156+ ` ,
157+ nftFilterTableFamily ,
158+ nftFilterTableName ,
159+ nftForwardChain ,
160+ wgIface ,
161+ outputIface ,
162+ nftForwardRuleComment (wgIface , outputIface , true ),
163+ outputIface ,
164+ wgIface ,
165+ nftForwardRuleComment (wgIface , outputIface , false ),
166+ )
167+ }
168+
169+ type nftBaseChain struct {
170+ family string
171+ table string
172+ name string
173+ }
174+
175+ type nftListRuleset struct {
176+ NFTables []map [string ]json.RawMessage `json:"nftables"`
177+ }
178+
179+ type nftListChain struct {
180+ Family string `json:"family"`
181+ Table string `json:"table"`
182+ Name string `json:"name"`
183+ Hook string `json:"hook"`
184+ }
185+
186+ func nftForwardBaseChains () ([]nftBaseChain , error ) {
187+ cmd := exec .Command ("nft" , "-j" , "list" , "ruleset" )
188+ out , err := cmd .CombinedOutput ()
189+ if err != nil {
190+ return nil , fmt .Errorf ("nft -j list ruleset: %w: %s" , err , strings .TrimSpace (string (out )))
191+ }
192+ return parseNFTForwardBaseChains (out )
193+ }
194+
195+ func parseNFTForwardBaseChains (data []byte ) ([]nftBaseChain , error ) {
196+ var ruleset nftListRuleset
197+ if err := json .Unmarshal (data , & ruleset ); err != nil {
198+ return nil , fmt .Errorf ("parse nft ruleset: %w" , err )
199+ }
200+
201+ chains := make ([]nftBaseChain , 0 )
202+ for _ , item := range ruleset .NFTables {
203+ raw , ok := item ["chain" ]
204+ if ! ok {
205+ continue
206+ }
207+ var chain nftListChain
208+ if err := json .Unmarshal (raw , & chain ); err != nil {
209+ return nil , fmt .Errorf ("parse nft chain: %w" , err )
210+ }
211+ if chain .Hook != nftForwardChain || ! nftForwardFamilySupported (chain .Family ) {
212+ continue
213+ }
214+ chains = append (chains , nftBaseChain {
215+ family : chain .Family ,
216+ table : chain .Table ,
217+ name : chain .Name ,
218+ })
219+ }
220+ return chains , nil
221+ }
222+
223+ func nftForwardFamilySupported (family string ) bool {
224+ return family == "ip" || family == "inet"
225+ }
226+
227+ func removeNFTForwardRules (chain nftBaseChain ) error {
228+ cmd := exec .Command ("nft" , "-a" , "list" , "chain" , chain .family , chain .table , chain .name )
229+ out , err := cmd .CombinedOutput ()
230+ if err != nil {
231+ return fmt .Errorf ("nft -a list chain %s %s %s: %w: %s" , chain .family , chain .table , chain .name , err , strings .TrimSpace (string (out )))
232+ }
233+
234+ for _ , handle := range nftRuleHandlesWithComment (out , nftForwardRulePrefix ) {
235+ if err := runNFT ("delete" , "rule" , chain .family , chain .table , chain .name , "handle" , handle ); err != nil {
236+ return err
237+ }
238+ }
239+ return nil
240+ }
241+
242+ func nftRuleHandlesWithComment (data []byte , commentPrefix string ) []string {
243+ handles := make ([]string , 0 )
244+ for _ , line := range strings .Split (string (data ), "\n " ) {
245+ if ! strings .Contains (line , commentPrefix ) {
246+ continue
247+ }
248+
249+ before , handle , ok := strings .Cut (line , "# handle " )
250+ if ! ok || strings .TrimSpace (before ) == "" {
251+ continue
252+ }
253+ fields := strings .Fields (handle )
254+ if len (fields ) == 0 {
255+ continue
256+ }
257+ handles = append (handles , fields [0 ])
258+ }
259+ return handles
260+ }
261+
262+ func insertNFTForwardRule (chain nftBaseChain , wgIface , outputIface string , outbound bool ) error {
263+ comment := nftForwardRuleComment (wgIface , outputIface , outbound )
264+ args := []string {"insert" , "rule" , chain .family , chain .table , chain .name }
265+ if outbound {
266+ args = append (args , "iifname" , nftString (wgIface ), "oifname" , nftString (outputIface ), "accept" , "comment" , nftString (comment ))
267+ } else {
268+ args = append (args , "iifname" , nftString (outputIface ), "oifname" , nftString (wgIface ), "ct" , "state" , "established,related" , "accept" , "comment" , nftString (comment ))
269+ }
270+ return runNFT (args ... )
271+ }
272+
273+ func nftForwardRuleComment (wgIface , outputIface string , outbound bool ) string {
274+ direction := "return"
275+ if outbound {
276+ direction = "outbound"
277+ }
278+ return fmt .Sprintf ("%s%s %s %s" , nftForwardRulePrefix , wgIface , outputIface , direction )
279+ }
280+
281+ func nftString (s string ) string {
282+ return fmt .Sprintf ("%q" , s )
283+ }
284+
110285func ensureIPv4Forwarding () error {
111286 out , err := os .ReadFile (ipv4ForwardPath )
112287 if err != nil {
0 commit comments