In [None]:
# from datasets import load_dataset
# from torchvision import transforms
# from torch.utils.data import Dataset, DataLoader
# from PIL import Image
# import torch
# import random

# import os
# MY_TOKEN = os.environ["HF_TOKEN"]  # read the token from the environment; never commit it (revoke any token that was committed here)
# # HuggingFace dataset
# dataset = load_dataset("nlphuji/flickr30k", split="train[:1%]", use_auth_token=MY_TOKEN)  # use a subset to test first

# # Preprocessing
# image_size = 512
# transform = transforms.Compose([
#     transforms.Resize((image_size, image_size)),
#     transforms.ToTensor(),
#     transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] for SD
# ])

# # Use CLIP tokenizer for prompt text
# from transformers import CLIPTokenizer
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# class FlickrDataset(Dataset):
#     def __init__(self, dataset, transform, tokenizer):
#         self.dataset = dataset
#         self.transform = transform
#         self.tokenizer = tokenizer

#     def __len__(self):
#         return len(self.dataset)

#     def __getitem__(self, idx):
#         entry = self.dataset[idx]
#         image = self.transform(entry['image'].convert("RGB"))
#         # caption = entry['caption'][0]
#         caption = random.choice(entry['caption'])
#         tokens = self.tokenizer(caption, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
#         return {
#             "image": image,
#             "input_ids": tokens["input_ids"].squeeze(0),
#             "attention_mask": tokens["attention_mask"].squeeze(0),
#             "text": caption
#         }

# # Create DataLoader
# flickr_dataset = FlickrDataset(dataset, transform, tokenizer)
# dataloader = DataLoader(flickr_dataset, batch_size=2, shuffle=True, num_workers=2)


In [None]:
import torch
from torch.utils.data import DataLoader, Dataset

# Path of the cached, already-preprocessed samples read by torch.load below.
SAVE_PATH = "flickr30k_preprocessed_1pct_comp.pt"
# NOTE(review): torch.load is pickle-based, so only load cache files produced by a trusted run.
# Presumably an indexable collection of per-sample dicts (it is indexed by
# CachedFlickrDataset and a 'text' entry is read from a batch below) — TODO confirm.
loaded_data = torch.load(SAVE_PATH)

class CachedFlickrDataset(Dataset):
    """Map-style dataset that serves samples from an in-memory collection.

    Args:
        data: Any object supporting ``len()`` and indexing; each element is
            handed back unchanged as one sample.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        """Return the cached sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

    def __len__(self) -> int:
        """Return how many cached samples are available."""
        n_samples = len(self.data)
        return n_samples

# Usage: wrap the cached samples and batch them (64 per batch, reshuffled on each pass).
dataset = CachedFlickrDataset(data=loaded_data)
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=64,
)

# Smoke test: show the 'text' entry of the first batch only, then stop iterating.
for batch in dataloader:
    print(batch['text'])
    break


In [None]:
import torch
torch.cuda.empty_cache()
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Import necessary diffusion components from diffusers and transformers
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel

# Import our custom attention utilities.
from new_attention_utils import (
    register_attention_hooks,
    compute_hallucination_penalty,
    get_last_cross_attention_resized,
)


# -----------------------------
# Configuration & Model Loading
# -----------------------------
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
model_path = "CompVis/stable-diffusion-v1-4"
batch_size = 32  # For demonstration; adjust as needed

