# Generative AI project about Data Augmentation
> *This is the notebook for the ninth Profession AI project, part of the Generative AI module.*

## Setup & Configuration

In [None]:
!git clone https://github.com/Silvano315/Gen-AI-for-Data-Augmentation.git

In [2]:
import os
# Switch into the cloned repository; presumably this is what lets the `src.*`
# imports and relative paths such as './data' used below resolve — confirm the
# clone landed in /content.
os.chdir('/content/Gen-AI-for-Data-Augmentation')

In [None]:
!pwd

In [None]:
from google.colab import drive
# Mount Google Drive; later cells read and write under /content/drive/MyDrive.
drive.mount('/content/drive')

## Import Libraries

In [None]:
!pip install clean-fid
!pip install git+https://github.com/openai/CLIP.git
!pip install -q transformers datasets accelerate sentencepiece
!pip install -q git+https://github.com/huggingface/transformers

In [6]:
from pathlib import Path
import random
import matplotlib.pyplot as plt
import torch
import json
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForCausalLM
from src.data.dataset import PetDatasetHandler
from src.captioning.caption_generator import CaptionGenerator
from src.captioning.captioning_blip_2 import BLIP2CaptionGenerator
from src.captioning.git_caption_generator import GITCaptionGenerator
from src.data.data_with_captions import PetDatasetWithCaptions
from src.generation.text_generation import TextVariationGenerator
from src.utils.logging import GANLogger
from src.generation.image_generator import GANConfig, ConditionalGAN
from src.training.callbacks import EarlyStopping, ModelCheckpoint, MetricsHistory
from src.evaluation.metrics import FIDScore, CLIPScore, MetricsTracker
from src.training.training import GANTrainer

## Initialize and load dataset without transforms for analysis

In [None]:
# Root folder handed to the dataset handler.
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
# No transform is passed: this copy of the splits is the one used for analysis.
train_dataset, test_dataset = handler.load_dataset()

### Basic dataset information

In [None]:
# Print every key/value pair of the summary returned by the handler.
info = handler.get_dataset_info()
print("Dataset Information:")
for key, value in info.items():
    print(f"{key}: {value}")

### Plot distributions and samples

In [None]:
handler.plot_class_distribution().show()

In [None]:
handler.visualize_samples(9).show()

### Get detailed image statistics


In [None]:
# sample_size=100 presumably limits the statistics to a subsample of images —
# confirm in PetDatasetHandler.get_image_stats.
stats = handler.get_image_stats(sample_size=100)
print("\nImage Statistics:")
# `stats` is consumed as a two-level mapping: category -> {name: value}.
for category, values in stats.items():
    print(f"\n{category.upper()}:")
    for key, value in values.items():
        # The ':.2f' format spec requires every leaf value to be numeric.
        print(f"{key}: {value:.2f}")

### For training, load with transforms

In [None]:
# Rebinds train_dataset/test_dataset to splits loaded with the handler's
# training transforms. Later cells call load_dataset() again WITHOUT a
# transform, which replaces these objects.
train_transforms = handler.get_training_transforms()
train_dataset, test_dataset = handler.load_dataset(transform=train_transforms)

## Image Captioning

### Compare caption generators: 
1. **Blip**
2. **Blip-2**
3. **GIT**

In [None]:
# Configurations

def get_random_images(image_dir, count=5):
    """Pick up to ``count`` distinct ``*.jpg`` files at random from a folder.

    Args:
        image_dir: Directory (str or Path) searched non-recursively.
        count: Maximum number of paths to return.

    Returns:
        A list of ``Path`` objects in random order. It is shorter than
        ``count`` when the folder holds fewer matching files.
    """
    candidates = [*Path(image_dir).glob("*.jpg")]
    # Never ask random.sample for more items than exist.
    sample_size = count if count < len(candidates) else len(candidates)
    return random.sample(candidates, sample_size)

# Reload the splits without transforms (replaces any transformed copies).
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

