In [19]:
from game import SnakeGame
from utils.rewards import manhattan_reward, naive_reward, advanced_naive_reward, euclidean_reward
from agents.QLearning import QLearning
from agents.DeepQLearning import DeepQLearning
from utils.utils import benchmark, plot, play_snake
from tqdm import tqdm
import pygame
import seaborn as sns
import matplotlib.pyplot as plt

In [20]:
# Board width/height passed to benchmark() throughout the notebook.
W, H = 200, 200

In [21]:
# Shared environment plus one tabular and one deep Q-learning agent.
# Built from W/H so the training board always matches the board size that
# benchmark(..., W, H) evaluates on, instead of repeating the literal 200s.
game = SnakeGame(W, H)
Qmodel = QLearning(game)
DQmodel = DeepQLearning(game)


In [22]:
# Reward function handed to train() calls below; later cells rebind it.
reward = advanced_naive_reward

In [6]:
from collections import defaultdict

# Training lengths (games played) to compare, and independent agents per length.
steps = [10, 100, 1000, 5000, 10000, 50000]
N_AGENTS_PER_STEP = 10

models = dict()
benchmarks = defaultdict(dict)
for step in steps:
    # Filled as agents finish; no placeholder values, so an interrupted run never
    # leaves integer stand-ins where trained agents are expected.
    models[step] = {}
    for i in range(N_AGENTS_PER_STEP):
        Qmodel = QLearning(game)
        # train()'s return value is not needed; only the benchmark below is recorded.
        Qmodel.train(step, float('inf'), reward)
        scores, average = benchmark(Qmodel, W, H)
        benchmarks[step][i] = average
        models[step][i] = Qmodel


100%|██████████| 10/10 [00:00<00:00, 3163.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1241.52it/s]
100%|██████████| 10/10 [00:00<00:00, 8533.68it/s]
100%|██████████| 1000/1000 [00:00<00:00, 9835.81it/s]
100%|██████████| 10/10 [00:00<00:00, 8667.71it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1567.22it/s]
100%|██████████| 10/10 [00:00<00:00, 9345.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 11564.18it/s]
100%|██████████| 10/10 [00:00<00:00, 8495.65it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1283.20it/s]
100%|██████████| 10/10 [00:00<00:00, 8533.68it/s]
100%|██████████| 1000/1000 [00:00<00:00, 8300.13it/s]
100%|██████████| 10/10 [00:00<00:00, 11161.00it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1267.49it/s]
100%|██████████| 10/10 [00:00<00:00, 9198.04it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2968.49it/s]
100%|██████████| 10/10 [00:00<00:00, 13336.42it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1301.44it/s]
100%|██████████| 10/10 [00:00<00:00, 11808.29it/s]
100%|█████

In [8]:
# Mean benchmark score per training length. The divisor comes from the data, so the
# result stays correct if the number of agents per step changes.
avg = list()
for step in steps:
    step_scores = list(benchmarks[step].values())
    print(step, step_scores)
    avg.append(sum(step_scores) / len(step_scores) if step_scores else float('nan'))

plot(avg)

10 [1.0, 1.515, 1.1368421052631579, 1.579, 1.0, 1.545, 1.125, 0.8880706921944035, 1.0909090909090908, 1.926]
100 [3.609896432681243, 6.986, 2.536213468869123, 1.4425228891149542, 2.4750254841997963, 3.9507211538461537, 1.7233009708737863, 1.3064971751412429, 2.840686274509804, 7.394]
1000 [17.433, 21.063, 15.712778429073857, 25.36, 19.755, 17.373, 16.007, 19.789, 15.631, 20.408]
5000 [19.443, 23.602, 19.293, 19.117, 20.659, 19.351, 20.343, 23.225, 24.403, 19.647]
10000 [19.446, 25.417, 19.644, 16.682, 20.56, 19.451, 19.759, 23.709, 25.336, 23.335]
50000 [23.618, 21.374, 23.181, 19.675, 20.364948453608246, 16.503, 17.333, 21.484, 24.985, 17.258]


In [24]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Per-agent benchmark averages for the tabular Q-learning sweep, keyed by training
# length (games played). Hard-coded copy of the values printed by the summary cell
# above (presumably so the figure can be redrawn without re-running the sweep).
data = {
    "10": [1.0, 1.515, 1.1368421052631579, 1.579, 1.0, 1.545, 1.125, 0.8880706921944035, 1.0909090909090908, 1.926],
    "100": [3.609896432681243, 6.986, 2.536213468869123, 1.4425228891149542, 2.4750254841997963, 3.9507211538461537, 1.7233009708737863, 1.3064971751412429, 2.840686274509804, 7.394],
    "1000": [17.433, 21.063, 15.712778429073857, 25.36, 19.755, 17.373, 16.007, 19.789, 15.631, 20.408],
    "5000": [19.443, 23.602, 19.293, 19.117, 20.659, 19.351, 20.343, 23.225, 24.403, 19.647],
    "10000": [19.446, 25.417, 19.644, 16.682, 20.56, 19.451, 19.759, 23.709, 25.336, 23.335],
    "50000": [23.618, 21.374, 23.181, 19.675, 20.364948453608246, 16.503, 17.333, 21.484, 24.985, 17.258]
}
epochs = list(data.keys())

