In [1]:
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import pickle
import torch
import clip
from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer, util
import numpy as np
import gc

In [3]:
# Choose computation device
device = "cuda:0" if torch.cuda.is_available() else "cpu" 

# Load pre-trained CLIP model
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
model.cuda().eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [4]:
path_data = '/'  # Base path for images

# Load the CSV file
data_unfiltered = pd.read_csv("/".join([path_data, "images_v2.csv"]))

In [5]:
len(data_unfiltered)

38479

## Remove subjective topics

In [6]:
subjective_topics = ['Favorite home decorations', 'Favourite item in kitchen', 'Favourite sports clubs', 'How the most loved item is used', 'icons', 'Idols', 'Latest furniture bought', ' looking over the shoulder', 'Most loved item', 'Most loved toy', 'Most played songs on the radio', 'Music idol', 'Next big thing you are planning to buy', 'Playing with most loved toy', 'Thing I dream about having', 'Things I wish I had', 'Using most loved item', 'Youth culture', 'What I wish I could buy']

In [7]:
unique_topics = data_unfiltered['topics'].unique()

In [8]:
Subjective_topic_list = []

for topic in unique_topics:
    topic_ = topic.replace('/ ', ', ')
    topic_split = topic_.split(',')
    for split in topic_split:
        if split in subjective_topics:
            Subjective_topic_list.append(topic)

In [9]:
data = data_unfiltered[~data_unfiltered['topics'].isin(Subjective_topic_list)]
data = data[data.id != '5d4befbfcf0b3a0f3f353c2e'] #remove rows with corrupted image 
display(data)

Unnamed: 0,id,country.name,country.id,region.id,type,imageRelPath,topics,place,income
1,5d4bf31ccf0b3a0f3f359814,Burundi,bi,af,image,assets/5d4bf31ccf0b3a0f3f359814/5d4bf31ccf0b3a...,Family snapshots,butoyi,26.994581
2,5d4bf31ccf0b3a0f3f35982a,Burundi,bi,af,image,assets/5d4bf31ccf0b3a0f3f35982a/5d4bf31ccf0b3a...,Cutlery,butoyi,26.994581
3,5d4bf31ccf0b3a0f3f35982e,Burundi,bi,af,image,assets/5d4bf31ccf0b3a0f3f35982e/5d4bf31ccf0b3a...,Family,butoyi,26.994581
4,5d4bf31ccf0b3a0f3f359830,Burundi,bi,af,image,assets/5d4bf31ccf0b3a0f3f359830/5d4bf31ccf0b3a...,Place where eating dinner,butoyi,26.994581
5,5d4bf31dcf0b3a0f3f35983c,Burundi,bi,af,image,assets/5d4bf31dcf0b3a0f3f35983c/5d4bf31dcf0b3a...,Plate of food,butoyi,26.994581
...,...,...,...,...,...,...,...,...,...
38474,5ec4f5513f62767d97a47324,France,fr,eu,image,assets/5ec4f5513f62767d97a47324/5ec4f5513f6276...,Bed,larriere,19671.000000
38475,5ec4f5513f62767d97a47325,France,fr,eu,image,assets/5ec4f5513f62767d97a47325/5ec4f5513f6276...,Bathroom/Toilet,larriere,19671.000000
38476,5ec4f5523f62767d97a47327,France,fr,eu,image,assets/5ec4f5523f62767d97a47327/5ec4f5523f6276...,Armchair,larriere,19671.000000
38477,5ec4f5523f62767d97a47328,France,fr,eu,image,assets/5ec4f5523f62767d97a47328/5ec4f5523f6276...,Armchair,larriere,19671.000000


In [10]:
len(data)

36753

In [11]:
labels=["poor", "lower-middle", "upper-middle", "rich"]

#list_of_topic2prompt_dict[1]
data["quartile"] = pd.qcut(data["income"], q=[0, 0.25, 0.5, 0.75, 1],    
                         labels=labels)

## Split topics

In [12]:
from collections import Counter
list_topics = list(data['topics'])


separate_topics = [t.lower().strip() for topic in list_topics for t in topic.split(",") ]


print(len(separate_topics))
#print(Counter(separate_topics))
set_topics = list(set(separate_topics))
print(len(set_topics))

45691
270