# Folder searched for *.jpg files; relative to the current working directory.
image_dir = "data/oxford-iiit-pet/images"
# The same ten images are reused by the BLIP, BLIP-2 and GIT test cells and by
# visualize_comparison, so the three models are compared on identical inputs.
sample_images = get_random_images(image_dir, count=10)
print(f"Selected {len(sample_images)} random images")

In [None]:
# Test the original BLIP model on the sampled images.

print("Test di BLIP...")
blip_model = CaptionGenerator()
# Maps str(image path) -> caption; visualize_comparison looks captions up by
# this same string key.
blip_captions = {}

for img_path in sample_images:
    caption = blip_model.generate_caption(str(img_path))
    blip_captions[str(img_path)] = caption
    print(f"BLIP - {img_path.name}: {caption}")

In [None]:
# Test the BLIP-2 model on the sampled images.
# Be careful: BLIP-2 is compute- and memory-intensive; run this cell only if
# you have ample computational resources.

print("\nTest di BLIP-2...")
blip2_model = BLIP2CaptionGenerator(model_name="Salesforce/blip2-opt-2.7b")
# Maps str(image path) -> caption, keyed like blip_captions.
blip2_captions = {}

for img_path in sample_images:
    caption = blip2_model.generate_caption(str(img_path))
    blip2_captions[str(img_path)] = caption
    print(f"BLIP-2 - {img_path.name}: {caption}")

In [None]:
# Test the GIT model on the sampled images, driving the Hugging Face
# processor/model pair directly instead of going through a project wrapper.

print("\nTest di GIT...")
git_model_name = "microsoft/git-base-coco"
processor_git = AutoProcessor.from_pretrained(git_model_name)
model_git = AutoModelForCausalLM.from_pretrained(git_model_name)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_git = model_git.to(device)

# Maps str(image path) -> caption, keyed like blip_captions.
git_captions = {}

for img_path in sample_images:
    image = Image.open(img_path).convert("RGB")
    # Preprocess the image and move the resulting tensors to the model's device.
    inputs_git = processor_git(images=image, return_tensors="pt").to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        generated_ids = model_git.generate(
            pixel_values=inputs_git.pixel_values,
            max_length=50,
            num_beams=5
        )

    # One image per call, so the decoded batch has a single entry.
    caption = processor_git.batch_decode(generated_ids, skip_special_tokens=True)[0]
    git_captions[str(img_path)] = caption
    print(f"GIT - {img_path.name}: {caption}")

In [None]:
# Visualize images with captions from three models (side by side)

from textwrap import wrap

def visualize_comparison(image_paths, blip_captions, blip2_captions, git_captions, wrap_width=30):
    """Show every image three times in a row, captioned by BLIP, BLIP-2 and GIT.

    Args:
        image_paths: Iterable of image paths. ``str(path)`` is the key used to
            look each caption up in the three dictionaries.
        blip_captions: Mapping ``str(path) -> caption`` produced by BLIP.
        blip2_captions: Mapping ``str(path) -> caption`` produced by BLIP-2.
        git_captions: Mapping ``str(path) -> caption`` produced by GIT.
        wrap_width: Maximum number of characters per caption line.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If a caption dictionary has no entry for one of the images.
    """
    image_paths = list(image_paths)
    if not image_paths:
        # plt.subplots() rejects a zero-row grid, and there is nothing to draw.
        return

    columns = (
        ("BLIP", blip_captions),
        ("BLIP-2", blip2_captions),
        ("GIT", git_captions),
    )
    n_images = len(image_paths)

    # squeeze=False keeps `axes` two-dimensional even for a single image, so
    # the [row, col] indexing below is always valid.
    _, axes = plt.subplots(
        n_images, len(columns), figsize=(15, 5 * n_images), squeeze=False
    )

    # Column headers go on the first row only.
    for col, (model_name, _) in enumerate(columns):
        axes[0, col].set_title(model_name, fontsize=14)

    for row, img_path in enumerate(image_paths):
        caption_key = str(img_path)
        # Close the file handle promptly; convert() returns a loaded copy.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")

        for col, (_, captions) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(img)
            # axis('off') would also hide the xlabel that carries the caption,
            # so remove only the ticks and the frame.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            wrapped_caption = "\n".join(wrap(captions[caption_key], wrap_width))
            ax.set_xlabel(wrapped_caption, fontsize=12)

    plt.tight_layout()
    plt.show()


