# EcomFruitAI - Fruit Image Generation with Diffusion Models

This notebook demonstrates the creation of synthetic fruit images using diffusion models trained on the Fruits-360 dataset.

In [None]:
# Install required packages into the kernel's environment (%pip targets the active kernel).
# NOTE(review): no versions are pinned, so a fresh run may resolve different releases — consider pinning.
%pip install kagglehub diffusers transformers accelerate torch torchvision

# Import our custom modules.
# The relative '..' entry is resolved against the current working directory and must be
# appended before the `ecomfruitai` imports below run — presumably the package lives in
# the parent directory of this notebook; confirm against the repository layout.
import sys
sys.path.append('..')

from ecomfruitai.dataset import download_fruit_dataset, create_datasets_and_loaders
from ecomfruitai.modeling.train import setup_models, train_model
from ecomfruitai.modeling.predict import generate_image, generate_multiple_images
from ecomfruitai.plots import show_generated_image, show_multiple_generated_images, visualize_dataset_samples
# NOTE(review): GENERATION_CONFIG is not referenced in the cells visible here — confirm it is still needed.
from ecomfruitai.config import DEVICE, GENERATION_CONFIG

# Show the DEVICE value imported from ecomfruitai.config.
print(f"Using device: {DEVICE}")

In [None]:
# Obtain the dataset location, then build the train/test loaders and the
# descriptive class list from it.
dataset_path = download_fruit_dataset()
loaders_and_classes = create_datasets_and_loaders(dataset_path)
train_loader, test_loader, descriptive_classes = loaders_and_classes

# Summarise the sizes of what was just loaded, one line per item.
print(
    f"Train dataset: {len(train_loader.dataset)} images",
    f"Test dataset: {len(test_loader.dataset)} images",
    f"Number of descriptive classes: {len(descriptive_classes)}",
    sep="\n",
)

In [None]:
# Preview the dataset: hand the dataset path and class descriptions to the
# plotting helper, asking for 10 samples.
visualize_dataset_samples(
    dataset_path,
    descriptive_classes,
    num_samples=10,
)

In [None]:
# Build the model bundle once; keep it whole as `models` (passed to the
# training and generation helpers) and also unpack its five components.
tokenizer, text_encoder, vae, unet, scheduler = models = setup_models()

In [None]:
# Train the model
def generate_fn(prompt, num_inference_steps=20):
    """Generate an image for ``prompt`` using the notebook-level ``models``.

    Handed to ``train_model`` as a callback. ``models`` is looked up when the
    function is called, not when it is defined.

    Args:
        prompt: Text prompt forwarded to ``generate_image``.
        num_inference_steps: Step count forwarded to ``generate_image``
            (defaults to 20).

    Returns:
        Whatever ``generate_image`` returns for the given arguments.
    """
    image = generate_image(prompt, models, num_inference_steps)
    return image

# NOTE(review): `trained_unet` is not used by the later cells, which keep passing
# `models` to the generation helpers — confirm train_model updates the UNet held
# in `models` in place; otherwise later generation would use untrained weights.
trained_unet = train_model(train_loader, models, generate_fn)

In [None]:
# Prompts used to spot-check generation after training.
test_prompts = [
    "red apple, whole fruit, realistic photo",
    "green apple, whole fruit, realistic photo",
    "yellow banana, whole fruit, realistic photo",
    "orange carrot, whole vegetable, realistic photo",
]

# Run generation over all prompts with 20 inference steps, then pass the
# results together with their prompts to the display helper.
generated_images = generate_multiple_images(
    test_prompts,
    models,
    num_inference_steps=20,
)
show_multiple_generated_images(generated_images, test_prompts)

In [None]:
# Single image with a larger step budget (50 inference steps instead of the
# 20 used for the batch above), displayed with the prompt in its title.
prompt = "red apple, whole fruit, realistic photo"
high_quality_image = generate_image(
    prompt,
    models,
    num_inference_steps=50,
)
show_generated_image(
    high_quality_image,
    title=f"High Quality: {prompt}",
)