Skip to content

Commit

Permalink
Add support for Message Polls (serenity-rs#2836)
Browse files Browse the repository at this point in the history
This PR adds support for Polls attached to messages. This has been
tested for deserialising, creating, getting answer voters, and ending a
poll.

The builder is designed differently, as there are many required fields,
so I used the typestate pattern to prevent the user from using the
builder until it is ready.
  • Loading branch information
GnomedDev authored Apr 25, 2024
1 parent 5668654 commit 5c7d8af
Show file tree
Hide file tree
Showing 13 changed files with 533 additions and 20 deletions.
10 changes: 10 additions & 0 deletions src/builder/create_message.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use super::create_poll::Ready;
#[cfg(feature = "http")]
use super::{check_overflow, Builder};
use super::{
CreateActionRow,
CreateAllowedMentions,
CreateAttachment,
CreateEmbed,
CreatePoll,
EditAttachments,
};
#[cfg(feature = "http")]
Expand Down Expand Up @@ -69,6 +71,8 @@ pub struct CreateMessage {
flags: Option<MessageFlags>,
pub(crate) attachments: EditAttachments,
enforce_nonce: bool,
#[serde(skip_serializing_if = "Option::is_none")]
poll: Option<CreatePoll<super::create_poll::Ready>>,

// The following fields are handled separately.
#[serde(skip)]
Expand Down Expand Up @@ -288,6 +292,12 @@ impl CreateMessage {
self.enforce_nonce = enforce_nonce;
self
}

/// Sets the [`Poll`] for this message.
pub fn poll(mut self, poll: CreatePoll<Ready>) -> Self {
self.poll = Some(poll);
self
}
}

#[cfg(feature = "http")]
Expand Down
172 changes: 172 additions & 0 deletions src/builder/create_poll.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use crate::model::channel::{PollLayoutType, PollMedia, PollMediaEmoji};

#[derive(serde::Serialize, Clone, Debug)]
pub struct NeedsQuestion;
#[derive(serde::Serialize, Clone, Debug)]
pub struct NeedsAnswers;
#[derive(serde::Serialize, Clone, Debug)]
pub struct NeedsDuration;
#[derive(serde::Serialize, Clone, Debug)]
pub struct Ready;

mod sealed {
use super::*;

pub trait Sealed {}

impl Sealed for NeedsQuestion {}
impl Sealed for NeedsAnswers {}
impl Sealed for NeedsDuration {}
impl Sealed for Ready {}
}

use sealed::*;

/// "Only text is supported."
#[derive(serde::Serialize, Clone, Debug)]
struct CreatePollMedia {
text: String,
}

#[derive(serde::Serialize, Clone, Debug)]
#[must_use = "Builders do nothing unless built"]
pub struct CreatePoll<Stage: Sealed> {
question: CreatePollMedia,
answers: Vec<CreatePollAnswer>,
duration: u8,
allow_multiselect: bool,
layout_type: Option<PollLayoutType>,

#[serde(skip)]
_stage: Stage,
}

impl Default for CreatePoll<NeedsQuestion> {
/// See the documentation of [`Self::new`].
fn default() -> Self {
// Producing dummy values is okay as we must transition through all `Stage`s before firing,
// which fills in the values with real values.
Self {
question: CreatePollMedia {
text: String::default(),
},
answers: Vec::default(),
duration: u8::default(),
allow_multiselect: false,
layout_type: None,

_stage: NeedsQuestion,
}
}
}

impl CreatePoll<NeedsQuestion> {
/// Creates a builder for creating a Poll.
///
/// This must be transitioned through in order, to provide all required fields.
///
/// ```rust
/// use serenity::builder::{CreateMessage, CreatePoll, CreatePollAnswer};
///
/// let poll = CreatePoll::new()
/// .question("Cats or Dogs?")
/// .answers(vec![
/// CreatePollAnswer::new().emoji("🐱".to_string()).text("Cats!"),
/// CreatePollAnswer::new().emoji("🐶".to_string()).text("Dogs!"),
/// CreatePollAnswer::new().text("Neither..."),
/// ])
/// .duration(std::time::Duration::from_secs(60 * 60 * 24 * 7));
///
/// let message = CreateMessage::new().poll(poll);
/// ```
pub fn new() -> Self {
Self::default()
}

/// Sets the question to be polled.
pub fn question(self, text: impl Into<String>) -> CreatePoll<NeedsAnswers> {
CreatePoll {
question: CreatePollMedia {
text: text.into(),
},
answers: self.answers,
duration: self.duration,
allow_multiselect: self.allow_multiselect,
layout_type: self.layout_type,
_stage: NeedsAnswers,
}
}
}

impl CreatePoll<NeedsAnswers> {
/// Sets the answers that can be picked from.
pub fn answers(self, answers: Vec<CreatePollAnswer>) -> CreatePoll<NeedsDuration> {
CreatePoll {
question: self.question,
answers,
duration: self.duration,
allow_multiselect: self.allow_multiselect,
layout_type: self.layout_type,
_stage: NeedsDuration,
}
}
}

impl CreatePoll<NeedsDuration> {
/// Sets the duration for the Poll to run for.
///
/// This must be less than a week, and will be rounded to hours towards zero.
pub fn duration(self, duration: std::time::Duration) -> CreatePoll<Ready> {
let hours = duration.as_secs() / 3600;

CreatePoll {
question: self.question,
answers: self.answers,
duration: hours.try_into().unwrap_or(168),
allow_multiselect: self.allow_multiselect,
layout_type: self.layout_type,
_stage: Ready,
}
}
}

impl<Stage: Sealed> CreatePoll<Stage> {
/// Sets the layout type for the Poll to take.
///
/// This is currently only ever [`PollLayoutType::Default`], and is optional.
pub fn layout_type(mut self, layout_type: PollLayoutType) -> Self {
self.layout_type = Some(layout_type);
self
}

/// Allows users to select multiple answers for the Poll.
pub fn allow_multiselect(mut self) -> Self {
self.allow_multiselect = true;
self
}
}

#[derive(serde::Serialize, Clone, Debug, Default)]
#[must_use = "Builders do nothing unless built"]
pub struct CreatePollAnswer {
poll_media: PollMedia,
}

impl CreatePollAnswer {
/// Creates a builder for a Poll answer.
///
/// [`Self::text`] or [`Self::emoji`] must be provided.
pub fn new() -> Self {
Self::default()
}

pub fn text(mut self, text: impl Into<String>) -> Self {
self.poll_media.text = Some(text.into());
self
}

pub fn emoji(mut self, emoji: impl Into<PollMediaEmoji>) -> Self {
self.poll_media.emoji = Some(emoji.into());
self
}
}
2 changes: 2 additions & 0 deletions src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ mod create_interaction_response;
mod create_interaction_response_followup;
mod create_invite;
mod create_message;
pub mod create_poll;
mod create_scheduled_event;
mod create_stage_instance;
mod create_sticker;
Expand Down Expand Up @@ -91,6 +92,7 @@ pub use create_interaction_response::*;
pub use create_interaction_response_followup::*;
pub use create_invite::*;
pub use create_message::*;
pub use create_poll::{CreatePoll, CreatePollAnswer};
pub use create_scheduled_event::*;
pub use create_stage_instance::*;
pub use create_sticker::*;
Expand Down
6 changes: 6 additions & 0 deletions src/client/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,12 @@ fn update_cache_with_event(
Event::EntitlementDelete(event) => FullEvent::EntitlementDelete {
entitlement: event.entitlement,
},
Event::MessagePollVoteAdd(event) => FullEvent::MessagePollVoteAdd {
event,
},
Event::MessagePollVoteRemove(event) => FullEvent::MessagePollVoteRemove {
event,
},
};

Some((event, extra_event))
Expand Down
8 changes: 8 additions & 0 deletions src/client/event_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,14 @@ event_handler! {
/// be set.
EntitlementDelete { entitlement: Entitlement } => async fn entitlement_delete(&self, ctx: Context);

/// Dispatched when a user votes on a message poll.
///
/// This will be dispatched multiple times if multiple answers are selected.
MessagePollVoteAdd { event: MessagePollVoteAddEvent } => async fn poll_vote_add(&self, ctx: Context);

/// Dispatched when a user removes a previous vote on a poll.
MessagePollVoteRemove { event: MessagePollVoteRemoveEvent } => async fn poll_vote_remove(&self, ctx: Context);

/// Dispatched when an HTTP rate limit is hit
Ratelimit { data: RatelimitInfo } => async fn ratelimit(&self);
}
Expand Down
60 changes: 60 additions & 0 deletions src/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3124,6 +3124,66 @@ impl Http {
.await
}

/// Get a list of users that voted for this specific answer.
pub async fn get_poll_answer_voters(
&self,
channel_id: ChannelId,
message_id: MessageId,
answer_id: AnswerId,
after: Option<UserId>,
limit: Option<u8>,
) -> Result<Vec<User>> {
#[derive(serde::Deserialize)]
struct VotersResponse {
users: Vec<User>,
}

let mut params = Vec::with_capacity(2);
if let Some(after) = after {
params.push(("after", after.to_string()));
}

if let Some(limit) = limit {
params.push(("limit", limit.to_string()));
}

let resp: VotersResponse = self
.fire(Request {
body: None,
multipart: None,
headers: None,
method: LightMethod::Get,
route: Route::ChannelPollGetAnswerVoters {
channel_id,
message_id,
answer_id,
},
params: Some(params),
})
.await?;

Ok(resp.users)
}

pub async fn expire_poll(
&self,
channel_id: ChannelId,
message_id: MessageId,
) -> Result<Message> {
self.fire(Request {
body: None,
multipart: None,
headers: None,
method: LightMethod::Post,
route: Route::ChannelPollExpire {
channel_id,
message_id,
},
params: None,
})
.await
}

/// Gets information about the current application.
///
/// **Note**: Only applications may use this endpoint.
Expand Down
8 changes: 8 additions & 0 deletions src/http/routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ routes! ('a, {
api!("/channels/{}/users/@me/threads/archived/private", channel_id),
Some(RatelimitingKind::PathAndId(channel_id.into()));

ChannelPollGetAnswerVoters { channel_id: ChannelId, message_id: MessageId, answer_id: AnswerId },
api!("/channels/{}/polls/{}/answers/{}", channel_id, message_id, answer_id),
Some(RatelimitingKind::PathAndId(channel_id.into()));

ChannelPollExpire { channel_id: ChannelId, message_id: MessageId },
api!("/channels/{}/polls/{}/expire", channel_id, message_id),
Some(RatelimitingKind::PathAndId(channel_id.into()));

Gateway,
api!("/gateway"),
Some(RatelimitingKind::Path);
Expand Down
25 changes: 25 additions & 0 deletions src/model/channel/channel_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,31 @@ impl ChannelId {
) -> Result<ThreadsData> {
http.as_ref().get_channel_joined_archived_private_threads(self, before, limit).await
}

/// Get a list of users that voted for this specific answer.
///
/// # Errors
///
/// If the message does not have a poll.
pub async fn get_poll_answer_voters(
self,
http: impl AsRef<Http>,
message_id: MessageId,
answer_id: AnswerId,
after: Option<UserId>,
limit: Option<u8>,
) -> Result<Vec<User>> {
http.as_ref().get_poll_answer_voters(self, message_id, answer_id, after, limit).await
}

/// Ends the [`Poll`] on a given [`MessageId`], if there is one.
///
/// # Errors
///
/// If the message does not have a poll, or if the poll was not created by the current user.
pub async fn end_poll(self, http: impl AsRef<Http>, message_id: MessageId) -> Result<Message> {
http.as_ref().expire_poll(self, message_id).await
}
}

#[cfg(feature = "model")]
Expand Down
Loading

0 comments on commit 5c7d8af

Please sign in to comment.