In [1]:
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from tqdm import tqdm
import pandas as pd
import ast

In [2]:
class LSTMModel(nn.Module):
    def __init__(self, input_size=2, hidden_size=64, output_size=2):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, lengths):
        # Pack the padded sequences before feeding to LSTM
        packed_input = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        
        # Pass through the LSTM
        packed_output, (hn, cn) = self.lstm(packed_input)
        
        # Optionally unpack the sequence (if needed)
        unpacked_output, _ = pad_packed_sequence(packed_output, batch_first=True)
        
        # We use only the last hidden state for the output (for simplicity)
        out = self.fc(hn[-1])
        return out

In [3]:
cityA = pd.read_csv("../Datasets/Task 2/frequent_sequences_A.csv")
cityA = list(cityA['Pattern'])
cityA_sequence = []

for str_seq in cityA:
    tuple_seq = ast.literal_eval(str_seq)
    cityA_sequence.append(tuple_seq)

cityA_sequence
train_sequences, test_sequences = train_test_split(cityA_sequence, test_size=0.2, random_state=42)




In [4]:
def get_inputs_labels(dataset):
    inputs = []
    labels = []
    for seq in dataset:
        inputs.append(torch.tensor(seq[:-1], dtype=torch.float32))
        labels.append(torch.tensor(seq[-1], dtype=torch.float32))

    return inputs, labels

In [5]:
X_train, y_train = get_inputs_labels(train_sequences)

model = LSTMModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for seq, target in tqdm(zip(X_train, y_train), total=len(X_train), desc=f"Epoch {epoch + 1}/{num_epochs}"):       
        seq_length = torch.tensor([len(seq)])  # Length of the sequence
        padded_seq = pad_sequence([seq], batch_first=True, padding_value=0).to(torch.float32)

        # Convert target to tensor
        target = target.unsqueeze(0)  # Add batch dimension

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(padded_seq, seq_length)

        # Calculate loss
        loss = criterion(output, target)
        epoch_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    # Print loss for every epoch
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(X_train):.4f}")
        
print("Training complete.")



Epoch 1/100: 100%|██████████| 17815/17815 [00:30<00:00, 587.77it/s]
Epoch 2/100: 100%|██████████| 17815/17815 [00:30<00:00, 588.30it/s]
Epoch 3/100: 100%|██████████| 17815/17815 [00:29<00:00, 596.64it/s]
Epoch 4/100: 100%|██████████| 17815/17815 [00:29<00:00, 608.67it/s]
Epoch 5/100: 100%|██████████| 17815/17815 [00:29<00:00, 603.32it/s]
Epoch 6/100: 100%|██████████| 17815/17815 [00:29<00:00, 608.81it/s]
Epoch 7/100: 100%|██████████| 17815/17815 [00:30<00:00, 588.90it/s]
Epoch 8/100: 100%|██████████| 17815/17815 [00:27<00:00, 640.67it/s]
Epoch 9/100: 100%|██████████| 17815/17815 [00:27<00:00, 649.54it/s]
Epoch 10/100: 100%|██████████| 17815/17815 [00:28<00:00, 622.74it/s]


Epoch [10/100], Loss: 27.0909


Epoch 11/100: 100%|██████████| 17815/17815 [00:30<00:00, 591.27it/s]
Epoch 12/100: 100%|██████████| 17815/17815 [00:29<00:00, 605.20it/s]
Epoch 13/100: 100%|██████████| 17815/17815 [00:28<00:00, 627.28it/s]
Epoch 14/100: 100%|██████████| 17815/17815 [00:27<00:00, 652.21it/s]
Epoch 15/100: 100%|██████████| 17815/17815 [00:27<00:00, 646.57it/s]
Epoch 16/100: 100%|██████████| 17815/17815 [00:29<00:00, 610.08it/s]
Epoch 17/100: 100%|██████████| 17815/17815 [00:28<00:00, 618.59it/s]
Epoch 18/100: 100%|██████████| 17815/17815 [00:29<00:00, 603.43it/s]
Epoch 19/100: 100%|██████████| 17815/17815 [00:29<00:00, 600.18it/s]
Epoch 20/100: 100%|██████████| 17815/17815 [00:28<00:00, 629.84it/s]


