In [None]:
!pip install diffusers\
             transformers\
             accelerate\
             datasets \
             safetensors \
             bitsandbytes \
             wandb

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_c

In [None]:
import os
import csv
import copy
import torch
import random
from dataclasses import dataclass

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from diffusers import StableDiffusionPipeline
from diffusers import DDPMScheduler
from diffusers import UNet2DConditionModel
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

import torch.nn.functional as F
from accelerate import Accelerator
from torch.utils.data import DataLoader

from tqdm import tqdm

print("PyTorch version:", torch.__version__)

PyTorch version: 2.6.0+cu124


In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Make the project folder on Drive the working directory; the relative paths
# used in this notebook (data/..., prompts.txt, output_teacher/...) are
# resolved against it.
os.chdir('/content/drive/MyDrive/Projet/ADD0')
print("Répertoire courant :", os.getcwd())

Répertoire courant : /content/drive/MyDrive/Projet/ADD0


In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split

def create_splits(full_csv, train_csv="data/pokemon_train.csv", val_csv="data/pokemon_val.csv", test_csv="data/pokemon_test.csv"):
    """Split ``full_csv`` into train / val / test CSV files (80% / 10% / 10%).

    The split is deterministic (``random_state=42``). Nothing is written when
    all three output files already exist.

    Args:
        full_csv: Path of the CSV holding the complete dataset.
        train_csv: Destination of the training split.
        val_csv: Destination of the validation split.
        test_csv: Destination of the test split.

    Note:
        When ``train_csv`` designates the same file as ``full_csv``, the
        complete dataset is replaced by its training subset. A warning is
        printed in that case so the overwrite is never silent.
    """
    split_paths = (train_csv, val_csv, test_csv)

    # All three outputs must be present to skip the work; checking only val
    # and test would leave a missing train file un-regenerated.
    if all(os.path.exists(path) for path in split_paths):
        print("Les fichiers split train/val/test existent déjà.")
        return

    print("Création des splits train / val / test...")
    df_full = pd.read_csv(full_csv)
    train_df, temp_df = train_test_split(df_full, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

    # Make the in-place overwrite of the source file explicit.
    if os.path.realpath(full_csv) == os.path.realpath(train_csv):
        print(f"Attention : '{train_csv}' est aussi le fichier source ; "
              "il est remplacé par le split train (80% des données).")

    for frame, path in zip((train_df, val_df, test_df), split_paths):
        parent = os.path.dirname(path)
        if parent:  # os.makedirs("") raises, so skip bare file names
            os.makedirs(parent, exist_ok=True)
        frame.to_csv(path, index=False)

    print("Splits créés : {} samples train, {} samples val, {} samples test".format(
        len(train_df), len(val_df), len(test_df)))

In [None]:
# The source file passed here is also the default train_csv of create_splits,
# so when the splits are created it is rewritten with the training subset.
full_csv_path = "data/pokemon_train.csv"
create_splits(full_csv_path)

Création des splits train / val / test...
Splits créés : 647 samples train, 81 samples val, 81 samples test


In [None]:
@dataclass
class PokemonDataset(Dataset):
    """Image/caption dataset backed by a CSV file.

    The CSV must provide an ``image_path`` and a ``caption`` column. All rows
    are read when the dataset is constructed; images are only opened when a
    sample is requested.

    Attributes:
        csv_file: Path of the CSV file listing the samples.
        transform: Optional callable applied to each RGB image.
        samples: List of ``(image_path, caption)`` tuples, filled by
            ``__post_init__``.
    """

    csv_file: str
    # String annotation: documents that None is allowed without a new import.
    transform: "transforms.Compose | None" = None

    def __post_init__(self):
        # newline="" is what the csv module requires so that line breaks
        # inside quoted captions are parsed and preserved correctly.
        with open(self.csv_file, 'r', encoding='utf-8', newline='') as f:
            self.samples = [
                (row["image_path"], row["caption"]) for row in csv.DictReader(f)
            ]

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image, "prompt": caption}`` for ``idx``."""
        image_path, text_prompt = self.samples[idx]

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Explicit None check: a provided transform is always applied, even
        # if the object happens to be falsy.
        if self.transform is not None:
            image = self.transform(image)

        return {
            "pixel_values": image,
            "prompt": text_prompt
        }

# Image preprocessing: bicubic resize to 512x512, random horizontal flip,
# conversion to a tensor, then normalisation with mean 0.5 and std 0.5.
image_transforms = transforms.Compose(
    [
        transforms.Resize(
            size=(512, 512),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# Deterministic preprocessing for evaluation: same resize and normalisation
# as image_transforms but without the random horizontal flip, so validation
# and test losses are not perturbed by data augmentation.
eval_transforms = transforms.Compose([
    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_dataset = PokemonDataset(csv_file="data/pokemon_train.csv", transform=image_transforms)
val_dataset   = PokemonDataset(csv_file="data/pokemon_val.csv",   transform=eval_transforms)
test_dataset  = PokemonDataset(csv_file="data/pokemon_test.csv",  transform=eval_transforms)

print("Nombre d'échantillons :")
print(" - Train :", len(train_dataset))
print(" - Val   :", len(val_dataset))
print(" - Test  :", len(test_dataset))

Nombre d'échantillons :
 - Train : 647
 - Val   : 81
 - Test  : 81


In [None]:
# Sanity check: fetch the first training sample and show its tensor shape and caption.
example = train_dataset[0]
print("Forme de l'image échantillon :", example["pixel_values"].shape)
print("Exemple de caption :", example["prompt"])

Forme de l'image échantillon : torch.Size([3, 512, 512])
Exemple de caption : solgaleo, a Psychic/Steel type Sunne Pokémon. Solgaleo is a majestic lion-like Pokémon with a golden mane and a mane that shines like the sun.


In [None]:
print("Number of training samples:", len(train_dataset))

# Inspect one sample.
# NOTE(review): the output recorded for this cell reports 809 training samples,
# while the split cell reports 647 — the saved output looks stale (presumably
# produced before the split was created); re-run to confirm.
example = train_dataset[0]
print("Sample image tensor shape:", example["pixel_values"].shape)
print("Sample caption:", example["prompt"])

Number of training samples: 809
Sample image tensor shape: torch.Size([3, 512, 512])
Sample caption: bulbasaur, a Grass/Poison type Seed Pokémon. Bulbasaur is a small, green Pokémon with a bulb on its back that blooms into a beautiful flower as it grows.


In [None]:
# Base checkpoint; each component below is loaded from its own subfolder.
model_name = "CompVis/stable-diffusion-v1-4"

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

In [None]:
# Training hyperparameters.
batch_size = 4
learning_rate = 1e-5
num_epochs = 5
# NOTE(review): only referenced by the commented-out early stop in the training loop.
max_train_steps = 100
# NOTE(review): not passed to Accelerator and not used in the visible training loop.
gradient_accumulation_steps = 1

# Accelerator configured for fp16 mixed precision; its device is reused for
# the manual .to(device) calls in the following cells.
accelerator = Accelerator(mixed_precision="fp16")
device = accelerator.device

In [None]:
# Only the training loader is shuffled; validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Hand each loader to the accelerator.
train_dataloader = accelerator.prepare(train_dataloader)
val_dataloader = accelerator.prepare(val_dataloader)
test_dataloader = accelerator.prepare(test_dataloader)

In [None]:
# Freeze the VAE: it is only used to encode images and is never optimised.
vae.requires_grad_(False)
vae.to(device)
vae.eval()

# A single AdamW instance updates both the UNet and the text encoder.
optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": learning_rate},
        {"params": text_encoder.parameters(), "lr": learning_rate},
    ],
    betas=(0.9, 0.999),
    weight_decay=1e-2
)

# train_dataloader has already been passed through accelerator.prepare when
# the dataloaders were built; preparing it a second time would wrap it again
# (and, with several processes, shard it twice), so only the models and the
# optimizer are prepared here.
unet, text_encoder, optimizer = accelerator.prepare(
    unet, text_encoder, optimizer
)

# Fine-tuning BASIC avec boucle de validation


In [None]:
def compute_denoising_loss(batch):
    """Compute the noise-prediction MSE loss for one batch.

    The prompts are tokenized and encoded by ``text_encoder``, the images are
    encoded to scaled VAE latents, noise is added at one random timestep per
    sample and ``unet`` is asked to predict that noise.

    Args:
        batch: Mapping with a ``"prompt"`` list of strings and a
            ``"pixel_values"`` image tensor.

    Returns:
        Scalar tensor: mean squared error between predicted and actual noise.
    """
    inputs = tokenizer(batch["prompt"], padding="max_length", max_length=77,
                       truncation=True, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    encoder_hidden_states = text_encoder(input_ids)[0]

    pixel_values = batch["pixel_values"].to(device)
    # The VAE is frozen, so no graph is needed for the encoding step.
    with torch.no_grad():
        # Read the latent scaling factor from the VAE config instead of
        # hard-coding it.
        latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor

    noise = torch.randn_like(latents)
    # Scheduler hyperparameters live on `.config`; reading them straight from
    # the scheduler object is deprecated in diffusers.
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                              (latents.shape[0],), device=device).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    return F.mse_loss(noise_pred, noise)


print("Starting BASIC finetuning...")
global_step = 0
for epoch in range(num_epochs):
    unet.train()
    text_encoder.train()

    progress_bar = tqdm(train_dataloader, desc=f"[BASIC] Epoch {epoch+1}", leave=True)
    epoch_loss_total = 0
    num_steps = 0

    # Training loop
    for step, batch in enumerate(progress_bar):
        loss = compute_denoising_loss(batch)

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        epoch_loss_total += loss_value
        num_steps += 1
        progress_bar.set_postfix(loss=loss_value)

        # Optional early stop (disabled):
        # global_step += 1
        # if global_step >= max_train_steps:
        #     break

    # Same zero guard as the validation average below.
    average_loss = epoch_loss_total / num_steps if num_steps > 0 else 0.
    print(f"\n[BASIC] ✅ Epoch {epoch+1} finished — avg train loss: {average_loss:.4f}")

    # Validation pass after each epoch (no backward)
    unet.eval()
    text_encoder.eval()
    val_loss_total = 0
    val_steps = 0
    with torch.no_grad():
        for batch in val_dataloader:
            val_loss = compute_denoising_loss(batch)
            val_loss_total += val_loss.item()
            val_steps += 1
    avg_val_loss = val_loss_total / val_steps if val_steps > 0 else 0.
    print(f"[BASIC] Epoch {epoch+1} validation loss: {avg_val_loss:.4f}")

Starting BASIC finetuning...


[BASIC] Epoch 1: 100%|██████████| 162/162 [02:30<00:00,  1.08it/s, loss=0.0255]



[BASIC] ✅ Epoch 1 finished — avg train loss: 0.0370
[BASIC] Epoch 1 validation loss: 0.0395


[BASIC] Epoch 2: 100%|██████████| 162/162 [00:58<00:00,  2.77it/s, loss=0.0437]



[BASIC] ✅ Epoch 2 finished — avg train loss: 0.0345
[BASIC] Epoch 2 validation loss: 0.0343


[BASIC] Epoch 3: 100%|██████████| 162/162 [00:58<00:00,  2.76it/s, loss=0.075]



[BASIC] ✅ Epoch 3 finished — avg train loss: 0.0337
[BASIC] Epoch 3 validation loss: 0.0342


[BASIC] Epoch 4: 100%|██████████| 162/162 [00:58<00:00,  2.77it/s, loss=0.0439]



[BASIC] ✅ Epoch 4 finished — avg train loss: 0.0354
[BASIC] Epoch 4 validation loss: 0.0333


[BASIC] Epoch 5: 100%|██████████| 162/162 [00:58<00:00,  2.78it/s, loss=0.0216]



[BASIC] ✅ Epoch 5 finished — avg train loss: 0.0356
[BASIC] Epoch 5 validation loss: 0.0381


In [None]:
# Final evaluation on Test Set
print("Final evaluation on Test set:")
unet.eval()
text_encoder.eval()
test_loss_total = 0
test_steps = 0

# Pas de rétro-propagation durant l'évaluation
with torch.no_grad():
    for batch in test_dataloader:
        prompts = batch["prompt"]
        inputs = tokenizer(
            prompts,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt"
        )
        input_ids = inputs.input_ids.to(device)
        encoder_hidden_states = text_encoder(input_ids)[0]

        pixel_values = batch["pixel_values"].to(device)
        latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215

        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
                                  (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        loss = F.mse_loss(noise_pred, noise)

        test_loss_total += loss.item()
        test_steps += 1

avg_test_loss = test_loss_total / test_steps if test_steps > 0 else 0.
print("Final Test Loss: {:.4f}".format(avg_test_loss))

Final evaluation on Test set:
Final Test Loss: 0.0405


# Inférence

In [None]:
# Inference settings.
# Text file holding one prompt per line.
prompt_file = "prompts.txt"
# Root folder for generated images.
output_base = "output_teacher"
num_inference_steps = 50
guidance_scale = 7.5
# Used for both height and width of the generated images.
image_size = 512
# Half precision is only selected when CUDA is available.
use_fp16 = True

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Half precision is selected only when requested and running on CUDA.
dtype = torch.float32
if use_fp16 and device.type == "cuda":
    dtype = torch.float16

# One prompt per line; lines that are empty once stripped are dropped.
with open(prompt_file, "r", encoding="utf-8") as f:
    prompts = [text for text in map(str.strip, f) if text]

print(f"Loaded {len(prompts)} prompts.")

Loaded 20 prompts.


In [None]:
# Reference pipeline loaded straight from the base checkpoint.
vanilla_pipe = StableDiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=dtype
).to(device)

vanilla_output_dir = os.path.join(output_base, "vanilla")

model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

scheduler_config-checkpoint.json:   0%|          | 0.00/209 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
def generate_images(pipe, prompts, out_dir, tag="default", seed=None):
    """Render one image per prompt with ``pipe`` and save each as a PNG file.

    Files are written to ``out_dir`` as ``{index:03d}_{tag}_{safe_prompt}.png``
    where ``safe_prompt`` is the prompt with every non-alphanumeric character
    replaced by ``_`` and cut to 50 characters.

    The sampling settings (``num_inference_steps``, ``guidance_scale``,
    ``image_size``) and the autocast ``device`` are read from notebook-level
    globals.

    Args:
        pipe: Pipeline called with the prompt; its result must expose an
            ``images`` list.
        prompts: Iterable of prompt strings.
        out_dir: Output directory, created when missing.
        tag: Label inserted in the progress message and in each file name.
        seed: Optional base seed. When given, prompt ``i`` is sampled with a
            generator seeded with ``seed + i``, which makes runs repeatable.
            When ``None`` (default) no generator is passed and sampling is
            unseeded.
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"Generating images with tag '{tag}'...")
    for i, prompt in enumerate(tqdm(prompts)):
        generator = None
        if seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed + i)

        with torch.autocast(device_type=device.type):
            image = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=image_size,
                width=image_size,
                generator=generator
            ).images[0]

        # Keep file names filesystem-safe and reasonably short.
        safe_name = "".join(c if c.isalnum() else "_" for c in prompt)[:50]
        image.save(os.path.join(out_dir, f"{i:03d}_{tag}_{safe_name}.png"))

