Skip to content
9 changes: 9 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,26 @@ type ChatMessageImageURL struct {
Detail ImageURLDetail `json:"detail,omitempty"`
}

// ChatMessageFile is a placeholder for file parts in chat messages.
type ChatMessageFile struct {
FileID string `json:"file_id,omitempty"`
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"` // Base64 encoded file data
}

type ChatMessagePartType string

const (
ChatMessagePartTypeText ChatMessagePartType = "text"
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
ChatMessagePartTypeFile ChatMessagePartType = "file"
)

type ChatMessagePart struct {
Type ChatMessagePartType `json:"type,omitempty"`
Text string `json:"text,omitempty"`
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
File *ChatMessageFile `json:"file,omitempty"`
Comment on lines +84 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming these all checkout to what OpenAI expects, LGTM. Alas, I'm failing to find where this API is documented in their chat completion API documentation.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @matthewmorgan . I wasn't using their docs site correctly.

}

type ChatCompletionMessage struct {
Expand Down
114 changes: 112 additions & 2 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,14 @@ func TestMultipartChatCompletions(t *testing.T) {
Detail: openai.ImageURLDetailLow,
},
},
{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-123",
FileName: "test.txt",
FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content"
},
},
},
},
},
Expand All @@ -807,7 +815,8 @@ func TestMultipartChatCompletions(t *testing.T) {
func TestMultipartChatMessageSerialization(t *testing.T) {
jsonText := `[{"role":"system","content":"system-message"},` +
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` +
`{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]`

var msgs []openai.ChatCompletionMessage
err := json.Unmarshal([]byte(jsonText), &msgs)
Expand All @@ -820,7 +829,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
t.Errorf("invalid user message: %v", msgs[0])
}
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 {
t.Errorf("invalid user message")
}
parts := msgs[1].MultiContent
Expand All @@ -830,6 +839,10 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
t.Errorf("invalid image_url part")
}
if parts[2].Type != "file" || parts[2].File.FileID != "file-123" ||
parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" {
t.Errorf("invalid file part: %v", parts[2])
}

s, err := json.Marshal(msgs)
if err != nil {
Expand Down Expand Up @@ -876,6 +889,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
}
}

func TestChatMessageFile(t *testing.T) {
// Test file part with FileID
filePart := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-abc123",
},
}

// Test serialization
data, err := json.Marshal(filePart)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

expected := `{"type":"file","file":{"file_id":"file-abc123"}}`
result := strings.ReplaceAll(string(data), " ", "")
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}

// Test deserialization
var parsedPart openai.ChatMessagePart
err = json.Unmarshal(data, &parsedPart)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

if parsedPart.Type != openai.ChatMessagePartTypeFile {
t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type)
}
if parsedPart.File == nil {
t.Fatal("Expected File to be non-nil")
}
if parsedPart.File.FileID != "file-abc123" {
t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID)
}

// Test file part with all fields
filePartComplete := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-xyz789",
FileName: "document.pdf",
FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n"
},
}

data, err = json.Marshal(filePartComplete)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}`
result = strings.ReplaceAll(string(data), " ", "")
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}

// Test deserialization of complete file part
var parsedCompleteFile openai.ChatMessagePart
err = json.Unmarshal(data, &parsedCompleteFile)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

if parsedCompleteFile.File.FileID != "file-xyz789" {
t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID)
}
if parsedCompleteFile.File.FileName != "document.pdf" {
t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName)
}
if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" {
t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData)
}
}

func TestChatMessagePartTypeConstants(t *testing.T) {
// Test that the new file constant is properly defined
if openai.ChatMessagePartTypeFile != "file" {
t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile)
}

// Test all part type constants
expectedTypes := map[openai.ChatMessagePartType]string{
openai.ChatMessagePartTypeText: "text",
openai.ChatMessagePartTypeImageURL: "image_url",
openai.ChatMessagePartTypeFile: "file",
}

for constant, expected := range expectedTypes {
if string(constant) != expected {
t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant))
}
}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down
7 changes: 6 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package openai
import (
"net/http"
"regexp"
"strings"
)

const (
Expand Down Expand Up @@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
APIType: APITypeAzure,
APIVersion: "2023-05-15",
AzureModelMapperFunc: func(model string) string {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
// only 3.5 models have the "." stripped in their names
if strings.Contains(model, "3.5") {
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
}
return strings.ReplaceAll(model, ":", "")
Comment on lines +74 to +78
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this is the latest from upstream main. No worries from me about this.

},

HTTPClient: &http.Client{},
Expand Down
4 changes: 4 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
Model: "gpt-3.5-turbo-0301",
Expect: "gpt-35-turbo-0301",
},
{
Model: "gpt-4.1",
Expect: "gpt-4.1",
},
{
Model: "text-embedding-ada-002",
Expect: "text-embedding-ada-002",
Expand Down
Loading