# Image Enhancement & Background Removal
This notebook provides tools to fix blurry/pixelated images and remove backgrounds cleanly.

**Features:**
- **Super Resolution (Upscaling):** Enlarges images 2x with an AI model to sharpen blurry or pixelated detail, while preserving transparency.
- **Background Removal:** Cuts the main subject out of a photo and makes the background transparent.
- **Line-by-Line Documentation:** Beginner-friendly explanations for every step.

In [None]:
# Step 1: Install the libraries this notebook uses:
#   transformers + torch -> the AI models, torchvision -> image normalization,
#   Pillow + numpy -> image handling, ipywidgets -> the upload dashboard.
# Note: ipywidgets is pinned to 7.7.1 to avoid a common rendering error in VS Code.
%pip install -q transformers torch Pillow ipywidgets==7.7.1 numpy torchvision

In [None]:
# Import the core engine 'torch' for AI math
import torch
# 'Image' tool for opening and seeing pictures
from PIL import Image
import io
import numpy as np
# 'AutoModelForImageSegmentation' is a specialized tool for background removal
from transformers import AutoModelForImageSegmentation
# 'pipeline' for specific AI tasks
from transformers import pipeline
# UI tools to create buttons and uploaders
import ipywidgets as widgets
from IPython.display import display
from torchvision.transforms.functional import normalize

# Pick the compute device: the GPU (CUDA) when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Hardware detected: {device}")

In [None]:
# --- Defining the Models ---
# Hugging Face Hub model IDs used by the task functions.
UPSCALER_MODEL_ID = "caidas/swin2SR-classical-sr-x2-64"  # Swin2SR super-resolution ("x2" model)
BG_MODEL_ID = "briaai/RMBG-1.4"  # BRIA background-removal segmentation model

# Nothing is loaded here: each task function loads its model when it is called,
# so no model memory is used until you actually run a task.

In [None]:
# --- Logic for Background Removal ---
def _primary_prediction(output):
    """Return the main mask tensor from a segmentation model's raw output.

    RMBG-1.4 returns nested Python sequences (a tuple of lists; its model card
    reads the main prediction as ``result[0][0]``), while other models may
    return a bare tensor. Descend through the FIRST element of every
    list/tuple until a tensor is reached.

    Raises:
        ValueError: if an empty list/tuple is encountered.
        TypeError: if no tensor is found.
    """
    while isinstance(output, (list, tuple)):
        if len(output) == 0:
            raise ValueError("The model returned an empty prediction.")
        output = output[0]
    if not torch.is_tensor(output):
        raise TypeError(f"Unexpected model output type: {type(output).__name__}")
    return output


def _to_unit_range(pred):
    """Return ``pred`` as a float tensor scaled into the 0..1 range."""
    pred = pred.float()
    # RMBG-1.4 applies its sigmoid inside the model (the model card post-processes
    # result[0][0] directly). A second sigmoid would squash every pixel into
    # ~0.5-0.73, i.e. a semi-transparent mask. Only treat values as raw logits
    # when they fall outside 0..1.
    if pred.min() < 0 or pred.max() > 1:
        pred = pred.sigmoid()
    # Min-max stretch, mirroring the model card's postprocessing; a perfectly
    # flat map is left untouched to avoid dividing by zero.
    lo, hi = pred.min(), pred.max()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    return pred


