In [1]:
import sys
# Make the parent of the notebook's working directory importable so the local
# project packages below (data, models, utils, flags) resolve.
sys.path.append('../')
#  Torch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and cache the fastest convolution algorithms for the
# input shapes it sees.
cudnn.benchmark = True

# Python imports
import numpy as np
# NOTE(review): `import tqdm` binds the module, but `from tqdm import tqdm` two
# lines below rebinds the same name, so from here on `tqdm` is the progress-bar
# class, not the module. Keep this order.
import tqdm
import torchvision.models as tmodels
from tqdm import tqdm
import os
from os.path import join as ospj
import itertools
import glob
import random

#Local imports
# (resolved through the sys.path entry added at the top of this cell)
from data import dataset as dset
from models.common import Evaluator
from models.image_extractor import get_image_extractor
from models.manifold_methods import RedWine, LabelEmbedPlus, AttributeOperator
from models.modular_methods import GatedGeneralNN
from models.symnet import Symnet
from utils.utils import save_args, UnNormalizer, load_args
from utils.config_model import configure_model
from flags import parser
from PIL import Image
import matplotlib.pyplot as plt
import importlib
import easydict

# Device used for inference: the GPU when PyTorch can see one, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# CLI parsing is disabled in the notebook; `args` is built by hand in the args cell:
# args, unknown = parser.parse_known_args()

In [2]:
# Log directory of the trained run under evaluation, plus the two files taken
# from it: the saved run config (YAML) and the best-AUC checkpoint.
path = 'logs/co-cge-ow/mitstates/'
yml, ck = (ospj(path, fname) for fname in ('mit.yml', 'ckpt_best_auc.t7'))

In [3]:
# args format
# Default run arguments. The next cell calls load_args() on the run's YAML,
# which presumably overrides many of these values — verify against utils.utils.
# Each key appears exactly once: the original literal listed train_only,
# fast_eval, graph_init, gcn_type and gr_emb twice, and Python silently kept the
# later value (e.g. gr_emb was 'd4096', not 'd600'). Values are unchanged.

import easydict

args = easydict.EasyDict({

    'config': 'configs/co-cge/mit_anno.yml',
    'dataset': 'mitstates',
    'data_dir': 'mit-states',
    'logpath': path,
    'splitname': 'compositional-split-natural',
    'cv_dir': 'logs/',
    'name': path,
    'load': None,
    'image_extractor': 'resnet18',
    'norm_family': 'imagenet',
    'num_negs': 1,
    'pair_dropout': 0.0,
    'test_set': 'val',
    'clean_only': False,
    'subset': False,
    'open_world': True,
    'test_batch_size': 128,
    'cpu_eval': True,

    'model': 'graphfull',
    'emb_dim': 512,
    'nlayers': 2,
    'nmods': 24,
    'embed_rank': 64,
    'bias': 1e3,
    'update_features': True,
    'freeze_features': False,
    # 'use_feature': True,
    'emb_init': 'ft+w2v',
    'clf_init': False,
    'static_inp': False,
    'composition': 'mlp_add',
    # 'relu': True,
    'dropout': True,
    'norm': True,
    # Model parameters
    'train_only': False,
    'train_triplet_loss': False,

    # Evaluation
    'fast_eval': True,
    'closed_eval': None,

    #CGE / graph methods
    'graph': False,
    'graph_init': None,
    'gcn_type': 'gcn',

    # Forward
    'eval_type': 'dist',

    # Primitive-based loss
    'lambda_aux': 0.0,

    # AoP
    'lambda_inv': 0.0,
    'lambda_comm': 0.0,
    'lambda_ant': 0.0,

    # SymNet
    'lambda_trip': 0,
    'lambda_sym': 0,
    'lambda_axiom': 0,
    'lambda_cls_attr': 0,
    'lambda_cls_obj': 0,

    # CompCos (for the margin, see the hyperparameters below)
    'cosine_scale': 50,
    'epoch_max_margin': 100,
    'update_feasibility_every': 1,
    'hard_masking': False,
    'threshold': None,
    'threshold_trials': 50,

    # Graph methods
    'gr_emb': 'd4096',
    'cosine_classifier': True,
    'feasibility_adjacency': True,

    # Hyperparameters
    'topk': 3,
    'margin': 1.0,
    'workers': 8,
    'batch_size': 128,
    'lr': 5e-5,
    'lrg': 1e-3,
    'wd': 5e-5,
    'save_every': 10000,
    'eval_val_every': 1,
    'max_epochs': 200,
    'fc_emb': '768,1024',
})


