In [22]:
import clip 
import os 
import pandas as pd
import json
from PIL import Image
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [41]:
# Locations of the thresholded VQA splits and of the image folder.
image_dir = '../vizviz/vqa/train'
train_data_path = '../vizviz/vqa/train_df_thresh_3.csv'
test_data_path = '../vizviz/vqa/val_df_thresh_3.csv'


def _absolute_image_path(file_name):
    """Return the absolute path of `file_name` inside `image_dir`."""
    return os.path.abspath(os.path.join(image_dir, file_name))


train_df = pd.read_csv(train_data_path)
test_df = pd.read_csv(test_data_path)

# Give every evaluation row a ready-to-open absolute image path.
# NOTE(review): the validation frame is resolved against the *train* image
# folder -- confirm that the validation images really live there.
test_df['image_path'] = test_df['image'].map(_absolute_image_path)

In [15]:
# Answer vocabulary: the distinct final answers of the answerable training
# rows, numbered in order of first appearance.
answerable_rows = train_df[train_df.answerable == 1]
answer_candidates = answerable_rows.final_answer.unique().tolist()

label2id = {answer: position for position, answer in enumerate(answer_candidates)}
id2label = {position: answer for answer, position in label2id.items()}

# Persist both directions of the mapping. JSON object keys are always strings,
# so the integer ids of id2label come back as strings when the file is re-read.
for mapping, json_path in (
    (label2id, '../vizviz/vqa/label2id.json'),
    (id2label, '../vizviz/vqa/id2label.json'),
):
    with open(json_path, 'w') as f:
        json.dump(mapping, f)


In [23]:
# Size of the answer vocabulary. The recorded output (821) is the same number
# that is hard-coded as ModelConfig.output_dim further down.
len(id2label)

821

In [34]:
class ModelConfig:
    """Static settings consumed when building the VQA inference wrapper.

    All values are class attributes, so the class itself can be passed
    around as the configuration object without being instantiated.
    """

    # --- artefacts on disk ---
    # JSON file mapping class index (stored as a string key) to answer text.
    id2label = '../vizviz/vqa/id2label.json'
    # Trained VQA classifier checkpoint.
    model_path = '../runs/vqa_clip_model/1/clip_vqa_0.632963594994311.pth'

    # --- CLIP backbone ---
    clip_model_name = "RN50"

    # --- classifier dimensions ---
    # NOTE(review): presumably input_dim is the width of the concatenated CLIP
    # image + text embeddings for the backbone above -- confirm before
    # switching to another CLIP variant.
    input_dim = 2048
    hidden_dim = 2048
    # Matches the number of answer labels reported by len(id2label).
    output_dim = 821

In [113]:
from network import VQAModelV1

class VQAModule:
    """Inference wrapper that pairs a CLIP encoder with a trained VQA head.

    CLIP embeds the image and the question; the two embeddings are
    concatenated and passed to the VQA classifier, whose arg-max class index
    is translated to an answer string through ``id2label``.
    """

    # CLIP variants accepted by ``load_clip``.
    CLIP_MODELS = [
            "RN50",
            "RN101",
            "RN50x4",
            "RN50x16",
            "RN50x64",
            "ViT-B/32",
            "ViT-B/16",
            "ViT-L/14",
            "ViT-L/14@336px",
        ]

    def __init__(self, model_config, device="cpu"):
        """Load CLIP, the id-to-answer mapping and the VQA checkpoint.

        Args:
            model_config: object exposing ``clip_model_name``, ``id2label``
                (path to a JSON file) and ``model_path`` (path to the
                checkpoint).
            device: torch device (string or ``torch.device``) on which both
                models run.
        """
        self.model_config = model_config
        self.device = device
        self.clip_model, self.preprocess = self.load_clip(
            model_name=self.model_config.clip_model_name, device=self.device
        )
        print("Clip Loaded")

        # Close the JSON file deterministically instead of leaking the handle.
        with open(model_config.id2label) as label_file:
            self.id2label = json.load(label_file)

        self.vqa_model = self._load_checkpoint(model_config.model_path).to(self.device)
        self.vqa_model.eval()

    def _load_checkpoint(self, model_path):
        """Deserialize the VQA model stored at ``model_path`` onto ``self.device``.

        The result is used as a module (``.to()`` / ``.eval()`` are called on
        it), so the file holds a whole pickled model rather than a state dict;
        its class must therefore be importable when this runs.

        SECURITY: this is a full pickle load and can execute arbitrary code.
        Only open checkpoints that come from a trusted source.
        """
        try:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # host. weights_only=False is required on PyTorch >= 2.6, where the
            # default became True and rejects whole pickled modules.
            return torch.load(model_path, map_location=self.device, weights_only=False)
        except TypeError as err:
            # Older PyTorch releases do not know the ``weights_only`` keyword.
            if "weights_only" not in str(err):
                raise
            return torch.load(model_path, map_location=self.device)

    def load_clip(self, model_name="RN50", device="cpu"):
        """Return ``(clip_model, preprocess)`` for one of ``CLIP_MODELS``.

        Raises:
            AssertionError: if ``model_name`` is not listed in ``CLIP_MODELS``.
        """
        assert model_name in self.CLIP_MODELS, f"clip models available {self.CLIP_MODELS}"
        clip_model, preprocess = clip.load(model_name, device=device)
        return clip_model, preprocess

    def predict(self, question : str, image_path : str):
        """Answer ``question`` about the image stored at ``image_path``.

        Args:
            question: question text; it is tokenized with ``clip.tokenize``.
            image_path: path of the image file to load.

        Returns:
            The ``id2label`` entry for the highest-scoring class.

        Raises:
            AssertionError: if ``image_path`` does not exist.
        """
        assert os.path.exists(image_path), f"{image_path} does not exists"

        # Open the image inside a context manager so the file handle is
        # released as soon as the pixels have been converted.
        with Image.open(image_path) as raw_image:
            image = self.preprocess(raw_image.convert("RGB")).unsqueeze(0).to(self.device)
        question_tokens = clip.tokenize(question).to(self.device)

        with torch.no_grad():
            image_features = self.clip_model.encode_image(image)
            text_features = self.clip_model.encode_text(question_tokens)
            # Cast to float32 before feeding the concatenated features to the
            # VQA head, whatever precision the CLIP encoders returned.
            vqa_input = torch.cat((image_features, text_features), 1).to(torch.float32)
            outputs = self.vqa_model(vqa_input)
            _, index = outputs.max(1)

        # id2label was read from JSON, so its keys are strings.
        pred_index = int(index.item())
        return self.id2label[str(pred_index)]

In [114]:
# Build the inference wrapper from the class-level settings in ModelConfig.
# No device is passed, so the constructor default ("cpu") is used.
vqa = VQAModule(model_config=ModelConfig)

Clip Loaded
