# Inference

This notebook retrieves the images closest to a query sketch, by comparing the sketch's embedding with precomputed image embeddings.

In [18]:
# Imports
import ast

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Large square figures and readable subplot titles for the retrieval grids.
plt.rcParams.update({"figure.figsize": (12, 12), "axes.titlesize": 15})

In [2]:
from src.data.loader_factory import load_data
from src.data.utils import default_image_loader
from src.models.encoder import EncoderCNN
from src.models.utils import load_checkpoint
from src.models.test import get_test_data
from src.models.metrics import get_similarity

In [3]:
class Args:
    """Static configuration for this notebook, standing in for parsed CLI args.

    Presumably mirrors the training script's options; not all of them are
    used by this notebook.
    """
    # Data
    dataset = "sketchy_extend"
    data_path = "../io/data/raw"
    prefetch = 2
    batch_size = 10

    # Model
    emb_size = 256
    attn = True
    grl_lambda = 0.5
    nopretrain = False
    load = None

    # Training schedule
    epochs = 1000
    early_stop = 20
    seed = 42

    # Hardware
    cuda = False
    ngpu = 1

    # Logging / output
    log = "../io/models/"
    log_interval = 20
    plot = False


args = Args()

In [5]:
# Inference only: turn autograd off for every later cell (trailing ';' hides the context-manager repr).
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f562c969ad0>

## Precompute the test-image embeddings and save them to disk

In [51]:
import pandas as pd

In [10]:
# Checkpoint of the trained image/sketch encoders used throughout this notebook.
BEST_CHECKPOINT = "../io/models/1_run-batch_size_10/checkpoint.pth"

In [11]:
def get_model(args, best_checkpoint):
    """Build the image and sketch encoders and load their trained weights.

    Args:
        args: configuration object; reads ``emb_size``, ``cuda``, ``ngpu`` and
            (if present) ``attn``.
        best_checkpoint: path to a checkpoint dict holding the encoder state
            dicts under ``'im_state'`` and ``'sk_state'``.

    Returns:
        ``(im_net, sk_net)`` in evaluation mode, optionally wrapped in
        ``torch.nn.DataParallel`` and/or moved to the GPU.
    """
    # Honour the configured attention flag; default to the previous hard-coded True.
    attention = getattr(args, 'attn', True)
    im_net = EncoderCNN(out_size=args.emb_size, attention=attention)
    sk_net = EncoderCNN(out_size=args.emb_size, attention=attention)

    # map_location='cpu' lets a GPU-trained checkpoint load on a CPU-only host;
    # the nets are moved to the GPU below only if args.cuda is set.
    checkpoint = torch.load(best_checkpoint, map_location='cpu')

    im_net.load_state_dict(checkpoint['im_state'])
    sk_net.load_state_dict(checkpoint['sk_state'])

    # Inference only: switch layers such as dropout / batch-norm (if any) to eval behaviour.
    im_net.eval()
    sk_net.eval()

    if args.cuda and args.ngpu > 1:
        print('\t* Data Parallel **NOT TESTED**')
        # `nn` is never imported in this notebook, so use the fully qualified name.
        im_net = torch.nn.DataParallel(im_net, device_ids=list(range(args.ngpu)))
        sk_net = torch.nn.DataParallel(sk_net, device_ids=list(range(args.ngpu)))

    if args.cuda:
        print('\t* CUDA')
        im_net, sk_net = im_net.cuda(), sk_net.cuda()

    return im_net, sk_net

In [None]:
# Load the trained encoders (`model_path` was never defined; use the checkpoint constant).
im_net, sk_net = get_model(args, BEST_CHECKPOINT)

In [None]:
# Same preprocessing the Inference class applies to query sketches: image -> tensor.
# NOTE(review): confirm this matches the transform used at training time.
transform = transforms.Compose([transforms.ToTensor()])

# The test split is unpacked as [sketches, images]; the sketch set is needed for
# its class list (order presumed from the original unpacking — TODO confirm in load_data).
_, [_, _], [test_sk_data, test_im_data], dict_class = load_data(args, transform)
print("Length Image: {}".format(len(test_im_data)))
print("Classes: {}".format(test_sk_data.get_class_dict()))

test_im_loader = DataLoader(test_im_data, batch_size=1)
images_fnames, images_embeddings, images_classes = get_test_data(test_im_loader, im_net, args)