Epoch [20/100], Loss: 24.1404


Epoch 21/100: 100%|██████████| 17815/17815 [00:30<00:00, 586.78it/s]
Epoch 22/100: 100%|██████████| 17815/17815 [00:28<00:00, 631.51it/s]
Epoch 23/100: 100%|██████████| 17815/17815 [00:27<00:00, 638.86it/s]
Epoch 24/100: 100%|██████████| 17815/17815 [00:27<00:00, 644.58it/s]
Epoch 25/100: 100%|██████████| 17815/17815 [00:30<00:00, 583.02it/s]
Epoch 26/100: 100%|██████████| 17815/17815 [00:30<00:00, 587.94it/s]
Epoch 27/100: 100%|██████████| 17815/17815 [00:30<00:00, 589.87it/s]
Epoch 28/100: 100%|██████████| 17815/17815 [00:30<00:00, 575.78it/s]
Epoch 29/100: 100%|██████████| 17815/17815 [00:29<00:00, 612.47it/s]
Epoch 30/100: 100%|██████████| 17815/17815 [00:30<00:00, 581.69it/s]


Epoch [30/100], Loss: 24.4113


Epoch 31/100: 100%|██████████| 17815/17815 [00:31<00:00, 559.03it/s]
Epoch 32/100: 100%|██████████| 17815/17815 [00:36<00:00, 485.36it/s]
Epoch 33/100: 100%|██████████| 17815/17815 [00:35<00:00, 500.43it/s]
Epoch 34/100: 100%|██████████| 17815/17815 [00:30<00:00, 585.89it/s]
Epoch 35/100: 100%|██████████| 17815/17815 [00:30<00:00, 585.53it/s]
Epoch 36/100: 100%|██████████| 17815/17815 [00:29<00:00, 600.68it/s]
Epoch 37/100: 100%|██████████| 17815/17815 [00:29<00:00, 613.67it/s]
Epoch 38/100: 100%|██████████| 17815/17815 [00:31<00:00, 565.62it/s]
Epoch 39/100: 100%|██████████| 17815/17815 [00:28<00:00, 617.27it/s]
Epoch 40/100: 100%|██████████| 17815/17815 [00:28<00:00, 625.09it/s]


Epoch [40/100], Loss: 23.1239


Epoch 41/100: 100%|██████████| 17815/17815 [00:27<00:00, 637.11it/s]
Epoch 42/100: 100%|██████████| 17815/17815 [00:27<00:00, 645.24it/s]
Epoch 43/100: 100%|██████████| 17815/17815 [00:27<00:00, 642.24it/s]
Epoch 44/100: 100%|██████████| 17815/17815 [00:28<00:00, 626.97it/s]
Epoch 45/100: 100%|██████████| 17815/17815 [00:30<00:00, 579.50it/s]
Epoch 46/100: 100%|██████████| 17815/17815 [00:30<00:00, 586.47it/s]
Epoch 47/100: 100%|██████████| 17815/17815 [00:29<00:00, 594.03it/s]
Epoch 48/100: 100%|██████████| 17815/17815 [00:30<00:00, 591.55it/s]
Epoch 49/100: 100%|██████████| 17815/17815 [00:29<00:00, 603.95it/s]
Epoch 50/100: 100%|██████████| 17815/17815 [00:27<00:00, 649.39it/s]


Epoch [50/100], Loss: 22.8884


