diff --git a/Cargo.lock b/Cargo.lock index 5bddae79..0788cb26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,6 +60,17 @@ dependencies = [ "serde", ] +[[package]] +name = "async-lock" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" +dependencies = [ + "event-listener", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -423,6 +434,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.12" @@ -621,6 +641,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1907,6 +1937,27 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "moka" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8261cd88c312e0004c1d51baad2980c66528dfdb2bee62003e643a4d8f86b077" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -2227,6 +2278,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + [[package]] name = "potential_utf" version = "0.1.2" @@ -2627,6 +2684,15 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.0.7" @@ -2740,6 +2806,7 @@ dependencies = [ "chrono", "hyper 1.8.1", "mime_guess", + "moka", "rand 0.9.2", "regex", "reqwest 0.12.24", diff --git a/rustmail/Cargo.toml b/rustmail/Cargo.toml index bd3d5780..63d4ce4c 100644 --- a/rustmail/Cargo.toml +++ b/rustmail/Cargo.toml @@ -25,6 +25,7 @@ serde_json = "1.0.145" rand = "0.9.2" base64 = "0.22.1" subtle = "2.6.1" +moka = { version = "0.12", features = ["future"] } [dependencies.uuid] version = "1.18.1" diff --git a/rustmail/src/api/handler/bot/tickets.rs b/rustmail/src/api/handler/bot/tickets.rs index be88070d..2e1ec752 100644 --- a/rustmail/src/api/handler/bot/tickets.rs +++ b/rustmail/src/api/handler/bot/tickets.rs @@ -1,9 +1,11 @@ +use crate::prelude::api::*; use crate::prelude::types::*; use axum::{ Json, extract::{Query, State}, http::StatusCode, }; +use axum_extra::extract::CookieJar; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::Mutex; @@ -62,6 +64,7 @@ pub struct PaginatedThreadsResponse { pub async fn handle_tickets_bot( State(bot_state): State>>, + jar: CookieJar, Query(params): Query, ) -> (StatusCode, Json) { let db_pool = { @@ -79,6 +82,47 @@ pub async fn handle_tickets_bot( } }; + let session_cookie = jar.get("session_id"); + let user_id = if let Some(cookie) = session_cookie { + get_user_id_from_session(cookie.value(), &db_pool).await + } else { + return ( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({ + "error": "No session found" + })), + ); + }; + + let (guild_id, bot_http) = { + let state_lock = bot_state.lock().await; + let guild_id = match &state_lock.config { + Some(config) => config.bot.get_staff_guild_id(), + None => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "Config not initialized" + })), + ); + } + }; + let bot_http = match &state_lock.bot_http { + Some(http) => http.clone(), + None => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "Bot HTTP not initialized" + })), + ); + } + }; + (guild_id, bot_http) + }; + + let is_admin = is_admin_or_owner(&user_id, guild_id, bot_http.clone()).await; + if let Some(id) = params.id { let thread = match sqlx::query!( r#" @@ -174,21 +218,59 @@ pub async fn handle_tickets_bot( user_id: thread.user_id, user_name: thread.user_name, channel_id: thread.channel_id, - created_at: thread.created_at + created_at: thread + .created_at .flatten() .and_then(|ts: String| ts.parse::().ok()) .unwrap_or_default(), new_message_number: thread.new_message_number.unwrap_or_default(), status: thread.status, user_left: thread.user_left, - closed_at: thread.closed_at.flatten().and_then(|ts: String| ts.parse::().ok()), + closed_at: thread + .closed_at + .flatten() + .and_then(|ts: String| ts.parse::().ok()), closed_by: thread.closed_by, category_id: thread.category_id, category_name: thread.category_name, - required_permissions: thread.required_permissions, + required_permissions: thread.required_permissions.clone(), messages, }; + if !is_admin { + if let Some(ref category_id) = complete.category_id { + if !category_id.is_empty() { + match get_user_permissions_in_category( + &user_id, + guild_id, + category_id, + bot_http.clone(), + ) + .await + { + Some(perms) => { + if !can_view_channel(perms) { + return ( + StatusCode::FORBIDDEN, + Json(serde_json::json!({ + "error": "You don't have permission to view this ticket" + })), + ); + } + } + None => { + return ( + StatusCode::FORBIDDEN, + Json(serde_json::json!({ + "error": "Failed to check permissions" + })), + ); + } + } + } + } + } + return (StatusCode::OK, Json(serde_json::json!(complete))); } @@ -218,11 +300,11 @@ pub async fn handle_tickets_bot( _ => "DESC", }; - let count_query = format!("SELECT COUNT(*) as count FROM threads WHERE {}", where_clause); - let total: i64 = match sqlx::query_scalar(&count_query) - .fetch_one(&db_pool) - .await - { + let count_query = format!( + "SELECT COUNT(*) as count FROM threads WHERE {}", + where_clause + ); + let _total: i64 = match sqlx::query_scalar(&count_query).fetch_one(&db_pool).await { Ok(count) => count, Err(err) => { eprintln!("Erreur SQL count: {:?}", err); @@ -235,8 +317,6 @@ pub async fn handle_tickets_bot( } }; - let total_pages = (total as f64 / page_size as f64).ceil() as i64; - let query_str = format!( r#" SELECT @@ -261,21 +341,24 @@ pub async fn handle_tickets_bot( where_clause, sort_column, sort_order, page_size, offset ); - let threads_query = match sqlx::query_as::<_, ( - String, - i64, - String, - String, - Option, - Option, - i64, - bool, - Option, - Option, - Option, - Option, - Option, - )>(&query_str) + let threads_query = match sqlx::query_as::< + _, + ( + String, + i64, + String, + String, + Option, + Option, + i64, + bool, + Option, + Option, + Option, + Option, + Option, + ), + >(&query_str) .fetch_all(&db_pool) .await { @@ -316,55 +399,68 @@ pub async fn handle_tickets_bot( placeholders ); - let mut messages_query = sqlx::query_as::<_, ( - i64, - String, - i64, - String, - bool, - Option, - Option, - Option, - String, - String, - )>(&messages_query_str); + let mut messages_query = sqlx::query_as::< + _, + ( + i64, + String, + i64, + String, + bool, + Option, + Option, + Option, + String, + String, + ), + >(&messages_query_str); for thread_id in &thread_ids { messages_query = messages_query.bind(thread_id); } - let all_messages = messages_query.fetch_all(&db_pool).await.unwrap_or_else(|err| { - eprintln!("Erreur SQL messages batch: {:?}", err); - Vec::new() - }); + let all_messages = messages_query + .fetch_all(&db_pool) + .await + .unwrap_or_else(|err| { + eprintln!("Erreur SQL messages batch: {:?}", err); + Vec::new() + }); let mut messages_by_thread: std::collections::HashMap> = std::collections::HashMap::new(); for msg in all_messages { - messages_by_thread.entry(msg.1.clone()).or_insert_with(Vec::new).push(ThreadMessage { - id: msg.0, - thread_id: msg.1.clone(), - user_id: msg.2, - user_name: msg.3, - is_anonymous: msg.4, - dm_message_id: msg.5, - inbox_message_id: msg.6, - message_number: msg.7, - created_at: msg.8, - content: msg.9, - }); + messages_by_thread + .entry(msg.1.clone()) + .or_insert_with(Vec::new) + .push(ThreadMessage { + id: msg.0, + thread_id: msg.1.clone(), + user_id: msg.2, + user_name: msg.3, + is_anonymous: msg.4, + dm_message_id: msg.5, + inbox_message_id: msg.6, + message_number: msg.7, + created_at: msg.8, + content: msg.9, + }); } for thread in threads_query { - let messages = messages_by_thread.get(&thread.0).cloned().unwrap_or_default(); + let messages = messages_by_thread + .get(&thread.0) + .cloned() + .unwrap_or_default(); threads.push(CompleteThread { id: thread.0.clone(), user_id: thread.1, user_name: thread.2, channel_id: thread.3, - created_at: thread.4 + created_at: thread + .4 .and_then(|ts: String| ts.parse::().ok()) .unwrap_or_default(), new_message_number: thread.5.unwrap_or_default(), @@ -379,12 +475,47 @@ pub async fn handle_tickets_bot( }); } + let filtered_threads = if is_admin { + threads + } else { + let mut filtered = Vec::new(); + for thread in threads { + let can_view = if let Some(ref category_id) = thread.category_id { + if category_id.is_empty() { + true + } else { + match get_user_permissions_in_category( + &user_id, + guild_id, + category_id, + bot_http.clone(), + ) + .await + { + Some(perms) => can_view_channel(perms), + None => false, + } + } + } else { + true + }; + + if can_view { + filtered.push(thread); + } + } + filtered + }; + + let filtered_total = filtered_threads.len() as i64; + let filtered_total_pages = (filtered_total as f64 / page_size as f64).ceil() as i64; + let response = PaginatedThreadsResponse { - threads, - total, + threads: filtered_threads, + total: filtered_total, page, page_size, - total_pages, + total_pages: filtered_total_pages, }; (StatusCode::OK, Json(serde_json::json!(response))) diff --git a/rustmail/src/api/utils/mod.rs b/rustmail/src/api/utils/mod.rs index f4d433bd..2e0a70ad 100644 --- a/rustmail/src/api/utils/mod.rs +++ b/rustmail/src/api/utils/mod.rs @@ -1,5 +1,9 @@ pub mod get_user_id_from_session; +pub mod permissions_cache; pub mod ping_internal; +pub mod user_permissions; pub use get_user_id_from_session::*; +pub use permissions_cache::*; pub use ping_internal::*; +pub use user_permissions::*; diff --git a/rustmail/src/api/utils/permissions_cache.rs b/rustmail/src/api/utils/permissions_cache.rs new file mode 100644 index 00000000..caa9dd19 --- /dev/null +++ b/rustmail/src/api/utils/permissions_cache.rs @@ -0,0 +1,26 @@ +use moka::future::Cache; +use std::time::Duration; + +pub fn get_admin_cache() -> &'static Cache { + use std::sync::OnceLock; + static ADMIN_CACHE: OnceLock> = OnceLock::new(); + + ADMIN_CACHE.get_or_init(|| { + Cache::builder() + .max_capacity(1000) + .time_to_live(Duration::from_secs(300)) + .build() + }) +} + +pub fn get_permissions_cache() -> &'static Cache<(String, String), u64> { + use std::sync::OnceLock; + static PERMISSIONS_CACHE: OnceLock> = OnceLock::new(); + + PERMISSIONS_CACHE.get_or_init(|| { + Cache::builder() + .max_capacity(10000) + .time_to_live(Duration::from_secs(300)) + .build() + }) +} diff --git a/rustmail/src/api/utils/user_permissions.rs b/rustmail/src/api/utils/user_permissions.rs new file mode 100644 index 00000000..1bc04531 --- /dev/null +++ b/rustmail/src/api/utils/user_permissions.rs @@ -0,0 +1,137 @@ +use crate::prelude::api::*; +use serenity::all::{ + ChannelId, GuildId, Http, PermissionOverwriteType, Permissions, RoleId, UserId, +}; +use std::sync::Arc; + +pub async fn is_admin_or_owner(user_id: &str, guild_id: u64, bot_http: Arc) -> bool { + let cache = get_admin_cache(); + if let Some(is_admin) = cache.get(user_id).await { + return is_admin; + } + + let user_id_num = match user_id.parse::() { + Ok(id) => id, + Err(_) => return false, + }; + let guild_id_obj = GuildId::new(guild_id); + let user_id_obj = UserId::new(user_id_num); + + let guild = match guild_id_obj.to_partial_guild(bot_http.clone()).await { + Ok(g) => g, + Err(_) => return false, + }; + + if guild.owner_id == user_id_obj { + cache.insert(user_id.to_string(), true).await; + return true; + } + + let member = match guild_id_obj.member(bot_http.clone(), user_id_obj).await { + Ok(m) => m, + Err(_) => { + cache.insert(user_id.to_string(), false).await; + return false; + } + }; + + let is_admin = member.roles.iter().any(|role_id| { + guild + .roles + .get(role_id) + .map(|role| role.permissions.contains(Permissions::ADMINISTRATOR)) + .unwrap_or(false) + }); + + cache.insert(user_id.to_string(), is_admin).await; + is_admin +} + +pub async fn get_user_permissions_in_category( + user_id: &str, + guild_id: u64, + category_id: &str, + bot_http: Arc, +) -> Option { + let cache = get_permissions_cache(); + let cache_key = (user_id.to_string(), category_id.to_string()); + if let Some(perms) = cache.get(&cache_key).await { + return Some(perms); + } + + let user_id_num = user_id.parse::().ok()?; + let category_id_num = category_id.parse::().ok()?; + let guild_id_obj = GuildId::new(guild_id); + let user_id_obj = UserId::new(user_id_num); + let channel_id = ChannelId::new(category_id_num); + + let member = match guild_id_obj.member(bot_http.clone(), user_id_obj).await { + Ok(m) => m, + Err(_) => return None, + }; + + let guild_roles = match guild_id_obj.roles(bot_http.clone()).await { + Ok(roles) => roles, + Err(_) => return None, + }; + + let category = match channel_id.to_channel(bot_http.clone()).await { + Ok(channel) => match channel.guild() { + Some(guild_channel) => guild_channel, + None => return None, + }, + Err(_) => return None, + }; + + let everyone_role_id = RoleId::new(guild_id_obj.get()); + let mut permissions = guild_roles + .get(&everyone_role_id) + .map(|r| r.permissions.bits()) + .unwrap_or(0u64); + + for overwrite in &category.permission_overwrites { + if let PermissionOverwriteType::Role(role_id) = overwrite.kind { + if role_id == everyone_role_id { + let deny = overwrite.deny.bits(); + let allow = overwrite.allow.bits(); + permissions = (permissions & !deny) | allow; + break; + } + } + } + + let mut combined_allow = 0u64; + let mut combined_deny = 0u64; + + for role_id in &member.roles { + for overwrite in &category.permission_overwrites { + if let PermissionOverwriteType::Role(overwrite_role_id) = overwrite.kind { + if overwrite_role_id == *role_id { + combined_allow |= overwrite.allow.bits(); + combined_deny |= overwrite.deny.bits(); + } + } + } + } + + permissions = (permissions & !combined_deny) | combined_allow; + + for overwrite in &category.permission_overwrites { + if let PermissionOverwriteType::Member(member_id) = overwrite.kind { + if member_id == user_id_obj { + let deny = overwrite.deny.bits(); + let allow = overwrite.allow.bits(); + permissions = (permissions & !deny) | allow; + break; + } + } + } + + cache.insert(cache_key, permissions).await; + Some(permissions) +} + +pub fn can_view_channel(user_permissions: u64) -> bool { + const VIEW_CHANNEL: u64 = 1 << 10; + (user_permissions & VIEW_CHANNEL) == VIEW_CHANNEL +} diff --git a/rustmail/src/commands/close/common.rs b/rustmail/src/commands/close/common.rs index 5a346bd3..cc5eb2e9 100644 --- a/rustmail/src/commands/close/common.rs +++ b/rustmail/src/commands/close/common.rs @@ -1,12 +1,26 @@ use std::time::Duration; +pub fn format_duration(seconds: u64) -> String { + if seconds < 60 { + format!("{}s", seconds) + } else if seconds < 3600 { + format!("{}m", seconds / 60) + } else if seconds < 86400 { + format!("{}h{}m", seconds / 3600, (seconds % 3600) / 60) + } else { + format!("{}d{}h", seconds / 86400, (seconds % 86400) / 3600) + } +} + pub fn parse_duration_spec(spec: &str) -> Option { if spec.is_empty() { return None; } + let mut total: u64 = 0; let mut num: u64 = 0; let mut has_unit_segment = false; + for ch in spec.chars() { if ch.is_ascii_digit() { let digit = ch.to_digit(10)? as u64; @@ -24,13 +38,15 @@ pub fn parse_duration_spec(spec: &str) -> Option { has_unit_segment = true; } } + if num > 0 { if has_unit_segment { total = total.saturating_add(num); } else { - total = total.saturating_add(num * 60); + total = total.saturating_add(num); } } + if total == 0 { None } else { diff --git a/rustmail/src/commands/close/slash_command/close.rs b/rustmail/src/commands/close/slash_command/close.rs index 1e78bf23..11fd098c 100644 --- a/rustmail/src/commands/close/slash_command/close.rs +++ b/rustmail/src/commands/close/slash_command/close.rs @@ -4,17 +4,17 @@ use crate::prelude::db::*; use crate::prelude::errors::*; use crate::prelude::handlers::*; use crate::prelude::i18n::*; +use crate::prelude::modules::*; use crate::prelude::utils::*; use chrono::Utc; use serenity::FutureExt; use serenity::all::{ - CommandDataOptionValue, CommandInteraction, CommandOptionType, Context, CreateCommand, - CreateCommandOption, GuildId, ResolvedOption, UserId, + Channel, CommandDataOptionValue, CommandInteraction, CommandOptionType, Context, CreateCommand, + CreateCommandOption, GuildId, PermissionOverwriteType, ResolvedOption, RoleId, UserId, }; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use tokio::time::sleep; pub struct CloseCommand; @@ -191,16 +191,32 @@ impl RegistrableCommand for CloseCommand { } if let Some(delay) = duration { + if let Ok(Some(existing)) = get_scheduled_closure(&thread.id, db_pool).await { + let remaining = existing.close_at - Utc::now().timestamp(); + if remaining > 0 { + let old_human = format_duration(remaining as u64); + + let mut warn_params = HashMap::new(); + warn_params.insert("old_time".to_string(), old_human); + + let response = MessageBuilder::system_message(&ctx, &config) + .translated_content( + "close.replacing_existing_closure", + Some(&warn_params), + Some(command.user.id), + command.guild_id.map(|g| g.get()), + ) + .await + .to_channel(command.channel_id) + .build_interaction_message_followup() + .await; + + let _ = command.create_followup(&ctx.http, response).await; + } + } + let delay_secs = delay.as_secs(); - let human = if delay_secs < 60 { - format!("{}s", delay_secs) - } else if delay_secs < 3600 { - format!("{}m", delay_secs / 60) - } else if delay_secs < 86400 { - format!("{}h{}m", delay_secs / 3600, (delay_secs % 3600) / 60) - } else { - format!("{}d{}h", delay_secs / 86400, (delay_secs % 86400) / 3600) - }; + let human = format_duration(delay_secs); let mut params = HashMap::new(); params.insert("time".to_string(), human); @@ -217,7 +233,7 @@ impl RegistrableCommand for CloseCommand { .build_interaction_message_followup() .await; - let _ = command.create_followup(&ctx.http, response).await; + command.create_followup(&ctx.http, response).await } else { let response = MessageBuilder::system_message(&ctx, &config) .translated_content( @@ -231,18 +247,58 @@ impl RegistrableCommand for CloseCommand { .build_interaction_message_followup() .await; - let _ = command.create_followup(&ctx.http, response).await; + command.create_followup(&ctx.http, response).await }; + let closed_by = command.user.id.to_string(); + + let (category_id, category_name, required_permissions) = + match command.channel_id.to_channel(&ctx.http).await { + Ok(Channel::Guild(guild_channel)) => { + let guild_id = guild_channel.guild_id; + let parent_id = guild_channel.parent_id; + + let category_id = + parent_id.map(|id| id.to_string()).unwrap_or_default(); + + let category_name = if let Some(parent_id) = parent_id { + guild_id + .channels(&ctx.http) + .await + .ok() + .and_then(|channels| { + channels.get(&parent_id).map(|c| c.name.clone()) + }) + .unwrap_or_default() + } else { + String::new() + }; + + let guild = guild_id.to_partial_guild(&ctx.http).await.ok(); + let everyone_role_id = RoleId::new(guild_id.get()); + + let mut perms = guild + .and_then(|g| { + g.roles.get(&everyone_role_id).map(|r| r.permissions.bits()) + }) + .unwrap_or(0u64); + + for overwrite in &guild_channel.permission_overwrites { + if let PermissionOverwriteType::Role(_) = overwrite.kind { + let allow = overwrite.allow.bits(); + let deny = overwrite.deny.bits(); + perms = (perms & !deny) | allow; + } + } + + (category_id, category_name, perms) + } + _ => (String::new(), String::new(), 0u64), + }; + let thread_id = thread.id.clone(); let close_at = Utc::now().timestamp() + delay.as_secs() as i64; - let closed_by = command.user.id.to_string(); - let category_id = get_category_id_from_command(&ctx, &command).await; - let category_name = get_category_name_from_command(&ctx, &command).await; - let required_permissions = - get_required_permissions_channel_from_command(&ctx, &command).await; - if let Err(e) = upsert_scheduled_closure( &thread_id, close_at, @@ -257,103 +313,8 @@ impl RegistrableCommand for CloseCommand { { eprintln!("Failed to persist scheduled closure: {e:?}"); } - let channel_id = command.channel_id; - let config_clone = config.clone(); - let ctx_clone = ctx.clone(); - let user_id_clone = user_id; - let thread_id_for_task = thread_id.clone(); - - tokio::spawn(async move { - sleep(delay).await; - if let Some(pool) = config_clone.db_pool.as_ref() { - if let Ok(Some(record)) = - get_scheduled_closure(&thread_id_for_task, pool).await - { - if record.close_at <= Utc::now().timestamp() { - let _ = close_thread( - &thread_id_for_task, - &record.closed_by, - &record.category_id, - &record.category_name, - record.required_permissions.parse::().unwrap_or(0), - pool, - ) - .await; - let _ = delete_scheduled_closure(&thread_id_for_task, pool).await; - - let community_guild_id = - GuildId::new(config_clone.bot.get_community_guild_id()); - - let user_still_member = community_guild_id - .member(&ctx_clone.http, user_id_clone) - .await - .is_ok(); - - if !record.silent && user_still_member { - let _ = - MessageBuilder::system_message(&ctx_clone, &config_clone) - .content(&config_clone.bot.close_message) - .to_user(user_id_clone) - .send(true) - .await; - } - let _ = channel_id.delete(&ctx_clone.http).await; - } else { - let delay2 = - (record.close_at - Utc::now().timestamp()).max(1) as u64; - let config_clone2 = config_clone.clone(); - let ctx_clone2 = ctx_clone.clone(); - let thread_id_again = thread_id_for_task.clone(); - - tokio::spawn(async move { - sleep(Duration::from_secs(delay2)).await; - if let Some(pool2) = config_clone2.db_pool.as_ref() { - if let Ok(Some(r2)) = - get_scheduled_closure(&thread_id_again, pool2).await - { - if r2.close_at <= Utc::now().timestamp() { - let _ = close_thread( - &thread_id_again, - &r2.closed_by, - &r2.category_id, - &r2.category_name, - r2.required_permissions - .parse::() - .unwrap_or(0), - pool2, - ) - .await; - let _ = delete_scheduled_closure( - &thread_id_again, - pool2, - ) - .await; - let community_guild_id = GuildId::new( - config_clone2.bot.get_community_guild_id(), - ); - let user_still_member = community_guild_id - .member(&ctx_clone2.http, user_id_clone) - .await - .is_ok(); - if !r2.silent && user_still_member { - let _ = MessageBuilder::system_message( - &ctx_clone2, - &config_clone2, - ) - .content(&config_clone2.bot.close_message) - .to_user(user_id_clone) - .send(true) - .await; - } - let _ = channel_id.delete(&ctx_clone2.http).await; - } - } - } - }); - } - } - } - }); + + schedule_one(&ctx, &config, thread_id, close_at); return Ok(()); } diff --git a/rustmail/src/commands/close/text_command/close.rs b/rustmail/src/commands/close/text_command/close.rs index 7cd515f4..fefd9206 100644 --- a/rustmail/src/commands/close/text_command/close.rs +++ b/rustmail/src/commands/close/text_command/close.rs @@ -3,13 +3,13 @@ use crate::prelude::config::*; use crate::prelude::db::*; use crate::prelude::errors::*; use crate::prelude::handlers::*; +use crate::prelude::modules::*; use crate::prelude::utils::*; use chrono::Utc; use serenity::all::{Channel, Context, GuildId, Message, PermissionOverwriteType, RoleId, UserId}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use tokio::time::sleep; pub async fn close( ctx: Context, @@ -127,21 +127,35 @@ pub async fn close( } if let Some(delay) = duration { + if let Ok(Some(existing)) = get_scheduled_closure(&thread.id, db_pool).await { + let remaining = existing.close_at - Utc::now().timestamp(); + if remaining > 0 { + let old_human = format_duration(remaining as u64); + + let mut warn_params = HashMap::new(); + warn_params.insert("old_time".to_string(), old_human); + + let _ = MessageBuilder::system_message(&ctx, config) + .translated_content( + "close.replacing_existing_closure", + Some(&warn_params), + Some(msg.author.id), + msg.guild_id.map(|g| g.get()), + ) + .await + .to_channel(msg.channel_id) + .send(true) + .await; + } + } + let delay_secs = delay.as_secs(); - let human = if delay_secs < 60 { - format!("{}s", delay_secs) - } else if delay_secs < 3600 { - format!("{}m", delay_secs / 60) - } else if delay_secs < 86400 { - format!("{}h{}m", delay_secs / 3600, (delay_secs % 3600) / 60) - } else { - format!("{}d{}h", delay_secs / 86400, (delay_secs % 86400) / 3600) - }; + let human = format_duration(delay_secs); let mut params = HashMap::new(); params.insert("time".to_string(), human); let _ = if silent { - let _ = MessageBuilder::system_message(&ctx, config) + MessageBuilder::system_message(&ctx, config) .translated_content( "close.silent_closing", Some(¶ms), @@ -151,9 +165,9 @@ pub async fn close( .await .to_channel(msg.channel_id) .send(true) - .await; + .await } else { - let _ = MessageBuilder::system_message(&ctx, config) + MessageBuilder::system_message(&ctx, config) .translated_content( "close.closing", Some(¶ms), @@ -163,51 +177,53 @@ pub async fn close( .await .to_channel(msg.channel_id) .send(true) - .await; + .await }; let closed_by = msg.author.id.to_string(); - let category_id = match msg.channel_id.to_channel(&ctx.http).await { - Ok(channel) => match channel.category() { - Some(category) => category.id.to_string(), - None => String::new(), - }, - _ => String::new(), - }; - let category_name = match msg.channel_id.to_channel(&ctx.http).await { - Ok(channel) => match channel.category() { - Some(category) => category.name.clone(), - None => String::new(), - }, - _ => String::new(), - }; - let required_permissions = match msg.channel_id.to_channel(&ctx.http).await { - Ok(Channel::Guild(guild_channel)) => { - let guild_id = guild_channel.guild_id; - let guild = guild_id.to_partial_guild(&ctx.http).await.ok(); + let (category_id, category_name, required_permissions) = + match msg.channel_id.to_channel(&ctx.http).await { + Ok(Channel::Guild(guild_channel)) => { + let guild_id = guild_channel.guild_id; + let parent_id = guild_channel.parent_id; + + let category_id = parent_id.map(|id| id.to_string()).unwrap_or_default(); + + let category_name = if let Some(parent_id) = parent_id { + guild_id + .channels(&ctx.http) + .await + .ok() + .and_then(|channels| channels.get(&parent_id).map(|c| c.name.clone())) + .unwrap_or_default() + } else { + String::new() + }; - let everyone_role_id = RoleId::new(guild_id.get()); + let guild = guild_id.to_partial_guild(&ctx.http).await.ok(); + let everyone_role_id = RoleId::new(guild_id.get()); - let mut perms = guild - .and_then(|g| g.roles.get(&everyone_role_id).map(|r| r.permissions.bits())) - .unwrap_or(0u64); + let mut perms = guild + .and_then(|g| g.roles.get(&everyone_role_id).map(|r| r.permissions.bits())) + .unwrap_or(0u64); - for overwrite in &guild_channel.permission_overwrites { - if let PermissionOverwriteType::Role(_) = overwrite.kind { - let allow = overwrite.allow.bits(); - let deny = overwrite.deny.bits(); - perms = (perms & !deny) | allow; + for overwrite in &guild_channel.permission_overwrites { + if let PermissionOverwriteType::Role(_) = overwrite.kind { + let allow = overwrite.allow.bits(); + let deny = overwrite.deny.bits(); + perms = (perms & !deny) | allow; + } } - } - perms - } - _ => 0u64, - }; + (category_id, category_name, perms) + } + _ => (String::new(), String::new(), 0u64), + }; let thread_id = thread.id.clone(); let close_at = Utc::now().timestamp() + delay.as_secs() as i64; + if let Err(e) = upsert_scheduled_closure( &thread_id, close_at, @@ -222,94 +238,8 @@ pub async fn close( { eprintln!("Failed to persist scheduled closure: {e:?}"); } - let channel_id = msg.channel_id; - let config_clone = config.clone(); - let ctx_clone = ctx.clone(); - let user_id_clone = user_id; - let thread_id_for_task = thread_id.clone(); - - tokio::spawn(async move { - sleep(delay).await; - if let Some(pool) = config_clone.db_pool.as_ref() { - if let Ok(Some(record)) = get_scheduled_closure(&thread_id_for_task, pool).await { - if record.close_at <= Utc::now().timestamp() { - let _ = close_thread( - &thread_id_for_task, - &record.closed_by, - &record.category_id, - &category_name, - record.required_permissions.parse::().unwrap_or(0), - pool, - ) - .await; - let _ = delete_scheduled_closure(&thread_id_for_task, pool).await; - - let community_guild_id = - GuildId::new(config_clone.bot.get_community_guild_id()); - let user_still_member = community_guild_id - .member(&ctx_clone.http, user_id_clone) - .await - .is_ok(); - - if !record.silent && user_still_member { - let _ = MessageBuilder::system_message(&ctx_clone, &config_clone) - .content(&config_clone.bot.close_message) - .to_user(user_id_clone) - .send(true) - .await; - } - let _ = channel_id.delete(&ctx_clone.http).await; - } else { - let delay2 = (record.close_at - Utc::now().timestamp()).max(1) as u64; - let config_clone2 = config_clone.clone(); - let ctx_clone2 = ctx_clone.clone(); - let thread_id_again = thread_id_for_task.clone(); - - tokio::spawn(async move { - sleep(Duration::from_secs(delay2)).await; - if let Some(pool2) = config_clone2.db_pool.as_ref() { - if let Ok(Some(r2)) = - get_scheduled_closure(&thread_id_again, pool2).await - { - if r2.close_at <= Utc::now().timestamp() { - let _ = close_thread( - &thread_id_again, - &r2.closed_by, - &r2.category_id, - &r2.category_id, - r2.required_permissions.parse::().unwrap_or(0), - pool2, - ) - .await; - let _ = - delete_scheduled_closure(&thread_id_again, pool2).await; - let community_guild_id = GuildId::new( - config_clone2.bot.get_community_guild_id(), - ); - let user_still_member = community_guild_id - .member(&ctx_clone2.http, user_id_clone) - .await - .is_ok(); - if !r2.silent && user_still_member { - let _ = MessageBuilder::system_message( - &ctx_clone2, - &config_clone2, - ) - .content(&config_clone2.bot.close_message) - .to_user(user_id_clone) - .send(true) - .await; - } - let _ = channel_id.delete(&ctx_clone2.http).await; - } - } - } - }); - } - } - } - }); + schedule_one(&ctx, config, thread_id, close_at); return Ok(()); } diff --git a/rustmail/src/handlers/guild_messages_handler.rs b/rustmail/src/handlers/guild_messages_handler.rs index 55ddf01a..6ca38d68 100644 --- a/rustmail/src/handlers/guild_messages_handler.rs +++ b/rustmail/src/handlers/guild_messages_handler.rs @@ -149,6 +149,24 @@ async fn manage_incoming_message( .await; return Err(error); } + + if let Ok(thread) = fetch_thread(pool, &channel_id_str).await { + if let Ok(existed) = delete_scheduled_closure(&thread.id, pool).await { + if existed { + let _ = MessageBuilder::system_message(ctx, config) + .translated_content( + "close.auto_canceled_on_message", + None, + Some(msg.author.id), + None, + ) + .await + .to_channel(channel_id) + .send(true) + .await; + } + } + } } } else { create_channel(ctx, msg, config).await; diff --git a/rustmail/src/i18n/language/en.rs b/rustmail/src/i18n/language/en.rs index c915ccc1..749fd196 100644 --- a/rustmail/src/i18n/language/en.rs +++ b/rustmail/src/i18n/language/en.rs @@ -602,6 +602,14 @@ pub fn load_english_messages(dict: &mut ErrorDictionary) { "close.closure_canceled".to_string(), DictionaryMessage::new("Closure canceled."), ); + dict.messages.insert( + "close.auto_canceled_on_message".to_string(), + DictionaryMessage::new("Scheduled closure has been automatically canceled because a message was received."), + ); + dict.messages.insert( + "close.replacing_existing_closure".to_string(), + DictionaryMessage::new("⚠️ Warning: A closure was already scheduled in {old_time}. It will be replaced by the new one."), + ); dict.messages.insert( "close.no_scheduled_closures_to_cancel".to_string(), DictionaryMessage::new("No scheduled closures to cancel."), diff --git a/rustmail/src/i18n/language/fr.rs b/rustmail/src/i18n/language/fr.rs index 783d9894..fed8901d 100644 --- a/rustmail/src/i18n/language/fr.rs +++ b/rustmail/src/i18n/language/fr.rs @@ -630,6 +630,14 @@ pub fn load_french_messages(dict: &mut ErrorDictionary) { "close.closure_canceled".to_string(), DictionaryMessage::new("Fermeture programmée annulée."), ); + dict.messages.insert( + "close.auto_canceled_on_message".to_string(), + DictionaryMessage::new("La fermeture programmée a été automatiquement annulée car un message a été reçu."), + ); + dict.messages.insert( + "close.replacing_existing_closure".to_string(), + DictionaryMessage::new("⚠️ Attention : Une fermeture était déjà programmée dans {old_time}. Elle sera remplacée par la nouvelle."), + ); dict.messages.insert( "close.no_scheduled_closures_to_cancel".to_string(), DictionaryMessage::new("Aucune fermeture programmée à annuler."), diff --git a/rustmail/src/modules/scheduled_closures.rs b/rustmail/src/modules/scheduled_closures.rs index 07707260..27a7ac10 100644 --- a/rustmail/src/modules/scheduled_closures.rs +++ b/rustmail/src/modules/scheduled_closures.rs @@ -5,7 +5,7 @@ use chrono::Utc; use serenity::all::{ChannelId, Context, UserId}; use tokio::time::{Duration, sleep}; -fn schedule_one(ctx: &Context, config: &Config, thread_id: String, close_at: i64) { +pub fn schedule_one(ctx: &Context, config: &Config, thread_id: String, close_at: i64) { let now = Utc::now().timestamp(); let delay_secs = (close_at - now).max(0) as u64; let ctx_clone = ctx.clone(); @@ -61,6 +61,7 @@ pub async fn hydrate_scheduled_closures(ctx: &Context, config: &Config) { let Some(pool) = config.db_pool.as_ref() else { return; }; + let list = match get_all_scheduled_closures(pool).await { Ok(l) => l, Err(e) => { @@ -68,6 +69,7 @@ pub async fn hydrate_scheduled_closures(ctx: &Context, config: &Config) { return; } }; + for sc in list { if let Some(thread) = get_thread_by_id(&sc.thread_id, pool).await { if sc.close_at <= Utc::now().timestamp() {