In [6]:
# from dotenv import load_dotenv
# load_dotenv()
import os
import boto3

from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import Compute, SourceCode, StoppingCondition



# iam = boto3.client('iam')
# role = iam.get_role(RoleName='sagemaker')['Role']['Arn']

In [7]:
from huggingface_hub import HfFolder

# Training container image; framework/CUDA/Python versions are encoded in the tag.
pytorch_image = '763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.8.0-gpu-py312-cu129-ubuntu22.04-sagemaker'

# Script to be run by the job: `eval.sh` inside the local `scripts` directory.
source_code = SourceCode(
    source_dir="scripts",
    entry_script="eval.sh",
)

model_id = 'Qwen/Qwen3-0.6B'
dataset_name = 'Josephgflowers/Finance-Instruct-500k'

# Single source of truth for the MLflow experiment name, so the environment
# variable and the hyperparameter handed to the job cannot drift apart.
mlflow_experiment_name = 'qwen3-06b-lora-ft-finance'

# Environment variables passed to the job.
environment = {
    'HF_TOKEN': HfFolder.get_token(),
    "HF_DATASET": dataset_name,
    "MODEL_ID": model_id,
    # An MLFLOW_TRACKING_URI set in the local environment overrides the default ARN.
    "MLFLOW_TRACKING_URI": os.getenv(
        'MLFLOW_TRACKING_URI',
        'arn:aws:sagemaker:us-east-1:198346569064:mlflow-tracking-server/vlm-finetuning-server',
    ),
    "MLFLOW_EXPERIMENT_NAME": mlflow_experiment_name,
    "NUM_SAMPLES": '10',
}

# Hyperparameters passed to the job.
hyperparameters = {
    "model_id": model_id,
    "adapter_path": 's3://sagemaker-us-east-1-198346569064/qwen3-06b-fine-tuned/',
    'dataset_name': dataset_name,
    'experiment_name': mlflow_experiment_name,
    'run_name': 'fine-tuning-run-1'
}

# Fail fast before a job is configured with unusable settings. Truthiness
# checks are used so that an empty string (e.g. an exported-but-blank
# environment variable or an empty token) is rejected as well as a missing
# value or the "XXX" placeholder.
assert (
    environment["MLFLOW_TRACKING_URI"]
    and environment["MLFLOW_TRACKING_URI"] != "XXX"
), "Please set your MLFLOW_TRACKING_URI in the environment variable"

assert (
    environment["HF_TOKEN"]
), "Please set your HF_TOKEN in the environment variable"

In [8]:
# One ml.g5.2xlarge instance, with keep_alive_period_in_seconds set to one hour.
compute = Compute(
    instance_type="ml.g5.2xlarge",
    instance_count=1,
    keep_alive_period_in_seconds=60 * 60,  # minutes * seconds
    # volume_size_in_gb=96,
)

# Upper bound on job runtime: 10 hours, expressed in seconds.
stopping_condition = StoppingCondition(
    max_runtime_in_seconds=10 * 60 * 60,  # hours * minutes * seconds
)

In [9]:
# Prefix used for the job name.
base_job_name = "mlflow-eval-llmaaj"

# Assemble the ModelTrainer from the image, source, compute, runtime limit,
# environment and hyperparameters configured in the cells above.
model_trainer = ModelTrainer(
    base_job_name=base_job_name,
    training_image=pytorch_image,
    source_code=source_code,
    compute=compute,
    stopping_condition=stopping_condition,
    hyperparameters=hyperparameters,
    environment=environment,
)

In [10]:
# Launch the job configured above.
# NOTE(review): wait=False presumably returns without blocking until the job
# finishes — confirm against the SageMaker SDK documentation.
model_trainer.train(wait=False)

In [11]:
# import mlflow

# def get_run_id_from_name(experiment_name: str, run_name: str) -> str:
#     # Look up the experiment ID
#     experiment = mlflow.get_experiment_by_name(experiment_name)
#     if experiment is None:
#         raise ValueError(f"Experiment '{experiment_name}' not found")

#     # Search runs in that experiment by run_name (stored as tag)
#     runs = mlflow.search_runs(
#         experiment_ids=[experiment.experiment_id],
#         filter_string=f"tags.mlflow.runName = '{run_name}'"
#     )

#     if runs.empty:
#         raise ValueError(f"No run found with name '{run_name}' in experiment '{experiment_name}'")

#     # Take the first match (assuming run_name is unique per experiment)
#     return runs.iloc[0].run_id
# mlflow.set_tracking_uri('arn:aws:sagemaker:us-east-1:198346569064:mlflow-tracking-server/vlm-finetuning-server')
# experiment_name = 'qwen3-06b-lora-ft-finance'
# run_name = 'qwen3-06b-finance'
# run_id = get_run_id_from_name(experiment_name, run_name)
# run_id


In [12]:
import mlflow

# Point the client at the same tracking server the job is given: honour a
# local MLFLOW_TRACKING_URI override and fall back to the default ARN.
mlflow.set_tracking_uri(
    os.getenv(
        'MLFLOW_TRACKING_URI',
        'arn:aws:sagemaker:us-east-1:198346569064:mlflow-tracking-server/vlm-finetuning-server',
    )
)

experiment_name = 'qwen3-06b-lora-ft-finance'
run_name = 'qwen3-06b'

# get_experiment_by_name returns None for an unknown experiment; fail with a
# clear message instead of an AttributeError on `.experiment_id`.
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
    raise ValueError(f"Experiment '{experiment_name}' not found")

# Search runs in that experiment by run_name (stored as tag)
runs = mlflow.search_runs(
    experiment_ids=[experiment.experiment_id],
    filter_string=f"tags.mlflow.runName = '{run_name}'",
)
if runs.empty:
    raise ValueError(
        f"No run found with name '{run_name}' in experiment '{experiment_name}'"
    )

# Take the first match (assuming run_name is unique per experiment)
run_id = runs.iloc[0].run_id
# `id` shadows the builtin; it is kept only as an alias because later cells
# reference it. Prefer `run_id` in new code.
id = run_id
run_id

'02699a4cbfde4ba9b95771a5c1b2ff9e'

In [13]:
# Re-display the run ID bound to `id` by the lookup cell above
# (note: this name shadows the builtin `id`).
id

'02699a4cbfde4ba9b95771a5c1b2ff9e'

In [14]:
# Select the experiment by name; the call returns the Experiment object
# rendered as this cell's output.
mlflow.set_experiment(experiment_name=experiment_name)

<Experiment: artifact_location='s3://sagemaker-us-east-1-198346569064/mlflow-assets/8', creation_time=1759972361966, experiment_id='8', last_update_time=1759972361966, lifecycle_stage='active', name='qwen3-06b-lora-ft-finance', tags={}>

In [15]:

# with mlflow.start_run(run_id=id):
#     with mlflow.start_run(run_name='my-test-run', nested=True):
#         mlflow.log_param('fruit','mango')