In [None]:
import click
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST
from tensorboardX import SummaryWriter
import uuid

from ptdec.dec import DEC
# from ptdec.model import train, predict
from ptsdae.sdae import StackedDenoisingAutoEncoder
import ptsdae.model as ae
from ptdec.utils import cluster_accuracy

from sklearn.cluster import KMeans

In [None]:
def mmap_fvecs(fname):
    """Memory-map an .fvecs file and return its vectors as a float32 matrix.

    The first int32 of the file is read as the vector dimension ``d``; the
    file is then viewed as rows of ``d + 1`` 4-byte words and the leading
    (dimension) word of every row is dropped.

    :param fname: path of the .fvecs file
    :return: read-only memory-mapped view of shape (num_vectors, d), dtype float32
    """
    raw_words = np.memmap(fname, dtype='int32', mode='r')
    dim = raw_words[0]
    records = raw_words.view('float32').reshape(-1, dim + 1)
    # Column 0 holds the per-record dimension header, not vector data.
    return records[:, 1:]

def mmap_bvecs(fname):
    """Memory-map a .bvecs file and return its vectors as a uint8 matrix.

    The first four bytes of the file are interpreted as an int32 dimension
    ``d``; the file is then viewed as rows of ``d + 4`` bytes and the 4-byte
    header of every row is dropped.

    :param fname: path of the .bvecs file
    :return: read-only memory-mapped view of shape (num_vectors, d), dtype uint8
    """
    raw_bytes = np.memmap(fname, dtype='uint8', mode='r')
    dim = raw_bytes[:4].view('int32')[0]
    records = raw_bytes.reshape(-1, dim + 4)
    # The first 4 bytes of each record are the dimension header.
    return records[:, 4:]

def ivecs_read(fname):
    """Load an .ivecs file fully into memory as an int32 matrix.

    The first int32 is read as the row length ``d``; the file is reshaped to
    rows of ``d + 1`` int32 values and the leading length column is removed.

    Wenqi: format of the ground truth (for 10000 query vectors):
        1000(topK), [1000 ids]
        1000(topK), [1000 ids]
             ...     ...
        1000(topK), [1000 ids]
    10000 rows in total, 10000 * 1001 elements, 10000 * 1001 * 4 bytes.

    :param fname: path of the .ivecs file
    :return: newly allocated array of shape (num_rows, d), dtype int32
    """
    flat = np.fromfile(fname, dtype='int32')
    row_len = flat[0]
    table = flat.reshape(-1, row_len + 1)
    # Copy so the result is contiguous and independent of the raw buffer.
    return table[:, 1:].copy()

def fvecs_read(fname):
    """Load an .fvecs file into memory by reinterpreting the ivecs payload as float32.

    :param fname: path of the .fvecs file
    :return: array of dtype float32 with the same shape ``ivecs_read`` returns
    """
    int_rows = ivecs_read(fname)
    # Same bytes, viewed as float32 instead of int32 (no conversion of values).
    return int_rows.view('float32')

In [None]:
dbname = 'SIFT1M'
num_vec_learn = int(1e4)  # number of leading base vectors used as the training set

# Directory holding the BIGANN files (machine-specific; adjust for your environment).
bigann_dir = '/mnt/scratch/wenqi/Faiss_experiments/bigann'

if not dbname.startswith('SIFT'):
    # Fail loudly with a normal exception: the original branch called sys.exit()
    # without `sys` ever being imported, which surfaced as a NameError.
    raise ValueError('unknown dataset %s' % dbname)

# SIFT1M to SIFT1000M: the number between "SIFT" and the trailing "M" is the size in millions.
dbsize = int(dbname[4:-1])
xb = mmap_bvecs(bigann_dir + '/bigann_base.bvecs')
xq = mmap_bvecs(bigann_dir + '/bigann_query.bvecs')
gt = ivecs_read(bigann_dir + '/gnd/idx_%dM.ivecs' % dbsize)

N_VEC = int(dbsize * 1000 * 1000)

# trim xb to correct size
xb = xb[:N_VEC]

