In [12]:
import os
import random
import json

import sys 
sys.path.append(".")
sys.path.append("..")

from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from transformers import CLIPProcessor, CLIPTextModel

%load_ext autoreload 
%autoreload 2

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [10]:
def solve_for_W(A, B):
    """Solve ``A @ W ≈ B`` for ``W`` using the Moore-Penrose pseudo-inverse.

    Args:
        A: Array-like coefficient matrix.
        B: Array-like target matrix with as many rows as ``A``.

    Returns:
        ``pinv(A)`` multiplied by ``B`` (via ``np.dot``).
    """
    return np.dot(np.linalg.pinv(A), B)

In [13]:
# Load the CLIP processor and the text-only encoder from the same checkpoint.
model_name = "openai/clip-vit-large-patch14"  
processor = CLIPProcessor.from_pretrained(model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name)



In [24]:
# Two-word pairs; a later cell turns each one into an "a X and a Y" prompt.
pairs = [
    ["salt", "pepper"],
    ["knife", "fork"],
    ["pen", "paper"],
    ["phone", "charger"],
    ["milk", "cookies"],
    ["shoes", "socks"],
    ["tea", "sugar"],
    ["computer", "mouse"],
]

In [22]:
# Pairs kept apart from `pairs`; a later cell builds `infer_prompt` from them.
valid_pairs = [["book", "bookmark"], ["coffee", "creamer"]]

In [25]:
# One "a X and a Y" prompt per entry of `pairs`.
prompts = [f'a {s1} and a {s2}' for s1, s2 in pairs]

In [35]:
# One "a X and a Y" prompt per entry of `valid_pairs`.
infer_prompt = [f'a {s1} and a {s2}' for s1, s2 in valid_pairs]

In [2]:
# Every rider/mount combination, rider-major ("man" prompts first).
prompts = [
    f'a {s1} is riding on a {s2}'
    for s1 in ['man', 'woman']
    for s2 in ['bike', 'motorcycle', 'horse']
]

In [3]:
prompts

['a man is riding on a bike',
 'a man is riding on a motorcycle',
 'a man is riding on a horse',
 'a woman is riding on a bike',
 'a woman is riding on a motorcycle',
 'a woman is riding on a horse']

In [6]:
# 1-based word positions inside each whitespace-split prompt
# (the lookup below subtracts 1 before indexing).
indice_dic = {'s1': 2, 's2': 7, 'r':4}
# For every role name, the word found at that position in each prompt of `prompts`.
token_dic = {name: [prompt.split()[pos-1] for prompt in prompts] for name, pos in indice_dic.items()}

# Rebuild each prompt from its role words only, as "a <s1> <r> on a <s2>".
refer_prompt = [f"a {s1} {r} on a {s2}" for s1, r, s2 in zip(token_dic['s1'], token_dic['r'], token_dic['s2'])]
print(refer_prompt)

['a man riding on a bike', 'a man riding on a motorcycle', 'a man riding on a horse', 'a woman riding on a bike', 'a woman riding on a motorcycle', 'a woman riding on a horse']


In [7]:
token_dic

{'s1': ['man', 'man', 'man', 'woman', 'woman', 'woman'],
 's2': ['bike', 'motorcycle', 'horse', 'bike', 'motorcycle', 'horse'],
 'r': ['riding', 'riding', 'riding', 'riding', 'riding', 'riding']}

