This repository has been archived by the owner on Mar 16, 2024. It is now read-only.
/
message.go
96 lines (83 loc) · 2.38 KB
/
message.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package message
import (
"context"
v1 "github.com/acorn-io/assistant-runtime/pkg/apis/assistant.acorn.io/v1"
"github.com/acorn-io/baaah/pkg/conditions"
"github.com/acorn-io/baaah/pkg/router"
apierror "k8s.io/apimachinery/pkg/api/errors"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)
// setThreadName fills in msg.Status.ThreadName and, for a message that has a
// parent, records msg as the parent's next message.
//
// Without a Spec.ParentMessageName, the thread name is cleared and then set to
// the name of the first listed thread in msg.Namespace whose
// Spec.StartMessageName equals msg.Name; it stays empty if none matches.
//
// With a parent, the parent is fetched and, unless it already points at msg,
// its Status.NextMessageName is switched to msg.Name through the status
// client. A different message previously recorded there is deleted first; if
// that message no longer exists the deletion is skipped. Finally msg takes on
// the parent's Status.ThreadName.
//
// msg itself is only modified in memory; this function never writes it back.
// Any error from the client calls is returned unchanged.
func setThreadName(ctx context.Context, c kclient.Client, msg *v1.Message) error {
	parentName := msg.Spec.ParentMessageName

	if parentName == "" {
		// Clear first: the name is only re-populated when a matching thread is found.
		msg.Status.ThreadName = ""

		var list v1.ThreadList
		opts := &kclient.ListOptions{Namespace: msg.Namespace}
		if err := c.List(ctx, &list, opts); err != nil {
			return err
		}

		for i := range list.Items {
			if list.Items[i].Spec.StartMessageName != msg.Name {
				continue
			}
			msg.Status.ThreadName = list.Items[i].Name
			break
		}
		return nil
	}

	var parent v1.Message
	if err := c.Get(ctx, router.Key(msg.Namespace, parentName), &parent); err != nil {
		return err
	}

	if prev := parent.Status.NextMessageName; prev != msg.Name {
		if prev != "" {
			// The parent currently points at another message; remove it
			// before taking its place.
			var replaced v1.Message
			err := c.Get(ctx, router.Key(msg.Namespace, prev), &replaced)
			switch {
			case apierror.IsNotFound(err):
				// Already gone, nothing to delete.
			case err != nil:
				return err
			default:
				if err := c.Delete(ctx, &replaced); err != nil {
					return err
				}
			}
		}

		parent.Status.NextMessageName = msg.Name
		if err := c.Status().Update(ctx, &parent); err != nil {
			return err
		}
	}

	msg.Status.ThreadName = parent.Status.ThreadName
	return nil
}
// Initialize is a router handler that derives a Message's status from its
// spec. req.Object must be a *v1.Message; any other type makes the type
// assertion panic. resp is not used.
//
// Steps, in order:
//  1. If the message already names a thread and that thread is not found,
//     the message is deleted and the result of the delete is returned.
//  2. setThreadName resolves the thread name and links the message to its
//     parent.
//  3. Spec.Input is validated; a validation error is returned wrapped by
//     conditions.NewErrTerminal.
//  4. Status.InProgress is set when the input asks for it (it is never
//     cleared here).
//  5. For a completion input only the role is set (assistant). Otherwise the
//     input's content and tool call are copied to the status message and the
//     role is user when there is no tool call, tool when there is one.
//
// The status changes are made on the in-memory object only.
// NOTE(review): presumably the router persists them after the handler
// returns — confirm against the router setup.
func Initialize(req router.Request, resp router.Response) error {
	msg := req.Object.(*v1.Message)

	if threadName := msg.Status.ThreadName; threadName != "" {
		var thread v1.Thread
		err := req.Get(&thread, msg.Namespace, threadName)
		if apierror.IsNotFound(err) {
			// The owning thread is gone, so the message goes with it.
			return req.Client.Delete(req.Ctx, msg)
		}
		if err != nil {
			return err
		}
	}

	if err := setThreadName(req.Ctx, req.Client, msg); err != nil {
		return err
	}

	if err := msg.Spec.Input.Valid(); err != nil {
		return conditions.NewErrTerminal(err)
	}

	if msg.Spec.Input.InProgress {
		msg.Status.InProgress = true
	}

	if msg.Spec.Input.Completion {
		msg.Status.Message.Role = v1.RoleTypeAssistant
		return nil
	}

	msg.Status.Message.Content = msg.Spec.Input.Content
	msg.Status.Message.ToolCall = msg.Spec.Input.ToolCall

	// Default to the tool role and fall back to user when no tool call is attached.
	msg.Status.Message.Role = v1.RoleTypeTool
	if msg.Status.Message.ToolCall == nil {
		msg.Status.Message.Role = v1.RoleTypeUser
	}
	return nil
}