diff --git a/tests/test_admin_commands.py b/tests/test_admin_commands.py index 91c5b9f..2636f7a 100644 --- a/tests/test_admin_commands.py +++ b/tests/test_admin_commands.py @@ -276,8 +276,8 @@ def test_cooldown_expires(self, workflow_state): assert remaining == 0.0 -class TestOnMessageListener: - """Test suite for on_message DM listener.""" +class TestAdminSessionExclusivity: + """Admin DM workflows are mutually exclusive — only one active session per user.""" @pytest.fixture def admin_cog(self, mock_bot, database): @@ -285,54 +285,95 @@ def admin_cog(self, mock_bot, database): return AdminCommands(mock_bot) @pytest.mark.asyncio - async def test_ignores_bot_messages(self, admin_cog): - """Bot messages are ignored to prevent infinite loops.""" - mock_message = MagicMock() - mock_message.author.bot = True - mock_message.guild = None + async def test_fixture_create_command_blocked_when_results_session_active( + self, admin_cog, mock_interaction_admin + ): + """fixture_create gives a single clear error — not 'Check your DMs' + error.""" + user_id = str(mock_interaction_admin.user.id) + mock_interaction_admin.channel_id = int(mock_interaction_admin.channel.id) + mock_interaction_admin.guild_id = mock_interaction_admin.guild.id + admin_cog.results_handler.start_session(user_id, 1, 111111, week_number=1) - result = await admin_cog.on_message(mock_message) - assert result is None + await admin_cog.fixture_create.callback(admin_cog, mock_interaction_admin) + + assert len(mock_interaction_admin.response_sent) == 1 + response = mock_interaction_admin.response_sent[0]["content"] + assert "results entry" in response.lower() + assert "Check your DMs" not in response + assert not admin_cog.fixture_handler.has_session(user_id) @pytest.mark.asyncio - async def test_ignores_guild_messages(self, admin_cog): - """Guild messages are ignored - admin workflows require DMs.""" - mock_message = MagicMock() - mock_message.guild = MagicMock() # Has guild + async def test_start_fixture_dm_blocked_when_results_session_active( + self, admin_cog, mock_interaction_admin + ): + """_start_fixture_dm fallback guard covers the view-triggered path.""" + user_id = str(mock_interaction_admin.user.id) + admin_cog.results_handler.start_session(user_id, 1, 111111, week_number=1) + + result = await admin_cog._start_fixture_dm( + mock_interaction_admin.user, + user_id, + channel_id=123456, + guild_id=111111, + ) - result = await admin_cog.on_message(mock_message) - assert result is None + assert result is False + assert not admin_cog.fixture_handler.has_session(user_id) + assert len(mock_interaction_admin.user.dm_sent) == 1 + assert "results entry" in mock_interaction_admin.user.dm_sent[0].lower() @pytest.mark.asyncio - async def test_handles_fixture_creation_dm(self, admin_cog): - """Fixture creation DMs route to the correct handler.""" - mock_message = MagicMock() - mock_message.guild = None - user_id = "123456" - mock_message.author.id = 123456 - mock_message.author.bot = False + async def test_results_enter_blocked_when_fixture_session_active( + self, admin_cog, mock_interaction_admin, sample_games + ): + """Starting results entry while fixture creation is in progress is rejected.""" + from datetime import UTC, datetime, timedelta + deadline = datetime.now(UTC) + timedelta(days=1) + await admin_cog.db.create_fixture(1, sample_games, deadline) + + user_id = str(mock_interaction_admin.user.id) + mock_interaction_admin.guild_id = mock_interaction_admin.guild.id admin_cog.fixture_handler.start_session(user_id, 123456, 111111) - admin_cog.fixture_handler.handle_dm = AsyncMock(return_value=True) - await admin_cog.on_message(mock_message) + await admin_cog.results_enter.callback(admin_cog, mock_interaction_admin, 1) + + assert not admin_cog.results_handler.has_session(user_id) + response = mock_interaction_admin.response_sent[-1] + assert "fixture creation" in response["content"].lower() + + @pytest.mark.asyncio + async def test_fixture_create_allowed_when_no_conflicting_session( + self, admin_cog, mock_interaction_admin + ): + """Fixture creation proceeds normally when no admin session is active.""" + user_id = str(mock_interaction_admin.user.id) + + result = await admin_cog._start_fixture_dm( + mock_interaction_admin.user, + user_id, + channel_id=123456, + guild_id=111111, + ) + assert result is True assert admin_cog.fixture_handler.has_session(user_id) @pytest.mark.asyncio - async def test_handles_results_entry_dm(self, admin_cog): - """Results entry DMs route to the correct handler.""" - mock_message = MagicMock() - mock_message.guild = None - user_id = "123456" - mock_message.author.id = 123456 - mock_message.author.bot = False + async def test_results_enter_allowed_when_no_conflicting_session( + self, admin_cog, mock_interaction_admin, sample_games + ): + """Results entry proceeds normally when no admin session is active.""" + from datetime import UTC, datetime, timedelta - admin_cog.results_handler.start_session(user_id, 1, 111111, week_number=1) - admin_cog.results_handler.handle_dm = AsyncMock(return_value=True) + deadline = datetime.now(UTC) + timedelta(days=1) + await admin_cog.db.create_fixture(1, sample_games, deadline) - await admin_cog.on_message(mock_message) + mock_interaction_admin.guild_id = mock_interaction_admin.guild.id + await admin_cog.results_enter.callback(admin_cog, mock_interaction_admin, 1) + + user_id = str(mock_interaction_admin.user.id) assert admin_cog.results_handler.has_session(user_id) diff --git a/tests/test_bot.py b/tests/test_bot.py index ce8e7c5..56bb163 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -48,9 +48,16 @@ class TestSetupHook: async def bot_instance(self): mock_tree = MagicMock() mock_tree.sync = AsyncMock(return_value=[]) + mock_admin_cog = MagicMock() + mock_admin_cog.fixture_handler = MagicMock() + mock_admin_cog.results_handler = MagicMock() + mock_user_cog = MagicMock() + mock_user_cog.prediction_handler = MagicMock() + mock_cogs = {"AdminCommands": mock_admin_cog, "UserCommands": mock_user_cog} with ( patch("typer_bot.bot.commands.Bot.__init__", return_value=None), patch.object(TyperBot, "tree", mock_tree), + patch.object(TyperBot, "cogs", mock_cogs), ): bot = TyperBot.__new__(TyperBot) bot.db = MagicMock() @@ -421,6 +428,72 @@ async def test_on_message_sets_trace_id(self, bot_instance): mock_set_trace.assert_called_once_with("msg-123456") +class TestOnMessageDMRouting: + """Test suite verifying DM messages are routed through DMRouter.""" + + @pytest.fixture + def bot_instance(self): + mock_router = MagicMock() + mock_router.route = AsyncMock(return_value=True) + with patch("typer_bot.bot.commands.Bot.__init__", return_value=None): + bot = TyperBot.__new__(TyperBot) + bot.thread_handler = MagicMock() + bot.thread_handler.on_message = AsyncMock(return_value=False) + bot.dm_router = mock_router + yield bot + + @pytest.mark.asyncio + async def test_dm_routes_through_dm_router(self, bot_instance): + """DMs are dispatched to the router, not to cog listeners.""" + mock_message = MagicMock() + mock_message.author.bot = False + mock_message.guild = None + mock_message.id = 1 + + await bot_instance.on_message(mock_message) + + bot_instance.dm_router.route.assert_awaited_once_with(mock_message) + + @pytest.mark.asyncio + async def test_guild_messages_skip_dm_router(self, bot_instance): + """Guild messages go through normal command processing, not the DM router.""" + mock_message = MagicMock() + mock_message.author.bot = False + mock_message.guild = MagicMock() + mock_message.id = 2 + + with patch("discord.ext.commands.Bot.on_message", new_callable=AsyncMock): + await bot_instance.on_message(mock_message) + + bot_instance.dm_router.route.assert_not_awaited() + + @pytest.mark.asyncio + async def test_none_router_logs_warning_and_drops_dm(self, bot_instance): + """DMs received before the router is ready are logged and dropped.""" + bot_instance.dm_router = None + mock_message = MagicMock() + mock_message.author.bot = False + mock_message.guild = None + mock_message.id = 3 + + with patch("typer_bot.bot.logger") as mock_logger: + await bot_instance.on_message(mock_message) + mock_logger.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_thread_handler_takes_priority_over_dm_router(self, bot_instance): + """Thread messages are consumed before reaching the DM router.""" + bot_instance.thread_handler.on_message = AsyncMock(return_value=True) + mock_message = MagicMock() + mock_message.author.bot = False + mock_message.guild = None + mock_message.id = 4 + + await bot_instance.on_message(mock_message) + + bot_instance.dm_router.route.assert_not_awaited() + + class TestOnInteraction: """Test suite for on_interaction event handler.""" diff --git a/tests/test_dm_prediction_handler.py b/tests/test_dm_prediction_handler.py index 41815c0..6193a36 100644 --- a/tests/test_dm_prediction_handler.py +++ b/tests/test_dm_prediction_handler.py @@ -31,20 +31,6 @@ async def test_ignores_guild_messages(self, prediction_handler, mock_message): assert not handled assert len(mock_message.author.dm_sent) == 0 - @pytest.mark.asyncio - async def test_ignores_dms_during_results_entry( - self, prediction_handler, mock_message, workflow_state - ): - mock_message.guild = None - user_id = str(mock_message.author.id) - session = workflow_state.start_results_session(user_id, 1, 123456) - session.created_at = datetime.now(UTC) - - handled = await prediction_handler.handle_dm(mock_message) - - assert not handled - assert len(mock_message.author.dm_sent) == 0 - @pytest.mark.asyncio async def test_rejects_message_too_long(self, prediction_handler, mock_message): mock_message.guild = None diff --git a/tests/test_dm_router.py b/tests/test_dm_router.py new file mode 100644 index 0000000..f12c09b --- /dev/null +++ b/tests/test_dm_router.py @@ -0,0 +1,125 @@ +"""Tests for DM routing precedence.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from typer_bot.services.dm_router import DMRouter + + +def _make_dm_message(user_id: str = "123456", bot: bool = False, in_guild: bool = False): + message = MagicMock() + message.author.id = int(user_id) + message.author.bot = bot + message.guild = MagicMock() if in_guild else None + return message + + +@pytest.fixture +def fixture_handler(): + h = MagicMock() + h.has_session.return_value = False + h.handle_dm = AsyncMock(return_value=True) + return h + + +@pytest.fixture +def results_handler(): + h = MagicMock() + h.has_session.return_value = False + h.handle_dm = AsyncMock(return_value=True) + return h + + +@pytest.fixture +def prediction_handler(): + h = MagicMock() + h.handle_dm = AsyncMock(return_value=True) + return h + + +@pytest.fixture +def router(fixture_handler, results_handler, prediction_handler): + return DMRouter(fixture_handler, results_handler, prediction_handler) + + +class TestRouterIgnoresNonDMs: + @pytest.mark.asyncio + async def test_ignores_bot_messages(self, router): + result = await router.route(_make_dm_message(bot=True)) + assert result is False + + @pytest.mark.asyncio + async def test_ignores_guild_messages(self, router): + result = await router.route(_make_dm_message(in_guild=True)) + assert result is False + + +class TestRoutingPrecedence: + @pytest.mark.asyncio + async def test_fixture_session_routes_to_fixture_handler( + self, router, fixture_handler, results_handler, prediction_handler + ): + fixture_handler.has_session.return_value = True + message = _make_dm_message() + + result = await router.route(message) + + assert result is True + fixture_handler.handle_dm.assert_awaited_once() + results_handler.handle_dm.assert_not_awaited() + prediction_handler.handle_dm.assert_not_awaited() + + @pytest.mark.asyncio + async def test_results_session_routes_to_results_handler( + self, router, fixture_handler, results_handler, prediction_handler + ): + results_handler.has_session.return_value = True + message = _make_dm_message() + + result = await router.route(message) + + assert result is True + results_handler.handle_dm.assert_awaited_once() + fixture_handler.handle_dm.assert_not_awaited() + prediction_handler.handle_dm.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_admin_session_falls_through_to_prediction( + self, router, fixture_handler, results_handler, prediction_handler + ): + message = _make_dm_message() + + result = await router.route(message) + + assert result is True + prediction_handler.handle_dm.assert_awaited_once_with(message) + fixture_handler.handle_dm.assert_not_awaited() + results_handler.handle_dm.assert_not_awaited() + + @pytest.mark.asyncio + async def test_fixture_session_takes_precedence_over_results( + self, router, fixture_handler, results_handler + ): + """Fixture check runs first; results handler should never be reached.""" + fixture_handler.has_session.return_value = True + results_handler.has_session.return_value = True + message = _make_dm_message() + + await router.route(message) + + fixture_handler.handle_dm.assert_awaited_once() + results_handler.handle_dm.assert_not_awaited() + + @pytest.mark.asyncio + async def test_admin_session_takes_precedence_over_prediction( + self, router, results_handler, prediction_handler + ): + """Any active admin session blocks the prediction handler.""" + results_handler.has_session.return_value = True + message = _make_dm_message() + + await router.route(message) + + results_handler.handle_dm.assert_awaited_once() + prediction_handler.handle_dm.assert_not_awaited() diff --git a/tests/test_user_commands.py b/tests/test_user_commands.py index 0608b50..a04aaad 100644 --- a/tests/test_user_commands.py +++ b/tests/test_user_commands.py @@ -14,17 +14,6 @@ async def user_commands(mock_bot, database): return UserCommands(mock_bot) -class TestOnMessage: - @pytest.mark.asyncio - async def test_delegates_dm_messages_to_prediction_handler(self, user_commands, mock_message): - handler = AsyncMock(return_value=True) - user_commands.prediction_handler.handle_dm = handler - - await user_commands.on_message(mock_message) - - handler.assert_awaited_once_with(mock_message) - - class TestPredictCommand: @pytest.mark.asyncio async def test_no_fixture_shows_error(self, user_commands, mock_interaction): diff --git a/typer_bot/bot.py b/typer_bot/bot.py index bcfccd3..f66052d 100644 --- a/typer_bot/bot.py +++ b/typer_bot/bot.py @@ -13,6 +13,7 @@ from typer_bot.database import Database from typer_bot.handlers.thread_prediction_handler import ThreadPredictionHandler from typer_bot.services import WorkflowStateStore +from typer_bot.services.dm_router import DMRouter from typer_bot.utils import format_for_discord, now from typer_bot.utils.config import IS_PRODUCTION from typer_bot.utils.logger import set_log_context, set_trace_id @@ -40,6 +41,7 @@ def __init__(self): self.db = Database() self.workflow_state = WorkflowStateStore() self.thread_handler = ThreadPredictionHandler(self, self.db, self.workflow_state) + self.dm_router: DMRouter | None = None logger.info("Database instance created") async def on_interaction(self, interaction: discord.Interaction): @@ -69,7 +71,16 @@ async def on_message(self, message: discord.Message): if handled: return - await super().on_message(message) + if message.guild is None: + # DMs: explicit router owns precedence — no cog listener ordering dependency. + if self.dm_router is None: + logger.warning( + "DM received before router initialised, dropping: user=%s", user_id + ) + return + await self.dm_router.route(message) + else: + await super().on_message(message) finally: from typer_bot.utils.logger import clear_log_context, clear_trace_id @@ -110,6 +121,17 @@ async def setup_hook(self): logger.exception("Failed to load admin_commands") raise + admin_cog = self.cogs.get("AdminCommands") + user_cog = self.cogs.get("UserCommands") + if admin_cog is None or user_cog is None: + raise RuntimeError("Required cogs not loaded before DM router initialisation") + self.dm_router = DMRouter( + admin_cog.fixture_handler, # type: ignore[attr-defined] + admin_cog.results_handler, # type: ignore[attr-defined] + user_cog.prediction_handler, # type: ignore[attr-defined] + ) + logger.info("DM router initialised") + logger.info("Syncing slash commands...") try: synced = await self.tree.sync() diff --git a/typer_bot/commands/admin_commands.py b/typer_bot/commands/admin_commands.py index 580cce2..5823167 100644 --- a/typer_bot/commands/admin_commands.py +++ b/typer_bot/commands/admin_commands.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import logging import discord @@ -18,7 +19,7 @@ from typer_bot.handlers import FixtureCreationHandler, ResultsEntryHandler from typer_bot.services import AdminService, WorkflowStateStore from typer_bot.services.admin_service import FixtureScoreResult -from typer_bot.utils import format_fixture_results, format_standings, is_admin, is_admin_member, now +from typer_bot.utils import format_fixture_results, format_standings, is_admin, now from typer_bot.utils.config import BACKUP_DIR from typer_bot.utils.db_backup import cleanup_old_backups, create_backup @@ -124,6 +125,14 @@ async def _start_fixture_dm( guild_id: int, ) -> bool: """Returns True if the DM was sent successfully.""" + if self.workflow_state.has_results_session(user_id): + with contextlib.suppress(Exception): + await user.send( + "❌ You have an active results entry session. " + "Finish or cancel it before starting a new fixture." + ) + return False + self.fixture_handler.start_session(user_id, channel_id, guild_id) max_week = await self.db.get_max_week_number() predicted_week = max_week + 1 @@ -141,7 +150,7 @@ async def _start_fixture_dm( "```\n" "One game per line." ) - except discord.HTTPException as exc: + except Exception as exc: reason = "dm_forbidden" if isinstance(exc, discord.Forbidden) else "dm_error" self.fixture_handler.cancel_session(user_id, reason=reason) return False @@ -182,21 +191,6 @@ async def _post_calculation_to_channel( "Scores calculated but failed to post to channel.", ephemeral=True ) - @commands.Cog.listener() - async def on_message(self, message: discord.Message): - if message.author.bot or message.guild is not None: - return - - user_id = str(message.author.id) - - if self.fixture_handler.has_session(user_id): - await self.fixture_handler.handle_dm(message, user_id, is_admin_member) - return - - if self.results_handler.has_session(user_id): - await self.results_handler.handle_dm(message, user_id, is_admin_member) - return - admin = app_commands.Group(name="admin", description="Admin commands for managing fixtures") fixture = app_commands.Group(name="fixture", description="Manage fixtures", parent=admin) results = app_commands.Group(name="results", description="Manage results", parent=admin) @@ -221,6 +215,14 @@ async def fixture_create(self, interaction: discord.Interaction): ) return + if self.workflow_state.has_results_session(user_id): + await interaction.response.send_message( + "❌ You have an active results entry session. " + "Finish or cancel it before starting a new fixture.", + ephemeral=True, + ) + return + open_fixtures = await self.db.get_open_fixtures() if open_fixtures: open_weeks = self._format_open_weeks(open_fixtures) @@ -302,6 +304,15 @@ async def results_enter(self, interaction: discord.Interaction, week: int | None "Error: Invalid interaction context.", ephemeral=True ) return + + if self.workflow_state.has_fixture_session(user_id): + await interaction.response.send_message( + "❌ You have an active fixture creation session. " + "Finish or cancel it before entering results.", + ephemeral=True, + ) + return + self.results_handler.start_session( user_id, fixture["id"], interaction.guild_id, fixture["week_number"] ) @@ -329,8 +340,9 @@ async def results_enter(self, interaction: discord.Interaction, week: int | None ] ) await interaction.user.send("\n".join(lines)) - except discord.Forbidden: - self.results_handler.cancel_session(user_id, reason="dm_forbidden") + except Exception as exc: + reason = "dm_forbidden" if isinstance(exc, discord.Forbidden) else "dm_error" + self.results_handler.cancel_session(user_id, reason=reason) await interaction.followup.send( "I can't send you DMs. Please enable DMs from server members and try again.", ephemeral=True, diff --git a/typer_bot/commands/user_commands.py b/typer_bot/commands/user_commands.py index 1a39e47..9a62375 100644 --- a/typer_bot/commands/user_commands.py +++ b/typer_bot/commands/user_commands.py @@ -66,11 +66,6 @@ async def _send_chunked_ephemeral(self, interaction: discord.Interaction, conten for chunk in chunks[1:]: await interaction.followup.send(chunk, ephemeral=True) - @commands.Cog.listener() - async def on_message(self, message: discord.Message): - """Listen for DMs with predictions.""" - await self.prediction_handler.handle_dm(message) - @app_commands.command(name="predict", description="Submit your predictions for open fixtures") @app_commands.checks.cooldown(1, 1.0) async def predict(self, interaction: discord.Interaction): diff --git a/typer_bot/handlers/dm_prediction_handler.py b/typer_bot/handlers/dm_prediction_handler.py index e34c873..470419e 100644 --- a/typer_bot/handlers/dm_prediction_handler.py +++ b/typer_bot/handlers/dm_prediction_handler.py @@ -154,10 +154,6 @@ async def handle_dm(self, message: discord.Message) -> bool: user_id = str(message.author.id) - # Ignore admin result-entry DMs so they do not overwrite stored predictions as late. - if self.workflow_state.has_results_session(user_id): - return False - if len(message.content) > MAX_MESSAGE_LENGTH: await message.author.send(f"❌ Message too long! (max {MAX_MESSAGE_LENGTH} characters)") return True diff --git a/typer_bot/services/dm_router.py b/typer_bot/services/dm_router.py new file mode 100644 index 0000000..c533ef6 --- /dev/null +++ b/typer_bot/services/dm_router.py @@ -0,0 +1,50 @@ +"""Explicit DM routing coordinator. + +Routing precedence (highest → lowest): +1. Admin fixture creation session +2. Admin results entry session +3. User prediction session + +All three handlers are checked against the same WorkflowStateStore, so precedence +is determined by the order of checks here — not by listener registration order across cogs. +""" + +from __future__ import annotations + +import discord + +from typer_bot.handlers.dm_prediction_handler import DMPredictionHandler +from typer_bot.handlers.fixture_handler import FixtureCreationHandler +from typer_bot.handlers.results_handler import ResultsEntryHandler +from typer_bot.utils.permissions import is_admin_member + + +class DMRouter: + """Routes incoming DMs to the correct workflow handler.""" + + def __init__( + self, + fixture_handler: FixtureCreationHandler, + results_handler: ResultsEntryHandler, + prediction_handler: DMPredictionHandler, + ) -> None: + self._fixture_handler = fixture_handler + self._results_handler = results_handler + self._prediction_handler = prediction_handler + + async def route(self, message: discord.Message) -> bool: + """Route a DM to the correct handler. Returns True if consumed.""" + if message.author.bot or message.guild is not None: + return False + + user_id = str(message.author.id) + + if self._fixture_handler.has_session(user_id): + await self._fixture_handler.handle_dm(message, user_id, is_admin_member) + return True + + if self._results_handler.has_session(user_id): + await self._results_handler.handle_dm(message, user_id, is_admin_member) + return True + + return await self._prediction_handler.handle_dm(message)