# Load the UNet and the VAE from their checkpoint subfolders in float16,
# then move each of them to `device`.
unet = UNet2DConditionModel.from_pretrained(
    model_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
unet = unet.to(device)
vae = AutoencoderKL.from_pretrained(
    model_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
vae = vae.to(device)
# The noise scheduler is created from the checkpoint's "scheduler" subfolder.
scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

# Text conditioning: CLIP tokenizer plus CLIP text encoder (float16, on `device`),
# both taken from the same checkpoint.
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    model_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder = text_encoder.to(device)

# -----------------------------
# Inject LoRA into UNet
# -----------------------------
from peft import LoraConfig, get_peft_model

# Adapter settings: r=4, lora_alpha=16, dropout 0.1, bias="none",
# targeting the modules named to_q / to_k / to_v.
lora_config = LoraConfig(
    target_modules=["to_q", "to_k", "to_v"],
    r=4,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

# Wrap the UNet with the LoRA configuration and put it in training mode.
unet = get_peft_model(unet, lora_config)
unet.train()

# Hooks are registered on the wrapped UNet so the custom attention utility
# can collect attention maps during forward passes.
register_attention_hooks(unet)

# -----------------------------
# Create Dataset and DataLoader
# -----------------------------
# Build the loader inside this cell so that training does not rely on an earlier
# cell having been executed: running this cell on a fresh kernel used to stop with
# "NameError: name 'dataloader' is not defined".
# NOTE(review): torch.load is pickle-based; only load cache files you produced yourself.
PREPROCESSED_PATH = "flickr30k_preprocessed_1pct_comp.pt"
train_samples = torch.load(PREPROCESSED_PATH)
# A map-style dataset only has to support len() and indexing, so the loaded
# collection is passed to DataLoader directly. `batch_size` is the value set in
# the configuration section of this cell.
dataloader = DataLoader(train_samples, batch_size=batch_size, shuffle=True)

# -----------------------------
# Helper for Prompt Encoding
# -----------------------------
# Hidden dimension for a projector that maps the attribute vector to the text encoder's hidden dimension.
# NOTE(review): no such projector is defined in the visible code and `hidden_dim` is not
# referenced by the training loop below — confirm whether this is still needed.
hidden_dim = 768


# -----------------------------
# Training Loop
# -----------------------------
# Hand AdamW only the parameters that require gradients instead of every UNet
# parameter (presumably just the LoRA adapter weights added by get_peft_model — confirm).
trainable_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-6)
# NOTE: despite its name, this is the number of full passes over `dataloader`
# (it bounds the outer loop below).
num_train_steps = 10  # For testing; adjust as needed

# Weight of the hallucination penalty in the total loss. It never changes, so it
# is set once here rather than being re-assigned on every batch.
lambda_h = 0.3

for epoch in range(num_train_steps):
    print(f"Epoch {epoch+1}")
    for batch in tqdm(dataloader):
        # Move the batch to the training device; images are cast to float16 to
        # match the dtype the models were loaded with.
        images = batch["image"].to(device, dtype=torch.float16)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Encode images into latents via VAE (no grad through VAE)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor

        # Draw the target noise and one random timestep per sample in the batch.
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=device).long()

        # Add noise (forward diffusion)
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        # NOTE(review): marks the UNet input as requiring grad; the reason is not
        # evident from this cell — confirm it is still needed.
        noisy_latents.requires_grad_(True)

        # Text conditioning is computed without gradients (the text encoder is not updated here).
        with torch.no_grad():
            encoder_hidden_states = text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # Forward pass through UNet (this call triggers attention hooks)
        model_output = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

        # Standard denoising (MSE) loss between predicted noise and true noise,
        # evaluated in float32.
        loss_mse = F.mse_loss(model_output.float(), noise.float(), reduction='mean')

        # Extract and resize attention maps (to a chosen target shape; e.g., 32x32).
        try:
            attn_maps = get_last_cross_attention_resized(target_shape=(32, 32))
        except ValueError as e:
            # Best-effort fallback: report the problem and continue with an all-zero
            # map sized like the latents, allocated directly on the training device.
            print(e)
            attn_maps = torch.zeros(1, 1, noisy_latents.shape[-2], noisy_latents.shape[-1], device=device)

        # Compute hallucination loss:
        # According to our formulation, compute A_sum from attention maps, then M_halluc = 1 - clip(A_sum,0,1)
        # and then penalize the predicted features model_output with M_halluc.
        loss_hallu = compute_hallucination_penalty(attn_maps, model_output)

        # Total loss = denoising loss + lambda * hallucination loss
        loss = loss_mse + lambda_h * loss_hallu

        # Standard optimisation step: clear old gradients, backpropagate, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # tqdm.write prints the per-batch losses through tqdm rather than a bare print.
        tqdm.write(f"Loss: {loss.item():.6f} | MSE: {loss_mse.item():.6f} | Hallu: {loss_hallu.item():.6f}")

    # Optionally, reset attention storage here if necessary.

# Save the fine-tuned LoRA weights.
save_path = "./lora_finetuned_f3_all"
unet.save_pretrained(save_path)
print(f"LoRA weights saved to {save_path}")


  from .autonotebook import tqdm as notebook_tqdm


Epoch 1


NameError: name 'dataloader' is not defined