In [24]:
%pip install Levenshtein
%pip install transformers
%pip install accelerate
%pip install diffusers




In [25]:
# IMPORTS

import psutil
import json
import os
import random
import time
import torch
import uuid

from datetime import datetime
from Levenshtein import distance
from jinja2 import Template
from PIL import Image


import os
import re
import torch
from IPython.display import display
from diffusers import StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
from diffusers import AutoencoderKL


In [27]:
# VARIOUS HELPER FUNCTIONS

### SIMILARITY CALCULATIONS FOR TRACK NAMES

SIMILARITY_DISTANCE = 3
def are_tracks_similar(tracks):
    for i in range(len(tracks)):
        for j in range(i + 1, len(tracks)):
            if distance(tracks[i]['name'], tracks[j]['name']) < SIMILARITY_DISTANCE:
                return True
    return False

def similarity_groups(tracks):
    found_group = [False] * len(tracks)
    groups = []

    for i in range(len(tracks)):
        if found_group[i]:
            continue

        found_group[i] = True
        groups.append([tracks[i]])

        for j in range(i + 1, len(tracks)):
            if found_group[j]:
                continue
            if distance(tracks[i]['name'], tracks[j]['name']) < SIMILARITY_DISTANCE:
                found_group[j] = True

    return groups

### OUTPUT FILE ID NAME
def get_run_id(run_in):
    # model = run_in['model_id'].replace('/', '-')
    return f"{run_in['album_id']}_{run_in['positive-prompt']}_{run_in['negative-prompt']}_{run_in['inference_steps']}_{run_in['guidance_scale']}_{run_in['batch_size']}"


In [28]:
# GET COMPUTER SPECS

platform_info = {}

platform_info['physical_cpu_cores'] = psutil.cpu_count(logical=False)
platform_info['total_cpu_cores'] = psutil.cpu_count(logical=True)

def get_available_device():
    """Helper method to find best possible hardware to run
    Returns:
        torch.device used to run experiments.
        str representation of backend.
    """
    # Check if CUDA is available
    if torch.cuda.is_available():
        return torch.device("cuda"), "cuda"

    # Check if ROCm is available
    if torch.version.hip is not None and torch.backends.mps.is_available():
        return torch.device("rocm"), "rocm"

    # Check if MPS (Apple Silicon) is available
    if torch.backends.mps.is_available():
        return torch.device('cpu'), "mps"

    # Fall back to CPU
    return torch.device("cpu"), "cpu"

# Check device info
device, backend = get_available_device()

# Check for GPU-specific details if CUDA or ROCm is available
if device.type == "cuda":
    cuda_device_count = torch.cuda.device_count()
    cuda_device_name = torch.cuda.get_device_name(0)
    cuda_version = torch.version.cuda
elif device.type == "rocm":
    cuda_device_count = torch.cuda.device_count()
    cuda_device_name = torch.cuda.get_device_name(0)
    cuda_version = torch.version.hip
else:
    cuda_device_count = 0
    cuda_device_name = "N/A"
    cuda_version = "N/A"

platform_info['device'] = device.type
platform_info['backend'] = backend
platform_info['cuda_device_count'] = cuda_device_count
platform_info['cuda_device_name'] = cuda_device_name
platform_info['cuda_version'] = cuda_version

# print(json.dumps(platform_info, indent=4))

In [31]:
import os
print(os.getcwd())

from google.colab import files
uploaded = files.upload()

/content


In [32]:
# GET ALBUM DATA
file_id = "" # if need for a specific album, put the file name here

if file_id == "":
    album_files = os.listdir('/content/MSGAI/input/')
    random_album_file = random.choice(album_files)
else:
    random_album_file = f'{file_id}.json'

with open(f'MSGAI/input/{random_album_file}', 'r') as file:
    album_data = json.load(file)

In [33]:
# POSITIVE PROMPT TEMPLATES

pos_prompt_templates = {}

