From 769cc94bdd361ef56e05f01da6faba27d4facdc8 Mon Sep 17 00:00:00 2001 From: AnnatarHe Date: Thu, 17 Aug 2023 22:55:57 +0800 Subject: [PATCH] tests(graphql): add project and prompt tests --- schema/project_test.go | 32 ++++++++++++++++++++++++ schema/prompt.go | 8 +++--- schema/prompt_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/schema/project_test.go b/schema/project_test.go index 34cf4b5..f0d81f1 100644 --- a/schema/project_test.go +++ b/schema/project_test.go @@ -2,6 +2,7 @@ package schema import ( "context" + "errors" "testing" "github.com/PromptPal/PromptPal/config" @@ -81,6 +82,37 @@ func (s *projectTestSuite) TestGetProject() { assert.Nil(s.T(), err) pj := result assert.Equal(s.T(), s.projectName, pj.Name()) + assert.Equal(s.T(), "https://api.openai.com/v1", pj.OpenAIBaseURL()) + assert.Equal(s.T(), "gpt-3.5-turbo", pj.OpenAIModel()) + assert.Equal(s.T(), "SOME_RANDOM_TOKEN_HERE", pj.OpenAIToken()) + assert.EqualValues(s.T(), 1, pj.OpenAITemperature()) + assert.EqualValues(s.T(), 0.9, pj.OpenAITopP()) + assert.EqualValues(s.T(), 0, pj.OpenAIMaxTokens()) + assert.NotEmpty(s.T(), pj.CreatedAt()) + assert.NotEmpty(s.T(), pj.UpdatedAt()) + + creator, err := pj.Creator(ctx) + assert.Nil(s.T(), err) + + assert.EqualValues(s.T(), 1, creator.ID()) + + lps := pj.LatestPrompts(ctx) + cs, err := lps.Count(ctx) + assert.Nil(s.T(), err) + assert.EqualValues(s.T(), cs, 0) + lps2, err := lps.Edges(ctx) + assert.Nil(s.T(), err) + assert.Len(s.T(), lps2, 0) + + _, err = q.Project(ctx, projectArgs{ + ID: int32(887771), + }) + assert.Error(s.T(), err) + + ge, ok := err.(GraphQLHttpError) + assert.True(s.T(), ok) + assert.EqualValues(s.T(), "[500]: ent: project not found", ge.Error()) + assert.EqualValues(s.T(), errors.New("[500]: ent: project not found"), ge.Unwrap()) } func (s *projectTestSuite) TearDownSuite() { diff --git a/schema/prompt.go b/schema/prompt.go index ddd8b19..3058bfc 100644 --- a/schema/prompt.go +++ b/schema/prompt.go @@ -115,7 +115,7 @@ func (q QueryResolver) UpdatePrompt(ctx context.Context, args updatePromptArgs) return } - updator := tx.Prompt.UpdateOneID(int(args.ID)). + updater := tx.Prompt.UpdateOneID(int(args.ID)). SetDescription(payload.Description). SetTokenCount(int(payload.TokenCount)). SetPrompts(payload.Prompts). @@ -123,13 +123,13 @@ func (q QueryResolver) UpdatePrompt(ctx context.Context, args updatePromptArgs) SetPublicLevel(payload.PublicLevel) if args.Data.Enabled != nil { - updator = updator.SetEnabled(*args.Data.Enabled) + updater = updater.SetEnabled(*args.Data.Enabled) } if args.Data.Debug != nil { - updator = updator.SetNillableDebug(args.Data.Debug) + updater = updater.SetNillableDebug(args.Data.Debug) } - updatedPrompt, err := updator.Save(ctx) + updatedPrompt, err := updater.Save(ctx) if err != nil { tx.Rollback() diff --git a/schema/prompt_test.go b/schema/prompt_test.go index 555dcb3..3797edd 100644 --- a/schema/prompt_test.go +++ b/schema/prompt_test.go @@ -122,6 +122,62 @@ func (s *promptTestSuite) TestGetPrompt() { assert.Nil(s.T(), err) pt := result assert.Equal(s.T(), "test-prompt", pt.Name()) + + assert.NotEmpty(s.T(), pt.CreatedAt()) + assert.NotEmpty(s.T(), pt.UpdatedAt()) +} + +func (s *promptTestSuite) TestUpdatePrompt() { + q := QueryResolver{} + ctx := context.WithValue(context.Background(), service.GinGraphQLContextKey, service.GinGraphQLContextType{ + UserID: 1, + }) + + truthy := true + + result, err := q.UpdatePrompt(ctx, updatePromptArgs{ + ID: int32(s.promptID), + Data: createPromptData{ + ProjectID: int32(s.pjID), + Name: "test-prompt-podcast-AsyncTalk", + Description: "welcome to listen the podcast: `AsyncTalk`", + TokenCount: 9231, + Enabled: &truthy, + Debug: &truthy, + PublicLevel: prompt.PublicLevelPrivate, + Prompts: []dbSchema.PromptRow{ + { + Prompt: "AsyncTalk podcast is a a good chinese podcast talk about frontend development {{ var88 }}", + Role: "system", + }, + }, + Variables: []dbSchema.PromptVariable{ + { + Name: "var88", + Type: "string", + }, + }, + }, + }) + assert.Nil(s.T(), err) + assert.True(s.T(), result.Debug()) + assert.True(s.T(), result.Enabled()) + assert.EqualValues(s.T(), "test-prompt", result.Name()) + assert.EqualValues(s.T(), 9231, result.TokenCount()) + assert.EqualValues(s.T(), s.promptID, result.ID()) + + pts := result.Prompts() + assert.Len(s.T(), pts, 1) + pt := pts[0] + assert.Equal(s.T(), "system", pt.Role()) + assert.Equal(s.T(), "AsyncTalk podcast is a a good chinese podcast talk about frontend development {{ var88 }}", pt.Prompt()) + + vars := result.Variables() + assert.Len(s.T(), vars, 1) + var1 := vars[0] + assert.Equal(s.T(), "var88", var1.Name()) + assert.Equal(s.T(), "string", var1.Type()) + } func (s *promptTestSuite) TearDownSuite() {