In [2]:
from transformers import CLIPTokenizer,CLIPTextModel
import re
import numpy as np
import torch
import torch.nn as nn
import argparse
import matplotlib.pyplot as plt


from PIL import Image, ImageFont, ImageDraw, ImageOps


In [35]:
class TextConditioner(nn.Module):
    """Frozen CLIP text encoder that maps a prompt to per-token embeddings.

    Wraps a Hugging Face ``CLIPTokenizer`` / ``CLIPTextModel`` pair. All
    transformer weights are frozen (``requires_grad=False``), and the
    transformer is kept in eval mode even if ``.train()`` is called on this
    wrapper.

    Args:
        version: Model id or local path passed to both ``from_pretrained``
            calls. Defaults to the CLIP ViT-L/14 checkpoint.
        max_length: Number of token positions the tokenizer pads/truncates to.
        verbose: When True, print intermediate debugging information
            (prompt, tokenizer output, decoded tokens, embedding shape).
    """

    def __init__(self, version='openai/clip-vit-large-patch14', max_length=77, verbose=True):
        super().__init__()
        # processes the tokenized text to extract embeddings
        self.transformer = CLIPTextModel.from_pretrained(version)
        # converts raw text into tokens (mapping of words to ids)
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.max_length = max_length
        self.verbose = verbose

        self.transformer.eval()
        for param in self.transformer.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set this wrapper's training flag while keeping the frozen transformer in eval mode."""
        super().train(mode)
        # Without this, text_encoder.train() would put the frozen CLIP model
        # back into training mode (enabling any dropout layers it has).
        self.transformer.eval()
        return self

    def forward(self, prompt):
        """Encode ``prompt`` into contextualized token embeddings.

        Args:
            prompt: A string or list of strings accepted by the tokenizer.

        Returns:
            Tuple ``(last_hidden_state, attention_mask)`` placed on the same
            device as the transformer's parameters. For one prompt and the
            default checkpoint these have shapes (1, 77, 768) and (1, 77).
        """
        # Follow the model's device instead of hard-coding CUDA, so the
        # encoder also works on CPU or on a non-default GPU.
        device = next(self.transformer.parameters()).device

        if self.verbose:
            print("The prompt is: ", prompt)
        batch_encoding = self.tokenizer(prompt, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        input_ids = batch_encoding["input_ids"]
        if self.verbose:
            print("output from tokenizer:", batch_encoding)
            words = self.tokenizer.convert_ids_to_tokens(input_ids.view(-1).tolist())
            print("Getting back the words from the token: ", words)

        # input_ids are the token ids assigned to each word piece
        text_embedding = self.transformer(input_ids.to(device))
        if self.verbose:
            print("contextualized_text_embedding from transformer: ", text_embedding.last_hidden_state.shape)
        # last_hidden_state is already on `device`; only the mask needs moving.
        return text_embedding.last_hidden_state, batch_encoding["attention_mask"].to(device)


In [40]:
# Load the frozen CLIP text encoder on the GPU and encode a sample prompt.
prompt = "Hello World"
text_encoder = TextConditioner().cuda().eval()

# Whitespace-separated word count, for comparison with the tokenizer output.
word_count = len(prompt.split(' '))
print(word_count)

text_embedding, mask = text_encoder(prompt)



2
The prompt is:  Hello World
output from tokenizer: {'input_ids': tensor([[49406,  3306,  1002, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0]]), 'length': tensor([77])}
Getting