Skip to content

Commit

Permalink
feat(all): stable working chat
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadBnei committed Feb 18, 2024
1 parent 57817ce commit 3cca3ca
Show file tree
Hide file tree
Showing 13 changed files with 648 additions and 471 deletions.
321 changes: 321 additions & 0 deletions service/chatMessages.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
package service

import (
"errors"
"os"
"sort"
"time"

"github.com/MohammadBnei/go-ai-cli/tool"
"github.com/bwmarrin/snowflake"
"github.com/jinzhu/copier"
"github.com/pkoukk/tiktoken-go"
"github.com/samber/lo"
"github.com/sashabaranov/go-openai"
"github.com/spf13/viper"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"gopkg.in/yaml.v3"
)

type ROLES string
type TYPE string

const (
RoleUser ROLES = "user"
RoleSystem ROLES = "system"
RoleAssistant ROLES = "assistant"
RoleApp ROLES = "app"

TypeFile TYPE = "file"
TypeUser TYPE = "user"
)

type ChatMessages struct {
Id string
Description string
Messages []ChatMessage
TotalTokens int

node *snowflake.Node `json:"-"`
}

func NewChatMessages(id string) *ChatMessages {
node, _ := snowflake.NewNode(1)
return &ChatMessages{
Id: id,
Messages: []ChatMessage{},
TotalTokens: 0,

node: node,
}
}

func (c *ChatMessages) SetId(id string) *ChatMessages {
c.Id = id

return c
}
func (c *ChatMessages) SetDescription(description string) *ChatMessages {
c.Description = description

return c
}

func (c *ChatMessages) SaveToFile(filename string) error {
if filename == "" {
return errors.New("filename cannot be empty")
}

data, err := yaml.Marshal(c)
if err != nil {
return err
}

err = tool.SaveToFile(data, filename, false)
if err != nil {
return err
}

return nil
}

func (c *ChatMessages) LoadFromFile(filename string) error {
content, err := os.ReadFile(filename)
if err != nil {
return err
}

marshalledC := &ChatMessages{}

if err := yaml.Unmarshal(content, marshalledC); err != nil {
return err
}

c.Messages = marshalledC.Messages

c.RecountTokens()
c.Id = marshalledC.Id
c.Description = marshalledC.Description

return nil
}

func (c *ChatMessages) FindById(id int64) *ChatMessage {
_, index, ok := lo.FindIndexOf[ChatMessage](c.Messages, func(item ChatMessage) bool {
return item.Id == snowflake.ParseInt64(int64(id))
})
if !ok {
return nil
}

return &c.Messages[index]
}

var ErrNotFound = errors.New("not found")

func (c *ChatMessages) FindMessageByContent(content string) (*ChatMessage, error) {
exists, ok := lo.Find[ChatMessage](c.Messages, func(item ChatMessage) bool {
return item.Content == content
})

if !ok {
return nil, ErrNotFound
}

return &exists, nil
}

var ErrAlreadyExist = errors.New("already exists")

func (c *ChatMessages) AddMessage(content string, role ROLES) (*ChatMessage, error) {
if role == "" {
return nil, errors.New("role cannot be empty")
}

tokenCount, err := CountTokens(content)
if err != nil {
return nil, err
}

if exists, ok := lo.Find[ChatMessage](c.Messages, func(item ChatMessage) bool {
return item.Content == content && item.Role == role && item.Role != RoleUser
}); ok && content != "" {
return &exists, ErrAlreadyExist
}

msg := ChatMessage{
Id: c.node.Generate(),
Role: role,
Content: content,
Date: time.Now(),
Type: TypeUser,
AssociatedMessageId: -1,
Meta: Meta{
ApiType: viper.GetString("API_TYPE"),
Model: viper.GetString("model"),
},
}

msg.Tokens = tokenCount

c.Messages = append(c.Messages, msg)

c.TotalTokens += tokenCount

sort.SliceStable(c.Messages, func(i, j int) bool {
return c.Messages[i].Date.Before(c.Messages[j].Date)
})

return &msg, nil
}

func (c *ChatMessages) AddMessageFromFile(filename string) (*ChatMessage, error) {
content, err := os.ReadFile(filename)
if err != nil {
return nil, err
}

return c.AddMessage(string(content), RoleUser)
}

