## 配置

In [1]:
import os
import torch


class Context:
    """配置项"""
    def __init__(self):
        self.min_x = None
        self.max_x = None
        self.min_y = None
        self.max_y = None
        self.min_ts = None
        self.max_ts = None
    
ctx = Context()

In [2]:
def init_ctx(self: Context):
    # 数据文件夹
    data_dir = 'tdrive-data'

    def get_dir(path: str):
        return os.path.join(data_dir, path)

    # 数据文件
    self.sampled_tr_path = get_dir('tdrive-r-train-ps-40')
    self.complete_tr_path = get_dir('tdrive-r-train-ps-50')
    # self.test_sampled_tr_path = get_dir('brinkhoff-stability-head')
    # self.test_complete_tr_path = get_dir('geolife-r-train-ps-50')

    # 数据特征
    self.num_x_grids = 200
    self.num_y_grids = 200
    self.ts_gap = 10 * 60 * 1000  # 10分钟  # don't forget to change this
    # self.ts_gap = 1
    self.sampled_tr_len = 40
    self.complete_tr_len = 50

    # 时空数据表示矩阵
    self.ts_rep_dict_path = get_dir('ts_rep_dict.pkl')
    self.sp_rep_dict_path = get_dir('sp_rep_dict.pkl')

    # 模型文件
    self.ts_pretrain_model_path = get_dir('ts_pretrain_model.pt')
    self.sp_pretrain_model_path = get_dir('sp_pretrain_model.pt')
    self.sm_pretrain_model_path = get_dir('semantic2vec.model')
    self.bare_dataset_path = get_dir('bare_dataset.pt')
    self.dataset_path = get_dir('dataset.pt')
    self.at2vec_model_path = get_dir('at2vec_model.pt')
    self.at2vec_rep_path = get_dir('at2vec_rep_path.pt')

    # 其他配置
    self.sp_len = 100  # 空间表示向量长度
    self.ts_len = 100  # 时间表示向量长度
    self.sm_len = 100  # 语义表示向量长度
    self.pt_len = self.sp_len + self.ts_len + self.sm_len  # 轨迹点表示向量长度
    self.hidden_len = 256  # 最终生成的轨迹点向量长度
    self.k = 10  # KNN个数
    self.batch_size = 32
    self.device = torch.device('cuda')
    
    self.alpha = 1
    self.beta = 1
    self.gamma = 1


init_ctx(ctx)

## 轨迹数据读取

In [3]:
import torch
import pandas as pd
from tqdm.notebook import tqdm

def cover_none(f, x, y):
    # x and y may be None
    if x is None and y is None:
        return None
    if x is None:
        return y
    if y is None:
        return x
    return f(x, y)


# 读取轨迹文件
class BareDataset(torch.utils.data.Dataset):
    def __init__(self, sampled_tr_path="", complete_tr_path="", ntrs=None, update_ctx=True, *, ctx=None):
        if not (sampled_tr_path or complete_tr_path):
            sampled_tr_path = ctx.sampled_tr_path
            complete_tr_path = ctx.complete_tr_path

        def read_data(path: str, nrows: int):
            if not path:
                return None, None
            data = pd.read_csv(path, sep="\t", nrows=nrows,
                               usecols=[0, 1, 2, 3, 4], header=None)
            idx = dict()  # tid -> [start_index, len]
            current_tid = -1
            print('total lines: ' + str(data.shape[0]))
            # 假设每条轨迹在文件中是连续的
            for i, point in tqdm(data.iterrows(), total=data.shape[0], disable=True):
                tid, ts, x, y, semantics = (point.iloc[i] for i in range(5))
                if current_tid != tid:
                    idx[tid] = [i, 0]
                    current_tid = tid
                idx[tid][1] += 1
                if update_ctx:
                    ctx.min_x = cover_none(min, x, ctx.min_x)
                    ctx.max_x = cover_none(max, x, ctx.max_x)
                    ctx.min_y = cover_none(min, y, ctx.min_y)
                    ctx.max_y = cover_none(max, y, ctx.max_y)
                    ctx.min_ts = cover_none(min, ts, ctx.min_ts)
                    ctx.max_ts = cover_none(max, ts, ctx.max_ts)
            return data, idx

        nrows = None if ntrs is None else ntrs * ctx.sampled_tr_len
        # 读取采样轨迹数据
        self.sampled_data, self.sampled_data_idx = read_data(sampled_tr_path, nrows=nrows)

        nrows = None if ntrs is None else ntrs * ctx.complete_tr_len
        # 读取原始轨迹数据
        self.complete_data, self.complete_data_idx = read_data(complete_tr_path, nrows=nrows)

        if self.sampled_data_idx is not None and self.complete_data_idx is not None:
            assert len(self.sampled_data_idx) == len(self.complete_data_idx)

    def __len__(self):
        if self.complete_data_idx is not None:
            return len(self.complete_data_idx)
        return len(self.sampled_data_idx)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
            
        def get_from(data, data_idx, index):
            if data is None:
                return None
            start_idx, length = data_idx[index]
            return data.iloc[start_idx:start_idx+length]

        return (get_from(self.sampled_data, self.sampled_data_idx, index),
                get_from(self.complete_data, self.complete_data_idx, index))

In [4]:
if __name__ == '__main__':
    bare_dataset = BareDataset(ctx.sampled_tr_path, ctx.complete_tr_path, ctx=ctx)
    print(f'dataset len: {len(bare_dataset)}')
    # torch.save(bare_dataset, ctx.bare_dataset_path)
    print(ctx.__dict__)

