# Generative AI project about Data Augmentation
> *This is the notebook for the ninth Profession AI project, part of the Generative AI module.*

## Setup & Configuration

In [None]:
!git clone https://github.com/Silvano315/Gen-AI-for-Data-Augmentation.git

In [2]:
import os
os.chdir('/content/Gen-AI-for-Data-Augmentation')

In [None]:
!pwd

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Check GPU availability

!nvidia-smi

In [None]:
# Set up device

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Import Libraries

In [None]:
!pip install clean-fid
!pip install git+https://github.com/openai/CLIP.git
!pip install -q transformers datasets accelerate sentencepiece
!pip install -q git+https://github.com/huggingface/transformers
!pip install torchmetrics

In [None]:
import random
import matplotlib.pyplot as plt
import json
import os
import numpy as np
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from transformers import AutoProcessor, AutoModelForCausalLM
from src.data.dataset import PetDatasetHandler
from src.captioning.caption_generator import CaptionGenerator
from src.captioning.captioning_blip_2 import BLIP2CaptionGenerator
from src.captioning.git_caption_generator import GITCaptionGenerator
from src.data.data_with_captions import PetDatasetWithCaptions
from src.generation.text_generation import TextVariationGenerator
from src.utils.logging import GANLogger
from src.generation.image_generator import GANConfig, ConditionalGAN
from src.training.callbacks import EarlyStopping, ModelCheckpoint, MetricsHistory
from src.evaluation.metrics import FIDScore, CLIPScore, MetricsTracker
from src.training.training import GANTrainer

## Initialize and load dataset without transforms for analysis

In [None]:
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Basic dataset information

In [None]:
info = handler.get_dataset_info()
print("Dataset Information:")
for key, value in info.items():
    print(f"{key}: {value}")

### Plot distributions and samples

In [None]:
handler.plot_class_distribution().show()

In [None]:
handler.visualize_samples(9).show()

### Get detailed image statistics


In [None]:
stats = handler.get_image_stats(sample_size=100)
print("\nImage Statistics:")
for category, values in stats.items():
    print(f"\n{category.upper()}:")
    for key, value in values.items():
        print(f"{key}: {value:.2f}")

### For training, load with transforms

In [None]:
train_transforms = handler.get_training_transforms()
train_dataset, test_dataset = handler.load_dataset(transform=train_transforms)

## Image Captioning

### Compare caption generators: 
1. **Blip**
2. **Blip-2**
3. **GIT**

In [None]:
# Configurations

def get_random_images(image_dir, count=5):
    """Randomly select up to ``count`` JPEG images from a directory.

    Args:
        image_dir: Directory (str or Path) searched non-recursively for ``*.jpg`` files.
        count: Maximum number of images to return; negative values are treated as 0.

    Returns:
        A list of ``Path`` objects drawn without replacement. The list is shorter
        than ``count`` (possibly empty) when the directory holds fewer JPEG files.
    """
    # glob() yields entries in filesystem order, which varies between machines;
    # sorting first makes the draw reproducible once ``random.seed`` is set.
    image_paths = sorted(Path(image_dir).glob("*.jpg"))
    sample_size = max(0, min(count, len(image_paths)))
    return random.sample(image_paths, sample_size)

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

image_dir = "data/oxford-iiit-pet/images"
sample_images = get_random_images(image_dir, count=10)
print(f"Selected {len(sample_images)} random images")

In [None]:
# Test BLIP original model

print("Test di BLIP...")
blip_model = CaptionGenerator()
blip_captions = {}

for img_path in sample_images:
    caption = blip_model.generate_caption(str(img_path))
    blip_captions[str(img_path)] = caption
    print(f"BLIP - {img_path.name}: {caption}")

In [None]:
# Test BLIP-2 model
# Be careful: BLIP-2 is compute- and memory-intensive; run this cell only if you have sufficient resources.

print("\nTest di BLIP-2...")
blip2_model = BLIP2CaptionGenerator(model_name="Salesforce/blip2-opt-2.7b")
blip2_captions = {}  # maps str(image path) -> caption generated by BLIP-2

# Caption the same `sample_images` used for the other models so results can be compared.
for img_path in sample_images:
    caption = blip2_model.generate_caption(str(img_path))
    blip2_captions[str(img_path)] = caption
    print(f"BLIP-2 - {img_path.name}: {caption}")

In [None]:
# Test GIT model

print("\nTest di GIT...")
git_model_name = "microsoft/git-base-coco"
# Load the processor and the causal-LM checkpoint published under this model name.
processor_git = AutoProcessor.from_pretrained(git_model_name)
model_git = AutoModelForCausalLM.from_pretrained(git_model_name)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_git = model_git.to(device)

git_captions = {}  # maps str(image path) -> caption generated by GIT

for img_path in sample_images:
    image = Image.open(img_path).convert("RGB")
    inputs_git = processor_git(images=image, return_tensors="pt").to(device)

    # Inference only: gradients are not needed for generation.
    with torch.no_grad():
        generated_ids = model_git.generate(
            pixel_values=inputs_git.pixel_values,
            max_length=50,
            num_beams=5
        )

    # One image was passed in, so the first decoded string is its caption.
    caption = processor_git.batch_decode(generated_ids, skip_special_tokens=True)[0]
    git_captions[str(img_path)] = caption
    print(f"GIT - {img_path.name}: {caption}")

In [None]:
# Visualize images with captions from three models (side by side)

from textwrap import wrap

