In [10]:
import torch
from PIL import Image
from diffusers import DiffusionPipeline

In [11]:
import json
import os
from pathlib import Path
from PIL import Image, ImageDraw

def create_mask_from_json(json_file, output_dir=None, save=True):
    """
    Creates a full-resolution binary mask from a LabelMe JSON file.
    Background is Black (0), Polygons are White (255).

    Args:
        json_file: Path (str or Path) to a LabelMe annotation JSON file.
        output_dir: Folder the mask is written to when ``save`` is True.
            Defaults to a ``masks`` folder next to ``json_file``.
        save: If True, write the mask as ``<json stem>_mask.png``.

    Returns:
        PIL.Image.Image: an 8-bit grayscale ('L') mask with the same
        width/height as the annotated image.

    Raises:
        ValueError: if the JSON lacks ``imageHeight``/``imageWidth`` and the
            image named by ``imagePath`` cannot be found next to the JSON.
    """
    json_path = Path(json_file)

    # 1. Load the JSON data (JSON is UTF-8 by spec; don't rely on the
    #    platform default encoding, which is not UTF-8 on many Windows setups)
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # 2. Get Image Dimensions
    # LabelMe JSONs usually store height/width.
    # If missing, fall back to opening the referenced image to read its size.
    height = data.get('imageHeight')
    width = data.get('imageWidth')

    if not height or not width:
        image_name = data.get('imagePath')
        # Guard against an empty/null imagePath: joining '' onto the parent
        # yields the directory itself, which exists but is not an image file.
        image_path = json_path.parent / image_name if image_name else None
        if image_path is not None and image_path.is_file():
            with Image.open(image_path) as img:
                width, height = img.size
        else:
            raise ValueError(f"Could not determine dimensions for {json_file}")

    # 3. Create a blank Black image (Mode 'L' = 8-bit grayscale)
    mask = Image.new('L', (width, height), 0)
    draw = ImageDraw.Draw(mask)

    # 4. Draw all polygons in White
    for shape in data.get('shapes') or []:
        # Older LabelMe files omit shape_type (or store null); LabelMe itself
        # treats those shapes as polygons, so do the same here.
        if (shape.get('shape_type') or 'polygon') != 'polygon':
            continue
        # Convert points to a list of tuples [(x, y), (x, y), ...]
        points = [tuple(p) for p in shape['points']]
        # Fill the polygon (and its outline) with 255 (White)
        draw.polygon(points, outline=255, fill=255)

    # 5. Save the mask
    if save:
        out_dir = Path(output_dir) if output_dir is not None else json_path.parent / 'masks'
        out_dir.mkdir(parents=True, exist_ok=True)
        # Save as PNG: JPEG is lossy and would corrupt the hard 0/255 edges
        mask.save(out_dir / f"{json_path.stem}_mask.png")

    return mask

In [12]:
# Set the path to your images directory
images_dir = Path('data')

# Example: Process a single file
example_json = images_dir / '20200713_1207268869_plant1041_rgb_trigger010.json'

# Fail fast: later cells use `mask`, so a missing file should stop
# "Run All" here with a clear message instead of a NameError further down.
if not example_json.exists():
    raise FileNotFoundError(f"File not found: {example_json}")

# Build the mask in memory only (save=False); output_dir is used only if save=True
mask = create_mask_from_json(
    example_json,
    output_dir=images_dir / 'masks',
    save=False
)
print("Mask created!")

# Render inline in the notebook; PIL's Image.show() opens an external viewer instead
from IPython.display import display
display(mask)


Mask created!


In [13]:
# 1. Load the specific "Paint By Example" model
# It repaints the masked region of the image, conditioned on the example (reference) image.
# float16 halves GPU memory, but many CPU kernels lack half-precision support and
# enable_model_cpu_offload() requires a GPU accelerator, so choose settings per hardware.
# NOTE(review): the load log reports falling back to pickled (.bin) weights
# ("unsafe serialization"); only load checkpoints from sources you trust.
use_cuda = torch.cuda.is_available()
pipe = DiffusionPipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example",
    torch_dtype=torch.float16 if use_cuda else torch.float32,
)
if use_cuda:
    # Keep sub-models on the CPU and move each one to the GPU only while it runs
    pipe.enable_model_cpu_offload()
# Otherwise the pipeline stays on the CPU in float32 (slow, but it works).

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]An error occurred while trying to fetch C:\Users\ShanO\.cache\huggingface\hub\models--Fantasy-Studio--Paint-by-Example\snapshots\351e6427d8c28a3b24f7c751d43eb4b6735127f7\vae: Error no file named diffusion_pytorch_model.safetensors found in directory C:\Users\ShanO\.cache\huggingface\hub\models--Fantasy-Studio--Paint-by-Example\snapshots\351e6427d8c28a3b24f7c751d43eb4b6735127f7\vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  40%|████      | 2/5 [00:00<00:00,  3.07it/s]You are using a model of type clip_vision_model to instantiate a model of type clip. This is not supported for all configurations of models and can yield errors.
Loading pipeline components...:  60%|██████    | 3/5 [00:02<00:01,  1.10it/s]An error occurred while trying to fetch C:\Users\ShanO\.cache\huggingface\hub\models--Fantasy-Studio--Paint-by-Example\snapshots\351e6427d8c2

In [None]:
# 2. Load and preprocess your images
# Paint-by-Example works best with 512x512 images
TARGET_SIZE = 512

def preprocess_image(img, target_size=TARGET_SIZE):
    """Convert to RGB and resize to target size (maintains aspect ratio, center crops)"""
    if img.mode != 'RGB':
        img = img.convert('RGB')
    
    # Calculate resize dimensions maintaining aspect ratio
    width, height = img.size
    aspect_ratio = width / height
    
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_height = target_size
        new_width = int(target_size * aspect_ratio)
    
    # Resize
    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    
    # Center crop to exact target size
    left = (new_width - target_size) // 2
    top = (new_height - target_size) // 2
    right = left + target_size
    bottom = top + target_size
    img = img.crop((left, top, right, bottom))
    
    return img

def preprocess_mask(mask, target_size=TARGET_SIZE):
    """Resize mask to target size (grayscale, maintains aspect ratio, center crops)"""
    if mask.mode != 'L':
        mask = mask.convert('L')
    
    # Calculate resize dimensions maintaining aspect ratio
    width, height = mask.size
    aspect_ratio = width / height
    
    if width > height:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_height = target_size
        new_width = int(target_size * aspect_ratio)
    
    # Resize
    mask = mask.resize((new_width, new_height), Image.Resampling.LANCZOS)
    
    # Center crop to exact target size
    left = (new_width - target_size) // 2
    top = (new_height - target_size) // 2
    right = left + target_size
    bottom = top + target_size
    mask = mask.crop((left, top, right, bottom))
    
    return mask

# Load images
image = Image.open("data/20200713_1207268869_plant1041_rgb_trigger010.png")
mask_image = mask
example_image = Image.open("reference_diseases/Figure-1_poly1.png")

# The mask is drawn at the annotation's resolution; if the photo differs, the
# center crops below would not line up, so check before preprocessing.
assert image.size == mask_image.size, (
    f"Image size {image.size} != mask size {mask_image.size}; crops would misalign"
)

# Preprocess all images to same size
image = preprocess_image(image, TARGET_SIZE)
mask_image = preprocess_mask(mask_image, TARGET_SIZE)
example_image = preprocess_image(example_image, TARGET_SIZE)

# Verify sizes match (the pipeline takes image, mask and example together)
print(f"Image size: {image.size}, mode: {image.mode}")
print(f"Mask size: {mask_image.size}, mode: {mask_image.mode}")
print(f"Example size: {example_image.size}, mode: {example_image.mode}")
assert image.size == mask_image.size == example_image.size, "Preprocessed sizes differ"

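## Tips for Better Results

1. **Reference image quality**: Use a sharp, well-lit reference image in which the disease pattern is clearly visible.
2. **Mask precision**: The mask should tightly outline the region where the disease should appear; only the white (255) area of the mask is repainted.
3. **Guidance scale**:
   - Lower (3.0–4.0): more varied output, less faithful to the reference
   - Higher (6.0–7.0): more faithful to the reference, but may blend less naturally with the surrounding image
4. **Inference steps**: More steps (e.g. 75–100) can improve detail, with diminishing returns; runtime grows roughly linearly with the step count.
5. **Try different seeds**: Change the value passed to `torch.Generator().manual_seed(...)` to get different variations; keep it fixed to reproduce a result.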
5. **Try Different Seeds**: Change the seed value to get different variations


In [None]:
# 3. Run generation
# Sampling settings are collected in one dict so they are easy to tweak and reuse.
generation_kwargs = dict(
    image=image,
    mask_image=mask_image,
    example_image=example_image,
    num_inference_steps=50,  # 50 is the default; 75-100 can add detail at the cost of time
    guidance_scale=5.0,  # larger values follow the example image more closely (try 3.0-7.0)
    generator=torch.Generator().manual_seed(42),  # fixed seed -> reproducible output
)
pipeline_output = pipe(**generation_kwargs)
result = pipeline_output.images[0]

# Persist the result, creating the output folder on the first run
os.makedirs("synthetic_data", exist_ok=True)
result.save("synthetic_data/output.png")
print("Result saved to synthetic_data/output.png")

100%|██████████| 50/50 [03:17<00:00,  3.96s/it]


In [None]:
# Optional: Experiment with different parameters
# Uncomment and modify to try different settings

# result_v2 = pipe(
#     image=image,
#     mask_image=mask_image,
#     example_image=example_image,
#     num_inference_steps=75,  # More steps for better quality
#     guidance_scale=6.0,  # Higher for stronger reference adherence
#     generator=torch.Generator().manual_seed(123),  # Different seed for variation
# ).images[0]
# 
# result_v2.save("synthetic_data/output_v2.png")
