In [7]:
import json
import sqlite3
import cloudpickle
import pickle
import collections
import ray
import threading
import hmac

def serialize_sqlite_connection(conn):
    return ray.data.datasource

def deserialize_sqlite_connection(path):
    return sqlite3.connect(path)

def serialize_thread_lock(lock):
    print("serialized_thread_lock")
    return None

def deserialize_thread_lock(_):
    return None

class SerializableHMAC:
    def __init__(self, hmac_obj):
        self.msg = hmac_obj.digest()
        self.digestmod = hmac_obj.digest_size

    @classmethod
    def from_hmac(cls, hmac_obj):
        return cls(hmac_obj)

    def to_hmac(self):
        return hmac.new(b'', self.msg, digestmod=self.digestmod)

def serialize_hmac(hmac_obj):
    return SerializableHMAC.from_hmac(hmac_obj)

def deserialize_hmac(serialized_hmac):
    return serialized_hmac.to_hmac()

class HandleWrapper:
    def __init__(self, obj):
        self.class_name = type(obj).__name__
        self.attributes = {}
        for key, value in obj.__dict__.items():
            if not key.startswith('_'):
                try:
                    cloudpickle.dumps(value)
                    self.attributes[key] = value
                except:
                    self.attributes[key] = f"Unpicklable_{type(value).__name__}"

    @classmethod
    def from_object(cls, obj):
        return cls(obj)

    def to_object(self):
        # This is a placeholder. You might need to implement proper reconstruction logic.
        return type(self.class_name, (), self.attributes)()

def serialize_handle_object(obj):
    return HandleWrapper.from_object(obj)

def deserialize_handle_object(wrapped):
    return wrapped.to_object()


# # Register the custom serializers with Ray
# ray.util.register_serializer(
#     object,  # This will catch all objects
#     serializer=lambda obj: serialize_handle_object(obj) if hasattr(obj, 'handle') or not cloudpickle.is_picklable(obj) else obj,
#     deserializer=lambda obj: deserialize_handle_object(obj) if isinstance(obj, HandleWrapper) else obj
# )

ray.util.register_serializer(sqlite3.Connection, serializer=serialize_sqlite_connection, deserializer=deserialize_sqlite_connection)
ray.util.register_serializer(type(threading.Lock), serializer=serialize_thread_lock, deserializer=deserialize_thread_lock)
ray.util.register_serializer(hmac.HMAC, serializer=serialize_hmac, deserializer=deserialize_hmac)

# Initialize Ray
if not ray.is_initialized():
    ray.init()

In [8]:
import ray
import mlflow
import pandas as pd
from pydantic import BaseModel
from typing import Dict, Any, List
import uuid
import json

# Initialize Ray
if not ray.is_initialized():
    ray.init()

# Define models (not used directly with Ray)
class DataObject(BaseModel):
    content: Dict[str, Any]
    metadata: Dict[str, Any] = {}
    vector: List[float] = None

class Artifact(BaseModel):
    artifact_id: str
    payload: bytes
    metadata: Dict[str, Any] = {}

class Agent(BaseModel):
    agent_id: str

    def process(self, *args, **kwargs):
        raise NotImplementedError("Subclasses must implement this method")

class DocumentAgent(Agent):
    def process(self, data_object: Dict[str, Any]):
        data_object['metadata']['length'] = len(str(data_object['content']))
        data_object['metadata']['type'] = type(data_object['content']).__name__
        return data_object

class ArtifactAgent(Agent):
    def process(self, artifact: Dict[str, Any]):
        print(f"Saving artifact {artifact['artifact_id']} to storage")
        return artifact

class GraphAgent(Agent):
    def process(self, relations: List[Dict[str, str]]):
        return pd.DataFrame(relations).to_dict()

class HTNGenerator(Agent):
    def process(self, data_objects: List[Dict], artifacts: List[Dict], graph: Dict):
        return {
            "root": "process_data",
            "subtasks": [
                {"task": "process_documents", "objects": data_objects},
                {"task": "process_artifacts", "objects": artifacts},
                {"task": "analyze_graph", "relations": graph}
            ]
        }

# Ray remote functions
@ray.remote
def process_document(data_object: Dict[str, Any]):
    agent = DocumentAgent(agent_id="doc_agent").model_dump()
    return agent.process(data_object)