In [14]:
def extract_embedding(prompt, token_dic):
    """Encode prompts and their role words with the CLIP text encoder.

    Args:
        prompt: List of prompt strings, one per example.
        token_dic: Mapping with keys 's1', 'r' and 's2'; each value is a list
            of strings that is encoded on its own, separately from `prompt`.

    Returns:
        A list ``[s1, r, s2, eot]``. ``s1``, ``r`` and ``s2`` are the hidden
        states at sequence position 1 of the corresponding `token_dic` entries
        (the first token after the start token the CLIP tokenizer prepends);
        ``eot`` is the hidden state of each prompt at its own end-of-text
        position. Every tensor has one row per input string.
    """

    def _encode(texts):
        # Tokenize to a fixed length and run the text encoder once.
        inputs = processor(
            texts,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return inputs.input_ids, text_encoder(inputs.input_ids).last_hidden_state

    input_ids, embeds = _encode(prompt)

    # Find the end-of-text token per row from the token ids rather than from
    # the word count of the first prompt: a word can split into several
    # tokens and prompts in one batch can differ in length. argmax returns the
    # first match, so padding that reuses the EOS id does not shift the result.
    eos_id = processor.tokenizer.eos_token_id
    eot_pos = (input_ids == eos_id).int().argmax(dim=-1)
    rows = torch.arange(input_ids.shape[0], device=input_ids.device)
    eot = embeds[rows, eot_pos]

    # Role words are encoded in isolation; position 1 is their first token.
    s1, r, s2 = (_encode(token_dic[key])[1][:, 1] for key in ('s1', 'r', 's2'))

    return [s1, r, s2, eot]

In [15]:
# Reference embeddings, in the order extract_embedding returns them: [s1, r, s2, eot].
refer = extract_embedding(refer_prompt, token_dic)

In [16]:
# Least-squares fit of a linear map `a` such that (s1 + r + s2) @ a ≈ eot
# over the reference prompts.
a = torch.linalg.lstsq(refer[0]+refer[1]+refer[2], refer[3]).solution

In [17]:
# Single prompt used at inference time.
# NOTE(review): `infer_prompt` is also assigned in other cells, so the value in
# effect depends on which cell ran last.
infer_prompt = ['a man is riding on a dog']

In [12]:
# "woman" prompts for each of the three mounts.
infer_prompt = [
    f'a {s1} is riding on a {s2}'
    for s1 in ['woman']
    for s2 in ['bike', 'motorcycle', 'horse']
]

In [18]:
infer_prompt

['a man is riding on a dog']

In [25]:
token_dic

{'s1': ['man', 'man', 'man', 'woman', 'woman', 'woman'],
 's2': ['bike', 'motorcycle', 'horse', 'bike', 'motorcycle', 'horse'],
 'r': ['riding', 'riding', 'riding', 'riding', 'riding', 'riding']}

In [27]:
# Role words for the inference prompt, under the same keys
# ('s1', 'r', 's2') that extract_embedding reads.
infer_dic = {'s1': ['man'], 's2': ['dog'], 'r': ['riding']}

In [28]:
# Inference embeddings, ordered [s1, r, s2, eot] like `refer`.
infer = extract_embedding(infer_prompt, infer_dic)

In [29]:
# Predict the EOT embedding by applying the fitted map `a` to the summed
# role-word embeddings; the encoder's own EOT state (infer[3]) is not used here.
infer_eot = torch.matmul(infer[0]+infer[1]+infer[2], a)

In [30]:
infer_eot.shape

torch.Size([1, 768])

In [171]:
# Persist the predicted EOT embedding to ride_eot.pt (path relative to the working directory).
torch.save(infer_eot, 'ride_eot.pt')

In [45]:
# Tokenize the reference prompts to fixed length, then keep the text
# encoder's per-position hidden states.
inputs = processor(refer_prompt, padding="max_length", truncation=True, return_tensors="pt")
embeds = text_encoder(inputs.input_ids).last_hidden_state

In [44]:
refer_prompt

['a salt and a pepper',
 'a knife and a fork',
 'a pen and a paper',
 'a phone and a charger',
 'a milk and a cookies',
 'a shoes and a socks',
 'a tea and a sugar',
 'a computer and a mouse']

In [46]:
# Hidden state at sequence position 6 for every prompt in the batch.
# NOTE(review): 6 is presumably the end-of-text position for five-word
# prompts — confirm against the prompts actually encoded into `embeds`.
t = embeds[:, 6]

In [51]:
# Row-wise cosine similarity between the selected hidden states and the predicted EOT.
F.cosine_similarity(t, infer_eot)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

In [44]:
def get_subject_embedding(token):
    """Return the text-encoder hidden state at sequence position 1 for each input.

    Args:
        token: String or list of strings handed to the CLIP processor.

    Returns:
        Tensor with one row per input, taken from ``last_hidden_state[:, 1]``.
    """
    encoded = processor(token, padding="max_length", truncation=True, return_tensors="pt")
    hidden_states = text_encoder(encoded.input_ids).last_hidden_state
    return hidden_states[:, 1]

In [166]:
# Position-1 embedding of each word, encoded as three separate inputs (one row per word).
tmp = get_subject_embedding(['man', 'riding', 'dog'])

In [167]:
# Persist the word embeddings to sub_embeds.pt (path relative to the working directory).
torch.save(tmp, 'sub_embeds.pt')

In [158]:
tmp.shape

torch.Size([3, 768])

In [1]:
import json

In [8]:
# Capitalized nouns that a later cell dumps to ../inside.json.
# NOTE(review): presumably things something can be "inside" of, going by the
# variable name — confirm intended use with the consumer of that file.
inside_subjects = [
    "House", "Car", "Box", "Bag", "Room", "Envelope", "Drawer", "Cabinet", "Closet", 
    "Safe", "Jar", "Bottle", "Cup", "Bowl", "Fridge", "Oven", "Microwave", "Pantry", 
    "Suitcase", "Briefcase", "Backpack", "Pocket", "Wallet", "Purse", "Garage", 
    "Warehouse", "Building", "Tent", "Apartment", "Office", "Library", "Classroom", 
    "Laboratory", "Hospital", "Clinic", "Church", "Temple", "Mosque", "Stadium", 
    "Gym", "Arena", "Courtroom", "Prison", "Jail", "Cell", "Dungeon", "Castle", 
    "Basement", "Attic", "Shed", "Factory", "Workshop", "Barn", "Stable", "Aquarium", 
    "Cage", "Nest", "Burrow", "Cave", "Tunnel", "Subway", "Elevator", "Escalator", 
    "Hallway", "Corridor", "Lighthouse", "Observatory", "Theater", "Cinema", 
    "Museum", "Exhibit", "Gallery", "Restaurant", "Café", "Bar", "Club", "Hotel", 
    "Resort", "Lobby", "Reception", "Chamber", "Vault", "Bank", "Mine", "Crater", 
    "Volcano", "Cocoon", "Shell", "Egg", "Womb", "Heart", "Mind", "Soul", "Book", 
    "Computer", "Phone", "Tablet", "Cloud", "Network", "Internet", "Database", 
    "File", "Folder"
]

In [9]:
# Serialize the subject list as a JSON array into ../inside.json.
with open('../inside.json', 'w') as fn:
    fn.write(json.dumps(inside_subjects))