# CLIP for extracting visual features
* Why CLIP?  
   Because setting up the original BUTD feature extracter is an absolute nightmare...  
   ... and we will corrupt the image features with KB embeddings anyway

## Setup

In [1]:
# CLIP
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
from IPython.display import clear_output 
clear_output()

In [2]:
# Stuff we may or may not need.
# Handling data
import glob
import json
from PIL import Image
import os
import random

# Modeling and training
import torch
import numpy
from transformers import AutoConfig, AutoTokenizer, GPT2PreTrainedModel, GPT2Model, AdamW, get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup

from torch.utils.tensorboard import SummaryWriter
import tqdm

# Pretrained CLIP models
import clip

# Evaluation
import torchtext

import matplotlib.pyplot as plt

2021-10-06 07:37:08.896060: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(21)

encoder_name = "ViT-B/32"

print("Available CLIP variants: {}".format(clip.available_models()))
encoder, preprocess = clip.load(encoder_name, device=device)

batch_size = 16

Available CLIP variants: ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']


100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 225MiB/s]


In [4]:
scene_paths = glob.glob('../input/simmc-img/data/simmc2_scene_jsons_dstc10_public/public/*_scene.json')
print(len(scene_paths))
print(scene_paths[5])

1740
../input/simmc-img/data/simmc2_scene_jsons_dstc10_public/public/cloth_store_paul_5_4_scene.json


# Do the thing

In [5]:
def get_features(img):
    img = preprocess(img).to(device)
    encoder.eval()
    with torch.no_grad():
        image_features = encoder.encode_image(img.unsqueeze(0))
    return image_features

In [6]:
out = {} 
error = 0 # 3 corrupted images
for idx, scene_path in enumerate(scene_paths):
    
    with open(scene_path, 'r') as f:
        objects = json.load(f)['scenes'][0]['objects']
    
    # Retrive image path
    img_path = '../input/simmc-img/data/all_images/'+scene_path.split('/')[-1].replace('m_','').replace('_scene.json','.png')
    
    # Get the scene image
    scene_img = Image.open(img_path)
    
    # Get sub-images for each bbox
    try:
        imgs = {'scene': get_features(scene_img)}
    except:
        error += 1
        continue
    for obj_i, object in enumerate(objects):
        index = object['index']
        x,y,h,w = object['bbox']
        x = max(x,1) # Some bbox has width 0 ???
        y = max(y,1)
        w = max(w,1)
        h = max(h,1)
        
        obj_img = scene_img.crop((x,y,x+w,y+h))
        imgs[index] = get_features(obj_img)
    
    out[scene_path.split('/')[-1].replace('.json','')] = imgs
    
#     if idx > 10:
#         break

torch.save(out, './img_features.pt')
print(f"# ERROR: {error}")

# ERROR: 4


In [7]:
data = torch.load('img_features.pt')
img = data['cloth_store_2_11_11_scene'][2]
print(img.shape)
print(img)

torch.Size([1, 512])
tensor([[ 1.5833e-01,  8.3313e-02, -6.0449e-01,  8.7128e-03,  3.2990e-02,
         -2.2253e-01,  2.2510e-01,  7.2656e-01,  8.9307e-01,  9.6069e-02,
          5.0586e-01, -1.7017e-01,  6.3525e-01, -2.0654e-01, -3.5449e-01,
         -2.0459e-01,  3.8867e-01, -2.0679e-01,  2.7979e-01,  6.6223e-02,
          2.8516e-01,  1.3513e-01,  1.7236e-01, -3.8788e-02,  3.7817e-01,
          2.8857e-01, -9.7717e-02, -3.1592e-01, -3.9258e-01, -2.6074e-01,
          1.5228e-02,  1.2451e-01,  3.7524e-01, -9.9915e-02, -4.9097e-01,
         -4.4342e-02, -3.2812e-01,  2.1863e-01, -1.1124e-02, -5.2490e-01,
         -6.4746e-01, -5.0000e-01, -7.3576e-04, -3.3813e-01, -7.2266e-01,
          1.7539e+00,  2.8564e-01,  1.3257e-01, -1.1884e-01, -2.1204e-01,
          2.8833e-01, -9.4910e-02,  9.4177e-02, -2.8247e-01, -1.9714e-01,
          2.4609e-01, -1.3611e-01,  4.2381e-03,  1.2622e-01,  3.6279e-01,
          1.0283e+00,  2.4438e-01,  6.7139e-02, -6.5479e-01, -1.5613e-01,
         -9.9060e