def remove_background(img):
    """Remove the background from ``img`` using the RMBG-1.4 segmentation model.

    Args:
        img: A PIL image in any mode; it is converted to RGB for the model.

    Returns:
        An RGBA copy of ``img`` at its original size whose alpha channel is the
        predicted foreground mask (255 keeps a pixel, 0 makes it transparent).

    Raises:
        ValueError, TypeError: if the model output contains no mask tensor.
    """
    print("Loading Background Removal Model...")
    # Loaded on every call (not cached), so the weights can be freed afterwards.
    model = AutoModelForImageSegmentation.from_pretrained(BG_MODEL_ID, trust_remote_code=True)
    model.to(device)
    model.eval()

    # Standardize image for the AI (1024x1024)
    original_size = img.size
    input_images = img.convert("RGB").resize((1024, 1024), Image.BILINEAR)
    input_images = np.array(input_images) / 255.0          # HWC floats in 0..1
    input_images = np.transpose(input_images, (2, 0, 1))   # HWC -> CHW
    input_images = torch.tensor(input_images, dtype=torch.float32).unsqueeze(0).to(device)
    # mean 0.5 / std 1.0 shifts pixel values from 0..1 to -0.5..0.5
    input_images = normalize(input_images, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Run AI inference and turn the main prediction into a 0..1 mask
    with torch.no_grad():
        output = model(input_images)
        mask = _to_unit_range(_primary_prediction(output)).cpu()

    # Drop batch/channel dims -> 2-D (H, W); a 2-D uint8 array becomes an "L" image.
    mask_array = (mask.squeeze().numpy() * 255).round().astype(np.uint8)
    mask_image = Image.fromarray(mask_array).resize(original_size, Image.BILINEAR)

    # convert() returns a new image, so the caller's image is not modified.
    img_rgba = img.convert("RGBA")
    img_rgba.putalpha(mask_image)
    return img_rgba

# --- Logic for Upscaling with Transparency Preservation ---
def upscale_image(img, keep_resolution=False):
    """Upscale ``img`` with the Swin2SR model while keeping any transparency.

    Only the colour channels go through the AI model. The alpha channel (if any)
    is resized separately with Lanczos filtering and re-attached.

    Args:
        img: A PIL image in any mode.
        keep_resolution: If True, resize the enhanced result back to the
            original dimensions.

    Returns:
        The upscaled PIL image: "RGBA" when the input had transparency,
        otherwise the pipeline's RGB output.
    """
    print("Loading HD Enhancer (preserving transparency)...")
    orig_size = img.size

    # 1. Handle Transparency: separate the alpha channel if the image has one.
    # Besides RGBA this covers LA/PA images and palette/RGB/L images that mark a
    # transparent colour via the "transparency" info key (common in PNG/GIF).
    has_alpha = "A" in img.getbands() or "transparency" in img.info
    if has_alpha:
        rgba = img.convert("RGBA")
        alpha_part = rgba.getchannel("A")
        rgb_part = rgba.convert("RGB")
    else:
        alpha_part = None
        rgb_part = img.convert("RGB")

    # 2. Upscale the Colors with AI
    upscaler = pipeline("image-to-image", model=UPSCALER_MODEL_ID, device=0 if device == "cuda" else -1)
    result = upscaler(rgb_part)
    upscaled_rgb = result[0] if isinstance(result, list) else result

    # The Swin2SR image processor pads the input (bottom/right) before the model
    # runs, so the output can be slightly larger than `scale` x the original.
    # Crop that padding off so the colours line up with the resized alpha mask
    # and with `keep_resolution`. Skipped if the model config has no `upscale`.
    scale = getattr(getattr(upscaler.model, "config", None), "upscale", None)
    if isinstance(scale, int) and scale > 0:
        expected_w, expected_h = orig_size[0] * scale, orig_size[1] * scale
        if (upscaled_rgb.width >= expected_w and upscaled_rgb.height >= expected_h
                and upscaled_rgb.size != (expected_w, expected_h)):
            upscaled_rgb = upscaled_rgb.crop((0, 0, expected_w, expected_h))

    # 3. Upscale the Transparency Mask (Lanczos keeps its edges smooth)
    target_size = upscaled_rgb.size
    if alpha_part is not None:
        upscaled_alpha = alpha_part.resize(target_size, Image.LANCZOS)
        final_image = upscaled_rgb.convert("RGBA")
        final_image.putalpha(upscaled_alpha)
    else:
        final_image = upscaled_rgb

    # 4. Resolve Size: shrink back if requested (the detail gained by the AI
    # pass is downsampled with Lanczos rather than the original being reused)
    if keep_resolution:
        print(f"Finalizing at original size: {orig_size}")
        return final_image.resize(orig_size, Image.LANCZOS)

    return final_image

In [None]:
# --- Interaction Dashboard ---
# File picker restricted to image MIME types; one file at a time.
uploader = widgets.FileUpload(
    accept='image/*',
    multiple=False,
    description="Upload",
)

# Task menu: each option is (label shown to the user, value read by the handler).
task_selector = widgets.Dropdown(
    options=[
        ('Upscale (Fix Pixels)', 'upscale'),
        ('Remove Background', 'remove_bg'),
    ],
    value='upscale',
    description='Task:',
)

# Only consulted for the upscale task: shrink the result back to the input size.
resolution_check = widgets.Checkbox(
    value=False,
    description='Keep Original Resolution',
    indent=False,
)

# The action button and the area where messages and results are shown.
button = widgets.Button(
    description="Fix My Image",
    button_style='success',
)
output = widgets.Output()

def on_click(b):
    """Run the selected task on the uploaded image when the button is clicked.

    Args:
        b: The ``widgets.Button`` that fired the event (required by
            ``Button.on_click``; not used).

    Progress messages, errors and the finished image are all shown inside the
    ``output`` widget, which is cleared at the start of every click.
    """
    # Local import: only this handler needs ImageOps.
    from PIL import ImageOps

    with output:
        output.clear_output()
        if not uploader.value:
            print("No photo uploaded!")
            return

        # ipywidgets 7 stores uploads as {filename: {"content": ...}}; newer
        # versions use a tuple of dicts. Take the single upload either way.
        val = list(uploader.value.values())[0] if isinstance(uploader.value, dict) else uploader.value[0]

        print(f"Action: {task_selector.label}...")
        try:
            # Decode inside the try so an unreadable or unsupported file yields
            # an error message instead of an uncaught traceback.
            input_image = Image.open(io.BytesIO(val['content']))
            # Honour the EXIF orientation tag (set by most phone cameras);
            # otherwise the processed result can appear rotated.
            input_image = ImageOps.exif_transpose(input_image)

            if task_selector.value == 'upscale':
                final = upscale_image(input_image, keep_resolution=resolution_check.value)
            else:
                final = remove_background(input_image)

            print("Done! Opening result...")
            display(final)
        except Exception as e:  # UI boundary: report any failure to the user
            print(f"Error: {type(e).__name__}: {e}")

# Call on_click for every button press, then render the controls in one column.
button.on_click(on_click)
display(
    widgets.VBox(
        children=[uploader, task_selector, resolution_check, button, output]
    )
)