# Generative AI project about Data Augmentation
> *This is the notebook for the ninth Profession AI project, part of the Generative AI module.*

## Setup & Configuration

In [None]:
!git clone https://github.com/Silvano315/Gen-AI-for-Data-Augmentation.git

In [2]:
import os
os.chdir('/content/Gen-AI-for-Data-Augmentation')

In [None]:
!pwd

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Import Libraries

In [None]:
!pip install clean-fid
!pip install git+https://github.com/openai/CLIP.git
!pip install -q transformers datasets accelerate sentencepiece
!pip install -q git+https://github.com/huggingface/transformers

In [6]:
from pathlib import Path
import random
import matplotlib.pyplot as plt
import torch
import json
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForCausalLM
from src.data.dataset import PetDatasetHandler
from src.captioning.caption_generator import CaptionGenerator
from src.captioning.captioning_blip_2 import BLIP2CaptionGenerator
from src.captioning.git_caption_generator import GITCaptionGenerator
from src.data.data_with_captions import PetDatasetWithCaptions
from src.generation.text_generation import TextVariationGenerator
from src.utils.logging import GANLogger
from src.generation.image_generator import GANConfig, ConditionalGAN
from src.training.callbacks import EarlyStopping, ModelCheckpoint, MetricsHistory
from src.evaluation.metrics import FIDScore, CLIPScore, MetricsTracker
from src.training.training import GANTrainer

## Initialize and load dataset without transforms for analysis

In [None]:
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Basic dataset information

In [None]:
# Fetch the summary mapping exposed by the dataset handler and list every entry.
info = handler.get_dataset_info()
print("Dataset Information:")
for field_name in info:
    print(f"{field_name}: {info[field_name]}")

### Plot distributions and samples

In [None]:
handler.plot_class_distribution().show()

In [None]:
handler.visualize_samples(9).show()

### Get detailed image statistics


In [None]:
# Ask the handler for image statistics (requested with sample_size=100).
stats = handler.get_image_stats(sample_size=100)
print("\nImage Statistics:")
for category in stats:
    heading = category.upper()
    print(f"\n{heading}:")
    # Every value in a category is numeric and printed with two decimals.
    for metric, number in stats[category].items():
        print(f"{metric}: {number:.2f}")

### For training, load with transforms

In [None]:
train_transforms = handler.get_training_transforms()
train_dataset, test_dataset = handler.load_dataset(transform=train_transforms)

## Image Captioning

### Compare caption generators: 
1. **BLIP**
2. **BLIP-2**
3. **GIT**

In [None]:
# Configurations

def get_random_images(image_dir, count=5):
    """Return up to ``count`` randomly chosen ``*.jpg`` paths from ``image_dir``.

    Args:
        image_dir: Directory (``str`` or ``Path``) searched non-recursively
            for files matching ``*.jpg``.
        count: Maximum number of images to return. When fewer images are
            available, all of them are returned (in random order).

    Returns:
        A list of ``Path`` objects sampled without replacement.

    Raises:
        ValueError: If ``count`` is negative (raised by ``random.sample``).
    """
    # Sort first: directory listing order is filesystem-dependent, so without
    # a stable order a fixed ``random.seed`` could not reproduce the selection.
    image_paths = sorted(Path(image_dir).glob("*.jpg"))
    return random.sample(image_paths, min(count, len(image_paths)))

# Same handler/dataset setup as in the analysis section, repeated here.
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

# Draw up to 10 random JPEGs; the caption-comparison cells below iterate over them.
image_dir = "data/oxford-iiit-pet/images"
sample_images = get_random_images(image_dir, count=10)
print(f"Selected {len(sample_images)} random images")

In [None]:
# Test BLIP original model

print("Test di BLIP...")
blip_model = CaptionGenerator()

# Caption every sampled image, keyed by its string path for the later comparison.
blip_captions = {}
for sample_path in sample_images:
    path_key = str(sample_path)
    blip_captions[path_key] = blip_model.generate_caption(path_key)
    print(f"BLIP - {sample_path.name}: {blip_captions[path_key]}")

In [None]:
# Test BLIP-2 model 
# Be careful: BLIP-2 is compute- and memory-intensive; run this cell only if you have substantial computational resources.

print("\nTest di BLIP-2...")
blip2_model = BLIP2CaptionGenerator(model_name="Salesforce/blip2-opt-2.7b")

# Store one BLIP-2 caption per sampled image under the image's string path.
blip2_captions = {}
for sample_path in sample_images:
    path_key = str(sample_path)
    generated = blip2_model.generate_caption(path_key)
    blip2_captions[path_key] = generated
    print(f"BLIP-2 - {sample_path.name}: {generated}")

In [None]:
# Test GIT model

