diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d6a8d10..e11b525 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -24,7 +24,9 @@ dev-dependencies = [ "pytest>=8.3.4", "coverage>=7.8.0", "ruff>=0.9.5", - "pre-commit>=4.2.0" + "pre-commit>=4.2.0", + "pytest-cov>=6.1.1", + "pytest-asyncio>=0.26.0", ] [tool.hatch.build.targets.wheel] diff --git a/backend/src/flashcards/services.py b/backend/src/flashcards/services.py index 2c217a8..c2d3f5f 100644 --- a/backend/src/flashcards/services.py +++ b/backend/src/flashcards/services.py @@ -314,23 +314,6 @@ def get_practice_cards( return practice_cards, count -def get_next_card( - session: Session, practice_session_id: uuid.UUID -) -> tuple[Card, PracticeCard] | None: - statement = ( - select(PracticeCard, Card) - .join(Card, PracticeCard.card_id == Card.id) - .where( - PracticeCard.session_id == practice_session_id, - PracticeCard.is_practiced.is_not(True), - ) - .limit(1) - ) - result = session.exec(statement).first() - if result: - return result[1], result[0] - - def get_practice_card( session: Session, practice_session_id: uuid.UUID, @@ -373,12 +356,6 @@ def record_practice_card_result( return practice_card -def get_session_statistics( - session: Session, practice_session_id: uuid.UUID -) -> PracticeSession: - return session.get(PracticeSession, practice_session_id) - - def get_card_by_id(session: Session, card_id: uuid.UUID) -> Card | None: statement = select(Card).where(Card.id == card_id) return session.exec(statement).first() diff --git a/backend/tests/flashcards/card/__init__.py b/backend/tests/flashcards/card/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/flashcards/card/test_api.py b/backend/tests/flashcards/card/test_api.py new file mode 100644 index 0000000..b4ea6cf --- /dev/null +++ b/backend/tests/flashcards/card/test_api.py @@ -0,0 +1,382 @@ +import uuid +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from src.core.config import settings +from src.flashcards.schemas import CardCreate, CardUpdate, CollectionCreate + + +@pytest.fixture +def test_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +) -> dict[str, Any]: + """Create a testing collection""" + collection_data = CollectionCreate(name="Test Collection") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + return rsp.json() + + +@pytest.fixture +def test_card( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +) -> dict[str, Any]: + """Create a testing card""" + collection_id = test_collection["id"] + card_data = CardCreate(front="Test front", back="Test back") + rsp = client.post( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/", + json=card_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + return rsp.json() + + +@pytest.fixture +def test_multiple_cards( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +) -> list[dict[str, Any]]: + """Create testing cards""" + collection_id = test_collection["id"] + + cards = [] + for i in range(5): + card_data = CardCreate(front=f"Test front {i}", back=f"Test back {i}") + rsp = client.post( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/", + json=card_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + cards.append(rsp.json()) + return cards + + +def test_create_card_with_nonexistent_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +): + non_existent_collection_id = uuid.uuid4() + card_data = CardCreate(front="Test front", back="Test back") + + rsp = client.post( + f"{settings.API_V1_STR}/collections/{non_existent_collection_id}/cards/", + json=card_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_read_card( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + """Read one card""" + + collection_id = test_collection["id"] + card_id = test_card["id"] + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert content + assert content["collection_id"] == collection_id + assert content["id"] == card_id + + +def test_read_nonexistent_card( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + non_existent_card_id = str(uuid.uuid4()) + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{non_existent_card_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_different_user_access( + client: TestClient, + superuser_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + collection_id = test_collection["id"] + card_id = test_card["id"] + + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=superuser_token_headers, + ) + + assert rsp.status_code == 404 + + +def test_read_cards( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + """Read cards""" + collection_id = test_collection["id"] + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert "data" in content + assert "count" in content + assert content["count"] >= 1 + + card_ids = [card["id"] for card in content["data"]] + assert test_card["id"] in card_ids + + +def test_read_cards_with_pagination( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_multiple_cards: list[dict[str, Any]], +): + collection_id = test_collection["id"] + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards?skip=2&limit=3", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert content["count"] >= len(test_multiple_cards) + assert len(content["data"]) <= 3 + + +def test_read_cards_with_nonexistent_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +): + non_existent_collection_id = uuid.uuid4() + + rsp = client.get( + f"{settings.API_V1_STR}/collections/{non_existent_collection_id}/cards", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_update_card_success( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + collection_id = test_collection["id"] + card_id = test_card["id"] + update_data = CardUpdate(front="Front Update", back=test_card["back"]) + + rsp = client.put( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + json=update_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert collection_id == content["collection_id"] + assert card_id == content["id"] + assert content["front"] == update_data.front + # Make sure other data not change + assert content["back"] == test_card["back"] + + +def test_update_nonexistent_card( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + non_existent_card_id = str(uuid.uuid4()) + update_data = CardUpdate( + front="Nonexistent Card Front", back="Nonexistent Card Back" + ) + + rsp = client.put( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{non_existent_card_id}", + json=update_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_different_user_update( + client: TestClient, + normal_user_token_headers: dict[str, str], + superuser_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + collection_id = test_collection["id"] + card_id = test_card["id"] + update_data = CardUpdate(front="Cross Card Front", back="Cross Card Back") + rsp = client.put( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + json=update_data.model_dump(), + headers=superuser_token_headers, + ) + + assert rsp.status_code == 404 + + # Verity the data is still the same before updating + verify_rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=normal_user_token_headers, + ) + assert verify_rsp.status_code == 200 + content = verify_rsp.json() + + assert content["collection_id"] == test_collection["id"] + assert content["id"] == test_card["id"] + assert content["front"] == test_card["front"] + assert content["back"] == test_card["back"] + + +def test_delete_card_success( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + collection_id = test_collection["id"] + card_id = test_card["id"] + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 204 + assert rsp.content == b"" + + verify_rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=normal_user_token_headers, + ) + assert verify_rsp.status_code == 404 + content = verify_rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_delete_nonexistent_card( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + non_existent_card_id = str(uuid.uuid4()) + + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{non_existent_card_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_different_user_delete( + client: TestClient, + normal_user_token_headers: dict[str, str], + superuser_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_card: dict[str, Any], +): + collection_id = test_collection["id"] + card_id = test_card["id"] + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=superuser_token_headers, + ) + + assert rsp.status_code == 404 + + # Verity the data still exists + verify_rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{card_id}", + headers=normal_user_token_headers, + ) + assert verify_rsp.status_code == 200 + + +def test_deleted_card_not_in_list( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], + test_multiple_cards: list[dict[str, Any]], +): + collection_id = test_collection["id"] + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/", + headers=normal_user_token_headers, + ) + list_before = rsp.json() + + delete_card = test_multiple_cards[0] + + # Check test_card in the card list + assert delete_card["id"] in [card["id"] for card in list_before["data"]] + + # Delete test card + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/{delete_card['id']}", + headers=normal_user_token_headers, + ) + assert rsp.status_code == 204 + + # Re-read the card list + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}/cards/", + headers=normal_user_token_headers, + ) + list_after = rsp.json() + assert list_after["count"] == len(test_multiple_cards) - 1 + assert delete_card["id"] not in [card["id"] for card in list_after["data"]] diff --git a/backend/tests/flashcards/card/test_services.py b/backend/tests/flashcards/card/test_services.py new file mode 100644 index 0000000..51f7725 --- /dev/null +++ b/backend/tests/flashcards/card/test_services.py @@ -0,0 +1,154 @@ +import uuid + +from sqlmodel import Session + +from src.flashcards.models import Card, Collection +from src.flashcards.schemas import CardCreate, CardUpdate +from src.flashcards.services import ( + create_card, + delete_card, + get_card, + get_card_by_id, + get_card_with_collection, + get_cards, + update_card, +) + + +def test_create_card(db: Session, test_collection: Collection): + card_in = CardCreate(front="front", back="back") + card = create_card(session=db, card_in=card_in, collection_id=test_collection.id) + + assert card is not None + assert card.id is not None + assert card.front == card_in.front + assert card.back == card_in.back + assert card.collection_id == test_collection.id + assert card.created_at is not None + assert card.updated_at is not None + + +def test_get_card(db: Session, test_card: Card): + db_card = get_card(session=db, card_id=test_card.id) + + assert db_card is not None + assert db_card.id == test_card.id + assert db_card.front == test_card.front + assert db_card.back == test_card.back + assert db_card.collection_id == test_card.collection_id + assert db_card.updated_at is not None + assert db_card.created_at is not None + + +def test_get_card_not_found(db: Session): + db_card = get_card(session=db, card_id=uuid.uuid4()) + + assert db_card is None + + +def test_get_card_by_id(db: Session, test_card: Card): + db_card = get_card_by_id(session=db, card_id=test_card.id) + + assert db_card is not None + assert db_card.id == test_card.id + + +def test_get_card_by_nonexistent_id(db: Session): + non_existent_card_id = uuid.uuid4() + db_card = get_card_by_id(session=db, card_id=non_existent_card_id) + + assert db_card is None + + +def test_get_card_with_collection( + db: Session, test_collection: Collection, test_card: Card +): + db_card = get_card_with_collection( + session=db, card_id=test_card.id, user_id=test_collection.user_id + ) + + assert db_card is not None + + +def test_get_card_with_wrong_collection(db: Session, test_card: Card): + db_card = get_card_with_collection( + session=db, + card_id=test_card.id, + user_id=uuid.uuid4(), + ) + + assert db_card is None + + +def test_get_cards( + db: Session, test_collection: Collection, test_multiple_cards: list[Card] +): + limit = 3 + db_cards, count = get_cards( + session=db, collection_id=test_collection.id, limit=limit + ) + + assert len(db_cards) == limit + assert count == len(test_multiple_cards) + # Verify the order + for i in range(len(db_cards) - 1): + assert db_cards[i].updated_at >= db_cards[i + 1].updated_at + + +def test_get_cards_skip( + db: Session, test_collection: Collection, test_multiple_cards: list[Card] +): + limit = 3 + skip = 2 + db_cards, count = get_cards( + session=db, collection_id=test_collection.id, skip=skip, limit=limit + ) + + assert len(db_cards) == limit + assert count == len(test_multiple_cards) + + +def test_get_cards_empty(db: Session, test_collection: Collection): + db_cards, count = get_cards(session=db, collection_id=test_collection.id) + + assert len(db_cards) == 0 + assert count == 0 + + +def test_update_card(db: Session, test_card: Card): + original_updated_at = test_card.updated_at + + import time + + time.sleep(0.01) + + card_in = CardUpdate(front="Update front", back="Update back") + updated_card = update_card(session=db, card=test_card, card_in=card_in) + + assert updated_card.front == card_in.front + assert updated_card.back == card_in.back + assert updated_card.updated_at > original_updated_at + + +def test_update_card_partial(db: Session, test_card: Card): + original_updated_at = test_card.updated_at + + import time + + time.sleep(0.01) + + card_in = CardUpdate(front="Update front") + updated_card = update_card(session=db, card=test_card, card_in=card_in) + + assert updated_card.front == card_in.front + assert updated_card.back == test_card.back + assert updated_card.updated_at > original_updated_at + + +def test_delete_card(db: Session, test_collection: Collection, test_card: Card): + delete_card(session=db, card=test_card) + + card = get_card_with_collection( + session=db, card_id=test_card.id, user_id=test_collection.user_id + ) + assert card is None diff --git a/backend/tests/flashcards/collection/__init__.py b/backend/tests/flashcards/collection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/flashcards/collection/test_api.py b/backend/tests/flashcards/collection/test_api.py new file mode 100644 index 0000000..22ab547 --- /dev/null +++ b/backend/tests/flashcards/collection/test_api.py @@ -0,0 +1,383 @@ +import uuid +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi.testclient import TestClient + +from src.ai_models.gemini.exceptions import AIGenerationError +from src.core.config import settings +from src.flashcards.schemas import Card, Collection, CollectionCreate, CollectionUpdate + + +@pytest.fixture +def test_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +) -> dict[str, Any]: + """Create a testing collection""" + collection_data = CollectionCreate(name="Test Collection") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + return rsp.json() + + +@pytest.fixture +def test_multiple_collections( + client: TestClient, normal_user_token_headers: dict[str, str] +) -> list[dict[str, Any]]: + """Create multiple testing collections""" + + collections = [] + for i in range(5): + collection_data = CollectionCreate(name=f"Test Collection {i}") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + collections.append(rsp.json()) + return collections + + +@pytest.fixture +def mock_collection() -> Collection: + collection_id = uuid.uuid4() + collection = Collection( + id=collection_id, + name="AI collection", + user_id=uuid.uuid4(), + cards=[ + Card( + id=uuid.uuid4(), + front="front1", + back="back1", + collection_id=collection_id, + ), + Card( + id=uuid.uuid4(), + front="front2", + back="back2", + collection_id=collection_id, + ), + ], + ) + return collection + + +def test_create_collection_with_prompt( + client: TestClient, + normal_user_token_headers: dict[str, str], + mock_collection: Collection, +): + collection_data = CollectionCreate( + name="AI collection", prompt="Create flashcards about pytest" + ) + + with patch( + "src.flashcards.services.generate_ai_collection", new_callable=AsyncMock + ) as mock_ai_generate: + mock_ai_generate.return_value = mock_collection + + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert content["name"] == collection_data.name + assert content["id"] == str(mock_collection.id) + + mock_ai_generate.assert_called_once() + + +def test_create_collection_with_ai_generation_error( + client: TestClient, normal_user_token_headers: dict[str, str] +): + collection_data = CollectionCreate( + name="Test AI Error", prompt="Create flashcards but fail with AI error" + ) + + with patch( + "src.flashcards.services.generate_ai_collection", new_callable=AsyncMock + ) as mock_ai_generate: + err_msg = "AI service is unavailable" + mock_ai_generate.side_effect = AIGenerationError(err_msg) + + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 500 + content = rsp.json() + assert "detail" in content + assert err_msg in content["detail"] + + +def test_read_collection( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + """Read one collection""" + + collection_id = test_collection["id"] + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert content + assert content["user_id"] + assert content["id"] == collection_id + + +def test_read_nonexistent_collection( + client: TestClient, + normal_user_token_headers: dict[str, str], +): + non_existent_collection_id = str(uuid.uuid4()) + rsp = client.get( + f"{settings.API_V1_STR}/collections/{non_existent_collection_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_different_user_access( + client: TestClient, + normal_user_token_headers: dict[str, str], + superuser_token_headers: dict[str, str], +): + collection_data = CollectionCreate(name="User Restricted Collection") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + assert rsp.status_code == 200 + content = rsp.json() + collection_id = content["id"] + + rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=superuser_token_headers, + ) + + assert rsp.status_code == 404 + + +def test_read_collections( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + """Read collections""" + rsp = client.get( + f"{settings.API_V1_STR}/collections/", headers=normal_user_token_headers + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert "data" in content + assert "count" in content + assert content["count"] >= 1 + + collection_ids = [collection["id"] for collection in content["data"]] + assert test_collection["id"] in collection_ids + + +def test_read_collections_with_pagination( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_multiple_collections: list[dict[str, Any]], +): + rsp = client.get( + f"{settings.API_V1_STR}/collections?skip=2&limit=2", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert content["count"] >= len(test_multiple_collections) + assert len(content["data"]) <= 2 + + +def test_update_collection_success( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + update_data = CollectionUpdate(name="Update Collection") + + rsp = client.put( + f"{settings.API_V1_STR}/collections/{collection_id}", + json=update_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert collection_id == content["id"] + assert content["name"] == update_data.name + # Make sure other data not change + assert len(content["cards"]) == len(test_collection["cards"]) + assert content["user_id"] == test_collection["user_id"] + + +def test_update_nonexistent_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +): + non_existent_collection_id = str(uuid.uuid4()) + update_data = CollectionUpdate(name="Nonexistent Collection") + + rsp = client.put( + f"{settings.API_V1_STR}/collections/{non_existent_collection_id}", + json=update_data.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_different_user_update( + client: TestClient, + normal_user_token_headers: dict[str, str], + superuser_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + update_data = CollectionUpdate(name="Cross User Collection Update") + rsp = client.put( + f"{settings.API_V1_STR}/collections/{collection_id}", + json=update_data.model_dump(), + headers=superuser_token_headers, + ) + + assert rsp.status_code == 404 + + # Verity the data is still the same before updating + verify_rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=normal_user_token_headers, + ) + content = verify_rsp.json() + + assert content["name"] == test_collection["name"] + + +def test_delete_collection_success( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 204 + assert rsp.content == b"" + + verify_rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=normal_user_token_headers, + ) + assert verify_rsp.status_code == 404 + content = verify_rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_delete_nonexistent_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +): + non_existent_collection_id = str(uuid.uuid4()) + + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{non_existent_collection_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_different_user_delete( + client: TestClient, + normal_user_token_headers: dict[str, str], + superuser_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=superuser_token_headers, + ) + + assert rsp.status_code == 404 + + # Verity the data still exists + verify_rsp = client.get( + f"{settings.API_V1_STR}/collections/{collection_id}", + headers=normal_user_token_headers, + ) + assert verify_rsp.status_code == 200 + + +def test_deleted_collection_not_in_list( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_multiple_collections: list[dict[str, Any]], +): + rsp = client.get( + f"{settings.API_V1_STR}/collections/", headers=normal_user_token_headers + ) + list_before = rsp.json() + + delete_collection = test_multiple_collections[0] + + # Check test_collection in the collection list + assert delete_collection["id"] in [ + collection["id"] for collection in list_before["data"] + ] + + # Delete test collection + rsp = client.delete( + f"{settings.API_V1_STR}/collections/{delete_collection['id']}", + headers=normal_user_token_headers, + ) + assert rsp.status_code == 204 + + # Re-read the collection list + rsp = client.get( + f"{settings.API_V1_STR}/collections/", headers=normal_user_token_headers + ) + list_after = rsp.json() + assert list_after["count"] == len(test_multiple_collections) - 1 + assert delete_collection["id"] not in [ + collection["id"] for collection in list_after["data"] + ] diff --git a/backend/tests/flashcards/collection/test_services.py b/backend/tests/flashcards/collection/test_services.py new file mode 100644 index 0000000..a871220 --- /dev/null +++ b/backend/tests/flashcards/collection/test_services.py @@ -0,0 +1,149 @@ +import uuid +from typing import Any + +from sqlmodel import Session + +from src.flashcards.models import Collection +from src.flashcards.schemas import CollectionCreate, CollectionUpdate +from src.flashcards.services import ( + check_collection_access, + create_collection, + delete_collection, + get_collection, + get_collections, + update_collection, +) + + +def test_create_collection(db: Session, test_user: dict[str, Any]): + collection_in = CollectionCreate(name="Test Collection") + collection = create_collection( + session=db, + collection_in=collection_in, + user_id=test_user["id"], + ) + + assert collection.id is not None + assert collection.name == collection_in.name + assert collection.user_id == test_user["id"] + assert collection.created_at is not None + assert collection.updated_at is not None + + +def test_get_collection( + db: Session, test_user: dict[str, Any], test_collection: Collection +): + db_collection = get_collection( + session=db, id=test_collection.id, user_id=test_user["id"] + ) + assert db_collection is not None + assert db_collection.id == test_collection.id + assert db_collection.name == test_collection.name + assert db_collection.user_id == test_user["id"] + + +def test_get_collection_not_found(db: Session, test_user: dict[str, Any]): + db_collection = get_collection(session=db, id=uuid.uuid4(), user_id=test_user["id"]) + assert db_collection is None + + +def test_get_collection_with_wrong_user(db: Session, test_other_user: dict[str, Any]): + db_collection = get_collection( + session=db, id=uuid.uuid4(), user_id=test_other_user["id"] + ) + assert db_collection is None + + +def test_get_collections( + db: Session, test_user: dict[str, Any], test_multiple_collections: list[Collection] +): + limit = 2 + db_collections, count = get_collections( + session=db, user_id=test_user["id"], limit=limit + ) + + assert len(db_collections) == limit + assert count == len(test_multiple_collections) + + # Verify the order + for i in range(len(db_collections) - 1): + assert db_collections[i].updated_at >= db_collections[i + 1].updated_at + + +def test_get_collection_skip( + db: Session, test_user: dict[str, Any], test_multiple_collections: list[Collection] +): + skip = 2 + limit = 2 + + db_collections, count = get_collections( + session=db, user_id=test_user["id"], skip=skip, limit=limit + ) + + assert len(db_collections) == limit + assert count == len(test_multiple_collections) + + +def test_get_collection_empty(db: Session, test_user: dict[str, Any]): + db_collections, count = get_collections(session=db, user_id=test_user["id"]) + + assert len(db_collections) == 0 + assert count == 0 + + +def test_update_collection(db: Session, test_collection: Collection): + original_updated_at = test_collection.updated_at + + import time + + time.sleep(0.01) + + collection_in = CollectionUpdate(name="Updated Collection") + collection_update = update_collection( + session=db, collection=test_collection, collection_in=collection_in + ) + + assert collection_update is not None + assert collection_update.name == collection_in.name + assert collection_update.updated_at > original_updated_at + + +def test_delete_collection( + db: Session, test_user: dict[str, Any], test_collection: Collection +): + delete_collection(session=db, collection=test_collection) + + db_collection = get_collection( + session=db, id=test_collection.id, user_id=test_user["id"] + ) + + assert db_collection is None + + +def test_check_collection_access(db: Session, test_collection: Collection): + can_access = check_collection_access( + session=db, collection_id=test_collection.id, user_id=test_collection.user_id + ) + + assert can_access is True + + +def test_check_collection_access_with_nonexistent_collection( + db: Session, test_user: dict[str, Any] +): + non_existent_collection_id = uuid.uuid4() + can_access = check_collection_access( + session=db, collection_id=non_existent_collection_id, user_id=test_user["id"] + ) + + assert can_access is False + + +def test_check_collection_access_with_other_user( + db: Session, test_collection: Collection, test_other_user: dict[str, Any] +): + can_access = check_collection_access( + session=db, collection_id=test_collection.id, user_id=test_other_user["id"] + ) + + assert can_access is False diff --git a/backend/tests/flashcards/conftest.py b/backend/tests/flashcards/conftest.py new file mode 100644 index 0000000..e645e05 --- /dev/null +++ b/backend/tests/flashcards/conftest.py @@ -0,0 +1,105 @@ +from typing import Any + +import pytest +from sqlmodel import Session + +from src.flashcards.models import Card, Collection +from src.flashcards.schemas import CardCreate, CollectionCreate +from src.flashcards.services import create_card, create_collection +from src.users.schemas import UserCreate +from src.users.services import create_user +from tests.utils.utils import random_email, random_lower_string + + +@pytest.fixture +def test_user(db: Session) -> dict[str, Any]: + email = random_email() + password = random_lower_string() + full_name = random_lower_string() + + user_in = UserCreate(email=email, password=password, full_name=full_name) + user = create_user(session=db, user_create=user_in) + + return {"id": user.id, "email": user.email} + + +@pytest.fixture +def test_other_user(db: Session) -> dict[str, Any]: + email = random_email() + password = random_lower_string() + full_name = random_lower_string() + + user_in = UserCreate(email=email, password=password, full_name=full_name) + user = create_user(session=db, user_create=user_in) + + return {"id": user.id, "email": user.email} + + +@pytest.fixture +def test_collection(db: Session, test_user: dict[str, Any]) -> Collection: + collection_in = CollectionCreate(name="Test Collection") + collection = create_collection( + session=db, + collection_in=collection_in, + user_id=test_user["id"], + ) + + return collection + + +@pytest.fixture +def test_collection_with_multiple_cards( + db: Session, test_user: dict[str, Any] +) -> Collection: + collection_in = CollectionCreate(name="Test Collection") + collection = create_collection( + session=db, + collection_in=collection_in, + user_id=test_user["id"], + ) + + for i in range(5): + card_in = CardCreate(front=f"front {i}", back=f"back {i}") + create_card(session=db, card_in=card_in, collection_id=collection.id) + + return collection + + +@pytest.fixture +def test_multiple_collections( + db: Session, test_user: dict[str, Any] +) -> list[Collection]: + collections = [] + + for i in range(5): + collection_in = CollectionCreate(name=f"Test Collection {i}") + collection = create_collection( + session=db, + collection_in=collection_in, + user_id=test_user["id"], + ) + card_in = CardCreate(front=f"front {i}", back=f"back {i}") + create_card(session=db, card_in=card_in, collection_id=collection.id) + collections.append(collection) + + return collections + + +@pytest.fixture +def test_card(db: Session, test_collection: Collection) -> Card: + card_in = CardCreate(front="front", back="back") + card = create_card(session=db, card_in=card_in, collection_id=test_collection.id) + return card + + +@pytest.fixture +def test_multiple_cards(db: Session, test_collection: Collection) -> list[Card]: + cards = [] + for i in range(5): + card_in = CardCreate(front=f"front {i}", back=f"back {i}") + card = create_card( + session=db, card_in=card_in, collection_id=test_collection.id + ) + cards.append(card) + + return cards diff --git a/backend/tests/flashcards/practice_session/__init__.py b/backend/tests/flashcards/practice_session/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/flashcards/practice_session/test_api.py b/backend/tests/flashcards/practice_session/test_api.py new file mode 100644 index 0000000..e62411e --- /dev/null +++ b/backend/tests/flashcards/practice_session/test_api.py @@ -0,0 +1,509 @@ +import uuid +from datetime import datetime +from typing import Any +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +from src.core.config import settings +from src.flashcards.schemas import ( + CardCreate, + CollectionCreate, + PracticeCard, + PracticeCardResultPatch, + PracticeSession, +) + + +@pytest.fixture +def test_collection( + client: TestClient, normal_user_token_headers: dict[str, str] +) -> dict[str, Any]: + """Create a testing collection""" + collection_data = CollectionCreate(name="Test Collection") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + assert rsp.status_code == 200 + collection = rsp.json() + + card_count = 3 + for i in range(card_count): + card_data = CardCreate(front=f"Test front {i}", back=f"Test back {i}") + rsp = client.post( + f"{settings.API_V1_STR}/collections/{collection['id']}/cards", + json=card_data.model_dump(), + headers=normal_user_token_headers, + ) + assert rsp.status_code == 200 + + return collection + + +@pytest.fixture +def test_practice_session( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +) -> dict[str, Any]: + """Create a testing practice session""" + collection_id = test_collection["id"] + + rsp = client.post( + f"{settings.API_V1_STR}/practice-sessions", + json={"collection_id": str(collection_id)}, + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + return rsp.json() + + +@pytest.fixture +def mock_completed_session() -> PracticeSession: + session_id = uuid.uuid4() + return PracticeSession( + id=session_id, + collection_id=uuid.uuid4(), + user_id=uuid.uuid4(), + is_completed=True, + total_cards=2, + cards_practiced=2, + correct_answers=2, + created_at=datetime.now(), + updated_at=datetime.now(), + practice_cards=[ + PracticeCard( + id=uuid.uuid4(), + card_id=uuid.uuid4(), + session_id=session_id, + is_correct=True, + is_practiced=True, + created_at=datetime.now(), + updated_at=datetime.now(), + ), + PracticeCard( + id=uuid.uuid4(), + card_id=uuid.uuid4(), + session_id=session_id, + is_correct=True, + is_practiced=True, + created_at=datetime.now(), + updated_at=datetime.now(), + ), + ], + ) + + +@pytest.fixture +def multiple_practice_sessions( + client: TestClient, normal_user_token_headers: dict[str, str] +) -> dict[str, Any]: + practice_session_count = 3 + + # Create collections and sessions + for i in range(practice_session_count): + # Create collection first + collection_data = CollectionCreate(name=f"Test Collection {i}") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + collection = rsp.json() + collection_id = collection["id"] + + # Add card into collection + card_data = CardCreate(front=f"Test front {i}", back=f"Test back {i}") + rsp = client.post( + f"{settings.API_V1_STR}/collections/{collection_id}/cards", + json=card_data.model_dump(), + headers=normal_user_token_headers, + ) + + # Create practice session + rsp = client.post( + f"{settings.API_V1_STR}/practice-sessions", + json={"collection_id": str(collection_id)}, + headers=normal_user_token_headers, + ) + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions", headers=normal_user_token_headers + ) + sessions = rsp.json() + return sessions + + +def test_start_practice_session( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_collection: dict[str, Any], +): + collection_id = test_collection["id"] + + rsp = client.post( + f"{settings.API_V1_STR}/practice-sessions", + json={"collection_id": str(collection_id)}, + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + session = rsp.json() + assert session["collection_id"] == collection_id + assert session["is_completed"] is False + assert session["total_cards"] > 0 + assert session["cards_practiced"] == 0 + assert session["correct_answers"] == 0 + + +def test_start_practice_with_empty_collection( + client: TestClient, normal_user_token_headers +): + # Create collection first + collection_data = CollectionCreate(name="Empty Collection") + rsp = client.post( + f"{settings.API_V1_STR}/collections/", + json=collection_data.model_dump(), + headers=normal_user_token_headers, + ) + + collection = rsp.json() + collection_id = collection["id"] + rsp = client.post( + f"{settings.API_V1_STR}/practice-sessions", + json={"collection_id": str(collection_id)}, + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 400 + content = rsp.json() + assert "detail" in content + assert "empty collection" in content["detail"] + + +def test_start_practice_session_with_nonexistent_collection( + client: TestClient, + normal_user_token_headers: dict[str, str], +): + non_existent_collection_id = uuid.uuid4() + + rsp = client.post( + f"{settings.API_V1_STR}/practice-sessions", + json={"collection_id": str(non_existent_collection_id)}, + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_list_practice_sessions( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_practice_session: dict[str, Any], +): + """List practice session""" + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions", headers=normal_user_token_headers + ) + + assert rsp.status_code == 200 + content = rsp.json() + assert "data" in content + assert "count" in content + assert len(content["data"]) >= 1 + + session_ids = [session["id"] for session in content["data"]] + assert test_practice_session["id"] in session_ids + + +def test_list_practice_session_with_pagination( + client: TestClient, + normal_user_token_headers: dict[str, str], + multiple_practice_sessions: dict[str, Any], +): + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions?skip=0&limit=2", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + + assert len(content["data"]) <= 2 + assert content["count"] == multiple_practice_sessions["count"] + + +def test_get_practice_session_status( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_practice_session: dict[str, Any], +): + session_id = test_practice_session["id"] + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + + assert session_id == content["id"] + assert "total_cards" in content + assert "cards_practiced" in content + assert "correct_answers" in content + + +def test_get_nonexistent_practice_session( + client: TestClient, normal_user_token_headers: dict[str, str] +): + non_existent_session_id = str(uuid.uuid4()) + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{non_existent_session_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_list_practice_cards( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_practice_session: dict[str, Any], +): + session_id = test_practice_session["id"] + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + + # Verify PracticeSessionList response + assert "data" in content + assert "count" in content + assert content["count"] > 0 + + # Verify PracticeCardResponse + card = content["data"][0] + assert "card" in card + assert "is_practiced" in card + assert "is_correct" in card + + assert card["is_practiced"] is False + assert card["is_correct"] is None + + +def test_list_practice_cards_with_nonexistent_session( + client: TestClient, + normal_user_token_headers: dict[str, str], +): + non_existent_session_id = str(uuid.uuid4()) + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{non_existent_session_id}/cards", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_list_practice_cards_with_status_filter( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_practice_session: dict[str, Any], +): + session_id = test_practice_session["id"] + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + content = rsp.json() + + cards = content["practice_cards"] + card_id = cards[0]["card_id"] + + # Mark card as practiced and correct + practice_result = PracticeCardResultPatch(is_correct=True) + rsp = client.patch( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards/{card_id}", + json=practice_result.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards?status=completed", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + complete_cards = rsp.json()["data"] + + assert len(complete_cards) == 1 + assert complete_cards[0]["is_practiced"] is True + assert complete_cards[0]["is_correct"] is True + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards?status=pending", + headers=normal_user_token_headers, + ) + + pending_cards = rsp.json()["data"] + + assert len(pending_cards) == 2 + for card in pending_cards: + assert card["is_practiced"] is False + + +def test_update_practice_card_result( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_practice_session: dict[str, Any], +): + session_id = test_practice_session["id"] + + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards", + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + cards = rsp.json()["data"] + card_id = cards[0]["card"]["id"] + + # Mark card0 as practiced and correct + practice_result = PracticeCardResultPatch(is_correct=True) + rsp = client.patch( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards/{card_id}", + json=practice_result.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + result = rsp.json() + + assert result["is_practiced"] is True + assert result["is_correct"] is True + + # Check the session stats are updated + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}", + headers=normal_user_token_headers, + ) + + session = rsp.json() + assert session["cards_practiced"] == 1 + assert session["correct_answers"] == 1 + + card_id = cards[1]["card"]["id"] + + # Mark card1 as practiced and correct + practice_result = PracticeCardResultPatch(is_correct=False) + rsp = client.patch( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards/{card_id}", + json=practice_result.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 200 + result = rsp.json() + + assert result["is_practiced"] is True + assert result["is_correct"] is False + + # Check the session stats are updated + rsp = client.get( + f"{settings.API_V1_STR}/practice-sessions/{session_id}", + headers=normal_user_token_headers, + ) + + session = rsp.json() + assert session["cards_practiced"] == 2 + assert session["correct_answers"] == 1 + + +def test_update_practice_card_with_nonexistent_session( + client: TestClient, normal_user_token_headers: dict[str, str] +): + non_existent_session_id = str(uuid.uuid4()) + card_id = str(uuid.uuid4()) + + practice_result = PracticeCardResultPatch(is_correct=True) + rsp = client.patch( + f"{settings.API_V1_STR}/practice-sessions/{non_existent_session_id}/cards/{card_id}", + json=practice_result.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_update_practice_card_with_nonexistent_card( + client: TestClient, + normal_user_token_headers: dict[str, str], + test_practice_session: dict[str, Any], +): + session_id = test_practice_session["id"] + non_existent_card_id = str(uuid.uuid4()) + + practice_result = PracticeCardResultPatch(is_correct=True) + rsp = client.patch( + f"{settings.API_V1_STR}/practice-sessions/{session_id}/cards/{non_existent_card_id}", + json=practice_result.model_dump(), + headers=normal_user_token_headers, + ) + + assert rsp.status_code == 404 + content = rsp.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_update_practice_card_with_completed_session( + client: TestClient, + normal_user_token_headers: dict[str, str], + mock_completed_session: PracticeSession, +): + with patch("src.flashcards.services.get_practice_session") as mock_session: + mock_session.return_value = mock_completed_session + + practice_result = PracticeCardResultPatch(is_correct=True) + rsp = client.patch( + f"{settings.API_V1_STR}/practice-sessions/{mock_completed_session.id}/cards/{mock_completed_session.practice_cards[0].id}", + json=practice_result.model_dump(), + headers=normal_user_token_headers, + ) + + mock_session.assert_called_once() + + assert rsp.status_code == 400 + content = rsp.json() + + assert "detail" in content + assert "completed" in content["detail"] diff --git a/backend/tests/flashcards/practice_session/test_services.py b/backend/tests/flashcards/practice_session/test_services.py new file mode 100644 index 0000000..877f405 --- /dev/null +++ b/backend/tests/flashcards/practice_session/test_services.py @@ -0,0 +1,423 @@ +import json +import uuid +from collections.abc import Generator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlmodel import Session + +from src.ai_models.gemini.exceptions import AIGenerationError +from src.ai_models.gemini.provider import GeminiProvider +from src.flashcards.models import Card, Collection, PracticeSession +from src.flashcards.schemas import ( + AIFlashcard, + AIFlashcardCollection, +) +from src.flashcards.services import ( + _generate_ai_flashcards, + _get_collection_cards, + _save_ai_collection, + generate_ai_collection, + get_or_create_practice_session, + get_practice_card, + get_practice_cards, + get_practice_session, + get_practice_sessions, + record_practice_card_result, +) + + +def mock_get_provider() -> Generator[GeminiProvider, None, None]: + mock_provider = AsyncMock(spec=GeminiProvider) + yield mock_provider + + +@pytest.fixture +def test_practice_session( + db: Session, + test_collection_with_multiple_cards: Collection, +) -> PracticeSession: + session = get_or_create_practice_session( + session=db, + collection_id=test_collection_with_multiple_cards.id, + user_id=test_collection_with_multiple_cards.user_id, + ) + return session + + +@pytest.fixture +def test_multiple_practice_sessions( + db: Session, + test_multiple_collections: list[Collection], +) -> list[PracticeSession]: + sessions = [] + for i in range(3): + session = get_or_create_practice_session( + session=db, + collection_id=test_multiple_collections[i].id, + user_id=test_multiple_collections[i].user_id, + ) + sessions.append(session) + + return sessions + + +def test_get_or_create_practice_sessions( + db: Session, + test_collection_with_multiple_cards: Collection, + test_user: dict[str, Any], +): + session = get_or_create_practice_session( + session=db, + collection_id=test_collection_with_multiple_cards.id, + user_id=test_collection_with_multiple_cards.user_id, + ) + + assert session is not None + assert session.collection_id == test_collection_with_multiple_cards.id + assert session.user_id == test_user["id"] + assert session.is_completed is False + assert session.total_cards == len(test_collection_with_multiple_cards.cards) + assert session.cards_practiced == 0 + assert session.correct_answers == 0 + assert session.created_at is not None + assert session.updated_at is not None + + +def test_get_practice_session( + db: Session, + test_practice_session: PracticeSession, + test_user: dict[str, Any], +): + session = get_practice_session( + session=db, + session_id=test_practice_session.id, + user_id=test_practice_session.user_id, + ) + + assert session is not None + assert session.collection_id == test_practice_session.collection_id + assert session.user_id == test_user["id"] + assert session.is_completed is False + assert session.total_cards == test_practice_session.total_cards + assert session.cards_practiced == 0 + assert session.correct_answers == 0 + assert session.created_at is not None + assert session.updated_at is not None + + +def test_get_practice_sessions( + db: Session, + test_multiple_practice_sessions: list[PracticeSession], + test_user: dict[str, Any], +): + limit = 2 + db_sessions, count = get_practice_sessions( + session=db, user_id=test_user["id"], limit=limit + ) + + assert count == len(test_multiple_practice_sessions) + assert limit == len(db_sessions) + # Verify the order + for i in range(len(db_sessions) - 1): + assert db_sessions[i].updated_at >= db_sessions[i + 1].updated_at + + +def test_get_practice_sessions_with_skip( + db: Session, + test_multiple_practice_sessions: list[PracticeSession], + test_user: dict[str, Any], +): + skip = 2 + limit = 2 + db_sessions, count = get_practice_sessions( + session=db, user_id=test_user["id"], skip=skip, limit=limit + ) + + assert count == len(test_multiple_practice_sessions) + assert 1 == len(db_sessions) # Due to the total sessions is 3 + + +def test_get_collection_cards( + db: Session, + test_collection_with_multiple_cards: Collection, +): + cards = _get_collection_cards( + session=db, collection_id=test_collection_with_multiple_cards.id + ) + + assert len(cards) == len(test_collection_with_multiple_cards.cards) + + +def test_get_practice_card(db: Session, test_collection: Collection, test_card: Card): + session = get_or_create_practice_session( + session=db, collection_id=test_collection.id, user_id=test_collection.user_id + ) + + card = get_practice_card( + session=db, + practice_session_id=session.id, + card_id=test_card.id, + ) + + assert card is not None + assert card.card_id == test_card.id + assert card.is_correct is None + assert card.is_practiced is False + + +def test_get_nonexistent_practice_card( + db: Session, test_practice_session: PracticeSession +): + non_existent_card_id = uuid.uuid4() + card = get_practice_card( + session=db, + practice_session_id=test_practice_session.id, + card_id=non_existent_card_id, + ) + + assert card is None + + +def test_get_practice_cards(db: Session, test_practice_session: PracticeSession): + limit = 3 + cards, count = get_practice_cards( + session=db, practice_session_id=test_practice_session.id, limit=limit + ) + + assert limit == len(cards) + assert count == len(test_practice_session.practice_cards) + for card in cards: + assert card.is_practiced is False + assert card.is_correct is None + + +def test_get_practice_cards_with_status( + db: Session, test_practice_session: PracticeSession +): + cards, count = get_practice_cards( + session=db, practice_session_id=test_practice_session.id + ) + + record_practice_card_result(session=db, practice_card=cards[0], is_correct=True) + record_practice_card_result(session=db, practice_card=cards[1], is_correct=True) + + complete_count = 2 + cards, count = get_practice_cards( + session=db, practice_session_id=test_practice_session.id, status="completed" + ) + + assert count == complete_count + assert len(cards) == complete_count + for card in cards: + assert card.is_practiced is True + assert card.is_correct is True + + cards, count = get_practice_cards( + session=db, practice_session_id=test_practice_session.id, status="pending" + ) + + assert len(cards) == len(test_practice_session.practice_cards) - complete_count + for card in cards: + assert card.is_practiced is False + assert card.is_correct is None + + +def test_get_practice_card_with_asc_order( + db: Session, test_practice_session: PracticeSession +): + cards, _ = get_practice_cards( + session=db, practice_session_id=test_practice_session.id, order="asc" + ) + + # Verify the order + for i in range(len(cards) - 1): + assert cards[i].updated_at <= cards[i + 1].updated_at + + +def test_get_practice_card_with_desc_order( + db: Session, test_practice_session: PracticeSession +): + cards, _ = get_practice_cards( + session=db, practice_session_id=test_practice_session.id, order="desc" + ) + + # Verify the order + for i in range(len(cards) - 1): + assert cards[i].updated_at >= cards[i + 1].updated_at + + +def test_record_practice_card_result( + db: Session, test_practice_session: PracticeSession +): + before_card = test_practice_session.practice_cards[0] + original_updated_at = before_card.updated_at + + import time + + time.sleep(0.01) + + after_card = record_practice_card_result( + session=db, practice_card=before_card, is_correct=True + ) + + session = get_practice_session( + session=db, + session_id=test_practice_session.id, + user_id=test_practice_session.user_id, + ) + + assert session.correct_answers == 1 + assert session.cards_practiced == 1 + assert session.total_cards == len(test_practice_session.practice_cards) + assert session.is_completed is False + assert after_card.updated_at > original_updated_at + assert after_card.is_correct is True + assert after_card.is_practiced is True + + for card in test_practice_session.practice_cards[0:]: + record_practice_card_result(session=db, practice_card=card, is_correct=True) + + session = get_practice_session( + session=db, + session_id=test_practice_session.id, + user_id=test_practice_session.user_id, + ) + + # assert session.is_completed is True + assert session.cards_practiced == len(test_practice_session.practice_cards) + assert session.correct_answers == len(test_practice_session.practice_cards) + + +@pytest.mark.asyncio +async def test_generate_ai_collection(): + mock_session = MagicMock() + mock_provider = AsyncMock() + user_id = uuid.uuid4() + + test_cards = [AIFlashcard(front="Question 1", back="Answer 1")] + ai_collection = AIFlashcardCollection(name="AI Generated", cards=test_cards) + db_collection = Collection(name="AI Generated", user_id=user_id) + + with ( + patch( + "src.flashcards.services._generate_ai_flashcards", new_callable=AsyncMock + ) as mock_generate, + patch("src.flashcards.services._save_ai_collection") as mock_save, + ): + mock_generate.return_value = ai_collection + mock_save.return_value = db_collection + + result = await generate_ai_collection( + session=mock_session, + user_id=user_id, + prompt="Create pytest flashcards", + provider=mock_provider, + ) + + assert result == db_collection + + mock_generate.assert_called_once_with(mock_provider, "Create pytest flashcards") + mock_save.assert_called_once_with(mock_session, user_id, ai_collection) + + +def test_save_ai_collection(): + mock_session = MagicMock() + user_id = uuid.uuid4() + + def mock_refresh(obj): + if isinstance(obj, Collection) and obj.id is None: + obj.id = uuid.uuid4() + + mock_session.refresh.side_effect = mock_refresh + + test_cards = [ + AIFlashcard(front="Question 1", back="Answer 1"), + AIFlashcard(front="Question 2", back="Answer 2"), + ] + + test_collection = AIFlashcardCollection(name="Test Collection", cards=test_cards) + + result = _save_ai_collection(mock_session, user_id, test_collection) + + assert result.name == "Test Collection" + assert result.user_id == user_id + + assert mock_session.add.call_count == 3 # 1 collection + 2 cards + assert mock_session.commit.call_count == 2 + assert mock_session.refresh.call_count == 2 # Collection refreshed twice + + +@pytest.mark.asyncio +async def test_generate_ai_flashcards_success(): + mock_provider = AsyncMock() + mock_get_config = MagicMock() + + # Setup run_model return value + valid_json_response = json.dumps( + { + "collection": { + "name": "AI Collection", + "cards": [ + {"front": "Question 1", "back": "Answer 1"}, + {"front": "Question 2", "back": "Answer 2"}, + ], + } + } + ) + mock_provider.run_model.return_value = valid_json_response + + with patch( + "src.flashcards.services.get_flashcard_config", return_value=mock_get_config + ): + result = await _generate_ai_flashcards( + mock_provider, "Create flashcards about Pytest" + ) + + assert result.name == "AI Collection" + assert len(result.cards) == 2 + assert result.cards[0].front == "Question 1" + assert result.cards[0].back == "Answer 1" + + mock_provider.run_model.assert_called_once_with( + mock_get_config, "Create flashcards about Pytest" + ) + + +@pytest.mark.asyncio +async def test_generate_ai_flashcards_invalid_json(): + mock_provider = AsyncMock() + mock_provider.run_model.return_value = "Invalid JSON" + + with patch("src.flashcards.services.get_flashcard_config"): + with pytest.raises( + AIGenerationError, match="Failed to parse AI response as JSON" + ): + await _generate_ai_flashcards(mock_provider, "Create flashcards") + + +@pytest.mark.asyncio +async def test_generate_ai_flashcards_missing_collection(): + mock_provider = AsyncMock() + mock_provider.run_model.return_value = json.dumps({"other_field": "value"}) + + with patch("src.flashcards.services.get_flashcard_config"): + with pytest.raises( + AIGenerationError, match="AI response missing 'collection' field" + ): + await _generate_ai_flashcards(mock_provider, "Create flashcards") + + +@pytest.mark.asyncio +async def test_generate_ai_flashcards_empty_cards(): + mock_provider = AsyncMock() + mock_provider.run_model.return_value = json.dumps( + {"collection": {"name": "Empty Collection", "cards": []}} + ) + + with patch("src.flashcards.services.get_flashcard_config"): + with pytest.raises( + AIGenerationError, match="AI generated an empty collection with no cards" + ): + await _generate_ai_flashcards(mock_provider, "Create flashcards") diff --git a/backend/tests/stats/test_api.py b/backend/tests/stats/test_api.py index 6e2da99..ff9ea03 100644 --- a/backend/tests/stats/test_api.py +++ b/backend/tests/stats/test_api.py @@ -1,4 +1,9 @@ +import uuid +from unittest.mock import patch + import pytest +from fastapi.testclient import TestClient +from sqlmodel import Session from src.core.config import settings from src.flashcards.models import Collection @@ -81,3 +86,72 @@ def test_stats_endpoint_unauthorized(client, db, collection_with_sessions): ) assert response.status_code in (401, 403) + + +def test_get_collection_statistics_with_nonexistent_collection( + client: TestClient, db: Session +): + non_existent_collection_id = uuid.uuid4() + user = create_random_user(db) + headers = authentication_token_from_email(client=client, email=user.email, db=db) + response = client.get( + f"{settings.API_V1_STR}/collections/{non_existent_collection_id}/stats", + headers=headers, + ) + + assert response.status_code == 404 + content = response.json() + assert "detail" in content + assert "not found" in content["detail"] + + +def test_get_collection_statistics_with_value_error( + client: TestClient, db: Session, collection_with_sessions +): + with ( + patch("src.stats.api.check_collection_access") as mock_access, + patch("src.stats.api.get_collection_stats") as mock_stats, + ): + user = create_random_user(db) + headers = authentication_token_from_email( + client=client, email=user.email, db=db + ) + collection = collection_with_sessions(user.id, num_cards=5, num_sessions=10) + mock_access.return_value = True + mock_stats.side_effect = ValueError("Error testing") + + response = client.get( + f"{settings.API_V1_STR}/collections/{collection.id}/stats", + headers=headers, + ) + + assert response.status_code == 404 + content = response.json() + assert "detail" in content + assert "Error testing" in content["detail"] + + +def test_get_collection_statistics_with_exception( + client: TestClient, db: Session, collection_with_sessions +): + with ( + patch("src.stats.api.check_collection_access") as mock_access, + patch("src.stats.api.get_collection_stats") as mock_stats, + ): + user = create_random_user(db) + headers = authentication_token_from_email( + client=client, email=user.email, db=db + ) + collection = collection_with_sessions(user.id, num_cards=5, num_sessions=10) + mock_access.return_value = True + mock_stats.side_effect = Exception("Exception testing") + + response = client.get( + f"{settings.API_V1_STR}/collections/{collection.id}/stats", + headers=headers, + ) + + assert response.status_code == 500 + content = response.json() + assert "detail" in content + assert "Error retrieving collection statistics" in content["detail"] diff --git a/backend/tests/stats/test_services.py b/backend/tests/stats/test_services.py new file mode 100644 index 0000000..40eb442 --- /dev/null +++ b/backend/tests/stats/test_services.py @@ -0,0 +1,175 @@ +import uuid +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from sqlmodel import Session + +from src.flashcards.models import Collection +from src.stats.schemas import ( + CardBasicStats, + CollectionBasicInfo, + CollectionStats, + PracticeSessionStats, +) +from src.stats.services import ( + _get_collection_basic_info, + _get_difficult_cards, + _get_recent_sessions, + get_collection_stats, +) +from src.users.schemas import UserCreate +from src.users.services import create_user +from tests.stats.utils import create_cards, create_practice_cards, create_sessions +from tests.utils.utils import random_email, random_lower_string + + +@pytest.fixture +def test_user(db: Session) -> dict[str, Any]: + email = random_email() + password = random_lower_string() + full_name = random_lower_string() + + user_in = UserCreate(email=email, password=password, full_name=full_name) + user = create_user(session=db, user_create=user_in) + + return {"id": user.id, "email": user.email} + + +@pytest.fixture +def collection_with_sessions(db, test_user: dict[str, Any]): + def _create(user_id, num_cards=5, num_sessions=10): + collection = Collection(name="Test Collection", user_id=user_id) + db.add(collection) + db.flush() + cards = create_cards(db, collection, num_cards) + sessions = create_sessions(db, user_id, collection, num_sessions, num_cards) + create_practice_cards(db, sessions, cards) + return collection + + collection_out = _create(user_id=test_user["id"]) + + return collection_out + + +def test_get_collection_basic_info_success( + db: Session, collection_with_sessions: Collection +): + info = _get_collection_basic_info( + session=db, collection_id=collection_with_sessions.id + ) + + assert info is not None + assert info.name == collection_with_sessions.name + assert info.total_cards == len(collection_with_sessions.cards) + assert info.total_practice_sessions == len( + collection_with_sessions.practice_sessions + ) + + +def test_get_collection_basic_info_with_nonexistent_collection(db: Session): + non_existent_collection_id = uuid.uuid4() + with pytest.raises(ValueError) as exc_info: + _get_collection_basic_info(session=db, collection_id=non_existent_collection_id) + + assert f"Collection with id {non_existent_collection_id} not found" in str( + exc_info.value + ) + + +def test_get_recent_sessions(db: Session, collection_with_sessions: Collection): + sessions = _get_recent_sessions( + session=db, collection_id=collection_with_sessions.id + ) + + assert isinstance(sessions, list) + assert len(sessions) == len(collection_with_sessions.practice_sessions) + + +def test_get_difficult_cards(): + mock_session = MagicMock(spec=Session) + mock_collection_id = uuid.uuid4() + + mock_difficult_cards = [ + (uuid.uuid4(), "Difficult Card 1", 7, 2), + (uuid.uuid4(), "Difficult Card 2", 6, 2), + (uuid.uuid4(), "Difficult Card 3", 5, 1), + ] + + mock_exec = mock_session.exec.return_value + mock_exec.all.return_value = mock_difficult_cards + + result = _get_difficult_cards( + session=mock_session, collection_id=mock_collection_id + ) + + assert isinstance(result, list) + assert len(result) == len(mock_difficult_cards) + + for i, card_stat in enumerate(result): + assert card_stat.id == mock_difficult_cards[i][0] + assert card_stat.front == mock_difficult_cards[i][1] + assert card_stat.total_attempts == mock_difficult_cards[i][2] + assert card_stat.correct_answers == mock_difficult_cards[i][3] + + +def test_get_collection_stats(): + mock_session = MagicMock(spec=Session) + mock_collection_id = uuid.uuid4() + + with ( + patch("src.stats.services._get_collection_basic_info") as mock_get_info, + patch("src.stats.services._get_recent_sessions") as mock_get_sessions, + patch("src.stats.services._get_difficult_cards") as mock_get_cards, + ): + mock_get_info.return_value = CollectionBasicInfo( + name="Pytest collection", total_cards=15, total_practice_sessions=7 + ) + mock_get_sessions.return_value = [ + PracticeSessionStats( + id=uuid.uuid4(), + created_at=datetime.now(), + cards_practiced=7, + correct_answers=7, + total_cards=7, + is_completed=True, + ), + PracticeSessionStats( + id=uuid.uuid4(), + created_at=datetime.now(), + cards_practiced=7, + correct_answers=2, + total_cards=7, + is_completed=True, + ), + ] + mock_get_cards.return_value = [ + CardBasicStats( + id=uuid.uuid4(), + front="Difficult Card 1", + total_attempts=7, + correct_answers=2, + ), + CardBasicStats( + id=uuid.uuid4(), + front="Difficult Card 2", + total_attempts=6, + correct_answers=2, + ), + ] + + result = get_collection_stats(mock_session, mock_collection_id) + + assert isinstance(result, CollectionStats) + assert result.collection_info == mock_get_info.return_value + assert result.recent_sessions == mock_get_sessions.return_value + assert result.difficult_cards == mock_get_cards.return_value + + mock_get_info.assert_called_once_with(mock_session, mock_collection_id) + mock_get_sessions.assert_called_once_with( + session=mock_session, collection_id=mock_collection_id, limit=30 + ) + mock_get_cards.assert_called_once_with( + session=mock_session, collection_id=mock_collection_id + ) diff --git a/backend/uv.lock b/backend/uv.lock index c825f6a..9de5613 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -64,6 +64,8 @@ dev = [ { name = "coverage" }, { name = "pre-commit" }, { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-cov" }, { name = "ruff" }, ] @@ -72,7 +74,7 @@ requires-dist = [ { name = "alembic", specifier = ">=1.14.1" }, { name = "bcrypt", specifier = "==4.0.1" }, { name = "email-validator", specifier = ">=2.2.0" }, - { name = "fastapi", extras = ["standard"], specifier = "<1.0.0,>=0.114.2" }, + { name = "fastapi", extras = ["standard"], specifier = ">=0.114.2,<1.0.0" }, { name = "fastapi-pagination", specifier = ">=0.12.34" }, { name = "google-genai", specifier = ">=1.5.0" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, @@ -89,6 +91,8 @@ dev = [ { name = "coverage", specifier = ">=7.8.0" }, { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest-asyncio", specifier = ">=0.26.0" }, + { name = "pytest-cov", specifier = ">=6.1.1" }, { name = "ruff", specifier = ">=0.9.5" }, ] @@ -204,7 +208,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -280,6 +284,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/59/f1/4da7717f0063a222db253e7121bd6a56f6fb1ba439dcc36659088793347c/coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7", size = 203435 }, ] +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11'" }, +] + [[package]] name = "distlib" version = "0.3.9" @@ -337,12 +346,10 @@ wheels = [ [package.optional-dependencies] standard = [ { name = "email-validator" }, - { name = "fastapi-cli" }, { name = "fastapi-cli", extra = ["standard"] }, { name = "httpx" }, { name = "jinja2" }, { name = "python-multipart" }, - { name = "uvicorn" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -362,7 +369,6 @@ wheels = [ [package.optional-dependencies] standard = [ - { name = "uvicorn" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -962,6 +968,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/c4/453c52c659521066969523e87d85d54139bbd17b78f09532fb8eb8cdb58e/pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f", size = 54156 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 }, +] + +[[package]] +name = "pytest-cov" +version = "6.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/69/5f1e57f6c5a39f81411b550027bf72842c4567ff5fd572bed1edc9e4b5d9/pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a", size = 66857 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/d0/def53b4a790cfb21483016430ed828f64830dd981ebe1089971cd10cab25/pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde", size = 23841 }, +] + [[package]] name = "python-dotenv" version = "1.0.1"