In [None]:
import numpy as np
from base64 import b64encode
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, PNDMScheduler, DDPMScheduler
from huggingface_hub import notebook_login
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import logging
import os
from huggingface_hub import HfFolder
import random
import json
import torch.nn.functional as F
import re
import time
from IPython.display import HTML

# Seed every RNG the notebook draws from so that "Restart & Run All" is
# reproducible. Later cells pick files with `random.choice`, so the stdlib
# `random` module has to be seeded as well as torch and numpy.
fixed_seed = 1
torch.manual_seed(fixed_seed)
np.random.seed(fixed_seed)
random.seed(fixed_seed)

# Authenticate against the Hugging Face Hub: use the token from the
# HUGGING_FACE_TOKEN environment variable when it is set, otherwise fall back
# to the interactive login prompt.
token = os.getenv("HUGGING_FACE_TOKEN")
if token:
    HfFolder.save_token(token)
else:
    # `notebook_login` is already imported with the other imports above.
    notebook_login()

# Set device: prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    torch_device = "cuda:0"
elif torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"

if torch_device == "mps":
    # Ask PyTorch to fall back to another backend for ops MPS does not implement.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

print(f"Using device: {torch_device}")

In [None]:
# Floating-point precision that the VAE, the UNet and the conditioning tensors are cast to in later cells.
dtype = torch.float32

In [None]:
# Repository that provides the pretrained VAE weights and the scheduler config.
model_path = "runwayml/stable-diffusion-v1-5"

# Checkpoint directory holding the fine-tuned UNet (EMA weights).
# unet_ema_path = "../data/300000-step-model-OCT/checkpoint-116000/" # TODO change back to this
unet_ema_path = "/store/CIA/scfc3/diffusion/bloodimage/saved_models/500000-step-model-pumas-VAE/checkpoint-21000"

# Load each network and move it straight to the target device / precision.
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(
    torch_device, dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(
    unet_ema_path, subfolder="unet_ema").to(torch_device, dtype=dtype)

scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

In [None]:
def convert_dim(caption, target_shape=(77, 768)):
    """Tile a 1-D conditioning vector into a zero-padded 2-D matrix.

    Each of the ``target_shape[0]`` rows receives a copy of ``caption`` in its
    leading columns; all remaining columns stay zero.

    Args:
        caption: 1-D sequence of numbers (list, tuple or array).
        target_shape: ``(rows, cols)`` of the returned matrix.

    Returns:
        A new ``numpy.ndarray`` of shape ``target_shape`` with the default
        ``np.zeros`` dtype (float64).

    Raises:
        ValueError: if ``caption`` is not one-dimensional.
        AssertionError: if ``caption`` has more than ``target_shape[1]`` entries.
    """
    caption = np.asarray(caption)
    if caption.ndim != 1:
        raise ValueError(
            f"caption must be one-dimensional, got shape {caption.shape}")

    caption_len = len(caption)
    assert target_shape[1] >= caption_len, (
        f"caption has {caption_len} entries but only {target_shape[1]} columns are available")

    final_array = np.zeros(target_shape)
    # Broadcasting writes the same vector into every row in one assignment.
    final_array[:, :caption_len] = caption
    return final_array


# Latent-feature indices exported with 1-based (Julia) numbering; the trailing
# "- 1" converts them to 0-based Python indices in the same expression.
selected_features = np.array([
    92, 111, 50, 3, 91, 67, 37, 8, 90, 120, 54, 56, 21, 61, 75, 29,
    80, 12, 95, 118, 73, 94, 101, 20, 48, 99, 104, 13, 59, 52, 106, 79,
    4, 86, 93, 85, 72, 32, 87, 35, 47, 113, 40, 53, 36, 55, 122, 22,
    5, 2, 88, 77, 26, 15, 7, 108, 58, 28, 39, 128, 126, 25, 103, 65,
    105, 34, 18, 69, 27, 43, 64, 123, 38, 78, 17, 121, 42, 49, 33, 66,
    57, 6, 24, 112, 10, 115, 68, 45, 11, 51, 41, 97, 70, 102, 114, 89,
    71, 44, 110, 109, 62, 31, 124, 16, 1, 74, 9, 119, 14, 83, 117, 76,
    60, 46, 23, 84, 98, 82, 100, 107, 81, 125, 127, 30, 19, 96, 63, 116,
]) - 1

In [None]:
# --- Pick one random conditioning vector from disk --------------------------
prompt_path = "../data/generate_images/generate_json_latent"
files = os.listdir(prompt_path)
json_files = [name for name in files if name.endswith('.json')]
random_json_file = random.choice(json_files)
full_path = os.path.join(prompt_path, random_json_file)
with open(full_path, 'r') as f:
    prompt = json.load(f)

# --- Sampling configuration -------------------------------------------------
height, width = 512, 512
num_inference_steps = 50
batch_size = 1
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` to sample with ``num_inference_steps`` steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# --- Initial latents: Gaussian noise scaled by the scheduler's init sigma ---
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents = latents * scheduler.init_noise_sigma

# --- Denoising loop ---------------------------------------------------------
# autocast is requested for "cuda" regardless of the selected torch_device.
with autocast("cuda"):
    # Tile the vector into the conditioning matrix and add a batch axis.
    real_input = prompt
    real_input_mat = torch.tensor(convert_dim(real_input)).to(
        torch_device, dtype=dtype)
    input_mat = real_input_mat.unsqueeze(0)

    for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
        latent_model_input = scheduler.scale_model_input(latents, t)

        # Predict the noise residual (no gradients needed for inference).
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=input_mat
            ).sample

        latents = scheduler.step(noise_pred, t, latents).prev_sample

# --- Decode: undo the VAE scaling factor, then run the decoder --------------
latents = 1 / vae.config.scaling_factor * latents
with torch.no_grad():
    latents = latents.to(dtype=dtype)
    image = vae.decode(latents).sample

# --- Display: map [-1, 1] -> [0, 255] uint8, channels last, grayscale -------
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(frame).convert("L") for frame in images]
image = pil_images[0].convert("L")
display(image)

In [None]:
# --- Load the table of conditioning vectors ---------------------------------
json_file_path = "../saved_eta_and_lv_data/json/lvs_matrix_100k.json"
with open(json_file_path, 'r') as file:
    prompts_data = json.load(file)

# --- Sampling configuration -------------------------------------------------
height, width = 512, 512
num_inference_steps = 50
batch_size = 15
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)

# The file is indexed by stringified integers; take the first `batch_size`.
selected_prompts = [prompts_data[str(i)] for i in range(batch_size)]
prompts = []
for prompt_vector in selected_prompts:
    tiled = torch.tensor(convert_dim(prompt_vector))
    prompts.append(tiled.to(torch_device, dtype=dtype))

# One conditioning matrix per sample, stacked along a new leading batch axis.
input_mat = torch.stack(prompts, dim=0)


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` to sample with ``num_inference_steps`` steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# --- Initial latents --------------------------------------------------------
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents = latents * scheduler.init_noise_sigma

# --- Denoising loop (no classifier-free guidance) ---------------------------
# autocast is requested for "cuda" regardless of the selected torch_device.
with autocast("cuda"):
    for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
        latent_model_input = scheduler.scale_model_input(latents, t)

        # Predict the noise residual.
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=input_mat
            ).sample

        latents = scheduler.step(noise_pred, t, latents).prev_sample