Length Image: 10453
Classes: ['bat', 'cabin', 'cow', 'dolphin', 'door', 'giraffe', 'helicopter', 'mouse', 'pear', 'raccoon', 'rhinoceros', 'saw', 'scissors', 'seagull', 'skyscraper', 'songbird', 'sword', 'tree', 'wheelchair', 'windmill', 'window']
0 images processed on 10453
200 images processed on 10453
400 images processed on 10453
600 images processed on 10453
800 images processed on 10453
1000 images processed on 10453
1200 images processed on 10453
1400 images processed on 10453
1600 images processed on 10453
1800 images processed on 10453
2000 images processed on 10453
2200 images processed on 10453
2400 images processed on 10453
2600 images processed on 10453
2800 images processed on 10453
3000 images processed on 10453
3200 images processed on 10453
3400 images processed on 10453
3600 images processed on 10453
3800 images processed on 10453
4000 images processed on 10453
4200 images processed on 10453
4400 images processed on 10453
4600 images processed on 10453
4800 images pro

In [None]:
# One row per test image. Building the frame column-wise keeps `classes` numeric
# and stores each embedding vector as a single cell aligned with its filename,
# instead of transposing an all-object frame built from rows.
df = pd.DataFrame({
    'fnames': images_fnames,
    'embeddings': list(images_embeddings),
    'classes': images_classes,
})

In [None]:
# CSV can only hold text, so serialise explicitly so the file can be parsed back:
# - each filename cell is a 1-element batch from the DataLoader -> store the bare path;
# - each embedding vector -> comma-separated floats (repr round-trips exactly).
df_csv = df.assign(
    fnames=df['fnames'].map(lambda f: f[0] if isinstance(f, (list, tuple)) else f),
    embeddings=df['embeddings'].map(lambda e: ','.join(str(float(x)) for x in np.ravel(e))),
)
df_csv.to_csv('../io/data/processed/images_embeddings.csv', sep=' ', header=True)

## Inference

In [62]:
def get_test_data(data_loader, model, args):
    """Embed every batch of ``data_loader`` with ``model``.

    NOTE(review): this redefinition shadows the ``get_test_data`` imported from
    ``src.models.test`` — consider keeping only one of the two.

    Args:
        data_loader: sized iterable of ``(image, fname, target)`` batches.
        model: callable returning ``(features, extra)``; only the features are kept.
        args: unused; kept for signature compatibility.

    Returns:
        ``(fnames, embeddings, classes)``: the list of per-batch filename
        batches, and the features / targets concatenated along axis 0 as numpy
        arrays (empty arrays when the loader yields nothing).
    """
    fnames, embedding_batches, class_batches = [], [], []
    for i, (image, fname, target) in enumerate(data_loader):
        if i % 500 == 0:
            print(f'{i} images processed on {len(data_loader)}')
        out_features, _ = model(image)

        # Filenames are kept for qualitative inspection of the results.
        fnames.append(fname)
        # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU ones.
        embedding_batches.append(out_features.detach().cpu().numpy())
        class_batches.append(target.detach().cpu().numpy())

    if not embedding_batches:
        return fnames, np.empty(0), np.empty(0)

    # Concatenate once at the end instead of re-growing the arrays every batch (quadratic).
    embeddings = np.concatenate(embedding_batches, axis=0)
    classes = np.concatenate(class_batches, axis=0)
    return fnames, embeddings, classes

