# Generative AI project about Data Augmentation
> *This is the notebook for the ninth Profession AI project, part of the Generative AI module*

## Setup & Configuration

In [None]:
# Clone the project repository into the current working directory.
!git clone https://github.com/Silvano315/Gen-AI-for-Data-Augmentation.git

In [2]:
import os
# Work from the cloned repository so that relative paths (e.g. ./data) and the
# `src.*` imports used by later cells resolve against it.
os.chdir('/content/Gen-AI-for-Data-Augmentation')

In [None]:
# Print the current working directory.
!pwd

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells read and write files under
# /content/drive/MyDrive.
drive.mount('/content/drive')

## Import Libraries

In [None]:
# Install the extra packages this notebook needs.
!pip install clean-fid
!pip install git+https://github.com/openai/CLIP.git
# NOTE(review): transformers is installed twice, first from PyPI and then from
# the GitHub main branch; presumably the GitHub build is the one that ends up
# installed. Confirm which version is intended.
!pip install -q transformers datasets accelerate sentencepiece
!pip install -q git+https://github.com/huggingface/transformers

In [6]:
from pathlib import Path
import random
import matplotlib.pyplot as plt
import torch
import json
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForCausalLM
from src.data.dataset import PetDatasetHandler
from src.captioning.caption_generator import CaptionGenerator
from src.captioning.captioning_blip_2 import BLIP2CaptionGenerator
from src.captioning.git_caption_generator import GITCaptionGenerator
from src.data.data_with_captions import PetDatasetWithCaptions
from src.generation.text_generation import TextVariationGenerator
from src.utils.logging import GANLogger
from src.generation.image_generator import GANConfig, ConditionalGAN
from src.training.callbacks import EarlyStopping, ModelCheckpoint, MetricsHistory
from src.evaluation.metrics import FIDScore, CLIPScore, MetricsTracker
from src.training.training import GANTrainer

## Initialize and load dataset without transforms for analysis

In [None]:
# Point the handler at ./data and load both splits; no transform is passed here.
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Basic dataset information

In [None]:
# Print every key/value pair reported by the handler.
info = handler.get_dataset_info()
print("Dataset Information:")
for key, value in info.items():
    print(f"{key}: {value}")

### Plot distributions and samples

In [None]:
# Display the object returned by plot_class_distribution() via its .show().
handler.plot_class_distribution().show()

In [None]:
# Display the object returned by visualize_samples(); 9 is presumably the
# number of samples shown (TODO confirm against the handler).
handler.visualize_samples(9).show()

### Get detailed image statistics


In [None]:
# `stats` is iterated as a two-level mapping: category -> {name: value}.
stats = handler.get_image_stats(sample_size=100)
print("\nImage Statistics:")
for category, values in stats.items():
    print(f"\n{category.upper()}:")
    for key, value in values.items():
        # The :.2f format requires every value to be numeric.
        print(f"{key}: {value:.2f}")

### For training, load with transforms

In [None]:
# Reload both splits, this time passing the handler's training transforms.
train_transforms = handler.get_training_transforms()
train_dataset, test_dataset = handler.load_dataset(transform=train_transforms)

## Image Captioning

### Compare caption generators: 
1. **Blip**
2. **Blip-2**
3. **GIT**

In [None]:
# Configurations

def get_random_images(image_dir, count=5):
    """Return up to ``count`` randomly chosen ``*.jpg`` paths from ``image_dir``.

    Only files directly inside ``image_dir`` are considered (the glob is not
    recursive). When fewer than ``count`` images exist, all of them are
    returned in random order.
    """
    candidates = [*Path(image_dir).glob("*.jpg")]
    sample_size = count if count < len(candidates) else len(candidates)
    return random.sample(candidates, sample_size)

# Rebuild the handler and reload the splits (same calls as in the analysis
# section); the returned datasets are not used further in this cell.
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

# Relative path: resolved against the current working directory.
image_dir = "data/oxford-iiit-pet/images"
# Up to 10 random images shared by all three caption-model tests below.
sample_images = get_random_images(image_dir, count=10)
print(f"Selected {len(sample_images)} random images")

In [None]:
# Test the original BLIP model on the sampled images

print("Test di BLIP...")
blip_model = CaptionGenerator()
# Captions are keyed by str(path), the same key visualize_comparison looks up.
blip_captions = {}

for img_path in sample_images:
    caption = blip_model.generate_caption(str(img_path))
    blip_captions[str(img_path)] = caption
    print(f"BLIP - {img_path.name}: {caption}")

In [None]:
# Test the BLIP-2 model on the sampled images
# Caution: BLIP-2 is compute- and memory-intensive; run this cell only if
# enough computational resources are available.

print("\nTest di BLIP-2...")
blip2_model = BLIP2CaptionGenerator(model_name="Salesforce/blip2-opt-2.7b")
# Captions are keyed by str(path), the same key visualize_comparison looks up.
blip2_captions = {}

for img_path in sample_images:
    caption = blip2_model.generate_caption(str(img_path))
    blip2_captions[str(img_path)] = caption
    print(f"BLIP-2 - {img_path.name}: {caption}")

In [None]:
# Test the GIT model on the sampled images, calling transformers directly

print("\nTest di GIT...")
git_model_name = "microsoft/git-base-coco"
processor_git = AutoProcessor.from_pretrained(git_model_name)
model_git = AutoModelForCausalLM.from_pretrained(git_model_name)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_git = model_git.to(device)

# Captions are keyed by str(path), the same key visualize_comparison looks up.
git_captions = {}

for img_path in sample_images:
    image = Image.open(img_path).convert("RGB")
    # Preprocess the image and move the resulting tensors to the model's device.
    inputs_git = processor_git(images=image, return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_ids = model_git.generate(
            pixel_values=inputs_git.pixel_values,
            max_length=50,
            num_beams=5
        )

    # A single image was passed, so the first decoded sequence is its caption.
    caption = processor_git.batch_decode(generated_ids, skip_special_tokens=True)[0]
    git_captions[str(img_path)] = caption
    print(f"GIT - {img_path.name}: {caption}")

In [None]:
# Visualize images with captions from three models (side by side)

from textwrap import wrap

def visualize_comparison(image_paths, blip_captions, blip2_captions, git_captions, wrap_width=30):
    """Show each image three times, captioned by BLIP, BLIP-2 and GIT.

    One grid row is drawn per image and one column per model; the model name
    is the title of the first row and the wrapped caption is placed under
    each image as its x-label.

    Args:
        image_paths: Image paths to display.
        blip_captions: Mapping ``str(path) -> caption`` produced by BLIP.
        blip2_captions: Mapping ``str(path) -> caption`` produced by BLIP-2.
        git_captions: Mapping ``str(path) -> caption`` produced by GIT.
        wrap_width: Maximum number of characters per caption line.

    Raises:
        KeyError: If a path is missing from one of the caption mappings.
    """
    n_images = len(image_paths)
    if n_images == 0:
        # plt.subplots cannot build a grid with zero rows; nothing to show.
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image, so
    # axes[row, col] indexing is always valid.
    _fig, axes = plt.subplots(n_images, 3, figsize=(15, 5 * n_images), squeeze=False)

    columns = (
        ("BLIP", blip_captions),
        ("BLIP-2", blip2_captions),
        ("GIT", git_captions),
    )
    for col, (model_name, _) in enumerate(columns):
        axes[0, col].set_title(model_name, fontsize=14)

    for row, img_path in enumerate(image_paths):
        key = str(img_path)
        img = Image.open(img_path).convert("RGB")

        for col, (_, captions) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(img)
            # Hide ticks and frame individually instead of calling
            # axis('off'): turning the axis off also hides the x-label,
            # which is where the caption is drawn.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_xlabel("\n".join(wrap(captions[key], wrap_width)), fontsize=12)

    plt.tight_layout()
    plt.show()


# Render the side-by-side comparison for the sampled images.
visualize_comparison(sample_images, blip_captions, blip2_captions, git_captions)

### Initialize caption generator

In [None]:
# Create the GIT-based caption generator with its default arguments.
caption_gen = GITCaptionGenerator()

### Load Dataset (If you haven’t done it before)

In [None]:
# Load both splits without a transform (same calls as in the analysis section).
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Test single image caption generation

In [None]:
# Pick a random training index.
sample = random.randint(0, len(train_dataset)-1)
# NOTE: `_images` is a private attribute of the dataset object; this relies on
# its internals.
sample_image_path = Path(train_dataset._images[sample])
# Element [1] of a dataset item is used as an index into `classes` to get the
# class name.
label = train_dataset.classes[train_dataset[sample][1]]
caption = caption_gen.generate_caption(sample_image_path, label, max_length = 50)
print(f"Sample caption: {caption}")

In [None]:
# Show the sampled image with its class name as the title.
plt.figure(figsize=(10, 10))
plt.title(f"{label}")
# The value returned by imshow is stored in `fig` but not used in this cell.
fig = plt.imshow(train_dataset[sample][0])

### Process a batch of images


In [None]:
batch_size = 4
# First 10 training images and their class names; range(10) assumes the split
# holds at least 10 items.
image_paths = [Path(img) for img in train_dataset._images[:10]]
labels = [train_dataset.classes[train_dataset[i][1]] for i in range(10)]
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

### Process train and test datasets

In [None]:
batch_size = 4
# Every test image path, paired with its class name.
image_paths = [Path(img) for img in test_dataset._images]
# NOTE(review): each label lookup indexes the dataset (test_dataset[i]), which
# presumably loads the image as well and may be slow; confirm.
labels = [test_dataset.classes[test_dataset[i][1]] for i in range(len(image_paths))]

In [None]:
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

# Define and create the output directory here: on a fresh kernel `save_dir`
# does not exist yet when this cell runs, so saving raised a NameError.
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)
caption_gen.save_captions(save_dir / 'captions_testdataset.json')