@ray.remote
def process_artifact(artifact: Dict[str, Any]):
    agent = ArtifactAgent(agent_id="artifact_agent").model_dump()
    return agent.process(artifact)

@ray.remote
def process_graph(relations: List[Dict[str, str]]):
    agent = GraphAgent(agent_id="graph_agent")
    return agent.process(relations)

@ray.remote
def generate_htn(data_objects: List[Dict], artifacts: List[Dict], graph: Dict):
    agent = HTNGenerator(agent_id="htn_generator")
    return agent.process(data_objects, artifacts, graph)

@ray.remote
def htn_workflow(data_objects: List[Dict], artifacts: List[Dict], relations: List[Dict[str, str]]):
    processed_docs = ray.get([process_document.remote(obj) for obj in data_objects])
    processed_artifacts = ray.get([process_artifact.remote(art) for art in artifacts])
    graph = ray.get(process_graph.remote(relations))
    htn = ray.get(generate_htn.remote(processed_docs, processed_artifacts, graph))
    return htn, processed_docs, processed_artifacts, graph

# Example usage
if __name__ == "__main__":
    # Create sample data (as dictionaries)
    data_objects = [
        {"content": {"text": "Sample document 1"}},
        {"content": {"text": "Sample document 2"}}
    ]
    artifacts = [
        {"artifact_id": str(uuid.uuid4()), "payload": b"Sample binary data 1"},
        {"artifact_id": str(uuid.uuid4()), "payload": b"Sample binary data 2"}
    ]
    relations = [
        {"subject": "doc1", "predicate": "relates_to", "object": "art1"},
        {"subject": "doc2", "predicate": "relates_to", "object": "art2"}
    ]

    # Run the workflow
    # mlflow.set_experiment("HTN_Agent_System")
    # with mlflow.start_run(run_name="HTN_Workflow"):
    htn, processed_docs, processed_artifacts, graph = ray.get(htn_workflow.remote(data_objects, artifacts, relations))
        
        # Log metrics and parameters
        # mlflow.log_metric("num_documents", len(processed_docs))
        # mlflow.log_metric("num_artifacts", len(processed_artifacts))
        # mlflow.log_metric("num_relations", len(graph))
        
        # for i, doc in enumerate(processed_docs):
        #     mlflow.log_metric(f"doc_{i}_length", doc['metadata']['length'])
        
        # for i, art in enumerate(processed_artifacts):
        #     mlflow.log_param(f"artifact_{i}_id", art['artifact_id'])
        
        # mlflow.log_dict(htn, "htn.json")

    print("Generated HTN:", htn)

    # Query MLflow for results
    print("\nMLflow Runs:")
    for run in mlflow.search_runs():
        print(f"Run ID: {run.run_id}")
        print(f"Metrics: {run.data.metrics}")
        print(f"Params: {run.data.params}")
        print("---")

serialized_thread_lock
serialized_thread_lock
serialized_thread_lock
serialized_thread_lock


TypeError: Could not serialize the function 1909578674.htn_workflow:
=====================================================================
Checking Serializability of <function htn_workflow at 0x7f5e35edc8b0>
=====================================================================
[31m!!! FAIL[39m serialization: cannot pickle '_thread.lock' object
Detected 3 global variables. Checking serializability...
    Serializing 'ray' <module 'ray' from '/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/ray/__init__.py'>...
    Serializing 'process_graph' <ray.remote_function.RemoteFunction object at 0x7f5e36828700>...
    [31m!!! FAIL[39m serialization: cannot pickle '_thread.lock' object
        Serializing '_function' <function process_graph at 0x7f5e35edd2d0>...
        [31m!!! FAIL[39m serialization: cannot pickle '_thread.lock' object
        Detected 1 global variables. Checking serializability...
            Serializing 'GraphAgent' <class '__main__.GraphAgent'>...
            [31m!!! FAIL[39m serialization: cannot pickle '_thread.lock' object
        Serializing '__generator_backpressure_num_objects' None...
=====================================================================
Variable: 

	[1mFailTuple(GraphAgent [obj=<class '__main__.GraphAgent'>, parent=<function process_graph at 0x7f5e35edd2d0>])[0m

was found to be non-serializable. There may be multiple other undetected variables that were non-serializable. 
Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class. 
=====================================================================
Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information.
If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/
=====================================================================