# Grid layout and subplot titles derived from the data instead of repeated by hand.
n_cols = 3
n_rows = max(1, (len(epochs) + n_cols - 1) // n_cols)
fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    subplot_titles=[f"Games played {epoch}" for epoch in epochs],
)

# One panel per training length: x is the 1-based agent number, y its average score.
for i, epoch in enumerate(epochs):
    row = i // n_cols + 1
    col = i % n_cols + 1
    fig.add_trace(
        go.Scatter(
            x=list(range(1, len(data[epoch]) + 1)),
            y=data[epoch],
            mode='lines+markers',
            name=f'Games played {epoch}',
        ),
        row=row,
        col=col,
    )

# Same y-range on every panel so the panels are directly comparable.
fig.update_yaxes(range=[0, 30], title_text="Average Score across 1000 games")
fig.update_xaxes(title_text="Agent")
fig.update_layout(height=400 * n_rows, width=1200, title_text="", showlegend=False)
fig.show()

# Mean over agents for each training length (divisor taken from the data).
avg = [sum(vals) / len(vals) if vals else float('nan') for vals in data.values()]

# Second figure: how the across-agent mean evolves with training length.
fig = go.Figure()
fig.add_trace(go.Scatter(x=epochs, y=avg, mode='lines+markers', name='Average Score'))
fig.update_layout(
    title="Average Scores across all agents",
    title_x=0.5,
    xaxis_title="Games played",
    yaxis_title="Average Score",
    height=600,
    width=800,
)
fig.show()

In [16]:
step = 5000
reward = manhattan_reward
N_MODELS = 100  # independent Q-learning agents to train and benchmark

models = list()
avg_list = list()
for i in range(N_MODELS):
    print(i)
    Qmodel = QLearning(game)
    # train()'s return value is unused; `scores` is bound by benchmark() below.
    Qmodel.train(step, float('inf'), reward)
    scores, average = benchmark(Qmodel, W, H)
    avg_list.append(average)
    models.append(Qmodel)





0


100%|██████████| 5000/5000 [00:00<00:00, 25561.06it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42484.72it/s]


1


100%|██████████| 5000/5000 [00:00<00:00, 28288.86it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41211.54it/s]


2


100%|██████████| 5000/5000 [00:00<00:00, 27908.48it/s]
100%|██████████| 1000/1000 [00:00<00:00, 36607.18it/s]


3


100%|██████████| 5000/5000 [00:00<00:00, 28501.19it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41690.81it/s]


4


100%|██████████| 5000/5000 [00:00<00:00, 27599.44it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41722.33it/s]


5


100%|██████████| 5000/5000 [00:00<00:00, 28024.69it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40261.32it/s]


6


100%|██████████| 5000/5000 [00:00<00:00, 28053.48it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41187.66it/s]


7


100%|██████████| 5000/5000 [00:00<00:00, 27810.11it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40346.92it/s]


8


100%|██████████| 5000/5000 [00:00<00:00, 28156.00it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42185.18it/s]


9


100%|██████████| 5000/5000 [00:00<00:00, 28254.33it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41650.65it/s]


10


100%|██████████| 5000/5000 [00:00<00:00, 27077.18it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41339.48it/s]


11


100%|██████████| 5000/5000 [00:00<00:00, 28055.21it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41230.17it/s]


12


100%|██████████| 5000/5000 [00:00<00:00, 27922.23it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42138.56it/s]


13


100%|██████████| 5000/5000 [00:00<00:00, 28024.20it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40564.65it/s]


14


100%|██████████| 5000/5000 [00:00<00:00, 27697.13it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41237.47it/s]


15


100%|██████████| 5000/5000 [00:00<00:00, 28142.25it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41542.16it/s]


16


100%|██████████| 5000/5000 [00:00<00:00, 27771.29it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40281.82it/s]


17


100%|██████████| 5000/5000 [00:00<00:00, 27531.83it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41325.23it/s]


18


100%|██████████| 5000/5000 [00:00<00:00, 26683.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40389.26it/s]


19


100%|██████████| 5000/5000 [00:00<00:00, 28259.58it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42506.68it/s]


20


100%|██████████| 5000/5000 [00:00<00:00, 28814.00it/s]
100%|██████████| 1000/1000 [00:00<00:00, 43057.36it/s]


