In [None]:
!pip install diffusers torch requests pillow tqdm
!pip install pycocotools
!pip install torch-fidelity
!pip install openai-clip
!wget http://images.cocodataset.org/zips/val2017.zip
!unzip -q val2017.zip
!mkdir -p annotations
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip -q annotations_trainval2017.zip -d annotations

In [None]:
import json
import os
import random
import shutil

import clip
import requests
import torch
from diffusers import PixArtAlphaPipeline
from PIL import Image
from torch_fidelity import calculate_metrics
from tqdm import tqdm

# --- COCO 2017 validation data locations ---
# Captions JSON; the setup cell unzips the annotations archive into
# ./annotations, which itself contains an `annotations/` folder.
COCO_ANN_PATH: str = "annotations/annotations/captions_val2017.json"
# Folder holding the val2017 images (unzipped in the setup cell).
COCO_IMG_DIR: str = "./val2017"
# How many (image, caption) pairs to sample for generation and scoring.
# NOTE(review): FID estimated from this few images is extremely noisy;
# treat the FID numbers as a smoke test rather than a measurement.
N: int = 5

def ablate_dit_part(dit, blocks_to_patch, part='ff', mode='zero'):
    """
    Patch a given part ('attn1', 'attn2', 'ff') of DiT blocks in PixArt-α.

    The chosen sub-module's ``forward`` is replaced by an instance attribute.
    This relies on ``nn.Module.__call__`` dispatching through ``self.forward``,
    so every ``sub_module(...)`` call then returns the ablated output. The
    first real ``forward`` is saved as ``sub_module._original_forward``.
    Patching the same sub-module again only swaps the ablation and never
    overwrites that saved forward.

    Args:
        dit: The transformer model; must expose an indexable
            ``transformer_blocks`` sequence.
        blocks_to_patch: Iterable of block indices to patch. Negative indices
            count from the end. Out-of-range indices are skipped with a
            printed warning.
        part: Attribute name of the sub-module in each block:
            'attn1' (self-attn), 'attn2' (cross-attn), 'ff' (MLP).
        mode: What the patched sub-module returns, computed from its input
            activation ``x``, which is the first positional argument or the
            ``hidden_states`` keyword:
              'zero'  -> ``torch.zeros_like(x)``
              'input' -> ``x`` unchanged
              'mean'  -> mean of ``x`` over its last dim, broadcast back to
                         ``x``'s shape
            NOTE(review): each mode yields a tensor shaped like the
            sub-module's *input*; this assumes the patched part's output has
            the same shape as its input. Confirm this for every part used.

    Raises:
        ValueError: if ``mode`` is not 'zero', 'input' or 'mean'. This is
            checked before any block is patched.
    """
    ablations = {
        'zero': torch.zeros_like,
        'input': lambda x: x,
        'mean': lambda x: x.mean(dim=-1, keepdim=True).expand_as(x),
    }
    # Reject unknown modes up front, so a typo cannot silently leave the
    # model unablated while the log claims otherwise.
    if mode not in ablations:
        raise ValueError(
            f"Unknown ablation mode {mode!r}; expected one of {list(ablations)}"
        )
    ablate = ablations[mode]

    def ablated_forward(*args, **kwargs):
        # The activation being replaced is the first positional argument or
        # the `hidden_states` keyword. All other arguments (masks, encoder
        # states, ...) are ignored.
        x = args[0] if args else kwargs['hidden_states']
        return ablate(x)

    n_blocks = len(dit.transformer_blocks)
    for idx in blocks_to_patch:
        # Accept the same index range as sequence indexing (incl. negatives).
        if not -n_blocks <= idx < n_blocks:
            print(f"Warning: Block index {idx} exceeds available blocks ({n_blocks})")
            continue

        block = dit.transformer_blocks[idx]
        sub_module = getattr(block, part, None)
        if sub_module is None:
            print(f"Warning: Part '{part}' not found in block {idx}")
            continue

        # Save the real forward only once, so re-patching never stores an
        # earlier ablation as the "original".
        if not hasattr(sub_module, '_original_forward'):
            sub_module._original_forward = sub_module.forward

        sub_module.forward = ablated_forward
        print(f"Ablated block {idx}, part '{part}' with mode '{mode}'")

