Skip to content

Commit

Permalink
util/linuxfw: fix broken tests
Browse files Browse the repository at this point in the history
These tests were broken at HEAD. CI currently does not run these
as root, will figure out how to do that in a followup.

Updates tailscale#5621
Updates tailscale#8555
Updates tailscale#8762

Signed-off-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Alex Paguis <alex@windscribe.com>
  • Loading branch information
maisem authored and alexelisenko committed Feb 15, 2024
1 parent e9cb4de commit e53be95
Showing 1 changed file with 68 additions and 178 deletions.
246 changes: 68 additions & 178 deletions util/linuxfw/nftables_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,10 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) {

func newSysConn(t *testing.T) *nftables.Conn {
t.Helper()
if os.Geteuid() != 0 {
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
return nil
}

runtime.LockOSThread()

Expand Down Expand Up @@ -512,12 +516,21 @@ func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
}
}

func TestAddAndDelNetfilterChains(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
return
func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListChainsOfTableFamily(fam)
if err != nil {
t.Fatalf("conn.ListChainsOfTableFamily(%v) failed: %v", fam, err)
}
if len(got) != wantCount {
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}

func TestAddAndDelNetfilterChains(t *testing.T) {
conn := newSysConn(t)
checkChains(t, conn, nftables.TableFamilyIPv4, 0)
checkChains(t, conn, nftables.TableFamilyIPv6, 0)

runner := newFakeNftablesRunner(t, conn)
runner.AddChains()
Expand All @@ -531,33 +544,22 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
t.Fatalf("len(tables) = %d, want 4", len(tables))
}

chainsV4, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
t.Fatalf("list chains failed: %v", err)
}

if len(chainsV4) != 6 {
t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4))
}

chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6)
if err != nil {
t.Fatalf("list chains failed: %v", err)
}

if len(chainsV6) != 6 {
t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6))
}
checkChains(t, conn, nftables.TableFamilyIPv4, 6)
checkChains(t, conn, nftables.TableFamilyIPv6, 6)

runner.DelChains()

// The default chains should still be present.
checkChains(t, conn, nftables.TableFamilyIPv4, 3)
checkChains(t, conn, nftables.TableFamilyIPv6, 3)

tables, err = conn.ListTables()
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
}

if len(tables) != 0 {
t.Fatalf("len(tables) = %d, want 0", len(tables))
if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
}
}

Expand Down Expand Up @@ -646,12 +648,19 @@ func findCommonBaseRules(
return get, nil
}

func TestNFTAddAndDelNetfilterBase(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
return
// checkChainRules verifies that the chain has the expected number of rules.
func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, wantCount int) {
t.Helper()
got, err := conn.GetRules(chain.Table, chain)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(got) != wantCount {
t.Fatalf("got = %d, want %d", len(got), wantCount)
}
}

func TestNFTAddAndDelNetfilterBase(t *testing.T) {
conn := newSysConn(t)

runner := newFakeNftablesRunner(t, conn)
Expand All @@ -664,30 +673,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
if err != nil {
t.Fatalf("getTsChains() failed: %v", err)
}

inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(inputV4Rules) != 2 {
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
}

forwardV4Rules, err := conn.GetRules(runner.nft4.Filter, forwardV4)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(forwardV4Rules) != 4 {
t.Fatalf("len(forwardV4Rules) = %d, want 4", len(forwardV4Rules))
}

postroutingV4Rules, err := conn.GetRules(runner.nft4.Nat, postroutingV4)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(postroutingV4Rules) != 0 {
t.Fatalf("len(postroutingV4Rules) = %d, want 0", len(postroutingV4Rules))
}
checkChainRules(t, conn, inputV4, 3)
checkChainRules(t, conn, forwardV4, 4)
checkChainRules(t, conn, postroutingV4, 0)

_, err = findV4BaseRules(conn, inputV4, forwardV4, "testTunn")
if err != nil {
Expand All @@ -703,30 +691,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
if err != nil {
t.Fatalf("getTsChains() failed: %v", err)
}

inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV6)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(inputV6Rules) != 0 {
t.Fatalf("len(inputV6Rules) = %d, want 0", len(inputV4Rules))
}

forwardV6Rules, err := conn.GetRules(runner.nft6.Filter, forwardV6)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(forwardV6Rules) != 3 {
t.Fatalf("len(forwardV6Rules) = %d, want 3", len(forwardV4Rules))
}

postroutingV6Rules, err := conn.GetRules(runner.nft6.Nat, postroutingV6)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(postroutingV6Rules) != 0 {
t.Fatalf("len(postroutingV6Rules) = %d, want 0", len(postroutingV4Rules))
}
checkChainRules(t, conn, inputV6, 3)
checkChainRules(t, conn, forwardV6, 4)
checkChainRules(t, conn, postroutingV6, 0)

_, err = findCommonBaseRules(conn, forwardV6, "testTunn")
if err != nil {
Expand All @@ -740,13 +707,7 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
t.Fatalf("conn.ListChains() failed: %v", err)
}
for _, chain := range chains {
chainRules, err := conn.GetRules(chain.Table, chain)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(chainRules) != 0 {
t.Fatalf("len(chainRules) = %d, want 0", len(chainRules))
}
checkChainRules(t, conn, chain, 0)
}
}