21


100%|██████████| 5000/5000 [00:00<00:00, 19786.54it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42286.83it/s]


22


100%|██████████| 5000/5000 [00:00<00:00, 28902.71it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42306.45it/s]


23


100%|██████████| 5000/5000 [00:00<00:00, 28524.18it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41738.11it/s]


24


100%|██████████| 5000/5000 [00:00<00:00, 27914.87it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41366.39it/s]


25


100%|██████████| 5000/5000 [00:00<00:00, 26669.92it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41150.89it/s]


26


100%|██████████| 5000/5000 [00:00<00:00, 28397.65it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41548.74it/s]


27


100%|██████████| 5000/5000 [00:00<00:00, 28353.99it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41412.55it/s]


28


100%|██████████| 5000/5000 [00:00<00:00, 27931.86it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41319.12it/s]


29


100%|██████████| 5000/5000 [00:00<00:00, 27767.29it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41121.43it/s]


30


100%|██████████| 5000/5000 [00:00<00:00, 27467.68it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41339.48it/s]


31


100%|██████████| 5000/5000 [00:00<00:00, 27656.55it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41120.22it/s]


32


100%|██████████| 5000/5000 [00:00<00:00, 28069.85it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40829.62it/s]


33


100%|██████████| 5000/5000 [00:00<00:00, 26728.80it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41098.06it/s]


34


100%|██████████| 5000/5000 [00:00<00:00, 27997.42it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41788.84it/s]


35


100%|██████████| 5000/5000 [00:00<00:00, 27352.90it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41388.84it/s]


36


100%|██████████| 5000/5000 [00:00<00:00, 27471.78it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41164.21it/s]


37


100%|██████████| 5000/5000 [00:00<00:00, 27389.91it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41711.96it/s]


38


100%|██████████| 5000/5000 [00:00<00:00, 27935.32it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40513.72it/s]


39


100%|██████████| 5000/5000 [00:00<00:00, 28158.80it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40872.59it/s]


40


100%|██████████| 5000/5000 [00:00<00:00, 27292.10it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41495.31it/s]


41


100%|██████████| 5000/5000 [00:00<00:00, 26904.88it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42009.84it/s]


42


100%|██████████| 5000/5000 [00:00<00:00, 27854.06it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42406.54it/s]


43


100%|██████████| 5000/5000 [00:00<00:00, 27754.64it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41069.49it/s]


44


100%|██████████| 5000/5000 [00:00<00:00, 28075.71it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40848.70it/s]


45


100%|██████████| 5000/5000 [00:00<00:00, 27982.96it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42590.85it/s]


46


100%|██████████| 5000/5000 [00:00<00:00, 27950.14it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41165.83it/s]


47


100%|██████████| 5000/5000 [00:00<00:00, 26122.94it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41424.82it/s]


48


100%|██████████| 5000/5000 [00:00<00:00, 28113.58it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41792.17it/s]


49


100%|██████████| 5000/5000 [00:00<00:00, 27959.68it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41962.34it/s]


50


100%|██████████| 5000/5000 [00:00<00:00, 27920.18it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41224.09it/s]


51


100%|██████████| 5000/5000 [00:00<00:00, 27366.96it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40698.48it/s]


52


100%|██████████| 5000/5000 [00:00<00:00, 28267.58it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41472.74it/s]


53


100%|██████████| 5000/5000 [00:00<00:00, 28319.11it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42312.43it/s]


54


100%|██████████| 5000/5000 [00:00<00:00, 27859.02it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42112.75it/s]


55


100%|██████████| 5000/5000 [00:00<00:00, 26651.31it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42013.20it/s]


56


100%|██████████| 5000/5000 [00:00<00:00, 28247.17it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42224.25it/s]


57


100%|██████████| 5000/5000 [00:00<00:00, 27948.43it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41237.06it/s]


58


100%|██████████| 5000/5000 [00:00<00:00, 27789.99it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42006.05it/s]


59


100%|██████████| 5000/5000 [00:00<00:00, 28339.39it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42656.69it/s]


60


100%|██████████| 5000/5000 [00:00<00:00, 27993.16it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40392.38it/s]


61


100%|██████████| 5000/5000 [00:00<00:00, 27190.47it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40818.09it/s]


62


100%|██████████| 5000/5000 [00:00<00:00, 26367.86it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41142.01it/s]


63


100%|██████████| 5000/5000 [00:00<00:00, 27391.73it/s]
100%|██████████| 1000/1000 [00:00<00:00, 39394.97it/s]


64


100%|██████████| 5000/5000 [00:00<00:00, 27280.17it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41910.77it/s]


65


100%|██████████| 5000/5000 [00:00<00:00, 27832.33it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42121.64it/s]