### Save & Load captions


In [None]:
# Create the captions folder on Google Drive if needed, then save the
# generator's captions there.
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)
caption_gen.save_captions(save_dir / 'captions_testdataset.json')

In [None]:
# NOTE(review): this loads 'captions.json', while the cells above save
# 'captions_testdataset.json'; confirm which file is intended.
caption_gen.load_captions(save_dir / 'captions.json')

### Visualize results

In [None]:
# Visualize captions, requesting 4 samples.
caption_gen.visualize_captions(num_samples=4)

### Print some statistics

In [None]:
# `captions_cache` is used as a mapping of image path -> caption.
print(f"\nTotal captions generated: {len(caption_gen.captions_cache)}")
print("\nSample of generated captions:")
# Show only the first three cached entries.
for path, caption in list(caption_gen.captions_cache.items())[:3]:
    print(f"\nImage: {Path(path).name}")
    print(f"Caption: {caption}")

## Text Generation with Flan-T5 to increase the number of captions

#### Initialize Text Generator

In [None]:
# Create the caption-variation generator backed by google/flan-t5-large.
generator = TextVariationGenerator(model_name="google/flan-t5-large")

#### How it works
> Comparison of different types of prompts

In [None]:
# Compare the generator's prompt types on one fixed example caption.
results = generator.test_prompt_types("A white dog sitting on a brown chair", temperature=0.95, num_variations=5)

#### Test Quality for Few-Shot Prompting

In [None]:
# NOTE(review): the caption under test is the single character "c", which
# looks like a placeholder or a truncated string; confirm the intended caption.
original_caption = "c"

variations = generator.test_few_shot_quality(
    original_caption,
    num_variations=3,
    temperature=1
)

#### Load saved captions and datasets

In [None]:
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json')

# Parse the saved caption file (JSON) into `captions`.
captions = json.loads(caption_file.read_text())

In [None]:
# Load the dataset splits without a transform

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

#### Test the variations on random captions

In [None]:
# Draw up to three random (image path, caption) pairs; capping the sample size
# avoids a ValueError when fewer than three captions are available.
caption_items = list(captions.items())
sample_captions = dict(random.sample(caption_items, min(3, len(caption_items))))

for img_path, caption in sample_captions.items():
    # "\n" starts each entry on a new line; the previous "\I" was an invalid
    # escape sequence that printed a literal backslash.
    print(f"\nImage: {Path(img_path).name}")
    print(f"Original Caption: {caption}")

    variations = generator.generate_variations(
        caption,
        num_variations=3,
        temperature=1,
        prompt_type="few-shot"
    )

    print("New Versions:")
    for i, var in enumerate(variations):
        print(f"{i+1}. {var}")

In [None]:
# Display images with caption variations

def visualize_caption_variations(image_path, caption, variations):
    """Display one image titled with its original caption and numbered variations.

    Args:
        image_path: Path of the image to show.
        caption: Original caption, printed first in the title.
        variations: Iterable of caption variations, listed below it as
            ``1. ...``, ``2. ...``.
    """
    plt.figure(figsize=(10, 12))

    picture = Image.open(image_path).convert("RGB")
    plt.subplot(1, 1, 1)
    plt.imshow(picture)
    plt.axis('off')

    # Build the multi-line title: header first, then one numbered line per variation.
    numbered_lines = [f"{number}. {text}\n" for number, text in enumerate(variations, start=1)]
    heading = f"Original: {caption}\n\nVariations:\n" + "".join(numbered_lines)

    plt.title(heading, fontsize=10, loc='left')
    plt.tight_layout()
    plt.show()

# `variations` still holds the output for the LAST caption processed by the
# loop in the previous cell, so pair it with that same (last) entry. Picking
# the first entry would show one image with another caption's variations.
sample_img_path = list(sample_captions.keys())[-1]
visualize_caption_variations(
    sample_img_path,
    sample_captions[sample_img_path],
    variations
)

#### Process entire caption file

In [None]:
# Destination JSON for the generated caption variations.
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

print("Generation of variations for all captions...")
# Process every caption in `caption_file`; all options below are passed
# through to TextVariationGenerator.process_caption_file.
variations_dict = generator.process_caption_file(
    caption_file=caption_file,
    output_file=output_file,
    variations_per_caption=3,
    class_balancing=True,
    target_per_class=150,
    min_variations=1,
    max_variations=5,
    prompt_type="few-shot",
    temperature=1
)

print(f"Generate variations for {len(variations_dict)} caption")
print(f"File saved in: {output_file}")

#### Evaluate captions generated

In [None]:
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

# Load the variations file: image path -> list of caption variations.
variations_data = json.loads(output_file.read_text())

# Number of variations stored for each image path.
variations_counts = {img_path: len(vars_list) for img_path, vars_list in variations_data.items()}

# Statistics
counts = list(variations_counts.values())
total_variations = sum(counts)
avg_variations = total_variations / len(counts)
max_variations = max(counts)
min_variations = min(counts)

print("Statistics of variations:")
print(f"Total original captions: {len(variations_counts)}")
print(f"Total generated variations: {total_variations}")
print(f"Mean variation per caption: {avg_variations:.2f}")
print(f"Max variations per caption: {max_variations}")
print(f"Min variations per caption: {min_variations}")

In [None]:
# Extract the breed from the captions and calculate statistics by breed

def extract_breed(caption):
    """Return the breed named in a ``This is a/an <breed>.`` sentence, or ``None``.

    The breed is the text between the article and the next full stop.
    """
    import re

    found = re.search(r"This is an? ([^\.]+)\.", caption)
    if found is None:
        return None
    return found.group(1)

# Count the original captions per breed; captions without a recognizable
# breed sentence are skipped.
breed_counts_original = {}
for caption in captions.values():
    breed = extract_breed(caption)
    if breed:
        breed_counts_original[breed] = breed_counts_original.get(breed, 0) + 1

# Start from the original counts and add every generated variation, so this
# dictionary holds originals + variations per breed.
breed_counts_variations = breed_counts_original.copy()
for var_list in variations_data.values():
    for var in var_list:
        breed = extract_breed(var)
        if breed:
            breed_counts_variations[breed] = breed_counts_variations.get(breed, 0) + 1

# Only breeds present in the original captions are plotted.
breeds = sorted(breed_counts_original.keys())
original_counts = [breed_counts_original[breed] for breed in breeds]
total_counts = [breed_counts_variations[breed] for breed in breeds]

# Grouped bar chart: two bars of width 0.4 per breed, tick centred between them.
plt.figure(figsize=(15, 10))
x = range(len(breeds))
plt.bar(x, original_counts, width=0.4, align='edge', label='Original')
plt.bar([i+0.4 for i in x], total_counts, width=0.4, align='edge', label='After Variations')
plt.xticks([i+0.2 for i in x], breeds, rotation=90)
plt.xlabel('Breed')
plt.ylabel('Count')
plt.title('Dataset Augmentation by Breed')
plt.legend()
plt.tight_layout()
plt.show()

#### Balanced Subset selection

In [None]:
# Input files (original captions, generated variations) and the output file
# for the balanced subset, all on Google Drive.
balanced_output = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/balanced_flan_t5_variations.json')
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json')

# Build the subset with a per-class target of 150 (passed through to
# TextVariationGenerator.select_balanced_subset).
balanced_subset = generator.select_balanced_subset(
    caption_file=caption_file,
    variations_file=output_file,
    output_file=balanced_output,
    target_per_class=150
)

print(f"Created balanced subset with {len(balanced_subset)} captions")
print(f"File saved in: {balanced_output}")

In [None]:
# Count the captions of the balanced subset per breed; captions without a
# recognizable breed sentence are skipped.
breed_counts_balanced = {}
for caption in balanced_subset.values():
    breed = extract_breed(caption)
    if breed:
        breed_counts_balanced[breed] = breed_counts_balanced.get(breed, 0) + 1

