Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions app/api/v1/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
import smtplib
import aiosmtplib
from email.mime.text import MIMEText
from email.header import Header
import random
Expand All @@ -18,15 +18,15 @@

router = APIRouter()

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # 使用 bcrypt 加密算法
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # 使用 bcrypt 加密算法
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = settings.ALGORITHM
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES

# 配置 Redis 连接
redis_client = get_redis_client()

def create_access_token(data: dict, expires_delta: timedelta = None):
async def create_access_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now() + expires_delta
Expand All @@ -37,36 +37,35 @@ def create_access_token(data: dict, expires_delta: timedelta = None):
return encoded_jwt

@router.post("/register", response_model=dict)
def register(user: UserCreate, db: Session = Depends(get_db)):
existing_user = get_user_by_email(db, user.email)
if (redis_client.exists(f"email:{user.email}:code")):
code = redis_client.get(f"email:{user.email}:code").decode("utf-8")
if (user.code != code):
async def register(user: UserCreate, db: AsyncSession = Depends(get_db)):
existing_user = await get_user_by_email(db, user.email)
if redis_client.exists(f"email:{user.email}:code"):
code = redis_client.get(f"email:{user.email}:code")
if user.code != code:
raise HTTPException(status_code=400, detail="Invalid verification code")
else:
raise HTTPException(status_code=400, detail="Verification code expired or not sent")
if (existing_user):

if existing_user:
raise HTTPException(status_code=400, detail="Email already registered")
hashed_password = pwd_context.hash(user.password)
create_user(db, user.email, user.username, hashed_password)
await create_user(db, user.email, user.username, hashed_password)
return {"msg": "User registered successfully"}