66


100%|██████████| 5000/5000 [00:00<00:00, 27997.08it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41781.35it/s]


67


100%|██████████| 5000/5000 [00:00<00:00, 28063.73it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42141.10it/s]


68


100%|██████████| 5000/5000 [00:00<00:00, 27571.94it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41086.39it/s]


69


100%|██████████| 5000/5000 [00:00<00:00, 26566.91it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41135.55it/s]


70


100%|██████████| 5000/5000 [00:00<00:00, 27851.88it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40494.16it/s]


71


100%|██████████| 5000/5000 [00:00<00:00, 28029.90it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41731.46it/s]


72


100%|██████████| 5000/5000 [00:00<00:00, 27614.89it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41082.77it/s]


73


100%|██████████| 5000/5000 [00:00<00:00, 28008.71it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40877.77it/s]


74


100%|██████████| 5000/5000 [00:00<00:00, 27199.22it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41250.85it/s]


75


100%|██████████| 5000/5000 [00:00<00:00, 27442.81it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40078.58it/s]


76


100%|██████████| 5000/5000 [00:00<00:00, 27516.95it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41874.37it/s]


77


100%|██████████| 5000/5000 [00:00<00:00, 26718.99it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42151.69it/s]


78


100%|██████████| 5000/5000 [00:00<00:00, 27631.19it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40522.33it/s]


79


100%|██████████| 5000/5000 [00:00<00:00, 26822.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40566.61it/s]


80


100%|██████████| 5000/5000 [00:00<00:00, 27297.03it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40581.14it/s]


81


100%|██████████| 5000/5000 [00:00<00:00, 27459.55it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40883.34it/s]


82


100%|██████████| 5000/5000 [00:00<00:00, 28192.19it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41944.72it/s]


83


100%|██████████| 5000/5000 [00:00<00:00, 27767.91it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40938.41it/s]


84


100%|██████████| 5000/5000 [00:00<00:00, 27583.29it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41727.31it/s]


85


100%|██████████| 5000/5000 [00:00<00:00, 25434.92it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41552.45it/s]


86


100%|██████████| 5000/5000 [00:00<00:00, 27180.11it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40864.62it/s]


87


100%|██████████| 5000/5000 [00:00<00:00, 27101.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41651.07it/s]


88


100%|██████████| 5000/5000 [00:00<00:00, 27729.73it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41990.49it/s]


89


100%|██████████| 5000/5000 [00:00<00:00, 23094.71it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41159.37it/s]


90


100%|██████████| 5000/5000 [00:00<00:00, 28132.25it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41563.15it/s]


91


100%|██████████| 5000/5000 [00:00<00:00, 27468.76it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41303.66it/s]


92


100%|██████████| 5000/5000 [00:00<00:00, 22447.70it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40006.33it/s]


93


100%|██████████| 5000/5000 [00:00<00:00, 27590.15it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41124.66it/s]


94


100%|██████████| 5000/5000 [00:00<00:00, 27901.80it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40981.61it/s]


95


100%|██████████| 5000/5000 [00:00<00:00, 27967.66it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40963.20it/s]


96


100%|██████████| 5000/5000 [00:00<00:00, 28001.30it/s]
100%|██████████| 1000/1000 [00:00<00:00, 42213.63it/s]


97


100%|██████████| 5000/5000 [00:00<00:00, 28132.55it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41204.65it/s]


98


100%|██████████| 5000/5000 [00:00<00:00, 27815.61it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41552.45it/s]


99


100%|██████████| 5000/5000 [00:00<00:00, 26213.12it/s]
100%|██████████| 1000/1000 [00:00<00:00, 41196.56it/s]


In [18]:
# Hard-coded per-agent averages (presumably copied from an earlier run of the
# 100-agent, 5000-game sweep) trained with the advanced naive reward.
avg_list_adv_naive = [23.199, 21.174, 24.369, 19.39632277834525, 20.436, 18.573, 19.332988624612202, 23.135, 17.857917570498916, 19.215, 20.772, 17.325, 16.199, 23.523, 20.252, 23.558, 19.273, 17.295362903225808, 20.816, 24.892892892892892, 23.123, 19.48, 16.701, 24.469, 23.286, 22.576, 25.894, 23.23, 16.92581300813008, 19.468, 19.193, 12.735106382978723, 15.922, 17.335, 17.134, 25.106, 20.327, 23.383, 25.72, 20.819, 21.119, 25.373, 19.53, 20.781, 22.974, 25.264, 19.005, 23.246, 19.976, 23.6, 17.135, 18.013, 20.697, 20.33811475409836, 20.056, 15.919, 20.2672147995889, 23.308, 19.668, 16.863, 23.175, 22.823, 16.479, 19.133, 23.75, 20.236, 17.386, 20.407826982492278, 20.234, 20.713, 30.853, 20.267, 19.556, 20.740277777777777, 23.402, 23.123, 23.01226158038147, 23.246, 19.828, 16.285, 17.520120724346075, 22.668, 20.444, 23.354, 31.438, 23.733, 18.834029227557412, 19.555, 19.295, 20.897, 19.444, 22.91, 23.236, 25.192, 19.525, 21.723, 21.265, 19.418, 19.271, 21.077]
plot(avg_list_adv_naive, title="Average score for 100 models trained for 5000 steps (advanced naive)", xlabel="Agent", ylabel="Average Score")

