In [1]:
# Install (or upgrade to the latest release of) the LAVIS library, which provides
# the CLIP feature extractor loaded later in this notebook.
# NOTE(review): the version is not pinned, so a fresh run may resolve a different
# release than the one these outputs were produced with — consider pinning.
!pip install salesforce-lavis -U



In [2]:
import torch
import numpy as np
import random
from PIL import Image
from tqdm.notebook import tqdm
from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
import torch.nn.functional as F
from torch import nn

  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [3]:
# Make reproducible code
GLOBAL_SEED = 10

# Seed every random number generator used here: NumPy, Python's `random`, and PyTorch.
np.random.seed(GLOBAL_SEED)
random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)
# Ask PyTorch to use deterministic algorithms only.
torch.use_deterministic_algorithms(True)
# Workspace setting for cuBLAS that goes together with deterministic mode on CUDA
# (see the PyTorch reproducibility notes).
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

# cuDNN: request deterministic kernels and turn off the benchmark-based autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [4]:
# Run on the GPU when one is available, otherwise on the CPU.
# Always build a `torch.device`, so `device` has the same type on both paths
# (previously the CPU fallback was the bare string "cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# **Load Dataset**

In [5]:
# Install the Hugging Face Hub client quietly (-q); it is used below to log in to the Hub.
!pip install huggingface_hub -q

In [6]:
from huggingface_hub import login
from google.colab import userdata

