In [None]:
!pip install transformers diffusers lpips accelerate
!pip install torch torchvision
!pip install diffusers[torch]
!pip install tqdm pillow matplotlib ipython numpy

In [None]:
import os
import gc
import torch
from torch import autocast
from torchvision import transforms as tfms
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from diffusers.models.attention import BasicTransformerBlock
import torch.nn.utils.prune as prune
from accelerate import cpu_offload_with_hook
from lpips import LPIPS
from PIL import Image
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
from transformers.utils import logging
logging.set_verbosity_info()


In [None]:
# Enable automatic mixed precision (AMP) exactly when CUDA is available.
# `use_amp` is read later when the autocast context for the denoising loop is chosen.
use_amp = torch.cuda.is_available()
if use_amp:
    print("AMP enabled.")

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# (transformers logging verbosity is already configured in the imports cell.)
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {torch_device}")

In [None]:
# Hub repositories the individual pipeline components are loaded from.
SD_REPO = "CompVis/stable-diffusion-v1-4"
CLIP_REPO = "openai/clip-vit-large-patch14"

# NOTE(review): `use_auth_token` may be deprecated in newer diffusers releases in
# favour of `token` — confirm against the installed version.
vae = AutoencoderKL.from_pretrained(
    SD_REPO,
    subfolder="vae",
    use_auth_token=True,
)
tokenizer = CLIPTokenizer.from_pretrained(CLIP_REPO)
text_encoder = CLIPTextModel.from_pretrained(CLIP_REPO)
unet = UNet2DConditionModel.from_pretrained(
    SD_REPO,
    subfolder="unet",
    use_auth_token=True,
)

# Noise scheduler used for sampling, configured with a scaled-linear beta schedule.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)


In [None]:
# NOTE(review): gradient checkpointing trades compute for memory during a backward pass;
# no backward pass is visible in this notebook, so this call may be unnecessary — confirm.
unet.enable_gradient_checkpointing()

In [None]:
# Move all three networks onto the selected device in one pass.
vae, text_encoder, unet = (
    network.to(torch_device) for network in (vae, text_encoder, unet)
)

In [None]:
# On a CPU-only run, attach accelerate's offload hook to the VAE with the CPU as
# its execution device.
# NOTE(review): the return value is discarded and the execution device is already
# the CPU, so this looks like it has no practical effect — confirm it is still needed.
if torch_device == "cpu":
    cpu_offload_with_hook(vae, execution_device=torch.device("cpu"))

In [None]:
def prune_unet_attention(unet_model, amount=0.2, block_types=None):
    """Randomly prune the parameters inside a model's transformer blocks, in place.

    Every parameter owned by a module inside a matching block is re-parametrised
    with ``torch.nn.utils.prune.random_unstructured`` so that ``amount`` of its
    entries are masked to zero.

    Pruning is applied on the sub-module that directly owns each parameter.
    Passing a dotted name such as ``"attn1.to_q.weight"`` together with the
    enclosing block does not work: the pruning utility resolves the name with a
    plain attribute lookup on the given module, which fails for nested names.
    Errors are no longer swallowed, so a failed pruning call is visible.

    Parameters that are already pruned (``<name>_orig`` with a matching
    ``<name>_mask`` buffer) are skipped, so re-running the cell does not stack
    additional pruning on top of the first pass.

    Args:
        unet_model: Module that is searched for transformer blocks.
        amount: Fraction (float) or absolute count (int) of entries to prune in
            each parameter; forwarded unchanged to ``random_unstructured``.
        block_types: Module class, or tuple of classes, whose parameters are
            pruned. Defaults to ``BasicTransformerBlock``.

    Returns:
        int: Number of parameter tensors that were pruned by this call.
    """
    if block_types is None:
        block_types = (BasicTransformerBlock,)

    # Collect (owner, parameter-name) pairs before pruning anything: pruning
    # rewrites the owner's parameter dict, and a module reachable through
    # nested matching blocks must only be handled once.
    targets = []
    visited = set()
    for block in unet_model.modules():
        if not isinstance(block, block_types):
            continue
        for owner in block.modules():
            if id(owner) in visited:
                continue
            visited.add(id(owner))
            for param_name, _ in owner.named_parameters(recurse=False):
                already_pruned = (
                    param_name.endswith("_orig")
                    and hasattr(owner, param_name[: -len("_orig")] + "_mask")
                )
                if not already_pruned:
                    targets.append((owner, param_name))

    for owner, param_name in targets:
        prune.random_unstructured(owner, name=param_name, amount=amount)
    return len(targets)