# Same kind of hard-coded snapshot, for agents trained with the naive reward.
avg_list_naive = [23.762, 19.692, 25.653, 19.106, 20.519, 19.014, 25.438, 20.23, 20.249, 27.540860215053762, 19.195, 19.368, 30.934, 17.234, 20.266, 21.293, 20.825, 19.426, 20.236, 20.384, 26.135, 25.441, 16.138, 20.452, 23.402, 20.82059123343527, 19.454, 20.395, 20.437, 20.096, 25.985, 19.283, 21.179, 25.64, 20.3480947476828, 22.094, 28.727, 19.821, 23.82, 30.597, 21.647, 22.573, 19.053, 25.604, 25.487, 16.894, 23.397, 23.174, 21.309, 23.385, 19.761, 21.825, 28.293, 20.928, 25.367, 20.502, 25.326, 20.487, 20.405, 23.664, 21.57, 24.922, 26.199, 18.749, 20.326, 24.837, 29.068, 22.18, 24.448, 26.324, 21.413, 19.283, 20.24, 21.438, 25.268, 23.285, 20.079, 20.131, 23.755, 23.118, 23.306, 19.444, 25.328, 17.423, 19.622, 23.196, 22.571, 18.25, 17.148, 25.46946946946947, 17.473, 22.005, 25.324, 22.25, 14.33838383838384, 24.46, 23.395, 13.669, 21.472, 17.0668016194332]
plot(avg_list_naive, title="Average score for 100 models trained for 5000 steps (naive)", xlabel="Agent", ylabel="Average Score")

# Live results of the sweep cell above, which trains with manhattan_reward.
print(avg_list)


[0.0, 0.042, 0.05, 0.051, 0.0, 0.0, 0.0, 0.046, 0.0, 0.0, 0.051, 0.046, 0.0, 0.052, 0.042, 0.0, 0.046, 0.042, 0.0, 0.042, 0.0, 0.0, 0.0, 0.043, 0.042, 0.035, 0.0, 0.042, 0.036, 0.038, 0.046, 0.0, 0.046, 0.0, 0.049, 0.0, 0.05, 0.039, 0.043, 0.0, 0.0, 0.052, 0.0, 0.037, 0.036, 0.049, 0.044, 0.0, 0.056, 0.044, 0.042, 0.048, 0.038, 0.041, 0.0, 0.041, 0.0, 0.035, 0.0, 0.044, 0.037, 0.0, 0.0, 0.047, 0.033, 0.0, 0.0, 0.042, 0.04, 0.037, 0.0, 0.0, 0.041, 0.045, 0.04, 0.0, 0.053, 0.044, 0.0, 0.05, 0.0, 0.0, 0.0, 0.0, 0.041, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.048, 0.043, 0.046, 0.048, 0.0, 0.037, 0.045, 0.031, 0.039]


In [14]:
# Select the agent with the highest benchmark average from the preceding sweep
# (first one on ties) and re-benchmark it on a larger 640x400 board.
best_idx = max(range(len(avg_list)), key=avg_list.__getitem__)
Qmodel = models[best_idx]
scores, avg_score = benchmark(Qmodel, 640, 400)

print(avg_score)

100%|██████████| 1000/1000 [00:27<00:00, 36.95it/s]

82.6189024390244





In [13]:
# Persist the selected Q-learning policy. The file name is derived from the reward
# function currently bound to `reward` (the one the sweep above trained with), so a
# manhattan-trained agent is not written over `best_naive_policy.txt`.
# e.g. naive_reward -> policies/QL/best_naive_policy.txt
reward_name = getattr(reward, '__name__', 'unknown_reward')
if reward_name.endswith('_reward'):
    reward_name = reward_name[:-len('_reward')]
filename = f'policies/QL/best_{reward_name}_policy.txt'
Qmodel.save_model(filename)

In [7]:
# Restore the previously saved best advanced-naive Q-learning policy into Qmodel.
filename = 'policies/QL/best_advanced_naive_policy.txt'
Qmodel.load_model(filename)

