***usage***
- 查看eval的结果
- 根据输入的图片（路径），产生输出

In [25]:
from PIL import Image
import torch
import os
import json
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
# from modi

In [27]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [22]:
def displayEvalResult(dataset, id, ref_path, hyp_path):
    """
    :brief: read eval results by ref_path,hyp_path
    :param dataset: dataset name
    :param id :image index
    :param ref_path: reference path
    :param hyp_path: answer path
    """
    # with open(os.join(dataset,'TEST',"0",str(id)), )
    img1 = Image.open(os.path.join('dataset', dataset,'TEST',"imgs", "0",str(id)+".png"))
    img2 = Image.open(os.path.join('dataset', dataset,'TEST',"imgs","1",str(id)+".png"))
    img3 = Image.open(os.path.join('dataset', dataset,'TEST',"imgs","2",str(id)+".png"))
    with open(ref_path, 'r', encoding='utf-8') as f:
        ref = json.load(f)
    with open(hyp_path, 'r', encoding='utf-8') as f:
        hyp = json.load(f)
    img1.show(f"{id}: before")
    img2.show(f"{id}: during")
    img3.show(f"{id}: after")
    print(f"reference:")
    for i in "".join(ref[str(id)]).split("<sep>"):
        print(i)
    print("###########")   
    print(f"answers:")
    for i in "".join(hyp[str(id)]).split("<sep>"):
        print(i)
    print("&&&&&&&&&&&")


In [38]:
#extract feature
transform_to_tensor = transforms.Compose([
    transforms.ToTensor(),
])
# resnet50 = models.resnet50(pretrained=True)
# resnet50.eval()

def read_img(img_path):
    img = Image.open(img_path)
    img = img.convert('RGB')# read png...
    img = transform_to_tensor(img).to(device)
    img = img.unsqueeze(0)
    return img
    

