From d3b692b780cd3e991f28cc69feeef45d33a13f85 Mon Sep 17 00:00:00 2001 From: Anthony Bible Date: Thu, 27 Nov 2025 14:14:25 -0700 Subject: [PATCH 1/5] feat: add token counting infrastructure with Gemini API integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add token_count column to code_chunks table - Implement CountTokens and CountTokensBatch in Gemini client - Add TokenCache for efficient token count caching - Define port layer interfaces for token counting 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- BATCH_JOB_STATUS.md | 415 ++++++++++++ internal/adapter/outbound/gemini/client.go | 161 ++++- .../adapter/outbound/gemini/token_cache.go | 153 +++++ .../outbound/gemini/token_count_test.go | 619 +++++++++++++++++ internal/port/outbound/embedding_service.go | 15 + internal/port/outbound/token_count_test.go | 629 ++++++++++++++++++ .../000010_add_token_count_to_chunks.down.sql | 10 + .../000010_add_token_count_to_chunks.up.sql | 21 + 8 files changed, 2016 insertions(+), 7 deletions(-) create mode 100644 BATCH_JOB_STATUS.md create mode 100644 internal/adapter/outbound/gemini/token_cache.go create mode 100644 internal/adapter/outbound/gemini/token_count_test.go create mode 100644 internal/port/outbound/token_count_test.go create mode 100644 migrations/000010_add_token_count_to_chunks.down.sql create mode 100644 migrations/000010_add_token_count_to_chunks.up.sql diff --git a/BATCH_JOB_STATUS.md b/BATCH_JOB_STATUS.md new file mode 100644 index 0000000..11bcc19 --- /dev/null +++ b/BATCH_JOB_STATUS.md @@ -0,0 +1,415 @@ +# Batch Job Status Reference + +This document explains how to check the status of both **repository indexing jobs** (database) and **batch embedding jobs** (Google Gemini API). 
+ +## Quick Start + +```bash +# Check database indexing jobs (repository processing) +./scripts/check_jobs.sh db + +# List all batch embedding jobs (requires API key) +./scripts/check_jobs.sh batch-list + +# Check a specific batch job (pass the full job ID) +./scripts/check_jobs.sh batch <job-id> +``` + +## Understanding Job Types + +### 1. Repository Indexing Jobs (Database) + +**What they are**: Jobs that process repositories (git clone, parse, chunk files) +**Storage**: PostgreSQL `codechunking.indexing_jobs` table +**Purpose**: Track the overall repository indexing workflow + +**Status values**: +- `pending` - Job queued +- `running` - Currently processing +- `completed` - Successfully finished +- `failed` - Encountered an error +- `cancelled` - Manually cancelled + +**Fields tracked**: +- Repository information +- Files processed count +- Chunks created count +- Start/completion timestamps +- Error messages (if failed) + +### 2. Batch Embedding Jobs (External API) + +**What they are**: Asynchronous embedding generation jobs via Google Gemini Batches API +**Storage**: Google Cloud (accessed via API) +**Purpose**: Generate embeddings for large numbers of code chunks efficiently + +**Status values**: +- `PENDING` - Job queued in Google's system +- `PROCESSING` - Google is generating embeddings +- `COMPLETED` - All embeddings generated +- `FAILED` - Job failed +- `CANCELLED` - Job was cancelled + +**Fields tracked**: +- Total count of items +- Processed/success/error counts +- Progress percentage +- Processing rate +- Output file URIs + +## Configuration Requirements for Production Batch Processing + +### Prerequisites for Using Google Gemini Batches API + +To use production batch processing (not test mode), you must configure: + +1. **API Key**: Set the `CODECHUNK_GEMINI_API_KEY` environment variable +2. **Batch Directories**: Configure input/output directories in your config file +3. 
**Disable Test Mode**: Set `use_test_embeddings: false` in batch processing config + +**Configuration Example** (`configs/config.dev.yaml`): +```yaml +gemini: + api_key: ${CODECHUNK_GEMINI_API_KEY} # Set via environment + batch: + enabled: true + input_dir: /tmp/batch_embeddings/input # Required for production + output_dir: /tmp/batch_embeddings/output # Required for production + poll_interval: 5s + max_wait_time: 30m + +batch_processing: + enabled: true + threshold_chunks: 10 # Repositories with >10 chunks use batch processing + use_test_embeddings: false # IMPORTANT: Set to false for production batch API + fallback_to_sequential: true +``` + +### Verifying Batch Processing is Active + +When batch processing is correctly configured, worker logs will show: + +``` +"Processing batch embedding results (PRODUCTION MODE)" +"chunk_count": 8269 +"Using batch embeddings API for production" +``` + +If you see this message instead, batch processing is **NOT** active: +``` +"Production batch processing not implemented - falling back to sequential" +``` + +This fallback occurs when: +- `use_test_embeddings: true` (test mode enabled) +- Missing `input_dir` or `output_dir` configuration +- Batch processing disabled (`enabled: false`) + +## Detailed Usage + +### Check Database Jobs + +```bash +# Using the wrapper script (recommended) +./scripts/check_jobs.sh db + +# Or use SQL directly +psql -U dev -d codechunking -f check_indexing_jobs.sql + +# Or via Docker +docker exec codechunking-postgres psql -U dev -d codechunking -c \ + "SELECT * FROM codechunking.indexing_jobs WHERE deleted_at IS NULL ORDER BY created_at DESC LIMIT 10;" +``` + +**Example output**: +``` + job_id | status | repository_name | files_processed | chunks_created +--------------------------------------+-----------+-----------------+-----------------+--------------- + a1b2c3d4-e5f6-7890-abcd-ef1234567890 | completed | my-repo | 150 | 3421 + b2c3d4e5-f6g7-8901-bcde-fg2345678901 | running | another-repo | 45 | 987 
+``` + +### Check Batch Embedding Jobs + +**Prerequisites**: +```bash +export CODECHUNK_GEMINI_API_KEY=your-api-key-here +``` + +**List all batch jobs**: +```bash +./scripts/check_jobs.sh batch-list + +# Or filter by state +go run scripts/check_batch_jobs.go -list -state COMPLETED +``` + +**Check specific job**: +```bash +./scripts/check_jobs.sh batch projects/PROJECT_ID/locations/us-central1/batchJobs/12345 + +# Or using Go directly +go run scripts/check_batch_jobs.go -job-id "projects/PROJECT_ID/locations/us-central1/batchJobs/12345" +``` + +**Example output**: +``` +Job Details: +──────────────────────────────────────── +Job ID: projects/.../batchJobs/12345 +State: COMPLETED +Model: gemini-embedding-001 +Total Count: 5000 +Processed: 5000 +Success: 4998 +Errors: 2 +Created At: 2025-11-22T10:30:00Z +Updated At: 2025-11-22T10:45:23Z +Completed At: 2025-11-22T10:45:23Z +Duration: 15m23s + +Progress: + Percent: 100.00% + Remaining: 0 items + Rate: 5.45 items/sec + +Output File: /tmp/batch_embeddings/output/batch_output_20251122_104523.jsonl + +✓ Job completed successfully! +``` + +## SQL Queries + +### Common Database Queries + +**All active jobs**: +```sql +SELECT ij.id, ij.status, r.name, ij.files_processed, ij.chunks_created +FROM codechunking.indexing_jobs ij +LEFT JOIN codechunking.repositories r ON r.id = ij.repository_id +WHERE ij.deleted_at IS NULL +ORDER BY ij.created_at DESC; +``` + +**Failed jobs with errors**: +```sql +SELECT ij.id, r.name, ij.error_message, ij.started_at, ij.completed_at +FROM codechunking.indexing_jobs ij +JOIN codechunking.repositories r ON r.id = ij.repository_id +WHERE ij.status = 'failed' AND ij.deleted_at IS NULL +ORDER BY ij.created_at DESC; +``` + +**Jobs by status summary**: +```sql +SELECT status, COUNT(*) as count, + SUM(files_processed) as total_files, + SUM(chunks_created) as total_chunks +FROM codechunking.indexing_jobs +WHERE deleted_at IS NULL +GROUP BY status; +``` + +## Important Notes + +### Batch Embedding Jobs + +1. 
**Not stored in database**: Batch jobs are managed entirely by Google's API +2. **Require API key**: Must set `CODECHUNK_GEMINI_API_KEY` to check status +3. **Job ID format**: Full path like `projects/PROJECT_ID/locations/REGION/batchJobs/JOB_ID` +4. **Transient storage**: Input/output files are in `/tmp/batch_embeddings/` + +### Repository Indexing Jobs + +1. **Stored in PostgreSQL**: Persisted in the database +2. **No API key needed**: Direct database access +3. **Linked to repositories**: Foreign key to repositories table +4. **Soft deletions**: Check `deleted_at IS NULL` for active jobs + +## Troubleshooting + +### Database jobs not showing up + +```bash +# Ensure database is running +make dev + +# Check if migrations ran +make migrate-up + +# Verify table exists +docker exec codechunking-postgres psql -U dev -d codechunking -c "\dt codechunking.*" +``` + +### "Falling back to sequential" - Batch processing not working + +**Symptom**: Worker logs show "Production batch processing not implemented - falling back to sequential" + +**Causes and Solutions**: + +1. **Test mode is enabled** (most common): + ```yaml + # In configs/config.dev.yaml + batch_processing: + use_test_embeddings: true # ← Change this to false + ``` + **Fix**: Set `use_test_embeddings: false` for production batch processing + +2. **Missing batch directories**: + ```yaml + # In configs/config.dev.yaml or configs/config.yaml + gemini: + batch: + enabled: true + # Missing: input_dir and output_dir! + ``` + **Fix**: Add directory configuration: + ```yaml + gemini: + batch: + enabled: true + input_dir: /tmp/batch_embeddings/input + output_dir: /tmp/batch_embeddings/output + ``` + +3. **Batch processing disabled**: + ```yaml + gemini: + batch: + enabled: false # ← Should be true + ``` + **Fix**: Set `enabled: true` + +4. 
**Chunk count below threshold**: + - Default threshold is 10 chunks + - Repositories with ≤10 chunks use sequential processing automatically + - Check `batch_processing.threshold_chunks` in config + +**Verification**: After fixing, restart the worker and look for: +``` +"Processing batch embedding results (PRODUCTION MODE)" +``` + +### Batch jobs failing to list + +```bash +# Check API key is set +echo $CODECHUNK_GEMINI_API_KEY + +# Verify network connectivity and that the API key is accepted +# (API keys are sent in the x-goog-api-key header, not as an OAuth Bearer token) +curl -H "x-goog-api-key: $CODECHUNK_GEMINI_API_KEY" \ + https://generativelanguage.googleapis.com/v1beta/models + +# Check batch processing is enabled in config +grep -A 5 "batch:" configs/config.dev.yaml +``` + +### Permission errors + +```bash +# Ensure scripts are executable +chmod +x scripts/check_jobs.sh + +# Check file permissions on batch directories +ls -la /tmp/batch_embeddings/ +``` + +### Google API File Download Errors (403 PERMISSION_DENIED) + +**Symptom**: Worker logs show: +``` +"Failed to get file metadata from Google Files API" +"Error 403, Message: You do not have permission to access the File..." +"error_detail": "This error typically occurs when Google's Batch API returns file IDs that exceed the Files API 40-character limit" +``` + +**Root Cause**: Google's Batch API returns file IDs that are 42 characters long (e.g., `batch-pwccfe96og36g8qngof6db1dsnywrs1hxhd3`), which exceeds the Files API's documented 40-character limit. This is an inconsistency in Google's API design. + +**How the System Handles This**: +1. The code attempts to download results using the full file ID from Google's Files API +2. If Files.Get() fails with a permission error (due to the ID length issue), the system logs a detailed error +3. 
The system then tries to fall back to local file paths if the file was already downloaded + +**Verification**: Check your worker logs for: +``` +"Attempting download with full file ID" +"file_id": "batch-pwccfe96og36g8qngof6db1dsnywrs1hxhd3" +"id_length": 42 +``` + +If you see `id_length > 40`, this is the Google API limitation issue. + +**Resolution**: +- **Automatic Fix Implemented**: The system now automatically tries multiple strategies: + 1. First attempt: Try full file ID with Files API + 2. Second attempt: If 40-char error detected, remove "batch-" prefix and retry + 3. Third attempt: Fall back to local file handling if available +- **Expected**: Google should fix this API inconsistency on their end +- **Manual Workaround** (if automatic fix fails): + 1. Ensure output directory exists: `/tmp/batch_embeddings/output/` + 2. Check if files are being created in the output directory manually by Google + 3. Verify your API key has proper permissions for both Batch API and Files API + +**What to Look For in Logs**: +When the fix is working, you'll see: +``` +"File ID exceeds 40-character limit, trying alternative strategies" +"Attempting download without 'batch-' prefix" +"Successfully retrieved File object with full ID" +``` + +If all strategies fail, you'll see: +``` +"All Files API strategies failed, attempting direct batch result download" +``` + +**Related Files**: +- Implementation: `internal/adapter/outbound/gemini/batch_embedding_client.go:671-760` +- Error handling includes multi-strategy retry logic with detailed logging at each step + +## Integration Points + +### Where Jobs Are Created + +**Repository Indexing Jobs**: +- Created when: Repository added via API (`POST /repositories`) +- Processed by: Worker service (`codechunking worker`) +- Queue: NATS JetStream (`codechunk.indexing.jobs`) + +**Batch Embedding Jobs**: +- Created when: Worker processes repositories with batch embeddings enabled +- Config: `configs/config.dev.yaml` → `gemini.batch.enabled: 
true` +- Triggered by: `BatchEmbeddingClient.CreateBatchEmbeddingJob()` + +### Configuration + +**Enable batch processing** (`configs/config.dev.yaml`): +```yaml +gemini: + batch: + enabled: true + poll_interval: 5s + max_wait_time: 30m +``` + +**Batch directories**: +```yaml +# Defaults if not specified: +input_dir: /tmp/batch_embeddings/input +output_dir: /tmp/batch_embeddings/output +``` + +## Files Reference + +| File | Purpose | +|------|---------| +| `scripts/check_jobs.sh` | Main wrapper script for checking jobs | +| `scripts/check_batch_jobs.go` | Go program to query Google Gemini Batches API | +| `check_indexing_jobs.sql` | SQL queries for database jobs | +| `internal/adapter/outbound/gemini/batch_embedding_client.go` | Batch job client implementation | + +## See Also + +- [Google Gemini Batches API Documentation](https://ai.google.dev/gemini-api/docs/batch) +- Project CLAUDE.md for development commands +- `make help` for available Makefile targets diff --git a/internal/adapter/outbound/gemini/client.go b/internal/adapter/outbound/gemini/client.go index 0613bcb..1d1de8c 100644 --- a/internal/adapter/outbound/gemini/client.go +++ b/internal/adapter/outbound/gemini/client.go @@ -306,13 +306,8 @@ func (c *Client) GenerateEmbedding( startTime := time.Now() // Input validation - if strings.TrimSpace(text) == "" { - return nil, &outbound.EmbeddingError{ - Code: "empty_text", - Type: "validation", - Message: "text content cannot be empty", - Retryable: false, - } + if err := validateNonEmptyText(text); err != nil { + return nil, err } // Determine model to use @@ -757,6 +752,20 @@ func max(a, b int) int { return b } +// validateNonEmptyText validates that text is not empty or whitespace-only. +// Returns an EmbeddingError if validation fails, nil otherwise. 
+func validateNonEmptyText(text string) error { + if strings.TrimSpace(text) == "" { + return &outbound.EmbeddingError{ + Code: "invalid_input", + Type: "validation", + Message: "text cannot be empty", + Retryable: false, + } + } + return nil +} + // getGenaiClient returns the cached genai.Client instance for making API requests. // // This method uses a read lock (RLock) to allow concurrent access from multiple @@ -839,6 +848,144 @@ func (c *Client) convertToFloat64Slice(values []float32) []float64 { return result } +// CountTokens counts the exact number of tokens in the given text using the Gemini API. +func (c *Client) CountTokens(ctx context.Context, text string, model string) (*outbound.TokenCountResult, error) { + startTime := time.Now() + + // Input validation + if err := validateNonEmptyText(text); err != nil { + return nil, err + } + + // Check context cancellation early + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Use default model if not specified + if model == "" { + model = c.config.Model + } + + // Log request + slogger.Info(ctx, "Token counting request initiated", slogger.Fields{ + "text_length": len(text), + "model": model, + "word_count": countWords(text), + }) + + // Get the cached GenAI client + genaiClient := c.getGenaiClient() + + // Create content for token counting + content := genai.NewContentFromText(text, genai.RoleUser) + + // Call the CountTokens API + result, err := genaiClient.Models.CountTokens(ctx, model, []*genai.Content{content}, nil) + if err != nil { + sdkErr := c.convertSDKError(err) + duration := time.Since(startTime) + slogger.Error(ctx, "Token counting request failed", slogger.Fields{ + "error_type": sdkErr.Type, + "error_code": sdkErr.Code, + "message": sdkErr.Message, + "duration_ms": duration.Milliseconds(), + "model": model, + }) + return nil, fmt.Errorf("failed to count tokens: %w", sdkErr) + } + + // Build result + tokenResult := &outbound.TokenCountResult{ + TotalTokens: 
int(result.TotalTokens), + Model: model, + CachedAt: nil, + } + + // Log successful response + duration := time.Since(startTime) + slogger.Info(ctx, "Token counting request completed", slogger.Fields{ + "token_count": tokenResult.TotalTokens, + "model": model, + "duration_ms": duration.Milliseconds(), + "text_length": len(text), + }) + + return tokenResult, nil +} + +// CountTokensBatch counts tokens for multiple texts in a single batch request. +func (c *Client) CountTokensBatch( + ctx context.Context, + texts []string, + model string, +) ([]*outbound.TokenCountResult, error) { + startTime := time.Now() + + // Input validation + if len(texts) == 0 { + return nil, &outbound.EmbeddingError{ + Code: "invalid_input", + Type: "validation", + Message: "texts slice cannot be empty", + Retryable: false, + } + } + + // Check context cancellation early + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Use default model if not specified + effectiveModel := model + if effectiveModel == "" { + effectiveModel = c.config.Model + } + + // Log batch request + slogger.Info(ctx, "Batch token counting request initiated", slogger.Fields{ + "batch_size": len(texts), + "model": effectiveModel, + }) + + // Process each text sequentially + results := make([]*outbound.TokenCountResult, 0, len(texts)) + for i, text := range texts { + result, err := c.CountTokens(ctx, text, model) + if err != nil { + duration := time.Since(startTime) + slogger.Error(ctx, "Batch token counting failed", slogger.Fields{ + "failed_at_index": i, + "batch_size": len(texts), + "duration_ms": duration.Milliseconds(), + "model": effectiveModel, + }) + return nil, fmt.Errorf("failed to count tokens for text at index %d: %w", i, err) + } + results = append(results, result) + } + + // Log successful batch completion + duration := time.Since(startTime) + totalTokens := 0 + for _, result := range results { + totalTokens += result.TotalTokens + } + slogger.Info(ctx, "Batch token counting 
// Defaults applied by NewTokenCache when the configuration leaves a field
// unset (zero) or sets it to a negative value.
const (
	defaultTokenCacheMaxSize = 1000
	defaultTokenCacheTTL     = time.Hour
)

// TokenCacheEntry is a cached token count together with the model it was
// counted for and the time it was stored. Expiry is measured from CachedAt.
type TokenCacheEntry struct {
	TokenCount int
	Model      string
	CachedAt   time.Time
}

// IsExpired reports whether more than ttl has elapsed since the entry was
// cached.
func (e *TokenCacheEntry) IsExpired(ttl time.Duration) bool {
	return time.Since(e.CachedAt) > ttl
}

// TokenCache is a concurrency-safe, bounded cache of token counts keyed by
// (text, model). Entries expire after a TTL, and when the cache is full the
// least recently used entry (by Get or Set) is evicted.
type TokenCache struct {
	mu      sync.RWMutex
	cache   map[string]*TokenCacheEntry
	maxSize int
	ttl     time.Duration
	lruList []string // recency order: least recently used first, newest last
}

// TokenCacheConfig holds configuration for the token cache.
type TokenCacheConfig struct {
	MaxSize int           // Maximum number of entries (<= 0 selects the default of 1000)
	TTL     time.Duration // Time-to-live for cache entries (<= 0 selects the default of 1 hour)
}

// NewTokenCache creates a token cache from config. Non-positive MaxSize and
// TTL values are replaced by the defaults (1000 entries, 1 hour).
func NewTokenCache(config TokenCacheConfig) *TokenCache {
	maxSize := config.MaxSize
	if maxSize <= 0 {
		maxSize = defaultTokenCacheMaxSize
	}

	ttl := config.TTL
	if ttl <= 0 {
		ttl = defaultTokenCacheTTL
	}

	return &TokenCache{
		cache:   make(map[string]*TokenCacheEntry),
		maxSize: maxSize,
		ttl:     ttl,
		lruList: make([]string, 0, maxSize),
	}
}

// Get returns the cached entry for (text, model) and true when one exists and
// has not expired; otherwise it returns nil and false.
//
// A hit marks the entry as most recently used. An expired entry is removed so
// that it no longer occupies capacity.
func (tc *TokenCache) Get(text, model string) (*TokenCacheEntry, bool) {
	// Hashing touches no shared state, so do it before taking the lock.
	key := tc.generateKey(text, model)

	// A full lock is required: a hit reorders lruList and an expired hit
	// deletes from the map.
	tc.mu.Lock()
	defer tc.mu.Unlock()

	entry, ok := tc.cache[key]
	if !ok {
		return nil, false
	}

	if entry.IsExpired(tc.ttl) {
		delete(tc.cache, key)
		tc.removeFromLRU(key)
		return nil, false
	}

	tc.updateLRU(key)
	return entry, true
}

// Set stores tokenCount for (text, model), stamping it with the current time
// and marking it most recently used. When the cache is full and the key is
// new, the least recently used entry is evicted first; updating an existing
// key never evicts.
func (tc *TokenCache) Set(text, model string, tokenCount int) {
	key := tc.generateKey(text, model)

	tc.mu.Lock()
	defer tc.mu.Unlock()

	if _, exists := tc.cache[key]; !exists && len(tc.cache) >= tc.maxSize {
		tc.evictOldest()
	}

	tc.cache[key] = &TokenCacheEntry{
		TokenCount: tokenCount,
		Model:      model,
		CachedAt:   time.Now(),
	}
	tc.updateLRU(key)
}

// Clear removes all entries from the cache.
func (tc *TokenCache) Clear() {
	tc.mu.Lock()
	defer tc.mu.Unlock()

	tc.cache = make(map[string]*TokenCacheEntry)
	tc.lruList = make([]string, 0, tc.maxSize)
}

// Size returns the number of entries currently stored, including entries
// that have expired but have not yet been looked up or evicted.
func (tc *TokenCache) Size() int {
	tc.mu.RLock()
	defer tc.mu.RUnlock()

	return len(tc.cache)
}

// generateKey derives the cache key for (text, model) as a hex SHA-256 digest.
//
// The model is hashed first and its fixed-length digest is used as a prefix,
// so the boundary between model and text is unambiguous: ("ab", "c") and
// ("a", "bc") can no longer map to the same key.
func (tc *TokenCache) generateKey(text, model string) string {
	modelSum := sha256.Sum256([]byte(model))

	h := sha256.New()
	h.Write(modelSum[:]) // hash.Hash.Write never returns an error
	h.Write([]byte(text))
	return hex.EncodeToString(h.Sum(nil))
}

// evictOldest removes the least recently used entry. The caller must hold
// tc.mu for writing.
func (tc *TokenCache) evictOldest() {
	if len(tc.lruList) == 0 {
		return
	}

	oldestKey := tc.lruList[0]
	delete(tc.cache, oldestKey)
	tc.removeFromLRU(oldestKey)
}

// updateLRU marks key as most recently used by moving it to the end of the
// recency list, adding it if absent. The caller must hold tc.mu for writing.
func (tc *TokenCache) updateLRU(key string) {
	tc.removeFromLRU(key)
	tc.lruList = append(tc.lruList, key)
}

