diff --git a/app/api/v1/endpoints/article.py b/app/api/v1/endpoints/article.py index 9cdb1dc..c1897d9 100644 --- a/app/api/v1/endpoints/article.py +++ b/app/api/v1/endpoints/article.py @@ -3,11 +3,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from typing import Optional import os +import io +from zipfile import ZipFile -from utils.get_db import get_db -from utils.auth import get_current_user -from curd.user import get_user_by_email -from curd.article import crud_upload_to_self_folder, crud_get_self_folders, crud_get_articles_in_folder, crud_self_create_folder, crud_self_article_to_recycle_bin, crud_self_folder_to_recycle_bin, crud_read_article +from app.utils.get_db import get_db +from app.utils.auth import get_current_user +from app.curd.article import crud_upload_to_self_folder, crud_get_self_folders, crud_get_articles_in_folder, crud_self_create_folder, crud_self_article_to_recycle_bin, crud_self_folder_to_recycle_bin, crud_read_article, crud_import_self_folder router = APIRouter() @@ -36,16 +37,8 @@ async def upload_to_self_folder(folder_id: int = Query(...), article: UploadFile @router.get("/getSelfFolders", response_model="dict") async def get_self_folders(page_number: Optional[int] = Query(None, ge=1), page_size: Optional[int] = Query(None, ge=1), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): - # 获取用户邮箱 - user_email = user.get("email") - if not user_email: - raise HTTPException(status_code=401, detail="Unauthorized") - - # 由邮箱查得id - db_user = await get_user_by_email(db, user_email) - if not db_user: - raise HTTPException(status_code=404, detail="User not found") - user_id = db_user.id + # 获取用户id + user_id = user.get("id") # 数据库查询 folders = await crud_get_self_folders(user_id, page_number, page_size, db) @@ -66,16 +59,8 @@ async def self_create_folder(folder_name: str = Body(...), db: AsyncSession = De if folder_name == "": raise HTTPException(status_code=405, detail="Empty Folder Name") - # 获取用户邮箱 - user_email = user.get("email") - if not user_email: - raise HTTPException(status_code=401, detail="Unauthorized") - - # 由邮箱查得id - db_user = await get_user_by_email(db, user_email) - if not db_user: - raise HTTPException(status_code=404, detail="User not found") - user_id = db_user.id + # 获取用户id + user_id = user.get("id") # 数据库插入 await crud_self_create_folder(folder_name, user_id, db) @@ -83,18 +68,10 @@ async def self_create_folder(folder_name: str = Body(...), db: AsyncSession = De # 返回结果 return {"msg": "User Folder Created Successfully"} -@router.delete("/selfArticleToRecycleBin", resplonse_model="dict") +@router.delete("/selfArticleToRecycleBin", response_model="dict") async def self_article_to_recycle_bin(article_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): - # 获取用户邮箱 - user_email = user.get("email") - if not user_email: - raise HTTPException(status_code=401, detail="Unauthorized") - - # 由邮箱查得id - db_user = await get_user_by_email(db, user_email) - if not db_user: - raise HTTPException(status_code=404, detail="User not found") - user_id = db_user.id + # 获取用户id + user_id = user.get("id") # 数据库修改 await crud_self_article_to_recycle_bin(article_id, user_id, db) @@ -104,16 +81,8 @@ async def self_article_to_recycle_bin(article_id: int = Query(...), db: AsyncSes @router.delete("/selfFolderToRecycleBin", response_model="dict") async def self_folder_to_recycle_bin(folder_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): - # 获取用户邮箱 - user_email = user.get("email") - if not user_email: - raise HTTPException(status_code=401, detail="Unauthorized") - - # 由邮箱查得id - db_user = await get_user_by_email(db, user_email) - if not db_user: - raise HTTPException(status_code=404, detail="User not found") - user_id = db_user.id + # 获取用户id + user_id = user.get("id") # 数据库修改 await crud_self_folder_to_recycle_bin(folder_id, user_id, db) @@ -135,16 +104,8 @@ async def annotate_self_article(article_id: int = Query(...), article: UploadFil @router.get("/readArticle", response_class=FileResponse) async def read_article(article_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): - # 获取用户邮箱 - user_email = user.get("email") - if not user_email: - raise HTTPException(status_code=401, detail="Unauthorized") - - # 由邮箱查得id - db_user = await get_user_by_email(db, user_email) - if not db_user: - raise HTTPException(status_code=404, detail="User not found") - user_id = db_user.id + # 获取用户id + user_id = user.get("id") # 文件路径 file_path = f"articles/{article_id}.pdf" @@ -155,4 +116,30 @@ async def read_article(article_id: int = Query(...), db: AsyncSession = Depends( article_name = await crud_read_article(article_id, user_id, db) # 返回结果 - return FileResponse(path=file_path, filename=f"{article_name}.pdf", media_type='application/pdf') \ No newline at end of file + return FileResponse(path=file_path, filename=f"{article_name}.pdf", media_type='application/pdf') + +@router.post("/importSelfFolder", response_model="dict") +async def import_self_folder(folder_name: str = Query(...), zip: UploadFile = File(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)): + # 获取用户id + user_id = user.get("id") + + # 获取压缩包中的所有文献名(去掉.pdf) + zip_bytes = await zip.read() + zip_file = ZipFile(io.BytesIO(zip_bytes)) + article_names = [os.path.splitext(os.path.basename(name))[0] for name in zip_file.namelist() if name.endswith('.pdf')] + + # 记入数据库 + result = await crud_import_self_folder(folder_name, article_names, user_id, db) + + # 存储文献,暂时存储到本地 + os.makedirs("articles", exist_ok=True) + for i in range(0, len(result), 2): + article_id = result[i] + article_name = result[i + 1] + pdf_filename_in_zip = f"{article_name}.pdf" + with zip_file.open(pdf_filename_in_zip) as source_file: + target_path = os.path.join("articles", f"{article_id}.pdf") + with open(target_path, "wb") as out_file: + out_file.write(source_file.read()) + + return {"msg": "Succesfully import articles"} \ No newline at end of file diff --git a/app/curd/article.py b/app/curd/article.py index 987a456..74055be 100644 --- a/app/curd/article.py +++ b/app/curd/article.py @@ -2,7 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from sqlalchemy.future import select -from models.model import User, Group, Folder, Article, Note, Tag, user_group +from app.models.model import User, Group, Folder, Article, Note, Tag, user_group async def crud_upload_to_self_folder(name: str, folder_id: int, db: AsyncSession): new_article = Article(name=name, folder_id=folder_id) @@ -97,4 +97,24 @@ async def crud_read_article(article_id: int, user_id: int, db: AsyncSession): if not relation: raise HTTPException(status_code=403, detail="This is an article of a group which you don't belong to") - return article.name \ No newline at end of file + return article.name + +async def crud_import_self_folder(folder_name: str, article_names, user_id: int, db: AsyncSession): + result = [] + + # 新建文件夹 + new_folder = Folder(name=folder_name, user_id=user_id) + db.add(new_folder) + await db.commit() + await db.refresh(new_folder) + + # 新建文献 + for article_name in article_names: + new_article = Article(name=article_name, folder_id=new_folder.id) + db.add(new_article) + await db.commit() + await db.refresh(new_article) + result.append(new_article.id) + result.append(new_article.name) + + return result \ No newline at end of file diff --git a/app/utils/auth.py b/app/utils/auth.py index f17f09f..b79818e 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -1,26 +1,22 @@ from fastapi.security import OAuth2PasswordBearer -from jwt import PyJWTError, decode -from app.core.config import settings from fastapi import Depends, HTTPException -import asyncio +from jose import JWTError, jwt # 用 jose 替代 jwt +from app.core.config import settings # 配置 OAuth2PasswordBearer oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") async def get_current_user(token: str = Depends(oauth2_scheme)): try: - loop = asyncio.get_event_loop() - payload = await loop.run_in_executor( - None, decode, token, settings.SECRET_KEY, [settings.ALGORITHM] - ) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) email: str = payload.get("sub") - user_id: int = payload.get("id") # 从 payload 中提取用户 ID + user_id: int = payload.get("id") if email is None or user_id is None: raise HTTPException( status_code=401, detail="Invalid authentication credentials" ) - return {"email": email, "id": user_id} # 返回用户 ID 和 email - except PyJWTError: + return {"email": email, "id": user_id} + except JWTError: raise HTTPException( status_code=401, detail="Invalid authentication credentials" ) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 54537e3..71a2482 100644 Binary files a/requirements.txt and b/requirements.txt differ