# --- Decode -----------------------------------------------------------------
latents = 1 / vae.config.scaling_factor * latents
with torch.no_grad():
    latents = latents.to(dtype=dtype)
    image = vae.decode(latents).sample

# --- Display every sample of the batch in grayscale -------------------------
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(frame) for frame in images]
image = pil_images[0]
for i in range(batch_size):
    display(pil_images[i].convert("L"))

In [None]:
# Build four side-by-side comparisons: a randomly chosen real test image on the
# left and an image synthesised from its stored latent vector on the right.
# Each comparison is saved as the next free "<n>.jpeg" inside `save_dir`.

# Loop-invariant configuration, defined once instead of on every iteration.
image_dir = "../data/data_resized/bm3d_496_512_test"
json_dir = "../data/generate_images/generate_json_latent_test"
save_dir = "/store/CIA/scfc3/diffusion/TemporalRetinaVAE/data/generate_images/diffusion_compare_rel_synth"
height = 512
width = 512
num_inference_steps = 50
batch_size = 1


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` to sample with ``num_inference_steps`` steps."""
    scheduler.set_timesteps(num_inference_steps)


# `exist_ok=True` avoids the check-then-create race of exists() + makedirs().
os.makedirs(save_dir, exist_ok=True)

# List all JPEG files in the directory
jpeg_files = [f for f in os.listdir(image_dir) if f.endswith('.jpeg')]

for _ in range(4):
    # Randomly select one JPEG and locate the JSON file with the same stem.
    selected_jpeg = random.choice(jpeg_files)
    json_filename = selected_jpeg.replace('.jpeg', '.json')
    json_path = os.path.join(json_dir, json_filename)

    # Read the vector from the JSON file
    with open(json_path, 'r') as file:
        prompt = json.load(file)  # Assuming the JSON structure directly contains the vector

    # Load the original image and resize it to the generated image's size.
    original_image_path = os.path.join(image_dir, selected_jpeg)
    with Image.open(original_image_path) as opened_image:
        original_image = opened_image.resize((width, height))

    seed_nr = np.random.randint(10000)
    print(f"seed_nr {seed_nr}")
    generator = torch.manual_seed(seed_nr)

    # Re-initialise the scheduler for every sample.
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    latents = (latents * scheduler.init_noise_sigma)

    # autocast is requested for "cuda" regardless of the selected torch_device.
    with autocast("cuda"):
        real_input = prompt
        real_input_mat = convert_dim(real_input)
        real_input_mat = torch.tensor(real_input_mat)
        real_input_mat = real_input_mat.to(torch_device, dtype=dtype)
        input_mat = real_input_mat.unsqueeze(0)

        # Iterate the timesteps directly so the outer loop's `_` is not shadowed.
        for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
            latent_model_input = scheduler.scale_model_input(latents, t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_mat).sample

            # TODO: no classifier-free guidance is applied here.
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # scale and decode the image latents with vae
    latents = 1 / vae.config.scaling_factor * latents
    with torch.no_grad():
        latents = latents.to(dtype=dtype)
        image = vae.decode(latents).sample

    # Map [-1, 1] -> uint8, channels last, then to a grayscale PIL image.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    image = pil_images[0]
    image = image.convert("L")

    # Original on the left, generated image on the right.
    concatenated = Image.new('L', (2 * width, height))
    concatenated.paste(original_image, (0, 0))
    concatenated.paste(image, (width, 0))
    # display(concatenated)

    # Next free number: highest existing "<digits>.jpeg" plus one. The listing
    # is refreshed each iteration because the previous one added a file.
    existing_numbers = [
        int(f.split('.')[0])
        for f in os.listdir(save_dir)
        if f.endswith('.jpeg') and f.split('.')[0].isdigit()
    ]
    highest_num = max(existing_numbers, default=0)

    # File name for the new concatenated image
    new_file_name = f"{highest_num + 1}.jpeg"
    save_path = os.path.join(save_dir, new_file_name)

    # Save the concatenated image
    concatenated.save(save_path)
    print(f"Concatenated image saved as: {save_path}")

In [None]:
# Interpolate between two prompts
prompt_1 = [-0.010155955,-0.004592131,-0.38701367,-0.022426253,-0.00024568333,0.03570405,-0.025474727,-0.003766098,0.00015272689,-0.009874655,0.008434392,-0.012125211,-0.0005047297,0.03926266,-0.005911234,0.012074554,0.025061749,0.016118947,-0.025295403,0.0043429234,-0.019303171,-0.008965155,0.0048098783,-0.022155657,-0.018439487,0.012438098,-0.008230246,-0.011369854,0.013975238,-0.024733193,-0.0024420747,-0.007300651,-0.010927623,0.009804951,0.013800517,-0.010402694,-0.0024602013,0.010395709,-0.026068788,0.007702348,0.014202778,0.0013273212,0.010732115,0.011755251,-0.0026727943,-0.00014252868,-0.00964251,-0.01090727,-0.0086603165,0.91866994,0.0036692163,0.0012229425,0.012757576,-0.009995481,-0.019993355,0.01013756,0.024875883,0.008574486,-0.000510592,-0.013572663,0.0047257543,0.033691775,-0.004811735,-0.0038559795,-0.023267845,0.005499406,-0.19547385,0.021643838,0.007663427,0.01097489,0.02083284,0.0013352409,0.01342792,0.012294747,0.008467486,-0.014881456,0.0070831156,0.004481094,0.030578267,0.010763423,0.022135999,0.004610107,0.0061101215,-0.023164,-0.02054793,-0.01389838,0.012618901,0.0015867227,-0.0007768469,-0.025492975,-2.8981638,1.124453,0.0047108997,-0.0034027295,-0.012984973,-0.0029615848,0.015046431,0.0046694623,-0.009886998,-0.01942618,-0.009058119,0.014815218,0.0002879626,-0.0055417474,-0.0027186424,0.0059638117,0.0034492463,0.010449398,0.021700718,-0.0009963056,0.5486783,-0.014873964,-0.005197687,-0.00982392,0.015218205,0.0066417474,0.007279752,-0.0034043728,0.009700814,-0.014037486,-0.0048284503,-0.009108011,-0.0073809493,0.009162379,-0.024887985,0.003875365,0.0010573778,0.021624224]
prompt_2 = [-0.026637983,-0.0049919384,-0.72735846,0.0013894774,1.663604e-6,-0.0037689628,-0.0007240372,0.011690616,0.011068916,0.035094958,0.0075723547,-0.020591881,0.014244843,0.015679605,0.0059791747,0.023565397,-2.31124e-5,0.007694515,-0.010229966,0.010043709,0.011308225,0.01125709,-0.00085738534,-0.0023815772,-0.0067716558,0.013232091,-0.020934608,0.0052047134,0.0031938986,-0.004571068,0.019595314,-0.014215741,0.017283715,0.0058534164,0.013672844,-0.008371335,-0.0011384739,-0.005467668,0.0069078547,-0.022255898,0.023867454,-0.008623882,-0.011428493,0.016150393,-0.010343215,0.009841408,0.0035812396,-0.0038042641,0.00046871044,-0.34486362,-0.04114393,-0.014105927,-0.006903317,0.005046356,0.008912971,0.024923684,-0.03211934,0.008048945,-0.0031226752,0.021200834,0.027791256,-0.012126264,-0.013014689,-0.010026319,-0.011552734,0.0023576792,0.6625984,-0.0037716748,-0.017448092,-0.003082925,0.0055787405,0.013155896,0.012592762,0.013704545,0.014140189,-0.034019794,0.008365242,0.0046707354,0.0045412,0.012962765,0.008882722,-0.0053416304,0.018359562,-0.013084315,0.00395564,-0.013893956,0.019868493,0.009645246,0.0083075445,0.01042995,3.8250873,-0.53500366,-0.0055012787,-0.00024795998,-0.025108527,0.0001758025,0.016005095,-0.014115073,-0.022973824,0.0007388084,-0.009637654,-0.0043257056,0.01392425,-0.007446825,-0.008244691,-0.01408944,-6.8385154e-5,-0.021611854,0.0012386662,-0.0153120095,1.1141413,-0.016341472,-0.0043792315,0.011762388,0.012778547,-0.021212585,-0.038239643,0.013913666,0.027715262,0.013375203,-0.00016425946,-0.04146558,0.00089889695,0.0057990993,0.010341369,0.02737344,-0.00036986242,0.014797182]

# --- Sampling configuration -------------------------------------------------
height, width = 512, 512
num_inference_steps = 50
batch_size = 1
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` to sample with ``num_inference_steps`` steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# One shared noise tensor: every interpolation step starts from a clone of it.
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents_orig = latents * scheduler.init_noise_sigma

