-
Notifications
You must be signed in to change notification settings - Fork 0
/
session.go
459 lines (409 loc) · 15.9 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
// Package exporter provides tools for extracting and converting chat session data
// from JSON files into various formats, such as CSV and JSON datasets.
//
// This package facilitates tasks like data visualization, reporting, and machine learning data preparation.
//
// The exporter package defines types to represent chat sessions, messages, and associated metadata.
//
// It includes functions to:
//
// - Read chat session data from JSON files
// - Convert sessions to CSV with different formatting options
// - Create separate CSV files for sessions and messages
// - Extract sessions to a JSON format for Hugging Face datasets
//
// The package also handles fields in the source JSON that may be represented as either
// strings or integers by using the custom StringOrInt type.
//
// Additionally, it now supports context-aware operations, allowing for better control
// over long-running processes and the ability to cancel them if needed.
//
// Code:
//
// func (soi *StringOrInt) UnmarshalJSON(data []byte) error {
// // Try unmarshalling into a string
// var s string
// if err := json.Unmarshal(data, &s); err != nil {
// // If there is an error, try unmarshalling into an int
// var i int64
// if err := json.Unmarshal(data, &i); err != nil {
// return err // Return the error if it is not a string or int
// }
// // Convert int to string and assign it to the custom type
// *soi = StringOrInt(strconv.FormatInt(i, 10))
// return nil
// }
// // If no error, assign the string value to the custom type
// *soi = StringOrInt(s)
// return nil
// }
//
// Usage examples:
//
// To read chat sessions from a JSON file and convert them to a CSV format with context support:
//
// ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
// defer cancel()
//
// store, err := exporter.ReadJSONFromFile("path/to/chat-sessions.json")
// if err != nil {
// log.Fatal(err)
// }
// err = exporter.ConvertSessionsToCSV(ctx, store.ChatNextWebStore.Sessions, exporter.FormatOptionInline, "output.csv")
// if err != nil {
// log.Fatal(err)
// }
//
// To create separate CSV files for sessions and messages:
//
// err = exporter.CreateSeparateCSVFiles(store.ChatNextWebStore.Sessions, "sessions.csv", "messages.csv")
// if err != nil {
// log.Fatal(err)
// }
//
// To extract chat sessions to a JSON dataset:
//
// datasetJSON, err := exporter.ExtractToDataset(store.ChatNextWebStore.Sessions)
// if err != nil {
// log.Fatal(err)
// }
// fmt.Println(datasetJSON)
package exporter
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
)
// Output format selectors. The values start at 1 (iota + 1), so the zero value
// of an int is not a valid option.
const (
// FormatOptionInline selects one CSV row per session, with all of the
// session's messages joined into a single "messages" column.
FormatOptionInline = iota + 1
// FormatOptionPerLine selects one CSV row per message.
FormatOptionPerLine
// FormatOptionJSON selects one CSV row per session, with the session's
// messages encoded as a JSON string in the "messages" column.
FormatOptionJSON
// OutputFormatSeparateCSVFiles identifies the "separate CSV files for
// sessions and messages" output mode. It is not accepted by getCSVHeaders
// or getWriteFunction, so passing it to ConvertSessionsToCSV yields an
// "invalid format option" error.
OutputFormatSeparateCSVFiles
)
// StringOrInt holds a JSON value that may arrive either as a string or as an
// integer. Whatever the wire form, the value is stored as a string.
//
// It implements json.Unmarshaler.
type StringOrInt string

