In [None]:
!pip install diffusers transformers datasets wandb Pillow tqdm xformers accelerate

In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub before any model or dataset is
# requested below (presumably prompts for an access token — confirm).
login()

In [None]:
import torch
import gc
# Run Python garbage collection, then release PyTorch's cached CUDA memory.
# Presumably meant for re-running the notebook inside an already-used kernel.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import os
import io
import glob
import shutil
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import math
from torch.utils.data import DataLoader, IterableDataset
from torchvision import transforms
from PIL import Image, ImageOps
from tqdm.auto import tqdm
import wandb
import xformers # Added for XFORMERS

# Hugging Face Libraries
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.training_utils import EMAModel
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import load_dataset
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import ProjectConfiguration, set_seed, LoggerType

# --- Configuration (Hardcoded for simplicity) ---
vae_model_id = "stabilityai/sd-vae-ft-ema"
text_encoder_model_id = "openai/clip-vit-large-patch14"
dataset_id = "BLIP3o/BLIP3o-Pretrain-Short-Caption"

# Training Hyperparameters
image_resolution = 512 # Images are resized to this square size, so other datasets can be used as well
latent_resolution = image_resolution // 8 # image_resolution // vae_scale_factor (usually 8)
train_batch_size = 16 # Batch size per GPU
gradient_accumulation_steps = 8 # Effective batch size (GPU * acc_steps * batch_size)
learning_rate = 1e-4 # standard value
adam_beta1 = 0.9 # standard value
adam_beta2 = 0.999 # standard value
adam_weight_decay = 1e-2 # standard value
adam_epsilon = 1e-08 # standard value
max_grad_norm = 1.0 # Gradient-clipping threshold
num_train_timesteps_scheduler = 1000 # Number of steps of the noise scheduler
mixed_precision = "bf16" # Saves memory and also makes the computation faster
seed = 42 # Random seed
cfg_drop_probability = 0.1 # CFG Drop Probability
max_train_steps = 30_000
# Checkpoint interval: 1/100 of the run. Integer division, so this is 0 when
# max_train_steps < 100.
save_every_steps = max_train_steps // 100
num_warmup_steps_ratio = 0.05 # 5% of max_train_steps for warmup
num_warmup_steps = int(num_warmup_steps_ratio * max_train_steps)
# --- Storage Path ---
# Everything the run writes (logs, checkpoints, samples, final model) lives
# under <persistent_storage_mount_path>/<project_data_folder_name>.
persistent_storage_mount_path = "/workspace"
project_data_folder_name = "PozzyDiffusion"
output_dir_base = os.path.join(persistent_storage_mount_path, project_data_folder_name)
logging_dir = os.path.join(output_dir_base, "logs")
checkpoint_dir_base = os.path.join(output_dir_base, "checkpoints")
final_model_dir = os.path.join(output_dir_base, "final_model")
max_keep_checkpoints = 2 # Number of most recent checkpoints kept on disk
# WandB
wandb_project_name = "PozzyDiffusion_A40"
use_wandb = True
# Hardcoded prompts used for the periodic sample images
sampling_prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a fantasy landscape with a castle and a dragon",
    "a cyberpunk city at night with neon lights",
    "a serene Japanese garden with cherry blossoms",
    "a whimsical illustration of a cat playing a piano"
]
num_inference_steps_sampling = 50 # Denoising steps used to produce each sample image

# --- Accelerator ---
# The Accelerator is configured with gradient accumulation, mixed precision and
# (optionally) WandB tracking; it is used throughout the rest of the script.
project_config = ProjectConfiguration(project_dir=output_dir_base, logging_dir=logging_dir)
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
    log_with="wandb" if use_wandb else None,
    project_config=project_config,
)
# Set the random seed globally so that a run can be reproduced
if seed is not None:
    set_seed(seed)
# Create the output folders up front (main process only)
if accelerator.is_main_process:
    os.makedirs(output_dir_base, exist_ok=True)
    os.makedirs(checkpoint_dir_base, exist_ok=True)
    os.makedirs(final_model_dir, exist_ok=True)

# --- Helper Functions ---
def save_checkpoint_custom(accelerator, unet_original, ema_unet, optimizer, lr_scheduler, global_step, epoch, checkpoint_dir_base, filename_prefix="ckpt", max_keep=2):
    """Write a checkpoint for ``global_step`` and prune older checkpoints.

    Does nothing unless called on the main process. The checkpoint is a
    directory ``<checkpoint_dir_base>/<filename_prefix>_step<global_step>``
    containing:

    * whatever ``accelerator.save_state`` writes,
    * ``ema_unet.pth`` - ``ema_unet.state_dict()``,
    * ``custom_training_state.pt`` - global step, epoch and WandB run id.

    Args:
        accelerator: object providing ``is_main_process``, ``save_state`` and
            ``print``.
        unet_original, optimizer, lr_scheduler: accepted for interface
            symmetry with ``load_latest_checkpoint_custom``; not used here.
        ema_unet: EMA wrapper whose ``state_dict()`` is saved.
        global_step: step number stored in the state and used in the name.
        epoch: epoch number stored in the state.
        checkpoint_dir_base: directory that holds all checkpoint directories.
        filename_prefix: prefix of the checkpoint directory names.
        max_keep: number of most recent checkpoints to keep. ``None`` or a
            value <= 0 disables pruning.
    """
    if not accelerator.is_main_process:
        return

    os.makedirs(checkpoint_dir_base, exist_ok=True)
    # The checkpoint directory name encodes the step.
    save_path = os.path.join(checkpoint_dir_base, f"{filename_prefix}_step{global_step}")

    accelerator.save_state(save_path)

    # The EMA weights are not passed through accelerator.prepare() in this
    # script, so they are stored separately.
    torch.save(ema_unet.state_dict(), os.path.join(save_path, "ema_unet.pth"))

    # Extra bookkeeping needed to resume the run.
    custom_state = {
        'global_step': global_step,
        'epoch': epoch,
        'wandb_run_id': wandb.run.id if use_wandb and wandb.run else None
    }
    torch.save(custom_state, os.path.join(save_path, "custom_training_state.pt"))

    accelerator.print(f"Saved checkpoint: {save_path}")

    if max_keep is None or max_keep <= 0:
        return

    # Only directories whose suffix after 'step' is a plain number are treated
    # as checkpoints. A stray entry such as "ckpt_step_backup" must neither
    # crash the pruning (int() would raise) nor be deleted.
    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir_base, f"{filename_prefix}_step*")):
        suffix = os.path.basename(path).split('step')[-1]
        if suffix.isdecimal() and os.path.isdir(path):
            numbered.append((int(suffix), path))
    numbered.sort()  # oldest (lowest step) first

    for _, old_ckpt_path in numbered[:-max_keep]:
        accelerator.print(f"Removing old checkpoint: {old_ckpt_path}")
        shutil.rmtree(old_ckpt_path)  # Remove the entire directory

