Skip to content

Commit 864ace6

Browse files
committed
fix(wireguard): allow forwarded host routing traffic
1 parent 891a4b4 commit 864ace6

4 files changed

Lines changed: 247 additions & 2 deletions

File tree

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ API_KEY = xxxxxxxx-yyyy-zzzz-mmmm-aaaaaaaaaaa
2424
# STATS_CLEANUP_INTERVAL_SECONDS = 300
2525

2626
### WireGuard host NAT
27-
### Built-in routing enables runtime IPv4 forwarding and only manages nft table ip pg_node_wg_nat.
27+
### Built-in routing enables runtime IPv4 forwarding and manages scoped nft NAT/forwarding rules.
2828
# PG_NODE_WG_HOST_ROUTING = 1
2929
# PG_NODE_WG_NAT_OUTPUT_INTERFACE = eth0
3030
# PG_NODE_WG_NAT_EGRESS_ONLY = 0

backend/wireguard/host_routing_linux.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package wireguard
44

55
import (
6+
"encoding/json"
67
"fmt"
78
"log"
89
"os"
@@ -18,6 +19,10 @@ const (
1819
nftTableFamily = "ip"
1920
nftTableName = "pg_node_wg_nat"
2021
nftPostroutingChain = "postrouting"
22+
nftFilterTableFamily = "inet"
23+
nftFilterTableName = "pg_node_wg_filter"
24+
nftForwardChain = "forward"
25+
nftForwardRulePrefix = "pg_node_wg_forward "
2126
)
2227

2328
// applyLinuxHostRouting installs an nftables masquerade rule for traffic from the
@@ -70,6 +75,10 @@ func applyLinuxHostRouting(wgInterfaceName string) {
7075
if err := ensureNFTMasquerade(wgIf, outIf, egressOnly); err != nil {
7176
log.Printf("wireguard host routing: nftables masquerade failed: %v", err)
7277
}
78+
79+
if err := ensureNFTForwarding(wgIf, outIf); err != nil {
80+
log.Printf("wireguard host routing: nftables forward rules failed: %v", err)
81+
}
7382
}
7483

7584
func envTruthy(s string) bool {
@@ -107,6 +116,172 @@ func nftMasqueradeConfig(rule string) string {
107116
`, nftTableFamily, nftTableName, nftPostroutingChain, rule)
108117
}
109118

119+
func ensureNFTForwarding(wgIface, outputIface string) error {
120+
if err := runNFT("delete", "table", nftFilterTableFamily, nftFilterTableName); err != nil && !nftTableMissing(err) {
121+
return err
122+
}
123+
if err := runNFTScript(nftForwardConfig(wgIface, outputIface)); err != nil {
124+
return err
125+
}
126+
127+
chains, err := nftForwardBaseChains()
128+
if err != nil {
129+
return err
130+
}
131+
for _, chain := range chains {
132+
if chain.family == nftFilterTableFamily && chain.table == nftFilterTableName {
133+
continue
134+
}
135+
if err := removeNFTForwardRules(chain); err != nil {
136+
return err
137+
}
138+
if err := insertNFTForwardRule(chain, wgIface, outputIface, true); err != nil {
139+
return err
140+
}
141+
if err := insertNFTForwardRule(chain, wgIface, outputIface, false); err != nil {
142+
return err
143+
}
144+
}
145+
return nil
146+
}
147+
148+
func nftForwardConfig(wgIface, outputIface string) string {
149+
return fmt.Sprintf(`table %s %s {
150+
chain %s {
151+
type filter hook forward priority 0; policy accept;
152+
iifname %q oifname %q accept comment %q
153+
iifname %q oifname %q ct state established,related accept comment %q
154+
}
155+
}
156+
`,
157+
nftFilterTableFamily,
158+
nftFilterTableName,
159+
nftForwardChain,
160+
wgIface,
161+
outputIface,
162+
nftForwardRuleComment(wgIface, outputIface, true),
163+
outputIface,
164+
wgIface,
165+
nftForwardRuleComment(wgIface, outputIface, false),
166+
)
167+
}
168+
169+
type nftBaseChain struct {
170+
family string
171+
table string
172+
name string
173+
}
174+
175+
type nftListRuleset struct {
176+
NFTables []map[string]json.RawMessage `json:"nftables"`
177+
}
178+
179+
type nftListChain struct {
180+
Family string `json:"family"`
181+
Table string `json:"table"`
182+
Name string `json:"name"`
183+
Hook string `json:"hook"`
184+
}
185+
186+
func nftForwardBaseChains() ([]nftBaseChain, error) {
187+
cmd := exec.Command("nft", "-j", "list", "ruleset")
188+
out, err := cmd.CombinedOutput()
189+
if err != nil {
190+
return nil, fmt.Errorf("nft -j list ruleset: %w: %s", err, strings.TrimSpace(string(out)))
191+
}
192+
return parseNFTForwardBaseChains(out)
193+
}
194+
195+
func parseNFTForwardBaseChains(data []byte) ([]nftBaseChain, error) {
196+
var ruleset nftListRuleset
197+
if err := json.Unmarshal(data, &ruleset); err != nil {
198+
return nil, fmt.Errorf("parse nft ruleset: %w", err)
199+
}
200+
201+
chains := make([]nftBaseChain, 0)
202+
for _, item := range ruleset.NFTables {
203+
raw, ok := item["chain"]
204+
if !ok {
205+
continue
206+
}
207+
var chain nftListChain
208+
if err := json.Unmarshal(raw, &chain); err != nil {
209+
return nil, fmt.Errorf("parse nft chain: %w", err)
210+
}
211+
if chain.Hook != nftForwardChain || !nftForwardFamilySupported(chain.Family) {
212+
continue
213+
}
214+
chains = append(chains, nftBaseChain{
215+
family: chain.Family,
216+
table: chain.Table,
217+
name: chain.Name,
218+
})
219+
}
220+
return chains, nil
221+
}
222+
223+
func nftForwardFamilySupported(family string) bool {
224+
return family == "ip" || family == "inet"
225+
}
226+
227+
func removeNFTForwardRules(chain nftBaseChain) error {
228+
cmd := exec.Command("nft", "-a", "list", "chain", chain.family, chain.table, chain.name)
229+
out, err := cmd.CombinedOutput()
230+
if err != nil {
231+
return fmt.Errorf("nft -a list chain %s %s %s: %w: %s", chain.family, chain.table, chain.name, err, strings.TrimSpace(string(out)))
232+
}
233+
234+
for _, handle := range nftRuleHandlesWithComment(out, nftForwardRulePrefix) {
235+
if err := runNFT("delete", "rule", chain.family, chain.table, chain.name, "handle", handle); err != nil {
236+
return err
237+
}
238+
}
239+
return nil
240+
}
241+
242+
func nftRuleHandlesWithComment(data []byte, commentPrefix string) []string {
243+
handles := make([]string, 0)
244+
for _, line := range strings.Split(string(data), "\n") {
245+
if !strings.Contains(line, commentPrefix) {
246+
continue
247+
}
248+
249+
before, handle, ok := strings.Cut(line, "# handle ")
250+
if !ok || strings.TrimSpace(before) == "" {
251+
continue
252+
}
253+
fields := strings.Fields(handle)
254+
if len(fields) == 0 {
255+
continue
256+
}
257+
handles = append(handles, fields[0])
258+
}
259+
return handles
260+
}
261+
262+
func insertNFTForwardRule(chain nftBaseChain, wgIface, outputIface string, outbound bool) error {
263+
comment := nftForwardRuleComment(wgIface, outputIface, outbound)
264+
args := []string{"insert", "rule", chain.family, chain.table, chain.name}
265+
if outbound {
266+
args = append(args, "iifname", nftString(wgIface), "oifname", nftString(outputIface), "accept", "comment", nftString(comment))
267+
} else {
268+
args = append(args, "iifname", nftString(outputIface), "oifname", nftString(wgIface), "ct", "state", "established,related", "accept", "comment", nftString(comment))
269+
}
270+
return runNFT(args...)
271+
}
272+
273+
func nftForwardRuleComment(wgIface, outputIface string, outbound bool) string {
274+
direction := "return"
275+
if outbound {
276+
direction = "outbound"
277+
}
278+
return fmt.Sprintf("%s%s %s %s", nftForwardRulePrefix, wgIface, outputIface, direction)
279+
}
280+
281+
func nftString(s string) string {
282+
return fmt.Sprintf("%q", s)
283+
}
284+
110285
func ensureIPv4Forwarding() error {
111286
out, err := os.ReadFile(ipv4ForwardPath)
112287
if err != nil {

backend/wireguard/host_routing_linux_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,76 @@ func TestNFTMasqueradeConfigIsScoped(t *testing.T) {
5050
}
5151
}
5252

53+
func TestNFTForwardConfigIsScoped(t *testing.T) {
54+
cfg := nftForwardConfig("wg0", "eth0")
55+
56+
for _, want := range []string{
57+
"table inet pg_node_wg_filter",
58+
"chain forward",
59+
"type filter hook forward priority 0; policy accept;",
60+
`iifname "wg0" oifname "eth0" accept comment "pg_node_wg_forward wg0 eth0 outbound"`,
61+
`iifname "eth0" oifname "wg0" ct state established,related accept comment "pg_node_wg_forward wg0 eth0 return"`,
62+
} {
63+
if !strings.Contains(cfg, want) {
64+
t.Fatalf("config missing %q:\n%s", want, cfg)
65+
}
66+
}
67+
68+
if strings.Contains(cfg, "flush ruleset") {
69+
t.Fatalf("managed config must not flush the global nft ruleset:\n%s", cfg)
70+
}
71+
}
72+
73+
func TestParseNFTForwardBaseChains(t *testing.T) {
74+
const ruleset = `{
75+
"nftables": [
76+
{"metainfo": {"version": "1.0.9"}},
77+
{"table": {"family": "ip", "name": "filter"}},
78+
{"chain": {"family": "ip", "table": "filter", "name": "FORWARD", "type": "filter", "hook": "forward", "prio": 0, "policy": "drop"}},
79+
{"chain": {"family": "inet", "table": "firewalld", "name": "filter_FORWARD", "type": "filter", "hook": "forward", "prio": 10, "policy": "accept"}},
80+
{"chain": {"family": "ip6", "table": "filter", "name": "FORWARD", "type": "filter", "hook": "forward", "prio": 0, "policy": "drop"}},
81+
{"chain": {"family": "ip", "table": "filter", "name": "INPUT", "type": "filter", "hook": "input", "prio": 0, "policy": "drop"}}
82+
]
83+
}`
84+
85+
chains, err := parseNFTForwardBaseChains([]byte(ruleset))
86+
if err != nil {
87+
t.Fatalf("parseNFTForwardBaseChains returned error: %v", err)
88+
}
89+
90+
if len(chains) != 2 {
91+
t.Fatalf("expected 2 supported forward chains, got %#v", chains)
92+
}
93+
94+
if chains[0] != (nftBaseChain{family: "ip", table: "filter", name: "FORWARD"}) {
95+
t.Fatalf("unexpected first chain: %#v", chains[0])
96+
}
97+
if chains[1] != (nftBaseChain{family: "inet", table: "firewalld", name: "filter_FORWARD"}) {
98+
t.Fatalf("unexpected second chain: %#v", chains[1])
99+
}
100+
}
101+
102+
func TestNFTString(t *testing.T) {
103+
if got := nftString("pg_node_wg_forward wg0 eth0 outbound"); got != `"pg_node_wg_forward wg0 eth0 outbound"` {
104+
t.Fatalf("unexpected quoted nft string: %s", got)
105+
}
106+
}
107+
108+
func TestNFTRuleHandlesWithComment(t *testing.T) {
109+
const chain = `table ip filter {
110+
chain FORWARD {
111+
iifname "wg0" oifname "eth0" accept comment "pg_node_wg_forward wg0 eth0 outbound" # handle 12
112+
iifname "eth0" oifname "wg0" ct state established,related accept comment "pg_node_wg_forward wg0 eth0 return" # handle 14
113+
counter packets 0 bytes 0 # handle 20
114+
}
115+
}`
116+
117+
handles := nftRuleHandlesWithComment([]byte(chain), nftForwardRulePrefix)
118+
if strings.Join(handles, ",") != "12,14" {
119+
t.Fatalf("unexpected handles: %#v", handles)
120+
}
121+
}
122+
53123
type staticError string
54124

55125
func (e staticError) Error() string { return string(e) }

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ services:
1313
environment:
1414
SERVICE_PORT: 62050
1515
SERVICE_PROTOCOL: "grpc"
16-
# Linux: enable runtime IPv4 forwarding and install a scoped nftables masquerade rule in table ip pg_node_wg_nat.
16+
# Linux: enable runtime IPv4 forwarding and install scoped nftables NAT/forwarding rules.
1717
# NAT egress is auto-detected (ip route / /proc/net/route); set PG_NODE_WG_NAT_OUTPUT_INTERFACE to override (e.g. eth0, ens192).
1818
PG_NODE_WG_HOST_ROUTING: "1"
1919

0 commit comments

Comments
 (0)