Epoch 51/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.63it/s]
Epoch 52/100: 100%|██████████| 17815/17815 [00:27<00:00, 653.70it/s]
Epoch 53/100: 100%|██████████| 17815/17815 [00:28<00:00, 628.54it/s]
Epoch 54/100: 100%|██████████| 17815/17815 [00:27<00:00, 645.49it/s]
Epoch 55/100: 100%|██████████| 17815/17815 [00:27<00:00, 646.56it/s]
Epoch 56/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.72it/s]
Epoch 57/100: 100%|██████████| 17815/17815 [00:27<00:00, 650.83it/s]
Epoch 58/100: 100%|██████████| 17815/17815 [00:27<00:00, 648.69it/s]
Epoch 59/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.91it/s]
Epoch 60/100: 100%|██████████| 17815/17815 [00:27<00:00, 648.74it/s]


Epoch [60/100], Loss: 26.0472


Epoch 61/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.14it/s]
Epoch 62/100: 100%|██████████| 17815/17815 [00:30<00:00, 588.03it/s]
Epoch 63/100: 100%|██████████| 17815/17815 [00:33<00:00, 527.26it/s]
Epoch 64/100: 100%|██████████| 17815/17815 [00:29<00:00, 601.02it/s]
Epoch 65/100: 100%|██████████| 17815/17815 [00:29<00:00, 609.44it/s]
Epoch 66/100: 100%|██████████| 17815/17815 [00:28<00:00, 626.52it/s]
Epoch 67/100: 100%|██████████| 17815/17815 [00:28<00:00, 634.72it/s]
Epoch 68/100: 100%|██████████| 17815/17815 [00:27<00:00, 644.00it/s]
Epoch 69/100: 100%|██████████| 17815/17815 [00:27<00:00, 649.02it/s]
Epoch 70/100: 100%|██████████| 17815/17815 [00:27<00:00, 636.69it/s]


Epoch [70/100], Loss: 23.2294


Epoch 71/100: 100%|██████████| 17815/17815 [00:31<00:00, 571.37it/s]
Epoch 72/100: 100%|██████████| 17815/17815 [00:27<00:00, 636.28it/s]
Epoch 73/100: 100%|██████████| 17815/17815 [00:27<00:00, 645.74it/s]
Epoch 74/100: 100%|██████████| 17815/17815 [00:27<00:00, 644.25it/s]
Epoch 75/100: 100%|██████████| 17815/17815 [00:27<00:00, 644.50it/s]
Epoch 76/100: 100%|██████████| 17815/17815 [00:27<00:00, 636.60it/s]
Epoch 77/100: 100%|██████████| 17815/17815 [00:28<00:00, 622.73it/s]
Epoch 78/100: 100%|██████████| 17815/17815 [00:28<00:00, 626.37it/s]
Epoch 79/100: 100%|██████████| 17815/17815 [00:27<00:00, 646.58it/s]
Epoch 80/100: 100%|██████████| 17815/17815 [00:27<00:00, 648.75it/s]


Epoch [80/100], Loss: 24.4928


Epoch 81/100: 100%|██████████| 17815/17815 [00:27<00:00, 642.04it/s]
Epoch 82/100: 100%|██████████| 17815/17815 [00:27<00:00, 642.49it/s]
Epoch 83/100: 100%|██████████| 17815/17815 [00:27<00:00, 645.70it/s]
Epoch 84/100: 100%|██████████| 17815/17815 [00:27<00:00, 648.52it/s]
Epoch 85/100: 100%|██████████| 17815/17815 [00:27<00:00, 648.65it/s]
Epoch 86/100: 100%|██████████| 17815/17815 [00:27<00:00, 644.66it/s]
Epoch 87/100: 100%|██████████| 17815/17815 [00:27<00:00, 649.29it/s]
Epoch 88/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.92it/s]
Epoch 89/100: 100%|██████████| 17815/17815 [00:27<00:00, 652.69it/s]
Epoch 90/100: 100%|██████████| 17815/17815 [00:27<00:00, 655.13it/s]


