In [1]:
import os, json
import faiss
import numpy as np
import torch
from torch.utils.data import Dataset
from path_utils import *
from mmap_dataset import MmapDataset
from knn.pq_wrapper import TorchPQCodec
from dataset import KNNDataset

In [2]:
# Smoke-test construction on the "test" split with a dummy subtoken array.
# Arguments are positional (subtokens, sizes, data_path, split) following the
# KNNDataset signature defined in this notebook; a KNNDataset is also imported
# from `dataset`, so which class is bound here depends on cell execution order.
ds = KNNDataset(
    np.zeros((1000, 3000)),
    [3, 2],
    "data",
    "test",
    k=128,
)

In [4]:
# Sanity check: neighbor_offsets is opened with shape (num_tokens, k); with
# sizes=[3, 2] and k=128 the recorded output is (5, 128).
ds.neighbor_offsets.shape

(5, 128)

In [3]:
# NOTE(review): this cell raised IndexError ("index 5 is out of bounds for axis 0
# with size 5") in the recorded run. The KNNDataset class defined in this notebook
# has no __getitem__ in view, so the failure presumably comes from the imported
# dataset.KNNDataset -- confirm, and fix or remove this cell before sharing.
ds[0]

IndexError: index 5 is out of bounds for axis 0 with size 5

In [25]:
# Take the first two rows of the quantized train features and decode them with
# the target quantizer; the recorded output shape is torch.Size([2, 768]).
# NOTE(review): presumably each row is a PQ code and 768 is the feature
# dimension -- confirm against the quantizer that was trained.
feats = ds.quantize_neighbor_feat[:2]
ds.tgt_quantizer.decode(torch.LongTensor(feats)).shape

torch.Size([2, 768])

In [3]:
class KNNDataset(Dataset):
    """
    Language-model dataset that carries pre-computed KNN information.

    On construction it opens the per-split neighbor offsets, the train-split
    neighbor tokens, the quantized train features and the PQ quantizer index
    found under ``data_path``.

    NOTE(review): no ``__getitem__`` is defined in this cell, so instances of
    this class cannot yet be indexed to produce samples.
    """

    def __init__(self, subtokens, sizes, data_path, split, k):
        """
        Build a LM dataset with KNN info.

        Args:
            subtokens: a list of subtoken ids
            sizes: list of subtokens length
            data_path: path to the data directory
            split: train/valid/test
            k: number of neighbors
        """
        self.dtype = np.float16
        self.invalid_context = 512
        self.k = k
        self.subtokens = subtokens
        self.sizes = np.array(sizes)
        # cum_sizes[i] is the number of tokens that precede file i, so it has
        # len(sizes) + 1 entries and starts with 0.
        self.cum_sizes = np.insert(np.cumsum(self.sizes), 0, 0)
        # Plain Python int so it can be used directly as a memmap dimension.
        self.num_tokens = int(self.sizes.sum())

        # Neighbors always come from the train datastore, whatever `split` is.
        self.neighbor_info = self._read_json(
            os.path.join(dstore_path(data_path, "train"), "info.json"))
        self.info = self._read_json(
            os.path.join(dstore_path(data_path, split), "info.json"))
        self.train_tokens = self.neighbor_info["dstore_size"]
        self.quantize_neighbor_feat = np.load(quantized_feature_path(data_path, "train"))

        # One row of k neighbor offsets per token of this split.
        self.neighbor_offsets = MmapDataset(neighbor_path(data_dir=data_path, mode=split, k=self.k),
                                            dtype=np.int64, shape=(self.num_tokens, self.k),
                                            warmup=False)
        # NOTE(review): another cell in this notebook memory-maps
        # "data/train_dstore/vals.npy" as int32; if value_path() resolves to the
        # same file, the int64 dtype here would be wrong -- confirm.
        self.neighbor_tokens = MmapDataset(value_path(data_dir=data_path, mode="train"),
                                           dtype=np.int64,
                                           shape=(self.train_tokens, 2), warmup=False)
        quantiz_path = quantizer_path(data_dir=data_path, suffix="", norm=False)
        self.tgt_quantizer = TorchPQCodec(index=faiss.read_index(quantiz_path))

        self.init_block()
        self.init_special_tokens()

    @staticmethod
    def _read_json(path):
        """Load a JSON file and close the handle deterministically."""
        with open(path) as handle:
            return json.load(handle)

    def init_block(self, block_size=512):
        """
        Set file index, file offset and sample length.

        Placeholder implementation (the original note reads "Left to zhz"):
        one sample per file, starting at offset 0, with nominal length
        ``block_size``.

        Args:
            block_size: nominal length assigned to every sample (default 512).

        NOTE(review): sample_length is not clipped to the actual file size, so
        it can exceed ``sizes[i]`` for short files -- confirm the consumer
        handles that.
        """
        num_files = len(self.sizes)
        self.file_index = np.arange(num_files)
        self.file_offset = np.zeros(num_files, dtype=np.int64)
        self.sample_length = np.full(num_files, block_size, dtype=np.int64)

    def init_special_tokens(self):
        """Set the ids of the pad / unk / eos special tokens."""
        self.pad = 0
        self.unk = 1
        self.eos = 2

    def double_idx_to_one(self, file_idx, offset):
        """
        Convert a double index to a single index.

        Args:
            file_idx: index of the file (position in ``sizes``).
            offset: token offset inside that file.

        Returns:
            ``cum_sizes[file_idx] + offset``, i.e. the token's position in the
            concatenation of all files.
        """
        return self.cum_sizes[file_idx] + offset

    def __len__(self):
        """
        Number of samples, i.e. the number of entries in ``file_index``.
        """
        return len(self.file_index)

In [41]:
# Memory-map the values file read-only as a (dstore_size, 2) int32 array.
# NOTE(review): dstore_size is hard-coded; KNNDataset reads the same quantity
# from the train info.json ("dstore_size") -- consider loading it from there.
# NOTE(review): np.memmap reads raw bytes, so this assumes the file has no
# .npy header despite its extension -- confirm how it was written.
val_file = "data/train_dstore/vals.npy"
dstore_size = 147600000
vals = np.memmap(val_file, dtype=np.int32, mode="r", shape=(dstore_size, 2))

In [42]:
# Spot-check a single row of the values memmap; the recorded run printed
# [19960, 710] for row 11000.
vals[11000]

memmap([19960,   710], dtype=int32)