Data loading section

In [1]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
path = "./data/taxi_log_2008_by_id"
os.chdir(path)
files = os.listdir()

total_files = len(files[1:])
target_files = int(total_files * 0.3)

gps_data=pd.read_csv(files[0],names=['taxi_id','time','latitude','longitude'])

for i,file in enumerate(tqdm(files[1:], desc="Processing files")):
    tmp = pd.read_csv(file, names=['taxi_id', 'time', 'longitude', 'latitude'])
    if not tmp.empty:
        gps_data = pd.concat([gps_data, tmp])
    if i + 1 >= target_files:
        break
 
gps_data


Processing files:  30%|██▉       | 3105/10356 [03:18<07:44, 15.61it/s] 


Unnamed: 0,taxi_id,time,latitude,longitude
0,1766,2008-02-02 13:47:24,116.42342,39.83735
1,1766,2008-02-02 13:57:25,116.42343,39.83725
2,1766,2008-02-02 14:07:24,116.42339,39.83720
3,1766,2008-02-02 14:17:24,116.42334,39.83726
4,1766,2008-02-02 14:27:24,116.42342,39.83728
...,...,...,...,...
679,9091,2008-02-08 10:40:38,40.12872,116.63859
680,9091,2008-02-08 10:40:38,40.12872,116.63859
681,9091,2008-02-08 10:40:38,40.12872,116.63859
682,9091,2008-02-08 10:40:38,40.12872,116.63859


In [2]:
gps_data1=gps_data.sort_values(by=['taxi_id','time'],ignore_index=True)
gps_data1.drop_duplicates(inplace=True,ignore_index=True)
gps_data1['time'] = pd.to_datetime(gps_data1['time'])
gps_data1

Unnamed: 0,taxi_id,time,latitude,longitude
0,2,2008-02-02 13:33:52,39.88781,116.36422
1,2,2008-02-02 13:37:16,39.88782,116.37481
2,2,2008-02-02 13:38:53,39.88791,116.37677
3,2,2008-02-02 13:42:18,39.88795,116.38033
4,2,2008-02-02 13:43:55,39.89014,116.39392
...,...,...,...,...
4782221,10356,2008-02-07 22:10:49,40.21196,116.24457
4782222,10356,2008-02-07 22:15:51,40.21237,116.25047
4782223,10356,2008-02-07 22:20:53,40.22385,116.23035
4782224,10356,2008-02-07 22:24:45,40.22432,116.23075


In [3]:
gps_data1=gps_data1[(gps_data1['latitude']>39.4)&
                  (gps_data1['latitude']<41.6)&
                  (gps_data1['longitude']>115.7)&
                  (gps_data1['longitude']<117.4)]
gps_data1

Unnamed: 0,taxi_id,time,latitude,longitude
0,2,2008-02-02 13:33:52,39.88781,116.36422
1,2,2008-02-02 13:37:16,39.88782,116.37481
2,2,2008-02-02 13:38:53,39.88791,116.37677
3,2,2008-02-02 13:42:18,39.88795,116.38033
4,2,2008-02-02 13:43:55,39.89014,116.39392
...,...,...,...,...
4782221,10356,2008-02-07 22:10:49,40.21196,116.24457
4782222,10356,2008-02-07 22:15:51,40.21237,116.25047
4782223,10356,2008-02-07 22:20:53,40.22385,116.23035
4782224,10356,2008-02-07 22:24:45,40.22432,116.23075


Data pre-processing section

In [4]:
window_size = 3

total_samples = 0
for _, group in gps_data1.groupby('taxi_id'):
    n_samples = len(group) - window_size
    if n_samples > 0:
        total_samples += n_samples

X = np.zeros((total_samples, window_size, 3))
y = np.zeros((total_samples, 2))

sample_idx = 0