def visualize_comparison(image_paths, blip_captions, blip2_captions, git_captions, wrap_width=30):
    """Show every image three times, captioned by BLIP, BLIP-2 and GIT respectively.

    Args:
        image_paths: Iterable of image paths; each one becomes a row of the figure.
        blip_captions: Dict mapping ``str(image_path)`` to the BLIP caption.
        blip2_captions: Dict mapping ``str(image_path)`` to the BLIP-2 caption.
        git_captions: Dict mapping ``str(image_path)`` to the GIT caption.
        wrap_width: Maximum number of characters per caption line before wrapping.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    An image with no caption from one of the models is labelled with a
    placeholder instead of raising ``KeyError``.
    """
    image_paths = list(image_paths)
    n_images = len(image_paths)
    if n_images == 0:
        # plt.subplots() cannot build a grid with zero rows.
        print("No images to visualize")
        return

    # squeeze=False keeps `axes` two-dimensional even when there is a single row,
    # so axes[row, col] indexing is always valid.
    fig, axes = plt.subplots(n_images, 3, figsize=(15, 5 * n_images), squeeze=False)

    model_columns = (
        ("BLIP", blip_captions),
        ("BLIP-2", blip2_captions),
        ("GIT", git_captions),
    )

    # Column headers go on the first row only.
    for col, (model_name, _) in enumerate(model_columns):
        axes[0, col].set_title(model_name, fontsize=14)

    for row, img_path in enumerate(image_paths):
        img_path_str = str(img_path)
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")

        for col, (_, model_captions) in enumerate(model_columns):
            ax = axes[row, col]
            ax.imshow(img)
            # axis('off') would also suppress the x-label that carries the caption,
            # so only the ticks and the frame are hidden.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            caption = model_captions.get(img_path_str, "(no caption)")
            ax.set_xlabel("\n".join(wrap(caption, wrap_width)), fontsize=12)

    plt.tight_layout()
    plt.show()


visualize_comparison(sample_images, blip_captions, blip2_captions, git_captions)

### Initialize caption generator

In [None]:
caption_gen = GITCaptionGenerator()

### Load Dataset (If you haven’t done it before)

In [None]:
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

### Test single image caption generation

In [None]:
sample = random.randint(0, len(train_dataset)-1)
sample_image_path = Path(train_dataset._images[sample])
label = train_dataset.classes[train_dataset[sample][1]]
caption = caption_gen.generate_caption(sample_image_path, label, max_length = 50)
print(f"Sample caption: {caption}")

In [None]:
plt.figure(figsize=(10, 10))
plt.title(f"{label}")
fig = plt.imshow(train_dataset[sample][0])

### Process a batch of images


In [None]:
batch_size = 4
image_paths = [Path(img) for img in train_dataset._images[:10]]
labels = [train_dataset.classes[train_dataset[i][1]] for i in range(10)]
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)

### Process train and test datasets

In [None]:
batch_size = 4
image_paths = [Path(img) for img in train_dataset._images]
labels = [train_dataset.classes[train_dataset[i][1]] for i in range(len(image_paths))]

In [None]:
captions = caption_gen.process_batch(image_paths, labels, batch_size=batch_size)
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
# Create the target folder before saving, in case save_captions does not create
# missing parent directories; otherwise the whole captioning run could be lost.
save_dir.mkdir(parents=True, exist_ok=True)
caption_gen.save_captions(save_dir / 'captions_traindataset.json')

### Save & Load captions


In [None]:
save_dir = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions')
save_dir.mkdir(parents=True, exist_ok=True)
caption_gen.save_captions(save_dir / 'captions_traindataset.json')

In [None]:
# Load the same file name that the save step writes ('captions_traindataset.json').
caption_gen.load_captions(save_dir / 'captions_traindataset.json')

### Visualize results

In [None]:
caption_gen.visualize_captions(num_samples=4)

### Print some statistics

In [None]:
print(f"\nTotal captions generated: {len(caption_gen.captions_cache)}")
print("\nSample of generated captions:")
for path, caption in list(caption_gen.captions_cache.items())[:3]:
    print(f"\nImage: {Path(path).name}")
    print(f"Caption: {caption}")

## Text Generation with Flan-T5 to increase the number of captions

#### Initialize Text Generator

In [None]:
generator = TextVariationGenerator(model_name="google/flan-t5-large")

#### How it works
> Comparison of different types of prompts

In [None]:
results = generator.test_prompt_types("A white dog sitting on a brown chair", temperature=0.95, num_variations=5)

#### Test Quality for Few-Shot Prompting

In [None]:
# Use a complete descriptive sentence as input: a one-letter placeholder gives
# the few-shot prompt nothing meaningful to rephrase.
original_caption = "A white dog sitting on a brown chair"

variations = generator.test_few_shot_quality(
    original_caption,
    num_variations=3,
    temperature=1
)

#### Load Captions saved and Datasets

In [None]:
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json')

with open(caption_file, 'r') as f:
    captions = json.load(f)

In [None]:
# Load Data

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

#### Test the variations on random captions

In [None]:
# Draw up to three (image path, caption) pairs; the min() guard keeps
# random.sample from raising when fewer than three captions were loaded.
sample_captions = dict(random.sample(list(captions.items()), min(3, len(captions))))

for img_path, caption in sample_captions.items():
    # "\n" (not "\I", an invalid escape that printed a literal backslash)
    # separates the output of consecutive samples.
    print(f"\nImage: {Path(img_path).name}")
    print(f"Original Caption: {caption}")

    variations = generator.generate_variations(
        caption,
        num_variations=3,
        temperature=1,
        prompt_type="few-shot"
    )

    print("New Versions:")
    for i, var in enumerate(variations):
        print(f"{i+1}. {var}")