visualize_comparison(sample_images, blip_captions, blip2_captions, git_captions)

### Initialize caption generator

In [None]:
caption_gen = GITCaptionGenerator()

### Load Dataset (If you haven’t done it before)

In [None]:
# Reload the splits without transforms, so indexing a dataset yields the
# untransformed image.
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Test single image caption generation

In [None]:
# Pick one random training index (randint is inclusive on both ends).
sample = random.randint(0, len(train_dataset)-1)
# `_images` is a private attribute of the dataset object; presumably the list
# of image file paths — verify against the dataset class.
sample_image_path = Path(train_dataset._images[sample])
# Indexing the dataset returns (image, label_index); the index is mapped to
# its class name.
label = train_dataset.classes[train_dataset[sample][1]]
caption = caption_gen.generate_caption(sample_image_path, label, max_length = 50)
print(f"Sample caption: {caption}")

In [None]:
# Show the sampled image with its class name as the title.
plt.figure(figsize=(10, 10))
plt.title(f"{label}")
# Despite its name, `fig` holds the return value of plt.imshow, not a Figure.
fig = plt.imshow(train_dataset[sample][0])

### Process a batch of images


In [None]:
# Smoke-test batch captioning on the first ten training images.
batch_size = 4
image_paths = [Path(img) for img in train_dataset._images[:10]]
# train_dataset[i] is indexed only to read the label (element 1 of the pair).
labels = [train_dataset.classes[train_dataset[i][1]] for i in range(10)]
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

### Process train and test datasets

In [None]:
# Collect every test image path and its class name for full-dataset captioning.
batch_size = 4
image_paths = [Path(img) for img in test_dataset._images]
# NOTE(review): test_dataset[i] is indexed once per image only to read the
# label; if indexing also loads the image this is slow on the full split —
# consider reading the labels directly if the dataset class exposes them.
labels = [test_dataset.classes[test_dataset[i][1]] for i in range(len(image_paths))]

In [None]:
# Resolve and create the output folder BEFORE the long captioning run. The
# name `save_dir` is otherwise only bound by a later cell, so a top-to-bottom
# run would fail with NameError after all captions had been generated.
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)

captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

caption_gen.save_captions(save_dir / 'captions_testdataset.json')

### Save & Load captions


In [None]:
# Captions are persisted on Google Drive; create the folder if it is missing.
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)
# Writes to the same file name as the save at the end of the previous cell.
caption_gen.save_captions(save_dir / 'captions_testdataset.json')

In [None]:
# NOTE(review): this loads 'captions.json', while the cells above save
# 'captions_testdataset.json' — confirm which file is intended.
caption_gen.load_captions(save_dir / 'captions.json')

### Visualize results

In [None]:
caption_gen.visualize_captions(num_samples=4)

### Print some statistics

In [None]:
# `captions_cache` is iterated as a mapping of image path -> caption.
print(f"\nTotal captions generated: {len(caption_gen.captions_cache)}")
print("\nSample of generated captions:")
# Show only the first three cached entries.
for path, caption in list(caption_gen.captions_cache.items())[:3]:
    print(f"\nImage: {Path(path).name}")
    print(f"Caption: {caption}")

## Text Generation with Flan-T5 to increase the number of captions

#### Initialize Text Generator

In [None]:
generator = TextVariationGenerator(model_name="google/flan-t5-large")

#### How it works
> Comparison of different types of prompts

In [None]:
results = generator.test_prompt_types("A white dog sitting on a brown chair", temperature=0.95, num_variations=5)