# Run the pruning pass on the loaded UNet; any pruning is applied to `unet` in place.
prune_unet_attention(unet)

# Optional attention map hook: captured maps accumulate in this list.
attention_maps = []


def save_attention_hook(module, input, output):
    """Forward hook that stores ``output.attn_probs`` (detached, on the CPU).

    Outputs that do not expose an ``attn_probs`` attribute are ignored.
    NOTE(review): whether the hooked module's output actually carries
    ``attn_probs`` is not visible here — verify that anything is captured.
    """
    if not hasattr(output, 'attn_probs'):
        return
    probs_on_cpu = output.attn_probs.detach().cpu()
    attention_maps.append(probs_on_cpu)

# Attach the hook to the first transformer block only, in named_modules() order.
first_block = next(
    (
        (layer_name, layer)
        for layer_name, layer in unet.named_modules()
        if isinstance(layer, BasicTransformerBlock)
    ),
    None,
)
if first_block is not None:
    layer_name, layer = first_block
    layer.register_forward_hook(save_attention_hook)
    print(f"Hooked attention at layer: {layer_name}")


In [None]:
# Text prompt(s) to render; adjacent string literals are concatenated into one prompt.
prompt = [
    "A post-apocalyptic cityscape with crumbling, dilapidated skyscrapers overtaken by nature. "
    "Vines and massive trees grow through shattered windows and collapsed roofs. The streets are cracked, "
    "filled with roots and overgrowth. No signs of humans, just nature reclaiming the ruins. Moody lighting, "
    "overcast skies, high detail, ultra-realistic, cinematic, 4K, concept art style."
]
height = 512
width = 768
num_inference_steps = 50
guidance_scale = 7.5
# Seed the global RNG; the returned generator is used to draw the initial latents.
generator = torch.manual_seed(4)
# Derive the batch size from the prompt list so the unconditional embeddings and the
# initial latents always match the number of prompts.
batch_size = len(prompt)

# High-resolution mode: use 1.5x the sampling steps for images beyond 512x768.
if height > 512 or width > 768:
    print("High-res mode activated.")
    num_inference_steps = int(num_inference_steps * 1.5)

# Tokenize the prompt, padded/truncated to the tokenizer's maximum length.
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

# Empty prompts of the same token length supply the unconditional branch.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)

# Encode both token batches without tracking gradients.
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

# Unconditional embeddings first, conditional second; the denoising loop splits
# the noise prediction in the same order.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

In [None]:
# Configure the sampler for the requested number of denoising steps.
scheduler.set_timesteps(num_inference_steps)

# Draw the starting noise with the seeded generator, then move it to the target device.
# The channel count is read from `unet.config`; accessing `unet.in_channels`
# directly is deprecated in diffusers.
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
# Scale the noise by the scheduler's first sigma.
latents = latents * scheduler.sigmas[0]

In [None]:
# Autocast on CUDA when AMP is enabled; otherwise a plain no-grad context.
autocast_context = autocast("cuda") if use_amp else torch.no_grad()

# Inference only: keep autograd disabled for the whole loop (UNet call and scheduler math).
with torch.no_grad(), autocast_context:
    for t in tqdm(scheduler.timesteps):
        # Duplicate the latents so the unconditional and text-conditioned
        # predictions are produced in a single UNet forward pass.
        latent_model_input = torch.cat([latents] * 2)
        # Let the scheduler apply its own per-timestep input scaling instead of
        # looking up sigma by hand; the scheduler expects this call before step().
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Classifier-free guidance: push the prediction away from the
        # unconditional branch towards the text-conditioned one.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Advance the latents by one scheduler step.
        latents = scheduler.step(noise_pred, t, latents).prev_sample

