In [None]:
# Install dependencies.
!rm -r sample_data
!pip install -qq --upgrade transformers accelerate git+https://github.com/TimothyAlexisVass/diffusers.git

In [None]:
# Set the details for your model here:
import torch

from diffusers import StableDiffusionXLPipeline, AutoencoderKL

# Load the VAE on its own, in half precision.
# NOTE(review): the repo name suggests an fp16-safe SDXL VAE - confirm against its model card.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
# Build the SDXL base pipeline in fp16 from safetensors weights, substituting the VAE loaded above.
base = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Move the pipeline to the GPU; the result is bound to `_` so it is not kept under a meaningful name.
_ = base.to("cuda")

# Expose the pipeline components as notebook-level names for later cells.
# The trailing cpu/cuda notes are the author's record of where each component lives after the move above.
tokenizer = base.tokenizer            # cpu
tokenizer_2 = base.tokenizer_2        # cpu
scheduler = base.scheduler            # cpu
text_encoder = base.text_encoder      # cuda
text_encoder_2 = base.text_encoder_2  # cuda
unet = base.unet                      # cuda

# Ask PyTorch's caching allocator to release unused cached GPU memory.
torch.cuda.empty_cache()

In [None]:
print(unet.parameters().__next__().device)
print(scheduler.sigmas)
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
import os
import random
import torch
import diffusers
import numpy as np

