Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions frontend/src/services/aiService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ export const aiService = {
provider: config.provider,
endpoint: config.endpoint,
model: config.model,
api_key: config.apiKey,
max_tokens: config.maxTokens,
timeout: formatTimeout(config.timeout)
}
Expand All @@ -67,7 +66,7 @@ export const aiService = {
message: string
provider: string
error?: string
}>('test_connection', payload)
}>('test_connection', payload, { apiKey: config.apiKey })

return {
success: toBoolean(result.success),
Expand Down Expand Up @@ -130,12 +129,11 @@ export const aiService = {
include_explanation: request.includeExplanation,
provider: request.provider,
endpoint: request.endpoint,
api_key: request.apiKey,
max_tokens: request.maxTokens,
timeout: formatTimeout(request.timeout),
database_type: request.databaseDialect
})
})
}, { apiKey: request.apiKey })

console.log('📥 [aiService] Received backend result', {
hasContent: !!result.content,
Expand Down Expand Up @@ -214,12 +212,11 @@ export const aiService = {
provider: config.provider,
endpoint: config.endpoint,
model: config.model,
api_key: config.apiKey,
max_tokens: config.maxTokens,
timeout: formatTimeout(config.timeout),
database_type: config.databaseDialect
}
})
}, { apiKey: config.apiKey })
}
}

