Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ Attention!!!
- Must connect Redis before running the application. ‼️‼️‼️


## Token Authentication
- JWT (JSON Web Token) is used for authentication.
- Refresh tokens for 7 days and access tokens for 5min.

## Folder Structure
- `app/`: Contains the main application code.
- `tests/`: Contains test cases.
Expand Down
32 changes: 32 additions & 0 deletions alembic/versions/698cace619a6_添加约束.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""添加约束

Revision ID: 698cace619a6
Revises: f1242bbcad2d
Create Date: 2025-04-18 17:18:55.547867

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '698cace619a6'
down_revision: Union[str, None] = 'f1242bbcad2d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
48 changes: 43 additions & 5 deletions app/api/v1/endpoints/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from email.utils import formataddr

from app.schemas.auth import UserCreate, UserLogin, UserSendCode
from app.schemas.auth import UserCreate, UserLogin, UserSendCode, ReFreshToken
from app.core.config import settings
from app.curd.user import get_user_by_email, create_user
from app.utils.get_db import get_db
Expand All @@ -31,11 +31,21 @@ async def create_access_token(data: dict, expires_delta: timedelta = None):
if expires_delta:
expire = datetime.now() + expires_delta
else:
expire = datetime.now() + timedelta(minutes=15)
expire = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

async def create_refresh_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now() + expires_delta
else:
expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire, "type": "refresh"})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