breeds = sorted(breed_counts_balanced.keys())
balanced_counts = [breed_counts_balanced.get(breed, 0) for breed in breeds]

# One bar per breed, in alphabetical order.
plt.figure(figsize=(15, 10))
plt.bar(breeds, balanced_counts)
plt.xlabel('Breed')
plt.ylabel('Count')
plt.title('Balanced Subset by Breed')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

## Image Generation with Diffusion Models

### Zero-Shot Testing

In [None]:
# Import Libraries

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, AutoPipelineForText2Image, FluxPipeline
from PIL import Image
import matplotlib.pyplot as plt
import json
import random
from pathlib import Path
import numpy as np

In [None]:
# Function to load and test different models
def test_diffusion_model(model_id, caption, num_images=1, seed=None):
    """
    Generate and display images from a caption with a pretrained diffusion model.

    The pipeline class is chosen from substrings of ``model_id`` ("xl",
    "kandinsky", "flux"; anything else uses ``StableDiffusionPipeline``) and
    the weights are always loaded from ``model_id`` itself. The pipeline is
    moved to CUDA, so a GPU is required.

    Args:
        model_id (str): Model ID to be tested
        caption (str or list of str): Caption to use. When a list/tuple is
            given, its first entry is used. A caption shaped like
            ``"<text> - This is a/an <breed>."`` is rewritten so that the
            breed leads the prompt.
        num_images (int): Number of images to generate
        seed (int, optional): Base seed; image ``i`` is generated with
            ``seed + i``

    Returns:
        list: List of generated images
    """
    import re  # local import: `re` is not among the notebook's top-level imports

    print(f"Loading the model: {model_id}")
    model_key = model_id.lower()

    if "xl" in model_key:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        )
    elif "kandinsky" in model_key:
        # Load the requested checkpoint rather than a hard-coded one.
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
    elif "flux" in model_key:
        pipe = FluxPipeline.from_pretrained(
            model_id, torch_dtype=torch.bfloat16
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16
        )

    pipe = pipe.to("cuda")

    # Save memory
    pipe.enable_attention_slicing()
    if hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()

    # Accept a list of captions as well as a plain string. Indexing a string
    # with [0] only yields its first character, which is why the breed
    # sentence was never detected before.
    if isinstance(caption, (list, tuple)):
        caption = caption[0] if caption else ""

    # Split "<description> - This is a/an <breed>." into its two halves; the
    # optional "n" covers breeds introduced with "an".
    parts = re.split(r" - This is an? ", caption, maxsplit=1)
    if len(parts) == 2:
        cleaned_caption, breed_info = parts[0], parts[1].strip(".")
        prompt = f"A high-quality photo of a {breed_info}, {cleaned_caption}"
    else:
        prompt = f"A high-quality photo of {caption}"

    print(f"Prompt: {prompt}")

    images = []
    for i in range(num_images):
        generator = None
        if seed is not None:
            # A distinct but reproducible seed for every image.
            generator = torch.Generator(device="cuda").manual_seed(seed + i)

        if "kandinsky" in model_key:
            image = pipe(prompt, generator=generator).images[0]
        elif "flux" in model_key:
            image = pipe(
                prompt,
                height=512,
                width=512,
                guidance_scale=3.5,
                num_inference_steps=50,
                max_sequence_length=512,
                generator=generator
            ).images[0]
        else:
            image = pipe(
                prompt,
                guidance_scale=7.5,
                num_inference_steps=30,
                generator=generator
            ).images[0]

        images.append(image)

    if not images:
        # Nothing to plot; plt.subplots cannot build a grid with zero columns.
        return images

    # squeeze=False keeps `axes` two-dimensional even for a single image.
    _fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5), squeeze=False)

    for i, img in enumerate(images):
        ax = axes[0, i]
        ax.imshow(np.array(img))
        ax.set_title(f"Image {i+1}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

    return images

In [None]:
# Load the balanced caption subset saved earlier.
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/balanced_flan_t5_variations.json')
with open(caption_file, 'r') as f:
    balanced_captions = json.load(f)

# Draw three random captions (the dictionary values) for the zero-shot test.
sample_captions = random.sample(list(balanced_captions.values()), 3)
print("Caption selected for this test:")
for i, caption in enumerate(sample_captions):
    print(f"{i+1}. {caption}")

In [None]:
# Candidate models for zero-shot generation.
models_to_test = [
    "stabilityai/stable-diffusion-2-1-base",
    #"stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "kandinsky-community/kandinsky-2-2-decoder",
    "black-forest-labs/FLUX.1-dev"
]

# Only the first model in the list is exercised; change the index to try
# another one. The model is loaded again on every call.
for caption in sample_captions:
    test_diffusion_model(models_to_test[0], caption)

### LoRA Fine Tuning
> Take a look [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README.md)

> See this [link](https://huggingface.co/docs/diffusers/v0.13.0/en/training/lora) for an explanation of LoRA and a worked example

In [None]:
# Install diffusers from a local clone of its repository

!git clone https://github.com/huggingface/diffusers
# NOTE: each `!` command runs in its own subshell, so this `cd` does not change
# the notebook's working directory. The install below is unaffected because it
# uses an absolute path.
!cd diffusers
!pip install /content/Gen-AI-for-Data-Augmentation/diffusers

In [None]:
# Import Libraries 

import os
import json
import random
import logging
from pathlib import Path
import shutil
import subprocess

In [None]:
# Set up logging: INFO level, "timestamp - level - message" format

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Logger for this module/notebook.
logger = logging.getLogger(__name__)

In [None]:
def prepare_dataset(
    captions_file,
    images_dir,
    output_dir,
    max_samples_per_breed=None,
    min_samples_per_breed=3,
    target_total_samples=100
):
    """
    Prepare an image/caption dataset for LoRA Stable Diffusion fine-tuning.

    Captions shaped like ``"<text> - This is a/an <breed>."`` are grouped by
    breed, a random subset is drawn per breed, the selected images are copied
    into ``output_dir`` as ``image_000000.jpg``, ... and a ``metadata.jsonl``
    (one ``{"file_name", "text"}`` object per line) is written next to them.
    ``file_name`` is relative to the folder that holds ``metadata.jsonl``.

    Args:
        captions_file: Path to the JSON file mapping image path -> caption
        images_dir: Directory containing the images; only the file name of
            each caption key is used to locate the image here
        output_dir: Output directory for the prepared dataset
        max_samples_per_breed: Maximum number of samples per breed
            (``None`` or 0 means no cap)
        min_samples_per_breed: Lower bound applied to the per-breed sample
            count after the cap; it can never exceed the images available
        target_total_samples: Total target number of samples; when the
            selection is larger it is shuffled and truncated to this size

    Returns:
        ``output_dir``, exactly as it was passed in.
    """
    import re  # local import: `re` is not among the notebook's top-level imports

    images_output_dir = Path(output_dir)
    images_output_dir.mkdir(parents=True, exist_ok=True)

    with open(captions_file, "r", encoding="utf-8") as f:
        captions = json.load(f)

    # Accept both "This is a <breed>" and "This is an <breed>" so breeds that
    # take "an" are not silently dropped.
    breed_marker = re.compile(r" - This is an? ")

    breed_samples = {}
    for img_path, caption in captions.items():
        parts = breed_marker.split(caption, maxsplit=1)
        if len(parts) != 2:
            continue

        breed = parts[1].strip(".")
        samples = breed_samples.setdefault(breed, [])

        # Captions may have been generated elsewhere: look the file up by
        # name inside images_dir and skip images that are not there.
        full_img_path = Path(images_dir) / Path(img_path).name
        if full_img_path.exists():
            samples.append((str(full_img_path), caption))

    # Select a balanced subset
    selected_samples = []
    for samples in breed_samples.values():
        cap = max_samples_per_breed if max_samples_per_breed else len(samples)
        num_samples = max(min(len(samples), cap), min_samples_per_breed)

        # Never request more samples than the breed actually has.
        selected_samples.extend(random.sample(samples, min(num_samples, len(samples))))

    # Limit with target_total_samples
    if target_total_samples and len(selected_samples) > target_total_samples:
        random.shuffle(selected_samples)
        selected_samples = selected_samples[:target_total_samples]

    print(f"Selected {len(selected_samples)} samples from {len(breed_samples)} breed")

    # Copy the images and build the metadata.jsonl entries.
    metadata = []
    for i, (img_path, caption) in enumerate(selected_samples):
        dest_filename = f"image_{i:06d}.jpg"
        shutil.copy(img_path, images_output_dir / dest_filename)

        metadata.append({
            # The images are copied next to metadata.jsonl, so the relative
            # file name is just the bare name. The previous "images/" prefix
            # pointed at a folder that was never created.
            "file_name": dest_filename,
            "text": caption
        })

    with open(images_output_dir / "metadata.jsonl", "w", encoding="utf-8") as f:
        for item in metadata:
            f.write(json.dumps(item) + "\n")

    print(f"Dataset prepared in {output_dir}")
    return output_dir

In [None]:
def run_lora_training(
    dataset_dir,
    output_dir,
    base_model="runwayml/stable-diffusion-v1-5",
    resolution=512,
    train_batch_size=1,
    max_train_steps=1000,
    validation_prompts=None
):
    """
    Fine-tune LoRA weights with Hugging Face's ``train_text_to_image_lora.py``.

    Installs diffusers from GitHub, downloads the training script into the
    current working directory and runs it through ``accelerate launch``. The
    script's stdout and stderr are printed once it finishes.

    Args:
        dataset_dir: Directory passed as ``--train_data_dir``
        output_dir: Directory passed as ``--output_dir``
        base_model: Model passed as ``--pretrained_model_name_or_path``
        resolution: Value passed as ``--resolution``
        train_batch_size: Value passed as ``--train_batch_size``
        max_train_steps: Value passed as ``--max_train_steps``
        validation_prompts: Prompts joined with ``"; "`` into a single
            ``--validation_prompt``; three generic pet prompts when ``None``

    Raises:
        subprocess.CalledProcessError: If downloading the training script or
            the training run itself exits with a non-zero status.
    """
    import sys  # local import: `sys` is not among the notebook's top-level imports

    # Correct diffusers version if you didn't run it before. Use
    # `python -m pip` so the package lands in the interpreter running this
    # code. This stays best effort: an already installed diffusers may still
    # work, so a failure only produces a warning.
    install = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "git+https://github.com/huggingface/diffusers.git"]
    )
    if install.returncode != 0:
        print(f"Warning: diffusers installation exited with code {install.returncode}")

    # We need to locally download the most up-to-date version of train_text_to_image_lora.py.
    # A failed download must stop here: `wget -O` leaves an empty or partial
    # file behind, which would otherwise be launched as the training script.
    subprocess.run(
        [
            "wget",
            "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/text_to_image/train_text_to_image_lora.py",
            "-O",
            "train_text_to_image_lora.py",
        ],
        check=True,
    )

    # TO-DO: improve validation_prompts if validation_prompts is None
    if validation_prompts is None:
        validation_prompts = [
            "A high-quality photo of a dog",
            "A high-quality photo of a cat",
            "A close-up portrait of a pet"
        ]

    # Command to use train_text_to_image_lora.py
    cmd = [
        "accelerate", "launch",
        "train_text_to_image_lora.py",
        f"--pretrained_model_name_or_path={base_model}",
        f"--train_data_dir={dataset_dir}",
        f"--output_dir={output_dir}",
        f"--resolution={resolution}",
        "--center_crop",
        "--random_flip",
        f"--train_batch_size={train_batch_size}",
        "--gradient_accumulation_steps=4",
        "--gradient_checkpointing",
        "--mixed_precision=fp16",
        f"--max_train_steps={max_train_steps}",
        "--learning_rate=1e-04",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--validation_epochs=100",
        # The command is run without a shell, so the argument must not carry
        # quote characters: they would become part of the prompt text.
        f"--validation_prompt={'; '.join(validation_prompts)}",
        "--seed=42",
        "--checkpointing_steps=500"
    ]

    print("Start of LoRA training...")
    print(f"Command: {' '.join(cmd)}")
    process = subprocess.run(cmd, capture_output=True, text=True)
    print(f"Standard output: {process.stdout}")
    print(f"Standard error: {process.stderr}")

    if process.returncode != 0:
        # Report the failure instead of claiming that training completed.
        print(f"Training failed with exit code {process.returncode}")
        process.check_returncode()

    print(f"Training completed. Results in {output_dir}")

