In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataloader import load_train, load_val
import matplotlib.pyplot as plt
from tqdm import tqdm


In [None]:
# Choose the compute backend: start from the CPU fallback and upgrade to an
# accelerator when one is reported as available. MPS is probed first; CUDA is
# only probed when MPS is unavailable.
backend_name = "cpu"
if torch.backends.mps.is_available():
    backend_name = "mps"
elif torch.cuda.is_available():
    backend_name = "cuda"

device = torch.device(backend_name)
print("Using device:", device)

In [None]:
LOOKBACK = 30     # nominal number of timesteps in each input window
N_PREDICT = 10    # number of future timesteps predicted per window
N_FEATURES = 4    # LAT, LON, SOG, COG


class model_lstm(nn.Module):
    """LSTM encoder with a linear head that predicts a block of future steps.

    The input sequence is run through a (possibly stacked) LSTM; the hidden
    output at the final timestep is projected by a bias-free linear layer to
    ``N_PREDICT * N_FEATURES`` values and reshaped to
    ``(batch, N_PREDICT, N_FEATURES)``.

    Args:
        hidden_size: Number of hidden units in each LSTM layer.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size=64, num_layers=1):
        super().__init__()

        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(
            input_size=N_FEATURES,    # features per timestep
            hidden_size=hidden_size,  # hidden units per LSTM layer
            num_layers=num_layers,    # number of stacked LSTM layers
            batch_first=True,         # inputs are (batch, seq_len, N_FEATURES)
        )

        # Bias-free projection from the last hidden output to the flattened
        # (N_PREDICT x N_FEATURES) prediction.
        self.l_out = nn.Linear(
            in_features=hidden_size,
            out_features=N_PREDICT * N_FEATURES,
            bias=False,
        )

    def forward(self, x):
        """Predict the next ``N_PREDICT`` steps from an input window.

        Args:
            x: Tensor of shape ``(batch, seq_len, N_FEATURES)``. ``seq_len`` is
                normally ``LOOKBACK`` but any positive length is accepted.

        Returns:
            Tensor of shape ``(batch, N_PREDICT, N_FEATURES)``.
        """
        seq_out, _ = self.lstm(x)  # (batch, seq_len, hidden_size)
        # Only the final timestep feeds the head. Selecting it *before* the
        # linear layer avoids projecting every timestep and removes the
        # dependence on a fixed sequence length that a reshape by LOOKBACK
        # would impose.
        last_step = seq_out[:, -1, :]  # (batch, hidden_size)
        flat_pred = self.l_out(last_step)  # (batch, N_PREDICT * N_FEATURES)
        return flat_pred.view(-1, N_PREDICT, N_FEATURES)


net = model_lstm()
print(net)

In [None]:
# Hyperparameters
# These values are read by the data-loading, model-construction and training
# cells; change them here and re-run those cells for them to take effect.
num_epochs = 50    # number of passes of the training loop
batch_size = 2048  # batch size passed to both DataLoaders
lr = 1e-3          # learning rate passed to the Adam optimizer
hidden_size = 64   # hidden units per LSTM layer (model_lstm argument)
num_layers = 1     # number of stacked LSTM layers (model_lstm argument)

In [None]:
# Load the train/validation datasets and wrap them in batched loaders.
train_ds, val_ds = load_train(), load_val()

# Options shared by both loaders; only the shuffling policy differs.
loader_opts = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_opts)   # reshuffled every epoch
val_loader = DataLoader(val_ds, shuffle=False, **loader_opts)      # fixed order


In [None]:
# Fresh model, loss and optimizer for this training run.
net = model_lstm(hidden_size=hidden_size, num_layers=num_layers)
net.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

# Per-epoch losses. Each entry is the mean over samples: batch means are
# weighted by batch size so a smaller final batch does not skew the average.
training_loss, validation_loss = [], []
best_val_loss = float("inf")

for epoch in range(1, num_epochs + 1):
    # --- training pass ---
    net.train()
    epoch_train_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = net(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item() * x.size(0)
    epoch_train_loss /= len(train_loader.dataset)

    # --- validation pass ---
    # Runs AFTER this epoch's updates so that validation_loss[k] and
    # training_loss[k] describe the same epoch, and so that the weights
    # produced by the final epoch are also evaluated and eligible to be saved.
    net.eval()
    epoch_val_loss = 0.0
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False):
            x, y = x.to(device), y.to(device)
            y_pred = net(x)
            loss = criterion(y_pred, y)
            epoch_val_loss += loss.item() * x.size(0)
    epoch_val_loss /= len(val_loader.dataset)

    # Checkpoint whenever the validation loss improves.
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(net.state_dict(), "best_lstm_model.pt")

    # Record this epoch's losses.
    training_loss.append(epoch_train_loss)
    validation_loss.append(epoch_val_loss)

In [None]:
# Plot training and validation loss per epoch.
# The epoch axes are derived from the recorded histories rather than from
# num_epochs, so the cell still works if training was interrupted early or
# num_epochs was edited after the training cell ran.
train_epochs = range(1, len(training_loss) + 1)
val_epochs = range(1, len(validation_loss) + 1)

fig, ax = plt.subplots()
ax.plot(train_epochs, training_loss, 'r', label='Training loss')
ax.plot(val_epochs, validation_loss, 'b', label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.legend()
ax.set_title('Training and Validation Loss')
plt.show()

# Print the recorded training and validation losses, one line per epoch.
for epoch_idx, (train_l, val_l) in enumerate(zip(training_loss, validation_loss), start=1):
    print(f"Epoch {epoch_idx}: Training Loss = {train_l:.6f}, Validation Loss = {val_l:.6f}")

In [None]:
from plot_trajectory import plot_paths

# Number of validation examples to visualise; one plot_paths call per example.
num_samples_to_plot = 5

# Rebuild the architecture and restore the best checkpoint written by the
# training loop. map_location lets the file load even when it was saved from a
# different device than the one in use now.
net = model_lstm(hidden_size=hidden_size, num_layers=num_layers).to(device)
net.load_state_dict(torch.load("best_lstm_model.pt", map_location=device))
net.eval()

# val_loader keeps the batch size it was constructed with, so the global
# batch_size is deliberately NOT reassigned here (doing so would have no effect
# on the loader and would clobber the training hyperparameter). Instead, the
# first sample of each of the first `num_samples_to_plot` batches is plotted.
for idx, (x, y) in zip(range(num_samples_to_plot), val_loader):
    # Only the plotted sample is sent through the model.
    x_first = x[:1].to(device)
    with torch.no_grad():
        y_pred = net(x_first)

    x_np = x[0].cpu().numpy()  # input window, expected (LOOKBACK, N_FEATURES)
    y_np = y[0].cpu().numpy()  # ground truth, expected (N_PREDICT, N_FEATURES)
    # The model already returns (1, N_PREDICT, N_FEATURES); the reshape only
    # makes the expected layout explicit.
    y_pred_np = y_pred[0].cpu().numpy().reshape(N_PREDICT, N_FEATURES)

    plot_paths(x_np, y_np, y_pred_np, idx)