num_images = 5
for k in range(num_images):
    latents = latents_orig.clone()

    # Linear blend: scale runs from 0 (pure prompt_1) to 1 (pure prompt_2).
    scale = k / (num_images - 1)
    prompt = (1-scale) * np.array(prompt_1) + scale * np.array(prompt_2)

    # Conditioning matrix with a leading batch axis.
    real_input = prompt
    real_input_mat = torch.tensor(convert_dim(real_input)).to(
        torch_device, dtype=dtype)
    input_mat = real_input_mat.unsqueeze(0)

    # autocast is requested for "cuda" regardless of the selected torch_device.
    with autocast("cuda"):
        for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
            latent_model_input = scheduler.scale_model_input(latents, t)

            # Predict the noise residual.
            with torch.no_grad():
                noise_pred = unet(
                    latent_model_input, t, encoder_hidden_states=input_mat
                ).sample

            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the VAE scaling factor and decode.
    latents = 1 / vae.config.scaling_factor * latents
    with torch.no_grad():
        latents = latents.to(dtype=dtype)
        image = vae.decode(latents).sample

    # Map [-1, 1] -> uint8, channels last, and show the grayscale result.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(frame) for frame in images]
    image = pil_images[0].convert("L")
    display(image)

In [None]:
# Generate the tiles of a feature-sweep grid: one row per named feature, one
# column per value of that feature, each tile saved as its own PNG.
x = 0
base_dir = f"../data/generate_images/diffusion_images/{x}"
os.makedirs(base_dir, exist_ok=True)