# Wenqi: load xq / xb to main memory as float32.
# astype() already allocates a fresh in-memory array, so no extra .copy() is needed.
xq = xq.astype('float32')
xb = xb.astype('float32')
gt = np.asarray(gt, dtype=np.int32)

print("Vector shapes:")
print("Base vector xb: ", xb.shape)
print("Query vector xq: ", xq.shape)
print("Ground truth gt: ", gt.shape)

dim = xb.shape[1] # should be 128
nq = xq.shape[0]

# Scale the byte-valued components into [0, 1) (uint8 max 255 / 256).
# Done in place on the freshly allocated arrays to avoid another full-size copy.
xb /= 256
xq /= 256
xt = xb[:num_vec_learn]  # training subset: a view of the first num_vec_learn base vectors

In [None]:
# TensorBoard logger shared by the notebook. No arguments are passed, so the log
# location is whatever tensorboardX chooses by default (TODO confirm where that is).
writer = SummaryWriter()  # create the TensorBoard object
# The training callback defined below reads `writer` from this module-level scope.

def training_callback(epoch, lr, loss, validation_loss):
    """Log one training step to TensorBoard under the "data/autoencoder" group.

    Uses the module-level ``writer``.

    :param epoch: value used as the global step of the logged scalars
    :param lr: current learning rate
    :param loss: training loss
    :param validation_loss: validation loss
    :return: None
    """
    scalars = {"lr": lr, "loss": loss, "validation_loss": validation_loss}
    writer.add_scalars("data/autoencoder", scalars, epoch)

In [None]:
# Run configuration shared by the training and prediction cells below.
cuda = False            # move models / batches to the GPU when True
batch_size = 256        # mini-batch size for autoencoder pretraining and fine-tuning
pretrain_epochs = 100   # epochs passed to ae.pretrain
finetune_epochs = 100   # epochs passed to ae.train
testing_mode = False    # whether to run in testing mode (default False).

In [None]:
# Layer widths handed to the stacked denoising autoencoder.
# First entry 128 matches the SIFT vector dimension; last entry 32 is the embedding size.
# Smaller alternative: DNN_arch = [128, 100, 100, 64, 32]
DNN_arch = [128, 500, 500, 2000, 32]

autoencoder = StackedDenoisingAutoEncoder(DNN_arch, final_activation=None)

# Move the model to the GPU only when requested in the configuration cell.
if cuda:
    autoencoder.cuda()

In [None]:
# from torch.utils.data import Dataset
# https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset
# A blog about torch dataset: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel

class SIFTDataset(Dataset):
    """Unlabelled PyTorch dataset wrapping a 2-D collection of vectors.

    Each sample is one row of ``vectors``; no label is returned.
    """

    def __init__(self, vectors):
        """Store the vectors (dim 0: number of vectors, dim 1: vector dimension)."""
        self.vectors = vectors

    def __len__(self):
        """Return the number of vectors, i.e. the size of dimension 0."""
        num_vectors = self.vectors.shape[0]
        return num_vectors

    def __getitem__(self, index):
        """Return the vector stored at position ``index``."""
        sample = self.vectors[index]
        return sample

In [None]:
# Training set: the first num_vec_learn base vectors (xt).
ds_train = SIFTDataset(torch.Tensor(xt))

# Validation set: the next num_vec_learn base vectors, disjoint from the training slice.
ds_val = SIFTDataset(torch.Tensor(xb[num_vec_learn:2 * num_vec_learn]))

In [None]:
# Pretraining of the stacked denoising autoencoder.
# ae.pretrain takes a torch dataset as input:
#   SDAE model: https://github.com/vlukiyanov/pt-sdae/blob/master/ptsdae/model.py
#   Torch dataset manual: https://pytorch.org/docs/stable/data.html
# The optimizer and scheduler are passed as factories (callables), not instances.
print("Pretraining stage.")
ae.pretrain(
    ds_train,
    autoencoder,
    epochs=pretrain_epochs,
    batch_size=batch_size,
    validation=ds_val,
    corruption=0.2,
    optimizer=lambda module: SGD(module.parameters(), lr=0.1, momentum=0.9),
    scheduler=lambda opt: StepLR(opt, step_size=100, gamma=0.1),
    cuda=cuda,
)

