In [None]:
# Install Python dependencies. %pip targets the environment of the running
# kernel, which a bare "!pip" is not guaranteed to do.
%pip install -q diffusers torch requests pillow tqdm pycocotools torch-fidelity openai-clip

# COCO val2017 images -> ./val2017
#   wget -nc : skip the download when the zip is already present
#   unzip -n : never overwrite (or prompt about) files that were already extracted
!wget -nc http://images.cocodataset.org/zips/val2017.zip
!unzip -q -n val2017.zip

# COCO 2017 annotations -> ./annotations
# Extract into the working directory, NOT with "-d annotations": the later cells
# extract this same archive to "." and then read ./annotations/captions_val2017.json,
# so adding "-d annotations" would nest the files one level too deep
# (./annotations/annotations/...) and trigger a second download.
!wget -nc http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip -q -n annotations_trainval2017.zip

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import os
import random
import requests
import torch
from PIL import Image
from tqdm import tqdm
from torch_fidelity import calculate_metrics
from diffusers import PixArtAlphaPipeline
import clip
# COCO paths (set your own path if not in current dir)
COCO_ANN_PATH = "./annotations/captions_val2017.json"
COCO_IMG_DIR = "./val2017"  # Folder containing val2017/*.jpg

N = 5  # Number of random samples to generate
SEED = 0  # Seed for choosing the N pairs, so reruns evaluate the same samples

# 1. Download captions file if not exists
if not os.path.exists(COCO_ANN_PATH):
    import zipfile

    url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    zip_path = "annotations_trainval2017.zip"
    print("Downloading COCO annotations...")
    try:
        # raise_for_status() turns an HTTP error into an exception here instead
        # of saving an error page and failing later inside zipfile; the context
        # manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(zip_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(".")
    finally:
        # Remove the archive on success and on failure, so a partial download
        # is never left behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)

# 2. Load COCO captions
import json
with open(COCO_ANN_PATH, "r") as f:
    coco = json.load(f)

# Build (img_path, caption) pairs, keeping only captions whose image file is
# actually present under COCO_IMG_DIR.
id2filename = {img['id']: img['file_name'] for img in coco['images']}
data = []
for ann in coco['annotations']:
    img_path = os.path.join(COCO_IMG_DIR, id2filename[ann['image_id']])
    if os.path.exists(img_path):
        data.append((img_path, ann['caption']))
print(f"Loaded {len(data)} (image, caption) pairs from COCO val2017.")

# 3. Pick N random samples
if len(data) < N:
    raise ValueError(
        f"Need at least {N} (image, caption) pairs but found {len(data)}; "
        f"check that the images exist under {COCO_IMG_DIR!r}."
    )
# A private Random instance makes the draw reproducible without reseeding the
# global `random` module state.
chosen = random.Random(SEED).sample(data, N)

def ablate_dit_part(dit, blocks_to_patch, part='ff', mode='zero'):
    """
    Patch a given part ('attn1', 'attn2', 'ff') of DiT blocks in PixArt-α (diffusers).

    For every block of ``dit.transformer_blocks`` whose index is in
    ``blocks_to_patch``, the ``forward`` of the sub-module named ``part`` is
    replaced in place. Nothing is returned and the patch is not undone here.

    dit: object exposing an iterable ``transformer_blocks``
    blocks_to_patch: collection of block indices to patch
    part: 'attn1' (self-attn), 'attn2' (cross-attn), 'ff' (MLP)
    mode: 'zero' | 'input' | 'mean'
        'zero'  -> zeros shaped like the first positional input
        'input' -> the first positional input, returned unchanged
        'mean'  -> mean of the input over its last dim, expanded back to its shape
        anything else -> the sub-module's own original forward is called
    """
    def make_forward(orig_forward):
        # Built through a factory so each patched sub-module keeps a reference
        # to its OWN original forward. A closure defined directly inside the
        # loop below would look `orig_forward` up at call time and therefore
        # always see the value from the last patched block (late binding).
        def ablated_forward(self, x, *args, **kwargs):
            if mode == 'zero':
                return torch.zeros_like(x)
            elif mode == 'input':
                return x
            elif mode == 'mean':
                return x.mean(dim=-1, keepdim=True).expand_as(x)
            return orig_forward(x, *args, **kwargs)
        return ablated_forward

    for idx, block in enumerate(dit.transformer_blocks):
        if idx not in blocks_to_patch:
            continue
        sub = getattr(block, part)
        # Bind as a method of this instance; sub.forward (read before the
        # assignment) is the still-unpatched bound method.
        sub.forward = make_forward(sub.forward).__get__(sub, sub.__class__)

