In [9]:
import argparse
import functools
import re

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont, ImageOps
from transformers import CLIPTextModel, CLIPTokenizer


In [10]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [18]:
class TextConditioner(nn.Module):
    """Frozen CLIP text encoder that turns prompts into per-token embeddings.

    Both the tokenizer and the text transformer are loaded from the
    ``openai/clip-vit-large-patch14`` checkpoint. The transformer is put in
    eval mode and all of its parameters are frozen (``requires_grad=False``).
    """

    def __init__(self):
        super().__init__()
        self.transformer = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

        # Used purely as a fixed feature extractor: no dropout, no gradients.
        self.transformer.eval()
        for param in self.transformer.parameters():
            param.requires_grad = False

    def _device(self):
        """Return the device holding the transformer weights (CPU if it has none)."""
        first_param = next(self.transformer.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, prompt_list):
        """Encode a batch of prompts.

        Args:
            prompt_list (list[str]): prompts to encode.

        Returns:
            tuple: ``(last_hidden_state, attention_mask)`` of shapes
            ``(batch, 77, hidden)`` (hidden is 768 for this checkpoint) and
            ``(batch, 77)``. Both are on the same device as the transformer
            weights, so the module works on CPU as well as GPU.
        """
        batch_encoding = self.tokenizer(prompt_list, truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        device = self._device()
        text_embedding = self.transformer(batch_encoding["input_ids"].to(device))
        # last_hidden_state is already on `device`; only the mask needs moving.
        return text_embedding.last_hidden_state, batch_encoding["attention_mask"].to(device)


In [19]:
# Frozen CLIP text encoder on the device selected earlier (falls back to CPU
# instead of hard-coding .cuda()).
text_encoder = TextConditioner().to(device).eval()
# The encoder already loaded a CLIPTokenizer from the same
# 'openai/clip-vit-large-patch14' checkpoint; reuse it instead of loading twice.
tokenizer = text_encoder.tokenizer

In [39]:
@functools.lru_cache(maxsize=32)
def _load_font(font_path, font_size):
    """Load a TrueType font once per (path, size); callers measure many words with the same font."""
    return ImageFont.truetype(font_path, font_size)


def get_width(font_path, text, font_size=24):
    """
    Compute the rendered width of `text` in the given font.

    Args:
        font_path (str): path to a TrueType/OpenType font file.
        text (str): text to measure.
        font_size (int): size the font is loaded at for measuring. Defaults to 24.

    Returns:
        The advance width of `text` as returned by Pillow's
        ``FreeTypeFont.getlength`` (in pixels).
    """
    return _load_font(font_path, font_size).getlength(text)

def get_key_words(text: str, max_keywords: int = 8):
    """
    Detect keywords (spans enclosed in single quotes) in a user prompt.

    The keywords are used to guide the layout generation. Each quoted span is
    split on whitespace because ``process_caption`` matches keywords against
    individual caption words; an unsplit multi-word span such as
    ``'new york'`` could never match there and would be silently dropped.

    Args:
        text (str): user prompt.
        max_keywords (int): if this many keywords or more are found, an empty
            list is returned. Defaults to 8.

    Returns:
        list[str]: the keywords in order of appearance, or ``[]``.
    """
    # Non-greedy match between pairs of single quotes. Note that apostrophes
    # (e.g. "don't") are single quotes too and will pair with other quotes.
    matches = re.findall(r"'(.*?)'", text)
    words = [word for match in matches for word in match.split()]

    if len(words) >= max_keywords:
        return []

    return words

In [45]:
_MAX_TOKENS = 77  # CLIP context length; also the max_length passed to the tokenizer below
_MAX_BOXES = 8    # number of box slots returned (truncated / zero-padded to this)
_NO_WORD = 127    # sentinel for special (SOS/EOS) and padding positions


def _align_tokens_to_words(tokens, caption_split, length):
    """Assign each of the first `length` tokens to the caption word it helps spell.

    Args:
        tokens (list[str]): token strings with '</w>' already stripped.
        caption_split (list[str]): lower-cased caption words.
        length (int): number of real tokens, including SOS and EOS.

    Returns:
        tuple: ``(state_list, word_match_list, start_dic)``.
        ``state_list`` holds 0 for the first token of a word, 1 for a
        continuation token and 2 for the first/last (SOS/EOS) positions.
        ``word_match_list`` holds each token's word index (_NO_WORD for SOS/EOS).
        ``start_dic`` maps word index -> index of that word's first token.
    """
    start_dic = {}
    state_list = []
    word_match_list = []
    word_index = 0
    partial = ''  # concatenation of the tokens seen so far for the current word

    for i in range(length):
        # The first and last real positions are the SOS / EOS special tokens.
        if i == 0 or i == length - 1:
            state_list.append(2)
            word_match_list.append(_NO_WORD)
            continue

        if partial == '':
            state_list.append(0)
            start_dic[word_index] = i
        else:
            state_list.append(1)

        partial += tokens[i]
        word_match_list.append(word_index)
        # Once the accumulated pieces spell the whole word, move on to the next word.
        if partial == caption_split[word_index]:
            partial = ''
            word_index += 1

    return state_list, word_match_list, start_dic


def process_caption(font_path, caption, keywords, verbose=True):
    """
    Tokenize a caption with CLIP and build per-token layout-conditioning data.

    Uses the module-level ``tokenizer`` and ``get_width``.

    Args:
        font_path (str): font used to measure word widths via ``get_width``.
        caption (str): user prompt. Every character that is not an ASCII
            letter or digit is replaced by a space before tokenizing.
        keywords (list[str]): words to place boxes for (e.g. from
            ``get_key_words``). They are matched case-insensitively against the
            caption words. A keyword listed several times is assigned to
            successive occurrences in the caption.
        verbose (bool): print intermediate tokenization / alignment results.

    Returns:
        tuple: ``(caption, length_list, width_list, info_array, words,
        state_list, word_match_list, boxes, boxes_length)``:
          - caption: the cleaned caption string.
          - length_list: (77,) long. Character count of each token's word;
            0 for special and padding positions.
          - width_list: (77,) long. Rendered width of each token's word;
            0 for special and padding positions.
          - info_array: (77, 5) float64. Column 0 is 1 at the first token of
            each matched keyword; columns 1-4 hold its box (all zeros here).
          - words: all 77 token strings with '</w>' stripped, padding included.
          - state_list: (77,) long. 0 = first token of a word,
            1 = continuation, 2 = SOS/EOS, 127 = padding.
          - word_match_list: (77,) long. Caption-word index of each token;
            127 for SOS/EOS and padding.
          - boxes: (8, 4) float tensor of box placeholders.
          - boxes_length: number of keyword boxes found, counted before
            truncation/padding to 8.
    """
    # remove punctuations. please remove this statement if you want to paint punctuations
    caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption)

    encoding = tokenizer([caption], truncation=True, max_length=_MAX_TOKENS, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    token_ids = encoding['input_ids']  # (1, 77)
    # Number of real tokens (SOS + caption tokens + EOS); later positions are padding.
    length = int(encoding['length'])

    # Convert ids back to token strings and strip CLIP's '</w>' end-of-word suffix.
    words = [tok.replace('</w>', '') for tok in tokenizer.convert_ids_to_tokens(token_ids.view(-1).tolist())]
    words_valid = words[:length]

    original_words = caption.split()  # original case: used for length / width
    caption_split = [w.lower() for w in original_words]  # lower case: compared with tokens and keywords

    if verbose:
        print(length)
        print(words_valid)
        print(caption_split)

    state_list, word_match_list, start_dic = _align_tokens_to_words(words_valid, caption_split, length)

    if verbose:
        print(state_list)
        print(word_match_list)

    # Pad both per-token lists to the fixed context length.
    padding = [_NO_WORD] * (_MAX_TOKENS - len(state_list))
    state_list = state_list + padding
    word_match_list = word_match_list + padding

    # Character count and rendered width of the word each token belongs to.
    # Special/padding positions get 0 so both lists have _MAX_TOKENS entries.
    length_list = []
    width_list = []
    for word_index in word_match_list:
        if word_index == _NO_WORD:
            length_list.append(0)
            width_list.append(0)
        else:
            word = original_words[word_index]
            length_list.append(len(word))
            width_list.append(get_width(font_path, word))

    # Per-token info: column 0 flags the first token of a keyword, columns 1-4 its box.
    info_array = np.zeros((_MAX_TOKENS, 5))
    boxes = []
    next_search_start = {}  # keyword -> caption position right after its last match

    for keyword in keywords:
        keyword = keyword.lower()
        # Repeated keywords are assigned to successive occurrences in the caption.
        try:
            word_index = caption_split.index(keyword, next_search_start.get(keyword, 0))
        except ValueError:
            continue  # not in the caption, or every occurrence is already used
        next_search_start[keyword] = word_index + 1

        token_index = start_dic.get(word_index)
        if token_index is None:
            continue  # the alignment never reached this word (tokenizer truncation)

        box = [0, 0, 0, 0]  # placeholder: no box coordinates are computed here
        info_array[token_index][0] = 1
        info_array[token_index][1:] = box
        boxes.append(list(box))

    boxes_length = len(boxes)
    boxes = boxes[:_MAX_BOXES]
    boxes += [[0, 0, 0, 0] for _ in range(_MAX_BOXES - len(boxes))]

    return (
        caption,
        torch.tensor(length_list, dtype=torch.long),
        torch.tensor(width_list).long(),
        torch.from_numpy(info_array),
        words,
        torch.tensor(state_list, dtype=torch.long),
        torch.tensor(word_match_list, dtype=torch.long),
        torch.Tensor(boxes),
        boxes_length,
    )
    

        






In [56]:
# NOTE(review): absolute, machine-specific path kept only as the backward-compatible
# default; pass `font_path` explicitly on other machines.
DEFAULT_FONT_PATH = '/home/adi_techbuddy/Desktop/python/repos/brush-your-text/controlnet_util/Textgen/English/Roboto/static/Roboto-Regular.ttf'


def get_layout_from_prompt(font_path=DEFAULT_FONT_PATH, caption="I love 'nlp' and 'ai'", device=None):
    """
    Detect the quoted keywords in a prompt and build its per-token info array.

    Args:
        font_path (str): font used for word-width measurement.
        caption (str): prompt; keywords are the spans enclosed in single quotes.
        device (torch.device | str | None): where to place the result. Defaults
            to ``cuda:0`` when CUDA is available, otherwise the CPU.

    Returns:
        torch.Tensor: the info array from ``process_caption`` with a leading
        batch dimension, shape (1, 77, 5).
    """
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    keywords = get_key_words(caption)
    print("The following words to be displayed in image were detected", keywords)

    caption, length_list, width_list, target, words, state_list, word_match_list, boxes, boxes_length = process_caption(font_path, caption, keywords)

    print(target.shape)
    target = target.to(device).unsqueeze(0)  # add batch dimension: (77, 5) -> (1, 77, 5)
    print(target.shape)
    return target



In [57]:
# Run the end-to-end layout demo on its example caption.
get_layout_from_prompt()

The following words to be displayed in image were detected ['nlp', 'ai']
tensor([7])
['<|startoftext|>', 'i', 'love', 'nlp', 'and', 'ai', '<|endoftext|>']
['i', 'love', 'nlp', 'and', 'ai']
[2, 0, 0, 0, 0, 0, 2]
[127, 0, 1, 2, 3, 4, 127]
torch.Size([77, 5])
torch.Size([1, 77, 5])
