Skip to content

Commit

Permalink
feat(option-menu): added an option menu as i ran out of key presses
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadBnei committed Feb 17, 2024
1 parent c622954 commit a4c9e10
Show file tree
Hide file tree
Showing 16 changed files with 452 additions and 96 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/TannerKvarfordt/hfapigo v1.3.1
github.com/atotto/clipboard v0.1.4
github.com/briandowns/spinner v1.23.0
github.com/bwmarrin/snowflake v0.3.0
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500
github.com/charmbracelet/bubbles v0.17.2-0.20240108170749-ec883029c8e6
github.com/charmbracelet/bubbletea v0.25.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ github.com/briandowns/spinner v1.23.0 h1:alDF2guRWqa/FOZZYWjlMIx2L6H0wyewPxo/CH4
github.com/briandowns/spinner v1.23.0/go.mod h1:rPG4gmXeN3wQV/TsAY4w8lPdIM6RX3yqeBQJSrbXjuE=
github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE=
github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0=
github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE=
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500 h1:6lhrsTEnloDPXyeZBvSYvQf8u86jbKehZPVDDlkgDl4=
github.com/c2h5oh/datasize v0.0.0-20231215233829-aa82cc1e6500/go.mod h1:S/7n9copUssQ56c7aAgHqftWO4LTf4xY6CGWt8Bc+3M=
github.com/catppuccin/go v0.2.0 h1:ktBeIrIP42b/8FGiScP9sgrWOss3lw0Z5SktRoithGA=
Expand Down
21 changes: 11 additions & 10 deletions service/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"errors"
"fmt"

"github.com/bwmarrin/snowflake"
"github.com/samber/lo"
)

type ContextHold struct {
UserChatId int
UserChatId snowflake.ID
Ctx context.Context
CancelFn func()
}
Expand All @@ -34,8 +35,8 @@ func (pc *PromptConfig) AddContext(ctx context.Context, cancelFn func()) {
pc.Contexts = append(pc.Contexts, ContextHold{Ctx: ctx, CancelFn: cancelFn})
}

func (pc *PromptConfig) AddContextWithId(ctx context.Context, cancelFn func(), id int) {
pc.Contexts = append(pc.Contexts, ContextHold{Ctx: ctx, CancelFn: cancelFn, UserChatId: id})
func (pc *PromptConfig) AddContextWithId(ctx context.Context, cancelFn func(), id int64) {
pc.Contexts = append(pc.Contexts, ContextHold{Ctx: ctx, CancelFn: cancelFn, UserChatId: snowflake.ParseInt64(id)})
}

func (pc *PromptConfig) DeleteContext(ctx context.Context) {
Expand All @@ -44,28 +45,28 @@ func (pc *PromptConfig) DeleteContext(ctx context.Context) {
})
}

