In [None]:
import os
import pickle
import joblib
from abc import ABC, abstractmethod

class BaseFormatHandler(ABC):
    """Interface that every serialization-format handler implements.

    Both ``save`` and ``load`` are abstract, so the class cannot be
    instantiated until a subclass overrides the two of them.
    """

    @abstractmethod
    def save(self, obj, path, **kwargs):
        """Write ``obj`` to ``path``.

        Args:
            obj: Object to serialize.
            path: Destination file path.
            **kwargs: Format-specific options defined by the subclass.
        """

    @abstractmethod
    def load(self, path, **kwargs):
        """Read back the object stored at ``path``.

        Args:
            path: Source file path.
            **kwargs: Format-specific options defined by the subclass.
        """

class PickleHandler(BaseFormatHandler):
    """Handler that persists objects with the standard-library ``pickle``."""

    def save(self, obj, path, **kwargs):
        """Pickle ``obj`` into the file at ``path``.

        Args:
            obj: Object to serialize.
            path: Destination file; opened in binary write mode, so an
                existing file is overwritten.
            **kwargs: Forwarded unchanged to ``pickle.dump``.
        """
        with open(path, 'wb') as sink:
            pickle.dump(obj, sink, **kwargs)

    def load(self, path, **kwargs):
        """Unpickle and return the object stored at ``path``.

        Args:
            path: Source file, opened in binary read mode.
            **kwargs: Forwarded unchanged to ``pickle.load``.

        Returns:
            The deserialized object.
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as source:
            restored = pickle.load(source, **kwargs)
        return restored

class JoblibHandler(BaseFormatHandler):
    """Handler that persists objects through ``joblib``."""

    def save(self, obj, path, **kwargs):
        """Dump ``obj`` to ``path`` with ``joblib.dump``.

        Args:
            obj: Object to serialize.
            path: Destination passed straight to ``joblib.dump``.
            **kwargs: Forwarded unchanged to ``joblib.dump``.
        """
        # The value returned by joblib.dump is intentionally not propagated.
        joblib.dump(obj, path, **kwargs)

    def load(self, path, **kwargs):
        """Return the object that ``joblib.load`` reads from ``path``.

        Args:
            path: Source passed straight to ``joblib.load``.
            **kwargs: Forwarded unchanged to ``joblib.load``.
        """
        restored = joblib.load(path, **kwargs)
        return restored

class OnnxHandler(BaseFormatHandler):
    """Handler for ONNX files.

    The ``onnx`` package is imported lazily, so it is only required when
    this handler is actually used.
    """

    @staticmethod
    def _require_onnx():
        """Import and return the ``onnx`` module.

        Raises:
            ImportError: If the ``onnx`` package cannot be imported. The
                original import failure is attached as the cause.
        """
        try:
            import onnx
        except ImportError as exc:
            raise ImportError("onnx package is required for ONNX format.") from exc

        return onnx

    def save(self, obj, path, **kwargs):
        """Write ``obj`` to ``path`` with ``onnx.save``.

        Args:
            obj: Object handed to ``onnx.save``.
            path: Destination passed straight to ``onnx.save``.
            **kwargs: Forwarded unchanged to ``onnx.save``.

        Raises:
            ImportError: If the ``onnx`` package is not installed.
        """
        onnx = self._require_onnx()
        onnx.save(obj, path, **kwargs)

    def load(self, path, **kwargs):
        """Return whatever ``onnx.load`` reads from ``path``.

        Args:
            path: Source passed straight to ``onnx.load``.
            **kwargs: Forwarded unchanged to ``onnx.load``.

        Raises:
            ImportError: If the ``onnx`` package is not installed.
        """
        onnx = self._require_onnx()
        return onnx.load(path, **kwargs)

class PmmlHandler(BaseFormatHandler):
    """Placeholder for PMML support; neither operation is implemented."""

    def save(self, obj, path, **kwargs):
        """Reject the call, because PMML export is not available.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError("PMML saving not implemented yet.")

    def load(self, path, **kwargs):
        """Reject the call, because PMML import is not available.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError("PMML loading not implemented yet.")

In [None]:
import os
import glob
from datetime import datetime

class PersistenceManager:
    def __init__(
        self,
        save_dir='./models',
        default_format='joblib',
        checkpoint_enabled=False,
        checkpoint_step=1,
        checkpoint_keep_last=5,
    ):
        """Configure storage and checkpointing, and create ``save_dir``.

        Args:
            save_dir: Directory that relative paths are resolved against;
                created immediately if it does not exist.
            default_format: Format name used when a call does not name one.
            checkpoint_enabled: Whether ``save_checkpoint`` writes anything.
            checkpoint_step: Stored on the instance; no method in this
                class reads it.
            checkpoint_keep_last: Number of checkpoints retained when old
                ones are pruned.
        """
        # Storage location and fallback format.
        self.save_dir = save_dir
        self.default_format = default_format

        # Checkpoint settings.
        self.checkpoint_enabled = checkpoint_enabled
        self.checkpoint_step = checkpoint_step
        self.checkpoint_keep_last = checkpoint_keep_last

        os.makedirs(save_dir, exist_ok=True)

        # Registry mapping each format name to its handler instance.
        self._formats = dict(
            pickle=PickleHandler(),
            joblib=JoblibHandler(),
            onnx=OnnxHandler(),
            pmml=PmmlHandler(),
        )

        # One record per checkpoint written through this instance.
        self._checkpoint_history = []

    def _get_handler(self, format):
        if format not in self._formats:
            raise ValueError(f"Unsupported format: {format}. Supported: {list(self._formats.keys())}")

        return self._formats[format]

    def save_model(self, model, path, format=None, **kwargs):
        format = format or self.default_format
        handler = self._get_handler(format)

        full_path = os.path.join(self.save_dir, path) if not os.path.isabs(path) else path
        os.makedirs(os.path.dirname(full_path), exist_ok=True)

        handler.save(model, full_path, **kwargs)
        print(f"Model saved to {full_path} (format: {format})")
    
    def load_model(self, path, format=None, **kwargs):
        if format is None:
            ext = os.path.splitext(path)[1].lstrip('.')
            if ext in self._formats:
                format = ext

            else:
                format = self.default_format

        handler = self._get_handler(format)
        full_path = os.path.join(self.save_dir, path) if not os.path.isabs(path) else path
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File not found: {full_path}")

        return handler.load(full_path, **kwargs)

    def save_preprocessor(self, preprocessor, path, format=None, **kwargs):
        self.save_model(preprocessor, path, format, **kwargs)

    def load_preprocessor(self, path, format=None, **kwargs):
        return self.load_model(path, format, **kwargs)

    def save_checkpoint(self, step_name, objects_dict, format=None, **kwargs):
        if not self.checkpoint_enabled:
            return

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"checkpoint_{step_name}_{timestamp}.pkl"
        full_path = os.path.join(self.save_dir, 'checkpoints', filename)
        os.makedirs(os.path.dirname(full_path), exist_ok=True)

        with open(full_path, 'wb') as f:
            pickle.dump(objects_dict, f)

        self._checkpoint_history.append({
            'step_name': step_name,
            'path': full_path,
            'timestamp': timestamp,
            'objects': list(objects_dict.keys())
        })

        self._clean_old_checkpoints()

    def load_checkpoint(self, step_name=None):
        if not self._checkpoint_history:
            raise RuntimeError("No checkpoint found.")

        if step_name is None:
            checkpoint = self._checkpoint_history[-1]

        else:
            matching = [c for c in self._checkpoint_history if c['step_name'] == step_name]
            if not matching:
                raise ValueError(f"No checkpoint found for step '{step_name}'.")

            checkpoint = matching[-1]

        with open(checkpoint['path'], 'rb') as f:
            return pickle.load(f)

    def _clean_old_checkpoints(self):
        if len(self._checkpoint_history) <= self.checkpoint_keep_last:
            return

        sorted_ck = sorted(self._checkpoint_history, key=lambda x: x['timestamp'])
        to_remove = sorted_ck[:-self.checkpoint_keep_last]

        for ck in to_remove:
            try:
                os.remove(ck['path'])
                self._checkpoint_history.remove(ck)

            except OSError:
                pass

    def set_checkpoint_config(self, enabled=None, step=None, keep_last=None):
        if enabled is not None:
            self.checkpoint_enabled = enabled

        if step is not None:
            self.checkpoint_step = step

        if keep_last is not None:
            self.checkpoint_keep_last = keep_last