In [1]:
import sys

# Put the parent directory on the import path (presumably so the `src`
# package imported in the next cell resolves -- adjust the path to match
# your folder layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')


In [2]:
import os
from dataclasses import dataclass

from omegaconf import DictConfig

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

from src.models.classification import DocumentClassifier
from src.dataloader.classification import ClassificationDataModule
from src.callbacks.mlflow_callback import MLFlowModelRegistryCallback

In [25]:
# ~~~ Configuration ~~~
@dataclass
class Config:
    """Tunable settings for the document-classifier training run.

    Environment-dependent values are read from environment variables when
    the class is defined (re-run this cell after changing them):

    * ``DOC_CLASSIFIER_DATA_DIR`` overrides ``data_dir``.
    * ``MLFLOW_TRACKING_URI`` overrides ``logger_uri``.

    Never embed credentials in ``logger_uri``. For an authenticated MLflow
    server, export ``MLFLOW_TRACKING_USERNAME`` / ``MLFLOW_TRACKING_PASSWORD``
    instead; the MLflow client reads them itself.
    """
    # Dataset root; falls back to the path this notebook was last run with.
    data_dir: str = os.environ.get(
        "DOC_CLASSIFIER_DATA_DIR",
        "/Users/jeff/Dev/MSE_Projects/PI/monorepo/doc-analyzer/model-experiment/data",
    )
    batch_size: int = 16
    num_workers: int = 4
    shuffle: bool = True
    experiment_name: str = "doc-classifier-v1.0"
    # Declared exactly once: the previous duplicate declaration was silently
    # overridden by the local file store, which remains the default here.
    # NOTE(review): the credentials that were hardcoded in the old URI have
    # been committed and should be rotated.
    logger_uri: str = os.environ.get("MLFLOW_TRACKING_URI", "file:./mlruns")
    max_epochs: int = 3
    upsample: bool = True
    # Caps the number of training batches per epoch (forwarded to the Trainer).
    limit_train_batches: int = 5

config = Config()

In [26]:
# ~~~ Data Preparation ~~~
# Build the data module from the shared config (arguments are positional:
# data dir, batch size, worker count, shuffle flag, upsample flag), then
# run its setup step right away.
data_module = ClassificationDataModule(
    config.data_dir,
    config.batch_size,
    config.num_workers,
    config.shuffle,
    config.upsample,
)
data_module.setup()

In [5]:
#data_module.plot_label_distribution()

In [27]:
# ~~~ Model Initialization ~~~
# Create the classifier with its default constructor arguments.
model = DocumentClassifier()



In [28]:
# ~~~ Experiment Tracking ~~~
# MLflow logger pointed at the tracking store and experiment name from the
# config. `log_model=True` asks the logger to also record model checkpoints
# (see the Lightning MLFlowLogger documentation for the exact semantics).
mlf_logger = MLFlowLogger(
    tracking_uri=config.logger_uri,
    experiment_name=config.experiment_name,
    log_model=True,
)

In [29]:
# ~~~ Training ~~~
# Deliberately short run: the number of training batches per epoch and the
# number of epochs are both capped by the config, and metrics go to the
# MLflow logger.
trainer = Trainer(
    limit_train_batches=config.limit_train_batches,
    max_epochs=config.max_epochs,
    logger=mlf_logger,
    fast_dev_run=False,
    #callbacks=[MLFlowModelRegistryCallback(config.experiment_name)],
)
# Hand the data module over through the dedicated `datamodule=` keyword
# (not `train_dataloaders=`), matching how `trainer.test` is called.
trainer.fit(
    model,
    datamodule=data_module,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name           | Type                  | Params
---------------------------------------------------------
0 | model          | ResNet                | 11.2 M
1 | train_accuracy | BinaryAccuracy        | 0     
2 | val_accuracy   | BinaryAccuracy        | 0     
3 | test_accuracy  | BinaryAccuracy        | 0     
4 | conf_matrix    | BinaryConfusionMatrix | 0     
---------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.708    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/jeff/Dev/MSE_Projects/PI/monorepo/doc-analyzer/model-experiment/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Pl

Training: |          | 0/? [00:00<?, ?it/s]

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


Validation: |          | 0/? [00:00<?, ?it/s]

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to di

Validation: |          | 0/? [00:00<?, ?it/s]

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to di

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [30]:
# Evaluate the trained model on the data module's test split.
# NOTE(review): the recorded output of this cell ends with
# "TypeError: MlflowClient.log_artifact() missing 1 required positional
# argument: 'local_path'". Presumably a call site in the model or a callback
# passes only the file path and omits the run id -- confirm in `src/`.
trainer.test(
    model,
    datamodule=data_module,
)

/Users/jeff/Dev/MSE_Projects/PI/monorepo/doc-analyzer/model-experiment/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. P

Testing: |          | 0/? [00:00<?, ?it/s]

TypeError: MlflowClient.log_artifact() missing 1 required positional argument: 'local_path'