In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataloader import load_train, load_val
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

In [None]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print("Using device:", device)

In [None]:
LOOKBACK = 30     # number of past timesteps per input window
N_PREDICT = 10    # number of future timesteps the model outputs
N_FEATURES = 2    # LAT, LON

class model_lstm(nn.Module):
    """Stacked LSTM that maps a window of past positions to future positions.

    The full input sequence is run through the LSTM, the output of the final
    timestep is taken, and a bias-free linear layer projects it to
    ``n_predict * n_features`` values, which are reshaped to
    ``(batch, n_predict, n_features)``.

    Args:
        hidden_size: number of hidden units in every LSTM layer.
        num_layers: number of stacked LSTM layers.
        n_features: features per timestep (defaults to ``N_FEATURES``).
        n_predict: number of predicted timesteps (defaults to ``N_PREDICT``).
    """

    def __init__(self, hidden_size, num_layers,
                 n_features=N_FEATURES, n_predict=N_PREDICT):
        super().__init__()

        self.hidden_size = hidden_size
        # Keep the output dimensions on the instance so forward() does not
        # depend on module-level globals that may be edited after the model
        # has been built (a common source of hidden state in notebooks).
        self.n_features = n_features
        self.n_predict = n_predict

        self.lstm = nn.LSTM(
            input_size=n_features,    # number of features per timestep
            hidden_size=hidden_size,  # hidden units per LSTM layer
            num_layers=num_layers,    # number of stacked LSTM layers
            batch_first=True          # input shape: (batch, seq_len, n_features)
        )

        # Attribute names and bias=False are kept unchanged so that state
        # dicts saved by earlier runs still load.
        self.l_out = nn.Linear(
            in_features=hidden_size,
            out_features=n_predict * n_features,
            bias=False
        )

    def forward(self, x):
        """Predict future positions for a batch of input windows.

        Args:
            x: tensor of shape (batch, seq_len, n_features).

        Returns:
            Tensor of shape (batch, n_predict, n_features).
        """
        out, _ = self.lstm(x)            # (batch, seq_len, hidden)
        last_output = out[:, -1, :]      # output at the last timestep: (batch, hidden)
        y = self.l_out(last_output)      # (batch, n_predict * n_features)
        return y.view(-1, self.n_predict, self.n_features)

In [None]:
# Hyperparameters (tune these)
num_epochs = 10     # number of passes over the training set
lr = 1e-3           # learning rate passed to the optimizer
hidden_size = 512   # hidden units per LSTM layer
num_layers = 5      # number of stacked LSTM layers

# Fix the random seeds so weight initialisation and DataLoader shuffling are
# repeatable across "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
batch_size = 2048

# Load the datasets. load_train also returns a scaler, which is handed on to
# load_val (presumably so both splits share one normalisation -- see dataloader.py).
train_ds, scaler = load_train()
val_ds = load_val(scaler)

# Only the training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)


In [None]:
# Build a fresh model on the selected device, with its loss and optimizer.
net = model_lstm(hidden_size=hidden_size, num_layers=num_layers)
net.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

# Per-epoch loss history and the lowest validation loss seen so far.
training_loss = []
validation_loss = []
best_val_loss = float("inf")

for epoch in range(1, num_epochs + 1):
    # ---- training pass ----
    net.train()
    train_total = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False)
    for inputs, targets in train_bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(net(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # The criterion yields a per-batch mean; weight it by the batch size
        # so the epoch figure is a true per-sample average.
        train_total += batch_loss.item() * inputs.size(0)
    epoch_train_loss = train_total / len(train_loader.dataset)

    # ---- validation pass (no gradients, eval mode) ----
    net.eval()
    val_total = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False)
        for inputs, targets in val_bar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_loss = criterion(net(inputs), targets)
            val_total += batch_loss.item() * inputs.size(0)
    epoch_val_loss = val_total / len(val_loader.dataset)

    # Checkpoint the weights whenever the validation loss improves.
    improved = epoch_val_loss < best_val_loss
    if improved:
        best_val_loss = epoch_val_loss
        torch.save(net.state_dict(), "best_lstm_model.pt")

    # Record this epoch's losses for the plot further down.
    training_loss.append(epoch_train_loss)
    validation_loss.append(epoch_val_loss)

    # Report progress once per epoch.
    print(f"Epoch {epoch}/{num_epochs}, Training Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}")

In [None]:
# Plot training and validation loss per epoch.
# The x-axis is derived from the number of recorded losses rather than from
# num_epochs, so the cell still works if training was interrupted early or
# num_epochs was changed after the training cell ran.
epochs_recorded = range(1, len(training_loss) + 1)

plt.figure()
plt.plot(epochs_recorded, training_loss, 'r', label='Training loss')
plt.plot(epochs_recorded, validation_loss, 'b', label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()

In [None]:
def plot_paths(x_sample, y_true, y_pred, idx):
    """Plot one trajectory: observed past, true future and predicted future.

    Column 0 of every array goes on the horizontal axis and column 1 on the
    vertical axis. Values are plotted exactly as passed in (NOTE(review): they
    are presumably still in the scaler's normalised units -- confirm).

    Args:
        x_sample: past positions; reshaped to (timesteps, 2), so either a
            (timesteps, 2) array or its flattened form is accepted.
        y_true: true future positions, indexed as [:, 0] and [:, 1].
        y_pred: predicted future positions, indexed as [:, 0] and [:, 1].
        idx: sample index shown in the figure title.
    """
    # Infer the number of past timesteps instead of hard-coding 30, so the
    # function keeps working when the look-back window changes.
    x_sample = x_sample.reshape(-1, 2)

    plt.figure(figsize=(6,6))
    plt.plot(x_sample[:,0], x_sample[:,1], 'bo-', label='Past')
    plt.plot(y_true[:,0], y_true[:,1], 'go-', label='True')
    plt.plot(y_pred[:,0], y_pred[:,1], 'ro--', label='Predicted')
    plt.xlabel('Latitude')
    plt.ylabel('Longitude')
    plt.title(f'Trajectory Sample {idx}')
    plt.legend()
    plt.show()

num_samples_to_plot = 50

# Rebuild the model and restore the best checkpoint written during training.
# map_location="cpu" lets a checkpoint saved on one accelerator be read on a
# machine without it; load_state_dict then copies the values into the model's
# parameters, which already live on `device`.
net = model_lstm(hidden_size=hidden_size, num_layers=num_layers).to(device)
net.load_state_dict(torch.load("best_lstm_model.pt", map_location="cpu"))
net.eval()

# One figure per validation batch (at most num_samples_to_plot figures),
# always using the first sample of the batch.
for idx, (x, y) in enumerate(val_loader):
    if idx >= num_samples_to_plot:
        break

    # Only the plotted sample is run through the network; predicting the whole
    # batch just to discard all but one row is wasted work.
    with torch.no_grad():
        y_pred = net(x[:1].to(device))

    x_np = x[0].cpu().numpy()   # past window of the first sample
    y_np = y[0].cpu().numpy()   # true future of the first sample
    # The model already returns (N_PREDICT, N_FEATURES) per sample; the reshape
    # uses the shared constants instead of hard-coded 10 and 2.
    y_pred_np = y_pred[0].cpu().numpy().reshape(N_PREDICT, N_FEATURES)

    plot_paths(x_np, y_np, y_pred_np, idx)