In [None]:
# Display images with caption variations

def visualize_caption_variations(image_path, caption, variations):
    """Display one image titled with its original caption and a numbered list of variations.

    Args:
        image_path: Path of the image to open.
        caption: Original caption, shown on the first title line.
        variations: Iterable of alternative captions, listed as "1. ...", "2. ...".
    """
    plt.figure(figsize=(10, 12))

    picture = Image.open(image_path).convert("RGB")
    plt.subplot(1, 1, 1)
    plt.imshow(picture)
    plt.axis('off')

    # Assemble the multi-line title: header first, then one numbered line per variation.
    numbered_lines = [f"{rank}. {text}\n" for rank, text in enumerate(variations, start=1)]
    heading = f"Original: {caption}\n\nVariations:\n" + "".join(numbered_lines)

    plt.title(heading, fontsize=10, loc='left')
    plt.tight_layout()
    plt.show()

# `variations` still holds the output of the LAST iteration of the sampling loop,
# so it must be paired with the last sampled caption (dicts keep insertion order).
# Using the first key would show an image next to variations of another caption.
sample_img_path = list(sample_captions.keys())[-1]
visualize_caption_variations(
    sample_img_path,
    sample_captions[sample_img_path],
    variations
)

#### Process entire caption file

In [None]:
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

print("Generation of variations for all captions...")
variations_dict = generator.process_caption_file(
    caption_file=caption_file,
    output_file=output_file,
    variations_per_caption=3,
    class_balancing=True,
    target_per_class=150,
    min_variations=1,
    max_variations=5,
    prompt_type="few-shot",
    temperature=1
)

print(f"Generate variations for {len(variations_dict)} caption")
print(f"File saved in: {output_file}")

#### Evaluate captions generated

In [None]:
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')

with open(output_file, 'r') as f:
    variations_data = json.load(f)

# Number of generated variations for each image path.
variations_counts = {img_path: len(vars_list) for img_path, vars_list in variations_data.items()}

# Fail with a clear message instead of ZeroDivisionError / max() on an empty sequence.
if not variations_counts:
    raise ValueError(f"No variations found in {output_file}")

# Statistics
total_variations = sum(variations_counts.values())
avg_variations = total_variations / len(variations_counts)
max_variations = max(variations_counts.values())
min_variations = min(variations_counts.values())

print(f"Statistics of variations:")
print(f"Total original captions: {len(variations_counts)}")
print(f"Total generated variations: {total_variations}")
print(f"Mean variation per caption: {avg_variations:.2f}")
print(f"Max variations per caption: {max_variations}")
print(f"Min variations per caption: {min_variations}")

In [None]:
# Extract the breed from the captions and calculate statistics by breed

def extract_breed(caption):
    """Return the breed named by the first "This is a/an <breed>." sentence in a caption.

    Args:
        caption: Caption text to search.

    Returns:
        The text between "This is a " / "This is an " and the next period,
        or None when the caption contains no such sentence.
    """
    import re
    breed_sentence = re.search(r"This is an? ([^\.]+)\.", caption)
    if breed_sentence is None:
        return None
    return breed_sentence.group(1)

# Count how many original captions mention each breed.
breed_counts_original = {}
for caption in captions.values():
    breed = extract_breed(caption)
    if breed:
        breed_counts_original[breed] = breed_counts_original.get(breed, 0) + 1

# Start from the original counts so the second series shows originals + variations.
breed_counts_variations = breed_counts_original.copy()
for var_list in variations_data.values():
    for var in var_list:
        breed = extract_breed(var)
        if breed:
            breed_counts_variations[breed] = breed_counts_variations.get(breed, 0) + 1

# Only breeds seen in the original captions are plotted; a breed that appears
# solely in the variations is counted above but not shown.
breeds = sorted(breed_counts_original.keys())
original_counts = [breed_counts_original[breed] for breed in breeds]
total_counts = [breed_counts_variations[breed] for breed in breeds]

# Grouped bars: two 0.4-wide bars per breed, with the tick label centred between them.
plt.figure(figsize=(15, 10))
x = range(len(breeds))
plt.bar(x, original_counts, width=0.4, align='edge', label='Original')
plt.bar([i+0.4 for i in x], total_counts, width=0.4, align='edge', label='After Variations')
plt.xticks([i+0.2 for i in x], breeds, rotation=90)
plt.xlabel('Breed')
plt.ylabel('Count')
plt.title('Dataset Augmentation by Breed')
plt.legend()
plt.tight_layout()
plt.show()

#### Balanced Subset selection

In [None]:
balanced_output = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/balanced_flan_t5_variations.json')
output_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json')

balanced_subset = generator.select_balanced_subset(
    caption_file=caption_file,
    variations_file=output_file,
    output_file=balanced_output,
    target_per_class=150
)

print(f"Created balanced subset with {len(balanced_subset)} captions")
print(f"File saved in: {balanced_output}")

In [None]:
breed_counts_balanced = {}
for caption in balanced_subset.values():
    breed = extract_breed(caption)
    if breed:
        breed_counts_balanced[breed] = breed_counts_balanced.get(breed, 0) + 1

breeds = sorted(breed_counts_balanced.keys())
balanced_counts = [breed_counts_balanced.get(breed, 0) for breed in breeds]

plt.figure(figsize=(15, 10))
plt.bar(breeds, balanced_counts)
plt.xlabel('Breed')
plt.ylabel('Count')
plt.title('Balanced Subset by Breed')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