#### Test Quality for Few-Shot Prompting

In [None]:
# Example in the "<description> - This is a <breed>." caption format that the
# later breed-extraction and prompt-building code expects.
original_caption = "A gray cat sitting on a window sill - This is a British Shorthair."

variations = generator.test_few_shot_quality(
    original_caption,
    num_variations=3,
    temperature=1
)

#### Load saved captions and datasets

In [None]:
# GIT captions for the training split, stored on Drive.
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json')

# `captions` is later used as a mapping of image path -> caption string.
with open(caption_file, 'r') as f:
    captions = json.load(f)

In [None]:
# Load the dataset splits (no transforms).

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

#### Test the variations on random captions

In [None]:
# Draw up to three random (image path, caption) pairs. Capping the sample size
# at the number of available captions keeps random.sample from raising
# ValueError on a very small caption file.
sample_captions = dict(
    random.sample(list(captions.items()), min(3, len(captions)))
)

for img_path, caption in sample_captions.items():
    # The leading newline separates consecutive entries in the output.
    print(f"\nImage: {Path(img_path).name}")
    print(f"Original Caption: {caption}")

    # After the loop, `variations` holds the output for the LAST caption only.
    variations = generator.generate_variations(
        caption,
        num_variations=3,
        temperature=1,
        prompt_type="few-shot"
    )

    print("New Versions:")
    for i, var in enumerate(variations):
        print(f"{i+1}. {var}")

In [None]:
# Display images with caption variations

def visualize_caption_variations(image_path, caption, variations):
    """Display one image with its caption and numbered variations as the title.

    Args:
        image_path: Path of the image file to open.
        caption: Original caption, shown first.
        variations: Iterable of alternative captions, listed 1..n below it.
    """
    plt.figure(figsize=(10, 12))

    picture = Image.open(image_path).convert("RGB")
    plt.subplot(1, 1, 1)
    plt.imshow(picture)
    plt.axis('off')

    # One numbered line per variation, appended under the original caption.
    numbered_lines = "".join(
        f"{i+1}. {var}\n" for i, var in enumerate(variations)
    )
    heading = f"Original: {caption}\n\nVariations:\n"

    plt.title(heading + numbered_lines, fontsize=10, loc='left')
    plt.tight_layout()
    plt.show()

# `variations` still holds the output generated for the LAST caption iterated
# in the previous cell (dicts preserve insertion order), so pair it with that
# same entry. Using the first key would show one image next to variations
# written for a different caption.
sample_img_path = list(sample_captions.keys())[-1]
visualize_caption_variations(
    sample_img_path,
    sample_captions[sample_img_path],
    variations
)

#### Process entire caption file

In [None]:
# Destination for the generated variations (on Drive).
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

print("Generation of variations for all captions...")
# NOTE(review): how variations_per_caption interacts with class_balancing,
# target_per_class and min/max_variations is defined in
# TextVariationGenerator.process_caption_file — check there before tuning.
variations_dict = generator.process_caption_file(
    caption_file=caption_file,
    output_file=output_file,
    variations_per_caption=3,
    class_balancing=True,
    target_per_class=150,
    min_variations=1,
    max_variations=5,
    prompt_type="few-shot",
    temperature=1
)

# len() counts the entries of the returned mapping, one per processed caption.
print(f"Generate variations for {len(variations_dict)} caption")
print(f"File saved in: {output_file}")

#### Evaluate generated captions

In [None]:
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

# The file is read as a mapping of image path -> list of variation strings.
with open(output_file, 'r') as f:
    variations_data = json.load(f)

# Number of variations stored for each image.
variations_counts = {img_path: len(vars_list) for img_path, vars_list in variations_data.items()}

# Statistics
# NOTE: an empty file makes the division raise ZeroDivisionError and
# max()/min() raise ValueError.
avg_variations = sum(variations_counts.values()) / len(variations_counts)
max_variations = max(variations_counts.values())
min_variations = min(variations_counts.values())
total_variations = sum(variations_counts.values())

