In [None]:
import os
from datasets import load_dataset
from segment.create_dataset import CreateSegmentationDataset
from segment.utils import load_resize_image
import random
from datasets import Dataset
from huggingface_hub import create_repo

In [None]:
# Folder holding the validation images; it is read with the "imagefolder" loader.
image_dir = "datasets/fashion_people_detection/images/val"

# Text prompt passed to the processing pipeline further down.
text_prompt = ["face", "glasses", "clothes"]

# Load the folder, shuffle it (no seed is given, so the order differs between
# runs), and keep only two examples.
ds = load_dataset("imagefolder", data_dir=image_dir, split="train")
ds = ds.shuffle()
ds = ds.take(2)

In [None]:
from typing import List, Dict, Any, Union, Set
from datasets import Dataset
from segment.components.segment.SegmentSam import SegmentSam
from segment.components.detect.DetectDino import DetectDino
from segment.components.base import Component
from segment.format_results import format_all_results
from PIL import Image
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from PIL import Image


class DataManager:
    """Loads datasets and pushes a dataset to the Hugging Face Hub.

    The dataset most recently returned by ``load_dataset`` is remembered on
    ``self.ds`` so that ``push_to_hub`` has something to upload by default.
    """

    def __init__(self):
        # Last dataset handed out by load_dataset(); None until one is loaded.
        self.ds = None

    def load_dataset(self, dataset: Union[str, Dataset], split="train") -> Dataset:
        """Return a dataset, loading it first when a string is given.

        Args:
            dataset: Either a string forwarded to the module-level
                ``load_dataset`` function, or an existing dataset object that
                is returned unchanged.
            split: Split requested when ``dataset`` is a string.

        Returns:
            The loaded or passed-through dataset. It is also stored on
            ``self.ds``.
        """
        if isinstance(dataset, str):
            print(f"Loading dataset: {dataset}")
            # Inside the method body this name resolves to the imported
            # module-level function, not to this method.
            dataset = load_dataset(dataset, split=split)

        self.ds = dataset
        return dataset

    def push_to_hub(
        self,
        repo_id,
        token,
        commit_message="md",
        private=True,
        dataset: Union[Dataset, None] = None,
    ):
        """Create the dataset repo if needed and push a dataset to it.

        Args:
            repo_id: Target repository id on the Hub.
            token: Access token passed to both the repo creation and the push.
            commit_message: Commit message used for the push.
            private: Whether the repo is created as private.
            dataset: Dataset to push. When omitted, the dataset remembered by
                the last ``load_dataset`` call is used.

        Raises:
            ValueError: If no dataset was passed and none has been loaded.
        """
        # getattr keeps this safe for instances created without __init__.
        target = dataset if dataset is not None else getattr(self, "ds", None)
        if target is None:
            # Fail before touching the Hub so no empty repo is left behind.
            raise ValueError(
                "No dataset to push: call load_dataset() first or pass dataset=..."
            )

        create_repo(
            repo_id=repo_id,
            repo_type="dataset",
            exist_ok=True,
            private=private,
            token=token,
        )

        target.push_to_hub(repo_id, commit_message=commit_message, token=token)

        print(f"Pushed Dataset to Hub: {repo_id}")


class TrainingManager:
    """Stand-in for model training; for now it only reports what it was given."""

    def train_model(self, dataset: Dataset, model_config: Dict[str, Any]):
        """Print the training configuration and the dataset.

        No training is performed yet; nothing is returned.
        """
        report = (
            f"Training model with config: {model_config}",
            f"Using dataset: {dataset}",
        )
        for line in report:
            print(line)