Expand Down Expand Up @@ -790,36 +751,36 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
}

func TestNFTAddAndDelLoopbackRule(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
return
}

conn := newSysConn(t)

runner := newFakeNftablesRunner(t, conn)
runner.AddChains()
defer runner.DelChains()
runner.AddBase("testTunn")
defer runner.DelBase()

addr := netip.MustParseAddr("192.168.0.2")
addrV6 := netip.MustParseAddr("2001:db8::2")
runner.AddLoopbackRule(addr)
runner.AddLoopbackRule(addrV6)

inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4)
if err != nil {
t.Fatalf("getTsChains() failed: %v", err)
}

inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4)
inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(inputV4Rules) != 3 {
t.Fatalf("len(inputV4Rules) = %d, want 3", len(inputV4Rules))
t.Fatalf("getTsChains() failed: %v", err)
}
checkChainRules(t, conn, inputV4, 0)
checkChainRules(t, conn, inputV6, 0)

runner.AddBase("testTunn")
defer runner.DelBase()
checkChainRules(t, conn, inputV4, 3)
checkChainRules(t, conn, inputV6, 3)

addr := netip.MustParseAddr("192.168.0.2")
addrV6 := netip.MustParseAddr("2001:db8::2")
runner.AddLoopbackRule(addr)
runner.AddLoopbackRule(addrV6)

checkChainRules(t, conn, inputV4, 4)
checkChainRules(t, conn, inputV6, 4)

existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr)
if err != nil {
Expand All @@ -830,19 +791,6 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
t.Fatalf("existingLoopBackRule.Handle = %d, want 0", existingLoopBackRule.Handle)
}

inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6)
if err != nil {
t.Fatalf("getTsChains() failed: %v", err)
}

inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV4)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(inputV6Rules) != 1 {
t.Fatalf("len(inputV4Rules) = %d, want 1", len(inputV4Rules))
}

existingLoopBackRuleV6, err := findLoopBackRule(conn, nftables.TableFamilyIPv6, runner.nft6.Filter, inputV6, addrV6)
if err != nil {
t.Fatalf("findLoopBackRule() failed: %v", err)
Expand All @@ -855,21 +803,11 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
runner.DelLoopbackRule(addr)
runner.DelLoopbackRule(addrV6)

inputV4Rules, err = conn.GetRules(runner.nft4.Filter, inputV4)
if err != nil {
t.Fatalf("conn.GetRules() failed: %v", err)
}
if len(inputV4Rules) != 2 {
t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules))
}
checkChainRules(t, conn, inputV4, 3)
checkChainRules(t, conn, inputV6, 3)
}

func TestNFTAddAndDelHookRule(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip(t.Name(), " requires privileges to create a namespace in order to run")
return
}

conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner.AddChains()
Expand All @@ -880,72 +818,24 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
if err != nil {
t.Fatalf("failed to get forwardChain: %v", err)
}

forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}

if len(forwardChainRules) != 1 {
t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules))
}

inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT")
if err != nil {
t.Fatalf("failed to get inputChain: %v", err)
}

inputChainRules, err := conn.GetRules(inputChain.Table, inputChain)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}

if len(inputChainRules) != 1 {
t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules))
}

postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING")
if err != nil {
t.Fatalf("failed to get postroutingChain: %v", err)
}

postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}

if len(postroutingChainRules) != 1 {
t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
}
checkChainRules(t, conn, forwardChain, 1)
checkChainRules(t, conn, inputChain, 1)
checkChainRules(t, conn, postroutingChain, 1)

runner.DelHooks(t.Logf)

forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}

if len(forwardChainRules) != 0 {
t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules))
}

inputChainRules, err = conn.GetRules(inputChain.Table, inputChain)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}

if len(inputChainRules) != 0 {
t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules))
}

postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain)
if err != nil {
t.Fatalf("failed to get rules: %v", err)
}

if len(postroutingChainRules) != 0 {
t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
}
checkChainRules(t, conn, forwardChain, 0)
checkChainRules(t, conn, inputChain, 0)
checkChainRules(t, conn, postroutingChain, 0)
}

type testFWDetector struct {
Expand Down

0 comments on commit e53be95

Please sign in to comment.