## Image Generation with Diffusion Models

#### Import Libraries

In [None]:
from src.generation.image_diffusion_generator import DiffusionModelManager

#### Initialize Diffusion class

In [None]:
diffusion_manager = DiffusionModelManager(
    output_dir="/content/drive/MyDrive/outputs_master_ProfAI",
    default_model="runwayml/stable-diffusion-v1-5"
)

#### Zero-Shot Prompting

In [None]:
# Checkpoints compared with the same zero-shot prompt.
# NOTE(review): some of these checkpoints may need Hugging Face authentication or more
# GPU memory than the current runtime offers — confirm before running the whole loop,
# since one failing model aborts the remaining ones.
models_to_test = [
    "stabilityai/stable-diffusion-2-1-base",
    #"stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "kandinsky-community/kandinsky-2-2-decoder",
    "black-forest-labs/FLUX.1-dev"
]

# Generate two images per model; all results go to the same output sub-folder.
for model in models_to_test:
    diffusion_manager.test_diffusion_model(
        model_id=model,
        prompt="A high-quality photo of a British Shorthair cat",
        num_images=2,
        output_subdir="model_comparison"
    )

In [None]:
caption_file = Path('/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json')
with open(caption_file, 'r') as f:
    captions_data = json.load(f)

# The variations file stores, per image, a list of caption strings. Flatten it so
# that each sampled prompt is a single caption rather than a whole list.
all_variations = []
for entry in captions_data.values():
    if isinstance(entry, str):
        all_variations.append(entry)
    else:
        all_variations.extend(entry)

# min() keeps random.sample from raising when fewer than three captions exist.
sample_captions = random.sample(all_variations, min(3, len(all_variations)))
print("Caption selected for this test:")
for i, caption in enumerate(sample_captions):
    print(f"{i+1}. {caption}")

In [None]:
for caption in sample_captions:
    diffusion_manager.test_diffusion_model(models_to_test[0], caption)

#### LoRA fine-tuning

In [None]:
# Load Images and Captions from original dataset in the correct format

captions_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/captions_git_train_dataset.json"
images_dir = "/content/Gen-AI-for-Data-Augmentation/data/oxford-iiit-pet/images"

dataset_dir = diffusion_manager.prepare_dataset(
    captions_file=captions_file,
    images_dir=images_dir,
    max_samples_per_breed=20,
    min_samples_per_breed=10,
    target_total_samples=37*20,
    resolution=512
)

variations_file = "/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json"
validation_prompts = diffusion_manager.select_validation_prompts_from_variations(
    variations_file=variations_file,
    num_prompts=10
)

In [None]:
# Run LoRA

lora_model_dir = diffusion_manager.run_lora_training(
    dataset_dir=dataset_dir,
    output_name="pet_breeds_lora",
    base_model="runwayml/stable-diffusion-v1-5",
    resolution=512,
    max_train_steps=1500,
    learning_rate=5e-5,
    validation_prompts=validation_prompts,
    rank=4
)

#### Inference with LoRA fine-tuned model

In [None]:
# Load the fine-tuned model

lora_model_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI/lora_models/pet_breeds_lora/")

diffusion_manager.load_lora_model(
    base_model_id="runwayml/stable-diffusion-v1-5",
    lora_weights_path=str(lora_model_dir / "pytorch_lora_weights.safetensors")
)

# Generate images with custom prompts
test_prompts = [
    "A high-quality photograph of a Maine Coon cat with long fur",
    "A detailed image of a Beagle dog running in a park",
    "A professional photo of a Persian cat with blue eyes"
]

generated_images = diffusion_manager.generate_images(
    prompts=test_prompts,
    output_subdir="test_generation",
    num_images_per_prompt=2,
    guidance_scale=7.5,
    num_inference_steps=40
)

In [None]:
# Generate images from variations

variation_images = diffusion_manager.generate_from_variations(
    variations_file=variations_file,
    output_subdir="variation_generation",
    num_samples=20,
    num_images_per_prompt=1,
    guidance_scale=7.5,
    num_inference_steps=40
)

#### Model Evaluation