In [None]:
# Configuration + dataset preparation (the validation-prompt helper is defined below)

captions_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json"
images_dir = "/content/Gen-AI-for-Data-Augmentation/data/oxford-iiit-pet/images"
dataset_dir = "/content/Gen-AI-for-Data-Augmentation/lora_dataset"
output_dir = "/content/Gen-AI-for-Data-Augmentation/lora_model"

# Build the fine-tuning dataset in `dataset_dir`.
# NOTE(review): 37*10 is presumably 37 breeds x 10 samples each; confirm.
prepare_dataset(
    captions_file,
    images_dir,
    dataset_dir,
    max_samples_per_breed=10,
    min_samples_per_breed=3,
    target_total_samples=37*10
)

def select_validation_prompts_from_variations(variations_file, num_prompts=5, seed=42):
    """Select some validation prompts from the generated variations.

    Args:
        variations_file: JSON file mapping each key to a list of caption
            variations.
        num_prompts: Number of prompts to draw; capped at the number of
            variations available.
        seed: Seed that makes the selection reproducible.

    Returns:
        list: The randomly selected captions.
    """
    with open(variations_file, 'r') as f:
        variations = json.load(f)

    # Flatten the per-image lists into one pool of candidate prompts.
    all_captions = [
        text
        for variations_list in variations.values()
        for text in variations_list
    ]

    # A private RNG yields the same selection as seeding the global one, but
    # leaves the module-level `random` state untouched for the rest of the
    # notebook.
    rng = random.Random(seed)
    return rng.sample(all_captions, min(num_prompts, len(all_captions)))

In [None]:
# Start LoRA training

# Validation prompts are drawn from the generated caption variations.
variations_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json"
validation_prompts = select_validation_prompts_from_variations(variations_file)

# NOTE(review): resolution=128 and max_train_steps=800 override the function
# defaults (512 and 1000); presumably chosen to fit the available hardware.
# Confirm.
run_lora_training(
    dataset_dir,
    output_dir,
    base_model = "runwayml/stable-diffusion-v1-5",
    resolution = 128,
    train_batch_size = 1,
    max_train_steps=800,
    validation_prompts=validation_prompts
)

### Inference 

In [None]:
# Import Libraries

import torch
import json
import random
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from typing import List, Union, Optional

In [None]:
def load_lora_model(
    base_model_id: str = "runwayml/stable-diffusion-v1-5",
    lora_weights_path: Optional[str] = None,
    device: str = "cuda",
    torch_dtype = torch.float16
):
    """
    Load a Stable Diffusion pipeline and optionally apply LoRA weights.

    Args:
        base_model_id (str): Base Stable Diffusion model ID
        lora_weights_path (str): Path to trained LoRA weights; when empty or
            ``None`` the plain base model is returned
        device (str): Device to use (cuda or cpu)
        torch_dtype: Precision to use (float16 or float32)

    Returns:
        pipeline: Stable Diffusion pipeline, with the LoRA weights loaded
        when a path was given
    """
    print(f"Loading base model {base_model_id}...")

    pipeline = StableDiffusionPipeline.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype
    )

    # Swap in the DPM-Solver multistep scheduler, built from the config of
    # the scheduler the pipeline shipped with.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

    if lora_weights_path:
        print(f"Loading LoRA weights from {lora_weights_path}...")
        if hasattr(pipeline, "load_lora_weights"):
            # Pipeline-level LoRA loader; diffusers deprecates
            # unet.load_attn_procs for LoRA checkpoints in its favor.
            pipeline.load_lora_weights(lora_weights_path)
        else:
            # Older diffusers releases only offer the UNet-level loader.
            pipeline.unet.load_attn_procs(lora_weights_path)
        print("LoRA weights loaded with success!")

    pipeline.to(device)

    # Optimize memory
    pipeline.enable_attention_slicing()

    return pipeline