print(f"Statistics of variations:")
print(f"Total original captions: {len(variations_counts)}")
print(f"Total generated variations: {total_variations}")
print(f"Mean variation per caption: {avg_variations:.2f}")
print(f"Max variations per caption: {max_variations}")
print(f"Min variations per caption: {min_variations}")

In [None]:
# Extract the breed from the captions and calculate statistics by breed

def extract_breed(caption):
    """Return the breed named by a "This is a/an <breed>." sentence.

    Args:
        caption: Caption text to search.

    Returns:
        The text between "This is a " / "This is an " and the next period, or
        ``None`` when the caption contains no such sentence.
    """
    import re

    found = re.search(r"This is an? ([^\.]+)\.", caption)
    if found is None:
        return None
    return found.group(1)

# Count the original captions per breed; captions without a recognisable
# breed sentence are skipped.
breed_counts_original = {}
for caption in captions.values():
    breed = extract_breed(caption)
    if breed:
        breed_counts_original[breed] = breed_counts_original.get(breed, 0) + 1

# Start from the original counts and add every generated variation, so these
# totals are "original + variations", not variations alone.
breed_counts_variations = breed_counts_original.copy()
for var_list in variations_data.values():
    for var in var_list:
        breed = extract_breed(var)
        if breed:
            breed_counts_variations[breed] = breed_counts_variations.get(breed, 0) + 1

# Only breeds present in the ORIGINAL captions are plotted; a breed string
# that appears solely in variations is counted above but not shown.
breeds = sorted(breed_counts_original.keys())
original_counts = [breed_counts_original[breed] for breed in breeds]
total_counts = [breed_counts_variations[breed] for breed in breeds]

# Grouped bar chart: two 0.4-wide bars per breed, tick centred between them.
plt.figure(figsize=(15, 10))
x = range(len(breeds))
plt.bar(x, original_counts, width=0.4, align='edge', label='Original')
plt.bar([i+0.4 for i in x], total_counts, width=0.4, align='edge', label='After Variations')
plt.xticks([i+0.2 for i in x], breeds, rotation=90)
plt.xlabel('Breed')
plt.ylabel('Count')
plt.title('Dataset Augmentation by Breed')
plt.legend()
plt.tight_layout()
plt.show()

#### Balanced Subset selection

In [None]:
# Output path for the balanced subset, and the variations file it is built from.
balanced_output = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/balanced_flan_t5_variations.json')
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

# `caption_file` must still point at the original caption JSON loaded earlier.
balanced_subset = generator.select_balanced_subset(
    caption_file=caption_file,
    variations_file=output_file,
    output_file=balanced_output,
    target_per_class=150
)

print(f"Created balanced subset with {len(balanced_subset)} captions")
print(f"File saved in: {balanced_output}")

In [None]:
# Count captions per breed in the balanced subset; its values are treated as
# caption strings.
breed_counts_balanced = {}
for caption in balanced_subset.values():
    breed = extract_breed(caption)
    if breed:
        breed_counts_balanced[breed] = breed_counts_balanced.get(breed, 0) + 1

breeds = sorted(breed_counts_balanced.keys())
balanced_counts = [breed_counts_balanced.get(breed, 0) for breed in breeds]

# One bar per breed, in alphabetical order.
plt.figure(figsize=(15, 10))
plt.bar(breeds, balanced_counts)
plt.xlabel('Breed')
plt.ylabel('Count')
plt.title('Balanced Subset by Breed')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

## Image Generation with Diffusion Models

### Zero-Shot Testing

In [None]:
# Import Libraries

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, DiffusionPipeline
from PIL import Image
import matplotlib.pyplot as plt
import json
import random
from pathlib import Path
import numpy as np

