In [6]:
import numpy as np
import cv2
# Seed NumPy's global RNG once: GaussianBlur (random blur/sigma) and get_cifar
# (class subsampling, val/test split) both draw from this global state.
np.random.seed(0)

class GaussianBlur(object):
    """Randomly Gaussian-blur an image (SimCLR-style augmentation).

    With probability 0.5 the image is blurred with a square
    ``kernel_size x kernel_size`` kernel whose sigma is drawn uniformly from
    ``[min, max)``; otherwise it is returned unchanged. The input is always
    converted with ``np.array`` first, so the output is a NumPy array.

    Args:
        kernel_size: side of the square blur kernel. OpenCV requires a
            positive odd integer. (The SimCLR paper uses ~10% of the image
            side; here it is whatever the caller passes.)
        min: lower bound of the sampled sigma.
        max: upper bound of the sampled sigma.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer. Without
            this check the error would only surface inside ``cv2`` on the
            random half of calls that actually blur.
    """

    def __init__(self, kernel_size, min=0.1, max=2.0):
        # `min`/`max` shadow the builtins but are kept for keyword callers.
        if (isinstance(kernel_size, bool)
                or not isinstance(kernel_size, (int, np.integer))
                or kernel_size <= 0 or kernel_size % 2 == 0):
            raise ValueError(
                'kernel_size must be a positive odd integer, got %r' % (kernel_size,))
        self.min = min
        self.max = max
        self.kernel_size = kernel_size

    def __call__(self, sample):
        sample = np.array(sample)

        # Coin flip first, then sigma: the same RNG draw order as before, so a
        # given global seed yields the same augmentations.
        if np.random.random_sample() < 0.5:
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample

In [7]:
import math
import os
import statistics as stat
import sys

class Logger(object):
    """Echo log lines to stdout and optionally append them to a file.

    Args:
        log_path: path of the log file. When file logging is on and a file
            already exists at that path, '+' is appended to the name until an
            unused path is found, so an earlier log is never appended to.
        on: when True, every logged line is also appended to ``log_path``.
    """

    def __init__(self, log_path, on=True):
        self.log_path = log_path
        # Honor the caller's choice (this used to be hard-coded to False,
        # which silently disabled file logging for every caller).
        self.on = on

        if self.on:
            while os.path.isfile(self.log_path):
                self.log_path += '+'

    def log(self, string, newline=True):
        """Write ``string`` (plus '\\n' if ``newline``) to the file and stdout."""
        if self.on:
            with open(self.log_path, 'a') as logf:
                logf.write(string)
                if newline:
                    logf.write('\n')

        sys.stdout.write(string)
        if newline:
            sys.stdout.write('\n')
        sys.stdout.flush()

    def log_perfs(self, perfs):
        """Log summary statistics of per-run performances.

        Infinite values (runs that never produced a score) are excluded from
        the statistics. If no finite value remains, that is logged instead of
        raising ``ValueError`` from ``max([])``.
        """
        valid_perfs = [perf for perf in perfs if not math.isinf(perf)]
        self.log('-' * 89)
        self.log('%d perfs: %s' % (len(perfs), str(perfs)))
        if valid_perfs:
            self.log('perf max: %g' % max(valid_perfs))
            self.log('perf min: %g' % min(valid_perfs))
            self.log('perf avg: %g' % stat.mean(valid_perfs))
            self.log('perf std: %g' % (stat.stdev(valid_perfs)
                                         if len(valid_perfs) > 1 else 0.0))
        else:
            self.log('no finite perfs to summarize')
        self.log('(excluded %d out of %d runs that produced -inf)' %
                 (len(perfs) - len(valid_perfs), len(perfs)))
        self.log('-' * 89)

In [8]:
from torch.autograd import Variable
import numpy as np


def _encode_loader(loader, encode_discrete, device):
    """Encode every (data, target) batch of `loader`; return (codes, labels) arrays."""
    codes, labels = [], []
    for data, target in loader:
        # `Variable` is an identity wrapper since PyTorch 0.4, so pass the tensor directly.
        code = encode_discrete(data.to(device))
        codes.extend(code.detach().cpu().numpy())
        labels.extend(target)
    return np.array(codes), np.stack(labels)


def compress(train, test, encode_discrete, device):
    """Encode a retrieval set and a query set into NumPy arrays.

    :param train: iterable of (data, target) batches forming the retrieval/database set
    :param test: iterable of (data, target) batches forming the query set
    :param encode_discrete: callable mapping a batch tensor to its codes
    :param device: device the input batches are moved to before encoding
    :return: (retrievalB, retrievalL, queryB, queryL): codes and stacked labels
        of `train`, then of `test`, in iteration order
    :raises ValueError: from ``np.stack`` if a loader yields no items
    """
    retrievalB, retrievalL = _encode_loader(train, encode_discrete, device)
    queryB, queryL = _encode_loader(test, encode_discrete, device)
    return retrievalB, retrievalL, queryB, queryL


def calculate_hamming(B1, B2):
    """Hamming distance between a {-1,+1} code and every row of a code matrix.

    For +/-1 vectors of length q, <b1, b2> = q - 2 * hamming(b1, b2), so the
    distance is recovered as (q - <b1, b2>) / 2.

    :param B1: query code, shape [q] (a [m, q] matrix also works -> [m, r])
    :param B2: database codes, shape [r, q]
    :return: Hamming distances, shape [r], as floats
    """
    code_length = B2.shape[1]
    inner_products = np.dot(B1, B2.T)
    return 0.5 * (code_length - inner_products)


def calculate_top_map(qB, rB, queryL, retrievalL, topk):
    """Mean average precision over the top-`topk` Hamming-ranked retrievals.

    For each query, database items are ranked by Hamming distance to the query
    code; an item is relevant when it shares at least one label with the
    query. The AP of a query is the mean precision at each relevant rank within
    the first `topk` items (so it is normalized by the number of relevant items
    retrieved there); queries with no relevant item in the top-`topk`
    contribute 0. The result is the average over all queries.

    :param qB: {-1,+1}^{m x q} query bits
    :param rB: {-1,+1}^{n x q} retrieval bits
    :param queryL: {0,1}^{m x l} query labels
    :param retrievalL: {0,1}^{n x l} retrieval labels
    :param topk: number of top-ranked database items scored per query
    :return: mAP@topk; 0.0 when there are no queries
    """
    num_query = queryL.shape[0]
    if num_query == 0:
        # Previously this divided by zero.
        return 0.0

    topkmap = 0.0
    for query_idx in range(num_query):
        # Relevance of every database item for this query (shared label -> 1).
        gnd = (np.dot(queryL[query_idx, :], retrievalL.transpose()) > 0).astype(np.float32)
        hamm = calculate_hamming(qB[query_idx, :], rB)
        gnd = gnd[np.argsort(hamm)]  # reorder relevance by distance

        tgnd = gnd[:topk]
        tsum = int(np.sum(tgnd))
        if tsum == 0:
            continue
        count = np.linspace(1, tsum, tsum)          # 1..k-th relevant hit
        tindex = np.flatnonzero(tgnd == 1) + 1.0    # 1-based ranks of the hits
        topkmap += np.mean(count / tindex)          # mean precision at each hit
    return topkmap / num_query