Epoch [90/100], Loss: 24.6480


Epoch 91/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.73it/s]
Epoch 92/100: 100%|██████████| 17815/17815 [00:27<00:00, 648.51it/s]
Epoch 93/100: 100%|██████████| 17815/17815 [00:27<00:00, 651.10it/s]
Epoch 94/100: 100%|██████████| 17815/17815 [00:27<00:00, 654.56it/s]
Epoch 95/100: 100%|██████████| 17815/17815 [00:27<00:00, 649.72it/s]
Epoch 96/100: 100%|██████████| 17815/17815 [00:27<00:00, 653.15it/s]
Epoch 97/100: 100%|██████████| 17815/17815 [00:28<00:00, 633.47it/s]
Epoch 98/100: 100%|██████████| 17815/17815 [00:27<00:00, 647.63it/s]
Epoch 99/100: 100%|██████████| 17815/17815 [00:28<00:00, 627.44it/s]
Epoch 100/100: 100%|██████████| 17815/17815 [00:27<00:00, 637.75it/s]

Epoch [100/100], Loss: 22.9910
Training complete.





In [6]:
torch.save(model.state_dict(), "lstm_model.pth")
print('saved to lstm_model.pth')

saved to lstm_model.pth


In [7]:
X_test, y_test = get_inputs_labels(test_sequences)

model.eval()
predictions = []

with torch.no_grad():
    for seq in X_test:
        # Prepare sequence
        seq_length = torch.tensor([len(seq)])  # Sequence length
        padded_seq = pad_sequence([seq], batch_first=True).float()  # Convert to float if needed

        # Forward pass to get the prediction
        pred = model(padded_seq, seq_length)
        predictions.append(pred.squeeze().cpu())  # Append the prediction, removing extra dimensions

# Display predictions and compare to true targets
for i, (pred, true_target) in enumerate(zip(predictions, y_test)):
    print(f"Test Sequence {i+1}:")
    print(f"  Predicted next coordinate: {pred.numpy()}")
    print(f"  True next coordinate: {true_target.numpy()}")

Test Sequence 1:
  Predicted next coordinate: [176.1075  45.214 ]
  True next coordinate: [176.  48.]
Test Sequence 2:
  Predicted next coordinate: [ 54.77323 168.20311]
  True next coordinate: [ 56. 168.]
Test Sequence 3:
  Predicted next coordinate: [ 18.27745 109.11273]
  True next coordinate: [ 24. 111.]
Test Sequence 4:
  Predicted next coordinate: [ 96.6445 110.6371]
  True next coordinate: [ 97. 111.]
Test Sequence 5:
  Predicted next coordinate: [ 1.5959778 34.46978  ]
  True next coordinate: [ 1. 33.]
Test Sequence 6:
  Predicted next coordinate: [117.64739  96.80162]
  True next coordinate: [117. 101.]
Test Sequence 7:
  Predicted next coordinate: [182.91556 103.50127]
  True next coordinate: [183. 104.]
Test Sequence 8:
  Predicted next coordinate: [143.89758   88.020996]
  True next coordinate: [142.  90.]
Test Sequence 9:
  Predicted next coordinate: [157.70715   37.794777]
  True next coordinate: [156.  40.]
Test Sequence 10:
  Predicted next coordinate: [ 76.88026 100.50

In [9]:
def pred_next_coord(sequence):
    seq_tensor = torch.tensor(sequence, dtype=torch.float32).unsqueeze(0)
    seq_length = torch.tensor([seq_tensor.size(1)])
    model.eval()

    with torch.no_grad():
        predicted_next_coord = model(seq_tensor, seq_length)

    print("Next coordinate predicted: ", predicted_next_coord.squeeze().numpy())

In [14]:
sequence = ((88, 95), (77, 99), (77, 100))


pred_next_coord(sequence)

Next coordinate predicted:  [73.32224  97.934845]
