In [1]:
import warnings
warnings.filterwarnings("ignore")

import os, sys, numpy as np, argparse, imp, datetime, time, pickle as pkl, random, json

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import pandas as pd

import torch, torch.nn as nn
import auxiliaries_nofaiss as aux

import evaluate as eval

from torchvision import transforms
import itertools

In [2]:
# Experiment settings shared by the cells below.
# Both locations are anchored at the directory the notebook was started from.
source_path = '{}/Datasets/online_products'.format(os.getcwd())
save_path = '{}/Training_Results/online_products'.format(os.getcwd())
# presumably the k values used for Recall@k evaluation -- confirm against the evaluation code
k_vals = [1, 10, 100, 1000]
# NOTE: `sampling` is the string 'None', not the None singleton.
batches_per_super_pair, sampling, pretrained = 10, 'None', False
# Everything in this notebook is placed on the CPU.
device = torch.device('cpu')

In [3]:
class ResNet50_mcn(nn.Module):
    """
    class definition for the ResNet50 model imported from MatConvNet

    The network is a stem (7x7 conv, batch norm, ReLU, 3x3 max-pool), four stages of
    bottleneck residual blocks, a 7x7 average pool and a linear layer that maps the
    2048 pooled features to a 512-dimensional embedding. forward() returns that
    embedding L2-normalised along the last dimension.

    Every layer is registered under a flat attribute name
    (features_<stage>_<block>_<layer>), in the same order and with the same
    hyper-parameters as the original hand-unrolled definition, so state_dict keys
    are unchanged.
    """

    # One row per residual stage:
    # (stage index used in attribute names, number of bottleneck blocks,
    #  input channels of the stage, bottleneck channels, output channels,
    #  stride applied by the first block of the stage).
    _STAGES = (
        (4, 3, 64, 64, 256, 1),
        (5, 4, 256, 128, 512, 2),
        (6, 6, 512, 256, 1024, 2),
        (7, 3, 1024, 512, 2048, 2),
    )

    def __init__(self):
        super(ResNet50_mcn, self).__init__()

        # Input statistics and size stored alongside the model; nothing in this
        # class reads them.
        self.meta = {'mean': [0.485, 0.456, 0.406],
                     'std': [0.229, 0.224, 0.225],
                     'imageSize': [224, 224]}

        # Stem.
        self.features_0 = nn.Conv2d(3, 64, kernel_size=[7, 7], stride=(2, 2), padding=(3, 3), bias=False)
        self.features_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.features_2 = nn.ReLU(inplace=True)
        self.features_3 = nn.MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=1, dilation=1, ceil_mode=False)

        # Residual stages. Only the first block of a stage changes the channel
        # count / resolution, so only it gets a stride and a projection shortcut.
        for stage, num_blocks, in_channels, mid_channels, out_channels, stride in self._STAGES:
            for block in range(num_blocks):
                is_first = block == 0
                self._add_bottleneck('features_{}_{}'.format(stage, block),
                                     in_channels if is_first else out_channels,
                                     mid_channels,
                                     out_channels,
                                     stride if is_first else 1,
                                     is_first)

        # Head. The fixed 7x7 pooling window matches the feature map produced by a
        # 224x224 input (see self.meta['imageSize']).
        self.features_8 = nn.AvgPool2d(kernel_size=[7, 7], stride=[1, 1], padding=0)
        self.fc = nn.Linear(in_features=2048, out_features=512, bias=True)

    def _add_bottleneck(self, prefix, in_channels, mid_channels, out_channels, stride, with_downsample):
        """Register the layers of one bottleneck block as attributes named <prefix>_<layer>."""
        def register(suffix, module):
            setattr(self, prefix + '_' + suffix, module)

        def batch_norm(channels):
            return nn.BatchNorm2d(channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Registration order is kept identical to the original explicit definition.
        register('conv1', nn.Conv2d(in_channels, mid_channels, kernel_size=[1, 1], stride=(1, 1), bias=False))
        register('bn1', batch_norm(mid_channels))
        register('relu1', nn.ReLU(inplace=True))
        register('conv2', nn.Conv2d(mid_channels, mid_channels, kernel_size=[3, 3], stride=(stride, stride),
                                    padding=(1, 1), bias=False))
        register('bn2', batch_norm(mid_channels))
        register('relu2', nn.ReLU(inplace=True))
        register('conv3', nn.Conv2d(mid_channels, out_channels, kernel_size=[1, 1], stride=(1, 1), bias=False))
        register('bn3', batch_norm(out_channels))
        if with_downsample:
            register('downsample_0', nn.Conv2d(in_channels, out_channels, kernel_size=[1, 1],
                                               stride=(stride, stride), bias=False))
            register('downsample_1', batch_norm(out_channels))
        register('id_relu', nn.ReLU(inplace=True))

    def _bottleneck(self, prefix, block_input, with_downsample):
        """Run the bottleneck block registered under <prefix> and return its activated output."""
        def layer(suffix):
            return getattr(self, prefix + '_' + suffix)

        residual = layer('relu1')(layer('bn1')(layer('conv1')(block_input)))
        residual = layer('relu2')(layer('bn2')(layer('conv2')(residual)))
        residual = layer('bn3')(layer('conv3')(residual))
        if with_downsample:
            shortcut = layer('downsample_1')(layer('downsample_0')(block_input))
        else:
            shortcut = block_input
        # Plain two-tensor add: equivalent to the former torch.add(shortcut, 1, residual)
        # without relying on the deprecated positional-alpha overload.
        return layer('id_relu')(torch.add(shortcut, residual))

    def forward(self, data):
        """
        Args:
            data: image batch with 3 input channels; the fixed 7x7 average pool and the
                  2048-feature linear layer assume 224x224 inputs.
        Returns:
            512-dimensional embeddings, L2-normalised along the last dimension.
        """
        features = self.features_3(self.features_2(self.features_1(self.features_0(data))))
        for stage, num_blocks, _, _, _, _ in self._STAGES:
            for block in range(num_blocks):
                features = self._bottleneck('features_{}_{}'.format(stage, block), features, block == 0)
        features_8 = self.features_8(features)
        classifier_flatten = features_8.view(features_8.size(0), -1)
        logits = self.fc(classifier_flatten)

        # The original note here says no normalization is used when N-Pair loss is the
        # criterion; this implementation nevertheless always normalizes the embedding.
        return torch.nn.functional.normalize(logits, dim=-1)

    def load_pth(self, weights_path, map_location='cpu'):
        """
        Load a state dict from `weights_path` into this model; a falsy path is a no-op.

        Args:
            weights_path: path (or file-like object) accepted by torch.load, or None/''.
            map_location: forwarded to torch.load. Defaults to 'cpu' so checkpoints saved
                          from a GPU can be read on a CPU-only machine; load_state_dict
                          then copies the values into this model's existing parameters.
        """
        if weights_path:
            state_dict = torch.load(weights_path, map_location=map_location)
            self.load_state_dict(state_dict)

    def to_optim(self, opt):
        """
        Args:
            opt: object exposing `lr` and `decay` attributes.
        Returns:
            list with a single optimizer parameter group covering all model parameters.
        """
        return [{'params':self.parameters(),'lr':opt.lr,'weight_decay':opt.decay}]

In [4]:
# Build the embedding network; no weights are loaded here (load_pth is not called).
model = ResNet50_mcn()

In [5]:
# Move the model to the configured device. The call is the cell's last expression,
# which is why the module repr is echoed below.
model.to(device)

ResNet50_mcn(
  (features_0): Conv2d(3, 64, kernel_size=[7, 7], stride=(2, 2), padding=(3, 3), bias=False)
  (features_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (features_2): ReLU(inplace=True)
  (features_3): MaxPool2d(kernel_size=[3, 3], stride=[2, 2], padding=1, dilation=1, ceil_mode=False)
  (features_4_0_conv1): Conv2d(64, 64, kernel_size=[1, 1], stride=(1, 1), bias=False)
  (features_4_0_bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (features_4_0_relu1): ReLU(inplace=True)
  (features_4_0_conv2): Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1), bias=False)
  (features_4_0_bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (features_4_0_relu2): ReLU(inplace=True)
  (features_4_0_conv3): Conv2d(64, 256, kernel_size=[1, 1], stride=(1, 1), bias=False)
  (features_4_0_bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [6]:
# Single optimizer parameter group covering every model parameter, with a fixed
# learning rate (1e-5) and weight decay (4e-4). Same structure as
# ResNet50_mcn.to_optim, but with literal values instead of an `opt` namespace.
# NOTE: model.parameters() returns a generator, so this group can be consumed only once.
to_optim = [{'params': model.parameters(), 'lr': 0.00001, 'weight_decay': 0.0004}]

In [7]:
def give_dataloaders(source_path=None, batch_size=112, num_workers=8):
    """
    Wrap every dataset returned by give_OnlineProducts_datasets in a PyTorch DataLoader.

    Args:
        source_path: string or None, dataset root handed to give_OnlineProducts_datasets.
                     None (default) uses <current working directory>/Datasets/online_products,
                     which is what this function previously hard-coded.
        batch_size:  int, batch size used for every dataloader (default 112).
        num_workers: int, number of worker processes per dataloader (default 8).
    Returns:
        dataloaders: dict with one DataLoader per key of the dataset dict.
    """
    if source_path is None:
        source_path = os.getcwd()+'/Datasets/online_products'

    #Dataset selection
    datasets = give_OnlineProducts_datasets(source_path)

    #Move datasets to dataloaders.
    dataloaders = {}
    for key, dataset in datasets.items():
        if isinstance(dataset, SuperLabelTrainDataset) and key == 'training':
            # important: use a SequentialSampler
            # see reasoning in class definition of SuperLabelTrainDataset
            dataloaders[key] = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                    num_workers=num_workers, sampler=torch.utils.data.SequentialSampler(dataset),
                    pin_memory=True, drop_last=False)
        else:
            # Validation-style datasets are iterated in order and keep their last
            # (possibly smaller) batch; all others are shuffled and drop it.
            is_val = dataset.is_validation
            dataloaders[key] = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                    num_workers=num_workers, shuffle=not is_val, pin_memory=True, drop_last=not is_val)

    return dataloaders

In [8]:
def give_OnlineProducts_datasets(source_path):
    """
    Build the training, testing and evaluation datasets for Metric Learning on the Online-Products dataset.
    Training and test sets are provided by given text-files, Ebay_train.txt & Ebay_test.txt,
    so there is no random shuffling of classes.

    Args:
        source_path: str, dataset root; images are read from '<source_path>/images' and the
                     split files from '<source_path>/Info_Files'.
    Returns:
        dict of PyTorch datasets with keys 'training' (super-label batches over the training
        split), 'testing' (test split) and 'evaluation' (training split in validation mode).
    """
    image_sourcepath  = source_path+'/images'
    #Load text-files containing classes and imagepaths.
    training_files = pd.read_table(source_path+'/Info_Files/Ebay_train.txt', header=0, delimiter=' ')
    test_files     = pd.read_table(source_path+'/Info_Files/Ebay_test.txt', header=0, delimiter=' ')

    #Generate Conversion dict: class_id (as given in the text-files) -> first component of the image path.
    conversion = {}
    for files in (training_files, test_files):
        for class_id, path in zip(files['class_id'], files['path']):
            conversion[class_id] = path.split('/')[0]

    #Generate image_dicts of shape {class_idx:[list of paths to images belong to this class] ...}
    #Class ids are shifted by -1 relative to the text-files.
    train_image_dict, val_image_dict  = {},{}
    for image_dict, files in ((train_image_dict, training_files), (val_image_dict, test_files)):
        for class_id, img_path in zip(files['class_id'], files['path']):
            image_dict.setdefault(class_id-1, []).append(image_sourcepath+'/'+img_path)

    #Two-level dict {super_class_idx: {class_idx: [image paths]}} for super-label batch construction.
    super_dict = {}
    for cid, scid, path in zip(training_files['class_id'], training_files['super_class_id'], training_files['path']):
        super_dict.setdefault(scid-1, {}).setdefault(cid-1, []).append(image_sourcepath+'/'+path)

    #Build the training dataset once, after super_dict is complete. Constructing it inside
    #the loop above would rebuild (and reshuffle) it for every single training image.
    train_dataset = SuperLabelTrainDataset(super_dict)

    val_dataset   = BaseTripletDataset(val_image_dict, is_validation=True)
    eval_dataset  = BaseTripletDataset(train_image_dict, is_validation=True)

    train_dataset.conversion       = conversion
    val_dataset.conversion         = conversion
    eval_dataset.conversion        = conversion

    return {'training':train_dataset, 'testing':val_dataset, 'evaluation':eval_dataset}

In [9]:
from torch.utils.data import Dataset
class BaseTripletDataset(Dataset):
    """
    Dataset class to provide (augmented) correctly prepared samples corresponding to standard DML literature.
    Images are normalized to ImageNet statistics. During training a random resized crop of
    size 224 is taken; during validation images are resized to 256 and center-cropped to 224.
    """
    def __init__(self, image_dict, samples_per_class=8, is_validation=False):
        """
        Dataset Init-Function.

        Args:
            image_dict:         dict, Dictionary of shape {class_idx:[list of paths to images belong to this class] ...} providing all the paths and classes.
            samples_per_class:  Number of samples to draw from one class before moving to the next when filling the batch.
                                1 disables class-grouped sampling, 0 means "use all images of a class".
            is_validation:      If is true, dataset properties for validation/testing are used instead of ones for training.
        Returns:
            Nothing!
        """
        #Define length of dataset. Builtin sum keeps this a plain int, also for an empty dict.
        self.n_files     = sum(len(paths) for paths in image_dict.values())

        self.is_validation = is_validation

        #Convert image dictionary from classname:content to class_idx:content, because the initial indices are not necessarily from 0 - <n_classes>.
        self.image_dict    = {i:image_dict[key] for i,key in enumerate(sorted(image_dict.keys()))}
        self.avail_classes = sorted(self.image_dict.keys())

        #Init. properties that are used when filling up batches.
        if not self.is_validation:
            self.samples_per_class = samples_per_class
            #Select current class to sample images from up to <samples_per_class>
            self.current_class   = np.random.randint(len(self.avail_classes))
            self.classes_visited = [self.current_class, self.current_class]
            self.n_samples_drawn = 0

        #Data augmentation/processing methods.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        if not self.is_validation:
            transf_list = [transforms.RandomResizedCrop(size=224)]
        else:
            transf_list = [transforms.Resize(256), transforms.CenterCrop(224)]
        transf_list.extend([transforms.ToTensor(), normalize])
        self.transform = transforms.Compose(transf_list)

        #Convert Image-Dict to list of (image_path, image_class). Allows for easier direct sampling.
        self.image_list = [(path, key) for key, paths in self.image_dict.items() for path in paths]

        #Flag that denotes if dataset is called for the first time.
        self.is_init = True


    def ensure_3dim(self, img):
        """
        Function that converts the input img to RGB.

        Args:
            img: PIL.Image, image which is to be converted (i.e. if some images are black-and-white in an otherwise coloured dataset).
        Returns:
            Checked PIL.Image img.
        """
        # NOTE(review): PIL's `Image.size` is (width, height), so this condition looks like it
        # holds for every PIL image and the conversion is effectively unconditional - confirm
        # that this is intended before tightening it to a mode check.
        if len(img.size)==2:
            img = img.convert('RGB')
        return img


    def _load_image(self, path):
        """Open the image at `path`, force it to RGB and apply the dataset transform."""
        return self.transform(self.ensure_3dim(Image.open(path)))


    def _switch_class(self, idx):
        """Pick the next class to sample from, skipping the two most recently visited ones."""
        candidates = [c for c in self.avail_classes if c not in self.classes_visited]
        if not candidates:
            #With fewer than three classes every class may be "recently visited";
            #fall back to all classes instead of dividing by zero below.
            candidates = list(self.avail_classes)

        self.current_class   = candidates[idx%len(candidates)]
        self.classes_visited = self.classes_visited[1:]+[self.current_class]
        self.n_samples_drawn = 0


    def __getitem__(self, idx):
        """
        Args:
            idx: Sample idx for training sample
        Returns:
            tuple of form (sample_class, torch.Tensor() of input image)
        """
        if self.is_init:
            self.current_class = self.avail_classes[idx%len(self.avail_classes)]
            self.is_init = False

        #Validation, or training without class grouping: index the flat image list directly.
        if self.is_validation or self.samples_per_class==1:
            path, label = self.image_list[idx]
            return label, self._load_image(path)

        #Once enough samples per class have been drawn, we choose another class to draw samples from.
        #NOTE: if self.samples_per_class is 0, then use all the images from current_class
        class_exhausted = (self.samples_per_class == 0
                           and self.n_samples_drawn == len(self.image_dict[self.current_class]))
        if class_exhausted or self.n_samples_drawn == self.samples_per_class:
            self._switch_class(idx)

        class_sample_idx = idx%len(self.image_dict[self.current_class])
        self.n_samples_drawn += 1

        out_img = self._load_image(self.image_dict[self.current_class][class_sample_idx])
        return self.current_class,out_img

    def __len__(self):
        return self.n_files

In [10]:
class SuperLabelTrainDataset(Dataset):
    """
    Dataset class to provide (augmented) correctly prepared training samples, utilizing
    super-label information to construct the batches.

    Each batch takes a pair of super-labels (s1,s2). Then, for each s{i}, sample half the batch
    from classes belonging to it.

    NOTE: 
        SuperLabelTrainDataset implements a custom reshuffle(), so it's important that DataLoader 
        does NOT do further randomization. This means it should use a SequentialSampler, and its
        batch size has to equal `self.batch_size`.
    """
    def __init__(self, image_dict, super_pairs=None, batch_size=112,
                 batches_per_super_pair=5, samples_per_class=0):
        """
        Args:
            image_dict: two-level dict, `super_dict[super_class_id][class_id]` gives the list of 
                        image paths having the same super-label and class label
            super_pairs: optional list of (super_label, super_label) pairs to draw batches for;
                         defaults to all 2-combinations of the super-labels in image_dict.
            batch_size: int, even number of samples per pre-computed batch.
            batches_per_super_pair: int, number of batches generated for each super-label pair.
                                    This is an explicit argument; no notebook-level setting is read.
            samples_per_class: int, if > 0 images are grouped into same-class chunks of this size
                               (must divide half the batch size); 0 keeps each class contiguous.
        """
        self.batch_size = batch_size
        self.batches_per_super_pair = batches_per_super_pair
        self.samples_per_class = samples_per_class

        # checks
        assert self.batch_size % 2 == 0, "opt.bs should be an even number"
        self.half_bs = self.batch_size // 2
        if self.samples_per_class > 0:
            assert self.half_bs % self.samples_per_class == 0, "opt.bs not a multiple of opt.samples_per_class"

        # provide avail_classes
        self.avail_classes = []
        for sid in image_dict.keys():
            self.avail_classes += list(image_dict[sid].keys())

        # Data augmentation/processing methods.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=224),
            transforms.ToTensor(),
            normalize])

        # for each super-label, store a list of lists:
        # super_image_lists[super0]: [
        #       [(class0, image0), (class0, image1), ...], 
        #       [(class1, image0), (class1, image1), ...], 
        #     ], 
        # ...
        self.super_image_lists = {
            sid: [[(cid, path) for path in paths] for cid, paths in classes.items()]
            for sid, classes in image_dict.items()
        }

        if super_pairs is None:
            self.super_pairs = list(itertools.combinations(image_dict.keys(), 2))
        else:
            self.super_pairs = super_pairs  # allow super_pairs to be supplied

        self.reshuffle()


    def ensure_3dim(self, img):
        """Return `img` converted to RGB when its `size` has two entries."""
        if len(img.size) == 2:
            img = img.convert('RGB')
        return img


    def reshuffle(self):
        """Re-randomize the image order and rebuild `self.batches` for the next epoch."""
        # for each super-label, concat all images into a long list:
        # super_images[super0]: [
        #       (class0_in_super0, image0), (class0_in_super0, image1), ...
        #       (class1_in_super0, image0), (class1_in_super0, image1), ...
        #       ...
        #     ] 
        # ...
        super_images, num_images, cur_pos = {}, {}, {}
        spc = self.samples_per_class

        for sid, all_imgs_in_super in self.super_image_lists.items():
            if spc > 0:
                chunks_list = []
                for cls_imgs in all_imgs_in_super:
                    random.shuffle(cls_imgs)
                    num = len(cls_imgs)
                    # take chunks of size `samples_per_class` and append to chunks_list;
                    # -(-a // b) is integer ceil(a / b), the last chunk wraps around the class
                    for c in range(-(-num // spc)):
                        chunks_list.append([cls_imgs[i % num] for i in range(c*spc, (c+1)*spc)])
                # shuffle the chunks, then concat the "list of lists" into a long list
                random.shuffle(chunks_list)
                super_images[sid] = list(itertools.chain.from_iterable(chunks_list))
            else:
                for cls_imgs in all_imgs_in_super:
                    random.shuffle(cls_imgs)  # shuffle images in each class
                # shuffle the class order, then concat the "list of lists" into a long list
                random.shuffle(all_imgs_in_super)
                super_images[sid] = list(itertools.chain.from_iterable(all_imgs_in_super))

            num_images[sid] = len(super_images[sid])
            cur_pos[sid] = 0

        # pre-compute all the batches
        # batches = [
        #   [(cid,img), (cid,img), ...],   # batch No.0
        #   [(cid,img), (cid,img), ...],   # batch No.1
        #   ...
        # ]
        self.batches = []

        # for each pair of super-labels, e.g. (bicycle, chair)
        for s0, s1 in self.super_pairs:
            # sample `batches_per_super_pair` batches
            for _ in range(self.batches_per_super_pair):
                # get half of the batch from each super-label, wrapping around its image list
                ind0 = [(cur_pos[s0]+i) % num_images[s0] for i in range(self.half_bs)]
                ind1 = [(cur_pos[s1]+i) % num_images[s1] for i in range(self.half_bs)]
                cur_batch = [super_images[s0][i] for i in ind0] + [super_images[s1][i] for i in ind1]

                # move pointers and append to list
                cur_pos[s0] = (ind0[-1] + 1) % num_images[s0]
                cur_pos[s1] = (ind1[-1] + 1) % num_images[s1]
                self.batches.append(cur_batch)


    def __getitem__(self, idx):
        """Return (class, transformed image) for flat index `idx` into the pre-computed batches."""
        # we use SequentialSampler together with SuperLabelTrainDataset,
        # so idx==0 indicates the start of a new epoch
        # NOTE(review): with DataLoader workers each worker presumably holds its own copy of
        # this dataset, so only the copy that serves idx 0 gets reshuffled - confirm.
        if idx == 0:
            self.reshuffle()

        batch_idx    = idx // self.batch_size  # global batch index
        batch_offset = idx % self.batch_size   # offset from start of this batch
        batch_item   = self.batches[batch_idx][batch_offset]

        cls = batch_item[0]
        img = Image.open(batch_item[1])
        return cls, self.transform(self.ensure_3dim(img))


    def __len__(self):
        return len(self.batches) * self.batch_size


In [11]:
# Build the 'training', 'testing' and 'evaluation' dataloaders for Online-Products.
dataloaders      = give_dataloaders()

In [12]:
# Number of class entries collected by the training dataset across all super-labels.
num_classes  = len(dataloaders['training'].dataset.avail_classes)

In [13]:
num_classes

11318

In [14]:
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

from PIL import Image

# FastAP loss proposed in CVPR'19 paper "Deep Metric Learning to Rank"
from FastAP_loss import FastAPLoss

In [15]:
# FastAP loss; `num_bins` is the only argument passed to it.
histbins = 10
loss_params  = {'num_bins':histbins}
criterion    = FastAPLoss(**loss_params)

In [16]:
to_optim, criterion

([{'params': <generator object Module.parameters at 0x7fd16b9f34d0>,
   'lr': 1e-05,
   'weight_decay': 0.0004}],
 FastAPLoss())

In [17]:
# Adam over the single parameter group defined in `to_optim`.
optimizer    = torch.optim.Adam(to_optim)
# Learning rate is multiplied by 0.3 at milestones 30 and 50.
# NOTE(review): no `scheduler.step()` call appears in the cells shown here - confirm whether it is needed.
scheduler    = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50], gamma=0.3)

