Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
415 changes: 415 additions & 0 deletions BATCH_JOB_STATUS.md

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions configs/config.dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,9 @@ batch_processing:
# Async batch job poller configuration
poller_interval: 30s # Poll Gemini batch jobs every 30 seconds
max_concurrent_polls: 5 # Max concurrent batch job status checks
# Token counting configuration
token_counting:
enabled: true # Enable token counting
mode: "all" # Mode: "all", "sample", or "on_demand"
sample_percent: 10 # Percentage of chunks to sample (for "sample" mode)
max_tokens_per_chunk: 8192 # Maximum tokens per chunk (Gemini embedding model limit)
9 changes: 8 additions & 1 deletion configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ batch_processing:
submission_max_backoff: 30m # Maximum backoff duration
max_submission_attempts: 10 # Max retry attempts per batch submission

# Token counting configuration
token_counting:
enabled: true # Enable token counting
mode: "all" # Mode: "all", "sample", or "on_demand"
sample_percent: 10 # Percentage of chunks to sample (for "sample" mode)
max_tokens_per_chunk: 8192 # Maximum tokens per chunk (Gemini embedding model limit)

log:
level: info
format: json
format: json
211 changes: 204 additions & 7 deletions internal/adapter/outbound/gemini/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,8 @@ func (c *Client) GenerateEmbedding(
startTime := time.Now()

// Input validation
if strings.TrimSpace(text) == "" {
return nil, &outbound.EmbeddingError{
Code: "empty_text",
Type: "validation",
Message: "text content cannot be empty",
Retryable: false,
}
if err := validateNonEmptyText(text); err != nil {
return nil, err
}

// Determine model to use
Expand Down Expand Up @@ -757,6 +752,20 @@ func max(a, b int) int {
return b
}

// validateNonEmptyText validates that text is not empty or whitespace-only.
// Returns an EmbeddingError if validation fails, nil otherwise.
func validateNonEmptyText(text string) error {
if strings.TrimSpace(text) == "" {
return &outbound.EmbeddingError{
Code: "invalid_input",
Type: "validation",
Message: "text cannot be empty",
Retryable: false,
}
}
return nil
}

// getGenaiClient returns the cached genai.Client instance for making API requests.
//
// This method uses a read lock (RLock) to allow concurrent access from multiple
Expand Down Expand Up @@ -839,6 +848,194 @@ func (c *Client) convertToFloat64Slice(values []float32) []float64 {
return result
}

// CountTokens counts the exact number of tokens in the given text using the Gemini API.
func (c *Client) CountTokens(ctx context.Context, text string, model string) (*outbound.TokenCountResult, error) {
startTime := time.Now()

// Input validation
if err := validateNonEmptyText(text); err != nil {
return nil, err
}

// Check context cancellation early
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

// Use default model if not specified
if model == "" {
model = c.config.Model
}

// Log request
slogger.Info(ctx, "Token counting request initiated", slogger.Fields{
"text_length": len(text),
"model": model,
"word_count": countWords(text),
})

// Get the cached GenAI client
genaiClient := c.getGenaiClient()

// Create content for token counting
content := genai.NewContentFromText(text, genai.RoleUser)

// Call the CountTokens API
result, err := genaiClient.Models.CountTokens(ctx, model, []*genai.Content{content}, nil)
if err != nil {
sdkErr := c.convertSDKError(err)
duration := time.Since(startTime)
slogger.Error(ctx, "Token counting request failed", slogger.Fields{
"error_type": sdkErr.Type,
"error_code": sdkErr.Code,
"message": sdkErr.Message,
"duration_ms": duration.Milliseconds(),
"model": model,
})
return nil, fmt.Errorf("failed to count tokens: %w", sdkErr)
}

// Build result
tokenResult := &outbound.TokenCountResult{
TotalTokens: int(result.TotalTokens),
Model: model,
CachedAt: nil,
}

// Log successful response
duration := time.Since(startTime)
slogger.Info(ctx, "Token counting request completed", slogger.Fields{
"token_count": tokenResult.TotalTokens,
"model": model,
"duration_ms": duration.Milliseconds(),
"text_length": len(text),
})

return tokenResult, nil
}

// CountTokensBatch counts tokens for multiple texts in a single batch request.
func (c *Client) CountTokensBatch(
ctx context.Context,
texts []string,
model string,
) ([]*outbound.TokenCountResult, error) {
startTime := time.Now()

// Input validation
if len(texts) == 0 {
return nil, &outbound.EmbeddingError{
Code: "invalid_input",
Type: "validation",
Message: "texts slice cannot be empty",
Retryable: false,
}
}

// Check context cancellation early
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

// Use default model if not specified
effectiveModel := model
if effectiveModel == "" {
effectiveModel = c.config.Model
}

// Log batch request
slogger.Info(ctx, "Batch token counting request initiated", slogger.Fields{
"batch_size": len(texts),
"model": effectiveModel,
})

// Process each text sequentially
results := make([]*outbound.TokenCountResult, 0, len(texts))
for i, text := range texts {
result, err := c.CountTokens(ctx, text, model)
if err != nil {
duration := time.Since(startTime)
slogger.Error(ctx, "Batch token counting failed", slogger.Fields{
"failed_at_index": i,
"batch_size": len(texts),
"duration_ms": duration.Milliseconds(),
"model": effectiveModel,
})
return nil, fmt.Errorf("failed to count tokens for text at index %d: %w", i, err)
}
results = append(results, result)
}

// Log successful batch completion
duration := time.Since(startTime)
totalTokens := 0
for _, result := range results {
totalTokens += result.TotalTokens
}
slogger.Info(ctx, "Batch token counting completed", slogger.Fields{
"batch_size": len(texts),
"total_tokens": totalTokens,
"duration_ms": duration.Milliseconds(),
"model": effectiveModel,
})

return results, nil
}

// CountTokensWithCallback counts tokens for each chunk and invokes the callback after each result.
// This enables progressive processing (e.g., saving chunks in batches) during token counting.
func (c *Client) CountTokensWithCallback(
ctx context.Context,
chunks []outbound.CodeChunk,
model string,
callback outbound.TokenCountCallback,
) error {
if len(chunks) == 0 {
return nil
}

// Check context cancellation early
select {
case <-ctx.Done():
return ctx.Err()
default:
}

effectiveModel := model
if effectiveModel == "" {
effectiveModel = c.config.Model
}

slogger.Info(ctx, "Progressive token counting started", slogger.Fields{
"chunk_count": len(chunks),
"model": effectiveModel,
})

for i := range chunks {
// Check for cancellation between chunks
select {
case <-ctx.Done():
return ctx.Err()
default:
}

result, err := c.CountTokens(ctx, chunks[i].Content, model)
if err != nil {
return fmt.Errorf("token counting failed at index %d: %w", i, err)
}

if err := callback(i, &chunks[i], result); err != nil {
return fmt.Errorf("callback failed at index %d: %w", i, err)
}
}

return nil
}

// Request/Response structures for JSON serialization/deserialization

// EmbeddingRequest represents the JSON structure for Gemini embedding requests.
Expand Down
Loading