From ac9be443f43e620fad57096d9628035db5f10dfb Mon Sep 17 00:00:00 2001 From: Ahmed Radwan Date: Mon, 1 Jan 2024 17:45:20 +0200 Subject: [PATCH] test renew access token API --- api/main_test.go | 6 +- api/token_test.go | 184 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 api/token_test.go diff --git a/api/main_test.go b/api/main_test.go index a2e073f..12bec16 100644 --- a/api/main_test.go +++ b/api/main_test.go @@ -12,10 +12,12 @@ import ( "github.com/stretchr/testify/require" ) +// newTestServer create a test server with a config suitable for testing func newTestServer(t *testing.T, store db.Store) *Server { config := util.Config{ - TokenSymmetricKey: util.RandomString(32), - AccessTokenDuration: time.Minute, + TokenSymmetricKey: util.RandomString(32), + AccessTokenDuration: time.Minute, + RefreshTokenDuration: time.Minute, } server, err := NewServer(config, store) require.NoError(t, err) diff --git a/api/token_test.go b/api/token_test.go new file mode 100644 index 0000000..b70c1b9 --- /dev/null +++ b/api/token_test.go @@ -0,0 +1,184 @@ +package api + +import ( + "bytes" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + mockdb "github.com/aradwann/eenergy/db/mock" + db "github.com/aradwann/eenergy/db/store" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func createTestSession(t *testing.T, user db.User, server *Server) (db.Session, string) { + refreshToken, refreshPayload, err := server.tokenMaker.CreateToken( + user.Username, + server.config.RefreshTokenDuration, + ) + require.NoError(t, err, "failed to create token") + + session := db.Session{ + ID: refreshPayload.ID, + Username: user.Username, + RefreshToken: refreshToken, + UserAgent: "unknown", + ClientIp: "unknown", + IsBlocked: false, + ExpiresAt: refreshPayload.ExpiredAt, + } + return session, refreshToken +} + +func TestRenewAccessTokenAPI(t *testing.T) { + user, _ := randomUser(t) + + testCases := []struct { + name string + buildStubs func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string + checkResponse func(recorder *httptest.ResponseRecorder) + }{ + { + name: "RenewAccessToken_OK", + buildStubs: func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string { + session, refreshToken := createTestSession(t, user, server) + + // Ensure GetSession is called with the expected arguments + store.EXPECT(). + GetSession(gomock.Any(), gomock.Eq(session.ID)). + Times(1). + Return(session, nil) + + return refreshToken + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "RenewAccessToken_Unauthorized", + buildStubs: func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string { + // Create a session with a blocked flag + session, refreshToken := createTestSession(t, user, server) + session.IsBlocked = true + + // Ensure GetSession is called with the expected arguments + store.EXPECT(). + GetSession(gomock.Any(), gomock.Eq(session.ID)). + Times(1). + Return(session, nil) + + return refreshToken + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "RenewAccessToken_ExpiredSession", + buildStubs: func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string { + // Create a session with an expired ExpiresAt + session, refreshToken := createTestSession(t, user, server) + session.ExpiresAt = time.Now().Add(-time.Hour) // Set to a past time + + // Ensure GetSession is called with the expected arguments + store.EXPECT(). + GetSession(gomock.Any(), gomock.Eq(session.ID)). + Times(1). + Return(session, nil) + + return refreshToken + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "RenewAccessToken_SessionNotFound", + buildStubs: func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string { + // Create a session with an expired ExpiresAt + session, refreshToken := createTestSession(t, user, server) + + // Ensure GetSession is called with the expected arguments + store.EXPECT(). + GetSession(gomock.Any(), gomock.Eq(session.ID)). + Times(1). + Return(db.Session{}, sql.ErrNoRows) + + return refreshToken + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusNotFound, recorder.Code) + }, + }, + { + name: "RenewAccessToken_MismatchToken", + buildStubs: func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string { + // Create a session with an expired ExpiresAt + session, refreshToken := createTestSession(t, user, server) + session.RefreshToken = "mismatchttttoken" + // Ensure GetSession is called with the expected arguments + store.EXPECT(). + GetSession(gomock.Any(), gomock.Eq(session.ID)). + Times(1). + Return(session, nil) + + return refreshToken + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "RenewAccessToken_IncorrectSessionUser", + buildStubs: func(t *testing.T, store *mockdb.MockStore, user db.User, server *Server) string { + // Create a session with an expired ExpiresAt + session, refreshToken := createTestSession(t, user, server) + session.Username = "mismatchusername" + // Ensure GetSession is called with the expected arguments + store.EXPECT(). + GetSession(gomock.Any(), gomock.Eq(session.ID)). + Times(1). + Return(session, nil) + + return refreshToken + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + + store := mockdb.NewMockStore(ctrl) + + server := newTestServer(t, store) + refreshToken := tc.buildStubs(t, store, user, server) + body := gin.H{ + "refresh_token": refreshToken, + } + recorder := httptest.NewRecorder() + + // Marshal body data to JSON + data, err := json.Marshal(body) + require.NoError(t, err) + + url := "/tokens/renew_access" + request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) + require.NoError(t, err) + + server.router.ServeHTTP(recorder, request) + tc.checkResponse(recorder) + }) + } +}