// removeFromLRU deletes key from the recency list if present, shifting later
// keys down in place so the list keeps its preallocated capacity. The caller
// must hold tc.mu for writing.
func (tc *TokenCache) removeFromLRU(key string) {
	for i, k := range tc.lruList {
		if k == key {
			tc.lruList = append(tc.lruList[:i], tc.lruList[i+1:]...)
			return
		}
	}
}
+func TestClient_CountTokens(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupFunc func(t *testing.T) (*gemini.Client, context.Context) + text string + model string + wantErr bool + checkFunc func(t *testing.T, result *outbound.TokenCountResult, err error) + }{ + { + name: "valid_text_returns_token_count", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + text: "This is a test string for token counting", + model: "gemini-embedding-001", + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + // RED PHASE: Should fail because CountTokens is not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN PHASE), verify: + // - result is not nil + // - result.TotalTokens > 0 + // - result.Model matches requested model + // - result.CachedAt is nil for fresh count + }, + }, + { + name: "empty_text_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + text: "", + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for empty text, got nil") + return + } + if 
result != nil { + t.Errorf("expected nil result for empty text, got %+v", result) + } + + // Should be a validation error + var embErr *outbound.EmbeddingError + if errors.As(err, &embErr) { + if !embErr.IsValidationError() { + t.Errorf("expected validation error, got type=%s", embErr.Type) + } + } + }, + }, + { + name: "whitespace_only_text_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + text: " \t\n ", + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for whitespace-only text, got nil") + return + } + if result != nil { + t.Errorf("expected nil result for whitespace-only text, got %+v", result) + } + }, + }, + { + name: "context_cancellation_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + return client, ctx + }, + text: "Test text", + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for cancelled context, got nil") + return + } + // When implemented, should return context.Canceled error + }, + }, + { + name: "empty_model_uses_default", + setupFunc: func(t 
*testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + text: "Test text", + model: "", // Empty model should use default + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + // RED PHASE: Should fail because CountTokens is not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN PHASE), verify: + // - result.Model should be "gemini-embedding-001" (default) + }, + }, + { + name: "long_text_returns_higher_token_count", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + text: strings.Repeat("This is a longer text that should have many tokens. 
", 100), + model: "gemini-embedding-001", + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + // RED PHASE: Should fail because CountTokens is not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN PHASE), verify: + // - result.TotalTokens should be significantly higher (e.g., > 100) + }, + }, + { + name: "timeout_context_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + time.Sleep(10 * time.Millisecond) // Ensure timeout + return client, ctx + }, + text: "Test text", + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for timeout context, got nil") + return + } + // When implemented, should return context.DeadlineExceeded error + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, ctx := tt.setupFunc(t) + + result, err := client.CountTokens(ctx, tt.text, tt.model) + + if (err != nil) != tt.wantErr { + t.Errorf("CountTokens() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.checkFunc != nil { + tt.checkFunc(t, result, err) + } + }) + } +} + +// TestClient_CountTokensBatch tests batch token counting functionality. 
+func TestClient_CountTokensBatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setupFunc func(t *testing.T) (*gemini.Client, context.Context) + texts []string + model string + wantErr bool + checkFunc func(t *testing.T, results []*outbound.TokenCountResult, err error) + }{ + { + name: "valid_texts_returns_matching_results", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: []string{ + "First text for batch counting", + "Second text for batch counting", + "Third text for batch counting", + }, + model: "gemini-embedding-001", + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) { + // RED PHASE: Should fail because CountTokensBatch is not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN PHASE), verify: + // - results length matches input length (3) + // - each result has TotalTokens > 0 + // - each result has Model set + // - order is preserved + }, + }, + { + name: "empty_slice_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: []string{}, + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, results 
[]*outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for empty slice, got nil") + return + } + if results != nil { + t.Errorf("expected nil results for empty slice, got %+v", results) + } + + // Should be a validation error + var embErr *outbound.EmbeddingError + if errors.As(err, &embErr) { + if !embErr.IsValidationError() { + t.Errorf("expected validation error, got type=%s", embErr.Type) + } + } + }, + }, + { + name: "nil_slice_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: nil, + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for nil slice, got nil") + return + } + }, + }, + { + name: "single_text_in_batch", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: []string{"Single text for batch processing"}, + model: "gemini-embedding-001", + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) { + // RED PHASE: Should fail because CountTokensBatch is not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN 
PHASE), verify: + // - results length is 1 + // - result has valid token count + }, + }, + { + name: "mixed_empty_and_valid_texts", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: []string{ + "Valid text", + "", + "Another valid text", + }, + model: "gemini-embedding-001", + wantErr: true, // Should handle gracefully or return error + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) { + // Implementation should decide how to handle mixed content: + // Option 1: Return error for entire batch + // Option 2: Return nil for invalid entries + // Option 3: Skip empty entries and return partial results + + // For now, just verify it handles the case + if err == nil && results != nil { + t.Log("Implementation handles mixed content - verify behavior") + } + }, + }, + { + name: "order_preserved_in_results", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: []string{ + "A short text", + "B this is a much longer text with many more tokens to count properly", + "C medium length text", + }, + model: "gemini-embedding-001", + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) { + // RED PHASE: Should fail because CountTokensBatch is not implemented + if err == nil { + 
t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN PHASE), verify: + // - results[0] corresponds to "A short text" + // - results[1] corresponds to "B this is..." (highest token count) + // - results[2] corresponds to "C medium..." (middle token count) + // - Order matches input order exactly + }, + }, + { + name: "context_cancellation_returns_error", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + return client, ctx + }, + texts: []string{ + "Text 1", + "Text 2", + }, + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) { + if err == nil { + t.Error("expected error for cancelled context, got nil") + return + } + // When implemented, should return context.Canceled error + }, + }, + { + name: "varying_text_lengths_produce_different_counts", + setupFunc: func(t *testing.T) (*gemini.Client, context.Context) { + client, err := gemini.NewClient(&gemini.ClientConfig{ + APIKey: "test-api-key-" + t.Name(), + Model: "gemini-embedding-001", + TaskType: "RETRIEVAL_DOCUMENT", + Timeout: 30 * time.Second, + Dimensions: 768, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + return client, context.Background() + }, + texts: []string{ + "Short", + "This is a much longer piece of text that should definitely have more tokens than the short one", + }, + model: "gemini-embedding-001", + wantErr: true, // RED PHASE: Not implemented yet + checkFunc: func(t *testing.T, results []*outbound.TokenCountResult, err error) 
{ + // RED PHASE: Should fail because CountTokensBatch is not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + return + } + + // When implemented (GREEN PHASE), verify: + // - results[1].TotalTokens > results[0].TotalTokens + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, ctx := tt.setupFunc(t) + + results, err := client.CountTokensBatch(ctx, tt.texts, tt.model) + + if (err != nil) != tt.wantErr { + t.Errorf("CountTokensBatch() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.checkFunc != nil { + tt.checkFunc(t, results, err) + } + }) + } +} + +// TestTokenCountResult_FieldValidation validates TokenCountResult structure. +func TestTokenCountResult_FieldValidation(t *testing.T) { + t.Parallel() + + now := time.Now() + + tests := []struct { + name string + result *outbound.TokenCountResult + checkFunc func(t *testing.T, result *outbound.TokenCountResult) + }{ + { + name: "all_fields_set_correctly", + result: &outbound.TokenCountResult{ + TotalTokens: 42, + Model: "gemini-embedding-001", + CachedAt: &now, + }, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult) { + if result.TotalTokens != 42 { + t.Errorf("expected TotalTokens=42, got=%d", result.TotalTokens) + } + if result.Model != "gemini-embedding-001" { + t.Errorf("expected Model='gemini-embedding-001', got='%s'", result.Model) + } + if result.CachedAt == nil { + t.Error("expected CachedAt to be set, got nil") + } + if !result.CachedAt.Equal(now) { + t.Errorf("expected CachedAt=%v, got=%v", now, *result.CachedAt) + } + }, + }, + { + name: "cached_at_can_be_nil", + result: &outbound.TokenCountResult{ + TotalTokens: 100, + Model: "test-model", + CachedAt: nil, + }, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult) { + if result.CachedAt != nil { + t.Error("expected CachedAt to be nil, got value") + } + if result.TotalTokens != 100 { + 
t.Errorf("expected TotalTokens=100, got=%d", result.TotalTokens) + } + }, + }, + { + name: "zero_tokens_is_valid", + result: &outbound.TokenCountResult{ + TotalTokens: 0, + Model: "test-model", + CachedAt: nil, + }, + checkFunc: func(t *testing.T, result *outbound.TokenCountResult) { + if result.TotalTokens != 0 { + t.Errorf("expected TotalTokens=0, got=%d", result.TotalTokens) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.checkFunc != nil { + tt.checkFunc(t, tt.result) + } + }) + } +} diff --git a/internal/port/outbound/embedding_service.go b/internal/port/outbound/embedding_service.go index 5f87d53..62ce7aa 100644 --- a/internal/port/outbound/embedding_service.go +++ b/internal/port/outbound/embedding_service.go @@ -42,6 +42,14 @@ type EmbeddingService interface { // EstimateTokenCount estimates the number of tokens in the given text EstimateTokenCount(ctx context.Context, text string) (int, error) + + // CountTokens counts the exact number of tokens in the given text using the embedding model + // Returns a TokenCountResult containing the token count, model used, and optional cache timestamp + CountTokens(ctx context.Context, text string, model string) (*TokenCountResult, error) + + // CountTokensBatch counts tokens for multiple texts in a single batch request + // Returns a slice of TokenCountResult matching the input texts order + CountTokensBatch(ctx context.Context, texts []string, model string) ([]*TokenCountResult, error) } // BatchEmbeddingService defines the interface for file-based batch embedding operations. @@ -316,6 +324,13 @@ type BatchEmbeddingRequest struct { Metadata map[string]interface{} `json:"metadata,omitempty"` // Optional metadata } +// TokenCountResult represents the result of a token counting operation. 
+type TokenCountResult struct { + TotalTokens int `json:"total_tokens"` // Total number of tokens counted + Model string `json:"model"` // Model used for token counting + CachedAt *time.Time `json:"cached_at,omitempty"` // Optional timestamp if result was cached +} + // BatchEmbeddingResponse represents a single embedding response from a batch. type BatchEmbeddingResponse struct { RequestID string `json:"request_id"` // Request identifier matching the request diff --git a/internal/port/outbound/token_count_test.go b/internal/port/outbound/token_count_test.go new file mode 100644 index 0000000..4878c40 --- /dev/null +++ b/internal/port/outbound/token_count_test.go @@ -0,0 +1,629 @@ +package outbound + +import ( + "context" + "errors" + "testing" + "time" +) + +// testEmbeddingService provides a minimal stub implementation for testing CountTokens methods. +// This stub will initially return errors to ensure tests fail in the red phase. +type testEmbeddingService struct{} + +func newTestEmbeddingService() EmbeddingService { + return &testEmbeddingService{} +} + +// Stub implementations that return errors (for red phase) +func (s *testEmbeddingService) GenerateEmbedding( + _ context.Context, + _ string, + _ EmbeddingOptions, +) (*EmbeddingResult, error) { + return nil, errors.New("not implemented") +} + +func (s *testEmbeddingService) GenerateBatchEmbeddings( + _ context.Context, + _ []string, + _ EmbeddingOptions, +) ([]*EmbeddingResult, error) { + return nil, errors.New("not implemented") +} + +func (s *testEmbeddingService) GenerateCodeChunkEmbedding( + _ context.Context, + _ *CodeChunk, + _ EmbeddingOptions, +) (*CodeChunkEmbedding, error) { + return nil, errors.New("not implemented") +} + +func (s *testEmbeddingService) ValidateApiKey(_ context.Context) error { + return errors.New("not implemented") +} + +func (s *testEmbeddingService) GetModelInfo(_ context.Context) (*ModelInfo, error) { + return nil, errors.New("not implemented") +} + +func (s 
*testEmbeddingService) GetSupportedModels(_ context.Context) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (s *testEmbeddingService) EstimateTokenCount(_ context.Context, _ string) (int, error) { + return 0, errors.New("not implemented") +} + +// CountTokens stub - returns error to make tests fail +func (s *testEmbeddingService) CountTokens( + _ context.Context, + text string, + _ string, +) (*TokenCountResult, error) { + // Validate empty text + if text == "" { + return nil, &EmbeddingError{ + Code: "invalid_input", + Message: "text cannot be empty", + Type: "validation", + Retryable: false, + } + } + // Return error for red phase - implementation doesn't exist yet + return nil, errors.New("CountTokens not implemented") +} + +// CountTokensBatch stub - returns error to make tests fail +func (s *testEmbeddingService) CountTokensBatch( + _ context.Context, + texts []string, + _ string, +) ([]*TokenCountResult, error) { + // Validate empty slice + if len(texts) == 0 { + return nil, &EmbeddingError{ + Code: "invalid_input", + Message: "texts slice cannot be empty", + Type: "validation", + Retryable: false, + } + } + // Return error for red phase - implementation doesn't exist yet + return nil, errors.New("CountTokensBatch not implemented") +} + +// TestTokenCountResult_Structure validates the TokenCountResult type has correct fields. 
+func TestTokenCountResult_Structure(t *testing.T) { + t.Parallel() + + now := time.Now() + result := &TokenCountResult{ + TotalTokens: 100, + Model: "gemini-embedding-001", + CachedAt: &now, + } + + tests := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "has TotalTokens field", + testFunc: func(t *testing.T) { + if result.TotalTokens != 100 { + t.Errorf("expected TotalTokens=100, got=%d", result.TotalTokens) + } + }, + }, + { + name: "has Model field", + testFunc: func(t *testing.T) { + if result.Model != "gemini-embedding-001" { + t.Errorf("expected Model='gemini-embedding-001', got='%s'", result.Model) + } + }, + }, + { + name: "has optional CachedAt field", + testFunc: func(t *testing.T) { + if result.CachedAt == nil { + t.Error("expected CachedAt to be set, got nil") + } + if !result.CachedAt.Equal(now) { + t.Errorf("expected CachedAt=%v, got=%v", now, *result.CachedAt) + } + }, + }, + { + name: "CachedAt can be nil", + testFunc: func(t *testing.T) { + resultNilCache := &TokenCountResult{ + TotalTokens: 50, + Model: "test-model", + CachedAt: nil, + } + if resultNilCache.CachedAt != nil { + t.Error("expected CachedAt to be nil") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.testFunc(t) + }) + } +} + +// TestCountTokens_Interface validates the CountTokens method exists and has correct signature. 
+func TestCountTokens_Interface(t *testing.T) { + t.Parallel() + + ctx := context.Background() + svc := newTestEmbeddingService() + + tests := []struct { + name string + text string + model string + wantErr bool + checkFunc func(t *testing.T, result *TokenCountResult, err error) + }{ + { + name: "returns error for empty text", + text: "", + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + if err == nil { + t.Error("expected error for empty text, got nil") + } + if result != nil { + t.Errorf("expected nil result for empty text, got %+v", result) + } + // Should be a validation error + var embErr *EmbeddingError + if errors.As(err, &embErr) { + if !embErr.IsValidationError() { + t.Errorf("expected validation error, got type=%s", embErr.Type) + } + } + }, + }, + { + name: "accepts valid text", + text: "This is a test string for token counting", + model: "gemini-embedding-001", + wantErr: true, // Red phase - not implemented yet + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + // In red phase, we expect error because it's not implemented + // In green phase, this should return a valid result + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + } + }, + }, + { + name: "accepts empty model (should use default)", + text: "Test text", + model: "", + wantErr: true, // Red phase - not implemented yet + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + // In red phase, we expect error because it's not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + } + }, + }, + { + name: "accepts custom model name", + text: "Test text", + model: "custom-model-v1", + wantErr: true, // Red phase - not implemented yet + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + // In red phase, we expect error because it's not 
implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := svc.CountTokens(ctx, tt.text, tt.model) + + if (err != nil) != tt.wantErr { + t.Errorf("CountTokens() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.checkFunc != nil { + tt.checkFunc(t, result, err) + } + }) + } +} + +// TestCountTokens_ResultValidation validates expected TokenCountResult contents. +func TestCountTokens_ResultValidation(t *testing.T) { + t.Parallel() + + ctx := context.Background() + svc := newTestEmbeddingService() + + tests := []struct { + name string + text string + model string + checkFunc func(t *testing.T, result *TokenCountResult, err error) + }{ + { + name: "result should have positive token count", + text: "Hello world", + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && result != nil { + if result.TotalTokens <= 0 { + t.Error("expected TotalTokens > 0") + } + } else { + // Expected in red phase + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "result should include model name", + text: "Test text", + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && result != nil { + if result.Model == "" { + t.Error("expected Model to be set") + } + if result.Model != "gemini-embedding-001" { + t.Errorf("expected Model='gemini-embedding-001', got='%s'", result.Model) + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "CachedAt should be nil for fresh count", + text: "Fresh count", + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, result *TokenCountResult, err 
error) { + // Red phase: we expect this to fail + if err == nil && result != nil { + if result.CachedAt != nil { + t.Error("expected CachedAt to be nil for fresh count") + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "longer text should have more tokens", + text: "This is a much longer piece of text that should have significantly more tokens than a short string", + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, result *TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && result != nil { + // Rough heuristic: long text should have many tokens + if result.TotalTokens < 10 { + t.Errorf("expected longer text to have more tokens, got %d", result.TotalTokens) + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := svc.CountTokens(ctx, tt.text, tt.model) + if tt.checkFunc != nil { + tt.checkFunc(t, result, err) + } + }) + } +} + +// TestCountTokensBatch_Interface validates the CountTokensBatch method exists and has correct signature. 
+func TestCountTokensBatch_Interface(t *testing.T) { + t.Parallel() + + ctx := context.Background() + svc := newTestEmbeddingService() + + tests := []struct { + name string + texts []string + model string + wantErr bool + checkFunc func(t *testing.T, results []*TokenCountResult, err error) + }{ + { + name: "returns error for empty slice", + texts: []string{}, + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + if err == nil { + t.Error("expected error for empty slice, got nil") + } + if results != nil { + t.Errorf("expected nil results for empty slice, got %+v", results) + } + // Should be a validation error + var embErr *EmbeddingError + if errors.As(err, &embErr) { + if !embErr.IsValidationError() { + t.Errorf("expected validation error, got type=%s", embErr.Type) + } + } + }, + }, + { + name: "returns error for nil slice", + texts: nil, + model: "gemini-embedding-001", + wantErr: true, + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + if err == nil { + t.Error("expected error for nil slice, got nil") + } + }, + }, + { + name: "accepts single text", + texts: []string{ + "Single text for batch processing", + }, + model: "gemini-embedding-001", + wantErr: true, // Red phase - not implemented yet + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // In red phase, we expect error because it's not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + } + }, + }, + { + name: "accepts multiple texts", + texts: []string{ + "First text", + "Second text", + "Third text", + }, + model: "gemini-embedding-001", + wantErr: true, // Red phase - not implemented yet + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // In red phase, we expect error because it's not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - 
implementation exists!") + } + }, + }, + { + name: "handles mixed content", + texts: []string{ + "Short", + "This is a much longer text with many more tokens to count", + "Medium length text here", + }, + model: "gemini-embedding-001", + wantErr: true, // Red phase - not implemented yet + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // In red phase, we expect error because it's not implemented + if err == nil { + t.Error("RED PHASE: expected error (not implemented), got nil - implementation exists!") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results, err := svc.CountTokensBatch(ctx, tt.texts, tt.model) + + if (err != nil) != tt.wantErr { + t.Errorf("CountTokensBatch() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.checkFunc != nil { + tt.checkFunc(t, results, err) + } + }) + } +} + +// TestCountTokensBatch_ResultValidation validates expected batch results. +func TestCountTokensBatch_ResultValidation(t *testing.T) { + t.Parallel() + + ctx := context.Background() + svc := newTestEmbeddingService() + + tests := []struct { + name string + texts []string + model string + checkFunc func(t *testing.T, results []*TokenCountResult, err error) + }{ + { + name: "results length should match input length", + texts: []string{ + "First", + "Second", + "Third", + }, + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && results != nil { + if len(results) != 3 { + t.Errorf("expected 3 results, got %d", len(results)) + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "each result should have valid token count", + texts: []string{ + "Hello world", + "Testing token counts", + }, + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // Red phase: we expect this 
to fail + if err == nil && results != nil { + for i, result := range results { + if result == nil { + t.Errorf("result[%d] is nil", i) + continue + } + if result.TotalTokens <= 0 { + t.Errorf("result[%d].TotalTokens = %d, want > 0", i, result.TotalTokens) + } + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "each result should have model set", + texts: []string{ + "Text one", + "Text two", + }, + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && results != nil { + for i, result := range results { + if result == nil { + t.Errorf("result[%d] is nil", i) + continue + } + if result.Model == "" { + t.Errorf("result[%d].Model is empty", i) + } + if result.Model != "gemini-embedding-001" { + t.Errorf("result[%d].Model = %s, want 'gemini-embedding-001'", i, result.Model) + } + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "results should maintain input order", + texts: []string{ + "A", + "B", + "C", + }, + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && results != nil { + // We can't verify content mapping without implementation + // but we can verify structure + if len(results) != 3 { + t.Errorf("expected 3 results in order, got %d", len(results)) + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + { + name: "varying length texts should have different token counts", + texts: []string{ + "Short", + "This is a much longer piece of text that should definitely have more tokens", + }, + model: "gemini-embedding-001", + checkFunc: func(t *testing.T, results []*TokenCountResult, err error) { + // Red phase: we expect this to fail + if err == nil && 
results != nil && len(results) == 2 { + if results[0].TotalTokens >= results[1].TotalTokens { + t.Errorf("expected longer text to have more tokens: short=%d, long=%d", + results[0].TotalTokens, results[1].TotalTokens) + } + } else { + t.Log("RED PHASE: Implementation not ready, test will pass when implemented") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results, err := svc.CountTokensBatch(ctx, tt.texts, tt.model) + if tt.checkFunc != nil { + tt.checkFunc(t, results, err) + } + }) + } +} + +// TestCountTokens_ContextCancellation ensures context cancellation is respected. +func TestCountTokens_ContextCancellation(t *testing.T) { + t.Parallel() + + svc := newTestEmbeddingService() + + // Create cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result, err := svc.CountTokens(ctx, "test text", "gemini-embedding-001") + + // In a real implementation, should return context.Canceled error + // In red phase, we just verify the method handles context + if err == nil && result != nil { + // If implementation exists, it should respect context cancellation + t.Log("Implementation exists - should verify context cancellation handling") + } else { + t.Log("RED PHASE: Implementation not ready") + } +} + +// TestCountTokensBatch_ContextCancellation ensures batch method respects context cancellation. 
+func TestCountTokensBatch_ContextCancellation(t *testing.T) { + t.Parallel() + + svc := newTestEmbeddingService() + + // Create cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + texts := []string{"text1", "text2", "text3"} + results, err := svc.CountTokensBatch(ctx, texts, "gemini-embedding-001") + + // In a real implementation, should return context.Canceled error + // In red phase, we just verify the method handles context + if err == nil && results != nil { + // If implementation exists, it should respect context cancellation + t.Log("Implementation exists - should verify context cancellation handling") + } else { + t.Log("RED PHASE: Implementation not ready") + } +} diff --git a/migrations/000010_add_token_count_to_chunks.down.sql b/migrations/000010_add_token_count_to_chunks.down.sql new file mode 100644 index 0000000..f636315 --- /dev/null +++ b/migrations/000010_add_token_count_to_chunks.down.sql @@ -0,0 +1,10 @@ +-- Rollback: Remove token_count columns from code_chunks table + +DROP INDEX IF EXISTS codechunking.idx_code_chunks_repo_token_count; +DROP INDEX IF EXISTS codechunking.idx_code_chunks_token_count; + +ALTER TABLE codechunking.code_chunks +DROP COLUMN IF EXISTS token_counted_at; + +ALTER TABLE codechunking.code_chunks +DROP COLUMN IF EXISTS token_count; diff --git a/migrations/000010_add_token_count_to_chunks.up.sql b/migrations/000010_add_token_count_to_chunks.up.sql new file mode 100644 index 0000000..be5f9fb --- /dev/null +++ b/migrations/000010_add_token_count_to_chunks.up.sql @@ -0,0 +1,21 @@ +-- Add token_count column to code_chunks table for storing exact token counts from Google CountTokens API +-- This enables usage monitoring and chunk optimization + +ALTER TABLE codechunking.code_chunks +ADD COLUMN IF NOT EXISTS token_count INTEGER; + +ALTER TABLE codechunking.code_chunks +ADD COLUMN IF NOT EXISTS token_counted_at TIMESTAMP WITH TIME ZONE; + +-- Index for queries filtering by token count (useful 
for chunk optimization) +CREATE INDEX IF NOT EXISTS idx_code_chunks_token_count +ON codechunking.code_chunks(token_count) +WHERE token_count IS NOT NULL; + +-- Composite index for repository + token count queries (find oversized chunks per repo) +CREATE INDEX IF NOT EXISTS idx_code_chunks_repo_token_count +ON codechunking.code_chunks(repository_id, token_count) +WHERE deleted_at IS NULL AND token_count IS NOT NULL; + +COMMENT ON COLUMN codechunking.code_chunks.token_count IS 'Exact token count from Google CountTokens API'; +COMMENT ON COLUMN codechunking.code_chunks.token_counted_at IS 'Timestamp when token count was retrieved from API'; From 1366e29c6ea7708c7f884f5568d6d6961f59a467 Mon Sep 17 00:00:00 2001 From: Anthony Bible Date: Thu, 27 Nov 2025 14:49:37 -0700 Subject: [PATCH 2/5] feat: integrate token counting into chunk repository and job processor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add token_count support to chunk repository operations - Integrate token counting step in job processor workflow - Add token counting metrics and configuration options - Update test mocks with CountTokens methods 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/config.yaml | 7 + .../queue/batch_queue_manager_test.go | 26 + .../outbound/queue/parallel_processor_test.go | 26 + .../batch_progress_repository_test.go | 14 +- .../outbound/repository/chunk_repository.go | 86 ++- .../chunk_repository_token_count_test.go | 606 +++++++++++++++++ .../service/search_service_test.go | 24 + internal/application/worker/job_processor.go | 182 +++++ .../application/worker/job_processor_test.go | 33 + .../job_processor_token_counting_test.go | 631 ++++++++++++++++++ internal/config/config.go | 11 + internal/port/outbound/chunk_repository.go | 13 + internal/port/outbound/token_count_test.go | 6 +- 13 files changed, 1648 insertions(+), 17 deletions(-) create mode 100644 
internal/adapter/outbound/repository/chunk_repository_token_count_test.go create mode 100644 internal/application/worker/job_processor_token_counting_test.go diff --git a/configs/config.yaml b/configs/config.yaml index c551a41..fe86c99 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -94,6 +94,13 @@ batch_processing: submission_max_backoff: 30m # Maximum backoff duration max_submission_attempts: 10 # Max retry attempts per batch submission + # Token counting configuration + token_counting: + enabled: true # Enable token counting + mode: "all" # Mode: "all", "sample", or "on_demand" + sample_percent: 10 # Percentage of chunks to sample (for "sample" mode) + max_tokens_per_chunk: 8192 # Maximum tokens per chunk (Gemini embedding model limit) + log: level: info format: json \ No newline at end of file diff --git a/internal/adapter/outbound/queue/batch_queue_manager_test.go b/internal/adapter/outbound/queue/batch_queue_manager_test.go index b2aabf4..ff37f61 100644 --- a/internal/adapter/outbound/queue/batch_queue_manager_test.go +++ b/internal/adapter/outbound/queue/batch_queue_manager_test.go @@ -70,6 +70,32 @@ func (m *mockEmbeddingService) EstimateTokenCount(ctx context.Context, text stri return len(text) / 4, nil } +func (m *mockEmbeddingService) CountTokens( + ctx context.Context, + text string, + model string, +) (*outbound.TokenCountResult, error) { + return &outbound.TokenCountResult{ + TotalTokens: len(text) / 4, + Model: model, + }, nil +} + +func (m *mockEmbeddingService) CountTokensBatch( + ctx context.Context, + texts []string, + model string, +) ([]*outbound.TokenCountResult, error) { + results := make([]*outbound.TokenCountResult, len(texts)) + for i, text := range texts { + results[i] = &outbound.TokenCountResult{ + TotalTokens: len(text) / 4, + Model: model, + } + } + return results, nil +} + // Test helper to create a manager with mocks. 
func createTestManager() outbound.BatchQueueManager { embeddingService := &mockEmbeddingService{} diff --git a/internal/adapter/outbound/queue/parallel_processor_test.go b/internal/adapter/outbound/queue/parallel_processor_test.go index 66aee8c..4c77c89 100644 --- a/internal/adapter/outbound/queue/parallel_processor_test.go +++ b/internal/adapter/outbound/queue/parallel_processor_test.go @@ -154,6 +154,32 @@ func (m *MockEmbeddingServiceWithDelay) EstimateTokenCount(ctx context.Context, return len(text) / 4, nil } +func (m *MockEmbeddingServiceWithDelay) CountTokens( + ctx context.Context, + text string, + model string, +) (*outbound.TokenCountResult, error) { + return &outbound.TokenCountResult{ + TotalTokens: len(text) / 4, + Model: model, + }, nil +} + +func (m *MockEmbeddingServiceWithDelay) CountTokensBatch( + ctx context.Context, + texts []string, + model string, +) ([]*outbound.TokenCountResult, error) { + results := make([]*outbound.TokenCountResult, len(texts)) + for i, text := range texts { + results[i] = &outbound.TokenCountResult{ + TotalTokens: len(text) / 4, + Model: model, + } + } + return results, nil +} + // RED PHASE TESTS - All tests should FAIL initially func TestMockEmbeddingServiceFailure(t *testing.T) { diff --git a/internal/adapter/outbound/repository/batch_progress_repository_test.go b/internal/adapter/outbound/repository/batch_progress_repository_test.go index 9ff326b..f856a9c 100644 --- a/internal/adapter/outbound/repository/batch_progress_repository_test.go +++ b/internal/adapter/outbound/repository/batch_progress_repository_test.go @@ -1116,7 +1116,7 @@ func TestBatchProgressRepository_GetPendingSubmissionBatch_DistributedLocking(t // Manually query within transaction to test FOR UPDATE SKIP LOCKED query := `SELECT id, repository_id, indexing_job_id, batch_number, total_batches, chunks_processed, status, retry_count, next_retry_at, error_message, - gemini_batch_job_id, created_at, updated_at, + gemini_batch_job_id, gemini_file_uri, 
created_at, updated_at, batch_request_data, submission_attempts, next_submission_at FROM codechunking.batch_job_progress WHERE status = 'pending_submission' @@ -1134,13 +1134,14 @@ func TestBatchProgressRepository_GetPendingSubmissionBatch_DistributedLocking(t var nextRetryAt *time.Time var errorMessage *string var geminiBatchJobID *string + var geminiFileURI *string var createdAt, updatedAt time.Time var batchRequestData []byte var nextSubmissionAt *time.Time err := row.Scan( &id, &repositoryID, &indexingJobID, &batchNumber, &totalBatches, &chunksProcessed, - &status, &retryCount, &nextRetryAt, &errorMessage, &geminiBatchJobID, + &status, &retryCount, &nextRetryAt, &errorMessage, &geminiBatchJobID, &geminiFileURI, &createdAt, &updatedAt, &batchRequestData, &submissionAttempts, &nextSubmissionAt, ) if err != nil { @@ -1149,7 +1150,7 @@ func TestBatchProgressRepository_GetPendingSubmissionBatch_DistributedLocking(t return entity.RestoreBatchJobProgress( id, repositoryID, indexingJobID, batchNumber, totalBatches, chunksProcessed, - status, retryCount, nextRetryAt, errorMessage, geminiBatchJobID, + status, retryCount, nextRetryAt, errorMessage, geminiBatchJobID, geminiFileURI, createdAt, updatedAt, batchRequestData, submissionAttempts, nextSubmissionAt, ), nil }() @@ -1173,7 +1174,7 @@ func TestBatchProgressRepository_GetPendingSubmissionBatch_DistributedLocking(t // Manually query within transaction to test FOR UPDATE SKIP LOCKED query := `SELECT id, repository_id, indexing_job_id, batch_number, total_batches, chunks_processed, status, retry_count, next_retry_at, error_message, - gemini_batch_job_id, created_at, updated_at, + gemini_batch_job_id, gemini_file_uri, created_at, updated_at, batch_request_data, submission_attempts, next_submission_at FROM codechunking.batch_job_progress WHERE status = 'pending_submission' @@ -1191,13 +1192,14 @@ func TestBatchProgressRepository_GetPendingSubmissionBatch_DistributedLocking(t var nextRetryAt *time.Time var errorMessage 
*string var geminiBatchJobID *string + var geminiFileURI *string var createdAt, updatedAt time.Time var batchRequestData []byte var nextSubmissionAt *time.Time err := row.Scan( &id, &repositoryID, &indexingJobID, &batchNumber, &totalBatches, &chunksProcessed, - &status, &retryCount, &nextRetryAt, &errorMessage, &geminiBatchJobID, + &status, &retryCount, &nextRetryAt, &errorMessage, &geminiBatchJobID, &geminiFileURI, &createdAt, &updatedAt, &batchRequestData, &submissionAttempts, &nextSubmissionAt, ) if err != nil { @@ -1206,7 +1208,7 @@ func TestBatchProgressRepository_GetPendingSubmissionBatch_DistributedLocking(t return entity.RestoreBatchJobProgress( id, repositoryID, indexingJobID, batchNumber, totalBatches, chunksProcessed, - status, retryCount, nextRetryAt, errorMessage, geminiBatchJobID, + status, retryCount, nextRetryAt, errorMessage, geminiBatchJobID, geminiFileURI, createdAt, updatedAt, batchRequestData, submissionAttempts, nextSubmissionAt, ), nil }() diff --git a/internal/adapter/outbound/repository/chunk_repository.go b/internal/adapter/outbound/repository/chunk_repository.go index 633b17a..aac10db 100644 --- a/internal/adapter/outbound/repository/chunk_repository.go +++ b/internal/adapter/outbound/repository/chunk_repository.go @@ -35,10 +35,13 @@ const ( INSERT INTO codechunking.code_chunks ( id, repository_id, file_path, chunk_type, content, language, start_line, end_line, entity_name, parent_entity, content_hash, metadata, - qualified_name, signature, visibility - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + qualified_name, signature, visibility, token_count, token_counted_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) ON CONFLICT (repository_id, file_path, content_hash) - DO UPDATE SET id = code_chunks.id + DO UPDATE SET + id = code_chunks.id, + token_count = COALESCE(EXCLUDED.token_count, code_chunks.token_count), + token_counted_at = COALESCE(EXCLUDED.token_counted_at, 
code_chunks.token_counted_at) RETURNING id ` @@ -285,6 +288,8 @@ func (r *PostgreSQLChunkRepository) SaveChunk(ctx context.Context, chunk *outbou qualifiedName, signature, visibility, + chunk.TokenCount, + chunk.TokenCountedAt, ).Scan(&actualChunkID) if err != nil { slogger.Error(ctx, "Failed to save chunk", slogger.Fields{ @@ -404,6 +409,8 @@ func (r *PostgreSQLChunkRepository) SaveChunks(ctx context.Context, chunks []out qualifiedName, signature, visibility, + chunk.TokenCount, + chunk.TokenCountedAt, ).Scan(&actualChunkID) if err != nil { slogger.Error(ctx, "Failed to save chunk in batch", slogger.Fields{ @@ -464,14 +471,18 @@ func (r *PostgreSQLChunkRepository) FindOrCreateChunks( }() // Use a query that returns the actual persisted ID (whether new or existing) + // On conflict, preserve existing token_count (find or create, not update) query := ` INSERT INTO codechunking.code_chunks ( id, repository_id, file_path, chunk_type, content, language, start_line, end_line, entity_name, parent_entity, content_hash, metadata, - qualified_name, signature, visibility - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + qualified_name, signature, visibility, token_count, token_counted_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) ON CONFLICT (repository_id, file_path, content_hash) - DO UPDATE SET id = code_chunks.id + DO UPDATE SET + id = code_chunks.id, + token_count = COALESCE(code_chunks.token_count, EXCLUDED.token_count), + token_counted_at = COALESCE(code_chunks.token_counted_at, EXCLUDED.token_counted_at) RETURNING id ` @@ -554,6 +565,8 @@ func (r *PostgreSQLChunkRepository) FindOrCreateChunks( qualifiedName, signature, visibility, + chunk.TokenCount, + chunk.TokenCountedAt, ).Scan(&returnedID) if err != nil { slogger.Error(ctx, "Failed to insert/find chunk", slogger.Fields{ @@ -597,7 +610,8 @@ func (r *PostgreSQLChunkRepository) FindOrCreateChunks( func (r *PostgreSQLChunkRepository) 
GetChunk(ctx context.Context, id uuid.UUID) (*outbound.CodeChunk, error) { query := ` SELECT id, file_path, start_line, end_line, content, language, content_hash, created_at, - chunk_type, COALESCE(entity_name, ''), COALESCE(parent_entity, ''), COALESCE(qualified_name, ''), COALESCE(signature, ''), COALESCE(visibility, '') + chunk_type, COALESCE(entity_name, ''), COALESCE(parent_entity, ''), COALESCE(qualified_name, ''), COALESCE(signature, ''), COALESCE(visibility, ''), + token_count, token_counted_at FROM codechunking.code_chunks WHERE id = $1 AND deleted_at IS NULL ` @@ -620,6 +634,8 @@ func (r *PostgreSQLChunkRepository) GetChunk(ctx context.Context, id uuid.UUID) &chunk.QualifiedName, &chunk.Signature, &chunk.Visibility, + &chunk.TokenCount, + &chunk.TokenCountedAt, ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -647,7 +663,8 @@ func (r *PostgreSQLChunkRepository) GetChunksForRepository( ) ([]outbound.CodeChunk, error) { query := ` SELECT id, file_path, start_line, end_line, content, language, content_hash, created_at, - chunk_type, COALESCE(entity_name, ''), COALESCE(parent_entity, ''), COALESCE(qualified_name, ''), COALESCE(signature, ''), COALESCE(visibility, '') + chunk_type, COALESCE(entity_name, ''), COALESCE(parent_entity, ''), COALESCE(qualified_name, ''), COALESCE(signature, ''), COALESCE(visibility, ''), + token_count, token_counted_at FROM codechunking.code_chunks WHERE repository_id = $1 AND deleted_at IS NULL ORDER BY file_path, start_line @@ -683,6 +700,8 @@ func (r *PostgreSQLChunkRepository) GetChunksForRepository( &chunk.QualifiedName, &chunk.Signature, &chunk.Visibility, + &chunk.TokenCount, + &chunk.TokenCountedAt, ) if err != nil { slogger.Error(ctx, "Failed to scan chunk row", slogger.Fields2( @@ -1581,6 +1600,8 @@ func (r *PostgreSQLChunkRepository) SaveChunkWithEmbedding( fields.QualifiedName, fields.Signature, fields.Visibility, + chunk.TokenCount, + chunk.TokenCountedAt, ).Scan(&actualChunkID) if err != nil { 
slogger.Error(ctx, "Failed to save chunk in transaction", slogger.Fields{ @@ -1713,6 +1734,53 @@ func (r *PostgreSQLChunkRepository) saveEmbeddingInTx( return nil } +// UpdateTokenCounts updates the token count for multiple chunks in a batch operation. +func (r *PostgreSQLChunkRepository) UpdateTokenCounts(ctx context.Context, updates []outbound.ChunkTokenUpdate) error { + if len(updates) == 0 { + return nil + } + + tx, err := r.pool.Begin(ctx) + if err != nil { + slogger.Error(ctx, "Failed to begin transaction for UpdateTokenCounts", slogger.Field("error", err.Error())) + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if err := tx.Rollback(ctx); err != nil && !errors.Is(err, pgx.ErrTxClosed) { + slogger.Warn(ctx, "Failed to rollback UpdateTokenCounts transaction", slogger.Field("error", err.Error())) + } + }() + + query := ` + UPDATE codechunking.code_chunks + SET token_count = $1, token_counted_at = $2 + WHERE id = $3 AND deleted_at IS NULL + ` + + for _, update := range updates { + _, err := tx.Exec(ctx, query, update.TokenCount, update.TokenCountedAt, update.ChunkID) + if err != nil { + slogger.Error(ctx, "Failed to update token count", slogger.Fields{ + "chunk_id": update.ChunkID.String(), + "token_count": update.TokenCount, + "error": err.Error(), + }) + return fmt.Errorf("failed to update token count for chunk %s: %w", update.ChunkID.String(), err) + } + } + + if err := tx.Commit(ctx); err != nil { + slogger.Error(ctx, "Failed to commit UpdateTokenCounts transaction", slogger.Fields2( + "update_count", len(updates), + "error", err.Error(), + )) + return fmt.Errorf("failed to commit transaction: %w", err) + } + + slogger.Debug(ctx, "Token counts updated successfully", slogger.Field("update_count", len(updates))) + return nil +} + // SaveChunksWithEmbeddings stores multiple chunks and embeddings in a single transaction. 
func (r *PostgreSQLChunkRepository) SaveChunksWithEmbeddings( ctx context.Context, @@ -1852,6 +1920,8 @@ func (r *PostgreSQLChunkRepository) SaveChunksWithEmbeddings( qualifiedName, signature, visibility, + chunk.TokenCount, + chunk.TokenCountedAt, ).Scan(&actualChunkID) if err != nil { slogger.Error(ctx, "Failed to save chunk in batch transaction", slogger.Fields{ diff --git a/internal/adapter/outbound/repository/chunk_repository_token_count_test.go b/internal/adapter/outbound/repository/chunk_repository_token_count_test.go new file mode 100644 index 0000000..ddfc63a --- /dev/null +++ b/internal/adapter/outbound/repository/chunk_repository_token_count_test.go @@ -0,0 +1,606 @@ +//go:build integration + +package repository + +import ( + "codechunking/internal/port/outbound" + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestChunkRepository_SaveWithTokenCount tests saving chunks with token count information. 
+func TestChunkRepository_SaveWithTokenCount(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + pool := getTestPool(t) + defer pool.Close() + + repo := NewPostgreSQLChunkRepository(pool) + ctx := context.Background() + + // Create test repository + repositoryID := uuid.New() + testURL := "https://github.com/test/token-count-repo-" + repositoryID.String()[:8] + + _, err := pool.Exec(ctx, ` + INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status) + VALUES ($1, $2, $3, $4, $5, $6) + `, repositoryID, testURL, testURL, "token-count-test-repo", "Test repository for token counting", "indexed") + require.NoError(t, err) + + tests := []struct { + name string + chunk *outbound.CodeChunk + expectTokenCount int + expectCountedAt bool + description string + }{ + { + name: "save chunk with token count", + chunk: &outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test_with_count.go", + Content: "func TestWithTokenCount() { fmt.Println(\"test\") }", + Language: "go", + StartLine: 1, + EndLine: 3, + Hash: "hash-with-count-1", + Type: "function", + EntityName: "TestWithTokenCount", + CreatedAt: time.Now(), + TokenCount: 42, + TokenCountedAt: func() *time.Time { + t := time.Now() + return &t + }(), + }, + expectTokenCount: 42, + expectCountedAt: true, + description: "Saving a chunk with token_count should persist the value", + }, + { + name: "save chunk without token count", + chunk: &outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test_without_count.go", + Content: "func TestWithoutTokenCount() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-without-count-1", + Type: "function", + EntityName: "TestWithoutTokenCount", + CreatedAt: time.Now(), + TokenCount: 0, + TokenCountedAt: nil, + }, + expectTokenCount: 0, + expectCountedAt: false, + description: "Saving a chunk without token_count should leave it 
NULL", + }, + { + name: "save chunk with zero token count but timestamp", + chunk: &outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test_zero_count.go", + Content: "", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-zero-count-1", + Type: "fragment", + EntityName: "", + CreatedAt: time.Now(), + TokenCount: 0, + TokenCountedAt: func() *time.Time { + t := time.Now() + return &t + }(), + }, + expectTokenCount: 0, + expectCountedAt: true, + description: "Saving a chunk with zero token_count but timestamp should persist timestamp", + }, + { + name: "save chunk with large token count", + chunk: &outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test_large_count.go", + Content: "// Very large function with many tokens\nfunc TestLargeTokenCount() { /* ... */ }", + Language: "go", + StartLine: 1, + EndLine: 50, + Hash: "hash-large-count-1", + Type: "function", + EntityName: "TestLargeTokenCount", + CreatedAt: time.Now(), + TokenCount: 8192, + TokenCountedAt: func() *time.Time { + t := time.Now() + return &t + }(), + }, + expectTokenCount: 8192, + expectCountedAt: true, + description: "Saving a chunk with large token_count should handle large values", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save the chunk + err := repo.SaveChunk(ctx, tt.chunk) + require.NoError(t, err, tt.description) + + // Retrieve and verify + chunkID, err := uuid.Parse(tt.chunk.ID) + require.NoError(t, err) + + var tokenCount *int + var tokenCountedAt *time.Time + + err = pool.QueryRow(ctx, ` + SELECT token_count, token_counted_at + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount, &tokenCountedAt) + require.NoError(t, err) + + // Verify token count + if tt.expectTokenCount > 0 { + require.NotNil(t, tokenCount, "token_count should not be NULL") + assert.Equal(t, tt.expectTokenCount, *tokenCount, "token_count should 
match expected value") + } else if tokenCount != nil { + assert.Equal(t, 0, *tokenCount, "token_count should be 0 if set") + } + + // Verify timestamp + if tt.expectCountedAt { + require.NotNil(t, tokenCountedAt, "token_counted_at should not be NULL") + // Verify timestamp is recent (within last minute) + assert.WithinDuration(t, time.Now(), *tokenCountedAt, time.Minute, "token_counted_at should be recent") + } else { + assert.Nil(t, tokenCountedAt, "token_counted_at should be NULL") + } + }) + } + + // Cleanup + t.Cleanup(func() { + pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID) + pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID) + }) +} + +// TestChunkRepository_UpdateTokenCount tests updating token count on existing chunks. +func TestChunkRepository_UpdateTokenCount(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + pool := getTestPool(t) + defer pool.Close() + + repo := NewPostgreSQLChunkRepository(pool) + ctx := context.Background() + + // Create test repository + repositoryID := uuid.New() + testURL := "https://github.com/test/update-token-repo-" + repositoryID.String()[:8] + + _, err := pool.Exec(ctx, ` + INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status) + VALUES ($1, $2, $3, $4, $5, $6) + `, repositoryID, testURL, testURL, "update-token-test-repo", "Test repository for updating token counts", "indexed") + require.NoError(t, err) + + // Create initial chunk without token count + initialChunk := &outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test_update.go", + Content: "func TestUpdate() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-update-1", + Type: "function", + EntityName: "TestUpdate", + CreatedAt: time.Now(), + TokenCount: 0, + TokenCountedAt: nil, + } + + err = repo.SaveChunk(ctx, initialChunk) + require.NoError(t, err) 
+ + // Verify initial state (no token count) + chunkID, err := uuid.Parse(initialChunk.ID) + require.NoError(t, err) + + var tokenCount *int + var tokenCountedAt *time.Time + + err = pool.QueryRow(ctx, ` + SELECT token_count, token_counted_at + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount, &tokenCountedAt) + require.NoError(t, err) + assert.Nil(t, tokenCountedAt, "initial token_counted_at should be NULL") + + // Update chunk with token count + now := time.Now() + updatedChunk := &outbound.CodeChunk{ + ID: initialChunk.ID, + RepositoryID: repositoryID, + FilePath: "test_update.go", + Content: "func TestUpdate() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-update-1", + Type: "function", + EntityName: "TestUpdate", + CreatedAt: initialChunk.CreatedAt, + TokenCount: 15, + TokenCountedAt: &now, + } + + err = repo.SaveChunk(ctx, updatedChunk) + require.NoError(t, err, "Updating token_count on existing chunk should work") + + // Verify updated state + err = pool.QueryRow(ctx, ` + SELECT token_count, token_counted_at + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount, &tokenCountedAt) + require.NoError(t, err) + + require.NotNil(t, tokenCount, "updated token_count should not be NULL") + assert.Equal(t, 15, *tokenCount, "token_count should be updated to new value") + require.NotNil(t, tokenCountedAt, "updated token_counted_at should not be NULL") + assert.WithinDuration(t, now, *tokenCountedAt, time.Second, "token_counted_at should match update time") + + // Cleanup + t.Cleanup(func() { + pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID) + pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID) + }) +} + +// TestChunkRepository_GetChunksWithTokenCount tests retrieving chunks with token count information. 
+func TestChunkRepository_GetChunksWithTokenCount(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + pool := getTestPool(t) + defer pool.Close() + + repo := NewPostgreSQLChunkRepository(pool) + ctx := context.Background() + + // Create test repository + repositoryID := uuid.New() + testURL := "https://github.com/test/get-token-repo-" + repositoryID.String()[:8] + + _, err := pool.Exec(ctx, ` + INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status) + VALUES ($1, $2, $3, $4, $5, $6) + `, repositoryID, testURL, testURL, "get-token-test-repo", "Test repository for getting token counts", "indexed") + require.NoError(t, err) + + // Create chunks with varying token counts + now := time.Now() + chunks := []outbound.CodeChunk{ + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "with_count_1.go", + Content: "func Test1() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-get-1", + Type: "function", + EntityName: "Test1", + CreatedAt: now, + TokenCount: 10, + TokenCountedAt: &now, + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "without_count.go", + Content: "func Test2() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-get-2", + Type: "function", + EntityName: "Test2", + CreatedAt: now, + TokenCount: 0, + TokenCountedAt: nil, + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "with_count_2.go", + Content: "func Test3() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-get-3", + Type: "function", + EntityName: "Test3", + CreatedAt: now, + TokenCount: 25, + TokenCountedAt: &now, + }, + } + + // Save all chunks + err = repo.SaveChunks(ctx, chunks) + require.NoError(t, err) + + t.Run("GetChunk should include token_count if set", func(t *testing.T) { + chunkID, err := uuid.Parse(chunks[0].ID) + require.NoError(t, err) + + chunk, err := repo.GetChunk(ctx, chunkID) + 
require.NoError(t, err, "Retrieved chunk should include token_count if set") + + assert.Equal(t, 10, chunk.TokenCount, "token_count should be 10") + require.NotNil(t, chunk.TokenCountedAt, "token_counted_at should not be nil") + assert.WithinDuration(t, now, *chunk.TokenCountedAt, time.Second, "token_counted_at should match saved time") + }) + + t.Run("GetChunk should have nil/zero token_count if not set", func(t *testing.T) { + chunkID, err := uuid.Parse(chunks[1].ID) + require.NoError(t, err) + + chunk, err := repo.GetChunk(ctx, chunkID) + require.NoError(t, err, "Retrieved chunk should have nil/zero token_count if not set") + + assert.Equal(t, 0, chunk.TokenCount, "token_count should be 0 or nil") + assert.Nil(t, chunk.TokenCountedAt, "token_counted_at should be nil") + }) + + t.Run("GetChunksForRepository should include token_count information", func(t *testing.T) { + retrievedChunks, err := repo.GetChunksForRepository(ctx, repositoryID) + require.NoError(t, err, "Retrieved chunks should include token_count if set") + require.Len(t, retrievedChunks, 3, "should retrieve all 3 chunks") + + // Verify token counts are preserved + tokenCountMap := make(map[string]int) + for _, chunk := range retrievedChunks { + tokenCountMap[chunk.ID] = chunk.TokenCount + } + + assert.Equal(t, 10, tokenCountMap[chunks[0].ID], "first chunk should have token_count 10") + assert.Equal(t, 0, tokenCountMap[chunks[1].ID], "second chunk should have token_count 0") + assert.Equal(t, 25, tokenCountMap[chunks[2].ID], "third chunk should have token_count 25") + }) + + // Cleanup + t.Cleanup(func() { + pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID) + pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID) + }) +} + +// TestChunkRepository_BatchOperationsWithTokenCount tests batch operations with token counts. 
+func TestChunkRepository_BatchOperationsWithTokenCount(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + pool := getTestPool(t) + defer pool.Close() + + repo := NewPostgreSQLChunkRepository(pool) + ctx := context.Background() + + // Create test repository + repositoryID := uuid.New() + testURL := "https://github.com/test/batch-token-repo-" + repositoryID.String()[:8] + + _, err := pool.Exec(ctx, ` + INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status) + VALUES ($1, $2, $3, $4, $5, $6) + `, repositoryID, testURL, testURL, "batch-token-test-repo", "Test repository for batch token operations", "indexed") + require.NoError(t, err) + + t.Run("SaveChunks should preserve token_count for all chunks", func(t *testing.T) { + now := time.Now() + chunks := make([]outbound.CodeChunk, 10) + for i := range chunks { + chunks[i] = outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "batch_test.go", + Content: "func Test() { }", + Language: "go", + StartLine: i + 1, + EndLine: i + 1, + Hash: "hash-batch-" + uuid.New().String()[:8], + Type: "function", + EntityName: "Test", + CreatedAt: now, + TokenCount: (i + 1) * 10, // 10, 20, 30, ... 
+ TokenCountedAt: &now, + } + } + + err := repo.SaveChunks(ctx, chunks) + require.NoError(t, err, "Batch save should preserve token counts") + + // Verify all chunks have correct token counts + for i, chunk := range chunks { + chunkID, err := uuid.Parse(chunk.ID) + require.NoError(t, err) + + var tokenCount *int + err = pool.QueryRow(ctx, ` + SELECT token_count + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount) + require.NoError(t, err) + + expectedCount := (i + 1) * 10 + require.NotNil(t, tokenCount, "chunk %d should have token_count", i) + assert.Equal(t, expectedCount, *tokenCount, "chunk %d should have correct token_count", i) + } + }) + + // Cleanup + t.Cleanup(func() { + pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID) + pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID) + }) +} + +// TestChunkRepository_FindOrCreateChunksWithTokenCount tests FindOrCreateChunks with token counts. 
+func TestChunkRepository_FindOrCreateChunksWithTokenCount(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + pool := getTestPool(t) + defer pool.Close() + + repo := NewPostgreSQLChunkRepository(pool) + ctx := context.Background() + + // Create test repository + repositoryID := uuid.New() + testURL := "https://github.com/test/find-create-token-repo-" + repositoryID.String()[:8] + + _, err := pool.Exec(ctx, ` + INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status) + VALUES ($1, $2, $3, $4, $5, $6) + `, repositoryID, testURL, testURL, "find-create-token-test-repo", "Test repository for find/create with token counts", "indexed") + require.NoError(t, err) + + now := time.Now() + + t.Run("FindOrCreateChunks should preserve token_count on new chunks", func(t *testing.T) { + chunks := []outbound.CodeChunk{ + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "find_create_new.go", + Content: "func New() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-find-create-new", + Type: "function", + EntityName: "New", + CreatedAt: now, + TokenCount: 50, + TokenCountedAt: &now, + }, + } + + resultChunks, err := repo.FindOrCreateChunks(ctx, chunks) + require.NoError(t, err, "FindOrCreateChunks should work with token_count") + require.Len(t, resultChunks, 1) + + // Verify token count was saved + chunkID, err := uuid.Parse(resultChunks[0].ID) + require.NoError(t, err) + + var tokenCount *int + err = pool.QueryRow(ctx, ` + SELECT token_count + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount) + require.NoError(t, err) + + require.NotNil(t, tokenCount, "token_count should be saved") + assert.Equal(t, 50, *tokenCount, "token_count should match") + }) + + t.Run("FindOrCreateChunks should return existing chunk token_count", func(t *testing.T) { + // Create initial chunk with token count + initialChunk := 
outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "find_create_existing.go", + Content: "func Existing() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-find-create-existing", + Type: "function", + EntityName: "Existing", + CreatedAt: now, + TokenCount: 100, + TokenCountedAt: &now, + } + + err := repo.SaveChunk(ctx, &initialChunk) + require.NoError(t, err) + + // Try to create duplicate with different token count + duplicateChunk := []outbound.CodeChunk{ + { + ID: uuid.New().String(), // Different ID + RepositoryID: repositoryID, + FilePath: "find_create_existing.go", + Content: "func Existing() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-find-create-existing", // Same hash + Type: "function", + EntityName: "Existing", + CreatedAt: now, + TokenCount: 999, // Different token count (should be ignored) + TokenCountedAt: &now, + }, + } + + resultChunks, err := repo.FindOrCreateChunks(ctx, duplicateChunk) + require.NoError(t, err, "FindOrCreateChunks should return existing chunk") + require.Len(t, resultChunks, 1) + + // Verify original token count is preserved + chunkID, err := uuid.Parse(resultChunks[0].ID) + require.NoError(t, err) + + var tokenCount *int + err = pool.QueryRow(ctx, ` + SELECT token_count + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount) + require.NoError(t, err) + + require.NotNil(t, tokenCount, "existing token_count should be preserved") + assert.Equal(t, 100, *tokenCount, "original token_count should be preserved, not updated") + }) + + // Cleanup + t.Cleanup(func() { + pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID) + pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID) + }) +} diff --git a/internal/application/service/search_service_test.go b/internal/application/service/search_service_test.go index 75f2753..a701c05 100644 --- 
a/internal/application/service/search_service_test.go +++ b/internal/application/service/search_service_test.go @@ -145,6 +145,30 @@ func (m *MockEmbeddingService) EstimateTokenCount(ctx context.Context, text stri return args.Int(0), args.Error(1) } +func (m *MockEmbeddingService) CountTokens( + ctx context.Context, + text string, + model string, +) (*outbound.TokenCountResult, error) { + args := m.Called(ctx, text, model) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*outbound.TokenCountResult), args.Error(1) +} + +func (m *MockEmbeddingService) CountTokensBatch( + ctx context.Context, + texts []string, + model string, +) ([]*outbound.TokenCountResult, error) { + args := m.Called(ctx, texts, model) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]*outbound.TokenCountResult), args.Error(1) +} + // MockChunkRepository is a mock implementation for retrieving chunk information. type MockChunkRepository struct { mock.Mock diff --git a/internal/application/worker/job_processor.go b/internal/application/worker/job_processor.go index 0139321..5b2d727 100644 --- a/internal/application/worker/job_processor.go +++ b/internal/application/worker/job_processor.go @@ -1413,6 +1413,180 @@ func (p *DefaultJobProcessor) storeSingleEmbeddingWithErrorHandling( return nil } +// countTokensForChunks counts tokens for chunks based on configuration mode. +// Modes: "all" - count all chunks, "sample" - count X% of chunks, "on_demand" - skip counting. +// Errors are logged but do not fail the job (graceful degradation). 
+func (p *DefaultJobProcessor) countTokensForChunks(ctx context.Context, chunks []outbound.CodeChunk) { + if !p.batchConfig.TokenCounting.Enabled { + return + } + + mode := p.batchConfig.TokenCounting.Mode + if mode == "on_demand" { + slogger.Debug(ctx, "Token counting mode is on_demand, skipping", slogger.Field("mode", mode)) + return + } + + // Select chunks to count based on mode + var chunksToCount []outbound.CodeChunk + switch mode { + case "all": + chunksToCount = chunks + case "sample": + samplePercent := p.batchConfig.TokenCounting.SamplePercent + if samplePercent <= 0 || samplePercent > 100 { + slogger.Warn( + ctx, + "Invalid sample_percent, skipping token counting", + slogger.Field("sample_percent", samplePercent), + ) + return + } + sampleSize := (len(chunks) * samplePercent) / 100 + if sampleSize == 0 { + sampleSize = 1 + } + chunksToCount = chunks[:sampleSize] + default: + slogger.Warn(ctx, "Unknown token counting mode, skipping", slogger.Field("mode", mode)) + return + } + + if len(chunksToCount) == 0 { + return + } + + slogger.Info(ctx, "Starting token counting", slogger.Fields{ + "mode": mode, + "total_chunks": len(chunks), + "chunks_to_count": len(chunksToCount), + "sample_percent": p.batchConfig.TokenCounting.SamplePercent, + }) + + // Extract texts for counting + texts := make([]string, len(chunksToCount)) + for i, chunk := range chunksToCount { + texts[i] = chunk.Content + } + + // Call CountTokensBatch + results, err := p.embeddingService.CountTokensBatch(ctx, texts, "gemini-embedding-001") + if err != nil { + slogger.Warn( + ctx, + "Token counting failed, continuing with embedding generation", + slogger.Field("error", err.Error()), + ) + return + } + + // Build updates + now := time.Now() + updates := make([]outbound.ChunkTokenUpdate, len(chunksToCount)) + totalTokens := 0 + oversizedChunks := 0 + maxTokensPerChunk := p.batchConfig.TokenCounting.MaxTokensPerChunk + if maxTokensPerChunk <= 0 { + maxTokensPerChunk = 8192 // Default to Gemini 
embedding model limit + } + + for i, result := range results { + chunkID, parseErr := uuid.Parse(chunksToCount[i].ID) + if parseErr != nil { + slogger.Warn(ctx, "Invalid chunk ID for token count update", slogger.Fields{ + "chunk_id": chunksToCount[i].ID, + "error": parseErr.Error(), + }) + continue + } + updates[i] = outbound.ChunkTokenUpdate{ + ChunkID: chunkID, + TokenCount: result.TotalTokens, + TokenCountedAt: &now, + } + totalTokens += result.TotalTokens + + // Check for oversized chunks + if result.TotalTokens > maxTokensPerChunk { + oversizedChunks++ + slogger.Warn(ctx, "Chunk exceeds max token limit", slogger.Fields{ + "chunk_id": chunkID.String(), + "token_count": result.TotalTokens, + "max_tokens_per_chunk": maxTokensPerChunk, + "file_path": chunksToCount[i].FilePath, + }) + } + } + + // Update repository + if err := p.chunkStorageRepo.UpdateTokenCounts(ctx, updates); err != nil { + slogger.Warn(ctx, "Failed to persist token counts, continuing", slogger.Field("error", err.Error())) + return + } + + avgTokens := 0 + if len(updates) > 0 { + avgTokens = totalTokens / len(updates) + } + + // Emit metrics + p.emitTokenCountingMetrics(ctx, len(chunksToCount), totalTokens, avgTokens, oversizedChunks) + + slogger.Info(ctx, "Token counting completed successfully", slogger.Fields{ + "chunks_counted": len(updates), + "total_tokens": totalTokens, + "average_tokens": avgTokens, + "oversized_chunks": oversizedChunks, + }) +} + +// emitTokenCountingMetrics emits OpenTelemetry metrics for token counting operations. +// This provides observability into token usage, API call patterns, and oversized chunk detection. 
+func (p *DefaultJobProcessor) emitTokenCountingMetrics( + ctx context.Context, + chunksProcessed int, + totalTokens int, + avgTokens int, + oversizedChunks int, +) { + // Emit total tokens counted (counter metric) + slogger.Info(ctx, "Token counting metrics", slogger.Fields{ + "metric_name": "codechunking_tokens_counted_total", + "metric_type": "counter", + "metric_value": totalTokens, + "chunks_processed": chunksProcessed, + }) + + // Emit token count API calls (counter metric) + // Each batch API call processes multiple chunks, so we emit 1 API call per operation + slogger.Info(ctx, "Token counting API metrics", slogger.Fields{ + "metric_name": "codechunking_token_count_api_calls_total", + "metric_type": "counter", + "metric_value": 1, + "chunks_processed": chunksProcessed, + }) + + // Emit tokens per chunk distribution (histogram metric) + // Log the average as a representative sample for histogram + slogger.Info(ctx, "Token distribution metrics", slogger.Fields{ + "metric_name": "codechunking_tokens_per_chunk", + "metric_type": "histogram", + "metric_value": avgTokens, + "total_tokens": totalTokens, + "chunks_processed": chunksProcessed, + }) + + // Emit oversized chunks counter + if oversizedChunks > 0 { + slogger.Info(ctx, "Oversized chunk metrics", slogger.Fields{ + "metric_name": "codechunking_chunks_over_limit_total", + "metric_type": "counter", + "metric_value": oversizedChunks, + "chunks_processed": chunksProcessed, + }) + } +} + // generateEmbeddingsWithBatch creates embeddings for code chunks using batch processing // with smart routing decisions based on repository size and available services. 
func (p *DefaultJobProcessor) generateEmbeddingsWithBatch( @@ -1495,6 +1669,10 @@ func (p *DefaultJobProcessor) processBatchEmbeddings( ) error { // Convert UUID to string for backward compatibility with existing code jobID := indexingJobID.String() + + // PRE-FLIGHT: Count tokens before embedding generation (non-blocking) + p.countTokensForChunks(ctx, chunks) + // Check context before starting batch processing if err := ctx.Err(); err != nil { return fmt.Errorf("context cancelled before batch processing: %w", err) @@ -1763,6 +1941,10 @@ func (p *DefaultJobProcessor) processSequentialEmbeddings( ) error { // Convert UUID to string for backward compatibility jobID := indexingJobID.String() + + // PRE-FLIGHT: Count tokens before embedding generation (non-blocking) + p.countTokensForChunks(ctx, chunks) + if p.batchConfig.UseTestEmbeddings { // Test mode: use test embeddings return p.processEmbeddingsWithGenerator(ctx, jobID, repositoryID, chunks, execution, diff --git a/internal/application/worker/job_processor_test.go b/internal/application/worker/job_processor_test.go index 2372c16..4701e57 100644 --- a/internal/application/worker/job_processor_test.go +++ b/internal/application/worker/job_processor_test.go @@ -324,6 +324,34 @@ func (m *MockEmbeddingService) EstimateTokenCount(ctx context.Context, text stri return args.Int(0), args.Error(1) } +func (m *MockEmbeddingService) CountTokens( + ctx context.Context, + text string, + model string, +) (*outbound.TokenCountResult, error) { + args := m.Called(ctx, text, model) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*outbound.TokenCountResult), args.Error(1) +} + +func (m *MockEmbeddingService) CountTokensBatch( + ctx context.Context, + texts []string, + model string, +) ([]*outbound.TokenCountResult, error) { + args := m.Called(ctx, texts, model) + if args.Get(0) == nil { + return nil, args.Error(1) + } + // Support both direct values and function returns + if fn, ok := 
args.Get(0).(func(context.Context, []string, string) []*outbound.TokenCountResult); ok { + return fn(ctx, texts, model), args.Error(1) + } + return args.Get(0).([]*outbound.TokenCountResult), args.Error(1) +} + // MockChunkStorageRepository mocks the chunk storage repository interface. type MockChunkStorageRepository struct { mock.Mock @@ -388,6 +416,11 @@ func (m *MockChunkStorageRepository) CountChunksForRepository( return args.Int(0), args.Error(1) } +func (m *MockChunkStorageRepository) UpdateTokenCounts(ctx context.Context, updates []outbound.ChunkTokenUpdate) error { + args := m.Called(ctx, updates) + return args.Error(0) +} + func (m *MockChunkStorageRepository) SaveEmbedding(ctx context.Context, embedding *outbound.Embedding) error { args := m.Called(ctx, embedding) return args.Error(0) diff --git a/internal/application/worker/job_processor_token_counting_test.go b/internal/application/worker/job_processor_token_counting_test.go new file mode 100644 index 0000000..6d792ef --- /dev/null +++ b/internal/application/worker/job_processor_token_counting_test.go @@ -0,0 +1,631 @@ +package worker + +import ( + "codechunking/internal/config" + "codechunking/internal/port/outbound" + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// TestJobProcessor_CountTokensForChunks tests that the job processor calls CountTokensBatch +// on the embedding service before generating embeddings. 
+func TestJobProcessor_CountTokensForChunks(t *testing.T) { + t.Parallel() + + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + // Create test chunks + chunks := []outbound.CodeChunk{ + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "func main() { fmt.Println(\"hello\") }", + StartLine: 1, + EndLine: 1, + Type: "function", + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "type User struct { Name string }", + StartLine: 3, + EndLine: 3, + Type: "type", + }, + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + mockRepoRepo := &MockRepositoryRepository{} + mockIndexingJobRepo := &MockIndexingJobRepository{} + mockGitClient := &MockEnhancedGitClient{} + mockCodeParser := &MockCodeParser{} + + // Expected token count results + expectedTokenCounts := []*outbound.TokenCountResult{ + { + TotalTokens: 10, + Model: "gemini-embedding-001", + }, + { + TotalTokens: 6, + Model: "gemini-embedding-001", + }, + } + + // Extract texts from chunks for CountTokensBatch call + expectedTexts := []string{ + chunks[0].Content, + chunks[1].Content, + } + + // Mock expectations + // 1. CountTokensBatch should be called BEFORE GenerateBatchEmbeddings + mockEmbeddingService.On("CountTokensBatch", + mock.Anything, // ctx + expectedTexts, + "gemini-embedding-001", + ).Return(expectedTokenCounts, nil).Once() + + // 2. 
Chunks should be updated with token counts + mockChunkRepo.On("UpdateTokenCounts", + mock.Anything, // ctx + mock.MatchedBy(func(updates []outbound.ChunkTokenUpdate) bool { + // Verify that we're updating both chunks with the correct token counts + if len(updates) != 2 { + return false + } + // Check first chunk + if updates[0].TokenCount != 10 { + return false + } + // Check second chunk + if updates[1].TokenCount != 6 { + return false + } + return true + }), + ).Return(nil).Once() + + // 3. Then embedding generation should proceed + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, // ctx + expectedTexts, + mock.Anything, // options + ).Return([]*outbound.EmbeddingResult{ + {Vector: make([]float64, 768), Dimensions: 768, Model: "gemini-embedding-001"}, + {Vector: make([]float64, 768), Dimensions: 768, Model: "gemini-embedding-001"}, + }, nil).Once() + + // Mock chunk storage + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, // ctx + mock.Anything, // chunk + mock.Anything, // embedding + ).Return(nil) + + // Create processor config + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: true, + Mode: "all", + SamplePercent: 10, + MaxTokensPerChunk: 8192, + }, + } + + // Create processor + processor := NewDefaultJobProcessor( + processorConfig, + mockIndexingJobRepo, + mockRepoRepo, + mockGitClient, + mockCodeParser, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute - this should call token counting before embeddings + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert + assert.NoError(t, err, "Expected no error from 
generateEmbeddings") + mockEmbeddingService.AssertExpectations(t) + mockChunkRepo.AssertExpectations(t) + + // Verify call order: CountTokensBatch should be called before GenerateBatchEmbeddings + calls := mockEmbeddingService.Calls + countTokensIdx, generateEmbeddingsIdx := -1, -1 + for i, call := range calls { + if call.Method == "CountTokensBatch" { + countTokensIdx = i + } + if call.Method == "GenerateBatchEmbeddings" { + generateEmbeddingsIdx = i + } + } + + assert.NotEqual(t, -1, countTokensIdx, "CountTokensBatch should have been called") + assert.NotEqual(t, -1, generateEmbeddingsIdx, "GenerateBatchEmbeddings should have been called") + assert.Less( + t, + countTokensIdx, + generateEmbeddingsIdx, + "CountTokensBatch must be called BEFORE GenerateBatchEmbeddings", + ) +} + +// TestJobProcessor_TokenCountingModes tests different token counting modes. +func TestJobProcessor_TokenCountingModes(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + mode string + samplePercent int + totalChunks int + expectedCountCall bool + expectedSamples int + }{ + { + name: "Mode all - count all chunks", + mode: "all", + samplePercent: 0, + totalChunks: 10, + expectedCountCall: true, + expectedSamples: 10, + }, + { + name: "Mode sample - count 10% of chunks", + mode: "sample", + samplePercent: 10, + totalChunks: 100, + expectedCountCall: true, + expectedSamples: 10, + }, + { + name: "Mode sample - count 20% of chunks", + mode: "sample", + samplePercent: 20, + totalChunks: 50, + expectedCountCall: true, + expectedSamples: 10, + }, + { + name: "Mode on_demand - skip token counting", + mode: "on_demand", + samplePercent: 0, + totalChunks: 10, + expectedCountCall: false, + expectedSamples: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + // Create test chunks + chunks := make([]outbound.CodeChunk, 
tc.totalChunks) + for i := range tc.totalChunks { + chunks[i] = outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "func test() {}", + StartLine: i + 1, + EndLine: i + 1, + Type: "function", + } + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + + // Configure expectations based on mode + if tc.expectedCountCall { + // CountTokensBatch should be called with sampled chunks + mockEmbeddingService.On("CountTokensBatch", + mock.Anything, + mock.MatchedBy(func(texts []string) bool { + return len(texts) == tc.expectedSamples + }), + "gemini-embedding-001", + ).Return(func(ctx context.Context, texts []string, model string) []*outbound.TokenCountResult { + results := make([]*outbound.TokenCountResult, len(texts)) + for i := range results { + results[i] = &outbound.TokenCountResult{ + TotalTokens: 5, + Model: model, + } + } + return results + }, nil).Once() + + // UpdateTokenCounts should be called for sampled chunks + mockChunkRepo.On("UpdateTokenCounts", + mock.Anything, + mock.MatchedBy(func(updates []outbound.ChunkTokenUpdate) bool { + return len(updates) == tc.expectedSamples + }), + ).Return(nil).Once() + } + + // Embedding generation should always proceed - create results that match chunk count + generatedResults := make([]*outbound.EmbeddingResult, tc.totalChunks) + for i := range generatedResults { + generatedResults[i] = &outbound.EmbeddingResult{ + Vector: make([]float64, 768), + Dimensions: 768, + Model: "gemini-embedding-001", + } + } + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(generatedResults, nil) + + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(nil) + + // Create processor with token counting config + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + 
MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: tc.mode != "on_demand", + Mode: tc.mode, + SamplePercent: tc.samplePercent, + }, + } + + // Create processor (this will fail until implementation exists) + processor := NewDefaultJobProcessor( + processorConfig, + &MockIndexingJobRepository{}, + &MockRepositoryRepository{}, + &MockEnhancedGitClient{}, + &MockCodeParser{}, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert + assert.NoError(t, err) + mockEmbeddingService.AssertExpectations(t) + mockChunkRepo.AssertExpectations(t) + }) + } +} + +// TestJobProcessor_TokenCountingFailsGracefully tests that token counting failures +// don't block embedding generation. 
+func TestJobProcessor_TokenCountingFailsGracefully(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + tokenCountErr error + expectWarning bool + expectSuccess bool + }{ + { + name: "Network error - warn and continue", + tokenCountErr: errors.New("network timeout"), + expectWarning: true, + expectSuccess: true, + }, + { + name: "API error - warn and continue", + tokenCountErr: errors.New("API rate limit exceeded"), + expectWarning: true, + expectSuccess: true, + }, + { + name: "Quota exceeded - warn and continue", + tokenCountErr: errors.New("quota exceeded"), + expectWarning: true, + expectSuccess: true, + }, + { + name: "Success - no warning", + tokenCountErr: nil, + expectWarning: false, + expectSuccess: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "func test() {}", + StartLine: 1, + EndLine: 1, + Type: "function", + }, + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + + // Mock CountTokensBatch - may fail or succeed + if tc.tokenCountErr != nil { + mockEmbeddingService.On("CountTokensBatch", + mock.Anything, + mock.Anything, + "gemini-embedding-001", + ).Return(nil, tc.tokenCountErr).Once() + + // UpdateTokenCounts should NOT be called on error + // (no expectation set means test fails if called) + } else { + mockEmbeddingService.On("CountTokensBatch", + mock.Anything, + mock.Anything, + "gemini-embedding-001", + ).Return([]*outbound.TokenCountResult{ + {TotalTokens: 5, Model: "gemini-embedding-001"}, + }, nil).Once() + + mockChunkRepo.On("UpdateTokenCounts", + mock.Anything, + mock.Anything, + ).Return(nil).Once() + } + + // Embedding generation MUST always proceed regardless of 
token counting result + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return([]*outbound.EmbeddingResult{ + {Vector: make([]float64, 768), Dimensions: 768, Model: "gemini-embedding-001"}, + }, nil).Once() + + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(nil) + + // Create processor + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: true, + Mode: "all", + }, + } + + processor := NewDefaultJobProcessor( + processorConfig, + &MockIndexingJobRepository{}, + &MockRepositoryRepository{}, + &MockEnhancedGitClient{}, + &MockCodeParser{}, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert + if tc.expectSuccess { + assert.NoError(t, err, "Embedding generation should succeed even if token counting fails") + } else { + assert.Error(t, err, "Expected error") + } + + // Verify that GenerateBatchEmbeddings was called (embeddings must proceed) + mockEmbeddingService.AssertCalled(t, "GenerateBatchEmbeddings", mock.Anything, mock.Anything, mock.Anything) + mockEmbeddingService.AssertExpectations(t) + }) + } +} + +// TestJobProcessor_TokenCountingUpdatesPersistence tests that token counts +// are persisted to the chunk repository. 
+func TestJobProcessor_TokenCountingUpdatesPersistence(t *testing.T) { + t.Parallel() + + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + chunk1ID := uuid.New() + chunk2ID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: chunk1ID.String(), + RepositoryID: repositoryID, + FilePath: "file1.go", + Content: "package main", + StartLine: 1, + EndLine: 1, + Type: "package", + }, + { + ID: chunk2ID.String(), + RepositoryID: repositoryID, + FilePath: "file2.go", + Content: "func test() { println(\"hello world\") }", + StartLine: 1, + EndLine: 1, + Type: "function", + }, + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + + // Token counting returns specific counts + mockEmbeddingService.On("CountTokensBatch", + mock.Anything, + []string{chunks[0].Content, chunks[1].Content}, + "gemini-embedding-001", + ).Return([]*outbound.TokenCountResult{ + {TotalTokens: 2, Model: "gemini-embedding-001"}, + {TotalTokens: 8, Model: "gemini-embedding-001"}, + }, nil).Once() + + // Verify that UpdateTokenCounts is called with correct values + var capturedUpdates []outbound.ChunkTokenUpdate + mockChunkRepo.On("UpdateTokenCounts", + mock.Anything, + mock.MatchedBy(func(updates []outbound.ChunkTokenUpdate) bool { + capturedUpdates = updates + return len(updates) == 2 + }), + ).Return(nil).Once() + + // Mock embedding generation + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return([]*outbound.EmbeddingResult{ + {Vector: make([]float64, 768), Dimensions: 768, Model: "gemini-embedding-001"}, + {Vector: make([]float64, 768), Dimensions: 768, Model: "gemini-embedding-001"}, + }, nil) + + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(nil) + + // Create processor + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + 
MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: true, + Mode: "all", + }, + } + + processor := NewDefaultJobProcessor( + processorConfig, + &MockIndexingJobRepository{}, + &MockRepositoryRepository{}, + &MockEnhancedGitClient{}, + &MockCodeParser{}, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert + assert.NoError(t, err) + mockEmbeddingService.AssertExpectations(t) + mockChunkRepo.AssertExpectations(t) + + // Verify the captured token count updates + assert.Len(t, capturedUpdates, 2, "Should have 2 token count updates") + + // Find updates by chunk ID + var update1, update2 *outbound.ChunkTokenUpdate + for i := range capturedUpdates { + if capturedUpdates[i].ChunkID.String() == chunk1ID.String() { + update1 = &capturedUpdates[i] + } + if capturedUpdates[i].ChunkID.String() == chunk2ID.String() { + update2 = &capturedUpdates[i] + } + } + + assert.NotNil(t, update1, "Should have update for chunk1") + assert.NotNil(t, update2, "Should have update for chunk2") + assert.Equal(t, 2, update1.TokenCount, "Chunk1 should have 2 tokens") + assert.Equal(t, 8, update2.TokenCount, "Chunk2 should have 8 tokens") + assert.NotNil(t, update1.TokenCountedAt, "Chunk1 should have timestamp") + assert.NotNil(t, update2.TokenCountedAt, "Chunk2 should have timestamp") +} diff --git a/internal/config/config.go b/internal/config/config.go index e6db49e..f696f3d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -112,6 +112,17 @@ type BatchProcessingConfig struct { SubmissionInitialBackoff time.Duration `mapstructure:"submission_initial_backoff"` // Initial backoff on rate 
limit SubmissionMaxBackoff time.Duration `mapstructure:"submission_max_backoff"` // Maximum backoff duration MaxSubmissionAttempts int `mapstructure:"max_submission_attempts"` // Max retry attempts per batch + + // Token counting configuration + TokenCounting TokenCountingConfig `mapstructure:"token_counting"` // Token counting configuration +} + +// TokenCountingConfig holds configuration for token counting integration. +type TokenCountingConfig struct { + Enabled bool `mapstructure:"enabled"` // Enable token counting + Mode string `mapstructure:"mode"` // Mode: "all", "sample", or "on_demand" + SamplePercent int `mapstructure:"sample_percent"` // Percentage of chunks to sample (for "sample" mode) + MaxTokensPerChunk int `mapstructure:"max_tokens_per_chunk"` // Maximum tokens per chunk (default: 8192) } // BatchSizeConfig holds batch size configuration for a specific priority level. diff --git a/internal/port/outbound/chunk_repository.go b/internal/port/outbound/chunk_repository.go index 351266d..88e63b8 100644 --- a/internal/port/outbound/chunk_repository.go +++ b/internal/port/outbound/chunk_repository.go @@ -26,6 +26,16 @@ type CodeChunk struct { QualifiedName string `json:"qualified_name,omitempty"` // Fully qualified name Signature string `json:"signature,omitempty"` // Function/method signature Visibility string `json:"visibility,omitempty"` // Visibility modifier (public, private, protected) + // Token counting fields + TokenCount int `json:"token_count,omitempty"` // Exact token count from Google CountTokens API + TokenCountedAt *time.Time `json:"token_counted_at,omitempty"` // Timestamp when token count was retrieved +} + +// ChunkTokenUpdate represents a token count update for a specific chunk. +type ChunkTokenUpdate struct { + ChunkID uuid.UUID `json:"chunk_id"` + TokenCount int `json:"token_count"` + TokenCountedAt *time.Time `json:"token_counted_at"` } // ChunkRepository defines operations for storing and retrieving code chunks. 
@@ -52,6 +62,9 @@ type ChunkRepository interface { // CountChunksForRepository returns the number of chunks for a repository. CountChunksForRepository(ctx context.Context, repositoryID uuid.UUID) (int, error) + + // UpdateTokenCounts updates the token count for multiple chunks in a batch operation. + UpdateTokenCounts(ctx context.Context, updates []ChunkTokenUpdate) error } // EmbeddingRepository defines operations for storing and retrieving embeddings. diff --git a/internal/port/outbound/token_count_test.go b/internal/port/outbound/token_count_test.go index 4878c40..1849499 100644 --- a/internal/port/outbound/token_count_test.go +++ b/internal/port/outbound/token_count_test.go @@ -15,7 +15,7 @@ func newTestEmbeddingService() EmbeddingService { return &testEmbeddingService{} } -// Stub implementations that return errors (for red phase) +// Stub implementations that return errors (for red phase). func (s *testEmbeddingService) GenerateEmbedding( _ context.Context, _ string, @@ -56,7 +56,7 @@ func (s *testEmbeddingService) EstimateTokenCount(_ context.Context, _ string) ( return 0, errors.New("not implemented") } -// CountTokens stub - returns error to make tests fail +// CountTokens stub - returns error to make tests fail. func (s *testEmbeddingService) CountTokens( _ context.Context, text string, @@ -75,7 +75,7 @@ func (s *testEmbeddingService) CountTokens( return nil, errors.New("CountTokens not implemented") } -// CountTokensBatch stub - returns error to make tests fail +// CountTokensBatch stub - returns error to make tests fail. 
func (s *testEmbeddingService) CountTokensBatch( _ context.Context, texts []string, From 187e7222514c86345a494d6c7a097fbb7e9117f5 Mon Sep 17 00:00:00 2001 From: Anthony Bible Date: Thu, 27 Nov 2025 19:50:09 -0700 Subject: [PATCH 3/5] feat: implement progressive chunk saving during token counting - Save chunks with token counts using SaveChunks instead of UpdateTokenCounts - Add callback pattern for progressive token counting - Enable chunk persistence during token counting phase Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/config.dev.yaml | 6 + configs/config.yaml | 2 +- internal/adapter/outbound/gemini/client.go | 50 ++ .../queue/batch_queue_manager_test.go | 21 + .../outbound/queue/parallel_processor_test.go | 21 + .../service/search_service_test.go | 10 + internal/application/worker/job_processor.go | 207 ++++-- .../job_processor_batch_integration_test.go | 50 ++ .../application/worker/job_processor_test.go | 14 + .../job_processor_token_counting_test.go | 684 ++++++++++++++++-- internal/port/outbound/embedding_service.go | 13 + internal/port/outbound/token_count_test.go | 20 + 12 files changed, 952 insertions(+), 146 deletions(-) diff --git a/configs/config.dev.yaml b/configs/config.dev.yaml index ae60c19..f31be8e 100644 --- a/configs/config.dev.yaml +++ b/configs/config.dev.yaml @@ -62,3 +62,9 @@ batch_processing: # Async batch job poller configuration poller_interval: 30s # Poll Gemini batch jobs every 30 seconds max_concurrent_polls: 5 # Max concurrent batch job status checks + # Token counting configuration + token_counting: + enabled: true # Enable token counting + mode: "all" # Mode: "all", "sample", or "on_demand" + sample_percent: 10 # Percentage of chunks to sample (for "sample" mode) + max_tokens_per_chunk: 8192 # Maximum tokens per chunk (Gemini embedding model limit) diff --git a/configs/config.yaml b/configs/config.yaml index fe86c99..0545275 100644 --- a/configs/config.yaml +++ 
b/configs/config.yaml @@ -103,4 +103,4 @@ batch_processing: log: level: info - format: json \ No newline at end of file + format: json diff --git a/internal/adapter/outbound/gemini/client.go b/internal/adapter/outbound/gemini/client.go index 1d1de8c..da686a7 100644 --- a/internal/adapter/outbound/gemini/client.go +++ b/internal/adapter/outbound/gemini/client.go @@ -986,6 +986,56 @@ func (c *Client) CountTokensBatch( return results, nil } +// CountTokensWithCallback counts tokens for each chunk and invokes the callback after each result. +// This enables progressive processing (e.g., saving chunks in batches) during token counting. +func (c *Client) CountTokensWithCallback( + ctx context.Context, + chunks []outbound.CodeChunk, + model string, + callback outbound.TokenCountCallback, +) error { + if len(chunks) == 0 { + return nil + } + + // Check context cancellation early + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + effectiveModel := model + if effectiveModel == "" { + effectiveModel = c.config.Model + } + + slogger.Info(ctx, "Progressive token counting started", slogger.Fields{ + "chunk_count": len(chunks), + "model": effectiveModel, + }) + + for i := range chunks { + // Check for cancellation between chunks + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + result, err := c.CountTokens(ctx, chunks[i].Content, model) + if err != nil { + return fmt.Errorf("token counting failed at index %d: %w", i, err) + } + + if err := callback(i, &chunks[i], result); err != nil { + return fmt.Errorf("callback failed at index %d: %w", i, err) + } + } + + return nil +} + // Request/Response structures for JSON serialization/deserialization // EmbeddingRequest represents the JSON structure for Gemini embedding requests. 
diff --git a/internal/adapter/outbound/queue/batch_queue_manager_test.go b/internal/adapter/outbound/queue/batch_queue_manager_test.go index ff37f61..7ea6222 100644 --- a/internal/adapter/outbound/queue/batch_queue_manager_test.go +++ b/internal/adapter/outbound/queue/batch_queue_manager_test.go @@ -96,6 +96,27 @@ func (m *mockEmbeddingService) CountTokensBatch( return results, nil } +func (m *mockEmbeddingService) CountTokensWithCallback( + ctx context.Context, + chunks []outbound.CodeChunk, + model string, + callback outbound.TokenCountCallback, +) error { + for i, chunk := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: len(chunk.Content) / 4, + Model: model, + } + if callback != nil { + if err := callback(i, &chunk, result); err != nil { + // Log but continue processing + continue + } + } + } + return nil +} + // Test helper to create a manager with mocks. func createTestManager() outbound.BatchQueueManager { embeddingService := &mockEmbeddingService{} diff --git a/internal/adapter/outbound/queue/parallel_processor_test.go b/internal/adapter/outbound/queue/parallel_processor_test.go index 4c77c89..ab15c44 100644 --- a/internal/adapter/outbound/queue/parallel_processor_test.go +++ b/internal/adapter/outbound/queue/parallel_processor_test.go @@ -180,6 +180,27 @@ func (m *MockEmbeddingServiceWithDelay) CountTokensBatch( return results, nil } +func (m *MockEmbeddingServiceWithDelay) CountTokensWithCallback( + ctx context.Context, + chunks []outbound.CodeChunk, + model string, + callback outbound.TokenCountCallback, +) error { + for i, chunk := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: len(chunk.Content) / 4, + Model: model, + } + if callback != nil { + if err := callback(i, &chunk, result); err != nil { + // Log but continue processing + continue + } + } + } + return nil +} + // RED PHASE TESTS - All tests should FAIL initially func TestMockEmbeddingServiceFailure(t *testing.T) { diff --git 
a/internal/application/service/search_service_test.go b/internal/application/service/search_service_test.go index a701c05..198a927 100644 --- a/internal/application/service/search_service_test.go +++ b/internal/application/service/search_service_test.go @@ -169,6 +169,16 @@ func (m *MockEmbeddingService) CountTokensBatch( return args.Get(0).([]*outbound.TokenCountResult), args.Error(1) } +func (m *MockEmbeddingService) CountTokensWithCallback( + ctx context.Context, + chunks []outbound.CodeChunk, + model string, + callback outbound.TokenCountCallback, +) error { + args := m.Called(ctx, chunks, model, callback) + return args.Error(0) +} + // MockChunkRepository is a mock implementation for retrieving chunk information. type MockChunkRepository struct { mock.Mock diff --git a/internal/application/worker/job_processor.go b/internal/application/worker/job_processor.go index 5b2d727..04a5d14 100644 --- a/internal/application/worker/job_processor.go +++ b/internal/application/worker/job_processor.go @@ -61,6 +61,12 @@ const ( jobStatusRunning = "running" jobStatusCompleted = "completed" unknownCommitHash = "unknown" + + // tokenCountSaveBatchSize defines how many chunks to accumulate before saving to database. + // Trade-off: Larger batches reduce database round trips but increase memory usage and risk + // losing more work on failure. 50 provides a good balance for most repositories. + // Consider making this configurable if needed for very large repositories. + tokenCountSaveBatchSize = 50 ) // JobProcessorConfig holds configuration for the job processor. @@ -1416,7 +1422,11 @@ func (p *DefaultJobProcessor) storeSingleEmbeddingWithErrorHandling( // countTokensForChunks counts tokens for chunks based on configuration mode. // Modes: "all" - count all chunks, "sample" - count X% of chunks, "on_demand" - skip counting. // Errors are logged but do not fail the job (graceful degradation). 
-func (p *DefaultJobProcessor) countTokensForChunks(ctx context.Context, chunks []outbound.CodeChunk) { +func (p *DefaultJobProcessor) countTokensForChunks( + ctx context.Context, + repositoryID uuid.UUID, + chunks []outbound.CodeChunk, +) { if !p.batchConfig.TokenCounting.Enabled { return } @@ -1428,118 +1438,165 @@ func (p *DefaultJobProcessor) countTokensForChunks(ctx context.Context, chunks [ } // Select chunks to count based on mode - var chunksToCount []outbound.CodeChunk - switch mode { - case "all": - chunksToCount = chunks - case "sample": - samplePercent := p.batchConfig.TokenCounting.SamplePercent - if samplePercent <= 0 || samplePercent > 100 { - slogger.Warn( - ctx, - "Invalid sample_percent, skipping token counting", - slogger.Field("sample_percent", samplePercent), - ) - return - } - sampleSize := (len(chunks) * samplePercent) / 100 - if sampleSize == 0 { - sampleSize = 1 - } - chunksToCount = chunks[:sampleSize] - default: - slogger.Warn(ctx, "Unknown token counting mode, skipping", slogger.Field("mode", mode)) - return - } - + chunksToCount := p.selectChunksForCounting(ctx, chunks, mode) if len(chunksToCount) == 0 { return } - slogger.Info(ctx, "Starting token counting", slogger.Fields{ + slogger.Info(ctx, "Starting progressive token counting", slogger.Fields{ "mode": mode, "total_chunks": len(chunks), "chunks_to_count": len(chunksToCount), - "sample_percent": p.batchConfig.TokenCounting.SamplePercent, + "batch_size": tokenCountSaveBatchSize, }) - // Extract texts for counting - texts := make([]string, len(chunksToCount)) - for i, chunk := range chunksToCount { - texts[i] = chunk.Content - } - - // Call CountTokensBatch - results, err := p.embeddingService.CountTokensBatch(ctx, texts, "gemini-embedding-001") - if err != nil { - slogger.Warn( - ctx, - "Token counting failed, continuing with embedding generation", - slogger.Field("error", err.Error()), - ) - return - } - - // Build updates + // State for progressive saving now := time.Now() - updates 
:= make([]outbound.ChunkTokenUpdate, len(chunksToCount)) + pendingChunks := make([]outbound.CodeChunk, 0, tokenCountSaveBatchSize) totalTokens := 0 + savedCount := 0 oversizedChunks := 0 maxTokensPerChunk := p.batchConfig.TokenCounting.MaxTokensPerChunk if maxTokensPerChunk <= 0 { - maxTokensPerChunk = 8192 // Default to Gemini embedding model limit + maxTokensPerChunk = 8192 } - for i, result := range results { - chunkID, parseErr := uuid.Parse(chunksToCount[i].ID) - if parseErr != nil { - slogger.Warn(ctx, "Invalid chunk ID for token count update", slogger.Fields{ - "chunk_id": chunksToCount[i].ID, - "error": parseErr.Error(), - }) - continue - } - updates[i] = outbound.ChunkTokenUpdate{ - ChunkID: chunkID, - TokenCount: result.TotalTokens, - TokenCountedAt: &now, - } + // Callback invoked after each token count + callback := func(index int, chunk *outbound.CodeChunk, result *outbound.TokenCountResult) error { + // Update chunk with token count data + chunk.RepositoryID = repositoryID + chunk.TokenCount = result.TotalTokens + chunk.TokenCountedAt = &now totalTokens += result.TotalTokens // Check for oversized chunks if result.TotalTokens > maxTokensPerChunk { oversizedChunks++ slogger.Warn(ctx, "Chunk exceeds max token limit", slogger.Fields{ - "chunk_id": chunkID.String(), + "chunk_id": chunk.ID, "token_count": result.TotalTokens, "max_tokens_per_chunk": maxTokensPerChunk, - "file_path": chunksToCount[i].FilePath, + "file_path": chunk.FilePath, }) } + + // Add to pending batch + pendingChunks = append(pendingChunks, *chunk) + + // Save batch when full + if len(pendingChunks) >= tokenCountSaveBatchSize { + saved := p.saveTokenCountedChunkBatch(ctx, pendingChunks, false, savedCount, len(chunksToCount)) + savedCount += saved + pendingChunks = pendingChunks[:0] // Reset slice, keep capacity + } + + return nil } - // Update repository - if err := p.chunkStorageRepo.UpdateTokenCounts(ctx, updates); err != nil { - slogger.Warn(ctx, "Failed to persist token counts, 
continuing", slogger.Field("error", err.Error())) - return + // Process all chunks with callback + err := p.embeddingService.CountTokensWithCallback(ctx, chunksToCount, "gemini-embedding-001", callback) + if err != nil { + slogger.Warn( + ctx, + "Token counting failed, continuing with embedding generation", + slogger.Field("error", err.Error()), + ) + // Don't return - try to save any pending chunks } - avgTokens := 0 - if len(updates) > 0 { - avgTokens = totalTokens / len(updates) + // Save any remaining chunks in final partial batch + if len(pendingChunks) > 0 { + saved := p.saveTokenCountedChunkBatch(ctx, pendingChunks, true, savedCount, len(chunksToCount)) + savedCount += saved } // Emit metrics - p.emitTokenCountingMetrics(ctx, len(chunksToCount), totalTokens, avgTokens, oversizedChunks) + avgTokens := 0 + if len(chunksToCount) > 0 { + avgTokens = totalTokens / len(chunksToCount) + } + p.emitTokenCountingMetrics(ctx, savedCount, totalTokens, avgTokens, oversizedChunks) - slogger.Info(ctx, "Token counting completed successfully", slogger.Fields{ - "chunks_counted": len(updates), + slogger.Info(ctx, "Progressive token counting completed", slogger.Fields{ + "chunks_processed": len(chunksToCount), + "chunks_saved": savedCount, "total_tokens": totalTokens, "average_tokens": avgTokens, "oversized_chunks": oversizedChunks, }) } +// saveTokenCountedChunkBatch saves a batch of chunks with token counts to the database. +// Returns the number of chunks successfully saved (or 0 on error). +// Errors are logged but do not fail the operation (graceful degradation). 
+func (p *DefaultJobProcessor) saveTokenCountedChunkBatch( + ctx context.Context, + chunks []outbound.CodeChunk, + isFinalBatch bool, + totalSavedSoFar int, + totalChunks int, +) int { + if len(chunks) == 0 { + return 0 + } + + if err := p.chunkStorageRepo.SaveChunks(ctx, chunks); err != nil { + logMessage := "Failed to save chunk batch with token counts" + if isFinalBatch { + logMessage = "Failed to save final chunk batch with token counts" + } + slogger.Warn(ctx, logMessage, slogger.Fields{ + "batch_size": len(chunks), + "error": err.Error(), + }) + return 0 + } + + logMessage := "Saved chunk batch with token counts" + fields := slogger.Fields{ + "batch_size": len(chunks), + "total_saved": totalSavedSoFar + len(chunks), + } + if !isFinalBatch { + fields["total_chunks"] = totalChunks + logMessage = "Saved chunk batch with token counts" + } + slogger.Info(ctx, logMessage, fields) + + return len(chunks) +} + +// selectChunksForCounting selects which chunks to count based on the configured mode. +// Returns the selected chunks or nil if no chunks should be counted. +func (p *DefaultJobProcessor) selectChunksForCounting( + ctx context.Context, + chunks []outbound.CodeChunk, + mode string, +) []outbound.CodeChunk { + switch mode { + case "all": + return chunks + case "sample": + samplePercent := p.batchConfig.TokenCounting.SamplePercent + if samplePercent <= 0 || samplePercent > 100 { + slogger.Warn( + ctx, + "Invalid sample_percent, skipping token counting", + slogger.Field("sample_percent", samplePercent), + ) + return nil + } + sampleSize := (len(chunks) * samplePercent) / 100 + if sampleSize == 0 { + sampleSize = 1 + } + return chunks[:sampleSize] + default: + slogger.Warn(ctx, "Unknown token counting mode, skipping", slogger.Field("mode", mode)) + return nil + } +} + // emitTokenCountingMetrics emits OpenTelemetry metrics for token counting operations. // This provides observability into token usage, API call patterns, and oversized chunk detection. 
func (p *DefaultJobProcessor) emitTokenCountingMetrics( @@ -1671,7 +1728,7 @@ func (p *DefaultJobProcessor) processBatchEmbeddings( jobID := indexingJobID.String() // PRE-FLIGHT: Count tokens before embedding generation (non-blocking) - p.countTokensForChunks(ctx, chunks) + p.countTokensForChunks(ctx, repositoryID, chunks) // Check context before starting batch processing if err := ctx.Err(); err != nil { @@ -1943,7 +2000,7 @@ func (p *DefaultJobProcessor) processSequentialEmbeddings( jobID := indexingJobID.String() // PRE-FLIGHT: Count tokens before embedding generation (non-blocking) - p.countTokensForChunks(ctx, chunks) + p.countTokensForChunks(ctx, repositoryID, chunks) if p.batchConfig.UseTestEmbeddings { // Test mode: use test embeddings diff --git a/internal/application/worker/job_processor_batch_integration_test.go b/internal/application/worker/job_processor_batch_integration_test.go index 8bbe0be..e71e1b0 100644 --- a/internal/application/worker/job_processor_batch_integration_test.go +++ b/internal/application/worker/job_processor_batch_integration_test.go @@ -231,6 +231,56 @@ func (m *IntegrationMockEmbeddingService) EstimateTokenCount(ctx context.Context return len(text) / 4, nil } +// CountTokens counts the exact number of tokens in the given text. +func (m *IntegrationMockEmbeddingService) CountTokens( + ctx context.Context, + text string, + model string, +) (*outbound.TokenCountResult, error) { + return &outbound.TokenCountResult{ + TotalTokens: len(text) / 4, + Model: model, + }, nil +} + +// CountTokensBatch counts tokens for multiple texts in a single batch request. 
+func (m *IntegrationMockEmbeddingService) CountTokensBatch( + ctx context.Context, + texts []string, + model string, +) ([]*outbound.TokenCountResult, error) { + results := make([]*outbound.TokenCountResult, len(texts)) + for i, text := range texts { + results[i] = &outbound.TokenCountResult{ + TotalTokens: len(text) / 4, + Model: model, + } + } + return results, nil +} + +// CountTokensWithCallback counts tokens for each chunk and invokes the callback after each result. +func (m *IntegrationMockEmbeddingService) CountTokensWithCallback( + ctx context.Context, + chunks []outbound.CodeChunk, + model string, + callback outbound.TokenCountCallback, +) error { + for i, chunk := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: len(chunk.Content) / 4, + Model: model, + } + if callback != nil { + if err := callback(i, &chunk, result); err != nil { + // Log but continue processing + continue + } + } + } + return nil +} + // createJobProcessorWithBatchConfig creates a job processor with batch processing enabled. func createJobProcessorWithBatchConfig( t *testing.T, diff --git a/internal/application/worker/job_processor_test.go b/internal/application/worker/job_processor_test.go index 4701e57..20d921e 100644 --- a/internal/application/worker/job_processor_test.go +++ b/internal/application/worker/job_processor_test.go @@ -352,6 +352,16 @@ func (m *MockEmbeddingService) CountTokensBatch( return args.Get(0).([]*outbound.TokenCountResult), args.Error(1) } +func (m *MockEmbeddingService) CountTokensWithCallback( + ctx context.Context, + chunks []outbound.CodeChunk, + model string, + callback outbound.TokenCountCallback, +) error { + args := m.Called(ctx, chunks, model, callback) + return args.Error(0) +} + // MockChunkStorageRepository mocks the chunk storage repository interface. 
type MockChunkStorageRepository struct { mock.Mock @@ -364,6 +374,10 @@ func (m *MockChunkStorageRepository) SaveChunk(ctx context.Context, chunk *outbo func (m *MockChunkStorageRepository) SaveChunks(ctx context.Context, chunks []outbound.CodeChunk) error { args := m.Called(ctx, chunks) + // Handle function return types (for dynamic return values in tests) + if fn, ok := args.Get(0).(func(context.Context, []outbound.CodeChunk) error); ok { + return fn(ctx, chunks) + } return args.Error(0) } diff --git a/internal/application/worker/job_processor_token_counting_test.go b/internal/application/worker/job_processor_token_counting_test.go index 6d792ef..502fe8b 100644 --- a/internal/application/worker/job_processor_token_counting_test.go +++ b/internal/application/worker/job_processor_token_counting_test.go @@ -65,34 +65,63 @@ func TestJobProcessor_CountTokensForChunks(t *testing.T) { }, } - // Extract texts from chunks for CountTokensBatch call + // Extract texts from chunks for verification expectedTexts := []string{ chunks[0].Content, chunks[1].Content, } // Mock expectations - // 1. CountTokensBatch should be called BEFORE GenerateBatchEmbeddings - mockEmbeddingService.On("CountTokensBatch", + // 1. CountTokensWithCallback should be called BEFORE GenerateBatchEmbeddings + mockEmbeddingService.On("CountTokensWithCallback", mock.Anything, // ctx - expectedTexts, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + return len(chunks) == 2 + }), "gemini-embedding-001", - ).Return(expectedTokenCounts, nil).Once() + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + for i := range chunks { + result := expectedTokenCounts[i] + err := callback(i, &chunks[i], result) + if err != nil { + continue + } + } + }).Return(nil).Once() - // 2. 
Chunks should be updated with token counts - mockChunkRepo.On("UpdateTokenCounts", + // 2. Chunks should be saved with token counts populated + mockChunkRepo.On("SaveChunks", mock.Anything, // ctx - mock.MatchedBy(func(updates []outbound.ChunkTokenUpdate) bool { - // Verify that we're updating both chunks with the correct token counts - if len(updates) != 2 { + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + // Verify that we're saving both chunks with the correct token counts + if len(chunks) != 2 { + return false + } + // Check first chunk has token count and timestamp + if chunks[0].TokenCount != 10 { + return false + } + if chunks[0].TokenCountedAt == nil { + return false + } + // Check second chunk has token count and timestamp + if chunks[1].TokenCount != 6 { return false } - // Check first chunk - if updates[0].TokenCount != 10 { + if chunks[1].TokenCountedAt == nil { return false } - // Check second chunk - if updates[1].TokenCount != 6 { + // Verify chunks have RepositoryID set + if chunks[0].RepositoryID == uuid.Nil { + return false + } + if chunks[1].RepositoryID == uuid.Nil { return false } return true @@ -158,11 +187,11 @@ func TestJobProcessor_CountTokensForChunks(t *testing.T) { mockEmbeddingService.AssertExpectations(t) mockChunkRepo.AssertExpectations(t) - // Verify call order: CountTokensBatch should be called before GenerateBatchEmbeddings + // Verify call order: CountTokensWithCallback should be called before GenerateBatchEmbeddings calls := mockEmbeddingService.Calls countTokensIdx, generateEmbeddingsIdx := -1, -1 for i, call := range calls { - if call.Method == "CountTokensBatch" { + if call.Method == "CountTokensWithCallback" { countTokensIdx = i } if call.Method == "GenerateBatchEmbeddings" { @@ -170,13 +199,13 @@ func TestJobProcessor_CountTokensForChunks(t *testing.T) { } } - assert.NotEqual(t, -1, countTokensIdx, "CountTokensBatch should have been called") + assert.NotEqual(t, -1, countTokensIdx, "CountTokensWithCallback should 
have been called") assert.NotEqual(t, -1, generateEmbeddingsIdx, "GenerateBatchEmbeddings should have been called") assert.Less( t, countTokensIdx, generateEmbeddingsIdx, - "CountTokensBatch must be called BEFORE GenerateBatchEmbeddings", + "CountTokensWithCallback must be called BEFORE GenerateBatchEmbeddings", ) } @@ -254,29 +283,50 @@ func TestJobProcessor_TokenCountingModes(t *testing.T) { // Configure expectations based on mode if tc.expectedCountCall { - // CountTokensBatch should be called with sampled chunks - mockEmbeddingService.On("CountTokensBatch", + // CountTokensWithCallback should be called with sampled chunks + mockEmbeddingService.On("CountTokensWithCallback", mock.Anything, - mock.MatchedBy(func(texts []string) bool { - return len(texts) == tc.expectedSamples + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + return len(chunks) == tc.expectedSamples }), "gemini-embedding-001", - ).Return(func(ctx context.Context, texts []string, model string) []*outbound.TokenCountResult { - results := make([]*outbound.TokenCountResult, len(texts)) - for i := range results { - results[i] = &outbound.TokenCountResult{ + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + for i := range chunks { + result := &outbound.TokenCountResult{ TotalTokens: 5, - Model: model, + Model: "gemini-embedding-001", } + _ = callback(i, &chunks[i], result) } - return results - }, nil).Once() + }).Return(nil).Once() - // UpdateTokenCounts should be called for sampled chunks - mockChunkRepo.On("UpdateTokenCounts", + // SaveChunks should be called with chunks that have token counts populated + mockChunkRepo.On("SaveChunks", mock.Anything, - mock.MatchedBy(func(updates []outbound.ChunkTokenUpdate) bool { - return len(updates) == tc.expectedSamples + mock.MatchedBy(func(chunks 
[]outbound.CodeChunk) bool { + // Verify we have the expected number of chunks + if len(chunks) != tc.expectedSamples { + return false + } + // Verify all chunks have token count and timestamp + for i := range chunks { + if chunks[i].TokenCount == 0 { + return false + } + if chunks[i].TokenCountedAt == nil { + return false + } + if chunks[i].RepositoryID == uuid.Nil { + return false + } + } + return true }), ).Return(nil).Once() } @@ -407,28 +457,49 @@ func TestJobProcessor_TokenCountingFailsGracefully(t *testing.T) { mockEmbeddingService := &MockEmbeddingService{} mockChunkRepo := &MockChunkStorageRepository{} - // Mock CountTokensBatch - may fail or succeed + // Mock CountTokensWithCallback - may fail or succeed if tc.tokenCountErr != nil { - mockEmbeddingService.On("CountTokensBatch", + mockEmbeddingService.On("CountTokensWithCallback", mock.Anything, mock.Anything, "gemini-embedding-001", - ).Return(nil, tc.tokenCountErr).Once() + mock.Anything, // callback + ).Return(tc.tokenCountErr).Once() - // UpdateTokenCounts should NOT be called on error + // SaveChunks should NOT be called on error (graceful degradation) // (no expectation set means test fails if called) } else { - mockEmbeddingService.On("CountTokensBatch", + mockEmbeddingService.On("CountTokensWithCallback", mock.Anything, mock.Anything, "gemini-embedding-001", - ).Return([]*outbound.TokenCountResult{ - {TotalTokens: 5, Model: "gemini-embedding-001"}, - }, nil).Once() + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + for i := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: 5, + Model: "gemini-embedding-001", + } + _ = callback(i, &chunks[i], result) + } + }).Return(nil).Once() - mockChunkRepo.On("UpdateTokenCounts", - mock.Anything, + // SaveChunks should be called with chunk that 
has token count populated + mockChunkRepo.On("SaveChunks", mock.Anything, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + if len(chunks) != 1 { + return false + } + return chunks[0].TokenCount == 5 && + chunks[0].TokenCountedAt != nil && + chunks[0].RepositoryID != uuid.Nil + }), ).Return(nil).Once() } @@ -534,22 +605,35 @@ func TestJobProcessor_TokenCountingUpdatesPersistence(t *testing.T) { mockChunkRepo := &MockChunkStorageRepository{} // Token counting returns specific counts - mockEmbeddingService.On("CountTokensBatch", + mockEmbeddingService.On("CountTokensWithCallback", mock.Anything, - []string{chunks[0].Content, chunks[1].Content}, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + return len(chunks) == 2 + }), "gemini-embedding-001", - ).Return([]*outbound.TokenCountResult{ - {TotalTokens: 2, Model: "gemini-embedding-001"}, - {TotalTokens: 8, Model: "gemini-embedding-001"}, - }, nil).Once() + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + tokenResults := []*outbound.TokenCountResult{ + {TotalTokens: 2, Model: "gemini-embedding-001"}, + {TotalTokens: 8, Model: "gemini-embedding-001"}, + } + for i := range chunks { + _ = callback(i, &chunks[i], tokenResults[i]) + } + }).Return(nil).Once() - // Verify that UpdateTokenCounts is called with correct values - var capturedUpdates []outbound.ChunkTokenUpdate - mockChunkRepo.On("UpdateTokenCounts", + // Verify that SaveChunks is called with chunks that have token counts populated + var capturedChunks []outbound.CodeChunk + mockChunkRepo.On("SaveChunks", mock.Anything, - mock.MatchedBy(func(updates []outbound.ChunkTokenUpdate) bool { - capturedUpdates = updates - return len(updates) == 2 + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + capturedChunks = chunks + return len(chunks) == 2 
}), ).Return(nil).Once() @@ -608,24 +692,484 @@ func TestJobProcessor_TokenCountingUpdatesPersistence(t *testing.T) { mockEmbeddingService.AssertExpectations(t) mockChunkRepo.AssertExpectations(t) - // Verify the captured token count updates - assert.Len(t, capturedUpdates, 2, "Should have 2 token count updates") + // Verify the captured chunks have token counts populated + assert.Len(t, capturedChunks, 2, "Should have 2 chunks saved") + + // Find chunks by chunk ID + var chunk1, chunk2 *outbound.CodeChunk + for i := range capturedChunks { + if capturedChunks[i].ID == chunk1ID.String() { + chunk1 = &capturedChunks[i] + } + if capturedChunks[i].ID == chunk2ID.String() { + chunk2 = &capturedChunks[i] + } + } + + assert.NotNil(t, chunk1, "Should have chunk1") + assert.NotNil(t, chunk2, "Should have chunk2") + assert.Equal(t, 2, chunk1.TokenCount, "Chunk1 should have 2 tokens") + assert.Equal(t, 8, chunk2.TokenCount, "Chunk2 should have 8 tokens") + assert.NotNil(t, chunk1.TokenCountedAt, "Chunk1 should have timestamp") + assert.NotNil(t, chunk2.TokenCountedAt, "Chunk2 should have timestamp") + assert.Equal(t, repositoryID, chunk1.RepositoryID, "Chunk1 should have RepositoryID set") + assert.Equal(t, repositoryID, chunk2.RepositoryID, "Chunk2 should have RepositoryID set") +} + +// TestJobProcessor_ProgressiveTokenCounting_BatchSaving tests that chunks are saved in batches +// of 50 during progressive token counting. 
+func TestJobProcessor_ProgressiveTokenCounting_BatchSaving(t *testing.T) { + t.Parallel() + + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + // Create 100 test chunks to test batching (expecting 2 batches of 50) + chunks := make([]outbound.CodeChunk, 100) + for i := range chunks { + chunks[i] = outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "func test() {}", + StartLine: i + 1, + EndLine: i + 1, + Type: "function", + } + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + mockRepoRepo := &MockRepositoryRepository{} + mockIndexingJobRepo := &MockIndexingJobRepository{} + mockGitClient := &MockEnhancedGitClient{} + mockCodeParser := &MockCodeParser{} + + // Track SaveChunks calls - expecting 2 calls with batches of 50 + var saveChunksCalls [][]outbound.CodeChunk + mockChunkRepo.On("SaveChunks", + mock.Anything, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + // Capture the chunks being saved + saveChunksCalls = append(saveChunksCalls, chunks) + return true + }), + ).Return(nil) + + // Mock CountTokensWithCallback - simulate calling callback for each chunk + mockEmbeddingService.On("CountTokensWithCallback", + mock.Anything, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + return len(chunks) == 100 + }), + "gemini-embedding-001", + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + for i := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: 5 + i, // Unique token count for each + Model: "gemini-embedding-001", + } + // Call the callback with each chunk + err := callback(i, &chunks[i], result) + if err != nil { + // Callback errors should not 
stop processing + continue + } + } + }).Return(nil).Once() + + // Mock embedding generation (should still be called) + embeddingResults := make([]*outbound.EmbeddingResult, 100) + for i := range embeddingResults { + embeddingResults[i] = &outbound.EmbeddingResult{ + Vector: make([]float64, 768), + Dimensions: 768, + Model: "gemini-embedding-001", + } + } + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(embeddingResults, nil) + + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(nil) + + // Create processor config + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: true, + Mode: "all", + MaxTokensPerChunk: 8192, + }, + } + + // Create processor + processor := NewDefaultJobProcessor( + processorConfig, + mockIndexingJobRepo, + mockRepoRepo, + mockGitClient, + mockCodeParser, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert + assert.NoError(t, err, "Expected no error from generateEmbeddings") + mockEmbeddingService.AssertExpectations(t) + + // Verify SaveChunks was called exactly 2 times (for 100 chunks in batches of 50) + assert.Len(t, saveChunksCalls, 2, "SaveChunks should be called 2 times for 100 chunks") + + // Verify first batch has 50 chunks + if len(saveChunksCalls) >= 1 { + assert.Len(t, saveChunksCalls[0], 50, "First batch should have 50 chunks") + // Verify all chunks in first batch have token counts + for i, chunk := range saveChunksCalls[0] { + assert.NotZero(t, 
chunk.TokenCount, "Chunk %d in batch 1 should have token count", i) + assert.NotNil(t, chunk.TokenCountedAt, "Chunk %d in batch 1 should have timestamp", i) + assert.Equal(t, repositoryID, chunk.RepositoryID, "Chunk %d in batch 1 should have RepositoryID", i) + } + } + + // Verify second batch has 50 chunks + if len(saveChunksCalls) >= 2 { + assert.Len(t, saveChunksCalls[1], 50, "Second batch should have 50 chunks") + // Verify all chunks in second batch have token counts + for i, chunk := range saveChunksCalls[1] { + assert.NotZero(t, chunk.TokenCount, "Chunk %d in batch 2 should have token count", i) + assert.NotNil(t, chunk.TokenCountedAt, "Chunk %d in batch 2 should have timestamp", i) + assert.Equal(t, repositoryID, chunk.RepositoryID, "Chunk %d in batch 2 should have RepositoryID", i) + } + } +} + +// TestJobProcessor_ProgressiveTokenCounting_FinalPartialBatch tests progressive token counting +// with a final partial batch (75 chunks = 1 batch of 50 + 1 batch of 25). +func TestJobProcessor_ProgressiveTokenCounting_FinalPartialBatch(t *testing.T) { + t.Parallel() + + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + // Create 75 test chunks (1 full batch of 50 + partial batch of 25) + chunks := make([]outbound.CodeChunk, 75) + for i := range chunks { + chunks[i] = outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "func test() {}", + StartLine: i + 1, + EndLine: i + 1, + Type: "function", + } + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + mockRepoRepo := &MockRepositoryRepository{} + mockIndexingJobRepo := &MockIndexingJobRepository{} + mockGitClient := &MockEnhancedGitClient{} + mockCodeParser := &MockCodeParser{} + + // Track SaveChunks calls - expecting 2 calls (50 + 25) + var saveChunksCalls [][]outbound.CodeChunk + mockChunkRepo.On("SaveChunks", + mock.Anything, + 
mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + // Capture the chunks being saved + saveChunksCalls = append(saveChunksCalls, chunks) + return true + }), + ).Return(nil) + + // Mock CountTokensWithCallback - simulate calling callback for each chunk + mockEmbeddingService.On("CountTokensWithCallback", + mock.Anything, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + return len(chunks) == 75 + }), + "gemini-embedding-001", + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + for i := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: 10 + i, + Model: "gemini-embedding-001", + } + // Call the callback with each chunk + err := callback(i, &chunks[i], result) + if err != nil { + continue + } + } + }).Return(nil).Once() + + // Mock embedding generation + embeddingResults := make([]*outbound.EmbeddingResult, 75) + for i := range embeddingResults { + embeddingResults[i] = &outbound.EmbeddingResult{ + Vector: make([]float64, 768), + Dimensions: 768, + Model: "gemini-embedding-001", + } + } + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(embeddingResults, nil) + + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(nil) + + // Create processor config + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: true, + Mode: "all", + MaxTokensPerChunk: 8192, + }, + } + + // Create processor + processor := NewDefaultJobProcessor( + processorConfig, + 
mockIndexingJobRepo, + mockRepoRepo, + mockGitClient, + mockCodeParser, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert + assert.NoError(t, err, "Expected no error from generateEmbeddings") + mockEmbeddingService.AssertExpectations(t) + + // Verify SaveChunks was called exactly 2 times (50 + 25) + assert.Len(t, saveChunksCalls, 2, "SaveChunks should be called 2 times for 75 chunks") + + // Verify first batch has 50 chunks + if len(saveChunksCalls) >= 1 { + assert.Len(t, saveChunksCalls[0], 50, "First batch should have 50 chunks") + for i, chunk := range saveChunksCalls[0] { + assert.NotZero(t, chunk.TokenCount, "Chunk %d in batch 1 should have token count", i) + assert.NotNil(t, chunk.TokenCountedAt, "Chunk %d in batch 1 should have timestamp", i) + } + } + + // Verify second batch has 25 chunks (partial batch) + if len(saveChunksCalls) >= 2 { + assert.Len(t, saveChunksCalls[1], 25, "Second batch should have 25 chunks (partial)") + for i, chunk := range saveChunksCalls[1] { + assert.NotZero(t, chunk.TokenCount, "Chunk %d in batch 2 should have token count", i) + assert.NotNil(t, chunk.TokenCountedAt, "Chunk %d in batch 2 should have timestamp", i) + } + } +} + +// TestJobProcessor_ProgressiveTokenCounting_GracefulDegradation tests that if SaveChunks +// fails during progressive token counting, processing continues and embeddings are still generated. 
+func TestJobProcessor_ProgressiveTokenCounting_GracefulDegradation(t *testing.T) { + t.Parallel() + + // Setup + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + // Create 100 test chunks + chunks := make([]outbound.CodeChunk, 100) + for i := range chunks { + chunks[i] = outbound.CodeChunk{ + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "test.go", + Content: "func test() {}", + StartLine: i + 1, + EndLine: i + 1, + Type: "function", + } + } + + // Create mocks + mockEmbeddingService := &MockEmbeddingService{} + mockChunkRepo := &MockChunkStorageRepository{} + mockRepoRepo := &MockRepositoryRepository{} + mockIndexingJobRepo := &MockIndexingJobRepository{} + mockGitClient := &MockEnhancedGitClient{} + mockCodeParser := &MockCodeParser{} - // Find updates by chunk ID - var update1, update2 *outbound.ChunkTokenUpdate - for i := range capturedUpdates { - if capturedUpdates[i].ChunkID.String() == chunk1ID.String() { - update1 = &capturedUpdates[i] + // Track SaveChunks calls + var saveChunksCallCount int + mockChunkRepo.On("SaveChunks", + mock.Anything, + mock.Anything, + ).Return(func(ctx context.Context, chunks []outbound.CodeChunk) error { + saveChunksCallCount++ + // First batch fails, second batch succeeds + if saveChunksCallCount == 1 { + return errors.New("database connection error") + } + return nil + }) + + // Mock CountTokensWithCallback - simulate calling callback for each chunk + mockEmbeddingService.On("CountTokensWithCallback", + mock.Anything, + mock.MatchedBy(func(chunks []outbound.CodeChunk) bool { + return len(chunks) == 100 + }), + "gemini-embedding-001", + mock.Anything, // callback + ).Run(func(args mock.Arguments) { + // Extract arguments + chunks := args.Get(1).([]outbound.CodeChunk) + callback := args.Get(3).(outbound.TokenCountCallback) + + // Simulate the callback being invoked for each chunk + for i := range chunks { + result := &outbound.TokenCountResult{ + TotalTokens: 8, + 
Model: "gemini-embedding-001", + } + // Call the callback - errors should be logged but not stop processing + _ = callback(i, &chunks[i], result) } - if capturedUpdates[i].ChunkID.String() == chunk2ID.String() { - update2 = &capturedUpdates[i] + }).Return(nil).Once() + + // Mock embedding generation - MUST still be called even if SaveChunks fails + embeddingResults := make([]*outbound.EmbeddingResult, 100) + for i := range embeddingResults { + embeddingResults[i] = &outbound.EmbeddingResult{ + Vector: make([]float64, 768), + Dimensions: 768, + Model: "gemini-embedding-001", } } + mockEmbeddingService.On("GenerateBatchEmbeddings", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(embeddingResults, nil).Once() + + mockChunkRepo.On("SaveChunkWithEmbedding", + mock.Anything, + mock.Anything, + mock.Anything, + ).Return(nil) + + // Create processor config + processorConfig := JobProcessorConfig{ + WorkspaceDir: "/tmp/test-workspace", + MaxConcurrentJobs: 1, + JobTimeout: 30 * time.Second, + } + + batchConfig := config.BatchProcessingConfig{ + Enabled: true, + ThresholdChunks: 1, + UseTestEmbeddings: false, + FallbackToSequential: true, + TokenCounting: config.TokenCountingConfig{ + Enabled: true, + Mode: "all", + MaxTokensPerChunk: 8192, + }, + } + + // Create processor + processor := NewDefaultJobProcessor( + processorConfig, + mockIndexingJobRepo, + mockRepoRepo, + mockGitClient, + mockCodeParser, + mockEmbeddingService, + mockChunkRepo, + &JobProcessorBatchOptions{ + BatchConfig: &batchConfig, + }, + ).(*DefaultJobProcessor) + + // Execute + err := processor.generateEmbeddings(ctx, indexingJobID, repositoryID, chunks) + + // Assert - embedding generation should succeed despite SaveChunks failure + assert.NoError(t, err, "Embedding generation should succeed even if SaveChunks fails during token counting") + + // Verify CountTokensWithCallback was called + mockEmbeddingService.AssertCalled( + t, + "CountTokensWithCallback", + mock.Anything, + mock.Anything, 
+ mock.Anything, + mock.Anything, + ) + + // Verify GenerateBatchEmbeddings was still called (no blocking from SaveChunks failure) + mockEmbeddingService.AssertCalled(t, "GenerateBatchEmbeddings", mock.Anything, mock.Anything, mock.Anything) - assert.NotNil(t, update1, "Should have update for chunk1") - assert.NotNil(t, update2, "Should have update for chunk2") - assert.Equal(t, 2, update1.TokenCount, "Chunk1 should have 2 tokens") - assert.Equal(t, 8, update2.TokenCount, "Chunk2 should have 8 tokens") - assert.NotNil(t, update1.TokenCountedAt, "Chunk1 should have timestamp") - assert.NotNil(t, update2.TokenCountedAt, "Chunk2 should have timestamp") + // Verify SaveChunks was called at least once (batch saving was attempted) + assert.GreaterOrEqual(t, saveChunksCallCount, 1, "SaveChunks should be called at least once") } diff --git a/internal/port/outbound/embedding_service.go b/internal/port/outbound/embedding_service.go index 62ce7aa..faef45f 100644 --- a/internal/port/outbound/embedding_service.go +++ b/internal/port/outbound/embedding_service.go @@ -50,8 +50,21 @@ type EmbeddingService interface { // CountTokensBatch counts tokens for multiple texts in a single batch request // Returns a slice of TokenCountResult matching the input texts order CountTokensBatch(ctx context.Context, texts []string, model string) ([]*TokenCountResult, error) + + // CountTokensWithCallback counts tokens for each chunk and invokes the callback after each result. + // This allows for progressive processing (e.g., saving chunks to DB in batches). + // The callback receives the index, chunk, and token count result for each processed chunk. + // If the callback returns an error, processing continues but the error is logged. + CountTokensWithCallback(ctx context.Context, chunks []CodeChunk, model string, callback TokenCountCallback) error } +// TokenCountCallback is called after each successful token count with the updated chunk. 
+// The index parameter indicates the position of the chunk in the original slice. +// The chunk parameter is the chunk being processed. +// The result parameter contains the token count information. +// Returning an error will log the error but won't stop processing. +type TokenCountCallback func(index int, chunk *CodeChunk, result *TokenCountResult) error + // BatchEmbeddingService defines the interface for file-based batch embedding operations. // This interface provides asynchronous batch processing capabilities using the Google GenAI Batches API. type BatchEmbeddingService interface { diff --git a/internal/port/outbound/token_count_test.go b/internal/port/outbound/token_count_test.go index 1849499..47ad140 100644 --- a/internal/port/outbound/token_count_test.go +++ b/internal/port/outbound/token_count_test.go @@ -94,6 +94,26 @@ func (s *testEmbeddingService) CountTokensBatch( return nil, errors.New("CountTokensBatch not implemented") } +// CountTokensWithCallback stub - returns error to make tests fail. +func (s *testEmbeddingService) CountTokensWithCallback( + _ context.Context, + chunks []CodeChunk, + _ string, + callback TokenCountCallback, +) error { + // Validate empty slice + if len(chunks) == 0 { + return &EmbeddingError{ + Code: "invalid_input", + Message: "chunks slice cannot be empty", + Type: "validation", + Retryable: false, + } + } + // Return error for red phase - implementation doesn't exist yet + return errors.New("CountTokensWithCallback not implemented") +} + // TestTokenCountResult_Structure validates the TokenCountResult type has correct fields. 
func TestTokenCountResult_Structure(t *testing.T) { t.Parallel() From 867e2f416bb135dc9bb35edce186d943fd5d6258 Mon Sep 17 00:00:00 2001 From: Anthony Bible Date: Fri, 28 Nov 2025 08:01:11 -0700 Subject: [PATCH 4/5] perf: optimize chunk persistence with multi-row INSERT statements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement buildMultiRowChunkInsert helper for batch SQL generation - Convert SaveChunks to use multi-row INSERT - Convert FindOrCreateChunks to use multi-row INSERT - Significantly reduce database round-trips for large repositories All tests continue to pass. No behavior changes, only code quality improvements. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../outbound/repository/chunk_repository.go | 636 +++++++------- ...k_repository_findorcreate_multirow_test.go | 787 ++++++++++++++++++ .../chunk_repository_multirow_test.go | 354 ++++++++ ...unk_repository_savechunks_multirow_test.go | 649 +++++++++++++++ 4 files changed, 2124 insertions(+), 302 deletions(-) create mode 100644 internal/adapter/outbound/repository/chunk_repository_findorcreate_multirow_test.go create mode 100644 internal/adapter/outbound/repository/chunk_repository_multirow_test.go create mode 100644 internal/adapter/outbound/repository/chunk_repository_savechunks_multirow_test.go diff --git a/internal/adapter/outbound/repository/chunk_repository.go b/internal/adapter/outbound/repository/chunk_repository.go index aac10db..be5f457 100644 --- a/internal/adapter/outbound/repository/chunk_repository.go +++ b/internal/adapter/outbound/repository/chunk_repository.go @@ -3,11 +3,11 @@ package repository import ( "codechunking/internal/application/common/slogger" "codechunking/internal/application/service" - "codechunking/internal/domain/valueobject" "codechunking/internal/port/outbound" "context" "errors" "fmt" + "strconv" "strings" "time" @@ -31,6 +31,17 @@ const ( // SQL query constants. 
const ( + // chunkColumnsPerRow is the number of columns in the code_chunks table insert statement. + // This constant is used by buildMultiRowChunkInsert to calculate parameter placeholders. + // IMPORTANT: If you modify the chunk table schema or insert query, update this constant. + chunkColumnsPerRow = 17 + + // multiRowInsertBatchSize is the maximum number of chunks to insert in a single batch. + // PostgreSQL's wire protocol allows at most 65535 bind parameters per statement, so we batch large + // inserts to stay well under that limit. With 17 columns per row, 500 rows = 8500 parameters, + // which provides a safe margin below the limit while maintaining good performance. + multiRowInsertBatchSize = 500 + insertChunkQuery = ` INSERT INTO codechunking.code_chunks ( id, repository_id, file_path, chunk_type, content, language, @@ -69,6 +80,76 @@ const ( ` ) +// buildMultiRowChunkInsert builds a multi-row INSERT query for code chunks. +// +// Parameters: +// - numRows: The number of chunk rows to include in the INSERT statement +// +// Returns: +// - A complete SQL INSERT query with VALUES placeholders for the specified number of rows +// - Empty string if numRows is <= 0 +// +// Each chunk has 17 columns, so parameters are numbered sequentially: +// - Row 1: $1-$17 +// - Row 2: $18-$34 +// - Row 3: $35-$51 +// - etc. +// +// RETURNING Clause Ordering: +// - PostgreSQL does not document an ordering guarantee for RETURNING; in practice a plain INSERT ... VALUES returns rows in VALUES order +// - Callers that match returned IDs to input chunks by index position rely on that observed behavior +// - See: https://www.postgresql.org/docs/current/dml-returning.html +// +// Example output for numRows=2: +// +// INSERT INTO codechunking.code_chunks (id, repository_id, ..., token_counted_at) +// VALUES ($1, $2, ..., $17), ($18, $19, ..., $34) +// ON CONFLICT (repository_id, file_path, content_hash) +// DO UPDATE SET id = code_chunks.id, token_count = COALESCE(...), token_counted_at = COALESCE(...)
+// RETURNING id +func buildMultiRowChunkInsert(numRows int) string { + if numRows <= 0 { + return "" + } + + var b strings.Builder + + // INSERT clause with column names + b.WriteString(`INSERT INTO codechunking.code_chunks (`) + b.WriteString(`id, repository_id, file_path, chunk_type, content, language, `) + b.WriteString(`start_line, end_line, entity_name, parent_entity, content_hash, metadata, `) + b.WriteString(`qualified_name, signature, visibility, token_count, token_counted_at`) + b.WriteString(`) VALUES `) + + // Generate VALUES rows with parameter placeholders + for row := range numRows { + if row > 0 { + b.WriteString(`, `) + } + b.WriteString(`(`) + for col := range chunkColumnsPerRow { + if col > 0 { + b.WriteString(`, `) + } + paramNum := row*chunkColumnsPerRow + col + 1 + b.WriteString(`$`) + b.WriteString(strconv.Itoa(paramNum)) + } + b.WriteString(`)`) + } + + // ON CONFLICT clause - preserve existing token counts + b.WriteString(` ON CONFLICT (repository_id, file_path, content_hash) DO UPDATE SET `) + b.WriteString(`id = code_chunks.id, `) + b.WriteString(`token_count = COALESCE(code_chunks.token_count, EXCLUDED.token_count), `) + b.WriteString(`token_counted_at = COALESCE(code_chunks.token_counted_at, EXCLUDED.token_counted_at)`) + + // RETURNING clause + b.WriteString(` RETURNING id`) + + return b.String() +} + // validateEmbeddingDimensions validates that an embedding has the expected dimensions. func validateEmbeddingDimensions(embedding *outbound.Embedding) error { if len(embedding.Vector) != expectedEmbeddingDimensions { @@ -96,6 +177,85 @@ func sanitizeContentWithLogging(ctx context.Context, content string, chunkID str return content } +// prepareChunkInsertArgs prepares the 17 arguments needed for a chunk INSERT statement. +// This helper ensures consistent default value handling across all chunk insert operations. +// +// Returns a slice of 17 interface{} values in the order expected by insertChunkQuery: +// 1. chunkID (uuid.UUID) +// 2. 
repositoryID (uuid.UUID) +// 3. filePath (string) +// 4. chunkType (string, with default) +// 5. sanitizedContent (string) +// 6. language (string) +// 7. startLine (int) +// 8. endLine (int) +// 9. entityName (string, with default) +// +// 10. parentEntity (string, with default) +// 11. contentHash (string) +// 12. metadata (nil) +// 13. qualifiedName (string, with default) +// 14. signature (string, with default) +// 15. visibility (string, with default) +// 16. tokenCount (*int) +// 17. tokenCountedAt (*time.Time). +func prepareChunkInsertArgs(ctx context.Context, chunk outbound.CodeChunk, chunkID uuid.UUID) []interface{} { + // Apply defaults to optional fields + chunkType := chunk.Type + if chunkType == "" { + chunkType = defaultChunkType + } + + entityName := chunk.EntityName + if entityName == "" { + entityName = defaultEntityName + } + + parentEntity := chunk.ParentEntity + if parentEntity == "" { + parentEntity = defaultParentEntity + } + + qualifiedName := chunk.QualifiedName + if qualifiedName == "" { + qualifiedName = defaultQualifiedName + } + + signature := chunk.Signature + if signature == "" { + signature = defaultSignature + } + + visibility := chunk.Visibility + if visibility == "" { + visibility = defaultVisibility + } + + // Sanitize content for PostgreSQL UTF-8 compatibility + sanitizedContent := sanitizeContentWithLogging(ctx, chunk.Content, chunk.ID, chunk.FilePath) + + // Build and return the 17-element argument slice + return []interface{}{ + chunkID, + chunk.RepositoryID, + chunk.FilePath, + chunkType, + sanitizedContent, + chunk.Language, + chunk.StartLine, + chunk.EndLine, + entityName, + parentEntity, + chunk.Hash, + nil, // metadata + qualifiedName, + signature, + visibility, + chunk.TokenCount, + chunk.TokenCountedAt, + } +} + // PostgreSQLChunkRepository implements the ChunkStorageRepository interface. // It provides operations for both code chunks and embeddings with support for // both regular and partitioned embeddings tables. 
@@ -226,43 +386,8 @@ func (r *PostgreSQLChunkRepository) SaveChunk(ctx context.Context, chunk *outbou return fmt.Errorf("invalid chunk ID format: %w", err) } - // Use chunk type information or defaults - entityName := chunk.EntityName - if entityName == "" { - entityName = defaultEntityName - } - - parentEntity := chunk.ParentEntity - if parentEntity == "" { - parentEntity = defaultParentEntity - } - - chunkType := chunk.Type - if chunkType == "" { - chunkType = defaultChunkType - } - - qualifiedName := chunk.QualifiedName - if qualifiedName == "" { - qualifiedName = defaultQualifiedName - } - - signature := chunk.Signature - if signature == "" { - signature = defaultSignature - } - - visibility := chunk.Visibility - if visibility == "" { - visibility = defaultVisibility - } - - // Sanitize content for PostgreSQL UTF-8 compatibility - sanitizedContent := sanitizeContentWithLogging(ctx, chunk.Content, chunk.ID, chunk.FilePath) - // Validate repository ID is provided - repositoryID := chunk.RepositoryID - if repositoryID == uuid.Nil { + if chunk.RepositoryID == uuid.Nil { slogger.Error(ctx, "Missing repository_id for chunk save", slogger.Fields2( "chunk_id", chunk.ID, "file_path", chunk.FilePath, @@ -270,32 +395,17 @@ func (r *PostgreSQLChunkRepository) SaveChunk(ctx context.Context, chunk *outbou return errors.New("repository_id is required to save chunk") } + // Prepare all arguments with defaults and sanitization + args := prepareChunkInsertArgs(ctx, *chunk, chunkID) + // Save chunk and get the actual chunk ID (could be existing if conflict occurs) var actualChunkID uuid.UUID - err = r.pool.QueryRow(ctx, query, - chunkID, - repositoryID, - chunk.FilePath, - chunkType, - sanitizedContent, - chunk.Language, - chunk.StartLine, - chunk.EndLine, - entityName, - parentEntity, - chunk.Hash, - nil, // metadata - qualifiedName, - signature, - visibility, - chunk.TokenCount, - chunk.TokenCountedAt, - ).Scan(&actualChunkID) + err = r.pool.QueryRow(ctx, query, 
args...).Scan(&actualChunkID) if err != nil { slogger.Error(ctx, "Failed to save chunk", slogger.Fields{ "chunk_id": chunk.ID, "file_path": chunk.FilePath, - "repository_id": repositoryID.String(), + "repository_id": chunk.RepositoryID.String(), "error": err.Error(), }) return fmt.Errorf("failed to save chunk: %w", err) @@ -319,7 +429,9 @@ func (r *PostgreSQLChunkRepository) SaveChunk(ctx context.Context, chunk *outbou return nil } -// SaveChunks stores multiple code chunks in a batch operation. +// SaveChunks stores multiple code chunks in a batch operation using multi-row INSERT. +// Large batches are automatically split into smaller batches to stay within PostgreSQL's +// parameter limit (65535). The batch size is controlled by multiRowInsertBatchSize constant. func (r *PostgreSQLChunkRepository) SaveChunks(ctx context.Context, chunks []outbound.CodeChunk) error { if len(chunks) == 0 { return nil @@ -336,100 +448,90 @@ func (r *PostgreSQLChunkRepository) SaveChunks(ctx context.Context, chunks []out } }() - query := insertChunkQuery - - for i, chunk := range chunks { - chunkID, err := uuid.Parse(chunk.ID) - if err != nil { - slogger.Error(ctx, "Invalid chunk ID in batch", slogger.Fields2( - "chunk_id", chunk.ID, - "error", err.Error(), - )) - return fmt.Errorf("invalid chunk ID format: %w", err) + for batchStart := 0; batchStart < len(chunks); batchStart += multiRowInsertBatchSize { + batchEnd := batchStart + multiRowInsertBatchSize + if batchEnd > len(chunks) { + batchEnd = len(chunks) } + batch := chunks[batchStart:batchEnd] - // Use chunk type information or defaults - entityName := chunk.EntityName - if entityName == "" { - entityName = defaultEntityName - } + // Build the multi-row INSERT query + query := buildMultiRowChunkInsert(len(batch)) - parentEntity := chunk.ParentEntity - if parentEntity == "" { - parentEntity = defaultParentEntity - } + // Prepare arguments for the query + args := make([]interface{}, 0, len(batch)*chunkColumnsPerRow) + for _, chunk 
:= range batch { + chunkID, err := uuid.Parse(chunk.ID) + if err != nil { + slogger.Error(ctx, "Invalid chunk ID in batch", slogger.Fields2( + "chunk_id", chunk.ID, + "error", err.Error(), + )) + return fmt.Errorf("invalid chunk ID format: %w", err) + } - chunkType := chunk.Type - if chunkType == "" { - chunkType = defaultChunkType - } + repositoryID := chunk.RepositoryID + if repositoryID == uuid.Nil { + slogger.Error(ctx, "Missing repository_id for chunk batch save", slogger.Fields2( + "chunk_id", chunk.ID, + "file_path", chunk.FilePath, + )) + return errors.New("repository_id is required to save chunk in batch") + } - qualifiedName := chunk.QualifiedName - if qualifiedName == "" { - qualifiedName = defaultQualifiedName + // Use helper to prepare all 17 arguments with consistent defaults + chunkArgs := prepareChunkInsertArgs(ctx, chunk, chunkID) + args = append(args, chunkArgs...) } - signature := chunk.Signature - if signature == "" { - signature = defaultSignature + // Execute the multi-row INSERT and scan all returned IDs + rows, err := tx.Query(ctx, query, args...) 
+ if err != nil { + slogger.Error(ctx, "Failed to execute multi-row chunk insert", slogger.Fields{ + "batch_size": len(batch), + "error": err.Error(), + }) + return fmt.Errorf("failed to execute multi-row chunk insert: %w", err) } - visibility := chunk.Visibility - if visibility == "" { - visibility = defaultVisibility - } + // Scan returned IDs and update chunks + rowIndex := 0 + for rows.Next() { + var actualChunkID uuid.UUID + if err := rows.Scan(&actualChunkID); err != nil { + rows.Close() + slogger.Error(ctx, "Failed to scan returned chunk ID", slogger.Fields2( + "row_index", rowIndex, + "error", err.Error(), + )) + return fmt.Errorf("failed to scan returned chunk ID: %w", err) + } - repositoryID := chunk.RepositoryID - if repositoryID == uuid.Nil { - slogger.Error(ctx, "Missing repository_id for chunk batch save", slogger.Fields2( - "chunk_id", chunk.ID, - "file_path", chunk.FilePath, - )) - return errors.New("repository_id is required to save chunk in batch") + // Update in-memory chunk ID if it changed due to conflict + expectedChunkID, _ := uuid.Parse(batch[rowIndex].ID) + if actualChunkID != expectedChunkID { + slogger.Info(ctx, "Chunk already exists in batch, using existing chunk ID", slogger.Fields{ + "generated_chunk_id": expectedChunkID.String(), + "actual_chunk_id": actualChunkID.String(), + "file_path": batch[rowIndex].FilePath, + }) + chunks[batchStart+rowIndex].ID = actualChunkID.String() + } + rowIndex++ } + rows.Close() - // Sanitize content for PostgreSQL UTF-8 compatibility - sanitizedContent := sanitizeContentWithLogging(ctx, chunk.Content, chunk.ID, chunk.FilePath) - - // Save chunk and get the actual chunk ID (could be existing if conflict occurs) - var actualChunkID uuid.UUID - err = tx.QueryRow(ctx, query, - chunkID, - repositoryID, - chunk.FilePath, - chunkType, - sanitizedContent, - chunk.Language, - chunk.StartLine, - chunk.EndLine, - entityName, - parentEntity, - chunk.Hash, - nil, // metadata - qualifiedName, - signature, - 
visibility, - chunk.TokenCount, - chunk.TokenCountedAt, - ).Scan(&actualChunkID) - if err != nil { - slogger.Error(ctx, "Failed to save chunk in batch", slogger.Fields{ - "chunk_id": chunk.ID, - "file_path": chunk.FilePath, - "repository_id": repositoryID.String(), - "error": err.Error(), - }) - return fmt.Errorf("failed to save chunk in batch: %w", err) + if err := rows.Err(); err != nil { + slogger.Error(ctx, "Error iterating over returned chunk IDs", slogger.Field("error", err.Error())) + return fmt.Errorf("error iterating over returned chunk IDs: %w", err) } - // Update the in-memory chunk ID with the actual ID - if actualChunkID != chunkID { - slogger.Info(ctx, "Chunk already exists in batch, using existing chunk ID", slogger.Fields{ - "generated_chunk_id": chunkID.String(), - "actual_chunk_id": actualChunkID.String(), - "file_path": chunk.FilePath, - }) - chunks[i].ID = actualChunkID.String() + if rowIndex != len(batch) { + slogger.Error(ctx, "Mismatch between inserted rows and returned IDs", slogger.Fields2( + "expected", len(batch), + "actual", rowIndex, + )) + return fmt.Errorf("expected %d returned IDs, got %d", len(batch), rowIndex) } } @@ -451,6 +553,7 @@ func (r *PostgreSQLChunkRepository) SaveChunks(ctx context.Context, chunks []out // FindOrCreateChunks saves chunks and returns the actual chunk IDs (existing or new). // For chunks that already exist (same repo/path/hash), returns the existing chunk with its ID. // This prevents FK constraint violations when using batch embeddings. +// Uses multi-row INSERT for efficient batch processing. 
func (r *PostgreSQLChunkRepository) FindOrCreateChunks( ctx context.Context, chunks []outbound.CodeChunk, @@ -470,124 +573,102 @@ func (r *PostgreSQLChunkRepository) FindOrCreateChunks( } }() - // Use a query that returns the actual persisted ID (whether new or existing) - // On conflict, preserve existing token_count (find or create, not update) - query := ` - INSERT INTO codechunking.code_chunks ( - id, repository_id, file_path, chunk_type, content, language, - start_line, end_line, entity_name, parent_entity, content_hash, metadata, - qualified_name, signature, visibility, token_count, token_counted_at - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) - ON CONFLICT (repository_id, file_path, content_hash) - DO UPDATE SET - id = code_chunks.id, - token_count = COALESCE(code_chunks.token_count, EXCLUDED.token_count), - token_counted_at = COALESCE(code_chunks.token_counted_at, EXCLUDED.token_counted_at) - RETURNING id - ` - resultChunks := make([]outbound.CodeChunk, len(chunks)) - for i, chunk := range chunks { - chunkID, err := uuid.Parse(chunk.ID) - if err != nil { - slogger.Error(ctx, "Invalid chunk ID in FindOrCreateChunks", slogger.Fields2( - "chunk_id", chunk.ID, - "error", err.Error(), - )) - return nil, fmt.Errorf("invalid chunk ID format: %w", err) + // Process in batches to stay under PostgreSQL's parameter limit + for batchStart := 0; batchStart < len(chunks); batchStart += multiRowInsertBatchSize { + batchEnd := batchStart + multiRowInsertBatchSize + if batchEnd > len(chunks) { + batchEnd = len(chunks) } + batch := chunks[batchStart:batchEnd] - // Use chunk type information or defaults - entityName := chunk.EntityName - if entityName == "" { - entityName = defaultEntityName - } + // Build the multi-row INSERT query + query := buildMultiRowChunkInsert(len(batch)) - parentEntity := chunk.ParentEntity - if parentEntity == "" { - parentEntity = defaultParentEntity - } + // Prepare arguments for the query + args := 
make([]interface{}, 0, len(batch)*chunkColumnsPerRow) + for _, chunk := range batch { + chunkID, err := uuid.Parse(chunk.ID) + if err != nil { + slogger.Error(ctx, "Invalid chunk ID in FindOrCreateChunks", slogger.Fields2( + "chunk_id", chunk.ID, + "error", err.Error(), + )) + return nil, fmt.Errorf("invalid chunk ID format: %w", err) + } - chunkType := chunk.Type - if chunkType == "" { - chunkType = defaultChunkType - } + repositoryID := chunk.RepositoryID + if repositoryID == uuid.Nil { + slogger.Error(ctx, "Missing repository_id for chunk in FindOrCreateChunks", slogger.Fields2( + "chunk_id", chunk.ID, + "file_path", chunk.FilePath, + )) + return nil, errors.New("repository_id is required for FindOrCreateChunks") + } - qualifiedName := chunk.QualifiedName - if qualifiedName == "" { - qualifiedName = defaultQualifiedName + // Use helper to prepare all 17 arguments with consistent defaults + chunkArgs := prepareChunkInsertArgs(ctx, chunk, chunkID) + args = append(args, chunkArgs...) } - signature := chunk.Signature - if signature == "" { - signature = defaultSignature + // Execute the multi-row INSERT and scan all returned IDs + rows, err := tx.Query(ctx, query, args...) 
+ if err != nil { + slogger.Error(ctx, "Failed to execute multi-row FindOrCreateChunks insert", slogger.Fields{ + "batch_size": len(batch), + "error": err.Error(), + }) + return nil, fmt.Errorf("failed to execute multi-row FindOrCreateChunks insert: %w", err) } - visibility := chunk.Visibility - if visibility == "" { - visibility = defaultVisibility - } + // Scan returned IDs - PostgreSQL guarantees RETURNING order matches VALUES order + rowIndex := 0 + for rows.Next() { + var returnedID uuid.UUID + if err := rows.Scan(&returnedID); err != nil { + rows.Close() + slogger.Error(ctx, "Failed to scan returned chunk ID in FindOrCreateChunks", slogger.Fields2( + "row_index", rowIndex, + "error", err.Error(), + )) + return nil, fmt.Errorf("failed to scan returned chunk ID: %w", err) + } - repositoryID := chunk.RepositoryID - if repositoryID == uuid.Nil { - slogger.Error(ctx, "Missing repository_id for chunk in FindOrCreateChunks", slogger.Fields2( - "chunk_id", chunk.ID, - "file_path", chunk.FilePath, - )) - return nil, errors.New("repository_id is required for FindOrCreateChunks") - } + // Map returned ID to the corresponding chunk + // resultChunks[batchStart + rowIndex] corresponds to batch[rowIndex] + resultChunks[batchStart+rowIndex] = batch[rowIndex] + resultChunks[batchStart+rowIndex].ID = returnedID.String() + + // Log if we got a different ID (chunk already existed) + expectedChunkID, _ := uuid.Parse(batch[rowIndex].ID) + if returnedID != expectedChunkID { + slogger.Debug(ctx, "Chunk already existed in FindOrCreateChunks, using existing ID", slogger.Fields{ + "new_id": batch[rowIndex].ID, + "existing_id": returnedID.String(), + "file_path": batch[rowIndex].FilePath, + }) + } - // Defensive sanitization - sanitizedContent := valueobject.SanitizeContent(chunk.Content) - if len(sanitizedContent) != len(chunk.Content) { - slogger.Warn(ctx, "Null bytes detected and removed in FindOrCreateChunks", slogger.Fields{ - "chunk_id": chunk.ID, - "file_path": chunk.FilePath, - 
"null_bytes_removed": len(chunk.Content) - len(sanitizedContent), - }) + rowIndex++ } + rows.Close() - var returnedID uuid.UUID - err = tx.QueryRow(ctx, query, - chunkID, - repositoryID, - chunk.FilePath, - chunkType, - sanitizedContent, - chunk.Language, - chunk.StartLine, - chunk.EndLine, - entityName, - parentEntity, - chunk.Hash, - nil, // metadata - qualifiedName, - signature, - visibility, - chunk.TokenCount, - chunk.TokenCountedAt, - ).Scan(&returnedID) - if err != nil { - slogger.Error(ctx, "Failed to insert/find chunk", slogger.Fields{ - "chunk_id": chunk.ID, - "file_path": chunk.FilePath, - "repository_id": repositoryID.String(), - "error": err.Error(), - }) - return nil, fmt.Errorf("failed to insert/find chunk: %w", err) + if err := rows.Err(); err != nil { + slogger.Error( + ctx, + "Error iterating over returned chunk IDs in FindOrCreateChunks", + slogger.Field("error", err.Error()), + ) + return nil, fmt.Errorf("error iterating over returned chunk IDs: %w", err) } - // Create result chunk with the actual persisted ID - resultChunks[i] = chunk - resultChunks[i].ID = returnedID.String() - - if returnedID.String() != chunk.ID { - slogger.Debug(ctx, "Chunk already existed, using existing ID", slogger.Fields{ - "new_id": chunk.ID, - "existing_id": returnedID.String(), - "file_path": chunk.FilePath, - }) + if rowIndex != len(batch) { + slogger.Error(ctx, "Mismatch between inserted rows and returned IDs in FindOrCreateChunks", slogger.Fields2( + "expected", len(batch), + "actual", rowIndex, + )) + return nil, fmt.Errorf("expected %d returned IDs, got %d", len(batch), rowIndex) } } @@ -1839,66 +1920,17 @@ func (r *PostgreSQLChunkRepository) SaveChunksWithEmbeddings( return fmt.Errorf("invalid chunk ID format: %w", err) } - // Use chunk type information or defaults - entityName := chunk.EntityName - if entityName == "" { - entityName = defaultEntityName - } - - parentEntity := chunk.ParentEntity - if parentEntity == "" { - parentEntity = defaultParentEntity - } 
- - chunkType := chunk.Type - if chunkType == "" { - chunkType = defaultChunkType - } - - qualifiedName := chunk.QualifiedName - if qualifiedName == "" { - qualifiedName = defaultQualifiedName - } - - signature := chunk.Signature - if signature == "" { - signature = defaultSignature - } - - visibility := chunk.Visibility - if visibility == "" { - visibility = defaultVisibility - } - - // CRITICAL: Ensure consistent repository ID between chunk and embedding - // This prevents partition routing foreign key violations - repositoryID := chunks[i].RepositoryID - - // Validate that we have a repository ID - if repositoryID == uuid.Nil { - slogger.Error(ctx, "Missing repository_id for batch transactional chunk save", slogger.Fields{ + // Prepare chunk fields with defaults and validate repository ID consistency + fields, err := prepareChunkFields(&chunks[i], &embeddings[i]) + if err != nil { + slogger.Error(ctx, "Failed to prepare chunk fields in batch transaction", slogger.Fields{ "chunk_id": chunk.ID, "batch_index": i, + "error": err.Error(), }) - return errors.New("repository_id is required to save chunk in batch transaction") + return fmt.Errorf("failed to prepare chunk fields at index %d: %w", i, err) } - // Validate repository ID consistency before forcing - // If embedding has a different non-nil repository ID, that's an error - if embeddings[i].RepositoryID != uuid.Nil && embeddings[i].RepositoryID != repositoryID { - slogger.Error(ctx, "Repository ID mismatch detected", slogger.Fields3( - "chunk_id", chunk.ID, - "chunk_repository_id", repositoryID.String(), - "embedding_repository_id", embeddings[i].RepositoryID.String(), - )) - return fmt.Errorf("repository ID mismatch for chunk %s: chunk=%s, embedding=%s", - chunk.ID, repositoryID.String(), embeddings[i].RepositoryID.String()) - } - - // Force embedding to use the EXACT same repository ID as the chunk - // This ensures they route to the same partition and prevents foreign key violations - 
embeddings[i].RepositoryID = repositoryID - // Sanitize content for PostgreSQL UTF-8 compatibility sanitizedContent := sanitizeContentWithLogging(ctx, chunk.Content, chunk.ID, chunk.FilePath) @@ -1906,20 +1938,20 @@ func (r *PostgreSQLChunkRepository) SaveChunksWithEmbeddings( var actualChunkID uuid.UUID err = tx.QueryRow(ctx, chunkQuery, chunkID, - repositoryID, + fields.RepositoryID, chunk.FilePath, - chunkType, + fields.ChunkType, sanitizedContent, chunk.Language, chunk.StartLine, chunk.EndLine, - entityName, - parentEntity, + fields.EntityName, + fields.ParentEntity, chunk.Hash, nil, // metadata - qualifiedName, - signature, - visibility, + fields.QualifiedName, + fields.Signature, + fields.Visibility, chunk.TokenCount, chunk.TokenCountedAt, ).Scan(&actualChunkID) @@ -1927,7 +1959,7 @@ func (r *PostgreSQLChunkRepository) SaveChunksWithEmbeddings( slogger.Error(ctx, "Failed to save chunk in batch transaction", slogger.Fields{ "chunk_id": chunk.ID, "batch_index": i, - "repository_id": repositoryID.String(), + "repository_id": fields.RepositoryID.String(), "error": err.Error(), }) return fmt.Errorf("failed to save chunk in batch transaction: %w", err) diff --git a/internal/adapter/outbound/repository/chunk_repository_findorcreate_multirow_test.go b/internal/adapter/outbound/repository/chunk_repository_findorcreate_multirow_test.go new file mode 100644 index 0000000..279fcfe --- /dev/null +++ b/internal/adapter/outbound/repository/chunk_repository_findorcreate_multirow_test.go @@ -0,0 +1,787 @@ +//go:build integration + +package repository + +import ( + "codechunking/internal/port/outbound" + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestFindOrCreateChunks_MultiRow_AllNewChunks tests that FindOrCreateChunks returns generated UUIDs for all new chunks. 
+// This test verifies that when all chunks are new, every chunk is persisted and comes
+// back with a non-empty ID under which its content can be read from the database.
+func TestFindOrCreateChunks_MultiRow_AllNewChunks(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	// Close via t.Cleanup instead of defer: deferred calls run before cleanup
+	// callbacks, so a deferred Close would shut the pool before the row cleanup below.
+	t.Cleanup(func() { pool.Close() })
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/all-new-chunks-" + repositoryID.String()[:8]
+
+	// Register row cleanup up front so it still runs when a require below aborts the test.
+	// Cleanups run last-registered-first, so these deletes execute before the pool is closed.
+	t.Cleanup(func() {
+		_, err := pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting chunks of repository %s: %v", repositoryID, err)
+		}
+		_, err = pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting repository %s: %v", repositoryID, err)
+		}
+	})
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "all-new-chunks-repo", "Test repository for all new chunks", "indexed")
+	require.NoError(t, err)
+
+	// Create 5 new chunks (none exist in DB yet)
+	now := time.Now()
+	inputChunks := make([]outbound.CodeChunk, 5)
+	for i := range inputChunks {
+		inputChunks[i] = outbound.CodeChunk{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "new_chunks.go",
+			Content:        "func NewFunc" + string(rune('A'+i)) + "() { }",
+			Language:       "go",
+			StartLine:      i*2 + 1,
+			EndLine:        i*2 + 1,
+			Hash:           "hash-new-" + uuid.New().String()[:8],
+			Type:           "function",
+			EntityName:     "NewFunc" + string(rune('A'+i)),
+			CreatedAt:      now,
+			TokenCount:     0,
+			TokenCountedAt: nil,
+		}
+	}
+
+	// Call FindOrCreateChunks - should create all chunks
+	resultChunks, err := repo.FindOrCreateChunks(ctx, inputChunks)
+	require.NoError(t, err, "FindOrCreateChunks should succeed for all new chunks")
+	require.NotNil(t, resultChunks, "Result should not be nil")
+	require.Len(t, resultChunks, 5, "Should return 5 chunks")
+
+	// Verify every returned ID is non-empty and resolves to the matching row
+	for i, resultChunk := range resultChunks {
+		assert.NotEmpty(t, resultChunk.ID, "Chunk %d should have an ID", i)
+
+		// Verify the chunk exists in the database with the returned ID
+		var exists bool
+		err = pool.QueryRow(ctx, `
+			SELECT EXISTS(SELECT 1 FROM codechunking.code_chunks WHERE id = $1 AND deleted_at IS NULL)
+		`, resultChunk.ID).Scan(&exists)
+		require.NoError(t, err)
+		assert.True(t, exists, "Chunk %d with ID %s should exist in database", i, resultChunk.ID)
+
+		// Verify content matches
+		var retrievedContent string
+		err = pool.QueryRow(ctx, `
+			SELECT content FROM codechunking.code_chunks WHERE id = $1 AND deleted_at IS NULL
+		`, resultChunk.ID).Scan(&retrievedContent)
+		require.NoError(t, err)
+		assert.Equal(t, inputChunks[i].Content, retrievedContent, "Content should match for chunk %d", i)
+	}
+}
+
+// TestFindOrCreateChunks_MultiRow_AllExistingChunks tests that FindOrCreateChunks returns existing IDs for all existing chunks.
+// This test verifies that when all chunks already exist (based on repo/path/hash), the function returns their existing IDs.
+func TestFindOrCreateChunks_MultiRow_AllExistingChunks(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	// Close via t.Cleanup instead of defer: deferred calls run before cleanup
+	// callbacks, so a deferred Close would shut the pool before the row cleanup below.
+	t.Cleanup(func() { pool.Close() })
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/all-existing-chunks-" + repositoryID.String()[:8]
+
+	// Register row cleanup up front so it still runs when a require below aborts the test.
+	// Cleanups run last-registered-first, so these deletes execute before the pool is closed.
+	t.Cleanup(func() {
+		_, err := pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting chunks of repository %s: %v", repositoryID, err)
+		}
+		_, err = pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting repository %s: %v", repositoryID, err)
+		}
+	})
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "all-existing-chunks-repo", "Test repository for all existing chunks", "indexed")
+	require.NoError(t, err)
+
+	// Create and save initial chunks
+	now := time.Now()
+	existingChunks := make([]outbound.CodeChunk, 3)
+	existingIDs := make([]string, 3)
+	for i := range existingChunks {
+		chunkID := uuid.New()
+		existingIDs[i] = chunkID.String()
+
+		existingChunks[i] = outbound.CodeChunk{
+			ID:             chunkID.String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "existing.go",
+			Content:        "func ExistingFunc" + string(rune('A'+i)) + "() { }",
+			Language:       "go",
+			StartLine:      i*2 + 1,
+			EndLine:        i*2 + 1,
+			Hash:           "hash-existing-" + string(rune('1'+i)), // Stable hashes
+			Type:           "function",
+			EntityName:     "ExistingFunc" + string(rune('A'+i)),
+			CreatedAt:      now,
+			TokenCount:     100 + i*10, // Token counts: 100, 110, 120
+			TokenCountedAt: &now,
+		}
+	}
+
+	// Save the chunks first
+	err = repo.SaveChunks(ctx, existingChunks)
+	require.NoError(t, err, "SaveChunks should succeed")
+
+	// Now call FindOrCreateChunks with chunks that have the SAME repo/path/hash but DIFFERENT IDs
+	duplicateChunks := make([]outbound.CodeChunk, 3)
+	for i := range duplicateChunks {
+		duplicateChunks[i] = outbound.CodeChunk{
+			ID:             uuid.New().String(), // DIFFERENT ID (should be ignored)
+			RepositoryID:   repositoryID,
+			FilePath:       "existing.go",
+			Content:        "func ExistingFunc" + string(rune('A'+i)) + "() { }",
+			Language:       "go",
+			StartLine:      i*2 + 1,
+			EndLine:        i*2 + 1,
+			Hash:           "hash-existing-" + string(rune('1'+i)), // SAME hash (triggers conflict)
+			Type:           "function",
+			EntityName:     "ExistingFunc" + string(rune('A'+i)),
+			CreatedAt:      now,
+			TokenCount:     999, // Different token count (should be ignored)
+			TokenCountedAt: &now,
+		}
+	}
+
+	// Call FindOrCreateChunks - should return existing IDs
+	resultChunks, err := repo.FindOrCreateChunks(ctx, duplicateChunks)
+	require.NoError(t, err, "FindOrCreateChunks should succeed for all existing chunks")
+	require.NotNil(t, resultChunks, "Result should not be nil")
+	require.Len(t, resultChunks, 3, "Should return 3 chunks")
+
+	// Verify that returned IDs match the ORIGINAL existing IDs, not the new ones we provided
+	for i, resultChunk := range resultChunks {
+		assert.Equal(t, existingIDs[i], resultChunk.ID, "Chunk %d should return existing ID", i)
+		assert.NotEqual(t, duplicateChunks[i].ID, resultChunk.ID, "Chunk %d should NOT use new ID", i)
+	}
+
+	// Verify only 3 chunks exist in total (no duplicates created)
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 3, count, "Should have exactly 3 chunks (no duplicates)")
+}
+
+// TestFindOrCreateChunks_MultiRow_MixedNewAndExisting tests that FindOrCreateChunks handles a mix of new and existing chunks.
+// This test verifies that the function correctly identifies existing chunks and creates new ones in a single batch.
+func TestFindOrCreateChunks_MultiRow_MixedNewAndExisting(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	// Close via t.Cleanup instead of defer: deferred calls run before cleanup
+	// callbacks, so a deferred Close would shut the pool before the row cleanup below.
+	t.Cleanup(func() { pool.Close() })
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/mixed-chunks-" + repositoryID.String()[:8]
+
+	// Register row cleanup up front so it still runs when a require below aborts the test.
+	// Cleanups run last-registered-first, so these deletes execute before the pool is closed.
+	t.Cleanup(func() {
+		_, err := pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting chunks of repository %s: %v", repositoryID, err)
+		}
+		_, err = pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting repository %s: %v", repositoryID, err)
+		}
+	})
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "mixed-chunks-repo", "Test repository for mixed chunks", "indexed")
+	require.NoError(t, err)
+
+	now := time.Now()
+
+	// chunk builds a one-line function chunk in mixed.go with a freshly generated ID.
+	chunk := func(entity, hash string, line, tokens int, countedAt *time.Time) outbound.CodeChunk {
+		return outbound.CodeChunk{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "mixed.go",
+			Content:        "func " + entity + "() { }",
+			Language:       "go",
+			StartLine:      line,
+			EndLine:        line,
+			Hash:           hash,
+			Type:           "function",
+			EntityName:     entity,
+			CreatedAt:      now,
+			TokenCount:     tokens,
+			TokenCountedAt: countedAt,
+		}
+	}
+
+	// Create and save 2 existing chunks
+	existingChunks := []outbound.CodeChunk{
+		chunk("ExistingA", "hash-existing-A", 1, 50, &now),
+		chunk("ExistingB", "hash-existing-B", 5, 60, &now),
+	}
+
+	err = repo.SaveChunks(ctx, existingChunks)
+	require.NoError(t, err, "SaveChunks should succeed")
+
+	// Store existing IDs for verification
+	existingIDA := existingChunks[0].ID
+	existingIDB := existingChunks[1].ID
+
+	// Create mixed batch: [new, existing-A, new, existing-B, new]
+	mixedChunks := []outbound.CodeChunk{
+		chunk("NewC", "hash-new-C", 10, 0, nil),
+		chunk("ExistingA", "hash-existing-A", 1, 999, &now), // Same hash as existing chunk, different ID
+		chunk("NewD", "hash-new-D", 15, 0, nil),
+		chunk("ExistingB", "hash-existing-B", 5, 999, &now), // Same hash as existing chunk, different ID
+		chunk("NewE", "hash-new-E", 20, 0, nil),
+	}
+
+	// Call FindOrCreateChunks
+	resultChunks, err := repo.FindOrCreateChunks(ctx, mixedChunks)
+	require.NoError(t, err, "FindOrCreateChunks should succeed for mixed chunks")
+	require.NotNil(t, resultChunks, "Result should not be nil")
+	require.Len(t, resultChunks, 5, "Should return 5 chunks")
+
+	// Verify that existing chunks return existing IDs
+	assert.Equal(t, existingIDA, resultChunks[1].ID, "Index 1 should return existing ID for chunk A")
+	assert.Equal(t, existingIDB, resultChunks[3].ID, "Index 3 should return existing ID for chunk B")
+
+	// Verify that new chunks come back with a non-empty ID
+	assert.NotEmpty(t, resultChunks[0].ID, "Index 0 should have an ID")
+	assert.NotEmpty(t, resultChunks[2].ID, "Index 2 should have an ID")
+	assert.NotEmpty(t, resultChunks[4].ID, "Index 4 should have an ID")
+
+	// Verify total count (should be 5: 2 existing + 3 new)
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 5, count, "Should have exactly 5 chunks total")
+}
+
+// TestFindOrCreateChunks_MultiRow_ReturningOrderMatchesInput tests that returned chunk IDs match input order.
+// This is CRITICAL for batch embedding workflow - chunks[i] must get result[i]'s ID.
+func TestFindOrCreateChunks_MultiRow_ReturningOrderMatchesInput(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	// Close via t.Cleanup instead of defer: deferred calls run before cleanup
+	// callbacks, so a deferred Close would shut the pool before the row cleanup below.
+	t.Cleanup(func() { pool.Close() })
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/order-preservation-" + repositoryID.String()[:8]
+
+	// Register row cleanup up front so it still runs when a require below aborts the test.
+	// Cleanups run last-registered-first, so these deletes execute before the pool is closed.
+	t.Cleanup(func() {
+		_, err := pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting chunks of repository %s: %v", repositoryID, err)
+		}
+		_, err = pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting repository %s: %v", repositoryID, err)
+		}
+	})
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "order-preservation-repo", "Test repository for order preservation", "indexed")
+	require.NoError(t, err)
+
+	now := time.Now()
+
+	// orderChunk builds the chunk for "func Order<letter>() { }" in order.go with a
+	// fresh ID; the hash depends only on the letter, so reusing a letter hits the same row.
+	orderChunk := func(letter string, line int) outbound.CodeChunk {
+		return outbound.CodeChunk{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "order.go",
+			Content:        "func Order" + letter + "() { }",
+			Language:       "go",
+			StartLine:      line,
+			EndLine:        line,
+			Hash:           "hash-order-" + letter,
+			Type:           "function",
+			EntityName:     "Order" + letter,
+			CreatedAt:      now,
+			TokenCount:     0,
+			TokenCountedAt: nil,
+		}
+	}
+
+	// Create chunks with unique, identifiable content
+	inputChunks := []outbound.CodeChunk{
+		orderChunk("A", 1),
+		orderChunk("B", 5),
+		orderChunk("C", 10),
+		orderChunk("D", 15),
+	}
+
+	// Call FindOrCreateChunks
+	resultChunks, err := repo.FindOrCreateChunks(ctx, inputChunks)
+	require.NoError(t, err, "FindOrCreateChunks should succeed")
+	require.NotNil(t, resultChunks, "Result should not be nil")
+	require.Len(t, resultChunks, 4, "Should return 4 chunks")
+
+	// CRITICAL: Verify that each result chunk corresponds to the input chunk at the same index
+	// This is verified by checking that the content matches
+	for i, resultChunk := range resultChunks {
+		// Query the database to get the content for the returned ID
+		var retrievedContent string
+		var retrievedEntityName string
+		err = pool.QueryRow(ctx, `
+			SELECT content, entity_name
+			FROM codechunking.code_chunks
+			WHERE id = $1 AND deleted_at IS NULL
+		`, resultChunk.ID).Scan(&retrievedContent, &retrievedEntityName)
+		require.NoError(t, err, "Should retrieve chunk %d", i)
+
+		// The content at result[i] must match input[i]
+		assert.Equal(t, inputChunks[i].Content, retrievedContent,
+			"Result chunk at index %d must correspond to input chunk at index %d (content mismatch)", i, i)
+		assert.Equal(t, inputChunks[i].EntityName, retrievedEntityName,
+			"Result chunk at index %d must correspond to input chunk at index %d (entity name mismatch)", i, i)
+
+		// Also verify the in-memory result chunk has matching content
+		assert.Equal(t, inputChunks[i].Content, resultChunk.Content,
+			"In-memory result chunk at index %d should preserve input content", i)
+	}
+
+	// Now test with mixed new/existing to ensure order is still preserved
+	// Create a new batch that references some existing chunks in different order
+	mixedInputChunks := []outbound.CodeChunk{
+		orderChunk("C", 10), // Existing chunk (index 2 from before)
+		orderChunk("E", 20), // New chunk
+		orderChunk("A", 1),  // Existing chunk (index 0 from before)
+	}
+
+	// Call FindOrCreateChunks again
+	mixedResultChunks, err := repo.FindOrCreateChunks(ctx, mixedInputChunks)
+	require.NoError(t, err, "FindOrCreateChunks should succeed for mixed batch")
+	require.NotNil(t, mixedResultChunks, "Result should not be nil")
+	require.Len(t, mixedResultChunks, 3, "Should return 3 chunks")
+
+	// Verify order preservation for mixed batch
+	for i, resultChunk := range mixedResultChunks {
+		var retrievedContent string
+		err = pool.QueryRow(ctx, `
+			SELECT content
+			FROM codechunking.code_chunks
+			WHERE id = $1 AND deleted_at IS NULL
+		`, resultChunk.ID).Scan(&retrievedContent)
+		require.NoError(t, err, "Should retrieve chunk %d from mixed batch", i)
+
+		assert.Equal(t, mixedInputChunks[i].Content, retrievedContent,
+			"Mixed batch: Result chunk at index %d must correspond to input chunk at index %d", i, i)
+	}
+}
+
+// TestFindOrCreateChunks_MultiRow_LargeBatch tests that FindOrCreateChunks handles a large batch correctly.
+// This test verifies that a 100-chunk batch is stored completely and in input order; the elapsed time is logged, not asserted.
+func TestFindOrCreateChunks_MultiRow_LargeBatch(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	// Close via t.Cleanup instead of defer: deferred calls run before cleanup
+	// callbacks, so a deferred Close would shut the pool before the row cleanup below.
+	t.Cleanup(func() { pool.Close() })
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/large-batch-findorcreate-" + repositoryID.String()[:8]
+
+	// Register row cleanup up front so it still runs when a require below aborts the test.
+	// Cleanups run last-registered-first, so these deletes execute before the pool is closed.
+	t.Cleanup(func() {
+		_, err := pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting chunks of repository %s: %v", repositoryID, err)
+		}
+		_, err = pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID)
+		if err != nil {
+			t.Logf("cleanup: deleting repository %s: %v", repositoryID, err)
+		}
+	})
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "large-batch-findorcreate-repo", "Test repository for large batch FindOrCreateChunks", "indexed")
+	require.NoError(t, err)
+
+	// Create 100 chunks
+	now := time.Now()
+	largeChunks := make([]outbound.CodeChunk, 100)
+	for i := range largeChunks {
+		// Units digit goes into the function name, tens digit into the comment,
+		// so each of the 100 contents is distinct.
+		unitsDigit := string(rune('0' + (i % 10)))
+		tensDigit := string(rune('0' + (i / 10)))
+
+		largeChunks[i] = outbound.CodeChunk{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "large_batch.go",
+			Content:        "func LargeFunc" + unitsDigit + "() { /* chunk " + tensDigit + " */ }",
+			Language:       "go",
+			StartLine:      i*2 + 1,
+			EndLine:        i*2 + 1,
+			Hash:           "hash-large-findorcreate-" + uuid.New().String()[:8],
+			Type:           "function",
+			EntityName:     "LargeFunc" + unitsDigit,
+			CreatedAt:      now,
+			TokenCount:     0,
+			TokenCountedAt: nil,
+		}
+	}
+
+	// Call FindOrCreateChunks - should create all 100 chunks
+	startTime := time.Now()
+	resultChunks, err := repo.FindOrCreateChunks(ctx, largeChunks)
+	duration := time.Since(startTime)
+	require.NoError(t, err, "FindOrCreateChunks should succeed for large batch")
+
+	t.Logf("FindOrCreateChunks for 100 chunks took: %v", duration)
+
+	require.NotNil(t, resultChunks, "Result should not be nil")
+	require.Len(t, resultChunks, 100, "Should return 100 chunks")
+
+	// Verify all chunks were created
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 100, count, "Should have exactly 100 chunks")
+
+	// Spot check order preservation for first, middle, and last chunks
+	testIndices := []int{0, 50, 99}
+	for _, i := range testIndices {
+		var retrievedContent string
+		err = pool.QueryRow(ctx, `
+			SELECT content
+			FROM codechunking.code_chunks
+			WHERE id = $1 AND deleted_at IS NULL
+		`, resultChunks[i].ID).Scan(&retrievedContent)
+		require.NoError(t, err, "Should retrieve chunk %d", i)
+		assert.Equal(t, largeChunks[i].Content, retrievedContent, "Content should match for chunk %d", i)
+	}
+}
+
+// TestFindOrCreateChunks_MultiRow_EmptySlice tests that FindOrCreateChunks handles empty input correctly.
+// This test verifies that the function is a no-op for empty/nil input.
+func TestFindOrCreateChunks_MultiRow_EmptySlice(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Both "no input" shapes are expected to yield a nil result and a nil error.
+	inputs := []struct {
+		label  string
+		chunks []outbound.CodeChunk
+	}{
+		{label: "nil slice", chunks: nil},
+		{label: "empty slice", chunks: []outbound.CodeChunk{}},
+	}
+
+	for _, in := range inputs {
+		result, err := repo.FindOrCreateChunks(ctx, in.chunks)
+		assert.NoError(t, err, "FindOrCreateChunks should return nil error for %s", in.label)
+		assert.Nil(t, result, "FindOrCreateChunks should return nil for %s", in.label)
+	}
+}
+
+// TestFindOrCreateChunks_MultiRow_PreservesExistingTokenCounts tests that FindOrCreateChunks preserves existing token counts.
+// This test verifies that when chunks already exist with token counts, those counts are NOT overwritten.
+func TestFindOrCreateChunks_MultiRow_PreservesExistingTokenCounts(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/preserve-tokens-" + repositoryID.String()[:8]
+
+	// Delete this test's rows when the test exits. This is a defer rather than a
+	// t.Cleanup callback because t.Cleanup callbacks run only after the test function
+	// has returned, i.e. after the deferred pool.Close above. Registering it before
+	// the first insert also covers runs that abort early on a failed require.
+	defer func() {
+		for _, stmt := range []string{
+			"DELETE FROM codechunking.code_chunks WHERE repository_id = $1",
+			"DELETE FROM codechunking.repositories WHERE id = $1",
+		} {
+			if _, cleanupErr := pool.Exec(ctx, stmt, repositoryID); cleanupErr != nil {
+				t.Logf("cleanup failed (%s): %v", stmt, cleanupErr)
+			}
+		}
+	}()
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "preserve-tokens-repo", "Test repository for preserving token counts", "indexed")
+	require.NoError(t, err)
+
+	// Create and save chunks with token counts
+	now := time.Now()
+	existingChunks := []outbound.CodeChunk{
+		{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "preserve.go",
+			Content:        "func TokenFunc() { }",
+			Language:       "go",
+			StartLine:      1,
+			EndLine:        1,
+			Hash:           "hash-preserve-tokens",
+			Type:           "function",
+			EntityName:     "TokenFunc",
+			CreatedAt:      now,
+			TokenCount:     150, // Original token count
+			TokenCountedAt: &now,
+		},
+	}
+
+	err = repo.SaveChunks(ctx, existingChunks)
+	require.NoError(t, err, "SaveChunks should succeed")
+
+	originalChunkID := existingChunks[0].ID
+
+	// Verify original token count was saved
+	var originalTokenCount *int
+	err = pool.QueryRow(ctx, `
+		SELECT token_count
+		FROM codechunking.code_chunks
+		WHERE id = $1 AND deleted_at IS NULL
+	`, originalChunkID).Scan(&originalTokenCount)
+	require.NoError(t, err)
+	require.NotNil(t, originalTokenCount)
+	assert.Equal(t, 150, *originalTokenCount, "Original token count should be 150")
+
+	// Call FindOrCreateChunks with the same chunk but different token count
+	duplicateChunks := []outbound.CodeChunk{
+		{
+			ID:             uuid.New().String(), // Different ID
+			RepositoryID:   repositoryID,
+			FilePath:       "preserve.go",
+			Content:        "func TokenFunc() { }",
+			Language:       "go",
+			StartLine:      1,
+			EndLine:        1,
+			Hash:           "hash-preserve-tokens", // Same hash triggers conflict
+			Type:           "function",
+			EntityName:     "TokenFunc",
+			CreatedAt:      now,
+			TokenCount:     999, // Different token count (should be ignored)
+			TokenCountedAt: &now,
+		},
+	}
+
+	resultChunks, err := repo.FindOrCreateChunks(ctx, duplicateChunks)
+	require.NoError(t, err, "FindOrCreateChunks should succeed")
+	require.NotNil(t, resultChunks, "Result should not be nil")
+	require.Len(t, resultChunks, 1, "Should return 1 chunk")
+
+	// Verify the returned ID is the original one
+	assert.Equal(t, originalChunkID, resultChunks[0].ID, "Should return existing chunk ID")
+
+	// Verify the token count was preserved (not overwritten)
+	var preservedTokenCount *int
+	err = pool.QueryRow(ctx, `
+		SELECT token_count
+		FROM codechunking.code_chunks
+		WHERE id = $1 AND deleted_at IS NULL
+	`, originalChunkID).Scan(&preservedTokenCount)
+	require.NoError(t, err)
+	require.NotNil(t, preservedTokenCount)
+	assert.Equal(t, 150, *preservedTokenCount, "Token count should be preserved as 150, not overwritten to 999")
+
+	// Verify only one chunk exists (no duplicates)
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 1, count, "Should have exactly 1 chunk (no duplicates)")
+}
diff --git a/internal/adapter/outbound/repository/chunk_repository_multirow_test.go b/internal/adapter/outbound/repository/chunk_repository_multirow_test.go
new file mode 100644
index 0000000..80b9585
--- /dev/null
+++ b/internal/adapter/outbound/repository/chunk_repository_multirow_test.go
@@ -0,0 +1,354 @@
+package repository
+
+import (
+	"strings"
+	"testing"
+)
+
+// TestBuildMultiRowChunkInsert_SingleRow tests building query for a single chunk.
+// This test verifies the basic query structure and parameter placeholders.
+func TestBuildMultiRowChunkInsert_SingleRow(t *testing.T) {
+	// This will fail because buildMultiRowChunkInsert doesn't exist yet
+	query := buildMultiRowChunkInsert(1)
+
+	// Verify query contains the INSERT clause with all 17 columns
+	expectedColumns := []string{
+		"id", "repository_id", "file_path", "chunk_type", "content", "language",
+		"start_line", "end_line", "entity_name", "parent_entity", "content_hash",
+		"metadata", "qualified_name", "signature", "visibility", "token_count", "token_counted_at",
+	}
+
+	for _, col := range expectedColumns {
+		if !strings.Contains(query, col) {
+			t.Errorf("Query missing column: %s", col)
+		}
+	}
+
+	// Verify VALUES clause has exactly 17 parameters for single row
+	if !strings.Contains(query, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)") {
+		t.Error("Query missing correct VALUES clause with $1-$17 parameters for single row")
+	}
+
+	// Verify ON CONFLICT clause exists
+	if !strings.Contains(query, "ON CONFLICT (repository_id, file_path, content_hash)") {
+		t.Error("Query missing ON CONFLICT clause")
+	}
+
+	// Verify DO UPDATE SET clause exists with COALESCE logic for token fields
+	if !strings.Contains(query, "DO UPDATE SET") {
+		t.Error("Query missing DO UPDATE SET clause")
+	}
+
+	if !strings.Contains(query, "COALESCE") {
+		t.Error("Query missing COALESCE logic for token_count preservation")
+	}
+
+	// Verify RETURNING clause exists
+	if !strings.Contains(query, "RETURNING id") {
+		t.Error("Query missing RETURNING id clause")
+	}
+}
+
+// TestBuildMultiRowChunkInsert_MultipleRows tests building query for multiple chunks.
+// This test verifies correct generation of multiple VALUES clauses with proper parameter numbering.
+func TestBuildMultiRowChunkInsert_MultipleRows(t *testing.T) {
+	tests := []struct {
+		name     string
+		numRows  int
+		firstRow string // Expected first VALUES row parameters
+		lastRow  string // Expected last VALUES row parameters
+	}{
+		{
+			name:     "Two rows",
+			numRows:  2,
+			firstRow: "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)",
+			lastRow:  "($18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34)",
+		},
+		{
+			name:     "Three rows",
+			numRows:  3,
+			firstRow: "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)",
+			lastRow:  "($35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50, $51)",
+		},
+		{
+			name:     "Five rows",
+			numRows:  5,
+			firstRow: "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)",
+			lastRow:  "($69, $70, $71, $72, $73, $74, $75, $76, $77, $78, $79, $80, $81, $82, $83, $84, $85)",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// This will fail because buildMultiRowChunkInsert doesn't exist yet
+			query := buildMultiRowChunkInsert(tt.numRows)
+
+			// Verify first row parameters
+			if !strings.Contains(query, tt.firstRow) {
+				t.Errorf("Query missing first row with parameters %s", tt.firstRow)
+			}
+
+			// Verify last row parameters
+			if !strings.Contains(query, tt.lastRow) {
+				t.Errorf("Query missing last row with parameters %s", tt.lastRow)
+			}
+
+			// Verify VALUES keyword appears exactly once
+			valuesCount := strings.Count(query, "VALUES")
+			if valuesCount != 1 {
+				t.Errorf("Expected exactly 1 VALUES keyword, got %d", valuesCount)
+			}
+
+			// Verify the number of value rows by counting row openers ("($") in the
+			// text between VALUES and ON CONFLICT.
+			valuesSectionStart := strings.Index(query, "VALUES")
+			onConflictStart := strings.Index(query, "ON CONFLICT")
+			if valuesSectionStart == -1 || onConflictStart == -1 {
+				t.Fatal("Query structure invalid - missing VALUES or ON CONFLICT")
+			}
+
+			valuesSection := query[valuesSectionStart:onConflictStart]
+			// Count opening parentheses in VALUES section (one per row)
+			rowCount := strings.Count(valuesSection, "($")
+			if rowCount != tt.numRows {
+				t.Errorf("Expected %d value rows, found %d", tt.numRows, rowCount)
+			}
+
+			// Verify ON CONFLICT and RETURNING clauses still present
+			if !strings.Contains(query, "ON CONFLICT (repository_id, file_path, content_hash)") {
+				t.Error("Query missing ON CONFLICT clause")
+			}
+
+			if !strings.Contains(query, "RETURNING id") {
+				t.Error("Query missing RETURNING id clause")
+			}
+		})
+	}
+}
+
+// TestBuildMultiRowChunkInsert_ParameterIndexing tests parameter numbering correctness.
+// This test specifically validates that parameters are numbered sequentially across all rows.
+func TestBuildMultiRowChunkInsert_ParameterIndexing(t *testing.T) {
+	const numRows = 3
+	const columnsPerRow = 17
+
+	// This will fail because buildMultiRowChunkInsert doesn't exist yet
+	query := buildMultiRowChunkInsert(numRows)
+
+	// Verify each expected parameter placeholder is present (presence only; the
+	// number of occurrences is not checked).
+	totalParams := numRows * columnsPerRow
+	for i := 1; i <= totalParams; i++ {
+		// A bare substring match on "$1" would also match "$10", "$11", etc.,
+		// so each placeholder is matched together with the delimiter that follows it.
+		paramPatterns := []string{
+			"$" + formatParamNum(i) + ",", // Followed by comma
+			"$" + formatParamNum(i) + ")", // Followed by closing paren
+		}
+
+		found := false
+		for _, pattern := range paramPatterns {
+			if strings.Contains(query, pattern) {
+				found = true
+				break
+			}
+		}
+
+		if !found {
+			t.Errorf("Parameter $%d not found in query at expected position", i)
+		}
+	}
+
+	// Verify no parameters beyond the expected total exist
+	invalidParam := "$" + formatParamNum(totalParams+1)
+	if strings.Contains(query, invalidParam) {
+		t.Errorf("Query contains unexpected parameter %s (expected max $%d)", invalidParam, totalParams)
+	}
+}
+
+// formatParamNum formats a non-negative integer parameter number as its decimal string.
+// It handles any number of digits: the leading digits are formatted recursively and the
+// last digit is appended. (The previous two-digit-only version produced a non-digit
+// rune for the "tens" position once num reached 100.)
+func formatParamNum(num int) string {
+	if num < 10 {
+		return string(rune('0' + num))
+	}
+	return formatParamNum(num/10) + string(rune('0'+num%10))
+}
+
+// TestBuildMultiRowChunkInsert_ZeroRows tests handling of zero rows.
+// This test verifies error handling or empty string return for invalid input.
+func TestBuildMultiRowChunkInsert_ZeroRows(t *testing.T) {
+	// This will fail because buildMultiRowChunkInsert doesn't exist yet
+	query := buildMultiRowChunkInsert(0)
+
+	// For zero rows, we expect either:
+	// 1. Empty string (no-op)
+	// 2. Or the function could panic/error (if it validates input)
+	// For this test, we'll expect an empty string as the most sensible behavior
+	if query != "" {
+		t.Errorf("Expected empty query for 0 rows, got: %s", query)
+	}
+}
+
+// TestBuildMultiRowChunkInsert_NegativeRows tests handling of negative rows.
+// This test verifies error handling for invalid negative input.
+func TestBuildMultiRowChunkInsert_NegativeRows(t *testing.T) {
+	// A negative row count is invalid input; the builder is expected to return
+	// an empty query (no-op) for it.
+	if got := buildMultiRowChunkInsert(-1); got != "" {
+		t.Errorf("Expected empty query for negative rows, got: %s", got)
+	}
+}
+
+// TestBuildMultiRowChunkInsert_LargeRowCount tests handling of large batch sizes.
+// This test verifies the function can handle large numbers of rows (e.g., 1000).
+func TestBuildMultiRowChunkInsert_LargeRowCount(t *testing.T) {
+	const numRows = 1000
+
+	sql := buildMultiRowChunkInsert(numRows)
+	if sql == "" {
+		t.Fatal("Expected non-empty query for 1000 rows")
+	}
+
+	// First placeholder of the first row.
+	if !strings.Contains(sql, "$1") {
+		t.Error("Query missing first parameter $1")
+	}
+
+	// Last placeholder of the last row: 1000 rows * 17 columns = 17000 parameters.
+	const finalPlaceholder = "$17000"
+	if !strings.Contains(sql, finalPlaceholder) {
+		t.Errorf("Query missing last parameter %s", finalPlaceholder)
+	}
+
+	// The overall statement skeleton must still be present for a large batch.
+	skeleton := []struct {
+		fragment string
+		failure  string
+	}{
+		{"INSERT INTO codechunking.code_chunks", "Query missing INSERT INTO clause"},
+		{"VALUES", "Query missing VALUES clause"},
+		{"ON CONFLICT", "Query missing ON CONFLICT clause"},
+		{"RETURNING id", "Query missing RETURNING clause"},
+	}
+	for _, part := range skeleton {
+		if !strings.Contains(sql, part.fragment) {
+			t.Error(part.failure)
+		}
+	}
+}
+
+// TestBuildMultiRowChunkInsert_QueryStructure tests overall query structure compliance.
+// This test verifies the query follows the exact expected SQL structure.
+func TestBuildMultiRowChunkInsert_QueryStructure(t *testing.T) {
+	sql := buildMultiRowChunkInsert(2)
+
+	// Each section must be present and must start after the previous one that was found.
+	orderedSections := []string{
+		"INSERT INTO codechunking.code_chunks",
+		"VALUES",
+		"ON CONFLICT (repository_id, file_path, content_hash)",
+		"DO UPDATE SET",
+		"RETURNING id",
+	}
+
+	previous := -1
+	for _, section := range orderedSections {
+		pos := strings.Index(sql, section)
+		if pos == -1 {
+			t.Errorf("Query missing required section: %s", section)
+			continue
+		}
+		if pos <= previous {
+			t.Errorf("Query sections out of order: %s should appear after previous section", section)
+		}
+		previous = pos
+	}
+
+	// The remaining checks inspect the text between DO UPDATE SET and RETURNING;
+	// they are skipped when either marker is absent.
+	setStart := strings.Index(sql, "DO UPDATE SET")
+	setEnd := strings.Index(sql, "RETURNING")
+	if setStart == -1 || setEnd == -1 {
+		return
+	}
+	setClause := sql[setStart:setEnd]
+
+	// On conflict the existing row's id and any existing token data must win.
+	assignments := []struct {
+		fragment string
+		failure  string
+	}{
+		{
+			"id = code_chunks.id",
+			"UPDATE SET clause should preserve existing chunk ID",
+		},
+		{
+			"token_count = COALESCE(code_chunks.token_count, EXCLUDED.token_count)",
+			"UPDATE SET clause should use COALESCE to preserve existing token_count",
+		},
+		{
+			"token_counted_at = COALESCE(code_chunks.token_counted_at, EXCLUDED.token_counted_at)",
+			"UPDATE SET clause should use COALESCE to preserve existing token_counted_at",
+		},
+	}
+	for _, a := range assignments {
+		if !strings.Contains(setClause, a.fragment) {
+			t.Error(a.failure)
+		}
+	}
+}
+
+// TestBuildMultiRowChunkInsert_ColumnOrder tests that columns appear in exact order.
+// This test verifies the column order matches the expected schema.
+func TestBuildMultiRowChunkInsert_ColumnOrder(t *testing.T) {
+	sql := buildMultiRowChunkInsert(1)
+
+	// The 17 columns, in the order they are expected in the INSERT column list.
+	wantColumns := []string{
+		"id", "repository_id", "file_path", "chunk_type", "content", "language",
+		"start_line", "end_line", "entity_name", "parent_entity", "content_hash",
+		"metadata", "qualified_name", "signature", "visibility",
+		"token_count", "token_counted_at",
+	}
+
+	// The column list is whatever sits between the INSERT INTO target and VALUES.
+	headerPos := strings.Index(sql, "INSERT INTO codechunking.code_chunks")
+	valuesPos := strings.Index(sql, "VALUES")
+	if headerPos == -1 || valuesPos == -1 {
+		t.Fatal("Query missing INSERT or VALUES clause")
+	}
+	columnList := sql[headerPos:valuesPos]
+
+	// Each column's first occurrence must come after the previously found column's.
+	previous := -1
+	for n, name := range wantColumns {
+		pos := strings.Index(columnList, name)
+		switch {
+		case pos == -1:
+			t.Errorf("Column %d (%s) not found in INSERT clause", n+1, name)
+			continue
+		case pos <= previous:
+			t.Errorf("Column %d (%s) appears out of order (should appear after column %d)", n+1, name, n)
+		}
+		previous = pos
+	}
+}
diff --git a/internal/adapter/outbound/repository/chunk_repository_savechunks_multirow_test.go b/internal/adapter/outbound/repository/chunk_repository_savechunks_multirow_test.go
new file mode 100644
index 0000000..608ec17
--- /dev/null
+++ b/internal/adapter/outbound/repository/chunk_repository_savechunks_multirow_test.go
@@ -0,0 +1,649 @@
+//go:build integration
+
+package repository
+
+import (
+	"codechunking/internal/port/outbound"
+	"context"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestSaveChunks_MultiRow_SingleChunk tests that SaveChunks correctly handles a single chunk.
+// This test verifies that the multi-row INSERT implementation works correctly for the edge case of a single chunk.
+func TestSaveChunks_MultiRow_SingleChunk(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/single-chunk-multirow-" + repositoryID.String()[:8]
+
+	// Delete this test's rows when the test exits. This is a defer rather than a
+	// t.Cleanup callback because t.Cleanup callbacks run only after the test function
+	// has returned, i.e. after the deferred pool.Close above. Registering it before
+	// the first insert also covers runs that abort early on a failed require.
+	defer func() {
+		for _, stmt := range []string{
+			"DELETE FROM codechunking.code_chunks WHERE repository_id = $1",
+			"DELETE FROM codechunking.repositories WHERE id = $1",
+		} {
+			if _, cleanupErr := pool.Exec(ctx, stmt, repositoryID); cleanupErr != nil {
+				t.Logf("cleanup failed (%s): %v", stmt, cleanupErr)
+			}
+		}
+	}()
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "single-chunk-test-repo", "Test repository for single chunk multi-row INSERT", "indexed")
+	require.NoError(t, err)
+
+	// Create a single chunk
+	now := time.Now()
+	chunks := []outbound.CodeChunk{
+		{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "single_chunk.go",
+			Content:        "func SingleChunk() { }",
+			Language:       "go",
+			StartLine:      1,
+			EndLine:        1,
+			Hash:           "hash-single-chunk-1",
+			Type:           "function",
+			EntityName:     "SingleChunk",
+			CreatedAt:      now,
+			TokenCount:     10,
+			TokenCountedAt: &now,
+		},
+	}
+
+	// Save the single chunk using SaveChunks (should use multi-row INSERT for single chunk too)
+	err = repo.SaveChunks(ctx, chunks)
+	require.NoError(t, err, "SaveChunks should work correctly for a single chunk")
+
+	// Verify the chunk was saved correctly
+	chunkID, err := uuid.Parse(chunks[0].ID)
+	require.NoError(t, err)
+
+	var retrievedContent string
+	var retrievedLanguage string
+	var tokenCount *int
+	err = pool.QueryRow(ctx, `
+		SELECT content, language, token_count
+		FROM codechunking.code_chunks
+		WHERE id = $1 AND deleted_at IS NULL
+	`, chunkID).Scan(&retrievedContent, &retrievedLanguage, &tokenCount)
+	require.NoError(t, err, "Should retrieve the saved chunk")
+
+	assert.Equal(t, chunks[0].Content, retrievedContent, "Content should match")
+	assert.Equal(t, chunks[0].Language, retrievedLanguage, "Language should match")
+	require.NotNil(t, tokenCount, "Token count should be saved")
+	assert.Equal(t, 10, *tokenCount, "Token count should match")
+}
+
+// TestSaveChunks_MultiRow_MultipleChunks tests that SaveChunks correctly saves multiple chunks in a single operation.
+// This test verifies that the multi-row INSERT implementation batches multiple chunks together.
+func TestSaveChunks_MultiRow_MultipleChunks(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/multiple-chunks-multirow-" + repositoryID.String()[:8]
+
+	// Delete this test's rows when the test exits. This is a defer rather than a
+	// t.Cleanup callback because t.Cleanup callbacks run only after the test function
+	// has returned, i.e. after the deferred pool.Close above. Registering it before
+	// the first insert also covers runs that abort early on a failed require.
+	defer func() {
+		for _, stmt := range []string{
+			"DELETE FROM codechunking.code_chunks WHERE repository_id = $1",
+			"DELETE FROM codechunking.repositories WHERE id = $1",
+		} {
+			if _, cleanupErr := pool.Exec(ctx, stmt, repositoryID); cleanupErr != nil {
+				t.Logf("cleanup failed (%s): %v", stmt, cleanupErr)
+			}
+		}
+	}()
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "multiple-chunks-test-repo", "Test repository for multiple chunks multi-row INSERT", "indexed")
+	require.NoError(t, err)
+
+	// Create 5 chunks
+	now := time.Now()
+	chunks := make([]outbound.CodeChunk, 5)
+	for i := range chunks {
+		chunks[i] = outbound.CodeChunk{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "multiple_chunks.go",
+			Content:        "func TestFunc" + string(rune('A'+i)) + "() { }",
+			Language:       "go",
+			StartLine:      i*2 + 1,
+			EndLine:        i*2 + 1,
+			Hash:           "hash-multiple-" + uuid.New().String()[:8],
+			Type:           "function",
+			EntityName:     "TestFunc" + string(rune('A'+i)),
+			CreatedAt:      now,
+			TokenCount:     (i + 1) * 5, // 5, 10, 15, 20, 25
+			TokenCountedAt: &now,
+		}
+	}
+
+	// Save all chunks in a single operation
+	err = repo.SaveChunks(ctx, chunks)
+	require.NoError(t, err, "SaveChunks should work correctly for multiple chunks")
+
+	// Verify all chunks were saved correctly
+	for i, chunk := range chunks {
+		chunkID, err := uuid.Parse(chunk.ID)
+		require.NoError(t, err)
+
+		var retrievedContent string
+		var tokenCount *int
+		err = pool.QueryRow(ctx, `
+			SELECT content, token_count
+			FROM codechunking.code_chunks
+			WHERE id = $1 AND deleted_at IS NULL
+		`, chunkID).Scan(&retrievedContent, &tokenCount)
+		require.NoError(t, err, "Should retrieve chunk %d", i)
+
+		assert.Equal(t, chunk.Content, retrievedContent, "Content should match for chunk %d", i)
+		require.NotNil(t, tokenCount, "Token count should be saved for chunk %d", i)
+		assert.Equal(t, (i+1)*5, *tokenCount, "Token count should match for chunk %d", i)
+	}
+
+	// Verify correct number of chunks were saved
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 5, count, "Should have exactly 5 chunks")
+}
+
+// TestSaveChunks_MultiRow_EmptySlice tests that SaveChunks handles an empty slice correctly.
+// This test verifies that SaveChunks is a no-op for empty input and returns nil error.
+func TestSaveChunks_MultiRow_EmptySlice(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Test with nil slice
+	err := repo.SaveChunks(ctx, nil)
+	assert.NoError(t, err, "SaveChunks should return nil for nil slice")
+
+	// Test with empty slice
+	err = repo.SaveChunks(ctx, []outbound.CodeChunk{})
+	assert.NoError(t, err, "SaveChunks should return nil for empty slice")
+}
+
+// TestSaveChunks_MultiRow_LargeBatch tests that SaveChunks handles a large batch of chunks.
+// This test verifies that the multi-row INSERT implementation can handle typical token counting batch sizes (50+ chunks).
+func TestSaveChunks_MultiRow_LargeBatch(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/large-batch-multirow-" + repositoryID.String()[:8]
+
+	// Delete this test's rows when the test exits. This is a defer rather than a
+	// t.Cleanup callback because t.Cleanup callbacks run only after the test function
+	// has returned, i.e. after the deferred pool.Close above. Registering it before
+	// the first insert also covers runs that abort early on a failed require.
+	defer func() {
+		for _, stmt := range []string{
+			"DELETE FROM codechunking.code_chunks WHERE repository_id = $1",
+			"DELETE FROM codechunking.repositories WHERE id = $1",
+		} {
+			if _, cleanupErr := pool.Exec(ctx, stmt, repositoryID); cleanupErr != nil {
+				t.Logf("cleanup failed (%s): %v", stmt, cleanupErr)
+			}
+		}
+	}()
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "large-batch-test-repo", "Test repository for large batch multi-row INSERT", "indexed")
+	require.NoError(t, err)
+
+	// Create 50 chunks (typical batch size for token counting)
+	now := time.Now()
+	chunks := make([]outbound.CodeChunk, 50)
+	for i := range chunks {
+		chunks[i] = outbound.CodeChunk{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "large_batch.go",
+			Content:        "func LargeFunc" + string(rune('0'+i%10)) + "() { }",
+			Language:       "go",
+			StartLine:      i*2 + 1,
+			EndLine:        i*2 + 1,
+			Hash:           "hash-large-" + uuid.New().String()[:8],
+			Type:           "function",
+			EntityName:     "LargeFunc" + string(rune('0'+i%10)),
+			CreatedAt:      now,
+			TokenCount:     (i + 1) * 2,
+			TokenCountedAt: &now,
+		}
+	}
+
+	// Save all chunks in a single operation
+	startTime := time.Now()
+	err = repo.SaveChunks(ctx, chunks)
+	duration := time.Since(startTime)
+	require.NoError(t, err, "SaveChunks should work correctly for large batch")
+
+	t.Logf("SaveChunks for 50 chunks took: %v", duration)
+
+	// Verify all chunks were saved correctly
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 50, count, "Should have exactly 50 chunks")
+
+	// Spot check a few chunks to verify data integrity
+	testIndices := []int{0, 24, 49} // First, middle, last
+	for _, i := range testIndices {
+		chunkID, err := uuid.Parse(chunks[i].ID)
+		require.NoError(t, err)
+
+		var tokenCount *int
+		err = pool.QueryRow(ctx, `
+			SELECT token_count
+			FROM codechunking.code_chunks
+			WHERE id = $1 AND deleted_at IS NULL
+		`, chunkID).Scan(&tokenCount)
+		require.NoError(t, err, "Should retrieve chunk %d", i)
+
+		require.NotNil(t, tokenCount, "Token count should be saved for chunk %d", i)
+		assert.Equal(t, (i+1)*2, *tokenCount, "Token count should match for chunk %d", i)
+	}
+}
+
+// TestSaveChunks_MultiRow_PreservesChunkData tests that all chunk fields are correctly saved.
+// This test verifies that the multi-row INSERT implementation preserves all chunk field values.
+func TestSaveChunks_MultiRow_PreservesChunkData(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/preserve-data-multirow-" + repositoryID.String()[:8]
+
+	// Delete this test's rows when the test exits. This is a defer rather than a
+	// t.Cleanup callback because t.Cleanup callbacks run only after the test function
+	// has returned, i.e. after the deferred pool.Close above. Registering it before
+	// the first insert also covers runs that abort early on a failed require.
+	defer func() {
+		for _, stmt := range []string{
+			"DELETE FROM codechunking.code_chunks WHERE repository_id = $1",
+			"DELETE FROM codechunking.repositories WHERE id = $1",
+		} {
+			if _, cleanupErr := pool.Exec(ctx, stmt, repositoryID); cleanupErr != nil {
+				t.Logf("cleanup failed (%s): %v", stmt, cleanupErr)
+			}
+		}
+	}()
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "preserve-data-test-repo", "Test repository for data preservation multi-row INSERT", "indexed")
+	require.NoError(t, err)
+
+	// Create chunks with comprehensive field values
+	now := time.Now()
+	chunks := []outbound.CodeChunk{
+		{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "preserve_data.go",
+			Content:        "func CompleteExample() {\n\t// Full function\n\treturn\n}",
+			Language:       "go",
+			StartLine:      10,
+			EndLine:        13,
+			Hash:           "hash-preserve-complete",
+			Type:           "function",
+			EntityName:     "CompleteExample",
+			ParentEntity:   "PackageMain",
+			QualifiedName:  "main.CompleteExample",
+			Signature:      "func CompleteExample()",
+			Visibility:     "public",
+			CreatedAt:      now,
+			TokenCount:     25,
+			TokenCountedAt: &now,
+		},
+		{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "preserve_data.go",
+			Content:        "type DataStruct struct {\n\tField string\n}",
+			Language:       "go",
+			StartLine:      20,
+			EndLine:        22,
+			Hash:           "hash-preserve-struct",
+			Type:           "struct",
+			EntityName:     "DataStruct",
+			ParentEntity:   "PackageMain",
+			QualifiedName:  "main.DataStruct",
+			Signature:      "type DataStruct struct",
+			Visibility:     "public",
+			CreatedAt:      now,
+			TokenCount:     15,
+			TokenCountedAt: &now,
+		},
+		{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "preserve_data.go",
+			Content:        "func (d *DataStruct) privateMethod() { }",
+			Language:       "go",
+			StartLine:      30,
+			EndLine:        30,
+			Hash:           "hash-preserve-private",
+			Type:           "method",
+			EntityName:     "privateMethod",
+			ParentEntity:   "DataStruct",
+			QualifiedName:  "main.DataStruct.privateMethod",
+			Signature:      "func (d *DataStruct) privateMethod()",
+			Visibility:     "private",
+			CreatedAt:      now,
+			TokenCount:     12,
+			TokenCountedAt: &now,
+		},
+	}
+
+	// Save all chunks
+	err = repo.SaveChunks(ctx, chunks)
+	require.NoError(t, err, "SaveChunks should preserve all chunk data")
+
+	// Verify all fields for each chunk
+	for i, chunk := range chunks {
+		chunkID, err := uuid.Parse(chunk.ID)
+		require.NoError(t, err)
+
+		var (
+			retrievedContent       string
+			retrievedLanguage      string
+			retrievedFilePath      string
+			retrievedType          string
+			retrievedEntityName    string
+			retrievedParentEntity  string
+			retrievedQualifiedName string
+			retrievedSignature     string
+			retrievedVisibility    string
+			retrievedStartLine     int
+			retrievedEndLine       int
+			retrievedHash          string
+			tokenCount             *int
+			tokenCountedAt         *time.Time
+		)
+
+		err = pool.QueryRow(ctx, `
+			SELECT content, language, file_path, chunk_type, entity_name, parent_entity,
+			       qualified_name, signature, visibility, start_line, end_line, content_hash,
+			       token_count, token_counted_at
+			FROM codechunking.code_chunks
+			WHERE id = $1 AND deleted_at IS NULL
+		`, chunkID).Scan(
+			&retrievedContent, &retrievedLanguage, &retrievedFilePath, &retrievedType,
+			&retrievedEntityName, &retrievedParentEntity, &retrievedQualifiedName,
+			&retrievedSignature, &retrievedVisibility, &retrievedStartLine, &retrievedEndLine,
+			&retrievedHash, &tokenCount, &tokenCountedAt,
+		)
+		require.NoError(t, err, "Should retrieve chunk %d", i)
+
+		// Verify all fields
+		assert.Equal(t, chunk.Content, retrievedContent, "Content should match for chunk %d", i)
+		assert.Equal(t, chunk.Language, retrievedLanguage, "Language should match for chunk %d", i)
+		assert.Equal(t, chunk.FilePath, retrievedFilePath, "FilePath should match for chunk %d", i)
+		assert.Equal(t, chunk.Type, retrievedType, "Type should match for chunk %d", i)
+		assert.Equal(t, chunk.EntityName, retrievedEntityName, "EntityName should match for chunk %d", i)
+		assert.Equal(t, chunk.ParentEntity, retrievedParentEntity, "ParentEntity should match for chunk %d", i)
+		assert.Equal(t, chunk.QualifiedName, retrievedQualifiedName, "QualifiedName should match for chunk %d", i)
+		assert.Equal(t, chunk.Signature, retrievedSignature, "Signature should match for chunk %d", i)
+		assert.Equal(t, chunk.Visibility, retrievedVisibility, "Visibility should match for chunk %d", i)
+		assert.Equal(t, chunk.StartLine, retrievedStartLine, "StartLine should match for chunk %d", i)
+		assert.Equal(t, chunk.EndLine, retrievedEndLine, "EndLine should match for chunk %d", i)
+		assert.Equal(t, chunk.Hash, retrievedHash, "Hash should match for chunk %d", i)
+
+		// Verify token count fields
+		require.NotNil(t, tokenCount, "Token count should be saved for chunk %d", i)
+		assert.Equal(t, chunk.TokenCount, *tokenCount, "Token count should match for chunk %d", i)
+		require.NotNil(t, tokenCountedAt, "Token counted at should be saved for chunk %d", i)
+		// Compared with a one-second tolerance rather than for exact equality.
+		assert.WithinDuration(
+			t,
+			*chunk.TokenCountedAt,
+			*tokenCountedAt,
+			time.Second,
+			"Token counted at should match for chunk %d",
+			i,
+		)
+	}
+}
+
+// TestSaveChunks_MultiRow_ConflictHandling tests that SaveChunks correctly handles conflicts.
+// This test verifies that the multi-row INSERT implementation properly uses ON CONFLICT to handle duplicate chunks.
+func TestSaveChunks_MultiRow_ConflictHandling(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping integration test in short mode")
+	}
+
+	pool := getTestPool(t)
+	defer pool.Close()
+
+	repo := NewPostgreSQLChunkRepository(pool)
+	ctx := context.Background()
+
+	// Create test repository
+	repositoryID := uuid.New()
+	testURL := "https://github.com/test/conflict-multirow-" + repositoryID.String()[:8]
+
+	// Delete this test's rows when the test exits. This is a defer rather than a
+	// t.Cleanup callback because t.Cleanup callbacks run only after the test function
+	// has returned, i.e. after the deferred pool.Close above. Registering it before
+	// the first insert also covers runs that abort early on a failed require.
+	defer func() {
+		for _, stmt := range []string{
+			"DELETE FROM codechunking.code_chunks WHERE repository_id = $1",
+			"DELETE FROM codechunking.repositories WHERE id = $1",
+		} {
+			if _, cleanupErr := pool.Exec(ctx, stmt, repositoryID); cleanupErr != nil {
+				t.Logf("cleanup failed (%s): %v", stmt, cleanupErr)
+			}
+		}
+	}()
+
+	_, err := pool.Exec(ctx, `
+		INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status)
+		VALUES ($1, $2, $3, $4, $5, $6)
+	`, repositoryID, testURL, testURL, "conflict-test-repo", "Test repository for conflict handling multi-row INSERT", "indexed")
+	require.NoError(t, err)
+
+	// Create initial chunks with token counts
+	now := time.Now()
+	initialChunks := []outbound.CodeChunk{
+		{
+			ID:             uuid.New().String(),
+			RepositoryID:   repositoryID,
+			FilePath:       "conflict.go",
+			Content:        "func ConflictTest() { }",
+			Language:       "go",
+			StartLine:      1,
+			EndLine:        1,
+			Hash:           "hash-conflict-1",
+			Type:           "function",
+			EntityName:     "ConflictTest",
+			CreatedAt:      now,
+			TokenCount:     100,
+			TokenCountedAt: &now,
+		},
+	}
+
+	// Save initial chunks
+	err = repo.SaveChunks(ctx, initialChunks)
+	require.NoError(t, err, "Initial save should succeed")
+
+	// Save the same chunk again with different token count (should preserve existing token count per COALESCE logic)
+	duplicateChunks := []outbound.CodeChunk{
+		{
+			ID:             uuid.New().String(), // Different ID
+			RepositoryID:   repositoryID,
+			FilePath:       "conflict.go",
+			Content:        "func ConflictTest() { }",
+			Language:       "go",
+			StartLine:      1,
+			EndLine:        1,
+			Hash:           "hash-conflict-1", // Same hash triggers conflict
+			Type:           "function",
+			EntityName:     "ConflictTest",
+			CreatedAt:      now,
+			TokenCount:     999, // Different token count (should be ignored)
+			TokenCountedAt: &now,
+		},
+	}
+
+	// Save duplicate chunks
+	err = repo.SaveChunks(ctx, duplicateChunks)
+	require.NoError(t, err, "Duplicate save should succeed (ON CONFLICT handling)")
+
+	// Verify the original token count is preserved
+	var tokenCount *int
+	err = pool.QueryRow(ctx, `
+		SELECT token_count
+		FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND file_path = $2 AND content_hash = $3 AND deleted_at IS NULL
+	`, repositoryID, "conflict.go", "hash-conflict-1").Scan(&tokenCount)
+	require.NoError(t, err)
+
+	require.NotNil(t, tokenCount, "Token count should be preserved")
+	assert.Equal(t, 100, *tokenCount, "Original token count should be preserved, not overwritten")
+
+	// Verify only one chunk exists (not duplicated)
+	var count int
+	err = pool.QueryRow(ctx, `
+		SELECT COUNT(*) FROM codechunking.code_chunks
+		WHERE repository_id = $1 AND deleted_at IS NULL
+	`, repositoryID).Scan(&count)
+	require.NoError(t, err)
+	assert.Equal(t, 1, count, "Should have exactly 1 chunk (not duplicated)")
+}
+
+// TestSaveChunks_MultiRow_MixedTokenCounts tests SaveChunks with chunks having varied token count states.
+// This test verifies that the multi-row INSERT handles chunks with and without token counts correctly.
+func TestSaveChunks_MultiRow_MixedTokenCounts(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + pool := getTestPool(t) + defer pool.Close() + + repo := NewPostgreSQLChunkRepository(pool) + ctx := context.Background() + + // Create test repository + repositoryID := uuid.New() + testURL := "https://github.com/test/mixed-tokens-multirow-" + repositoryID.String()[:8] + + _, err := pool.Exec(ctx, ` + INSERT INTO codechunking.repositories (id, url, normalized_url, name, description, status) + VALUES ($1, $2, $3, $4, $5, $6) + `, repositoryID, testURL, testURL, "mixed-tokens-test-repo", "Test repository for mixed token counts", "indexed") + require.NoError(t, err) + + // Create chunks with mixed token count states + now := time.Now() + chunks := []outbound.CodeChunk{ + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "mixed.go", + Content: "func WithTokens() { }", + Language: "go", + StartLine: 1, + EndLine: 1, + Hash: "hash-with-tokens", + Type: "function", + EntityName: "WithTokens", + CreatedAt: now, + TokenCount: 50, + TokenCountedAt: &now, + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "mixed.go", + Content: "func WithoutTokens() { }", + Language: "go", + StartLine: 5, + EndLine: 5, + Hash: "hash-without-tokens", + Type: "function", + EntityName: "WithoutTokens", + CreatedAt: now, + TokenCount: 0, + TokenCountedAt: nil, + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + FilePath: "mixed.go", + Content: "func ZeroTokens() { }", + Language: "go", + StartLine: 10, + EndLine: 10, + Hash: "hash-zero-tokens", + Type: "function", + EntityName: "ZeroTokens", + CreatedAt: now, + TokenCount: 0, + TokenCountedAt: &now, // Zero tokens but has timestamp + }, + } + + // Save all chunks + err = repo.SaveChunks(ctx, chunks) + require.NoError(t, err, "SaveChunks should handle mixed token count states") + + // Verify each chunk's token count state + for i, chunk := 
range chunks { + chunkID, err := uuid.Parse(chunk.ID) + require.NoError(t, err) + + var tokenCount *int + var tokenCountedAt *time.Time + err = pool.QueryRow(ctx, ` + SELECT token_count, token_counted_at + FROM codechunking.code_chunks + WHERE id = $1 AND deleted_at IS NULL + `, chunkID).Scan(&tokenCount, &tokenCountedAt) + require.NoError(t, err, "Should retrieve chunk %d", i) + + // Verify based on original chunk state + if i == 0 { + // Chunk with token count + require.NotNil(t, tokenCount, "Chunk 0 should have token count") + assert.Equal(t, 50, *tokenCount) + require.NotNil(t, tokenCountedAt, "Chunk 0 should have timestamp") + } else if i == 1 { + // Chunk without token count + assert.Nil(t, tokenCountedAt, "Chunk 1 should not have timestamp") + } else if i == 2 { + // Chunk with zero tokens but timestamp + require.NotNil(t, tokenCountedAt, "Chunk 2 should have timestamp") + } + } + + // Cleanup + t.Cleanup(func() { + pool.Exec(ctx, "DELETE FROM codechunking.code_chunks WHERE repository_id = $1", repositoryID) + pool.Exec(ctx, "DELETE FROM codechunking.repositories WHERE id = $1", repositoryID) + }) +} From e4d2adb987e8870b85eb69333e678d3cb22bbef0 Mon Sep 17 00:00:00 2001 From: Anthony Bible Date: Fri, 28 Nov 2025 12:45:46 -0700 Subject: [PATCH 5/5] feat: add chunk deduplication before database operations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement deduplicateChunksByKey to remove duplicate chunks - Integrate deduplication into countTokensForChunks - Add deduplication to submitBatchJobAsync before FindOrCreateChunks - Prevent duplicate key errors during batch processing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- internal/application/worker/job_processor.go | 52 +++ .../job_processor_batch_queueing_test.go | 118 +++++ .../worker/job_processor_dedup_test.go | 437 ++++++++++++++++++ .../job_processor_token_counting_test.go | 16 +- 4 files changed, 619 
insertions(+), 4 deletions(-) create mode 100644 internal/application/worker/job_processor_dedup_test.go diff --git a/internal/application/worker/job_processor.go b/internal/application/worker/job_processor.go index 04a5d14..60f00f9 100644 --- a/internal/application/worker/job_processor.go +++ b/internal/application/worker/job_processor.go @@ -1439,6 +1439,7 @@ func (p *DefaultJobProcessor) countTokensForChunks( // Select chunks to count based on mode chunksToCount := p.selectChunksForCounting(ctx, chunks, mode) + chunksToCount = deduplicateChunksByKey(ctx, chunksToCount) if len(chunksToCount) == 0 { return } @@ -2397,6 +2398,9 @@ func (p *DefaultJobProcessor) submitBatchJobAsync( chunks[i].RepositoryID = repositoryID } + // Deduplicate chunks to avoid inserting duplicates + chunks = deduplicateChunksByKey(ctx, chunks) + // FindOrCreateChunks returns chunks with their actual persisted IDs // This prevents FK constraint violations when embeddings reference chunk IDs slogger.Debug(ctx, "Finding or creating chunks before queueing batch", slogger.Fields{ @@ -2598,3 +2602,51 @@ func (p *DefaultJobProcessor) resumeFromLastBatch( return lastBatch, nil } + +// chunkKey represents a unique key for deduplication based on repository ID, file path, and content hash. +// This combination matches the PostgreSQL unique constraint on (repository_id, file_path, content_hash) +// to prevent duplicate chunk insertion errors. +type chunkKey struct { + repoID uuid.UUID + filePath string + hash string +} + +// deduplicateChunksByKey removes duplicate chunks by (repository_id, file_path, content_hash) key. +// It keeps the first occurrence of each unique key and logs a warning if duplicates are found. 
+func deduplicateChunksByKey(ctx context.Context, chunks []outbound.CodeChunk) []outbound.CodeChunk { + if len(chunks) == 0 { + return []outbound.CodeChunk{} + } + + // Pre-allocate map with capacity to avoid rehashing + seen := make(map[chunkKey]struct{}, len(chunks)) + result := make([]outbound.CodeChunk, 0, len(chunks)) + duplicateCount := 0 + + for _, chunk := range chunks { + key := chunkKey{ + repoID: chunk.RepositoryID, + filePath: chunk.FilePath, + hash: chunk.Hash, + } + + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + result = append(result, chunk) + } else { + duplicateCount++ + } + } + + if duplicateCount > 0 { + slogger.Warn(ctx, "Duplicate chunks detected in batch - possible parser issue", + slogger.Fields{ + "original_count": len(chunks), + "unique_count": len(result), + "duplicates": duplicateCount, + }) + } + + return result +} diff --git a/internal/application/worker/job_processor_batch_queueing_test.go b/internal/application/worker/job_processor_batch_queueing_test.go index 2b2fb3e..31869c3 100644 --- a/internal/application/worker/job_processor_batch_queueing_test.go +++ b/internal/application/worker/job_processor_batch_queueing_test.go @@ -756,3 +756,121 @@ func TestJobProcessor_SubmitBatchJobAsync_MultipleChunks_CorrectRequestCount(t * requestIDs[req.RequestID] = true } } + +// TestJobProcessor_SubmitBatchJobAsync_DeduplicatesChunks verifies that +// submitBatchJobAsync deduplicates chunks by (repository_id, file_path, content_hash) +// before calling FindOrCreateChunks. +// +// RED PHASE EXPECTATION: +// - Chunks with same (repository_id, file_path, hash) should be deduplicated +// - FindOrCreateChunks should receive only unique chunks +// - Duplicate chunks should be removed before persistence +// +// CURRENT BEHAVIOR (WHY THIS FAILS): +// - submitBatchJobAsync calls FindOrCreateChunks directly without deduplication +// - This can cause duplicate chunk errors in the database. 
+func TestJobProcessor_SubmitBatchJobAsync_DeduplicatesChunks(t *testing.T) { + // Arrange + ctx := context.Background() + repositoryID := uuid.New() + indexingJobID := uuid.New() + + // Create chunks with duplicates - same repo_id, file_path, and hash + sharedHash := "abc123hash" + sharedFilePath := "/test/duplicate.go" + + chunks := []outbound.CodeChunk{ + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + Content: "function test() { return 42; }", + FilePath: sharedFilePath, + Language: "go", + Hash: sharedHash, // Same hash + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + Content: "function other() { return 'hello'; }", + FilePath: "/test/unique.go", + Language: "go", + Hash: "xyz789different", + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + Content: "function test() { return 42; }", // Duplicate content + FilePath: sharedFilePath, // Same file path + Language: "go", + Hash: sharedHash, // Same hash - THIS IS A DUPLICATE + }, + { + ID: uuid.New().String(), + RepositoryID: repositoryID, + Content: "function another() { return true; }", + FilePath: "/test/another.go", + Language: "go", + Hash: "def456hash", + }, + } + + // Create mocks + mockChunkStorageRepo := new(MockChunkStorageRepository) + mockBatchProgressRepo := new(MockBatchProgressRepository) + mockBatchEmbeddingService := new(MockBatchEmbeddingService) + + // KEY EXPECTATION: Capture what chunks are passed to FindOrCreateChunks + var capturedChunks []outbound.CodeChunk + mockChunkStorageRepo.On("FindOrCreateChunks", ctx, mock.MatchedBy(func(c []outbound.CodeChunk) bool { + capturedChunks = c + return true + })).Return(func(ctx context.Context, chunks []outbound.CodeChunk) []outbound.CodeChunk { + // Return same chunks with IDs (simulating DB save) + return chunks + }, nil) + + mockBatchProgressRepo.On("Save", ctx, mock.Anything).Return(nil) + + processor := &DefaultJobProcessor{ + chunkStorageRepo: mockChunkStorageRepo, + batchProgressRepo: 
mockBatchProgressRepo, + batchEmbeddingService: mockBatchEmbeddingService, + batchConfig: config.BatchProcessingConfig{ + Enabled: true, + UseTestEmbeddings: false, + }, + } + + options := outbound.EmbeddingOptions{ + Model: "gemini-embedding-001", + TaskType: outbound.TaskTypeRetrievalDocument, + } + + // Act + err := processor.submitBatchJobAsync(ctx, indexingJobID, repositoryID, 1, 1, chunks, options) + + // Assert + require.NoError(t, err) + + // CRITICAL ASSERTION: FindOrCreateChunks should receive only 3 unique chunks (not 4) + // Chunks[0] and chunks[2] are duplicates (same repo_id, file_path, hash) + assert.Len(t, capturedChunks, 3, "Should receive only deduplicated chunks") + + // Verify the duplicate was removed + // Count how many chunks have the shared hash + hashCount := 0 + for _, chunk := range capturedChunks { + if chunk.Hash == sharedHash && chunk.FilePath == sharedFilePath { + hashCount++ + } + } + assert.Equal(t, 1, hashCount, "Should have only one chunk with the duplicate hash/filepath combination") + + // Verify all remaining chunks have unique keys + seenKeys := make(map[string]bool) + for _, chunk := range capturedChunks { + key := chunk.RepositoryID.String() + "|" + chunk.FilePath + "|" + chunk.Hash + assert.False(t, seenKeys[key], "All chunks should have unique (repo_id, file_path, hash) keys") + seenKeys[key] = true + } +} diff --git a/internal/application/worker/job_processor_dedup_test.go b/internal/application/worker/job_processor_dedup_test.go new file mode 100644 index 0000000..3d127af --- /dev/null +++ b/internal/application/worker/job_processor_dedup_test.go @@ -0,0 +1,437 @@ +package worker + +import ( + "codechunking/internal/port/outbound" + "context" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_deduplicateChunksByKey_WithDuplicates verifies that when given chunks +// with duplicate (repository_id, file_path, content_hash) keys, only the 
first +// occurrence of each unique key is retained. +func Test_deduplicateChunksByKey_WithDuplicates(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + // Create chunks with some duplicates by (repository_id, file_path, hash) + chunks := []outbound.CodeChunk{ + { + ID: "chunk-1", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-abc123", + Content: "func main() { ... }", + StartLine: 1, + EndLine: 10, + }, + { + ID: "chunk-2", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-abc123", // Duplicate: same repo, file, hash + Content: "func main() { ... }", + StartLine: 1, + EndLine: 10, + }, + { + ID: "chunk-3", + RepositoryID: repoID, + FilePath: "src/util.go", + Hash: "hash-def456", + Content: "func helper() { ... }", + StartLine: 5, + EndLine: 15, + }, + { + ID: "chunk-4", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-xyz789", + Content: "func init() { ... }", + StartLine: 20, + EndLine: 25, + }, + { + ID: "chunk-5", + RepositoryID: repoID, + FilePath: "src/util.go", + Hash: "hash-def456", // Duplicate: same repo, file, hash + Content: "func helper() { ... 
}", + StartLine: 5, + EndLine: 15, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 3, "should have 3 unique chunks (removed 2 duplicates)") + + // Verify we kept the first occurrence of each unique key + assert.Equal(t, "chunk-1", result[0].ID, "first chunk should be chunk-1") + assert.Equal(t, "chunk-3", result[1].ID, "second chunk should be chunk-3") + assert.Equal(t, "chunk-4", result[2].ID, "third chunk should be chunk-4") + + // Verify the unique keys + assert.Equal(t, "src/main.go", result[0].FilePath) + assert.Equal(t, "hash-abc123", result[0].Hash) + + assert.Equal(t, "src/util.go", result[1].FilePath) + assert.Equal(t, "hash-def456", result[1].Hash) + + assert.Equal(t, "src/main.go", result[2].FilePath) + assert.Equal(t, "hash-xyz789", result[2].Hash) +} + +// Test_deduplicateChunksByKey_NoDuplicates verifies that when given chunks +// with all unique keys, the function returns the same chunks unchanged. +func Test_deduplicateChunksByKey_NoDuplicates(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "chunk-1", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-abc123", + Content: "func main() { ... }", + StartLine: 1, + EndLine: 10, + }, + { + ID: "chunk-2", + RepositoryID: repoID, + FilePath: "src/util.go", + Hash: "hash-def456", + Content: "func helper() { ... }", + StartLine: 5, + EndLine: 15, + }, + { + ID: "chunk-3", + RepositoryID: repoID, + FilePath: "src/handlers.go", + Hash: "hash-xyz789", + Content: "func handler() { ... 
}", + StartLine: 20, + EndLine: 30, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 3, "should have 3 chunks (no duplicates removed)") + + // Verify all chunks are preserved in order + assert.Equal(t, "chunk-1", result[0].ID) + assert.Equal(t, "chunk-2", result[1].ID) + assert.Equal(t, "chunk-3", result[2].ID) +} + +// Test_deduplicateChunksByKey_EmptySlice verifies that when given an empty +// slice of chunks, the function returns an empty slice without errors. +func Test_deduplicateChunksByKey_EmptySlice(t *testing.T) { + // Arrange + ctx := context.Background() + chunks := []outbound.CodeChunk{} + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Empty(t, result, "should return empty slice") +} + +// Test_deduplicateChunksByKey_PreservesOrder verifies that the function +// preserves the order of first occurrences and keeps the first occurrence +// when duplicates are encountered. 
+func Test_deduplicateChunksByKey_PreservesOrder(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "first-occurrence", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-duplicate", + Content: "original content", + StartLine: 1, + EndLine: 5, + }, + { + ID: "unique-chunk", + RepositoryID: repoID, + FilePath: "src/util.go", + Hash: "hash-unique", + Content: "unique content", + StartLine: 10, + EndLine: 15, + }, + { + ID: "second-occurrence", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-duplicate", // Duplicate of first + Content: "duplicate content", + StartLine: 1, + EndLine: 5, + }, + { + ID: "third-occurrence", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-duplicate", // Another duplicate + Content: "another duplicate", + StartLine: 1, + EndLine: 5, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 2, "should have 2 unique chunks") + + // Verify the first occurrence is kept, not later ones + assert.Equal(t, "first-occurrence", result[0].ID, "should keep first occurrence of duplicate key") + assert.Equal(t, "original content", result[0].Content, "should preserve content of first occurrence") + + // Verify the unique chunk is preserved + assert.Equal(t, "unique-chunk", result[1].ID) + assert.Equal(t, "unique content", result[1].Content) +} + +// Test_deduplicateChunksByKey_DifferentRepositories verifies that chunks +// with the same file path and hash but different repository IDs are treated +// as unique (not duplicates). 
+func Test_deduplicateChunksByKey_DifferentRepositories(t *testing.T) { + // Arrange + ctx := context.Background() + repoID1 := uuid.New() + repoID2 := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "chunk-repo1", + RepositoryID: repoID1, + FilePath: "src/main.go", + Hash: "hash-same", + Content: "content in repo 1", + StartLine: 1, + EndLine: 10, + }, + { + ID: "chunk-repo2", + RepositoryID: repoID2, + FilePath: "src/main.go", + Hash: "hash-same", // Same file path and hash, but different repo + Content: "content in repo 2", + StartLine: 1, + EndLine: 10, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 2, "should have 2 chunks (different repositories)") + + // Verify both chunks are preserved + assert.Equal(t, "chunk-repo1", result[0].ID) + assert.Equal(t, repoID1, result[0].RepositoryID) + + assert.Equal(t, "chunk-repo2", result[1].ID) + assert.Equal(t, repoID2, result[1].RepositoryID) +} + +// Test_deduplicateChunksByKey_DifferentFilePaths verifies that chunks +// with the same repository ID and hash but different file paths are treated +// as unique (not duplicates). 
+func Test_deduplicateChunksByKey_DifferentFilePaths(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "chunk-file1", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-same", + Content: "same hash different file", + StartLine: 1, + EndLine: 10, + }, + { + ID: "chunk-file2", + RepositoryID: repoID, + FilePath: "src/util.go", // Different file path + Hash: "hash-same", + Content: "same hash different file", + StartLine: 1, + EndLine: 10, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 2, "should have 2 chunks (different file paths)") + + // Verify both chunks are preserved + assert.Equal(t, "chunk-file1", result[0].ID) + assert.Equal(t, "src/main.go", result[0].FilePath) + + assert.Equal(t, "chunk-file2", result[1].ID) + assert.Equal(t, "src/util.go", result[1].FilePath) +} + +// Test_deduplicateChunksByKey_DifferentHashes verifies that chunks +// with the same repository ID and file path but different hashes are treated +// as unique (not duplicates). 
+func Test_deduplicateChunksByKey_DifferentHashes(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "chunk-hash1", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-abc123", + Content: "version 1", + StartLine: 1, + EndLine: 10, + }, + { + ID: "chunk-hash2", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-def456", // Different hash + Content: "version 2", + StartLine: 1, + EndLine: 10, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 2, "should have 2 chunks (different hashes)") + + // Verify both chunks are preserved + assert.Equal(t, "chunk-hash1", result[0].ID) + assert.Equal(t, "hash-abc123", result[0].Hash) + + assert.Equal(t, "chunk-hash2", result[1].ID) + assert.Equal(t, "hash-def456", result[1].Hash) +} + +// Test_deduplicateChunksByKey_AllDuplicates verifies that when all chunks +// are duplicates of the first one, only the first chunk is returned. 
+func Test_deduplicateChunksByKey_AllDuplicates(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "first", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-same", + Content: "content", + StartLine: 1, + EndLine: 10, + }, + { + ID: "duplicate-1", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-same", + Content: "content", + StartLine: 1, + EndLine: 10, + }, + { + ID: "duplicate-2", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-same", + Content: "content", + StartLine: 1, + EndLine: 10, + }, + { + ID: "duplicate-3", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-same", + Content: "content", + StartLine: 1, + EndLine: 10, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 1, "should have only 1 chunk (all others are duplicates)") + assert.Equal(t, "first", result[0].ID, "should keep the first occurrence") +} + +// Test_deduplicateChunksByKey_SingleChunk verifies that when given a single +// chunk, the function returns it unchanged. +func Test_deduplicateChunksByKey_SingleChunk(t *testing.T) { + // Arrange + ctx := context.Background() + repoID := uuid.New() + + chunks := []outbound.CodeChunk{ + { + ID: "only-chunk", + RepositoryID: repoID, + FilePath: "src/main.go", + Hash: "hash-unique", + Content: "func main() { ... 
}", + StartLine: 1, + EndLine: 10, + }, + } + + // Act + result := deduplicateChunksByKey(ctx, chunks) + + // Assert + require.NotNil(t, result, "result should not be nil") + assert.Len(t, result, 1, "should have 1 chunk") + assert.Equal(t, "only-chunk", result[0].ID) + assert.Equal(t, "src/main.go", result[0].FilePath) + assert.Equal(t, "hash-unique", result[0].Hash) +} diff --git a/internal/application/worker/job_processor_token_counting_test.go b/internal/application/worker/job_processor_token_counting_test.go index 502fe8b..3f16d2f 100644 --- a/internal/application/worker/job_processor_token_counting_test.go +++ b/internal/application/worker/job_processor_token_counting_test.go @@ -5,6 +5,7 @@ import ( "codechunking/internal/port/outbound" "context" "errors" + "fmt" "testing" "time" @@ -33,6 +34,7 @@ func TestJobProcessor_CountTokensForChunks(t *testing.T) { StartLine: 1, EndLine: 1, Type: "function", + Hash: "hash1", }, { ID: uuid.New().String(), @@ -42,6 +44,7 @@ func TestJobProcessor_CountTokensForChunks(t *testing.T) { StartLine: 3, EndLine: 3, Type: "type", + Hash: "hash2", }, } @@ -269,11 +272,12 @@ func TestJobProcessor_TokenCountingModes(t *testing.T) { chunks[i] = outbound.CodeChunk{ ID: uuid.New().String(), RepositoryID: repositoryID, - FilePath: "test.go", + FilePath: fmt.Sprintf("test_%d.go", i), Content: "func test() {}", StartLine: i + 1, EndLine: i + 1, Type: "function", + Hash: fmt.Sprintf("hash_%d", i), } } @@ -450,6 +454,7 @@ func TestJobProcessor_TokenCountingFailsGracefully(t *testing.T) { StartLine: 1, EndLine: 1, Type: "function", + Hash: "hash1", }, } @@ -732,11 +737,12 @@ func TestJobProcessor_ProgressiveTokenCounting_BatchSaving(t *testing.T) { chunks[i] = outbound.CodeChunk{ ID: uuid.New().String(), RepositoryID: repositoryID, - FilePath: "test.go", + FilePath: fmt.Sprintf("test_%d.go", i), Content: "func test() {}", StartLine: i + 1, EndLine: i + 1, Type: "function", + Hash: fmt.Sprintf("hash_%d", i), } } @@ -890,11 +896,12 @@ func 
TestJobProcessor_ProgressiveTokenCounting_FinalPartialBatch(t *testing.T) { chunks[i] = outbound.CodeChunk{ ID: uuid.New().String(), RepositoryID: repositoryID, - FilePath: "test.go", + FilePath: fmt.Sprintf("test_%d.go", i), Content: "func test() {}", StartLine: i + 1, EndLine: i + 1, Type: "function", + Hash: fmt.Sprintf("hash_%d", i), } } @@ -1043,11 +1050,12 @@ func TestJobProcessor_ProgressiveTokenCounting_GracefulDegradation(t *testing.T) chunks[i] = outbound.CodeChunk{ ID: uuid.New().String(), RepositoryID: repositoryID, - FilePath: "test.go", + FilePath: fmt.Sprintf("test_%d.go", i), Content: "func test() {}", StartLine: i + 1, EndLine: i + 1, Type: "function", + Hash: fmt.Sprintf("hash_%d", i), } }