In [None]:
def generate_images(
    pipeline,
    prompts: Union[str, List[str]],
    output_dir: Optional[str] = None,
    num_images_per_prompt: int = 1,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 30,
    seed: Optional[int] = None,
    width: int = 512,
    height: int = 512,
    save_images: bool = True,
    display_images: bool = True,
    batch_size: int = 1
):
    """
    Generate images for one or more prompts with the given pipeline.

    Args:
        pipeline: Stable Diffusion Pipeline
        prompts: List of prompts or single prompt
        output_dir: Directory where to save images; nothing is saved when it
            is empty or ``None``
        num_images_per_prompt: Number of images to be generated for each prompt
        guidance_scale: Guidance factor (higher = more faithful to the text)
        num_inference_steps: Number of inference steps
        seed: Seed for the generation; batch ``b`` uses ``seed + b``
        width: Image width
        height: Image height
        save_images: Whether to save images on disk
        display_images: Whether to display images
        batch_size: Batch size for generation

    Returns:
        Dict[str, List[Image.Image]]: Dictionary that maps prompts to
        generated images. Images of a prompt that occurs more than once are
        accumulated under the same key.
    """
    import hashlib  # local import: `hashlib` is not among the notebook's top-level imports

    if isinstance(prompts, str):
        prompts = [prompts]

    output_path = None
    if save_images and output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    results = {}
    all_images = []
    all_prompts = []

    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]

        generator = None
        if seed is not None:
            generator = torch.Generator(device=pipeline.device).manual_seed(seed)
            # Next batch gets a different, still reproducible, seed.
            seed += 1

        print(f"Batch Generation {i//batch_size + 1}/{(len(prompts)-1)//batch_size + 1}...")
        batch_results = pipeline(
            batch_prompts,
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            width=width,
            height=height
        )

        for j, prompt in enumerate(batch_prompts):
            # The pipeline returns one flat list: take this prompt's slice.
            images = batch_results.images[j*num_images_per_prompt:(j+1)*num_images_per_prompt]
            # Accumulate instead of overwriting so a repeated prompt keeps
            # the images from its earlier occurrences.
            results.setdefault(prompt, []).extend(images)
            all_images.extend(images)
            all_prompts.extend([prompt] * num_images_per_prompt)

            if output_path is not None:
                # A digest-based tag is identical across runs; the built-in
                # hash() of a string is salted per process, which made the
                # file names unpredictable.
                digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
                prompt_hash = int(digest, 16) % 10000
                for k, img in enumerate(images):
                    stem = f"gen_{prompt_hash}_{i}_{j}_{k}"
                    img.save(output_path / f"{stem}.png")

                    # Keep the prompt next to the image for traceability.
                    with open(output_path / f"{stem}.txt", "w", encoding="utf-8") as f:
                        f.write(prompt)

    # Skip plotting when nothing was generated: with zero images the grid
    # computation below would divide by zero.
    if display_images and all_images:
        n_cols = min(5, len(all_images))
        n_rows = (len(all_images) + n_cols - 1) // n_cols

        plt.figure(figsize=(n_cols * 5, n_rows * 6))

        for idx, (image, prompt) in enumerate(zip(all_images, all_prompts)):
            plt.subplot(n_rows, n_cols, idx + 1)
            plt.imshow(np.array(image))
            # Long prompts are truncated to 50 characters in the title.
            plt.title(prompt[:50] + "..." if len(prompt) > 50 else prompt, fontsize=8)
            plt.axis("off")

        plt.tight_layout()
        plt.show()

    return results

In [None]:
def generate_from_variations(
    pipeline,
    variations_file: str,
    output_dir: str,
    num_samples: int = 5,
    num_images_per_prompt: int = 1,
    seed: Optional[int] = None,
    random_variations: bool = True
):
    """
    Generate images for caption variations read from a JSON file.

    Args:
        pipeline: Stable Diffusion Pipeline
        variations_file: JSON file mapping each key to a list of variations
        output_dir: Directory where to save images
        num_samples: Number of variations to use as prompts
        num_images_per_prompt: Number of images per prompt
        seed: Seed forwarded to ``generate_images``
        random_variations: Draw the variations at random when true, otherwise
            take the first ``num_samples`` in file order

    Returns:
        Dict[str, List[Image.Image]]: Dictionary that maps prompts to generated images
    """
    with open(variations_file, "r") as f:
        variations_data = json.load(f)

    # Flatten the per-image lists into a single pool of prompts.
    all_variations = [
        text
        for variations in variations_data.values()
        for text in variations
    ]

    if not random_variations:
        selected_variations = all_variations[:num_samples]
    else:
        sample_size = min(num_samples, len(all_variations))
        selected_variations = random.sample(all_variations, sample_size)

    return generate_images(
        pipeline,
        selected_variations,
        output_dir,
        num_images_per_prompt=num_images_per_prompt,
        seed=seed
    )

In [None]:
# Configuration + load the model + test with custom prompts

base_model_id = "runwayml/stable-diffusion-v1-5"
lora_weights_path = "/content/drive/MyDrive/outputs_master_ProfAI/lora_model/pytorch_lora_weights.safetensors"
variations_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json"
output_dir = "/content/drive/MyDrive/outputs_master_ProfAI/generated_images"

# Base model with the LoRA weights applied (default device and dtype).
pipeline = load_lora_model(
    base_model_id=base_model_id,
    lora_weights_path=lora_weights_path
)

# Hand-written prompts used as a quick check of the fine-tuned model.
test_prompts = [
    "A high-quality photograph of a British Shorthair cat sitting on a windowsill",
    "A detailed image of a Beagle dog playing in a park",
    "A professional photo of a Maine Coon cat with fluffy fur"
]

# Images are saved under <output_dir>/test_prompts.
generate_images(
    pipeline,
    test_prompts,
    output_dir=output_dir + "/test_prompts",
    num_images_per_prompt=1,
    seed=42
)

In [None]:
# Generate images from the previously generated caption variations

# Ten variations are used as prompts; images are saved under
# <output_dir>/variations.
generate_from_variations(
    pipeline,
    variations_file,
    output_dir=output_dir + "/variations",
    num_samples=10,
    num_images_per_prompt=1,
    seed=42
)

### Unique Class

In [None]:
import os
import torch
import json
import random
import shutil
import subprocess
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from typing import List, Dict, Union, Optional, Tuple, Any
from diffusers import (
    StableDiffusionPipeline, 
    StableDiffusionXLPipeline,
    DPMSolverMultistepScheduler,
    DDPMScheduler,
    AutoPipelineForText2Image, 
    FluxPipeline
)

