# Predict new flight positions

## 1. Load the saved model

In [1]:
# Import necessary libraries
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import math

# Select the compute device: the CUDA GPU when available, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Number of positions the learned position embedding can address,
# i.e. the longest input sequence the model can process
context_size = 2048

class PositionEmbedding(torch.nn.Module):
    """Add a learned position embedding to a sequence.

    Holds one trainable 6-dimensional vector per position (``context_size``
    positions) and adds the vector for position ``t`` to step ``t`` of the
    input. There is no token embedding: the input features are used as-is.
    """

    def __init__(self):
        """Create the (context_size x 6) position lookup table."""
        super().__init__()
        self.position_emb = torch.nn.Embedding(num_embeddings=context_size, embedding_dim=6)

    def forward(self, x):
        """Return ``x`` plus the embedding of each position along dim 1.

        Args:
            x: tensor whose dim 1 is the sequence axis and whose last dim is 6;
                the (seq_len, 6) embedding is broadcast over leading dims.

        Raises:
            ValueError: if the sequence is longer than the embedding table.
        """
        len_input = x.size(1)
        max_len = self.position_emb.num_embeddings
        if len_input > max_len:
            # Fail early with a readable message; an out-of-range lookup would
            # otherwise surface as an opaque IndexError / CUDA device assert.
            raise ValueError(
                f"sequence length {len_input} exceeds the {max_len} positions "
                "supported by the position embedding"
            )
        # Create the indices on the input's own device (not a global one) so the
        # module keeps working after model.to(...) or when run on the CPU.
        positions = torch.arange(len_input, device=x.device)
        return x + self.position_emb(positions)

def create_attention_mask(key_length, query_length, dtype):
    """
    Build a causal attention mask for the multi-head attention layer.

    Returns a (query_length, key_length) tensor cast to ``dtype`` whose entry
    (q, k) is True (or 1) when query q must NOT attend to key k, i.e. when k
    lies more than ``key_length - query_length`` positions after q. With equal
    lengths this blocks every key after the query's own position.
    """
    rows = torch.arange(query_length).unsqueeze(1)  # (query_length, 1)
    cols = torch.arange(key_length).unsqueeze(0)    # (1, key_length)
    # Queries line up with the last `query_length` keys.
    offset = key_length - query_length
    blocked = (cols - rows) > offset
    return blocked.to(dtype)

class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block: masked self-attention, then a feed-forward net.

    Each sub-layer's output goes through dropout, is added back to its input
    (residual connection) and is normalised with LayerNorm.
    """

    def __init__(self, num_heads, embed_dim, ff_dim, mask_function, dropout_rate=0.1):
        """Create the attention, feed-forward and normalisation layers.

        Args:
            num_heads: number of attention heads.
            embed_dim: feature size of the block's input and output.
            ff_dim: hidden width of the feed-forward net.
            mask_function: called as ``mask_function(seq_len, seq_len, torch.bool)``
                in ``forward``; its result is passed as ``attn_mask`` to
                ``MultiheadAttention`` (for a bool mask, True = may not attend).
            dropout_rate: dropout probability applied after each sub-layer.
        """
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.dropout_1 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm_1 = torch.nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.ffn_1 = torch.nn.Linear(in_features=embed_dim, out_features=ff_dim)
        self.ffn_2 = torch.nn.Linear(in_features=ff_dim, out_features=embed_dim)
        self.dropout_2 = torch.nn.Dropout(p=dropout_rate)
        self.layer_norm_2 = torch.nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.mask_function = mask_function

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``inputs`` of shape (batch, seq_len, embed_dim)."""
        seq_len = inputs.size(1)
        # Build the mask on the activations' device rather than a global one,
        # so the block still runs after being moved to another device.
        mask = self.mask_function(seq_len, seq_len, torch.bool).to(inputs.device)
        attention_output, _ = self.attn(
            query=inputs, key=inputs, value=inputs, attn_mask=mask
        )
        out1 = self.layer_norm_1(inputs + self.dropout_1(attention_output))
        # NOTE(review): there is no activation between ffn_1 and ffn_2, so the
        # feed-forward net is a single affine map. Adding one (e.g. ReLU) would
        # change the outputs of already-trained checkpoints, so it is kept as-is.
        ffn_output = self.dropout_2(self.ffn_2(self.ffn_1(out1)))
        return self.layer_norm_2(out1 + ffn_output)