total lines: 2000000
total lines: 2500000
dataset len: 50000
{'min_x': 116.15005, 'max_x': 116.59997, 'min_y': 39.75, 'max_y': 40.09967, 'min_ts': 1201930247000, 'max_ts': 1202463545000, 'sampled_tr_path': 'tdrive-data/tdrive-r-train-ps-40', 'complete_tr_path': 'tdrive-data/tdrive-r-train-ps-50', 'test_complete_tr_path': 'tdrive-data/geolife-r-train-ps-50', 'num_x_grids': 200, 'num_y_grids': 200, 'ts_gap': 600000, 'sampled_tr_len': 40, 'complete_tr_len': 50, 'ts_rep_dict_path': 'tdrive-data/ts_rep_dict.pkl', 'sp_rep_dict_path': 'tdrive-data/sp_rep_dict.pkl', 'ts_pretrain_model_path': 'tdrive-data/ts_pretrain_model.pt', 'sp_pretrain_model_path': 'tdrive-data/sp_pretrain_model.pt', 'sm_pretrain_model_path': 'tdrive-data/semantic2vec.model', 'bare_dataset_path': 'tdrive-data/bare_dataset.pt', 'dataset_path': 'tdrive-data/dataset.pt', 'at2vec_model_path': 'tdrive-data/at2vec_model.pt', 'at2vec_rep_path': 'tdrive-data/at2vec_rep_path.pt', 'sp_len': 100, 'ts_len': 100, 'sm_len': 100, 'pt_len

In [5]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def visualize_data(data: pd.DataFrame, length):
    """绘制图像"""
    x = data.iloc[:, 2]
    y = data.iloc[:, 3]
    ts = data.iloc[:, 1]
    labels = [x for x in range(len(ts))]
    size = len(data)
    i = 0
    with tqdm(total=size, disable=True) as pbar:
        while i < size:
            plt.plot(x.iloc[i:i+length-1], y.iloc[i:i+length-1],
                     x.iloc[i+1:i+length], y.iloc[i+1:i+length])
            i += length
            pbar.update(length)
    # plt.figure()
    # plt.scatter(ts, labels)

# visualize_data(bare_dataset.complete_data, ctx.complete_tr_len)

In [6]:
# For geolife


## 时空语义信息预处理

### 处理时间标记

In [7]:
from torch import nn
import torch
import torch.nn.functional as F


class PretrainModel(nn.Module):
    def __init__(self, vocab_size, embed_size, device):
        """
        Args: 
          vocab_size: 总的单元个数
          embed_size: 时间表示向量长度
        """
        super(PretrainModel, self).__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.device = device

        # Skip-gram参考：https://youtu.be/c2qIe74NN6A?t=3152
        self.ctx_embed = nn.Embedding(
            self.vocab_size, self.embed_size, max_norm=1)
        nn.init.xavier_normal_(self.ctx_embed.weight)
        self.tgt_embed = nn.Linear(self.embed_size, self.vocab_size)
        nn.init.xavier_normal_(self.tgt_embed.weight)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, context: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
          context:  中心词 (batch_size), 0 <= value < vocab_size
          targets: 上下文词 (batch_size, k), 0 <= value < vocab_size

        len: 时间/空间向量维度
        k: K近邻个数

        Returns: 损失值
        """
        batch_size = context.shape[0]
        # vector: (batch_size, embed_size)
        vector = self.ctx_embed(context).to(self.device)
        # print(f'vector: {vector.shape}, {vector.dtype}')

        # Z: (batch_size, vocab_size)
        Z = self.tgt_embed(vector).to(self.device)
        # print(f'Z: {Z.shape}, {Z.dtype}')

        # Y_hat: (batch_size, vocab_size)
        Y_hat = self.log_softmax(Z).to(self.device)
        # print(f'Y_hat: {Y_hat.shape}, {Y_hat.dtype}')

        return self.criterion(Y_hat, targets)

    def criterion(self, output, labels):
        """
        Args:
            output: (batch_size, vocab_size)
            labels: (batch_size, k)
        Returns: loss
        """
        k = labels.shape[1]
        # loss: (batch_size, )
        loss = torch.stack([self.loss_fn(output, label)
                           for label in torch.unbind(labels, dim=1)])
        return loss

    def embed(self, index):
        with torch.no_grad():
            if isinstance(index, int):
                index = torch.tensor(index).to(self.device)
            return self.ctx_embed(index)

In [8]:
def ts2id(ts: int, min_ts: int, max_ts: int, ts_gap: int):
    ts = min(max_ts, ts)
    ts = max(min_ts, ts)
    return (ts - min_ts) // ts_gap


def ts_k_nearest(k: int, max_tsid: int):
    """返回一个生成K近邻的函数。"""
    def f(tsid: int) -> 'list[int]':
        # tsid in [0, max_tsid)
        nearest = list()
        l = tsid - 1
        r = tsid + 1
        while True:
            if len(nearest) >= k:
                break
            if l >= 0:
                nearest.append(l)
                l -= 1
            if len(nearest) >= k:
                break
            if r < max_tsid:
                nearest.append(r)
                r += 1
            if l < 0 and r >= max_tsid:
                break
        return nearest

    return f


ts_k_nearest(4, 10)(5)  # [4, 6, 3, 7]

[4, 6, 3, 7]

In [9]:
# DataLoader for timestamps data
class SpaceTimestampDataset(torch.utils.data.Dataset):
    def __init__(self, vocab_size, k_nearest, device):
        self.vocab_size = vocab_size
        self.k_nearest = k_nearest
        self.device = device

    def __len__(self):
        return self.vocab_size

    def __getitem__(self, index):
        return (torch.tensor(index).to(self.device),
                torch.tensor(self.k_nearest(index)).to(self.device))

In [10]:
from d2l import torch as d2l
import torch
from torch.utils.data import DataLoader
import random
from tqdm import tqdm, trange

def validate(model: PretrainModel, k_nearest, num_tests=2) -> float:
    """验证训练出来的表示向量是否能够反映邻近关系"""
    with torch.no_grad():
        count = 0.0
        vocab_size = model.vocab_size
        for _ in trange(num_tests):
            index = random.randrange(vocab_size)
            neighbors = k_nearest(index)
            id_dist = []
            index_vec = model.embed(index)
            for j in range(vocab_size):
                j_vec = model.embed(j)
                id_dist.append((torch.dist(index_vec, j_vec).item(), j))
            id_dist.sort()
            # print(id_dist[:20])
            assert id_dist[0][1] == index
            for k in range(1, min(vocab_size, 2 * len(neighbors) + 1)):
                if id_dist[k][1] in neighbors:
                    count += 1 / len(neighbors)
        return count / num_tests


def pretrain(dataloader: DataLoader, num_epochs: int, lr: float, vocab_size: int, embed_size: int,
             k_nearest, ctx: Context, /, state=None, draw=True, save_path=None):
    model = PretrainModel(vocab_size=vocab_size,
                          embed_size=embed_size,
                          device=ctx.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if state is None:
        model.to(ctx.device)
        start_epoch = 0
    else:
        model.load_state_dict(state['model'])
        model.to(ctx.device)
        # optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']

    if draw:
        board = d2l.ProgressBoard(xlabel="epoch",
                                  ylabel="loss",
                                  xlim=[start_epoch, start_epoch + num_epochs],
                                  figsize=(5, 3))

    # xlim=[0, num_epochs])
    num_rounds = len(dataloader)
    # pbar = tqdm(total=num_epochs)
    # pbar1 = tqdm(total=num_rounds)
    for epoch in range(start_epoch, start_epoch + num_epochs):
        # pbar1.reset()
        if not draw:
            print(f"========= Epoch {epoch} ========")

        total_loss = 0
        numel = 0
        for i, (context, target) in enumerate(dataloader):
            optimizer.zero_grad()
            loss = model(context, target)
            loss_sum = loss.sum()
            loss_sum.backward()
            optimizer.step()
            with torch.no_grad():
                total_loss += loss_sum.item()
                numel += loss.shape[0]
                #if not draw:
                #    print(
                #        f"epoch {i+1:4d}/{num_rounds:4d}: loss = {loss_sum / loss.shape[0]:.4f}")
            # pbar1.update(1)

        with torch.no_grad():
            if draw:
                board.draw(epoch, total_loss / numel,
                           'loss', every_n=num_epochs // 100)
            else:
                print(f"total loss: {total_loss / numel:.4f}\n")

            if (epoch + 1) % 5 == 0 or epoch + 1 == num_epochs:
                # 每100次计算一次accuracy
                # val_count = validate(model, k_nearest)
                # print(f'accuracy: {val_count}')
                # if draw:
                #     board.draw(epoch, val_count, 'accuracy')
                # 第一次不需要保存
                state = {'epoch': epoch + 1,
                         'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()}
                torch.save(state, save_path)
                if not draw: print('saved')
        # pbar.update(1)
    # pbar.close()
    # pbar1.close()

    print(f'final loss: {total_loss / numel}')

    return model

In [11]:
import gc
import torch

if __name__ == '__main__':
    # 每个时间切片的长度（单位：ms)
    # ctx.ts_gap = 600000
    ctx.num_ts_grids = (ctx.max_ts - ctx.min_ts) // ctx.ts_gap + 1

In [12]:
# if __name__ == '__main__':
#     ts_model = PretrainModel(vocab_size=ctx.num_ts_grids,
#                                   embed_size=ctx.ts_len,
#                                   device=ctx.device)
#     ts_model.load_state_dict(torch.load(ctx.ts_pretrain_model_path)['model'])
#     ts_model.to(ctx.device)
#     ts_model.embed(5)

### 处理空间标记

In [13]:
import math


def pair2spid(x_id: int, y_id: int, num_x_grids: int):
    return y_id * num_x_grids + x_id


def spid2pair(id: int, num_x_grids: int):
    return id % num_x_grids, id // num_x_grids


def sp2id(x: float, y: float,
          min_x: float, min_y: float,
          max_x: float, max_y: float,
          x_gap: float, y_gap: float):
    """
    (x, y)坐标转换为空间网格令牌值。假设max_x和max_y不能取到。

    Returns:
        令牌值, (x轴编号, y轴编号)
    """
    x, y = max(min_x, x), max(min_y, y)
    x, y = min(max_x, x), min(max_y, y)
    num_x_grids = int(math.ceil((max_x - min_x) / x_gap))
    x_grid, y_grid = (int(math.floor((x - min_x) / x_gap)),
                      int(math.floor((y - min_y) / y_gap)))
    return pair2spid(x_grid, y_grid, num_x_grids)


if __name__ == '__main__':
    spid = sp2id(6, 4.5, 0, 0, 16, 9, 4, 3)
    pair = spid2pair(spid, 4)
    spid, pair

In [14]:
def sp_k_nearest(k: int, max_x_id: int, max_y_id: int):
    """返回一个生成K近邻的函数。"""
    assert 0 < k < max_x_id * max_y_id

    def iswithin(x_id, y_id):
        return 0 <= x_id < max_x_id and 0 <= y_id < max_y_id

    def f(spid: int) -> 'list[int]':
        # 使用Manhattan距离
        # x_id in [0, max_x_id)
        # y_id in [0, max_y_id)
        x_id, y_id = spid2pair(spid, max_x_id)

        nearest = list()
        dist = 1
        while True:
            delta_x_abs = dist - dist // 2
            while delta_x_abs <= dist:
                delta_x = delta_x_abs
                delta_y = dist - delta_x

                diagonal = delta_x == delta_y or delta_y == 0

                for _ in range(4):
                    new_x_id = x_id + delta_x
                    new_y_id = y_id + delta_y
                    if iswithin(new_x_id, new_y_id):
                        nearest.append(pair2spid(new_x_id, new_y_id, max_x_id))
                        if len(nearest) >= k:
                            return nearest
                    if not diagonal:
                        new_x_id = x_id + delta_y
                        new_y_id = y_id + delta_x
                        if iswithin(new_x_id, new_y_id):
                            nearest.append(
                                pair2spid(new_x_id, new_y_id, max_x_id))
                            if len(nearest) >= k:
                                return nearest
                    delta_x, delta_y = -delta_y, delta_x  # 顺时针旋转90°

                delta_x_abs += 1
            dist += 1

    return f

if __name__ == '__main__':
    sp_k_nearest(20, 7, 7)(24)

In [15]:
import gc
import torch

if __name__ == '__main__':
    # 每个时间切片的长度（单位：ms)
    ctx.x_gap, ctx.y_gap = ((ctx.max_x - ctx.min_x) / ctx.num_x_grids,
                            (ctx.max_y - ctx.min_y) / ctx.num_y_grids)
    ctx.num_sp_grids = sp2id(ctx.max_x, ctx.max_y,
                             ctx.min_x, ctx.min_y,
                             ctx.max_x, ctx.max_y,
                             ctx.x_gap, ctx.y_gap)
    ctx.num_sp_grids

In [16]:
# if __name__ == '__main__':
#     sp_model = PretrainModel(vocab_size=ctx.num_sp_grids,
#                                   embed_size=ctx.sp_len,
#                                   device=ctx.device)
#     sp_model.load_state_dict(torch.load(ctx.sp_pretrain_model_path)['model'])
#     sp_model.to(ctx.device)
#     sp_model.embed(5)

### 处理语义标记

In [17]:
# from gensim.models import Word2Vec

# if __name__ == '__main__':
#     word2vec = Word2Vec.load(ctx.sm_pretrain_model_path)

### 轨迹特征融合

In [18]:
import torch
from gensim.models import Word2Vec
from functools import partial


def get_mat(tr, sp_model, ts_model, sm_model):
    get_spid = partial(sp2id, min_x=ctx.min_x, max_x=ctx.max_x, min_y=ctx.min_y, max_y=ctx.max_y,
                           x_gap=ctx.x_gap, y_gap=ctx.y_gap)
    get_tsid = partial(ts2id, min_ts=ctx.min_ts,
                           max_ts=ctx.max_ts, ts_gap=ctx.ts_gap)
    ts_col, all_cols, sm_col = (tr.iloc[:, 1],
                                tr.iloc[:, 2:4],
                                tr.iloc[:, 4])
    sp_vec = torch.stack([sp_model.embed(get_spid(al.iloc[0], al.iloc[1]))
                          for (_, al) in all_cols.iterrows()], dim=0)
    ts_vec = torch.stack([ts_model.embed(get_tsid(ts))
                         for (_, ts) in ts_col.iteritems()], dim=0)
    # semantics are more complicated
    vec_set = []
    for _, sm in sm_col.iteritems():
        # For each trajectory point
        # keyword list of this point
        kws = sm.replace(' ', '-').split(',')
        # 所有关键词向量取平均并归一化，作为该点语义向量
        avg_vec = torch.from_numpy(sm_model.wv.get_mean_vector(
            kws, pre_normalize=True, post_normalize=True))
        vec_set.append(avg_vec)
    sm_vec = torch.stack(vec_set, dim=0)
    # returns: (tr_len, sp_len)
    return torch.cat((sp_vec, ts_vec, sm_vec), dim=1)


class TrajectoryDataset(torch.utils.data.Dataset):
    def __init__(self,
                 bare_dataset: BareDataset,
                 sp_model: PretrainModel,
                 ts_model: PretrainModel,
                 sm_model: Word2Vec,
                 ctx: Context):
        self.bare_dataset = bare_dataset
        length = len(bare_dataset)
        self.length = length
        
        dev = ctx.device
        if self.bare_dataset.sampled_data is not None:
            self.sampled_tr = torch.zeros(
                length, ctx.sampled_tr_len, ctx.pt_len)
            for i in range(length):
                if i % 10000 == 0:
                    print(f'processing sampled trajectory #{i}')
                bare_tr, _ = self.bare_dataset[i]
                self.sampled_tr[i, :] = get_mat(bare_tr, sp_model, ts_model, sm_model).cpu()
        else:
            self.sampled_tr = None

        if self.bare_dataset.complete_data is not None:
            self.complete_tr = torch.zeros(
                length, ctx.complete_tr_len, ctx.pt_len)
            for i in range(length):
                if i % 10000 == 0:
                    print(f'processing full trajectory #{i}')
                _, bare_tr = self.bare_dataset[i]
                self.complete_tr[i, :] = get_mat(bare_tr, sp_model, ts_model, sm_model).cpu()
        else:
            self.complete_tr = None

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        sampled = self.sampled_tr[index] if self.sampled_tr is not None else None
        complete = self.complete_tr[index] if self.complete_tr is not None else None
        return index, sampled, complete

In [19]:
# 得到轨迹向量表示
from gensim.models import Word2Vec

if __name__ == '__main__':
    sp_model = PretrainModel(ctx.num_sp_grids, ctx.sp_len, torch.device('cpu'))
    sp_model.load_state_dict(torch.load(ctx.sp_pretrain_model_path)['model'])
    ts_model = PretrainModel(ctx.num_ts_grids, ctx.ts_len, torch.device('cpu'))
    ts_model.load_state_dict(torch.load(ctx.ts_pretrain_model_path)['model'])
    sm_model = Word2Vec.load(ctx.sm_pretrain_model_path)

    # dataset = TrajectoryDataset(bare_dataset, sp_model, ts_model, sm_model, ctx)
    # torch.save(dataset, ctx.dataset_path)

In [20]:
if __name__ == '__main__':
    dataset = torch.load(ctx.dataset_path)
    # print(dataset[0])

## 模型定义

### 注意力机制

In [21]:
import torch.nn as nn


class Attention(nn.Module):
    """注意力机制"""

    def __init__(self, hidden_len, vec_len, atten_len):
        """
        Args:
            hidden_len
            vec_len
            atten_len

        Point-level:   N_e=hidden_len, N=vec_len=atten_len
        Feature-level: N_e=hidden_len, T=vec_len=atten_len
        Encoder-level: N_d=hidden_len, N_e=vec_len, N_p=atten_len
        """
        super().__init__()
        self.vec_len = vec_len
        self.hidden_len = hidden_len
        # 对(h_{t-1}, q_{t-1}^e)施加的线性变换
        # h是encoder hidden state，q是encoder cell state
        self.W1 = nn.Linear(2 * hidden_len, atten_len)
        # 对E(·,j)施加的线性变换
        # E(·,j)是第j个轨迹点的表示向量
        self.W2 = nn.Linear(vec_len, atten_len)
        self.v1 = nn.Linear(atten_len, 1)
        # 不需要b1，因为Linear层含bias

    def forward(self, src, h, q, t):
        """
        Args: （Point|Feature|Encoder attention)
            src: 采样轨迹|采样轨迹|encoder hidden state
                 (batch_size, sampled_tr_len, pt_len->vec_len)
            h: t-1时刻encoder hidden state|encoder hidden state|decoder hidden state
            q: t-1时刻encoder cell state|encoder cell state|decoder cell state
            t: 当前时刻（即当前是轨迹中的第几个点，从0开始计数）

        Returns: (batch_size, atten_len)
        """
        # part_1: (batch_size, atten_len)
        part_1 = self.W1(torch.cat((h, q), dim=1))
        # part_2: (batch_size, sampled_tr_len, atten_len)
        part_2 = self.W2(src)
        # a: (batch_size, sampled_tr_len)
        a = self.v1(torch.tanh(part_1.unsqueeze(1) + part_2)).squeeze()
        if len(a.shape) == 1:
            a = a.unsqueeze(0)
        a = torch.exp(a)
        a = a / torch.pow(a.sum(dim=1, keepdim=True), 1 / (t+1))
        return a

### 编码器—解码器模型

In [22]:
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self, sampled_tr_len, pt_len, hidden_len, device):
        """
        Args:
            sampled_tr_len: 采样轨迹长度（轨迹点个数）
            pt_len: 轨迹点特征个数（表示向量长度）
            hidden_len: 隐藏层大小
        """
        super().__init__()
        self.sampled_tr_len = sampled_tr_len
        self.pt_len = pt_len
        self.hidden_len = hidden_len
        self.device = device

        self.lstm_cell = nn.LSTMCell(pt_len, hidden_len)
        self.point_attn = Attention(hidden_len, pt_len, pt_len)
        self.feature_attn = Attention(
            hidden_len, sampled_tr_len, sampled_tr_len)

    def forward(self, sampled_tr):
        """
        Args:
            sampled_tr: 采样轨迹 (batch_size, sampled_tr_len, pt_len)
        Returns: 
            (enc_hiddens, enc_cell)
            enc_hiddens: 编码器所有hidden state (batch_size, sampled_tr_len, hidden_len)
            enc_cell:    编码器最后时刻cell state (batch_size, hidden_len)
        """
        batch_size = sampled_tr.shape[0]
        enc_hiddens = torch.zeros(
            (batch_size, self.sampled_tr_len, self.hidden_len)).to(self.device)
        enc_hidden = torch.zeros((batch_size, self.hidden_len)).to(self.device)
        enc_cell = torch.zeros((batch_size, self.hidden_len)).to(self.device)
        for t in range(self.sampled_tr_len):
            # point_attn_vec: (batch_size, sampled_tr_len, 1)
            point_attn_vec = self.point_attn(
                sampled_tr, enc_hidden, enc_cell, t).unsqueeze(2)
            # point_attn_matrix: (batch_size, ptr_len, sampled_tr_len)
            point_attn_matrix = (sampled_tr * point_attn_vec).transpose(1, 2)
            # feature_attn_vec: (batch_size, ptr_len, 1)
            feature_attn_vec = self.feature_attn(
                point_attn_matrix, enc_hidden, enc_cell, t).unsqueeze(2)
            # feature_attn_matrix: (batch_size, sampled_tr_len, pt_len)
            feature_attn_matrix = (point_attn_matrix *
                                   feature_attn_vec).transpose(1, 2)
            # enc_hidden, enc_cell: (batch_size, hidden_len)
            enc_hidden, enc_cell = self.lstm_cell(
                feature_attn_matrix[:, t], (enc_hidden, enc_cell))
            enc_hiddens[:, t] = enc_hidden
        return enc_hiddens, enc_cell

    def get_rep_vector(self, tr):
        """
        Args:
            tr: 轨迹 ([batch_size,] *tr_len, pt_len)
        Returns: 
            最终表示 ([batch_size,] hidden_len)
        """
        if len(tr.shape) == 2:
            tr = tr.unsqueeze(0)
            
        batch_size, tr_len, _ = tr.shape
        enc_hidden = torch.zeros((batch_size, self.hidden_len)).to(self.device)
        enc_cell = torch.zeros((batch_size, self.hidden_len)).to(self.device)
        for t in range(tr_len):
            # point_attn_vec: (batch_size, tr_len, 1)
            point_attn_vec = self.point_attn(
                tr, enc_hidden, enc_cell, t).unsqueeze(2)
            # point_attn_matrix: (batch_size, tr_len, pt_len)
            point_attn_matrix = (tr * point_attn_vec)
            # enc_hidden, enc_cell: (batch_size, hidden_len)
            enc_hidden, enc_cell = self.lstm_cell(
                point_attn_matrix[:, t], (enc_hidden, enc_cell))
        return enc_hidden.squeeze()

if __name__ == '__main__':
    encoder = Encoder(3, 4, 5, torch.device('cpu'))
    sampled_tr = torch.ones((2, 3, 4))
    h, c = encoder(sampled_tr)
    h, c

In [23]:
class Decoder(nn.Module):
    def __init__(self, complete_tr_len, pt_len, hidden_len, device):
        """
        Args:
            pt_len: 轨迹点特征个数（表示向量长度）
            hidden_len: 隐藏层大小
        """
        super().__init__()
        self.complete_tr_len = complete_tr_len
        self.pt_len = pt_len
        self.hidden_len = hidden_len
        self.device = device

        self.lstm_cell = nn.LSTMCell(pt_len + hidden_len, hidden_len)
        self.enc_attn = Attention(hidden_len, hidden_len, hidden_len)

    def forward(self, complete_tr, enc_hiddens, enc_cell):
        """
        Args:
            complete_tr: 完整轨迹 (batch_size, complete_tr_len, pt_len)
            enc_hiddens: 编码器所有隐藏层 (batch_size, sampled_tr_len, hidden_len)
            enc_cell:    编码器最后时刻cell state (batch_size, hidden_len)
        Returns:
            解码器所有隐藏层 (batch_size, complete_tr_len, hidden_len)
        """
        batch_size = complete_tr.shape[0]
        # dec_hiddens: (batch_size, complete_tr_len, hidden_len)
        dec_hiddens = torch.zeros(
            batch_size, self.complete_tr_len, self.hidden_len).to(self.device)
        # dec_hidden: (batch_size, hidden_len)
        dec_hidden = enc_hiddens[:, -1, :]
        # print('initial dec_hidden:', dec_hidden)
        # dec_cell
        dec_cell = enc_cell

        for t in range(self.complete_tr_len):
            # enc_attn_vec: (batch_size, 1, sampled_tr_len)
            enc_attn_vec = self.enc_attn(
                enc_hiddens, dec_hidden, dec_cell, t).unsqueeze(1)
            # c_t: (batch_size, hidden_len)
            c_t = torch.bmm(enc_attn_vec, enc_hiddens).squeeze()
            # pt_and_c_t: (batch_size, pt_size + hidden_len)
            pt_and_c_t = torch.cat((complete_tr[:, t], c_t), dim=1)
            dec_hidden, dec_cell = self.lstm_cell(
                pt_and_c_t, (dec_hidden, dec_cell))
        
            # print('dec_hidden:', dec_hidden)
            dec_hiddens[:, t] = dec_hidden
            # TODO: c?

        return dec_hiddens

if __name__ == '__main__':
    decoder = Decoder(6, 4, 5, torch.device('cpu'))
    complete_tr = torch.ones(2, 6, 4)
    o, _ = decoder(complete_tr, h, c)
    o

In [24]:
class EncoderDecoder(nn.Module):
    def __init__(self, sampled_tr_len, complete_tr_len, pt_len, hidden_len,
                 num_sp_grids, num_ts_grids, num_keywords, device):
        """
        Args:
            sampled_tr_len: 采样轨迹长度
            complete_tr_len: 完整轨迹长度
            pt_len: 轨迹点特征数（向量长度）
            hidden_len: 隐藏层维度
            num_sp_grids: 空间网格个数
            num_ts_grids: 时间单位个数
            num_keywords: 语义词汇个数
        """
        super().__init__()
        self.encoder = Encoder(sampled_tr_len, pt_len, hidden_len, device)
        self.decoder = Decoder(complete_tr_len, pt_len, hidden_len, device)
        self.sp_dense = nn.Linear(hidden_len, num_sp_grids)
        self.ts_dense = nn.Linear(hidden_len, num_ts_grids)
        self.sm_dense = nn.Linear(hidden_len, num_keywords)
        self.to(device)

    def forward(self, sampled_tr, complete_tr):
        """
        Args:
            sampled_tr:  采样轨迹 (batch_size, sampled_tr_len,  pt_len)
            complete_tr: 完整轨迹 (batch_size, complete_tr_len, pt_len)
        Returns: (_, _, _), _
            sp_prediction 空间预测 (batch_size, complete_tr_len, num_sp_grids)
            ts_prediction 时间预测 (batch_size, complete_tr_len, num_ts_grids)
            sm_prediction 语义预测 (batch_size, complete_tr_len, num_keywords)
            c             最终表示 (batch_size, hidden_len)
        """
        enc_hiddens, enc_cell = self.encoder(sampled_tr)
        dec_hiddens = self.decoder(complete_tr, enc_hiddens, enc_cell)
        sp_prediction = self.sp_dense(dec_hiddens)
        ts_prediction = self.ts_dense(dec_hiddens)
        sm_prediction = self.sm_dense(dec_hiddens)
        c = enc_hiddens[:, -1]
        return (sp_prediction, ts_prediction, sm_prediction), c
    
    def get_rep_vector(self, tr):
        """
        Args:
            tr: 采样轨迹 (batch_size, sampled_tr_len, pt_len)
        Returns:
            最终表示 (batch_size, hidden_len)
        """
        with torch.no_grad():
            return self.encoder.get_rep_vector(tr)
    
if __name__ == '__main__':
    enc_dec = EncoderDecoder(3, 6, 4, 5, 10, 20, 30, torch.device('cpu'))
    sampled_tr = torch.ones(2, 3, 4)
    complete_tr = torch.ones(2, 6, 4)
    (p1, p2, p3), c = enc_dec(sampled_tr, complete_tr)
    p1.shape, p2.shape, p3.shape, c.shape

### STA损失函数

In [25]:
from gensim.models import Word2Vec
import numpy as np


def get_sp_knn(ctx: Context, k: int):
    """
    返回一个函数，该函数输入为x、y坐标值，输出为k个邻近的空间网格令牌值。
    """
    knn_func = sp_k_nearest(k, ctx.num_x_grids, ctx.num_y_grids)

    def f(x: float, y: float) -> 'list[int]':
        spid = sp2id(x, y, ctx.min_x, ctx.min_y, ctx.max_x,
                     ctx.max_y, ctx.x_gap, ctx.y_gap)
        return [spid] + knn_func(spid)

    return f


def get_ts_knn(ctx: Context, k: int):
    """
    返回一个函数，该函数输入为时间戳值，输出为k个邻近的时间单元令牌值。
    """
    knn_func = ts_k_nearest(k, ctx.num_ts_grids)

    def f(ts: int) -> 'list[int]':
        tsid = ts2id(ts, ctx.min_ts, ctx.max_ts, ctx.ts_gap)
        return [tsid] + knn_func(tsid)

    return f


def get_sm_knn(model: Word2Vec, ctx: Context, k: int):
    """
    返回一个函数，该函数输入为某个轨迹点的语义向量（即所有关键词向量的平均），
    输出与该语义向量最相邻的k个关键词。
    """
    def f(vec: torch.Tensor) -> 'list[str]':
        vec_cpu = vec.cpu()
        knn = model.wv.similar_by_vector(vec_cpu.detach().numpy(), topn=k)
        return [key for (key, similarity) in knn]

    return f

In [26]:
from torch.cuda import memory_allocated

def inspect_memory():
    pass
    # print(f'memory allocated: {memory_allocated() / 1024 / 1024:.2f} MB')  

In [27]:
from gensim.models import Word2Vec
import torch.nn.functional as F
import torch

def criterion(sp_prediction: torch.Tensor,
              ts_prediction: torch.Tensor,
              sm_prediction: torch.Tensor,
              true_tr: torch.Tensor,
              index: torch.Tensor,
              k: int,
              alpha: float,
              beta: float,
              gamma: float,
              ts_start: int,
              sm_start: int,
              pt_len: int,
              sp_knn,
              ts_knn,
              sm_knn,
              sp_model: PretrainModel,
              ts_model: PretrainModel,
              sm_model: Word2Vec,
              bare_dataset: BareDataset,
              device):
    """
    STA损失函数。

    Args:
        sp_prediction: 空间预测 (batch_size, complete_tr_len, num_sp_grids)
        ts_prediction: 时间预测 (batch_size, complete_tr_len, num_ts_grids)
        sm_prediction: 语义预测 (batch_size, complete_tr_len, num_keywords)
        true_tr:  完整轨迹 (batch_size, complete_tr_len, pt_len)
        index: 轨迹编号 (batch_size, )
        k: KNN个数
        alpha: 空间部分权重
        beta:  时间部分权重
        gamma: 语义部分权重
        ts_start: 时间信息开始于 (100)
        sm_start: 语义信息开始于 (200)
        pt_len:  总长度 (300)
        sp_knn: 输入x/y，返回空间网格k近邻spid+自身spid
        ts_knn: 输入ts，返回时间分段k近邻tsid+自身tsid
        sm_knn: 输入语义tensor array (sm_len,)，返回k近邻(list[str])
        sp_model: 
        ts_model:
        sm_model:
        bare_dataset: BareDataset 
        device: 
    Returns:
        损失值 (batch_size, )
    """
    inspect_memory()
    batch_size, complete_tr_len, _ = sp_prediction.shape
    cross_entropy = torch.nn.CrossEntropyLoss(reduction='none')
    # loss_pts: list of (batch_size, ), len=complete_tr_len
    loss_pts = []
    for t in range(complete_tr_len):
        inspect_memory()
        # 准备第t个轨迹点
        # true_pt: (batch_size, 1, pt_len)
        true_pt = true_tr[:, t:t+1]
        # raw: pd.DataFrame (batch_size, [tid, ts, x, y, kw])
        raw = []
        for i in range(batch_size):
            pt_raw = bare_dataset[index[i].item()][1].iloc[t]
            raw.append(pt_raw)

        # 空间（此部分K近邻包含自己，k=k+1）
        predicted_sp = sp_prediction[:, t].unsqueeze(1)  # (batch_size, 1, num_sp_grids)
        true_sp = true_pt[:, :, :ts_start]  # (batch_size, 1, sp_len)
        # knn_sp: 目标轨迹点KNN
        knn_label_set = []
        knn_vec_set = []
        for i in range(batch_size):
            x, y = raw[i].iloc[2], raw[i].iloc[3]
            labels = sp_knn(x, y)  # list of spid
            knn_label_set.append(labels)
            # knn_vec_i: (k, sp_len)
            knn_vec_i = torch.stack([sp_model.embed(spid).to(device) 
                                     for spid in labels], dim=0)
            knn_vec_set.append(knn_vec_i)
        # knn_sp_labels: (batch_size, k)
        knn_sp_labels = torch.tensor(knn_label_set, device=device)
        # knn_sp_vec: (batch_size, k, sp_len)
        knn_sp_vec = torch.stack(knn_vec_set, dim=0)
        # w_sp: (batch_size, k)
        w_sp = -torch.linalg.vector_norm(knn_sp_vec - true_sp, dim=2) / alpha
        w_sp = F.softmax(w_sp, dim=1)
        # ce_sp: (batch_size, k) -- cross entropy works on axis 1
        ce_sp = cross_entropy(predicted_sp.expand(-1, knn_sp_labels.shape[1], -1).permute(0, 2, 1), 
                              knn_sp_labels)
        # loss_sp_pt: (batch_size, )
        loss_sp_pt = torch.mean(w_sp * ce_sp, dim=1)
        
        # 时间（此部分K近邻包含自己，k=k+1）
        predicted_ts = ts_prediction[:, t].unsqueeze(1)  # (batch_size, 1, num_ts_grids)
        true_ts = true_pt[:, :, :ts_start]  # (batch_size, 1, ts_len)
        # knn_ts: 目标轨迹点KNN
        knn_label_set = []
        knn_vec_set = []
        for i in range(batch_size):
            ts = raw[i].iloc[1]
            labels = ts_knn(ts)  # list of tsid
            knn_label_set.append(labels)
            # knn_vec_i: (k, ts_len)
            knn_vec_i = torch.stack([ts_model.embed(int(tsid)).to(device) 
                                     for tsid in labels], dim=0)
            knn_vec_set.append(knn_vec_i)
        # knn_ts_labels: (batch_size, k)
        knn_ts_labels = torch.tensor(knn_label_set, device=device)
        # knn_ts_vec: (batch_size, k, ts_len)
        knn_ts_vec = torch.stack(knn_vec_set, dim=0)
        # w_ts: (batch_size, k)
        w_ts = -torch.linalg.vector_norm(knn_ts_vec - true_ts, dim=2) / alpha
        w_ts = F.softmax(w_ts, dim=1)
        # ce_ts: (batch_size, k) -- cross entropy works on axis 1
        ce_ts = cross_entropy(predicted_ts.expand(-1, knn_ts_labels.shape[1], -1).permute(0, 2, 1), 
                              knn_ts_labels) 
        # loss_ts_pt: (batch_size, )
        loss_ts_pt = torch.mean(w_ts * ce_ts, dim=1)
        
        # 语义
        predicted_sm = sm_prediction[:, t].unsqueeze(1)  # (batch_size, 1, num_keywords)
        true_sm = true_pt[:, :, :ts_start]  # (batch_size, 1, sm_len)
        # knn_sm: 目标轨迹点KNN
        knn_label_set = []
        knn_vec_set = []
        for i in range(batch_size):
            # true_sm_vec: (sm_len, )
            true_sm_vec = true_sm[i].squeeze()
            knn_keywords = sm_knn(true_sm_vec)
            labels = [sm_model.wv.get_index(kw) for kw in knn_keywords]
            knn_label_set.append(labels)
            # knn_vec_i: (k, sm_len)
            knn_vec_i = torch.stack([torch.tensor(sm_model.wv.get_vector(kw), device=device)
                                     for kw in knn_keywords], dim=0)
            knn_vec_set.append(knn_vec_i)
        # knn_sm_labels: (batch_size, k)
        knn_sm_labels = torch.tensor(knn_label_set, device=device)
        # knn_sm_vec: (batch_size, k, sm_len)
        knn_sm_vec = torch.stack(knn_vec_set, dim=0)
        # w_sm: (batch_size, k)
        w_sm = -F.cosine_similarity(knn_sm_vec, true_sm, dim=2) / gamma
        w_sm = F.softmax(w_sm, dim=1)
        # ce_sm: (batch_size, k) -- cross entropy works on axis 1
        ce_sm = cross_entropy(predicted_sm.expand(-1, knn_sm_labels.shape[1], -1).permute(0, 2, 1), 
                              knn_sm_labels) 
        # loss_sm_pt: (batch_size, )
        loss_sm_pt = torch.sum(w_sm * ce_sm, dim=1)
        
        loss_pts.append(loss_sp_pt + loss_ts_pt + loss_sm_pt)

    return torch.stack(loss_pts, dim=1).sum(dim=1)


In [28]:
if __name__ == '__main__':
    _, _, true_tr = dataset[5]
    true_tr = true_tr.unsqueeze(0).cuda()
    # 轨迹长度：8
    sp_prediction = torch.rand(1, 50, ctx.num_sp_grids, dtype=torch.float32).cuda()
    ts_prediction = torch.rand(1, 50, ctx.num_ts_grids, dtype=torch.float32).cuda()
    sm_prediction = torch.ones(1, 50, len(sm_model.wv), dtype=torch.float32).cuda()

    sp_knn_func = get_sp_knn(ctx, ctx.k)
    ts_knn_func = get_ts_knn(ctx, ctx.k)
    sm_knn_func = get_sm_knn(sm_model, ctx, ctx.k)

    criterion(sp_prediction, ts_prediction, sm_prediction, true_tr, torch.Tensor([5]), 
              ctx.k, 1, 1, 1, 100, 200, 300,
              sp_knn_func, ts_knn_func, sm_knn_func,
              sp_model, ts_model, sm_model,
              bare_dataset, torch.device('cuda'))

## 模型训练

In [29]:
from d2l import torch as d2l
import torch
from torch.utils.data import DataLoader
import random
from functools import partial
from tqdm.notebook import tqdm

# def validate(model: PretrainModel, k_nearest, num_tests=100) -> float:
#     """验证训练出来的表示向量是否能够反映邻近关系"""
#     with torch.no_grad():
#         count = 0.0
#         vocab_size = model.vocab_size
#         for _ in range(num_tests):
#             index = random.randrange(vocab_size)
#             neighbors = k_nearest(index)
#             id_dist = []
#             index_vec = model.embed(index)
#             for j in range(vocab_size):
#                 j_vec = model.embed(j)
#                 id_dist.append((torch.dist(index_vec, j_vec).item(), j))
#             id_dist.sort()
#             # print(id_dist[:20])
#             assert id_dist[0][1] == index
#             for k in range(1, min(vocab_size, 2 * len(neighbors) + 1)):
#                 if id_dist[k][1] in neighbors:
#                     count += 1 / len(neighbors)
#         return count / num_tests


def train(model: EncoderDecoder, dataloader: DataLoader,
          sp_model: PretrainModel, ts_model: PretrainModel, sm_model: PretrainModel,
          bare_dataset: BareDataset,
          num_epochs: int, lr: float, ctx: Context, /, state=None, draw=True,
          save_path=None):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if state is None:
        start_epoch = 0
    else:
        model.load_state_dict(state['model'])
        # optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']

    if draw:
        board = d2l.ProgressBoard(xlabel="epoch",
                                  ylabel="loss",
                                  xlim=[start_epoch, start_epoch + num_epochs],
                                  figsize=(5, 3))
        
    sp_knn_func = get_sp_knn(ctx, ctx.k)
    ts_knn_func = get_ts_knn(ctx, ctx.k)
    sm_knn_func = get_sm_knn(sm_model, ctx, ctx.k)
    loss_fn = partial(criterion, k=ctx.k, alpha=ctx.alpha, beta=ctx.beta, gamma=ctx.gamma,
                      ts_start=ctx.sp_len, sm_start=ctx.sp_len+ctx.ts_len, pt_len=ctx.pt_len,
                      sp_knn=sp_knn_func, ts_knn=ts_knn_func, sm_knn=sm_knn_func,
                      sp_model=sp_model, ts_model=ts_model, sm_model=sm_model,
                      bare_dataset=bare_dataset, device=ctx.device)
                      
    # xlim=[0, num_epochs])
    num_rounds = len(dataloader)
    pbar = tqdm(total=num_epochs * num_rounds)
    for epoch in range(num_epochs):
        save = True  # 是否保存数据？
        
        if not draw:
            print(f"========= Epoch {epoch} =========")

        total_loss = 0
        numel = 0
        for i, (index, sampled, complete) in enumerate(dataloader):
            optimizer.zero_grad()
            index = index.to(ctx.device)
            sampled = sampled.to(ctx.device)
            complete = complete.to(ctx.device)
            (sp_pred, ts_pred, sm_pred), c = model(sampled, complete)
            loss = loss_fn(sp_pred, ts_pred, sm_pred, complete, index)
            # print(f'after loss reserved: {torch.cuda.memory_reserved()}')
            loss_sum = loss.sum()
            loss_sum.backward()
            optimizer.step()

            total_loss += loss_sum.item()
            numel += loss.shape[0]
            if not draw and i % 20 == 0:
                print(
                    f"loss = {loss_sum / loss.shape[0]:.4f} [{i+1:4d}/{num_rounds:4d}]")
            pbar.update(1)
            
        with torch.no_grad():
            if draw:
                board.draw(epoch, total_loss / numel,
                           'loss', every_n=num_epochs // 100)
            else:
                print(f"total loss: {total_loss / numel:.4f}\n")

            if save:
                # 每100次计算一次accuracy
                # val_count = validate(model, k_nearest)
                # print(f'accuracy: {val_count}')
                if draw:
                    board.draw(epoch, val_count, 'accuracy')

                state = {'epoch': epoch + 1,
                         'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()}
                torch.save(state, save_path)
    pbar.close()

    print(f'final loss: {total_loss / numel}')

    return model

In [30]:
if __name__ == '__main__':
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=ctx.batch_size, shuffle=True)

In [None]:
import gc
from datetime import datetime

if __name__ == '__main__':
    start = datetime.now()
    model = EncoderDecoder(ctx.sampled_tr_len, ctx.complete_tr_len, ctx.pt_len, ctx.hidden_len,
                          ctx.num_sp_grids, ctx.num_ts_grids, len(sm_model.wv), ctx.device)

    # state = torch.load(ctx.at2vec_model_path)

    gc.collect()
    torch.cuda.empty_cache()
    state = None
    # state = torch.load(ctx.at2vec_model_path)
    train(model, dataloader, sp_model, ts_model, sm_model, bare_dataset, 2, 0.001,
          ctx, state=state, draw=False, 
          save_path=ctx.at2vec_model_path)
    end = datetime.now()
    print(f'training time: {end - start}')

  0%|          | 0/3126 [00:00<?, ?it/s]

loss = 367.1755 [   1/1563]
loss = 358.8768 [  21/1563]


## 测试准确度

In [None]:
import torch
import random
import heapq

def test_accuracy(model: EncoderDecoder, test_dataset, device):
    """
    测试准确度。
    
    Args:
        test_dataloader
    """
    # no grad
    K = 50
    queue = []  # list[distance, index]

    i = random.randrange(len(test_dataset))
    chosen_index, _, chosen_tr = test_dataset[i]
    chosen_vec = model.get_rep_vector(chosen_tr.to(device))
    
    dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=128)
   
    # print(f'chosen: {chosen_index} {chosen_vec.shape}')
    for indexes, _, trs in dataloader:
        vecs = model.get_rep_vector(trs.to(device))
        for i in range(indexes.shape[0]):
            heapq.heappush(queue, (-torch.dist(vecs[i], chosen_vec), int(indexes[i])))
            if len(queue) > K:
                heapq.heappop(queue)
    queue.sort(reverse=True)
    hit_count = 0
    min_accepted_idx = chosen_index // 50 * 50
    max_accepted_idx = min_accepted_idx + 49
    for _, idx in queue:
        print(f"{idx:7d} ", end="")
        if min_accepted_idx <= idx <= max_accepted_idx:
            hit_count += 1
    print()
    for dist, _ in queue:
        print(f"{-dist:.5f} ", end="")
    print()
    print(f'#{chosen_index}: {hit_count} / 50', end='\n\n')
    return hit_count / 50

In [None]:
if __name__ == '__main__':
    model = EncoderDecoder(ctx.sampled_tr_len, ctx.complete_tr_len, ctx.pt_len, ctx.hidden_len,
                           ctx.num_sp_grids, ctx.num_ts_grids, len(sm_model.wv), ctx.device)

    state = torch.load(ctx.at2vec_model_path)
    model.load_state_dict(state['model'])

In [None]:
if __name__ == '__main__':
    bare_dataset = BareDataset(sampled_tr_path=None,
                               complete_tr_path=ctx.test_complete_tr_path,
                               update_ctx=False, ctx=ctx)
    test_dataset = TrajectoryDataset(bare_dataset, sp_model, ts_model, sm_model, ctx)

In [None]:
if __name__ == '__main__':
    accuracy = 0
    test_k = 50
    for i in range(test_k):
        accuracy += test_accuracy(model, test_dataset, ctx.device)
    accuracy / test_k