# Latent-feature indices exported with 1-based (Julia) numbering; the trailing
# "- 1" converts them to 0-based Python indices.
selected_features = np.array([
    92, 111, 50, 3, 91, 67, 37, 8, 90, 120, 54, 56, 21, 61, 75, 29,
    80, 12, 95, 118, 73, 94, 101, 20, 48, 99, 104, 13, 59, 52, 106, 79,
    4, 86, 93, 85, 72, 32, 87, 35, 47, 113, 40, 53, 36, 55, 122, 22,
    5, 2, 88, 77, 26, 15, 7, 108, 58, 28, 39, 128, 126, 25, 103, 65,
    105, 34, 18, 69, 27, 43, 64, 123, 38, 78, 17, 121, 42, 49, 33, 66,
    57, 6, 24, 112, 10, 115, 68, 45, 11, 51, 41, 97, 70, 102, 114, 89,
    71, 44, 110, 109, 62, 31, 124, 16, 1, 74, 9, 119, 14, 83, 117, 76,
    60, 46, 23, 84, 98, 82, 100, 107, 81, 125, 127, 30, 19, 96, 63, 116,
]) - 1

# --- Sampling configuration -------------------------------------------------
height, width = 512, 512
num_inference_steps = 50
batch_size = 1
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` to sample with ``num_inference_steps`` steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# One shared noise tensor: every tile starts from a clone of it.
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents_orig = latents * scheduler.init_noise_sigma

save_names = ["verticle_position", "tilt", "luminance", "curvature", "morphing"]
cols = 7
rows = 5
max_val = 4
for j in range(rows):
    save_name = save_names[j]
    for i in range(cols):
        # Feature 0 is pinned at 3.0; feature j + 1 is swept linearly from
        # -max_val (first column) to +max_val (last column).
        prompt = np.zeros(128)
        prompt[selected_features[0]] = 3.0
        prompt[selected_features[j + 1]] = -max_val + i * 2 * max_val / (cols -1)
        latents = latents_orig.clone()

        # Conditioning matrix with a leading batch axis.
        real_input = prompt
        real_input_mat = torch.tensor(convert_dim(real_input)).to(
            torch_device, dtype=dtype)
        input_mat = real_input_mat.unsqueeze(0)

        # autocast is requested for "cuda" regardless of the selected torch_device.
        with autocast("cuda"):
            for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
                latent_model_input = scheduler.scale_model_input(latents, t)

                # Predict the noise residual.
                with torch.no_grad():
                    noise_pred = unet(
                        latent_model_input, t, encoder_hidden_states=input_mat
                    ).sample

                latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the VAE scaling factor and decode.
        latents = 1 / vae.config.scaling_factor * latents
        with torch.no_grad():
            latents = latents.to(dtype=dtype)
            image = vae.decode(latents).sample

        # Map [-1, 1] -> uint8, channels last, grayscale, then write the tile.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(frame) for frame in images]
        image = pil_images[0].convert("L")
        # display(image)
        image_path = f"{base_dir}/{save_name}_{i}.png"
        image.save(image_path)

In [None]:
# Assemble the individually saved tiles into one grid image: one row per name
# in `save_names`, `cols` tiles per row.
# NOTE: `cols` is taken from the grid-generation cell, which must run first.
img_nr = 0
base_dir = f"../data/generate_images/diffusion_images"

# Define the names used for saving the images
# save_names = ["verticle_position", "tilt", "luminance", "curvature", "morphing"]
save_names = ["verticle_position", "tilt", "luminance", "curvature"]

# Loaded tiles plus the (row, col) slot each one belongs to. Remembering the
# slot keeps the grid aligned when a file is missing: the gap stays blank
# instead of every later tile shifting one position.
loaded_images = []
grid_positions = []

for row, save_name in enumerate(save_names):
    for col in range(cols):
        image_path = f"{base_dir}/{img_nr}/{save_name}_{col}.png"
        try:
            with Image.open(image_path) as img:
                loaded_images.append(img.copy())
        except FileNotFoundError:
            print(f"File not found: {image_path}")
            continue
        grid_positions.append((row, col))

if not loaded_images:
    raise FileNotFoundError(
        f"No grid images could be loaded from {base_dir}/{img_nr}")

# Assuming all images are of the same size
img_width, img_height = loaded_images[0].size

# Canvas large enough for len(save_names) rows of `cols` tiles.
grid_image = Image.new("L", (img_width * cols, img_height * len(save_names)))

# Paste every tile into its own slot.
for (row, col), img in zip(grid_positions, loaded_images):
    grid_image.paste(img, (col * img_width, row * img_height))

# Display the concatenated image
# display(grid_image)

# Save the grid image
grid_image_path = f"{base_dir}/concatenated_grid_{img_nr}.png"
grid_image.save(grid_image_path)
print(f"Concatenated grid image saved at: {grid_image_path}")

In [None]:
def pil_to_latent(input_im):
    """Encode one PIL image into a scaled VAE latent sample.

    Non-RGB images (e.g. grayscale) are converted to RGB first. Pixel values
    are mapped from [0, 1] to [-1, 1] before encoding, and a sample drawn from
    the encoder's latent distribution is multiplied by the VAE scaling factor.

    Returns:
        A latent tensor with a leading batch axis of size 1.
    """
    rgb_im = input_im if input_im.mode == 'RGB' else input_im.convert('RGB')
    pixels = transforms.ToTensor()(rgb_im).unsqueeze(0).to(
        torch_device, dtype=dtype)
    with torch.no_grad():
        latent = vae.encode(pixels * 2 - 1)  # rescale [0, 1] -> [-1, 1]
    return vae.config.scaling_factor * latent.latent_dist.sample()


def latents_to_pil(latents):
    """Decode a batch of scaled latents into a list of grayscale PIL images.

    The VAE scaling factor is undone, the decoder output is mapped from
    [-1, 1] to 8-bit [0, 255], and each image is converted to mode "L".
    """
    unscaled = (1 / vae.config.scaling_factor) * latents
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Channels-first -> channels-last numpy array for PIL.
    frames = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    frames = (frames * 255).round().astype("uint8")
    return [Image.fromarray(frame).convert("L") for frame in frames]

In [None]:
image_dir = "../data/data_resized/bm3d_496_512_train"
json_dir = "../data/generate_images/generate_json_latent"

# Choose a random training JPEG and derive the path of the JSON file that
# shares its file stem.
jpeg_files = [name for name in os.listdir(image_dir) if name.endswith('.jpeg')]
selected_jpeg = random.choice(jpeg_files)
json_filename = selected_jpeg.replace('.jpeg', '.json')
json_path = os.path.join(json_dir, json_filename)

# Assuming the JSON structure directly contains the vector
with open(json_path, 'r') as file:
    prompt = json.load(file)

# Load the original image at the working resolution. `width` and `height` are
# the values set by the generation cells that ran earlier.
original_image_path = os.path.join(image_dir, selected_jpeg)
input_image = Image.open(original_image_path).resize((width, height))
input_image

In [None]:
# Encode the selected image into the VAE latent space. `encoded` is reused by
# the noising and image-to-image cells further down.
encoded = pil_to_latent(input_image)
# Last expression of the cell: shows the shape of the latent tensor.
encoded.shape

In [None]:
# Show the first four channels of the encoded latent side by side, one
# subplot per channel.
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c, ax in enumerate(axs):
    ax.imshow(encoded[0][c].cpu(), cmap="Greys")

This 4x64x64 tensor captures lots of information about the image, hopefully enough that when we feed it through the decoder we get back something very close to our input image:

In [None]:
# Decode this latent representation back into an image. `latents_to_pil`
# returns a list of images, so take the first entry.
decoded = latents_to_pil(encoded)[0]
# Last expression of the cell: displays the decoded image.
decoded

In [None]:
# Setting the number of sampling steps. `set_timesteps` is the helper defined
# in the generation cells above, so one of them must have run first.
set_timesteps(scheduler, 50)
print(f"len(scheduler.timesteps) = {len(scheduler.timesteps)}")

In [None]:
# Print the timestep values selected by the scheduler. The original note says
# these are expressed in terms of the 1000 steps used for training
# (TODO: confirm against the scheduler config).
print(scheduler.timesteps)

In [None]:
# Draw Gaussian noise with the same shape as the encoded latents.
noise = torch.randn_like(encoded)  # Random noise
# Position in scheduler.timesteps whose noise level is applied below.
sampling_step = 35
# add_noise receives the chosen timestep wrapped in a one-element tensor.
encoded_and_noised = scheduler.add_noise(
    encoded, noise, timesteps=torch.tensor([scheduler.timesteps[sampling_step]]))
# Decode the noised latents; the bare expression displays the first image.
latents_to_pil(encoded_and_noised.float())[0]  # Display

In [None]:
# Show the first four channels of the noised latent, one subplot per channel.
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c, ax in enumerate(axs):
    ax.imshow(encoded_and_noised[0][c].cpu(), cmap="Greys")

In [None]:
# Image-to-image run: noise the encoded input image to the noise level of
# `scheduler.timesteps[start_step]`, then denoise it from that step onwards
# while conditioning on the image's own latent vector.
# Depends on `encoded`, `json_path` and `input_image` from the cells above.
height = 512
width = 512
num_inference_steps = 50
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)
batch_size = 1


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` to sample with ``num_inference_steps`` steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)
with open(json_path, 'r') as file:
    # Assuming the JSON structure directly contains the vector
    prompt = json.load(file)
