In [2]:
import os
import torch
import numpy as np
from PIL import Image, ImageOps
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from diffusers import (
    StableDiffusionXLInstructPix2PixPipeline,
    EDMEulerScheduler,
    AutoencoderKL,
)
from huggingface_hub import hf_hub_download
import time
import gc

# Add support for quantization
try:
    import bitsandbytes as bnb
    from transformers.utils.quantization_config import BitsAndBytesConfig
    HAS_BNB = True
except ImportError:
    HAS_BNB = False
    print("bitsandbytes not found, CUDA-based quantization will not be available")

# ========== BASIC MPS SETUP ==========
# Set environment variables
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# Check device availability
print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
print(f"MPS built: {torch.backends.mps.is_built()}")

# Set device - simpler approach
if torch.backends.mps.is_available():
    device = "mps"
    print("Using MPS")
elif torch.cuda.is_available():
    device = "cuda"
    print("Using CUDA")
else:
    device = "cpu"
    print("Using CPU")

# ========== Memory Management ==========
def torch_gc():
    """Basic memory cleanup compatible with all PyTorch versions"""
    try:
        # Try to empty MPS cache if available
        if device == "mps" and hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    except:
        pass
        
    # Always do GC collection
    gc.collect()
    
    # Try CUDA cleanup just in case
    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()

# ========== Custom Quantization for MPS ==========
class MPSQuantizedLinear(torch.nn.Module):
    """MPS-friendly quantized linear layer using 8-bit representation"""
    def __init__(self, linear_layer):
        super().__init__()
        self.input_size = linear_layer.in_features
        self.output_size = linear_layer.out_features
        
        # Quantize weights to int8, scale to preserve range
        weight = linear_layer.weight.data.float()
        self.w_scale = weight.abs().max() / 127.0
        self.w_quant = (weight / self.w_scale).round().char()
        
        # Keep bias in float32
        self.bias = None
        if linear_layer.bias is not None:
            self.bias = linear_layer.bias.data.clone()
            
    def forward(self, x):
        # Convert quantized weights back to float for computation
        w_dequant = self.w_quant.float() * self.w_scale
        output = torch.nn.functional.linear(x, w_dequant)
        if self.bias is not None:
            output += self.bias
        return output

def quantize_module_for_mps(module, dtype=torch.float16):
    """Apply MPS-friendly quantization or precision reduction to modules recursively"""
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.Linear):
            setattr(module, name, MPSQuantizedLinear(child))
        elif len(list(child.children())) > 0:
            # Recursively quantize child modules
            quantize_module_for_mps(child, dtype)
    
    # Convert remaining parameters to half precision
    for param_name, param in module.named_parameters():
        if param.dtype == torch.float32:
            param.data = param.data.to(dtype)
    
    return module

# ========== Padding Helper ==========
def pad_to_square_and_record(image, fill_color=(255, 255, 255)):
    width, height = image.size
    max_side = max(width, height)
    max_side = (max_side + 7) // 8 * 8  # Make divisible by 8

    delta_w = max_side - width
    delta_h = max_side - height
    padding = (
        delta_w // 2, delta_h // 2,
        delta_w - delta_w // 2, delta_h - delta_h // 2
    )
    padded = ImageOps.expand(image, padding, fill=fill_color)
    return padded, padding

def crop_back_to_original(image: Image.Image, padding) -> Image.Image:
    left, top, right, bottom = padding
    width, height = image.size
    return image.crop((left, top, width - right, height - bottom))

def remove_white_background(image: Image.Image, white_thresh=240) -> Image.Image:
    image = image.convert("RGBA")
    data = np.array(image)

    r, g, b, a = data.T
    white_mask = (r >= white_thresh) & (g >= white_thresh) & (b >= white_thresh)
    data[..., 3][white_mask.T] = 0

    fade_mask = (r >= 200) & (g >= 200) & (b >= 200) & (~white_mask)
    data[..., 3][fade_mask.T] = 128
    return Image.fromarray(data)