In [None]:
# Fine-tuning of the whole autoencoder; progress is logged through training_callback.
print("Training stage.")
ae_optimizer = SGD(autoencoder.parameters(), lr=0.1, momentum=0.9)
ae.train(
    ds_train,
    autoencoder,
    epochs=finetune_epochs,
    batch_size=batch_size,
    validation=ds_val,
    corruption=0.2,
    optimizer=ae_optimizer,
    scheduler=StepLR(ae_optimizer, step_size=100, gamma=0.1),
    update_callback=training_callback,
    cuda=cuda,
)

In [None]:
# Save the SAE model
import pickle
import os

def save_obj(obj, dirc, name):
    """Pickle ``obj`` to ``<dirc>/<name>.pkl`` using pickle protocol 4.

    :param obj: any picklable object
    :param dirc: destination directory (must already exist)
    :param name: file name without the ".pkl" extension
    :return: None
    """
    target_path = os.path.join(dirc, name + '.pkl')
    with open(target_path, 'wb') as handle:
        # Protocol pinned to 4 (pickle.HIGHEST_PROTOCOL on Python 3.7).
        pickle.dump(obj, handle, protocol=4)

def load_obj(dirc, name):
    """Load and return the object pickled at ``<dirc>/<name>.pkl``.

    NOTE: unpickling can execute arbitrary code; only load files you trust.

    :param dirc: directory containing the pickle file
    :param name: file name without the ".pkl" extension
    :return: the unpickled object
    """
    source_path = os.path.join(dirc, name + '.pkl')
    with open(source_path, 'rb') as handle:
        restored = pickle.load(handle)
    return restored

In [None]:
# Persist the trained stacked autoencoder under ./models/.
fdir = './models/'
os.makedirs(fdir, exist_ok=True)

# File name encodes the architecture and the epoch counts, e.g.
# SAE_128_500_500_2000_32_epoch_100_100
DNN_arch_name = ''.join('{}_'.format(width) for width in DNN_arch)
file_name = 'SAE_' + DNN_arch_name + 'epoch_{}_{}'.format(pretrain_epochs, finetune_epochs)
print('file_name', file_name)

# Save the autoencoder itself: `model` (the DEC model) does not exist yet at this
# point of a top-to-bottom run, and it is not the SAE this file name describes.
save_obj(autoencoder, fdir, file_name)

In [None]:
from torch.utils.data.dataloader import DataLoader, default_collate
from typing import Tuple, Callable, Optional, Union
from tqdm import tqdm
from ptdec.utils import target_distribution, cluster_accuracy

