In [None]:
import numpy as np
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from huggingface_hub import notebook_login
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import logging
import os
from huggingface_hub import HfFolder
import random
import json
import torch.nn.functional as F
import re
import time

# Seed every RNG the notebook draws from so that Restart & Run All is reproducible.
# `random` is seeded too because the test-set comparison cell selects its image
# with random.choice().
fixed_seed = 1
random.seed(fixed_seed)
torch.manual_seed(fixed_seed)
np.random.seed(fixed_seed)

# Authenticate with the Hugging Face Hub: prefer the HUGGING_FACE_TOKEN environment
# variable and only fall back to the interactive login prompt when it is not set.
token = os.getenv("HUGGING_FACE_TOKEN")
if token:
    # Token found in the environment: store it via HfFolder so it is used for authentication.
    HfFolder.save_token(token)
else:
    # No token in the environment: prompt for manual login
    # (notebook_login is already imported at the top of the notebook).
    notebook_login()

# Set device: CUDA if available, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    torch_device = "cuda:0"
elif torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"

if torch_device == "mps":
    # NOTE(review): this variable is set after torch has already been imported;
    # confirm that it still takes effect in the installed PyTorch version.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# All models and conditioning tensors in this notebook are kept in this dtype.
dtype = torch.float32
print(f"Using device: {torch_device}")

In [None]:
# Model sources: the VAE weights and the scheduler configuration are read from the
# Stable Diffusion v1.5 repository, the UNet from the locally saved SD_OCT checkpoint.
model_path = "runwayml/stable-diffusion-v1-5"
SD_path = "../../OCT-Longitudinal/saved_models/SD_OCT"