In [13]:
# map each topic to list of corresponding images
dict_topic2img = {}
for list_topics, image_path in zip(data['topics'], data['imageRelPath']):
    for topic in list_topics.split(","):
        topic = topic.lower().strip()
        if topic not in dict_topic2img:
            dict_topic2img[topic] = set() #### here a set was used in place of list to avoid duplicate where keyword appears twice in a topic
        dict_topic2img[topic].add(image_path)

print(len(dict_topic2img))

270


In [14]:
ground_truth_counts = [len(dict_topic2img[i]) for i in dict_topic2img]
print(sum(ground_truth_counts))

45615


In [15]:
topics = dict_topic2img.keys()

In [16]:
with open("mapping_topic2function.pkl", 'rb') as f:
    topic2function = pickle.load(f)

In [17]:
dict_function2imgGT = {}

for topic in topics:
    key = topic2function[topic]
    dict_function2imgGT[key] = dict_topic2img[topic]
print(len(dict_function2imgGT))

270


In [18]:
dict_img2topic, dict_img2country, dict_img2incomelevel, dict_country2region, dict_img2income = {}, {}, {}, {}, {}
for list_topics, image_path, country, incomelevel, region, income in zip(data['topics'], data['imageRelPath'], data['country.name'], data['quartile'], data['region.id'], data['income']):
    dict_img2topic[image_path] = list_topics
    dict_img2country[image_path] = country
    dict_img2incomelevel[image_path] = incomelevel
    dict_country2region[country] = region
    dict_img2income[image_path] = income
    

In [19]:
#dict_topic2img

In [20]:
image_list = []
for imgs in dict_topic2img.values():
    image_list += imgs
len(image_list)

45615

In [21]:
topic_list = []
for t, imgs in dict_topic2img.items():
    topics = [t] * len(imgs)
    topic_list += topics
len(topic_list)

45615

In [22]:
function_list = [topic2function[f] for f in topic_list]
len(function_list)

45615

In [24]:
country_list = [dict_img2country[c] for c in image_list]
income_list = [dict_img2incomelevel[i] for i in image_list]
region_list = [dict_country2region[cntry] for cntry in country_list]
income = [dict_img2income[inc] for inc in image_list]

functions = [' '.join(function.split(' ')[5 :]) for function in function_list]
function_topic = [f"the {clean_id} represents {function}" for clean_id, function in zip(topic_list, functions)]


In [25]:
data_sep = pd.DataFrame()
data_sep['topics'] = topic_list
data_sep['image ids'] = image_list
data_sep['generated functions'] = function_list
data_sep['income level'] = income_list
data_sep['income'] = income
data_sep['country'] = country_list
data_sep['continent'] = region_list
data_sep['function_topic'] = function_topic

In [26]:
data_sep.to_csv('one_image_many_topics.csv')

In [27]:
data = data_sep

In [28]:
path_images = [path_data + s for s in list(data['image ids'])]
image_ids1 = ['/'.join(s.split("/")[8:]) for s in path_images]

In [29]:
image_ids1[0].split("/")[1]

'5ec4fb77f0611d7ddd742855'

In [30]:
len(path_images)

45615

In [55]:
image_path_list = []
imgs_corrupted = []
for image_path in tqdm(path_images):
        path_to_img = image_path
        if Path(path_to_img).is_file():
            image_path_list.append(path_to_img)
        else:
            imgs_corrupted.append(path_to_img)
len(image_path_list)

100%|██████████| 45615/45615 [00:48<00:00, 934.57it/s] 


45615

In [56]:
def load_image(url_or_path):
    if url_or_path.startswith("http://") or url_or_path.startswith("https://"):
        return Image.open(requests.get(url_or_path, stream=True).raw)
    else:
        return Image.open(url_or_path).convert("RGB")

In [57]:
count = 5000
end = len(image_path_list) // count + 1
print(end)

list_img_emb = []
image_names_list = []
for k in tqdm(range(0, end)):
    print(k)                    
    imgs_corrupted = []
    images_prep = []
    k_images = image_path_list[count*k:count*(k+1)]
    print(len(k_images))
    

    for image_path in k_images:
        if Path(image_path).is_file():
            image = Image.open(image_path).convert("RGB")
            images_prep.append(preprocess(image))
        
            
    images_prep = torch.stack(images_prep).to("cuda")
    torch.cuda.empty_cache()
    
    with torch.no_grad():
        image_features = model.encode_image(images_prep).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)
        torch.cuda.empty_cache()
    list_img_emb.append(image_features)
    
    image_names = ['/'.join(image_path.split("/")[8:])  for image_path in image_path_list]
    image_names_list.append(image_names)
    