# Read the Hugging Face access token from Colab's user data (so it is not written
# into the notebook source) and authenticate this session with the Hub.
access_token = userdata.get('HF_TOKEN_ALL')
login(token = access_token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [7]:
# Hub account name, read from Colab's user data like the token.
USERNAME = userdata.get('HUGGINGFACE_USERNAME')
# Upper-case alias of the access token, used when building authenticated repository URLs.
ACCESS_TOKEN = access_token

In [12]:
# Repository names and local paths used throughout the notebook.
# Name of the test dataset repository; also the folder it is cloned into under `local`.
ds_test = 'MNIST_test'
# Local working directory (trailing slash included, so paths are built by concatenation).
local = '/content/'
# Dataset repository id on the Hugging Face Hub.
hf_path_test = f'VQA-Illusion/{ds_test}'
# Name of the model repository (under the VQA-Illusion namespace) holding the fine-tuned weights.
hf_path_model = 'MNIST_CLIP'
# File stem of the checkpoint inside the model repository (loaded as "<stem>.pth").
hf_path_weights = 'CLIP_MNIST_train'

In [9]:
%cd {local}
!git clone 'https://{USERNAME}:{ACCESS_TOKEN}@huggingface.co/datasets/{hf_path_test}'

/content
Cloning into 'MNIST_test'...
remote: Enumerating objects: 5559, done.[K
remote: Counting objects: 100% (1/1), done.[K
remote: Total 5559 (delta 0), reused 0 (delta 0), pack-reused 5558 (from 1)[K
Receiving objects: 100% (5559/5559), 814.07 KiB | 3.97 MiB/s, done.
Resolving deltas: 100% (1/1), done.
Updating files: 100% (5548/5548), done.
Filtering content: 100% (5545/5545), 952.21 MiB | 4.40 MiB/s, done.


In [10]:
import pandas as pd

# Load the per-image metadata table shipped with the test dataset.
# NOTE(review): the displayed frame contains an "Unnamed: 0" column, which looks
# like a row index saved into the CSV — `index_col=0` would absorb it; confirm
# nothing downstream relies on that column before changing this.
df = pd.read_csv(f'{local + ds_test}/df_data.csv')
df

Unnamed: 0,image_name,Pprompt,Nprompt,illusion_strength,label
0,Mnist_1,A field of blooming sunflowers swaying in the ...,low quality,1.5,7
1,Mnist_2,A peaceful countryside scene with grazing shee...,low quality,1.5,2
2,Mnist_3,A tranquil pond with lily pads floating on the...,low quality,1.5,1
3,Mnist_4,A sunny vineyard with rows of ripe grapes,low quality,1.5,0
4,Mnist_5,A picturesque vineyard at sunset with the sky ...,low quality,1.5,4
...,...,...,...,...,...
1104,Mnist_1105,Misty jungle surrounded by vibrant flowers and...,low quality,1.5,8
1105,Mnist_1106,A forest with blooming flowers,low quality,1.5,0
1106,Mnist_1107,"Desolate desert landscape, shifting sands illu...",low quality,1.5,5
1107,Mnist_1108,A vast desert with a towering canyon in the di...,low quality,1.5,0


# **Load Model**

In [11]:
def load_model(model_path, device):
  """Build the LAVIS CLIP feature extractor and load fine-tuned weights into it.

  Args:
    model_path: Path to a checkpoint file containing a state dict for the model.
    device: Device the model is created on; checkpoint tensors are mapped to it.

  Returns:
    A tuple ``(model, vis_processors, text_processors)``: the model (created with
    ``is_eval=True``) carrying the fine-tuned weights, and the image / text
    processors returned by ``load_model_and_preprocess``.
  """
  loaded_model, loaded_vis_processors, loaded_text_processors = load_model_and_preprocess(
      "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
  )
  # map_location: without it tensors are restored onto the device they were saved
  # from, so a checkpoint written on a GPU fails to load on a CPU-only runtime.
  # weights_only=True: restrict unpickling to tensors and plain containers, which is
  # all a state dict needs, instead of executing arbitrary pickled objects.
  fine_tuned_weights = torch.load(model_path, map_location=device, weights_only=True)
  loaded_model.load_state_dict(fine_tuned_weights)
  return loaded_model, loaded_vis_processors, loaded_text_processors

In [13]:
%cd {local}
!git clone 'https://{USERNAME}:{ACCESS_TOKEN}@huggingface.co/VQA-Illusion/{hf_path_model}'

/content
Cloning into 'MNIST_CLIP'...
remote: Enumerating objects: 6, done.[K
remote: Total 6 (delta 0), reused 0 (delta 0), pack-reused 6 (from 1)[K
Unpacking objects: 100% (6/6), 2.11 KiB | 2.11 MiB/s, done.


In [14]:
# Load the fine-tuned CLIP model and its processors from the downloaded checkpoint.
# The path is built from `local` (like the dataset path) instead of repeating "/content/".
model, vis_processors, text_processors = load_model(f"{local}{hf_path_model}/{hf_path_weights}.pth", device)

100%|████████████████████████████████████████| 354M/354M [00:01<00:00, 201MiB/s]
  fine_tuned_weights = torch.load(model_path)


In [None]:
# model, vis_processors, text_processors = load_model_and_preprocess("clip_feature_extractor", "ViT-B-32", is_eval=True, device = device)

100%|████████████████████████████████████████| 354M/354M [00:03<00:00, 110MiB/s]


# **Inference**

In [15]:
# Candidate captions for the generated image variants: one per digit 0-9,
# followed by a final class for images that contain no hidden digit.
labels = [f"illusion digit {digit}" for digit in range(10)] + ["no illusion digit"]

# Candidate captions for the raw images: one per digit 0-9.
raw_labels = [f"digit {digit}" for digit in range(10)]

In [16]:
# Pass every caption through the evaluation-time text processor, keeping list order.
# Both names are rebound to the processed captions.
labels = list(map(text_processors["eval"], labels))
raw_labels = list(map(text_processors["eval"], raw_labels))

In [17]:
# Create one empty (None-filled) answer column per image variant; the inference
# loop fills them in row by row.
for _answer_col in (
    "raw_answer",
    "ill_answer",
    "illless_answer",
    "ill_filter_answer",
    "illless_filter_answer",
):
    df[_answer_col] = None

In [18]:
def inference(img, labels, model, vis_processors, device, temperature=0.01):
    """Pick the caption in ``labels`` that best matches ``img``.

    Args:
        img: Image accepted by ``vis_processors["eval"]``.
        labels: List of candidate captions passed to the model as ``text_input``.
        model: Model exposing ``extract_features(sample)`` whose result has
            ``image_embeds_proj`` and ``text_embeds_proj``.
        vis_processors: Mapping whose ``"eval"`` entry turns ``img`` into a tensor.
        device: Device the processed image is moved to.
        temperature: Divisor applied to the similarities before the softmax.
            Defaults to 0.01, the value that was previously hard-coded.

    Returns:
        The element of ``labels`` with the highest softmax probability (the first
        one on ties).
    """
    image = vis_processors["eval"](img).unsqueeze(0).to(device)
    sample = {"image": image, "text_input": labels}
    # Inference only: disable autograd so no graph is built or kept alive per image.
    with torch.no_grad():
        clip_features = model.extract_features(sample)
        image_features = clip_features.image_embeds_proj
        text_features = clip_features.text_embeds_proj
        # Similarity of the single image to every caption, sharpened by the temperature.
        sims = (image_features @ text_features.t())[0] / temperature
        probs = torch.softmax(sims, dim=0).tolist()
    max_index = probs.index(max(probs))
    return labels[max_index]

In [19]:
# Work from the dataset folder: the inference loop opens images through relative
# paths such as "./raw_images/<name>.jpg".
%cd '/content/MNIST_test'

/content/MNIST_test


In [20]:
# One entry per image variant: (answer column, image folder, candidate captions).
# Raw images are scored against the plain digit captions; every generated variant
# is scored against the illusion captions.
image_variants = [
    ("raw_answer", "raw_images", raw_labels),
    ("ill_answer", "ill_images", labels),
    ("illless_answer", "illusionless_images", labels),
    ("ill_filter_answer", "illusion_images_filtered", labels),
    ("illless_filter_answer", "illusionless_images_filtered", labels),
]

for index, row in tqdm(df.iterrows(), total=len(df)):
    for answer_col, image_dir, candidate_labels in image_variants:
        # `with` closes the underlying file as soon as the RGB copy has been made.
        with Image.open(f"./{image_dir}/{row['image_name']}.jpg") as image_file:
            rgb_image = image_file.convert("RGB")
        df.loc[index, answer_col] = inference(rgb_image, candidate_labels, model, vis_processors, device)

  0%|          | 0/1109 [00:00<?, ?it/s]

In [21]:
df

Unnamed: 0,image_name,Pprompt,Nprompt,illusion_strength,label,raw_answer,ill_answer,illless_answer,ill_filter_answer,illless_filter_answer
0,Mnist_1,A field of blooming sunflowers swaying in the ...,low quality,1.5,7,digit 7,illusion digit 7,no illusion digit,illusion digit 7,no illusion digit
1,Mnist_2,A peaceful countryside scene with grazing shee...,low quality,1.5,2,digit 2,illusion digit 2,no illusion digit,illusion digit 2,no illusion digit
2,Mnist_3,A tranquil pond with lily pads floating on the...,low quality,1.5,1,digit 1,illusion digit 1,no illusion digit,illusion digit 1,no illusion digit
3,Mnist_4,A sunny vineyard with rows of ripe grapes,low quality,1.5,0,digit 0,illusion digit 0,no illusion digit,illusion digit 0,no illusion digit
4,Mnist_5,A picturesque vineyard at sunset with the sky ...,low quality,1.5,4,digit 4,illusion digit 4,no illusion digit,illusion digit 4,no illusion digit
...,...,...,...,...,...,...,...,...,...,...
1104,Mnist_1105,Misty jungle surrounded by vibrant flowers and...,low quality,1.5,8,digit 8,illusion digit 8,no illusion digit,illusion digit 8,no illusion digit
1105,Mnist_1106,A forest with blooming flowers,low quality,1.5,0,digit 0,illusion digit 0,no illusion digit,illusion digit 0,no illusion digit
1106,Mnist_1107,"Desolate desert landscape, shifting sands illu...",low quality,1.5,5,digit 5,illusion digit 5,no illusion digit,illusion digit 5,no illusion digit
1107,Mnist_1108,A vast desert with a towering canyon in the di...,low quality,1.5,0,digit 0,illusion digit 0,no illusion digit,illusion digit 0,no illusion digit


In [22]:
# Save the table with the predicted answers next to the other artifacts in `local`
# (row index omitted). Built from `local` instead of a hard-coded "/content/" prefix.
df.to_csv(f"{local}MNIST_CLIP_inference.csv", index=False)