def train(
    dataset: torch.utils.data.Dataset,
    model: torch.nn.Module,
    epochs: int,
    batch_size: int,
    optimizer: torch.optim.Optimizer,
    stopping_delta: Optional[float] = None,
    collate_fn=default_collate,
    cuda: bool = True,
    sampler: Optional[torch.utils.data.sampler.Sampler] = None,
    silent: bool = False,
    update_freq: int = 10,
    evaluate_batch_size: int = 1024,
    update_callback: Optional[Callable[[float, float], None]] = None,
    epoch_callback: Optional[Callable[[int, torch.nn.Module], None]] = None,
) -> None:
    """
    Train the DEC model given a dataset, a model instance and various configuration parameters.

    This is a label-free variant of the ptdec training loop: cluster accuracy is
    never computed, so any labels found in a batch are discarded.

    Steps:
      1. Encode the whole dataset with ``model.encoder`` and run k-means on the
         embeddings to initialise ``assignment.cluster_centers``.
      2. For each epoch, minimise the KL divergence between the soft assignments
         and their sharpened target distribution, then re-predict all labels and
         measure the fraction that changed (``delta_label``).

    :param dataset: instance of Dataset to use for training
    :param model: instance of DEC model to train; must expose ``cluster_number``,
        ``encoder`` and a state-dict entry ``assignment.cluster_centers``
    :param epochs: number of training epochs
    :param batch_size: size of the batch to train with
    :param optimizer: instance of optimizer to use
    :param stopping_delta: label delta as a proportion to use for stopping, None to disable, default None
    :param collate_fn: function to merge a list of samples into mini-batch
    :param cuda: whether to use CUDA, defaults to True
    :param sampler: optional sampler to use in the DataLoader, defaults to None;
        when given, it replaces shuffling in the training loader
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param update_freq: kept for interface compatibility; the progress bar is
        refreshed on every batch, so this value currently has no effect
    :param evaluate_batch_size: batch size for evaluation stage, default 1024
    :param update_callback: kept for interface compatibility; currently never
        invoked because accuracy is not available without labels
    :param epoch_callback: optional function of epoch and model, default None
    :return: None
    """
    static_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        pin_memory=False,
        sampler=sampler,
        shuffle=False,
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        sampler=sampler,
        # DataLoader rejects shuffle=True together with a sampler, so only
        # shuffle when the caller did not supply one.
        shuffle=sampler is None,
    )
    data_iterator = tqdm(
        static_dataloader,
        leave=True,
        unit="batch",
        postfix={
            "epo": -1,
            "lss": "%.8f" % 0.0,
            "dlb": "%.4f" % -1,
        },
        disable=silent,
    )
    kmeans = KMeans(n_clusters=model.cluster_number, n_init=20)
    model.train()
    features = []
    # form initial cluster centres
    with torch.no_grad():  # embeddings only feed k-means; no graph is needed
        for batch in data_iterator:
            if isinstance(batch, (tuple, list)) and len(batch) == 2:
                batch, _ = batch  # if we have a prediction label, strip it away
            if cuda:
                batch = batch.cuda(non_blocking=True)
            features.append(model.encoder(batch).detach().cpu())
    predicted = kmeans.fit_predict(torch.cat(features).numpy())
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    cluster_centers = torch.tensor(
        kmeans.cluster_centers_, dtype=torch.float, requires_grad=True
    )
    if cuda:
        cluster_centers = cluster_centers.cuda(non_blocking=True)
    with torch.no_grad():
        # initialise the cluster centers
        model.state_dict()["assignment.cluster_centers"].copy_(cluster_centers)
    # reduction="sum" is the non-deprecated spelling of size_average=False;
    # the sum is divided by the batch size below.
    loss_function = nn.KLDivLoss(reduction="sum")
    delta_label = None
    for epoch in range(epochs):
        data_iterator = tqdm(
            train_dataloader,
            leave=True,
            unit="batch",
            postfix={
                "epo": epoch,
                "lss": "%.8f" % 0.0,
                "dlb": "%.4f" % (delta_label or 0.0),
            },
            disable=silent,
        )
        model.train()
        for batch in data_iterator:
            if isinstance(batch, (tuple, list)) and len(batch) == 2:
                batch, _ = batch  # if we have a prediction label, strip it away
            if cuda:
                batch = batch.cuda(non_blocking=True)
            output = model(batch)
            # The target is treated as a constant: no gradient flows through it.
            target = target_distribution(output).detach()
            loss = loss_function(output.log(), target) / output.shape[0]
            data_iterator.set_postfix(
                epo=epoch,
                lss="%.8f" % float(loss.item()),
                dlb="%.4f" % (delta_label or 0.0),
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step(closure=None)
        # Re-assign every sample (in dataset order) to measure label churn.
        predicted = predict(
            dataset,
            model,
            batch_size=evaluate_batch_size,
            collate_fn=collate_fn,
            silent=True,
            return_actual=False,
            cuda=cuda,
        )
        delta_label = (
            float((predicted != predicted_previous).float().sum().item())
            / predicted_previous.shape[0]
        )
        if stopping_delta is not None and delta_label < stopping_delta:
            print(
                'Early stopping as label delta "%1.5f" less than "%1.5f".'
                % (delta_label, stopping_delta)
            )
            break
        predicted_previous = predicted
        data_iterator.set_postfix(
            epo=epoch,
            lss="%.8f" % 0.0,
            dlb="%.4f" % (delta_label or 0.0),
        )
        if epoch_callback is not None:
            epoch_callback(epoch, model)


def predict(
    dataset: torch.utils.data.Dataset,
    model: torch.nn.Module,
    batch_size: int = 1024,
    collate_fn=default_collate,
    cuda: bool = True,
    silent: bool = False,
    return_actual: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Predict clusters for a dataset given a DEC model instance and various configuration parameters.

    The model is switched to eval mode and applied batch by batch in dataset
    order (no shuffling); the predicted cluster of a sample is the argmax of the
    model output along dimension 1.

    :param dataset: instance of Dataset to predict on
    :param model: instance of DEC model to predict
    :param batch_size: size of the batch to predict with, default 1024
    :param collate_fn: function to merge a list of samples into mini-batch
    :param cuda: whether CUDA is used, defaults to True
    :param silent: set to True to prevent printing out summary statistics, defaults to False
    :param return_actual: return actual values, if present in the Dataset
    :return: tuple of prediction and actual if return_actual is True otherwise prediction
    :raises ValueError: if return_actual is set but a batch carries no label
    """
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False
    )
    data_iterator = tqdm(dataloader, leave=True, unit="batch", disable=silent,)
    features = []
    actual = []
    model.eval()
    # Inference only: skip autograd bookkeeping instead of building a graph
    # for every batch and discarding it with detach().
    with torch.no_grad():
        for batch in data_iterator:
            if isinstance(batch, (tuple, list)) and len(batch) == 2:
                batch, value = batch  # unpack if we have a prediction label
                if return_actual:
                    actual.append(value)
            elif return_actual:
                raise ValueError(
                    "Dataset has no actual value to unpack, but return_actual is set."
                )
            if cuda:
                batch = batch.cuda(non_blocking=True)
            features.append(
                model(batch).detach().cpu()
            )  # move to the CPU to prevent out of memory on the GPU
    predicted = torch.cat(features).max(1)[1]
    if return_actual:
        return predicted, torch.cat(actual).long()
    return predicted

In [None]:
print("Note: should pretrain & train the models for enough epochs (e.g., 100), " \
      "otherwise this model can predict all vectors to a same centroid which lead to zero gradients")

print("DEC stage.")
# The DEC hidden dimension must equal the encoder's output width (last entry of DNN_arch).
model = DEC(cluster_number=32, hidden_dimension=DNN_arch[-1], encoder=autoencoder.encoder)
if cuda:
    model.cuda()


# Learning rate: the MNIST example (inputs ranging up to ~5.x) uses 0.01;
# our inputs are scaled to 0~1, so a smaller rate of 0.001 is tried here.
# SGD alternative: dec_optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
# Referenced through `torch.optim` because only SGD is imported by name from it.
dec_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data)

In [None]:
# Run DEC clustering on the training subset; stop early once fewer than
# 1e-6 of the samples change cluster between two consecutive epochs.
train(
    dataset=ds_train,
    model=model,
    optimizer=dec_optimizer,
    epochs=100,
    batch_size=256,
    stopping_delta=1e-6,
    cuda=cuda,
)

In [None]:
# Dataset over every base vector. ds_train is deliberately left untouched so it
# keeps referring to the training subset (the chained assignment used to
# overwrite it with the full base set).
ds_all = SIFTDataset(torch.Tensor(xb))

In [None]:
# Observe unbalance factor on the training set

partition_IDs_train = predict(
    ds_train, model, batch_size=1024, silent=True, return_actual=False, cuda=cuda
).cpu().numpy()

# Create a mapping: partition ID -> {list of vector IDs}

num_partition = 32

partition_id_vec_id_list_train = {pid: [] for pid in range(num_partition)}

# Only the first num_vec_learn vectors are training vectors. Read them from
# partition_IDs_train (computed above), not from partition_IDs, which is only
# defined by a later cell.
for i in range(num_vec_learn):
    partition_ID = int(partition_IDs_train[i])
    partition_id_vec_id_list_train[partition_ID].append(i)

for i in range(num_partition):
    print('items in partition ', i, len(partition_id_vec_id_list_train[i]), 'average =', int(num_vec_learn/num_partition))

In [None]:
# Assign every base vector to a partition (cluster) with the trained DEC model.
predicted = predict(
    ds_all,
    model,
    batch_size=1024,
    cuda=cuda,
    silent=True,
    return_actual=False,
)
# NumPy copy of the assignments, indexed by base-vector ID.
partition_IDs = predicted.cpu().numpy()

In [None]:
# Show the predicted partition IDs as the cell output (one entry per vector in ds_all).
partition_IDs

In [None]:
# Create a mapping: partition ID -> {list of vector IDs}

num_partition = 32

# Use the number of vectors that were actually predicted rather than a
# hard-coded 1e6, so the cell also works for datasets other than SIFT1M.
num_vec_total = len(partition_IDs)

partition_id_vec_id_list_1M = {pid: [] for pid in range(num_partition)}

# tolist() yields plain Python ints, which are valid dict keys and list items;
# vector IDs are appended in ascending order.
for vec_id, partition_ID in enumerate(partition_IDs.tolist()):
    partition_id_vec_id_list_1M[partition_ID].append(vec_id)

for i in range(num_partition):
    print('items in partition ', i, len(partition_id_vec_id_list_1M[i]), 'average =', int(num_vec_total/num_partition))

In [None]:
import heapq

def scan_partition(query_vec, partition_id_list, vector_set, chunk_size=65536):
    """
    Exhaustively search one partition for the vector closest to the query (L2 distance).

    query_vec = (128, )
    partition_id_list = (N_num_vec, ) IDs (row indices into vector_set) to scan
    vector_set = 1M dataset (1M, 128)

    Distances are computed in vectorised chunks instead of one Python-level
    ``np.linalg.norm`` call per vector. The running minimum starts at +inf, so a
    neighbour is returned however large the distances are (a fixed 1e10 sentinel
    would return None for far-away data).

    :param query_vec: 1-D query vector
    :param partition_id_list: iterable of row indices of ``vector_set`` to scan
    :param vector_set: 2-D array of database vectors, indexable by an integer array
    :param chunk_size: number of candidates processed per vectorised step; bounds
        the size of the temporary arrays
    :return: the ID from ``partition_id_list`` with the smallest distance (the
        last one on ties), or None if the list is empty
    """
    id_list = list(partition_id_list)
    query = np.asarray(query_vec)
    min_dist = np.inf
    min_dist_ID = None
    for start in range(0, len(id_list), chunk_size):
        chunk_ids = np.asarray(id_list[start:start + chunk_size], dtype=np.int64)
        candidates = np.asarray(vector_set[chunk_ids])
        dists = np.linalg.norm(candidates - query, axis=1)
        # argmin on the reversed array selects the LAST minimum of the chunk,
        # matching the `<=` comparison below (later equal distances win).
        pos = len(chunk_ids) - 1 - int(np.argmin(dists[::-1]))
        if dists[pos] <= min_dist:
            min_dist = dists[pos]
            min_dist_ID = id_list[start + pos]

    return min_dist_ID

In [None]:
# Number of query vectors to evaluate (use N = 1000 for a longer run).
N = 100

#### Wenqi: here had a bug: previously xb, now xq
ds_xq = SIFTDataset(torch.Tensor(xq))

# Partition predicted for each query vector by the DEC model.
query_partition = predict(
    ds_xq,
    model,
    batch_size=1024,
    cuda=cuda,
    silent=True,
    return_actual=False,
).cpu().numpy()

# For each of the first N queries, scan only the partition the query was
# assigned to and record the nearest base vector found there.
nearest_neighbors = []
for i in range(N):
    partition_id = int(query_partition[i])
    candidate_ids = partition_id_vec_id_list_1M[partition_id]
    nearest_neighbor_ID = scan_partition(xq[i], candidate_ids, xb)
    nearest_neighbors.append(nearest_neighbor_ID)
    print(i, nearest_neighbor_ID)

In [None]:
# recall@1: fraction of the N evaluated queries whose reported neighbour equals
# the first entry of the ground-truth list for that query.
correct_count = sum(
    1 for i in range(N) if nearest_neighbors[i] == gt[i][0]
)

print(correct_count, 'recall@1 = ', correct_count / N)