In [1]:
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
# 加载GPT-2模型和分词器
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')



In [11]:
def relation_search(input_text, threshold=None, top_k=None):

    res = {}

    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
    # 取最后一个词的logits
    last_token_logits = logits[0, -1, :]

    probabilities = torch.softmax(last_token_logits, dim=-1)

    if top_k:
        top_k_probs, top_k_indices = torch.topk(probabilities, top_k)
        top_k_tokens = [tokenizer.decode([idx]) for idx in top_k_indices]
        for i in range(top_k):
            res[top_k_tokens[i]] = top_k_probs[i].item()
    
    if threshold:
        high_prob_indices = (probabilities > threshold).nonzero(as_tuple=True)[0]
        high_probabilities = probabilities[high_prob_indices]

        words_and_probs = [(tokenizer.decode([idx]), prob.item()) for idx, prob in zip(high_prob_indices, high_probabilities)]
        words_and_probs_sorted = sorted(words_and_probs, key=lambda x: x[1], reverse=True)

        for word, prob in words_and_probs_sorted:
            res[word] = prob
    return res

In [33]:
subjects = [
    "Mural",
    "Graffiti",
    "Logo",
    "Artwork",
    "Design",
    "Pattern",
    "Scene",
    "Stripes",
    "Message",
    "Picture",
    "Portrait",
    "Advertisement",
    "Symbol",
    "Flag",
    "Sign",
    "Illustration",
    "Map",
    "Image",
    "Drawing",
    "Landscape",
    "Animal",
    "Flower",
    "Character",
    "Slogan",
    "Banner",
    "Shape",
    "Stencil",
    "Letter",
    "Number",
    "Caricature",
    "Cartoon",
    "Texture",
    "Wave",
    "Cloud",
    "Mountain",
    "Tree",
    "Building",
    "Figure",
    "Silhouette",
    "Pattern",
    "Rainbow",
    "Planet",
    "Bird",
    "Sun",
    "Moon",
    "Star",
    "Quote",
    "Sketch",
    "Vehicle",
    "Boat",
    "Plane",
    "Arrow",
    "Fish",
    "Abstract",
    "Dragon",
    "Tiger",
    "Eagle",
    "Phoenix",
    "Butterfly",
    "Spiral",
    "Sunset",
    "River",
    "Bridge",
    "Cityscape",
    "Skyscraper",
    "Forest",
    "Dinosaur",
    "Elephant",
    "Tiger",
    "Lion",
    "Whale",
    "Shark",
    "Eagle",
    "Horse",
    "Tree",
    "Castle",
    "Galaxy",
    "Robot",
    "Spaceship",
    "Alien",
    "Wave",
    "Cloud",
    "Desert",
    "Volcano",
    "Island",
    "Ocean",
    "Waterfall",
    "Train",
    "Bicycle",
    "Helmet",
    "Sword",
    "Shield",
    "Armor",
    "Scroll",
    "Tapestry",
    "Banner",
    "Fresco",
    "Emblem",
    "Badge",
    "Monogram",
    "Seal"
]


In [34]:
import json

with open('painted.json', 'w') as fn:
    json.dump(subjects, fn)

In [24]:
# 输入文本
input_text = [f"a {sub} is painted on a" for sub in subjects]

In [18]:
input_text = ['a man is riding on a']

In [25]:
res = []
for text in input_text:
    tmp = relation_search(text, threshold=0.05)
    if len(tmp) != 0:
        tmp_text = [f'{text}{sub}' for sub in tmp.keys()]
        res.extend(tmp_text)

In [27]:
len(res)

44

In [32]:
res[0]

'a Mural is painted on a wall'

In [18]:
prompts = [input_text + s for s in res.keys()]

In [19]:
relation_search(input_text, threshold=0.1)

['a bird is playing with a bird',
 'a bird is playing with a ball',
 'a bird is playing with a toy',
 'a bird is playing with a stick',
 'a bird is playing with a dog']

In [9]:
input_text[0]

'a Mural is painted on a'

In [31]:
a = '5'
b = '4'

print(f'{a}{b}')

54


In [35]:
res

['a Mural is painted on a wall',
 'a Graffiti is painted on a wall',
 'a Logo is painted on a white',
 'a Scene is painted on a wall',
 'a Stripes is painted on a wall',
 'a Message is painted on a wall',
 'a Picture is painted on a wall',
 'a Advertisement is painted on a wall',
 'a Flag is painted on a flag',
 'a Sign is painted on a wall',
 'a Image is painted on a wall',
 'a Slogan is painted on a wall',
 'a Banner is painted on a wall',
 'a Letter is painted on a wall',
 'a Letter is painted on a piece',
 'a Number is painted on a piece',
 'a Building is painted on a wall',
 'a Figure is painted on a piece',
 'a Rainbow is painted on a white',
 'a Quote is painted on a piece',
 'a Quote is painted on a wall',
 'a Sketch is painted on a wall',
 'a Vehicle is painted on a white',
 'a Plane is painted on a white',
 'a Abstract is painted on a wall',
 'a River is painted on a wall',
 'a Skyscraper is painted on a wall',
 'a Elephant is painted on a wall',
 'a Lion is painted on a whit