# Input variables

In [None]:
# Select the dataset: keep exactly one DATA_PATH assignment active.
DATA_PATH = "/kaggle/input/dog-breed-images/pug"
# DATA_PATH = "/kaggle/input/dog-breed-images/golden_retriever"
# DATA_PATH = "/kaggle/input/calico-25"
# DATA_PATH = "/kaggle/input/calico-cat"

# `animal` is the word passed to calculate_best_function and interpolated into
# the inference prompts.
# For the calico datasets, uncomment the 'cat' line AND comment out the 'dog'
# line below: the later assignment wins, so uncommenting alone has no effect.
#animal = 'cat'
animal = 'dog'

# Initialize Stable Diffusion, CLIP

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import torchvision.transforms as T
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPTokenizer, CLIPModel
from diffusers import StableDiffusionPipeline
from sklearn.model_selection import train_test_split
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

In [None]:
# Notebook-wide constants
# Seed passed to the torch/numpy RNGs and to train_test_split(random_state=...).
RANDOM_SEED = 42
# File the training loop saves the projection model's state_dict to, and the
# inference section loads it from.
SAVE_PATH = "/kaggle/working/projection_model.pt"
# Placeholder token added to the Stable Diffusion tokenizer; its embedding row
# is overwritten with projected embeddings later in the notebook.
new_token = "[V]"

In [None]:
# Seed every RNG source this notebook imports (Python `random`, NumPy, PyTorch)
# so stochastic steps are repeatable from a fresh kernel.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Device string used for model placement, autocast and the gradient scaler:
# "cuda" when a GPU is visible to PyTorch, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the pretrained Stable Diffusion pipeline and CLIP model/processor
@torch.no_grad()
def initialize_models():
    """Load the Stable Diffusion pipeline, the CLIP model and the CLIP processor.

    Both models are requested in float16 and moved to the module-level
    ``device``. The processor is loaded with ``use_fast=True``.

    Returns:
        tuple: ``(dream_model, clip_model, clip_processor)`` in that order.
    """
    sd_checkpoint = "runwayml/stable-diffusion-v1-5"
    clip_checkpoint = "openai/clip-vit-large-patch14"

    pipeline = StableDiffusionPipeline.from_pretrained(
        sd_checkpoint,
        torch_dtype=torch.float16,
    )
    pipeline = pipeline.to(device)

    clip = CLIPModel.from_pretrained(clip_checkpoint, torch_dtype=torch.float16)
    clip = clip.to(device)

    processor = CLIPProcessor.from_pretrained(clip_checkpoint, use_fast=True)

    return pipeline, clip, processor

dream_model, clip_model, clip_processor = initialize_models()

In [None]:
# Tokenizers of the two models
dream_tokenizer = dream_model.tokenizer
clip_tokenizer = clip_processor.tokenizer

# Input token-embedding modules of the two text encoders
dream_embeddings = dream_model.text_encoder.get_input_embeddings()
clip_embeddings = clip_model.text_model.embeddings.token_embedding

In [None]:
def GetEmb(model, x):
    """Return an embedding for ``x`` from the diffusion embedding table or from CLIP.

    Args:
        model: ``"dream"`` selects a row lookup in ``dream_embeddings``; any
            other value selects the CLIP path.
        x: For ``"dream"``, a token (or list of tokens) that is converted to
            ids with ``dream_tokenizer``. For the CLIP path, a ``str`` yields
            text features; anything else is handed to ``clip_processor`` as
            image input and yields image features.

    Returns:
        The embedding tensor. CLIP features are computed under
        ``torch.no_grad()``; the ``"dream"`` lookup indexes the weight
        directly.
    """
    if model == "dream":
        # x is assumed to be a token or a list of tokens
        ids = dream_tokenizer.convert_tokens_to_ids(x)
        return dream_embeddings.weight[ids]

    if isinstance(x, str):
        # CLIP text branch
        encoded = clip_tokenizer(x, return_tensors="pt", padding=True, truncation=True)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            return clip_model.get_text_features(**encoded)

    # CLIP image branch
    pixels = clip_processor(images=x, return_tensors="pt").to(device)
    with torch.no_grad():
        return clip_model.get_image_features(**pixels)