func (c *ChatMessages) SetAssociatedId(idUser, idAssistant int64) error {
msgUser := c.FindById(idUser)
if msgUser == nil {
return errors.New("user message not found")
}

msgAssistant := c.FindById(idAssistant)
if msgAssistant == nil {
return errors.New("assistant message not found")
}

msgUser.AssociatedMessageId = idAssistant
msgAssistant.AssociatedMessageId = idUser

return nil
}

func (c *ChatMessages) UpdateMessage(m ChatMessage) error {
if m.Content == "" {
return errors.New("content cannot be empty")
}

if m.Role == "" {
return errors.New("role cannot be empty")
}

tokenCount, err := CountTokens(m.Content)
if err != nil {
return err
}

msg := c.FindById(m.Id.Int64())
if msg == nil {
c.AddMessage(m.Content, m.Role)
return nil
}

m.Tokens = tokenCount

copier.Copy(msg, m)

c.RecountTokens()

return nil
}

func (c *ChatMessages) DeleteMessage(id int64) error {
message, ok := lo.Find[ChatMessage](c.Messages, func(item ChatMessage) bool {
return item.Id == snowflake.ParseInt64(id)
})
if !ok {
return errors.New("message not found")
}

c.TotalTokens -= message.Tokens

c.Messages = lo.Filter[ChatMessage](c.Messages, func(item ChatMessage, _ int) bool {
return item.Id != snowflake.ParseInt64(id)
})

return nil
}

func (c *ChatMessages) ToLangchainMessage() []llms.MessageContent {
return lo.Map[ChatMessage, llms.MessageContent](c.FilterByOpenAIRoles(), func(item ChatMessage, index int) llms.MessageContent {
switch item.Role {
case RoleSystem:
return llms.TextParts(schema.ChatMessageTypeSystem, item.Content)
case RoleAssistant:
return llms.TextParts(schema.ChatMessageTypeAI, item.Content)
case RoleUser:
return llms.TextParts(schema.ChatMessageTypeGeneric, item.Content)
}
return llms.TextParts(schema.ChatMessageTypeGeneric, item.Content)
})
}

func (c *ChatMessages) ClearMessages() {
c.Messages = []ChatMessage{}
c.TotalTokens = 0
}

func (c *ChatMessages) LastMessage(role *ROLES) *ChatMessage {
messages := c.Messages
if role != nil {
messages = lo.Filter[ChatMessage](c.Messages, func(item ChatMessage, _ int) bool {
return item.Role == *role
})
}
if len(messages) == 0 {
return nil
}

return &messages[len(messages)-1]
}

func (c *ChatMessages) FilterMessages(role ROLES) (messages []ChatMessage, tokens int) {
messages = lo.Filter[ChatMessage](c.Messages, func(item ChatMessage, _ int) bool {
return item.Role == role
})

sort.Slice(messages, func(i, j int) bool {
return messages[i].Date.Before(messages[j].Date)
})

tokens = lo.Reduce[ChatMessage, int](messages, func(acc int, item ChatMessage, _ int) int {
tokenCount, _ := CountTokens(item.Content)
return acc + tokenCount
}, 0)

return
}

func (c *ChatMessages) FilterByOpenAIRoles() []ChatMessage {
return lo.Filter[ChatMessage](c.Messages, func(item ChatMessage, _ int) bool {
return lo.Contains[ROLES]([]ROLES{
openai.ChatMessageRoleUser,
openai.ChatMessageRoleAssistant,
openai.ChatMessageRoleSystem,
}, item.Role)
})
}

func (c *ChatMessages) RecountTokens() *ChatMessages {
c.TotalTokens = 0
for _, msg := range c.Messages {
msg.Tokens, _ = CountTokens(msg.Content)
c.TotalTokens += msg.Tokens
}
return c
}

func CountTokens(content string) (int, error) {
tkm, err := tiktoken.EncodingForModel(openai.GPT4)
if err != nil {
return 0, err
}

return len(tkm.Encode(content, nil, nil)), nil
}
Loading

0 comments on commit 3cca3ca

Please sign in to comment.