In [8]:
# Evaluate the loaded policy on the larger 640x400 board and report its average.
scores, avg_score = benchmark(Qmodel, 640, 400)
print(avg_score)

100%|██████████| 1000/1000 [00:28<00:00, 34.87it/s]

85.13374485596708





In [19]:
# Train the deep Q-learning agent for 120 games with the currently bound reward.
DQmodel.train(120, float('inf'), reward)

100%|██████████| 120/120 [00:09<00:00, 13.06it/s]

20





DQN(
  (fc1): Linear(in_features=6, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=3, bias=True)
)

In [45]:
from collections import defaultdict


reward = advanced_naive_reward

# Training lengths (games played) for the deep Q-learning sweep, and agents per length.
steps = [5, 10, 25, 40, 75, 150]
N_AGENTS_PER_STEP = 10

models = dict()
benchmarks = defaultdict(dict)
for step in steps:
    # Filled as agents finish; no integer placeholders left behind on interruption.
    models[step] = {}
    for i in range(N_AGENTS_PER_STEP):
        DQmodel = DeepQLearning(game)
        # train()'s return value is not needed; only the benchmark below is recorded.
        DQmodel.train(step, float('inf'), reward)
        scores, average = benchmark(DQmodel, W, H)
        benchmarks[step][i] = average
        # Keep the agent that produced `average` (the deep model), not the
        # unrelated tabular `Qmodel` left over from earlier cells.
        models[step][i] = DQmodel


100%|██████████| 5/5 [00:00<00:00, 24.77it/s]


1


100%|██████████| 1000/1000 [00:09<00:00, 100.75it/s]
100%|██████████| 5/5 [00:00<00:00, 56.33it/s]


2


100%|██████████| 1000/1000 [00:02<00:00, 398.74it/s]
100%|██████████| 5/5 [00:00<00:00, 45.35it/s]


1


100%|██████████| 1000/1000 [00:00<00:00, 1825.98it/s]
100%|██████████| 5/5 [00:00<00:00, 232.98it/s]


1


100%|██████████| 1000/1000 [00:00<00:00, 1269.38it/s]
100%|██████████| 5/5 [00:00<00:00, 168.55it/s]


1


100%|██████████| 1000/1000 [00:02<00:00, 365.60it/s]
100%|██████████| 5/5 [00:00<00:00, 71.22it/s]


1


100%|██████████| 1000/1000 [00:00<00:00, 1334.41it/s]
100%|██████████| 5/5 [00:00<00:00, 583.09it/s]


1


100%|██████████| 1000/1000 [00:08<00:00, 112.09it/s]
100%|██████████| 5/5 [00:00<00:00, 137.02it/s]


1


100%|██████████| 1000/1000 [00:01<00:00, 888.81it/s]
100%|██████████| 5/5 [00:00<00:00, 156.89it/s]


1


100%|██████████| 1000/1000 [00:09<00:00, 110.00it/s]
100%|██████████| 5/5 [00:00<00:00, 15098.29it/s]


1


100%|██████████| 1000/1000 [00:09<00:00, 106.90it/s]
100%|██████████| 10/10 [00:00<00:00, 52.09it/s]


1


100%|██████████| 1000/1000 [00:02<00:00, 388.37it/s]
100%|██████████| 10/10 [00:00<00:00, 83.75it/s]


2


100%|██████████| 1000/1000 [00:06<00:00, 155.50it/s]
100%|██████████| 10/10 [00:00<00:00, 50.83it/s]


2


100%|██████████| 1000/1000 [00:09<00:00, 110.03it/s]
100%|██████████| 10/10 [00:00<00:00, 101.41it/s]


1


100%|██████████| 1000/1000 [00:06<00:00, 163.11it/s]
100%|██████████| 10/10 [00:00<00:00, 60.13it/s]


2


100%|██████████| 1000/1000 [00:04<00:00, 229.58it/s]
100%|██████████| 10/10 [00:00<00:00, 26.42it/s]


2


100%|██████████| 1000/1000 [00:10<00:00, 97.37it/s]
100%|██████████| 10/10 [00:00<00:00, 60.70it/s]


2


100%|██████████| 1000/1000 [00:07<00:00, 134.30it/s]
100%|██████████| 10/10 [00:00<00:00, 41.27it/s]


1


100%|██████████| 1000/1000 [00:09<00:00, 106.14it/s]
100%|██████████| 10/10 [00:00<00:00, 51.68it/s]


2


100%|██████████| 1000/1000 [00:10<00:00, 97.63it/s]
100%|██████████| 10/10 [00:00<00:00, 56.97it/s]


2


100%|██████████| 1000/1000 [00:05<00:00, 187.25it/s]
100%|██████████| 25/25 [00:00<00:00, 31.93it/s]


5


