In [52]:
import yaml
import sys
import os
import seaborn as sns
from pathlib import Path

In [69]:
# Supported checkpoint file names -> model type identifier.
# parse_config() looks a checkpoint up here by its bare file name, and the
# key order below is the order listed in its "not supported" error message.
MODELS = {
    "sam_vit_h_4b8939.pth": "vit_h",
    "sam_vit_l_0b3195.pth": "vit_l",
    "sam_vit_b_01ec64.pth": "vit_b",
}

def _config_error(message):
    """Print ``message`` to stderr and stop with a non-zero exit status (raises SystemExit)."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _build_wing_segments(wing_cells):
    """Turn the ``wing-cells`` mapping into one record per cell with its own colour.

    Args:
        wing_cells: Mapping of cell id -> display name, taken from the config.

    Returns:
        Dict keyed by cell id. Each value holds the display name, a colour from
        seaborn's "hls" palette (one distinct colour per cell, assigned in the
        mapping's iteration order), and measurement slots initialised to None.
    """
    palette = sns.color_palette("hls", len(wing_cells))
    return {
        cell_id: {
            "display_name": display_name,
            "color": color,
            "mask": None,
            "wing_area": None,
            "wing_height": None,
            "cell_area": None,
            "cell_perimeter": None,
        }
        for (cell_id, display_name), color in zip(wing_cells.items(), palette)
    }


def parse_config(file_path):
    """Load and validate the YAML run configuration.

    Reads ``file_path`` with ``yaml.safe_load`` and checks that the
    ``checkpoint``, ``output-directory`` and ``wing-cells`` keys are present.
    It then verifies that the checkpoint file exists and that its file name is
    listed in ``MODELS``, and creates the output directory if it is missing.

    Args:
        file_path: Path (str or PathLike) to the YAML config file. A relative
            path is resolved against the current working directory.

    Returns:
        Tuple ``(checkpoint_path, model_type, wing_segments)``:
          - checkpoint_path: absolute ``Path`` to the checkpoint file.
          - model_type: the ``MODELS`` value for the checkpoint's file name.
          - wing_segments: dict keyed by wing-cell id, each entry holding the
            display name, a plot colour and empty measurement slots.
        The output directory is created or validated but not returned.

    Raises:
        SystemExit: with status 1, after printing a message to stderr, on any
            missing file, invalid content, unsupported checkpoint, YAML parse
            error or other unexpected error.
    """
    try:
        config_path = Path(file_path).resolve()
        if not config_path.is_file():
            _config_error(f"Error: Config file not found at '{config_path}'.")

        with config_path.open("r") as file:
            data = yaml.safe_load(file)

        if data is None:
            _config_error("Error: Config file is empty.")
        # A top-level scalar or list would make the `key in data` checks below
        # misbehave (a string would do substring matching), so require a mapping.
        if not isinstance(data, dict):
            _config_error("Error: Config file must contain a mapping of keys to values.")

        required_keys = ["checkpoint", "output-directory", "wing-cells"]
        missing_keys = [key for key in required_keys if key not in data]
        if missing_keys:
            _config_error(f"Error: Missing required keys in config: {', '.join(missing_keys)}")

        # Validate everything before touching the filesystem, so that an invalid
        # config never leaves a freshly created output directory behind.
        checkpoint_path = Path(data["checkpoint"]).resolve()
        if not checkpoint_path.is_file():
            _config_error(f"Error: Checkpoint file not found at '{checkpoint_path}'.")
        model_type = MODELS.get(checkpoint_path.name)
        if model_type is None:
            # Must stop here: continuing would leave model_type undefined.
            _config_error(
                f"Error: The checkpoint is not supported. Supported checkpoints: {', '.join(MODELS)}"
            )

        wing_cells = data["wing-cells"]
        if not isinstance(wing_cells, dict) or not wing_cells:
            _config_error("Error: 'wing-cells' must be a non-empty mapping of cell id to display name.")

        output_path = Path(data["output-directory"]).resolve()
        if output_path.is_dir():
            print("Warning: Output directory already exists. Files might get overwritten.", file=sys.stderr)
        elif output_path.exists():
            _config_error(f"Error: Output path '{output_path}' exists but is not a directory.")
        else:
            output_path.mkdir(parents=True, exist_ok=True)

        return checkpoint_path, model_type, _build_wing_segments(wing_cells)

    except yaml.YAMLError as e:
        _config_error(f"Error parsing config file: {e}")
    except Exception as e:
        # SystemExit from _config_error is a BaseException and passes through here.
        _config_error(f"An unexpected error occurred: {e}")


if __name__ == "__main__":
    # Parse config.yaml (resolved against the current working directory) and
    # echo what was resolved: checkpoint path, model type, then the segments.
    file_path = "config.yaml"
    checkpoint_path, model_type, wing_segments = parse_config(file_path)
    print(checkpoint_path, model_type, sep="\n")
    print("\n", wing_segments)

/home/wsl/bin/segment-anything/checkpoints/sam_vit_h_4b8939.pth
vit_h

 {'FWL': {'display_name': 'forewing lobe', 'color': (0.86, 0.3712, 0.33999999999999997), 'mask': None, 'wing_area': None, 'wing_height': None, 'cell_area': None, 'cell_perimeter': None}, 'MC': {'display_name': 'marginal cell', 'color': (0.86, 0.6832, 0.33999999999999997), 'mask': None, 'wing_area': None, 'wing_height': None, 'cell_area': None, 'cell_perimeter': None}, '1sMC': {'display_name': '1st submarginal cell', 'color': (0.7247999999999999, 0.86, 0.33999999999999997), 'mask': None, 'wing_area': None, 'wing_height': None, 'cell_area': None, 'cell_perimeter': None}, '2sMC': {'display_name': '2nd submarginal cell', 'color': (0.41279999999999994, 0.86, 0.33999999999999997), 'mask': None, 'wing_area': None, 'wing_height': None, 'cell_area': None, 'cell_perimeter': None}, '3sMC': {'display_name': '3rd submarginal cell', 'color': (0.33999999999999997, 0.86, 0.5792000000000002), 'mask': None, 'wing_area': None, 'wing_h

In [39]:
# Preview a 10-colour seaborn "hls" palette, the palette family parse_config
# uses to give each wing cell its own colour.
sns_colors = sns.color_palette(palette="hls", n_colors=10)
sns_colors