### Run one of the cells below to load the saved configuration for the dataset you want to test, then move on to the next section

In [4]:
# Overwrite the defaults in `args` with the run's saved configuration
# (load_args is project code; presumably it reads the YAML at `yml` — verify),
# then point `args.load` at the checkpoint to evaluate.
# The former `args.graph_init = args.graph_init` was a no-op and has been dropped.
best_ut = yml    # e.g. logs/unmatch/cgqa_2/cgqa_unmatch_2.yml
load_args(best_ut, args)
args.load = ck

### Loading the test dataset (arguments were loaded above)

In [5]:
# Build the test split of the dataset and an unshuffled loader over it.
from flags import DATA_FOLDER

# Evaluate on the held-out test split (the default in `args` is 'val').
args.test_set = 'test'
testset = dset.CompositionDataset(
    root=os.path.join(DATA_FOLDER, args.data_dir),
    phase=args.test_set,
    split=args.splitname,
    model=args.image_extractor,
    subset=args.subset,
    return_images=True,  # the last element of each batch carries the images
    update_features=args.update_features,
    open_world=args.open_world,
    # NOTE(review): args.clean_only is not forwarded to the dataset.
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.test_batch_size,  # configured evaluation batch size, not a magic 128
    shuffle=False,  # fixed sample order so results are reproducible
    num_workers=args.workers,
)

print('Objs ', len(testset.objs), ' Attrs ', len(testset.attrs))

Using all pairs
Dataset loaded
Train pairs: 1262, Validation pairs: 600, Test Pairs: 800
Train images: 30328, Validation images: 10417, Test images: 12988
Objs  245  Attrs  115


In [6]:
# Build the feature extractor, the compositional model and its optimizer from
# `args`, then bind an Evaluator to the test split and that model.
model_parts = configure_model(args, testset)
image_extractor, model, optimizer = model_parts
evaluator = Evaluator(testset, model)



Fasttext Embeddings loaded, total embeddings: torch.Size([360, 300])
Word2Vec Embeddings loaded, total embeddings: torch.Size([360, 300])
Combined embeddings are  torch.Size([360, 600])
Learnable image_embeddings
Evaluating with test pairs




In [7]:
# Restore trained weights when a checkpoint path is set, and put the networks in
# inference mode. Checkpoints are unpickled by torch.load: only load trusted files.
if args.load is not None:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(args.load, map_location=device)
    if image_extractor:
        if 'image_extractor' not in checkpoint:
            print('No Image extractor in checkpoint')
        else:
            try:
                image_extractor.load_state_dict(checkpoint['image_extractor'])
            except RuntimeError as err:
                # Keep going (best effort), but say why the weights were not restored
                # instead of misreporting a missing key.
                print('Image extractor weights in checkpoint could not be loaded:', err)
        # Always switch to eval mode, even when no weights were restored, so
        # dropout / batch-norm behave as in inference.
        image_extractor.eval()
    model.load_state_dict(checkpoint['net'])
    model.eval()
    print('Loaded model from ', args.load)
    print('Best AUC: ', checkpoint['AUC'])

Loaded model from  logs/co-cge-ow/mitstates_4/ckpt_best_auc.t7
Best AUC:  0.027498053712842557