As stated by the Diffusers team [here](https://huggingface.co/docs/diffusers/v0.26.1/conceptual/evaluation), evaluation should be done qualitatively. They also suggest trying other metrics (CLIP Score or FID) to get a quantitative point of view.

In [None]:
# Same folder that LoRA training writes to (output_name="pet_breeds_lora") and that
# the other inference cells load from; ".../lora_model" does not match that layout.
lora_model_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI/lora_models/pet_breeds_lora/")
variations_file = Path("/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json")

diffusion_manager.load_lora_model(
    base_model_id="runwayml/stable-diffusion-v1-5",
    lora_weights_path=str(lora_model_dir / "pytorch_lora_weights.safetensors")
)

variation_images = diffusion_manager.generate_from_variations(
    variations_file=variations_file,
    output_subdir="variation_generation",
    num_samples=20,
    num_images_per_prompt=1,
    guidance_scale=7.5,
    num_inference_steps=40
)

# Flatten {prompt: [images]} into two parallel lists so every image is paired
# with the prompt that produced it.
all_images = []
all_prompts = []
for prompt, images in variation_images.items():
    all_images.extend(images)
    all_prompts.extend([prompt] * len(images))

clip_scores = diffusion_manager.evaluate_clip_score(
    images=all_images,
    prompts=all_prompts
)

print(f"Average CLIP score: {clip_scores['mean_clip_score']}")
print(f"Individual scores: {clip_scores['individual_scores']}")

#### Generate Balanced Dataset

##### Load Dataset (if you haven't done it before)

In [None]:
data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

##### Generate and Save Zip file

In [None]:
# Load the fine-tuned model

lora_model_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI/lora_models/pet_breeds_lora/")

diffusion_manager.load_lora_model(
    base_model_id="runwayml/stable-diffusion-v1-5",
    lora_weights_path=str(lora_model_dir / "pytorch_lora_weights.safetensors")
)

In [None]:
# Generate balanced dataset
# Reminder: it is useful to save the images on Google Drive so that generation can be
# restarted if the run stops early (presumably what `time_limit_hours` and
# `resume_from_breed` below are for — confirm against DiffusionModelManager).

balanced_dataset_dir = diffusion_manager.generate_balanced_dataset(
    variations_file="/content/drive/MyDrive/outputs_master_ProfAI/captions/flan_t5_variations.json",
    original_dataset_dir="/content/Gen-AI-for-Data-Augmentation/data/oxford-iiit-pet/images",
    target_dir="/content/drive/MyDrive/outputs_master_ProfAI/generated_data/",
    target_samples_per_class=150,
    guidance_scale=7.5,
    num_inference_steps=40,
    zip_result=True,
    train_dataset = train_dataset,
    time_limit_hours = 2.0,
    resume_from_breed = "Pug"
)

In [None]:
# Find black images generated by the fine-tuned model

images_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI/generated_data/")
threshold = 10  # mean pixel value below which an image is treated as black

black_images = []
# Scan `images_dir` (the generated-data folder defined above). The earlier
# `image_dir` variable is a plain string pointing at the original dataset.
image_files = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
print(f"Analysis of {len(image_files)} images...")

for img_path in image_files:
    try:
        # The context manager closes the file handle after each image.
        with Image.open(img_path) as img:
            mean_brightness = np.mean(np.array(img))
        if mean_brightness < threshold:
            black_images.append(str(img_path))
            print(f"Black image found: {img_path.name} (brightness: {mean_brightness:.2f})")
    except Exception as e:
        # Best effort: report unreadable files and keep scanning the rest.
        print(f"Elaboration Error for {img_path.name}: {e}")

print(f"Found {len(black_images)} black images over {len(image_files)} total images")

# Persist the list so the black images can be removed in a later step.
if black_images:
    output_file = images_dir / "black_images.txt"
    with open(output_file, "w") as f:
        for img_path in black_images:
            f.write(f"{img_path}\n")
    print(f"List of black images saved in {output_file}")

In [None]:
# Remove black images and count by class

from collections import Counter

images_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI/generated_data/")
# Always read the list from its known location. A leftover `output_file` variable
# may point to an unrelated file (or be undefined) when no black image was found.
black_list_file = images_dir / "black_images.txt"

black_images = []
removed_count = 0
if black_list_file.exists():
    with open(black_list_file, "r") as f:
        black_images = [line.strip() for line in f if line.strip()]
    print(f"Loaded {len(black_images)} balck images to remove")

    for img_path in black_images:
        img_path = Path(img_path)
        if img_path.exists():
            img_path.unlink()
            # Remove the text file that shares the image's stem, if any.
            txt_path = img_path.with_suffix(".txt")
            if txt_path.exists():
                txt_path.unlink()
            removed_count += 1

    print(f"Removed {removed_count} black images and corresponding txt files")

# Generated files are named "<class>_gen_<...>.jpg"; the class name is the part
# before "_gen_" with underscores turned back into spaces.
class_counts = Counter()
for img_path in images_dir.glob("*.jpg"):
    filename = img_path.stem

    if "_gen_" in filename:
        class_name = filename.split("_gen_")[0].replace("_", " ")
        class_counts[class_name] += 1

sorted_counts = dict(sorted(class_counts.items()))

# Build the table once and reuse it for both the console and the report file.
table_lines = ["="*50, f"{'Class':<30} | {'Images':<10}", "-"*50]
table_lines += [f"{class_name:<30} | {count:<10}" for class_name, count in sorted_counts.items()]
table_lines += ["="*50, f"Total: {sum(sorted_counts.values())} generated images"]

# "\n" (not "\I", an invalid escape that printed a literal backslash).
print("\nImages count per class:")
print("\n".join(table_lines))

# Save a report
report_path = images_dir / "class_statistics.txt"
with open(report_path, "w") as f:
    f.write("Images count per class::\n")
    f.write("\n".join(table_lines) + "\n")

print(f"\nReport saved in {report_path}")

## Classification

In [None]:
import logging
import yaml
import torch
import collections
import numpy as np
import pandas as pd
import os
import shutil
from pathlib import Path
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, Subset, ConcatDataset, DataLoader

from src.utils.logger_setup_classifier import get_logger
from src.models.model_factory import create_model, validate_model_config, get_available_models
from src.training.trainer import ModelTrainer
from src.training.experiment import Experiment
from src.training.callbacks_classifier import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from src.visualization.plot_results import scatter_plot_metrics, plot_confusion_matrix, plot_misclassified_images

In [None]:
# Setup logging

logger = get_logger(ch_log_level=logging.INFO, fh_log_level=logging.DEBUG)

In [None]:
# Load config

with open('config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

In [None]:
# Training parameters

BATCH_SIZE = config['training']['batch_size']
NUM_EPOCHS = config['training']['num_epochs']
LEARNING_RATE = config['training']['learning_rate']
NUM_CLASSES = config['dataset']['num_classes']

In [None]:
# Download generated images from Drive folder and set up directories

data_dir = Path('./data')
generated_data_dir = data_dir / "generated_data"
drive_generated_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI/generated_data")

# exist_ok=True already covers the "folder is present" case.
generated_data_dir.mkdir(parents=True, exist_ok=True)

if drive_generated_dir.exists():
    jpg_files = list(drive_generated_dir.glob("*.jpg"))
    print(f"Download {len(jpg_files)} images from Drive...")

    # Copy only the files missing locally, so re-running the cell is cheap.
    for img_path in jpg_files:
        dest_path = generated_data_dir / img_path.name
        if not dest_path.exists():
            shutil.copy(img_path, dest_path)

    print(f"Loaded images in {generated_data_dir}")
else:
    # Say so explicitly instead of silently continuing without generated images.
    print(f"Drive folder not found: {drive_generated_dir} (is Google Drive mounted?)")

In [None]:
# Load Dataset

handler = PetDatasetHandler(data_dir)
original_train_dataset, test_dataset = handler.load_dataset(transform=None)

CLASS_NAMES = original_train_dataset.classes

In [None]:
# Define transformations

# Training pipeline: resize, convert to tensor, normalize with the configured statistics.
train_transform = transforms.Compose([
    transforms.Resize(config['preprocessing']['image']['size']),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=config['preprocessing']['image']['mean'],
        std=config['preprocessing']['image']['std']
    )
])

# Validation/test pipeline: same steps plus a Grayscale(3-channel) conversion.
# NOTE(review): Grayscale is applied here but not in `train_transform`, so evaluation
# images lose their colour while training images keep it. If the goal was only to
# guarantee three channels, an RGB conversion would be the usual choice — confirm intent.
val_test_transform = transforms.Compose([
    transforms.Resize(config['preprocessing']['image']['size']),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=config['preprocessing']['image']['mean'],
        std=config['preprocessing']['image']['std']
    )
])