# prompt[selected_features[2]] = 3.0

# Index into scheduler.timesteps at which denoising starts.
start_step = 0
noise = torch.randn((batch_size, unet.config.in_channels, height //
                    8, width // 8), generator=generator).to(torch_device)
latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([
                              scheduler.timesteps[start_step]]))
latents = latents.to(torch_device).float()

# Only the steps from `start_step` onwards are run; slicing them out keeps the
# progress bar in line with the work actually done.
remaining_timesteps = scheduler.timesteps[start_step:]

# autocast is requested for "cuda" regardless of the selected torch_device.
with autocast("cuda"):
    real_input_mat = convert_dim(prompt)
    real_input_mat = torch.tensor(real_input_mat)
    real_input_mat = real_input_mat.to(torch_device, dtype=dtype)
    input_mat = real_input_mat.unsqueeze(0)

    for t in tqdm(remaining_timesteps, total=len(remaining_timesteps)):
        latent_model_input = scheduler.scale_model_input(latents, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=input_mat).sample

        latents = scheduler.step(noise_pred, t, latents).prev_sample

# Decode a rescaled copy; `latents` itself stays in the scaled latent space so
# the next cell can compare it with `encoded`.
latents_scale = 1 / vae.config.scaling_factor * latents
with torch.no_grad():
    latents_scale = latents_scale.to(dtype=dtype)
    image = vae.decode(latents_scale).sample

# Map [-1, 1] -> uint8, channels last, then to a grayscale PIL image.
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
image = pil_images[0]
image = image.convert("L")
# display(image)

# Input image on the left, regenerated image on the right.
concatenated = Image.new('L', (2 * width, height))
input_image = input_image.resize((width, height))
concatenated.paste(input_image, (0, 0))
concatenated.paste(image, (width, 0))
display(concatenated)
# Next step: take the MSE in the latent space between the original and the
# generated image, and decode the difference between the two.

In [None]:
# Summary statistics of the generated latents next to the encoded original.
print(f"mean latents: {latents.mean()}, std latents: {latents.std()}")
print(f"mean encoded: {encoded.mean()}, std encoded: {encoded.std()}")

# Put both latent tensors on the same device and dtype before comparing them.
original_latents, generated_latents = (
    tensor.to(torch_device, dtype=dtype) for tensor in (encoded, latents)
)

# 1. Mean squared error between the two latent tensors.
mse_loss = F.mse_loss(original_latents, generated_latents, reduction='mean')
print(f"MSE in Latent Space: {mse_loss.item()}")

# 2. Element-wise absolute difference, treated as a latent in its own right.
new_latent_representation = (original_latents - generated_latents).abs()

# 3. Undo the VAE scaling factor and decode the difference latent.
with torch.no_grad():
    new_latent_representation = (
        1 / vae.config.scaling_factor * new_latent_representation
    ).to(dtype=dtype)
    new_image_tensor = vae.decode(new_latent_representation).sample

# Map [-1, 1] -> uint8, channels last. The result is left as decoded, i.e.
# not converted to grayscale.
new_image_tensor = (new_image_tensor / 2 + 0.5).clamp(0, 1)
new_image_tensor = new_image_tensor.detach().cpu().permute(0, 2, 3, 1).numpy()
new_images = (new_image_tensor * 255).round().astype("uint8")
new_pil_images = [Image.fromarray(frame) for frame in new_images]
new_image = new_pil_images[0]
# new_image = new_image.convert("L")

display(new_image)

In [None]:
def pil_to_tensor(pil_img):
    """Return ``pil_img`` as a float32 tensor with values divided by 255.

    The image is converted through ``np.array``, so the tensor keeps the
    array's shape.
    """
    pixels = torch.tensor(np.array(pil_img), dtype=torch.float32)
    return pixels / 255.


# Compare the input image and the generated image in pixel space.
# The generated image is grayscale, while the source JPEG may be opened in a
# different mode. Converting both to "L" first guarantees the two tensors have
# the same shape, so the MSE and the difference are computed pixel by pixel.
original_image_tensor = pil_to_tensor(input_image.convert("L")).to(
    torch_device, dtype=dtype).unsqueeze(0)
generated_image_tensor = pil_to_tensor(image.convert("L")).to(
    torch_device, dtype=dtype).unsqueeze(0)

# Mean squared error over all pixels (values are in [0, 1]).
mse_loss_pixel = F.mse_loss(original_image_tensor,
                            generated_image_tensor, reduction='mean')
print(f"MSE in Pixel Space: {mse_loss_pixel.item()}")

pixel_difference = torch.abs(original_image_tensor - generated_image_tensor)

# Drop the batch axis and rescale to 8-bit for display. The array is 2-D
# uint8, so PIL builds a grayscale image without an explicit mode argument.
pixel_difference_image = pixel_difference.squeeze(0).cpu()
pixel_difference_image = (pixel_difference_image.numpy()
                          * 255).round().astype("uint8")
new_diff_image = Image.fromarray(pixel_difference_image)

# Display the new image
display(new_diff_image)

In [None]:
prompt = [-0.026637983,-0.0049919384,-0.72735846,0.0013894774,1.663604e-6,-0.0037689628,-0.0007240372,0.011690616,0.011068916,0.035094958,0.0075723547,-0.020591881,0.014244843,0.015679605,0.0059791747,0.023565397,-2.31124e-5,0.007694515,-0.010229966,0.010043709,0.011308225,0.01125709,-0.00085738534,-0.0023815772,-0.0067716558,0.013232091,-0.020934608,0.0052047134,0.0031938986,-0.004571068,0.019595314,-0.014215741,0.017283715,0.0058534164,0.013672844,-0.008371335,-0.0011384739,-0.005467668,0.0069078547,-0.022255898,0.023867454,-0.008623882,-0.011428493,0.016150393,-0.010343215,0.009841408,0.0035812396,-0.0038042641,0.00046871044,-0.34486362,-0.04114393,-0.014105927,-0.006903317,0.005046356,0.008912971,0.024923684,-0.03211934,0.008048945,-0.0031226752,0.021200834,0.027791256,-0.012126264,-0.013014689,-0.010026319,-0.011552734,0.0023576792,0.6625984,-0.0037716748,-0.017448092,-0.003082925,0.0055787405,0.013155896,0.012592762,0.013704545,0.014140189,-0.034019794,0.008365242,0.0046707354,0.0045412,0.012962765,0.008882722,-0.0053416304,0.018359562,-0.013084315,0.00395564,-0.013893956,0.019868493,0.009645246,0.0083075445,0.01042995,3.8250873,-0.53500366,-0.0055012787,-0.00024795998,-0.025108527,0.0001758025,0.016005095,-0.014115073,-0.022973824,0.0007388084,-0.009637654,-0.0043257056,0.01392425,-0.007446825,-0.008244691,-0.01408944,-6.8385154e-5,-0.021611854,0.0012386662,-0.0153120095,1.1141413,-0.016341472,-0.0043792315,0.011762388,0.012778547,-0.021212585,-0.038239643,0.013913666,0.027715262,0.013375203,-0.00016425946,-0.04146558,0.00089889695,0.0057990993,0.010341369,0.02737344,-0.00036986242,0.014797182]

height = 512
width = 512
num_inference_steps = 20 # 200
seed_nr = np.random.randint(10000)
seed_nr = 9139
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)
batch_size = 1