In [8]:
# Partition (attribute, object) pairs into the spaces used by the analysis:
#   closed         - every pair appearing in the train, val or test split
#   seen_match     - pairs present in the training split
#   unseen_match   - val/test pairs that never appear in training
#   unseen_unmatch - pairs in testset.pairs that appear in no split
#   open_pairs     - every pair in testset.pairs (dataset built with open_world=True)
closed = sorted(set(testset.train_pairs + testset.val_pairs + testset.test_pairs))
seen_match = sorted(set(testset.train_pairs))
unseen_match = sorted(set(testset.val_pairs + testset.test_pairs) - set(testset.train_pairs))
unseen_unmatch = sorted(set(testset.pairs) - set(closed))
# Named `open_pairs` rather than `open` so the built-in open() is not shadowed.
open_pairs = sorted(set(testset.pairs))

In [9]:
def unmatch_ratio(scores, exp, flag=False, idx=None, topk=None):
    """Classify the top-k predicted pairs of one sample by pair space.

    Args:
        scores: mapping returned by ``evaluator.score_model``; ``scores[exp]``
            is indexed as ``(attr_idx, obj_idx)`` where each is 2-D and
            ``[idx, k]`` gives the k-th ranked attribute / object index.
        exp: key into ``scores`` (e.g. ``'open'``).
        flag: unused; kept for backward compatibility.
        idx: row of the sample inside the batch. Defaults to the notebook's
            global ``idx`` (the original version always read that global).
        topk: number of ranked candidates to inspect. Defaults to the global
            ``topk``.

    Returns:
        ``(seen_num, unseen_num, unmatch)``: how many of the top-k predicted
        pairs are in ``seen_match``, in ``unseen_match``, or in neither.
    """
    # Backward compatibility with callers that rely on the evaluation loop's
    # global `idx` / `topk`.
    if idx is None:
        idx = globals()['idx']
    if topk is None:
        topk = globals()['topk']

    attr_idx, obj_idx = scores[exp][0], scores[exp][1]
    seen_num = unseen_num = unmatch = 0
    for k in range(topk):
        pair = (evaluator.dset.attrs[attr_idx[idx, k]],
                evaluator.dset.objs[obj_idx[idx, k]])
        if pair in seen_match:
            seen_num += 1
        elif pair in unseen_match:
            unseen_num += 1
        else:
            unmatch += 1

    return seen_num, unseen_num, unmatch

In [10]:
# Over the whole test set, count how the top-k open-world predictions split into
# seen pairs / unseen-but-labelled pairs / unmatched pairs.
# NOTE: unmatch_ratio may read `idx`, `topk` and `data` from this notebook's
# global namespace, so those names are deliberately kept top-level here.
topk = 5
total_candidate_num = topk * len(testset)  # each test image contributes `topk` candidates
seen_candidate_num = 0
unseen_candidate_num = 0
unmatch_candidate_num = 0

with torch.no_grad():  # inference only: do not build autograd graphs per batch
    for data in tqdm(testloader, desc='scoring'):
        images = data[-1]
        data = [d.to(device) for d in data[:-1]]
        if image_extractor:
            data[0] = image_extractor(data[0])
        _, predictions, _ = model(data)
        data = [d.to('cpu') for d in data]
        # NOTE(review): bias is hard-coded; args.bias (default 1e3) is not used here.
        results = evaluator.score_model(predictions, data[2], bias=1000, topk=topk)

        for idx in range(len(images)):
            sm, um, uu = unmatch_ratio(results, 'open')
            seen_candidate_num += sm
            unseen_candidate_num += um
            unmatch_candidate_num += uu

print("unmatch pair candidate ratio")

print("total ratio: ", str(seen_candidate_num/total_candidate_num), " / " , str(unseen_candidate_num/total_candidate_num), " / " ,str(unmatch_candidate_num/total_candidate_num))

unmatch pair candidate ratio
seen unmatch ratio:  0.7767031118587048  /  0.30571909167367534
unseen unmatch ratio:  0.7558718190386428  /  0.29585296889726675
total unmatch ratio:  0.7596858638743456  /  0.2976593778872806