In [31]:
def beam_search_decoder(decoder, memory, word_map, rev_word_map, beam_size, max_len=120):
    k = beam_size
    vocab_size = len(word_map)
    
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    
    # Initialize sequences with the start token
    sequences = [[start_token]]
    scores = torch.zeros(k, 1).to(device)
    
    for _ in range(max_len):
        all_candidates = []
        
        for i in range(len(sequences)):
            seq = sequences[i]
            score = scores[i]
            
            if seq[-1] == end_token:
                all_candidates.append((seq, score))
                continue
            
            tgt = torch.tensor(seq).unsqueeze(1).to(device)
            tgt_length = tgt.size(0)
            
            mask = (torch.triu(torch.ones(tgt_length, tgt_length)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            mask = mask.to(device)
            
            tgt_embedding = decoder.vocab_embedding(tgt)
            tgt_embedding = decoder.position_encoding(tgt_embedding)
            pred = decoder.transformer(tgt_embedding, memory, tgt_mask=mask)
            pred = decoder.wdc(pred)
            pred = pred[-1, 0, :]
            
            topk_probs, topk_indices = torch.topk(pred, k, dim=-1)
            
            for j in range(k):
                candidate = [seq + [topk_indices[j].item()], score - torch.log(topk_probs[j])]
                all_candidates.append(candidate)
        
        ordered = sorted(all_candidates, key=lambda x: x[1])
        sequences, scores = zip(*ordered[:k])
    
    best_sequence = sequences[0]
    return best_sequence


In [42]:
def get_key(dict_, value):
  return [k for k, v in dict_.items() if v == value]


In [32]:
def load_wordmap(word_map_path):
    with open(word_map_path, 'r') as f:
        word_map = json.load(f)
    return word_map

In [45]:
def answer(img1, img2, img3, ckeckpoint, word_map_path, beam_size, decode_twice=False):
    checkpoint = torch.load(ckeckpoint, map_location='cuda:0')
    encoder = checkpoint['encoder']
    encoder = encoder.to(device)
    encoder.eval()

    decoder = checkpoint['decoder']
    decoder = decoder.to(device)
    decoder.eval()

    wordmap = load_wordmap(word_map_path=word_map_path)
    rev_word_map = {v: k for k, v in wordmap.items()}

    img1 = read_img(img_path=img1)
    img2 = read_img(img_path=img2)
    img3 = read_img(img_path=img3)
    memory = encoder(img1, img2, img3)
    if decode_twice:
        assert type(memory)==tuple
        memory1, memory2 = memory
        seq1 = beam_search_decoder(decoder=decoder, memory=memory1, word_map=wordmap, rev_word_map=rev_word_map, beam_size=beam_size, max_len=120)
        seq2 = beam_search_decoder(decoder=decoder, memory=memory2, word_map=wordmap, rev_word_map=rev_word_map, beam_size=beam_size, max_len=120)
        hyp = [w for w in seq1 if w not in {wordmap['<start>'], wordmap['<end>'], wordmap['<pad>']}]
        hyp.extend([w for w in seq2 if w not in {wordmap['<start>'], wordmap['<end>'], wordmap['<pad>']}])
        
    else:
        if type(memory)==tuple:
            memory1, memory2 = memory
            memory = (memory1+memory2)/2
        seq = beam_search_decoder(decoder=decoder, memory=memory, word_map=wordmap, rev_word_map=rev_word_map, beam_size=beam_size, max_len=120)
        hyp = [w for w in seq if w not in {wordmap['<start>'], wordmap['<end>'], wordmap['<pad>']}]
    
    line_hypo=""
    for word_idx in hyp:
        word = get_key(wordmap, word_idx)
        line_hypo += word[0] + " "

    print("#######")
    print("Answer:")
    hyp = "".join(line_hypo.split("<sep>"))
    print(hyp)
    return hyp


AdvanceMCCFormers-S 例

In [23]:
dataset = "CLEVR"
id = 0
eval_root = "./eval_results/CLEVR/AdvanceMCCFormers-S"
ref_path, hyp_path = eval_root+"/advance_CLEVRhat_decode_twice_epoch9decode_twice_gts.json", eval_root+"/advance_CLEVRhat_decode_twice_epoch9decode_twice_res.json"
displayEvalResult(dataset=dataset, id=id, ref_path=ref_path, hyp_path=hyp_path)

reference:
There is no longer a small gray rubber sphere . 
 The large red rubber cube has disappeared . 
 A new large green metal cylinder is visible . 
 The large cyan metal sphere has been moved . 
 The small blue metal sphere is in a different location . 
 Someone removed the large brown rubber cube . 
###########
answer:
The large purple rubber sphere is missing . 
 The large purple rubber sphere is missing . 
 The small purple rubber cylinder is missing . A new large purple metal sphere is visible . 
 The large purple metal sphere was moved from its original location . 
 The large purple metal sphere was moved from its original original location . 
 A new large purple metal cylinder is visible . there . A different location . 
 A large purple metal cylinder . 
&&&&&&&&&&&


Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)


Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)


AdvanceMCCFormers-D例

In [46]:
img1 = "./dataset/CLEVR/TEST/imgs/0/0.png"
img2 = "./dataset/CLEVR/TEST/imgs/1/0.png"
img3 = "./dataset/CLEVR/TEST/imgs/2/0.png"
# Image.open(img1).show("before")
# Image.open(img2).show("during")
# Image.open(img3).show("after")

checkpoint = "./result/AdvanceMCCFormers-D/CLEVRhat/checkpoint_epoch_9first_advance_CLEVRhat_use_2_imgs.pth.tar" #使用的是采用高阶训练的模型
word_map_path="./dataset/CLEVRhat/wordmap/wordmap.json"
beam_size = 3
answer(img1=img1, img2=img2, img3=img3, ckeckpoint=checkpoint, word_map_path=word_map_path, beam_size=beam_size, decode_twice=True)


#######
Answer:
The small green rubber sphere is missing .  The large green rubber sphere is missing .  The small green rubber sphere is missing . The large green rubber sphere is missing .  The large green metal sphere is missing .  The large green metal sphere is missing . 


'The small green rubber sphere is missing .  The large green rubber sphere is missing .  The small green rubber sphere is missing . The large green rubber sphere is missing .  The large green metal sphere is missing .  The large green metal sphere is missing . '