def load_latest_checkpoint_custom(accelerator, unet_original, ema_unet, optimizer, lr_scheduler, checkpoint_dir_base, filename_prefix="ckpt"):
    """Restore the most recent checkpoint found under ``checkpoint_dir_base``.

    Checkpoints are directories named ``<filename_prefix>_step<N>``; the one
    with the highest ``N`` is loaded. Entries whose suffix is not a plain
    number are ignored, so a stray directory cannot prevent resuming.

    Args:
        accelerator: object providing ``load_state`` and ``print``.
        unet_original, optimizer, lr_scheduler: accepted for interface
            symmetry with ``save_checkpoint_custom``; not used here.
        ema_unet: EMA wrapper restored from ``ema_unet.pth`` when present.
        checkpoint_dir_base: directory that holds the checkpoint directories.
        filename_prefix: prefix of the checkpoint directory names.

    Returns:
        ``(global_step, epoch, wandb_run_id)``. ``(0, 0, None)`` when no
        checkpoint exists or loading fails. When the bookkeeping file is
        missing, the step is taken from the directory name and the epoch is
        reported as 0.
    """
    candidates = []
    for path in glob.glob(os.path.join(checkpoint_dir_base, f"{filename_prefix}_step*")):
        suffix = os.path.basename(path).split('step')[-1]
        if suffix.isdecimal():
            candidates.append((int(suffix), path))
    if not candidates:
        return 0, 0, None  # No checkpoint found

    latest_step, latest_checkpoint_path = max(candidates)
    accelerator.print(f"Found checkpoint: {latest_checkpoint_path}")
    try:
        accelerator.load_state(latest_checkpoint_path)

        ema_path = os.path.join(latest_checkpoint_path, "ema_unet.pth")
        if os.path.exists(ema_path):
            ema_unet.load_state_dict(torch.load(ema_path, map_location="cpu"))
            accelerator.print("EMA UNet state loaded.")
        else:
            accelerator.print("Warning: ema_unet.pth not found. EMA weights were not restored.")

        custom_state_path = os.path.join(latest_checkpoint_path, "custom_training_state.pt")
        if os.path.exists(custom_state_path):
            custom_state = torch.load(custom_state_path, map_location="cpu")
            return custom_state.get('global_step', 0), custom_state.get('epoch', 0), custom_state.get('wandb_run_id')

        # Fallback: bookkeeping file missing, so take the step from the
        # directory name and assume epoch 0.
        accelerator.print("Warning: custom_training_state.pt not found. Global step and epoch might not be accurate from checkpoint.")
        return latest_step, 0, None

    except Exception as e:
        # Deliberate best-effort: a broken checkpoint must not abort the run.
        accelerator.print(f"Error loading checkpoint {latest_checkpoint_path}: {e}. Starting fresh or from an earlier state if accelerator handled it.")
        return 0, 0, None  # Could not load

def generate_samples(unet_model_for_sampling, vae, text_encoder, tokenizer_obj, noise_scheduler_obj, prompts, output_dir, global_step, device, accelerator_ref):
    """Generate one image per prompt and save it under ``output_dir``.

    For each prompt the latents start from seeded Gaussian noise, are denoised
    with ``noise_scheduler_obj`` for ``num_inference_steps_sampling`` steps,
    decoded by the VAE and written (main process only) to
    ``<output_dir>/samples/step_<global_step>/``.

    Reads the module-level ``latent_resolution``, ``seed`` and
    ``num_inference_steps_sampling``.

    Args:
        unet_model_for_sampling: UNet used for the denoising steps.
        vae: autoencoder used to decode the final latents.
        text_encoder: encoder producing the conditioning embeddings.
        tokenizer_obj: tokenizer matching ``text_encoder``.
        noise_scheduler_obj: scheduler whose ``step`` accepts a ``generator``
            keyword (the DDPMScheduler used in this file).
        prompts: iterable of prompt strings.
        output_dir: base directory for the ``samples`` folder.
        global_step: training step, used in the output folder name.
        device: device on which sampling runs.
        accelerator_ref: object providing ``is_main_process`` and ``print``.
    """
    unet_model_for_sampling.eval()
    vae.eval()
    text_encoder.eval()

    with torch.no_grad():
        for i, prompt in enumerate(prompts):
            text_inputs = tokenizer_obj(prompt, padding="max_length", max_length=tokenizer_obj.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_inputs.input_ids.to(device))[0]

            latents_shape = (1, unet_model_for_sampling.config.in_channels, latent_resolution, latent_resolution)

            # The generator must live on the same device as the tensors it
            # fills; one generator per prompt, seeded with seed + i.
            generator = torch.Generator(device=device).manual_seed(seed + i) if seed is not None else None
            latents = torch.randn(latents_shape, device=device, generator=generator)

            noise_scheduler_obj.set_timesteps(num_inference_steps_sampling)

            for t in tqdm(noise_scheduler_obj.timesteps, desc=f"Sampling for prompt {i+1}", disable=not accelerator_ref.is_main_process):
                latent_model_input = noise_scheduler_obj.scale_model_input(latents, t)
                noise_pred = unet_model_for_sampling(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                # Pass the same generator to the scheduler so that the noise
                # it adds per step is seeded too. Without it the samples are
                # not reproducible and sampling draws from the global RNG
                # that training also uses.
                latents = noise_scheduler_obj.step(noise_pred, t, latents, generator=generator).prev_sample

            # Undo the latent scaling applied at encode time, then decode.
            latents = 1 / vae.config.scaling_factor * latents
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]  # (H, W, C)
            image = Image.fromarray((image * 255).round().astype("uint8"))

            if accelerator_ref.is_main_process:
                sample_output_dir = os.path.join(output_dir, "samples", f"step_{global_step}")
                os.makedirs(sample_output_dir, exist_ok=True)
                img_path = os.path.join(sample_output_dir, f"prompt_{i+1}_seed{seed+i if seed is not None else 'rand'}.png")
                image.save(img_path)
                accelerator_ref.print(f"Generated sample for prompt '{prompt}' saved to {img_path}")

    # Put the UNet back into train mode (the VAE and text encoder stay in eval).
    unet_model_for_sampling.train()

# --- Initialize the models ---
accelerator.print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(vae_model_id)
accelerator.print("Loading Text Encoder & Tokenizer...")
tokenizer = CLIPTokenizer.from_pretrained(text_encoder_model_id)
text_encoder = CLIPTextModel.from_pretrained(text_encoder_model_id)
accelerator.print("Initializing UNet...")
# Base UNet (built from a config, not from pretrained weights); the EMA copy is derived from it
original_unet = UNet2DConditionModel(
    sample_size=latent_resolution, in_channels=vae.config.latent_channels, out_channels=vae.config.latent_channels,
    layers_per_block=2, block_out_channels=(256, 512, 768, 1024), # Example, adjust as needed
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=text_encoder.config.hidden_size,
)
# Count the trainable parameters
unet_param_count = sum(p.numel() for p in original_unet.parameters() if p.requires_grad)
accelerator.print(f"UNet initialized with {unet_param_count / 1e6:.2f}M parameters.")
# Initialize the EMA model from the UNet's current parameters
ema_unet = EMAModel(original_unet.parameters(), model_cls=UNet2DConditionModel, model_config=original_unet.config)
ema_unet.to(accelerator.device) # ...and move it to the accelerator device
# Move the VAE and text encoder to the device as well; they are not trained, so
# they are put in eval mode with requires_grad disabled to save memory
vae.to(accelerator.device).eval().requires_grad_(False)
text_encoder.to(accelerator.device).eval().requires_grad_(False)

# Embedding of the empty prompt for CFG, computed once here instead of on every step
with torch.no_grad():
    uncond_tokens = tokenizer([""], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
    uncond_embeddings = text_encoder(uncond_tokens.to(accelerator.device))[0]

# Try to use xFormers attention to save memory
# NOTE(review): xformers is imported unconditionally at the top of this cell,
# so the `else` branch below looks unreachable — confirm.
if xformers.version:
    try:
        original_unet.enable_xformers_memory_efficient_attention()
        accelerator.print("XFormers memory efficient attention enabled for UNet.")
    except Exception as e:
        # Best effort: training continues with the default attention.
        accelerator.print(f"Could not enable xformers memory efficient attention: {e}")
else:
    accelerator.print("XFormers not available or not installed. Skipping memory efficient attention.")
# The model used to be compiled here; it is no longer compiled
compiled_unet_object = None # Stays None when the model is not compiled
accelerator.print("Skipping UNet compilation to avoid potential CUDAGraph issues.") # Updated message
unet_to_prepare = original_unet # Directly use the original, uncompiled UNet

# --- Noise Scheduler ---
# Used during training to add noise and compute the velocity target, and
# reused by generate_samples for the sampling loop.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_train_timesteps_scheduler,
    beta_schedule="scaled_linear",
    prediction_type="v_prediction" # Switched to v-prediction
)

