In [36]:
from transformers import CLIPTokenizer, CLIPTextModel, GPT2Tokenizer
import torch
import tokenizations

# Initialize the CLIPTokenizer and CLIPTextModel
clip_tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
clip_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder")
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

clip_embeddings = clip_model.get_input_embeddings().weight
gpt_vocab_size = gpt_tokenizer.vocab_size


In [40]:
def get_gpt2_logits(prompt):
    # Tokenize the prompt
    tokens = gpt_tokenizer.encode(prompt, add_special_tokens=False)

    # Create a tensor to hold the one-hot encodings
    # Shape: [sequence_length, vocab_size]
    one_hot_encodings = torch.zeros((len(tokens), gpt_vocab_size))

    # Fill the tensor with one-hot encodings
    for i, token_id in enumerate(tokens):
        one_hot_encodings[i, token_id] = 1

    return one_hot_encodings

def recover_text_from_one_hot(one_hot_encodings, tokenizer):
    # Get the token IDs from the one-hot encodings
    token_ids = one_hot_encodings.argmax(dim=-1).tolist()

    # Decode the token IDs to text
    text = tokenizer.decode(token_ids)
    return text

def convert_gpt2_to_clip_onehots(gpt2_onehots, transformation_matrix):
    # Assuming transformation_matrix is a sparse tensor
    # Perform sparse matrix multiplication
    print(transformation_matrix.shape)
    print(gpt2_onehots.shape)
    return torch.sparse.mm(gpt2_onehots, transformation_matrix).t()  # Transpose inputs to match dimensions



def create_sparse_transformation_matrix(tokens_gpt2, tokens_clip, a2b, gpt2_vocab_size, clip_vocab_size):
    # Prepare indices and values for the sparse matrix
    indices = []
    values = []
    
    for gpt2_idx, alignments in enumerate(a2b):
        gpt2_token_id = gpt_tokenizer.convert_tokens_to_ids(tokens_gpt2[gpt2_idx])
        for clip_idx in alignments:
            clip_token_id = clip_tokenizer.convert_tokens_to_ids(tokens_clip[clip_idx])
            indices.append([gpt2_token_id, clip_token_id])
            values.append(1)  # We set the value to 1 to denote alignment
    
    # Convert lists to tensors
    indices = torch.LongTensor(indices).t()  # Transpose to fit COO format
    values = torch.FloatTensor(values)
    
    # Create sparse tensor
    transformation_matrix = torch.sparse.FloatTensor(indices, values, torch.Size([gpt2_vocab_size, clip_vocab_size]))
    return transformation_matrix


def tokenize_and_align(prompt):
    # Tokenize using GPT-2
    tokens_gpt2 = gpt_tokenizer.tokenize(prompt)
    # Tokenize using Stable Diffusion's CLIPTokenizer
    tokens_clip = clip_tokenizer.tokenize(prompt)

    # Get token alignments
    a2b, b2a = tokenizations.get_alignments(tokens_gpt2, tokens_clip)
    
    return tokens_gpt2, tokens_clip, a2b, b2a

def one_hot_to_embeddings(one_hot_encodings, embeddings):
    """
    Convert one-hot encodings to embeddings by matrix multiplication.
    one_hot_encodings: [sequence_length, vocab_size]
    embeddings: [vocab_size, embedding_dim]
    Returns:
    Tensor of shape [sequence_length, embedding_dim]
    """
    return torch.matmul(one_hot_encodings, embeddings)


prompt = "Hello world."

print("Prompt:", prompt)
# Get one-hot encodings of GPT-2 tokens
one_hot_gpt2 = get_gpt2_logits(prompt)
# Tokenize and get alignments
tokens_gpt2, tokens_clip, a2b, b2a = tokenize_and_align(prompt)
print("GPT-2 Tokens:", tokens_gpt2)
print("Stable Diffusion Tokens:", tokens_clip)
print("GPT-2 to Stable Diffusion Alignments:", a2b)

# Create the transformation matrix
gpt2_vocab_size = len(gpt_tokenizer.get_vocab())
clip_vocab_size = len(clip_tokenizer.get_vocab())
transformation_matrix = create_sparse_transformation_matrix(tokens_gpt2, tokens_clip, a2b, gpt2_vocab_size, clip_vocab_size)

# Get GPT-2 one-hot encodings
one_hot_gpt2 = get_gpt2_logits(prompt)

# Convert GPT-2 one-hots to Stable Diffusion one-hots
one_hot_clip = convert_gpt2_to_clip_onehots(one_hot_gpt2, transformation_matrix)
print("Converted One-Hot Encodings Shape:", one_hot_clip.shape)

diffusion_embeddings = one_hot_to_embeddings(one_hot_clip, clip_embeddings)
print("CLIP Embeddings Shape:", clip_embeddings.shape)

# recover the text from the one-hot encodings
recovered_text = recover_text_from_one_hot(one_hot_clip, clip_tokenizer)
print("Recovered Text:", recovered_text)

Prompt: Hello world.
GPT-2 Tokens: ['Hello', 'Ġworld', '.']
Stable Diffusion Tokens: ['hello</w>', 'world</w>', '.</w>']
GPT-2 to Stable Diffusion Alignments: [[0], [0, 1], [2]]
torch.Size([50257, 49408])
torch.Size([3, 50257])


RuntimeError: addmm: Argument #3 (dense): Expected dim 0 size 49408, got 50257