# Per-channel spatial filters (Gaussian blur based, or Laplacian edge based) for latent tensors.
def _gaussian_kernel(kernel_size, sigma, device):
    """Build a (kernel_size, kernel_size) float32 Gaussian kernel on `device`, normalized to sum to 1."""
    kernel = np.fromfunction(
        lambda x, y: (1 / (2 * np.pi * sigma ** 2)) * np.exp(-((x - kernel_size // 2) ** 2 + (y - kernel_size // 2) ** 2) / (2 * sigma ** 2)),
        (kernel_size, kernel_size)
    )
    kernel = torch.as_tensor(kernel, dtype=torch.float32, device=device)
    return kernel / kernel.sum()


def latents_filter(latent_tensor, timestep, filter_type, kernel_size, sigma, filter_strength):
    """Apply a timestep-weighted spatial filter to every channel of a latent tensor.

    Each channel is filtered independently with the same single-channel kernel
    (implemented as one depthwise convolution). All work happens on the device
    and in the dtype of `latent_tensor`.

    Args:
        latent_tensor: 4-D tensor; dimension 1 is treated as the channel axis and
            the last two dimensions are convolved.
        timestep: Number or 0-d tensor. The effect is scaled by
            ``1 - timestep * 0.001``, i.e. it vanishes at timestep 1000 and is
            strongest at timestep 0.
        filter_type: One of
            ``'sharpen'`` - ``x + s * (x - blur(x))`` with ``s = 0.00025 * factor``;
            ``'amplify'`` - ``x + s * blur(x)`` with ``s = 0.002 * factor``
            (positive ``s`` is additionally divided by 3);
            ``'enhance'`` - ``x - s * laplacian(x)`` with ``s = 0.00003 * factor``.
            Any other value returns the plain Gaussian blur of the input.
        kernel_size: Side length of the Gaussian kernel. Only an odd size keeps the
            blurred output the same height/width as the input, which 'sharpen' and
            'amplify' require. Unused for 'enhance'.
        sigma: Standard deviation of the Gaussian. Unused for 'enhance'.
        filter_strength: User-facing strength; negative values invert the effect.

    Returns:
        Tensor on the same device and with the same dtype as `latent_tensor`.

    Raises:
        ValueError: If `latent_tensor` is not 4-dimensional.
    """
    if latent_tensor.dim() != 4:
        raise ValueError(f"latent_tensor must be 4-D, got {latent_tensor.dim()} dimensions")

    num_channels = latent_tensor.shape[1]
    device = latent_tensor.device
    dtype = latent_tensor.dtype
    conv2d = torch.nn.functional.conv2d

    # Combine the user strength with the timestep-based fade.
    timestep_factor = filter_strength * (1.0 - (timestep * 0.001))

    if filter_type == 'enhance':
        # Edge enhancement with a Laplacian kernel; the Gaussian blur is not needed here.
        strength = 0.00003 * timestep_factor
        laplacian_kernel = torch.tensor(
            [[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=dtype, device=device
        ).view(1, 1, 3, 3).repeat(num_channels, 1, 1, 1)
        edges = conv2d(latent_tensor, laplacian_kernel, padding=1, groups=num_channels)
        return latent_tensor - strength * edges

    # Build the Gaussian kernel once (normalized in float32, then cast to the latent dtype)
    # and replicate it so groups=num_channels filters each channel on its own.
    kernel = _gaussian_kernel(kernel_size, sigma, device).to(dtype)
    kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(num_channels, 1, 1, 1)
    blurred = conv2d(latent_tensor, kernel, padding=kernel_size // 2, groups=num_channels)

    if filter_type == 'sharpen':
        # Unsharp mask: add back the difference between the original and its blur.
        strength = 0.00025 * timestep_factor
        return latent_tensor + strength * (latent_tensor - blurred)

    if filter_type == 'amplify':
        # Dampen/amplify: add a scaled copy of the blur; positive strengths are softened by 3.
        strength = 0.002 * timestep_factor
        if strength > 0:
            strength = strength / 3
        return latent_tensor + strength * blurred

    # Unknown filter types fall through to the plain blur.
    return blurred


# Text conditioning for the generation; the literals are split across lines for readability only.
prompt = (
    "Starwars jedi robot in indigo robe with golden details, "
    "looking happy with red lightsaber in hand, "
    "daylight lush nature park background. "
    "Intricate details even to the smallest particle, "
    "extreme detail of the enviroment, sharp portrait, well lit, "
    "interesting outfit, beautiful shadows, bright, photoquality, "
    "ultra realistic, masterpiece, 8k"
)
negative_prompt = (
    "bokeh, blur, ugly, old, boring, photoshopped, tired, wrinkles, scar, "
    "gray hair, big forehead, crosseyed, dumb, stupid, cockeyed, disfigured, "
    "blurry, assymetrical, unrealistic, grayscale, bad anatomy, "
    "unnatural irises, no pupils, blurry eyes, dark eyes, extra limbs, "
    "deformed, disfigured eyes, out of frame, no irises, assymetrical face, "
    "broken fingers, extra fingers, disfigured hands"
)
# Number of denoising steps requested from the pipeline.
num_inference_steps = 20

def callback(index, timestep, latents):
  """Per-step pipeline hook: run the 'enhance' filter on the current latents.

  `index` is accepted but not used. The filtered latents are handed back in a
  dict under the "latents" key.
  NOTE(review): this relies on the installed diffusers fork reading that dict -
  confirm against the fork's pipeline code.
  """
  filtered = latents_filter(
    latents,
    timestep,
    'enhance',
    kernel_size=5,
    sigma=1.0,
    filter_strength=-100.0,
  )
  return {"latents": filtered}

# Keyword arguments unpacked into the pipeline call further down.
# output_type="latent" requests latents back so they can be decoded manually.
parameters = dict(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=num_inference_steps,
    output_type="latent",
    num_images_per_prompt=1,
    guidance_scale=8,
    callback=callback,
)

# Image size in pixels; only referenced by the commented-out latent preparation below.
height = width = 1024
# latents = torch.randn((parameters["num_images_per_prompt"], base.unet.config.in_channels, height // 8, width // 8), generator=parameters["generator"]).to("cuda") * base.scheduler.init_noise_sigma

# latents = base.prepare_latents(parameters["num_images_per_prompt"], unet.config.in_channels, height, width, None, base.device, parameters["generator"])

def normalize_tensor(tensor):
  """Min-max rescale `tensor` so its smallest value maps to 0 and its largest to 1.

  The minimum and maximum are taken over the whole tensor, not per image or
  per channel. A constant tensor (max == min) has no range to stretch and is
  returned as all zeros instead of the NaNs a 0/0 division would produce.
  """
  min_val = torch.min(tensor)
  value_range = torch.max(tensor) - min_val
  if value_range == 0:
    # Dividing by 1 keeps the same dtype promotion as the regular path.
    value_range = torch.ones_like(value_range)
  return (tensor - min_val) / value_range

@torch.no_grad()
def decode_latents(latents, saturation=50, contrast=50, brightness=50, normalize=False):
    """Decode latents with the notebook-level `vae` and convert them via `base.numpy_to_pil`.

    Args:
        latents: Latent tensor passed to `vae.decode` after scaling.
        saturation: Folded into the latent multiplier, ``4.444 + saturation / 16``
            (7.569 at the default of 50).
        contrast: Percentage; the decoded samples are multiplied by ``contrast / 100``.
        brightness: Percentage; ``brightness / 100`` is added after the contrast step.
        normalize: When true, contrast/brightness are ignored and the decoded
            samples are min-max rescaled with `normalize_tensor` instead.

    Returns:
        Whatever `base.numpy_to_pil` returns for the channels-last array
        (the caller indexes it, so presumably a list of PIL images).
    """
    latent_scale = 4.444 + saturation / 16
    decoded = vae.decode(latents * latent_scale).sample

    if not normalize:
        # With the defaults this is x * 0.5 + 0.5, clipped to [0, 1].
        decoded = torch.clamp(decoded * (contrast / 100) + brightness / 100, 0, 1)
    else:
        decoded = normalize_tensor(decoded)

    # Reorder (N, C, H, W) -> (N, H, W, C) before leaving torch.
    channels_last = decoded.permute(0, 2, 3, 1)
    return base.numpy_to_pil(channels_last.cpu().numpy())

# Run the pipeline. torch.manual_seed(2222) seeds PyTorch's default generator and returns it,
# so the same generator object is passed in on every run of this cell.
# Because `parameters` sets output_type="latent", `.images` is expected to hold latents here
# rather than decoded pictures - NOTE(review): confirm for the installed diffusers fork.
wow = base(**parameters, generator=torch.manual_seed(2222)).images
# Decode with the default saturation/contrast/brightness and keep the first result.
image = decode_latents(wow, normalize=False)[0]
display(image)


In [None]:
!zip -r png_files.zip *.png
!rm -r *.png