# Load each network and place it on the selected device/dtype right away.
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(torch_device, dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(SD_path, subfolder="unet").to(torch_device, dtype=dtype)

# The scheduler holds no weights that need moving, so it is only instantiated.
scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

In [None]:
def convert_dim(caption, target_shape=(77, 768)):
    """Tile a 1-D vector across the rows of a zero-padded 2-D array.

    Args:
        caption: Sequence of numbers; its length must not exceed
            ``target_shape[1]``.
        target_shape: ``(rows, columns)`` of the returned array.

    Returns:
        A ``numpy.ndarray`` of shape ``target_shape`` (created with
        ``np.zeros``) in which every row starts with ``caption`` and is zero
        in the remaining columns.

    Raises:
        AssertionError: If ``caption`` is longer than ``target_shape[1]``.
    """
    values = np.array(caption)
    width = len(values)
    # The vector has to fit into a single row of the target array.
    assert target_shape[1] >= width
    padded = np.zeros(target_shape)
    # Broadcasting copies the vector into the leading columns of every row at once.
    padded[:, :width] = values
    return padded


In [None]:
# Create an image from a randomly drawn 128-dimensional conditioning vector.

# --- Sampling configuration -------------------------------------------------
prompt = np.random.randn(128)
height, width = 512, 512
num_inference_steps = 50
batch_size = 1
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
# torch.manual_seed returns the generator that is used for the initial latents below.
generator = torch.manual_seed(seed_nr)


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` for ``num_inference_steps`` denoising steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# --- Initial latents --------------------------------------------------------
# The latent grid is 8x smaller than the image in each spatial dimension; the
# noise is scaled by the scheduler's initial sigma.
latent_shape = (batch_size, unet.config.in_channels, height // 8, width // 8)
latents = torch.randn(latent_shape, generator=generator).to(torch_device)
latents = latents * scheduler.init_noise_sigma

# --- Denoising loop ---------------------------------------------------------
# NOTE(review): autocast is requested for "cuda" regardless of torch_device;
# confirm it is effectively a no-op when running on CPU or MPS.
with autocast("cuda"):
    # Expand the vector to the (1, 77, 768) conditioning matrix passed to the UNet.
    real_input = prompt
    real_input_mat = torch.tensor(convert_dim(real_input)).to(torch_device, dtype=dtype)
    input_mat = real_input_mat.unsqueeze(0)

    with torch.no_grad():
        for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
            latent_model_input = scheduler.scale_model_input(latents, t)
            # predict the noise residual, then take one scheduler step
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_mat).sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample

# --- Decode -----------------------------------------------------------------
# Undo the VAE latent scaling and decode the latents into image space.
latents = 1 / vae.config.scaling_factor * latents
with torch.no_grad():
    latents = latents.to(dtype=dtype)
    image = vae.decode(latents).sample

# --- Display ----------------------------------------------------------------
# Map from [-1, 1] to [0, 1], move channels last, and quantise to 8-bit.
image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
# Every frame is converted to gray scale ("L") when the PIL image is built.
pil_images = [Image.fromarray(frame).convert("L") for frame in images]
image = pil_images[0]
display(image)

In [None]:
# Use the latent vector from an image in the test set to generate an image with the
# diffusion model and compare it with the original image.

# Paths to the directories
image_dir = "../../OCT-Longitudinal/test_images"
json_dir = "../../OCT-Longitudinal/128_dim_latent_space/test_set"

# List all JPEG files in the directory. The listing is sorted because os.listdir
# returns entries in arbitrary order, which would make the random pick depend on
# the file system even when the RNG is seeded.
jpeg_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpeg'))
if not jpeg_files:
    # Fail with the offending path instead of random.choice's generic IndexError.
    raise FileNotFoundError(f"No .jpeg files found in: {image_dir}")
# Randomly select one JPEG file
selected_jpeg = random.choice(jpeg_files)
# Identify the corresponding JSON file (same file stem, .json extension)
json_filename = os.path.splitext(selected_jpeg)[0] + '.json'
json_path = os.path.join(json_dir, json_filename)

# Read the vector from the JSON file
print(f"Reading the vector from the JSON file: {json_path}")
with open(json_path, 'r') as file:
    prompt = json.load(file)

height = 512
width = 512
num_inference_steps = 50
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)
batch_size = 1

# Load the original image and resize it to the generated image's width and height.
# The context manager closes the underlying file once the resized copy exists.
original_image_path = os.path.join(image_dir, selected_jpeg)
with Image.open(original_image_path) as source_image:
    original_image = source_image.resize((width, height))


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` for ``num_inference_steps`` denoising steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# Prep latents: the latent grid is 8x smaller than the image in each spatial
# dimension; the noise is scaled by the scheduler's initial sigma.
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents = (latents * scheduler.init_noise_sigma)

# Loop
# NOTE(review): autocast is requested for "cuda" regardless of torch_device;
# confirm it is effectively a no-op when running on CPU or MPS.
with autocast("cuda"):
    # Expand the stored vector to the (1, 77, 768) conditioning matrix passed to the UNet.
    real_input = prompt
    real_input_mat = convert_dim(real_input)
    real_input_mat = torch.tensor(real_input_mat)
    real_input_mat = real_input_mat.to(torch_device, dtype=dtype)
    input_mat = real_input_mat.unsqueeze(0)

    for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
        latent_model_input = scheduler.scale_model_input(latents, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_mat).sample

        # NOTE: no classifier-free guidance is applied here.
        latents = scheduler.step(noise_pred, t, latents).prev_sample

# scale and decode the image latents with vae
latents = 1 / vae.config.scaling_factor * latents
with torch.no_grad():
    latents = latents.to(dtype=dtype)
    image = vae.decode(latents).sample

# Display: map from [-1, 1] to [0, 1], move channels last, quantise to 8-bit.
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
image = pil_images[0]
image = image.convert("L")

# Concatenate the images side by side: original on the left, generated on the right.
concatenated = Image.new('L', (2 * width, height))
concatenated.paste(original_image, (0, 0))
concatenated.paste(image, (width, 0))
display(concatenated)

# NOTE(review): this path is relative to the working directory ("data/..."), whereas
# the grid cells write below "../data/generate_images/..." — confirm which is intended.
save_dir = "data/generate_images/diffusion_compare_rel_synth"
os.makedirs(save_dir, exist_ok=True)

# Find the highest numbered file in the directory; files whose stem is not purely
# numeric are ignored.
existing_files = [f for f in os.listdir(save_dir) if f.endswith('.jpeg')]
existing_files_filtered = [f for f in existing_files if f.split('.')[0].isdigit()]
highest_num = max((int(f.split('.')[0]) for f in existing_files_filtered), default=0)

# File name for the new concatenated image
new_file_name = f"{highest_num + 1}.jpeg"
save_path = os.path.join(save_dir, new_file_name)

# Save the concatenated image
concatenated.save(save_path)
print(f"Concatenated image saved as: {save_path}")

In [None]:
# Create the images for a feature-traversal grid: one row per named feature, one
# column per value of that feature.
x = 0
base_dir = f"../data/generate_images/diffusion_images/{x}"
os.makedirs(base_dir, exist_ok=True)

selected_features = np.array(   [92,111,50,3,91,67,37,8,90,120,54,56,21,61,75,29,80,12,95,118,73,94,101,20,48,99,104,13,59,52,106,79,4,86,93,85,72,32,87,35,47,113,40,53,36,55,122,22,5,2,88,77,26,15,7,108,58,28,39,128,126,25,103,65,105,34,18,69,27,43,64,123,38,78,17,121,42,49,33,66,57,6,24,112,10,115,68,45,11,51,41,97,70,102,114,89,71,44,110,109,62,31,124,16,1,74,9,119,14,83,117,76,60,46,23,84,98,82,100,107,81,125,127,30,19,96,63,116]
)
# subtract 1 from all elements in selected_features (julia -> python indexing)
selected_features -= 1

height = 512
width = 512
num_inference_steps = 50
seed_nr = np.random.randint(10000)
print(f"seed_nr {seed_nr}")
generator = torch.manual_seed(seed_nr)
batch_size = 1


def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` for ``num_inference_steps`` denoising steps."""
    scheduler.set_timesteps(num_inference_steps)


set_timesteps(scheduler, num_inference_steps)

# Prep latents: drawn once and cloned for every grid image so that all images
# start from the same initial noise.
latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
).to(torch_device)
latents_orig = latents * scheduler.init_noise_sigma

save_names = ["verticle_position", "tilt", "luminance", "curvature"]
cols = 7
rows = len(save_names)
max_val = 4
for j in range(rows):
    save_name = save_names[j]
    for i in range(cols):
        # Conditioning vector: the first selected feature is held at 3.0 and the
        # (j + 1)-th selected feature is swept linearly from -max_val to +max_val.
        prompt = np.zeros(128)
        prompt[selected_features[0]] = 3.0
        prompt[selected_features[j + 1]] = -max_val + i * 2 * max_val / (cols - 1)
        latents = latents_orig.clone()

        # scheduler.step() adds freshly sampled noise at every step. Without an
        # explicit generator that noise comes from the global RNG, which advances
        # from image to image, so the grid images would differ in sampling noise
        # as well as in the swept feature. Re-creating the generator per image
        # keeps the step noise identical across the grid. The seed is offset by
        # one so the first step's noise is not a copy of the initial latents.
        step_generator = torch.Generator().manual_seed(seed_nr + 1)

        # Loop
        # NOTE(review): autocast is requested for "cuda" regardless of torch_device;
        # confirm it is effectively a no-op when running on CPU or MPS.
        with autocast("cuda"):
            # Expand the vector to the (1, 77, 768) conditioning matrix passed to the UNet.
            real_input = prompt
            real_input_mat = convert_dim(real_input)
            real_input_mat = torch.tensor(real_input_mat)
            real_input_mat = real_input_mat.to(torch_device, dtype=dtype)
            input_mat = real_input_mat.unsqueeze(0)

            for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
                latent_model_input = scheduler.scale_model_input(latents, t)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_mat).sample

                latents = scheduler.step(
                    noise_pred, t, latents, generator=step_generator
                ).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / vae.config.scaling_factor * latents
        with torch.no_grad():
            latents = latents.to(dtype=dtype)
            image = vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1], move channels last, quantise to 8-bit.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        image = pil_images[0]
        # to gray scale
        image = image.convert("L")
        # Saved as <feature name>_<column index>.png
        image_path = f"{base_dir}/{save_name}_{i}.png"
        image.save(image_path)

In [None]:
# Assemble the per-feature images written by the grid-generation cell into a single
# grid image (one row per feature name, one column per swept value).
img_nr = 0
base_dir = "../data/generate_images/diffusion_images"

# Define the names used for saving the images
save_names = ["verticle_position", "tilt", "luminance", "curvature"]

# NOTE: `cols` is not defined in this cell; it is taken from the grid-generation
# cell, which therefore has to be run first.

# Loaded images and, in parallel, the (row, column) each one belongs to. Keeping
# the position means a missing file leaves its own tile empty instead of shifting
# every later image into the wrong slot.
loaded_images = []
grid_positions = []

# Load the images
for j, save_name in enumerate(save_names):
    for i in range(cols):
        image_path = f"{base_dir}/{img_nr}/{save_name}_{i}.png"
        try:
            with Image.open(image_path) as img:
                tile = img.copy()
        except FileNotFoundError:
            print(f"File not found: {image_path}")
            continue
        loaded_images.append(tile)
        grid_positions.append((j, i))

if not loaded_images:
    # Nothing to assemble: report the directory instead of failing with an IndexError below.
    raise FileNotFoundError(f"No grid images found in: {base_dir}/{img_nr}")

# Assuming all images are of the same size
img_width, img_height = loaded_images[0].size

# Create a new image large enough for a len(save_names) x cols grid
grid_image = Image.new("L", (img_width * cols, img_height * len(save_names)))

# Paste every image at the tile given by its (row, column)
for (row, col), tile in zip(grid_positions, loaded_images):
    grid_image.paste(tile, (col * img_width, row * img_height))

# Display the concatenated image
display(grid_image)

# Save the grid image
grid_image_path = f"{base_dir}/concatenated_grid_{img_nr}.png"
grid_image.save(grid_image_path)
print(f"Concatenated grid image saved at: {grid_image_path}")