print("\nTest di GIT...")
git_model_name = "microsoft/git-base-coco"
processor_git = AutoProcessor.from_pretrained(git_model_name)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_git = AutoModelForCausalLM.from_pretrained(git_model_name).to(device)

git_captions = {}
for sample_path in sample_images:
    rgb_image = Image.open(sample_path).convert("RGB")
    encoded = processor_git(images=rgb_image, return_tensors="pt").to(device)

    # Inference only: gradients are not needed while generating with 5 beams.
    with torch.no_grad():
        token_ids = model_git.generate(
            pixel_values=encoded.pixel_values,
            max_length=50,
            num_beams=5,
        )

    decoded = processor_git.batch_decode(token_ids, skip_special_tokens=True)
    git_captions[str(sample_path)] = decoded[0]
    print(f"GIT - {sample_path.name}: {decoded[0]}")

In [None]:
# Visualize images with captions from three models (side by side)

from textwrap import wrap

def visualize_comparison(image_paths, blip_captions, blip2_captions, git_captions, wrap_width=30):
    """Show each image three times, captioned by BLIP, BLIP-2 and GIT respectively.

    Args:
        image_paths: Sequence of image paths; one figure row is drawn per path.
        blip_captions: Mapping ``str(path) -> caption`` produced by BLIP.
        blip2_captions: Mapping ``str(path) -> caption`` produced by BLIP-2.
        git_captions: Mapping ``str(path) -> caption`` produced by GIT.
        wrap_width: Maximum number of characters per caption line.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn
        when ``image_paths`` is empty.

    Raises:
        KeyError: If a path is missing from any of the caption mappings.
    """
    n_images = len(image_paths)
    if n_images == 0:
        # ``plt.subplots`` rejects a grid with zero rows; nothing to draw.
        return

    columns = (
        ("BLIP", blip_captions),
        ("BLIP-2", blip2_captions),
        ("GIT", git_captions),
    )

    # squeeze=False keeps ``axes`` two-dimensional even for a single image,
    # so ``axes[row, col]`` indexing works for any number of rows.
    _, axes = plt.subplots(
        n_images, len(columns), figsize=(15, 5 * n_images), squeeze=False
    )

    # Column headers go on the first row only.
    for col, (model_name, _) in enumerate(columns):
        axes[0, col].set_title(model_name, fontsize=14)

    for row, img_path in enumerate(image_paths):
        img_path_str = str(img_path)
        img = Image.open(img_path).convert("RGB")

        for col, (_, captions) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(img)
            # Hide ticks and frame by hand: ``ax.axis('off')`` would also
            # hide the x-label, which is where the caption is displayed.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            wrapped_caption = "\n".join(wrap(captions[img_path_str], wrap_width))
            ax.set_xlabel(wrapped_caption, fontsize=12)

    plt.tight_layout()
    plt.show()


visualize_comparison(sample_images, blip_captions, blip2_captions, git_captions)

### Initialize caption generator

In [None]:
caption_gen = GITCaptionGenerator()

### Load Dataset (If you haven’t done it before)

In [None]:
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Test single image caption generation

In [None]:
# Draw one random index; ``sample`` and ``label`` are reused by the plotting cell below.
sample = random.randrange(len(train_dataset))

# Element 1 of a dataset item is used as the index into the class-name list.
class_index = train_dataset[sample][1]
label = train_dataset.classes[class_index]

sample_image_path = Path(train_dataset._images[sample])
caption = caption_gen.generate_caption(sample_image_path, label, max_length=50)
print(f"Sample caption: {caption}")

In [None]:
plt.figure(figsize=(10, 10))
plt.title(f"{label}")
fig = plt.imshow(train_dataset[sample][0])

### Process a batch of images


In [None]:
# Caption a small preview subset before processing a whole split.
batch_size = 4
num_preview = 10
image_paths = [Path(img) for img in train_dataset._images[:num_preview]]
# Derive the label count from the sliced paths: the slice silently truncates on
# a short dataset, whereas a separate hard-coded range(10) could index past the
# end and leave paths and labels out of sync.
labels = [train_dataset.classes[train_dataset[i][1]] for i in range(len(image_paths))]
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

### Process train and test datasets

In [None]:
# Gather every test-split image path together with its class name.
batch_size = 4
image_paths = list(map(Path, test_dataset._images))
labels = [
    test_dataset.classes[test_dataset[index][1]]
    for index, _ in enumerate(image_paths)
]

In [None]:
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

# Define the output folder here: it is otherwise first assigned in a later cell,
# so running the notebook top to bottom would raise NameError on ``save_dir``.
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)
caption_gen.save_captions(save_dir / 'captions_testdataset.json')

### Save & Load captions


In [None]:
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)
caption_gen.save_captions(save_dir / 'captions_testdataset.json')