In [None]:
# Function to load and test different models
def _build_breed_prompt(caption):
    """Turn a dataset caption into a text-to-image prompt.

    Captions shaped like ``"<scene> - This is a/an <breed>."`` are rearranged
    so that the breed leads the prompt. Any other caption is used verbatim.
    """
    import re  # local import keeps this cell self-contained

    parsed = re.match(
        r"(?P<scene>.*?) - This is (?P<article>an?) (?P<breed>.+?)\.*\s*$",
        caption,
        flags=re.DOTALL,
    )
    if parsed is None:
        return f"A high-quality photo of {caption}"
    # Reuse the caption's own article so "an Abyssinian" stays grammatical.
    return (
        f"A high-quality photo of {parsed['article']} {parsed['breed']}, "
        f"{parsed['scene']}"
    )


def test_diffusion_model(model_id, caption, num_images=1, seed=None):
    """
    Test a diffusion model with a specific caption.

    The pipeline is loaded in float16 on the "cuda" device, so a CUDA GPU is
    required. The model is loaded again on every call.

    Args:
        model_id (str): Hugging Face model ID to test. IDs containing "xl" use
            the SDXL pipeline, IDs containing "kandinsky" use the generic
            DiffusionPipeline, and everything else uses StableDiffusionPipeline.
        caption (str): Caption to turn into a prompt. Captions of the form
            "<scene> - This is a/an <breed>." are rewritten to lead with the
            breed.
        num_images (int): Number of images to generate.
        seed (int, optional): Base seed; image ``i`` uses ``seed + i``.

    Returns:
        list: The generated images. Empty when ``num_images`` is below 1.
    """
    if num_images < 1:
        # Nothing to generate: skip the expensive model load and avoid
        # plt.subplots(1, 0), which raises ValueError.
        return []

    print(f"Loading the model: {model_id}")

    model_key = model_id.lower()
    is_kandinsky = "kandinsky" in model_key

    if "xl" in model_key:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        )
    elif is_kandinsky:
        # NOTE(review): confirm that the pipeline class resolved for this
        # checkpoint accepts a text prompt directly; a decoder-only Kandinsky
        # checkpoint may expect image embeddings instead.
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16
        )

    pipe = pipe.to("cuda")

    # Save memory
    pipe.enable_attention_slicing()
    if hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()

    prompt = _build_breed_prompt(caption)
    print(f"Prompt: {prompt}")

    # Generate the images one at a time.
    images = []
    for i in range(num_images):
        # A per-image generator makes each image reproducible when a seed is
        # given; None lets the pipeline use its own randomness.
        rng = None
        if seed is not None:
            rng = torch.Generator(device="cuda").manual_seed(seed + i)

        if is_kandinsky:
            # Pass the generator here too so `seed` is honoured for every
            # model family.
            image = pipe(prompt, guidance_scale=7.5, generator=rng).images[0]
        else:
            image = pipe(
                prompt,
                guidance_scale=7.5,
                num_inference_steps=30,
                generator=rng
            ).images[0]

        images.append(image)

    # squeeze=False keeps `axes` two-dimensional, so a single image needs no
    # special case.
    _, axes = plt.subplots(
        1, num_images, figsize=(5 * num_images, 5), squeeze=False
    )
    for i, (ax, img) in enumerate(zip(axes[0], images)):
        ax.imshow(np.array(img))
        ax.set_title(f"Image {i+1}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

    return images

In [None]:
# NOTE: this rebinds `caption_file`, which earlier cells used for the original
# GIT caption JSON, to the balanced-variations file.
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/balanced_flan_t5_variations.json')
with open(caption_file, 'r') as f:
    balanced_captions = json.load(f)

# `sample_captions` is a LIST of caption strings here (an earlier cell bound
# the same name to a dict). random.sample raises ValueError if the file holds
# fewer than three entries.
sample_captions = random.sample(list(balanced_captions.values()), 3)
print("Caption selected for this test:")
for i, caption in enumerate(sample_captions):
    print(f"{i+1}. {caption}")

In [None]:
# Candidate checkpoints; only the first entry is exercised by the loop below.
models_to_test = [
    "stabilityai/stable-diffusion-2-1-base",
    #"stabilityai/stable-diffusion-xl-base-1.0",
    "kandinsky-community/kandinsky-2-2-decoder"
]

# test_diffusion_model calls from_pretrained on every invocation, so the
# pipeline is loaded once per caption.
for caption in sample_captions:
    test_diffusion_model(models_to_test[0], caption)

## Image Generation with Conditional GAN

In [None]:
# Set up the device: use the GPU when available, otherwise fall back to CPU.

device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(f"Device used: {device}")

In [8]:
# Configuration shared by the transforms, data loaders and GAN config below.

batch_size = 32    # samples per DataLoader batch
image_size = 128   # images are resized to image_size x image_size
num_workers = 4    # DataLoader worker processes

In [None]:
# Set up the output folders, metrics, logger and training callbacks.

output_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI")
os.makedirs(output_dir, exist_ok=True)
# Only output_dir is created here; checkpoint_dir and log_dir are presumably
# created by the logger/callbacks that receive them — verify.
checkpoint_dir = output_dir / "checkpoints"
log_dir = output_dir / "logs"

metrics = MetricsTracker([
    FIDScore(device=device),
    CLIPScore(device=device)
])

logger = GANLogger("conditional_gan", log_dir=log_dir)

# Early stopping and checkpointing both monitor the 'fid' metric.
callbacks = [
    EarlyStopping(monitor='fid', patience=5),
    ModelCheckpoint(filepath=checkpoint_dir / "best_model.pt", monitor='fid'),
    MetricsHistory(log_dir=log_dir / "metrics")
]

In [19]:
# Load the dataset splits (no transforms); only the image paths are used below.

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

In [20]:
# Load the training captions.
# NOTE(review): this is a path relative to the working directory, unlike the
# Drive paths used for the other caption files — confirm the file exists there.
with open('output/captions/captions_traindataset.json', 'r') as f:
    caption_dict = json.load(f)

In [21]:
# Plain-string image paths for each split, taken from the datasets' private
# `_images` attribute.
train_images_paths = [str(img) for img in train_dataset._images]

test_images_paths = [str(img) for img in test_dataset._images]

In [None]:
# Initialize the train, validation and test loaders.

# Resize to a square, convert to a tensor, then normalise each channel with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

full_train_ds = PetDatasetWithCaptions(
    image_paths=train_images_paths,
    caption_dict=caption_dict,
    transform=transform
)

# 90/10 train/validation split; the fixed seed makes the split reproducible.
train_size = int(0.9 * len(full_train_ds))
val_size = len(full_train_ds) - train_size
train_ds, val_ds = random_split(full_train_ds, [train_size, val_size],
                                generator=torch.Generator().manual_seed(42))

# NOTE(review): the test dataset reuses `caption_dict`, which was loaded from
# the TRAIN caption file — confirm it contains entries for the test images.
test_ds = PetDatasetWithCaptions(
    image_paths=test_images_paths,
    caption_dict=caption_dict,
    transform=transform
)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

# Built for completeness; the trainer below receives only the train and
# validation loaders.
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

In [None]:
# Initialize the GAN model.

config = GANConfig(
    latent_dim = 100,
    # presumably the size of the caption embedding fed to the GAN — TODO confirm
    caption_dim = 768,
    # same value used by the Resize transform of the data pipeline
    image_size = image_size,
    num_channels = 3,
    generator_features = 64
)

gan = ConditionalGAN(config)

In [24]:
# Initialize the trainer with the model, loaders, metrics, logger and
# callbacks built in the cells above.

trainer = GANTrainer(
    gan=gan,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    metrics_tracker=metrics,
    logger=logger,
    callbacks=callbacks
)

In [None]:
# Train the GAN.
# NOTE(review): "samples" is relative to the working directory, whereas
# checkpoints and logs go to Drive — confirm samples should not be on Drive too.

trainer.train(
    num_epochs=100,
    eval_freq=1,
    sample_freq=500,
    sample_dir=Path("samples")
)