In [14]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

# --- Constants ---------------------------------------------------------------
# Width of the model output: one number per sample (the predicted people count).
result_num = 1
# Default per-timestep input width of TaxiRNN's recurrent layers.
feature_num = 49

# --- Hyper-parameters --------------------------------------------------------
embedding_dim, hidden_dim, dense_dim, n_layers = 10, 16, 32, 1

# --- Training configuration --------------------------------------------------
epochs, batch_size = 70, 72
learning_rate = 1e-3

class TaxiDataset(Dataset):
    """Map-style dataset over four parallel NumPy arrays.

    Each array is wrapped as a tensor with ``torch.from_numpy`` (no copy, the
    dtype of the source array is kept). Item ``i`` is the 4-tuple
    ``(features[i], labels[i], locations[i], times[i])``.

    Args:
        x: Array of input features, indexed by sample along axis 0.
        y: Array of labels; its length defines the dataset length.
        loc: Array of per-sample locations.
        time: Array of per-sample times.
    """

    def __init__(self, x, y, loc, time):
        as_tensor = torch.from_numpy
        self.features = as_tensor(x)
        self.labels = as_tensor(y)
        self.locations = as_tensor(loc)
        self.times = as_tensor(time)

    def __getitem__(self, index):
        """Return ``(features, labels, locations, times)`` at ``index``."""
        columns = (self.features, self.labels, self.locations, self.times)
        return tuple(column[index] for column in columns)

    def __len__(self):
        """Return the number of samples, taken from the label tensor."""
        sample_count = len(self.labels)
        return sample_count

class TaxiCNN(nn.Module):
    """Two-stage convolutional network followed by a linear head.

    Each stage is Conv2d(kernel 2, stride 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(kernel 2, stride 2), producing 16 and then 32 channels. The
    flattened result is mapped to ``num_classes`` outputs.

    Args:
        num_classes: Width of the output layer. Previously this name was read
            without ever being defined, so constructing the model raised
            ``NameError``; it is now a parameter defaulting to a single
            regression output.
        flat_features: Number of values per sample after the second stage is
            flattened. The default 352 equals 32 channels x 11 spatial
            positions, which is what a ``(N, 1, 8, 49)`` input produces
            (8 -> 7 -> 3 -> 2 -> 1 and 49 -> 48 -> 24 -> 23 -> 11).
            NOTE(review): the intended input size is not shown in this file;
            confirm against the caller.
    """

    def __init__(
        self,
        num_classes=1,
        flat_features=352,
    ):
        super(TaxiCNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, out_channels=16, kernel_size=2, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, out_channels=32, kernel_size=2, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Run both conv stages, flatten per sample, and apply the linear head.

        Args:
            x: Single-channel image batch of shape ``(N, 1, H, W)`` whose
                flattened size after both stages equals ``flat_features``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        # Keep the batch axis, collapse channels and spatial axes into one.
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

"""
class TaxiRNN(nn.Module):
    def __init__(
        self,
        sequence_size,
        embedding_dim=1,
        hidden_dim=100,
        dense_dim=32,
        max_norm=2,
        n_layers=1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            sequence_size,
            embedding_dim,
            norm_type=2,
            max_norm=max_norm,
        )
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=n_layers)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, dense_dim)
            nn.ReLU()
            nn.Linear(dense_dim, result_num)
        )
        
    def __forward__(self, features, locations, times):
        embeds = self.embedding(features)
       
"""
# Try to decrease the dimension of the feature sequence
class TaxiRNN(nn.Module):
    """LSTM feature compressor followed by a small MLP regressor.

    The LSTM (``batch_first=True``) reads ``features`` and emits
    ``hidden_dim`` values per timestep. Those values are flattened per sample,
    concatenated with ``locations`` and ``times``, and passed through
    Linear -> ReLU -> Linear to produce ``result_num`` outputs.

    Args:
        feature_size: Per-timestep input width of the recurrent layers.
        hidden_dim: Hidden size of the recurrent layers.
        dense_dim: Width of the hidden layer in the MLP head.
        n_layers: Number of stacked recurrent layers.
        seq_len: Number of timesteps in ``features``. Together with
            ``hidden_dim`` it determines the flattened LSTM width. The default
            reproduces the previously hard-coded ``8``.
        location_dim: Width of ``locations`` per sample. The default
            reproduces the previously hard-coded ``2``.
    """

    def __init__(
        self,
        feature_size=feature_num,
        hidden_dim=1,
        dense_dim=16,
        n_layers=1,
        seq_len=8,
        location_dim=2,
    ):
        super(TaxiRNN, self).__init__()
        self.lstm = nn.LSTM(feature_size, hidden_dim, batch_first=True, num_layers=n_layers)
        # Not used by forward(); kept so state_dict keys and parameter
        # initialisation order stay compatible with existing checkpoints.
        self.gru = nn.GRU(feature_size, hidden_dim, batch_first=True, num_layers=n_layers)
        # Flattened LSTM output + locations + one time value per sample.
        # With the defaults this is 8 * 1 + 2 + 1 = 11, as before.
        mlp_in = seq_len * hidden_dim + location_dim + 1
        self.mlp = nn.Sequential(
            nn.Linear(mlp_in, dense_dim),
            nn.ReLU(),
            nn.Linear(dense_dim, result_num)
        )

    def forward(self, features, locations, times):
        """Predict ``result_num`` values per sample.

        Args:
            features: Tensor of shape ``(batch, seq_len, feature_size)``.
            locations: Tensor expected to be ``(batch, location_dim)``.
            times: Tensor expected to be ``(batch,)``; a feature axis is added
                here before concatenation.

        Returns:
            Tensor of shape ``(batch, result_num)``.
        """
        # The LSTM reduces each timestep to hidden_dim values.
        rnn, _ = self.lstm(features)
        # Flatten per sample while always keeping the batch axis. A bare
        # squeeze() would also drop the batch axis for a single-sample batch
        # and would leave a 3-D tensor when hidden_dim > 1; both make the
        # concatenation below fail.
        flat = rnn.reshape(rnn.size(0), -1)
        times = times.unsqueeze(1)
        cat = torch.cat((flat, locations, times), dim=1)
        output = self.mlp(cat)
        return output
    
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Read the training arrays (features, labels, locations, times).
train_data = np.load('./train.npz')
train_features, train_labels, train_locations, train_times = (
    train_data[key] for key in ('x', 'y', 'locations', 'times')
)

# Read the validation arrays, stored under the same keys.
val_data = np.load('./val.npz')
val_features, val_labels, val_locations, val_times = (
    val_data[key] for key in ('x', 'y', 'locations', 'times')
)

# Wrap the arrays as datasets.
train_dataset = TaxiDataset(train_features, train_labels, train_locations, train_times)
val_dataset = TaxiDataset(val_features, val_labels, val_locations, val_times)

# Batch and shuffle both splits.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Build the model with its default sizes and move it to the chosen device.
model = TaxiRNN()
model = model.to(device)

# Mean-squared-error criterion and Adam over all model parameters.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model: minimise the per-batch RMSE (sqrt of MSELoss) with the
# optimizer configured above, logging every 500 steps.
model.train()
total_step = len(train_dataloader)
for epoch in range(epochs):
    for i, (features, labels, locations, times) in enumerate(train_dataloader):
        # Cast to float32 and move to the target device in one call.
        features = features.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.float32)
        locations = locations.to(device=device, dtype=torch.float32)
        times = times.to(device=device, dtype=torch.float32)

        # Forward pass
        outputs = model(features, locations, times)
        # Give labels the exact shape of the outputs. If labels arrive as
        # (batch,) while outputs are (batch, result_num), MSELoss would
        # silently broadcast to (batch, batch) and optimise the wrong
        # quantity. This is a no-op when the shapes already match.
        labels = labels.reshape(outputs.shape)
        # NOTE(review): the gradient of sqrt is unbounded at exactly zero
        # loss; add a small epsilon inside the sqrt if that can occur.
        loss = torch.sqrt(criterion(outputs, labels))

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 500 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, epochs, i+1, total_step, loss.item()))