class DiffusionModelManager:
    """
    Unified manager class for diffusion models that handles:
    - Zero-shot testing with different models
    - Dataset preparation for LoRA fine-tuning
    - LoRA fine-tuning process
    - Inference with fine-tuned models
    - Image generation from text variations
    
    This class centralizes all functionality related to diffusion models
    in a single interface for easier workflow management.
    """
    
    def __init__(
        self,
        base_models_dir: Optional[str] = None,
        output_dir: str = "diffusion_output",
        device: str = "cuda",
        default_model: str = "runwayml/stable-diffusion-v1-5"
    ):
        """
        Initialize the diffusion model manager.
        
        Args:
            base_models_dir: Directory to store cached models
            output_dir: Directory for outputs (images, logs, checkpoints)
            device: Device to use for inference/training (cuda or cpu)
            default_model: Default model ID to use
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        self.models_dir = Path(base_models_dir) if base_models_dir else None
        self.default_model = default_model

        self.current_pipeline = None
        self.current_model_id = None
        
        (self.output_dir / "zero_shot").mkdir(exist_ok=True)
        (self.output_dir / "fine_tuned").mkdir(exist_ok=True)
        (self.output_dir / "datasets").mkdir(exist_ok=True)
        (self.output_dir / "lora_models").mkdir(exist_ok=True)
        
        print(f"DiffusionModelManager initialized with device: {self.device}")
        print(f"Output directory: {self.output_dir}")

    def test_diffusion_model(
        self,
        model_id: str,
        caption: Union[str, List[str]],
        num_images: int = 1,
        seed: Optional[int] = None,
        output_subdir: Optional[str] = None,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        width: int = 512,
        height: int = 512,
        save_images: bool = True
    ) -> List[Image.Image]:
        """
        Test a diffusion model with a specific caption in zero-shot setting.

        The pipeline class is chosen from substrings of ``model_id`` ("xl",
        "kandinsky", "flux"; anything else uses ``StableDiffusionPipeline``).
        The pipeline is loaded for this call only and released afterwards so
        that several models can be compared in a loop.

        Args:
            model_id: Hugging Face model ID
            caption: Caption text, or a list of caption variations of which
                the first entry is used. A caption of the form
                "<description> - This is a <breed>." is rearranged into
                "A high-quality photo of a <breed>, <description>".
            num_images: Number of images to generate
            seed: Random seed for reproducibility (image i uses seed + i)
            output_subdir: Subdirectory of "zero_shot" to save images in
            guidance_scale: Classifier-free guidance scale (not passed to
                the Kandinsky pipeline)
            num_inference_steps: Number of denoising steps (not passed to
                the Kandinsky pipeline)
            width: Image width (currently only forwarded to Flux pipelines)
            height: Image height (currently only forwarded to Flux pipelines)
            save_images: Whether to save generated images and their prompt

        Returns:
            List of generated PIL images

        Raises:
            ValueError: If ``caption`` is an empty list
        """
        print(f"Testing model: {model_id}")

        # Captions sampled from a variations file arrive as lists; a plain
        # string must not be indexed, or only its first character is inspected.
        if isinstance(caption, (list, tuple)):
            if not caption:
                raise ValueError("caption list is empty")
            caption_text = caption[0]
        else:
            caption_text = caption

        # Use one separator for both the check and the split so that a caption
        # such as "... - This is an X" cannot cause an IndexError.
        description, separator, breed_info = caption_text.partition(" - This is a ")
        if separator:
            breed_info = breed_info.strip(".")
            prompt = f"A high-quality photo of a {breed_info}, {description}"
        else:
            prompt = f"A high-quality photo of {caption_text}"

        model_key = model_id.lower()
        is_kandinsky = "kandinsky" in model_key
        is_flux = "flux" in model_key

        if "xl" in model_key:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                use_safetensors=True,
                variant="fp16"
            )
        elif is_kandinsky:
            # Load the checkpoint that was actually requested.
            pipe = AutoPipelineForText2Image.from_pretrained(
                model_id, torch_dtype=torch.float16
            )
        elif is_flux:
            pipe = FluxPipeline.from_pretrained(
                model_id, torch_dtype=torch.bfloat16
            )
        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16
            )

        pipe = pipe.to(self.device)

        images = []
        try:
            # Save memory
            pipe.enable_attention_slicing()
            if hasattr(pipe, 'enable_vae_slicing'):
                pipe.enable_vae_slicing()

            print(f"Prompt: {prompt}")

            for i in range(num_images):
                generator = None
                if seed is not None:
                    generator = torch.Generator(device=self.device).manual_seed(seed + i)

                if is_kandinsky:
                    image = pipe(prompt, generator=generator).images[0]
                elif is_flux:
                    image = pipe(
                        prompt,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        generator=generator,
                        width=width,
                        height=height,
                        max_sequence_length=512
                    ).images[0]
                else:
                    # NOTE: width/height are deliberately not forwarded here so
                    # each model keeps its native resolution.
                    image = pipe(
                        prompt,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        generator=generator
                    ).images[0]

                images.append(image)
        finally:
            # Drop the pipeline so the next model in a comparison loop does
            # not have to share accelerator memory with this one.
            del pipe
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if save_images:
            save_dir = self.output_dir / "zero_shot"
            if output_subdir:
                save_dir = save_dir / output_subdir

            save_dir.mkdir(parents=True, exist_ok=True)

            model_name = model_id.split("/")[-1]
            for i, img in enumerate(images):
                stem = f"{model_name}_{i}_seed{seed}"
                img.save(save_dir / f"{stem}.png")

                # Keep the exact prompt next to the image it produced.
                with open(save_dir / f"{stem}.txt", "w", encoding="utf-8") as f:
                    f.write(prompt)

        # plt.subplots cannot build a grid with zero columns.
        if images:
            fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
            if len(images) == 1:
                axes = [axes]

            for i, img in enumerate(images):
                axes[i].imshow(np.array(img))
                axes[i].set_title(f"Image {i+1}")
                axes[i].axis("off")

            plt.tight_layout()
            plt.show()

        return images

    def prepare_dataset(
        self,
        captions_file: Union[str, Path],
        images_dir: Union[str, Path],
        output_name: Optional[str] = None,
        max_samples_per_breed: Optional[int] = None,
        min_samples_per_breed: int = 3,
        target_total_samples: Optional[int] = None,
        class_field: str = "breed",
        resolution: int = 512
    ) -> Path:
        """
        Prepare a dataset for LoRA fine-tuning.
        
        Args:
            captions_file: Path to JSON file with captions
            images_dir: Directory containing images
            output_name: Name of output dataset directory
            max_samples_per_breed: Maximum samples per class
            min_samples_per_breed: Minimum samples per class
            target_total_samples: Target total number of samples
            class_field: Field name for class information extraction
            resolution: Target resolution for images
            
        Returns:
            Path to prepared dataset
        """
        if output_name:
            output_dir = self.output_dir / "datasets" / output_name
        else:
            timestamp = Path(captions_file).stem
            output_dir = self.output_dir / "datasets" / f"dataset_{timestamp}"
            
        output_dir.mkdir(parents=True, exist_ok=True)
        
        with open(captions_file, "r") as f:
            captions = json.load(f)
                
        breed_samples = {}
        for img_path, caption in captions.items():
            if " - This is a " in caption:
                breed = caption.split(" - This is a ")[1].strip(".")

                if not breed:
                    continue

                if breed not in breed_samples:
                    breed_samples[breed] = []
                
                img_name = Path(img_path).name
                full_img_path = Path(images_dir) / img_name

                if full_img_path.exists():
                    breed_samples[breed].append((str(full_img_path), caption))
                else:
                    print(f"Warning: Image not found: {img_path}")
                    continue

        
        # Select a balanced subset
        selected_samples = []
        metadata = []
        
        for breed, samples in breed_samples.items():
            num_samples = min(
                len(samples),
                max_samples_per_breed if max_samples_per_breed else len(samples)
            )
            num_samples = max(num_samples, min_samples_per_breed)
            
            # Select random sample
            breed_selection = random.sample(samples, min(num_samples, len(samples)))
            selected_samples.extend(breed_selection)
        
        # Limit with target_total_samples
        if target_total_samples and len(selected_samples) > target_total_samples:
            random.shuffle(selected_samples)
            selected_samples = selected_samples[:target_total_samples]
        
        print(f"Selected {len(selected_samples)} samples from {len(breed_samples)} breed")
        
        # Create a metadata.jsonl as required by LoRA
        for i, (img_path, caption) in enumerate(selected_samples):
            dest_filename = f"image_{i:06d}.jpg"
            dest_path = output_dir / dest_filename
            #shutil.copy(img_path, dest_path)

            img = Image.open(img_path).convert("RGB")
            
            # Resize if needed
            if resolution:
                # Center crop to square while maintaining aspect ratio
                width, height = img.size
                min_dim = min(width, height)
                left = (width - min_dim) // 2
                top = (height - min_dim) // 2
                right = left + min_dim
                bottom = top + min_dim
                
                img = img.crop((left, top, right, bottom))
                img = img.resize((resolution, resolution), Image.LANCZOS)
                
            # Save image
            img.save(dest_path)
            
            metadata.append({
                "file_name": dest_filename,
                "text": caption
            })
        
        metadata_path = output_dir / "metadata.jsonl"
        with open(metadata_path, "w") as f:
            for item in metadata:
                f.write(json.dumps(item) + "\n")
        
        print(f"Dataset prepared in {output_dir}")
        print(f"Metadata saved to: {metadata_path}")

        return output_dir

    def select_validation_prompts_from_variations(
        self, 
        variations_file: Union[str, Path], 
        num_prompts: int = 5, 
        seed: int = 42
    ) -> List[str]:
        """
        Select validation prompts from a variations file.
        
        Args:
            variations_file: Path to variations JSON file
            num_prompts: Number of prompts to select
            seed: Random seed for selection
            
        Returns:
            List of selected prompts
        """
        with open(variations_file, 'r') as f:
            variations = json.load(f)
        
        all_captions = []
        for variations_list in variations.values():
            all_captions.extend(variations_list)
        
        random.seed(seed)
        selected_prompts = random.sample(all_captions, min(num_prompts, len(all_captions)))
        
        return selected_prompts

    def run_lora_training(
        self,
        dataset_dir: Union[str, Path],
        output_name: Optional[str] = None,
        base_model: Optional[str] = None,
        resolution: int = 512,
        train_batch_size: int = 1,
        max_train_steps: int = 1000,
        learning_rate: float = 1e-4,
        validation_prompts: Optional[List[str]] = None,
        rank: int = 4
    ) -> Path:
        """
        Run LoRA fine-tuning on a prepared dataset.

        Installs/updates the training dependencies, downloads the diffusers
        ``train_text_to_image_lora.py`` example script into the current
        working directory, and launches it through ``accelerate launch``.
        Captured stdout/stderr are written to ``training_log.txt`` in the
        output directory. A non-zero exit code of the training process is
        reported as a warning; the output directory is returned either way.

        Args:
            dataset_dir: Directory with prepared dataset
            output_name: Name for the output directory; defaults to
                "lora_<dataset directory name>"
            base_model: Base model ID (defaults to ``self.default_model``)
            resolution: Image resolution
            train_batch_size: Batch size for training
            max_train_steps: Maximum training steps
            learning_rate: Learning rate
            validation_prompts: Prompts for validation; they are joined with
                "; " into the single ``--validation_prompt`` argument
            rank: LoRA rank parameter

        Returns:
            Path to fine-tuned model

        Raises:
            RuntimeError: If the training script cannot be downloaded and no
                usable local copy exists
        """
        import shlex
        import sys

        dataset_dir = Path(dataset_dir)

        if output_name:
            output_dir = self.output_dir / "lora_models" / output_name
        else:
            output_dir = self.output_dir / "lora_models" / f"lora_{dataset_dir.name}"

        output_dir.mkdir(parents=True, exist_ok=True)

        if base_model is None:
            base_model = self.default_model

        print(f"Starting LoRA training with base model: {base_model}")
        print(f"Dataset: {dataset_dir}")
        print(f"Output directory: {output_dir}")

        # Best-effort dependency refresh. Use the running interpreter's pip so
        # packages land in the environment that will run the training, and
        # inspect the exit code (subprocess.run does not raise on failure).
        pip_install = [sys.executable, "-m", "pip", "install", "-q"]
        for packages in (
            ["git+https://github.com/huggingface/diffusers.git"],
            ["accelerate", "transformers", "bitsandbytes", "datasets"],
        ):
            try:
                completed = subprocess.run(pip_install + packages)
            except OSError as e:
                print(f"Warning: Failed to install dependencies - {e}")
            else:
                if completed.returncode != 0:
                    print(
                        "Warning: Failed to install dependencies - "
                        f"pip exited with code {completed.returncode}"
                    )

        # We need to locally download the most up-to-date version of
        # train_text_to_image_lora.py. Download to a side file first:
        # `wget -O` on the final path could leave a truncated script behind
        # and destroy a working local copy when the download fails.
        script_path = Path.cwd() / "train_text_to_image_lora.py"
        download_path = script_path.with_name(script_path.name + ".download")
        try:
            completed = subprocess.run([
                "wget",
                "https://raw.githubusercontent.com/huggingface/diffusers/main/examples/text_to_image/train_text_to_image_lora.py",
                "-O",
                str(download_path)
            ])
        except OSError as e:
            download_error = str(e)
        else:
            if completed.returncode != 0:
                download_error = f"wget exited with code {completed.returncode}"
            elif not download_path.exists() or download_path.stat().st_size == 0:
                download_error = "downloaded file is missing or empty"
            else:
                download_error = None

        if download_error is None:
            download_path.replace(script_path)
        else:
            print(f"Warning: Failed to download script - {download_error}")
            if download_path.exists():
                download_path.unlink()
            if not script_path.exists() or script_path.stat().st_size == 0:
                raise RuntimeError("Could not download training script and no local script found.")

        if validation_prompts is None:
            validation_prompts = [
                "A high-quality photo of a dog",
                "A high-quality photo of a cat",
                "A close-up portrait of a pet"
            ]

        # Command to use train_text_to_image_lora.py. The list is executed
        # without a shell, so values must not carry their own quote
        # characters: they would become part of the prompt text.
        cmd = [
            "accelerate", "launch",
            str(script_path),
            f"--pretrained_model_name_or_path={base_model}",
            f"--train_data_dir={dataset_dir}",
            f"--output_dir={output_dir}",
            f"--resolution={resolution}",
            "--center_crop",
            "--random_flip",
            f"--train_batch_size={train_batch_size}",
            "--gradient_accumulation_steps=4",
            "--gradient_checkpointing",
            "--mixed_precision=fp16",
            f"--max_train_steps={max_train_steps}",
            f"--learning_rate={learning_rate}",
            "--lr_scheduler=constant",
            "--lr_warmup_steps=0",
            "--validation_epochs=100",
            f"--validation_prompt={'; '.join(validation_prompts)}",
            "--seed=42",
            "--checkpointing_steps=500",
            f"--rank={rank}"
        ]

        # Shell-quote only for display so the printed line can be pasted
        # into a terminal as-is.
        print(f"Running command: {' '.join(shlex.quote(part) for part in cmd)}")
        try:
            process = subprocess.run(cmd, capture_output=True, text=True)
            # Log output
            with open(output_dir / "training_log.txt", "w") as f:
                f.write(f"STDOUT:\n{process.stdout}\n\nSTDERR:\n{process.stderr}")

            if process.returncode != 0:
                print(f"Warning: Training process exited with code {process.returncode}")
                print(f"Error details: {process.stderr}")
            else:
                print("Training completed successfully!")
        except Exception as e:
            print(f"Error during training: {e}")
            raise

        return output_dir

    def load_lora_model(
        self,
        base_model_id: Optional[str] = None,
        lora_weights_path: Optional[str] = None,
        torch_dtype = torch.float16
    ):
        """
        Build a Stable Diffusion pipeline, optionally with LoRA weights, and
        make it the manager's current pipeline.

        Args:
            base_model_id: ID of the base model; ``self.default_model`` when
                omitted
            lora_weights_path: Path to LoRA weights; skipped when falsy
            torch_dtype: Data type passed to ``from_pretrained``

        Returns:
            self for method chaining
        """
        model_id = self.default_model if base_model_id is None else base_model_id

        print(f"Loading base model: {model_id}")

        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype
        )

        # Replace the checkpoint's scheduler with DPM-Solver multistep,
        # reusing the existing scheduler configuration.
        scheduler_config = pipe.scheduler.config
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)

        if lora_weights_path:
            print(f"Loading LoRA weights from: {lora_weights_path}")
            # The weights are attached to the UNet's attention processors.
            pipe.unet.load_attn_procs(lora_weights_path)
            print("LoRA weights loaded successfully!")

        pipe.to(self.device)
        pipe.enable_attention_slicing()

        self.current_pipeline, self.current_model_id = pipe, model_id

        return self

    def generate_images(
        self,
        prompts: Union[str, List[str]],
        output_subdir: Optional[str] = None,
        num_images_per_prompt: int = 1,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        seed: Optional[int] = None,
        width: int = 512,
        height: int = 512,
        save_images: bool = True,
        display_images: bool = True,
        batch_size: int = 1,
        negative_prompt: Optional[str] = None
    ) -> Dict[str, List[Image.Image]]:
        """
        Generate images using the current model pipeline.
        
        Args:
            prompts: Text prompts for image generation
            output_subdir: Subdirectory to save images
            num_images_per_prompt: Number of images per prompt
            guidance_scale: Guidance scale for classifier-free guidance
            num_inference_steps: Number of inference steps
            seed: Random seed for reproducibility
            width: Image width
            height: Image height
            save_images: Whether to save generated images
            display_images: Whether to display generated images
            batch_size: Batch size for generation
            negative_prompt: Negative prompt for generation
            
        Returns:
            Dictionary mapping prompts to generated images
        """
        if self.current_pipeline is None:
            self.load_lora_model()
            
        if isinstance(prompts, str):
            prompts = [prompts]
            
        if save_images:
            if output_subdir:
                output_path = self.output_dir / "fine_tuned" / output_subdir
            else:
                output_path = self.output_dir / "fine_tuned" / "generated"
                
            output_path.mkdir(parents=True, exist_ok=True)
            
        results = {}
        all_images = []
        all_prompts = []
        
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            
            generator = None
            if seed is not None:
                generator = torch.Generator(device=self.device).manual_seed(seed)
                seed += 1
                
            print(f"Generating batch {i//batch_size + 1}/{(len(prompts)-1)//batch_size + 1}...")
            
            batch_results = self.current_pipeline(
                batch_prompts,
                num_images_per_prompt=num_images_per_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                #negative_prompt=negative_prompt,
                width=width,
                height=height
            )
            
            batch_images = batch_results.images
            
            for j, prompt in enumerate(batch_prompts):
                start_idx = j * num_images_per_prompt
                end_idx = start_idx + num_images_per_prompt
                prompt_images = batch_images[start_idx:end_idx]
                
                results[prompt] = prompt_images
                all_images.extend(prompt_images)
                all_prompts.extend([prompt] * num_images_per_prompt)
                
                if save_images:
                    for k, img in enumerate(prompt_images):
                        prompt_hash = abs(hash(prompt)) % 10000
                        img_path = output_path / f"gen_{prompt_hash}_{i}_{j}_{k}.png"
                        img.save(img_path)
                        
                        with open(output_path / f"gen_{prompt_hash}_{i}_{j}_{k}.txt", "w") as f:
                            f.write(prompt)
        
        if display_images and all_images:
            n_cols = min(5, len(all_images))
            n_rows = (len(all_images) + n_cols - 1) // n_cols
            
            plt.figure(figsize=(n_cols * 5, n_rows * 6))
            
            for i, (image, prompt) in enumerate(zip(all_images, all_prompts)):
                plt.subplot(n_rows, n_cols, i + 1)
                plt.imshow(np.array(image))
                plt.title(prompt[:50] + "..." if len(prompt) > 50 else prompt, fontsize=8)
                plt.axis("off")
                
            plt.tight_layout()
            plt.show()
            
        return results
    

    def generate_from_variations(
        self,
        variations_file: Union[str, Path],
        output_subdir: Optional[str] = None,
        num_samples: int = 5,
        num_images_per_prompt: int = 1,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 30,
        seed: Optional[int] = None,
        random_variations: bool = True,
        negative_prompt: Optional[str] = None
    ) -> Dict[str, List[Image.Image]]:
        """
        Generate images from text variations in a file.
        
        Args:
            variations_file: JSON file with text variations
            output_subdir: Subdirectory for saving images
            num_samples: Number of variation prompts to use
            num_images_per_prompt: Number of images per prompt
            guidance_scale: Guidance scale for generation
            num_inference_steps: Number of inference steps
            seed: Random seed for reproducibility
            random_variations: Whether to sample variations randomly
            negative_prompt: Negative prompt for generation
            
        Returns:
            Dictionary mapping prompts to generated images
        """
        with open(variations_file, "r") as f:
            variations_data = json.load(f)
            
        all_variations = []
        for img_path, variations in variations_data.items():
            all_variations.extend(variations)
            
        if random_variations:
            selected_variations = random.sample(all_variations, min(num_samples, len(all_variations)))
        else:
            selected_variations = all_variations[:num_samples]
            
        return self.generate_images(
            prompts=selected_variations,
            output_subdir=output_subdir or "variations",
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            #negative_prompt=negative_prompt
        )

In [None]:
# Initialize Diffusion class
# All manager outputs (zero_shot/, fine_tuned/, datasets/, lora_models/)
# are created under this Google Drive folder.

diffusion_manager = DiffusionModelManager(
    output_dir="/content/drive/MyDrive/outputs_master_ProfAI",
    default_model="runwayml/stable-diffusion-v1-5"
)

In [None]:
# Zero-Shot Prompting
# Generate two images per model for the same subject so the checkpoints can
# be compared side by side under zero_shot/model_comparison.

models_to_test = [
    "stabilityai/stable-diffusion-2-1-base",
    #"stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "kandinsky-community/kandinsky-2-2-decoder",
    "black-forest-labs/FLUX.1-dev"
]

for model in models_to_test:
    diffusion_manager.test_diffusion_model(
        model_id=model,
        # The parameter is called `caption` (there is no `prompt` keyword),
        # and the method itself prepends "A high-quality photo of ".
        caption="a British Shorthair cat",
        num_images=2,
        output_subdir="model_comparison"
    )

In [None]:
# Load the caption-variations JSON and draw three of its values at random
# (unseeded) for a quick zero-shot test.
# NOTE(review): the same file is used elsewhere as image path -> list of
# variations, so each sampled entry is presumably a list — confirm.
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')
with open(caption_file, 'r') as f:
    balanced_captions = json.load(f)

sample_captions = random.sample(list(balanced_captions.values()), 3)
print("Caption selected for this test:")
for i, caption in enumerate(sample_captions):
    print(f"{i+1}. {caption}")

In [None]:
# Run the first model of `models_to_test` on every sampled entry, passing
# the entry as the `caption` argument with all other settings at defaults.
for caption in sample_captions:
    diffusion_manager.test_diffusion_model(models_to_test[0], caption)

In [None]:
# Build the LoRA training set from the GIT-generated captions and the
# Oxford-IIIT Pet images cloned with the repository.
captions_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json"
images_dir = "/content/Gen-AI-for-Data-Augmentation/data/oxford-iiit-pet/images"

# Per-breed cap of 20 and an overall cap of 37 * 20 = 740 samples,
# written as 512x512 images.
dataset_dir = diffusion_manager.prepare_dataset(
    captions_file=captions_file,
    images_dir=images_dir,
    max_samples_per_breed=20,
    min_samples_per_breed=10,
    target_total_samples=37*20,
    resolution=512
)

# Pick 10 caption variations (default seed) to use as validation prompts
# during fine-tuning.
variations_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json"
validation_prompts = diffusion_manager.select_validation_prompts_from_variations(
    variations_file=variations_file,
    num_prompts=10
)

In [None]:
# Run LoRA fine-tuning
# Trains on the dataset prepared above; the returned path
# (<output_dir>/lora_models/pet_breeds_lora) is reused when loading weights.

lora_model_dir = diffusion_manager.run_lora_training(
    dataset_dir=dataset_dir,
    output_name="pet_breeds_lora",
    base_model="runwayml/stable-diffusion-v1-5",
    resolution=512,
    max_train_steps=1500,
    learning_rate=5e-5,
    validation_prompts=validation_prompts,
    rank=4
)

In [None]:
# Load the fine-tuned model
# The weights file is expected directly inside the training output directory.

diffusion_manager.load_lora_model(
    base_model_id="runwayml/stable-diffusion-v1-5",
    lora_weights_path=str(lora_model_dir / "pytorch_lora_weights.safetensors")
)

# Generate images with custom prompts
test_prompts = [
    "A high-quality photograph of a Maine Coon cat with long fur",
    "A detailed image of a Beagle dog running in a park",
    "A professional photo of a Persian cat with blue eyes"
]

# Two images per prompt, saved under fine_tuned/test_generation.
# No seed is passed, so results differ between runs.
generated_images = diffusion_manager.generate_images(
    prompts=test_prompts,
    output_subdir="test_generation",
    num_images_per_prompt=2,
    guidance_scale=7.5,
    num_inference_steps=40,
    negative_prompt="blurry, deformed, distorted, low quality, poor details, watermark"
)

In [None]:
# Generate images from variations
# Uses up to 20 randomly chosen caption variations, one image each, saved
# under fine_tuned/variation_generation. No seed is passed.

variation_images = diffusion_manager.generate_from_variations(
    variations_file=variations_file,
    output_subdir="variation_generation",
    num_samples=20,
    num_images_per_prompt=1,
    guidance_scale=7.5,
    num_inference_steps=40,
    negative_prompt="blurry, deformed, distorted, low quality, poor details, watermark"
)

## Image Generation with Conditional GAN

In [None]:
# Setup device
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device('cpu')
print(f"Device used: {device}")

In [8]:
# Configurations
# Shared by the transforms, the data loaders and the GAN config below.

batch_size, image_size, num_workers = 32, 128, 4

In [None]:
# Setup components
# Output locations, evaluation metrics, logger and training callbacks.

output_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI")
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = output_dir / "checkpoints"
log_dir = output_dir / "logs"

# FID and CLIP score are both computed on the selected device.
metrics = MetricsTracker([FIDScore(device=device), CLIPScore(device=device)])

logger = GANLogger("conditional_gan", log_dir=log_dir)

# Early stopping and checkpointing both monitor 'fid'.
callbacks = [
    EarlyStopping(monitor='fid', patience=5),
    ModelCheckpoint(filepath=checkpoint_dir / "best_model.pt", monitor='fid'),
    MetricsHistory(log_dir=log_dir / "metrics"),
]

In [19]:
# Load Dataset
# './data' is resolved against the current working directory, which the
# notebook sets to the cloned repository root at the top.

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

In [20]:
# Load Captions
# Relative path (resolved against the working directory); the parsed JSON is
# handed to PetDatasetWithCaptions as `caption_dict` further down.
with open('output/captions/captions_traindataset.json', 'r') as f:
    caption_dict = json.load(f)

In [21]:
# Collect the image paths of both splits as plain strings.
# They are read from the datasets' `_images` attribute.
train_images_paths = list(map(str, train_dataset._images))

test_images_paths = list(map(str, test_dataset._images))

In [None]:
# Initialize train and val loader

# Resize to a square of `image_size`, convert to a tensor and normalize each
# channel with mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

full_train_ds = PetDatasetWithCaptions(
    image_paths=train_images_paths,
    caption_dict=caption_dict,
    transform=transform
)

# 90/10 train/validation split; the fixed generator seed keeps the split
# identical across runs.
train_size = int(0.9 * len(full_train_ds))
val_size = len(full_train_ds) - train_size
train_ds, val_ds = random_split(full_train_ds, [train_size, val_size],
                                generator=torch.Generator().manual_seed(42))

# NOTE(review): the test set reuses `caption_dict`, which was loaded from
# captions_traindataset.json — confirm it also covers the test images.
test_ds = PetDatasetWithCaptions(
    image_paths=test_images_paths,
    caption_dict=caption_dict,
    transform=transform
)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

In [None]:
# Initialize GAN model
# `image_size` matches the resize applied by the data transforms.
# NOTE(review): caption_dim=768 presumably has to equal the width of the
# caption embeddings fed to the GAN — confirm against the text encoder used.

config = GANConfig(
    latent_dim = 100,
    caption_dim = 768,
    image_size = image_size,
    num_channels = 3,
    generator_features = 64
)

gan = ConditionalGAN(config)

In [24]:
# Initialize trainer
# Wires together the model, the train/validation loaders, the metrics
# tracker, the logger and the callbacks created in the cells above.

trainer = GANTrainer(
    gan=gan,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    metrics_tracker=metrics,
    logger=logger,
    callbacks=callbacks
)

In [None]:
# Train
# Up to 100 epochs; the EarlyStopping callback configured earlier may end
# the run sooner.
# NOTE(review): units of eval_freq / sample_freq (epochs vs. steps) are
# defined by GANTrainer.train — confirm there.
# `sample_dir` is a relative path, resolved against the working directory.

trainer.train(
    num_epochs=100,
    eval_freq=1,
    sample_freq=500,
    sample_dir=Path("samples")
)