In [2]:
!git clone https://github.com/CornerSiow/zero-shot-image-captioning.git
!pip install 'git+https://github.com/facebookresearch/detectron2.git'

Cloning into 'zero-shot-image-captioning'...
remote: Enumerating objects: 180, done.[K
remote: Counting objects: 100% (91/91), done.[K
remote: Compressing objects: 100% (91/91), done.[K
remote: Total 180 (delta 43), reused 0 (delta 0), pack-reused 89[K
Receiving objects: 100% (180/180), 76.90 MiB | 19.38 MiB/s, done.
Resolving deltas: 100% (82/82), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/facebookresearch/detectron2.git
  Cloning https://github.com/facebookresearch/detectron2.git to /tmp/pip-req-build-fdld7odi
  Running command git clone -q https://github.com/facebookresearch/detectron2.git /tmp/pip-req-build-fdld7odi
Collecting yacs>=0.1.8
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting fvcore<0.1.6,>=0.1.5
  Downloading fvcore-0.1.5.post20220512.tar.gz (50 kB)
[K     |████████████████████████████████| 50 kB 5.1 MB/s 
[?25hCollecting iopath<0.1.10,>=0.1.7
  Downloading 

In [3]:
!cp "zero-shot-image-captioning/code/Vocabulary.py" "Vocabulary.py"
!cp "zero-shot-image-captioning/code/DecoderLSTM.py" "DecoderLSTM.py"
!cp "zero-shot-image-captioning/code/Places365.py" "Places365.py"

In [4]:
import torch
from DecoderLSTM import DecoderLSTM
from Vocabulary import Vocabulary
import pickle
from Places365 import Places365
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
import cv2
import os
import numpy as np

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
place365 = Places365()
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
cfg.MODEL.DEVICE = str(device)
detectron2 = DefaultPredictor(cfg)
with open('zero-shot-image-captioning/data/filtered_symbolic.pickle', 'rb') as handle:
    filtered_symbolic = pickle.load(handle)  
vocab = Vocabulary()
vocab.loadFile("zero-shot-image-captioning/data/vocab.pickle")
vocab_size = len(vocab)
embed_size = len(filtered_symbolic)
hidden_size = 256 

model_final_cafdb1.pkl: 261MB [00:12, 21.5MB/s]                           


In [6]:
decoder = DecoderLSTM(embed_size, hidden_size, vocab_size)
decoder.load_state_dict(torch.load('zero-shot-image-captioning/data/lstm_decoder.pkl'))
decoder.to(device)
decoder.eval()

DecoderLSTM(
  (embedding): Embedding(30, 8)
  (lstm): LSTM(8, 256, bias=False, batch_first=True)
  (linear): Linear(in_features=256, out_features=30, bias=True)
)

In [7]:
def getImageSymbolics(img_path, threshold = 0.6):

    result = place365.pred(img_path)
   
    im = cv2.imread(img_path)        
    panoptic_seg, segments_info = detectron2(im)["panoptic_seg"]
    meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    stuffClass = meta.stuff_classes
    thingClass = meta.thing_classes

    stuffResult = []
    thingResult = []
    areaSize = 0
    
    for v in segments_info:        
        if v['isthing']:
            thingResult.append([thingClass[v['category_id']].replace("-", " ").replace("/", " ").replace("_", " "), v['score']])    
        else:
            areaSize = max(areaSize, v['area'])
            stuffResult.append([stuffClass[v['category_id']].replace("-", " ").replace("/", " ").replace("_", " "), v['area']])

    
    for v in thingResult:
            if v[0] not in result:
                result[v[0]] = 0
            result[v[0]] += v[1]
    
    
    for v in stuffResult:     
            if v[0] not in result:
                result[v[0]] = 0
            result[v[0]] += v[1]/areaSize
    
    symbolics = {}
    for r in result:
        if result[r] > threshold:
            symbolics[r] = result[r]
 
    return symbolics

In [8]:
results = []
for filename in os.scandir("zero-shot-image-captioning/img_test"):
  if filename.path.endswith('jpg'):
    symbolics = getImageSymbolics(filename.path)
    inputs = np.zeros(len(filtered_symbolic))

    for s in symbolics:
        if s.replace(" ", "_") in filtered_symbolic:
          inputs[filtered_symbolic.index(s.replace(" ", "_"))] = symbolics[s]
    
    inputs = torch.from_numpy(inputs).float()
    inputs = inputs.unsqueeze(0).unsqueeze(0).float().to(device)
    output = decoder.sample(inputs)    
    sentence = vocab.clean_sentence(output)        
    print(filename.name, sentence)
    results.append({
        'name' : filename.name,
        'path' : filename.path,
        'caption': sentence
    })

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


washing_2.jpg  a person washes his face in the sink
washing_1.jpg  a person washes his face in the sink
eating_2.jpg  a person eats a banana in front of a laptop
eating_1.jpg  a person eats a banana in front of a laptop
working_2.jpg  a person using a laptop in the office
bus_2.jpg  someone is waiting at the bus stop
cycling_1.jpg  a person riding a bike on a clear sky
working_1.jpg  a person using a laptop in the office
cycling_2.jpg  a person riding a bike on a clear sky
bus_1.jpg  someone is waiting at the bus stop


In [None]:
from google.colab.patches import cv2_imshow
for v in results:
  img = cv2.imread(v['path'])
  cv2_imshow(img)
  break