diff --git a/README.md b/README.md index 947fc58..f3395b5 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,10 @@ Attention!!! - Must connect Redis before running the application. ‼️‼️‼️ +## Token Authentication +- JWT (JSON Web Token) is used for authentication. +- Refresh tokens for 7 days and access tokens for 5min. + ## Folder Structure - `app/`: Contains the main application code. - `tests/`: Contains test cases. diff --git "a/alembic/versions/698cace619a6_\346\267\273\345\212\240\347\272\246\346\235\237.py" "b/alembic/versions/698cace619a6_\346\267\273\345\212\240\347\272\246\346\235\237.py" new file mode 100644 index 0000000..8b2d594 --- /dev/null +++ "b/alembic/versions/698cace619a6_\346\267\273\345\212\240\347\272\246\346\235\237.py" @@ -0,0 +1,32 @@ +"""添加约束 + +Revision ID: 698cace619a6 +Revises: f1242bbcad2d +Create Date: 2025-04-18 17:18:55.547867 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '698cace619a6' +down_revision: Union[str, None] = 'f1242bbcad2d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/app/api/v1/endpoints/auth.py b/app/api/v1/endpoints/auth.py index cf19bac..7060608 100644 --- a/app/api/v1/endpoints/auth.py +++ b/app/api/v1/endpoints/auth.py @@ -10,7 +10,7 @@ import time from email.utils import formataddr -from app.schemas.auth import UserCreate, UserLogin, UserSendCode +from app.schemas.auth import UserCreate, UserLogin, UserSendCode, ReFreshToken from app.core.config import settings from app.curd.user import get_user_by_email, create_user from app.utils.get_db import get_db @@ -31,11 +31,21 @@ async def create_access_token(data: dict, expires_delta: timedelta = None): if expires_delta: expire = datetime.now() + expires_delta else: - expire = datetime.now() + timedelta(minutes=15) + expire = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt +async def create_refresh_token(data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.now() + expires_delta + else: + expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + to_encode.update({"exp": expire, "type": "refresh"}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + @router.post("/register", response_model=dict) async def register(user: UserCreate, db: AsyncSession = Depends(get_db)): existing_user = await get_user_by_email(db, user.email) @@ -57,11 +67,39 @@ async def login(user: UserLogin, db: AsyncSession = Depends(get_db)): db_user = await get_user_by_email(db, user.email) if not db_user or not pwd_context.verify(user.password, db_user.password): raise HTTPException(status_code=401, detail="Invalid email or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_token_expires = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) access_token = await create_access_token( - data={"sub": db_user.email}, expires_delta=access_token_expires + data={"sub": db_user.email, "id": db_user.id}, expires_delta=access_token_expires + ) + refresh_token = await create_refresh_token( + data={"sub": db_user.email, "id": db_user.id}, expires_delta=refresh_token_expires ) - return {"access_token": access_token, "token_type": "bearer"} + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + "user_id": db_user.id, + "email": db_user.email, + "username": db_user.username, + "avatar": db_user.avatar + } + +@router.post("/refresh", response_model=dict) +async def refresh_token(refresh_token: ReFreshToken): + try: + payload = jwt.decode(refresh_token.refresh_token, SECRET_KEY, algorithms=[ALGORITHM]) + if payload.get("type") != "refresh": + raise HTTPException(status_code=401, detail="Invalid refresh token type") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = await create_access_token( + data={"sub": payload["sub"], "id": payload["id"]}, expires_delta=access_token_expires + ) + return {"access_token": access_token} + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=401, detail="Refresh token expired") + except Exception: + raise HTTPException(status_code=401, detail="Invalid refresh token") # 发送验证码 @router.post("/send_code", response_model=dict) diff --git a/app/api/v1/endpoints/user.py b/app/api/v1/endpoints/user.py new file mode 100644 index 0000000..d47e584 --- /dev/null +++ b/app/api/v1/endpoints/user.py @@ -0,0 +1,75 @@ +from fastapi import APIRouter, HTTPException, Depends, UploadFile, Form, File +from app.schemas.user import UserUpdate, PasswordUpdate +from app.curd.user import update_user_in_db, get_user_by_email, update_user_password +from sqlalchemy.ext.asyncio import AsyncSession +from app.utils.get_db import get_db +from app.utils.auth import get_current_user +from passlib.context import CryptContext +import os +from uuid import uuid4 + +router = APIRouter() + +# update current user +@router.put("", response_model=dict) +async def update_current_user( + username: str = Form(None), + avatar: UploadFile = File(None), + db: AsyncSession = Depends(get_db), + current_user: dict = Depends(get_current_user) +): + """ + Update the current user's information. + """ + db_user = await get_user_by_email(db, current_user["email"]) + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + try: + avatar_url = None + if avatar: + avatar_file: UploadFile = avatar + file_extension = os.path.splitext(avatar_file.filename)[1] + unique_filename = f"{uuid4()}{file_extension}" + avatar_dir = os.path.join("app", "static", "avatar") + avatar_path = os.path.join(avatar_dir, unique_filename) + + # 确保以二进制模式写入文件,避免编码问题 + with open(avatar_path, "wb") as f: + f.write(await avatar_file.read()) + + # 生成 URL 路径 + avatar_url = f"/app/static/avatar/{unique_filename}" + + # 删除旧的头像文件 + if db_user.avatar and db_user.avatar != "/app/static/avatar/default.png": + old_avatar_path = db_user.avatar.lstrip("/") # 去掉开头的斜杠 + if os.path.exists(old_avatar_path): + os.remove(old_avatar_path) + + + update_user_response = UserUpdate( + username=username or db_user.username, + avatar=avatar_url if avatar_url else db_user.avatar + ) + await update_user_in_db(db, update_user_response, db_user.id) + return {"msg": "User updated successfully"} + + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +@router.post("/password", response_model=dict) +async def change_password( + password_update: PasswordUpdate, + db: AsyncSession = Depends(get_db), + current_user: dict = Depends(get_current_user) +): + db_user = await get_user_by_email(db, current_user["email"]) + if not db_user: + raise HTTPException(status_code=404, detail="User not found") + if not pwd_context.verify(password_update.old_password, db_user.password): + raise HTTPException(status_code=400, detail="Old password is incorrect") + + await update_user_password(db, db_user.id, pwd_context.hash(password_update.new_password)) + return {"msg": "Password changed successfully"} \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index 4365eeb..d89dbe7 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -9,7 +9,8 @@ class Settings: SQLALCHEMY_DATABASE_URL = "mysql+asyncmy://root:coders007@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称 SECRET_KEY: str = os.getenv("SECRET_KEY", "default_secret_key") # JWT密钥 ALGORITHM: str = "HS256" # JWT算法 - ACCESS_TOKEN_EXPIRE_MINUTES: int = 300 # token过期时间 + ACCESS_TOKEN_EXPIRE_MINUTES: int = 5 # token过期时间 + REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 刷新token过期时间7天 SMTP_SERVER: str = "smtp.163.com" # SMTP服务器 SMTP_PORT: int = 465 # SMTP端口 SENDER_EMAIL : str = "jienote_buaa@163.com" diff --git a/app/curd/user.py b/app/curd/user.py index ca63bfb..7a332af 100644 --- a/app/curd/user.py +++ b/app/curd/user.py @@ -1,6 +1,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from app.models.model import User +from app.schemas.user import UserUpdate async def get_user_by_email(db: AsyncSession, email: str): stmt = select(User).where(User.email == email) @@ -8,8 +9,30 @@ async def get_user_by_email(db: AsyncSession, email: str): return result.scalar_one_or_none() async def create_user(db: AsyncSession, email: str, username: str, hashed_password: str): - new_user = User(email=email, username=username, password=hashed_password, avatar="app/static/avatar/default.jpg") + new_user = User(email=email, username=username, password=hashed_password, avatar="/app/static/avatar/default.png") db.add(new_user) await db.commit() await db.refresh(new_user) - return new_user \ No newline at end of file + return new_user + +async def update_user_in_db(db: AsyncSession, user_update: UserUpdate, user_id: int): + stmt = select(User).where(User.id == user_id) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + if user: + if user_update.username: + user.username = user_update.username + user.avatar = user_update.avatar + await db.commit() + await db.refresh(user) + return user + +async def update_user_password(db: AsyncSession, user_id: int, hashed_password: str): + stmt = select(User).where(User.id == user_id) + result = await db.execute(stmt) + user = result.scalar_one_or_none() + if user: + user.password = hashed_password + await db.commit() + await db.refresh(user) + return user diff --git a/app/models/model.py b/app/models/model.py index ef8cf18..b40e8bc 100644 --- a/app/models/model.py +++ b/app/models/model.py @@ -50,7 +50,9 @@ class Folder(Base): articles = relationship('Article', back_populates='folder') __table_args__ = ( - UniqueConstraint('user_id', 'group_id', name='uq_user_group_folder'), + # 不能同时为空 + UniqueConstraint('user_id', 'group_id', name='uq_user_group_folder'), # SQL中认为null 和 null 不相等 + CheckConstraint('user_id IS NOT NULL OR group_id IS NOT NULL', name='check_user_or_group'), ) class Article(Base): diff --git a/app/routers/router.py b/app/routers/router.py index 5d0118c..829a3b8 100644 --- a/app/routers/router.py +++ b/app/routers/router.py @@ -2,6 +2,7 @@ from app.utils.auth import get_current_user from app.api.v1.endpoints.auth import router as auth_router from app.api.v1.endpoints.note import router as note_router +from app.api.v1.endpoints.user import router as user_router def include_auth_router(app): app.include_router(auth_router, prefix="/public", tags=["auth"]) @@ -9,6 +10,10 @@ def include_auth_router(app): def include_note_router(app): app.include_router(note_router, prefix="/notes", tags=["note"], dependencies=[Depends(get_current_user)]) +def include_user_router(app): + app.include_router(user_router, prefix="/user", tags=["user"], dependencies=[Depends(get_current_user)]) + def include_routers(app): include_auth_router(app) - include_note_router(app) \ No newline at end of file + include_note_router(app) + include_user_router(app) \ No newline at end of file diff --git a/app/schemas/auth.py b/app/schemas/auth.py index 9fa4bfc..59569d2 100644 --- a/app/schemas/auth.py +++ b/app/schemas/auth.py @@ -12,4 +12,7 @@ class UserLogin(BaseModel): password: str class UserSendCode(BaseModel): - email: EmailStr \ No newline at end of file + email: EmailStr + +class ReFreshToken(BaseModel): + refresh_token: str \ No newline at end of file diff --git a/app/schemas/user.py b/app/schemas/user.py new file mode 100644 index 0000000..49c965e --- /dev/null +++ b/app/schemas/user.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +class UserUpdate(BaseModel): + username: str | None = None + avatar: str | None = None + +class PasswordUpdate(BaseModel): + old_password: str + new_password: str + diff --git a/app/static/avatar/06914b55-d613-4f7d-854d-4442c1d7782e.png b/app/static/avatar/06914b55-d613-4f7d-854d-4442c1d7782e.png new file mode 100644 index 0000000..639c438 Binary files /dev/null and b/app/static/avatar/06914b55-d613-4f7d-854d-4442c1d7782e.png differ diff --git a/app/static/avatar/default.jpg b/app/static/avatar/default.png similarity index 100% rename from app/static/avatar/default.jpg rename to app/static/avatar/default.png diff --git a/app/utils/auth.py b/app/utils/auth.py index 72528cf..f17f09f 100644 --- a/app/utils/auth.py +++ b/app/utils/auth.py @@ -14,11 +14,12 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): None, decode, token, settings.SECRET_KEY, [settings.ALGORITHM] ) email: str = payload.get("sub") - if email is None: + user_id: int = payload.get("id") # 从 payload 中提取用户 ID + if email is None or user_id is None: raise HTTPException( status_code=401, detail="Invalid authentication credentials" ) - return {"email": email} + return {"email": email, "id": user_id} # 返回用户 ID 和 email except PyJWTError: raise HTTPException( status_code=401, detail="Invalid authentication credentials" diff --git a/requirements.txt b/requirements.txt index 078842d..e843fe8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,6 +30,7 @@ PyJWT==2.10.1 PyMySQL==1.1.1 python-dateutil==2.9.0.post0 python-dotenv==1.1.0 +python-multipart==0.0.20 pytz==2025.2 redis==5.2.1 six==1.17.0