diff --git a/backend/app/api/routes/fine_tuning.py b/backend/app/api/routes/fine_tuning.py
index 9d053393..66baa3ad 100644
--- a/backend/app/api/routes/fine_tuning.py
+++ b/backend/app/api/routes/fine_tuning.py
@@ -1,17 +1,21 @@
 from typing import Optional
 import logging
 import time
-from uuid import UUID
+from uuid import UUID, uuid4
+from pathlib import Path
 
 import openai
 from sqlmodel import Session
-from fastapi import APIRouter, HTTPException, BackgroundTasks
+from fastapi import APIRouter, HTTPException, BackgroundTasks, File, Form, UploadFile
 
 from app.models import (
     FineTuningJobCreate,
     FineTuningJobPublic,
     FineTuningUpdate,
     FineTuningStatus,
+    Document,
+    ModelEvaluationBase,
+    ModelEvaluationStatus,
 )
 from app.core.cloud import get_cloud_storage
 from app.crud.document import DocumentCrud
@@ -21,10 +25,13 @@
     fetch_by_id,
     update_finetune_job,
     fetch_by_document_id,
+    create_model_evaluation,
+    fetch_active_model_evals,
 )
 from app.core.db import engine
 from app.api.deps import CurrentUserOrgProject, SessionDep
 from app.core.finetune.preprocessing import DataPreprocessor
+from app.api.routes.model_evaluation import run_model_evaluation
 
 logger = logging.getLogger(__name__)
 
@@ -38,16 +45,10 @@
     "running": FineTuningStatus.running,
     "succeeded": FineTuningStatus.completed,
     "failed": FineTuningStatus.failed,
+    "cancelled": FineTuningStatus.cancelled,
 }
 
 