# --- Optimizer ---
# AdamW over the UNet parameters only; the VAE and text encoder are frozen.
optimizer = AdamW(
    unet_to_prepare.parameters(), # Optimize parameters of the (potentially compiled) UNet
    lr=learning_rate, betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay, eps=adam_epsilon,
    # Fused AdamW is often handled automatically by PyTorch/Accelerate or can be enabled if available
    # fused=True if accelerator.device.type == 'cuda' and mixed_precision in ["fp16", "bf16"] else False # Be cautious with fused and compilation
)

# --- Learning Rate Scheduler ---
def lr_lambda_cosine(current_step: int):
    """Return the learning-rate multiplier for ``current_step``.

    Linear warmup from 0 to 1 over ``num_warmup_steps`` steps, then a
    half-cosine decay from 1 to 0 that ends at ``max_train_steps``. Both
    limits are read from module-level configuration.

    Steps beyond ``max_train_steps`` keep the multiplier at 0. Without the
    clamp the cosine would start rising again if the scheduler were stepped
    more often than planned.
    """
    if num_warmup_steps > 0 and current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    denominator = float(max(1, max_train_steps - num_warmup_steps))
    # Clamp so the schedule stays at its end value instead of wrapping around.
    progress = min(1.0, float(current_step - num_warmup_steps) / denominator)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# Scale the optimizer's base learning rate by lr_lambda_cosine(step).
lr_scheduler = LambdaLR(optimizer, lr_lambda_cosine)


# --- Image preparation ---
# Resize to a fixed square (the aspect ratio is not preserved), convert to a
# tensor and normalize.
image_transforms = transforms.Compose([
    transforms.Resize((image_resolution, image_resolution), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]), # Normalize to [-1, 1]
])
#Preprocessing
def preprocess_function(examples, tokenizer_obj, image_transforms_fn, accelerator_ref):
    """Turn a chunk of raw (image, caption) pairs into model inputs.

    Args:
        examples: mapping with parallel lists under ``"txt"`` (captions) and
            ``"jpg"`` (PIL images, possibly ``None``).
        tokenizer_obj: tokenizer called on the captions that survive.
        image_transforms_fn: callable applied to each RGB image.
        accelerator_ref: used only to print warnings on the main process.

    Returns:
        ``None`` when no image in the chunk could be processed, otherwise a
        dict with stacked ``"pixel_values"`` and the tokenized ``"input_ids"``.
        Pairs whose image is ``None`` or fails to process are dropped.
    """
    captions = examples["txt"]
    raw_images = examples["jpg"]

    image_tensors = []
    kept_captions = []
    for idx, (raw_image, prompt_text) in enumerate(zip(raw_images, captions)):
        try:
            if raw_image is None:
                if accelerator_ref.is_main_process:
                    accelerator_ref.print(f"Warning: Found None image for prompt: {prompt_text}. Skipping.")
            else:
                # Convert to RGB, then apply the EXIF orientation.
                upright = ImageOps.exif_transpose(raw_image.convert("RGB"))
                image_tensors.append(image_transforms_fn(upright))
                kept_captions.append(prompt_text)
        except Exception as e:
            # Warn on the main process only, to avoid flooding the log.
            if accelerator_ref.is_main_process:
                accelerator_ref.print(f"Warning: Skipping an image/prompt due to error: {e}. Prompt: '{prompt_text}'. Image index in batch: {idx}")

    if len(image_tensors) == 0:
        # Every image in this chunk was unusable.
        return None

    tokenized = tokenizer_obj(
        kept_captions,
        padding="max_length",
        max_length=tokenizer_obj.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return {"pixel_values": torch.stack(image_tensors), "input_ids": tokenized.input_ids}

#Streaming DS
class StreamingImageTextDataset(IterableDataset):
    """Iterable dataset that streams image/caption pairs and preprocesses them in chunks.

    Raw examples are collected until ``processing_chunk_size`` of them are
    available, handed to ``transform_fn`` in one call, and the resulting rows
    are yielded one by one as ``{"pixel_values", "input_ids"}`` dicts.
    """

    def __init__(self, dataset_id, split, transform_fn, tokenizer_obj, image_transforms_fn, processing_chunk_size, accelerator_ref):
        # Open the dataset in streaming mode.
        self.dataset = load_dataset(dataset_id, split=split, streaming=True)
        self.transform_fn = transform_fn
        self.tokenizer_obj = tokenizer_obj
        self.image_transforms_fn = image_transforms_fn
        self.processing_chunk_size = processing_chunk_size
        # Only used for logging (here and inside the transform function).
        self.accelerator_ref = accelerator_ref

    def _emit_processed(self, pending):
        """Preprocess the buffered raw examples and yield the resulting rows."""
        raw_chunk = {
            "jpg": [item["jpg"] for item in pending],
            "txt": [item["txt"] for item in pending],
        }
        processed = self.transform_fn(
            raw_chunk, self.tokenizer_obj, self.image_transforms_fn, self.accelerator_ref
        )
        if not processed:
            # The transform returned nothing usable for this chunk.
            return
        for row in range(processed["pixel_values"].size(0)):
            yield {
                "pixel_values": processed["pixel_values"][row],
                "input_ids": processed["input_ids"][row],
            }

    def __iter__(self):
        pending = []
        for example in self.dataset:
            # Both fields must be present and not None.
            if example.get("jpg") is None or example.get("txt") is None:
                if self.accelerator_ref.is_main_process:
                    self.accelerator_ref.print(f"Warning: Skipping example due to missing 'jpg' or 'txt' field: {example.get('txt', 'N/A')}")
                continue

            pending.append({"jpg": example["jpg"], "txt": example["txt"]})
            if len(pending) == self.processing_chunk_size:
                yield from self._emit_processed(pending)
                pending = []

        # Flush whatever is left once the stream is exhausted.
        if pending:
            yield from self._emit_processed(pending)

accelerator.print("Setting up dataset stream...")
# Number of samples handed to the preprocessing function in a single call.
transform_processing_chunk_size = 1
# Number of DataLoader workers doing the preprocessing.
num_dataloader_workers = 8

# Streaming dataset: raw examples -> preprocess_function -> tensors.
train_dataset = StreamingImageTextDataset(
    dataset_id=dataset_id,
    split="train",
    transform_fn=preprocess_function,
    tokenizer_obj=tokenizer,
    image_transforms_fn=image_transforms,
    processing_chunk_size=transform_processing_chunk_size,
    accelerator_ref=accelerator,
)

# Pinned memory and persistent workers are enabled only when workers are used.
_use_dataloader_workers = num_dataloader_workers > 0
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    num_workers=num_dataloader_workers,
    pin_memory=_use_dataloader_workers,
    persistent_workers=_use_dataloader_workers,
)

# Hand the trainable model, optimizer, dataloader and scheduler to accelerate.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet_to_prepare, optimizer, train_dataloader, lr_scheduler
)

# --- Resume from checkpoint, in case the machine dies halfway through ---
# Defaults used when nothing can be resumed.
global_step, start_epoch, resume_wandb_id = 0, 0, None
resume_wandb_id = None  # (redundant with the line above)
try:
    global_step, start_epoch, resume_wandb_id = load_latest_checkpoint_custom(
        accelerator, accelerator.unwrap_model(unet), ema_unet, optimizer, lr_scheduler, checkpoint_dir_base
    )
    if global_step > 0:
        accelerator.print(f"Resumed training from step {global_step}, epoch {start_epoch}.")
    else:
        accelerator.print("Starting training from scratch or no compatible checkpoint found.")
except Exception as e:
    # Any failure while resuming falls back to a fresh start.
    accelerator.print(f"Could not load checkpoint. Starting from scratch. Error: {e}")
    global_step, start_epoch, resume_wandb_id = 0, 0, None
# Ensure the EMA weights are on the accelerator device (the checkpoint loader
# reads them with map_location="cpu").
ema_unet.to(accelerator.device)