In [None]:
# Baseline images from the base checkpoint, written to vanilla_output_dir.
generate_images(vanilla_pipe, prompts, vanilla_output_dir, tag="vanilla")

Generating images with tag 'vanilla'...


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:03<01:07,  3.56s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 10%|█         | 2/20 [00:06<01:00,  3.36s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 15%|█▌        | 3/20 [00:10<00:56,  3.30s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 20%|██        | 4/20 [00:13<00:52,  3.27s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 25%|██▌       | 5/20 [00:16<00:48,  3.25s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.
 30%|███       | 6/20 [00:19<00:44,  3.20s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 35%|███▌      | 7/20 [00:22<00:41,  3.20s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 40%|████      | 8/20 [00:26<00:38,  3.22s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 45%|████▌     | 9/20 [00:29<00:35,  3.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 50%|█████     | 10/20 [00:32<00:32,  3.21s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 55%|█████▌    | 11/20 [00:35<00:29,  3.22s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 60%|██████    | 12/20 [00:38<00:25,  3.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 65%|██████▌   | 13/20 [00:42<00:22,  3.24s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 70%|███████   | 14/20 [00:45<00:19,  3.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 75%|███████▌  | 15/20 [00:48<00:16,  3.26s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 80%|████████  | 16/20 [00:51<00:12,  3.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 85%|████████▌ | 17/20 [00:55<00:09,  3.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 90%|█████████ | 18/20 [00:58<00:06,  3.23s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 95%|█████████▌| 19/20 [01:01<00:03,  3.24s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 20/20 [01:04<00:00,  3.24s/it]


In [None]:
finetuned_pipe = StableDiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=dtype
).to(device)

# accelerator.prepare may have wrapped the trained modules (for example for
# distributed training); the pipeline needs the underlying models, so unwrap
# them before plugging them in.
finetuned_pipe.unet = accelerator.unwrap_model(unet).to(device)
finetuned_pipe.text_encoder = accelerator.unwrap_model(text_encoder).to(device)

finetuned_output_dir = os.path.join(output_base, "finetuned1")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Same prompts rendered with the pipeline carrying the fine-tuned UNet and text encoder.
generate_images(finetuned_pipe, prompts, finetuned_output_dir, tag="finetuned")

Generating images with tag 'finetuned'...


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:03<00:59,  3.11s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 10%|█         | 2/20 [00:06<00:54,  3.04s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 15%|█▌        | 3/20 [00:09<00:51,  3.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 20%|██        | 4/20 [00:12<00:48,  3.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 25%|██▌       | 5/20 [00:15<00:45,  3.04s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 30%|███       | 6/20 [00:18<00:42,  3.03s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 35%|███▌      | 7/20 [00:21<00:39,  3.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 40%|████      | 8/20 [00:24<00:36,  3.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 45%|████▌     | 9/20 [00:27<00:33,  3.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 50%|█████     | 10/20 [00:30<00:29,  2.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 55%|█████▌    | 11/20 [00:33<00:26,  2.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 60%|██████    | 12/20 [00:36<00:23,  2.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 65%|██████▌   | 13/20 [00:39<00:20,  3.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 70%|███████   | 14/20 [00:42<00:17,  3.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 75%|███████▌  | 15/20 [00:45<00:14,  2.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 80%|████████  | 16/20 [00:48<00:11,  3.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 85%|████████▌ | 17/20 [00:51<00:09,  3.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 90%|█████████ | 18/20 [00:54<00:05,  3.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

 95%|█████████▌| 19/20 [00:57<00:02,  2.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 20/20 [01:00<00:00,  3.00s/it]


In [None]:
# Persist the complete fine-tuned pipeline so it can be reloaded later.
save_directory = "output_teacher/model"
os.makedirs(save_directory, exist_ok=True)
print("Sauvegarde du modèle dans :", save_directory)

finetuned_pipe.save_pretrained(save_directory)

print("Modèle sauvegardé avec succès.")