# -- Setup --
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-512x512",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.to(device)

# Load CLIP
clip_model, preprocess = clip.load("ViT-B/32", device=device)

output_size = (299, 299)  # For FID/CLIP
os.makedirs("./pixart_samples/real", exist_ok=True)
os.makedirs("./pixart_samples/fake", exist_ok=True)
os.makedirs("./pixart_samples/ablated", exist_ok=True)

GEN_SEED = 0  # Base seed for the diffusion sampling noise


def _seeded_generator(sample_idx):
    """Return a torch.Generator on `device`, seeded with GEN_SEED + sample_idx.

    The baseline and the ablated run request a generator for the same
    sample_idx, so both start from the same random state and differ only by
    the ablation patch rather than by the sampled noise.
    """
    return torch.Generator(device=device).manual_seed(GEN_SEED + sample_idx)


# -- Normal Generation (Baseline) --
for i, (img_path, prompt) in enumerate(tqdm(chosen, desc="Generating original")):
    gt_img = Image.open(img_path).convert("RGB").resize(output_size, Image.LANCZOS)
    gt_img.save(f"./pixart_samples/real/gt_{i+1}.jpg")
    gen_img = pipe(prompt, generator=_seeded_generator(i)).images[0]
    gen_img = gen_img.convert("RGB").resize(output_size, Image.LANCZOS)
    gen_img.save(f"./pixart_samples/fake/gen_{i+1}.jpg")

# -- Ablated Generation --
# Patch cross-attention (attn2) in blocks 0-11 with mode='input': the patched
# attn2 returns its input instead of attending to the prompt. The patch is
# applied in place, so `pipe` stays ablated for the rest of the session.
# Other configurations, for reference:
#   ablate_dit_part(pipe.transformer, blocks_to_patch=[5,8], part='attn1', mode='input')
#   ablate_dit_part(pipe.transformer, blocks_to_patch=[11], part='ff', mode='zero')
# NOTE(review): the recorded run printed a warning from the uint8 image cast
# during this loop — check the ablated outputs for non-finite values before
# trusting the metrics computed from them.
ablate_dit_part(pipe.transformer, blocks_to_patch=list(range(12)), part='attn2', mode='input')
print(len(pipe.transformer.transformer_blocks))
for i, (img_path, prompt) in enumerate(tqdm(chosen, desc="Generating ablated")):
    ablated_img = pipe(prompt, generator=_seeded_generator(i)).images[0]
    ablated_img = ablated_img.convert("RGB").resize(output_size, Image.LANCZOS)
    ablated_img.save(f"./pixart_samples/ablated/ablated_{i+1}.jpg")

# -- FID Calculation --
# Both comparisons score a folder of generated images against the same folder
# of real images with identical metric flags; only `input2` differs.
fid_common_args = dict(
    input1="./pixart_samples/real",
    cuda=torch.cuda.is_available(),
    isc=False,
    kid=False,
    fid=True,
    verbose=True,
)

print("FID (Original vs Real):")
fid_orig = calculate_metrics(input2="./pixart_samples/fake", **fid_common_args)

print("FID (Ablated vs Real):")
fid_ablate = calculate_metrics(input2="./pixart_samples/ablated", **fid_common_args)

