In [1]:
import torch

# Target the GPU when PyTorch reports a usable CUDA device, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [2]:
import torch

# Check if CUDA (GPU) is available
if not torch.cuda.is_available():
    # No usable GPU: keep everything on the CPU.
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")
else:
    # A CUDA device is present: target it and report its name.
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name())


Using GPU: NVIDIA GeForce GTX 1650 Ti


In [3]:
import torch
# Create a tensor on the CPU
tensor = torch.randn((3, 3))
# Move the tensor to the GPU when one is available; otherwise leave it on the
# CPU so this cell still runs on machines without CUDA.
tensor = tensor.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# #@title Setup
import os, subprocess

def setup():
    """Install the third-party packages this notebook needs.

    Runs ``pip install`` once per package through the interpreter that is
    executing this kernel (``sys.executable -m pip``), so the packages are
    installed into the environment the notebook imports from. pip's stdout
    and stderr are printed together for each command, and a non-zero exit
    code is reported instead of passing silently.

    Returns:
        None. Output is printed; a failed install does not raise.
    """
    import sys  # local import: the surrounding cell only imports os and subprocess

    packages = ['gradio', 'open_clip_torch', 'clip-interrogator']
    for package in packages:
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', package],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # show pip errors in the cell output as well
        )
        print(result.stdout.decode('utf-8', errors='replace'))
        if result.returncode != 0:
            print(f"WARNING: 'pip install {package}' exited with code {result.returncode}")

# Install the required packages first; the imports further down depend on them.
setup()


# Colab form fields selecting the caption model and the CLIP model.
caption_model_name = 'blip-large' #@param ["blip-base", "blip-large", "git-large-coco"]
clip_model_name = 'ViT-L-14/openai' #@param ["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"]

# Imported after setup() so a fresh environment already has the packages.
import gradio as gr
from clip_interrogator import Config, Interrogator

# Build the shared Interrogator (`ci`) from the selections above; image_analysis
# and image_to_prompt read it as a global.
config = Config()
config.clip_model_name = clip_model_name
config.caption_model_name = caption_model_name
ci = Interrogator(config)

def image_analysis(image):
    """Rank an image against each of the interrogator's label tables.

    The image is converted to RGB and turned into features with
    ``ci.image_to_features``. Each table (mediums, artists, movements,
    trendings, flavors) is asked for 5 ranked labels, and each label is then
    paired with the value ``ci.similarities`` returns for it.

    Args:
        image: Object exposing ``convert('RGB')`` (e.g. a PIL image).

    Returns:
        A 5-tuple of dicts mapping label -> similarity, in the order
        (mediums, artists, movements, trendings, flavors).
    """
    features = ci.image_to_features(image.convert('RGB'))

    # Rank every table first, then score the winners, as two separate passes.
    label_tables = (ci.mediums, ci.artists, ci.movements, ci.trendings, ci.flavors)
    top_labels = [table.rank(features, 5) for table in label_tables]

    return tuple(
        dict(zip(labels, ci.similarities(features, labels)))
        for labels in top_labels
    )

def image_to_prompt(image, mode):
    """Generate a prompt for ``image`` using the requested interrogation mode.

    Before interrogating, ``ci.config.chunk_size`` and
    ``ci.config.flavor_intermediate_count`` are both set to 2048 when the
    configured CLIP model is ``ViT-L-14/openai`` and to 1024 otherwise.

    Args:
        image: Object exposing ``convert('RGB')`` (e.g. a PIL image).
        mode: One of ``'best'``, ``'classic'``, ``'fast'`` or ``'negative'``.

    Returns:
        Whatever the matching ``ci.interrogate*`` method returns for the
        RGB-converted image.

    Raises:
        ValueError: If ``mode`` is not a supported name. Previously an
            unknown mode fell through and returned ``None``, which only
            failed later when the result was written or used as a filename.
    """
    method_names = {
        'best': 'interrogate',
        'classic': 'interrogate_classic',
        'fast': 'interrogate_fast',
        'negative': 'interrogate_negative',
    }
    if mode not in method_names:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(method_names)}"
        )

    # Both limits depend on the loaded CLIP model in the same way.
    limit = 2048 if ci.config.clip_model_name == "ViT-L-14/openai" else 1024
    ci.config.chunk_size = limit
    ci.config.flavor_intermediate_count = limit

    # Look the method up by name so only the requested one has to exist on `ci`.
    return getattr(ci, method_names[mode])(image.convert('RGB'))

In [None]:
#@title Batch process a folder of images 📁 -> 📝

#@markdown This will generate prompts for every image in a folder and either save results 
#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.
#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)
#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing) notebook.

import csv
import os
from IPython.display import clear_output, display
from PIL import Image
from tqdm import tqdm

