In [1]:
!pip install torch torchvision transformers matplotlib
!pip install tensorboard wandb




In [2]:
import torch
from torch import nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Import Hugging Face components
from transformers import CLIPTokenizer, CLIPTextModel  # Example text encoder and tokenizer

# TensorBoard (optional)
from torch.utils.tensorboard import SummaryWriter

# WandB (optional)
import wandb

# Import any additional libraries specific to your diffusion model


In [3]:
from diffusers import StableDiffusionXLImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch


In [4]:
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
policy_unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")


vae/config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/10.3G [00:00<?, ?B/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
policy_unet = policy_unet.to(device)
text_encoder = text_encoder.to(device)


In [12]:
class DiffusionEvaluation:
    def __init__(self, tokenizer, text_encoder, vae, policy_unet, noise_scheduler, device):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae
        self.policy_unet = policy_unet
        self.noise_scheduler = noise_scheduler
        self.device = device

        # Lists to store patterns
        self.noise_lst = []
        self.latents_lst = []
        self.attention_pattern_lst = []
        self.attention_pattern_act_lst = []
        self.attention_pattern_residual_lst = []
        self.attention_pattern_copy_lst = []

        # Add a transform for image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
        ])
    def capture_noise_schedule(self, latents, noise, timesteps):
        """ Capture noise patterns added during the diffusion process """
        self.noise_lst.append(noise.detach().cpu().numpy())

    def capture_latents(self, latents):
        """ Capture latent representations after VAE encoding """
        self.latents_lst.append(latents.detach().cpu().numpy())

    def capture_attention(self, layer_name, attention_weights):
        """ Capture attention weights or patterns from the UNet """
        if "act" in layer_name:
            self.attention_pattern_act_lst.append(attention_weights.detach().cpu().numpy())
        elif "residual" in layer_name:
            self.attention_pattern_residual_lst.append(attention_weights.detach().cpu().numpy())
        elif "copy" in layer_name:
            self.attention_pattern_copy_lst.append(attention_weights.detach().cpu().numpy())
        else:
            self.attention_pattern_lst.append(attention_weights.detach().cpu().numpy())

    def forward(self, batch):
        print("\n=== Starting Diffusion Step ===")

        # Process images
        print("Processing images...")
        winners = torch.stack(
            [self.transform(transforms.ToPILImage()(item["winner"])) for item in batch]
        ).to(self.device)
        losers = torch.stack(
            [self.transform(transforms.ToPILImage()(item["loser"])) for item in batch]
        ).to(self.device)
        pixel_values = torch.cat([winners, losers], dim=0)  # [2*batch_size, 3(channel), 512, 512]
        print(f"Combined pixel values shape: {pixel_values.shape}")

        # Process prompts
        print("Processing prompts...")
        tokenized = self.tokenizer(
            [item["prompt"] for item in batch],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Text embeddings
        encoder_hidden_states = self.text_encoder(tokenized.input_ids.to(self.device))[0]
        print(f"Text embeddings shape: {encoder_hidden_states.shape}")

        # Convert images to latent space (VAE encoding)
        print("Converting images to embedding / latent space...")
        with torch.no_grad():
            latents = self.vae.encode(pixel_values).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor
            print(f"VAE latents shape: {latents.shape}")

        # Capture the latent representations
        self.capture_latents(latents)

        # Sample noise and timesteps
        print("Sampling noise and timesteps...")
        noise = torch.randn_like(latents)
        batch_size = latents.shape[0] // 2  # Because we concatenated winners and losers
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (batch_size,), device=self.device).repeat(2)

        # Capture noise patterns
        self.capture_noise_schedule(latents, noise, timesteps)

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        print(f"Noisy latents shape: {noisy_latents.shape}")

        # Get predictions from policy UNet
        print("Getting predictions from policy UNet...")
        repeated_embeddings = encoder_hidden_states.repeat(2, 1, 1)
        policy_pred = self.policy_unet(noisy_latents, timesteps, repeated_embeddings).sample
        print(f"Policy model prediction shape: {policy_pred.shape}")

        # Capture attention patterns from policy UNet (example hook)
        self.capture_attention("attention_layer_act", policy_pred)

        return policy_pred

    def forward(self, batch):
        """Run a forward pass using the policy UNet on the CPU."""
        print("\n=== Starting Diffusion Step ===")
        print("Processing images...")
        winners = torch.stack(
            [self.transform(transforms.ToPILImage()(item["winner"])) for item in batch]
        ).to("cpu")  # Use CPU here
        losers = torch.stack(
            [self.transform(transforms.ToPILImage()(item["loser"])) for item in batch]
        ).to("cpu")  # Use CPU here
        pixel_values = torch.cat([winners, losers], dim=0)  # [2*batch_size, 3, 512, 512]
        print(f"Combined pixel values shape: {pixel_values.shape}")

        print("Processing prompts...")
        prompts = [item["prompt"] for item in batch]
        text_inputs = self.tokenizer(
            prompts, padding="max_length", max_length=77, return_tensors="pt"
        )
        text_embeddings = self.text_encoder(text_inputs.input_ids.to("cpu"))[0]  # Use CPU here
        print(f"Text embeddings shape: {text_embeddings.shape}")

        print("Converting images to embedding / latent space...")
        latents = self.vae.encode(pixel_values).latent_dist.sample().detach()
        latents = latents * 0.18215  # Scale factor for SD latents
        print(f"VAE latents shape: {latents.shape}")

        print("Sampling noise and timesteps...")
        noise = torch.randn_like(latents).to("cpu")  # Use CPU here
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device="cpu"
        ).long()  # Use CPU here
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        print(f"Noisy latents shape: {noisy_latents.shape}")

        print("Getting predictions from policy UNet...")
        repeated_embeddings = text_embeddings.repeat(2, 1, 1)
        text_embeds = repeated_embeddings  # Use text embeddings for added_cond_kwargs
        added_cond_kwargs = {"text_embeds": text_embeds}

        # Call UNet with the necessary additional kwargs
        policy_pred = self.policy_unet(
            noisy_latents, timesteps, repeated_embeddings, added_cond_kwargs=added_cond_kwargs
        ).sample
        print(f"Policy model prediction shape: {policy_pred.shape}")

        return policy_pred


    def evaluate(self, batch):
        """ Evaluate the model for a given batch """
        print("\n=== Starting Evaluation ===")
        policy_pred = self.forward(batch)

        # Return all captured patterns for visualization
        return {
            "noise_lst": self.noise_lst,
            "latents_lst": self.latents_lst,
            "attention_pattern_lst": self.attention_pattern_lst,
            "attention_pattern_act_lst": self.attention_pattern_act_lst,
            "attention_pattern_residual_lst": self.attention_pattern_residual_lst,
            "attention_pattern_copy_lst": self.attention_pattern_copy_lst,
        }