# ========== Load Models ==========
print("Loading models...")

# Configure quantization based on device
quantization_config = None
quantization_enabled = False
mps_quantization = False

# For CUDA devices, use BnB 4-bit quantization
if device == "cuda" and HAS_BNB:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",  # Q4_K is equivalent to NF4 (normalized float 4)
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    quantization_enabled = True
    print("CUDA quantization config enabled: q4km (4-bit quantization)")
# For MPS devices, use custom 8-bit quantization
elif device == "mps":
    mps_quantization = True
    dtype = torch.float16  # Use half precision on MPS
    quantization_enabled = True
    print("MPS quantization enabled: custom 8-bit + FP16")
else:
    # CPU or other devices
    dtype = torch.float32
    print("Quantization not available, loading full precision models")

# Load BLIP for captioning
print("Loading BLIP model...")
blip_processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base", 
    use_fast=False
)
torch_gc()

# Load BLIP model with appropriate quantization
if device == "cuda" and HAS_BNB and quantization_config:
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base",
        quantization_config=quantization_config
    )
else:
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    
    # Apply MPS quantization if needed
    if mps_quantization:
        print("Applying MPS quantization to BLIP model...")
        blip_model = quantize_module_for_mps(blip_model, dtype=dtype)

blip_model.to(device)
torch_gc()

# Load VAE
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", 
    torch_dtype=dtype
)

# Apply MPS quantization to VAE if needed
if mps_quantization:
    print("Applying MPS quantization to VAE...")
    vae = quantize_module_for_mps(vae, dtype=dtype)

torch_gc()

# Download COSXL model
print("Downloading COSXL model...")
cosxl_edit_path = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
torch_gc()

# Create pipeline with appropriate quantization
print(f"Creating pipeline on {device}...")
if device == "cuda" and HAS_BNB and quantization_config:
    pipe = StableDiffusionXLInstructPix2PixPipeline.from_single_file(
        cosxl_edit_path,
        num_in_channels=8,
        is_cosxl_edit=True,
        vae=vae,
        torch_dtype=dtype,
        quantization_config=quantization_config
    )
else:
    pipe = StableDiffusionXLInstructPix2PixPipeline.from_single_file(
        cosxl_edit_path,
        num_in_channels=8,
        is_cosxl_edit=True,
        vae=vae,
        torch_dtype=dtype,
    )
    
    # Apply MPS quantization if needed
    if mps_quantization:
        print("Applying MPS quantization to pipeline...")
        # Quantize UNet and text encoder
        pipe.unet = quantize_module_for_mps(pipe.unet, dtype=dtype)
        pipe.text_encoder = quantize_module_for_mps(pipe.text_encoder, dtype=dtype)
        pipe.text_encoder_2 = quantize_module_for_mps(pipe.text_encoder_2, dtype=dtype)

# Move to device
pipe = pipe.to(device)

# Set scheduler
pipe.scheduler = EDMEulerScheduler(
    sigma_min=0.002,
    sigma_max=120.0,
    sigma_data=1.0,
    prediction_type="v_prediction",
    sigma_schedule="exponential"
)

# Enable memory efficient attention if available
if hasattr(pipe, "enable_attention_slicing"):
    pipe.enable_attention_slicing(1)

# Enable memory efficient attention for CUDA
if hasattr(pipe, "enable_xformers_memory_efficient_attention") and device == "cuda":
    pipe.enable_xformers_memory_efficient_attention()

torch_gc()

# Skip warm-up as it might be causing issues
print("All models loaded")

