In [51]:
import traceback
import random
from logging import getLogger

from fastapi import (APIRouter, Depends, FastAPI, HTTPException, WebSocket,
                     WebSocketDisconnect, status)
from fastapi.websockets import WebSocketState
import uvicorn
import nest_asyncio

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from collections import deque

In [52]:
class Policy(nn.Module):
    def __init__(self, N_FEATURES, N_ACTIONS, H_SIZE):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(N_FEATURES, H_SIZE)
        self.fc2 = nn.Linear(H_SIZE, H_SIZE)
        self.fc3 = nn.Linear(H_SIZE, H_SIZE)
        self.fc4 = nn.Linear(H_SIZE, N_ACTIONS)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.softmax(x, dim=1)

    def select_action(self, state, device):
        if np.random.uniform() < 0.5:
            # Explore: select a random action
            # action = np.random.randint(0, n_actions)
            if state[0] > state[2]:
                action =np.random.randint(0, 3)
            elif state[0] < state[2]:
                action = np.random.randint(6, 9)
            else:
                action = np.random.randint(0, 9)
            return action 
        else:
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
                probs = self.forward(state)
                action = probs.argmax(dim=1).item()
                print(probs)
                return action

class DQN():
    torch.manual_seed(50)
    BINS = 10
    ANGLE_BINS = 2
    GAMMA = 0.99
    LR = 5e-4
    N_FEATURES = 8
    H_SIZE = 256

    ACTIONS = {0: (-1, -1), 1: (-1, 0), 2: (-1, 1), 3: (0, -1), 4: (0, 0), 5: (0, 1), 6: (1, -1), 7: (1, 0), 8: (1, 1)}
    N_ACTIONS = len(ACTIONS)

    def __init__(self, game_step):
        self.game_step = game_step
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.steps = 0
        self.last_player1_score = 0
        self.last_player2_score = 0
        self.last_player1_hit = 0
        self.last_player2_hit = 0
        self.score_diff = 0
        self.total_rewards = 0

    def process_data(self, data):
        player1 = data['player1']
        player2 = data['player2']
        ball = data['ball']

        player1_X = (player1['x'] // self.BINS)
        player1_angle = int(player1['angle'] // self.ANGLE_BINS)
        player2_X = (player2['x'] // self.BINS)
        player2_angle = int(player2['angle'] // self.ANGLE_BINS)
        ball_X = (ball['x'] // self.BINS)
        ball_Y = (ball['y'] // self.BINS)
        ball_velocity_X = int((ball['velocity']['x'] * 100) // 2)
        ball_velocity_Y = int((ball['velocity']['y'] * 100) // 2)

        state = np.array([player2_X, player2_angle, ball_X, ball_Y, ball_velocity_X, ball_velocity_Y, player1_X, player1_angle])
        reward = self.get_reward(player1, player2, ball['y'])
        done = self.check_done(player1, player2)
        self.update_stats(player1, player2)
        return state, reward, done

    async def start(self, policy: Policy, n_episodes):
        scores = []
        for i in range(1, n_episodes+1):
            rewards = []
            self.reset_stat()
            data = await self.game_step({"game": "reset"})
            state = self.process_data(data)[0]
            while True:
                action = policy.select_action(state, self.device)
                data = await self.game_step({"position": self.ACTIONS[action][0], "angle": self.ACTIONS[action][1]})
                state, reward, done  = self.process_data(data)
                rewards.append(reward)
                if done:
                    break

            score = sum(rewards)
            print("Score", score)

            scores.append(score)

        return scores

    async def get_result(self, model_path):
        policy = Policy(self.N_FEATURES, self.N_ACTIONS, self.H_SIZE).to(self.device)
        if os.path.exists(model_path):
            print("loaded")
            policy.load_state_dict(torch.load(model_path))
            scores = await self.start(policy,1000)
        else:
            print("model not exists")

    def check_done(self, player1, player2):
        if player1["score"] - player2["score"] != self.score_diff:
            self.score_diff = player1["score"] - player2["score"]
            return True
        return False

    def get_reward(self, player1, player2, ball_y):
        if player1["score"] - self.last_player1_score >= 1:
            return -8
        if player1["hit"] - self.last_player1_hit >= 1:
            return 8
        if player2["score"] - self.last_player2_score >= 1:
            return 18
        if player2["hit"] - self.last_player2_hit >= 1:
            return 0
        if ball_y > 225:
            return -0.01
        return 0

    def update_stats(self, player1, player2):
        self.last_player1_score = player1["score"]
        self.last_player1_hit = player1["hit"]
        self.last_player2_score = player2["score"]
        self.last_player2_hit = player2["hit"]

    def reset_stat(self):
        self.last_player1_score = 0
        self.last_player1_hit = 0
        self.last_player2_score = 0
        self.last_player2_hit = 0
        self.total_rewards = 0


In [53]:
result = []

log = getLogger(__name__)

game = APIRouter()

@game.websocket("/") # type: ignore
async def socket(websocket: WebSocket):

    await websocket.accept()
    
    while websocket.client_state == WebSocketState.CONNECTED:
        try:
            async def game_step(response):
                await websocket.send_json(response)
                data = await websocket.receive_json()
                return data
            dqn = DQN(game_step)
            result = await dqn.get_result("eval.pth")
            break
        except WebSocketDisconnect as e:
            log.info(f"Disconnected")
            return
        except Exception as e:
            log.error(f"error: {traceback.format_exc()}")
            return



In [54]:
app = FastAPI()
app.include_router(game)
nest_asyncio.apply()
uvicorn.run(
    app,
    host="0.0.0.0",
    port=8082,
    log_level="info",
    access_log=True,
    use_colors=True,
    proxy_headers=True,
)

[32mINFO[0m:     Started server process [[36m34932[0m]
[32mINFO[0m:     Waiting for application startup.
[32mINFO[0m:     Application startup complete.
[32mINFO[0m:     Uvicorn running on [1mhttp://0.0.0.0:8082[0m (Press CTRL+C to quit)
[32mINFO[0m:     ('127.0.0.1', 60682) - "WebSocket /" [accepted]
[32mINFO[0m:     connection open


loaded
tensor([[2.8626e-04, 1.1666e-04, 2.0151e-04, 3.2476e-04, 9.9827e-01, 1.6453e-04,
         2.1134e-04, 9.8711e-05, 3.2611e-04]], device='cuda:0')
tensor([[2.8626e-04, 1.1666e-04, 2.0151e-04, 3.2476e-04, 9.9827e-01, 1.6453e-04,
         2.1134e-04, 9.8711e-05, 3.2611e-04]], device='cuda:0')
tensor([[2.8567e-04, 1.1581e-04, 1.9793e-04, 3.2147e-04, 9.9829e-01, 1.5845e-04,
         2.1023e-04, 9.8474e-05, 3.2400e-04]], device='cuda:0')
tensor([[2.8567e-04, 1.1581e-04, 1.9793e-04, 3.2147e-04, 9.9829e-01, 1.5845e-04,
         2.1023e-04, 9.8474e-05, 3.2400e-04]], device='cuda:0')
tensor([[2.8567e-04, 1.1581e-04, 1.9793e-04, 3.2147e-04, 9.9829e-01, 1.5845e-04,
         2.1023e-04, 9.8474e-05, 3.2400e-04]], device='cuda:0')
tensor([[2.9427e-04, 1.1986e-04, 1.9990e-04, 3.3384e-04, 9.9825e-01, 1.5924e-04,
         2.1588e-04, 1.0032e-04, 3.2825e-04]], device='cuda:0')
tensor([[2.9427e-04, 1.1986e-04, 1.9990e-04, 3.3384e-04, 9.9825e-01, 1.5924e-04,
         2.1588e-04, 1.0032e-04, 3.2825e-0

In [None]:
print(result)