// UnmarshalJSON decodes data as a JSON string first. If that fails, it decodes
// data as an int64 and stores its base-10 representation. When neither form
// matches, the error from the integer attempt is returned.
func (soi *StringOrInt) UnmarshalJSON(data []byte) error {
	// Happy path: the value is already a JSON string.
	var text string
	if json.Unmarshal(data, &text) == nil {
		*soi = StringOrInt(text)
		return nil
	}

	// Fallback: accept an integer and keep its decimal text.
	var number int64
	if err := json.Unmarshal(data, &number); err != nil {
		return err
	}
	*soi = StringOrInt(strconv.FormatInt(number, 10))
	return nil
}
// Message is a single message within a chat session: its ID, date, the role
// of the sender, and the message content. All fields are decoded as plain
// JSON strings.
type Message struct {
// ID is read from the "id" key.
ID string `json:"id"`
// Date is kept as the raw string from the "date" key; it is not parsed here.
Date string `json:"date"`
// Role is read from the "role" key.
Role string `json:"role"`
// Content is the message text, read from the "content" key.
Content string `json:"content"`
}
// Stat holds the counters stored with a chat session: tokens, words, and
// characters. The values are decoded as-is from the session's "stat" object.
type Stat struct {
TokenCount int `json:"tokenCount"`
WordCount int `json:"wordCount"`
CharCount int `json:"charCount"`
}
// Mask holds the "mask" object attached to a chat session: an ID, an avatar,
// a name, a language, and a creation timestamp.
type Mask struct {
ID StringOrInt `json:"id"` // StringOrInt because the source JSON may carry the ID as a string or an integer
Avatar string `json:"avatar"`
Name string `json:"name"`
Lang string `json:"lang"`
CreatedAt int64 `json:"createdAt"` // presumably a Unix timestamp; unit not established here — TODO confirm
}
// Session is a single chat session: its metadata, statistics, mask, and the
// ordered list of messages decoded from the "messages" array.
type Session struct {
ID string `json:"id"`
Topic string `json:"topic"`
MemoryPrompt string `json:"memoryPrompt"`
Stat Stat `json:"stat"`
LastUpdate int64 `json:"lastUpdate"` // presumably a Unix timestamp; unit not established here — TODO confirm
LastSummarizeIndex int `json:"lastSummarizeIndex"`
Mask Mask `json:"mask"`
Messages []Message `json:"messages"`
}
// Store is the collection of chat sessions found under the "sessions" key.
// Sessions stays nil when that key is absent or null, which ReadJSONFromFile
// treats as a format mismatch.
type Store struct {
Sessions []Session `json:"sessions"`
}
// ChatNextWebStore is the top-level document shape: a JSON object whose
// "chat-next-web-store" key holds the Store.
type ChatNextWebStore struct {
ChatNextWebStore Store `json:"chat-next-web-store"`
}
// ReadJSONFromFile opens the JSON file at filePath and decodes it into a
// ChatNextWebStore.
//
// It returns an error if the file cannot be opened, if decoding fails, or if
// the decoded document has no session list under "chat-next-web-store"
// (the "sessions" key is absent or null), which is treated as a format
// mismatch.
//
// On every error path the returned store is the zero value, so callers never
// observe a partially decoded document. Open and decode errors are returned
// unwrapped so existing checks against them keep working.
func ReadJSONFromFile(filePath string) (ChatNextWebStore, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return ChatNextWebStore{}, err
	}
	// The handle is read-only, so a Close failure cannot lose data and is ignored.
	defer file.Close()

	var store ChatNextWebStore
	if err := json.NewDecoder(file).Decode(&store); err != nil {
		// Decode may have filled part of store before failing; discard it.
		return ChatNextWebStore{}, err
	}

	// A nil slice means "sessions" was missing or null in the input.
	if store.ChatNextWebStore.Sessions == nil {
		return ChatNextWebStore{}, fmt.Errorf("JSON does not match the expected format chat-next-web-store")
	}
	return store, nil
}
// ConvertSessionsToCSV writes sessions to a CSV file at outputFilePath using
// the layout selected by formatOption (FormatOptionInline, FormatOptionPerLine
// or FormatOptionJSON).
//
// formatOption is validated before the output file is touched, so an invalid
// option never creates or truncates outputFilePath.
//
// ctx is checked before each session is written. If it is cancelled, the
// function stops and returns ctx.Err(); rows already written stay in the file.
//
// It returns an error if the format option is invalid, the file cannot be
// created, the context is cancelled, a row cannot be written, or flushing or
// closing the file fails. A flush/close error is reported only when no
// earlier error occurred, so the first failure is the one the caller sees.
func ConvertSessionsToCSV(ctx context.Context, sessions []Session, formatOption int, outputFilePath string) (err error) {
	// Resolve everything that depends on formatOption before creating the file.
	headers, err := getCSVHeaders(formatOption)
	if err != nil {
		return err
	}
	writeFunc, err := getWriteFunction(formatOption)
	if err != nil {
		return err
	}

	outputFile, err := os.Create(outputFilePath)
	if err != nil {
		return fmt.Errorf("failed to create output CSV file: %w", err)
	}
	csvWriter := csv.NewWriter(outputFile)
	defer func() {
		// csv.Writer buffers rows, so write failures may only surface at flush
		// time. Report them instead of silently returning success.
		if cerr := closeCSVWriter(csvWriter, outputFile); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if err := WriteHeaders(csvWriter, headers); err != nil {
		return err
	}
	for _, session := range sessions {
		if err := checkContextCancellation(ctx); err != nil {
			return err
		}
		if err := writeFunc(csvWriter, session); err != nil {
			return err
		}
	}
	return nil
}
// getCSVHeaders maps formatOption to the CSV header row for that layout.
// Any option other than FormatOptionInline, FormatOptionPerLine or
// FormatOptionJSON yields an error.
func getCSVHeaders(formatOption int) ([]string, error) {
	switch formatOption {
	case FormatOptionInline, FormatOptionJSON:
		// Both one-row-per-session layouts use the same columns.
		return []string{"id", "topic", "memoryPrompt", "messages"}, nil
	case FormatOptionPerLine:
		return []string{"session_id", "message_id", "date", "role", "content", "memoryPrompt"}, nil
	}
	return nil, fmt.Errorf("invalid format option")
}
// getWriteFunction selects the per-session row writer for formatOption.
// The returned function writes one Session to a csv.Writer in that layout.
// An unrecognized formatOption yields a nil function and an error.
func getWriteFunction(formatOption int) (func(*csv.Writer, Session) error, error) {
	var rowWriter func(*csv.Writer, Session) error
	switch formatOption {
	case FormatOptionInline:
		rowWriter = writeInlineFormat
	case FormatOptionPerLine:
		rowWriter = writePerLineFormat
	case FormatOptionJSON:
		rowWriter = writeJSONFormat
	default:
		return nil, fmt.Errorf("invalid format option")
	}
	return rowWriter, nil
}
// writeInlineFormat writes one CSV row for session: ID, topic, memory prompt,
// and every message rendered as `[role, date] "content"` and joined with "; ".
// It returns the error from csvWriter.Write, if any.
func writeInlineFormat(csvWriter *csv.Writer, session Session) error {
	rendered := make([]string, len(session.Messages))
	for i, msg := range session.Messages {
		rendered[i] = fmt.Sprintf("[%s, %s] \"%s\"", msg.Role, msg.Date, msg.Content)
	}
	row := []string{
		session.ID,
		session.Topic,
		session.MemoryPrompt,
		strings.Join(rendered, "; "),
	}
	return csvWriter.Write(row)
}
// writePerLineFormat writes one CSV row per message of session. Each row
// carries the session ID, the message's ID, date, role and content, and the
// session's memory prompt. It stops at the first write error and returns it;
// a session without messages writes nothing.
func writePerLineFormat(csvWriter *csv.Writer, session Session) error {
	for i := range session.Messages {
		msg := session.Messages[i]
		row := []string{session.ID, msg.ID, msg.Date, msg.Role, msg.Content, session.MemoryPrompt}
		if err := csvWriter.Write(row); err != nil {
			return err
		}
	}
	return nil
}
// writeJSONFormat writes one CSV row for session: ID, topic, memory prompt,
// and the session's messages marshaled to a JSON string.
// It returns an error if marshaling fails or the row cannot be written.
func writeJSONFormat(csvWriter *csv.Writer, session Session) error {
	encoded, err := json.Marshal(session.Messages)
	if err != nil {
		return err
	}
	return csvWriter.Write([]string{
		session.ID,
		session.Topic,
		session.MemoryPrompt,
		string(encoded),
	})
}
// checkContextCancellation reports, without blocking, whether ctx is done.
// It returns ctx.Err() when ctx.Done() is closed and nil otherwise.
func checkContextCancellation(ctx context.Context) error {
	select {
	case <-ctx.Done():
		// Done: fall through to report why.
	default:
		return nil
	}
	return ctx.Err()
}
// WriteHeaders writes headers as a single record to csvWriter.
// A write failure is returned wrapped with the prefix
// "failed to write headers". The writer is not flushed here.
func WriteHeaders(csvWriter *csv.Writer, headers []string) error {
	err := csvWriter.Write(headers)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to write headers: %w", err)
}
// WriteSessionData writes one record per session to csvWriter, containing the
// session's ID, topic and memory prompt. It stops at the first write failure
// and returns it wrapped with the prefix "failed to write session data".
func WriteSessionData(csvWriter *csv.Writer, sessions []Session) error {
	for i := range sessions {
		s := sessions[i]
		err := csvWriter.Write([]string{s.ID, s.Topic, s.MemoryPrompt})
		if err != nil {
			return fmt.Errorf("failed to write session data: %w", err)
		}
	}
	return nil
}
// WriteMessageData writes one record per message, across all sessions, to
// csvWriter. Each record holds the owning session's ID, the message's ID,
// date, role and content, and the session's memory prompt. It stops at the
// first write failure and returns it wrapped with the prefix
// "failed to write message data".
func WriteMessageData(csvWriter *csv.Writer, sessions []Session) error {
	for i := range sessions {
		owner := sessions[i]
		for j := range owner.Messages {
			msg := owner.Messages[j]
			record := []string{owner.ID, msg.ID, msg.Date, msg.Role, msg.Content, owner.MemoryPrompt}
			if err := csvWriter.Write(record); err != nil {
				return fmt.Errorf("failed to write message data: %w", err)
			}
		}
	}
	return nil
}
// initializeCSVFile creates (or truncates) fileName, wraps it in a csv.Writer
// and writes headers as the first record. Nothing is flushed here; the caller
// owns both returned values and must flush and close them.
//
// On failure it returns nil for the file and the writer. If the header write
// fails, the file is closed before returning.
func initializeCSVFile(fileName string, headers []string) (*os.File, *csv.Writer, error) {
	out, createErr := os.Create(fileName)
	if createErr != nil {
		return nil, nil, fmt.Errorf("failed to create file %s: %w", fileName, createErr)
	}

	w := csv.NewWriter(out)
	headerErr := WriteHeaders(w, headers)
	if headerErr == nil {
		return out, w, nil
	}

	out.Close() // close error ignored; the header failure is the one reported
	return nil, nil, headerErr
}
// closeCSVWriter flushes csvWriter and then closes file. The file is closed
// in every case. A flush error takes precedence and is returned with the
// prefix "failed to flush data"; otherwise a close error is returned with the
// prefix "failed to close file".
func closeCSVWriter(csvWriter *csv.Writer, file *os.File) error {
	csvWriter.Flush()
	flushErr := csvWriter.Error()
	closeErr := file.Close()

	switch {
	case flushErr != nil:
		// The close result is dropped on purpose: the flush failure is reported.
		return fmt.Errorf("failed to flush data: %w", flushErr)
	case closeErr != nil:
		return fmt.Errorf("failed to close file: %w", closeErr)
	}
	return nil
}
// CreateSeparateCSVFiles writes sessions to two CSV files: sessionsFileName
// gets one row per session (id, topic, memoryPrompt) and messagesFileName gets
// one row per message (session_id, message_id, date, role, content,
// memoryPrompt).
//
// The sessions file is created and filled first. The messages file is only
// created if that succeeds.
//
// It returns an error if either file cannot be created, if writing fails, or
// if flushing or closing a file fails. The first error encountered is the one
// returned: a later flush/close failure never replaces an earlier error.
// Every file that was opened is flushed and closed before returning, whether
// or not an error occurred. Nothing is logged; errors are only returned.
func CreateSeparateCSVFiles(sessions []Session, sessionsFileName string, messagesFileName string) (err error) {
	sessionsFile, sessionsWriter, err := initializeCSVFile(sessionsFileName, []string{"id", "topic", "memoryPrompt"})
	if err != nil {
		return err
	}
	defer func() {
		// Keep the original failure; only surface the close error if nothing
		// else went wrong.
		if cerr := closeCSVWriter(sessionsWriter, sessionsFile); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if err = WriteSessionData(sessionsWriter, sessions); err != nil {
		return err
	}

	messagesFile, messagesWriter, err := initializeCSVFile(messagesFileName, []string{"session_id", "message_id", "date", "role", "content", "memoryPrompt"})
	if err != nil {
		return err
	}
	defer func() {
		if cerr := closeCSVWriter(messagesWriter, messagesFile); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if err = WriteMessageData(messagesWriter, sessions); err != nil {
		return err
	}
	return nil
}
// ExtractToDataset marshals sessions into an indented JSON document of the
// form {"dataset": [...]} and returns it as a string.
//
// It returns an empty string and the marshaling error if encoding fails.
func ExtractToDataset(sessions []Session) (string, error) {
	wrapped := map[string][]Session{"dataset": sessions}
	encoded, err := json.MarshalIndent(wrapped, "", " ")
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}