diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index a7dd7d8c9af2d..af54a878ac277 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -603,7 +603,7 @@ static bool mptcp_check_data_fin(struct sock *sk) WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1); WRITE_ONCE(msk->rcv_data_fin, 0); - sk->sk_shutdown |= RCV_SHUTDOWN; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); smp_mb__before_atomic(); /* SHUTDOWN must be visible first */ switch (sk->sk_state) { @@ -910,7 +910,7 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk) /* hopefully temporary hack: propagate shutdown status * to msk, when all subflows agree on it */ - sk->sk_shutdown |= RCV_SHUTDOWN; + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); smp_mb__before_atomic(); /* SHUTDOWN must be visible first */ sk->sk_data_ready(sk); @@ -2526,7 +2526,7 @@ static void mptcp_check_fastclose(struct mptcp_sock *msk) } inet_sk_state_store(sk, TCP_CLOSE); - sk->sk_shutdown = SHUTDOWN_MASK; + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); smp_mb__before_atomic(); /* SHUTDOWN must be visible first */ set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags); @@ -2958,7 +2958,7 @@ bool __mptcp_close(struct sock *sk, long timeout) bool do_cancel_work = false; int subflows_alive = 0; - sk->sk_shutdown = SHUTDOWN_MASK; + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) { mptcp_listen_inuse_dec(sk); @@ -3101,7 +3101,7 @@ static int mptcp_disconnect(struct sock *sk, int flags) mptcp_pm_data_reset(msk); mptcp_ca_reset(sk); - sk->sk_shutdown = 0; + WRITE_ONCE(sk->sk_shutdown, 0); sk_error_report(sk); return 0; } @@ -3806,9 +3806,6 @@ static __poll_t mptcp_check_writeable(struct mptcp_sock *msk) { struct sock *sk = (struct sock *)msk; - if (unlikely(sk->sk_shutdown & SEND_SHUTDOWN)) - return EPOLLOUT | EPOLLWRNORM; - if (sk_stream_is_writeable(sk)) return EPOLLOUT | EPOLLWRNORM; @@ -3826,6 +3823,7 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock, struct sock *sk = sock->sk; struct mptcp_sock *msk; __poll_t mask = 0; + u8 shutdown; int state; msk = mptcp_sk(sk); @@ -3842,17 +3840,22 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock, return inet_csk_listen_poll(ssock->sk); } + shutdown = READ_ONCE(sk->sk_shutdown); + if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) + mask |= EPOLLHUP; + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; + if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) { mask |= mptcp_check_readable(msk); - mask |= mptcp_check_writeable(msk); + if (shutdown & SEND_SHUTDOWN) + mask |= EPOLLOUT | EPOLLWRNORM; + else + mask |= mptcp_check_writeable(msk); } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) { /* cf tcp_poll() note about TFO */ mask |= EPOLLOUT | EPOLLWRNORM; } - if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) - mask |= EPOLLHUP; - if (sk->sk_shutdown & RCV_SHUTDOWN) - mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; /* This barrier is coupled with smp_wmb() in __mptcp_error_report() */ smp_rmb();