In [18]:
def train_one_epoch(train_dataloader, model, optimizer, criterion, epoch):
    """
    This function is called every epoch to perform training of the network over one full
    iteration of the given dataloader.

    Inputs are moved to the notebook-level `device` before the forward pass.

    Args:
        train_dataloader: torch.utils.data.DataLoader, yields (class_labels, images) batches.
        model:            Network to train.
        optimizer:        Optimizer to use for training.
        criterion:        criterion to use during training, called as criterion(features, class_labels).
        epoch:            int, Current epoch (only used for the progress-bar text).

    Returns:
        Nothing!
    """
    loss_collect = []
    last_batch   = len(train_dataloader)-1

    data_iterator = tqdm(train_dataloader, desc='Epoch {} Training...'.format(epoch))
    for i,(class_labels, images) in enumerate(data_iterator):
        #Compute embeddings for input batch.
        features  = model(images.to(device))
        #Compute loss.
        loss      = criterion(features, class_labels)

        #Ensure gradients are set to zero before backpropagation.
        optimizer.zero_grad()
        #Compute gradients.
        loss.backward()

        #Update weights using comp. gradients.
        optimizer.step()

        #Store loss per iteration.
        loss_collect.append(loss.item())
        #On the final batch, replace the progress-bar text with the mean loss of the epoch.
        if i==last_batch:
            data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect)))

    