@router.post("/register", response_model=dict)
async def register(user: UserCreate, db: AsyncSession = Depends(get_db)):
existing_user = await get_user_by_email(db, user.email)
Expand All @@ -57,11 +67,39 @@ async def login(user: UserLogin, db: AsyncSession = Depends(get_db)):
db_user = await get_user_by_email(db, user.email)
if not db_user or not pwd_context.verify(user.password, db_user.password):
raise HTTPException(status_code=401, detail="Invalid email or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_token_expires = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
access_token = await create_access_token(
data={"sub": db_user.email}, expires_delta=access_token_expires
data={"sub": db_user.email, "id": db_user.id}, expires_delta=access_token_expires
)
refresh_token = await create_refresh_token(
data={"sub": db_user.email, "id": db_user.id}, expires_delta=refresh_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"user_id": db_user.id,
"email": db_user.email,
"username": db_user.username,
"avatar": db_user.avatar
}

@router.post("/refresh", response_model=dict)
async def refresh_token(refresh_token: ReFreshToken):
try:
payload = jwt.decode(refresh_token.refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
if payload.get("type") != "refresh":
raise HTTPException(status_code=401, detail="Invalid refresh token type")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = await create_access_token(
data={"sub": payload["sub"], "id": payload["id"]}, expires_delta=access_token_expires
)
return {"access_token": access_token}
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Refresh token expired")
except Exception:
raise HTTPException(status_code=401, detail="Invalid refresh token")

# 发送验证码
@router.post("/send_code", response_model=dict)
Expand Down
75 changes: 75 additions & 0 deletions app/api/v1/endpoints/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from fastapi import APIRouter, HTTPException, Depends, UploadFile, Form, File
from app.schemas.user import UserUpdate, PasswordUpdate
from app.curd.user import update_user_in_db, get_user_by_email, update_user_password
from sqlalchemy.ext.asyncio import AsyncSession
from app.utils.get_db import get_db
from app.utils.auth import get_current_user
from passlib.context import CryptContext
import os
from uuid import uuid4

router = APIRouter()

# update current user
@router.put("", response_model=dict)
async def update_current_user(
username: str = Form(None),
avatar: UploadFile = File(None),
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user)
):
"""
Update the current user's information.
"""
db_user = await get_user_by_email(db, current_user["email"])
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
try:
avatar_url = None
if avatar:
avatar_file: UploadFile = avatar
file_extension = os.path.splitext(avatar_file.filename)[1]
unique_filename = f"{uuid4()}{file_extension}"
avatar_dir = os.path.join("app", "static", "avatar")
avatar_path = os.path.join(avatar_dir, unique_filename)

# 确保以二进制模式写入文件,避免编码问题
with open(avatar_path, "wb") as f:
f.write(await avatar_file.read())

# 生成 URL 路径
avatar_url = f"/app/static/avatar/{unique_filename}"

# 删除旧的头像文件
if db_user.avatar and db_user.avatar != "/app/static/avatar/default.png":
old_avatar_path = db_user.avatar.lstrip("/") # 去掉开头的斜杠
if os.path.exists(old_avatar_path):
os.remove(old_avatar_path)


update_user_response = UserUpdate(
username=username or db_user.username,
avatar=avatar_url if avatar_url else db_user.avatar
)
await update_user_in_db(db, update_user_response, db_user.id)
return {"msg": "User updated successfully"}

except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

@router.post("/password", response_model=dict)
async def change_password(
password_update: PasswordUpdate,
db: AsyncSession = Depends(get_db),
current_user: dict = Depends(get_current_user)
):
db_user = await get_user_by_email(db, current_user["email"])
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
if not pwd_context.verify(password_update.old_password, db_user.password):
raise HTTPException(status_code=400, detail="Old password is incorrect")

await update_user_password(db, db_user.id, pwd_context.hash(password_update.new_password))
return {"msg": "Password changed successfully"}
3 changes: 2 additions & 1 deletion app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class Settings:
SQLALCHEMY_DATABASE_URL = "mysql+asyncmy://root:coders007@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称
SECRET_KEY: str = os.getenv("SECRET_KEY", "default_secret_key") # JWT密钥
ALGORITHM: str = "HS256" # JWT算法
ACCESS_TOKEN_EXPIRE_MINUTES: int = 300 # token过期时间
ACCESS_TOKEN_EXPIRE_MINUTES: int = 5 # token过期时间
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 刷新token过期时间7天
SMTP_SERVER: str = "smtp.163.com" # SMTP服务器
SMTP_PORT: int = 465 # SMTP端口
SENDER_EMAIL : str = "jienote_buaa@163.com"
Expand Down
27 changes: 25 additions & 2 deletions app/curd/user.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.models.model import User
from app.schemas.user import UserUpdate

async def get_user_by_email(db: AsyncSession, email: str):
stmt = select(User).where(User.email == email)
result = await db.execute(stmt)
return result.scalar_one_or_none()

async def create_user(db: AsyncSession, email: str, username: str, hashed_password: str):
new_user = User(email=email, username=username, password=hashed_password, avatar="app/static/avatar/default.jpg")
new_user = User(email=email, username=username, password=hashed_password, avatar="/app/static/avatar/default.png")
db.add(new_user)
await db.commit()
await db.refresh(new_user)
return new_user
return new_user

async def update_user_in_db(db: AsyncSession, user_update: UserUpdate, user_id: int):
stmt = select(User).where(User.id == user_id)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if user:
if user_update.username:
user.username = user_update.username
user.avatar = user_update.avatar
await db.commit()
await db.refresh(user)
return user

async def update_user_password(db: AsyncSession, user_id: int, hashed_password: str):
stmt = select(User).where(User.id == user_id)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if user:
user.password = hashed_password
await db.commit()
await db.refresh(user)
return user
4 changes: 3 additions & 1 deletion app/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class Folder(Base):
articles = relationship('Article', back_populates='folder')

__table_args__ = (
UniqueConstraint('user_id', 'group_id', name='uq_user_group_folder'),
# 不能同时为空
UniqueConstraint('user_id', 'group_id', name='uq_user_group_folder'), # SQL中认为null 和 null 不相等
CheckConstraint('user_id IS NOT NULL OR group_id IS NOT NULL', name='check_user_or_group'),
)

class Article(Base):
Expand Down
7 changes: 6 additions & 1 deletion app/routers/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
from app.utils.auth import get_current_user
from app.api.v1.endpoints.auth import router as auth_router
from app.api.v1.endpoints.note import router as note_router
from app.api.v1.endpoints.user import router as user_router

def include_auth_router(app):
app.include_router(auth_router, prefix="/public", tags=["auth"])

def include_note_router(app):
app.include_router(note_router, prefix="/notes", tags=["note"], dependencies=[Depends(get_current_user)])

def include_user_router(app):
app.include_router(user_router, prefix="/user", tags=["user"], dependencies=[Depends(get_current_user)])

def include_routers(app):
include_auth_router(app)
include_note_router(app)
include_note_router(app)
include_user_router(app)
5 changes: 4 additions & 1 deletion app/schemas/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@ class UserLogin(BaseModel):
password: str

class UserSendCode(BaseModel):
email: EmailStr
email: EmailStr

class ReFreshToken(BaseModel):
refresh_token: str
10 changes: 10 additions & 0 deletions app/schemas/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pydantic import BaseModel

class UserUpdate(BaseModel):
username: str | None = None
avatar: str | None = None

class PasswordUpdate(BaseModel):
old_password: str
new_password: str

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes
5 changes: 3 additions & 2 deletions app/utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
None, decode, token, settings.SECRET_KEY, [settings.ALGORITHM]
)
email: str = payload.get("sub")
if email is None:
user_id: int = payload.get("id") # 从 payload 中提取用户 ID
if email is None or user_id is None:
raise HTTPException(
status_code=401, detail="Invalid authentication credentials"
)
return {"email": email}
return {"email": email, "id": user_id} # 返回用户 ID 和 email
except PyJWTError:
raise HTTPException(
status_code=401, detail="Invalid authentication credentials"
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ PyJWT==2.10.1
PyMySQL==1.1.1
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-multipart==0.0.20
pytz==2025.2
redis==5.2.1
six==1.17.0
Expand Down