class ComponentManager:
    """Registry of components plus the ordered pipeline that runs them.

    Components are stored under their ``name`` attribute. A pipeline is a list
    of those names; ``process`` feeds the output of each component into the
    next one.
    """

    def __init__(self):
        # name -> registered component
        self.components: Dict[str, Component] = {}
        # Component names in execution order.
        self.pipeline: List[str] = []
        # Names whose load_model() has been called and not yet undone.
        self.loaded_components: Set[str] = set()

    def register_component(self, component: Component):
        """Store ``component`` under ``component.name``, replacing any previous entry."""
        self.components[component.name] = component

    def get_component(self, name: str) -> Component:
        """Return the component registered as ``name``, or ``None`` if there is none."""
        return self.components.get(name)

    def set_pipeline(self, pipeline: List[str]):
        """Set the ordered list of component names to run."""
        self.pipeline = pipeline

    def validate_pipeline(self) -> bool:
        """Check that the pipeline can run.

        The pipeline is valid when every name is registered and, for each pair
        of adjacent components, the earlier one's ``output_keys`` contain all
        of the later one's ``input_requirements``. A message is printed for
        the first problem found.

        Returns:
            True if the pipeline is valid, False otherwise.
        """
        # Unregistered names would otherwise surface as AttributeError on None,
        # and a single-entry pipeline would never be looked at by the pair check.
        missing = [name for name in self.pipeline if name not in self.components]
        if missing:
            print(f"Invalid pipeline: unregistered component(s): {missing}")
            return False

        stages = [self.components[name] for name in self.pipeline]
        for current_component, next_component in zip(stages, stages[1:]):
            if not all(
                key in current_component.output_keys
                for key in next_component.input_requirements
            ):
                print(
                    f"Invalid pipeline: {current_component.name} does not produce all required inputs for {next_component.name}"
                )
                return False
        return True

    def load_models(self):
        """
        Loads the models for all components in the pipeline.

        Components already recorded as loaded are skipped.
        """
        for component_name in self.pipeline:
            if component_name not in self.loaded_components:
                component = self.get_component(component_name)
                component.load_model()
                # Record only after a successful load.
                self.loaded_components.add(component_name)

    def unload_models(self):
        """
        Unloads the models for all loaded components.
        """
        for component_name in self.loaded_components:
            component = self.get_component(component_name)
            component.unload_model()
        self.loaded_components.clear()

    def process(self, initial_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``initial_data`` through every component of the pipeline.

        Args:
            initial_data: Input handed to the first component.

        Returns:
            The value returned by the last component's ``process`` (or
            ``initial_data`` itself when the pipeline is empty).

        Raises:
            ValueError: If ``validate_pipeline`` reports an invalid pipeline.
        """
        if not self.validate_pipeline():
            raise ValueError("Invalid pipeline configuration")

        try:
            # Loading happens inside the try so that models loaded before a
            # later load failure are still unloaded by the finally clause.
            self.load_models()

            data = initial_data
            for component_name in self.pipeline:
                component = self.get_component(component_name)
                data = component.process(data)
            return data
        finally:
            self.unload_models()  # Ensure models are unloaded even if an exception occurs


class ImagenHeap:
    """Facade bundling a data manager, a component manager and a training manager."""

    def __init__(self):
        self.data_manager = DataManager()
        self.component_manager = ComponentManager()
        self.training_manager = TrainingManager()

    def load_dataset(self, ds: Union[str, Dataset]) -> Dataset:
        """Delegate to the data manager's ``load_dataset`` and return its result."""
        loaded = self.data_manager.load_dataset(ds)
        return loaded

    def process_dataset(
        self, dataset: Dataset, text_prompt: Union[str, List[str]], **kwargs
    ) -> Dataset:
        """Run the dataset's images through the component pipeline.

        The pipeline input holds ``dataset['image']`` under ``"images"`` and
        the prompt under ``"text_prompt"``; any extra keyword arguments are
        merged in afterwards and override those two keys on a name clash.

        Returns:
            Whatever the component manager's ``process`` returns.
        """
        pipeline_input = dict(images=dataset["image"], text_prompt=text_prompt)
        pipeline_input.update(kwargs)
        return self.component_manager.process(pipeline_input)

    def train_model(self, dataset: Dataset, model_config: Dict[str, Any]):
        """Delegate to the training manager's ``train_model``; returns nothing."""
        self.training_manager.train_model(dataset, model_config)


imagen_heap = ImagenHeap()

# Register one instance of each component class; each instance is stored
# under its own `name` attribute.
for component_cls in (DetectDino, SegmentSam):
    imagen_heap.component_manager.register_component(component_cls())

# Pipeline entries are component names, in execution order.
# NOTE(review): presumably these match DetectDino().name / SegmentSam().name — confirm.
imagen_heap.component_manager.set_pipeline(["detect","segment"])

# `ds` may be a dataset name or an existing dataset object.
dataset = imagen_heap.load_dataset(ds)

# Everything after `text_prompt` is forwarded as-is into the pipeline input.
processed_dataset = imagen_heap.process_dataset(
    dataset,
    text_prompt=text_prompt,
    max_image_side=1024,
    box_threshold=0.3,
    text_threshold=0.25,
    iou_threshold=0.8,
    return_tensors=True,
)

In [None]:
# Format the pipeline output, then keep only the value stored under
# 'polygons' for each formatted entry (pop also removes it from that entry).
# NOTE(review): polygons=False is passed, yet 'polygons' is read back here —
# confirm the key is always present.
results = [
    entry.pop('polygons')
    for entry in format_all_results(processed_dataset, polygons=False)
]
results

In [None]:
# Display the first element of the first entry in `results`.
results[0][0]