In [None]:
 tile_size = 32


 def tile_latents(latents, size=None):
     """Split a latent batch into non-overlapping square spatial tiles.

     Args:
         latents: 4-D tensor indexed as (batch, channels, height, width).
         size: Tile edge length in latent pixels. Defaults to the module-level
             ``tile_size``, read at call time.

     Returns:
         list: Tiles in row-major order (top to bottom, then left to right);
         each tile is the slice ``latents[:, :, i:i+size, j:j+size]``.

     Raises:
         ValueError: If the latent height or width is not a multiple of the
             tile size.
     """
     size = tile_size if size is None else size
     _, _, h, w = latents.shape
     if h % size != 0 or w % size != 0:
         raise ValueError(f"Latent size ({h}, {w}) not divisible by tile size {size}")
     return [
         latents[:, :, i:i + size, j:j + size]
         for i in range(0, h, size)
         for j in range(0, w, size)
     ]

# Split the final latents into tiles and report how many were produced.
# NOTE(review): `tiles` is not read by any later cell visible here — confirm
# whether tiled decoding is still planned.
tiles = tile_latents(latents)
print(f"Tiled into {len(tiles)} latent chunks with tile size {tile_size}")

In [None]:
# Rescale the latents before decoding. The result is stored under a new name so
# that re-running this cell does not rescale `latents` a second time.
# NOTE(review): 0.18215 is presumably the VAE latent scaling factor — confirm
# against the loaded VAE's config.
scaled_latents = 1 / 0.18215 * latents
with torch.no_grad():
    decoded_output = vae.decode(scaled_latents)

image_tensor = decoded_output.sample  # may need to update depending on your model
# Map the decoder output from [-1, 1] to [0, 1], clamping anything outside that range.
image = (image_tensor / 2 + 0.5).clamp(0, 1)
# Channels-first -> channels-last NumPy array on the CPU, as expected by PIL.
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(img) for img in images]

In [None]:
# Perceptual-similarity model (VGG backbone), created once and reused by calc_lpips.
lpips_fn = LPIPS(net='vgg')


def calc_lpips(image1, image2):
    """Return the LPIPS perceptual distance between two images as a float.

    Both images are converted with ``ToTensor`` and given a leading batch
    dimension. ``ToTensor`` produces values in [0, 1] for 8-bit PIL images,
    whereas LPIPS expects inputs in [-1, 1]; ``normalize=True`` tells LPIPS to
    perform that rescaling itself.

    Args:
        image1: First image, in a form accepted by ``torchvision`` ``ToTensor``.
        image2: Second image, same size as ``image1``.

    Returns:
        float: The LPIPS distance (lower means more similar).
    """
    tensor1 = tfms.ToTensor()(image1).unsqueeze(0)
    tensor2 = tfms.ToTensor()(image2).unsqueeze(0)
    # Pure evaluation: no autograd graph is needed for the metric.
    with torch.no_grad():
        return lpips_fn(tensor1, tensor2, normalize=True).item()

# Sanity check: compare the generated image with itself; the score should be ~0.
lpips_score = calc_lpips(pil_images[0], pil_images[0])
print("LPIPS score (identity):", lpips_score)

In [None]:
def print_memory_stats():
    """Print allocated and reserved CUDA memory in GB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"\n[Memory] Allocated: {allocated_gb:.2f} GB")
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[Memory] Reserved : {reserved_gb:.2f} GB")

# Report GPU memory usage, then run the garbage collector and release
# PyTorch's cached CUDA memory. The stats are printed before the cleanup.
print_memory_stats()
gc.collect()
torch.cuda.empty_cache()

# Show Image (explicit figure/axes interface)
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(pil_images[0])
ax.axis("off")
ax.set_title("Generated Image")
plt.show()