# Make a folder to store results
!rm -rf steps/
!mkdir -p steps/


def set_timesteps(scheduler, num_inference_steps):
    """Forward `num_inference_steps` to `scheduler.set_timesteps`.

    Thin wrapper kept so every cell configures the scheduler through the same
    call. Returns nothing; the scheduler is modified in place.
    """
    configure = scheduler.set_timesteps
    configure(num_inference_steps)
    return None

# Configure the scheduler for this run before sampling the starting latents.
set_timesteps(scheduler, num_inference_steps)

# Prep latents: sample noise with the seeded generator, move it to the target
# device, then scale it by the scheduler's init_noise_sigma.
latent_shape = (batch_size, unet.config.in_channels, height // 8, width // 8)
latents = torch.randn(latent_shape, generator=generator)
latents = latents.to(torch_device)
latents = latents * scheduler.init_noise_sigma

# Loop
# Denoising loop: one UNet forward pass and one scheduler update per timestep,
# saving a frame to steps/ after every update.
with autocast("cuda"):  # will fallback to CPU if no CUDA; no autocast for MPS

    # Build the conditioning tensor once; it is the same for every timestep.
    real_input = prompt
    real_input_mat = convert_dim(real_input)
    real_input_mat = torch.tensor(real_input_mat)
    real_input_mat = real_input_mat.to(torch_device, dtype=dtype)
    input_mat = real_input_mat.unsqueeze(0)  # add a leading batch axis

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        latent_model_input = scheduler.scale_model_input(latents, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_mat).sample

        # Call scheduler.step() exactly once per timestep and read both outputs
        # from that single result. Calling it a second time repeats the update:
        # it doubles the work, draws fresh noise for stochastic schedulers, and
        # would advance the internal state of multistep schedulers twice.
        step_output = scheduler.step(noise_pred, t, latents)
        latents_x0 = step_output.pred_original_sample
        latents = step_output.prev_sample

        # # FOR DUAL IMAGES
        # # To PIL Images
        # im_t0 = latents_to_pil(latents_x0)[0]
        # im_next = latents_to_pil(latents)[0]
        # # Combine the two images and save for later viewing
        # im = Image.new('RGB', (1024, 512))
        # im.paste(im_next, (0, 0))
        # im.paste(im_t0, (512, 0))
        # im.save(f'steps/{i:04}.png')


        im_t0 = latents_to_pil(latents)[0]
        # im_t0 = latents_to_pil(latents_x0)[0]
        # Canvas follows the configured size instead of a hard-coded 512x512
        # (PIL sizes are given as (width, height)).
        im = Image.new('RGB', (width, height))
        im.paste(im_t0, (0, 0))
        # Zero-padded index so the files sort in step order.
        im.save(f'steps/{i:04}.png')


# Show the last saved frame as the cell output.
im

In [None]:
# # import imageio
# import imageio.v2 as imageio

# import os

# # Define the directory where your images are stored
# image_directory = 'steps'
# image_files = [os.path.join(image_directory, img) for img in sorted(os.listdir(image_directory)) if img.endswith('.png')]

# # Define the output video file name
# output_video_file = 'out.mp4'

# # Create a writer object specifying the output file and framerate
# writer = imageio.get_writer(output_video_file, fps=20, codec='libx264', quality=10, pixelformat='yuv420p')

# # Iterate over image files, add them to the video
# for image_file in image_files:
#     image = imageio.imread(image_file)
#     writer.append_data(image)

# # Close the writer to finalize the video
# writer.close()

# # Display the video in a Jupyter notebook
# from IPython.display import Video

# Video(output_video_file, embed=True, width=512)


In [None]:
import imageio.v2 as imageio
import os

# Define the directory where your images are stored
image_directory = 'steps'
# Sorted so the zero-padded frame names (0000.png, 0001.png, ...) play in step order.
image_files = [os.path.join(image_directory, img) for img in sorted(os.listdir(image_directory)) if img.endswith('.png')]

# Define the output video file name
output_video_file = 'out.mp4'

# Write every frame to the video. The context manager closes (finalizes) the
# writer even if reading or appending a frame raises, so the file handle is
# never left open.
with imageio.get_writer(output_video_file) as writer:
    for image_file in image_files:
        image = imageio.imread(image_file)
        writer.append_data(image)

# Display the video in a Jupyter notebook
from IPython.display import Video

Video(output_video_file, embed=True, width=512)


In [None]:
import cv2
import os

# Define the directory where your images are stored
image_directory = 'steps'
# Sorted so the zero-padded frame names (0000.png, 0001.png, ...) play in step order.
image_files = [os.path.join(image_directory, img) for img in sorted(os.listdir(image_directory)) if img.endswith('.png')]
if not image_files:
    # Fail with a clear message instead of an IndexError on image_files[0].
    raise FileNotFoundError(f"No .png frames found in '{image_directory}'")

# Define the output video file name
output_video_file = 'out.mp4'

# Read the first image to get the frame size. Dedicated names are used so the
# notebook-level `height` / `width` generation settings are not overwritten.
frame = cv2.imread(image_files[0])
frame_height, frame_width = frame.shape[:2]

# Create a VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec used to create the video
video = cv2.VideoWriter(output_video_file, fourcc, 20, (frame_width, frame_height))

# Iterate over image files, add them to the video. The writer is released in
# `finally` so it is closed even if a frame fails to load or write.
try:
    for image_file in image_files:
        img = cv2.imread(image_file)
        video.write(img)
finally:
    video.release()

# To display the video in a Jupyter notebook, you can still use IPython display
from IPython.display import Video
Video(output_video_file, embed=True, width=512)


# Guidance


How can we add some extra control to this generation process?

At each step, we're going to use our model as before to predict the noise component of x. Then we'll use this to produce a predicted output image, and apply some loss function to this image.

This function can be anything, but let's demo with a super simple example. If we want images that have a lot of blue, we can craft a loss function that gives a high loss if pixels have a low blue component:

In [None]:
def blue_loss(images, target=0.9):
    """Mean absolute distance of the blue channel from `target`.

    Args:
        images: Tensor indexed as [batch, channel, ...]; channel index 2 is
            taken to be blue.
        target: Desired blue value. Defaults to 0.9, the value that used to be
            hard-coded, so existing calls behave exactly as before.

    Returns:
        Scalar tensor holding the mean of |images[:, 2] - target|.
    """
    blue_channel = images[:, 2]  # [:,2] -> all images in batch, only the blue channel
    return torch.abs(blue_channel - target).mean()

During each update step, we find the gradient of the loss with respect to the current noisy latents, and tweak them in the direction that reduces this loss as well as performing the normal update step:

In [None]:
# prompt = "A campfire (oil on canvas)"  # @param
# height = 512  # default height of Stable Diffusion
# width = 512  # default width of Stable Diffusion
# num_inference_steps = 50  # @param           # Number of denoising steps
# guidance_scale = 8  # @param               # Scale for classifier-free guidance
# generator = torch.manual_seed(32)  # Seed generator to create the initial latent noise
# batch_size = 1
# blue_loss_scale = 200  # @param

# # Prep text
# text_input = tokenizer(
#     [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
# )
# with torch.no_grad():
#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

# # And the uncond. input as before:
# max_length = text_input.input_ids.shape[-1]
# uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
# with torch.no_grad():
#     uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# # Prep Scheduler
# set_timesteps(scheduler, num_inference_steps)

# # Prep latents
# latents = torch.randn(
#     (batch_size, unet.in_channels, height // 8, width // 8),
#     generator=generator,
# )
# latents = latents.to(torch_device)
# latents = latents * scheduler.init_noise_sigma

# # Loop
# for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
#     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
#     latent_model_input = torch.cat([latents] * 2)
#     latent_model_input = scheduler.scale_model_input(latent_model_input, t)

#     # predict the noise residual
#     with torch.no_grad():
#         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

#     # perform CFG
#     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

#     #### ADDITIONAL GUIDANCE ###
#     if i % 5 == 0:
#         # Requires grad on the latents
#         latents = latents.detach().requires_grad_()

#         # Get the predicted x0:
#         latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

#         # Decode to image space
#         denoised_images = vae.decode((1 / vae.config.scaling_factor) * latents_x0).sample / 2 + 0.5  # range (0, 1)

#         # Calculate loss
#         loss = blue_loss(denoised_images) * blue_loss_scale

#         # Occasionally print it out
#         if i % 10 == 0:
#             print(i, "loss:", loss.item())

#         # Get gradient
#         cond_grad = torch.autograd.grad(loss, latents)[0]

#         # Modify the latents based on this gradient
#         # latents = latents.detach() - cond_grad * sigma**2 # TODO maybe need to replace this line with the right code...

#     # Now step with scheduler
#     latents = scheduler.step(noise_pred, t, latents).prev_sample


# latents_to_pil(latents)[0]

Tweak the scale (`blue_loss_scale`) - at low values, the image is mostly red and orange thanks to the prompt. At higher values, it is mostly bluish! Too high and we get a plain blue image.

Since this is slow, you'll notice that I only apply this loss once every 5 iterations - this was a suggestion from Jeremy and we left it in because for this demo it saves time and still works. For your own tests you may want to explore using a lower scale for the loss and applying it every iteration instead :)

NB: We should set latents requires_grad=True **before** we do the forward pass of the unet (removing `with torch.no_grad()`) if we want more accurate gradients. BUT this requires a lot of extra memory. You'll see both approaches used depending on whose implementation you're looking at.

Guiding with classifier models can give you images of a specific class. Guiding with a model like CLIP can help better match a text prompt. Guiding with a style loss can help add a particular style. Guiding with some sort of perceptual loss can force it towards the overall look of a target image. And so on.

In [None]:
# Switch the notebook's working directory to the project folder.
import os

# change wd to /home/scfc3/rds/rds-cbs31-cmih-covid19/user_files/scfc3_files/diffusion/bloodimage
project_dir = "/home/scfc3/rds/rds-cbs31-cmih-covid19/user_files/scfc3_files/diffusion/bloodimage"
os.chdir(project_dir)

# get working directory: kept as the last expression of the cell so the new
# location is actually displayed (a bare os.getcwd() mid-cell shows nothing).
os.getcwd()

In [None]:
# Run train_script_small.sh with `sh` through IPython's `!` shell escape.
!sh /home/scfc3/rds/rds-cbs31-cmih-covid19/user_files/scfc3_files/diffusion/bloodimage/train_script_small.sh