# --- WandB Initialization ---
# Main process only. Note that `wandb_config` is defined only inside this
# branch; the final-save section reads it when logging the model artifact.
if use_wandb and accelerator.is_main_process:
    wandb_config = {
        "learning_rate": learning_rate, "train_batch_size": train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "effective_batch_size": train_batch_size * accelerator.num_processes * gradient_accumulation_steps,
        "image_resolution": image_resolution, "unet_params_M": unet_param_count / 1e6,
        "vae_model_id": vae_model_id, "text_encoder_model_id": text_encoder_model_id,
        "dataset_id": dataset_id, "max_train_steps": max_train_steps,
        "num_warmup_steps": num_warmup_steps, "mixed_precision": mixed_precision,
        "torch_compile_mode": "reduce-overhead" if compiled_unet_object is not None else "None",
        "v_prediction": True, "cfg_drop_probability": cfg_drop_probability, "seed": seed,
    }
    # Resume the previous WandB run if a checkpoint provided its id
    run_id_to_resume = resume_wandb_id if global_step > 0 and resume_wandb_id else None

    accelerator.init_trackers(
        project_name=wandb_project_name,
        config=wandb_config,
        init_kwargs={"wandb": {"id": run_id_to_resume, "resume": "allow"}}
    )
    if run_id_to_resume: accelerator.print(f"Resuming WandB run with ID: {run_id_to_resume}")
    elif global_step > 0: accelerator.print("Warning: Resuming training but no wandb_run_id found in checkpoint for WandB.")


# --- Training Loop ---
accelerator.print(f"Starting training. Target steps: {max_train_steps}. Current step: {global_step}. Warmup steps: {num_warmup_steps}.")
# The progress bar starts at the resumed step and is shown on the main process only.
progress_bar = tqdm(initial=global_step, total=max_train_steps, desc="Training Steps", disable=not accelerator.is_main_process)
current_epoch_for_tracking = start_epoch # For saving in checkpoint

# UNet config, used later to build a fresh UNet that receives the EMA weights for sampling.
unet_for_sampling_config = accelerator.unwrap_model(unet).config

# Main training loop.
# Each iteration of the outer loop is one full pass over the streaming
# dataloader; training stops once `max_train_steps` optimizer updates are done.
while global_step < max_train_steps:
    unet.train()
    batches_in_pass = 0
    for batch in train_dataloader:
        batches_in_pass += 1

        with accelerator.accumulate(unet): # Gradient accumulation
            pixel_values = batch["pixel_values"]
            input_ids = batch["input_ids"]
            bsz = pixel_values.shape[0]

            # The VAE and text encoder are frozen, so no graph is needed here.
            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
                encoder_hidden_states = text_encoder(input_ids)[0]

            # Classifier-free guidance: with probability cfg_drop_probability
            # replace a sample's text embedding with the empty-prompt embedding.
            mask = torch.rand(bsz, device=accelerator.device) < cfg_drop_probability
            if mask.any():
                encoder_hidden_states[mask] = uncond_embeddings.expand(mask.sum(), -1, -1)

            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # v-prediction target instead of predicting the image or the noise.
            target_velocity = noise_scheduler.get_velocity(latents, noise, timesteps)

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            # The loss is computed in float32 regardless of the mixed-precision mode.
            loss = F.mse_loss(model_pred.float(), target_velocity.float(), reduction="mean")

            accelerator.backward(loss)
            if accelerator.sync_gradients: # Only clip when gradients are synced (after accumulation)
                accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        # Everything below runs once per real optimizer update, i.e. only on
        # the micro-batch that completed a gradient-accumulation cycle.
        if accelerator.sync_gradients:
            # Update the EMA weights from the unwrapped UNet's parameters.
            ema_unet.step(accelerator.unwrap_model(unet).parameters())

            progress_bar.update(1)
            global_step += 1

            if global_step % 100 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                # NOTE: this is the loss of the last micro-batch only.
                log_data = {"train_loss": loss.item(), "learning_rate": current_lr, "global_step": global_step}
                accelerator.log(log_data, step=global_step)
                progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "lr": f"{current_lr:.2e}"})

            # Checkpoint and sample. The interval is tested BEFORE the modulo:
            # save_every_steps is 0 when max_train_steps < 100, and taking the
            # modulo first would raise ZeroDivisionError.
            if save_every_steps > 0 and global_step % save_every_steps == 0:
                if accelerator.is_main_process:
                    save_checkpoint_custom(
                        accelerator,
                        accelerator.unwrap_model(unet), # original model
                        ema_unet, # EMA model instance
                        optimizer, # prepared optimizer
                        lr_scheduler, # prepared scheduler
                        global_step, # step reached so far
                        current_epoch_for_tracking, # bookkeeping only
                        checkpoint_dir_base,
                        max_keep=max_keep_checkpoints
                    )

                    # Generate a few preview images from the EMA weights: build
                    # a fresh UNet from the config and copy the EMA params in.
                    unet_ema_sample_model = UNet2DConditionModel.from_config(unet_for_sampling_config).to(accelerator.device)
                    ema_unet.copy_to(unet_ema_sample_model.parameters())

                    generate_samples(
                        unet_ema_sample_model, vae, text_encoder, tokenizer,
                        noise_scheduler, sampling_prompts, output_dir_base,
                        global_step, accelerator.device, accelerator
                    )
                    del unet_ema_sample_model # Free memory
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

            if global_step >= max_train_steps:
                break
    else:
        # Reached only when the dataloader was exhausted (no `break`).
        if batches_in_pass == 0:
            # Without this guard an empty stream would make the outer loop
            # spin forever, because global_step could never advance.
            raise RuntimeError("The training dataloader yielded no batches; cannot continue training.")
        # One complete pass over the data is finished.
        current_epoch_for_tracking += 1


progress_bar.close()
accelerator.print("Training finished.")

# --- Save the final model ---
# Main process only: writes every component needed for inference plus the
# state needed to continue training into final_model_dir.
if accelerator.is_main_process:
    accelerator.print(f"Saving final model and components to {final_model_dir}")
    os.makedirs(final_model_dir, exist_ok=True)
    # Raw (non-EMA) UNet weights.
    unwrapped_unet = accelerator.unwrap_model(unet)
    unwrapped_unet.save_pretrained(os.path.join(final_model_dir, "unet"))
    # EMA weights, copied into a fresh UNet so they can be saved with save_pretrained.
    final_ema_unet_model = UNet2DConditionModel.from_config(unwrapped_unet.config).to(accelerator.device)
    ema_unet.copy_to(final_ema_unet_model.parameters())
    final_ema_unet_model.save_pretrained(os.path.join(final_model_dir, "ema_unet"))
    del final_ema_unet_model
    vae.save_pretrained(os.path.join(final_model_dir, "vae"))
    text_encoder.save_pretrained(os.path.join(final_model_dir, "text_encoder"))
    tokenizer.save_pretrained(os.path.join(final_model_dir, "tokenizer"))
    noise_scheduler.save_config(os.path.join(final_model_dir, "scheduler")) # Saves config.json
    # Training state: accelerator state, EMA state dict and step/epoch bookkeeping.
    accelerator.save_state(os.path.join(final_model_dir, "accelerator_state"))
    torch.save(ema_unet.state_dict(), os.path.join(final_model_dir, "ema_unet_final_state.pth"))
    final_custom_state = {
        'global_step': global_step, 'epoch': current_epoch_for_tracking,
        'wandb_run_id': wandb.run.id if use_wandb and wandb.run else None
    }
    torch.save(final_custom_state, os.path.join(final_model_dir, "final_custom_training_state.pt"))
    if use_wandb and wandb.run:
        # Best effort: a failed artifact upload must not fail the run.
        try:
            final_model_artifact = wandb.Artifact(
                name=f"{wandb_project_name.lower().replace(' ', '_')}-final_model",
                type="model",
                description=f"Final trained diffusion model components at step {global_step}.",
                metadata=wandb_config
            )
            final_model_artifact.add_dir(final_model_dir)
            wandb.log_artifact(final_model_artifact)
            accelerator.print("Final model saved as WandB artifact.")
        except Exception as e:
            accelerator.print(f"Failed to save model as WandB artifact: {e}")