100%|██████████| 1000/1000 [00:10<00:00, 93.50it/s]
100%|██████████| 25/25 [00:01<00:00, 18.78it/s]


22


100%|██████████| 1000/1000 [00:07<00:00, 141.47it/s]
100%|██████████| 25/25 [00:01<00:00, 17.11it/s]


8


100%|██████████| 1000/1000 [00:10<00:00, 97.50it/s]
100%|██████████| 25/25 [00:00<00:00, 28.36it/s]


4


100%|██████████| 1000/1000 [00:09<00:00, 109.32it/s]
100%|██████████| 25/25 [00:01<00:00, 18.38it/s]


13


100%|██████████| 1000/1000 [00:12<00:00, 82.16it/s]
100%|██████████| 25/25 [00:01<00:00, 24.33it/s]


4


100%|██████████| 1000/1000 [00:10<00:00, 95.07it/s]
100%|██████████| 25/25 [00:01<00:00, 16.81it/s]


12


100%|██████████| 1000/1000 [00:06<00:00, 161.10it/s]
100%|██████████| 25/25 [00:00<00:00, 27.96it/s]


8


100%|██████████| 1000/1000 [00:09<00:00, 103.87it/s]
100%|██████████| 25/25 [00:00<00:00, 33.38it/s]


16


100%|██████████| 1000/1000 [00:07<00:00, 134.27it/s]
100%|██████████| 25/25 [00:01<00:00, 22.66it/s]


7


100%|██████████| 1000/1000 [00:11<00:00, 89.78it/s]
100%|██████████| 40/40 [00:02<00:00, 16.36it/s]


20


100%|██████████| 1000/1000 [00:11<00:00, 90.73it/s]
100%|██████████| 40/40 [00:02<00:00, 18.73it/s]


18


100%|██████████| 1000/1000 [00:09<00:00, 107.84it/s]
100%|██████████| 40/40 [00:01<00:00, 23.04it/s]


14


100%|██████████| 1000/1000 [00:10<00:00, 93.44it/s]
100%|██████████| 40/40 [00:03<00:00, 10.13it/s]


20


100%|██████████| 1000/1000 [00:09<00:00, 100.56it/s]
100%|██████████| 40/40 [00:02<00:00, 14.38it/s]


14


100%|██████████| 1000/1000 [00:10<00:00, 98.15it/s]
100%|██████████| 40/40 [00:01<00:00, 21.90it/s]


15


100%|██████████| 1000/1000 [00:06<00:00, 145.39it/s]
100%|██████████| 40/40 [00:01<00:00, 23.82it/s]


1


100%|██████████| 1000/1000 [00:13<00:00, 71.55it/s]
100%|██████████| 40/40 [00:02<00:00, 14.27it/s]


21


100%|██████████| 1000/1000 [00:06<00:00, 143.03it/s]
100%|██████████| 40/40 [00:02<00:00, 17.12it/s]


10


100%|██████████| 1000/1000 [00:06<00:00, 161.10it/s]
100%|██████████| 40/40 [00:01<00:00, 24.85it/s]


10


100%|██████████| 1000/1000 [00:09<00:00, 110.87it/s]
100%|██████████| 75/75 [00:05<00:00, 13.66it/s]


23


100%|██████████| 1000/1000 [00:11<00:00, 84.92it/s]
100%|██████████| 75/75 [00:03<00:00, 20.72it/s]


10


100%|██████████| 1000/1000 [00:10<00:00, 93.64it/s]
100%|██████████| 75/75 [00:05<00:00, 12.65it/s]


21


100%|██████████| 1000/1000 [00:10<00:00, 98.93it/s] 
100%|██████████| 75/75 [00:04<00:00, 15.50it/s]


21


100%|██████████| 1000/1000 [00:06<00:00, 155.79it/s]
100%|██████████| 75/75 [00:04<00:00, 18.28it/s]


11


100%|██████████| 1000/1000 [00:10<00:00, 95.24it/s]
100%|██████████| 75/75 [00:04<00:00, 15.82it/s]


20


100%|██████████| 1000/1000 [00:08<00:00, 120.26it/s]
100%|██████████| 75/75 [00:03<00:00, 24.23it/s]


8


100%|██████████| 1000/1000 [00:09<00:00, 107.72it/s]
100%|██████████| 75/75 [00:04<00:00, 17.09it/s]


19


100%|██████████| 1000/1000 [00:05<00:00, 167.62it/s]
100%|██████████| 75/75 [00:03<00:00, 24.37it/s]


2


100%|██████████| 1000/1000 [00:08<00:00, 111.90it/s]
100%|██████████| 75/75 [00:04<00:00, 17.58it/s]


20


100%|██████████| 1000/1000 [00:10<00:00, 99.32it/s]
100%|██████████| 150/150 [00:10<00:00, 14.02it/s]


