diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 6fb34f0a27584d..b4f82aaf68bd28 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -1647,6 +1647,10 @@ int bpf_prog_test_run_sk_lookup(struct bpf_prog *prog, const union bpf_attr *kattr, union bpf_attr __user *uattr); bool bpf_prog_test_check_kfunc_call(u32 kfunc_id, struct module *owner); +bool bpf_prog_test_is_acquire_kfunc(u32 kfunc_id, struct module *owner); +bool bpf_prog_test_is_release_kfunc(u32 kfunc_id, struct module *owner); +enum bpf_return_type bpf_prog_test_get_kfunc_return_type(u32 kfunc_id, + struct module *owner); bool btf_ctx_access(int off, int size, enum bpf_access_type type, const struct bpf_prog *prog, struct bpf_insn_access_aux *info); @@ -1874,6 +1878,24 @@ static inline bool bpf_prog_test_check_kfunc_call(u32 kfunc_id, return false; } +static inline bool bpf_prog_test_is_acquire_kfunc(u32 kfunc_id, + struct module *owner) +{ + return false; +} + +static inline bool bpf_prog_test_is_release_kfunc(u32 kfunc_id, + struct module *owner) +{ + return false; +} + +static inline enum bpf_return_type +bpf_prog_test_get_kfunc_return_type(u32 kfunc_id, struct module *owner) +{ + return __BPF_RET_TYPE_MAX; +} + static inline void bpf_map_put(struct bpf_map *map) { } diff --git a/include/linux/btf.h b/include/linux/btf.h index 464f22bf7d5fc8..87019548e37cb2 100644 --- a/include/linux/btf.h +++ b/include/linux/btf.h @@ -321,5 +321,6 @@ static inline int bpf_btf_mod_struct_access(struct kfunc_btf_id_list *klist, extern struct kfunc_btf_id_list bpf_tcp_ca_kfunc_list; extern struct kfunc_btf_id_list prog_test_kfunc_list; +extern struct kfunc_btf_id_list xdp_kfunc_list; #endif diff --git a/kernel/bpf/btf.c b/kernel/bpf/btf.c index 8b3c15f4359d07..d0e5101b8c2d82 100644 --- a/kernel/bpf/btf.c +++ b/kernel/bpf/btf.c @@ -735,6 +735,7 @@ const struct btf_type *btf_type_by_id(const struct btf *btf, u32 type_id) return NULL; return btf->types[type_id]; } +EXPORT_SYMBOL_GPL(btf_type_by_id); /* * Regular int is not a bit field and it must be either @@ -6502,3 +6503,4 @@ int bpf_btf_mod_struct_access(struct kfunc_btf_id_list *klist, DEFINE_KFUNC_BTF_ID_LIST(bpf_tcp_ca_kfunc_list); DEFINE_KFUNC_BTF_ID_LIST(prog_test_kfunc_list); +DEFINE_KFUNC_BTF_ID_LIST(xdp_kfunc_list); diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index 46dd9575596724..a678ddc97e0fec 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -232,6 +232,28 @@ struct sock * noinline bpf_kfunc_call_test3(struct sock *sk) return sk; } +struct prog_test_ref_kfunc { + int a; + int b; +}; + +static struct prog_test_ref_kfunc prog_test_struct; + +noinline struct prog_test_ref_kfunc *bpf_kfunc_call_test_acquire(char *ptr) +{ + /* randomly return NULL */ + if (get_jiffies_64() % 2) + return NULL; + prog_test_struct.a = 42; + prog_test_struct.b = 108; + return &prog_test_struct; +} + +noinline void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) +{ + return; +} + __diag_pop(); ALLOW_ERROR_INJECTION(bpf_modify_return_test, ERRNO); @@ -240,8 +262,14 @@ BTF_SET_START(test_sk_kfunc_ids) BTF_ID(func, bpf_kfunc_call_test1) BTF_ID(func, bpf_kfunc_call_test2) BTF_ID(func, bpf_kfunc_call_test3) +BTF_ID(func, bpf_kfunc_call_test_acquire) +BTF_ID(func, bpf_kfunc_call_test_release) BTF_SET_END(test_sk_kfunc_ids) +BTF_ID_LIST(test_sk_acq_rel) +BTF_ID(func, bpf_kfunc_call_test_acquire) +BTF_ID(func, bpf_kfunc_call_test_release) + bool bpf_prog_test_check_kfunc_call(u32 kfunc_id, struct module *owner) { if (btf_id_set_contains(&test_sk_kfunc_ids, kfunc_id)) @@ -249,6 +277,33 @@ bool bpf_prog_test_check_kfunc_call(u32 kfunc_id, struct module *owner) return bpf_check_mod_kfunc_call(&prog_test_kfunc_list, kfunc_id, owner); } +bool bpf_prog_test_is_acquire_kfunc(u32 kfunc_id, struct module *owner) +{ + if (!owner) /* bpf_kfunc_call_test_acquire */ + return kfunc_id == test_sk_acq_rel[0]; + return bpf_is_mod_acquire_kfunc(&prog_test_kfunc_list, kfunc_id, owner); +} + +bool bpf_prog_test_is_release_kfunc(u32 kfunc_id, struct module *owner) +{ + if (!owner) /* bpf_kfunc_call_test_release */ + return kfunc_id == test_sk_acq_rel[1]; + return bpf_is_mod_release_kfunc(&prog_test_kfunc_list, kfunc_id, owner); +} + +enum bpf_return_type bpf_prog_test_get_kfunc_return_type(u32 kfunc_id, + struct module *owner) +{ + if (!owner) { + if (kfunc_id == test_sk_acq_rel[0]) + return RET_PTR_TO_BTF_ID_OR_NULL; + else + return __BPF_RET_TYPE_MAX; + } + return bpf_get_mod_kfunc_return_type(&prog_test_kfunc_list, kfunc_id, + owner); +} + static void *bpf_test_init(const union bpf_attr *kattr, u32 size, u32 headroom, u32 tailroom) { diff --git a/net/core/filter.c b/net/core/filter.c index 8e8d3b49c29767..4e320de4472d1d 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -9948,6 +9948,12 @@ const struct bpf_prog_ops sk_filter_prog_ops = { .test_run = bpf_prog_test_run_skb, }; +static int xdp_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id); + const struct bpf_verifier_ops tc_cls_act_verifier_ops = { .get_func_proto = tc_cls_act_func_proto, .is_valid_access = tc_cls_act_is_valid_access, @@ -9955,17 +9961,67 @@ const struct bpf_verifier_ops tc_cls_act_verifier_ops = { .gen_prologue = tc_cls_act_prologue, .gen_ld_abs = bpf_gen_ld_abs, .check_kfunc_call = bpf_prog_test_check_kfunc_call, + .is_acquire_kfunc = bpf_prog_test_is_acquire_kfunc, + .is_release_kfunc = bpf_prog_test_is_release_kfunc, + .get_kfunc_return_type = bpf_prog_test_get_kfunc_return_type, + /* resuse the callback, there is nothing xdp specific in it */ + .btf_struct_access = xdp_btf_struct_access, }; const struct bpf_prog_ops tc_cls_act_prog_ops = { .test_run = bpf_prog_test_run_skb, }; +static bool xdp_is_acquire_kfunc(u32 kfunc_id, struct module *owner) +{ + return bpf_is_mod_acquire_kfunc(&xdp_kfunc_list, kfunc_id, owner); +} + +static bool xdp_is_release_kfunc(u32 kfunc_id, struct module *owner) +{ + return bpf_is_mod_release_kfunc(&xdp_kfunc_list, kfunc_id, owner); +} + +static enum bpf_return_type xdp_get_kfunc_return_type(u32 kfunc_id, + struct module *owner) +{ + return bpf_get_mod_kfunc_return_type(&xdp_kfunc_list, kfunc_id, owner); +} + +static int xdp_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id) +{ + int ret = __BPF_REG_TYPE_MAX; + struct module *mod; + + if (atype != BPF_READ) + return -EACCES; + + if (btf_is_module(btf)) { + mod = btf_try_get_module(btf); + if (!mod) + return -ENXIO; + ret = bpf_btf_mod_struct_access(&xdp_kfunc_list, mod, log, btf, t, off, size, + atype, next_btf_id); + module_put(mod); + } + if (ret == __BPF_REG_TYPE_MAX) + return btf_struct_access(log, btf, t, off, size, atype, next_btf_id); + return ret; +} + const struct bpf_verifier_ops xdp_verifier_ops = { .get_func_proto = xdp_func_proto, .is_valid_access = xdp_is_valid_access, .convert_ctx_access = xdp_convert_ctx_access, .gen_prologue = bpf_noop_prologue, + .is_acquire_kfunc = xdp_is_acquire_kfunc, + .is_release_kfunc = xdp_is_release_kfunc, + .get_kfunc_return_type = xdp_get_kfunc_return_type, + .btf_struct_access = xdp_btf_struct_access, }; const struct bpf_prog_ops xdp_prog_ops = { diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c index 202fa5eacd0f9b..7b4bfe79300244 100644 --- a/net/core/net_namespace.c +++ b/net/core/net_namespace.c @@ -299,6 +299,7 @@ struct net *get_net_ns_by_id(const struct net *net, int id) return peer; } +EXPORT_SYMBOL_GPL(get_net_ns_by_id); /* * setup_net runs the initializers for the network namespace object. diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c index 770a63103c7a42..69450b5c32f016 100644 --- a/net/netfilter/nf_conntrack_core.c +++ b/net/netfilter/nf_conntrack_core.c @@ -11,6 +11,9 @@ #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include #include #include #include @@ -2451,6 +2454,252 @@ static int kill_all(struct nf_conn *i, void *data) return net_eq(nf_ct_net(i), data); } +/* Unstable Kernel Helpers for XDP hook */ +static struct nf_conn *__bpf_nf_ct_lookup(struct net *net, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u8 protonum, + u64 netns_id, u64 flags) +{ + struct nf_conntrack_tuple_hash *hash; + struct nf_conntrack_tuple tuple; + + if (flags != IP_CT_DIR_ORIGINAL && flags != IP_CT_DIR_REPLY) + return ERR_PTR(-EINVAL); + + memset(&tuple, 0, sizeof(tuple)); + + switch (tuple_len) { + case sizeof(bpf_tuple->ipv4): + tuple.src.l3num = AF_INET; + tuple.src.u3.ip = bpf_tuple->ipv4.saddr; + tuple.src.u.tcp.port = bpf_tuple->ipv4.sport; + tuple.dst.u3.ip = bpf_tuple->ipv4.daddr; + tuple.dst.u.tcp.port = bpf_tuple->ipv4.dport; + break; + case sizeof(bpf_tuple->ipv6): + tuple.src.l3num = AF_INET6; + memcpy(tuple.src.u3.ip6, bpf_tuple->ipv6.saddr, sizeof(bpf_tuple->ipv6.saddr)); + tuple.src.u.tcp.port = bpf_tuple->ipv6.sport; + memcpy(tuple.dst.u3.ip6, bpf_tuple->ipv6.daddr, sizeof(bpf_tuple->ipv6.daddr)); + tuple.dst.u.tcp.port = bpf_tuple->ipv6.dport; + break; + default: + return ERR_PTR(-EAFNOSUPPORT); + } + + tuple.dst.protonum = protonum; + tuple.dst.dir = flags; + + if ((s32)netns_id >= 0) { + if ((s32)netns_id > S32_MAX) + return ERR_PTR(-EINVAL); + net = get_net_ns_by_id(net, netns_id); + if (!net) + return ERR_PTR(-EINVAL); + } + + hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple); + if ((s32)netns_id >= 0) + put_net(net); + if (!hash) + return ERR_PTR(-ENOENT); + return nf_ct_tuplehash_to_ctrack(hash); +} + +static struct nf_conn *bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u8 protonum, + u64 netns_id, u64 *flags_err) +{ + struct xdp_buff *ctx = (struct xdp_buff *)xdp_ctx; + struct net *caller_net; + struct nf_conn *nfct; + + if (!flags_err) + return NULL; + if (!bpf_tuple) { + *flags_err = -EINVAL; + return NULL; + } + caller_net = dev_net(ctx->rxq->dev); + nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple_len, protonum, + netns_id, *flags_err); + if (IS_ERR(nfct)) { + *flags_err = PTR_ERR(nfct); + return NULL; + } + return nfct; +} + +static struct nf_conn *bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u8 protonum, + u64 netns_id, u64 *flags_err) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct net *caller_net; + struct nf_conn *nfct; + + if (!flags_err) + return NULL; + if (!bpf_tuple) { + *flags_err = -EINVAL; + return NULL; + } + caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk); + nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple_len, protonum, + netns_id, *flags_err); + if (IS_ERR(nfct)) { + *flags_err = PTR_ERR(nfct); + return NULL; + } + return nfct; +} + +struct nf_conn *bpf_xdp_ct_lookup_tcp(struct xdp_md *xdp_ctx, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u64 netns_id, + u64 *flags_err) +{ + return bpf_xdp_ct_lookup(xdp_ctx, bpf_tuple, tuple_len, IPPROTO_TCP, + netns_id, flags_err); +} + +struct nf_conn *bpf_xdp_ct_lookup_udp(struct xdp_md *xdp_ctx, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u64 netns_id, + u64 *flags_err) +{ + return bpf_xdp_ct_lookup(xdp_ctx, bpf_tuple, tuple_len, IPPROTO_UDP, + netns_id, flags_err); +} + +struct nf_conn *bpf_skb_ct_lookup_tcp(struct __sk_buff *skb_ctx, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u64 netns_id, + u64 *flags_err) +{ + return bpf_skb_ct_lookup(skb_ctx, bpf_tuple, tuple_len, IPPROTO_TCP, + netns_id, flags_err); +} + +struct nf_conn *bpf_skb_ct_lookup_udp(struct __sk_buff *skb_ctx, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u64 netns_id, + u64 *flags_err) +{ + return bpf_skb_ct_lookup(skb_ctx, bpf_tuple, tuple_len, IPPROTO_UDP, + netns_id, flags_err); +} + +void bpf_ct_release(struct nf_conn *nfct) +{ + if (!nfct) + return; + nf_ct_put(nfct); +} + +BTF_SET_START(nf_conntrack_xdp_ids) +BTF_ID(func, bpf_xdp_ct_lookup_tcp) +BTF_ID(func, bpf_xdp_ct_lookup_udp) +BTF_ID(func, bpf_ct_release) +BTF_SET_END(nf_conntrack_xdp_ids) + +BTF_SET_START(nf_conntrack_skb_ids) +BTF_ID(func, bpf_skb_ct_lookup_tcp) +BTF_ID(func, bpf_skb_ct_lookup_udp) +BTF_ID(func, bpf_ct_release) +BTF_SET_END(nf_conntrack_skb_ids) + +BTF_ID_LIST(nf_conntrack_ids) +BTF_ID(func, bpf_xdp_ct_lookup_tcp) +BTF_ID(func, bpf_xdp_ct_lookup_udp) +BTF_ID(func, bpf_skb_ct_lookup_tcp) +BTF_ID(func, bpf_skb_ct_lookup_udp) +BTF_ID(func, bpf_ct_release) +BTF_ID(struct, nf_conn) + +bool nf_is_acquire_kfunc(u32 kfunc_id) +{ + return kfunc_id == nf_conntrack_ids[0] || + kfunc_id == nf_conntrack_ids[1] || + kfunc_id == nf_conntrack_ids[2] || + kfunc_id == nf_conntrack_ids[3]; +} + +bool nf_is_release_kfunc(u32 kfunc_id) +{ + return kfunc_id == nf_conntrack_ids[4]; +} + +enum bpf_return_type nf_get_kfunc_return_type(u32 kfunc_id) +{ + if (kfunc_id == nf_conntrack_ids[0] || + kfunc_id == nf_conntrack_ids[1] || + kfunc_id == nf_conntrack_ids[2] || + kfunc_id == nf_conntrack_ids[3]) + return RET_PTR_TO_BTF_ID_OR_NULL; + return __BPF_RET_TYPE_MAX; +} + +static int nf_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id) +{ + const struct btf_type *nf_conn_type; + size_t end; + + nf_conn_type = btf_type_by_id(btf, nf_conntrack_ids[5]); + if (!nf_conn_type) + return -EACCES; + /* This won't work (not even with btf_struct_ids_match for off == 0), + * see below for the reason: + * https://lore.kernel.org/bpf/20211028014428.rsuq6rkfvqzq23tg@apollo.localdomain + */ + if (t != nf_conn_type) /* skip */ + return __BPF_REG_TYPE_MAX; + + if (atype != BPF_READ) + return -EACCES; + + switch (off) { + case offsetof(struct nf_conn, status): + end = offsetofend(struct nf_conn, status); + break; + /* TODO(v2): We should do it per field offset */ + case bpf_ctx_range(struct nf_conn, proto): + end = offsetofend(struct nf_conn, proto); + break; + default: + return -EACCES; + } + + if (off + size > end) + return -EACCES; + + return NOT_INIT; +} + +static struct kfunc_btf_id_set nf_ct_xdp_kfunc_set = { + .owner = THIS_MODULE, + .set = &nf_conntrack_xdp_ids, + .is_acquire_kfunc = nf_is_acquire_kfunc, + .is_release_kfunc = nf_is_release_kfunc, + .get_kfunc_return_type = nf_get_kfunc_return_type, + .btf_struct_access = nf_btf_struct_access, +}; + +static struct kfunc_btf_id_set nf_ct_skb_kfunc_set = { + .owner = THIS_MODULE, + .set = &nf_conntrack_skb_ids, + .is_acquire_kfunc = nf_is_acquire_kfunc, + .is_release_kfunc = nf_is_release_kfunc, + .get_kfunc_return_type = nf_get_kfunc_return_type, + .btf_struct_access = nf_btf_struct_access, +}; + void nf_conntrack_cleanup_start(void) { conntrack_gc_work.exiting = true; @@ -2459,6 +2708,9 @@ void nf_conntrack_cleanup_start(void) void nf_conntrack_cleanup_end(void) { + unregister_kfunc_btf_id_set(&xdp_kfunc_list, &nf_ct_xdp_kfunc_set); + unregister_kfunc_btf_id_set(&prog_test_kfunc_list, &nf_ct_skb_kfunc_set); + RCU_INIT_POINTER(nf_ct_hook, NULL); cancel_delayed_work_sync(&conntrack_gc_work.dwork); kvfree(nf_conntrack_hash); @@ -2745,6 +2997,9 @@ int nf_conntrack_init_start(void) conntrack_gc_work_init(&conntrack_gc_work); queue_delayed_work(system_power_efficient_wq, &conntrack_gc_work.dwork, HZ); + register_kfunc_btf_id_set(&prog_test_kfunc_list, &nf_ct_skb_kfunc_set); + register_kfunc_btf_id_set(&xdp_kfunc_list, &nf_ct_xdp_kfunc_set); + return 0; err_proto: