In [None]:

import torch

from at2vec import ctx, sp2id, EncoderDecoder
# Configuration
if __name__ == '__main__':
    # Everything visible in this notebook is evaluation, so gradient tracking is
    # switched off globally.
    torch.autograd.set_grad_enabled(False)
    
    # Training-set parameters: the spatial bounding box and time range from
    # which the grid sizes below are derived.
    ctx.min_x = -35.00002
    ctx.max_x = 44.999763
    ctx.min_y = 110.00003
    ctx.max_y = 119.999954
    # NOTE(review): the timestamps look like epoch milliseconds — confirm.
    ctx.min_ts = 1200083742000
    ctx.max_ts = 1249975176000
    # Number of ts_gap-wide temporal cells needed to cover [min_ts, max_ts].
    ctx.num_ts_grids = (ctx.max_ts - ctx.min_ts) // ctx.ts_gap + 1
    # Extent of one spatial cell along x and y when the bounding box is split
    # into num_x_grids by num_y_grids cells.
    ctx.x_gap, ctx.y_gap = ((ctx.max_x - ctx.min_x) / ctx.num_x_grids,
                            (ctx.max_y - ctx.min_y) / ctx.num_y_grids)
    # sp2id is evaluated at the (max_x, max_y) corner of the box. Presumably it
    # returns the cell id of a point, so this uses the largest id as the cell
    # count — confirm against at2vec.sp2id (possible off-by-one).
    ctx.num_sp_grids = sp2id(ctx.max_x, ctx.max_y,
                             ctx.min_x, ctx.min_y,
                             ctx.max_x, ctx.max_y,
                             ctx.x_gap, ctx.y_gap)
    
    # Input files, log file and batch size used by the search run below.
    ctx.test_tr_path = 'data/geolife-speed-4'
    ctx.query_tr_path = 'data/geolife-speed-4-queries'
    ctx.keywords_path = 'data/keywords.txt'
    ctx.logging_path = 'data/test-time.log'
    ctx.test_batch_size = 128

In [None]:
import pandas as pd


class ChunkedTestDataset(torch.utils.data.IterableDataset):
    """Stream trajectories from a tab-separated file, one chunk at a time.

    The file is read in chunks of ``ctx.test_batch_size * tr_len`` rows. Each
    chunk is cut into consecutive ``tr_len``-row slices (one slice per
    trajectory) and every slice is converted with ``raw2tr`` in a worker pool.

    Args:
        path: Path of the headerless TSV file; only columns 0-4 are read.
        tr_len: Number of consecutive rows that make up one trajectory.
        raw2tr: Callable applied to each ``tr_len``-row DataFrame slice. It is
            handed to ``Pool.map``, so it must be picklable.

    Yields:
        ``(index, vector)`` pairs, where ``index`` is column 0 of the slice's
        first row and ``vector`` is ``raw2tr(slice)``.
    """

    def __init__(self, path: str, tr_len: int, raw2tr):
        self.path = path
        self.tr_len = tr_len
        self.raw2tr = raw2tr

    def __iter__(self):
        # Imported locally so the class does not rely on a later notebook cell
        # having already imported Pool.
        from torch.multiprocessing import Pool

        # Size the chunks with the same stride that is used to slice them, so a
        # chunk boundary can never fall in the middle of a trajectory.
        rows_per_chunk = ctx.test_batch_size * self.tr_len
        with pd.read_csv(self.path, sep="\t",
                         chunksize=rows_per_chunk,
                         usecols=[0, 1, 2, 3, 4], header=None) as reader:
            for chunk in reader:
                starts = range(0, chunk.shape[0], self.tr_len)
                raws = [chunk.iloc[s:s + self.tr_len] for s in starts]
                indexes = [chunk.iloc[s, 0] for s in starts]
                # The pool is created per chunk on purpose: no worker processes
                # stay alive while the generator is suspended at a yield (for
                # example when the consumer only takes a few items).
                with Pool() as p:
                    vectors = p.map(self.raw2tr, raws)
                yield from zip(indexes, vectors)

In [None]:
import torch
from torch.multiprocessing import Pool, set_start_method

# if __name__ == '__main__':
#     set_start_method('spawn')

def f(data):
    """Convert one dataset item to a CPU tensor.

    Element 1 of ``data`` is passed to the module-level ``raw2tr`` callable and
    the result is moved to the CPU. Kept at module level so it can be used as
    the function argument of ``Pool.map``.
    """
    raw = data[1]
    converted = raw2tr(raw)
    return converted.cpu()

# Conversion callable used inside pool workers; set once per worker process by
# _tr_worker_init.
_tr_worker_raw2tr = None


def _tr_worker_init(raw2tr):
    """Remember the conversion callable in the current worker process."""
    global _tr_worker_raw2tr
    _tr_worker_raw2tr = raw2tr