Expand All @@ -239,7 +236,7 @@ function formatTimeout(timeout: number | undefined): string {
* is designed for database queries and transforms the request format.
* The AI plugin expects: {type: 'ai', key: 'operation', sql: 'params_json'}
*/
async function callAPI<T>(key: string, data: any): Promise<T> {
async function callAPI<T>(key: string, data: any, options: { apiKey?: string } = {}): Promise<T> {
const requestBody = {
type: 'ai',
key,
Expand All @@ -254,12 +251,18 @@ async function callAPI<T>(key: string, data: any): Promise<T> {
})

try {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
'X-Store-Name': API_STORE
}

if (options.apiKey) {
headers['X-Auth'] = `Bearer ${options.apiKey}`
}

const response = await fetch(API_BASE, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Store-Name': API_STORE
},
headers,
body: JSON.stringify(requestBody)
})

Expand Down
33 changes: 33 additions & 0 deletions frontend/tests/services/aiService.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ describe('aiService', () => {
})
const payload = JSON.parse(body.sql)
expect(payload.config).toContain('timeout')
expect(payload.config).not.toContain('api_key')

return createFetchResponse({
data: [
Expand Down Expand Up @@ -68,6 +69,38 @@ describe('aiService', () => {
expect(response.meta).toEqual({ confidence: 0.9, model: 'demo' })
})

it('sends API key through authorization header only', async () => {
const apiKey = 'sk-secure'
fetchMock.mockImplementationOnce(async (_url: FetchArgs[0], options: FetchArgs[1]) => {
const headers = options?.headers as Record<string, string>
expect(headers['X-Auth']).toBe(`Bearer ${apiKey}`)

const body = JSON.parse(String(options?.body))
const payload = JSON.parse(body.sql)
expect(payload.config).not.toContain('api_key')

return createFetchResponse({
data: [
{ key: 'success', value: true },
{ key: 'content', value: 'sql:SELECT 1;' },
{ key: 'meta', value: '{}' }
]
})
})

await aiService.generateSQL({
provider: 'openai',
endpoint: 'https://api.openai.com',
apiKey,
model: 'gpt-5',
prompt: 'SELECT 1;',
timeout: 30,
maxTokens: 256,
includeExplanation: false,
databaseDialect: 'postgresql'
})
})

it('parses health check response when backend returns boolean healthy flag', async () => {
fetchMock.mockResolvedValueOnce(
createFetchResponse({
Expand Down
28 changes: 24 additions & 4 deletions pkg/ai/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type GenerateSQLRequest struct {
NaturalLanguage string `json:"natural_language"`
DatabaseType string `json:"database_type"`
Context map[string]string `json:"context,omitempty"`
RuntimeAPIKey string `json:"-"`
}

// GenerateSQLResponse represents an AI SQL generation response
Expand Down Expand Up @@ -113,7 +114,11 @@ func NewEngine(cfg config.AIConfig) (Engine, error) {

engine, err := newEngineFromManager(manager, cfg)
if err != nil {
_ = manager.Close()
if closeErr := manager.Close(); closeErr != nil {
logging.Logger.Warn("Failed to close AI manager after initialization error",
"provider", cfg.DefaultService,
"error", closeErr)
}
return nil, err
}
return engine, nil
Expand Down Expand Up @@ -163,7 +168,11 @@ func NewEngineWithManager(manager *Manager, cfg config.AIConfig) (Engine, error)

engine, err := newEngineFromManager(manager, cfg)
if err != nil {
_ = manager.Close()
if closeErr := manager.Close(); closeErr != nil {
logging.Logger.Warn("Failed to close AI manager after initialization error",
"provider", cfg.DefaultService,
"error", closeErr)
}
return nil, err
}
return engine, nil
Expand Down Expand Up @@ -206,6 +215,10 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G
MaxTokens: defaultMaxTokens,
}

if req.RuntimeAPIKey != "" {
options.APIKey = req.RuntimeAPIKey
}

// Add context if provided and extract preferred_model and runtime config
var runtimeConfig map[string]interface{}
if len(req.Context) > 0 {
Expand Down Expand Up @@ -236,7 +249,11 @@ func (e *aiEngine) GenerateSQL(ctx context.Context, req *GenerateSQLRequest) (*G
options.Provider = provider
}
if apiKey, ok := runtimeConfig["api_key"].(string); ok && apiKey != "" {
options.APIKey = apiKey
if options.APIKey == "" {
options.APIKey = apiKey
} else if options.APIKey != apiKey {
logging.Logger.Warn("Runtime config API key differs from secured metadata; using secured value")
}
}
if endpoint, ok := runtimeConfig["endpoint"].(string); ok && endpoint != "" {
options.Endpoint = endpoint
Expand Down Expand Up @@ -313,6 +330,9 @@ func (e *aiEngine) Close() {
e.generator.Close()
}
if e.manager != nil {
_ = e.manager.Close()
if err := e.manager.Close(); err != nil {
logging.Logger.Warn("Failed to close AI manager during engine shutdown",
"error", err)
}
}
}
83 changes: 68 additions & 15 deletions pkg/ai/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package ai

import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
Expand All @@ -38,10 +39,15 @@ type SQLGenerator struct {
sqlDialects map[string]SQLDialect
config config.AIConfig
capabilities *SQLCapabilities
runtimeClients map[string]interfaces.AIClient
runtimeClients map[string]*runtimeClientEntry
runtimeMu sync.RWMutex
}

type runtimeClientEntry struct {
client interfaces.AIClient
apiKeyFingerprint []byte
}

// Table represents a database table structure
type Table struct {
Name string `json:"name"`
Expand Down Expand Up @@ -142,7 +148,7 @@ func NewSQLGenerator(aiClient interfaces.AIClient, config config.AIConfig) (*SQL
aiClient: aiClient,
config: config,
sqlDialects: make(map[string]SQLDialect),
runtimeClients: make(map[string]interfaces.AIClient),
runtimeClients: make(map[string]*runtimeClientEntry),
}

// Initialize SQL dialects
Expand Down Expand Up @@ -630,23 +636,38 @@ func runtimeClientKey(options *GenerateOptions) string {
hasher.Write([]byte(options.Endpoint))
hasher.Write([]byte("|"))
hasher.Write([]byte(options.Model))
hasher.Write([]byte("|"))
hasher.Write([]byte(options.APIKey))
return hex.EncodeToString(hasher.Sum(nil))
}

func runtimeAPIKeyFingerprint(apiKey string) []byte {
if apiKey == "" {
return nil
}
sum := sha256.Sum256([]byte(apiKey))
fingerprint := make([]byte, len(sum))
copy(fingerprint, sum[:])
return fingerprint
}

func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (interfaces.AIClient, bool, error) {
key := runtimeClientKey(options)
fingerprint := runtimeAPIKeyFingerprint(options.APIKey)

g.runtimeMu.RLock()
if client, ok := g.runtimeClients[key]; ok {
g.runtimeMu.RUnlock()
return client, true, nil
if entry, ok := g.runtimeClients[key]; ok {
if bytes.Equal(entry.apiKeyFingerprint, fingerprint) {
client := entry.client
g.runtimeMu.RUnlock()
return client, true, nil
}
}
g.runtimeMu.RUnlock()

runtimeConfig := map[string]any{
"provider": options.Provider,
"api_key": options.APIKey,
}
if options.APIKey != "" {
runtimeConfig["api_key"] = options.APIKey
}
if options.Endpoint != "" {
runtimeConfig["base_url"] = options.Endpoint
Expand All @@ -664,23 +685,55 @@ func (g *SQLGenerator) getOrCreateRuntimeClient(options *GenerateOptions) (inter
}

g.runtimeMu.Lock()
if existing, ok := g.runtimeClients[key]; ok {
g.runtimeMu.Unlock()
_ = client.Close()
return existing, true, nil
var (
existingEntry *runtimeClientEntry
exists bool
)
if existingEntry, exists = g.runtimeClients[key]; exists {
if bytes.Equal(existingEntry.apiKeyFingerprint, fingerprint) {
g.runtimeMu.Unlock()
if err := client.Close(); err != nil {
logging.Logger.Warn("Failed to close redundant runtime client",
"provider", options.Provider,
"endpoint", options.Endpoint,
"error", err)
}
return existingEntry.client, true, nil
}
}

g.runtimeClients[key] = &runtimeClientEntry{
client: client,
apiKeyFingerprint: fingerprint,
}
g.runtimeClients[key] = client
g.runtimeMu.Unlock()

if exists && existingEntry != nil && existingEntry.client != nil {
if err := existingEntry.client.Close(); err != nil {
logging.Logger.Warn("Failed to close stale runtime client",
"provider", options.Provider,
"endpoint", options.Endpoint,
"error", err)
}
}

return client, false, nil
}

// Close releases all cached runtime clients held by the generator.
func (g *SQLGenerator) Close() {
g.runtimeMu.Lock()
defer g.runtimeMu.Unlock()
for key, client := range g.runtimeClients {
_ = client.Close()
for key, entry := range g.runtimeClients {
if entry == nil || entry.client == nil {
delete(g.runtimeClients, key)
continue
}
if err := entry.client.Close(); err != nil {
logging.Logger.Warn("Failed to close runtime client during generator shutdown",
"key", key,
"error", err)
}
delete(g.runtimeClients, key)
}
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/ai/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ package ai
import (
"testing"

"github.com/linuxsuren/atest-ext-ai/pkg/interfaces"
"github.com/stretchr/testify/require"
)

func TestRuntimeClientReuseAndClose(t *testing.T) {
generator := &SQLGenerator{
runtimeClients: make(map[string]interfaces.AIClient),
runtimeClients: make(map[string]*runtimeClientEntry),
}

options := &GenerateOptions{
Expand Down
Loading
Loading