In [1]:
# load clip model from lavis library
!pip install salesforce-lavis -U



In [2]:
import torch
import numpy as np
import random
from PIL import Image
from tqdm.notebook import tqdm
from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
import torch.nn.functional as F
from torch import nn

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [3]:
# Make reproducible code
GLOBAL_SEED = 10

np.random.seed(GLOBAL_SEED)
random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)
torch.use_deterministic_algorithms(True)
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [4]:
# Wrap both cases in torch.device so `device` has the same type with or without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# **Load Dataset**

In [5]:
!pip install huggingface_hub -q

In [7]:
from huggingface_hub import login
from google.colab import userdata

# Read the Hugging Face token from Colab's user secrets so it never appears in the notebook source.
access_token = userdata.get('HF_TOKEN_ALL')
# Authenticate the huggingface_hub client with that token.
login(token = access_token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [9]:
# Hugging Face account name (from Colab secrets) and the token fetched for login,
# kept under these names for the cells that access the Hugging Face repositories.
USERNAME = userdata.get('HUGGINGFACE_USERNAME')
ACCESS_TOKEN = access_token

In [10]:
# Name of the test dataset repository; also the name of its local folder under `local`.
ds_test = 'IllusionAnimals_test'
# Local working root (ends with a slash, so names are appended directly).
local = '/content/'
# Repository id of the test dataset on the Hugging Face Hub.
hf_path_test = f'VQA-Illusion/{ds_test}'
# Name of the model repository (under the VQA-Illusion organisation).
hf_path_model = 'IllusionAnimals_CLIP'
# Base name (without ".pth") of the fine-tuned weights file inside the model repository.
hf_path_weights = 'CLIP_IllusionAnimals_train'

In [11]:
%cd {local}
!git clone 'https://{USERNAME}:{ACCESS_TOKEN}@huggingface.co/datasets/{hf_path_test}'

/content
Cloning into 'IllusionAnimals_test'...
remote: Enumerating objects: 5027, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Total 5027 (delta 4), reused 4 (delta 4), pack-reused 5022 (from 1)[K
Receiving objects: 100% (5027/5027), 748.19 KiB | 3.67 MiB/s, done.
Resolving deltas: 100% (5/5), done.
Updating files: 100% (5004/5004), done.
Filtering content: 100% (5001/5001), 931.75 MiB | 5.21 MiB/s, done.


In [12]:
import pandas as pd

# Read the per-image table (image name, prompts, illusion strength, label) shipped with the test set.
df = pd.read_csv(f'{local}{ds_test}/df_data.csv')
df

Unnamed: 0,image_name,Pprompt,Nprompt,illusion_strength,label
0,IllusionAnimals_1,A raging river flowing through a dense jungle ...,low quality,2.5,cat
1,IllusionAnimals_2,A starry night sky over a tranquil lake,low quality,2.5,cat
2,IllusionAnimals_3,Jaguar (Panthera Onca) patrolling South Americ...,low quality,2.5,cat
3,IllusionAnimals_4,"Flowing lava illuminates cavern walls, ancient...",low quality,2.5,cat
4,IllusionAnimals_5,"Dense forest canopy, sunlight filters through,...",low quality,2.5,cat
...,...,...,...,...,...
995,IllusionAnimals_996,Saharan sandstorm turning daytime into darkness,low quality,2.5,rooster
996,IllusionAnimals_997,Urban city with skyscrapers and traffic,low quality,2.5,rooster
997,IllusionAnimals_998,A serene meadow with wildflowers and butterflies,low quality,2.5,rooster
998,IllusionAnimals_999,Elk bugling in crisp autumn air,low quality,2.5,rooster


# **Load Model**

In [13]:
def load_model(model_path, device):
  """Build the LAVIS CLIP ViT-B-32 feature extractor and load fine-tuned weights into it.

  Args:
    model_path: path to a checkpoint file holding a state dict for the model.
    device: device the model is created on and the checkpoint tensors are mapped to.

  Returns:
    Tuple ``(model, vis_processors, text_processors)`` as returned by
    ``load_model_and_preprocess``, with the fine-tuned state dict applied to the model.
  """
  loaded_model, loaded_vis_processors, loaded_text_processors = load_model_and_preprocess(
      "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
  )
  # map_location remaps the stored tensors onto `device`, so a checkpoint saved from
  # a GPU run still loads when this notebook falls back to CPU.
  # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted
  # source (consider weights_only=True if the installed torch version supports it).
  fine_tuned_weights = torch.load(model_path, map_location=device)
  loaded_model.load_state_dict(fine_tuned_weights)
  return loaded_model, loaded_vis_processors, loaded_text_processors

In [14]:
%cd {local}
!git clone 'https://{USERNAME}:{ACCESS_TOKEN}@huggingface.co/VQA-Illusion/{hf_path_model}'

/content
Cloning into 'IllusionAnimals_CLIP'...
remote: Enumerating objects: 10, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 10 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)[K
Unpacking objects: 100% (10/10), 3.97 KiB | 1.98 MiB/s, done.


In [15]:
# Build the weights path from `local` (the root the model repository was fetched into)
# instead of repeating a hard-coded "/content/".
model, vis_processors, text_processors = load_model(f"{local}{hf_path_model}/{hf_path_weights}.pth", device)

100%|████████████████████████████████████████| 354M/354M [00:01<00:00, 185MiB/s]
  fine_tuned_weights = torch.load(model_path)


In [None]:
# Alternative (disabled): load the pretrained CLIP model without applying the fine-tuned weights.
# model, vis_processors, text_processors = load_model_and_preprocess("clip_feature_extractor", "ViT-B-32", is_eval=True, device = device)

100%|████████████████████████████████████████| 354M/354M [00:01<00:00, 203MiB/s]


# **Inference**

In [16]:
# Plain class names, used as candidate labels for the raw images.
raw_labels = ['cat', 'dog', 'pigeon', 'butterfly', 'elephant', 'horse', 'deer', 'snake', 'fish', 'rooster']

# Candidate labels for the illusion variants: one "illusion animal <name>" prompt per
# class, followed by the "no illusion animal" option.
labels = [f"illusion animal {animal}" for animal in raw_labels] + ["no illusion animal"]

In [17]:
# Pass every candidate label through the eval-time text processor before it is used as model input.
labels = list(map(lambda text: text_processors["eval"](text), labels))
raw_labels = list(map(lambda text: text_processors["eval"](text), raw_labels))

In [18]:
# Create one empty prediction column per image variant; the inference loop fills them in.
for answer_column in (
    "raw_answer",
    "ill_answer",
    "illless_answer",
    "ill_filter_answer",
    "illless_filter_answer",
):
    df[answer_column] = None

In [19]:
def inference(img, labels, model, vis_processors, device):
    """Return the candidate label whose text embedding best matches the image.

    Args:
        img: image accepted by ``vis_processors["eval"]``.
        labels: sequence of candidate label strings, passed to the model as ``text_input``.
        model: object exposing ``extract_features(sample)`` whose result carries
            ``image_embeds_proj`` and ``text_embeds_proj``.
        vis_processors: mapping whose ``"eval"`` entry converts ``img`` into a tensor.
        device: device the preprocessed image tensor is moved to.

    Returns:
        The element of ``labels`` with the highest image-text similarity
        (the first one if several tie).
    """
    # Preprocess and add a leading batch dimension of size one.
    image = vis_processors["eval"](img).unsqueeze(0).to(device)
    sample = {"image": image, "text_input": labels}
    # Pure inference: disable autograd so no computation graph is built per image.
    with torch.no_grad():
        clip_features = model.extract_features(sample)
        image_features = clip_features.image_embeds_proj
        text_features = clip_features.text_embeds_proj
        # Similarity of the single image to every label, divided by a 0.01 temperature.
        sims = (image_features @ text_features.t())[0] / 0.01
        probs = torch.softmax(sims, dim=0)
    # argmax returns the first maximal index, matching list.index(max(...)).
    max_index = int(torch.argmax(probs))
    return labels[max_index]

In [20]:
# Switch into the dataset folder so the relative image paths (./raw_images/..., etc.) used during inference resolve.
%cd '/content/IllusionAnimals_test'

/content/IllusionAnimals_test


In [21]:
# Each answer column -> (image folder it is predicted from, candidate labels to choose among).
# Raw images are scored against the plain class names; every illusion variant is scored
# against the "illusion animal ..." prompts.
ANSWER_SOURCES = {
    "raw_answer": ("raw_images", raw_labels),
    "ill_answer": ("ill_images", labels),
    "illless_answer": ("illusionless_images", labels),
    "ill_filter_answer": ("illusion_images_filtered", labels),
    "illless_filter_answer": ("illusionless_images_filtered", labels),
}

for index, row in tqdm(df.iterrows(), total=len(df)):
    for answer_column, (image_dir, candidate_labels) in ANSWER_SOURCES.items():
        # The context manager closes the file as soon as the RGB copy has been made.
        with Image.open(f"./{image_dir}/{row['image_name']}.jpg") as image_file:
            image = image_file.convert("RGB")
        df.loc[index, answer_column] = inference(image, candidate_labels, model, vis_processors, device)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [22]:
df

Unnamed: 0,image_name,Pprompt,Nprompt,illusion_strength,label,raw_answer,ill_answer,illless_answer,ill_filter_answer,illless_filter_answer
0,IllusionAnimals_1,A raging river flowing through a dense jungle ...,low quality,2.5,cat,cat,illusion animal cat,no illusion animal,illusion animal cat,illusion animal cat
1,IllusionAnimals_2,A starry night sky over a tranquil lake,low quality,2.5,cat,cat,illusion animal cat,no illusion animal,illusion animal cat,no illusion animal
2,IllusionAnimals_3,Jaguar (Panthera Onca) patrolling South Americ...,low quality,2.5,cat,cat,illusion animal cat,no illusion animal,illusion animal cat,illusion animal cat
3,IllusionAnimals_4,"Flowing lava illuminates cavern walls, ancient...",low quality,2.5,cat,cat,illusion animal cat,no illusion animal,illusion animal cat,illusion animal cat
4,IllusionAnimals_5,"Dense forest canopy, sunlight filters through,...",low quality,2.5,cat,cat,illusion animal cat,no illusion animal,illusion animal cat,no illusion animal
...,...,...,...,...,...,...,...,...,...,...
995,IllusionAnimals_996,Saharan sandstorm turning daytime into darkness,low quality,2.5,rooster,rooster,illusion animal rooster,illusion animal snake,illusion animal rooster,illusion animal snake
996,IllusionAnimals_997,Urban city with skyscrapers and traffic,low quality,2.5,rooster,rooster,illusion animal rooster,no illusion animal,illusion animal rooster,no illusion animal
997,IllusionAnimals_998,A serene meadow with wildflowers and butterflies,low quality,2.5,rooster,rooster,illusion animal rooster,no illusion animal,illusion animal rooster,no illusion animal
998,IllusionAnimals_999,Elk bugling in crisp autumn air,low quality,2.5,rooster,rooster,illusion animal rooster,illusion animal deer,illusion animal rooster,illusion animal deer


In [23]:
# Write the predictions under the working root; index=False keeps the row index out of the CSV.
df.to_csv(f"{local}IllusionAnimals_CLIP_inference.csv", index=False)