In [None]:
# Choose which experiment to run:
#   1. Only original dataset with no augmentation techniques
#   2. Only original dataset with augmentation techniques
#   3. Original dataset concatenated with generated dataset
# NOTE(review): the branch selection below depends on `only_original` and
# `augmentation`; confirm where `generated` is actually read.

only_original = True
augmentation = False
generated = False


# With augmentation enabled, the training pipeline gains a random rotation
# (up to 20 degrees) and a random horizontal flip before tensor conversion.
if augmentation:
    train_transform = transforms.Compose([
        transforms.Resize(config['preprocessing']['image']['size']),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=config['preprocessing']['image']['mean'],
            std=config['preprocessing']['image']['std']
        )
    ])
    
class TransformDataset:
    """Wrap an indexable dataset of (image, label) pairs and apply a transform lazily.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing that
            returns ``(image, label)`` pairs.
        transform: Optional callable applied to the image when an item is
            fetched. The label is returned unchanged.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transforming the image if a transform is set."""
        sample, target = self.dataset[idx]
        return (self.transform(sample) if self.transform else sample), target

# --- 1. Pick the dataset that feeds the train/validation split -------------
if only_original:
    if augmentation:
        print("Using only original dataset with augmentation")
    else:
        print("Using only original dataset without augmentation")
    split_source = original_train_dataset
else:
    print("Using original dataset + generated images")

    # Generated files are named "<class>_gen_<...>.jpg"; ImageFolder needs one
    # sub-folder per class, so mirror them into "organized/<class>/".
    # This runs every time and only copies what is missing, so images added
    # after the first run are picked up as well.
    generated_dir = generated_data_dir / "organized"
    os.makedirs(generated_dir, exist_ok=True)

    for img_path in generated_data_dir.glob("*.jpg"):
        if "_gen_" not in img_path.stem:
            continue
        class_dir = generated_dir / img_path.stem.split("_gen_")[0]
        os.makedirs(class_dir, exist_ok=True)
        dest_path = class_dir / img_path.name
        if not dest_path.exists():
            shutil.copy(img_path, dest_path)

    print(f"Generated images organized in: {generated_dir}")

    generated_dataset = ImageFolder(root=str(generated_dir), transform=None)

    # ImageFolder numbers classes by sorted folder name. If the generated
    # folders do not match the original class list exactly, the label indices
    # of the two datasets may not line up once concatenated.
    if list(generated_dataset.classes) != list(original_train_dataset.classes):
        print(
            "WARNING: generated class folders differ from the original class list; "
            "label indices may not line up between the two datasets."
        )

    full_dataset = ConcatDataset([original_train_dataset, generated_dataset])
    split_source = full_dataset

# --- 2. Reproducible 80/20 train/validation split (shared by every mode) ---
train_size = int(0.8 * len(split_source))
val_size = len(split_source) - train_size