class FlightModel(torch.nn.Module):
    """Causally masked transformer over 6-feature flight positions.

    Adds a learned position embedding, runs a stack of TransformerBlocks
    (each using ``create_attention_mask``) and maps every step back to
    6 features with a final linear layer.
    """

    def __init__(self, feed_forward_dim, num_heads, num_layers=6):
        """Build the model.

        Args:
            feed_forward_dim: hidden width of each block's feed-forward net.
            num_heads: attention heads per block.
            num_layers: number of stacked TransformerBlocks; the default of 6
                matches the saved checkpoint loaded in this notebook.
        """
        super().__init__()
        self.embedding_layer = PositionEmbedding()
        self.transformer_layers = nn.ModuleList(
            [
                TransformerBlock(
                    num_heads=num_heads,
                    embed_dim=6,
                    ff_dim=feed_forward_dim,
                    mask_function=create_attention_mask,
                )
                for _ in range(num_layers)
            ]
        )
        self.output_layer = torch.nn.Linear(6, 6)

    def forward(self, input_tensor):
        """Return a 6-feature output for every step of ``input_tensor``.

        Args:
            input_tensor: normalised positions, presumably (batch, seq_len, 6)
                given the batch_first attention layers — TODO confirm callers.
        """
        hidden = self.embedding_layer(input_tensor)
        for layer in self.transformer_layers:
            hidden = layer(hidden)
        return self.output_layer(hidden)

# FlightModel(feed_forward_dim=24, num_heads=1): these must match the values
# the checkpoint was trained with, otherwise load_state_dict fails or misbehaves.
model = FlightModel(24, 1).to(device)
# map_location lets a checkpoint saved from a GPU run be restored when `device`
# fell back to the CPU; weights_only avoids unpickling arbitrary objects.
model.load_state_dict(
    torch.load("./flight_prediction_model_78.pt", map_location=device, weights_only=True)
)
# Inference mode: turns off the dropout layers inside the transformer blocks
model.eval()

cuda


FlightModel(
  (embedding_layer): PositionEmbedding(
    (position_emb): Embedding(2048, 6)
  )
  (transformer_layers): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=6, out_features=6, bias=True)
      )
      (dropout_1): Dropout(p=0.1, inplace=False)
      (layer_norm_1): LayerNorm((6,), eps=1e-06, elementwise_affine=True)
      (ffn_1): Linear(in_features=6, out_features=24, bias=True)
      (ffn_2): Linear(in_features=24, out_features=6, bias=True)
      (dropout_2): Dropout(p=0.1, inplace=False)
      (layer_norm_2): LayerNorm((6,), eps=1e-06, elementwise_affine=True)
    )
  )
  (output_layer): Linear(in_features=6, out_features=6, bias=True)
)

## 2. Generate new plane positions

In [28]:
# Print tensors in fixed-point rather than scientific notation so the
# de-normalised positions printed below are readable
torch.set_printoptions(sci_mode=False)