In [None]:
# NOTE(review): the save cells above write 'captions_testdataset.json'; confirm that
# 'captions.json' actually exists in save_dir before running this cell.
caption_gen.load_captions(save_dir / 'captions.json')

### Visualize results

In [None]:
caption_gen.visualize_captions(num_samples=4)

### Print some statistics

In [None]:
# Report how many captions are cached, then preview the first three entries.
cached = caption_gen.captions_cache
print(f"\nTotal captions generated: {len(cached)}")
print("\nSample of generated captions:")
preview = list(cached.items())[:3]
for image_key, text in preview:
    image_name = Path(image_key).name
    print(f"\nImage: {image_name}")
    print(f"Caption: {text}")

## Text Generation with Flan-T5 to increase the number of captions

#### Initialize Text Generator

In [None]:
generator = TextVariationGenerator()

#### How it works
> Comparison of different types of prompts

In [None]:
# Standard generation

original_caption = "A gray cat sitting on a window sill - This is a British Shorthair."

standard_variations = generator.generate_variations(
    original_caption,
    num_variations=3,
    prompt_type="standard",
    temperature=0.8
)
print("Standard Prompting:", standard_variations)

In [None]:
# Generation with a specific prompt

specific_variations = generator.generate_variations(
    original_caption,
    num_variations=3,
    prompt_type="specific",
    temperature=0.8
)
print("Specific Prompting:", specific_variations)

In [None]:
# Generation with few-shot prompting

fewshot_variations = generator.generate_variations(
    original_caption,
    num_variations=3,
    prompt_type="few-shot",
    temperature=0.8
)
print("Few-Shot Prompting:", fewshot_variations)

## Image Generation with Conditional GAN

In [None]:
# Setup device

device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(f"Device used: {device}")

In [8]:
# Configurations

batch_size = 32
image_size = 128
num_workers = 4

In [None]:
# Setup components

# Training outputs (checkpoints and logs) live under one folder on the mounted Drive.
output_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI")
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = output_dir / "checkpoints"
log_dir = output_dir / "logs"

# Evaluation metrics are built on the device selected earlier.
fid_metric = FIDScore(device=device)
clip_metric = CLIPScore(device=device)
metrics = MetricsTracker([fid_metric, clip_metric])

logger = GANLogger("conditional_gan", log_dir=log_dir)

# Training callbacks; both early stopping and checkpointing monitor 'fid'.
early_stopping = EarlyStopping(monitor='fid', patience=5)
checkpointing = ModelCheckpoint(filepath=checkpoint_dir / "best_model.pt", monitor='fid')
history = MetricsHistory(log_dir=log_dir / "metrics")
callbacks = [early_stopping, checkpointing, history]

In [19]:
# Load Dataset

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

In [20]:
# Load Captions
# Read the caption file with an explicit encoding so the result does not depend
# on the platform's locale default.
# NOTE(review): this relative path differs from the Drive folder used when
# saving captions earlier in the notebook — confirm the file is present here.
with open('output/captions/captions_traindataset.json', 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

In [21]:
# Convert each stored image path of both splits to a plain string.
train_images_paths = list(map(str, train_dataset._images))
test_images_paths = list(map(str, test_dataset._images))

In [None]:
# Initialize train and val loader

# Resize to a square, convert to a tensor, then normalize every channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

full_train_ds = PetDatasetWithCaptions(
    image_paths=train_images_paths,
    caption_dict=caption_dict,
    transform=transform,
)

# Hold out 10% of the training images for validation; the seeded generator
# keeps the split identical across runs.
train_size = int(0.9 * len(full_train_ds))
val_size = len(full_train_ds) - train_size
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(
    full_train_ds, [train_size, val_size], generator=split_rng
)

# NOTE(review): the test split reuses caption_dict, which was loaded from the
# train captions file — confirm it also contains entries for the test images.
test_ds = PetDatasetWithCaptions(
    image_paths=test_images_paths,
    caption_dict=caption_dict,
    transform=transform,
)

# Settings shared by all three loaders; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **loader_options)
test_loader = DataLoader(test_ds, shuffle=False, **loader_options)

In [None]:
# Initialize GAN model

# Hyperparameters for the conditional GAN; image_size is the same value used
# by the data transform.
config = GANConfig(
    latent_dim=100,
    caption_dim=768,
    image_size=image_size,
    num_channels=3,
    generator_features=64,
)

# Build the model from the configuration above.
gan = ConditionalGAN(config)

In [24]:
# Initialize trainer

trainer = GANTrainer(
    gan=gan,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    metrics_tracker=metrics,
    logger=logger,
    callbacks=callbacks
)

In [None]:
# Train

trainer.train(
    num_epochs=100,
    eval_freq=1,
    sample_freq=500,
    sample_dir=Path("samples")
)