# Folder scanned for .jpg/.png images by the batch loop below.
folder_path = "TAD66K dataset/TAD66K/" #@param {type:"string"}
# Interrogation mode passed to image_to_prompt for every image.
prompt_mode = 'fast' #@param ["best","fast","classic","negative"]
# 'desc.csv' writes a CSV into the folder; 'rename' renames each image after its prompt.
output_mode = 'desc.csv' #@param ["desc.csv","rename"]
# Length limit handed to sanitize_for_filename, which reserves 4 characters for the extension.
max_filename_len = 128 #@param {type:"integer"}


def sanitize_for_filename(prompt: str, max_len: int) -> str:
    """Reduce a prompt to characters that are safe to use in a filename.

    Keeps alphanumeric characters plus ``,._-!`` and spaces, strips
    surrounding whitespace, then truncates to ``max_len - 4`` characters so a
    four-character extension such as ``.jpg`` still fits.

    Args:
        prompt: Text to turn into a filename stem.
        max_len: Maximum length of the final filename, extension included.
            Values below 4 leave no room for a stem and yield ``""``.

    Returns:
        The sanitized, truncated stem (without extension).
    """
    allowed_punctuation = ",._-! "
    kept = "".join(c for c in prompt if c.isalnum() or c in allowed_punctuation)
    # Clamp at zero: a negative slice bound would trim characters from the end
    # of the name instead of limiting its length.
    budget = max(max_len - 4, 0)  # extra space for extension
    return kept.strip()[:budget]

ci.config.quiet = True

# os.listdir returns entries in arbitrary order; sort them so the processing
# order and the CSV rows are reproducible between runs.
if os.path.exists(folder_path):
    files = sorted(f for f in os.listdir(folder_path) if f.endswith(('.jpg', '.png')))
else:
    files = []

prompts = []
for idx, file in enumerate(tqdm(files, desc='Generating prompts')):
    # Periodically wipe the cell output so printed prompts and thumbnails do not pile up.
    if idx > 0 and idx % 100 == 0:
        clear_output(wait=True)

    # Close the source file as soon as the pixels are converted so the handle
    # is released before the file is (optionally) renamed below.
    with Image.open(os.path.join(folder_path, file)) as source_image:
        image = source_image.convert('RGB')
    prompt = image_to_prompt(image, prompt_mode)
    prompts.append(prompt)

    print(prompt)
    thumb = image.copy()
    thumb.thumbnail([256, 256])
    display(thumb)

    if output_mode == 'rename':
        name = sanitize_for_filename(prompt, max_filename_len)
        ext = os.path.splitext(file)[1]
        filename = name + ext
        # Dedicated collision counter: `idx` tracks the position in `files`
        # and must not be reused here.
        suffix = 1
        while os.path.exists(os.path.join(folder_path, filename)):
            candidate = f"{name}_{suffix}{ext}"
            print(f'File {filename} already exists, trying {candidate}...')
            filename = candidate
            suffix += 1
        os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))

if len(prompts):
    if output_mode == 'desc.csv':
        # One row per image: original filename and its generated prompt.
        csv_path = os.path.join(folder_path, 'desc.csv')
        with open(csv_path, 'w', encoding='utf-8', newline='') as f:
            w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
            w.writerow(['image', 'prompt'])
            for file, prompt in zip(files, prompts):
                w.writerow([file, prompt])

        print(f"\n\n\n\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!")
    else:
        print(f"\n\n\n\nGenerated {len(prompts)} prompts and renamed your files, enjoy!")
else:
    print(f"Sorry, I couldn't find any images in {folder_path}")

In [None]:
# Debug print
print("Number of Prompts:", len(prompts))

if len(prompts):
    # Write next to the notebook under a fixed name. The previous
    # os.path.join(output_mode) used the mode string itself as the path, which
    # produced a file literally called "rename" in rename mode.
    csv_path = 'desc.csv'
    # zip stops at the shorter list, so a partially completed run (fewer prompts
    # than files) is handled without slicing `files` to a fixed length.
    rows = list(zip(files, prompts))
    with open(csv_path, 'w', encoding='utf-8', newline='') as f:
        w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
        w.writerow(['image', 'prompt'])
        w.writerows(rows)

    # Report the number of rows actually written.
    print(f"\n\n\n\nGenerated {len(rows)} prompts and saved to {csv_path}, enjoy!")
else:
    print(f"Sorry, I couldn't find any images in {folder_path}")

In [None]:
# Add this import statement
from sentence_transformers import SentenceTransformer

# Continue from where we left off

import os

# Load the SentenceTransformer model
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Encode the generated prompts (`prompts` is the list built by the batch cell)
encoded_prompts = model.encode(prompts)

# Debug print to check the shape of encoded prompts
print("Shape of encoded prompts:", encoded_prompts.shape)