# Test plane positions: an observed track, one row per time step, 6 features per row.
# Columns 0-1 are latitude/longitude in degrees. The other four are presumably
# altitude (ft), speed, heading (degrees) and vertical rate, judging by the values
# and the scale factors used to normalise them — TODO confirm against training data.
plane_pos = [
                [39.335709,33.919495,30000,350.0,283.0,0],
                [39.336407,33.914917,30000,350.0,281.0,0],
                [39.337944,33.902649,30000,350.0,279.0,0],
                [39.339108,33.890503,30000,360.0,277.0,0],
                [39.342041,33.855385,29975,360.0,276.0,0],
                [39.342865,33.845295,29975,370.0,276.0,64],
                [39.345886,33.808754,30000,370.0,276.0,64],
                [39.34877,33.774422,30000,371.0,276.0,64],
                [39.352188,33.732788,29975,371.0,276.0,64],
                [39.355133,33.6968,29975,371.0,276.0,64],
                [39.358055,33.661621,30000,371.0,276.0,-64],
                [39.361038,33.625389,30000,371.0,276.0,64],
                [39.379944,33.390854,29975,373.0,276.0,64],
                [39.383286,33.347744,29975,373.0,275.0,-64],
                [39.385612,33.312073,29975,373.0,275.0,120],
                [39.39267,33.205341,29975,374.0,275.0,120],
                [39.394878,33.171997,29975,374.0,275.0,64],
                [39.39798,33.124794,29975,374.0,275.0,64],
                [39.400864,33.080669,30000,372.0,275.0,0],
                [39.403152,33.045799,30000,372.0,275.0,0],
                [39.405945,33.002213,29975,372.0,274.0,-64],
                [39.408379,32.964355,30000,373.0,275.0,0],
                [39.411163,32.921009,29975,373.0,275.0,64],
                [39.413361,32.886169,29975,373.0,275.0,0],
                [39.415455,32.852722,29975,373.0,274.0,64],
                [39.418442,32.805355,30000,372.0,275.0,0],
                [39.420639,32.770424,29975,372.0,274.0,0],
                [39.422791,32.735912,29975,373.0,274.0,0],
                [39.425537,32.691311,30000,372.0,274.0,0],
                [39.42778,32.655247,29975,372.0,275.0,0],
                [39.430489,32.611023,29975,371.0,275.0,64],
                [39.431374,32.593201,29975,372.0,273.0,0],
                [39.431717,32.576313,29975,373.0,271.0,0],
                [39.431561,32.559631,29975,374.0,269.0,64],
                [39.430771,32.540649,29975,375.0,266.0,0],
                [39.429428,32.522755,29975,376.0,264.0,-64],
                [39.42849,32.513733,29975,377.0,262.0,-64],
                [39.426113,32.494934,29975,377.0,260.0,0],
                [39.424255,32.483227,29975,380.0,258.0,0],
                [39.420807,32.464783,29975,381.0,256.0,0],
                [39.418808,32.455463,29975,381.0,254.0,0],
                [39.415688,32.442383,29975,383.0,253.0,0],
                [39.412197,32.429199,29975,383.0,251.0,0],
                [39.407776,32.414265,29975,383.0,249.0,64],
                [39.394638,32.370975,29975,388.0,249.0,0],
                [39.382263,32.330872,29975,390.0,248.0,0],
                [39.370285,32.29282,29975,390.0,248.0,64],
                [39.360057,32.260315,29975,390.0,248.0,64],
                [39.353817,32.240479,29975,390.0,248.0,0],
                [39.341436,32.201233,29975,392.0,248.0,64],
                [39.331194,32.16864,29975,390.7,247.9,32],
                [39.323327,32.143738,29975,389.8,247.8,32],
                [39.310758,32.104004,29975,388.5,247.9,448],
                [39.300568,32.071659,29975,387.6,247.9,64],
                [39.292419,32.046044,29975,385.3,247.9,64],
                [39.284637,32.021503,29975,385.3,247.9,0],
                [39.269348,31.9732,29975,384.8,247.7,192],
                [39.259225,31.941406,29975,384.4,247.8,448],
                [39.251724,31.917731,29975,383.5,247.8,-96],
                [39.238602,31.876465,29975,383.5,247.8,96],
                [39.226685,31.839275,29975,383.9,247.6,-96],
                [39.21514,31.802979,29975,384.8,247.7,-96],
                [39.212814,31.795747,29975,384.8,247.7,160],
                [39.209015,31.783806,29975,386.1,247.6,704],
                [39.200821,31.758191,29975,386.1,247.6,320],
                [39.195076,31.740356,29975,386.1,247.6,96],
                [39.185764,31.711365,30000,385.2,247.6,64],
                [39.169647,31.661165,29975,384.0,247.5,64],
                [39.161453,31.635729,29975,384.2,247.5,64],
                [39.14843,31.595276,29975,385.2,247.6,224],
                [39.145032,31.584717,29975,385.5,247.4,224],
                [39.142611,31.577271,29975,385.5,247.4,-448],
                [39.135721,31.555908,29975,386.5,247.5,-96],
                [39.134033,31.550645,29975,386.5,247.5,448],
                [39.133394,31.548584,29975,386.5,247.5,448],
                [39.13298,31.547361,30000,386.5,247.5,-64],
                [39.130228,31.538818,29975,386.5,247.5,-96],
                [39.129547,31.536733,29975,386.5,247.5,448],
                [39.120871,31.509949,29975,387.4,247.5,352],
                [39.118103,31.501385,30000,387.4,247.5,-544],
                [39.117426,31.499268,29975,387.4,247.5,672],
                [39.114354,31.489685,29975,386.5,247.5,672],
                [39.113004,31.485474,29975,386.5,247.5,96],
                [39.099564,31.444065,29975,385.5,247.4,512],
                [39.09613,31.433497,29975,385.5,247.4,512],
                [39.093292,31.42472,30000,385.5,247.4,-128],
                [39.0874,31.406616,29975,386.8,247.4,-544],
                [39.086334,31.403344,29975,386.8,247.4,-224],
                [39.084732,31.398329,29975,387.8,247.4,512],
                [39.082901,31.392716,29975,387.8,247.4,-288],
                [39.079468,31.382088,30000,387.8,247.4,-288],
                [39.076263,31.372176,29975,387.8,247.4,768],
                [39.074273,31.36615,29975,387.8,247.4,-1120],
                [39.071386,31.3573,29975,387.8,247.4,-1120],
                [39.070633,31.35504,29975,387.8,247.4,512],
                [39.06864,31.348755,29975,387.8,247.4,-480],
                [39.067848,31.346436,29975,387.8,247.4,928],
                [39.066824,31.343262,29975,387.8,247.4,-608],
                [39.053696,31.302979,29975,385.9,247.3,-256],
                [39.043688,31.272156,29975,384.6,247.4,-512],
                [39.041382,31.265119,29975,383.7,247.3,288],
                [39.038893,31.257507,29975,383.7,247.3,-448]
            ]

