In [1]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from omegaconf import OmegaConf
import mlflow
from mlflow.tracking import MlflowClient
from cyprus_fish.data import CyprusFishDataset
from cyprus_fish.utils import (
    compute_metrics,
    CalculateTrainMetricsCallback,
    extract_clean_history,
    get_best_run_by_parameter
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Notebook parameters -------------------------------------------------
# MLflow experiment that holds the training runs.
experiment_name = "Cyprus-Fish-Recognition"
# Model whose runs are looked up in that experiment.
model_name = "convnext-tiny-224"
# Column of the MLflow runs table passed to the best-run helper.
metric_name = "metrics.test_accuracy"

In [3]:
# Look up the MLflow experiment by its display name.
client = MlflowClient()
experiment = client.get_experiment_by_name(experiment_name)

# get_experiment_by_name returns None for an unknown name; stop here with a
# clear message rather than hitting an AttributeError on
# `experiment.experiment_id` in a later cell.
if experiment is None:
    raise ValueError(f"MLflow experiment '{experiment_name}' not found.")

print(experiment)

<Experiment: artifact_location='mlflow-artifacts:/117ed76dfb5f4a3e98207858ebeee4b7', creation_time=1768081917560, experiment_id='0', last_update_time=1768081917560, lifecycle_stage='active', name='Cyprus-Fish-Recognition', tags={'mlflow.experimentKind': 'custom_model_development'}>


In [6]:
# Conditions are AND-ed together: only finished runs, and only runs whose
# `model_name` param equals the model selected above.
filter_parts = [
    "status = 'FINISHED'",
    f"params.model_name = '{model_name}'",
]
filter_string = " AND ".join(filter_parts)

# Query the experiment; the result is bound to `runs` for inspection below.
runs = mlflow.search_runs(
    filter_string=filter_string,
    experiment_ids=[experiment.experiment_id],
)

In [7]:
# Preview only the first rows: the runs table grows with every logged run and
# dumping it whole bloats the notebook output.
runs.head()

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.train_loss,metrics.test_samples_per_second,metrics.test_runtime,metrics.train_steps_per_second,...,params.report_to,params.train,params.logging_strategy,params.gradient_checkpointing,params.model,params.no_cuda,tags.mlflow.user,tags.mlflow.source.name,tags.mlflow.runName,tags.mlflow.source.type
0,b0d6a7b2bc2848adac7918ffd4cfad0a,0,FINISHED,mlflow-artifacts:/117ed76dfb5f4a3e98207858ebee...,2026-01-12 04:12:21.777000+00:00,2026-01-12 04:12:51.959000+00:00,0.671482,23.407,2.6488,0.157,...,['mlflow'],"{'k_folds': 5, 'batch_size': 32, 'grad_acc': 1...",epoch,False,"{'name': 'convnext-tiny-224', 'hf_repo_id': 'f...",False,jayray5,/home/rbencharef/miniconda3/envs/cyprus-fish-e...,convnext-tiny-224_Global_train_Ep1_BS32_LR0.00...,LOCAL


In [None]:
# Look up the MLflow experiment by its display name.
client = MlflowClient()
experiment = client.get_experiment_by_name(experiment_name)

# get_experiment_by_name returns None for an unknown name; fail fast instead
# of raising an AttributeError on `experiment.experiment_id` below.
if experiment is None:
    raise ValueError(f"MLflow experiment '{experiment_name}' not found.")
print(experiment)

# Keep only finished runs logged for the selected model.
# NOTE(review): the param key is `model_name`. The runs table in this notebook
# shows the nested config logged as a single stringified `params.model` value,
# and the `params.model_name` filter is the one that matched a run, so
# `params.model.name` would match nothing — confirm against the logging code.
filter_parts = [
    "status = 'FINISHED'",
    f"params.model_name = '{model_name}'",
]
filter_string = " AND ".join(filter_parts)

# Search for runs
runs = mlflow.search_runs(
    experiment_ids=[experiment.experiment_id],
    filter_string=filter_string,
)

In [4]:
# Ask the project helper for the best run of `model_name` in the experiment
# (presumably ranked by `metric_name` — see cyprus_fish.utils to confirm).
# The returned value is displayed as the cell output.
get_best_run_by_parameter(
    experiment_name,
    metric_name,
    model_name,
)

[INFO] start pulling ml flow runs
<Experiment: artifact_location='mlflow-artifacts:/117ed76dfb5f4a3e98207858ebeee4b7', creation_time=1768081917560, experiment_id='0', last_update_time=1768081917560, lifecycle_stage='active', name='Cyprus-Fish-Recognition', tags={'mlflow.experimentKind': 'custom_model_development'}>
Run not found for the model 'convnext-tiny-224'.


(None, None, None)

In [19]:
# Minimal in-notebook configuration with three sections (data / model / train),
# wrapped in an OmegaConf object so values are read with attribute access
# (e.g. `cfg.train.lr`) in the cells below.
cfg = OmegaConf.create(
    {
        # Dataset location, pinned revision and label set.
        "data": {
            "repo_id": "JayRay5/cyprus-fish-dataset",
            "repo_revision": "be9dbe8f4048fe5c399a71713fcb8d0cfdd37ae5",
            "num_classes": 3,
            "class_names": ["fish_A", "fish_B", "fish_C"],
        },
        # Pretrained checkpoint (pinned revision) and output targets.
        "model": {
            "hf_repo_id": "facebook/convnext-tiny-224",
            "revision": "6166b7613034066690a621d8bf25ffdf181a34f0",
            "training_output_dir": "convnext-cyprus-fish-cls",
            "target_hf_repo_id": "fake2/convnext",
        },
        # Small training settings.
        "train": {
            "k_folds": 2,
            "batch_size": 2,
            "grad_acc": 1,
            "epochs": 1,
            "lr": 3e-4,
            "warmup_steps": 0,
            "weight_decay": 0.01,
            "scheduler": "constant",
            "num_workers": 2,
            "seed": 42,
            "device": "cpu",
            "fp16": False,
            "freeze_backbone": True,
            "push_to_hub": False,
        },
    }
)

In [20]:
# Build the training split of the project dataset. Both the dataset repo and
# the model repo are passed together with their pinned revisions from `cfg`.
dataset = CyprusFishDataset(
    split="train",
    num_classes=cfg.data.num_classes,
    repo_id=cfg.data.repo_id,
    repo_revision=cfg.data.repo_revision,
    model_name=cfg.model.hf_repo_id,
    model_revision=cfg.model.revision,
)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [12]:
# Load the pretrained backbone with a fresh classification head sized for
# this dataset.
model = AutoModelForImageClassification.from_pretrained(
    cfg.model.hf_repo_id,
    # Pin the checkpoint to the same revision the dataset cell passes as
    # `model_revision`, so the weights do not drift with the repo's default
    # branch.
    revision=cfg.model.revision,
    num_labels=cfg.data.num_classes,
    # The checkpoint's classifier has a different number of outputs than
    # `num_labels`; allow that layer to be re-initialised instead of raising.
    ignore_mismatched_sizes=True,
)

Some weights of ConvNextForImageClassification were not initialized from the model checkpoint at facebook/convnext-tiny-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Training hyper-parameters, taken from `cfg` wherever the config declares a
# value so the run matches the configuration shown above.
training_args = TrainingArguments(
    # Write checkpoints/logs to the configured directory rather than the
    # notebook's working directory.
    output_dir=cfg.model.training_output_dir,
    logging_strategy="epoch",
    eval_strategy="no",
    # Keep every dataset column; the Trainer must not drop inputs it cannot
    # match to the model's forward signature.
    remove_unused_columns=False,
    learning_rate=cfg.train.lr,
    lr_scheduler_type=cfg.train.scheduler,
    weight_decay=cfg.train.weight_decay,
    warmup_steps=cfg.train.warmup_steps,
    per_device_train_batch_size=cfg.train.batch_size,
    per_device_eval_batch_size=cfg.train.batch_size,
    gradient_accumulation_steps=cfg.train.grad_acc,
    num_train_epochs=cfg.train.epochs,
    dataloader_num_workers=cfg.train.num_workers,
    seed=cfg.train.seed,
    fp16=cfg.train.fp16,
    push_to_hub=cfg.train.push_to_hub,
    report_to=["mlflow"],
)

In [24]:
# Assemble the Hugging Face Trainer from the objects built in the cells above.
# Only a training dataset is supplied here (no eval dataset).
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    compute_metrics=compute_metrics,
)


In [29]:
# Inspect the evaluation strategy the Trainer ended up with; `.value` unwraps
# the enum member to its plain string form.
trainer.args.eval_strategy.value

'no'