# Summary of the two distances side by side.
print(f"Original FID: {fid_orig['frechet_inception_distance']}")
print(f"Ablated FID: {fid_ablate['frechet_inception_distance']}")

# -- CLIP Score Calculation --
def compute_clip_scores(img_dir, captions, prefix="gen"):
    """
    Score each saved image against its prompt with CLIP cosine similarity.

    img_dir: folder holding images named ``{prefix}_{k}.jpg``, k starting at 1
    captions: sequence of (img_path, prompt) pairs; only the prompt is used, and
        the k-th pair is scored against ``{prefix}_{k}.jpg``
    prefix: file-name prefix of the images (default "gen" -> gen_1.jpg, ...)

    Returns a list with one Python float per entry of ``captions``.
    Uses the module-level ``clip_model``, ``preprocess`` and ``device``.
    """
    scores = []
    for i, (_, prompt) in enumerate(captions):
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(f"{img_dir}/{prefix}_{i+1}.jpg") as raw:
            img = raw.convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        text = clip.tokenize([prompt]).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(img_tensor)
            text_features = clip_model.encode_text(text)
            # Unit-normalize both embeddings so the dot product is a cosine.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = (image_features @ text_features.T).item()
        scores.append(similarity)
    return scores


def compute_clip_scores_ablated(img_dir, captions):
    """Score images named ``ablated_{k}.jpg``; kept as a thin wrapper so
    existing calls keep working."""
    return compute_clip_scores(img_dir, captions, prefix="ablated")


print("CLIP Score (Original):")
orig_clip_scores = compute_clip_scores("./pixart_samples/fake", chosen)
print("CLIP Score (Ablated):")
ablated_clip_scores = compute_clip_scores_ablated("./pixart_samples/ablated", chosen)

print(f"Mean CLIP (Original): {sum(orig_clip_scores)/len(orig_clip_scores):.4f}")
print(f"Mean CLIP (Ablated): {sum(ablated_clip_scores)/len(ablated_clip_scores):.4f}")