accelerator.end_training()
accelerator.print("All components saved. Training script complete.")

In [None]:
# Large test to check whether the pipeline is the problem...
# It is not the problem...
"""
import os
import io
import glob
import shutil
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import math
from torch.utils.data import DataLoader, IterableDataset
from torchvision import transforms
from PIL import Image, ImageOps
from tqdm.auto import tqdm
import wandb
import xformers # Added for XFORMERS

# Hugging Face Libraries
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.training_utils import EMAModel
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import load_dataset
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import ProjectConfiguration, set_seed, LoggerType

# --- Configuration (Hardcoded for simplicity) ---
use_dummy_data = True  # <<< SET THIS TO True TO USE DUMMY DATA FOR BOTTLENECK TESTING
# If True, the script will bypass dataset loading and use randomly generated tensors.

vae_model_id = "stabilityai/sd-vae-ft-ema"
text_encoder_model_id = "openai/clip-vit-large-patch14"
dataset_id = "BLIP3o/BLIP3o-Pretrain-Short-Caption" # Used only if use_dummy_data is False

# Training Hyperparameters
image_resolution = 512
latent_resolution = image_resolution // 8
train_batch_size = 16
gradient_accumulation_steps = 8
learning_rate = 1e-4
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-08
max_grad_norm = 1.0
num_train_timesteps_scheduler = 1000
mixed_precision = "bf16"
seed = 42
cfg_drop_probability = 0.1
max_train_steps = 30_000
save_every_steps = max_train_steps // 100
num_warmup_steps_ratio = 0.05
num_warmup_steps = int(num_warmup_steps_ratio * max_train_steps)
# --- Storage Path ---
persistent_storage_mount_path = "/workspace"
project_data_folder_name = "PozzyDiffusion"
output_dir_base = os.path.join(persistent_storage_mount_path, project_data_folder_name)
logging_dir = os.path.join(output_dir_base, "logs")
checkpoint_dir_base = os.path.join(output_dir_base, "checkpoints")
final_model_dir = os.path.join(output_dir_base, "final_model")
max_keep_checkpoints = 2
# WandB
wandb_project_name = "PozzyDiffusion_A40"
use_wandb = False # Set to False if you want to run dummy data test without WandB logging
# Hardcoded prompts per sampling
sampling_prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a fantasy landscape with a castle and a dragon",
    "a cyberpunk city at night with neon lights",
    "a serene Japanese garden with cherry blossoms",
    "a whimsical illustration of a cat playing a piano"
]
num_inference_steps_sampling = 50

# --- Accelerator ---
project_config = ProjectConfiguration(project_dir=output_dir_base, logging_dir=logging_dir)
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
    log_with="wandb" if use_wandb and not use_dummy_data else None, # Optionally disable wandb for dummy runs
    project_config=project_config,
)
if seed is not None:
    set_seed(seed)
if accelerator.is_main_process:
    os.makedirs(output_dir_base, exist_ok=True)
    os.makedirs(checkpoint_dir_base, exist_ok=True)
    os.makedirs(final_model_dir, exist_ok=True)

# --- Helper Functions ---
def save_checkpoint_custom(accelerator, unet_original, ema_unet, optimizer, lr_scheduler, global_step, epoch, checkpoint_dir_base, filename_prefix="ckpt", max_keep=2):
    """Save a checkpoint for ``global_step`` and prune the oldest ones.

    Does nothing unless called on the main process. The checkpoint lives in
    ``<checkpoint_dir_base>/<filename_prefix>_step<global_step>`` and holds:

    * the accelerator state written by ``accelerator.save_state``,
    * ``ema_unet.pth`` with ``ema_unet.state_dict()``,
    * ``custom_training_state.pt`` with the step, the epoch and the WandB
      run id (``None`` when WandB is off, has no active run, or this is a
      dummy-data run).

    After saving, all directories matching the prefix are ordered by their
    numeric step and everything but the newest ``max_keep`` is deleted.

    ``unet_original``, ``optimizer`` and ``lr_scheduler`` are accepted but
    not used directly here.
    """
    if not accelerator.is_main_process:
        return

    os.makedirs(checkpoint_dir_base, exist_ok=True)
    save_path = os.path.join(checkpoint_dir_base, f"{filename_prefix}_step{global_step}")

    accelerator.save_state(save_path)
    torch.save(ema_unet.state_dict(), os.path.join(save_path, "ema_unet.pth"))

    # Remember the WandB run so a resumed training can continue the same run.
    run_id = None
    if use_wandb and wandb.run and not use_dummy_data:
        run_id = wandb.run.id
    torch.save(
        {'global_step': global_step, 'epoch': epoch, 'wandb_run_id': run_id},
        os.path.join(save_path, "custom_training_state.pt"),
    )
    accelerator.print(f"Saved checkpoint: {save_path}")

    def _step_of(path):
        # Directory names end in "step<N>"; order numerically, not lexically.
        return int(os.path.basename(path).split('step')[-1])

    existing = glob.glob(os.path.join(checkpoint_dir_base, f"{filename_prefix}_step*"))
    existing.sort(key=_step_of)
    if len(existing) <= max_keep:
        return
    for stale in existing[:-max_keep]:
        accelerator.print(f"Removing old checkpoint: {stale}")
        shutil.rmtree(stale)

def load_latest_checkpoint_custom(accelerator, unet_original, ema_unet, optimizer, lr_scheduler, checkpoint_dir_base, filename_prefix="ckpt"):
    """Restore training state from the newest checkpoint, if there is one.

    Looks for directories named ``<filename_prefix>_step<N>`` under
    ``checkpoint_dir_base`` and loads the one with the highest ``N``:
    the accelerator state via ``accelerator.load_state``, the EMA weights
    from ``ema_unet.pth`` (when present) and the bookkeeping stored in
    ``custom_training_state.pt`` (when present).

    ``unet_original``, ``optimizer`` and ``lr_scheduler`` are accepted but
    not used directly here.

    Returns:
        ``(global_step, epoch, wandb_run_id)``. ``(0, 0, None)`` when no
        checkpoint exists or loading fails. When the bookkeeping file is
        missing, the step parsed from the directory name is returned with
        epoch ``0`` and no run id.
    """
    candidates = []
    for path in glob.glob(os.path.join(checkpoint_dir_base, f"{filename_prefix}_step*")):
        try:
            step = int(os.path.basename(path).split('step')[-1])
        except ValueError:
            # A stray entry (e.g. a manual backup copy) must not abort the
            # resume and make the caller silently restart from scratch.
            accelerator.print(f"Ignoring path without a numeric step suffix: {path}")
            continue
        candidates.append((step, path))
    if not candidates:
        return 0, 0, None

    latest_step, latest_checkpoint_path = max(candidates, key=lambda item: item[0])
    accelerator.print(f"Found checkpoint: {latest_checkpoint_path}")
    try:
        accelerator.load_state(latest_checkpoint_path)
        ema_path = os.path.join(latest_checkpoint_path, "ema_unet.pth")
        if os.path.exists(ema_path):
            ema_unet.load_state_dict(torch.load(ema_path, map_location="cpu"))
            # The state was deserialized onto the CPU; move the EMA weights
            # back next to the model so later EMA updates do not mix devices.
            ema_unet.to(accelerator.device)
            accelerator.print("EMA UNet state loaded.")
        custom_state_path = os.path.join(latest_checkpoint_path, "custom_training_state.pt")
        if os.path.exists(custom_state_path):
            custom_state = torch.load(custom_state_path, map_location="cpu")
            return custom_state.get('global_step', 0), custom_state.get('epoch', 0), custom_state.get('wandb_run_id')
        return latest_step, 0, None
    except Exception as e:
        # Deliberate best effort: a broken checkpoint falls back to a fresh run.
        # NOTE(review): state loaded before the failure is not rolled back.
        accelerator.print(f"Error loading checkpoint {latest_checkpoint_path}: {e}. Starting fresh.")
        return 0, 0, None

def generate_samples(unet_model_for_sampling, vae, text_encoder, tokenizer_obj, noise_scheduler_obj, prompts, output_dir, global_step, device, accelerator_ref):
    """Generate one image per prompt and save it under ``output_dir``.

    For every prompt the text is tokenized and encoded, a latent is drawn
    from a per-prompt seeded generator, denoised for
    ``num_inference_steps_sampling`` scheduler steps with the given UNet
    (conditioned on the prompt embedding only; no guidance is applied),
    decoded by the VAE and, on the main process, written to
    ``<output_dir>/samples/step_<global_step>/``.

    The three models are put in eval mode; the UNet is switched back to
    train mode before returning. Nothing is generated in dummy-data runs.

    Returns:
        None.
    """
    # Samples are meaningless for a throughput test, so skip on every
    # process (not only on the main one).
    if use_dummy_data:
        accelerator_ref.print("Skipping sample generation during dummy data run.")
        return

    unet_model_for_sampling.eval()
    vae.eval()
    text_encoder.eval()
    with torch.no_grad():
        for i, prompt in enumerate(prompts):
            text_inputs = tokenizer_obj(prompt, padding="max_length", max_length=tokenizer_obj.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_inputs.input_ids.to(device))[0]
            latents_shape = (1, unet_model_for_sampling.config.in_channels, latent_resolution, latent_resolution)
            # Use a dedicated generator living on the target device:
            # torch.manual_seed() would hand back the global CPU generator,
            # which cannot drive a CUDA draw and would also reseed the RNG
            # that training relies on.
            generator = None
            if seed is not None:
                generator = torch.Generator(device=device).manual_seed(seed + i)
            latents = torch.randn(latents_shape, device=device, generator=generator)
            noise_scheduler_obj.set_timesteps(num_inference_steps_sampling)
            for t in tqdm(noise_scheduler_obj.timesteps, desc=f"Sampling for prompt {i+1}", disable=not accelerator_ref.is_main_process):
                latent_model_input = noise_scheduler_obj.scale_model_input(latents, t)
                noise_pred = unet_model_for_sampling(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                latents = noise_scheduler_obj.step(noise_pred, t, latents).prev_sample
            # Undo the latent scaling applied at training time before decoding.
            latents = 1 / vae.config.scaling_factor * latents
            image = vae.decode(latents).sample
            # Map from [-1, 1] to [0, 1], then to an 8-bit HWC image.
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = Image.fromarray((image * 255).round().astype("uint8"))
            if accelerator_ref.is_main_process:
                sample_output_dir = os.path.join(output_dir, "samples", f"step_{global_step}")
                os.makedirs(sample_output_dir, exist_ok=True)
                img_path = os.path.join(sample_output_dir, f"prompt_{i+1}_seed{seed+i if seed is not None else 'rand'}.png")
                image.save(img_path)
                accelerator_ref.print(f"Generated sample for prompt '{prompt}' saved to {img_path}")
    unet_model_for_sampling.train()


# --- Dummy Data Iterable Dataset ---
class DummyIterableDataset(IterableDataset):
    """Endless stream of synthetic samples for throughput testing.

    Every item is a dict with a random-normal ``pixel_values`` tensor of
    shape ``(3, image_res, image_res)`` and an all-ones ``input_ids``
    tensor of length ``tokenizer_max_len``. The iterator never terminates.
    """

    def __init__(self, image_res, tokenizer_max_len, pixel_dtype=torch.float32, input_id_dtype=torch.long):
        super().__init__()
        self.image_resolution = image_res
        self.tokenizer_max_length = tokenizer_max_len
        self.pixel_dtype = pixel_dtype
        self.input_id_dtype = input_id_dtype

    def _make_sample(self):
        """Build one synthetic sample."""
        side = self.image_resolution
        # Constant token ids are sufficient here; torch.randint could be
        # used instead if more varied ids are ever needed.
        return {
            "pixel_values": torch.randn(3, side, side, dtype=self.pixel_dtype),
            "input_ids": torch.ones(self.tokenizer_max_length, dtype=self.input_id_dtype),
        }

    def __iter__(self):
        while True:
            yield self._make_sample()

# --- Initialize the models ---
# The VAE, tokenizer and text encoder are loaded from pretrained weights.
accelerator.print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(vae_model_id)
accelerator.print("Loading Text Encoder & Tokenizer...")
tokenizer = CLIPTokenizer.from_pretrained(text_encoder_model_id)
text_encoder = CLIPTextModel.from_pretrained(text_encoder_model_id)
accelerator.print("Initializing UNet...")
# The UNet is constructed from a config (not loaded from pretrained weights).
# Its input/output channels follow the VAE latent channels and its
# cross-attention width follows the text encoder hidden size.
original_unet = UNet2DConditionModel(
    sample_size=latent_resolution, in_channels=vae.config.latent_channels, out_channels=vae.config.latent_channels,
    layers_per_block=2, block_out_channels=(256, 512, 768, 1024),
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=text_encoder.config.hidden_size,
)
unet_param_count = sum(p.numel() for p in original_unet.parameters() if p.requires_grad)
accelerator.print(f"UNet initialized with {unet_param_count / 1e6:.2f}M parameters.")
# EMA copy of the UNet parameters, kept on the accelerator device.
ema_unet = EMAModel(original_unet.parameters(), model_cls=UNet2DConditionModel, model_config=original_unet.config)
ema_unet.to(accelerator.device)
# The VAE and the text encoder are frozen: eval mode, no gradients.
vae.to(accelerator.device).eval().requires_grad_(False)
text_encoder.to(accelerator.device).eval().requires_grad_(False)

# Embedding of the empty prompt, computed once; the training loop swaps it in
# for a random subset of samples (see cfg_drop_probability).
with torch.no_grad():
    uncond_tokens = tokenizer([""], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
    uncond_embeddings = text_encoder(uncond_tokens.to(accelerator.device))[0]

# NOTE(review): `xformers.version` looks like a module attribute, which would
# always be truthy, making the else branch unreachable — confirm.
if xformers.version:
    try:
        original_unet.enable_xformers_memory_efficient_attention()
        accelerator.print("XFormers memory efficient attention enabled for UNet.")
    except Exception as e:
        accelerator.print(f"Could not enable xformers memory efficient attention: {e}")
else:
    accelerator.print("XFormers not available or not installed. Skipping memory efficient attention.")
# No compilation is performed; `compiled_unet_object` stays None and is only
# read later when building the WandB config.
compiled_unet_object = None
accelerator.print("Skipping UNet compilation to avoid potential CUDAGraph issues.")
unet_to_prepare = original_unet

# --- Noise Scheduler ---
# The scheduler is configured for v-prediction; the training loop uses
# get_velocity() to build the matching regression target.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_train_timesteps_scheduler,
    beta_schedule="scaled_linear",
    prediction_type="v_prediction"
)

# --- Optimizer ---
optimizer = AdamW(
    unet_to_prepare.parameters(), lr=learning_rate, betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay, eps=adam_epsilon,
)

# --- Learning Rate Scheduler ---
def lr_lambda_cosine(current_step: int):
    """Return the learning-rate multiplier for ``current_step``.

    The multiplier rises linearly from 0 to 1 over the first
    ``num_warmup_steps`` steps, then follows half a cosine from 1 down to 0
    at ``max_train_steps``. Steps beyond ``max_train_steps`` stay at 0.
    """
    if num_warmup_steps > 0 and current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    denominator = float(max(1, max_train_steps - num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / denominator
    # Clamp to [0, 1]: without this, a scheduler stepped past max_train_steps
    # (e.g. if it advances more than once per optimizer step) would walk back
    # up the cosine and raise the learning rate again.
    progress = min(1.0, max(0.0, progress))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
# LambdaLR multiplies the optimizer's base learning rate by lr_lambda_cosine(step).
lr_scheduler = LambdaLR(optimizer, lr_lambda_cosine)


# --- Image preprocessing (real-data path) ---
# Every image is resized to a square of image_resolution pixels (the aspect
# ratio is not preserved), converted to a tensor and normalized to [-1, 1].
image_transforms = transforms.Compose([
    transforms.Resize((image_resolution, image_resolution), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(), # Converts to [0,1] range, float32
    transforms.Normalize([0.5], [0.5]), # Normalize to [-1, 1]
])
def preprocess_function(examples, tokenizer_obj, image_transforms_fn, accelerator_ref):
    """Turn a chunk of raw (image, caption) pairs into model-ready tensors.

    ``examples`` is a dict with parallel lists under ``"jpg"`` (PIL images)
    and ``"txt"`` (captions). Each image is converted to RGB, EXIF-rotated
    and passed through ``image_transforms_fn``. Pairs whose image is
    ``None`` or whose processing raises are dropped, with a warning printed
    on the main process.

    Returns:
        ``{"pixel_values": stacked image tensor, "input_ids": token ids}``
        for the surviving pairs, or ``None`` when none survive.
    """
    captions = examples["txt"]
    raw_images = examples["jpg"]

    kept_pixels = []
    kept_captions = []
    for position, (raw_image, caption) in enumerate(zip(raw_images, captions)):
        try:
            if raw_image is None:
                if accelerator_ref.is_main_process:
                    accelerator_ref.print(f"Warning: Found None image for prompt: {caption}. Skipping.")
                continue
            rgb_image = ImageOps.exif_transpose(raw_image.convert("RGB"))
            kept_pixels.append(image_transforms_fn(rgb_image))
            # Only record the caption once its image made it through.
            kept_captions.append(caption)
        except Exception as err:
            if accelerator_ref.is_main_process:
                accelerator_ref.print(f"Warning: Skipping an image/prompt due to error: {err}. Prompt: '{caption}'. Image index in batch: {position}")
            continue

    if not kept_pixels:
        return None

    tokenized = tokenizer_obj(
        kept_captions,
        padding="max_length",
        max_length=tokenizer_obj.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return {"pixel_values": torch.stack(kept_pixels), "input_ids": tokenized.input_ids}

class StreamingImageTextDataset(IterableDataset):
    """Iterable dataset that streams image/caption pairs and yields tensors.

    Examples are read from a streaming dataset, grouped into chunks of
    ``processing_chunk_size``, run through ``transform_fn`` and yielded one
    sample at a time as ``{"pixel_values": ..., "input_ids": ...}``.
    Examples lacking a ``"jpg"`` or ``"txt"`` value are skipped with a
    warning on the main process.
    """

    def __init__(self, dataset_id_str, split, transform_fn, tokenizer_obj, image_transforms_fn, processing_chunk_size, accelerator_ref):
        self.dataset = load_dataset(dataset_id_str, split=split, streaming=True)
        self.transform_fn = transform_fn
        self.tokenizer_obj = tokenizer_obj
        self.image_transforms_fn = image_transforms_fn
        self.processing_chunk_size = processing_chunk_size
        self.accelerator_ref = accelerator_ref

    def _emit(self, pending):
        """Transform the buffered examples and yield the resulting samples."""
        columns = {
            "jpg": [item["jpg"] for item in pending],
            "txt": [item["txt"] for item in pending],
        }
        processed = self.transform_fn(
            columns, self.tokenizer_obj, self.image_transforms_fn, self.accelerator_ref
        )
        # The transform returns nothing when every example in the chunk was dropped.
        if not processed:
            return
        for row in range(processed["pixel_values"].size(0)):
            yield {"pixel_values": processed["pixel_values"][row], "input_ids": processed["input_ids"][row]}

    def __iter__(self):
        pending = []
        for example in self.dataset:
            if example.get("jpg") is None or example.get("txt") is None:
                if self.accelerator_ref.is_main_process:
                    self.accelerator_ref.print(f"Warning: Skipping example due to missing 'jpg' or 'txt' field: {example.get('txt', 'N/A')}")
                continue
            pending.append({"jpg": example["jpg"], "txt": example["txt"]})
            if len(pending) == self.processing_chunk_size:
                yield from self._emit(pending)
                pending = []
        # Flush whatever is left when the stream ends mid-chunk.
        if pending:
            yield from self._emit(pending)

# --- DATALOADER SETUP (REAL OR DUMMY) ---
transform_processing_chunk_size = 1 # Used only for real data
num_dataloader_workers = 8

if not use_dummy_data:
    # Real run: stream image/caption pairs and preprocess them on the fly.
    accelerator.print("Setting up real dataset stream...")
    train_dataset = StreamingImageTextDataset(
        dataset_id_str=dataset_id,
        split="train",
        transform_fn=preprocess_function,
        tokenizer_obj=tokenizer,
        image_transforms_fn=image_transforms,
        processing_chunk_size=transform_processing_chunk_size,
        accelerator_ref=accelerator,
    )
else:
    # Throughput test: synthetic tensors with the same layout as real samples.
    accelerator.print("INFO: Using DUMMY DATA for training to test GPU throughput.")
    train_dataset = DummyIterableDataset(
        image_res=image_resolution,
        tokenizer_max_len=tokenizer.model_max_length, # Ensure tokenizer is loaded before this
        pixel_dtype=torch.float32, # Matches output of image_transforms
    )

# Pinned memory and persistent workers are only enabled when worker processes are used.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    num_workers=num_dataloader_workers,
    pin_memory=num_dataloader_workers > 0,
    persistent_workers=num_dataloader_workers > 0,
)

# Let accelerate wrap the trainable model, optimizer, dataloader and scheduler.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet_to_prepare, optimizer, train_dataloader, lr_scheduler
)

# --- Resume from Checkpoint ---
# Defaults describe a fresh run; they are overwritten when a checkpoint loads.
global_step = 0
start_epoch = 0
resume_wandb_id = None
if use_dummy_data:
    accelerator.print("INFO: Dummy data run. Checkpoint loading skipped. Starting from step 0.")
else:
    # Checkpoints are usually for real data runs
    try:
        global_step, start_epoch, resume_wandb_id = load_latest_checkpoint_custom(
            accelerator,
            accelerator.unwrap_model(unet),
            ema_unet,
            optimizer,
            lr_scheduler,
            checkpoint_dir_base,
        )
        if global_step <= 0:
            accelerator.print("Starting training from scratch or no compatible checkpoint found.")
        else:
            accelerator.print(f"Resumed training from step {global_step}, epoch {start_epoch}.")
    except Exception as e:
        accelerator.print(f"Could not load checkpoint. Starting from scratch. Error: {e}")
        global_step, start_epoch, resume_wandb_id = 0, 0, None


# --- WandB Initialization ---
# Trackers are only initialized on the main process, for real-data runs with
# WandB enabled. When a run id was restored from a checkpoint, it is passed on
# so that logging continues in the same WandB run.
if use_wandb and accelerator.is_main_process and not use_dummy_data: # Optionally skip wandb for dummy data
    # Hyperparameters recorded with the run; also reused later as metadata of
    # the final model artifact.
    wandb_config = {
        "learning_rate": learning_rate, "train_batch_size": train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "effective_batch_size": train_batch_size * accelerator.num_processes * gradient_accumulation_steps,
        "image_resolution": image_resolution, "unet_params_M": unet_param_count / 1e6,
        "vae_model_id": vae_model_id, "text_encoder_model_id": text_encoder_model_id,
        "dataset_id": dataset_id, "max_train_steps": max_train_steps,
        "num_warmup_steps": num_warmup_steps, "mixed_precision": mixed_precision,
        "torch_compile_mode": "reduce-overhead" if compiled_unet_object is not None else "None",
        "v_prediction": True, "cfg_drop_probability": cfg_drop_probability, "seed": seed,
        "using_dummy_data": use_dummy_data
    }
    # Only reuse a run id when training is actually resuming from a checkpoint.
    run_id_to_resume = resume_wandb_id if global_step > 0 and resume_wandb_id else None
    # NOTE: use_dummy_data is always False inside this branch, so "mode" always
    # evaluates to "online" here.
    accelerator.init_trackers(
        project_name=wandb_project_name,
        config=wandb_config,
        init_kwargs={"wandb": {"id": run_id_to_resume, "resume": "allow", "mode": "online" if not use_dummy_data else "disabled"}}
    )
    if run_id_to_resume: accelerator.print(f"Resuming WandB run with ID: {run_id_to_resume}")
    elif global_step > 0: accelerator.print("Warning: Resuming training but no wandb_run_id found in checkpoint for WandB.")
elif use_dummy_data and accelerator.is_main_process:
    accelerator.print("INFO: Dummy data run. WandB logging is disabled or limited.")


# --- Training Loop ---
# One "global step" is one optimizer update, i.e. gradient_accumulation_steps
# micro-batches. The dataloader is re-iterated until max_train_steps is reached.
accelerator.print(f"Starting training. Target steps: {max_train_steps}. Current step: {global_step}. Warmup steps: {num_warmup_steps}.")
progress_bar = tqdm(initial=global_step, total=max_train_steps, desc="Training Steps", disable=not accelerator.is_main_process)
current_epoch_for_tracking = start_epoch
# Config used to build a throw-away UNet that receives the EMA weights for sampling.
unet_for_sampling_config = accelerator.unwrap_model(unet).config

while global_step < max_train_steps:
    unet.train()
    for batch_idx, batch in enumerate(train_dataloader):
        if global_step >= max_train_steps:
            break

        with accelerator.accumulate(unet):
            # The frozen VAE and text encoder live on accelerator.device, so
            # the batch is moved there explicitly.
            pixel_values = batch["pixel_values"].to(accelerator.device)
            input_ids = batch["input_ids"].to(accelerator.device)
            bsz = pixel_values.shape[0]

            with torch.no_grad():
                latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
                encoder_hidden_states = text_encoder(input_ids)[0]

            # Classifier-free-guidance dropout: replace the text embedding of a
            # random subset of samples with the empty-prompt embedding.
            mask = torch.rand(bsz, device=accelerator.device) < cfg_drop_probability
            if mask.any():
                encoder_hidden_states[mask] = uncond_embeddings.expand(mask.sum(), -1, -1)

            noise = torch.randn_like(latents) # latents are on device
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            # The scheduler is configured for v-prediction, so regress on the velocity.
            target_velocity = noise_scheduler.get_velocity(latents, noise, timesteps)

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target_velocity.float(), reduction="mean")

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        # Everything below runs once per optimizer update, not per micro-batch.
        if accelerator.sync_gradients:
            ema_unet.step(accelerator.unwrap_model(unet).parameters())
            progress_bar.update(1)
            global_step += 1

            # Logging (can be simplified or skipped for dummy data runs if desired)
            if global_step % 100 == 0: # Log every 100 steps
                current_lr = optimizer.param_groups[0]['lr']
                # Read the loss once; it is the loss of the last micro-batch.
                loss_value = loss.item()
                log_data = {"train_loss": loss_value, "learning_rate": current_lr, "global_step": global_step}
                if use_wandb and not use_dummy_data: # Only log to wandb if it's enabled and not dummy data
                    accelerator.log(log_data, step=global_step)
                elif accelerator.is_main_process: # Print to console for dummy data or if wandb is off
                    print(f"Step: {global_step}, Loss: {loss_value:.4f}, LR: {current_lr:.2e}")
                progress_bar.set_postfix({"loss": f"{loss_value:.4f}", "lr": f"{current_lr:.2e}"})

            # Save Checkpoint (usually skipped for dummy data runs).
            # The `save_every_steps > 0` guard must be evaluated before the
            # modulo, otherwise a cadence of 0 (max_train_steps < 100) raises
            # ZeroDivisionError instead of disabling checkpointing.
            if not use_dummy_data and save_every_steps > 0 and global_step > 0 and global_step % save_every_steps == 0:
                if accelerator.is_main_process:
                    save_checkpoint_custom(
                        accelerator, accelerator.unwrap_model(unet), ema_unet,
                        optimizer, lr_scheduler, global_step, current_epoch_for_tracking,
                        checkpoint_dir_base, max_keep=max_keep_checkpoints
                    )
                    # Sample with the EMA weights using a temporary UNet so the
                    # model being trained is left untouched.
                    unet_ema_sample_model = UNet2DConditionModel.from_config(unet_for_sampling_config).to(accelerator.device)
                    ema_unet.copy_to(unet_ema_sample_model.parameters())
                    generate_samples(
                        unet_ema_sample_model, vae, text_encoder, tokenizer,
                        noise_scheduler, sampling_prompts, output_dir_base,
                        global_step, accelerator.device, accelerator
                    )
                    del unet_ema_sample_model
                    if torch.cuda.is_available(): torch.cuda.empty_cache()

            # Stop right away instead of fetching one more batch just to discard it.
            if global_step >= max_train_steps:
                break

    current_epoch_for_tracking +=1 # Increment epoch conceptually after iterating through dataloader

progress_bar.close()
accelerator.print("Training finished.")

# --- Save final model (usually skipped for dummy data runs) ---
# On the main process of a real-data run, everything needed to reuse or resume
# the model is written under final_model_dir: the trained UNet, a UNet holding
# the EMA weights, the VAE, the text encoder, the tokenizer, the scheduler
# config, the accelerator state, the raw EMA state and a small bookkeeping file.
if accelerator.is_main_process and not use_dummy_data:
    accelerator.print(f"Saving final model and components to {final_model_dir}")
    os.makedirs(final_model_dir, exist_ok=True)
    unwrapped_unet = accelerator.unwrap_model(unet)
    unwrapped_unet.save_pretrained(os.path.join(final_model_dir, "unet"))
    # Materialize the EMA weights into a regular UNet so they can be saved
    # with save_pretrained like any other model.
    final_ema_unet_model = UNet2DConditionModel.from_config(unwrapped_unet.config).to(accelerator.device)
    ema_unet.copy_to(final_ema_unet_model.parameters())
    final_ema_unet_model.save_pretrained(os.path.join(final_model_dir, "ema_unet"))
    del final_ema_unet_model
    vae.save_pretrained(os.path.join(final_model_dir, "vae"))
    text_encoder.save_pretrained(os.path.join(final_model_dir, "text_encoder"))
    tokenizer.save_pretrained(os.path.join(final_model_dir, "tokenizer"))
    noise_scheduler.save_config(os.path.join(final_model_dir, "scheduler"))
    accelerator.save_state(os.path.join(final_model_dir, "accelerator_state"))
    torch.save(ema_unet.state_dict(), os.path.join(final_model_dir, "ema_unet_final_state.pth"))
    final_custom_state = {
        'global_step': global_step, 'epoch': current_epoch_for_tracking,
        'wandb_run_id': wandb.run.id if use_wandb and wandb.run else None
    }
    torch.save(final_custom_state, os.path.join(final_model_dir, "final_custom_training_state.pt"))
    # Upload the whole final_model_dir as a WandB artifact; a failure here is
    # reported but does not abort the script.
    if use_wandb and wandb.run:
        try:
            final_model_artifact = wandb.Artifact(
                name=f"{wandb_project_name.lower().replace(' ', '_')}-final_model", type="model",
                description=f"Final trained diffusion model components at step {global_step}.",
                metadata=wandb_config # wandb_config might not be fully defined if dummy_data was true and wandb init was skipped
            )
            final_model_artifact.add_dir(final_model_dir)
            wandb.log_artifact(final_model_artifact)
            accelerator.print("Final model saved as WandB artifact.")
        except Exception as e:
            accelerator.print(f"Failed to save model as WandB artifact: {e}")
elif accelerator.is_main_process and use_dummy_data:
    accelerator.print("INFO: Dummy data run. Final model saving skipped.")

accelerator.end_training()
accelerator.print("All components processed. Training script complete.")
"""