In [2]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
from minimum_snap import UAVTrajectoryPlanner  # Ensure you have this module

In [21]:
class WaypointPredictorAgent(nn.Module):
    """MLP that maps a flattened trajectory descriptor to flattened waypoints.

    Layout: ``input_dim -> 128 -> 256 -> output_dim`` with a ReLU between
    consecutive linear layers (no activation after the last one). Every linear
    layer is cast to ``dtype`` (float64 by default), so inputs must use the
    same dtype.
    """

    def __init__(self, input_dim, output_dim, dtype=torch.float64):
        super().__init__()
        widths = (input_dim, 128, 256, output_dim)
        layers = []
        # Build Linear/ReLU pairs in order; the layers are created in the same
        # order as a literal Sequential, so seeded initialisation is unchanged.
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out).to(dtype))
        self.network = nn.Sequential(*layers)

    def forward(self, trajectories):
        """Apply the MLP to ``trajectories`` (last dimension must equal ``input_dim``)."""
        return self.network(trajectories)


In [20]:
def train(epochs, planner, agent, optimizer, criterion, waypoints, total_time,
          initial_waypoints=None, log_every=100, expected_input_dim=36):
    """Train ``agent`` to recover ``waypoints`` from the expert planner's trajectory.

    The expert ``planner`` is queried once with the fixed target waypoints. Its
    x/y polynomial coefficients are flattened into a single ``(1, D)`` input row.
    Each epoch the agent maps that row to a flat waypoint vector, which is
    regressed onto ``waypoints.flatten()`` with ``criterion``.

    Args:
        epochs: Number of epochs; must be >= 1.
        planner: Callable ``planner(waypoints, total_time)`` returning
            ``(polys_x, polys_y, _)`` tensors.
        agent: ``nn.Module`` mapping the ``(1, D)`` input to ``waypoints.numel()``
            values. Inputs and targets are moved to its parameters' device/dtype.
        optimizer: Optimizer over ``agent``'s parameters.
        criterion: Loss called as ``criterion(prediction, target)``.
        waypoints: Target waypoints. They are flattened for the loss, and the
            history entries are reshaped to ``(2, -1)``.
        total_time: Forwarded to ``planner``.
        initial_waypoints: Optional baseline recorded as epoch 0 instead of an
            agent prediction. No optimizer step is taken for it. When None, a
            notebook-global ``initial_waypoints`` is used if one exists
            (backward-compatible with the original cell). Otherwise every epoch
            trains the agent.
        log_every: Print the loss every ``log_every`` epochs (0/None disables).
        expected_input_dim: Expected width D of the expert input; None skips
            the check.

    Returns:
        ``(final_prediction, history)``. ``final_prediction`` is a detached CPU
        tensor of shape ``(2, N)``. ``history`` is a list with one ``(2, N)``
        numpy array per epoch.

    Raises:
        ValueError: If ``epochs < 1``.
        AssertionError: If the expert input does not have shape
            ``(1, expected_input_dim)``.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if initial_waypoints is None:
        # Backward compatibility: the original cell read this name from notebook globals.
        initial_waypoints = globals().get("initial_waypoints")

    # Run inputs/targets on the agent's own device and dtype so a CUDA planner
    # or a float32 agent does not cause a device/dtype mismatch.
    first_param = next(agent.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dtype = first_param.dtype if first_param is not None else torch.float64

    # The expert trajectory depends only on the fixed target waypoints, so it is
    # computed once. It is detached because it is a fixed input, not something
    # the optimizer trains.
    polys_x, polys_y, _ = planner(waypoints, total_time)
    agent_input = torch.cat((polys_x.flatten(), polys_y.flatten())).unsqueeze(0)
    if expected_input_dim is not None:
        assert agent_input.shape == (1, expected_input_dim), "Input tensor has incorrect shape"
    agent_input = agent_input.detach().to(device=device, dtype=dtype)
    target = waypoints.flatten().to(device=device, dtype=dtype)

    history = []
    prediction = None
    for epoch in range(epochs):
        if epoch == 0 and initial_waypoints is not None:
            # Baseline epoch: only score the initial guess. It is not an agent
            # parameter, so backward/step could not change the agent anyway.
            prediction = initial_waypoints.flatten().clone().to(device=device, dtype=dtype)
            with torch.no_grad():
                loss = criterion(prediction, target)
        else:
            optimizer.zero_grad()
            # Flatten so the prediction matches the 1-D target instead of relying
            # on broadcasting (which MSELoss warns about).
            prediction = agent(agent_input).reshape(-1)
            loss = criterion(prediction, target)
            loss.backward()
            optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch}: Loss {loss.item()}')

        # Detach and move to CPU so the numpy conversion also works for CUDA tensors.
        prediction = prediction.detach().cpu().reshape(2, -1)
        history.append(prediction.numpy())

    return prediction, history


In [22]:
# Configure the expert UAVTrajectoryPlanner and the target/initial waypoints.
# Waypoint tensors are stored as (2, N): row 0 holds x, row 1 holds y
# (.t() turns the list of (x, y) points into that layout).
initial_waypoints = torch.tensor([[15, 30], [20, 60], [0, 20], [30, 40]], dtype=torch.float64).t()  # starting guess for the agent, used as the epoch-0 baseline in train
total_time = 24.0  # total trajectory duration passed to the planner
poly_order = 5
start_vel = [0, 0]
start_acc = [0, 0]
end_vel = [0, 0]
end_acc = [0, 0]
dtype = torch.float64
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
waypoints = torch.tensor([[0, 0], [1, 0], [1, 2], [0, 2]], dtype=torch.float64).t()  # target waypoints the agent must recover

# Passing waypoints and total_time positionally as well raised
# "TypeError: __init__() got multiple values for argument 'dtype'". That means
# the constructor has fewer positional parameters before `dtype`. train() calls
# planner(waypoints, total_time), so those two belong to the call, not the constructor.
# NOTE(review): assumes the signature is
# (poly_order, start_vel, start_acc, end_vel, end_acc, dtype, device). Confirm in minimum_snap.
planner = UAVTrajectoryPlanner(poly_order, start_vel, start_acc, end_vel, end_acc, dtype=dtype, device=device)


TypeError: __init__() got multiple values for argument 'dtype'

In [6]:
# Initialize the agent. Its input is the expert's flattened x and y polynomial
# coefficients; its output is one (x, y) pair per waypoint.
num_waypoints = waypoints.shape[1]  # waypoints are stored as (2, N), including the starting point
num_segments = num_waypoints - 1  # one polynomial segment between consecutive waypoints
# coefficients per segment * segments * (x, y); 6 * 3 * 2 = 36 for order 5 and 4 waypoints
input_dim = (poly_order + 1) * num_segments * 2
output_dim = num_waypoints * 2  # x and y for every waypoint

agent = WaypointPredictorAgent(input_dim, output_dim, dtype=torch.float64)


In [7]:
# Loss: mean-squared error between predicted and target waypoints.
criterion = nn.MSELoss()
# Optimizer: Adam over every agent parameter.
optimizer = optim.Adam(params=agent.parameters(), lr=1e-3)


In [8]:
# Train for 100 epochs; returns the final (2, N) prediction and the per-epoch history.
predicted_waypoints, predicted_waypoints_all = train(
    epochs=100,
    planner=planner, agent=agent,
    optimizer=optimizer, criterion=criterion,
    waypoints=waypoints, total_time=total_time,
)


NameError: name 'planner' is not defined

In [9]:
# Plot how the agent's predicted waypoints evolve during training, against the target.
PLOT_EVERY = 5  # plot only every 5th epoch to keep the figure readable

fig, ax = plt.subplots(figsize=(10, 5))
for epoch, preds in enumerate(predicted_waypoints_all):
    if epoch % PLOT_EVERY == 0:
        ax.plot(preds[0], preds[1], 'o--', label=f'Epoch {epoch}', alpha=0.6)

# Reuse the training target instead of re-typing it, so the plot cannot drift from it.
# .detach().cpu() makes the numpy conversion safe for GPU or grad-tracking tensors.
original_waypoints = waypoints.detach().cpu()
ax.plot(original_waypoints[0].numpy(), original_waypoints[1].numpy(), 'ro-', label='Original Waypoints')

final_waypoints = predicted_waypoints.detach().cpu()
ax.plot(final_waypoints[0].numpy(), final_waypoints[1].numpy(), 'bx--', label='Predicted Waypoints')

ax.set_title("Comparison of Original and Predicted Waypoints")
ax.set_xlabel("X Coordinate")
ax.set_ylabel("Y Coordinate")
ax.legend(fontsize='small')
ax.grid(True)
ax.axis('equal')  # equal aspect ratio so x and y distances are not distorted
plt.show()


NameError: name 'predicted_waypoints_all' is not defined

<Figure size 1000x500 with 0 Axes>