diff --git a/app/package-lock.json b/app/package-lock.json index bb97485..f289321 100644 --- a/app/package-lock.json +++ b/app/package-lock.json @@ -22,9 +22,11 @@ "@types/react-dom": "^18.2.22", "axios": "^1.6.8", "react": "^18.2.0", + "react-cookie": "^7.1.4", "react-dom": "^18.2.0", "react-router-dom": "^6.22.3", "react-scripts": "5.0.1", + "react-use-cookie": "^1.5.0", "typed-axios-instance": "^3.3.1", "typescript": "^4.9.5", "usehooks-ts": "^3.0.2", @@ -4522,6 +4524,11 @@ "@types/node": "*" } }, + "node_modules/@types/cookie": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@types/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-4Kh9a6B2bQciAhf7FSuMRRkUWecJgJu9nPnx3yzpsfXX/c50REIqpHY4C82bXP90qrLtXtkDxTZosYO3UpOwlA==" + }, "node_modules/@types/eslint": { "version": "8.56.6", "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.6.tgz", @@ -4575,6 +4582,15 @@ "@types/node": "*" } }, + "node_modules/@types/hoist-non-react-statics": { + "version": "3.3.5", + "resolved": "https://registry.npmjs.org/@types/hoist-non-react-statics/-/hoist-non-react-statics-3.3.5.tgz", + "integrity": "sha512-SbcrWzkKBw2cdwRTwQAswfpB9g9LJWfjtUeW/jvNwbhC8cpmmNYVePa+ncbUe0rGTQ7G3Ff6mYUN2VMfLVr+Sg==", + "dependencies": { + "@types/react": "*", + "hoist-non-react-statics": "^3.3.0" + } + }, "node_modules/@types/html-minifier-terser": { "version": "6.1.0", "resolved": "https://registry.npmjs.org/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz", @@ -15265,6 +15281,19 @@ "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.13.11.tgz", "integrity": "sha512-kY1AZVr2Ra+t+piVaJ4gxaFaReZVH40AKNo7UCX6W+dEwBo/2oZJzqfuN1qLq1oL45o56cPaTXELwrTh8Fpggg==" }, + "node_modules/react-cookie": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/react-cookie/-/react-cookie-7.1.4.tgz", + "integrity": "sha512-wDxxa/HYaSXSMlyWJvJ5uZTzIVtQTPf1gMksFgwAz/2/W3lCtY8r4OChCXMPE7wax0PAdMY97UkNJedGv7KnDw==", + "dependencies": { + "@types/hoist-non-react-statics": "^3.3.5", + "hoist-non-react-statics": "^3.3.2", + "universal-cookie": "^7.0.0" + }, + "peerDependencies": { + "react": ">= 16.3.0" + } + }, "node_modules/react-dev-utils": { "version": "12.0.1", "resolved": "https://registry.npmjs.org/react-dev-utils/-/react-dev-utils-12.0.1.tgz", @@ -15529,6 +15558,18 @@ "react-dom": ">=16.6.0" } }, + "node_modules/react-use-cookie": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/react-use-cookie/-/react-use-cookie-1.5.0.tgz", + "integrity": "sha512-zPEAmAYbRLXzpi3VD3rjYHszTo8BonuiaiLH/jYixHr6qE+Yukm2lA6AsinX1uL7/9nFSVeKBLqI4oOZYdhghQ==", + "engines": { + "node": ">=8", + "npm": ">=5" + }, + "peerDependencies": { + "react": ">=16.8" + } + }, "node_modules/read-cache": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", @@ -17618,6 +17659,15 @@ "node": ">=8" } }, + "node_modules/universal-cookie": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/universal-cookie/-/universal-cookie-7.1.4.tgz", + "integrity": "sha512-Q+DVJsdykStWRMtXr2Pdj3EF98qZHUH/fXv/gwFz/unyToy1Ek1w5GsWt53Pf38tT8Gbcy5QNsj61Xe9TggP4g==", + "dependencies": { + "@types/cookie": "^0.6.0", + "cookie": "^0.6.0" + } + }, "node_modules/universalify": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.1.tgz", diff --git a/app/package.json b/app/package.json index fd85f52..5021513 100644 --- a/app/package.json +++ b/app/package.json @@ -20,6 +20,7 @@ "react-dom": "^18.2.0", "react-router-dom": "^6.22.3", "react-scripts": "5.0.1", + "react-use-cookie": "^1.5.0", "typed-axios-instance": "^3.3.1", "typescript": "^4.9.5", "usehooks-ts": "^3.0.2", diff --git a/app/src/features/tokens/hooks/useApiTokens.ts b/app/src/features/tokens/hooks/useApiTokens.ts index b323ae0..31d3ed4 100644 --- a/app/src/features/tokens/hooks/useApiTokens.ts +++ b/app/src/features/tokens/hooks/useApiTokens.ts @@ -5,7 +5,6 @@ import { useGithubAuth } from '../../toolbar/hooks/useGithubAuth'; import axios from 'axios'; import HTTP, { CreateTokenResponse, - DeleteTokenResponse, RawToken, TokensResponse, } from '../../../utils/http'; diff --git a/app/src/features/toolbar/hooks/useGithubAuth.ts b/app/src/features/toolbar/hooks/useGithubAuth.ts index 6ae0f1f..d9b414c 100644 --- a/app/src/features/toolbar/hooks/useGithubAuth.ts +++ b/app/src/features/toolbar/hooks/useGithubAuth.ts @@ -2,6 +2,7 @@ import { useCallback, useEffect, useState } from 'react'; import { useSearchParams } from 'react-router-dom'; import { useLocalSession } from '../../../utils/localStorage'; import { SERVER_URI } from '../../../constants'; +import useCookie from 'react-use-cookie'; import axios from 'axios'; import HTTP, { AuthenticatedUser, @@ -13,13 +14,15 @@ export function useGithubAuth(): [ AuthenticatedUser | null, () => Promise ] { + const [sessionId, setSessionId] = useCookie('session'); const [githubUser, setGithubUser] = useState(null); const { githubCode, saveGithubCode, clearGithubCode } = useLocalSession(); const [searchParams, setSearchParams] = useSearchParams(); const logout = useCallback(async () => { + await HTTP.post(`/logout`); + setSessionId(''); setGithubUser(null); - HTTP.post(`/logout`); }, [setGithubUser]); // If this was a redirect from Github, we have a code to log in with. @@ -43,21 +46,20 @@ export function useGithubAuth(): [ if (data.user) { setGithubUser(data.user); } - }); + }).catch(() => clearGithubCode()); }, [githubCode, setGithubUser, clearGithubCode]); - // Attempt to fetch the logged in user info. useEffect(() => { - if (!!githubUser) { + // Attempt to fetch the logged in user info if the session cookie is set and the user hasn't been fetched. + if (!!githubUser || !sessionId) { return; } HTTP.get(`/user`).then(({ data }) => { - if (data.user) { - setGithubUser(data.user); - } + setGithubUser(data.user); }); - }, [githubUser, setGithubUser]); + }, [githubUser, setGithubUser, sessionId]); return [githubUser, logout]; } + diff --git a/app/src/utils/http.ts b/app/src/utils/http.ts index 6f23244..16d3056 100644 --- a/app/src/utils/http.ts +++ b/app/src/utils/http.ts @@ -16,20 +16,14 @@ export interface LoginRequest { } export interface LoginResponse { - sessionId?: string; - user?: AuthenticatedUser; - error?: string; + sessionId: string; + user: AuthenticatedUser; } export interface UserResponse { - user?: AuthenticatedUser; - error?: string; + user: AuthenticatedUser; } -export interface GenericResponse { - error?: string; - } - export interface RawToken { id: string, name: string, @@ -45,10 +39,6 @@ export interface GenericResponse { error?: string; } - export interface DeleteTokenResponse { - error?: string; - } - export interface TokensResponse { tokens: RawToken[]; error?: string; @@ -69,7 +59,6 @@ type Routes = [ { route: '/logout'; method: 'POST'; - jsonResponse: GenericResponse; }, { route: '/new_token'; @@ -85,7 +74,6 @@ type Routes = [ { route: '/token/[id]'; method: 'DELETE'; - jsonResponse: DeleteTokenResponse; } ]; @@ -97,12 +85,10 @@ const HTTP: TypedAxios = axios.create({ // Intercept the response and log any errors. HTTP.interceptors.response.use(function (response) { // Any status code that lie within the range of 2xx cause this function to trigger - if (response.data.error) { - console.log(`[${response.config.method}] API error: `, response.data.error); - } return response; }, function (error) { // Any status codes that falls outside the range of 2xx cause this function to trigger + console.error('HTTP Error:', error); return Promise.reject(error); }); export default HTTP; diff --git a/src/api/api_token.rs b/src/api/api_token.rs index 991523c..ea35f41 100644 --- a/src/api/api_token.rs +++ b/src/api/api_token.rs @@ -32,23 +32,7 @@ pub struct CreateTokenRequest { #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct CreateTokenResponse { - pub token: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -/// The CreateToken request. -#[derive(Deserialize, Debug)] -pub struct DeleteTokenRequest { - pub id: String, -} - -/// The response to a CreateToken request. -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct DeleteTokenResponse { - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, + pub token: Token, } /// The response to a CreateToken request. @@ -56,6 +40,4 @@ pub struct DeleteTokenResponse { #[serde(rename_all = "camelCase")] pub struct TokensResponse { pub tokens: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, } diff --git a/src/api/auth.rs b/src/api/auth.rs index 784fe42..64a1204 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -3,11 +3,6 @@ use uuid::Uuid; use crate::models; -pub struct UserSessionId { - pub user_id: Uuid, - pub session_id: String, -} - #[derive(Serialize, Deserialize, Debug, Default, Clone)] #[serde(rename_all = "camelCase")] pub struct User { @@ -42,25 +37,13 @@ pub struct LoginRequest { #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct LoginResponse { - pub user: Option, - pub session_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -/// The response to a logout request. -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct LogoutResponse { - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, + pub user: User, + pub session_id: String, } /// The response to a user GET request. #[derive(Serialize)] #[serde(rename_all = "camelCase")] pub struct UserResponse { - pub user: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, + pub user: User, } \ No newline at end of file diff --git a/src/api/mod.rs b/src/api/mod.rs index 199a558..2ca1c1a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,28 +1,30 @@ pub mod auth; pub mod api_token; -use rocket::serde::{Deserialize, Serialize}; +use rocket::{http::Status, response::Responder, serde::{json::Json, Deserialize, Serialize}, Request}; use thiserror::Error; +/// A wrapper for API responses that can return errors. +pub type ApiResult = Result, ApiError>; + +/// An empty response. +#[derive(Serialize)] +pub struct EmptyResponse; + #[derive(Error, Debug)] pub enum ApiError { - #[error("Unauthorized request")] - Unauthorized, - #[error("Missing session cookie")] - MissingCookie, + #[error("Database error: {0}")] + Database(#[from] crate::db::error::DatabaseError), + #[error("GitHub error: {0}")] + Github(#[from] crate::github::GithubError), } -/// The publish request. -#[derive(Deserialize, Debug)] -pub struct PublishRequest { - pub name: String, - pub version: String, +impl<'r, 'o: 'r> Responder<'r, 'o> for ApiError { + fn respond_to(self, _request: &'r Request<'_>) -> rocket::response::Result<'o> { + match self { + ApiError::Database(_) => Err(Status::InternalServerError), + ApiError::Github(_) => Err(Status::Unauthorized), + } + } } -/// The response to a publish request. -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PublishResponse { - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} diff --git a/src/db/api_token.rs b/src/db/api_token.rs index 1a6e0d6..22b1245 100644 --- a/src/db/api_token.rs +++ b/src/db/api_token.rs @@ -68,9 +68,6 @@ impl Database { .select(models::Token::as_returning()) .load(connection) .map_err(|_| DatabaseError::NotFound(user_id.to_string())) - // // TODO: fix return type - // eprintln!("res: {:?}", res); - // Err(DatabaseError::NotFound(user_id.to_string())) } } diff --git a/src/db/error.rs b/src/db/error.rs index 8dd73dc..7f2c3a5 100644 --- a/src/db/error.rs +++ b/src/db/error.rs @@ -8,8 +8,8 @@ pub enum DatabaseError { NotFound(String), #[error("Failed to save user: {0}")] InsertUserFailed(String), - #[error("Failed to save session for user; {0}")] + #[error("Failed to save session for user: {0}")] InsertSessionFailed(String), - #[error("Failed to save session for user; {0}")] + #[error("Failed to save token for user: {0}")] InsertTokenFailed(String), } diff --git a/src/db/mod.rs b/src/db/mod.rs index 510e098..ce8fbf4 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -1,4 +1,4 @@ -mod error; +pub mod error; mod user_session; mod api_token; @@ -19,6 +19,7 @@ pub struct Database { pub pool: DbPool, } + impl Default for Database { fn default() -> Self { Database::new() @@ -27,6 +28,7 @@ impl Default for Database { impl Database { pub fn new() -> Self { + // Create a connection pool let pool = Pool::builder() .build(ConnectionManager::::new(db_url())) diff --git a/src/db/user_session.rs b/src/db/user_session.rs index d4ca249..9b7ddaa 100644 --- a/src/db/user_session.rs +++ b/src/db/user_session.rs @@ -14,7 +14,7 @@ impl Database { &self, user: &api::auth::User, expires_in: u32, - ) -> Result { + ) -> Result { let connection = &mut self.connection(); // Insert or update a user @@ -51,7 +51,7 @@ impl Database { .get_result(connection) .map_err(|_| DatabaseError::InsertSessionFailed(user.github_login.clone()))?; - Ok(saved_session.id.to_string()) + Ok(saved_session) } /// Fetch a user given the user ID. @@ -64,6 +64,16 @@ impl Database { .map_err(|_| DatabaseError::NotFound(user_id.to_string())) } + /// Fetch a user given the user ID. + pub fn get_session(&self, session_id: Uuid) -> Result { + let connection = &mut self.connection(); + schema::sessions::table + .filter(schema::sessions::id.eq(session_id)) + .select(models::Session::as_returning()) + .first::(connection) + .map_err(|_| DatabaseError::NotFound(session_id.to_string())) + } + /// Fetch a user from the database for a given session ID. pub fn get_user_for_session(&self, session_id: String) -> Result { let session_uuid = string_to_uuid(session_id.clone())?; @@ -77,12 +87,11 @@ impl Database { } /// Delete a session given its ID. - pub fn delete_session(&self, session_id: String) -> Result<(), DatabaseError> { - let session_uuid = string_to_uuid(session_id.clone())?; + pub fn delete_session(&self, session_id: Uuid) -> Result<(), DatabaseError> { let connection = &mut self.connection(); - diesel::delete(schema::sessions::table.filter(schema::sessions::id.eq(session_uuid))) + diesel::delete(schema::sessions::table.filter(schema::sessions::id.eq(session_id))) .execute(connection) - .map_err(|_| DatabaseError::NotFound(session_id))?; + .map_err(|_| DatabaseError::NotFound(session_id.to_string()))?; Ok(()) } } diff --git a/src/lib.rs b/src/lib.rs index de38342..5a9e5c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ pub mod api; -pub mod cors; +pub mod middleware; pub mod db; pub mod github; pub mod models; pub mod schema; -pub mod util; \ No newline at end of file +pub mod util; +pub mod cors; \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index b19bd16..c512768 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,219 +3,95 @@ #[macro_use] extern crate rocket; -use std::hash::Hash; - -use forc_pub::api::api_token::{CreateTokenRequest, CreateTokenResponse, DeleteTokenRequest, DeleteTokenResponse, TokensResponse, Token}; -use forc_pub::api::auth::LogoutResponse; +use forc_pub::api::api_token::{CreateTokenRequest, CreateTokenResponse, Token, TokensResponse}; use forc_pub::api::{ - auth::{LoginRequest, LoginResponse, UserResponse, User, UserSessionId}, - ApiError, PublishRequest, PublishResponse, + auth::{LoginRequest, LoginResponse, UserResponse}, + ApiError, ApiResult, EmptyResponse, }; use forc_pub::cors::Cors; +use forc_pub::db::error::DatabaseError; use forc_pub::db::Database; use forc_pub::github::handle_login; +use forc_pub::middleware::session_auth::{SessionAuth, SESSION_COOKIE_NAME}; +use forc_pub::util::sys_time_to_epoch; use rocket::http::{Cookie, CookieJar}; +use rocket::time::OffsetDateTime; use rocket::{serde::json::Json, State}; -use uuid::Uuid; + #[derive(Default)] -struct ServerState { +pub struct ServerState { pub db: Database, } -impl ServerState { - pub fn get_authenticated_session(&self, cookies: &CookieJar) -> Result { - match cookies.get("session").map(|c| c.value()) { - Some(session_id) => { - let user = self - .db - .get_user_for_session(session_id.to_string()) - .map_err(|_| ApiError::Unauthorized)?; - Ok(UserSessionId { user_id : user.id, session_id: session_id.to_string()}) - } - None => Err(ApiError::MissingCookie), - } - } -} - /// The endpoint to authenticate with GitHub. #[post("/login", data = "")] async fn login( - state: &State, + db: &State, cookies: &CookieJar<'_>, request: Json, -) -> Json { - match handle_login(request.code.clone()).await { - Ok((user, expires_in)) => match state.db.insert_user_session(&user, expires_in) { - Ok(session_id) => { - cookies.add(Cookie::build("session", session_id.clone()).expires(None).finish()); - - Json(LoginResponse { - user: Some(user), - session_id: Some(session_id), - error: None, - }) - } - Err(e) => Json(LoginResponse { - user: None, - session_id: None, - error: Some(e.to_string()), - }), - }, - Err(e) => Json(LoginResponse { - user: None, - session_id: None, - error: Some(e.to_string()), - }), - } +) -> ApiResult { + let (user, expires_in) = handle_login(request.code.clone()).await?; + let session = db.insert_user_session(&user, expires_in)?; + let session_id = session.id.to_string(); + cookies.add( + Cookie::build(SESSION_COOKIE_NAME, session_id.clone()) + .finish(), + ); + Ok(Json(LoginResponse { user, session_id })) } /// The endpoint to log out. #[post("/logout")] -async fn logout( - state: &State, - cookies: &CookieJar<'_>, -) -> Json { - - match state.get_authenticated_session(cookies) { - Ok(UserSessionId { session_id, ..}) => { - match state.db.delete_session(session_id.to_string()) { - Ok(_) => { - cookies.remove(Cookie::named("session")); - Json(LogoutResponse { error: None }) - } - Err(e) => Json(LogoutResponse { - error: Some(e.to_string()), - }), - } - } - Err(_) => Json(LogoutResponse { - error: None - }), - } - - +async fn logout(db: &State, auth: SessionAuth) -> ApiResult { + let session_id = auth.session_id; + let _ = db.delete_session(session_id)?; + Ok(Json(EmptyResponse)) } /// The endpoint to authenticate with GitHub. #[get("/user")] -async fn user( - state: &State, - cookies: &CookieJar<'_>, -) -> Json { - // cookies.add(Cookie::build("session", id.clone()).expires(None).finish()); - - // match state.db.get_user_for_session(id) { - // Ok(user) => Json(SessionResponse { - // user: Some(User::from(user)), - // error: None, - // }), - - // Err(error) => Json(SessionResponse { - // user: None, - // error: Some(error.to_string()), - // }), - // } - - match state.get_authenticated_session(cookies) { - Ok(UserSessionId {user_id, ..}) => match state.db.get_user(user_id) { - Ok(user) => Json(UserResponse { - user: Some(User::from(user)), - error: None, - }), - - Err(error) => Json(UserResponse { - user: None, - error: Some(error.to_string()), - }), - }, - - Err(error) => Json(UserResponse { - user: None, - error: Some(error.to_string()), - }), - - } +fn user(auth: SessionAuth) -> Json { + Json(UserResponse { + user: auth.user.into(), + }) } #[post("/new_token", data = "")] fn new_token( - state: &State, - cookies: &CookieJar<'_>, + db: &State, + auth: SessionAuth, request: Json, -) -> Json { - match state.get_authenticated_session(cookies) { - Ok(UserSessionId { user_id, ..}) => match state.db.new_token(user_id, request.name.clone()) { - Ok((token, plain_token)) => Json(CreateTokenResponse { - token: Some(Token { - // The only time we return the plain token is when it's created. - token: Some(plain_token), - ..token.into() - } - ), - error: None, - }), - Err(e) => Json(CreateTokenResponse { - token: None, - error: Some(e.to_string()), - }), +) -> ApiResult { + let user = auth.user; + let (token, plain_token) = db.new_token(user.id, request.name.clone())?; + Ok(Json(CreateTokenResponse { + token: Token { + // The only time we return the plain token is when it's created. + token: Some(plain_token), + ..token.into() }, - - Err(e) => Json(CreateTokenResponse { - token: None, - error: Some(e.to_string()), - }), - } + })) } #[delete("/token/")] fn delete_token( - state: &State, - cookies: &CookieJar<'_>, + db: &State, + auth: SessionAuth, id: String, -) -> Json { - match state.get_authenticated_session(cookies) { - Ok(UserSessionId { user_id, ..}) => match state.db.delete_token(user_id, id.clone()) { - Ok(_) => Json(DeleteTokenResponse { - error: None, - }), - Err(e) => Json(DeleteTokenResponse { - error: Some(e.to_string()), - }), - }, - - Err(e) => Json(DeleteTokenResponse { - error: Some(e.to_string()), - }), - } +) -> ApiResult { + let user_id = auth.user.id; + let _ = db.delete_token(user_id, id.clone())?; + Ok(Json(EmptyResponse)) } #[get("/tokens")] -fn tokens(state: &State, cookies: &CookieJar<'_>) -> Json { - match state.get_authenticated_session(cookies) { - Ok(UserSessionId { user_id, ..}) => match state.db.get_tokens_for_user(user_id) { - Ok(tokens) => Json(TokensResponse { - tokens: tokens.into_iter().map(|t| t.into()).collect(), - error: None, - }), - Err(e) => Json(TokensResponse { - tokens: vec![], - error: Some(e.to_string()), - }), - }, - - Err(e) => Json(TokensResponse { - tokens: vec![], - error: Some(e.to_string()), - }), - } -} - -/// The endpoint to publish a package version. -#[post("/publish", data = "")] -fn publish(request: Json) -> Json { - eprintln!("Received request: {:?}", request); - Json(PublishResponse { error: None }) +fn tokens(db: &State, auth: SessionAuth) -> ApiResult { + let user_id = auth.user.id; + let tokens = db.get_tokens_for_user(user_id)?; + Ok(Json(TokensResponse { + tokens: tokens.into_iter().map(|t| t.into()).collect(), + })) } /// Catches all OPTION requests in order to get the CORS related Fairing triggered. @@ -240,7 +116,7 @@ fn health() -> String { #[launch] fn rocket() -> _ { rocket::build() - .manage(ServerState::default()) + .manage(Database::default()) .attach(Cors) .mount( "/", @@ -251,7 +127,6 @@ fn rocket() -> _ { new_token, delete_token, tokens, - publish, all_options, health ], diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs new file mode 100644 index 0000000..066d9a7 --- /dev/null +++ b/src/middleware/mod.rs @@ -0,0 +1 @@ +pub mod session_auth; \ No newline at end of file diff --git a/src/middleware/session_auth.rs b/src/middleware/session_auth.rs new file mode 100644 index 0000000..350c89c --- /dev/null +++ b/src/middleware/session_auth.rs @@ -0,0 +1,57 @@ +use std::error; +use std::f64::consts::E; +use std::time::SystemTime; + +use crate::db::{string_to_uuid, Database}; +use crate::{api, models}; +use rocket::fairing::{Fairing, Info, Kind}; +use rocket::http::{Header, Status}; +use rocket::outcome::try_outcome; +use rocket::request::{FromRequest, Outcome}; +use rocket::response::status; +use rocket::{Data, Request, Response}; +use thiserror::Error; +use uuid::Uuid; + +pub const SESSION_COOKIE_NAME: &str = "session"; + +pub struct SessionAuth { + pub user: models::User, + pub session_id: Uuid, +} + +#[derive(Debug)] +pub enum SessionAuthError { + Missing, + Invalid, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SessionAuth { + type Error = SessionAuthError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + // TODO: use fairing? + // let db = try_outcome!(request.guard::().await); + + let db = request.rocket().state::().unwrap(); + if let Some(Some(session_id)) = request + .cookies() + .get(SESSION_COOKIE_NAME) + .map(|c| Uuid::parse_str(c.value()).ok()) + { + if let Ok(session) = db.get_session(session_id.clone()) { + if let Ok(user) = db.get_user_for_session(session_id.to_string()) { + if session.expires_at > SystemTime::now() { + return Outcome::Success(SessionAuth { + user: user.into(), + session_id, + }); + } + } + } + return Outcome::Failure((Status::Unauthorized, SessionAuthError::Invalid)); + } + return Outcome::Failure((Status::Unauthorized, SessionAuthError::Missing)); + } +} diff --git a/src/models.rs b/src/models.rs index 599c535..b1d656c 100644 --- a/src/models.rs +++ b/src/models.rs @@ -2,7 +2,7 @@ use diesel::{prelude::*, sql_types::Bytea}; use std::time::SystemTime; use uuid::Uuid; -#[derive(Queryable, Selectable, Debug)] +#[derive(Queryable, Selectable, Debug, Clone)] #[diesel(table_name = crate::schema::users)] #[diesel(check_for_backend(diesel::pg::Pg))] pub struct User { diff --git a/tests/db_integration.rs b/tests/db_integration.rs index 2b93d93..837a49e 100644 --- a/tests/db_integration.rs +++ b/tests/db_integration.rs @@ -49,17 +49,14 @@ fn test_multiple_user_sessions() { let user2 = mock_user_2(); let session1 = db.insert_user_session(&user1, 1000).expect("result is ok"); - Uuid::parse_str(session1.as_str()).expect("result is a valid UUID"); // Insert an existing user let session2 = db.insert_user_session(&user1, 1000).expect("result is ok"); - Uuid::parse_str(session2.as_str()).expect("result is a valid UUID"); // Insert another user let session3 = db.insert_user_session(&user2, 1000).expect("result is ok"); - Uuid::parse_str(session3.as_str()).expect("result is a valid UUID"); - let result = db.get_user_for_session(session1).expect("result is ok"); + let result = db.get_user_for_session(session1.id.to_string()).expect("result is ok"); assert_eq!(result.github_login, TEST_LOGIN_1); assert_eq!(result.full_name, TEST_FULL_NAME_1); assert_eq!(result.email.expect("is some"), TEST_EMAIL_1); @@ -67,10 +64,10 @@ fn test_multiple_user_sessions() { assert_eq!(result.github_url, TEST_URL_2); assert!(result.is_admin); - let result = db.get_user_for_session(session2).expect("result is ok"); + let result = db.get_user_for_session(session2.id.to_string()).expect("result is ok"); assert_eq!(result.github_login, TEST_LOGIN_1); - let result = db.get_user_for_session(session3).expect("result is ok"); + let result = db.get_user_for_session(session3.id.to_string()).expect("result is ok"); assert_eq!(result.github_login, TEST_LOGIN_2); clear_tables(&db);