Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 43 additions & 56 deletions app/api/v1/endpoints/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
import os
import io
from zipfile import ZipFile

from utils.get_db import get_db
from utils.auth import get_current_user
from curd.user import get_user_by_email
from curd.article import crud_upload_to_self_folder, crud_get_self_folders, crud_get_articles_in_folder, crud_self_create_folder, crud_self_article_to_recycle_bin, crud_self_folder_to_recycle_bin, crud_read_article
from app.utils.get_db import get_db
from app.utils.auth import get_current_user
from app.curd.article import crud_upload_to_self_folder, crud_get_self_folders, crud_get_articles_in_folder, crud_self_create_folder, crud_self_article_to_recycle_bin, crud_self_folder_to_recycle_bin, crud_read_article, crud_import_self_folder

router = APIRouter()

Expand Down Expand Up @@ -36,16 +37,8 @@ async def upload_to_self_folder(folder_id: int = Query(...), article: UploadFile
@router.get("/getSelfFolders", response_model="dict")
async def get_self_folders(page_number: Optional[int] = Query(None, ge=1), page_size: Optional[int] = Query(None, ge=1),
db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
# 获取用户邮箱
user_email = user.get("email")
if not user_email:
raise HTTPException(status_code=401, detail="Unauthorized")

# 由邮箱查得id
db_user = await get_user_by_email(db, user_email)
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
user_id = db_user.id
# 获取用户id
user_id = user.get("id")

# 数据库查询
folders = await crud_get_self_folders(user_id, page_number, page_size, db)
Expand All @@ -66,35 +59,19 @@ async def self_create_folder(folder_name: str = Body(...), db: AsyncSession = De
if folder_name == "":
raise HTTPException(status_code=405, detail="Empty Folder Name")

# 获取用户邮箱
user_email = user.get("email")
if not user_email:
raise HTTPException(status_code=401, detail="Unauthorized")

# 由邮箱查得id
db_user = await get_user_by_email(db, user_email)
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
user_id = db_user.id
# 获取用户id
user_id = user.get("id")

# 数据库插入
await crud_self_create_folder(folder_name, user_id, db)

# 返回结果
return {"msg": "User Folder Created Successfully"}

@router.delete("/selfArticleToRecycleBin", resplonse_model="dict")
@router.delete("/selfArticleToRecycleBin", response_model="dict")
async def self_article_to_recycle_bin(article_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
# 获取用户邮箱
user_email = user.get("email")
if not user_email:
raise HTTPException(status_code=401, detail="Unauthorized")

# 由邮箱查得id
db_user = await get_user_by_email(db, user_email)
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
user_id = db_user.id
# 获取用户id
user_id = user.get("id")

# 数据库修改
await crud_self_article_to_recycle_bin(article_id, user_id, db)
Expand All @@ -104,16 +81,8 @@ async def self_article_to_recycle_bin(article_id: int = Query(...), db: AsyncSes

@router.delete("/selfFolderToRecycleBin", response_model="dict")
async def self_folder_to_recycle_bin(folder_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
# 获取用户邮箱
user_email = user.get("email")
if not user_email:
raise HTTPException(status_code=401, detail="Unauthorized")

# 由邮箱查得id
db_user = await get_user_by_email(db, user_email)
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
user_id = db_user.id
# 获取用户id
user_id = user.get("id")

# 数据库修改
await crud_self_folder_to_recycle_bin(folder_id, user_id, db)
Expand All @@ -135,16 +104,8 @@ async def annotate_self_article(article_id: int = Query(...), article: UploadFil

@router.get("/readArticle", response_class=FileResponse)
async def read_article(article_id: int = Query(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
# 获取用户邮箱
user_email = user.get("email")
if not user_email:
raise HTTPException(status_code=401, detail="Unauthorized")

# 由邮箱查得id
db_user = await get_user_by_email(db, user_email)
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
user_id = db_user.id
# 获取用户id
user_id = user.get("id")

# 文件路径
file_path = f"articles/{article_id}.pdf"
Expand All @@ -155,4 +116,30 @@ async def read_article(article_id: int = Query(...), db: AsyncSession = Depends(
article_name = await crud_read_article(article_id, user_id, db)

# 返回结果
return FileResponse(path=file_path, filename=f"{article_name}.pdf", media_type='application/pdf')
return FileResponse(path=file_path, filename=f"{article_name}.pdf", media_type='application/pdf')

@router.post("/importSelfFolder", response_model="dict")
async def import_self_folder(folder_name: str = Query(...), zip: UploadFile = File(...), db: AsyncSession = Depends(get_db), user: dict = Depends(get_current_user)):
# 获取用户id
user_id = user.get("id")

# 获取压缩包中的所有文献名(去掉.pdf)
zip_bytes = await zip.read()
zip_file = ZipFile(io.BytesIO(zip_bytes))
article_names = [os.path.splitext(os.path.basename(name))[0] for name in zip_file.namelist() if name.endswith('.pdf')]

# 记入数据库
result = await crud_import_self_folder(folder_name, article_names, user_id, db)

# 存储文献,暂时存储到本地
os.makedirs("articles", exist_ok=True)
for i in range(0, len(result), 2):
article_id = result[i]
article_name = result[i + 1]
pdf_filename_in_zip = f"{article_name}.pdf"
with zip_file.open(pdf_filename_in_zip) as source_file:
target_path = os.path.join("articles", f"{article_id}.pdf")
with open(target_path, "wb") as out_file:
out_file.write(source_file.read())

return {"msg": "Succesfully import articles"}
24 changes: 22 additions & 2 deletions app/curd/article.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.future import select
from models.model import User, Group, Folder, Article, Note, Tag, user_group
from app.models.model import User, Group, Folder, Article, Note, Tag, user_group

async def crud_upload_to_self_folder(name: str, folder_id: int, db: AsyncSession):
new_article = Article(name=name, folder_id=folder_id)
Expand Down Expand Up @@ -97,4 +97,24 @@ async def crud_read_article(article_id: int, user_id: int, db: AsyncSession):
if not relation:
raise HTTPException(status_code=403, detail="This is an article of a group which you don't belong to")

return article.name
return article.name

async def crud_import_self_folder(folder_name: str, article_names, user_id: int, db: AsyncSession):
result = []

# 新建文件夹
new_folder = Folder(name=folder_name, user_id=user_id)
db.add(new_folder)
await db.commit()
await db.refresh(new_folder)

# 新建文献
for article_name in article_names:
new_article = Article(name=article_name, folder_id=new_folder.id)
db.add(new_article)
await db.commit()
await db.refresh(new_article)
result.append(new_article.id)
result.append(new_article.name)

return result
16 changes: 6 additions & 10 deletions app/utils/auth.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
from fastapi.security import OAuth2PasswordBearer
from jwt import PyJWTError, decode
from app.core.config import settings
from fastapi import Depends, HTTPException
import asyncio
from jose import JWTError, jwt # 用 jose 替代 jwt
from app.core.config import settings

# 配置 OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

async def get_current_user(token: str = Depends(oauth2_scheme)):
try:
loop = asyncio.get_event_loop()
payload = await loop.run_in_executor(
None, decode, token, settings.SECRET_KEY, [settings.ALGORITHM]
)
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
email: str = payload.get("sub")
user_id: int = payload.get("id") # 从 payload 中提取用户 ID
user_id: int = payload.get("id")
if email is None or user_id is None:
raise HTTPException(
status_code=401, detail="Invalid authentication credentials"
)
return {"email": email, "id": user_id} # 返回用户 ID 和 email
except PyJWTError:
return {"email": email, "id": user_id}
except JWTError:
raise HTTPException(
status_code=401, detail="Invalid authentication credentials"
)
Binary file modified requirements.txt
Binary file not shown.