pos_prompt_templates['1-long'] =  Template("""\
Album cover for this album:
Album name : {{ album.name }}
Artist{% if album.artists|length > 1 %}s{% endif %} : {{ album.artists | join(', ') }}
Release Date : {{ album.date }}
Label : {{ album.label }}
Tracks:
{% for track in album.tracks %}- {{ track.name }}\n{% endfor %}
""")

pos_prompt_templates['2-only-tracks'] =  Template("""\
Album cover for these tracks:
{% for track in album.tracks %}- {{ track.name }}\n{% endfor %}
""")

pos_prompt_templates['3-long-with-track-similarity'] =  Template("""\
Album cover for this album:
Album name : {{ album.name }}
Artist{% if album.artists|length > 1 %}s{% endif %} : {{ album.artists | join(', ') }}
Release Date : {{ album.date }}
Label : {{ album.label }}

{% if are_tracks_similar(album.tracks) %} Track format : {% for track in similarity_groups(album.tracks) %}- {{ track.name }}\n{% endfor %}
{% else %} Tracks:
{% for track in album.tracks %}- {{ track.name }}\n{% endfor %}{% endif %}
""")


In [34]:
# NEGATIVE PROMPT TEMPLATES

neg_prompts = {}
neg_prompts['1-no-text'] = "text"

In [35]:
# OTHER PARAMETERS

INFERENCE_STEPS = [20, 100] # the lower the faster but loses in quality
GUIDANCE_SCALE = [5, 10] # the higher the more it follows the prompt BUT loses in creativity
BATCH_SIZE = [1, 2] # Number of images to generate in parallel
#MODELS = ['sd-legacy/stable-diffusion-v1-5', 'stabilityai/stable-diffusion-2']

In [36]:
# CREATE RUN PARAMETERS

runs = []

is_similar = are_tracks_similar(album_data['tracks'])

for pos_key, template in pos_prompt_templates.items():
    run_input = {
        'computer_specs': platform_info,
        'album_id': album_data['id'],
    }

    # Skip prompt if the tracks are not similar
    if pos_key == '3-long-with-track-similarity' and not is_similar:
        print(f"Skipping prompt {pos_key} as tracks are not similar.")
        continue

    run_input['positive-prompt'] = pos_key

    for neg_key, neg_prompt in neg_prompts.items():
        run_input['negative-prompt'] = neg_key

        for step in INFERENCE_STEPS:
            run_input['inference_steps'] = step

            for scale in GUIDANCE_SCALE:
                run_input['guidance_scale'] = scale

                for batch in BATCH_SIZE:
                    run_input['batch_size'] = batch

                    #for model in MODELS:
                    #run_input['model_id'] = model
                    runs.append(run_input.copy())

print(f"Total runs: {len(runs)}")

Skipping prompt 3-long-with-track-similarity as tracks are not similar.
Total runs: 16


In [37]:
# #LOAD MODEL

# Stable Diffusion model: https://huggingface.co/stabilityai/stable-diffusion-2
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", torch_dtype=torch.float32)
pipe = pipe.to("cuda")  # Use "cpu" if CUDA is not available
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

#to improve quality
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32).to("cuda")
pipe.vae = vae

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [23]:
# RUNS

curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")

output_dir = f'output/{curr_time}'
os.makedirs(f'{output_dir}/images', exist_ok=True)
os.makedirs(f'{output_dir}/runs', exist_ok=True)