-def handle_openai_error(e: openai.OpenAIError) -> str:
-    """Extract error message from OpenAI error."""
-    if isinstance(e.body, dict) and "message" in e.body:
-        return e.body["message"]
-    return str(e)
-
-
 def process_fine_tuning_job(
     job_id: int,
     ratio: float,
@@ -179,22 +180,58 @@ def process_fine_tuning_job(
     description=load_description("fine_tuning/create.md"),
     response_model=APIResponse,
 )
-def fine_tune_from_CSV(
+async def fine_tune_from_CSV(
     session: SessionDep,
     current_user: CurrentUserOrgProject,
-    request: FineTuningJobCreate,
     background_tasks: BackgroundTasks,
+    file: UploadFile = File(..., description="CSV file to use for fine-tuning"),
+    base_model: str = Form(...),
+    split_ratio: str = Form(...),
+    system_prompt: str = Form(...),
 ):
-    client = get_openai_client(  # Used here only to validate the user's OpenAI key;
+    # Parse split ratios
+    try:
+        split_ratios = [float(r.strip()) for r in split_ratio.split(",")]
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=f"Invalid split_ratio format: {e}")
+
+    # Validate file is CSV
+    if not file.filename.lower().endswith(".csv") and file.content_type != "text/csv":
+        raise HTTPException(status_code=400, detail="File must be a CSV file")
+
+    get_openai_client(  # Used here only to validate the user's OpenAI key;
         # the actual client is re-initialized separately inside the background task
        session,
        current_user.organization_id,
        current_user.project_id,
    )
 
+    # Upload the file to storage and create document
+    # ToDo: create a helper function and then use it rather than doing things in router
+    storage = get_cloud_storage(session=session, project_id=current_user.project_id)
+    document_id = uuid4()
+    object_store_url = storage.put(file, Path(str(document_id)))
+
+    # Create document in database
+    document_crud = DocumentCrud(session, current_user.project_id)
+    document = Document(
+        id=document_id,
+        fname=file.filename,
+        object_store_url=str(object_store_url),
+    )
+    created_document = document_crud.update(document)
+
+    # Create FineTuningJobCreate request object
+    request = FineTuningJobCreate(
+        document_id=created_document.id,
+        base_model=base_model,
+        split_ratio=split_ratios,
+        system_prompt=system_prompt.strip(),
+    )
+
     results = []
-    for ratio in request.split_ratio:
+    for ratio in split_ratios:
         job, created = create_fine_tuning_job(
             session=session,
             request=request,
@@ -246,7 +283,10 @@ def fine_tune_from_CSV(
     response_model=APIResponse[FineTuningJobPublic],
 )
 def refresh_fine_tune_status(
-    fine_tuning_id: int, session: SessionDep, current_user: CurrentUserOrgProject
+    fine_tuning_id: int,
+    background_tasks: BackgroundTasks,
+    session: SessionDep,
+    current_user: CurrentUserOrgProject,
 ):
     project_id = current_user.project_id
     job = fetch_by_id(session, fine_tuning_id, project_id)
@@ -282,6 +322,12 @@ def refresh_fine_tune_status(
             error_message=openai_error_msg,
         )
 
+    # Check if status is changing from running to completed
+    is_newly_completed = (
+        job.status == FineTuningStatus.running
+        and update_payload.status == FineTuningStatus.completed
+    )
+
     if (
         job.status != update_payload.status
         or job.fine_tuned_model != update_payload.fine_tuned_model
@@ -289,6 +335,43 @@
     ):
         job = update_finetune_job(session=session, job=job, update=update_payload)
 
+    # If the job just completed, automatically trigger evaluation
+    if is_newly_completed:
+        logger.info(
+            f"[refresh_fine_tune_status] Fine-tuning job completed, triggering evaluation | "
+            f"fine_tuning_id={fine_tuning_id}, project_id={project_id}"
+        )
+
+        # Check if there's already an active evaluation for this job
+        active_evaluations = fetch_active_model_evals(
+            session, fine_tuning_id, project_id
+        )
+
+        if not active_evaluations:
+            # Create a new evaluation
+            model_eval = create_model_evaluation(
+                session=session,
+                request=ModelEvaluationBase(fine_tuning_id=fine_tuning_id),
+                project_id=project_id,
+                organization_id=current_user.organization_id,
+                status=ModelEvaluationStatus.pending,
+            )
+
+            # Queue the evaluation task
+            background_tasks.add_task(
+                run_model_evaluation, model_eval.id, current_user
+            )
+
+            logger.info(
+                f"[refresh_fine_tune_status] Created and queued evaluation | "
+                f"eval_id={model_eval.id}, fine_tuning_id={fine_tuning_id}, project_id={project_id}"
+            )
+        else:
+            logger.info(
+                f"[refresh_fine_tune_status] Skipping evaluation creation - active evaluation exists | "
+                f"fine_tuning_id={fine_tuning_id}, project_id={project_id}"
+            )
+
     job = job.model_copy(
         update={
             "train_data_file_url": storage.get_signed_url(job.train_data_s3_object)
diff --git a/backend/app/core/finetune/evaluation.py b/backend/app/core/finetune/evaluation.py
index 527087eb..560a4c75 100644
--- a/backend/app/core/finetune/evaluation.py
+++ b/backend/app/core/finetune/evaluation.py
@@ -1,19 +1,17 @@
 import difflib
-import time
 import logging
+import time
+import uuid
 from typing import Set
 
 import openai
 import pandas as pd
 from openai import OpenAI
-import uuid
-from sklearn.metrics import (
-    matthews_corrcoef,
-)
+from sklearn.metrics import matthews_corrcoef
+
 from app.core.cloud import AmazonCloudStorage
-from app.api.routes.fine_tuning import handle_openai_error
 from app.core.finetune.preprocessing import DataPreprocessor
-
+from app.utils import handle_openai_error
 
 
 logger = logging.getLogger(__name__)
@@ -51,7 +49,8 @@ def load_labels_and_prompts(self) -> None:
             - 'label'
         """
         logger.info(
-            f"[ModelEvaluator.load_labels_and_prompts] Loading CSV from: {self.test_data_s3_object}"
+            f"[ModelEvaluator.load_labels_and_prompts] Loading CSV from: "
+            f"{self.test_data_s3_object}"
         )
         file_obj = self.storage.stream(self.test_data_s3_object)
         try:
@@ -66,11 +65,13 @@ def load_labels_and_prompts(self) -> None:
 
             if not query_col or not label_col:
                 logger.error(
-                    "[ModelEvaluator.load_labels_and_prompts] CSV must contain a 'label' column "
-                    f"and one of: {possible_query_columns}"
+                    "[ModelEvaluator.load_labels_and_prompts] CSV must "
+                    "contain a 'label' column and one of: "
+                    f"{possible_query_columns}"
                 )
                 raise ValueError(
-                    f"CSV must contain a 'label' column and one of: {possible_query_columns}"
+                    f"CSV must contain a 'label' column and one of: "
+                    f"{possible_query_columns}"
                 )
 
             prompts = df[query_col].astype(str).tolist()
@@ -85,12 +86,15 @@ def load_labels_and_prompts(self) -> None:
 
             logger.info(
                 "[ModelEvaluator.load_labels_and_prompts] "
-                f"Loaded {len(self.prompts)} prompts and {len(self.y_true)} labels; "
-                f"query_col={query_col}, label_col={label_col}, allowed_labels={self.allowed_labels}"
+                f"Loaded {len(self.prompts)} prompts and "
+                f"{len(self.y_true)} labels; "
+                f"query_col={query_col}, label_col={label_col}, "
+                f"allowed_labels={self.allowed_labels}"
             )
         except Exception as e:
             logger.error(
-                f"[ModelEvaluator.load_labels_and_prompts] Failed to load/parse test CSV: {e}",
+                f"[ModelEvaluator.load_labels_and_prompts] "
+                f"Failed to load/parse test CSV: {e}",
                 exc_info=True,
             )
             raise
@@ -111,13 +115,15 @@ def normalize_prediction(self, text: str) -> str:
             return closest[0]
 
         logger.warning(
-            f"[normalize_prediction] No close match found for '{t}'. Using default label '{next(iter(self.allowed_labels))}'."
+            f"[normalize_prediction] No close match found for '{t}'. "
+            f"Using default label '{next(iter(self.allowed_labels))}'."
         )
         return next(iter(self.allowed_labels))
 
     def generate_predictions(self) -> tuple[list[str], str]:
         logger.info(
-            f"[generate_predictions] Generating predictions for {len(self.prompts)} prompts."
+            f"[generate_predictions] Generating predictions for "
+            f"{len(self.prompts)} prompts."
         )
         start_preds = time.time()
         predictions = []
@@ -128,7 +134,9 @@ def generate_predictions(self) -> tuple[list[str], str]:
             while attempt < self.retries:
                 start_time = time.time()
                 logger.info(
-                    f"[generate_predictions] Processing prompt {idx}/{total_prompts} (Attempt {attempt + 1}/{self.retries})"
+                    f"[generate_predictions] Processing prompt "
+                    f"{idx}/{total_prompts} "
+                    f"(Attempt {attempt + 1}/{self.retries})"
                 )
 
                 try:
@@ -141,7 +149,8 @@ def generate_predictions(self) -> tuple[list[str], str]:
                     elapsed_time = time.time() - start_time
                     if elapsed_time > self.max_latency:
                         logger.warning(
-                            f"[generate_predictions] Timeout exceeded for prompt {idx}/{total_prompts}. Retrying..."
+                            f"[generate_predictions] Timeout exceeded for "
+                            f"prompt {idx}/{total_prompts}. Retrying..."
                         )
                         continue
 
@@ -153,23 +162,29 @@ def generate_predictions(self) -> tuple[list[str], str]:
                 except openai.OpenAIError as e:
                     error_msg = handle_openai_error(e)
                     logger.error(
-                        f"[generate_predictions] OpenAI API error at prompt {idx}/{total_prompts}: {error_msg}"
+                        f"[generate_predictions] OpenAI API error at prompt "
+                        f"{idx}/{total_prompts}: {error_msg}"
                     )
                     attempt += 1
                     if attempt == self.retries:
                         predictions.append("openai_error")
                         logger.error(
-                            f"[generate_predictions] Maximum retries reached for prompt {idx}/{total_prompts}. Appending 'openai_error'."
+                            f"[generate_predictions] Maximum retries reached "
+                            f"for prompt {idx}/{total_prompts}. "
+                            f"Appending 'openai_error'."
                        )
                     else:
                         logger.info(
-                            f"[generate_predictions] Retrying prompt {idx}/{total_prompts} after OpenAI error ({attempt}/{self.retries})."
+                            f"[generate_predictions] Retrying prompt "
+                            f"{idx}/{total_prompts} after OpenAI error "
+                            f"({attempt}/{self.retries})."
                         )
 
         total_elapsed = time.time() - start_preds
         logger.info(
-            f"[generate_predictions] Finished {total_prompts} prompts in {total_elapsed:.2f}s | "
-            f"Generated {len(predictions)} predictions."
+            f"[generate_predictions] Finished {total_prompts} prompts in "
+            f"{total_elapsed:.2f}s | Generated {len(predictions)} "
+            f"predictions."
         )
 
         prediction_data = pd.DataFrame(
@@ -188,7 +203,8 @@ def generate_predictions(self) -> tuple[list[str], str]:
         self.prediction_data_s3_object = prediction_data_s3_object
 
         logger.info(
-            f"[generate_predictions] Predictions CSV uploaded to S3 | url={prediction_data_s3_object}"
+            f"[generate_predictions] Predictions CSV uploaded to S3 | "
+            f"url={prediction_data_s3_object}"
         )
 
         return predictions, prediction_data_s3_object
@@ -197,11 +213,13 @@ def generate_predictions(self) -> tuple[list[str], str]:
         """Evaluate using the predictions CSV previously uploaded to S3."""
         if not getattr(self, "prediction_data_s3_object", None):
             raise RuntimeError(
-                "[evaluate] predictions_s3_object not set. Call generate_predictions() first."
+                "[evaluate] predictions_s3_object not set. "
+                "Call generate_predictions() first."
             )
 
         logger.info(
-            f"[evaluate] Streaming predictions CSV from: {self.prediction_data_s3_object}"
+            f"[evaluate] Streaming predictions CSV from: "
+            f"{self.prediction_data_s3_object}"
         )
         prediction_obj = self.storage.stream(self.prediction_data_s3_object)
         try:
@@ -211,7 +229,8 @@ def evaluate(self) -> dict:
 
             if "true_label" not in df.columns or "prediction" not in df.columns:
                 raise ValueError(
-                    "[evaluate] prediction data CSV must contain 'true_label' and 'prediction' columns."
+                    "[evaluate] prediction data CSV must contain 'true_label' "
+                    "and 'prediction' columns."
                 )
 
             y_true = df["true_label"].astype(str).str.strip().str.lower().tolist()
@@ -226,7 +245,10 @@ def evaluate(self) -> dict:
             raise
 
     def run(self) -> dict:
-        """Run the full evaluation process: load data, generate predictions, evaluate results."""
+        """Run the full evaluation process.
+
+        Load data, generate predictions, and evaluate results.
+        """
         try:
             self.load_labels_and_prompts()
             predictions, prediction_data_s3_object = self.generate_predictions()
diff --git a/backend/app/models/fine_tuning.py b/backend/app/models/fine_tuning.py
index a3b0e866..4e326ee5 100644
--- a/backend/app/models/fine_tuning.py
+++ b/backend/app/models/fine_tuning.py
@@ -15,6 +15,7 @@ class FineTuningStatus(str, Enum):
     running = "running"
     completed = "completed"
     failed = "failed"
+    cancelled = "cancelled"
 
 
 class FineTuningJobBase(SQLModel):
diff --git a/backend/app/tests/api/routes/test_fine_tuning.py b/backend/app/tests/api/routes/test_fine_tuning.py
index 5582b73f..abe00680 100644
--- a/backend/app/tests/api/routes/test_fine_tuning.py
+++ b/backend/app/tests/api/routes/test_fine_tuning.py
@@ -1,10 +1,18 @@
+import io
 import pytest
-
+from moto import mock_aws
 from unittest.mock import patch, MagicMock
+import boto3
 
 from app.tests.utils.test_data import create_test_fine_tuning_jobs
 from app.tests.utils.utils import get_document
-from app.models import Fine_Tuning
+from app.models import (
+    Fine_Tuning,
+    FineTuningStatus,
+    ModelEvaluation,
+    ModelEvaluationStatus,
+)
+from app.core.config import settings
 
 
 def create_file_mock(file_type):
@@ -23,72 +31,87 @@ def _side_effect(file=None, purpose=None):
 
 
 @pytest.mark.usefixtures("client", "db", "user_api_key_header")
-@patch("app.api.routes.fine_tuning.DataPreprocessor")
-@patch("app.api.routes.fine_tuning.get_openai_client")
 class TestCreateFineTuningJobAPI:
+    @mock_aws
     def test_finetune_from_csv_multiple_split_ratio(
         self,
-        mock_get_openai_client,
-        mock_preprocessor_cls,
         client,
         db,
         user_api_key_header,
     ):
-        document = get_document(db, "dalgo_sample.json")
+        # Setup S3 bucket for moto
+        s3 = boto3.client("s3", region_name=settings.AWS_DEFAULT_REGION)
+        bucket_name = settings.AWS_S3_BUCKET_PREFIX
+        if settings.AWS_DEFAULT_REGION == "us-east-1":
+            s3.create_bucket(Bucket=bucket_name)
+        else:
+            s3.create_bucket(
+                Bucket=bucket_name,
+                CreateBucketConfiguration={
+                    "LocationConstraint": settings.AWS_DEFAULT_REGION
+                },
+            )
 
+        # Create a test CSV file content
+        csv_content = "prompt,label\ntest1,label1\ntest2,label2\ntest3,label3"
+
+        # Setup test files for preprocessing
         for path in ["/tmp/train.jsonl", "/tmp/test.jsonl"]:
             with open(path, "w") as f:
-                f.write("{}")
-
-        mock_preprocessor = MagicMock()
-        mock_preprocessor.process.return_value = {
-            "train_jsonl_temp_filepath": "/tmp/train.jsonl",
-            "train_csv_s3_object": "s3://bucket/train.csv",
-            "test_csv_s3_object": "s3://bucket/test.csv",
-        }
-        mock_preprocessor.cleanup = MagicMock()
-        mock_preprocessor_cls.return_value = mock_preprocessor
-
-        mock_openai = MagicMock()
-        mock_openai.files.create.side_effect = create_file_mock("fine-tune")
-        mock_openai.fine_tuning.jobs.create.side_effect = [
-            MagicMock(id=f"ft_mock_job_{i}", status="running") for i in range(1, 4)
-        ]
-        mock_get_openai_client.return_value = mock_openai
-
-        body = {
-            "document_id": str(document.id),
-            "base_model": "gpt-4",
-            "split_ratio": [0.5, 0.7, 0.9],
-            "system_prompt": "you are a model able to classify",
-        }
-
-        with patch("app.api.routes.fine_tuning.Session") as SessionMock:
-            SessionMock.return_value.__enter__.return_value = db
-            SessionMock.return_value.__exit__.return_value = None
-
-            response = client.post(
-                "/api/v1/fine_tuning/fine_tune",
-                json=body,
-                headers=user_api_key_header,
-            )
+                f.write('{"prompt": "test", "completion": "label"}')
+
+        with patch(
+            "app.api.routes.fine_tuning.get_cloud_storage"
+        ) as mock_get_cloud_storage:
+            with patch(
+                "app.api.routes.fine_tuning.get_openai_client"
+            ) as mock_get_openai_client:
+                with patch(
+                    "app.api.routes.fine_tuning.process_fine_tuning_job"
+                ) as mock_process_job:
+                    # Mock cloud storage
+                    mock_storage = MagicMock()
+                    mock_storage.put.return_value = (
+                        f"s3://{settings.AWS_S3_BUCKET_PREFIX}/test.csv"
+                    )
+                    mock_get_cloud_storage.return_value = mock_storage
+
+                    # Mock OpenAI client (for validation only)
+                    mock_openai = MagicMock()
+                    mock_get_openai_client.return_value = mock_openai
+
+                    # Create file upload data
+                    csv_file = io.BytesIO(csv_content.encode())
+                    response = client.post(
+                        "/api/v1/fine_tuning/fine_tune",
+                        files={"file": ("test.csv", csv_file, "text/csv")},
+                        data={
+                            "base_model": "gpt-4",
+                            "split_ratio": "0.5,0.7,0.9",
+                            "system_prompt": "you are a model able to classify",
+                        },
+                        headers=user_api_key_header,
+                    )
 
         assert response.status_code == 200
         json_data = response.json()
         assert json_data["success"] is True
         assert json_data["data"]["message"] == "Fine-tuning job(s) started."
         assert json_data["metadata"] is None
+        assert "jobs" in json_data["data"]
+        assert len(json_data["data"]["jobs"]) == 3
+
+        # Verify that the background task was called for each split ratio
+        assert mock_process_job.call_count == 3
 
         jobs = db.query(Fine_Tuning).all()
         assert len(jobs) == 3
-        for i, job in enumerate(jobs, start=1):
+        for job in jobs:
             db.refresh(job)
-            assert job.status == "running"
-            assert job.provider_job_id == f"ft_mock_job_{i}"
-            assert job.training_file_id is not None
-            assert job.train_data_s3_object == "s3://bucket/train.csv"
-            assert job.test_data_s3_object == "s3://bucket/test.csv"
+            assert (
+                job.status == "pending"
+            )  # Since background processing is mocked, status remains pending
             assert job.split_ratio in [0.5, 0.7, 0.9]
@@ -100,7 +123,7 @@ def test_retrieve_fine_tuning_job(
         self,
         client,
         db,
         user_api_key_header,
     ):
         jobs, _ = create_test_fine_tuning_jobs(db, [0.3])
         job = jobs[0]
-        job.provider_job_id = "ft_mock_job_123"
+        job.provider_job_id = "ftjob-mock_job_123"
         db.flush()
 
         mock_openai_job = MagicMock(
@@ -129,7 +152,7 @@ def test_retrieve_fine_tuning_job_failed(
         self,
         client,
         db,
         user_api_key_header,
     ):
         jobs, _ = create_test_fine_tuning_jobs(db, [0.3])
         job = jobs[0]
-        job.provider_job_id = "ft_mock_job_123"
+        job.provider_job_id = "ftjob-mock_job_123"
         db.flush()
 
         mock_openai_job = MagicMock(
@@ -178,3 +201,267 @@ def test_fetch_jobs_document(self, client, db, user_api_key_header):
         for job in json_data["data"]:
             assert job["document_id"] == str(document.id)
             assert job["status"] == "pending"
+
+
+@pytest.mark.usefixtures("client", "db", "user_api_key_header")
+@patch("app.api.routes.fine_tuning.get_openai_client")
+@patch("app.api.routes.fine_tuning.get_cloud_storage")
+@patch("app.api.routes.fine_tuning.run_model_evaluation")
+class TestAutoEvaluationTrigger:
+    """Test cases for automatic evaluation triggering when fine-tuning completes."""
+
+    def test_successful_auto_evaluation_trigger(
+        self,
+        mock_run_model_evaluation,
+        mock_get_cloud_storage,
+        mock_get_openai_client,
+        client,
+        db,
+        user_api_key_header,
+    ):
+        """Test that evaluation is automatically triggered when job status changes from running to completed."""
+        # Setup: Create a fine-tuning job with running status
+        jobs, _ = create_test_fine_tuning_jobs(db, [0.7])
+        job = jobs[0]
+        job.status = FineTuningStatus.running
+        job.provider_job_id = "ftjob-mock_job_123"
+        # Add required fields for model evaluation
+        job.test_data_s3_object = f"{settings.AWS_S3_BUCKET_PREFIX}/test-data.csv"
+        job.system_prompt = "You are a helpful assistant"
+        db.add(job)
+        db.commit()
+        db.refresh(job)
+
+        # Mock cloud storage
+        mock_storage = MagicMock()
+        mock_storage.get_signed_url.return_value = (
+            "https://test.s3.amazonaws.com/signed-url"
+        )
+        mock_get_cloud_storage.return_value = mock_storage
+
+        # Mock OpenAI response indicating job completion
+        mock_openai_job = MagicMock(
+            status="succeeded",
+            fine_tuned_model="ft:gpt-4:custom-model:12345",
+            error=None,
+        )
+        mock_openai = MagicMock()
+        mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job
+        mock_get_openai_client.return_value = mock_openai
+
+        # Action: Refresh the fine-tuning job status
+        response = client.get(
+            f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header
+        )
+
+        # Verify response
+        assert response.status_code == 200
+        json_data = response.json()
+        assert json_data["data"]["status"] == "completed"
+        assert json_data["data"]["fine_tuned_model"] == "ft:gpt-4:custom-model:12345"
+
+        # Verify that model evaluation was triggered
+        mock_run_model_evaluation.assert_called_once()
+        call_args = mock_run_model_evaluation.call_args[0]
+        eval_id = call_args[0]
+
+        # Verify evaluation was created in database
+        model_eval = (
+            db.query(ModelEvaluation).filter(ModelEvaluation.id == eval_id).first()
+        )
+        assert model_eval is not None
+        assert model_eval.fine_tuning_id == job.id
+        assert model_eval.status == ModelEvaluationStatus.pending
+
+    def test_skip_evaluation_when_already_exists(
+        self,
+        mock_run_model_evaluation,
+        mock_get_cloud_storage,
+        mock_get_openai_client,
+        client,
+        db,
+        user_api_key_header,
+    ):
+        """Test that evaluation is skipped when an active evaluation already exists."""
+        # Setup: Create a fine-tuning job with running status
+        jobs, _ = create_test_fine_tuning_jobs(db, [0.7])
+        job = jobs[0]
+        job.status = FineTuningStatus.running
+        job.provider_job_id = "ftjob-mock_job_123"
+        # Add required fields for model evaluation
+        job.test_data_s3_object = f"{settings.AWS_S3_BUCKET_PREFIX}/test-data.csv"
+        job.system_prompt = "You are a helpful assistant"
+        db.add(job)
+        db.commit()
+
+        # Create an existing active evaluation
+        existing_eval = ModelEvaluation(
+            fine_tuning_id=job.id,
+            status=ModelEvaluationStatus.pending,
+            project_id=job.project_id,
+            organization_id=job.organization_id,
+            document_id=job.document_id,
+            fine_tuned_model="ft:gpt-4:test-model:123",
+            test_data_s3_object=f"{settings.AWS_S3_BUCKET_PREFIX}/test-data.csv",
+            base_model="gpt-4",
+            split_ratio=0.7,
+            system_prompt="You are a helpful assistant",
+        )
+        db.add(existing_eval)
+        db.commit()
+
+        # Mock cloud storage
+        mock_storage = MagicMock()
+        mock_storage.get_signed_url.return_value = (
+            "https://test.s3.amazonaws.com/signed-url"
+        )
+        mock_get_cloud_storage.return_value = mock_storage
+
+        # Mock OpenAI response indicating job completion
+        mock_openai_job = MagicMock(
+            status="succeeded",
+            fine_tuned_model="ft:gpt-4:custom-model:12345",
+            error=None,
+        )
+        mock_openai = MagicMock()
+        mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job
+        mock_get_openai_client.return_value = mock_openai
+
+        # Action: Refresh the fine-tuning job status
+        response = client.get(
+            f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header
+        )
+
+        # Verify response
+        assert response.status_code == 200
+        json_data = response.json()
+        assert json_data["data"]["status"] == "completed"
+
+        # Verify that no new evaluation was triggered
+        mock_run_model_evaluation.assert_not_called()
+
+        # Verify only one evaluation exists in database
+        evaluations = (
+            db.query(ModelEvaluation)
+            .filter(ModelEvaluation.fine_tuning_id == job.id)
+            .all()
+        )
+        assert len(evaluations) == 1
+        assert evaluations[0].id == existing_eval.id
+
+    def test_evaluation_not_triggered_for_non_completion_status_changes(
+        self,
+        mock_run_model_evaluation,
+        mock_get_cloud_storage,
+        mock_get_openai_client,
+        client,
+        db,
+        user_api_key_header,
+    ):
+        """Test that evaluation is not triggered for status changes other than to completed."""
+        # Test Case 1: pending to running
+        jobs, _ = create_test_fine_tuning_jobs(db, [0.7])
+        job = jobs[0]
+        job.status = FineTuningStatus.pending
+        job.provider_job_id = "ftjob-mock_job_123"
+        db.add(job)
+        db.commit()
+
+        # Mock cloud storage
+        mock_storage = MagicMock()
+        mock_storage.get_signed_url.return_value = (
+            "https://test.s3.amazonaws.com/signed-url"
+        )
+        mock_get_cloud_storage.return_value = mock_storage
+
+        mock_openai_job = MagicMock(
+            status="running",
+            fine_tuned_model=None,
+            error=None,
+        )
+        mock_openai = MagicMock()
+        mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job
+        mock_get_openai_client.return_value = mock_openai
+
+        response = client.get(
+            f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header
+        )
+
+        assert response.status_code == 200
+        json_data = response.json()
+        assert json_data["data"]["status"] == "running"
+        mock_run_model_evaluation.assert_not_called()
+
+        # Test Case 2: running to failed
+        job.status = FineTuningStatus.running
+        db.add(job)
+        db.commit()
+
+        mock_openai_job.status = "failed"
+        mock_openai_job.error = MagicMock(message="Training failed")
+
+        response = client.get(
+            f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header
+        )
+
+        assert response.status_code == 200
+        json_data = response.json()
+        assert json_data["data"]["status"] == "failed"
+        mock_run_model_evaluation.assert_not_called()
+
+    def test_evaluation_not_triggered_for_already_completed_jobs(
+        self,
+        mock_run_model_evaluation,
+        mock_get_cloud_storage,
+        mock_get_openai_client,
+        client,
+        db,
+        user_api_key_header,
+    ):
+        """Test that evaluation is not triggered when refreshing an already completed job."""
+        # Setup: Create a fine-tuning job that's already completed
+        jobs, _ = create_test_fine_tuning_jobs(db, [0.7])
+        job = jobs[0]
+        job.status = FineTuningStatus.completed
+        job.provider_job_id = "ftjob-mock_job_123"
+        job.fine_tuned_model = "ft:gpt-4:custom-model:12345"
+        db.add(job)
+        db.commit()
+
+        # Mock cloud storage
+        mock_storage = MagicMock()
+        mock_storage.get_signed_url.return_value = (
+            "https://test.s3.amazonaws.com/signed-url"
+        )
+        mock_get_cloud_storage.return_value = mock_storage
+
+        # Mock OpenAI response (job remains succeeded)
+        mock_openai_job = MagicMock(
+            status="succeeded",
+            fine_tuned_model="ft:gpt-4:custom-model:12345",
+            error=None,
+        )
+        mock_openai = MagicMock()
+        mock_openai.fine_tuning.jobs.retrieve.return_value = mock_openai_job
+        mock_get_openai_client.return_value = mock_openai
+
+        # Action: Refresh the fine-tuning job status
+        response = client.get(
+            f"/api/v1/fine_tuning/{job.id}/refresh", headers=user_api_key_header
+        )
+
+        # Verify response
+        assert response.status_code == 200
+        json_data = response.json()
+        assert json_data["data"]["status"] == "completed"
+
+        # Verify that no evaluation was triggered (since it wasn't newly completed)
+        mock_run_model_evaluation.assert_not_called()
+
+        # Verify no evaluations exist in database for this job
+        evaluations = (
+            db.query(ModelEvaluation)
+            .filter(ModelEvaluation.fine_tuning_id == job.id)
+            .all()
+        )
+        assert len(evaluations) == 0