In [None]:
import urllib.request

import matplotlib.pyplot as plt
import torch
import torchvision.datasets.folder as tv_helpers
from omegaconf import OmegaConf

from mmf.common.sample import Sample, SampleList
from mmf.models.mmf_transformer import MMFTransformer
from mmf.utils.build import build_processors

In [None]:
# Build config, processors and model

# Root directory MMF uses for its downloaded data. It is passed to MMF via
# `env.data_dir` and reused below to resolve the answer vocabulary path, so the
# two can never drift apart.
MMF_DATA_DIR = "/root/.cache/torch/mmf/data/"

# Build configuration. In OmegaConf.merge, later configs override earlier ones,
# so the dotlist overrides take precedence over the YAML files.
dataset_conf = OmegaConf.load('/content/configs/okvqa_colab.yaml')
model_conf = OmegaConf.load('/content/mmf_transformer_config.yaml')
experiment_conf = OmegaConf.load('/content/experiment_config.yaml')
cli_overrides = OmegaConf.from_dotlist([f"env.data_dir={MMF_DATA_DIR}"])

conf = OmegaConf.merge(dataset_conf, model_conf, experiment_conf, cli_overrides)

# The answer vocab path in the dataset config is treated as relative to the
# `datasets/` directory under the data root; make it absolute so the answer
# processor can locate the file. `answer_params` is the config node itself,
# so assigning to it updates `conf`.
answer_params = conf.dataset_config.okvqa_colab.processors.answer_processor.params
answer_params.vocab_file = MMF_DATA_DIR + "datasets/" + answer_params.vocab_file

# Build processors (text, image and answer processors keyed by name).
mmf_processors = build_processors(conf.dataset_config.okvqa_colab.processors)



# Build model
model = MMFTransformer(conf.model_config.mmf_transformer)
model.build()
model.init_losses()

# Load the checkpoint onto the CPU first. Loading then does not depend on the
# device the checkpoint was saved from, and it avoids a transient extra GPU copy
# of the weights. The model is moved to the GPU afterwards.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load('okvqa_mmft.ckpt', map_location="cpu")
# The checkpoint is a dict; its "model" entry holds the model weights.
model.load_state_dict(state_dict["model"])
model.to("cuda")
# Switch to evaluation mode (affects layers such as dropout) for inference.
model.eval()
print("Model Loaded Successfully!!")

In [None]:
def create_sample(image, text, device="cuda"):
  """Build a preprocessed, single-item SampleList ready to feed the model.

  Args:
    image: the input image, passed as-is to ``mmf_processors["image_processor"]``
      (e.g. the image returned by ``tv_helpers.default_loader``).
    text: the question string, run through ``mmf_processors["text_processor"]``.
    device: device the resulting batch is moved to. Defaults to "cuda", the
      device the model is placed on in this notebook.

  Returns:
    A ``SampleList`` wrapping one ``Sample`` that holds the processed text
    fields plus an ``image`` field, moved to ``device``.
  """
  current_sample = Sample()

  # The text processor returns a mapping of fields; merge them into the sample.
  processed_text = mmf_processors["text_processor"]({"text": text})
  current_sample.update(processed_text)

  # Run the image preprocessors on the loaded image.
  current_sample.image = mmf_processors["image_processor"](image)

  # Collate into a batch of size 1 and move it to the target device.
  return SampleList([current_sample]).to(device)

In [None]:
image_url = "http://images.cocodataset.org/train2017/000000444444.jpg" #@param {type:"string"}
question = "Which sport requires riding on the animal depicted?" #@param {type:"string"}

# Download the image and load it from disk.
urllib.request.urlretrieve(image_url, "/content/local.jpg")
image = tv_helpers.default_loader("/content/local.jpg")

print("Image :: \n")
plt.imshow(image)
plt.axis("off")
# Render immediately so the image appears right after "Image ::" rather than
# after all the printed output at the end of the cell.
plt.show()
print("Question :: ", question)

# Inference only: no_grad skips building the autograd graph, saving memory.
with torch.no_grad():
    model_output = model(create_sample(image, question))
    probs = torch.nn.functional.softmax(model_output["scores"], dim=1)

# Highest-probability answer for the single sample in the batch.
prob, indices = probs.topk(1, dim=1)
# .item() converts the 0-d tensor index to a plain Python int for the vocab lookup.
answer = mmf_processors["answer_processor"].idx2word(indices[0][0].item())
print(answer)
print(f"Confidence :: {prob[0][0].item():.3f}")