# ========== Prompt Generation ==========
def generate_prompt(image: Image.Image):
    try:
        image = image.convert("RGB")
        
        # Resize to save memory
        max_size = 512
        width, height = image.size
        if max(width, height) > max_size:
            ratio = max_size / max(width, height)
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        
        # Process image
        inputs = blip_processor(image, return_tensors="pt").to(device)
        
        # Generate caption
        with torch.no_grad():
            output = blip_model.generate(**inputs, max_length=30)
        
        caption = blip_processor.decode(output[0], skip_special_tokens=True)
        torch_gc()
        return caption
    except Exception as e:
        print(f"Error generating caption: {e}")
        return "An image"

def suggest_prompt_only(image):
    if image is None:
        return ""
        
    print("Generating caption...")
    return generate_prompt(image)

# ========== Core Generation ==========
def edit_image(input_img, instruction_override, cached_prompt, cfg_scale=7.0, steps=25):
    if input_img is None:
        return None
    
    start_time = time.time()
    input_img = input_img.convert("RGB")
    original_size = input_img.size
    
    # Preprocess image
    padded_img, padding = pad_to_square_and_record(input_img)
    
    # Use standard size for better results
    resize_size = 512
    padded_img = padded_img.resize((resize_size, resize_size), Image.Resampling.LANCZOS)

    # Construct prompt
    final_prompt = instruction_override.strip() if instruction_override.strip() else f"{cached_prompt}, but modified"
    print(f"Using prompt: '{final_prompt}'")

    # Clean memory
    torch_gc()
    
    try:
        # Ensure model is on correct device
        pipe.to(device)
        
        # Run inference with reduced steps for MPS if needed
        with torch.no_grad():
            result = pipe(
                prompt=final_prompt,
                image=padded_img,
                height=resize_size,
                width=resize_size,
                guidance_scale=cfg_scale,
                num_inference_steps=steps
            ).images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")
        return input_img
    finally:
        torch_gc()

    # Postprocess result
    result = remove_white_background(result)
    result = crop_back_to_original(result, padding)
    result = result.resize(original_size, Image.Resampling.LANCZOS)
    
    total_time = time.time() - start_time
    print(f"Total processing time: {total_time:.2f} seconds")
    
    return result