def _tr_worker_convert(item):
    """Convert element 1 of a dataset item and move the result to the CPU."""
    return _tr_worker_raw2tr(item[1]).cpu()


class TestDataset(torch.utils.data.Dataset):
    """In-memory dataset of trajectories converted up front in a worker pool.

    Args:
        bare_dataset: Iterable of items whose element 1 is the raw trajectory.
        raw2tr: Callable converting a raw trajectory to a tensor. It is
            installed once per worker through the pool initializer, so the
            argument is actually honoured and is not re-sent with every task.
        processes: Number of worker processes; ``None`` (the default) lets the
            pool pick, matching the previous behaviour.
    """

    def __init__(self, bare_dataset, raw2tr, processes=None):
        # Imported locally so the class does not rely on another notebook cell
        # having already imported Pool.
        from torch.multiprocessing import Pool

        with Pool(processes, initializer=_tr_worker_init,
                  initargs=(raw2tr,)) as p:
            self.vectors = p.map(_tr_worker_convert, bare_dataset)

    def __len__(self):
        return len(self.vectors)

    def __getitem__(self, index):
        """
        Returns:
            (index, tr): the requested position and the converted trajectory
            stored at that position.
        """
        tr = self.vectors[index]
        return index, tr

In [None]:

from gensim.models import Word2Vec
from functools import partial
from at2vec import get_mat, PretrainModel, BareDataset, TrajectoryDataset, EncoderDecoder
import torch


if __name__ == '__main__':
    # Prepare models and data.
    # map_location keeps every checkpoint loadable regardless of the device it
    # was saved from: the two pretrained grid models live on the CPU, the
    # encoder-decoder on ctx.device.
    sp_model = PretrainModel(ctx.num_sp_grids, ctx.sp_len, torch.device('cpu'))
    sp_model.load_state_dict(
        torch.load(ctx.sp_pretrain_model_path, map_location=torch.device('cpu'))['model'])
    ts_model = PretrainModel(ctx.num_ts_grids, ctx.ts_len, torch.device('cpu'))
    ts_model.load_state_dict(
        torch.load(ctx.ts_pretrain_model_path, map_location=torch.device('cpu'))['model'])
    sm_model = Word2Vec.load(ctx.sm_pretrain_model_path)
    
    model = EncoderDecoder(ctx.sampled_tr_len, ctx.complete_tr_len, ctx.pt_len, ctx.hidden_len,
                           ctx.num_sp_grids, ctx.num_ts_grids, len(sm_model.wv), ctx.device)
    state = torch.load(ctx.at2vec_model_path, map_location=ctx.device)
    model.load_state_dict(state['model'])

    # The notebook only runs inference: put the torch modules in eval mode so
    # that any dropout / batch-norm layers behave deterministically.
    sp_model.eval()
    ts_model.eval()
    model.eval()
    
    # Converter from a raw trajectory slice to its matrix form, with the three
    # pretrained models bound in.
    raw2tr = partial(get_mat, sp_model=sp_model, ts_model=ts_model, sm_model=sm_model)

In [None]:
def get_memory_usage(model):
    """Return the number of bytes held by a module's parameters and buffers.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterators of
            tensors (e.g. a ``torch.nn.Module``).

    Returns:
        Total size in bytes. Buffers are included: they were previously summed
        but left out of the returned value.
    """
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    return param_size + buffer_size  # in bytes

# Report get_memory_usage for the two pretrained grid models next to gensim's
# own estimate for the word2vec model, then for the encoder-decoder model.
print(
    get_memory_usage(sp_model),
    get_memory_usage(ts_model),
    sm_model.estimate_memory(),
)
print(
    get_memory_usage(model),
)

In [None]:
if __name__ == '__main__':
    # Switch torch.multiprocessing's tensor-sharing strategy to 'file_system'.
    # NOTE(review): presumably chosen because the worker pools hand many
    # tensors back to this process and the default strategy ran out of file
    # descriptors — confirm.
    torch.multiprocessing.set_sharing_strategy('file_system')

In [None]:
def get_accuracy(index: int, results, group_size: int = 50):
    """Fraction of results that fall in the same index group as the query.

    Indexes are partitioned into consecutive groups of ``group_size``
    (0..group_size-1, group_size..2*group_size-1, ...). A result counts as a
    hit when its index lies in the group that contains ``index``.

    Args:
        index: Index of the query.
        results: Sized collection of ``(distance, idx)`` pairs; only ``idx``
            is inspected.
        group_size: Number of consecutive indexes per group (default 50,
            the value that used to be hard-coded).

    Returns:
        ``hits / len(results)``, or ``0.0`` when ``results`` is empty (this
        used to raise ``ZeroDivisionError``).
    """
    if len(results) == 0:
        return 0.0
    min_accepted_idx = index // group_size * group_size
    max_accepted_idx = min_accepted_idx + group_size - 1
    hit_count = sum(1 for _, idx in results
                    if min_accepted_idx <= idx <= max_accepted_idx)
    return hit_count / len(results)

