In [14]:
import warnings
warnings.filterwarnings('ignore')

import clip 
import numpy as np 
import os
import pandas as pd
from PIL import Image
import torch

In [4]:
# List the CLIP model names exposed by the `clip` package; 'RN50' (the default used by load_clip) is among them.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [24]:
%%writefile ../src/clip_vqa/extract_features.py
import warnings
warnings.filterwarnings('ignore')

import clip 
import numpy as np 
import os
import pandas as pd
from PIL import Image
import torch
import argparse

def load_clip(model_name='RN50', device='cpu'):
    """Load a CLIP checkpoint and print a short summary of it.

    Args:
        model_name: checkpoint name handed to ``clip.load``
            (see ``clip.available_models()`` for the accepted names).
        device: device string forwarded to ``clip.load``.

    Returns:
        The ``(model, preprocess)`` pair exactly as returned by ``clip.load``.
    """
    model, transform = clip.load(model_name, device=device)

    # Total number of scalar weights across every parameter tensor.
    param_total = np.sum([int(np.prod(param.shape)) for param in model.parameters()])

    print("Model parameters:", f"{param_total:,}")
    print(f"Input resolution: {model.visual.input_resolution}")
    print(f"Context length: {model.context_length}")
    print(f"Vocab size: {model.vocab_size}")
    return model, transform

def extract_features(image_path, question, clip_model, transform, device='cpu', truncate=True):
    """Encode one (image, question) pair with CLIP.

    Args:
        image_path: path of the image file; it is loaded and converted to RGB.
        question: question text, tokenized with ``clip.tokenize``.
        clip_model: model returned by ``clip.load`` / ``load_clip``.
        transform: preprocessing callable returned alongside the model by
            ``clip.load``. NOTE(review): with ``None`` no preprocessing happens and
            the raw PIL image is handed to ``encode_image``, which presumably
            expects a tensor - confirm before relying on that path.
        device: device the inputs are moved to; should match the model's device.
        truncate: forwarded to ``clip.tokenize``. When True (default), questions
            longer than the model's context length are cut off instead of raising,
            so one long question cannot abort a whole extraction run.

    Returns:
        ``(image_features, text_features)`` tensors on the CPU.
    """
    # Context manager releases the file handle right away; convert() returns a new
    # image that no longer depends on the open file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if transform is not None:
        image = transform(image).unsqueeze(0).to(device)
    tokens = clip.tokenize(question, truncate=truncate).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        text_features = clip_model.encode_text(tokens)
    # Outputs produced under no_grad carry no autograd history, so .detach() is unnecessary.
    return image_features.cpu(), text_features.cpu()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract CLIP image/question features for every row of a JSON dataset.")
    # These paths have no sensible default; requiring them yields a clear argparse
    # error instead of a TypeError deep inside os/pandas.
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--feat-save-dir", type=str, required=True)
    parser.add_argument("--clip-model", type=str, default='RN50')
    parser.add_argument("--device", type=str, default='cuda:0')
    parser.add_argument("--save-path", type=str, required=True)

    args = parser.parse_args()

    # parser.error (unlike assert) is not stripped when running with `python -O`.
    if not os.path.exists(args.data_path):
        parser.error(f"--data-path does not exist: {args.data_path}")

    feat_save_dir = args.feat_save_dir
    device = args.device
    os.makedirs(feat_save_dir, exist_ok=True)
    clip_model_name = args.clip_model
    save_path = args.save_path

    # Honour --clip-model (it used to be hard-coded to 'RN50' here).
    clip_model, preprocess = load_clip(model_name=clip_model_name, device=device)

    data = pd.read_json(args.data_path)
    # NOTE(review): the script expects an 'image_name' column, but the train_df.json
    # explored in the notebook has 'image' instead; fall back to it - confirm
    # against the real input files.
    name_col = 'image_name' if 'image_name' in data.columns else 'image'

    img_feat_list = []
    text_feat_list = []

    total = len(data)
    for done, (_, row) in enumerate(data.iterrows(), start=1):
        # splitext strips only the extension, so names with extra dots don't collide.
        stem = os.path.splitext(row[name_col])[0]
        img_feat_path = os.path.join(feat_save_dir, stem + '_img.pt')
        text_feat_path = os.path.join(feat_save_dir, stem + '_text.pt')

        img_feat, text_feat = extract_features(image_path=row['image_path'],
                                               question=row['question'],
                                               clip_model=clip_model,
                                               transform=preprocess,
                                               device=device)
        torch.save(img_feat, img_feat_path)
        torch.save(text_feat, text_feat_path)

        img_feat_list.append(img_feat_path)
        text_feat_list.append(text_feat_path)

        if done % 500 == 0:
            print(f"Total={total} Done={done}")

    data['img_feat'] = img_feat_list
    data['text_feat'] = text_feat_list

    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    data.to_json(save_path)


