In [None]:
import torch
import torch.nn as nn

# --- CNN Encoder ---
class VisionEncoder(nn.Module):
    """CNN that maps an image tensor to a fixed-size feature vector.

    Three 5x5, stride-2 convolutions (16, 32 and 64 output channels), each
    followed by ReLU, are flattened and projected by a single linear layer.

    Args:
        input_shape: Optional ``(channels, height, width)`` of the images the
            encoder will receive. When given, the size of the flattened
            convolution output is measured with a dummy forward pass, so any
            image size the convolutions accept can be used. When ``None``
            (default) the original fixed layout is kept: 3 input channels and
            a flattened size of ``64 * 10 * 10``, which corresponds to images
            whose height and width are between 101 and 108 pixels.
        feature_dim: Length of the output feature vector (default 128).
    """

    def __init__(self, input_shape=None, feature_dim=128):
        super().__init__()
        in_channels = 3 if input_shape is None else input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, 5, stride=2), nn.ReLU(),
            nn.Conv2d(16, 32, 5, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 5, stride=2), nn.ReLU(),
            nn.Flatten()
        )
        if input_shape is None:
            # Legacy default: 64 channels on a 10x10 feature map.
            flat_dim = 64 * 10 * 10
        else:
            # Measure the flattened size instead of hand-computing it; the
            # zero tensor consumes no random numbers and tracks no gradients.
            with torch.no_grad():
                flat_dim = self.conv(torch.zeros(1, *input_shape)).shape[1]
        self.fc = nn.Linear(flat_dim, feature_dim)

    def forward(self, x):
        """Encode ``x`` into feature vectors.

        Args:
            x: Image tensor, either batched ``(N, C, H, W)`` or a single
                unbatched image ``(C, H, W)``.

        Returns:
            ``(N, feature_dim)`` for batched input, ``(feature_dim,)`` for a
            single unbatched image.
        """
        # nn.Flatten keeps dim 0, so a lone (C, H, W) image must be given a
        # batch axis first or it would be flattened per channel.
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        x = self.conv(x)
        x = self.fc(x)
        return x.squeeze(0) if unbatched else x

# --- RL Policy ---
class PolicyNet(nn.Module):
    """Fully connected policy head mapping a state vector to an action vector.

    The network is ``state_dim -> 256 -> 128 -> action_dim`` with ReLU between
    the linear layers and a final Tanh, so every action component lies in
    [-1, 1].

    Args:
        state_dim: Length of the input state vector (default 128).
        action_dim: Number of action components produced (default 4).
    """

    def __init__(self, state_dim=128, action_dim=4):
        super().__init__()
        first_hidden, second_hidden = 256, 128
        layers = [
            nn.Linear(state_dim, first_hidden),
            nn.ReLU(),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.Linear(second_hidden, action_dim),
            nn.Tanh(),  # Actions in [-1, 1]
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the Tanh-bounded action for ``state``."""
        action = self.net(state)
        return action

# --- Main Drone Controller Skeleton ---
class DroneAgent:
    """Glue object that turns a camera image into a drone action.

    Holds a ``VisionEncoder`` (image -> feature vector) and a ``PolicyNet``
    (feature vector -> Tanh-bounded action), both built with their default
    arguments.
    """

    def __init__(self):
        self.encoder = VisionEncoder()
        self.policy = PolicyNet()
        # self.logic_plugin = ... (future)

    def act(self, image_tensor):
        """Compute the policy's action for an image.

        Args:
            image_tensor: Either a batch ``(N, 3, H, W)`` or a single
                unbatched image ``(3, H, W)``.

        Returns:
            The raw policy output: ``(N, action_dim)`` for batched input,
            ``(action_dim,)`` for a single unbatched image.
        """
        # Add a batch axis for a lone image so the networks always see
        # batched input; it is removed again from the result below.
        unbatched = image_tensor.dim() == 3
        if unbatched:
            image_tensor = image_tensor.unsqueeze(0)
        state = self.encoder(image_tensor)
        raw_action = self.policy(state)
        # Optionally: raw_action = self.logic_plugin(raw_action)
        return raw_action.squeeze(0) if unbatched else raw_action

# --- Example usage in sim loop ---
# Build the agent once; DroneAgent() takes no arguments and creates its own
# encoder and policy networks.
agent = DroneAgent()
# Sketch of the intended per-step simulator loop. The helper names below are
# placeholders that are not defined in the cells shown here.
# image = get_image_from_airsim()  # (3, H, W) tensor
# action = agent.act(image)
# send_action_to_drone(action)


In [None]:
def logic_plugin(action, safety_data):
    """Override the policy's action when a safety condition is flagged.

    Args:
        action: Action vector whose component at index 3 is the throttle.
            May be a plain 1-D sequence (e.g. a list) or an array/tensor;
            for arrays/tensors the action components are taken from the last
            axis, so leading batch dimensions are allowed.
        safety_data: Mapping that must contain an ``"imminent_collision"``
            entry, interpreted by truthiness.

    Returns:
        ``action`` itself, untouched, when no collision is flagged; otherwise
        a copy of ``action`` with the throttle component set to 1.0. The
        caller's object is never modified.

    Raises:
        KeyError: If ``safety_data`` is a dict without the
            ``"imminent_collision"`` key.
    """
    import copy

    if not safety_data["imminent_collision"]:
        return action

    # E.g., if collision imminent, force throttle up.
    # Write into a copy: an in-place write on a network output would clobber
    # the caller's tensor and can break a later autograd backward pass.
    if hasattr(action, "clone"):
        overridden = action.clone()
    else:
        overridden = copy.copy(action)

    if hasattr(overridden, "ndim"):
        # Array/tensor: index the last axis so batched actions work too.
        overridden[..., 3] = 1.0  # Max throttle
    else:
        overridden[3] = 1.0  # Max throttle
    return overridden