# Test the model
model.eval()  # switch layers such as dropout/batchnorm to inference behaviour
all_loss = []  # per-batch RMSE values, kept for later inspection
squared_error_sum = 0.0
element_count = 0
with torch.no_grad():
    for features, labels, locations, times in val_dataloader:
        features = features.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.float32)
        locations = locations.to(device=device, dtype=torch.float32)
        times = times.to(device=device, dtype=torch.float32)

        outputs = model(features, locations, times)
        # Align labels with outputs so the error is element-wise rather than
        # broadcast across the batch (no-op when shapes already match).
        labels = labels.reshape(outputs.shape)
        all_loss.append(torch.sqrt(criterion(outputs, labels)).item())

        # Accumulate the raw squared error. Averaging per-batch RMSEs is not
        # the dataset RMSE: sqrt is non-linear and the last batch may be
        # smaller than the others.
        squared_error_sum += torch.sum((outputs - labels) ** 2).item()
        element_count += labels.numel()

# RMSE over every validation element; nan when the loader yielded nothing.
rmse = (squared_error_sum / element_count) ** 0.5 if element_count else float('nan')
print(f'The Root Main square error is {rmse}')

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

#with open('./time.txt',"a") as f:
#    for n in times:
#        f.write(str(n))
#    f.close()

Epoch [1/70], Step [500/1000], Loss: 11.7513
Epoch [1/70], Step [1000/1000], Loss: 9.0430
Epoch [2/70], Step [500/1000], Loss: 14.9804
Epoch [2/70], Step [1000/1000], Loss: 11.0783
Epoch [3/70], Step [500/1000], Loss: 7.1718
Epoch [3/70], Step [1000/1000], Loss: 10.7226
Epoch [4/70], Step [500/1000], Loss: 5.9076
Epoch [4/70], Step [1000/1000], Loss: 7.7332
Epoch [5/70], Step [500/1000], Loss: 6.3698
Epoch [5/70], Step [1000/1000], Loss: 7.1084
Epoch [6/70], Step [500/1000], Loss: 6.9797
Epoch [6/70], Step [1000/1000], Loss: 5.9480
Epoch [7/70], Step [500/1000], Loss: 6.2791
Epoch [7/70], Step [1000/1000], Loss: 7.3541
Epoch [8/70], Step [500/1000], Loss: 7.5475
Epoch [8/70], Step [1000/1000], Loss: 6.1868
Epoch [9/70], Step [500/1000], Loss: 7.8119
Epoch [9/70], Step [1000/1000], Loss: 7.5621
Epoch [10/70], Step [500/1000], Loss: 8.1741
Epoch [10/70], Step [1000/1000], Loss: 7.2734
Epoch [11/70], Step [500/1000], Loss: 6.4317
Epoch [11/70], Step [1000/1000], Loss: 4.6268
Epoch [12/70]