In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.autograd import Variable
from dataset import ReadImages
from utils import *
from model.nn_utils import *
from model.siamese import *
from model.custom_modules import *
from os import path
import cv2

In [9]:
trainSetPath = 'data/pre_proc/CLICIDE_video_384'
testSetPath = 'data/pre_proc/CLICIDE_video_384/test'
mean_std_file = 'data/CLICIDE_224sq_train_ms.txt'
num_classes = 464
cnn_model = models.resnet152
scale_size = 320
feature_size = (7, 7)
out_size = 2048 # 1000 # 7 * 7 * 2048  # 6 * 6 * 256
siam2_k = 6
weights_file = 'data/20170429-150806-986029_best_siam.pth.tar'

In [3]:
trainSetFull = ReadImages.readImageswithPattern(
    trainSetPath, lambda im: im.split('/')[-1].split('-')[0])
testSetFull = ReadImages.readImageswithPattern(
    testSetPath, lambda im: im.split('/')[-1].split('-')[0])

listLabel = [t[1] for t in trainSetFull if 'wall' not in t[1]]
labels = list(set(listLabel))
print(len(trainSetFull), len(testSetFull), len(labels))
print(labels)

(3245, 177, 464)
['29J', '29K', '29H', '29I', '29F', '29G', '29D', '29E', '29B', '29C', '29A', '34D', '34E', '34G', '34B', '34C', '34L', '34H', '34I', '34J', '34K', '6H', '3S', '3R', '3Q', '3P', '43C', '3T', '3K', '3J', '3I', '3H', '3O', '3N', '3M', '3L', '3C', '3B', '3A', '3G', '3F', '3E', '3D', '10J', '10H', '10I', '10B', '10C', '10A', '10F', '10G', '10D', '10E', '44C', '44A', '44E', '44D', '27A', '27B', '27C', '27D', '27E', '27F', '27G', '27H', '27I', '27J', '27K', '27L', '9A', '9C', '9B', '9E', '9D', '8J', '8K', '8H', '5J2', '8L', '33M', '33L', '33I', '33H', '33K', '33J', '33E', '33D', '33F', '33A', '33C', '33B', '8H2', '14F', '14G', '14D', '14E', '14B', '14C', '14L', '14M', '14J', '14K', '14H', '14I', '23L', '23M', '23N', '23O', '23H', '23J', '23K', '23D', '23E', '23F', '23G', '23A', '23B', '23C', '37I', '37K', '37J', '37L', '37A', '37B', '37E', '37D', '37G', '37F', '40G', '40F', '40E', '40C', '40B', '40A', '2D', '2E', '2F', '2G', '2A', '2B', '2C', '2H', '2I', '2J', '2K', '13C', '

In [4]:
m, s = readMeanStd(mean_std_file)
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(m, s)])

trainSet, testSet, trainNames, testNames = [], [], [], []
for im, lab in trainSetFull:
    if lab in labels:
        im_out = trans(imread_rgb(im))
        trainSet.append((im_out, lab))
        trainNames.append(im)
for im, lab in testSetFull:
    if lab in labels:
        im_out = trans(imread_rgb(im))
        testSet.append((im_out, lab))
        testNames.append(im)
print(len(trainSet))
print(len(testSet))

3245
165


In [10]:
tuning_model = TuneClassif(cnn_model(), num_classes, feature_size)
tuning_model.load_state_dict(torch.load('data/finetune_classif/cli_best_resnet152_classif_finetuned.pth.tar'))
classif_net = TuneClassifSub(tuning_model, num_classes, feature_size)
feature_net = FeatureNet(classif_net, feature_size, classify=True)
siam_net = Siamese2(feature_net, siam2_k, out_size, feature_size)
siam_net.load_state_dict(torch.load(weights_file))
siam_net = siam_net.eval().cuda(0)

In [11]:
def get_embeddings(net, dataset, out_size):
    def batch(last, i, is_final, batch):
        embeddings = last
        test_in = batch[0][0].unsqueeze(0).cuda(0)
        out = net(Variable(test_in, volatile=True))
        embeddings[i] = out.data[0]
        return embeddings
    init = torch.Tensor(len(dataset), out_size)
    return fold_batches(batch, init, dataset, 1)

In [12]:
embeddings_test = get_embeddings(siam_net, testSet, out_size)
embeddings_train = get_embeddings(siam_net, trainSet, out_size)
sim = torch.mm(embeddings_test, embeddings_train.t())
print(sim.size())

torch.Size([165, 3245])


In [13]:
max_sim, max_idx = sim.max(1)
max_label = []
for i in range(sim.size(0)):
    # get label from ref set which obtained highest score
    max_label.append(trainSet[max_idx[i, 0]][1])
correct = sum(testLabel == max_label[j] for j, (_, testLabel) in enumerate(testSet))
print('Correct: {0}/{1}'.format(correct, len(testSet)))
count = 0
for j, (_, test_label) in enumerate(testSet):
    if max_label[j] == test_label:
        continue
    count += 1
    print('Incorrect {0}: test im {1}, label {2} -> train im {3}, label {4}'.format(count, testNames[j], test_label, trainNames[max_idx[j, 0]], max_label[j]))
    print('Avg prec: {0}'.format(avg_precision(sim, j, testSet, trainSet)))

Correct: 156/165
Incorrect 1: test im data/pre_proc/CLICIDE_video_384/test/11C-0351.JPG, label 11C -> train im data/pre_proc/CLICIDE_video_384/11F-7.JPG, label 11F
Avg prec: 0.00229784975425
Incorrect 2: test im data/pre_proc/CLICIDE_video_384/test/11C-0436.JPG, label 11C -> train im data/pre_proc/CLICIDE_video_384/10J-1.JPG, label 10J
Avg prec: 0.00822340006581
Incorrect 3: test im data/pre_proc/CLICIDE_video_384/test/26J-1245.JPG, label 26J -> train im data/pre_proc/CLICIDE_video_384/43D-0.JPG, label 43D
Avg prec: 0.000796974303149
Incorrect 4: test im data/pre_proc/CLICIDE_video_384/test/26J-1247.JPG, label 26J -> train im data/pre_proc/CLICIDE_video_384/26B-2.JPG, label 26B
Avg prec: 0.220404411765
Incorrect 5: test im data/pre_proc/CLICIDE_video_384/test/30A-1257.JPG, label 30A -> train im data/pre_proc/CLICIDE_video_384/30J-6.JPG, label 30J
Avg prec: 0.100261003283
Incorrect 6: test im data/pre_proc/CLICIDE_video_384/test/36E-1265.JPG, label 36E -> train im data/pre_proc/CLICIDE_