temp_train, temp_val = random_split(
    split_source,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# --- 3. Attach the per-split transforms ------------------------------------
# random_split already returns Subset objects, so no extra Subset wrapper
# over range(len(...)) is needed.
train_ds = TransformDataset(temp_train, transform=train_transform)
val_ds = TransformDataset(temp_val, transform=val_test_transform)

# Transformation for test dataset
test_ds = TransformDataset(test_dataset, transform=val_test_transform)

In [None]:
# Create dataloaders
# Only the training loader shuffles; validation and test keep a fixed order.
# (num_workers / pin_memory are intentionally left at their defaults here.)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

for split_name, split_ds in (("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)):
    print(f"{split_name} samples: {len(split_ds)}")

In [None]:
# Count labels percentage for a chosen dataloader

dataloader = train_loader

# One full pass over the loader: every batch is (images, labels) and only the
# labels are needed to build the histogram.
class_counts = collections.Counter()
for _, labels in dataloader:
    class_counts.update(label.item() for label in labels)

# Sort by class index so the report is stable even though the loader shuffles.
class_counts_dict = dict(sorted(class_counts.items()))
class_names = CLASS_NAMES
total_images = len(dataloader.dataset)

# "\n" was intended here; the previous "\C" is an invalid escape sequence.
print("\nClass Distribution in Original Dataset - Train:")
print(f"Total images: {total_images}")
print("-" * 50)
for class_idx, count in class_counts_dict.items():
    class_name = class_names[class_idx] if class_names else f"Class {class_idx}"
    print(f"{class_name:<30} | {count:>5} | {count/total_images*100:>6.2f}%")

print("-" * 50)

In [None]:
# Model configuration --> Baseline model
# Uses the shared NUM_CLASSES constant (like the transfer-learning config)
# instead of a hard-coded 37, so both configurations stay in sync.

model_config = {
    'type': 'baseline',
    'num_classes': NUM_CLASSES,
    'input_channels': 3
}

In [None]:
# Model configuration --> Transfer learning model
# Toggle 'use_custom_classifier' to decide whether the custom classifier is used.

model_config = dict(
    type='transfer',
    model_name='resnet50',
    num_classes=NUM_CLASSES,
    pretrained=True,
    use_custom_classifier=True,
)

In [None]:
# Take a look at every available model

get_available_models()

In [None]:
# Validate configuration
# Runs before the model is created, on whichever model_config cell was executed last.

validate_model_config(model_config)

In [None]:
# Create model
# Built from the model_config currently in scope.

model = create_model(model_config)

In [None]:
# Optimizer and loss

# Cross-entropy with label smoothing (0.1).
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# Adam over all model parameters, with weight decay 1e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

In [None]:
# Setup experiment
# NOTE(review): keep this experiment name in sync with the name used when the
# history is reloaded in the analysis cells further down.

experiment = Experiment(
    name="resnet50_only_original_data",
    root="/content/drive/MyDrive/outputs_master_ProfAI/experiments_genAI",
    logger=logger
)
experiment.init()

In [None]:
# Setup callbacks
# Early stopping and checkpointing both monitor the validation loss; the LR
# scheduler is configured in 'min' mode with patience 5 and factor 0.1.

_early_stopping_cfg = config['training']['early_stopping']

callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=_early_stopping_cfg['patience'],
        min_delta=_early_stopping_cfg['min_delta'],
        verbose=True,
    ),
    ModelCheckpoint(
        filepath='best_baseline_model.pth',
        monitor='val_loss',
        save_best_only=True,
        verbose=1,
    ),
    ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        patience=5,
        factor=0.1,
        verbose=True,
    ),
]

In [None]:
# Initialize trainer
# Bundles the model, optimizer, loss, experiment, device and logger created in
# the cells above.

trainer = ModelTrainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    experiment=experiment,
    device=device,
    logger=logger
)

In [None]:
# Train the model

trained_model = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    callbacks=callbacks
)

# Persist the weights of the returned model under experiment.root.
torch.save(trained_model.state_dict(), experiment.root / 'final_model.pth')

In [None]:
# Evaluation on test set
# trainer.validate is reused with the test loader; the returned logs are
# stored as the 'test' history of the experiment and written to the logger.

test_logs = trainer.validate(test_loader)
experiment.save_history('test', **test_logs)
logger.info(f"Test Results: {test_logs}")

In [None]:
# Get predictions on test set

test_targets, test_predictions = trainer.predict(test_loader)

# Confusion matrix over the test set; experiment.root is passed as the save location.
plot_confusion_matrix(test_targets, test_predictions, classes = CLASS_NAMES, path_to_save=str(experiment.root))
logger.info("Confusion matrix saved as 'confusion_matrix.png'")

In [None]:
# Save test results
# Targets and predictions are converted to plain lists so they are JSON-serializable.
test_results = dict(
    targets=test_targets.tolist(),
    predictions=test_predictions.tolist(),
)

test_results_path = f"{experiment.results_dir}/test_results.json"
with open(test_results_path, 'w') as results_file:
    json.dump(test_results, results_file)

In [None]:
# Generate and save plots of the training history

experiment.plot_history()

In [None]:
# Evaluate train and validation results
# Reads the per-split history CSV files stored under experiment.root/history.

scatter_plot_metrics(f'{experiment.root}/history/train.csv', 
                     f'{experiment.root}/history/val.csv')

In [None]:
# Replace missing values with 0 in column lr from val.csv and test.csv

def fill_missing_lr(csv_path):
    """Fill missing ``lr`` values with 0 in a history CSV and rewrite it in place.

    Args:
        csv_path: Path of the history CSV file to patch.

    Returns:
        The loaded DataFrame (with ``lr`` filled when the column exists).
        Files without an ``lr`` column are left untouched instead of raising
        a KeyError.
    """
    history = pd.read_csv(csv_path)
    if 'lr' in history.columns:
        history['lr'] = history['lr'].fillna(0)
        history.to_csv(csv_path, index=False)
    return history


val = fill_missing_lr(f"{experiment.root}/history/val.csv")
test = fill_missing_lr(f"{experiment.root}/history/test.csv")

In [None]:
# Calculate average metrics for last n epochs

# Reload the experiment trained above. The name must match the one used in the
# "Setup experiment" cell ("resnet50_only_original_data"); it previously pointed
# at a "resnet34_only_original_data" run that this notebook never creates.
experiment = Experiment("resnet50_only_original_data", "/content/drive/MyDrive/outputs_master_ProfAI/experiments_genAI")
for split in ("val", "train", "test"):
    experiment.load_history_from_file(split)