In [None]:
class Inference():
    """Retrieve and plot the stored test images closest to a query sketch.

    Image embeddings are not recomputed: they are read from the space-separated
    CSV written by the preprocessing section. Each query sketch is embedded with
    the sketch encoder returned by ``get_model`` (configured by the notebook-level
    ``args``).
    """

    def __init__(self, model_path, embedding_path='../io/data/processed/images_embeddings.csv'):
        """
        Args:
            model_path: checkpoint path passed to ``get_model``.
            embedding_path: CSV with columns ``fnames``, ``embeddings`` and
                ``classes`` (first column is the index written by ``to_csv``).

        Raises:
            ValueError: if the embedding file contains no rows.
        """
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.loader = default_image_loader

        self.im_net, self.sk_net = get_model(args, model_path)

        # header=0: the first row holds the column names (pandas rejects header=True);
        # index_col=0: the leading unnamed column is the index written by to_csv.
        df = pd.read_csv(embedding_path, sep=' ', header=0, index_col=0)
        if df.empty:
            raise ValueError(f'No image embeddings found in {embedding_path}')

        self.images_fnames = [self._parse_fname(f) for f in df['fnames']]
        # Stack the per-row vectors into one (num_images, emb_size) array.
        self.images_embeddings = np.stack([self._parse_embedding(e) for e in df['embeddings']])
        self.images_classes = df['classes'].values

    @staticmethod
    def _parse_fname(cell):
        """Return the image path stored in one ``fnames`` cell.

        The DataLoader (batch_size=1) yields filenames as 1-element batches, so a
        cell may be the text of a list/tuple such as "['a.jpg']" instead of a bare path.
        """
        if isinstance(cell, (list, tuple)):
            return cell[0]
        text = str(cell)
        if text.startswith(('[', '(')):
            # literal_eval only evaluates Python literals, so it is safe on file contents.
            return ast.literal_eval(text)[0]
        return text

    @staticmethod
    def _parse_embedding(cell):
        """Turn one stored embedding back into a 1-D float32 vector.

        Accepts an already-parsed array-like, or its text form: comma-separated
        ("0.1,0.2") or numpy's bracketed, whitespace-separated repr ("[0.1 0.2]").
        """
        if isinstance(cell, str):
            cell = [float(v) for v in cell.strip().strip('[]').replace(',', ' ').split()]
        return np.asarray(cell, dtype=np.float32).ravel()

    def inference_sketch(self, sketch_fname, plot=True, num_closest=5):
        '''Embed one sketch file, rank the stored images against it and, if
        ``plot`` is true, show its ``num_closest`` best matches.

        TODO: only a single sketch per call is supported for now.
        '''
        # unsqueeze(0) adds a batch dimension of 1 for the single sketch.
        sketch = self.transform(self.loader(sketch_fname)).unsqueeze(0)
        sketch_embedding, _ = self.sk_net(sketch)
        self.get_closest_images(sketch_embedding)

        if plot:
            self.plot_closest(sketch_fname, num_closest)

    def get_closest_images(self, sketch_embedding):
        '''
        Rank every stored image by similarity to ``sketch_embedding`` and keep
        their paths, most similar first, in ``self.sorted_fnames``.
        '''
        similarity = get_similarity(sketch_embedding, self.images_embeddings)
        # Negate so the ascending argsort puts the most similar image first;
        # row 0 corresponds to the single query sketch.
        arg_sorted_sim = (-similarity).argsort()

        self.sorted_fnames = [self.images_fnames[i] for i in arg_sorted_sim[0]]

    def plot_closest(self, sketch_fname, num_closest=5):
        '''Show the sketch next to its ``num_closest`` most similar images.

        Requires a prior ``get_closest_images`` call (reads ``self.sorted_fnames``).
        '''
        num_closest = max(0, min(num_closest, len(self.sorted_fnames)))
        # squeeze=False keeps a 2-D axes grid even when only one panel is drawn.
        fig, axes = plt.subplots(1, num_closest + 1, squeeze=False)
        axes = axes[0]

        axes[0].imshow(mpimg.imread(sketch_fname))
        axes[0].set(title='Sketch')

        for rank, fname in enumerate(self.sorted_fnames[:num_closest], start=1):
            axes[rank].imshow(mpimg.imread(fname))
            axes[rank].set(title='Closest image ' + str(rank))

        plt.subplots_adjust(wspace=0.25, hspace=-0.35)

In [None]:
# Embeddings file written by the "Preprocess embeddings" section above.
inference = Inference(BEST_CHECKPOINT, '../io/data/processed/images_embeddings.csv')

## Results

In [None]:
# Query: a "bat" sketch from the Sketchy test sketches.
sketch_fname = '../io/data/raw/Sketchy/sketch/tx_000000000000/bat/n02139199_1332-1.png'
inference.inference_sketch(sketch_fname=sketch_fname, plot=True)

In [None]:
# Query: a "door" sketch from the Sketchy test sketches.
sketch_fname = '../io/data/raw/Sketchy/sketch/tx_000000000000/door/n03222176_681-1.png'
inference.inference_sketch(sketch_fname=sketch_fname, plot=True)

In [None]:
# Query: a "giraffe" sketch from the Sketchy test sketches.
sketch_fname = '../io/data/raw/Sketchy/sketch/tx_000000000000/giraffe/n02439033_67-1.png'
inference.inference_sketch(sketch_fname=sketch_fname, plot=True)

In [None]:
# Query: a "skyscraper" sketch (the class folder needs its own '/' separator).
sketch_fname = '../io/data/raw/Sketchy/sketch/tx_000000000000/skyscraper/n04233124_498-1.png'
inference.inference_sketch(sketch_fname, plot=True)

In [None]:
# Query: a "wheelchair" sketch from the Sketchy test sketches.
sketch_fname = '../io/data/raw/Sketchy/sketch/tx_000000000000/wheelchair/n04576002_150-2.png'
inference.inference_sketch(sketch_fname=sketch_fname, plot=True)