Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 77 additions & 42 deletions client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,31 @@ const (
ExitSetupFailed = 1
)

func cleanup(interfaceName string, ipt *iptables.IPTables, hostPeerIp string, dockerCIDRs []string, bridgeIp string, bridgeInterface string) {
func cleanup(interfaceName string, ipt *iptables.IPTables, hostPeerIp string, dockerCIDRs []string, dockerInterfaces []string, bridgeIp string, bridgeInterface string) {
if ipt != nil {
fmt.Println("Removing iptables NAT rules")
for _, cidr := range dockerCIDRs {
fmt.Printf("Removing NAT rule for CIDR: %s\n", cidr)
ipt.Delete("nat", "POSTROUTING", "-s", hostPeerIp, "-d", cidr, "-o", "docker+", "-j", "MASQUERADE")
for i, cidr := range dockerCIDRs {
if i < len(dockerInterfaces) {
fmt.Printf("Removing NAT rule for CIDR: %s on interface: %s\n", cidr, dockerInterfaces[i])
ipt.Delete("nat", "POSTROUTING", "-s", hostPeerIp, "-d", cidr, "-o", dockerInterfaces[i], "-j", "MASQUERADE")
}
}

fmt.Println("Removing iptables filter rules")
for _, cidr := range dockerCIDRs {
fmt.Printf("Removing filter rule for CIDR: %s\n", cidr)
ipt.Delete("filter", "DOCKER", "-s", hostPeerIp, "-d", cidr, "-o", "docker+", "-j", "ACCEPT")
for i, cidr := range dockerCIDRs {
if i < len(dockerInterfaces) {
fmt.Printf("Removing filter rule for CIDR: %s on interface: %s\n", cidr, dockerInterfaces[i])
ipt.Delete("filter", "DOCKER", "-s", hostPeerIp, "-d", cidr, "-o", dockerInterfaces[i], "-j", "ACCEPT")
}
}

if bridgeIp != "" {
fmt.Println("Removing bridge DOCKER-USER rules")
for _, cidr := range dockerCIDRs {
fmt.Printf("Removing DOCKER-USER rule for bridge IP %s to CIDR: %s\n", bridgeIp, cidr)
ipt.Delete("filter", "DOCKER-USER", "-s", bridgeIp, "-d", cidr, "-i", bridgeInterface, "-o", "docker+", "-j", "ACCEPT")
for i, cidr := range dockerCIDRs {
if i < len(dockerInterfaces) {
fmt.Printf("Removing DOCKER-USER rule for bridge IP %s to CIDR: %s on interface: %s\n", bridgeIp, cidr, dockerInterfaces[i])
ipt.Delete("filter", "DOCKER-USER", "-s", bridgeIp, "-d", cidr, "-i", bridgeInterface, "-o", dockerInterfaces[i], "-j", "ACCEPT")
}
}
}
}
Expand Down Expand Up @@ -105,6 +111,18 @@ func main() {
}
dockerCIDRs := strings.Split(dockerCIDRsString, ",")

dockerInterfacesString := os.Getenv("DOCKER_INTERFACES")
if dockerInterfacesString == "" {
fmt.Printf("DOCKER_INTERFACES is not set\n")
os.Exit(ExitSetupFailed)
}
dockerInterfaces := strings.Split(dockerInterfacesString, ",")

if len(dockerCIDRs) != len(dockerInterfaces) {
fmt.Printf("DOCKER_CIDRS and DOCKER_INTERFACES must have equal number of elements: got %d CIDRs and %d interfaces\n", len(dockerCIDRs), len(dockerInterfaces))
os.Exit(ExitSetupFailed)
}

enableDockerFilterString := os.Getenv("ENABLE_DOCKER_FILTER")
enableDockerFilter := strings.ToLower(enableDockerFilterString) == "true"

Expand All @@ -117,7 +135,7 @@ func main() {
defer func() {
if r := recover(); r != nil {
fmt.Printf("Panic occurred: %v\n", r)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}
}()
Expand Down Expand Up @@ -155,13 +173,13 @@ func main() {
vmIpNet, err := netlink.ParseIPNet(vmPeerIp + "/32")
if err != nil {
fmt.Printf("Could not parse VM peer IPNet: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}
hostIpNet, err := netlink.ParseIPNet(hostPeerIp + "/32")
if err != nil {
fmt.Printf("Could not parse host peer IPNet: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

Expand All @@ -171,14 +189,14 @@ func main() {
err = netlink.AddrAdd(wireguard, &addr)
if err != nil {
fmt.Printf("Failed to assign IP to WireGuard interface: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

c, err := wgctrl.New()
if err != nil {
fmt.Printf("Failed to create wgctrl client: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

Expand All @@ -187,35 +205,35 @@ func main() {
vmPrivateKey, err := wgtypes.ParseKey(vmPrivateKeyString)
if err != nil {
fmt.Printf("Failed to parse VM private key: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

hostPublicKey, err := wgtypes.ParseKey(hostPublicKeyString)
if err != nil {
fmt.Printf("Failed to parse host public key: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

wildcardIpNet, err := netlink.ParseIPNet("0.0.0.0/0")
if err != nil {
fmt.Printf("Failed to parse wildcard IPNet: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

ips, err := net.LookupIP("host.docker.internal")
if err != nil || len(ips) == 0 {
fmt.Printf("Failed to lookup IP: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

persistentKeepaliveInterval, err := time.ParseDuration("25s")
if err != nil {
fmt.Printf("Failed to parse duration: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

Expand All @@ -237,41 +255,46 @@ func main() {
})
if err != nil {
fmt.Printf("Failed to configure wireguard device: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

err = netlink.LinkSetUp(wireguard)
if err != nil {
fmt.Printf("Failed to set wireguard link to up: %v\n", err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

ipt, err = iptables.New()
if err != nil {
fmt.Printf("Failed to create new iptables client: %v\n", err)
cleanup(interfaceName, nil, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
cleanup(interfaceName, nil, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}

fmt.Println("Adding specific iptables NAT rules for Docker networks")

// Add specific iptables NAT rules for each Docker network CIDR
// This restricts masquerading only to traffic destined for Docker networks
// instead of masquerading all traffic from hostPeerIp
for _, cidr := range dockerCIDRs {
fmt.Printf("Adding NAT rule for Docker CIDR: %s\n", cidr)
// and uses the specific interface for each network instead of docker+ wildcard
for i, cidr := range dockerCIDRs {
if i >= len(dockerInterfaces) {
fmt.Printf("Warning: No interface found for CIDR %s, skipping\n", cidr)
continue
}
interfaceItem := dockerInterfaces[i]
fmt.Printf("Adding NAT rule for Docker CIDR: %s on interface: %s\n", cidr, interfaceItem)
err = ipt.AppendUnique(
"nat", "POSTROUTING",
"-s", hostPeerIp,
"-d", cidr,
"-o", "docker+",
"-o", interfaceItem,
"-j", "MASQUERADE",
)
if err != nil {
fmt.Printf("Failed to add iptables nat rule for CIDR %s: %v\n", cidr, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
fmt.Printf("Failed to add iptables nat rule for CIDR %s on interface %s: %v\n", cidr, interfaceItem, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}
}
Expand All @@ -281,26 +304,32 @@ func main() {

// Add specific iptables filter rules for each Docker network CIDR
// This allows traffic from hostPeerIp only to specific Docker networks
for _, cidr := range dockerCIDRs {
fmt.Printf("Adding filter rule for Docker CIDR: %s\n", cidr)
// and uses the specific interface for each network instead of docker+ wildcard
for i, cidr := range dockerCIDRs {
if i >= len(dockerInterfaces) {
fmt.Printf("Warning: No interface found for CIDR %s, skipping\n", cidr)
continue
}
interfaceItem := dockerInterfaces[i]
fmt.Printf("Adding filter rule for Docker CIDR: %s on interface: %s\n", cidr, interfaceItem)
err = ipt.DeleteIfExists("filter", "DOCKER",
"-s", hostPeerIp,
"-d", cidr,
"-o", "docker+",
"-o", interfaceItem,
"-j", "ACCEPT")
if err != nil {
fmt.Printf("Failed to delete iptables filter rule for CIDR %s: %v\n", cidr, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
fmt.Printf("Failed to delete iptables filter rule for CIDR %s on interface %s: %v\n", cidr, interfaceItem, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}
err = ipt.Insert("filter", "DOCKER", 1,
"-s", hostPeerIp,
"-d", cidr,
"-o", "docker+",
"-o", interfaceItem,
"-j", "ACCEPT")
if err != nil {
fmt.Printf("Failed to insert iptables filter rule for CIDR %s: %v\n", cidr, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
fmt.Printf("Failed to insert iptables filter rule for CIDR %s on interface %s: %v\n", cidr, interfaceItem, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}
}
Expand All @@ -310,17 +339,23 @@ func main() {
fmt.Printf("Adding bridge traffic DOCKER-USER rules for bridge IP: %s\n", bridgeIp)

// Add DOCKER-USER rule to accept bridge traffic from bridge IP to Docker networks
for _, cidr := range dockerCIDRs {
fmt.Printf("Adding DOCKER-USER rule for bridge IP %s to Docker CIDR: %s\n", bridgeIp, cidr)
// and uses the specific interface for each network instead of docker+ wildcard
for i, cidr := range dockerCIDRs {
if i >= len(dockerInterfaces) {
fmt.Printf("Warning: No interface found for CIDR %s, skipping\n", cidr)
continue
}
interfaceItem := dockerInterfaces[i]
fmt.Printf("Adding DOCKER-USER rule for bridge IP %s to Docker CIDR: %s on interface: %s\n", bridgeIp, cidr, interfaceItem)
err = ipt.AppendUnique("filter", "DOCKER-USER",
"-s", bridgeIp,
"-d", cidr,
"-i", bridgeInterface,
"-o", "docker+",
"-o", interfaceItem,
"-j", "ACCEPT")
if err != nil {
fmt.Printf("Failed to add DOCKER-USER rule for bridge IP %s to CIDR %s: %v\n", bridgeIp, cidr, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, bridgeIp, bridgeInterface)
fmt.Printf("Failed to add DOCKER-USER rule for bridge IP %s to CIDR %s on interface %s: %v\n", bridgeIp, cidr, interfaceItem, err)
cleanup(interfaceName, ipt, hostPeerIp, dockerCIDRs, dockerInterfaces, bridgeIp, bridgeInterface)
os.Exit(ExitSetupFailed)
}
}
Expand Down
17 changes: 13 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ func main() {
for {
logger.Verbosef("Setting up Wireguard on Docker Desktop VM\n")

dockerCIDRs := networkManager.GetDockerCIDRs(ctx, cli)
if len(dockerCIDRs) == 0 {
dockerNetworkInfos := networkManager.GetDockerNetworkInfo(ctx, cli)
if len(dockerNetworkInfos) == 0 {
logger.Verbosef("No Docker networks found, skipping VM setup\n")
time.Sleep(5 * time.Second)
continue
}

err = setupVm(ctx, cli, portNumWg, hostPeerIp, vmPeerIp, interfaceNameWg, dockerCIDRs, enableDockerFilter, bridgeIp, bridgeInterface, hostPrivateKey, vmPrivateKey)
err = setupVm(ctx, cli, portNumWg, hostPeerIp, vmPeerIp, interfaceNameWg, dockerNetworkInfos, enableDockerFilter, bridgeIp, bridgeInterface, hostPrivateKey, vmPrivateKey)
if err != nil {
logger.Errorf("Failed to setup VM: %v", err)
time.Sleep(5 * time.Second)
Expand Down Expand Up @@ -306,7 +306,7 @@ func setupVm(
hostPeerIp string,
vmPeerIp string,
interfaceName string,
dockerCIDRs []string,
dockerNetworkInfos []networkmanager.DockerNetworkInfo,
enableDockerFilter bool,
bridgeIp string,
bridgeInterface string,
Expand All @@ -327,12 +327,21 @@ func setupVm(
io.Copy(os.Stdout, pullStream)
}

// Convert network infos to environment variables
var dockerCIDRs []string
var dockerInterfaces []string
for _, info := range dockerNetworkInfos {
dockerCIDRs = append(dockerCIDRs, info.CIDR)
dockerInterfaces = append(dockerInterfaces, info.Interface)
}

env := []string{
"SERVER_PORT=" + strconv.Itoa(serverPort),
"HOST_PEER_IP=" + hostPeerIp,
"VM_PEER_IP=" + vmPeerIp,
"INTERFACE_NAME=" + interfaceName,
"DOCKER_CIDRS=" + strings.Join(dockerCIDRs, ","),
"DOCKER_INTERFACES=" + strings.Join(dockerInterfaces, ","),
"ENABLE_DOCKER_FILTER=" + strconv.FormatBool(enableDockerFilter),
"HOST_PUBLIC_KEY=" + hostPrivateKey.PublicKey().String(),
"VM_PRIVATE_KEY=" + vmPrivateKey.String(),
Expand Down
63 changes: 63 additions & 0 deletions networkmanager/networkmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ func (manager *NetworkManager) ProcessDockerNetworkDestroy(network network.Inspe
delete(manager.DockerNetworks, network.ID)
}

type DockerNetworkInfo struct {
CIDR string
Interface string
}

func (manager *NetworkManager) GetDockerCIDRs(ctx context.Context, cli *client.Client) []string {
var cidrs []string

Expand Down Expand Up @@ -150,3 +155,61 @@ func (manager *NetworkManager) GetDockerCIDRs(ctx context.Context, cli *client.C
}
return cidrs
}

func (manager *NetworkManager) GetDockerNetworkInfo(ctx context.Context, cli *client.Client) []DockerNetworkInfo {
var networkInfos []DockerNetworkInfo

networks, err := cli.NetworkList(ctx, network.ListOptions{})
if err != nil {
fmt.Printf("Failed to list Docker networks: %v\n", err)
return networkInfos
}

for _, dockerNet := range networks {
if dockerNet.Scope == "local" {
// Get detailed network info to access interface name
detailedNet, err := cli.NetworkInspect(ctx, dockerNet.ID, network.InspectOptions{})
if err != nil {
fmt.Printf("Failed to inspect Docker network %s: %v\n", dockerNet.ID, err)
continue
}

// Extract interface name from network options
interfaceName := ""
if detailedNet.Options != nil {
if iface, exists := detailedNet.Options["com.docker.network.bridge.name"]; exists {
interfaceName = iface
}
}

// If no explicit interface name, generate based on network ID (Docker's default naming)
if interfaceName == "" {
if len(detailedNet.ID) >= 12 {
interfaceName = "br-" + detailedNet.ID[:12]
} else {
interfaceName = "docker0" // fallback for default bridge
}
}

for _, config := range detailedNet.IPAM.Config {
if config.Subnet != "" {
// Parse the subnet to check if it's IPv4
_, ipNet, err := net.ParseCIDR(config.Subnet)
if err != nil {
fmt.Printf("Failed to parse CIDR %s: %v\n", config.Subnet, err)
continue
}

// Only include IPv4 CIDRs, exclude IPv6
if ipNet.IP.To4() != nil {
networkInfos = append(networkInfos, DockerNetworkInfo{
CIDR: config.Subnet,
Interface: interfaceName,
})
}
}
}
}
}
return networkInfos
}