In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import random
from einops.layers.torch import Rearrange
from einops import rearrange

from typing import Any, Dict, Tuple, Optional
from game_mechanics import (
    ChooseMoveCheckpoint,
    ShooterEnv,
    checkpoint_model,
    choose_move_randomly,
    human_player,
    load_network,
    play_shooter,
    save_network,
)
from tqdm.notebook import tqdm

from functools import partial
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
from copy import deepcopy
from functools import partial

from utils import *
%load_ext autoreload
%autoreload 2

pygame 2.1.2 (SDL 2.0.16, Python 3.10.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def choose_move(state, neural_network: nn.Module) -> int:
    probs = neural_network(state.to(device))
    probs = probs.cpu().detach().numpy()
    move = np.random.choice(range(6), p=probs)
    return int(move)

In [4]:
env = ShooterEnv(opponent_choose_move=choose_move_randomly, 
                 game_speed_multiplier=100_000,
                 include_barriers=False,
                 half_sized_game=True)

In [7]:
state, reward, done, info = env.reset()

In [38]:
def angle(point1, point2):
    distance = np.linalg.norm(point1 - point2)
    sin = (point1[0] - point2[0]) / distance
    cos = (point1[1] - point2[1]) / distance
    return [sin.item(), cos.item()]
    
def distance(point1, point2):
    return np.linalg.norm(point1 - point2)

def add_features(state):
    distance_ships = distance(state[:2], state[4:6])
    distance_ships
    angles_ships = angle(state[:2], state[4:6])
    angles_ships

    distance_bullets1_ship2 = [distance(state[4:6], state[8:10]), distance(state[4:6], state[12:14])]
    if (state[8:10] == torch.as_tensor([-1, -1])).all().item():
        distance_bullets1_ship2[0] = 10
    if (state[12:14] == torch.as_tensor([-1, -1])).all().item():
        distance_bullets1_ship2[1] = 10
    distance_bullets1_ship2

    distance_bullets2_ship1 = [distance(state[0:2], state[16:18]), distance(state[0:2], state[20:22])]
    if (state[16:18] == torch.as_tensor([-1, -1])).all().item():
        distance_bullets1_ship2[0] = 10
    if (state[20:22] == torch.as_tensor([-1, -1])).all().item():
        distance_bullets1_ship2[1] = 10

    bullets_fired = [1,1,1,1]
    if (state[8:10] == torch.as_tensor([-1, -1])).all().item():
        bullets_fired[0] = 0
    if (state[12:14] == torch.as_tensor([-1, -1])).all().item():
        bullets_fired[1] = 0
    if (state[16:18] == torch.as_tensor([-1, -1])).all().item():
        bullets_fired[2] = 0
    if (state[20:22] == torch.as_tensor([-1, -1])).all().item():
        bullets_fired[3] = 0

    features = [distance_ships] + angles_ships + distance_bullets1_ship2 + distance_bullets2_ship1 + bullets_fired
    features = torch.as_tensor(features, dtype=torch.float32)
    return torch.cat([state, features])

In [None]:
add

In [34]:

# distance(state[:2], state[4:6])

In [35]:
features

[1.5999999, 0.0, -1.0, 10, 10, 1.0189416, 1.0189416, 0, 0, 0, 0]

In [39]:
add_features(state).shape

torch.Size([35])

In [31]:
angles_ships

(tensor(0.), tensor(-1.))

In [27]:
(state[8:10] == torch.as_tensor([-1, -1])).all().item()

True

In [21]:
state[8:10]

tensor([-1., -1.])

In [8]:
state

tensor([ 0.0000e+00, -8.0444e-01,  1.0000e+00,  6.1232e-17,  0.0000e+00,
         7.9556e-01,  0.0000e+00,  1.0000e+00, -1.0000e+00, -1.0000e+00,
         0.0000e+00,  1.0000e+00, -1.0000e+00, -1.0000e+00,  0.0000e+00,
         1.0000e+00, -1.0000e+00, -1.0000e+00,  0.0000e+00,  1.0000e+00,
        -1.0000e+00, -1.0000e+00,  0.0000e+00,  1.0000e+00])