Overwriting ../src/clip_vqa/extract_features.py


In [9]:
# `%%writefile` only writes the script to disk; it does not define `load_clip`
# in this kernel. Import it from the written module so the notebook survives
# Restart & Run All.
import sys
if '../src' not in sys.path:
    sys.path.insert(0, '../src')
from clip_vqa.extract_features import load_clip

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
clip_model, preprocess = load_clip(device=DEVICE)

Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408


In [12]:
# Training split: each row holds an image file name, its question, the answers and the image path.
train_path = '../vizviz/vqa/train_df.json'
train_data = pd.read_json(path_or_buf=train_path)
train_data.head(n=2)

Unnamed: 0,image,question,answers,answer_type,answerable,final_answer,image_path
0,VizWiz_train_00000000.jpg,What's the name of this product?,"[{'answer_confidence': 'yes', 'answer': 'basil...",other,1,basil leaves,/nfs/home/scg1143/MLDS/Quarter3/DeepLearning/P...
1,VizWiz_train_00000001.jpg,Can you tell me what is in this can please?,"[{'answer_confidence': 'yes', 'answer': 'soda'...",other,1,coca cola,/nfs/home/scg1143/MLDS/Quarter3/DeepLearning/P...


In [15]:
def extract_features(image_path, question, clip_model, transform, device='cpu', truncate=True):
    """Encode one (image, question) pair with CLIP (notebook-local definition).

    `%%writefile` only writes the script's version of this function to disk; it
    does not define it in the kernel, hence this definition.

    Args:
        image_path: path of the image file; it is loaded and converted to RGB.
        question: question text, tokenized with ``clip.tokenize``.
        clip_model: model returned by ``clip.load`` / ``load_clip``.
        transform: preprocessing callable returned alongside the model by
            ``clip.load``. NOTE(review): with ``None`` the raw PIL image is handed
            to ``encode_image``, which presumably expects a tensor - confirm.
        device: device the inputs are moved to; should match the model's device.
        truncate: forwarded to ``clip.tokenize``. When True (default), questions
            longer than the model's context length are cut off instead of raising.

    Returns:
        ``(image_features, text_features)`` tensors on the CPU.
    """
    # Context manager releases the file handle right away; convert() returns a new
    # image that no longer depends on the open file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    if transform is not None:
        image = transform(image).unsqueeze(0).to(device)
    tokens = clip.tokenize(question, truncate=truncate).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        text_features = clip_model.encode_text(tokens)
    # Outputs produced under no_grad carry no autograd history, so .detach() is unnecessary.
    return image_features.cpu(), text_features.cpu()

In [16]:
# Smoke-test the extractor on the first training row. The device is taken from
# the loaded model so inputs always land where the model lives (GPU or CPU).
data = train_data.iloc[0]
model_device = next(clip_model.parameters()).device
img_feat, text_feat = extract_features(image_path=data.image_path,
                                       question=data.question,
                                       clip_model=clip_model,
                                       transform=preprocess,
                                       device=model_device)
                                    

In [22]:
# Save the sample image features to sample.pt in the working directory.
torch.save(obj=img_feat, f='sample.pt')

In [23]:
# Save the sample question features to sample_text.pt in the working directory.
torch.save(obj=text_feat, f='sample_text.pt')