# ========== Gradio UI ==========
with gr.Blocks() as demo:
    quant_description = ""
    if device == "cuda" and HAS_BNB and quantization_enabled:
        quant_description = "with Q4KM (4-bit) quantization"
    elif device == "mps" and quantization_enabled:
        quant_description = "with custom 8-bit + FP16 quantization"
    else:
        quant_description = "in full precision"
        
    gr.Markdown(f"## ✨ COSXL Edit with Quantization Support")
    gr.Markdown(f"Running on: **{device.upper()}** {quant_description}")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload Reference Image")
        output_image = gr.Image(type="pil", label="Edited Output")

    with gr.Row():
        suggested_prompt = gr.Textbox(label="🧠 Auto-Suggested Prompt", interactive=False)
        user_instruction = gr.Textbox(label="✍️ Your Edit Instruction", placeholder="e.g. make it futuristic")

    with gr.Row():
        slider_cfg = gr.Slider(1.0, 15.0, value=7.0, step=0.5, label="Guidance Scale (higher = more adherence to prompt)")
        slider_steps = gr.Slider(10, 50, value=25, step=1, label="Inference Steps (higher = more quality)")

    with gr.Row():
        run_btn = gr.Button("Generate", variant="primary")
        clear_btn = gr.Button("Clear")
        download_btn = gr.File(label="📥 Download as PNG", interactive=False)
        
    # Quantization options
    with gr.Accordion("Quantization Details", open=False):
        if device == "cuda" and HAS_BNB and quantization_enabled:
            quantization_info = gr.Markdown(f"""
            ### 4-bit Quantization (Q4KM/NF4)
            
            - Type: NF4 (Normalized Float 4-bit, equivalent to Q4_K/Q4KM)
            - Method: 4-bit quantization with double quantization
            - Compute Dtype: FP16
            - Memory savings: ~75% compared to FP16
            """)
        elif device == "mps" and quantization_enabled:
            quantization_info = gr.Markdown(f"""
            ### Custom MPS Quantization
            
            - Weights: 8-bit integer quantization
            - Activations: FP16 (half precision)
            - Method: Custom linear layer implementation for MPS
            - Memory savings: ~60% compared to FP32
            - Performance: Slightly slower but more memory efficient
            """)
        else:
            quantization_info = gr.Markdown(f"""
            ### Full Precision Mode
            
            Quantization is not enabled. Models are running in full precision ({dtype}).
            """)
        
    # System info
    with gr.Accordion("System Info", open=False):
        gr.Markdown(f"""
        - Device: {device.upper()}
        - PyTorch: {torch.__version__}
        - MPS Available: {torch.backends.mps.is_available()}
        - MPS Built: {torch.backends.mps.is_built()}
        - BnB Available: {HAS_BNB}
        - Quantization Enabled: {quantization_enabled}
        - Quantization Mode: {"CUDA Q4KM" if device == "cuda" and HAS_BNB and quantization_enabled else "MPS 8-bit" if device == "mps" and quantization_enabled else "None"}
        """)

    # Tips for MPS users
    if device == "mps":
        with gr.Accordion("Tips for Apple Silicon Users", open=True):
            gr.Markdown(f"""
            ### Optimizing Performance on Apple Silicon
            
            - Lower the inference steps (15-20) for faster generation
            - If you encounter out-of-memory errors, restart the application and try again
            - Close memory-intensive applications when running
            - If you need more VRAM, restart your computer to clear the memory
            - Smaller images (512x512) work better than larger ones
            """)

    # Event handlers
    input_image.change(fn=suggest_prompt_only, inputs=input_image, outputs=suggested_prompt)
    
    def clear_outputs():
        return None, None, ""
    
    clear_btn.click(fn=clear_outputs, inputs=[], outputs=[input_image, output_image, suggested_prompt])

    def process_and_save(input_img, instruction_override, cached_prompt, cfg_scale, steps):
        if input_img is None:
            return None, None
        
        result = edit_image(input_img, instruction_override, cached_prompt, cfg_scale, steps)
        if result is None:
            return None, None
            
        path = "edited_result.png"
        result.save(path, format="PNG")
        return result, path

    run_btn.click(
        fn=process_and_save,
        inputs=[input_image, user_instruction, suggested_prompt, slider_cfg, slider_steps],
        outputs=[output_image, download_btn]
    )

if __name__ == "__main__":
    print(f"Starting Gradio interface with device: {device}")
    demo.launch(
        debug=True, 
        server_name="0.0.0.0", 
        share=False,
        show_error=True
    )

bitsandbytes not found, CUDA-based quantization will not be available
PyTorch version: 2.8.0.dev20250412
MPS available: True
MPS built: True
Using MPS
Loading models...
MPS quantization enabled: custom 8-bit + FP16
Loading BLIP model...
Applying MPS quantization to BLIP model...
Loading VAE...
Applying MPS quantization to VAE...
Downloading COSXL model...
Creating pipeline on mps...


Fetching 17 files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Applying MPS quantization to pipeline...
All models loaded
Starting Gradio interface with device: mps
* Running on local URL:  http://0.0.0.0:7860

To create a public link, set `share=True` in `launch()`.


Generating caption...
Error generating caption: Tensor for argument weight is on cpu but expected on mps
Generating caption...
Error generating caption: Tensor for argument weight is on cpu but expected on mps
Using prompt: 'red belt'
Error during image generation: Tensor for argument weight is on cpu but expected on mps
Using prompt: 'red belt'
Error during image generation: Tensor for argument weight is on cpu but expected on mps
Using prompt: 'red belt'
Error during image generation: Tensor for argument weight is on cpu but expected on mps
Using prompt: 'red belt'
Error during image generation: Tensor for argument weight is on cpu but expected on mps
Keyboard interruption in main thread... closing server.
