From 3718933f96af79eb3cd3263260e81520fdfbf224 Mon Sep 17 00:00:00 2001 From: Fantasy lee <129943055+Fantasylee21@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:12:28 +0800 Subject: [PATCH 1/2] =?UTF-8?q?[chore]:=20=E5=B0=86=E5=90=8E=E7=AB=AF?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=8D=A2=E4=B8=BA=E5=BC=82=E6=AD=A5=E7=A8=8B?= =?UTF-8?q?=E5=BA=8F=EF=BC=8C=E6=8F=90=E9=AB=98=E6=9C=8D=E5=8A=A1=E6=95=88?= =?UTF-8?q?=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/v1/endpoints/auth.py | 58 +++++++++++++++++++----------------- app/api/v1/endpoints/note.py | 18 +++++------ app/core/config.py | 3 +- app/curd/note.py | 53 +++++++++++++++++++------------- app/curd/user.py | 15 ++++++---- app/db/session.py | 14 ++++++++- app/utils/auth.py | 9 ++++-- app/utils/get_db.py | 11 +++---- 8 files changed, 105 insertions(+), 76 deletions(-) diff --git a/app/api/v1/endpoints/auth.py b/app/api/v1/endpoints/auth.py index f14cb56..cf19bac 100644 --- a/app/api/v1/endpoints/auth.py +++ b/app/api/v1/endpoints/auth.py @@ -1,9 +1,9 @@ from fastapi import APIRouter, HTTPException, Depends -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from passlib.context import CryptContext from datetime import datetime, timedelta import jwt -import smtplib +import aiosmtplib from email.mime.text import MIMEText from email.header import Header import random @@ -18,7 +18,7 @@ router = APIRouter() -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # 使用 bcrypt 加密算法 +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # 使用 bcrypt 加密算法 SECRET_KEY = settings.SECRET_KEY ALGORITHM = settings.ALGORITHM ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES @@ -26,7 +26,7 @@ # 配置 Redis 连接 redis_client = get_redis_client() -def create_access_token(data: dict, expires_delta: timedelta = None): +async def create_access_token(data: dict, expires_delta: timedelta = None): to_encode = data.copy() if expires_delta: expire = datetime.now() + expires_delta @@ -37,36 +37,35 @@ def create_access_token(data: dict, expires_delta: timedelta = None): return encoded_jwt @router.post("/register", response_model=dict) -def register(user: UserCreate, db: Session = Depends(get_db)): - existing_user = get_user_by_email(db, user.email) - if (redis_client.exists(f"email:{user.email}:code")): - code = redis_client.get(f"email:{user.email}:code").decode("utf-8") - if (user.code != code): +async def register(user: UserCreate, db: AsyncSession = Depends(get_db)): + existing_user = await get_user_by_email(db, user.email) + if redis_client.exists(f"email:{user.email}:code"): + code = redis_client.get(f"email:{user.email}:code") + if user.code != code: raise HTTPException(status_code=400, detail="Invalid verification code") else: raise HTTPException(status_code=400, detail="Verification code expired or not sent") - - if (existing_user): + + if existing_user: raise HTTPException(status_code=400, detail="Email already registered") hashed_password = pwd_context.hash(user.password) - create_user(db, user.email, user.username, hashed_password) + await create_user(db, user.email, user.username, hashed_password) return {"msg": "User registered successfully"} @router.post("/login", response_model=dict) -def login(user: UserLogin, db: Session = Depends(get_db)): - db_user = get_user_by_email(db, user.email) +async def login(user: UserLogin, db: AsyncSession = Depends(get_db)): + db_user = await get_user_by_email(db, user.email) if not db_user or not pwd_context.verify(user.password, db_user.password): raise HTTPException(status_code=401, detail="Invalid email or password") access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( + access_token = await create_access_token( data={"sub": db_user.email}, expires_delta=access_token_expires ) - return {"access_token": access_token, "token_type": "bearer", } + return {"access_token": access_token, "token_type": "bearer"} # 发送验证码 @router.post("/send_code", response_model=dict) -def send_code(user_send_code : UserSendCode, db: Session = Depends(get_db)): - # 检查 Redis 中是否存在该邮箱的发送记录 +async def send_code(user_send_code: UserSendCode): if redis_client.exists(f"email:{user_send_code.email}:time"): raise HTTPException(status_code=429, detail="You can only request a verification code once every 5 minutes.") @@ -85,23 +84,26 @@ def send_code(user_send_code : UserSendCode, db: Session = Depends(get_db)): # 创建MIMEText对象时需要显式指定子类型和编码 message = MIMEText(_text=body, _subtype='plain', _charset='utf-8') - message["From"] = formataddr(("JieNote团队", "noreply@jienote.com")) + message["From"] = formataddr(("JieNote团队", "jienote_buaa@163.com")) message["To"] = user_send_code.email message["Subject"] = Header(subject, 'utf-8').encode() # 添加必要的内容传输编码头 message.add_header('Content-Transfer-Encoding', 'base64') try: - # 连接 SMTP 服务器并发送邮件 - with smtplib.SMTP_SSL(smtp_server, smtp_port) as server: - server.login(sender_email, sender_password) - server.sendmail(sender_email, [user_send_code.email], message.as_string()) - - # 将验证码和发送时间存储到 Redis,设置 5 分钟过期时间 - redis_client.setex(f"email:{user_send_code.email}:code", ACCESS_TOKEN_EXPIRE_MINUTES, code) - redis_client.setex(f"email:{user_send_code.email}:time", ACCESS_TOKEN_EXPIRE_MINUTES, int(time.time())) + await aiosmtplib.send( + message, + hostname=smtp_server, + port=smtp_port, + username=sender_email, + password=sender_password, + use_tls=True, + ) + + redis_client.setex(f"email:{user_send_code.email}:code", ACCESS_TOKEN_EXPIRE_MINUTES * 60, code) + redis_client.setex(f"email:{user_send_code.email}:time", ACCESS_TOKEN_EXPIRE_MINUTES * 60, int(time.time())) return {"msg": "Verification code sent"} - except smtplib.SMTPException as e: + except aiosmtplib.SMTPException as e: raise HTTPException(status_code=500, detail=f"Failed to send email: {str(e)}") \ No newline at end of file diff --git a/app/api/v1/endpoints/note.py b/app/api/v1/endpoints/note.py index f73975a..2b5a6b3 100644 --- a/app/api/v1/endpoints/note.py +++ b/app/api/v1/endpoints/note.py @@ -1,5 +1,5 @@ from fastapi import APIRouter, HTTPException, Depends -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.schemas.note import NoteCreate, NoteUpdate, NoteFind from app.utils.get_db import get_db from app.curd.note import create_note_in_db, delete_note_in_db, update_note_in_db, find_notes_in_db @@ -7,28 +7,28 @@ router = APIRouter() @router.post("", response_model=dict) -def create_note(note: NoteCreate, db: Session = Depends(get_db)): - new_note = create_note_in_db(note, db) +async def create_note(note: NoteCreate, db: AsyncSession = Depends(get_db)): + new_note = await create_note_in_db(note, db) return {"msg": "Note created successfully", "note_id": new_note.id} @router.delete("/{note_id}", response_model=dict) -def delete_note(note_id: int, db: Session = Depends(get_db)): - note = delete_note_in_db(note_id, db) +async def delete_note(note_id: int, db: AsyncSession = Depends(get_db)): + note = await delete_note_in_db(note_id, db) if not note: raise HTTPException(status_code=404, detail="Note not found") return {"msg": "Note deleted successfully"} @router.put("/{note_id}", response_model=dict) -def update_note(note_id: int, content: str, db: Session = Depends(get_db)): +async def update_note(note_id: int, content: str, db: AsyncSession = Depends(get_db)): note = NoteUpdate(id=note_id, content=content) - updated_note = update_note_in_db(note_id, note, db) + updated_note = await update_note_in_db(note_id, note, db) if not updated_note: raise HTTPException(status_code=404, detail="Note not found") return {"msg": "Note updated successfully", "note_id": updated_note.id} @router.get("", response_model=dict) -def get_notes(note_find: NoteFind = Depends(), db: Session = Depends(get_db)): - notes, total_count = find_notes_in_db(note_find, db) +async def get_notes(note_find: NoteFind = Depends(), db: AsyncSession = Depends(get_db)): + notes, total_count = await find_notes_in_db(note_find, db) return { "pagination": { "total_count": total_count, diff --git a/app/core/config.py b/app/core/config.py index 6163c9b..4365eeb 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -1,5 +1,4 @@ import os -from datetime import timedelta from dotenv import load_dotenv load_dotenv() @@ -7,7 +6,7 @@ class Settings: PROJECT_NAME: str = "JieNote Backend" # 项目名称 VERSION: str = "1.0.0" # 项目版本 - SQLALCHEMY_DATABASE_URL = "mysql+pymysql://root:coders007@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称 + SQLALCHEMY_DATABASE_URL = "mysql+asyncmy://root:coders007@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称 SECRET_KEY: str = os.getenv("SECRET_KEY", "default_secret_key") # JWT密钥 ALGORITHM: str = "HS256" # JWT算法 ACCESS_TOKEN_EXPIRE_MINUTES: int = 300 # token过期时间 diff --git a/app/curd/note.py b/app/curd/note.py index 71072a0..24cf806 100644 --- a/app/curd/note.py +++ b/app/curd/note.py @@ -1,41 +1,52 @@ -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import delete, func from app.models.model import Note from app.schemas.note import NoteCreate, NoteUpdate, NoteFind, NoteResponse -def create_note_in_db(note: NoteCreate, db: Session): +async def create_note_in_db(note: NoteCreate, db: AsyncSession): new_note = Note(content=note.content, article_id=note.article_id) db.add(new_note) - db.commit() - db.refresh(new_note) + await db.commit() + await db.refresh(new_note) return new_note -def delete_note_in_db(note_id: int, db: Session): - note = db.query(Note).filter(Note.id == note_id).first() +async def delete_note_in_db(note_id: int, db: AsyncSession): + stmt = select(Note).where(Note.id == note_id) + result = await db.execute(stmt) + note = result.scalar_one_or_none() if note: - db.delete(note) - db.commit() + delete_stmt = delete(Note).where(Note.id == note_id) + await db.execute(delete_stmt) + await db.commit() return note -def update_note_in_db(note_id: int, note: NoteUpdate, db: Session): - existing_note = db.query(Note).filter(Note.id == note_id).first() +async def update_note_in_db(note_id: int, note: NoteUpdate, db: AsyncSession): + stmt = select(Note).where(Note.id == note_id) + result = await db.execute(stmt) + existing_note = result.scalar_one_or_none() if existing_note: existing_note.content = note.content - db.commit() - db.refresh(existing_note) + await db.commit() + await db.refresh(existing_note) return existing_note -def find_notes_in_db(note_find: NoteFind, db: Session): - query = db.query(Note) +async def find_notes_in_db(note_find: NoteFind, db: AsyncSession): + stmt = select(Note) if note_find.id is not None: - query = query.filter(Note.id == note_find.id) + stmt = stmt.where(Note.id == note_find.id) elif note_find.article_id is not None: - query = query.filter(Note.article_id == note_find.article_id) + stmt = stmt.where(Note.article_id == note_find.article_id) + + total_count_stmt = select(func.count()).select_from(stmt) + total_count_result = await db.execute(total_count_stmt) + total_count = total_count_result.scalar() - totol_count = query.count() - # 添加分页逻辑 if note_find.page is not None and note_find.page_size is not None: offset = (note_find.page - 1) * note_find.page_size - query = query.offset(offset).limit(note_find.page_size) - notes = [NoteResponse.model_validate(note) for note in query.all()] - return notes, totol_count + stmt = stmt.offset(offset).limit(note_find.page_size) + + result = await db.execute(stmt) + notes = [NoteResponse.model_validate(note) for note in result.scalars().all()] + return notes, total_count diff --git a/app/curd/user.py b/app/curd/user.py index e0393ff..ca63bfb 100644 --- a/app/curd/user.py +++ b/app/curd/user.py @@ -1,12 +1,15 @@ -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select from app.models.model import User -def get_user_by_email(db: Session, email: str): - return db.query(User).filter(User.email == email).first() +async def get_user_by_email(db: AsyncSession, email: str): + stmt = select(User).where(User.email == email) + result = await db.execute(stmt) + return result.scalar_one_or_none() -def create_user(db: Session, email: str, username: str,hashed_password: str): +async def create_user(db: AsyncSession, email: str, username: str, hashed_password: str): new_user = User(email=email, username=username, password=hashed_password, avatar="app/static/avatar/default.jpg") db.add(new_user) - db.commit() - db.refresh(new_user) + await db.commit() + await db.refresh(new_user) return new_user \ No newline at end of file diff --git a/app/db/session.py b/app/db/session.py index 7c9648b..68701ee 100644 --- a/app/db/session.py +++ b/app/db/session.py @@ -1,7 +1,19 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from app.core.config import settings engine = create_engine(settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) #连接mysql -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) \ No newline at end of file +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建异步引擎 +async_engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) + +# 创建异步会话 +async_session = sessionmaker( + bind=async_engine, + class_=AsyncSession, + autocommit=False, + autoflush=False +) \ No newline at end of file diff --git a/app/utils/auth.py b/app/utils/auth.py index 1bd6667..72528cf 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -2,12 +2,17 @@ from jwt import PyJWTError, decode from app.core.config import settings from fastapi import Depends, HTTPException +import asyncio + # 配置 OAuth2PasswordBearer oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") -def get_current_user(token: str = Depends(oauth2_scheme)): +async def get_current_user(token: str = Depends(oauth2_scheme)): try: - payload = decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + loop = asyncio.get_event_loop() + payload = await loop.run_in_executor( + None, decode, token, settings.SECRET_KEY, [settings.ALGORITHM] + ) email: str = payload.get("sub") if email is None: raise HTTPException( diff --git a/app/utils/get_db.py b/app/utils/get_db.py index 2b23243..2165389 100644 --- a/app/utils/get_db.py +++ b/app/utils/get_db.py @@ -1,8 +1,5 @@ -from app.db.session import SessionLocal +from app.db.session import async_session -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() \ No newline at end of file +async def get_db(): + async with async_session() as db: + yield db \ No newline at end of file From f4e9138fa48fc55974150fdcc8144b8a519c2368 Mon Sep 17 00:00:00 2001 From: Fantasy lee <129943055+Fantasylee21@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:13:26 +0800 Subject: [PATCH 2/2] =?UTF-8?q?[docs]:=20=E6=9B=B4=E6=96=B0requirements.tx?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index 6ac98c4..078842d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,9 @@ +aiosmtplib==4.0.0 alembic==1.15.2 annotated-types==0.7.0 anyio==4.9.0 +async-timeout==5.0.1 +asyncmy==0.2.10 bcrypt==4.3.0 cffi==1.17.1 click==8.1.8 @@ -10,6 +13,7 @@ dnspython==2.7.0 dotenv==0.9.9 email_validator==2.2.0 fastapi==0.115.12 +fastapi-pagination==0.12.34 greenlet==3.1.1 h11==0.14.0 idna==3.10