In [None]:
from pydantic import BaseModel, field_validator, Extra
from typing import Any, Dict, List, Optional, Union
from numpydantic import NDArray
import dask.array as da
from pycromanager import Dataset
import numpy as np
import luigi
from abc import ABC, abstractmethod
import os
import json

def convert_to_pydantic_safe(value):
    """Recursively replace NumPy scalars with built-in Python values.

    Args:
        value: Any object. NumPy booleans, integers and floats are converted
            to ``bool``, ``int`` and ``float``; lists, tuples and dicts are
            walked recursively (tuples come back as lists, dict keys are
            left untouched).

    Returns:
        The converted value, or ``value`` itself when no rule applies.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(value, np.bool_):
        return bool(value)
    # np.integer / np.floating already cover every sized variant (int32, float64, ...).
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    if isinstance(value, (list, tuple)):
        return [convert_to_pydantic_safe(item) for item in value]
    if isinstance(value, dict):
        return {key: convert_to_pydantic_safe(item) for key, item in value.items()}
    return value

class Parameters(BaseModel, extra='allow', use_enum_values=True):
    """Pipeline-wide settings passed to every processing step.

    Unknown keyword arguments are kept on the model (``extra='allow'``), so
    steps can carry additional settings without declaring them here.
    """

    # Acquisition / detection geometry.
    # NOTE(review): units are not stated in this file (presumably nm) - confirm.
    voxel_size_yx: float = 130
    voxel_size_z: float = 500
    spot_z: float = 100
    spot_yx: float = 360

    index_dict: Optional[Dict[str, Any]] = None

    # Channel indices; None means "not set".
    nucChannel: Optional[int] = None
    cytoChannel: Optional[int] = None
    FISHChannel: Optional[int] = None

    independent_params: List[Dict[str, Any]] = [{}]
    timestep_s: Optional[float] = None

    # Data locations.
    local_dataset_location: Optional[Union[str, List]] = None
    nas_location: Optional[Union[str, List]] = None
    num_images: Optional[int] = None

    # Default to empty (0, 0, 0) dask arrays.
    images: Optional[NDArray] = da.ones((0, 0, 0))
    masks: Optional[NDArray] = da.ones((0, 0, 0))
    image_path: Optional[str] = None
    mask_path: Optional[str] = None

    clear_after_error: bool = True
    name: Optional[str] = None
    NUMBER_OF_CORES: int = 4
    num_images_to_run: int = 100000000
    # NOTE(review): machine-specific absolute path; override per installation.
    connection_config_location: str = 'c:\\Users\\formanj\\GitHub\\AngelFISH\\config_nas.yml'
    display_plots: bool = True
    load_in_mask: bool = False
    mask_structure: Optional[Any] = None
    order: str = 'pt'
    # The default is None, so the annotation must allow None.
    data_dir: Optional[str] = None
    share_name: str = 'share'


class Data(BaseModel, ABC, extra='allow', use_enum_values=True):
    """Common fields for the data records exchanged between pipeline steps.

    Unknown attributes are kept (``extra='allow'``), which is what lets
    ``append_dict`` attach arbitrary step results to a record.
    """

    nas_location: str
    history: Optional[List[str]] = []
    image_path: Optional[str] = None
    mask_path: Optional[str] = None
    # Default to empty (0, 0, 0) dask arrays.
    images: Optional[NDArray] = da.empty((0, 0, 0))
    masks: Optional[NDArray] = da.empty((0, 0, 0))
    independent_params: List[Dict[str, Any]] = [{}]

    def append_dict(self, data: Optional[dict]):
        """Set every key of ``data`` as an attribute on this record.

        Values are passed through ``convert_to_pydantic_safe`` first.

        Args:
            data: Mapping of attribute name to value. ``None`` or an empty
                mapping is accepted and leaves the record unchanged, so a
                step that produces no results does not crash the pipeline.

        Returns:
            This record, to allow chaining.
        """
        if not data:
            return self

        for key, value in data.items():
            setattr(self, key, convert_to_pydantic_safe(value))

        return self


class DataSingle(Data, extra='allow'):
    """Record for one (position, time point) pair.

    ``extra`` is given as the plain string ``'allow'`` because the ``Extra``
    enum is deprecated in pydantic v2.
    """

    p: int  # position index
    t: int  # time-point index
    # Default to empty (0, 0, 0) dask arrays.
    image: Optional[NDArray] = da.empty((0, 0, 0))
    mask: Optional[NDArray] = da.empty((0, 0, 0))
    independent_params: List[Dict[str, Any]] = [{}]


class DataBulk(Data, extra='allow'):
    """Record holding the images and masks of a whole dataset."""

    # Default to empty (0, 0, 0) dask arrays.
    images: Optional[NDArray] = da.empty((0, 0, 0))
    masks: Optional[NDArray] = da.empty((0, 0, 0))

    def split(self, p, t) -> DataSingle:
        """Return the ``DataSingle`` for position ``p`` and time point ``t``.

        Not implemented yet. The method takes ``self`` so that
        ``bulk.split(p=..., t=...)`` binds its arguments correctly, and it
        raises instead of silently returning ``None``.

        Raises:
            NotImplementedError: Always, until the split logic is written.
        """
        raise NotImplementedError('DataBulk.split is not implemented yet')

    @classmethod
    def merge(cls, data: list):
        """Combine a list of ``DataSingle`` records into one ``DataBulk``.

        Not implemented yet. It is a classmethod so that
        ``DataBulk.merge(records)`` keeps working, and it raises instead of
        silently returning ``None``.

        Raises:
            NotImplementedError: Always, until the merge logic is written.
        """
        raise NotImplementedError('DataBulk.merge is not implemented yet')


In [None]:
class ImageProcessor(luigi.Task, ABC):
    """Base Luigi task: load parameters and upstream data, call ``eval``, save the result.

    Subclasses implement ``eval`` and may override ``step_name``,
    ``modify_images`` and ``modify_masks``.
    """

    param_path = luigi.Parameter()
    step_name: str = "base"
    modify_images = False
    modify_masks = False
    previous_task = luigi.Parameter()
    nas_location = luigi.Parameter()

    def requires(self):
        """Return the upstream task, or ``None`` when there is none.

        ``previous_task`` may be a task class (instantiated without
        arguments, as before) or an already-built task instance (used as-is).
        """
        if self.previous_task is None:
            return None
        if isinstance(self.previous_task, type):
            return self.previous_task()
        return self.previous_task

    def output(self):
        """Define the output of this task."""
        return luigi.LocalTarget(f"{self.step_name}_output.json")

    @abstractmethod
    def eval(self, **kwargs):
        """Run the processing logic.

        Receives the dumped parameters merged with the dumped upstream data
        (data keys win) and returns a dict of results, or ``None``.
        """
        pass

    def run(self):
        """Load inputs, call ``eval`` and write the updated record to ``output()``."""
        print(f'Processing {os.path.basename(self.nas_location)} with {self.step_name}')

        if self.output().exists():
            # Nothing to recompute; do not overwrite the existing result.
            print('Results already exist')
            return

        # Files hold a JSON-encoded string containing the model's JSON (see the
        # json.dump(model_dump_json()) below), hence json.load + model_validate_json.
        with open(self.param_path, 'r') as json_file:
            params = Parameters.model_validate_json(json.load(json_file))

        if self.previous_task is not None:
            with open(self.input().path, 'r') as json_file:
                data = Data.model_validate_json(json.load(json_file))
        else:
            data = None

        if data is not None:
            kwargs = {**params.model_dump(), **data.model_dump()}
        else:
            kwargs = {**params.model_dump()}

        results = self.eval(**kwargs)

        if data is None:
            # First step of a chain: start a fresh record for this dataset.
            data = Data(nas_location=self.nas_location)
        if results:
            # eval may return None or {} when it only has side effects.
            data.append_dict(results)

        with self.output().open('w') as f:
            json.dump(data.model_dump_json(round_trip=True), f)

class SingleImageProcessingTask(ImageProcessor):
    """Processing step that runs on one (position ``p``, time point ``t``) record."""

    p = luigi.IntParameter()
    t = luigi.IntParameter()

    def requires(self):
        """Return the upstream task for this (p, t), or ``None`` when there is none.

        ``previous_task`` may be a nested ``{p: {t: task}}`` dict, a task
        class (instantiated without arguments) or a task instance.
        """
        if self.previous_task is None:
            return None
        if isinstance(self.previous_task, dict):
            return self.previous_task[self.p][self.t]
        if isinstance(self.previous_task, type):
            return self.previous_task()
        return self.previous_task

    def output(self):
        """Return the per-(p, t) JSON target for this step and dataset."""
        dataName = os.path.basename(self.nas_location)
        cache_key = f'{self.step_name}-p_{self.p}-t_{self.t}-{dataName}.json'
        return luigi.LocalTarget(cache_key)

    def _excluded_fields(self):
        """Return the array fields to leave out of the saved JSON.

        Arrays this step does not modify are not written again.
        """
        excluded = set()
        if not self.modify_images:
            excluded |= {'image', 'images'}
        if not self.modify_masks:
            excluded |= {'mask', 'masks'}
        return excluded

    def run(self):
        """Load inputs, call ``eval`` and write the updated record to ``output()``."""
        print(f'Processing {os.path.basename(self.nas_location)}, {self.p}, {self.t} with {self.step_name}')

        if self.output().exists():
            # Same policy as the other processors: report and skip.
            print('Results already exist')
            return

        with open(self.param_path, 'r') as json_file:
            params = Parameters.model_validate_json(json.load(json_file))

        if self.previous_task is not None:
            source = self.input()
            if isinstance(source, dict):
                # An upstream Splitter exposes a nested {p: {t: target}} dict.
                source = source[self.p][self.t]
            with open(source.path, 'r') as json_file:
                data = DataSingle.model_validate_json(json.load(json_file))
        else:
            data = None

        if data is not None:
            kwargs = {**params.model_dump(), **data.model_dump()}
        else:
            kwargs = {**params.model_dump()}

        results = self.eval(**kwargs)

        if data is None:
            data = DataSingle(nas_location=self.nas_location, p=self.p, t=self.t)
        if results:
            data.append_dict(results)

        # pydantic expects a set (or dict) for ``exclude``; None means "exclude nothing".
        excluded = self._excluded_fields() or None
        with self.output().open('w') as f:
            json.dump(data.model_dump_json(round_trip=True, exclude=excluded), f)

    @abstractmethod
    def eval(self, **kwargs):
        """Run the processing logic; called by ``run`` with keyword arguments only."""
        raise NotImplementedError


class BulkImageProcessingTask(ImageProcessor):
    """Processing step that runs once on the whole-dataset record."""

    def output(self):
        """Return the JSON target for this step and dataset."""
        dataName = os.path.basename(self.nas_location)
        cache_key = f'{self.step_name}-{dataName}.json'
        return luigi.LocalTarget(cache_key)

    def _excluded_fields(self):
        """Return the array fields to leave out of the saved JSON.

        Arrays this step does not modify are not written again.
        """
        excluded = set()
        if not self.modify_images:
            excluded |= {'image', 'images'}
        if not self.modify_masks:
            excluded |= {'mask', 'masks'}
        return excluded

    def run(self):
        """Load inputs, call ``eval`` and write the updated record to ``output()``."""
        print(f'Processing {os.path.basename(self.nas_location)} with {self.step_name}')

        if self.output().exists():
            # Nothing to recompute; do not overwrite the existing result.
            print('Results already exist')
            return

        with open(self.param_path, 'r') as json_file:
            params = Parameters.model_validate_json(json.load(json_file))

        if self.previous_task is not None:
            with open(self.input().path, 'r') as json_file:
                data = DataBulk.model_validate_json(json.load(json_file))
        else:
            data = None

        if data is not None:
            kwargs = {**params.model_dump(), **data.model_dump()}
        else:
            kwargs = {**params.model_dump()}

        results = self.eval(**kwargs)

        if data is None:
            # First step of a chain: start a fresh record for this dataset.
            data = DataBulk(nas_location=self.nas_location)
        if results:
            # eval may return None or {} when it only has side effects.
            data.append_dict(results)

        # pydantic expects a set (or dict) for ``exclude``; None means "exclude nothing".
        excluded = self._excluded_fields() or None
        with self.output().open('w') as f:
            json.dump(data.model_dump_json(round_trip=True, exclude=excluded), f)

    @abstractmethod
    def eval(self, **kwargs):
        """Run the processing logic; called by ``run`` with keyword arguments only."""
        pass

In [None]:
class Splitter(luigi.Task):
    """Split one ``DataBulk`` file into one ``DataSingle`` file per (p, t)."""

    nas_location = luigi.Parameter()
    previous_task = luigi.Parameter()
    # NOTE(review): both values are iterated below, so callers pass sequences
    # even though they are declared as IntParameter. The declaration is left
    # unchanged because another parameter type would alter how existing
    # callers' values are serialised.
    positions = luigi.IntParameter()
    time_points = luigi.IntParameter()

    def requires(self):
        """Return the upstream task.

        ``previous_task`` may be a task class (instantiated without
        arguments, as before) or an already-built task instance (used as-is).
        """
        if isinstance(self.previous_task, type):
            return self.previous_task()
        return self.previous_task

    def output(self):
        """Return a nested ``{p: {t: LocalTarget}}`` dict of split files."""
        dataName = os.path.basename(self.nas_location)
        return {p: {t: luigi.LocalTarget(f'Split-p_{p}-t_{t}-{dataName}.json') for t in self.time_points} for p in self.positions}

    def run(self):
        """Read the upstream ``DataBulk`` and write every (p, t) slice."""
        with open(self.input().path, 'r') as json_file:  # load the bulk data
            bulkData = DataBulk.model_validate_json(json.load(json_file))

        # Build the target mapping once instead of on every iteration.
        targets = self.output()
        for p in self.positions:
            for t in self.time_points:
                singleData = bulkData.split(p=p, t=t)
                with targets[p][t].open('w') as f:
                    json.dump(singleData.model_dump_json(round_trip=True), f)

class Merger(luigi.Task):
    """Collect every per-(p, t) ``DataSingle`` file into one ``DataBulk`` file."""

    nas_location = luigi.Parameter()
    previous_task = luigi.Parameter()
    positions = luigi.IntParameter()
    time_points = luigi.IntParameter()

    def requires(self):
        """Build one upstream task per (p, t), keyed as ``{p: {t: task}}``.

        NOTE(review): ``previous_task`` is called with only ``p`` and ``t``;
        any other required parameters of that task must come from elsewhere.
        """
        upstream = {}
        for p in self.positions:
            upstream[p] = {t: self.previous_task(p=p, t=t) for t in self.time_points}
        return upstream

    def output(self):
        """Return the target that receives the merged record."""
        base = os.path.basename(self.nas_location)
        return luigi.LocalTarget(f'Merged-{base}.json')

    def run(self):
        """Read every split record in (p, t) order, merge them and save the result."""
        singles = []
        for p in self.positions:
            for t in self.time_points:
                with self.input()[p][t].open('r') as handle:
                    record = DataSingle.model_validate_json(json.load(handle))
                    singles.append(record)

        merged = DataBulk.merge(singles)

        with self.output().open('w') as handle:
            json.dump(merged.model_dump_json(round_trip=True), handle)



In [None]:
class ProcessSingles(luigi.WrapperTask):
    """Wrapper that requires every step for every (position, time point).

    The attributes are declared as Luigi parameters: plain type annotations
    are not registered by Luigi, so keyword arguments such as ``steps=...``
    would be rejected and ``self.nas_location`` would not exist.
    """

    # Plain ``Parameter`` is used because the values are arbitrary Python
    # objects passed in code (task classes, index sequences).
    steps = luigi.Parameter()
    positions = luigi.Parameter()
    time_points = luigi.Parameter()
    nas_location = luigi.Parameter()
    # Optional so existing callers that do not pass it keep working; when
    # given, it is forwarded to every step.
    param_path = luigi.OptionalParameter(default=None)

    def requires(self):
        """For each position and time point, run the processing steps."""
        extra = {} if self.param_path is None else {'param_path': self.param_path}
        return [
            step(p=p, t=t, previous_task=Splitter, nas_location=self.nas_location, **extra)
            for step in self.steps
            for p in self.positions
            for t in self.time_points
        ]

class ProcessBulk(luigi.WrapperTask):
    """Wrapper that requires every bulk step, each fed by ``Merger``.

    The attributes are declared as Luigi parameters: plain type annotations
    are not registered by Luigi, so keyword arguments would be rejected and
    ``self.param_path`` (used below) would not exist.
    """

    # Plain ``Parameter`` because ``steps`` holds task classes passed in code.
    steps = luigi.Parameter()
    nas_location = luigi.Parameter()
    param_path = luigi.Parameter()

    def requires(self):
        """Return one task per step, in the order given by ``steps``."""
        return [
            step(param_path=self.param_path, nas_location=self.nas_location, previous_task=Merger)
            for step in self.steps
        ]

class Workflow(luigi.WrapperTask):
    """Top-level task tying the per-image and bulk processing stages together.

    The attributes are declared as Luigi parameters: plain type annotations
    are not registered by Luigi, so none of them could be passed or read.
    """

    # ``params`` is not read by this class; it stays optional and
    # non-significant so it never affects the task identity.
    params = luigi.Parameter(default=None, significant=False)
    steps = luigi.Parameter()
    positions = luigi.Parameter()
    time_points = luigi.Parameter()
    nas_location = luigi.Parameter()
    param_path = luigi.Parameter()

    def requires(self):
        """Require the per-image stage and the bulk stage."""
        # NOTE(review): ``param_path`` is not forwarded to ProcessSingles here,
        # matching the original call; confirm whether the single-image steps
        # need it.
        return [
            ProcessSingles(steps=self.steps, positions=self.positions, time_points=self.time_points, nas_location=self.nas_location),
            ProcessBulk(steps=self.steps, nas_location=self.nas_location, param_path=self.param_path)
        ]

    def output(self):
        """The final output of the pipeline."""
        return luigi.LocalTarget("final_output.json")

    def run(self):
        """Write a fixed marker string to the final output file."""
        with self.output().open('w') as f:
            f.write("Final result of configurable pipeline")

In [None]:
# Optional: uncomment to silence Luigi's console logging.
# import logging

# logger = logging.getLogger('luigi-interface')
# logger.setLevel(logging.ERROR)
# logger.propagate = False


# Simple example: build fake inputs, write them to disk, then run one task.
# Create some fake data using dask arrays
import dask.array as da

# Fake image stack of shape (10, 100, 100), chunked one slice at a time
fake_images = da.random.random((10, 100, 100), chunks=(1, 100, 100))

# Fake binary mask stack (values 0 or 1) with the same shape and chunking
fake_masks = da.random.randint(0, 2, size=(10, 100, 100), chunks=(1, 100, 100))

# Materialise both arrays as nested lists
# so they can be written to disk as JSON
fake_images_list = fake_images.compute().tolist()
fake_masks_list = fake_masks.compute().tolist()

with open('fake_images.json', 'w') as f:
    json.dump(fake_images_list, f)

with open('fake_masks.json', 'w') as f:
    json.dump(fake_masks_list, f)

# p and t are not DataBulk fields; they are kept as extra attributes
data = DataBulk(images=fake_images, masks=fake_masks, nas_location='somewhere', history=['step1', 'step2'], p=0, t=0)

# Define parameters
params = Parameters(
    voxel_size_yx=130,
    voxel_size_z=500,
    spot_z=100,
    spot_yx=360,
    num_images=10,
    images=fake_images,
    masks=fake_masks,
    image_path='fake_images',
    mask_path='fake_masks',
    data_dir='data'
)

# Files are written as a JSON-encoded string holding the model's JSON, which is
# the format the tasks read back with json.load + model_validate_json.
with open('params.json', 'w') as f:
    json.dump(params.model_dump_json(round_trip=True), f)

# NOTE(review): nothing in the visible code reads data.json - confirm it is needed.
with open('data.json', 'w') as f:
    json.dump(data.model_dump_json(round_trip=True), f)
# Example task: a bulk step with every Luigi parameter fixed as a class attribute
class ExampleSingleImageProcessingTask(BulkImageProcessingTask):
    """Demo step that sums the image stack and the mask stack.

    Despite its name it derives from ``BulkImageProcessingTask``. The plain
    class attributes below replace the inherited Luigi parameters, so the
    task can be built without arguments.
    """

    param_path = 'params.json'
    nas_location = 'dataset1'
    p = 0
    t = 0
    previous_task = None

    def eval(self, **kwargs):
        """Return ``image_sum`` and ``mask_sum`` computed from ``images`` and ``masks``."""
        totals = {}
        # Same evaluation order as before: images first, then masks.
        for result_key, source_key in (('image_sum', 'images'), ('mask_sum', 'masks')):
            totals[result_key] = kwargs[source_key].sum().compute()
        return totals

# Build the example task and run it with Luigi's in-process (local) scheduler
task = ExampleSingleImageProcessingTask()
luigi.build([task], local_scheduler=True)

In [None]:
# More complicated example: a load -> split -> measure -> merge -> save chain.
# Create some fake data using dask arrays

# Fake 6-D image dataset; axis order per the inline comment
# NOTE(review): `images` is not used below - the cell reuses fake_images and
# fake_masks from the previous cell. Confirm which data this example should use.
images = da.random.random((5, 7, 3, 10, 100, 100)) # p, t, c, z, y, x

# Save the fake data to disk as JSON
fake_images_list = fake_images.compute().tolist()
fake_masks_list = fake_masks.compute().tolist()

with open('fake_images.json', 'w') as f:
    json.dump(fake_images_list, f)

with open('fake_masks.json', 'w') as f:
    json.dump(fake_masks_list, f)

# p and t are not DataBulk fields; they are kept as extra attributes.
# NOTE(review): `data` is built but not written to disk in this cell.
data = DataBulk(images=fake_images, masks=fake_masks, nas_location='somewhere', history=['step1', 'step2'], p=0, t=0)

# Define parameters
params = Parameters(
    voxel_size_yx=130,
    voxel_size_z=500,
    spot_z=100,
    spot_yx=360,
    num_images=10,
    images=fake_images,
    masks=fake_masks,
    image_path='fake_images',
    mask_path='fake_masks',
    data_dir='data'
)

# Written as a JSON-encoded string holding the model's JSON, which is the
# format the tasks read back with json.load + model_validate_json.
with open('params.json', 'w') as f:
    json.dump(params.model_dump_json(round_trip=True), f)


class LoadInData(BulkImageProcessingTask):
    """Bulk step that records the image path it was given."""

    def eval(self, image_path, **kwargs):
        """Return ``image_path`` under the ``load_in_data`` key."""
        outcome = dict(load_in_data=image_path)
        return outcome

class GenerateMasks(SingleImageProcessingTask):
    """Per-image step that produces a random binary mask."""

    def eval(self, images, **kwargs):
        """Return a random 0/1 dask array of shape (10, 100, 100) as ``nuc_mask``.

        ``images`` is accepted but not used.
        NOTE(review): ``MeasureSomething.eval`` expects ``nuc_masks`` (plural);
        confirm which key is intended.
        """
        mask_shape = (10, 100, 100)
        return {'nuc_mask': da.random.randint(0, 2, size=mask_shape)}
    
# Example per-image measurement step
class MeasureSomething(SingleImageProcessingTask):
    """Per-image step that sums the images and the nuclear masks."""

    def eval(self, images, nuc_masks, **kwargs):
        """Return the computed totals as ``image_sum`` and ``mask_sum``."""
        # Images are reduced first, then masks, as in the original ordering.
        total_images = images.sum().compute()
        total_masks = nuc_masks.sum().compute()
        return dict(image_sum=total_images, mask_sum=total_masks)

# Example bulk step with a side effect only
class SaveSomething(BulkImageProcessingTask):
    """Bulk step that prints a message and contributes no results."""

    def eval(self, **kwargs):
        """Print a marker message and return an empty result dict.

        An empty dict is returned instead of ``None`` because ``run`` passes
        the return value to ``append_dict``, which iterates over its items.
        """
        print('something happened')
        return {}

# NOTE(review): this instance is not used in the build below; kept so the
# `task` global keeps its previous value.
task = ExampleSingleImageProcessingTask()

# Build each stage once and reuse it, instead of re-creating the same
# upstream tasks inline for every downstream stage.
dataset_name = 'dataset2'
param_file = 'params.json'
position_ids = list(np.arange(5))
time_ids = list(np.arange(7))

load_task = LoadInData(nas_location=dataset_name, param_path=param_file, previous_task=None)
split_task = Splitter(nas_location=dataset_name, previous_task=load_task,
                      positions=position_ids, time_points=time_ids)
measure_tasks = [
    MeasureSomething(nas_location=dataset_name, param_path=param_file,
                     previous_task=split_task, p=p, t=t)
    for p in range(5)
    for t in range(7)
]
# Merger receives the task class and builds one instance per (p, t) itself.
merge_task = Merger(nas_location=dataset_name, previous_task=MeasureSomething,
                    positions=position_ids, time_points=time_ids)
save_task = SaveSomething(nas_location=dataset_name, param_path=param_file, previous_task=merge_task)

luigi.build([load_task, split_task, *measure_tasks, merge_task, save_task],
            local_scheduler=True)