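# Data Augmentation - Diffusion Model
A diffusion model gradually adds noise to an image and then learns to reverse this process, which lets it generate new, realistic images from pure noise. This notebook uses the image-to-image variant (Stable Diffusion img2img). Instead of starting from pure noise, it partially noises an existing training image and denoises it under a category-specific text prompt. This produces new samples that stay close to the original class.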

In [None]:
import os
import shutil
import torch
import random
import numpy as np
from tqdm import tqdm
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

### Configure the generation settings (prompt, items, strength, guidance scale, model) for each category

In [None]:
# Let CUDA matmuls use TF32 tensor cores (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Checkpoint shared by every category below.
_SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"


def _category_entry(prompt, items, strength, guidance_scale):
    """Build the settings dict for one category.

    ``prompt`` is a template with an ``{items}`` placeholder that is filled with
    one entry of ``items`` at generation time. ``strength`` and
    ``guidance_scale`` are forwarded to the img2img pipeline. Every category
    uses a 512x512 input size and the same checkpoint; a new ``image_size``
    list is created per call so categories never share a mutable list.
    """
    return {
        "prompt": prompt,
        "items": items,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "image_size": [512, 512],
        "model": _SD_MODEL_ID,
    }


# Define category-specific configurations
category_settings = {
    "cardboard": _category_entry(
        prompt="A perfectly clean, brand-new cardboard box, neatly presented in {items}. It has sharp edges with no creases, dirt, or damage. The object is isolated on a seamless, bright white studio backdrop with soft, even lighting.",
        items=["rectangular box", "square box", "flat folded box", "open-top box", "closed shipping box", "tall vertical box"],
        strength=0.8,
        guidance_scale=9,
    ),
    "glass": _category_entry(
        prompt="A perfectly clean, brand-new {items} with crisp reflections and pristine glass. The object is isolated on a seamless, bright white studio backdrop with soft, even lighting.",
        items=["glass bottle", "glass jar", "drinking glass", "wine glass", "glass vase"],
        strength=0.6,
        guidance_scale=12,
    ),
    "metal": _category_entry(
        prompt="A detailed close-up shot of a {items}, without any dents, rust, or damage. The object showcases sharp details, isolated against a seamless, bright white studio backdrop with soft, even lighting.",
        items=["aluminum can", "tin can", "food can", "soda can", "drink can", "beverage can"],
        strength=0.8,
        guidance_scale=14,
    ),
    # NOTE(review): items such as "crumpled paper" / "folded paper" contradict
    # the prompt's "free of any folds, dirt, or creases", and "printer newspaper"
    # reads oddly — consider revising the template or the item list.
    "paper": _category_entry(
        prompt="A single, clean printer {items} lying flat on a seamless, bright white studio backdrop. The edges are perfectly straight, and the surface is pristine, free of any folds, dirt, or creases.",
        items=["newspaper", "magazine", "printed document", "written paper sheet", "folded paper", "crumpled paper", "envelope"],
        strength=0.75,
        guidance_scale=14,
    ),
    "plastic": _category_entry(
        prompt="A single plastic {items}, perfectly shaped with a smooth, glossy surface, free of any scratches or dents. The object is placed on an isolated, seamless, bright white studio backdrop with soft, even lighting.",
        items=["water bottle", "beverage bottle", "soda bottle", "juice bottle", "milk bottle"],
        strength=0.75,
        guidance_scale=14,
    ),
    # NOTE(review): "A single, dirty {items}" is filled with plural items
    # ("torn plastic wrappers", ...) and later says "The objects are casually
    # scattered" — the singular/plural framing is inconsistent.
    "trash": _category_entry(
        prompt="A single, dirty {items} on a plain surface. The object has stains, or slight tears. The objects are casually scattered, resembling everyday trash left behind after a meal or snack. The lighting is even, highlighting textures and reflections on the different materials.",
        items=["crumpled paper", "torn plastic wrappers", "empty snack pouches", "used napkins"],
        strength=0.8,
        guidance_scale=14,
    ),
}

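### Process the dataset and generate synthetic images with the diffusion model
### Idea: after augmentation, every category should contain twice as many images as the majority class (`majority_size * 2`), i.e. the original images plus enough synthetic ones to reach that target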

In [None]:
# Define dataset paths
input_dir = "../data/dataset_split/train"
output_dir = "../data/dataset_diffusion_balanced/train"


def _visible_entries(path):
    """Return the entry names in ``path``, skipping hidden ones (e.g. ``.ipynb_checkpoints``)."""
    return [name for name in os.listdir(path) if not name.startswith(".")]


# Count images per category: every visible subdirectory of input_dir is a
# class, and only regular files inside it count as images (so hidden files and
# nested folders created by Jupyter/OS tools don't inflate the counts).
class_counts = {
    cat: sum(
        os.path.isfile(os.path.join(input_dir, cat, name))
        for name in _visible_entries(os.path.join(input_dir, cat))
    )
    for cat in _visible_entries(input_dir)
    if os.path.isdir(os.path.join(input_dir, cat))
}
if not class_counts:
    raise ValueError(f"No category subdirectories found in '{input_dir}'")

# Only wipe the output directory once the input has been validated, so a bad
# input path doesn't destroy a previous run's results.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

majority_class = max(class_counts, key=class_counts.get)
majority_size = class_counts[majority_class]
target_size = majority_size * 2  # Double the majority class

def preprocess_image(image_path, size):
    """Load an image from disk, convert it to RGB and resize it.

    Args:
        image_path: Path to the source image file.
        size: Target ``(width, height)`` in pixels. Any 2-element sequence is
            accepted, e.g. the ``[512, 512]`` lists stored in
            ``category_settings``.

    Returns:
        A new RGB ``PIL.Image.Image`` of the requested size.
    """
    # Open inside a context manager so the underlying file handle is always
    # released; ``convert`` returns an independent image, so it remains valid
    # after the source image is closed.
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
    return rgb.resize(tuple(size))

def is_black_image(image, mean_threshold=10):
    """Report whether ``image`` is essentially all black.

    The image is converted to grayscale (``"L"`` mode) and the mean pixel value
    is compared with ``mean_threshold``. A mean strictly below the threshold
    counts as black.
    """
    luminance = np.asarray(image.convert("L"))
    mean_level = luminance.mean()
    return mean_level < mean_threshold

# Loaded img2img pipelines keyed by model id. Every category currently uses
# the same checkpoint, so loading it once avoids re-reading the weights (and
# briefly holding two copies in device memory) for every category.
_pipelines = {}


def get_pipeline(model_id):
    """Return the img2img pipeline for ``model_id``, loading it on first use."""
    if model_id not in _pipelines:
        _pipelines[model_id] = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        ).to(device)
    return _pipelines[model_id]


for category, count in tqdm(class_counts.items(), desc="Balancing dataset"):
    # Categories without their own entry fall back to the "trash" settings.
    settings = category_settings.get(category, category_settings["trash"])
    pipe = get_pipeline(settings["model"])

    class_path = os.path.join(input_dir, category)
    augmented_class_path = os.path.join(output_dir, category)
    os.makedirs(augmented_class_path, exist_ok=True)

    # Only regular, non-hidden files are source images; nested folders such as
    # .ipynb_checkpoints would make shutil.copy and PIL fail.
    images = sorted(
        name
        for name in os.listdir(class_path)
        if not name.startswith(".") and os.path.isfile(os.path.join(class_path, name))
    )
    for img_name in images:
        shutil.copy(os.path.join(class_path, img_name), os.path.join(augmented_class_path, img_name))

    num_needed = target_size - count
    if not images:
        tqdm.write(f"[{category}] no source images found, skipping augmentation")
        continue
    tqdm.write(f"[{category}] {count} original images, generating {max(num_needed, 0)} synthetic images")

    # Cap retries so a class whose outputs keep being rejected cannot loop forever.
    max_attempts = num_needed * 4
    attempts = 0

    while num_needed > 0 and attempts < max_attempts:
        attempts += 1
        img_name = random.choice(images)
        input_image = preprocess_image(os.path.join(class_path, img_name), settings["image_size"])

        prompt_item = random.choice(settings.get("items", ["object"]))
        prompt = settings["prompt"].format(items=prompt_item)

        with torch.no_grad():
            result = pipe(
                prompt=prompt,
                image=input_image,
                strength=settings["strength"],
                guidance_scale=settings["guidance_scale"],
                num_images_per_prompt=1,
            )

        synthetic_image = result.images[0]
        # Discard (near-)black outputs; presumably these are images blanked by
        # the pipeline's safety checker — confirm against diffusers' behavior.
        if is_black_image(synthetic_image):
            continue

        # num_needed strictly decreases, so the suffix keeps filenames unique
        # even when the same source image is picked several times.
        output_filename = f"{os.path.splitext(img_name)[0]}_aug_{num_needed}.png"
        synthetic_image.save(os.path.join(augmented_class_path, output_filename))
        num_needed -= 1

    if num_needed > 0:
        tqdm.write(
            f"[{category}] WARNING: gave up after {attempts} attempts; "
            f"{num_needed} synthetic images are still missing"
        )

print(f"Balanced dataset (Category-Specific Diffusion) saved at '{output_dir}'!")