In [9]:
import numpy as np
from PIL import Image
from torchvision import transforms
import torchvision.datasets as dsets
from torch.utils.data import Dataset, DataLoader

# from utils.gaussian_blur import GaussianBlur

class Data:
    """Base class: loads a dataset's splits and builds DataLoaders for them.

    Subclasses implement `load_datasets`, which must set X_train/Y_train,
    X_val/Y_val, X_test/Y_test and X_database/Y_database (and anything else
    callers read, such as a retrieval cut-off).

    Transform pipelines built here:
      * train_transforms: random resized crop to 224, flip, color jitter,
        grayscale, Gaussian blur, then tensor conversion and normalization.
      * test_transforms: tensor conversion and normalization only.
      * test_cifar10_transforms: resize to 224x224, tensor conversion and
        normalization.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.load_datasets()

        def channel_normalize():
            # Standard ImageNet per-channel mean / std; a fresh instance per pipeline.
            return transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        augmentation_steps = [
            transforms.RandomResizedCrop(size=224, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.7),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(3),
            transforms.ToTensor(),
            channel_normalize(),
        ]
        self.train_transforms = transforms.Compose(augmentation_steps)
        self.test_transforms = transforms.Compose([transforms.ToTensor(), channel_normalize()])
        self.test_cifar10_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            channel_normalize(),
        ])

    def load_datasets(self):
        """Populate the X_*/Y_* split arrays; must be overridden."""
        raise NotImplementedError

    def get_loaders(self, batch_size, num_workers, shuffle_train=False,
                    get_test=True):
        """Build DataLoaders for every split.

        Returns:
            (train_loader, val_loader, test_loader, database_loader); the
            test loader is None when `get_test` is False. Only the train loader
            may shuffle (controlled by `shuffle_train`).
        """
        eval_transform = (self.test_cifar10_transforms if self.dataset == 'cifar10'
                          else self.test_transforms)

        train_dataset = MyTrainDataset(self.X_train, self.Y_train, self.train_transforms)
        val_dataset = MyTestDataset(self.X_val, self.Y_val, eval_transform, self.dataset)
        test_dataset = MyTestDataset(self.X_test, self.Y_test, eval_transform, self.dataset)
        database_dataset = MyTestDataset(self.X_database, self.Y_database, eval_transform, self.dataset)

        def make_loader(split, shuffle):
            return DataLoader(dataset=split, batch_size=batch_size,
                              shuffle=shuffle, num_workers=num_workers)

        train_loader = make_loader(train_dataset, shuffle_train)
        val_loader = make_loader(val_dataset, False)
        test_loader = make_loader(test_dataset, False) if get_test else None
        database_loader = make_loader(database_dataset, False)

        return train_loader, val_loader, test_loader, database_loader

class LabeledData(Data):
    """Labeled image data for retrieval; only 'cifar10' is supported."""

    def __init__(self, dataset):
        super().__init__(dataset=dataset)

    def load_datasets(self):
        """Load the CIFAR-10 splits, or raise for any other dataset name."""
        if self.dataset != 'cifar10':
            raise NotImplementedError("Please use the right dataset!")
        # Retrieval cut-off passed as `topk` to the mAP evaluation.
        self.topK = 1000
        (self.X_train, self.Y_train,
         self.X_val, self.Y_val,
         self.X_test, self.Y_test,
         self.X_database, self.Y_database) = get_cifar()

class MyTrainDataset(Dataset):
    """Training set yielding two augmented views of each image.

    Each item is (view_i, view_j, label): `transform` is applied twice, in two
    separate calls, to the same PIL image built from `data[index]`.
    """

    def __init__(self, data, labels, transform):
        self.data, self.labels, self.transform = data, labels, transform

    def __getitem__(self, index):
        image = Image.fromarray(self.data[index])
        first_view = self.transform(image)
        second_view = self.transform(image)
        return first_view, second_view, self.labels[index]

    def __len__(self):
        return len(self.data)

class MyTestDataset(Dataset):
    """Evaluation set yielding (transformed item, label) pairs.

    For the 'cifar10' dataset the stored array is converted to a PIL image
    before `transform`; for any other dataset it is passed through as-is.
    """

    def __init__(self, data, labels, transform, dataset):
        self.data = data
        self.labels = labels
        self.transform = transform
        self.dataset = dataset

    def __getitem__(self, index):
        item = self.data[index]
        if self.dataset == 'cifar10':
            item = Image.fromarray(item)
        return (self.transform(item), self.labels[index])

    def __len__(self):
        return len(self.data)

def get_cifar(root='./data/cifar10/', num_train_per_class=500, num_val=5000):
    """Build the CIFAR-10 retrieval splits.

    - train: `num_train_per_class` random images of each of the 10 classes
      from the CIFAR-10 training set, grouped by class (0..9).
    - val / test: the CIFAR-10 test set in a random order, split into the
      first `num_val` images and the rest.
    - database: the whole CIFAR-10 training set.

    Labels are one-hot float arrays of shape [N, 10]. All randomness comes from
    NumPy's global RNG, and the draws happen in the same order as before (one
    permutation per class, then one shuffle), so a given seed reproduces the
    same splits.

    Args:
        root: dataset directory (the training split is downloaded if missing).
        num_train_per_class: training images sampled per class.
        num_val: size of the validation split carved from the test set.

    Returns:
        X_train, Y_train, X_val, Y_val, X_test, Y_test, X_database, Y_database
    """
    num_classes = 10
    one_hot = np.eye(num_classes)

    train_dataset = dsets.CIFAR10(root=root, train=True, download=True)
    test_dataset = dsets.CIFAR10(root=root, train=False)

    # Class-balanced training subset: collect indices, concatenate once.
    X = train_dataset.data
    L = np.array(train_dataset.targets)
    picked = []
    for label in range(num_classes):
        index = np.where(L == label)[0]
        index = index[np.random.permutation(index.shape[0])]
        picked.append(index[:num_train_per_class])
    picked = np.concatenate(picked)
    X_train = X[picked]
    Y_train = one_hot[L[picked]]

    # Random val / test split of the CIFAR-10 test set.
    idxs = list(range(len(test_dataset.data)))
    np.random.shuffle(idxs)
    test_data = np.array(test_dataset.data)
    test_targets = np.array(test_dataset.targets)

    X_val = test_data[idxs[:num_val]]
    Y_val = one_hot[test_targets[idxs[:num_val]]]

    X_test = test_data[idxs[num_val:]]
    Y_test = one_hot[test_targets[idxs[num_val:]]]

    # The database is the full training split; reuse the already-loaded
    # object instead of reading the same files a second time.
    X_database = train_dataset.data
    Y_database = one_hot[train_dataset.targets]

    return X_train, Y_train, X_val, Y_val, X_test, Y_test, X_database, Y_database



In [10]:
import math
import torch
import random
import pickle
import sklearn
import argparse
import numpy as np
import seaborn as sb
from PIL import Image
import torch.nn as nn
from copy import deepcopy
from datetime import timedelta
from matplotlib import gridspec
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.patheffects as pe
from collections import OrderedDict
from torch.autograd import Variable
from sklearn.datasets import load_digits
from timeit import default_timer as timer


# from utils.logger import Logger
# from utils.data import LabeledData
# from utils.evaluation import calculate_hamming
# from utils.evaluation import compress, calculate_top_map

class Base_Model(nn.Module):
    """Shared training / evaluation driver for deep hashing models.

    Subclasses implement `get_hparams_grid`, `define_parameters`,
    `configure_optimizers`, `forward(imgi, imgj, device)` and
    `encode_discrete(x)`. `forward` returns a dict of per-step values that
    must include 'loss' (it is back-propagated); the values are presumably
    scalar tensors or numbers, since they are logged with a float format.
    `encode_discrete` returns the codes used for Hamming retrieval.

    NOTE(review): a later notebook cell defines `Base_Model` again; whichever
    cell ran last is the class in use.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.load_data()

    def load_data(self):
        """Load the labeled dataset named by `hparams.dataset` into `self.data`."""
        self.data = LabeledData(self.hparams.dataset)

    def get_hparams_grid(self):
        """Return {hparam_name: candidate values} for random search."""
        raise NotImplementedError

    def define_parameters(self):
        """Build the model's submodules; called at the start of each run and by `load`."""
        raise NotImplementedError

    def configure_optimizers(self):
        """Return the optimizer used by `run_training_session`."""
        raise NotImplementedError

    @staticmethod
    def _as_float(value):
        """Convert a scalar (number or single-element tensor) to a Python float.

        Detaching matters: summing the raw loss tensors across steps would keep
        every step's autograd graph alive for the whole epoch.
        """
        if torch.is_tensor(value):
            return float(value.detach())
        return float(value)

    def run_training_sessions(self):
        """Run `hparams.num_runs` training sessions, keep the best, then test.

        The best run by validation mAP is saved to `hparams.model_path`,
        reloaded, and evaluated on the validation and test sets.

        Raises:
            RuntimeError: if no run produced a validated checkpoint, so there
                is nothing (or only a stale file) to reload.
        """
        logger = Logger(self.hparams.model_path + '.log', on=True)
        val_perfs = []
        best_val_perf = float('-inf')
        saved = False
        start = timer()
        random.seed(self.hparams.seed)  # For reproducible random runs

        for run_num in range(1, self.hparams.num_runs + 1):
            state_dict, val_perf = self.run_training_session(run_num, logger)
            val_perfs.append(val_perf)

            # A run that never validated returns (None, -inf); never save that.
            if state_dict is not None and val_perf > best_val_perf:
                best_val_perf = val_perf
                saved = True
                logger.log('----New best {:8.4f}, saving'.format(val_perf))
                torch.save({'hparams': self.hparams,
                            'state_dict': state_dict}, self.hparams.model_path)

        logger.log('Time: %s' % str(timedelta(seconds=round(timer() - start))))
        if not saved:
            logger.log('No run produced a validated checkpoint; nothing to test')
            raise RuntimeError('no checkpoint was saved to %r' % self.hparams.model_path)

        self.load()
        if self.hparams.num_runs > 1:
            logger.log_perfs(val_perfs)
            logger.log('best hparams: ' + self.flag_hparams())

        val_perf, test_perf = self.run_test()
        logger.log('Val:  {:8.4f}'.format(val_perf))
        logger.log('Test: {:8.4f}'.format(test_perf))

    def run_training_session(self, run_num, logger):
        """Train once, with hparams re-drawn at random when `num_runs > 1`.

        Validation runs every `validate_frequency` epochs and after the final
        epoch. Training stops early after more than `num_bad_epochs`
        consecutive non-improving validations, on a NaN loss, or on Ctrl-C.

        Returns:
            (best_state_dict, best_val_perf): a deep copy of the weights at the
            best validation and its mAP, or (None, -inf) if validation never
            completed.
        """
        self.train()

        # Scramble hyperparameters if number of runs is greater than 1.
        if self.hparams.num_runs > 1:
            logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
            for hparam, values in self.get_hparams_grid().items():
                assert hasattr(self.hparams, hparam)
                setattr(self.hparams, hparam, random.choice(values))

        random.seed(self.hparams.seed)
        torch.manual_seed(self.hparams.seed)

        self.define_parameters()

        # With 16-bit codes, train for at least 80 epochs.
        if self.hparams.encode_length == 16:
            self.hparams.epochs = max(80, self.hparams.epochs)

        logger.log('hparams: %s' % self.flag_hparams())

        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        self.to(device)

        optimizer = self.configure_optimizers()
        train_loader, val_loader, _, database_loader = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=True, get_test=False)
        best_val_perf = float('-inf')
        best_state_dict = None
        bad_epochs = 0

        try:
            for epoch in range(1, self.hparams.epochs + 1):
                forward_sum = {}
                num_steps = 0
                for imgi, imgj, _ in train_loader:
                    optimizer.zero_grad()

                    imgi = imgi.to(device)
                    imgj = imgj.to(device)

                    forward = self.forward(imgi, imgj, device)

                    # Accumulate detached floats for logging only.
                    for key, value in forward.items():
                        forward_sum[key] = forward_sum.get(key, 0.0) + self._as_float(value)
                    num_steps += 1

                    if math.isnan(forward_sum['loss']):
                        logger.log('Stopping epoch because loss is NaN')
                        break

                    forward['loss'].backward()
                    optimizer.step()

                if num_steps == 0:
                    logger.log('Stopping training session because the training loader is empty')
                    break

                if math.isnan(forward_sum['loss']):
                    logger.log('Stopping training session because loss is NaN')
                    break

                logger.log('End of epoch {:3d}'.format(epoch), False)
                logger.log(' '.join([' | {:s} {:8.4f}'.format(
                    key, forward_sum[key] / num_steps)
                                     for key in forward_sum]), True)

                # Also validate after the last epoch so a run whose length is
                # not a multiple of the frequency still yields a checkpoint.
                is_last_epoch = epoch == self.hparams.epochs
                if epoch % self.hparams.validate_frequency == 0 or is_last_epoch:
                    print('evaluating...')
                    val_perf = self.evaluate(database_loader, val_loader, self.data.topK, device)
                    logger.log(' | val perf {:8.4f}'.format(val_perf), False)

                    if val_perf > best_val_perf:
                        best_val_perf = val_perf
                        bad_epochs = 0
                        logger.log('\t\t*Best model so far, deep copying*')
                        best_state_dict = deepcopy(self.state_dict())
                    else:
                        bad_epochs += 1
                        logger.log('\t\tBad epoch %d' % bad_epochs)

                    if bad_epochs > self.hparams.num_bad_epochs:
                        break

        except KeyboardInterrupt:
            logger.log('-' * 89)
            logger.log('Exiting from training early')

        return best_state_dict, best_val_perf

    def evaluate(self, database_loader, val_loader, topK, device):
        """Return mAP@`topK` of `val_loader` queries against `database_loader`.

        Codes are computed in eval mode without autograd; the previous
        train/eval mode is restored afterwards, even if evaluation raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                retrievalB, retrievalL, queryB, queryL = compress(
                    database_loader, val_loader, self.encode_discrete, device)
                result = calculate_top_map(qB=queryB, rB=retrievalB, queryL=queryL,
                                           retrievalL=retrievalL, topk=topK)
        finally:
            self.train(was_training)
        return result

    def load(self):
        """Restore hparams and weights from the checkpoint at `hparams.model_path`.

        On CPU runs, tensors are mapped to CPU and the restored hparams get
        `cuda=False`. The checkpoint contains a pickled hparams object, so only
        load trusted files. NOTE(review): newer PyTorch releases default
        `torch.load` to `weights_only=True`, which may reject that object;
        confirm against the installed version.
        """
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        if self.hparams.cuda:
            checkpoint = torch.load(self.hparams.model_path)
        else:
            checkpoint = torch.load(self.hparams.model_path,
                                    map_location=torch.device('cpu'))
        if checkpoint['hparams'].cuda and not self.hparams.cuda:
            checkpoint['hparams'].cuda = False
        self.hparams = checkpoint['hparams']
        self.define_parameters()
        self.load_state_dict(checkpoint['state_dict'])
        self.to(device)

    def run_test(self):
        """Return (val mAP, test mAP) of the current model against the database."""
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        _, val_loader, test_loader, database_loader = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=False, get_test=True)

        val_perf = self.evaluate(database_loader, val_loader, self.data.topK, device)
        test_perf = self.evaluate(database_loader, test_loader, self.data.topK, device)
        return val_perf, test_perf

    def run_retrieval_case_study(self):
        """Plot the 10 nearest database images for test queries 0, 2 and 5.

        Each figure row holds the query in column 0, an empty spacer column,
        then its 10 nearest database images by Hamming distance. Saves
        'retrieval_case_study_{encode_length}bits.pdf'. Leaves the model in
        eval mode.
        """
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        query_idxs = [0, 2, 5]
        X_database = self.data.X_database
        X_test = self.data.X_test
        X_case = torch.cat([self.data.test_cifar10_transforms(Image.fromarray(X_test[i])).unsqueeze(0)
                            for i in query_idxs], dim=0)
        _, _, _, database_loader = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=False, get_test=True)

        # get hash codes
        self.eval()
        with torch.no_grad():
            retrievalB = []
            for data, _ in database_loader:
                code = self.encode_discrete(data.to(device))
                retrievalB.extend(code.cpu().data.numpy())

            queryB = list(self.encode_discrete(X_case.to(device)).cpu().data.numpy())

        retrievalB = np.array(retrievalB)
        queryB = np.array(queryB)

        # Indices of the 10 nearest database items for each query.
        top10_idx_list = []
        for idx in range(queryB.shape[0]):
            hamm = calculate_hamming(queryB[idx, :], retrievalB)
            top10_idx_list.append(list(np.argsort(hamm)[:10]))

        # plot results
        fig = plt.figure(0, figsize=(5, 1.2))
        fig.clf()
        gs = gridspec.GridSpec(queryB.shape[0], 12)
        gs.update(wspace=0.0001, hspace=0.0001)
        for i in range(queryB.shape[0]):
            axes = plt.subplot(gs[i, 0])
            axes.imshow(X_test[query_idxs[i]])
            axes.axis('off')

            for j in range(0, 10):
                axes = plt.subplot(gs[i, j + 2])
                axes.imshow(X_database[top10_idx_list[i][j]])
                axes.axis('off')
        fig.savefig("retrieval_case_study_{:d}bits.pdf".format(self.hparams.encode_length), bbox_inches='tight', pad_inches=0.0)

    def hash_code_visualization(self):
        """Plot a 2-D t-SNE embedding of the test-set hash codes, colored by class.

        cifar10 labels:
        0: Airplane 1: Automobile 2: Bird 3: Cat 4: Deer
        5: Dog 6: Frog 7: Horse 8: Ship 9: Truck

        Saves 'Ours_hash_codes_visulization_{encode_length}bits.pdf'. Codes are
        computed in eval mode without autograd; the previous mode is restored.
        """
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        _, _, test_loader, _ = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=False, get_test=True)

        was_training = self.training
        self.eval()
        retrievalB = []
        retrievalL = []
        try:
            with torch.no_grad():
                for data, target in test_loader:
                    code = self.encode_discrete(data.to(device))
                    retrievalB.extend(code.cpu().data.numpy())
                    retrievalL.extend(target.cpu().data.numpy())
        finally:
            self.train(was_training)

        hash_codes = np.array(retrievalB)
        # One-hot labels -> integer class ids.
        _, labels = np.where(np.array(retrievalL) == 1)

        # t-SNE projection of the codes to 2-D.
        mapper = TSNE(perplexity=30).fit_transform(hash_codes)

        plt.figure(figsize=(8, 8))
        # `np.int` was removed in NumPy 1.24; the builtin `int` is its replacement.
        plt.scatter(mapper[:, 0], mapper[:, 1], lw=0, s=20, c=labels.astype(int), cmap='Spectral')

        # Write each class id at the median position of its points.
        for i in range(10):
            xtext, ytext = np.median(mapper[labels == i, :], axis=0)
            txt = plt.text(xtext, ytext, str(i), fontsize=24)
            txt.set_path_effects([pe.Stroke(linewidth=5, foreground="w"), pe.Normal()])

        plt.axis("off")
        plt.gcf().tight_layout()
        plt.savefig('Ours_hash_codes_visulization_{:d}bits.pdf'.format(self.hparams.encode_length), bbox_inches='tight', pad_inches=0.0)

    def flag_hparams(self):
        """Render hparams as a command line: model_path followed by --flags.

        True booleans become bare flags, False booleans are omitted, and
        model_path / num_runs / num_workers are not repeated as flags.
        """
        flags = '%s' % (self.hparams.model_path)
        for hparam in vars(self.hparams):
            val = getattr(self.hparams, hparam)
            if str(val) == 'False':
                continue
            elif str(val) == 'True':
                flags += ' --%s' % (hparam)
            elif str(hparam) in {'model_path', 'num_runs',
                                 'num_workers'}:
                continue
            else:
                flags += ' --%s %s' % (hparam, val)
        return flags

    @staticmethod
    def get_general_hparams_grid():
        """Random-search grid shared by all models (seed, lr, batch_size)."""
        grid = OrderedDict({
            'seed': list(range(100000)),
            'lr': [0.003, 0.001, 0.0003, 0.0001],
            'batch_size': [64, 128, 256],
            })
        return grid

    @staticmethod
    def get_general_argparser():
        """Return an ArgumentParser with the options shared by all models."""
        parser = argparse.ArgumentParser()

        parser.add_argument('model_path', type=str)
        parser.add_argument('--train', action='store_true',
                            help='train a model?')
        parser.add_argument('-d', '--dataset', default = 'cifar10', type=str,
                            help='dataset [%(default)s]')
        parser.add_argument("-l","--encode_length", type = int, default=16,
                            help = "Number of bits of the hash code [%(default)d]")
        parser.add_argument("--lr", default = 1e-3, type = float,
                            help='initial learning rate [%(default)g]')
        parser.add_argument("--batch_size", default=64,type=int,
                            help='batch size [%(default)d]')
        parser.add_argument("-e","--epochs", default=60, type=int,
                            help='max number of epochs [%(default)d]')
        parser.add_argument('--cuda', action='store_true',
                            help='use CUDA?')
        parser.add_argument('--num_runs', type=int, default=1,
                            help='num random runs (not random if 1) '
                            '[%(default)d]')
        parser.add_argument('--num_bad_epochs', type=int, default=6,
                            help='num indulged bad epochs [%(default)d]')
        parser.add_argument('--validate_frequency', type=int, default=20,
                            help='validate every [%(default)d] epochs')
        parser.add_argument('--num_workers', type=int, default=8,
                            help='num dataloader workers [%(default)d]')
        parser.add_argument('--seed', type=int, default=8888,
                            help='random seed [%(default)d]')
        parser.add_argument('--device', type=int, default=0,
                            help='device of the gpu')

        return parser



In [11]:
import math
import torch
import random
import pickle
import sklearn
import argparse
import numpy as np
import seaborn as sb
from PIL import Image
import torch.nn as nn
from copy import deepcopy
from datetime import timedelta
from matplotlib import gridspec
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import matplotlib.patheffects as pe
from collections import OrderedDict
from torch.autograd import Variable
from sklearn.datasets import load_digits
from timeit import default_timer as timer


# from utils.logger import Logger
# from utils.data import LabeledData
# from utils.evaluation import calculate_hamming
# from utils.evaluation import compress, calculate_top_map

class Base_Model(nn.Module):
    def __init__(self, hparams):
        """Keep the parsed hyperparameters and load the dataset they name."""
        super().__init__()
        self.hparams = hparams
        # Dataset loading happens eagerly, at construction time.
        self.load_data()

    def load_data(self):
        """Attach the labeled dataset named by `hparams.dataset` as `self.data`."""
        dataset_name = self.hparams.dataset
        self.data = LabeledData(dataset_name)

    def get_hparams_grid(self):
        raise NotImplementedError

    def define_parameters(self):
        raise NotImplementedError

    def configure_optimizers(self):
        raise NotImplementedError

    def run_training_sessions(self):
        """Run `hparams.num_runs` training sessions, keep the best, then test.

        The best run by validation mAP is saved to `hparams.model_path`,
        reloaded, and evaluated on the validation and test sets.

        Raises:
            RuntimeError: if no run produced a validated checkpoint, so there
                is nothing (or only a stale file) to reload.
        """
        logger = Logger(self.hparams.model_path + '.log', on=True)
        val_perfs = []
        best_val_perf = float('-inf')
        saved = False
        start = timer()
        random.seed(self.hparams.seed)  # For reproducible random runs

        for run_num in range(1, self.hparams.num_runs + 1):
            state_dict, val_perf = self.run_training_session(run_num, logger)
            val_perfs.append(val_perf)

            # A run that never validated returns (None, -inf); never save that.
            if state_dict is not None and val_perf > best_val_perf:
                best_val_perf = val_perf
                saved = True
                logger.log('----New best {:8.4f}, saving'.format(val_perf))
                torch.save({'hparams': self.hparams,
                            'state_dict': state_dict}, self.hparams.model_path)

        logger.log('Time: %s' % str(timedelta(seconds=round(timer() - start))))
        if not saved:
            logger.log('No run produced a validated checkpoint; nothing to test')
            raise RuntimeError('no checkpoint was saved to %r' % self.hparams.model_path)

        self.load()
        if self.hparams.num_runs > 1:
            logger.log_perfs(val_perfs)
            logger.log('best hparams: ' + self.flag_hparams())

        val_perf, test_perf = self.run_test()
        logger.log('Val:  {:8.4f}'.format(val_perf))
        logger.log('Test: {:8.4f}'.format(test_perf))

    def run_training_session(self, run_num, logger):
        self.train()

        # Scramble hyperparameters if number of runs is greater than 1.
        if self.hparams.num_runs > 1:
            logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
            for hparam, values in self.get_hparams_grid().items():
                assert hasattr(self.hparams, hparam)
                self.hparams.__dict__[hparam] = random.choice(values)

        random.seed(self.hparams.seed)
        torch.manual_seed(self.hparams.seed)

        self.define_parameters()

        # if encode_length is 16, then al least 80 epochs!
        if self.hparams.encode_length == 16:
            self.hparams.epochs = max(80, self.hparams.epochs)

        logger.log('hparams: %s' % self.flag_hparams())

        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        self.to(device)

        optimizer = self.configure_optimizers()
        train_loader, val_loader, _, database_loader = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=True, get_test=False)
        best_val_perf = float('-inf')
        best_state_dict = None
        bad_epochs = 0

        try:
            for epoch in range(1, self.hparams.epochs + 1):
                forward_sum = {}
                num_steps = 0
                for batch_num, batch in enumerate(train_loader):
                    optimizer.zero_grad()

                    imgi, imgj, _ = batch
                    imgi = imgi.to(device)
                    imgj = imgj.to(device)

                    forward = self.forward(imgi, imgj, device)

                    for key in forward:
                        if key in forward_sum:
                            forward_sum[key] += forward[key]
                        else:
                            forward_sum[key] = forward[key]
                    num_steps += 1

                    if math.isnan(forward_sum['loss']):
                        logger.log('Stopping epoch because loss is NaN')
                        break

                    forward['loss'].backward()
                    optimizer.step()

                if math.isnan(forward_sum['loss']):
                    logger.log('Stopping training session because loss is NaN')
                    break

                logger.log('End of epoch {:3d}'.format(epoch), False)
                logger.log(' '.join([' | {:s} {:8.4f}'.format(
                    key, forward_sum[key] / num_steps)
                                     for key in forward_sum]), True)

                if epoch % self.hparams.validate_frequency == 0:
                    print('evaluating...')
                    val_perf = self.evaluate(database_loader, val_loader, self.data.topK, device)
                    logger.log(' | val perf {:8.4f}'.format(val_perf), False)

                    if val_perf > best_val_perf:
                        best_val_perf = val_perf
                        bad_epochs = 0
                        logger.log('\t\t*Best model so far, deep copying*')
                        best_state_dict = deepcopy(self.state_dict())
                    else:
                        bad_epochs += 1
                        logger.log('\t\tBad epoch %d' % bad_epochs)

                    if bad_epochs > self.hparams.num_bad_epochs:
                        break

        except KeyboardInterrupt:
            logger.log('-' * 89)
            logger.log('Exiting from training early')

        return best_state_dict, best_val_perf

    def evaluate(self, database_loader, val_loader, topK, device):
        """Return mAP@`topK` of `val_loader` queries against `database_loader`.

        Codes are computed in eval mode without autograd; the previous
        train/eval mode is restored afterwards (even if evaluation raises)
        instead of being forced to train mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                retrievalB, retrievalL, queryB, queryL = compress(
                    database_loader, val_loader, self.encode_discrete, device)
                result = calculate_top_map(qB=queryB, rB=retrievalB, queryL=queryL,
                                           retrievalL=retrievalL, topk=topK)
        finally:
            self.train(was_training)
        return result

    def load(self):
        """Reload hparams and weights from the checkpoint at `hparams.model_path`.

        On CPU runs, tensors are mapped to CPU and the restored hparams end up
        with `cuda=False`. The checkpoint holds a pickled hparams object, so
        only load trusted files. NOTE(review): newer PyTorch releases default
        `torch.load` to `weights_only=True`, which may reject that object;
        confirm against the installed version.
        """
        use_cuda = self.hparams.cuda
        target_device = torch.device('cuda' if use_cuda else 'cpu')
        if use_cuda:
            checkpoint = torch.load(self.hparams.model_path)
        else:
            checkpoint = torch.load(self.hparams.model_path,
                                    map_location=torch.device('cpu'))
        restored_hparams = checkpoint['hparams']
        if restored_hparams.cuda and not use_cuda:
            restored_hparams.cuda = False
        self.hparams = restored_hparams
        self.define_parameters()
        self.load_state_dict(checkpoint['state_dict'])
        self.to(target_device)

    def run_test(self):
        """Return (val mAP, test mAP) of the current model against the database."""
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        loaders = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=False, get_test=True)
        _, val_loader, test_loader, database_loader = loaders

        top_k = self.data.topK
        val_perf = self.evaluate(database_loader, val_loader, top_k, device)
        test_perf = self.evaluate(database_loader, test_loader, top_k, device)
        return val_perf, test_perf

    def run_retrieval_case_study(self):
        """Plot the 10 nearest database images for test queries 0, 2 and 5.

        Each figure row holds the query in column 0, an empty spacer column,
        then its 10 nearest database images by Hamming distance. Saves
        'retrieval_case_study_{encode_length}bits.pdf'. Leaves the model in
        eval mode.
        """
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        query_idxs = [0, 2, 5]
        db_images = self.data.X_database
        test_images = self.data.X_test
        query_batch = torch.cat(
            [self.data.test_cifar10_transforms(Image.fromarray(test_images[i])).unsqueeze(0)
             for i in query_idxs],
            dim=0)
        _, _, _, database_loader = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=False, get_test=True)

        # Encode the database and the queries without autograd.
        self.eval()
        with torch.no_grad():
            db_codes = []
            for batch, _ in database_loader:
                db_codes.extend(self.encode_discrete(batch.to(device)).cpu().data.numpy())
            query_codes = list(self.encode_discrete(query_batch.to(device)).cpu().data.numpy())

        db_codes = np.array(db_codes)
        query_codes = np.array(query_codes)
        num_queries = query_codes.shape[0]

        # Indices of the 10 nearest database items for each query.
        nearest = [list(np.argsort(calculate_hamming(query_codes[q, :], db_codes))[:10])
                   for q in range(num_queries)]

        fig = plt.figure(0, figsize=(5, 1.2))
        fig.clf()
        grid = gridspec.GridSpec(num_queries, 12)
        grid.update(wspace=0.0001, hspace=0.0001)
        for row in range(num_queries):
            ax = plt.subplot(grid[row, 0])
            ax.imshow(test_images[query_idxs[row]])
            ax.axis('off')

            for rank in range(10):
                ax = plt.subplot(grid[row, rank + 2])
                ax.imshow(db_images[nearest[row][rank]])
                ax.axis('off')
        fig.savefig("retrieval_case_study_{:d}bits.pdf".format(self.hparams.encode_length), bbox_inches='tight', pad_inches=0.0)

    def hash_code_visualization(self):
        """Save a 2-D t-SNE scatter plot of the test-set hash codes.

        cifar10 labels:
        0: Airplane 1: Automobile 2: Bird 3: Cat 4: Deer
        5: Dog 6: Frog 7: Horse 8: Ship 9: Truck

        Codes come from ``self.encode_discrete`` in eval mode under
        ``torch.no_grad()``; the model's previous train/eval mode is restored
        afterwards.  Points are coloured by class index, and each class present
        is annotated with its index at the median of its points.  The figure is
        saved as ``Ours_hash_codes_visulization_<encode_length>bits.pdf``.
        """
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        _, _, test_loader, _ = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=False, get_test=True)

        # Previously the codes were computed in train mode (any dropout layers
        # active -> stochastic codes) and with autograd graphs being built.
        was_training = self.training
        self.eval()
        retrievalB, retrievalL = [], []
        try:
            with torch.no_grad():
                for data, target in test_loader:
                    code = self.encode_discrete(data.to(device))
                    retrievalB.extend(code.cpu().numpy())
                    retrievalL.extend(target.cpu().numpy())
        finally:
            self.train(was_training)

        hash_codes = np.array(retrievalB)
        # Assumes targets are one-hot rows: the column holding the 1 is the
        # class index. NOTE(review): confirm against the data loader.
        _, labels = np.where(np.array(retrievalL) == 1)

        # t-SNE embedding of the hash codes
        mapper = TSNE(perplexity=30).fit_transform(hash_codes)

        plt.figure(figsize=(8, 8))
        # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent.
        plt.scatter(mapper[:, 0], mapper[:, 1], lw=0, s=20, c=labels.astype(int), cmap='Spectral')

        # Label each class at the median of its points.  Iterating over the
        # classes actually present avoids np.median on an empty selection.
        for i in np.unique(labels):
            xtext, ytext = np.median(mapper[labels == i, :], axis=0)
            txt = plt.text(xtext, ytext, str(i), fontsize=24)
            txt.set_path_effects([pe.Stroke(linewidth=5, foreground="w"), pe.Normal()])

        plt.axis("off")
        plt.gcf().tight_layout()
        plt.savefig('Ours_hash_codes_visulization_{:d}bits.pdf'.format(self.hparams.encode_length), bbox_inches='tight', pad_inches=0.0)

    def flag_hparams(self):
        """Render ``self.hparams`` as a command line: model path followed by flags.

        Options whose value prints as ``True`` become bare ``--name`` flags;
        those printing as ``False`` are dropped.  ``model_path``, ``num_runs``
        and ``num_workers`` are otherwise skipped, and every remaining option
        becomes ``--name value``.
        """
        skipped = {'model_path', 'num_runs', 'num_workers'}
        pieces = ['%s' % (self.hparams.model_path)]
        for name in vars(self.hparams):
            value = getattr(self.hparams, name)
            rendered = str(value)
            if rendered == 'False':
                continue
            if rendered == 'True':
                pieces.append(' --%s' % (name))
            elif str(name) not in skipped:
                pieces.append(' --%s %s' % (name, value))
        return ''.join(pieces)

    @staticmethod
    def get_general_hparams_grid():
        """Hyper-parameter candidates shared by every model.

        Subclasses extend the returned mapping with model-specific entries.

        Returns:
            OrderedDict of name -> list of candidate values, ordered
            seed, lr, batch_size.
        """
        candidates = [
            ('seed', list(range(100000))),
            ('lr', [0.003, 0.001, 0.0003, 0.0001]),
            ('batch_size', [64, 128, 256]),
        ]
        return OrderedDict(candidates)

    @staticmethod
    def get_general_argparser():
        """Build the argument parser holding the options common to all models.

        Subclasses add their own options to the returned parser.  Options are
        registered in a fixed order, which is also the attribute order of the
        parsed namespace.
        """
        parser = argparse.ArgumentParser()
        add = parser.add_argument

        add('model_path', type=str)
        add('--train', action='store_true', help='train a model?')
        add('-d', '--dataset', default='cifar10', type=str,
            help='dataset [%(default)s]')
        add('-l', '--encode_length', type=int, default=16,
            help='Number of bits of the hash code [%(default)d]')
        add('--lr', default=1e-3, type=float,
            help='initial learning rate [%(default)g]')
        add('--batch_size', default=64, type=int,
            help='batch size [%(default)d]')
        add('-e', '--epochs', default=60, type=int,
            help='max number of epochs [%(default)d]')
        add('--cuda', action='store_true', help='use CUDA?')
        add('--num_runs', type=int, default=1,
            help='num random runs (not random if 1) [%(default)d]')
        add('--num_bad_epochs', type=int, default=6,
            help='num indulged bad epochs [%(default)d]')
        add('--validate_frequency', type=int, default=20,
            help='validate every [%(default)d] epochs')
        add('--num_workers', type=int, default=8,
            help='num dataloader workers [%(default)d]')
        add('--seed', type=int, default=8888,
            help='random seed [%(default)d]')
        add('--device', type=int, default=0,
            help='device of the gpu')

        return parser


