Skip to content

Commit

Permalink
tests(graphql): add project and prompt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnatarHe committed Aug 17, 2023
1 parent c35983a commit 769cc94
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 4 deletions.
32 changes: 32 additions & 0 deletions schema/project_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package schema

import (
"context"
"errors"
"testing"

"github.com/PromptPal/PromptPal/config"
Expand Down Expand Up @@ -81,6 +82,37 @@ func (s *projectTestSuite) TestGetProject() {
assert.Nil(s.T(), err)
pj := result
assert.Equal(s.T(), s.projectName, pj.Name())
assert.Equal(s.T(), "https://api.openai.com/v1", pj.OpenAIBaseURL())
assert.Equal(s.T(), "gpt-3.5-turbo", pj.OpenAIModel())
assert.Equal(s.T(), "SOME_RANDOM_TOKEN_HERE", pj.OpenAIToken())
assert.EqualValues(s.T(), 1, pj.OpenAITemperature())
assert.EqualValues(s.T(), 0.9, pj.OpenAITopP())
assert.EqualValues(s.T(), 0, pj.OpenAIMaxTokens())
assert.NotEmpty(s.T(), pj.CreatedAt())
assert.NotEmpty(s.T(), pj.UpdatedAt())

creator, err := pj.Creator(ctx)
assert.Nil(s.T(), err)

assert.EqualValues(s.T(), 1, creator.ID())

lps := pj.LatestPrompts(ctx)
cs, err := lps.Count(ctx)
assert.Nil(s.T(), err)
assert.EqualValues(s.T(), cs, 0)
lps2, err := lps.Edges(ctx)
assert.Nil(s.T(), err)
assert.Len(s.T(), lps2, 0)

_, err = q.Project(ctx, projectArgs{
ID: int32(887771),
})
assert.Error(s.T(), err)

ge, ok := err.(GraphQLHttpError)
assert.True(s.T(), ok)
assert.EqualValues(s.T(), "[500]: ent: project not found", ge.Error())
assert.EqualValues(s.T(), errors.New("[500]: ent: project not found"), ge.Unwrap())
}

func (s *projectTestSuite) TearDownSuite() {
Expand Down
8 changes: 4 additions & 4 deletions schema/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,21 @@ func (q QueryResolver) UpdatePrompt(ctx context.Context, args updatePromptArgs)
return
}

updator := tx.Prompt.UpdateOneID(int(args.ID)).
updater := tx.Prompt.UpdateOneID(int(args.ID)).
SetDescription(payload.Description).
SetTokenCount(int(payload.TokenCount)).
SetPrompts(payload.Prompts).
SetVariables(payload.Variables).
SetPublicLevel(payload.PublicLevel)

if args.Data.Enabled != nil {
updator = updator.SetEnabled(*args.Data.Enabled)
updater = updater.SetEnabled(*args.Data.Enabled)
}
if args.Data.Debug != nil {
updator = updator.SetNillableDebug(args.Data.Debug)
updater = updater.SetNillableDebug(args.Data.Debug)
}

updatedPrompt, err := updator.Save(ctx)
updatedPrompt, err := updater.Save(ctx)

if err != nil {
tx.Rollback()
Expand Down
56 changes: 56 additions & 0 deletions schema/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,62 @@ func (s *promptTestSuite) TestGetPrompt() {
assert.Nil(s.T(), err)
pt := result
assert.Equal(s.T(), "test-prompt", pt.Name())

assert.NotEmpty(s.T(), pt.CreatedAt())
assert.NotEmpty(s.T(), pt.UpdatedAt())
}

func (s *promptTestSuite) TestUpdatePrompt() {
q := QueryResolver{}
ctx := context.WithValue(context.Background(), service.GinGraphQLContextKey, service.GinGraphQLContextType{
UserID: 1,
})

truthy := true

result, err := q.UpdatePrompt(ctx, updatePromptArgs{
ID: int32(s.promptID),
Data: createPromptData{
ProjectID: int32(s.pjID),
Name: "test-prompt-podcast-AsyncTalk",
Description: "welcome to listen the podcast: `AsyncTalk`",
TokenCount: 9231,
Enabled: &truthy,
Debug: &truthy,
PublicLevel: prompt.PublicLevelPrivate,
Prompts: []dbSchema.PromptRow{
{
Prompt: "AsyncTalk podcast is a a good chinese podcast talk about frontend development {{ var88 }}",
Role: "system",
},
},
Variables: []dbSchema.PromptVariable{
{
Name: "var88",
Type: "string",
},
},
},
})
assert.Nil(s.T(), err)
assert.True(s.T(), result.Debug())
assert.True(s.T(), result.Enabled())
assert.EqualValues(s.T(), "test-prompt", result.Name())
assert.EqualValues(s.T(), 9231, result.TokenCount())
assert.EqualValues(s.T(), s.promptID, result.ID())

pts := result.Prompts()
assert.Len(s.T(), pts, 1)
pt := pts[0]
assert.Equal(s.T(), "system", pt.Role())
assert.Equal(s.T(), "AsyncTalk podcast is a a good chinese podcast talk about frontend development {{ var88 }}", pt.Prompt())

vars := result.Variables()
assert.Len(s.T(), vars, 1)
var1 := vars[0]
assert.Equal(s.T(), "var88", var1.Name())
assert.Equal(s.T(), "string", var1.Type())

}

func (s *promptTestSuite) TearDownSuite() {
Expand Down

0 comments on commit 769cc94

Please sign in to comment.