for run_input in runs:
    # get ids for file naming
    img_id = uuid.uuid4().hex[:8]
    run_id = get_run_id(run_input)

    positive_prompt = pos_prompt_templates[run_input['positive-prompt']].render(album=album_data, are_tracks_similar=are_tracks_similar, similarity_groups=similarity_groups)
    negative_prompt = neg_prompts[run_input['negative-prompt']]

    #model_id = run_input['model_id']
    inference_steps = run_input['inference_steps']
    guidance_scale = run_input['guidance_scale']
    batch_size = run_input['batch_size']

    start_time = time.time()

    # TODO 1 : run the model here (remove temp images)
    imgs = []
    #for i in range(batch_size):
    #    imgs.append(Image.new('RGB', (256, 256), color = (73, 109, 137)))

    imgs.append(pipe(
                prompt= positive_prompt,
                negative_prompt= negative_prompt,
                guidance_scale = guidance_scale,
                num_inference_steps = inference_steps,
                batch_size = batch_size,
            ).images[0]
    )

    end_time = time.time()

    # TODO 2 : choose the measures to be returned
    run_info = run_input
    run_info['execution_time'] = end_time - start_time
    run_info['resolution'] = 'N/A'
    run_info['colour_quality'] = 'N/A'
    run_info['ssim'] = 'N/A'
    run_info['clip'] = 'N/A'
    run_info['image_id'] = img_id

    # Save the image(s)
    for i, img in enumerate(imgs):
        img.save(f"{output_dir}/images/{img_id}_{i}.png")

    # Save the run info
    with open(f"{output_dir}/runs/{run_id}.json", 'w') as f:
        json.dump(run_info, f, indent=4)


In [79]:
#Imports & quality measurements

import os
import uuid
import json
import time
import queue
import torch
from datetime import datetime
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms

# Function to measure color quality (mean color of the image)
def measure_color_quality(image):
    img_array = np.array(image)
    mean_color = img_array.mean(axis=(0, 1))  # Mean color across all pixels
    return np.mean(mean_color)  # Return mean of the RGB channels

# Function to check image resolution
def get_resolution(image):
    return image.size  # returns (width, height)

# Load the CLIP model and processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