# Train Projection Map

In [None]:
# Upper bound for the training loop, which iterates range(N_EPOCHS+1) and may
# stop earlier through its early-stopping check.
N_EPOCHS = 2000

In [None]:
# Projection model: two linear layers with a ReLU in between (no normalization layer)
class EmbeddingProjection(nn.Module):
    """Small MLP (Linear -> ReLU -> Linear) used to project embeddings.

    Args:
        input_dim: Size of the input vectors; defaults to
            ``clip_embeddings.embedding_dim``.
        output_dim: Size of the output vectors; defaults to
            ``dream_embeddings.embedding_dim``.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim=clip_embeddings.embedding_dim, output_dim=dream_embeddings.embedding_dim, hidden_dim=1024):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply linear1, ReLU and linear2 to ``x`` and return the result."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

projection_model = EmbeddingProjection().to(device)

In [None]:
@torch.no_grad()
def get_common_embeddings():
    """Build paired embedding matrices for tokens present in both vocabularies.

    For every token shared by ``dream_tokenizer`` and ``clip_tokenizer`` the
    function stores ``GetEmb("dream", token)`` and ``GetEmb("clip", token)``
    in the same row of two pre-allocated tensors on ``device``.

    Rows follow the sorted token order. Iterating a plain ``set`` of strings
    can yield a different order in every interpreter run (hash randomization),
    which would make the later fixed-seed train/validation split select
    different tokens from run to run.

    Returns:
        tuple: ``(clip_embs, dream_embs)``, both with one row per common token.
    """
    dream_tokens = set(dream_tokenizer.get_vocab().keys())
    clip_tokens = set(clip_tokenizer.get_vocab().keys())
    # Deterministic row order (see docstring).
    common_tokens = sorted(dream_tokens.intersection(clip_tokens))

    # Pre-allocate the output matrices; rows are filled one token at a time.
    num_tokens = len(common_tokens)
    dream_embs = torch.zeros(num_tokens, dream_embeddings.embedding_dim, device=device)
    clip_embs = torch.zeros(num_tokens, clip_embeddings.embedding_dim, device=device)

    for idx, token in enumerate(common_tokens):
        dream_embs[idx] = GetEmb("dream", token)
        # NOTE(review): the raw vocabulary string is passed to GetEmb as text, so it
        # goes through clip_tokenizer again; confirm this is intended for vocabulary
        # entries that are not plain words.
        clip_embs[idx] = GetEmb("clip", token)

    print(f"Number of common words: {num_tokens} \nNumber of words in diffusion model: {len(dream_tokens)} \nNumber of words in CLIP model: {len(clip_tokens)}")
    # NOTE(review): this subtraction presumes both matrices have the same width.
    print(f"Difference between embeddings {torch.norm(dream_embs-clip_embs)}")

    return clip_embs, dream_embs

clip_emb, dream_emb = get_common_embeddings()

In [None]:
# Register the placeholder token with the Stable Diffusion tokenizer.
dream_tokenizer.add_tokens(new_token)
# Resize the text encoder's embedding table to the new tokenizer length.
dream_model.text_encoder.resize_token_embeddings(len(dream_tokenizer), mean_resizing=False)
# Id of the placeholder token; later cells write projected embeddings into
# dream_embeddings.weight.data[new_token_id].
new_token_id = dream_tokenizer.convert_tokens_to_ids(new_token)

In [None]:
# Optimisation setup for the projection model
criterion = nn.MSELoss()
optimizer = optim.AdamW(
    projection_model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)
# Gradient scaler used together with the autocast region of the training loop
scaler = torch.amp.GradScaler(device)

In [None]:
# Random 80/20 train/validation split of the paired embeddings.
# No `stratify` argument is passed, so the split is not stratified.
# NOTE(review): clip_emb / dream_emb are torch tensors allocated on `device`;
# confirm the installed scikit-learn indexes them as expected.
X_train, X_val, Y_train, Y_val = train_test_split(
    clip_emb, dream_emb, test_size=0.2, random_state=RANDOM_SEED
)

In [None]:
# Full-batch training of the projection model. Every 50th epoch the validation
# loss is computed; an improvement saves a checkpoint to SAVE_PATH, and five
# consecutive checks without improvement stop training.
best_loss = float('inf')
early_stop_counter = 0

for epoch in range(N_EPOCHS+1):
    projection_model.train()
    optimizer.zero_grad()

    # Forward pass and loss under autocast (mixed precision)
    with torch.amp.autocast(device):
        outputs = projection_model(X_train)
        loss = criterion(outputs, Y_train)

    # Scaled backward pass, optimizer step, then scaler bookkeeping
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Only validate on every 50th epoch
    if epoch % 50 != 0:
        continue

    projection_model.eval()
    with torch.no_grad():
        val_outputs = projection_model(X_val)
        val_loss = criterion(val_outputs, Y_val)

    print(f"Epoch {epoch}: Train Loss = {loss.item():.4f}, Val Loss = {val_loss.item():.4f}")

    # Improvement: remember it, reset patience, checkpoint the weights
    if val_loss < best_loss:
        best_loss = val_loss
        early_stop_counter = 0
        torch.save(projection_model.state_dict(), SAVE_PATH)
        continue

    # No improvement: spend one unit of patience
    early_stop_counter += 1
    if early_stop_counter >= 5:
        print("Early stopping triggered")
        break

## Find the best distance metric

In [None]:
def CosineSimilarity(emb1, emb2):
    """Return the cosine similarity of two embeddings as a Python float.

    The similarity is taken along the last dimension. The result is converted
    with ``.item()``, so it must contain exactly one element.
    """
    return F.cosine_similarity(emb1, emb2, dim=-1).item()

In [None]:
def EuclideanDistance(emb1, emb2):
    """Return the L2 norm of ``emb1 - emb2`` as a Python float."""
    delta = emb1 - emb2
    l2 = torch.norm(delta, p=2)
    return l2.item()

In [None]:
def ManhattanDistance(emb1, emb2):
    """Return the L1 distance (sum of absolute differences) as a Python float.

    The value is converted with ``.item()`` so that this metric returns the
    same kind of value as ``CosineSimilarity`` and ``EuclideanDistance`` and
    prints as a plain number instead of a tensor repr.
    """
    return torch.sum(torch.abs(emb1 - emb2)).item()

In [None]:
# Samples at growing noise levels around an embedding (Manhattan experiment)
def sample_embeddings_for_manhattan(original_embedding, n_samples=9):
    """Return noisy float32 copies of ``original_embedding`` at increasing noise scales.

    Ten samples are produced. The Gaussian noise scale starts at 0.000001 and
    grows by 0.001 after each sample, so later samples lie further from the
    original.

    Args:
        original_embedding: Tensor to perturb; noise is drawn with its shape
            on its device.
        n_samples: Accepted but currently not used; the count is fixed at 10.

    Returns:
        list: Ten float32 tensors shaped like ``original_embedding``.
    """
    shape = original_embedding.shape
    target_device = original_embedding.device
    base = original_embedding.to(torch.float32)

    samples = []
    noise_scale = 0.000001
    while len(samples) < 10:
        noise = torch.randn(shape, device=target_device)
        samples.append(base + noise_scale*noise)
        noise_scale = noise_scale + 0.001

    return samples

In [None]:
# Samples at growing noise levels around an embedding (Euclidean experiment)
def sample_embeddings_for_euclidean(original_embedding, n_samples=9):
    """Return noisy float32 copies of ``original_embedding`` at increasing noise scales.

    Ten samples are produced. The Gaussian noise scale starts at 0.00001 and
    grows by 0.01 after each sample, so later samples lie further from the
    original.

    Args:
        original_embedding: Tensor to perturb; noise is drawn with its shape
            on its device.
        n_samples: Accepted but currently not used; the count is fixed at 10.

    Returns:
        list: Ten float32 tensors shaped like ``original_embedding``.
    """
    shape = original_embedding.shape
    target_device = original_embedding.device
    base = original_embedding.to(torch.float32)

    samples = []
    noise_scale = 0.00001
    while len(samples) < 10:
        noise = torch.randn(shape, device=target_device)
        samples.append(base + noise_scale*noise)
        noise_scale = noise_scale + 0.01

    return samples

In [None]:
def sample_embeddings_for_cos_sim(original_embedding, n_samples=9):
    """Sample embeddings spread over the cosine-similarity range of an embedding.

    Three groups are returned, in this order:

    1. ``original + f * ||original|| * noise`` for each ``f`` in ``factors``
       (similarity near 1 for small ``f``).
    2. ``-original + f * ||original|| * noise`` for each ``f`` in ``factors``
       (similarity near -1 for small ``f``).
    3. Five random vectors with their component along ``original`` removed
       and then scaled to unit length (similarity near 0).

    Args:
        original_embedding: Tensor to sample around; it is converted to
            float32 and noise is drawn with its shape on its device.
        n_samples: Accepted but currently not used; the number of samples is
            ``2 * len(factors) + 5``.

    Returns:
        list: float32 tensors shaped like ``original_embedding``.
    """
    sampled_embeddings = []
    embedding_shape = original_embedding.shape

    device = original_embedding.device
    original_embedding = original_embedding.to(torch.float32)
    factors = [0.001, 0.030, 0.040, 0.050, 0.06, 0.07, 0.080, 0.1, 0.2, 0.5, 0.8, 1.0, 1.1]
    # factors = [0.02, 0.045, 0.055, 0.065, 0.075]
    # factors = [0.048, 0.049, 0.050, 0.051, 0.052, 0.053]

    # Noise is scaled relative to the embedding's own norm.
    scale = original_embedding.norm()

    # Groups 1 and 2: perturb the embedding itself, then its negation.
    for direction in (original_embedding, -original_embedding):
        for factor in factors:
            noise = torch.randn(embedding_shape, device=device)
            sampled_embeddings.append(direction + factor*scale*noise)

    # Group 3: nearly orthogonal samples.
    flat_original = original_embedding.flatten()
    for _ in range(5):
        random_vector = torch.randn(embedding_shape, device=device)
        # Remove the component parallel to the original embedding.
        coefficient = random_vector.flatten().dot(flat_original) / (flat_original.norm() ** 2)
        random_vector = random_vector - coefficient * original_embedding

        # Normalize to unit length: divide by the norm, not by its square.
        random_vector = random_vector / random_vector.flatten().norm()

        sampled_embeddings.append(random_vector)

    return sampled_embeddings


In [None]:
# Cosine-similarity experiment: for each word, generate one image per sampled
# embedding and print the similarity between the sample and the word's CLIP
# embedding, to see whether the metric tracks image quality.

words = ["cat", "dog", "bottle"]
prompt = "A photo of a [V]"

projection_model.eval()

for word in words:
    original_embedding = GetEmb("clip", word)
    sampled_embeddings = sample_embeddings_for_cos_sim(original_embedding)

    for sampled_embedding in sampled_embeddings:
        # Project the sample and install it as the placeholder token's embedding
        with torch.no_grad():
            dream_sampled_embedding = projection_model(sampled_embedding)
            target_dtype = dream_embeddings.weight.dtype
            dream_embeddings.weight.data[new_token_id] = dream_sampled_embedding.to(target_dtype)

        image = dream_model(prompt).images[0]
        plt.imshow(image)
        plt.axis('off')
        plt.show()

        distance = CosineSimilarity(original_embedding, sampled_embedding)
        print(f"Word: {word}, Distance: {distance}")


In [None]:
# Euclidean-distance experiment: for each word, generate one image per sampled
# embedding and print the distance between the sample and the word's CLIP
# embedding, to see whether the metric tracks image quality.

words = ["cat", "dog", "bottle"]
prompt = "A photo of a [V]"

projection_model.eval()

for word in words:
    original_embedding = GetEmb("clip", word)
    sampled_embeddings = sample_embeddings_for_euclidean(original_embedding)

    for sampled_embedding in sampled_embeddings:
        # Project the sample and install it as the placeholder token's embedding
        with torch.no_grad():
            dream_sampled_embedding = projection_model(sampled_embedding)
            target_dtype = dream_embeddings.weight.dtype
            dream_embeddings.weight.data[new_token_id] = dream_sampled_embedding.to(target_dtype)

        image = dream_model(prompt).images[0]
        plt.imshow(image)
        plt.axis('off')
        plt.show()

        distance = EuclideanDistance(original_embedding, sampled_embedding)
        print(f"Word: {word}, Distance: {distance}")

In [None]:
# Manhattan-distance experiment: for each word, generate one image per sampled
# embedding and print the distance between the sample and the word's CLIP
# embedding, to see whether the metric tracks image quality.

words = ["cat", "dog", "bottle"]
prompt = "A photo of a [V]"

projection_model.eval()

for word in words:
    original_embedding = GetEmb("clip", word)
    sampled_embeddings = sample_embeddings_for_manhattan(original_embedding)

    for sampled_embedding in sampled_embeddings:
        # Project the sample and install it as the placeholder token's embedding
        with torch.no_grad():
            dream_sampled_embedding = projection_model(sampled_embedding)
            target_dtype = dream_embeddings.weight.dtype
            dream_embeddings.weight.data[new_token_id] = dream_sampled_embedding.to(target_dtype)

        image = dream_model(prompt).images[0]
        plt.imshow(image)
        plt.axis('off')
        plt.show()

        distance = ManhattanDistance(original_embedding, sampled_embedding)
        print(f"Word: {word}, Distance: {distance}")

In [None]:
# Metric used by calculate_best_function to score pooled embeddings against the
# target embedding; that function keeps the candidate with the LARGEST score.
best_metric = CosineSimilarity

## Find the best pooling ("collapse") function for the chosen distance metric

In [None]:
def get_input_embeddings(n = -1):
    """Return CLIP image embeddings for the files in ``DATA_PATH``.

    Args:
        n: Maximum number of files to embed. ``None`` or any negative value
            (the default is ``-1``) means every file. Slicing with ``[0:-1]``
            would silently drop the last file, so the limit is only applied
            for ``n >= 0``.

    Returns:
        Tensor made by concatenating the per-image embeddings along dim 0.

    Raises:
        ValueError: If no file was selected, so there is nothing to embed.
    """
    # os.listdir returns entries in arbitrary order; sort so the selection
    # (and the row order of the result) is the same on every run.
    filenames = sorted(os.listdir(DATA_PATH))
    if n is not None and n >= 0:
        filenames = filenames[:n]

    image_embeddings = []
    for filename in filenames:
        filepath = os.path.join(DATA_PATH, filename)
        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(filepath) as raw_image:
            image = raw_image.convert("RGB")

        image_embeddings.append(GetEmb("clip", image))

    if not image_embeddings:
        raise ValueError(f"No images to embed in {DATA_PATH!r} (n={n})")

    # Concatenate the per-image embeddings into one tensor
    return torch.cat(image_embeddings, dim=0)

In [None]:
# target is a word
def calculate_best_function(target):
    """Pick the pooling function whose pooled image embedding best matches ``target``.

    All image embeddings from ``get_input_embeddings()`` are reduced along
    dim 0 by each candidate pooling function, and every pooled vector is
    scored against the CLIP text embedding of ``target`` using
    ``best_metric``.

    Args:
        target: Word whose CLIP text embedding is the reference.

    Returns:
        tuple: ``(best_func, best_score)``, the candidate with the largest
        score and that score. ``best_func`` stays ``None`` if no score
        compares greater than ``-inf`` (for example when all scores are NaN).
    """
    embeddings = get_input_embeddings()
    target_embedding = GetEmb("clip", target)

    # Candidate pooling functions; each reduces along dim 0.
    # Their __name__ is printed by the caller, so the names are kept as they are.
    def mean(embeddings):
        return torch.mean(embeddings, dim=0)

    def median(embeddings):
        return torch.median(embeddings, dim=0).values

    def max_(embeddings):
        return torch.max(embeddings, dim=0).values

    def min_(embeddings):
        return torch.min(embeddings, dim=0).values

    def quantile(embeddings, p=0.75):
        # The input is cast to float before torch.quantile.
        return torch.quantile(embeddings.float(), p, dim=0)

    def trimmed_mean(embeddings, trim=0.1):
        count = len(embeddings)
        k = int(trim * count)
        ordered = torch.sort(embeddings, dim=0).values
        # Use an explicit end index: ``[k:-k]`` is an empty slice when k == 0
        # (fewer than 1/trim rows), which would make the mean NaN.
        trimmed = ordered[k:count - k]
        return torch.mean(trimmed, dim=0)

    def std_(embeddings):
        return torch.std(embeddings, dim=0)

    funcs = [mean, median, max_, min_, quantile, trimmed_mean, std_]

    # Keep the candidate with the largest score.
    # NOTE(review): "largest wins" suits a similarity such as CosineSimilarity;
    # a distance metric would need the comparison reversed.
    best_score = -float('inf')
    best_func = None

    for func in funcs:
        pooled_embedding = func(embeddings)
        distance = best_metric(pooled_embedding, target_embedding)

        if distance > best_score:
            best_score = distance
            best_func = func

    return best_func, best_score

In [None]:
# Select the pooling function that scores best for the configured `animal`
# word, then show its name and score.
best_func, best_score = calculate_best_function(animal)
print(best_func.__name__, best_score)

# Inference


In [None]:
# Restore the checkpoint the training loop saved to SAVE_PATH the last time the
# validation loss improved.
projection_model.load_state_dict(torch.load(SAVE_PATH, weights_only=True))

In [None]:
# Transform the per-image embeddings into a single float32 embedding using the
# selected pooling function
imgs_embeddings = get_input_embeddings()
pooled_embedding = best_func(imgs_embeddings)
desired_embedding = pooled_embedding.to(torch.float32)

In [None]:
# Project the pooled embedding with the projection model and store the result
# as the embedding of the placeholder token.
projection_model.eval()
with torch.no_grad():
    # Bind the result to a new name instead of overwriting `desired_embedding`:
    # re-running this cell must not feed an already projected vector through
    # the projection model a second time.
    projected_embedding = projection_model(desired_embedding)
    dream_embeddings.weight.data[new_token_id] = projected_embedding.to(dream_embeddings.weight.dtype)

In [None]:
# Generate and display one image per prompt, in the listed order, using the
# placeholder token whose embedding was set in the previous cell.
inference_prompts = [
    "A photo of a [V]",
    f"A photo of a [V] {animal} swimming in the sea",
    f"A photo of a [V] {animal} in front of the Eiffel Tower",
    f"A photo of a [V] {animal} in Christmas apparel",
    f"A painting of a [V] {animal} in Van Gogh style",
    f"A photo of a [V] {animal} in front of Tower Bridge in London",
]

for prompt in inference_prompts:
    result = dream_model(prompt, num_inference_steps=100, guidance_scale=7.0)
    image = result.images[0]
    plt.imshow(image)
    plt.axis('off')  # Hide axis
    plt.show()