# Save filename, prompt and embedding side by side
csv_encoded_path = "result/encoded_prompts.csv"
# open() does not create missing folders, so make sure "result/" exists first.
os.makedirs(os.path.dirname(csv_encoded_path), exist_ok=True)
with open(csv_encoded_path, 'w', encoding='utf-8', newline='') as f:
    w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
    w.writerow(['image', 'prompt', 'encoded_representation'])
    for file, prompt, encoded_prompt in zip(files, prompts, encoded_prompts):
        # The embedding is stored as the string form of a Python list in one CSV field.
        w.writerow([file, prompt, encoded_prompt.tolist()])


print(f"\nEncoded prompts saved to {csv_encoded_path}")

# Now you can continue with the rest of your code as needed.


In [None]:
# Add this import statement
import numpy as np

# Continue from where we left off

import os

# Save the encoded prompts to a NumPy file
npy_encoded_path = "result/encoded_prompts.npy"
# np.save does not create missing folders, so make sure "result/" exists first.
os.makedirs(os.path.dirname(npy_encoded_path), exist_ok=True)
np.save(npy_encoded_path, encoded_prompts)

# Report the path that was actually written.
print(f"\nEncoded prompts saved to {npy_encoded_path}")

# Now you can continue with the rest of your code as needed.

In [None]:
# Example verification after loading
import pandas as pd

# Reload the CSV at csv_encoded_path to confirm it parses.
data = pd.read_csv(csv_encoded_path)
# Show the first rows and the total row count as a quick sanity check.
print("Data preview:", data.head())
print("Data consistency check, Number of entries:", len(data))

In [1]:
# Use the %pip magic explicitly so the package is installed into the running
# kernel's environment and the cell does not depend on IPython automagic.
%pip install sentence-transformers

^C
Note: you may need to restart the kernel to use updated packages.


In [6]:
import csv
from sentence_transformers import SentenceTransformer

# Load the SentenceTransformer model
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Input CSV (image, prompt) and output CSV (image, prompt, embedding)
input_csv_path = "color_expert_desc_test.csv"
output_csv_path = "expert_encoded_test.csv"

# Read the CSV file
files = []
prompts = []
# newline='' lets the csv module handle line endings itself, so quoted fields
# containing embedded newlines are parsed correctly.
with open(input_csv_path, 'r', encoding='utf-8', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # Skip header; the default avoids StopIteration on an empty file
    for row in reader:
        if not row:
            # csv.reader yields [] for blank lines; skip them instead of raising IndexError.
            continue
        files.append(row[0])
        prompts.append(row[1])

# Encode the prompts
encoded_prompts = model.encode(prompts)

# Debug print to check the shape of encoded prompts
print("Shape of encoded prompts:", encoded_prompts.shape)

# Save filenames, prompts and encoded representations together
with open(output_csv_path, 'w', encoding='utf-8', newline='') as f:
    writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
    writer.writerow(['image', 'prompt', 'encoded_representation'])
    for file, prompt, encoded_prompt in zip(files, prompts, encoded_prompts):
        # The embedding is stored as the string form of a Python list in one CSV field.
        writer.writerow([file, prompt, encoded_prompt.tolist()])

print(f"\nEncoded prompts saved to {output_csv_path}")




Shape of encoded prompts: (500, 384)

Encoded prompts saved to expert_encoded_test.csv


In [7]:
import pandas as pd

# Load the encoded-prompt CSV and print it.
# NOTE(review): the embeddings were written with .tolist() into a single CSV
# field, so 'encoded_representation' is read back as text; it presumably needs
# parsing (e.g. json.loads) before numeric use — confirm against downstream code.
df = pd.read_csv('expert_encoded_test.csv')
print(df)

                            image  \
0              a4501-DSC_0354.jpg   
1    a4502-Duggan_090116_4368.jpg   
2              a4503-kme_0411.jpg   
3              a4504-_DGW7893.jpg   
4              a4505-DSC_0086.jpg   
..                            ...   
495  a4996-Duggan_090426_7783.jpg   
496            a4997-kme_0558.jpg   
497  a4998-Duggan_080210_5246.jpg   
498            a4999-DSC_0035.jpg   
499            a5000-kme_0204.jpg   

                                                prompt  \
0    Dominant colors in the image are: rgb(93, 89, ...   
1    Dominant colors in the image are: rgb(62, 54, ...   
2    Dominant colors in the image are: rgb(80, 59, ...   
3    Dominant colors in the image are: rgb(176, 176...   
4    Dominant colors in the image are: rgb(59, 57, ...   
..                                                 ...   
495  Dominant colors in the image are: rgb(191, 188...   
496  Dominant colors in the image are: rgb(220, 226...   
497  Dominant colors in the image