avg_metrics = experiment.calculate_average_metrics('val', last_n_epochs=5)
print("Average validation metrics:", avg_metrics)

In [None]:
# Export results in JSON
# Write next to the other results of the experiment currently in scope instead
# of a hard-coded folder of a differently named run.
# NOTE(review): assumes experiment.results_dir is available on a freshly
# constructed Experiment (it is used the same way when saving test_results.json) — confirm.

experiment.export_results_to_json(f"{experiment.results_dir}/results.json")

In [None]:
# Find best epoch according to validation accuracy 

metric = 'accuracy'

best_epoch = experiment.get_best_epoch(metric, mode='max')

# best_epoch is indexed with best_epoch - 1 into the history list.
best_val_accuracy = experiment.history['val'][metric][best_epoch - 1]

# The message is split with implicit string concatenation: a single-quoted
# f-string cannot span several source lines (the previous form was a SyntaxError).
print(
    f"Best validation accuracy was achieved at epoch {best_epoch} "
    f"with {100 * best_val_accuracy:.1f}%"
)

In [None]:
# Plot learning rate
# Uses the 'lr' series recorded in the training history.

experiment.plot_learning_rate(experiment.history['train']['lr'])

In [None]:
# Plot misclassified images with ground truth and prediction
# mean/std are the normalization statistics from the preprocessing config;
# presumably they are used to undo the normalization for display — TODO confirm.

plot_misclassified_images(
    model=trained_model,
    dataloader=test_loader,
    device=device,
    num_images=16,
    class_names=CLASS_NAMES,
    mean=config['preprocessing']['image']['mean'],
    std=config['preprocessing']['image']['std']
)

## Image Generation with Conditional GAN

In [None]:
# Setup device
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.

device_name = "cuda" if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Device used: {device}")

In [8]:
# Configurations for the GAN data pipeline

batch_size = 32   # batch size passed to the GAN dataloaders
image_size = 128  # images are resized to image_size x image_size
num_workers = 4   # num_workers passed to the GAN dataloaders

In [None]:
# Setup components
# All GAN outputs (checkpoints, logs, metric history) live under one Drive folder.

output_dir = Path("/content/drive/MyDrive/outputs_master_ProfAI")
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir, log_dir = output_dir / "checkpoints", output_dir / "logs"

# Metrics tracked during training: FID and CLIP score, both on `device`.
tracked_metrics = [FIDScore(device=device), CLIPScore(device=device)]
metrics = MetricsTracker(tracked_metrics)

logger = GANLogger("conditional_gan", log_dir=log_dir)

# Early stopping and checkpointing both monitor the 'fid' metric.
early_stopping = EarlyStopping(monitor='fid', patience=5)
best_checkpoint = ModelCheckpoint(filepath=checkpoint_dir / "best_model.pt", monitor='fid')
metrics_history = MetricsHistory(log_dir=log_dir / "metrics")
callbacks = [early_stopping, best_checkpoint, metrics_history]

In [19]:
# Load Dataset
# Loaded with the handler's default arguments; the GAN-specific transform is
# applied later by PetDatasetWithCaptions.

data_dir = Path('./data')
handler = PetDatasetHandler(data_dir)
train_dataset, test_dataset = handler.load_dataset()

In [20]:
# Load Captions
# Captions generated for the training dataset, stored as a JSON file.
captions_path = Path('output/captions/captions_traindataset.json')
caption_dict = json.loads(captions_path.read_text())

In [21]:
# Image paths as strings, taken from the datasets' `_images` attribute
# (a private attribute, so this may break if the dataset class changes).
train_images_paths = list(map(str, train_dataset._images))

test_images_paths = list(map(str, test_dataset._images))

In [None]:
# Initialize train, val and test loaders

# Resize to a square image and normalize each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

full_train_ds = PetDatasetWithCaptions(
    image_paths=train_images_paths,
    caption_dict=caption_dict,
    transform=transform
)

# Reproducible 90/10 train/validation split.
train_size = int(0.9 * len(full_train_ds))
val_size = len(full_train_ds) - train_size
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(full_train_ds, [train_size, val_size],
                                generator=split_generator)

# NOTE(review): the test set reuses caption_dict — confirm that it also holds
# captions for the test images.
test_ds = PetDatasetWithCaptions(
    image_paths=test_images_paths,
    caption_dict=caption_dict,
    transform=transform
)

# Options shared by the three loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

In [None]:
# Initialize GAN model
# NOTE(review): this rebinds `config`, which the classification cells above
# index as a dict (config['preprocessing'], config['training']). Re-running
# those cells after this one will fail unless the original config is reloaded.

config = GANConfig(
    latent_dim = 100,
    caption_dim = 768,
    image_size = image_size,  # same size used by the image transform
    num_channels = 3,
    generator_features = 64
)

gan = ConditionalGAN(config)

In [24]:
# Initialize trainer
# Wires the GAN to the caption dataloaders, the metrics tracker, the logger and
# the callbacks created in the cells above.

trainer = GANTrainer(
    gan=gan,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    metrics_tracker=metrics,
    logger=logger,
    callbacks=callbacks
)

In [None]:
# Train
# NOTE(review): sample_dir is a relative path, so it resolves against the
# current working directory rather than the Drive output folder — confirm this
# is intended.

trainer.train(
    num_epochs=100,
    eval_freq=1,
    sample_freq=500,
    sample_dir=Path("samples")
)