# Function to compute CLIP score
def compute_clip_score(image, text):
    # Resize the image to 224x224 if it is not already in that size
    if image.size != (224, 224):
        image = image.resize((224, 224))  # Resize to 224x224

    # Preprocess the image and text
    inputs = processor(text=text, images=image, return_tensors="pt", padding=True)

    # Get the image and text embeddings from CLIP model
    with torch.no_grad():
        outputs = model(**inputs)

    # Compute cosine similarity between image and text embeddings
    image_embeddings = outputs.image_embeds
    text_embeddings = outputs.text_embeds

    # Normalize the embeddings
    image_embeddings = image_embeddings / image_embeddings.norm(p=2, dim=-1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(p=2, dim=-1, keepdim=True)

    # Calculate cosine similarity (dot product)
    similarity = (image_embeddings @ text_embeddings.T).squeeze()

    return similarity.item()

In [92]:
#Handling of multiple requests

# Set the number of total runs
total_runs = 3  # Specify the number of runs you want to repeat (can be adjusted)
trade_off = 2 # when does the trade-off start between quality and latency
total_execution_time = 0
total_images_generated = 0
total_generation_time = 0

# Create the output directory with the current timestamp
curr_time = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f'output/{curr_time}'
os.makedirs(f'{output_dir}/images', exist_ok=True)
os.makedirs(f'{output_dir}/runs', exist_ok=True)

# Initialize the FIFO Queue
image_request_queue = queue.Queue()

# Simulate enqueuing requests
for _ in range(total_runs):
    image_request_queue.put(1)  # Enqueue one request for each total run
    request_time = time.time()  # Capture time when the request is added to the queue
    image_request_queue.put({'request_time': request_time})

# Outer loop for total runs
for run_num in range(total_runs):
    print(f"Running iteration {run_num + 1} of {total_runs}...")

    while not image_request_queue.empty():  # Use while loop to stop when queue is empty
        print(f'Queue Size = {image_request_queue.qsize()}')

        # Get the request and timestamp when it was enqueued
        request = image_request_queue.get()

        # Get IDs for file naming
        img_id = uuid.uuid4().hex[:8]
        run_id = get_run_id(run_input)

        # Generate the prompts
        positive_prompt = pos_prompt_templates[run_input['positive-prompt']].render(
            album=album_data,
            are_tracks_similar=are_tracks_similar,
            similarity_groups=similarity_groups
        )
        negative_prompt = neg_prompts[run_input['negative-prompt']]

        # Set the model parameters
        if image_request_queue.qsize() >= trade_off:
          inference_steps = 20
          height, width = 224, 224
        else:
          inference_steps = 40
          height, width = 768, 768
        print(f"Inference Steps: {inference_steps}")

        guidance_scale = run_input['guidance_scale']
        batch_size = run_input['batch_size']

        # Start timing the image generation
        start_time = time.time()

        # Run the model and generate the images
        imgs = []
        imgs.append(pipe(
            prompt=positive_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=inference_steps,
            batch_size=batch_size,
            height=height,
            width=width
        ).images[0])

        # Save the image(s)
        for i, img in enumerate(imgs):
            img.save(f"{output_dir}/images/{img_id}_{run_num+1}_{i}.png")  # Add the run number to the image name

        # Save the run info
        with open(f"{output_dir}/runs/{run_id}_{run_num+1}.json", 'w') as f:
            json.dump(run_info, f, indent=4)

        # Remove the image from the queue
        image_request_queue.get()  # Dequeue one image request

        # End timing the image generation
        end_time = time.time()
        latency = end_time - start_time  # Latency (time used to generate the image)

        # Add the time for this image to the total generation time
        total_execution_time += latency
        total_images_generated += 1
        total_gen_time = end_time - request_time
        total_generation_time += total_gen_time

        # Measure resolution
        image = imgs[0]
        resolution = get_resolution(image)

        # Measure color quality
        color_quality = measure_color_quality(image)

        # Measure CLIP score
        if image.size != (224, 224):
            image = image.resize((224, 224))
        clip_score = compute_clip_score(image, positive_prompt)

        # Gather information about the run
        run_info = run_input
        run_info['execution_time'] = latency
        run_info['resolution'] = resolution
        run_info['colour_quality'] = color_quality
        run_info['image_id'] = img_id
        run_info['total_generation_time'] = total_gen_time

        # Print the results after dequeuing
        print(f"Run {run_num + 1} - Image {i}: Latency = {latency:.2f}s, Total Generation Time = {total_gen_time:.2f}s, "
              f"Resolution = {resolution}, Color Quality = {color_quality:.2f}, CLIP score: {clip_score}")

    print(f"Iteration {run_num + 1} completed.\n")

    # Calculate the average time per image generation
    average_time_per_image = total_execution_time / total_images_generated if total_images_generated > 0 else 0
    average_total_gen_time = total_generation_time / total_images_generated if total_images_generated > 0 else 0

print(f"Total generation time: {total_generation_time:.2f}s")
print(f"Average generation time per image: {average_total_gen_time:.2f}s")
print(f"Average execution time per image: {average_time_per_image:.2f}s")
print(f"All {total_runs} runs completed and saved.")


Running iteration 1 of 3...
Queue Size = 6
Inference Steps: 20


  0%|          | 0/20 [00:00<?, ?it/s]

Run 1 - Image 0: Latency = 1.22s, Total Generation Time = 1.22s, Resolution = (224, 224), Color Quality = 78.69, CLIP score: 0.23766645789146423
Queue Size = 4
Inference Steps: 20


  0%|          | 0/20 [00:00<?, ?it/s]

Run 1 - Image 0: Latency = 1.23s, Total Generation Time = 2.60s, Resolution = (224, 224), Color Quality = 111.62, CLIP score: 0.2503531277179718
Queue Size = 2
Inference Steps: 40


  0%|          | 0/40 [00:00<?, ?it/s]

Run 1 - Image 0: Latency = 21.43s, Total Generation Time = 24.18s, Resolution = (768, 768), Color Quality = 121.16, CLIP score: 0.30140429735183716
Iteration 1 completed.

Running iteration 2 of 3...
Iteration 2 completed.

Running iteration 3 of 3...
Iteration 3 completed.

Total generation time: 28.01s
Average generation time per image: 9.34s
Average execution time per image: 7.96s
All 3 runs completed and saved.