20


100%|██████████| 1000/1000 [00:06<00:00, 149.39it/s]
100%|██████████| 150/150 [00:06<00:00, 22.91it/s]


5


100%|██████████| 1000/1000 [00:09<00:00, 106.75it/s]
100%|██████████| 150/150 [00:09<00:00, 15.63it/s]


15


100%|██████████| 1000/1000 [00:09<00:00, 102.19it/s]
100%|██████████| 150/150 [00:08<00:00, 18.66it/s]


13


100%|██████████| 1000/1000 [00:09<00:00, 106.87it/s]
100%|██████████| 150/150 [00:12<00:00, 12.24it/s]


27


100%|██████████| 1000/1000 [00:10<00:00, 98.90it/s]
100%|██████████| 150/150 [00:11<00:00, 12.83it/s]


21


100%|██████████| 1000/1000 [00:09<00:00, 106.03it/s]
100%|██████████| 150/150 [00:10<00:00, 13.68it/s]


23


100%|██████████| 1000/1000 [00:06<00:00, 164.93it/s]
100%|██████████| 150/150 [00:11<00:00, 13.32it/s]


19


100%|██████████| 1000/1000 [00:06<00:00, 161.89it/s]
100%|██████████| 150/150 [00:10<00:00, 14.15it/s]


18


100%|██████████| 1000/1000 [00:09<00:00, 107.81it/s]
100%|██████████| 150/150 [00:10<00:00, 14.45it/s]


17


100%|██████████| 1000/1000 [00:05<00:00, 171.93it/s]


In [46]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Per-agent benchmark averages from the deep Q-learning sweep, keyed by training
# length (games played).
data = {str(step): list(benchmarks[step].values()) for step in benchmarks.keys()}
epochs = list(data.keys())

# Grid layout and subplot titles come from this sweep's own training lengths; the
# DQL sweep uses different step counts than the tabular one, so they must not be
# hard-coded.
n_cols = 3
n_rows = max(1, (len(epochs) + n_cols - 1) // n_cols)
fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    subplot_titles=[f"Games played {epoch}" for epoch in epochs],
)

# One panel per training length: x is the 1-based agent number, y its average score.
for i, epoch in enumerate(epochs):
    row = i // n_cols + 1
    col = i % n_cols + 1
    fig.add_trace(
        go.Scatter(
            x=list(range(1, len(data[epoch]) + 1)),
            y=data[epoch],
            mode='lines+markers',
            name=f'Games played {epoch}',
        ),
        row=row,
        col=col,
    )

# Same y-range on every panel so the panels are directly comparable.
fig.update_yaxes(range=[0, 30], title_text="Average Score across 1000 games")
fig.update_xaxes(title_text="Agent")
fig.update_layout(height=400 * n_rows, width=1200, title_text="", showlegend=False)
fig.show()

# Mean over agents for each training length (divisor taken from the data).
avg = [sum(vals) / len(vals) if vals else float('nan') for vals in data.values()]

# Second figure: how the across-agent mean evolves with training length.
fig = go.Figure()
fig.add_trace(go.Scatter(x=epochs, y=avg, mode='lines+markers', name='Average Score'))
fig.update_layout(
    title="Average Scores across all agents",
    title_x=0.5,
    xaxis_title="Games played",
    yaxis_title="Average Score",
    height=600,
    width=800,
)
fig.show()

In [4]:
# Load a saved tabular Q-learning policy (manhattan reward, judging by the name).
# NOTE(review): other cells use a 'policies/QL/...' prefix; this path lacks it and
# contains a space — confirm it resolves from the notebook's working directory.
filename = 'QL/QL b_manhattan_policy.txt'
Qmodel.load_model(filename)

In [13]:
# Save the current deep Q-learning policy.
# NOTE(review): the file name says 'manhattan', but the DQL sweep cell above rebinds
# `reward` to advanced_naive_reward — confirm which reward this model was trained with.
filename = 'DQL/manhattan_policy.txt'
DQmodel.save_model(filename)

In [5]:
# Load a saved deep Q-learning policy (manhattan reward, judging by the name).
filename = 'DQL/DQL b_manhattan_policy.txt'
DQmodel.load_model(filename)

In [38]:
# Continue training the deep Q-learning agent for 20 more games, then benchmark it
# on the default W x H board and print the average score.
DQmodel.train(20, float('inf'), reward)

s, a = benchmark(DQmodel, W, H)
print(a)

100%|██████████| 20/20 [00:01<00:00, 11.81it/s]


19


100%|██████████| 1000/1000 [00:08<00:00, 112.12it/s]

20.305





In [5]:
# Let the currently loaded tabular agent play one game; the final score is printed.
play_snake(Qmodel)

Game Over => Score: 57