def restore_dit_part(dit, blocks_to_restore, part='ff'):
    """
    Undo a forward patch on ``part`` for the given transformer blocks.

    For each index below ``len(dit.transformer_blocks)`` whose ``part``
    sub-module carries a saved ``_original_forward``, that saved callable is
    reinstated as ``forward``, the saved attribute is deleted, and a line is
    printed. Indices that are too large, missing sub-modules and unpatched
    sub-modules are skipped without a message.
    """
    blocks = dit.transformer_blocks
    for block_index in blocks_to_restore:
        if block_index < len(blocks):
            target = getattr(blocks[block_index], part, None)
            if target is None or not hasattr(target, '_original_forward'):
                continue
            target.forward = target._original_forward
            del target._original_forward
            print(f"Restored block {block_index}, part '{part}'")

# Load COCO captions and pair each caption with its image file on disk.
with open(COCO_ANN_PATH, "r", encoding="utf-8") as f:
    coco = json.load(f)

id2filename = {img['id']: img['file_name'] for img in coco['images']}

# Resolve and stat each image once rather than once per caption, since
# several caption annotations can share an image_id. Images missing from
# COCO_IMG_DIR are dropped, as are captions that point at them.
id2path = {}
for img_id, file_name in id2filename.items():
    img_path = os.path.join(COCO_IMG_DIR, file_name)
    if os.path.exists(img_path):
        id2path[img_id] = img_path

# Kept in annotation order, so the seeded sample below is reproducible.
data = [
    (id2path[ann['image_id']], ann['caption'])
    for ann in coco['annotations']
    if ann['image_id'] in id2path
]

print(f"Loaded {len(data)} (image, caption) pairs from COCO val2017.")

# Pick random samples with a fixed seed. Note that this reseeds the global
# `random` RNG. Two sampled pairs may share the same image.
if len(data) < N:
    raise ValueError(
        f"Need at least N={N} (image, caption) pairs but found {len(data)}; "
        f"check COCO_IMG_DIR={COCO_IMG_DIR!r}"
    )
random.seed(42)  # Set seed for reproducibility
chosen = random.sample(data, N)

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use float16 only on CUDA. On the CPU fallback, load in float32, because
# half-precision support on CPU is limited or slow for many kernels.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-512x512",
    torch_dtype=pipe_dtype,
    use_safetensors=True,
)
pipe.to(device)

# Load CLIP
clip_model, preprocess = clip.load("ViT-B/32", device=device)

output_size = (299, 299)
# Start from empty output dirs. torch-fidelity scores every image found in
# these directories, so leftovers from an earlier run (e.g. one with a larger
# N) would otherwise leak into the FID computation.
for sub_dir in ("real", "fake", "ablated"):
    out_dir = f"./pixart_samples/{sub_dir}"
    shutil.rmtree(out_dir, ignore_errors=True)
    os.makedirs(out_dir, exist_ok=True)

# Sample i draws its initial noise from a generator seeded with GEN_SEED + i
# in BOTH passes. Each (original, ablated) pair therefore differs only by the
# ablation, not by different random noise.
GEN_SEED = 0


def _sample_generator(index):
    """Return a CPU torch.Generator seeded deterministically for sample `index`."""
    return torch.Generator(device="cpu").manual_seed(GEN_SEED + index)


# Generate original images
print("Generating original images...")
for i, (img_path, prompt) in enumerate(tqdm(chosen, desc="Generating original")):
    # Open inside `with` so the source file handle is closed promptly.
    with Image.open(img_path) as src_img:
        gt_img = src_img.convert("RGB").resize(output_size, Image.LANCZOS)
    gt_img.save(f"./pixart_samples/real/gt_{i+1}.jpg")

    gen_img = pipe(prompt, generator=_sample_generator(i)).images[0]
    gen_img = gen_img.convert("RGB").resize(output_size, Image.LANCZOS)
    gen_img.save(f"./pixart_samples/fake/gen_{i+1}.jpg")

# Apply ablation
print(f"Total transformer blocks: {len(pipe.transformer.transformer_blocks)}")