@router.post("/login", response_model=dict)
def login(user: UserLogin, db: Session = Depends(get_db)):
db_user = get_user_by_email(db, user.email)
async def login(user: UserLogin, db: AsyncSession = Depends(get_db)):
db_user = await get_user_by_email(db, user.email)
if not db_user or not pwd_context.verify(user.password, db_user.password):
raise HTTPException(status_code=401, detail="Invalid email or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
access_token = await create_access_token(
data={"sub": db_user.email}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer", }
return {"access_token": access_token, "token_type": "bearer"}

# 发送验证码
@router.post("/send_code", response_model=dict)
def send_code(user_send_code : UserSendCode, db: Session = Depends(get_db)):
# 检查 Redis 中是否存在该邮箱的发送记录
async def send_code(user_send_code: UserSendCode):
if redis_client.exists(f"email:{user_send_code.email}:time"):
raise HTTPException(status_code=429, detail="You can only request a verification code once every 5 minutes.")

Expand All @@ -85,23 +84,26 @@ def send_code(user_send_code : UserSendCode, db: Session = Depends(get_db)):

# 创建MIMEText对象时需要显式指定子类型和编码
message = MIMEText(_text=body, _subtype='plain', _charset='utf-8')
message["From"] = formataddr(("JieNote团队", "noreply@jienote.com"))
message["From"] = formataddr(("JieNote团队", "jienote_buaa@163.com"))
message["To"] = user_send_code.email
message["Subject"] = Header(subject, 'utf-8').encode()
# 添加必要的内容传输编码头
message.add_header('Content-Transfer-Encoding', 'base64')

try:
# 连接 SMTP 服务器并发送邮件
with smtplib.SMTP_SSL(smtp_server, smtp_port) as server:
server.login(sender_email, sender_password)
server.sendmail(sender_email, [user_send_code.email], message.as_string())

# 将验证码和发送时间存储到 Redis,设置 5 分钟过期时间
redis_client.setex(f"email:{user_send_code.email}:code", ACCESS_TOKEN_EXPIRE_MINUTES, code)
redis_client.setex(f"email:{user_send_code.email}:time", ACCESS_TOKEN_EXPIRE_MINUTES, int(time.time()))
await aiosmtplib.send(
message,
hostname=smtp_server,
port=smtp_port,
username=sender_email,
password=sender_password,
use_tls=True,
)

redis_client.setex(f"email:{user_send_code.email}:code", ACCESS_TOKEN_EXPIRE_MINUTES * 60, code)
redis_client.setex(f"email:{user_send_code.email}:time", ACCESS_TOKEN_EXPIRE_MINUTES * 60, int(time.time()))

return {"msg": "Verification code sent"}

except smtplib.SMTPException as e:
except aiosmtplib.SMTPException as e:
raise HTTPException(status_code=500, detail=f"Failed to send email: {str(e)}")
18 changes: 9 additions & 9 deletions app/api/v1/endpoints/note.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas.note import NoteCreate, NoteUpdate, NoteFind
from app.utils.get_db import get_db
from app.curd.note import create_note_in_db, delete_note_in_db, update_note_in_db, find_notes_in_db

router = APIRouter()

@router.post("", response_model=dict)
def create_note(note: NoteCreate, db: Session = Depends(get_db)):
new_note = create_note_in_db(note, db)
async def create_note(note: NoteCreate, db: AsyncSession = Depends(get_db)):
new_note = await create_note_in_db(note, db)
return {"msg": "Note created successfully", "note_id": new_note.id}

@router.delete("/{note_id}", response_model=dict)
def delete_note(note_id: int, db: Session = Depends(get_db)):
note = delete_note_in_db(note_id, db)
async def delete_note(note_id: int, db: AsyncSession = Depends(get_db)):
note = await delete_note_in_db(note_id, db)
if not note:
raise HTTPException(status_code=404, detail="Note not found")
return {"msg": "Note deleted successfully"}

@router.put("/{note_id}", response_model=dict)
def update_note(note_id: int, content: str, db: Session = Depends(get_db)):
async def update_note(note_id: int, content: str, db: AsyncSession = Depends(get_db)):
note = NoteUpdate(id=note_id, content=content)
updated_note = update_note_in_db(note_id, note, db)
updated_note = await update_note_in_db(note_id, note, db)
if not updated_note:
raise HTTPException(status_code=404, detail="Note not found")
return {"msg": "Note updated successfully", "note_id": updated_note.id}

@router.get("", response_model=dict)
def get_notes(note_find: NoteFind = Depends(), db: Session = Depends(get_db)):
notes, total_count = find_notes_in_db(note_find, db)
async def get_notes(note_find: NoteFind = Depends(), db: AsyncSession = Depends(get_db)):
notes, total_count = await find_notes_in_db(note_find, db)
return {
"pagination": {
"total_count": total_count,
Expand Down
3 changes: 1 addition & 2 deletions app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import os
from datetime import timedelta
from dotenv import load_dotenv

load_dotenv()

class Settings:
PROJECT_NAME: str = "JieNote Backend" # 项目名称
VERSION: str = "1.0.0" # 项目版本
SQLALCHEMY_DATABASE_URL = "mysql+pymysql://root:coders007@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称
SQLALCHEMY_DATABASE_URL = "mysql+asyncmy://root:coders007@47.93.172.156:3306/JieNote" # 替换为实际的用户名、密码和数据库名称
SECRET_KEY: str = os.getenv("SECRET_KEY", "default_secret_key") # JWT密钥
ALGORITHM: str = "HS256" # JWT算法
ACCESS_TOKEN_EXPIRE_MINUTES: int = 300 # token过期时间
Expand Down
53 changes: 32 additions & 21 deletions app/curd/note.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,52 @@
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete, func
from app.models.model import Note
from app.schemas.note import NoteCreate, NoteUpdate, NoteFind, NoteResponse

def create_note_in_db(note: NoteCreate, db: Session):
async def create_note_in_db(note: NoteCreate, db: AsyncSession):
new_note = Note(content=note.content, article_id=note.article_id)
db.add(new_note)
db.commit()
db.refresh(new_note)
await db.commit()
await db.refresh(new_note)
return new_note

def delete_note_in_db(note_id: int, db: Session):
note = db.query(Note).filter(Note.id == note_id).first()
async def delete_note_in_db(note_id: int, db: AsyncSession):
stmt = select(Note).where(Note.id == note_id)
result = await db.execute(stmt)
note = result.scalar_one_or_none()
if note:
db.delete(note)
db.commit()
delete_stmt = delete(Note).where(Note.id == note_id)
await db.execute(delete_stmt)
await db.commit()
return note

def update_note_in_db(note_id: int, note: NoteUpdate, db: Session):
existing_note = db.query(Note).filter(Note.id == note_id).first()
async def update_note_in_db(note_id: int, note: NoteUpdate, db: AsyncSession):
stmt = select(Note).where(Note.id == note_id)
result = await db.execute(stmt)
existing_note = result.scalar_one_or_none()
if existing_note:
existing_note.content = note.content
db.commit()
db.refresh(existing_note)
await db.commit()
await db.refresh(existing_note)
return existing_note

def find_notes_in_db(note_find: NoteFind, db: Session):
query = db.query(Note)
async def find_notes_in_db(note_find: NoteFind, db: AsyncSession):
stmt = select(Note)

if note_find.id is not None:
query = query.filter(Note.id == note_find.id)
stmt = stmt.where(Note.id == note_find.id)
elif note_find.article_id is not None:
query = query.filter(Note.article_id == note_find.article_id)
stmt = stmt.where(Note.article_id == note_find.article_id)

total_count_stmt = select(func.count()).select_from(stmt)
total_count_result = await db.execute(total_count_stmt)
total_count = total_count_result.scalar()

totol_count = query.count()
# 添加分页逻辑
if note_find.page is not None and note_find.page_size is not None:
offset = (note_find.page - 1) * note_find.page_size
query = query.offset(offset).limit(note_find.page_size)
notes = [NoteResponse.model_validate(note) for note in query.all()]
return notes, totol_count
stmt = stmt.offset(offset).limit(note_find.page_size)

result = await db.execute(stmt)
notes = [NoteResponse.model_validate(note) for note in result.scalars().all()]
return notes, total_count
15 changes: 9 additions & 6 deletions app/curd/user.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.models.model import User

def get_user_by_email(db: Session, email: str):
return db.query(User).filter(User.email == email).first()
async def get_user_by_email(db: AsyncSession, email: str):
stmt = select(User).where(User.email == email)
result = await db.execute(stmt)
return result.scalar_one_or_none()

def create_user(db: Session, email: str, username: str,hashed_password: str):
async def create_user(db: AsyncSession, email: str, username: str, hashed_password: str):
new_user = User(email=email, username=username, password=hashed_password, avatar="app/static/avatar/default.jpg")
db.add(new_user)
db.commit()
db.refresh(new_user)
await db.commit()
await db.refresh(new_user)
return new_user
14 changes: 13 additions & 1 deletion app/db/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession

from app.core.config import settings

engine = create_engine(settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) #连接mysql
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# 创建异步引擎
async_engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)

# 创建异步会话
async_session = sessionmaker(
bind=async_engine,
class_=AsyncSession,
autocommit=False,
autoflush=False
)
9 changes: 7 additions & 2 deletions app/utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
from jwt import PyJWTError, decode
from app.core.config import settings
from fastapi import Depends, HTTPException
import asyncio

# 配置 OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

def get_current_user(token: str = Depends(oauth2_scheme)):
async def get_current_user(token: str = Depends(oauth2_scheme)):
try:
payload = decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
loop = asyncio.get_event_loop()
payload = await loop.run_in_executor(
None, decode, token, settings.SECRET_KEY, [settings.ALGORITHM]
)
email: str = payload.get("sub")
if email is None:
raise HTTPException(
Expand Down
11 changes: 4 additions & 7 deletions app/utils/get_db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from app.db.session import SessionLocal
from app.db.session import async_session

def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
async def get_db():
async with async_session() as db:
yield db
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
aiosmtplib==4.0.0
alembic==1.15.2
annotated-types==0.7.0
anyio==4.9.0
async-timeout==5.0.1
asyncmy==0.2.10
bcrypt==4.3.0
cffi==1.17.1
click==8.1.8
Expand All @@ -10,6 +13,7 @@ dnspython==2.7.0
dotenv==0.9.9
email_validator==2.2.0
fastapi==0.115.12
fastapi-pagination==0.12.34
greenlet==3.1.1
h11==0.14.0
idna==3.10
Expand Down