In [12]:
import torch
import argparse
import torchvision
import torch.nn as nn
from torch.autograd import Function

# from model.base_model import Base_Model

class CIBHash(Base_Model):
    """Contrastive hashing model on top of a frozen, pretrained VGG-16.

    Two views of each image are encoded into ``encode_length``-bit codes.  The
    training loss is an NT-Xent contrastive loss between the binary codes plus
    ``weight`` times a symmetric KL term between the two views' per-bit
    Bernoulli probabilities.  Only the ``encoder`` MLP is optimised.
    """
    def __init__(self, hparams):
        super().__init__(hparams=hparams)

    def define_parameters(self):
        """Build the frozen VGG-16 backbone, the hashing MLP and the contrastive loss."""
        # NOTE: newer torchvision deprecates `pretrained=True` in favour of
        # `weights=...`; kept for compatibility with the installed version.
        self.vgg = torchvision.models.vgg16(pretrained=True)
        # Keep only the first 6 classifier modules; the encoder below expects
        # their 4096-d output.
        self.vgg.classifier = nn.Sequential(*list(self.vgg.classifier.children())[:6])
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.encoder = nn.Sequential(nn.Linear(4096, 1024),
                                     nn.ReLU(),
                                     nn.Linear(1024, self.hparams.encode_length),
                                     )

        self.criterion = NtXentLoss(self.hparams.batch_size, self.hparams.temperature)

    def _encode_prob(self, x):
        """Per-bit probabilities sigmoid(encoder(VGG features of ``x``))."""
        x = self.vgg.features(x)
        x = x.view(x.size(0), -1)
        x = self.vgg.classifier(x)
        return torch.sigmoid(self.encoder(x))

    def forward(self, imgi, imgj, device):
        """Training loss for a batch of paired views ``imgi`` / ``imgj``.

        Returns:
            dict with ``'loss'`` (= contra_loss + weight * kl_loss),
            ``'contra_loss'`` and ``'kl_loss'``.
        """
        prob_i = self._encode_prob(imgi)
        z_i = hash_layer(prob_i - 0.5)

        prob_j = self._encode_prob(imgj)
        z_j = hash_layer(prob_j - 0.5)

        # symmetric KL between the two views' bit distributions
        kl_loss = (self.compute_kl(prob_i, prob_j) + self.compute_kl(prob_j, prob_i)) / 2
        contra_loss = self.criterion(z_i, z_j, device)
        loss = contra_loss + self.hparams.weight * kl_loss

        return {'loss': loss, 'contra_loss': contra_loss, 'kl_loss': kl_loss}

    def encode_discrete(self, x):
        """Hash codes ``hash_layer(p - 0.5)`` for images ``x``, p = per-bit probability."""
        return hash_layer(self._encode_prob(x) - 0.5)

    def compute_kl(self, prob, prob_v):
        """KL(Bern(prob) || Bern(prob_v)) summed over bits and averaged over the batch.

        ``prob_v`` is detached, so gradients flow only through ``prob``.  The
        1e-8 epsilon keeps the logs finite for probabilities of exactly 0 or 1.
        """
        prob_v = prob_v.detach()
        kl = prob * (torch.log(prob + 1e-8) - torch.log(prob_v + 1e-8)) \
            + (1 - prob) * (torch.log(1 - prob + 1e-8) - torch.log(1 - prob_v + 1e-8))
        return torch.mean(torch.sum(kl, dim=1))

    def configure_optimizers(self):
        """Adam over the hashing MLP only (the VGG backbone is frozen)."""
        return torch.optim.Adam([{'params': self.encoder.parameters()}], lr = self.hparams.lr)

    def get_hparams_grid(self):
        """General search grid extended with temperature and KL weight candidates."""
        grid = Base_Model.get_general_hparams_grid()
        grid.update({
            'temperature': [0.2, 0.3, 0.4],
            'weight': [0.001, 0.005, 0.0005, 0.0001, 0.00005, 0.00001]
            })
        return grid

    @staticmethod
    def get_model_specific_argparser():
        """General argument parser plus the CIBHash-specific -t/-w options."""
        parser = Base_Model.get_general_argparser()

        # `%(default)g`, not `%(default)d`: the float default 0.3 was being
        # rendered as "Temperature [0]" in --help.
        parser.add_argument("-t", "--temperature", default = 0.3, type = float,
                            help = "Temperature [%(default)g]",)
        parser.add_argument('-w',"--weight", default = 0.001, type=float,
                            help='weight of I(x,z) [%(default)f]')
        return parser