for taxi_id, group in gps_data1.groupby('taxi_id'):
    n_samples = len(group) - window_size
    if n_samples <= 0:
        continue
    
    loc_data = group[['latitude', 'longitude']].values
    time_data = group['time'].values
    
    for i in range(n_samples):
        loc_features = loc_data[i:i+window_size]
        
        time_diff = (time_data[i:i+window_size] - time_data[i]).astype('timedelta64[ms]').astype(float) / 1000
        time_diff = time_diff.reshape(-1, 1)
        
        seq_input = np.hstack([loc_features, time_diff])
        
        target = loc_data[i+window_size]
        
        if sample_idx < total_samples:
            X[sample_idx] = seq_input
            y[sample_idx] = target
            sample_idx += 1
        else:
            print(f"warning: sample_idx={sample_idx} beyond {total_samples}")
            break

In [5]:
n = len(X)
train_X, train_y = X[:int(n*0.8)], y[:int(n*0.8)]
val_X, val_y = X[int(n*0.8):int(n*0.9)], y[int(n*0.8):int(n*0.9)]
test_X, test_y = X[int(n*0.9):], y[int(n*0.9):]

Data loader section

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np


class TaxiDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_dataset = TaxiDataset(train_X, train_y)
val_dataset = TaxiDataset(val_X, val_y)
test_dataset = TaxiDataset(test_X, test_y)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Model section

In [10]:
class TaxiTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward, dropout=0.1):
        super(TaxiTransformer, self).__init__()
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 3, d_model))

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, 
                                                   dim_feedforward=dim_feedforward, dropout=dropout, 
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        self.output_layer = nn.Linear(d_model, 2)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.input_linear(x)  # [batch_size, window_size, d_model]
        x = self.norm(x)
        x = x + self.positional_encoding
        x = self.transformer_encoder(x)
        x_last = x[:, -1, :]  # [batch_size, d_model]
        output = self.output_layer(x_last)  # [batch_size, 2]
        return output

# Model hyperparameters
input_dim = 3
d_model = 64
nhead = 8
num_layers = 2
dim_feedforward = 128
model = TaxiTransformer(input_dim, d_model, nhead, num_layers, dim_feedforward)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = TaxiTransformer(input_dim, d_model, nhead, num_layers, dim_feedforward).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Using device: cuda


Training section

In [11]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch_X, batch_y in train_loader:
        # 将数据移动到 GPU
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)
        
        optimizer.zero_grad()
        preds = model(batch_X)
        loss = criterion(preds, batch_y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * batch_X.size(0)
    train_loss /= len(train_loader.dataset)
    
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_X, batch_y in val_loader:
            # 将验证数据移动到 GPU
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            
            preds = model(batch_X)
            loss = criterion(preds, batch_y)
            val_loss += loss.item() * batch_X.size(0)
    val_loss /= len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")


Epoch 1/10, Train Loss: 32.0846, Val Loss: 0.0274
Epoch 2/10, Train Loss: 0.0234, Val Loss: 0.0225
Epoch 3/10, Train Loss: 0.0227, Val Loss: 0.0283
Epoch 4/10, Train Loss: 0.0227, Val Loss: 0.0305
Epoch 5/10, Train Loss: 0.0226, Val Loss: 0.0233
Epoch 6/10, Train Loss: 0.0232, Val Loss: 0.0228
Epoch 7/10, Train Loss: 0.0217, Val Loss: 0.0219
Epoch 8/10, Train Loss: 0.0218, Val Loss: 0.0304
Epoch 9/10, Train Loss: 0.0217, Val Loss: 0.0210
Epoch 10/10, Train Loss: 0.0213, Val Loss: 0.0254


Evaluation section

In [13]:
model.eval()
test_loss = 0.0
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)
        preds = model(batch_X)
        loss = criterion(preds, batch_y)
        test_loss += loss.item() * batch_X.size(0)
test_loss /= len(test_loader.dataset)
print(f"Test Loss: {test_loss:.4f}")

Test Loss: 0.0221


Model export section

In [14]:
os.chdir("../../")
torch.save(model.state_dict(), "./taxi_transformer.pth")