### Create a Neural Network with PyTorch for RL

In [4]:
def choose_action(observation, action_space, epsilon, explotation=False):
    """
    Pick an action with an epsilon-greedy rule.

    Args:
        observation: input passed directly to ``policy_net``.
        action_space: object exposing ``sample()``; used only for the
            random (exploratory) action.
        epsilon: threshold compared against a uniform draw from
            ``random.random()``; a draw that is not greater than it
            triggers a random action.
        explotation: when truthy, always take the greedy action and
            ignore ``epsilon``.

    Returns:
        A ``(1, 1)`` tensor. The greedy action is the index of the maximum
        along dimension 1 of ``policy_net(observation)``; the random action
        is a ``torch.long`` tensor on ``device``.

    Side effects:
        Increments the module-level ``steps_done`` counter on every call,
        including calls made with ``explotation=True``.
    """
    global steps_done

    # The draw is taken unconditionally so the random stream advances the
    # same way regardless of which branch is selected below.
    draw = random.random()
    steps_done += 1

    # Explore only when exploitation is not forced and the draw did not
    # exceed epsilon; epsilon is not evaluated when exploitation is forced.
    if not (explotation or draw > epsilon):
        random_action = action_space.sample()
        return torch.tensor([[random_action]], device=device, dtype=torch.long)

    # Greedy path: no gradients are needed for action selection.
    with torch.no_grad():
        net_output = policy_net(observation)
        best_index = net_output.max(1)[1]
        return best_index.view(1, 1)

def learn():
    """
    Perform one DQN optimisation step on a batch drawn from replay memory.

    Does nothing until ``memory`` holds at least ``batch_size`` transitions.
    Otherwise it samples ``batch_size`` transitions, computes the values
    ``policy_net`` assigns to the actions that were taken, builds targets
    ``reward + gamma * max target_net(next_state)`` (the bootstrap term is
    zero for transitions whose ``next_state`` is ``None``), and takes one
    optimizer step on the smooth L1 loss between the two.

    Relies on module-level ``memory``, ``batch_size``, ``Transition``,
    ``policy_net``, ``target_net``, ``optimizer``, ``gamma`` and ``device``.

    Returns:
        None.
    """
    if len(memory) < batch_size:
        return

    transitions = memory.sample(batch_size)
    # Turn a sequence of Transition records into one Transition whose
    # fields are tuples (one entry per sampled transition).
    batch = Transition(*zip(*transitions))

    # True where the transition has a successor state to bootstrap from.
    non_final_mask = torch.tensor(
        tuple(s is not None for s in batch.next_state),
        device=device,
        dtype=torch.bool,
    )
    non_final_next_states = [s for s in batch.next_state if s is not None]

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Value predicted by the policy network for the action actually taken.
    # assumes action_batch is (batch_size, 1) integer indices — TODO confirm
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Entries for transitions without a next state keep the value 0.
    next_state_values = torch.zeros(batch_size, device=device)
    # torch.cat raises on an empty list, so skip bootstrapping entirely when
    # every sampled transition has next_state None; targets are then just
    # the rewards.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = (
                target_net(torch.cat(non_final_next_states)).max(1).values
            )

    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # unsqueeze(1) aligns the targets with the (batch, 1) gathered values.
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-200, 200] before the update.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 200)
    optimizer.step()


In [None]:
# Run the current policy greedily for 10 episodes in a rendered
# LunarLander window.
env_human = gym.make("LunarLander-v3", render_mode="human")

# try/finally guarantees the environment (and its render window) is closed
# even if an episode raises part-way through.
try:
    for episode in range(10):
        state, info = env_human.reset()
        # Add a leading batch dimension before feeding the network.
        frame = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

        episode_over = False
        while not episode_over:
            # explotation=True forces the greedy action, so epsilon is
            # ignored here. The action space is taken from the environment
            # actually being stepped rather than from a separate one.
            action = choose_action(frame, env_human.action_space, epsilon, explotation=True)
            observation, reward, terminated, truncated, info = env_human.step(action.item())

            frame = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
            # The episode ends on either termination or truncation.
            episode_over = terminated or truncated
finally:
    env_human.close()