Loaded 25014 (image, caption) pairs from COCO val2017.


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Some weights of the model checkpoint at /root/.cache/huggingface/hub/models--PixArt-alpha--PixArt-XL-2-512x512/snapshots/50f702106901db6d0f8b67eb88e814c56ded2692/transformer were not used when initializing PixArtTransformer2DModel: 
 ['caption_projection.y_embedding']


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Generating original:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating original:  20%|██        | 1/5 [00:01<00:04,  1.06s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating original:  40%|████      | 2/5 [00:02<00:03,  1.05s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating original:  60%|██████    | 3/5 [00:03<00:02,  1.05s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating original:  80%|████████  | 4/5 [00:04<00:01,  1.05s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating original: 100%|██████████| 5/5 [00:05<00:00,  1.05s/it]


28


Generating ablated:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  images = (images * 255).round().astype("uint8")
Generating ablated:  20%|██        | 1/5 [00:00<00:03,  1.09it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating ablated:  40%|████      | 2/5 [00:01<00:02,  1.09it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating ablated:  60%|██████    | 3/5 [00:02<00:01,  1.09it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating ablated:  80%|████████  | 4/5 [00:03<00:00,  1.09it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Generating ablated: 100%|██████████| 5/5 [00:04<00:00,  1.10it/s]
Creating feature extractor "inception-v3-compat" with features ['2048']


FID (Original vs Real):


Extracting statistics from input 1
Looking for samples non-recursivelty in "./pixart_samples/real" with extensions png,jpg,jpeg
Found 5 samples, some are lossy-compressed - this may affect metrics
Processing samples
Extracting statistics from input 2
Looking for samples non-recursivelty in "./pixart_samples/fake" with extensions png,jpg,jpeg
Found 5 samples, some are lossy-compressed - this may affect metrics
Processing samples
Frechet Inception Distance: 308.2989888639742
Creating feature extractor "inception-v3-compat" with features ['2048']


FID (Ablated vs Real):


Extracting statistics from input 1
Looking for samples non-recursivelty in "./pixart_samples/real" with extensions png,jpg,jpeg
Found 5 samples, some are lossy-compressed - this may affect metrics
Processing samples
Extracting statistics from input 2
Looking for samples non-recursivelty in "./pixart_samples/ablated" with extensions png,jpg,jpeg
Found 5 samples, some are lossy-compressed - this may affect metrics
Processing samples
  arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
Frechet Inception Distance: 748.0256444859742


Original FID: 308.2989888639742
Ablated FID: 748.0256444859742
CLIP Score (Original):
CLIP Score (Ablated):
Mean CLIP (Original): 0.3205
Mean CLIP (Ablated): 0.2127


In [None]:
import os
import random
import requests
import torch
from PIL import Image
from tqdm import tqdm
from torch_fidelity import calculate_metrics
from diffusers import PixArtAlphaPipeline
import clip
# Where the COCO val2017 data lives, relative to the working directory.
# Point these elsewhere if the dataset was unpacked in another location.
COCO_IMG_DIR = "./val2017"  # directory holding the val2017 *.jpg images
COCO_ANN_PATH = "./annotations/captions_val2017.json"  # captions annotation file

def patch_dit_layers(dit, layers_to_ablate, mode='zero'):
    """
    Patch the forward method of the DiT's transformer layers in diffusers PixArt-α.

    Only blocks of ``dit.transformer_blocks`` whose index is in
    ``layers_to_ablate`` are touched; their ``forward`` is replaced in place.
    Nothing is returned and the patch is not undone here.

    dit: object exposing an iterable ``transformer_blocks``
    layers_to_ablate: collection of block indices to ablate
    mode: 'zero' | 'input' | 'mean'
        'zero'  -> zeros shaped like the first positional input
        'input' -> the first positional input, returned unchanged
        'mean'  -> mean of the input over dims 1 and 2, expanded back to its shape
        anything else -> the block's own original forward is called
    """
    def make_forward(orig_forward):
        # Built through a factory so each block captures its OWN original
        # forward. A closure defined directly in the loop would read `idx` and
        # `orig_forward` at call time, i.e. after the loop has finished, and
        # every block would then behave like the last one (late binding).
        def ablated_forward(self, x, *args, **kwargs):
            if mode == 'zero':
                return torch.zeros_like(x)
            elif mode == 'input':
                return x
            elif mode == 'mean':
                # Mean over dims 1 and 2 (the batch dim 0 is kept).
                # assumes x is 3-D, e.g. (batch, tokens, channels) — TODO confirm
                return x.mean(dim=[1, 2], keepdim=True).expand_as(x)
            return orig_forward(x, *args, **kwargs)
        return ablated_forward

    for idx, block in enumerate(dit.transformer_blocks):
        if idx not in layers_to_ablate:
            continue
        # block.forward (read before the assignment) is the unpatched bound method.
        block.forward = make_forward(block.forward).__get__(block, block.__class__)
N = 5  # Number of random samples to generate
SEED = 0  # Seed for choosing the N pairs, so reruns evaluate the same samples

# 1. Download captions file if not exists
if not os.path.exists(COCO_ANN_PATH):
    import zipfile

    url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    zip_path = "annotations_trainval2017.zip"
    print("Downloading COCO annotations...")
    try:
        # raise_for_status() turns an HTTP error into an exception here instead
        # of saving an error page and failing later inside zipfile; the context
        # manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(zip_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(".")
    finally:
        # Remove the archive on success and on failure, so a partial download
        # is never left behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)

# 2. Load COCO captions
import json
with open(COCO_ANN_PATH, "r") as f:
    coco = json.load(f)

# Build (img_path, caption) pairs, keeping only captions whose image file is
# actually present under COCO_IMG_DIR.
id2filename = {img['id']: img['file_name'] for img in coco['images']}
data = []
for ann in coco['annotations']:
    img_path = os.path.join(COCO_IMG_DIR, id2filename[ann['image_id']])
    if os.path.exists(img_path):
        data.append((img_path, ann['caption']))
print(f"Loaded {len(data)} (image, caption) pairs from COCO val2017.")

# 3. Pick N random samples
if len(data) < N:
    raise ValueError(
        f"Need at least {N} (image, caption) pairs but found {len(data)}; "
        f"check that the images exist under {COCO_IMG_DIR!r}."
    )
# A private Random instance makes the draw reproducible without reseeding the
# global `random` module state.
chosen = random.Random(SEED).sample(data, N)

# 4. Load PixArt-α
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-512x512",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.to(device)

# 5. Generate images
output_size = (299, 299)  # for FID

sample_dirs = ["./pixart_samples/real", "./pixart_samples/fake"]
for sample_dir in sample_dirs:
    os.makedirs(sample_dir, exist_ok=True)

for i, (img_path, prompt) in enumerate(tqdm(chosen, desc="Generating")):
    # Ground truth: the COCO photo, resized to the common evaluation size.
    gt_img = Image.open(img_path).convert("RGB").resize(output_size, Image.LANCZOS)
    gt_img.save(f"./pixart_samples/real/gt_{i+1}.jpg")

    # Generated: the pipeline's first image for the caption, resized the same way.
    gen_img = pipe(prompt).images[0].convert("RGB").resize(output_size, Image.LANCZOS)
    gen_img.save(f"./pixart_samples/fake/gen_{i+1}.jpg")

    print(f"Saved: gt_{i+1}.jpg (ground truth), gen_{i+1}.jpg (generated)")

# Drop any empty files from both folders before they are handed to the metric.
for folder in sample_dirs:
    candidates = (os.path.join(folder, fname) for fname in os.listdir(folder))
    for fpath in [path for path in candidates if os.path.getsize(path) == 0]:
        print(f"Removing zero-byte file: {fpath}")
        os.remove(fpath)

# -- FID Calculation --
# Compare the generated images against the ground-truth images. Only FID is
# requested; Inception Score and KID are switched off.
metrics = calculate_metrics(
    input1="./pixart_samples/real",
    input2="./pixart_samples/fake",
    cuda=torch.cuda.is_available(),
    isc=False,
    kid=False,
    fid=True,
    verbose=True,
)
print(f"FID between generated and ground-truth: {metrics['frechet_inception_distance']}")





In [None]:
import clip

# Load CLIP
model, preprocess = clip.load("ViT-B/32", device=device)


def _clip_cosine(image_path, prompt):
    """Cosine similarity between the CLIP embeddings of an image file and a prompt."""
    pixels = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
    tokens = clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        image_features = model.encode_image(pixels)
        text_features = model.encode_text(tokens)
        # Unit-normalize both embeddings so their dot product is a cosine.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return (image_features @ text_features.T).item()


# Score every generated image against the caption it was generated from.
clip_scores = []
for i, (img_path, prompt) in enumerate(chosen):
    similarity = _clip_cosine(f"./pixart_samples/fake/gen_{i+1}.jpg", prompt)
    print(f"Sample {i+1}: CLIP score = {similarity:.4f} | Caption: {prompt}")
    clip_scores.append(similarity)

print(f"\nMean CLIP score over {len(clip_scores)} samples: {sum(clip_scores)/len(clip_scores):.4f}")

100%|███████████████████████████████████████| 338M/338M [00:09<00:00, 35.9MiB/s]


Sample 1: CLIP score = 0.2825 | Caption: A plate of some food on a table.
Sample 2: CLIP score = 0.3010 | Caption: A plate of donuts with a person in the background.
Sample 3: CLIP score = 0.3057 | Caption: A picture of a small-sized kitchen with wood cabinets.
Sample 4: CLIP score = 0.3071 | Caption: a couple of boats sitting on top of a body of water.
Sample 5: CLIP score = 0.3213 | Caption: A very small boy on the beach with a disc.

Mean CLIP score over 5 samples: 0.3035
