# In [None]:
# train.py
from env import DroneEnv
from models import Actor, Critic
from replay_buffer import ReplayBuffer
import torch
import torch.optim as optim
import torch.nn.functional as F

# --- Runtime and environment -------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
env = DroneEnv()

# --- Observation / action geometry -------------------------------------------
# Shape handed to the networks and the replay buffer; the training loop
# transposes observations so that their last axis comes first.
img_shape, action_dim = (3, 84, 84), 4

# --- Networks: one actor and two independently initialised critics -----------
actor = Actor(img_shape, action_dim).to(device)
critic1, critic2 = (Critic(img_shape, action_dim).to(device) for _ in range(2))

# --- Optimizers: a separate Adam instance per network, same learning rate ----
actor_opt, critic1_opt, critic2_opt = (
    optim.Adam(net.parameters(), lr=3e-4) for net in (actor, critic1, critic2)
)

# --- Experience storage -------------------------------------------------------
replay = ReplayBuffer(100000, img_shape, action_dim)

# --- Hyperparameters ----------------------------------------------------------
gamma = 0.99       # discount factor applied to the bootstrapped next-state value
alpha = 0.2        # entropy regularization
batch_size = 64    # transitions drawn from the replay buffer per update

# Training loop.
#
# Each episode steps the environment with the current actor, stores every
# transition in the replay buffer, and -- once the buffer holds more than
# `batch_size` transitions -- performs one gradient step on both critics and
# one on the actor per environment step.
#
# NOTE(review): the update bootstraps from the live actor/critics (no separate
# target networks) and `alpha` is not referenced here -- confirm that this is
# intentional.
for episode in range(1000):
    state = env.reset()
    done = False
    while not done:
        # Move the last axis of the observation to the front (presumably
        # HWC -> CHW; confirm against DroneEnv) and add a batch axis.
        img = torch.tensor(state.transpose(2, 0, 1)).unsqueeze(0).to(device, dtype=torch.float32)

        # Acting does not need an autograd graph, so skip building one.
        with torch.no_grad():
            # squeeze(0) removes only the batch axis added above; a bare
            # squeeze() would also drop any other size-1 axis.
            action = actor(img).squeeze(0).cpu().numpy()
        next_state, reward, done = env.step(action)

        replay.add(state, action, reward, next_state, done)
        state = next_state

        if replay.size > batch_size:
            s, a, r, s2, d = replay.sample(batch_size)

            # Cast everything to float32 so the sampled batch matches the
            # dtype used on the acting path, regardless of how the buffer
            # stores its arrays. For `d` this also makes `1 - d` valid:
            # subtraction on a bool tensor raises in PyTorch.
            s = torch.tensor(s.transpose(0, 3, 1, 2)).to(device, dtype=torch.float32)
            a = torch.tensor(a).to(device, dtype=torch.float32)
            r = torch.tensor(r).to(device, dtype=torch.float32)
            s2 = torch.tensor(s2.transpose(0, 3, 1, 2)).to(device, dtype=torch.float32)
            d = torch.tensor(d).to(device, dtype=torch.float32)

            # Bootstrapped target: take the smaller of the two critic
            # estimates for the next state/action pair.
            with torch.no_grad():
                next_action = actor(s2)
                target_q1 = critic1(s2, next_action)
                target_q2 = critic2(s2, next_action)
                next_q = torch.min(target_q1, target_q2)

                # Align reward / done with the critic output shape. Without
                # this, a (batch,) reward against a (batch, 1) Q-value would
                # silently broadcast the target to (batch, batch).
                # Assumes one Q-value per sample; reshape raises otherwise.
                r = r.reshape(next_q.shape)
                d = d.reshape(next_q.shape)
                target_q = r + gamma * (1 - d) * next_q

            # Critic update: each critic regresses onto the shared target and
            # is stepped by its own optimizer.
            q1 = critic1(s, a)
            q2 = critic2(s, a)
            critic1_loss = F.mse_loss(q1, target_q)
            critic2_loss = F.mse_loss(q2, target_q)

            critic1_opt.zero_grad()
            critic1_loss.backward()
            critic1_opt.step()

            critic2_opt.zero_grad()
            critic2_loss.backward()
            critic2_opt.step()

            # Actor update: maximise critic1's value of the actor's own action.
            # backward() also writes gradients into critic1's parameters, but
            # only actor_opt steps here and critic1_opt.zero_grad() clears
            # them before the next critic update.
            actor_loss = -critic1(s, actor(s)).mean()
            actor_opt.zero_grad()
            actor_loss.backward()
            actor_opt.step()
