# Inference

This notebook enables to predict the closest images of a sketch. 

In [53]:
# Imports
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.utils.data import DataLoader

In [54]:
from src.data.loader_factory import load_data
from src.models.encoder import EncoderCNN
from src.models.utils import load_checkpoint
from src.models.test import get_test_data

In [55]:
class Args:
    dataset = "sketchy_extend"
    data_path = "../io/data/raw"
    emb_size = 256
    grl_lambda = 0.5
    nopretrain = False
    epochs = 1000
    batch_size = 10
    seed = 42
    load = None
    early_stop = 20
    ngpu = 1
    prefetch = 2
    log = "../io/models/"
    log_interval = 20
    attn = True
    plot = False
    cuda = True

args = Args()

## Load the model

In [56]:
BEST_CHECKPOINT = '../io/models/1_run-batch_size_10/checkpoint.pth'

In [57]:
def get_model(args, best_checkpoint):
    im_net = EncoderCNN(out_size=args.emb_size, attention=True)
    sk_net = EncoderCNN(out_size=args.emb_size, attention=True)
    
    checkpoint = load_checkpoint(best_checkpoint)
    im_net.load_state_dict(checkpoint['im_state'])
    sk_net.load_state_dict(checkpoint['sk_state'])
    return im_net, sk_net

In [58]:
im_net, sk_net = get_model(args, BEST_CHECKPOINT)

=> loading model '../io/models/1_run-batch_size_10/checkpoint.pth'
=> loaded model '../io/models/1_run-batch_size_10/checkpoint.pth' (epoch 42, map 0.2796034388821055)


## Get all images and compute their embeddings

In [59]:
transform = transforms.Compose([transforms.ToTensor()])
_, [_, _], [test_sk_data, test_im_data], dict_class = load_data(args, transform)

In [60]:
print("Length Sketch: {}".format(len(test_sk_data)))
print("Length Image: {}".format(len(test_im_data)))
print("Classes: {}".format(test_sk_data.get_class_dict()))
print("Num Classes: {}".format(len(test_sk_data.get_class_dict())))

Length Sketch: 12694
Length Image: 10453
Classes: ['bat', 'cabin', 'cow', 'dolphin', 'door', 'giraffe', 'helicopter', 'mouse', 'pear', 'raccoon', 'rhinoceros', 'saw', 'scissors', 'seagull', 'skyscraper', 'songbird', 'sword', 'tree', 'wheelchair', 'windmill', 'window']
Num Classes: 21


In [61]:
test_im_loader = DataLoader(test_im_data, batch_size=3 * args.batch_size, num_workers=args.prefetch, pin_memory=True)

In [62]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f6f6fc6ff10>

In [63]:
fnames, embeddings, classes = get_test_data(test_im_loader, im_net, args)

AttributeError: 'list' object has no attribute 'cuda'

In [68]:
fnames = []
for i, (image, fname, target) in enumerate(test_im_loader):
    print(i)
    print(image.shape)
    print(fname.shape)
    print(target)

    # Process
    out_features, _ = im_net(image)

    # Filename of the images for qualitative
    fnames.append(fname)

    if i == 0:
        embeddings = out_features.cpu().data.numpy()
        classes = target.cpu().data.numpy()
    else:
        embeddings = np.concatenate((embeddings, out_features.cpu().data.numpy()), axis=0)
        classes = np.concatenate((classes, target.cpu().data.numpy()), axis=0)
    
    break

0
torch.Size([30, 3, 224, 224])
torch.Size([30])
['../io/data/raw/Sketchy/extended_photo/bat/ext_274.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_137.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02142407_538.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_98.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02139199_9033.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_305.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02139199_14100.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_199.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_183.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02139199_15837.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02149420_28.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_202.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02139199_16702.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext_204.jpg', '../io/data/raw/Sketchy/extended_photo/bat/n02139199_14279.jpg', '../io/data/raw/Sketchy/extended_photo/bat/ext

AttributeError: 'list' object has no attribute 'cpu'

## Get sketch embedding

## Find closest images

## Plot results