In [None]:
import logging
import random
import heapq
from datetime import datetime
import time
import torch
from tqdm import tqdm
import math

if __name__ == '__main__':
    test_dataset = ChunkedTestDataset(ctx.test_tr_path, ctx.complete_tr_len, raw2tr)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=ctx.test_batch_size)

    # Per-query settings: query j searches trajectories with index below
    # dataset_range[j] and keeps its ks[j] nearest neighbours.
    num_queries = 8
    # dataset_range = [4096] * num_queries
    dataset_range = [1876550, 1876550, 1876550, 1876550, 469150, 938300, 1407450, 1876550]
    max_dataset_range = max(dataset_range)
    ks = [1, 5, 10, 20, 50, 50, 50, 50]
    assert len(dataset_range) == len(ks) == num_queries

    query_dataset = ChunkedTestDataset(ctx.query_tr_path, ctx.complete_tr_len, raw2tr)

    logging.basicConfig(filename=ctx.logging_path, format='%(message)s', level=logging.INFO)
    logging.info(str(datetime.now()))

    chosen_indexes = []
    # One independent heap per query. `[[]] * num_queries` would repeat a
    # reference to a single list, making every query share the same heap.
    queues = [[] for _ in range(num_queries)]
    chosen_vecs = []

    # Encode the first num_queries trajectories of the query file.
    query_it = iter(query_dataset)
    for test_num in range(num_queries):
        chosen_idx, chosen_tr = next(query_it)
        chosen_indexes.append(int(chosen_idx))
        chosen_tr = chosen_tr.to(ctx.device)
        # chosen_vec: (1, hidden_len)
        chosen_vec = model.get_rep_vector(chosen_tr).unsqueeze(0)
        chosen_vecs.append(chosen_vec)

    time_move = 0.0
    time_search = [0.0] * num_queries  # seconds

    batch_count = math.ceil(max_dataset_range / ctx.test_batch_size)
    start_time = datetime.now()
    for i, (indexes, trs) in enumerate(test_dataloader):
        # Stop once a whole batch lies beyond the largest search range
        # (this relies on the indexes in the file being ascending).
        min_index = int(torch.min(indexes))
        if min_index >= max_dataset_range:
            break

        start = time.time()
        trs = trs.to(ctx.device)
        end = time.time()
        time_move += end - start

        batch_indexes = indexes.tolist()
        for j, chosen_vec in enumerate(chosen_vecs):
            if min_index >= dataset_range[j]:
                continue
            k = ks[j]
            queue = queues[j]
            start = time.time()
            # indexes: (batch_size)
            # trs: (batch_size, tr_len, pt_len)
            # vecs: (batch_size, hidden_len)
            # NOTE(review): vecs depends only on the batch, yet it is encoded
            # again for every query. It is left inside the timing window so
            # that time_search keeps measuring the same thing — confirm intent
            # before hoisting it out of this loop.
            vecs = model.get_rep_vector(trs)
            if len(vecs.shape) == 1:
                dists = torch.dist(chosen_vec.squeeze(), vecs.squeeze()).unsqueeze(0)
            else:
                # reshape(-1) always yields a 1-D tensor; a bare squeeze()
                # collapses a batch of one to a 0-d tensor that cannot be indexed.
                dists = torch.cdist(chosen_vec, vecs).reshape(-1)
            # Max-heap on distance (negated): the root is the farthest of the
            # current k nearest neighbours and is the one evicted.
            for dist, idx in zip(dists.tolist(), batch_indexes):
                # A batch may straddle this query's range; skip the part of it
                # that lies beyond the range.
                if idx >= dataset_range[j]:
                    continue
                if len(queue) < k:
                    heapq.heappush(queue, (-dist, idx))
                else:
                    heapq.heappushpop(queue, (-dist, idx))
            end = time.time()
            time_search[j] += end - start
        elapsed_time = datetime.now() - start_time
        processed = (i + 1) * ctx.test_batch_size
        # Clamp at zero: the last batch can overshoot max_dataset_range.
        remaining_time = elapsed_time / processed * max(max_dataset_range - processed, 0)
        print(f'Batch {i+1}/{batch_count}, {str(elapsed_time).split(".")[0]} < {str(remaining_time).split(".")[0]}')

    print('Results:')
    for j, queue in enumerate(queues):
        print(f'Query #{chosen_indexes[j]}')
        # Undo the negation and list neighbours from nearest to farthest.
        results = sorted((-neg_dist, idx) for neg_dist, idx in queue)
        print(f'accuracy: {get_accuracy(chosen_indexes[j], results)}')
    print(f'time move: {time_move}')
    print(f'time search: {time_search}')