# Turn the rows into a float tensor of shape (batch=1, seq_len, 6)
pos_input = torch.tensor(plane_pos, dtype=torch.float32)
pos_input = pos_input.view(1, -1, 6)

# Normalise each feature by a fixed scale; the same factors undo the scaling at the end
div_tensor = torch.tensor([90, 180, 10000, 1000, 360, 10000], dtype=torch.float32)
pos_input = torch.div(pos_input, div_tensor)

# Send both tensors to the correct device
pos_input = pos_input.to(device)
div_tensor = div_tensor.to(device)

# How many future positions to generate (one model call per position)
num_steps = 100

# Autoregressive generation: run the model on the sequence so far, take its
# output at the last step and append that as the next position.
with torch.no_grad():
    for _ in range(num_steps):
        # The position embedding has only `context_size` slots, so condition on
        # at most the most recent `context_size` positions.
        window = pos_input[:, -context_size:, :]
        prediction = model(window)
        # Slice instead of squeeze + index so the result is always (1, 1, 6),
        # even for a single-step sequence.
        output = prediction[:, -1:, :]
        pos_input = torch.cat((pos_input, output), dim=1)

# Undo the normalisation and show the full (observed + generated) track
pos_input = torch.mul(pos_input, div_tensor)
print(pos_input)


tensor([[[   39.3357,    33.9195, 30000.0000,   350.0000,   283.0000,
              0.0000],
         [   39.3364,    33.9149, 30000.0000,   350.0000,   281.0000,
              0.0000],
         [   39.3379,    33.9026, 30000.0000,   350.0000,   279.0000,
              0.0000],
         ...,
         [   40.1068,    13.2412, 37567.3828,   485.7389,   133.2328,
              0.4075],
         [   39.8500,    10.7933, 37541.4961,   486.8513,   128.8679,
              0.3975],
         [   39.8631,    11.2531, 37453.8672,   487.0148,   126.6510,
              0.4028]]], device='cuda:0')


### Print only latitude and longitude, then copy and paste the output into https://tbensky.github.io/Maps/points.html

In [29]:
# Drop the batch dimension: (1, seq_len, 6) -> (seq_len, 6)
pos_input = torch.squeeze(pos_input)
# Print one "latitude,longitude" line per position for the map viewer.
# Copy both columns to the host in one transfer; calling float() on each
# element of a GPU tensor forces a separate device sync every time.
for lat, lon in pos_input.reshape(-1, 6)[:, :2].tolist():
    print(f"{lat},{lon}")

39.33570861816406,33.91949462890625
39.33640670776367,33.9149169921875
39.33794403076172,33.90264892578125
39.339107513427734,33.8905029296875
39.342041015625,33.855384826660156
39.342864990234375,33.84529495239258
39.34588623046875,33.808753967285156
39.34877014160156,33.77442169189453
39.35218811035156,33.7327880859375
39.355133056640625,33.696800231933594
39.358055114746094,33.66162109375
39.36103820800781,33.625389099121094
39.37994384765625,33.39085388183594
39.38328552246094,33.34774398803711
39.38561248779297,33.31207275390625
39.392669677734375,33.20534133911133
39.39487838745117,33.1719970703125
39.397979736328125,33.124794006347656
39.40086364746094,33.08066940307617
39.40315246582031,33.045799255371094
39.40594482421875,33.00221252441406
39.40837860107422,32.96435546875
39.411163330078125,32.9210090637207
39.413360595703125,32.88616943359375
39.41545486450195,32.85272216796875
39.41844177246094,32.805355072021484
39.42063903808594,32.770423889160156
39.42279052734375,32.7359