class hash(Function):
    """Sign binarisation with a straight-through gradient.

    Forward returns the element-wise sign of the input; backward passes the
    incoming gradient through unchanged.
    NOTE(review): this class name shadows Python's builtin ``hash`` in the
    notebook namespace.
    """

    @staticmethod
    def forward(ctx, input):
        binary = input.sign()
        return binary

    @staticmethod
    def backward(ctx, grad_output):
        # identity gradient: sign() is treated as having derivative 1
        passthrough = grad_output
        return passthrough

def hash_layer(input):
    """Binarise ``input`` with the straight-through sign function ``hash``."""
    codes = hash.apply(input)
    return codes

class NtXentLoss(nn.Module):
    """NT-Xent (normalised temperature-scaled cross-entropy) contrastive loss.

    For ``B`` positive pairs ``(z_i[k], z_j[k])``, the ``2B`` vectors are
    compared with cosine similarity divided by ``temperature``.  Each vector's
    positive is its partner view; the other ``2(B - 1)`` vectors are its
    negatives.  The loss is the mean cross-entropy over all ``2B`` anchors.
    """
    def __init__(self, batch_size, temperature):
        super(NtXentLoss, self).__init__()
        # `batch_size` is accepted for backward compatibility only; the batch
        # size actually used is read from the inputs on every call.
        self.temperature = temperature
        self.similarityF = nn.CosineSimilarity(dim = 2)
        self.criterion = nn.CrossEntropyLoss(reduction = 'sum')

    def mask_correlated_samples(self, batch_size, device=None):
        """Boolean ``(2B, 2B)`` mask selecting the negatives of every anchor.

        The mask is False on the diagonal (self-similarity) and on the
        positive-pair entries ``(k, k + B)`` and ``(k + B, k)``; True elsewhere.

        Args:
            batch_size: number of pairs ``B``.
            device: device to create the mask on (default: CPU).
        """
        N = 2 * batch_size
        mask = ~torch.eye(N, dtype=torch.bool, device=device)
        idx = torch.arange(batch_size, device=device)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def forward(self, z_i, z_j, device):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
        (In that sentence N is the number of pairs, i.e. ``z_i.shape[0]``.)

        Args:
            z_i, z_j: 2-D ``(B, D)`` codes of the two views.
            device: kept for backward compatibility; the mask and labels are
                created on the device of the similarity matrix so they always
                match it.

        Returns:
            Scalar loss tensor (mean over the ``2B`` anchors).
        """
        batch_size = z_i.shape[0]
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)
        # (N, N) pairwise cosine similarities scaled by the temperature
        sim = self.similarityF(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)

        # Build the mask on sim's device: no per-step Python loop and no
        # host->device transfer when indexing a CUDA tensor.
        mask = self.mask_correlated_samples(batch_size, device=sim.device)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).view(N, 1)
        negative_samples = sim[mask].view(N, -1)

        # The positive logit sits in column 0, so every target class is 0.
        labels = torch.zeros(N, dtype=torch.long, device=sim.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        return loss / N

In [None]:
args_list = ['cifar10', '--train', '--dataset', 'cifar10', '--encode_length', '16', '--cuda', '--device', '0', '--epochs', '200']

def main(argv=None):
    """Parse hyper-parameters, then train a CIBHash model or load a saved one.

    Args:
        argv: list of command-line tokens; defaults to the module-level
            ``args_list``.  A fixed list is used because inside a notebook
            ``sys.argv`` holds the kernel's launch arguments.
    """
    if argv is None:
        argv = args_list

    # Parse once with the model's own parser.  The former ad-hoc parser here
    # duplicated its options with diverging definitions (e.g. --device as str)
    # and rejected short flags such as -t/-w/-d/-l/-e.
    hparams = CIBHash.get_model_specific_argparser().parse_args(argv)

    print(f"Model Path: {hparams.model_path}")
    if hparams.train:
        print("Training mode activated")
    print(f"Dataset: {hparams.dataset}")
    print(f"Encode Length: {hparams.encode_length}")
    if hparams.cuda:
        print("Using CUDA")

    # Select the CUDA device before the model is built.
    if hparams.cuda:
        torch.cuda.set_device(hparams.device)

    # Initialise the model.
    model = CIBHash(hparams)

    # Train from scratch, or restore a saved model and report its settings.
    if hparams.train:
        model.run_training_sessions()
    else:
        model.load()
        print('Loaded model with: %s' % model.flag_hparams())


if __name__ == '__main__':
    main()

Model Path: cifar10
Training mode activated
Dataset: cifar10
Encode Length: 16
Using CUDA
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 90476740.93it/s] 


Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10/


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 175MB/s]


hparams: cifar10 --train --dataset cifar10 --encode_length 16 --lr 0.001 --batch_size 64 --epochs 200 --cuda --num_bad_epochs 6 --validate_frequency 20 --seed 8888 --device 0 --temperature 0.3 --weight 0.001




End of epoch   1 | loss   3.6298  | contra_loss   3.6266  | kl_loss   3.2314
End of epoch   2 | loss   3.4592  | contra_loss   3.4554  | kl_loss   3.7798
End of epoch   3 | loss   3.4102  | contra_loss   3.4067  | kl_loss   3.5199
End of epoch   4 | loss   3.3783  | contra_loss   3.3747  | kl_loss   3.6317
End of epoch   5 | loss   3.3549  | contra_loss   3.3515  | kl_loss   3.4163
End of epoch   6 | loss   3.3403  | contra_loss   3.3369  | kl_loss   3.4887
End of epoch   7 | loss   3.3181  | contra_loss   3.3146  | kl_loss   3.4518
End of epoch   8 | loss   3.2981  | contra_loss   3.2947  | kl_loss   3.4474
End of epoch   9 | loss   3.2762  | contra_loss   3.2728  | kl_loss   3.4128
End of epoch  10 | loss   3.3014  | contra_loss   3.2979  | kl_loss   3.5601
End of epoch  11 | loss   3.2925  | contra_loss   3.2891  | kl_loss   3.4637
End of epoch  12 | loss   3.2551  | contra_loss   3.2518  | kl_loss   3.3701
End of epoch  13 | loss   3.2611  | contra_loss   3.2578  | kl_loss   3.3550