# Find hard triplets 

In [None]:
from vectorrvnn.interfaces import *
from vectorrvnn.trainutils import *
from vectorrvnn.data import *
from vectorrvnn.utils import *
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from copy import deepcopy
from sklearn.manifold import TSNE

# Options for evaluating the pretrained OBBNet checkpoint. This cell only
# parses the options; model and data are built from them afterwards.
# Flags are listed in an (insertion-ordered) dict and flattened into the
# ['--flag', 'value', ...] list that Options().parse expects.
_testFlags = {
    '--checkpoints_dir': '../results',
    '--dataroot': '../data/All',
    '--embedding_size': '512',
    '--hidden_size': '1024',
    '--encoder_layers': '2',
    '--heads': '2',
    '--load_ckpt': 'obb-6/best_0-756-12-23-2021-18-13-36.pth',
    '--modelcls': 'OBBNet',
    '--name': 'test',
    '--phase': 'test',
}
opts = Options().parse(
    testing=[token for pair in _testFlags.items() for token in pair]
)

In [None]:
# Instantiate the network from the parsed options, then the datasets/loaders
# (evaluated left to right, same order as two separate assignments).
model, data = buildModel(opts), buildData(opts)

## What are the features? Visualizing one train and one validation triplet

In [None]:
def plotVisImage (im) : 
    """Show a (ref, plus, minus) triplet side by side in a single figure.

    Parameters
    ----------
    im : iterable of torch.Tensor
        Per-node images in (ref, plus, minus) order. Each tensor is detached,
        moved to the CPU and transposed with axes (1, 2, 0) before plotting,
        so it is presumably channel-first (C, H, W) -- TODO confirm against
        TripletVisCallback.visualized_image. Only the first three entries
        are used.

    Each panel is labelled with its node type on the x axis. The node-type
    list is also printed, as before.
    """
    nodeType = ['ref', 'plus', 'minus']
    print(nodeType)
    ims, labels = [], []
    for ntype, im_ in zip(nodeType, im) :
        numpyIm = im_.detach().cpu().numpy()
        ims.append(np.transpose(numpyIm, (1, 2, 0)))
        labels.append(ntype)

    # Panels are concatenated along the width axis. Put one tick at the
    # horizontal centre of each panel so the node type sits under its image.
    widths = np.asarray([x.shape[1] for x in ims])
    centres = np.cumsum(widths) - widths / 2

    _, ax = plt.subplots(1, 1, dpi=150)
    ax.imshow(np.concatenate(ims, 1))
    ax.set_xticks(centres)
    ax.set_xticklabels(labels)
    ax.set_yticks([])  # pixel rows carry no information here
    plt.show()
        
# buildData's result is unpacked into datasets and their loaders.
trainData, valData, trainDataLoader, valDataLoader = data

# Grab only the first batch of each split: after `break` the loop variable
# stays bound to the first item the loader yielded.
for trainBatch in trainDataLoader :
    break
for valBatch in valDataLoader :
    break

tripletviscallback = TripletVisCallback()

# Render one triplet per split; no model output exists yet, hence mask=None.
# NOTE(review): the third positional argument is False for the train batch
# and True for the val batch here, but False for val batches in the
# hard-triplet cell -- confirm what it controls in visualized_image.
trainTripletImage = tripletviscallback.visualized_image(
    trainBatch, dict(mask=None), False
)
valTripletImage = tripletviscallback.visualized_image(
    valBatch, dict(mask=None), True
)

for caption, tripletImage in (
    ("Plotting train triplet", trainTripletImage),
    ("Plotting val triplet", valTripletImage),
) :
    print(caption)
    plotVisImage(tripletImage)

## What are hard examples among validation triplets?

In [None]:
# NOTE(review): `interface` is not used further in this cell.
interface = TripletInterface(opts, model, trainData, valData)

# Upper bound on the number of hard-triplet figures this cell draws, so a
# weak checkpoint cannot flood the notebook with hundreds of plots.
# Set to float('inf') to plot every hard triplet.
maxPlots = 32

rets = [] 
nPlotted = 0
model.eval()
valDataLoader.reset()
with torch.no_grad() :
    for batch in tqdm(valDataLoader) :
        ret = model(**batch)
        # A triplet (ref, plus, minus) is hard when the negative lands at
        # least as close to the reference as the positive: dplus >= dminus.
        # (dminus > dplus would select the correctly ordered, easy triplets.)
        # NOTE(review): assumes dplus/dminus are distances (smaller = closer),
        # as the names suggest -- confirm against the model's forward().
        mask = ret['dplus'] >= ret['dminus']
        nHard = int(mask.sum())
        nShow = min(nHard, maxPlots - nPlotted)
        for i in range(nShow) :
            hardTriplet = tripletviscallback.visualized_image(
                batch, 
                dict(mask=mask), 
                False,
                i=i
            )
            plotVisImage(hardTriplet)
        nPlotted += nShow
        rets.append(ret)

# float() turns a 0-d tensor into a plain number so the printout is not
# "tensor(..., device=...)"; it is a no-op if avg already returns a float.
print('loss = ', float(avg(map(lambda r : r['loss'], rets))))
print('% hard triplets = ', float(avg(map(lambda r : r['hardpct'], rets))))