In [19]:
# Move the loss module to the training device; assigning to `_` suppresses the cell output.
_ = criterion.to(device)

In [20]:
# Switch the model to training mode, then train for a single epoch (labelled epoch 1).
_ = model.train()
train_one_epoch(dataloaders['training'], model, optimizer, criterion, 1)

Epoch (Train) 1: Mean Loss [0.8028]: 100%|██████████| 330/330 [10:48:44<00:00, 117.95s/it]


In [21]:
# Switch the model to evaluation mode before computing test metrics.
_ = model.eval()

In [23]:
# Evaluation settings: test-split dataloader, the trained model and the epoch label.
eval_params = {'dataloader':dataloaders['testing'], 'model':model, 'epoch':1}

In [29]:
from sklearn import metrics
from sklearn.cluster import KMeans
from scipy.spatial.distance import squareform, pdist, cdist
def eval_metrics_one_dataset(model, test_dataloader, device, k_vals):
    """
    Compute evaluation metrics on test-dataset, e.g. NMI, F1 and Recall @ k.

    Args:
        model:              PyTorch network, network to compute evaluation metrics for.
        test_dataloader:    PyTorch Dataloader, dataloader for test dataset, should have no shuffling and correct processing.
                            Its dataset must expose `avail_classes`; batches are (labels, images).
        device:             torch.device, Device to run inference on.
        k_vals:             list of int, Recall values to compute
        
    Returns:
        F1 score (float), NMI score (float), recall_at_k (list of float), data embedding (np.ndarray)
    """

    _ = model.eval()
    n_classes = len(test_dataloader.dataset.avail_classes)

    with torch.no_grad():
        ### For all test images, extract features
        target_labels, feature_coll = [],[]
        final_iter = tqdm(test_dataloader, desc='Computing Evaluation Metrics...')
        for inp in final_iter:
            input_img,target = inp[-1], inp[0]
            target_labels.extend(target.numpy().tolist())
            out = model(input_img.to(device))
            #Keep per-batch arrays instead of nested Python lists; they are stacked once below.
            feature_coll.append(out.cpu().detach().numpy())

        target_labels = np.hstack(target_labels).reshape(-1,1)
        feature_coll  = np.vstack(feature_coll).astype('float32')

        ### Cluster the embeddings into as many clusters as there are classes (sklearn replaces faiss here)
        kmeans = KMeans(n_clusters=n_classes, random_state=0).fit(feature_coll)
        model_generated_cluster_labels = kmeans.labels_
        computed_centroids = kmeans.cluster_centers_

        ### Compute NMI
        NMI = metrics.cluster.normalized_mutual_info_score(model_generated_cluster_labels.reshape(-1), target_labels.reshape(-1))

        ### Recover max(k_vals) nearest neighbours to use for recall computation
        #NOTE: this materialises the full (n_samples x n_samples) distance matrix.
        #Column 0 of the sorted neighbours is dropped below, presumably because it is the sample itself.
        k_closest_points  = squareform(pdist(feature_coll)).argsort(1)[:, :int(np.max(k_vals)+1)]
        k_closest_classes = target_labels.reshape(-1)[k_closest_points[:, 1:]]

        ### Compute Recall
        #A sample counts as recalled @k if any of its k nearest neighbours shares its label.
        #target_labels has shape (n, 1), so the comparison broadcasts over the neighbour axis.
        neighbour_hits = k_closest_classes == target_labels
        recall_all_k   = [neighbour_hits[:, :k].any(axis=1).sum()/len(target_labels) for k in k_vals]

        ### Compute F1 Score
        #NOTE(review): `f1_score` is not defined in the cells visible here - confirm it is
        #defined or imported before this function is called.
        F1 = f1_score(model_generated_cluster_labels, target_labels, feature_coll, computed_centroids)

    return F1, NMI, recall_all_k, feature_coll

In [30]:
# Wall-clock reference taken before the evaluation run.
start = time.time()
# Array built from the test dataset's list of (image_path, class_idx) entries.
image_paths = np.array(dataloaders['testing'].dataset.image_list)

In [None]:
# Evaluate on the test split. The function disables gradient tracking itself, so no
# extra `torch.no_grad()` wrapper is needed here.
eval_k_vals = [1,2,4,8]
#Compute Metrics
F1, NMI, recall_at_ks, feature_matrix_all = eval_metrics_one_dataset(model, dataloaders['testing'], 'cpu', eval_k_vals)
#Make printable summary string. The recall labels use the same k values that were
#evaluated, and the epoch label comes from `eval_params`.
result_str = ', '.join('@{0}: {1:.4f}'.format(k,rec) for k,rec in zip(eval_k_vals, recall_at_ks))
result_str = 'Epoch (Test) {0}: NMI [{1:.4f}] | F1 [{2:.4f}] | Recall [{3}]'.format(eval_params['epoch'], NMI, F1, result_str)


Computing Evaluation Metrics...: 100%|██████████| 541/541 [5:43:09<00:00, 38.06s/it]  


In [None]:
print(result_str)