print("Image encoding completed..")

10


  0%|          | 0/10 [00:00<?, ?it/s]

0
5000


 10%|█         | 1/10 [21:03<3:09:32, 1263.65s/it]

1
5000


 20%|██        | 2/10 [44:32<2:59:51, 1348.93s/it]

2
5000


 30%|███       | 3/10 [1:07:19<2:38:20, 1357.27s/it]

3
5000


 40%|████      | 4/10 [1:31:04<2:18:23, 1383.93s/it]

4
5000


 50%|█████     | 5/10 [1:38:36<1:27:20, 1048.08s/it]

5
5000


 60%|██████    | 6/10 [2:01:17<1:16:57, 1154.39s/it]

6
5000


 70%|███████   | 7/10 [2:23:22<1:00:30, 1210.10s/it]

7
5000


 80%|████████  | 8/10 [2:43:39<40:24, 1212.36s/it]  

8
5000


 90%|█████████ | 9/10 [3:00:13<19:04, 1144.05s/it]

9
615


100%|██████████| 10/10 [3:01:31<00:00, 1089.16s/it]

Image encoding completed..





In [58]:
img_features = torch.cat(list_img_emb, dim=0)
img_features.shape

torch.Size([45615, 512])

In [71]:
img_features.shape

torch.Size([45615, 512])

In [59]:
# f = open("sep_clip_img_embedding.pkl","wb")

# # write the python object (dict) to pickle file
# pickle.dump(img_features,f)

# # close file
# f.close()

### Text and Image Sim

In [31]:
with open("sep_clip_img_embedding.pkl","rb") as f:
    img_embedding = pickle.load(f)
img_embedding.shape

torch.Size([45615, 512])

In [32]:
img_features = img_embedding

In [33]:
gc.collect()
torch.cuda.empty_cache()

#### Change prompts here

In [34]:
text = data['function_topic'].to_list()


In [35]:
len(text)

45615

In [36]:
count = 5000
end = len(text) // count + 1
print(end)
score_list = []

for k in tqdm(range(10)):
    k_img_features = img_features[count*k:count*(k+1)]
    k_text = text[count*k:count*(k+1)]
    
    with torch.no_grad():
        text_tokens = clip.tokenize(k_text).cuda()
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = text_features.cpu().numpy() @ k_img_features.cpu().numpy().T # dot product of image and text features
    similarity_diag = similarity.diagonal()
    
    score_list += list(similarity_diag)


10


100%|██████████| 10/10 [00:13<00:00,  1.31s/it]


In [37]:
len(data)

45615

In [38]:
len(score_list)

45615

In [39]:
exp_1_results = data

In [40]:
exp_1_results["CLIP score"] = score_list

#exp_1_results = exp_1_results.drop(['Unnamed: 0'], axis=1)

In [41]:
exp_1_results.to_csv('exp_1_sep_function_topic_results.csv')

In [42]:
exp_1_results.head()

Unnamed: 0,topics,image ids,generated functions,income level,income,country,continent,function_topic,CLIP score
0,family snapshots,assets/5ec4fb77f0611d7ddd742855/5ec4fb77f0611d...,This is a photo of memories of shared moments.,upper-middle,1394.0,Colombia,am,the family snapshots represents memories of sh...,0.251733
1,family snapshots,assets/5d4beb77cf0b3a0f3f34c470/5d4beb77cf0b3a...,This is a photo of memories of shared moments.,rich,3267.97326,Kenya,af,the family snapshots represents memories of sh...,0.266196
2,family snapshots,assets/5ec4f7f1f0611d7ddd740a82/5ec4f7f1f0611d...,This is a photo of memories of shared moments.,poor,96.0,Nepal,as,the family snapshots represents memories of sh...,0.20418
3,family snapshots,assets/5d4beaffcf0b3a0f3f34b7bc/5d4beaffcf0b3a...,This is a photo of memories of shared moments.,poor,80.084998,India,as,the family snapshots represents memories of sh...,0.303029
4,family snapshots,assets/5d4bf4b0cf0b3a0f3f35c062/5d4bf4b0cf0b3a...,This is a photo of memories of shared moments.,rich,2944.177918,Vietnam,as,the family snapshots represents memories of sh...,0.272154