# Example ablations - modify as needed:
# ablate_dit_part(pipe.transformer, blocks_to_patch=[0], part='ff', mode='zero')
# ablate_dit_part(pipe.transformer, blocks_to_patch=[1, 2], part='attn1', mode='input')
# NOTE: only ABLATED_BLOCKS / ABLATED_PART are restored below; if you add
# extra ablate_dit_part calls, restore those too.
ABLATED_BLOCKS = [0]
ABLATED_PART = 'ff'
ablate_dit_part(pipe.transformer, blocks_to_patch=ABLATED_BLOCKS, part=ABLATED_PART, mode='mean')

# Generate ablated images
print("Generating ablated images...")
try:
    for i, (img_path, prompt) in enumerate(tqdm(chosen, desc="Generating ablated")):
        ablated_img = pipe(prompt, generator=_sample_generator(i)).images[0]
        ablated_img = ablated_img.convert("RGB").resize(output_size, Image.LANCZOS)
        ablated_img.save(f"./pixart_samples/ablated/ablated_{i+1}.jpg")
finally:
    # Un-patch so `pipe` is not left silently ablated for later cells, even
    # if generation fails part-way.
    restore_dit_part(pipe.transformer, ABLATED_BLOCKS, part=ABLATED_PART)

# FID of each generated set against the same resized COCO reference images.
print("Calculating FID scores...")


def _fid_vs_real(generated_dir):
    """Run torch-fidelity FID (only) between the real-image dir and `generated_dir`."""
    return calculate_metrics(
        input1="./pixart_samples/real",
        input2=generated_dir,
        cuda=torch.cuda.is_available(),
        isc=False,
        kid=False,
        fid=True,
        verbose=True,
    )


print("FID (Original vs Real):")
fid_orig = _fid_vs_real("./pixart_samples/fake")

print("FID (Ablated vs Real):")
fid_ablate = _fid_vs_real("./pixart_samples/ablated")

fid_key = 'frechet_inception_distance'
print(f"Original FID: {fid_orig[fid_key]:.4f}")
print(f"Ablated FID: {fid_ablate[fid_key]:.4f}")

# Calculate CLIP scores
def compute_clip_scores(img_dir, captions, prefix="gen"):
    """
    Score each saved image against its caption with CLIP.

    Image ``i`` (0-based) is read from ``{img_dir}/{prefix}_{i+1}.jpg`` and
    compared with the prompt in ``captions[i]``. Uses the module-level
    ``clip_model``, ``preprocess`` and ``device``.

    Args:
        img_dir: Directory containing the numbered images.
        captions: Sequence of ``(anything, prompt)`` pairs. Only the prompt
            (second element) is used.
        prefix: Filename prefix, e.g. "gen" or "ablated".

    Returns:
        A list with one float per caption: the cosine similarity of the
        L2-normalised CLIP image and text embeddings. This is a raw cosine;
        some "CLIP score" conventions rescale it (e.g. x100, clamped at 0).
    """
    scores = []
    for i, (_, prompt) in enumerate(captions):
        # Open inside `with` so the file handle is closed promptly.
        with Image.open(f"{img_dir}/{prefix}_{i+1}.jpg") as img:
            img_tensor = preprocess(img.convert("RGB")).unsqueeze(0).to(device)
        # clip.tokenize raises on text longer than its context length unless
        # truncate=True, which cuts the prompt off instead.
        text = clip.tokenize([prompt], truncate=True).to(device)

        with torch.no_grad():
            image_features = clip_model.encode_image(img_tensor)
            text_features = clip_model.encode_text(text)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        scores.append((image_features @ text_features.T).item())
    return scores

# Score both generated sets against the same prompts, then report the means.
print("Calculating CLIP scores...")
orig_clip_scores = compute_clip_scores("./pixart_samples/fake", chosen, "gen")
ablated_clip_scores = compute_clip_scores("./pixart_samples/ablated", chosen, "ablated")

for label, scores in (("Original", orig_clip_scores), ("Ablated", ablated_clip_scores)):
    mean_score = sum(scores) / len(scores)
    print(f"Mean CLIP ({label}): {mean_score:.4f}")