In [10]:
from PIL import Image
import numpy as np

# Prepare a batch for evaluation
transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])

# Convert the tensor to a PIL image
random_tensor = torch.randn(3, 512, 512)  # Replace with real image data
random_numpy = random_tensor.permute(1, 2, 0).numpy()  # Convert to HWC format
random_pil = Image.fromarray((random_numpy * 255).astype(np.uint8))  # Convert to PIL Image

winner_image_tensor = transform(random_pil)  # Apply transform pipeline
loser_image_tensor = transform(random_pil)   # Apply transform pipeline


In [13]:
# Initialize the DiffusionEvaluation class with SDXL components
evaluation = DiffusionEvaluation(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    policy_unet=policy_unet,
    noise_scheduler=noise_scheduler,
    device=device
)
batch = [
    {
        "prompt": "A high-quality realistic depiction of a fantasy cityscape.",
        "winner": winner_image_tensor,
        "loser": loser_image_tensor,
    },
    # Add more items to the batch if needed
]

# Run evaluation
evaluation_results = evaluation.evaluate(batch)

# Visualize the results
for i, noise in enumerate(evaluation_results['noise_lst']):
    plt.figure(figsize=(5, 5))
    plt.imshow(noise[0, 0, :, :].detach().cpu().numpy(), cmap='gray')
    plt.title(f"Noise Pattern at Timestep {i}")
    plt.colorbar()
    plt.show()



=== Starting Evaluation ===

=== Starting Diffusion Step ===
Processing images...
Combined pixel values shape: torch.Size([2, 3, 512, 512])
Processing prompts...


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)