func (pc *PromptConfig) FindContextWithId(id int) *ContextHold {
func (pc *PromptConfig) FindContextWithId(id int64) *ContextHold {
ctx, _ := lo.Find(pc.Contexts, func(item ContextHold) bool {
return item.UserChatId != id
return item.UserChatId != snowflake.ParseInt64(id)
})
return &ctx
}

func (pc *PromptConfig) DeleteContextById(id int) {
func (pc *PromptConfig) DeleteContextById(id int64) {
pc.Contexts = lo.Filter(pc.Contexts, func(item ContextHold, index int) bool {
return item.UserChatId != id
return item.UserChatId != snowflake.ParseInt64(id)
})
}

func (pc *PromptConfig) CloseContextById(id int) error {
ctx, _, ok := lo.FindLastIndexOf(pc.Contexts, func(item ContextHold) bool { return item.UserChatId == id })
func (pc *PromptConfig) CloseContextById(id int64) error {
ctx, _, ok := lo.FindLastIndexOf(pc.Contexts, func(item ContextHold) bool { return item.UserChatId == snowflake.ParseInt64(id) })
if !ok {
return fmt.Errorf("no context found with id %d, %s", id, pc.Contexts)
}
ctx.CancelFn()

pc.Contexts = lo.Filter(pc.Contexts, func(item ContextHold, index int) bool {
return item.UserChatId != id
return item.UserChatId != snowflake.ParseInt64(id)
})

return nil
Expand Down
28 changes: 17 additions & 11 deletions service/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/MohammadBnei/go-ai-cli/api"
"github.com/MohammadBnei/go-ai-cli/audio"
"github.com/MohammadBnei/go-ai-cli/tool"
"github.com/bwmarrin/snowflake"
"github.com/jinzhu/copier"
"github.com/pkoukk/tiktoken-go"
"github.com/samber/lo"
Expand All @@ -35,13 +36,13 @@ const (
)

type ChatMessage struct {
Id int
Id snowflake.ID
Role ROLES `json:"role"`
Content string `json:"content"`
Tokens int `json:"tokens"`
Type TYPE

AssociatedMessageId int
AssociatedMessageId int64

ToolCall openai.ToolCall
Date time.Time
Expand All @@ -60,21 +61,26 @@ type ChatMessages struct {
Description string
Messages []ChatMessage
TotalTokens int

node *snowflake.Node
}

func NewChatMessages(id string) *ChatMessages {
node, _ := snowflake.NewNode(1)
return &ChatMessages{
Id: id,
Messages: []ChatMessage{},
TotalTokens: 0,

node: node,
}
}

func (c *ChatMessages) SetId(id string) *ChatMessages {
c.Id = id

return c
}

func (c *ChatMessages) SetDescription(description string) *ChatMessages {
c.Description = description

Expand Down Expand Up @@ -118,9 +124,9 @@ func (c *ChatMessages) LoadFromFile(filename string) error {
return nil
}

func (c *ChatMessages) FindById(id int) *ChatMessage {
func (c *ChatMessages) FindById(id int64) *ChatMessage {
_, index, ok := lo.FindIndexOf[ChatMessage](c.Messages, func(item ChatMessage) bool {
return item.Id == id
return item.Id == snowflake.ParseInt64(int64(id))
})
if !ok {
return nil
Expand Down Expand Up @@ -162,7 +168,7 @@ func (c *ChatMessages) AddMessage(content string, role ROLES) (*ChatMessage, err
}

msg := ChatMessage{
Id: len(c.Messages),
Id: c.node.Generate(),
Role: role,
Content: content,
Date: time.Now(),
Expand Down Expand Up @@ -219,7 +225,7 @@ func (c *ChatMessage) PlayAudio(ctx context.Context) error {
return audio.PlaySound(ctx, c.Audio)
}

func (c *ChatMessages) SetAssociatedId(idUser, idAssistant int) error {
func (c *ChatMessages) SetAssociatedId(idUser, idAssistant int64) error {
msgUser := c.FindById(idUser)
if msgUser == nil {
return errors.New("user message not found")
Expand Down Expand Up @@ -250,7 +256,7 @@ func (c *ChatMessages) UpdateMessage(m ChatMessage) error {
return err
}

msg := c.FindById(m.Id)
msg := c.FindById(m.Id.Int64())
if msg == nil {
c.AddMessage(m.Content, m.Role)
return nil
Expand All @@ -265,9 +271,9 @@ func (c *ChatMessages) UpdateMessage(m ChatMessage) error {
return nil
}

func (c *ChatMessages) DeleteMessage(id int) error {
func (c *ChatMessages) DeleteMessage(id int64) error {
message, ok := lo.Find[ChatMessage](c.Messages, func(item ChatMessage) bool {
return item.Id == id
return item.Id == snowflake.ParseInt64(id)
})
if !ok {
return errors.New("message not found")
Expand All @@ -276,7 +282,7 @@ func (c *ChatMessages) DeleteMessage(id int) error {
c.TotalTokens -= message.Tokens

c.Messages = lo.Filter[ChatMessage](c.Messages, func(item ChatMessage, _ int) bool {
return item.Id != id
return item.Id != snowflake.ParseInt64(id)
})

return nil
Expand Down
9 changes: 4 additions & 5 deletions ui/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ func Chat(pc *service.PromptConfig) {
}

type currentChatIndexes struct {
user int
assistant int
user, assistant int64
}
type chatModel struct {
viewport viewport.Model
Expand Down Expand Up @@ -316,11 +315,11 @@ func (m chatModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.transitionModel.Title = msg.Title

case service.ChatMessage:
if msg.Id == m.currentChatIndices.user {
if msg.Id.Int64() == m.currentChatIndices.user {
m.userPrompt = msg.Content
}

if msg.Id == m.currentChatIndices.assistant {
if msg.Id.Int64() == m.currentChatIndices.assistant {
m.aiResponse = msg.Content
}

Expand Down Expand Up @@ -404,7 +403,7 @@ func (m chatModel) LoadingTitle() {
func (m chatModel) GetTitleView() string {
userPrompt := m.userPrompt
if m.currentChatIndices.user >= 0 {
_, index, _ := lo.FindIndexOf[service.ChatMessage](m.promptConfig.ChatMessages.Messages, func(c service.ChatMessage) bool { return c.Id == m.currentChatIndices.user })
_, index, _ := lo.FindIndexOf[service.ChatMessage](m.promptConfig.ChatMessages.Messages, func(c service.ChatMessage) bool { return c.Id.Int64() == m.currentChatIndices.user })
userPrompt = fmt.Sprintf("[%d] %s", index+1, userPrompt)
}
if userPrompt == "" {
Expand Down
24 changes: 14 additions & 10 deletions ui/chat/chatUpdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ func changeResponseUp(m chatModel) (chatModel, tea.Cmd) {
if len(m.promptConfig.ChatMessages.Messages) == 0 {
return m, nil
}
currentIndexes := lo.Filter[int]([]int{m.currentChatIndices.user, m.currentChatIndices.assistant}, func(i int, _ int) bool { return i >= 0 })
currentIndexes := lo.Filter([]int64{m.currentChatIndices.user, m.currentChatIndices.assistant}, func(i int64, _ int) bool { return i >= 0 })
minIndex := lo.Min(currentIndexes)
previous := minIndex - 1
if len(currentIndexes) == 0 {
previous = len(m.promptConfig.ChatMessages.Messages) - 1
previous = m.promptConfig.ChatMessages.Messages[len(m.promptConfig.ChatMessages.Messages)-1].Id.Int64()
}
c := m.promptConfig.ChatMessages.FindById(previous)
if c == nil {
Expand All @@ -88,7 +88,7 @@ func changeResponseDown(m chatModel) (chatModel, tea.Cmd) {
if len(m.promptConfig.ChatMessages.Messages) == 0 {
return m, nil
}
maxIndex := lo.Max([]int{m.currentChatIndices.assistant, m.currentChatIndices.user})
maxIndex := lo.Max([]int64{m.currentChatIndices.assistant, m.currentChatIndices.user})
next := maxIndex + 1
c := m.promptConfig.ChatMessages.FindById(next)
if c == nil {
Expand Down Expand Up @@ -139,10 +139,10 @@ func promptSend(m *chatModel) (tea.Model, tea.Cmd) {
return m, event.Error(err)
}

m.currentChatIndices.user = userMsg.Id
m.currentChatIndices.assistant = assistantMessage.Id
m.currentChatIndices.user = userMsg.Id.Int64()
m.currentChatIndices.assistant = assistantMessage.Id.Int64()

m.promptConfig.ChatMessages.SetAssociatedId(userMsg.Id, assistantMessage.Id)
m.promptConfig.ChatMessages.SetAssociatedId(userMsg.Id.Int64(), assistantMessage.Id.Int64())

go sendPrompt(m.promptConfig, *m.currentChatIndices)

Expand All @@ -159,15 +159,15 @@ func (m *chatModel) changeCurrentChatHelper(previous *service.ChatMessage) {
if previous.AssociatedMessageId >= 0 {
switch previous.Role {
case service.RoleUser:
m.currentChatIndices.user = previous.Id
m.currentChatIndices.user = previous.Id.Int64()
m.currentChatIndices.assistant = previous.AssociatedMessageId
case service.RoleAssistant:
m.currentChatIndices.assistant = previous.Id
m.currentChatIndices.assistant = previous.Id.Int64()
m.currentChatIndices.user = previous.AssociatedMessageId
}
} else {
m.currentChatIndices.assistant = -1
m.currentChatIndices.user = previous.Id
m.currentChatIndices.user = previous.Id.Int64()
}

if m.currentChatIndices.assistant >= 0 {
Expand Down Expand Up @@ -212,7 +212,7 @@ func sendPrompt(pc *service.PromptConfig, currentChatIds currentChatIndexes) err
return nil
}),
}

if v := viper.GetFloat64("temperature"); v >= 0 {
options = append(options, llms.WithTemperature(v))
}
Expand All @@ -224,6 +224,10 @@ func sendPrompt(pc *service.PromptConfig, currentChatIds currentChatIndexes) err
options = append(options, llms.WithTopP(v))
}

if pc.UpdateChan != nil {
pc.UpdateChan <- *pc.ChatMessages.FindById(currentChatIds.assistant)
}

_, err = generate(ctx, pc.ChatMessages.ToLangchainMessage(),
options...,
)
Expand Down
Loading

0 comments on commit a4c9e10

Please sign in to comment.