Merged

Commits (25)
c939d3a
experimenting claude
AkhileshNegi Sep 17, 2025
1482a0c
first stab with claude
AkhileshNegi Sep 18, 2025
c3364d5
first stab with claude
AkhileshNegi Sep 18, 2025
ebcd9a0
adding additional logic to call evaluation directly if status is chan…
AkhileshNegi Sep 22, 2025
00b415f
updating testcases
AkhileshNegi Sep 23, 2025
a2ef005
added more testcases
AkhileshNegi Sep 23, 2025
f8e28e9
added cancelled status in enum
AkhileshNegi Sep 23, 2025
397807d
cleanups
AkhileshNegi Sep 23, 2025
ca862cf
update claude.md
AkhileshNegi Sep 23, 2025
375eb5e
coderabbit suggestion
AkhileshNegi Sep 23, 2025
14b7db5
Merge branch 'main' into feature/claude
AkhileshNegi Sep 23, 2025
38dcf45
reverting unnecessary changes
AkhileshNegi Sep 25, 2025
8a1b496
coderabbit suggestions
AkhileshNegi Sep 25, 2025
9e8d046
remove import
AkhileshNegi Sep 25, 2025
dc7a3ce
Merge branch 'main' into feature/claude
AkhileshNegi Oct 6, 2025
724497b
merging endpoints
AkhileshNegi Oct 6, 2025
3bcd772
Merge branch 'main' into feature/classification-unified-api
AkhileshNegi Oct 6, 2025
b6a7073
following PEP8 standards
AkhileshNegi Oct 6, 2025
1bbd0dd
Merge branch 'feature/classification-unified-api' of github.com:Proje…
AkhileshNegi Oct 6, 2025
c47b254
removed redundant checks
AkhileshNegi Oct 8, 2025
a5323a1
Merge branch 'main' into feature/classification-unified-api
AkhileshNegi Oct 9, 2025
0979fd8
added as todo
AkhileshNegi Oct 10, 2025
9f0f26c
Merge branch 'main' into feature/classification-unified-api
AkhileshNegi Oct 10, 2025
c04275b
Merge branch 'feature/classification-unified-api' of github.com:Proje…
AkhileshNegi Oct 10, 2025
8801b39
updated the testcase
AkhileshNegi Oct 10, 2025
111 changes: 97 additions & 14 deletions backend/app/api/routes/fine_tuning.py
@@ -1,17 +1,21 @@
from typing import Optional
import logging
import time
from uuid import UUID
from uuid import UUID, uuid4
from pathlib import Path

import openai
from sqlmodel import Session
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, HTTPException, BackgroundTasks, File, Form, UploadFile

from app.models import (
FineTuningJobCreate,
FineTuningJobPublic,
FineTuningUpdate,
FineTuningStatus,
Document,
ModelEvaluationBase,
ModelEvaluationStatus,
)
from app.core.cloud import get_cloud_storage
from app.crud.document import DocumentCrud
@@ -21,10 +25,13 @@
fetch_by_id,
update_finetune_job,
fetch_by_document_id,
create_model_evaluation,
fetch_active_model_evals,
)
from app.core.db import engine
from app.api.deps import CurrentUserOrgProject, SessionDep
from app.core.finetune.preprocessing import DataPreprocessor
from app.api.routes.model_evaluation import run_model_evaluation


logger = logging.getLogger(__name__)
@@ -38,16 +45,10 @@
"running": FineTuningStatus.running,
"succeeded": FineTuningStatus.completed,
"failed": FineTuningStatus.failed,
"cancelled": FineTuningStatus.cancelled,
}


def handle_openai_error(e: openai.OpenAIError) -> str:
"""Extract error message from OpenAI error."""
if isinstance(e.body, dict) and "message" in e.body:
return e.body["message"]
return str(e)


def process_fine_tuning_job(
job_id: int,
ratio: float,
@@ -179,22 +180,58 @@ def process_fine_tuning_job(
description=load_description("fine_tuning/create.md"),
response_model=APIResponse,
)
def fine_tune_from_CSV(
async def fine_tune_from_CSV(
session: SessionDep,
current_user: CurrentUserOrgProject,
request: FineTuningJobCreate,
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="CSV file to use for fine-tuning"),
base_model: str = Form(...),
split_ratio: str = Form(...),
system_prompt: str = Form(...),
):
client = get_openai_client( # Used here only to validate the user's OpenAI key;
# Parse split ratios
try:
split_ratios = [float(r.strip()) for r in split_ratio.split(",")]
except ValueError as e:
raise HTTPException(status_code=400, detail=f"Invalid split_ratio format: {e}")

# Validate file is CSV
if not file.filename.lower().endswith(".csv") and file.content_type != "text/csv":
raise HTTPException(status_code=400, detail="File must be a CSV file")

get_openai_client( # Used here only to validate the user's OpenAI key;
# the actual client is re-initialized separately inside the background task
session,
current_user.organization_id,
current_user.project_id,
)

# Upload the file to storage and create document
# ToDo: create a helper function and then use it rather than doing things in router
storage = get_cloud_storage(session=session, project_id=current_user.project_id)
document_id = uuid4()
object_store_url = storage.put(file, Path(str(document_id)))

# Create document in database
document_crud = DocumentCrud(session, current_user.project_id)
document = Document(
id=document_id,
fname=file.filename,
object_store_url=str(object_store_url),
)
created_document = document_crud.update(document)

# Create FineTuningJobCreate request object
request = FineTuningJobCreate(
document_id=created_document.id,
base_model=base_model,
split_ratio=split_ratios,
system_prompt=system_prompt.strip(),
)

results = []

for ratio in request.split_ratio:
for ratio in split_ratios:
job, created = create_fine_tuning_job(
session=session,
request=request,
@@ -246,7 +283,10 @@ def fine_tune_from_CSV(
response_model=APIResponse[FineTuningJobPublic],
)
def refresh_fine_tune_status(
fine_tuning_id: int, session: SessionDep, current_user: CurrentUserOrgProject
fine_tuning_id: int,
background_tasks: BackgroundTasks,
session: SessionDep,
current_user: CurrentUserOrgProject,
):
project_id = current_user.project_id
job = fetch_by_id(session, fine_tuning_id, project_id)
@@ -282,13 +322,56 @@ def refresh_fine_tune_status(
error_message=openai_error_msg,
)

# Check if status is changing from running to completed
is_newly_completed = (
job.status == FineTuningStatus.running
and update_payload.status == FineTuningStatus.completed
)

Comment on lines +325 to +330

⚠️ Potential issue | 🟠 Major

Race: auto-evaluation can be created twice under concurrent refresh calls

Two concurrent requests can both observe “no active evals” and create duplicates. Add a DB-level guard (unique constraint/partial index) or a transactional re-check/catch on insert.

  • Option A (preferred): add a unique index on model_evaluation.fine_tuning_id where is_deleted=false and status != 'failed', then wrap create in try/except IntegrityError; on conflict, skip.
  • Option B: within a session transaction, lock the fine-tuning row (SELECT ... FOR UPDATE), re-check active evals, then create.

I can provide a concrete migration + guarded create if desired.

Also applies to: 343-369

🤖 Prompt for AI Agents
In backend/app/api/routes/fine_tuning.py around lines 330-335 (also applies to
343-369), two concurrent requests can both see “no active evals” and create
duplicate model_evaluation rows; implement a DB-level guard and guarded create.
Preferred fix: add a migration that creates a unique partial index on
model_evaluation(fine_tuning_id) WHERE is_deleted = false AND status !=
'failed', then wrap the create in a try/except catching IntegrityError and
skip/return if conflict occurs. Alternative: perform the check-and-create inside
a DB transaction that locks the fine_tuning row (SELECT ... FOR UPDATE),
re-check for active evals, then insert. Ensure sessions/transactions are used
consistently and surface a clear response when insert is skipped due to
integrity conflict.
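
A minimal sketch of Option A, assuming PostgreSQL with SQLModel/SQLAlchemy and Alembic; the index name and the exact `model_evaluation` column names (`is_deleted`, `status`) are assumptions, not taken from this diff:

```python
# Hypothetical Alembic migration body (revision identifiers omitted):
# allow at most one non-deleted, non-failed evaluation per fine-tuning job.
import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    op.create_index(
        "uq_model_evaluation_active_per_job",  # assumed index name
        "model_evaluation",
        ["fine_tuning_id"],
        unique=True,
        postgresql_where=sa.text("is_deleted = false AND status != 'failed'"),
    )


def downgrade() -> None:
    op.drop_index(
        "uq_model_evaluation_active_per_job", table_name="model_evaluation"
    )
```

```python
# Guarded create inside refresh_fine_tune_status (sketch): let the partial
# index reject a concurrent duplicate and treat the conflict as "already
# created". Assumes create_model_evaluation flushes/commits, so the
# IntegrityError surfaces at this call.
from sqlalchemy.exc import IntegrityError

if is_newly_completed:
    try:
        model_eval = create_model_evaluation(
            session=session,
            request=ModelEvaluationBase(fine_tuning_id=fine_tuning_id),
            project_id=project_id,
            organization_id=current_user.organization_id,
            status=ModelEvaluationStatus.pending,
        )
        background_tasks.add_task(run_model_evaluation, model_eval.id, current_user)
    except IntegrityError:
        session.rollback()
        logger.info(
            "[refresh_fine_tune_status] Active evaluation already exists, skipping | "
            f"fine_tuning_id={fine_tuning_id}"
        )
```

Option B would instead lock the fine-tuning row with `SELECT ... FOR UPDATE` (SQLAlchemy's `.with_for_update()` on the select), re-run `fetch_active_model_evals`, and create the evaluation inside that same transaction.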

if (
job.status != update_payload.status
or job.fine_tuned_model != update_payload.fine_tuned_model
or job.error_message != update_payload.error_message
):
job = update_finetune_job(session=session, job=job, update=update_payload)

# If the job just completed, automatically trigger evaluation
if is_newly_completed:
logger.info(
f"[refresh_fine_tune_status] Fine-tuning job completed, triggering evaluation | "
f"fine_tuning_id={fine_tuning_id}, project_id={project_id}"
)

# Check if there's already an active evaluation for this job
active_evaluations = fetch_active_model_evals(
session, fine_tuning_id, project_id
)

if not active_evaluations:
# Create a new evaluation
model_eval = create_model_evaluation(
session=session,
request=ModelEvaluationBase(fine_tuning_id=fine_tuning_id),
project_id=project_id,
organization_id=current_user.organization_id,
status=ModelEvaluationStatus.pending,
)

# Queue the evaluation task
background_tasks.add_task(
run_model_evaluation, model_eval.id, current_user
)

logger.info(
f"[refresh_fine_tune_status] Created and queued evaluation | "
f"eval_id={model_eval.id}, fine_tuning_id={fine_tuning_id}, project_id={project_id}"
)
else:
logger.info(
f"[refresh_fine_tune_status] Skipping evaluation creation - active evaluation exists | "
f"fine_tuning_id={fine_tuning_id}, project_id={project_id}"
)

job = job.model_copy(
update={
"train_data_file_url": storage.get_signed_url(job.train_data_s3_object)
78 changes: 50 additions & 28 deletions backend/app/core/finetune/evaluation.py
@@ -1,19 +1,17 @@
import difflib
import time
import logging
import time
import uuid
from typing import Set

import openai
import pandas as pd
from openai import OpenAI
import uuid
from sklearn.metrics import (
matthews_corrcoef,
)
from sklearn.metrics import matthews_corrcoef

from app.core.cloud import AmazonCloudStorage
from app.api.routes.fine_tuning import handle_openai_error
from app.core.finetune.preprocessing import DataPreprocessor

from app.utils import handle_openai_error

logger = logging.getLogger(__name__)

@@ -51,7 +49,8 @@ def load_labels_and_prompts(self) -> None:
- 'label'
"""
logger.info(
f"[ModelEvaluator.load_labels_and_prompts] Loading CSV from: {self.test_data_s3_object}"
f"[ModelEvaluator.load_labels_and_prompts] Loading CSV from: "
f"{self.test_data_s3_object}"
)
file_obj = self.storage.stream(self.test_data_s3_object)
try:
@@ -66,11 +65,13 @@

if not query_col or not label_col:
logger.error(
"[ModelEvaluator.load_labels_and_prompts] CSV must contain a 'label' column "
f"and one of: {possible_query_columns}"
"[ModelEvaluator.load_labels_and_prompts] CSV must "
"contain a 'label' column and one of: "
f"{possible_query_columns}"
)
raise ValueError(
f"CSV must contain a 'label' column and one of: {possible_query_columns}"
f"CSV must contain a 'label' column and one of: "
f"{possible_query_columns}"
)

prompts = df[query_col].astype(str).tolist()
@@ -85,12 +86,15 @@

logger.info(
"[ModelEvaluator.load_labels_and_prompts] "
f"Loaded {len(self.prompts)} prompts and {len(self.y_true)} labels; "
f"query_col={query_col}, label_col={label_col}, allowed_labels={self.allowed_labels}"
f"Loaded {len(self.prompts)} prompts and "
f"{len(self.y_true)} labels; "
f"query_col={query_col}, label_col={label_col}, "
f"allowed_labels={self.allowed_labels}"
)
except Exception as e:
logger.error(
f"[ModelEvaluator.load_labels_and_prompts] Failed to load/parse test CSV: {e}",
f"[ModelEvaluator.load_labels_and_prompts] "
f"Failed to load/parse test CSV: {e}",
exc_info=True,
)
raise
@@ -111,13 +115,15 @@ def normalize_prediction(self, text: str) -> str:
return closest[0]

logger.warning(
f"[normalize_prediction] No close match found for '{t}'. Using default label '{next(iter(self.allowed_labels))}'."
f"[normalize_prediction] No close match found for '{t}'. "
f"Using default label '{next(iter(self.allowed_labels))}'."
)
return next(iter(self.allowed_labels))

def generate_predictions(self) -> tuple[list[str], str]:
logger.info(
f"[generate_predictions] Generating predictions for {len(self.prompts)} prompts."
f"[generate_predictions] Generating predictions for "
f"{len(self.prompts)} prompts."
)
start_preds = time.time()
predictions = []
@@ -128,7 +134,9 @@ def generate_predictions(self) -> tuple[list[str], str]:
while attempt < self.retries:
start_time = time.time()
logger.info(
f"[generate_predictions] Processing prompt {idx}/{total_prompts} (Attempt {attempt + 1}/{self.retries})"
f"[generate_predictions] Processing prompt "
f"{idx}/{total_prompts} "
f"(Attempt {attempt + 1}/{self.retries})"
)

try:
@@ -141,7 +149,8 @@
elapsed_time = time.time() - start_time
if elapsed_time > self.max_latency:
logger.warning(
f"[generate_predictions] Timeout exceeded for prompt {idx}/{total_prompts}. Retrying..."
f"[generate_predictions] Timeout exceeded for "
f"prompt {idx}/{total_prompts}. Retrying..."
)
continue

@@ -153,23 +162,29 @@
except openai.OpenAIError as e:
error_msg = handle_openai_error(e)
logger.error(
f"[generate_predictions] OpenAI API error at prompt {idx}/{total_prompts}: {error_msg}"
f"[generate_predictions] OpenAI API error at prompt "
f"{idx}/{total_prompts}: {error_msg}"
)
attempt += 1
if attempt == self.retries:
predictions.append("openai_error")
logger.error(
f"[generate_predictions] Maximum retries reached for prompt {idx}/{total_prompts}. Appending 'openai_error'."
f"[generate_predictions] Maximum retries reached "
f"for prompt {idx}/{total_prompts}. "
f"Appending 'openai_error'."
)
else:
logger.info(
f"[generate_predictions] Retrying prompt {idx}/{total_prompts} after OpenAI error ({attempt}/{self.retries})."
f"[generate_predictions] Retrying prompt "
f"{idx}/{total_prompts} after OpenAI error "
f"({attempt}/{self.retries})."
)

total_elapsed = time.time() - start_preds
logger.info(
f"[generate_predictions] Finished {total_prompts} prompts in {total_elapsed:.2f}s | "
f"Generated {len(predictions)} predictions."
f"[generate_predictions] Finished {total_prompts} prompts in "
f"{total_elapsed:.2f}s | Generated {len(predictions)} "
f"predictions."
)

prediction_data = pd.DataFrame(
@@ -188,7 +203,8 @@ def generate_predictions(self) -> tuple[list[str], str]:
self.prediction_data_s3_object = prediction_data_s3_object

logger.info(
f"[generate_predictions] Predictions CSV uploaded to S3 | url={prediction_data_s3_object}"
f"[generate_predictions] Predictions CSV uploaded to S3 | "
f"url={prediction_data_s3_object}"
)

return predictions, prediction_data_s3_object
@@ -197,11 +213,13 @@ def evaluate(self) -> dict:
"""Evaluate using the predictions CSV previously uploaded to S3."""
if not getattr(self, "prediction_data_s3_object", None):
raise RuntimeError(
"[evaluate] predictions_s3_object not set. Call generate_predictions() first."
"[evaluate] predictions_s3_object not set. "
"Call generate_predictions() first."
)

logger.info(
f"[evaluate] Streaming predictions CSV from: {self.prediction_data_s3_object}"
f"[evaluate] Streaming predictions CSV from: "
f"{self.prediction_data_s3_object}"
)
prediction_obj = self.storage.stream(self.prediction_data_s3_object)
try:
@@ -211,7 +229,8 @@

if "true_label" not in df.columns or "prediction" not in df.columns:
raise ValueError(
"[evaluate] prediction data CSV must contain 'true_label' and 'prediction' columns."
"[evaluate] prediction data CSV must contain 'true_label' "
"and 'prediction' columns."
)

y_true = df["true_label"].astype(str).str.strip().str.lower().tolist()
@@ -226,7 +245,10 @@
raise

def run(self) -> dict:
"""Run the full evaluation process: load data, generate predictions, evaluate results."""
"""Run the full evaluation process.

Load data, generate predictions, and evaluate results.
"""
try:
self.load_labels_and_prompts()
predictions, prediction_data_s3_object = self.generate_predictions()
1 change: 1 addition & 0 deletions backend/app/models/fine_tuning.py
@@ -15,6 +15,7 @@ class FineTuningStatus(str, Enum):
running = "running"
completed = "completed"
failed = "failed"
cancelled = "cancelled"


class FineTuningJobBase(SQLModel):