In [11]:
import numpy as np
from numpy import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 16)
        self.out = nn.Linear(16, 4)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.out(x)

model = NeuralNetwork()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

epsilon = 0.1
gamma = 0.9
episodes = 50000

WIDTH, HEIGHT = 5, 5
states = np.arange(0, WIDTH * HEIGHT)
states.resize(WIDTH, HEIGHT)

r = np.zeros(WIDTH * HEIGHT)
r.resize(WIDTH, HEIGHT)
r[WIDTH - 1][HEIGHT - 1] = 10

start_state = 0
end_state = WIDTH * HEIGHT - 1

actions = ['U', 'D', 'L', 'R']

def in_bounds(current_state):
    return 0 <= current_state <= end_state

def get_q(state):
    state_tensor = torch.tensor([[state // WIDTH, state % WIDTH]], dtype=torch.float32)
    return model(state_tensor).squeeze()

def select_action(state):
    if np.random.random() < epsilon:
        return np.random.randint(4)
    else:
        return get_q(state).argmax().item()

def agent_step(state, action):

    if actions[action] == 'U':
        next_state = state - WIDTH
    elif actions[action] == 'D':
        next_state = state + WIDTH
    elif actions[action] == 'L':
        next_state = state - 1
    elif actions[action] == 'R':
        next_state = state + 1
    
    if in_bounds(next_state):
        td_target = torch.tensor(r[next_state // WIDTH][next_state % WIDTH] + gamma * get_q(state).max().item(), dtype=torch.float32)
    elif next_state == end_state:
        td_target = torch.tensor(r[next_state // WIDTH][next_state % WIDTH], dtype=torch.float32)
    else:
        td_target =  torch.tensor(-10, dtype=torch.float32)
    
    q_values = get_q(state)
    current_q = q_values[action]

    loss = loss_fn(current_q, td_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    next_action = select_action(next_state)
    return next_state, next_action

for episode in tqdm(range(episodes)):

    current_state = start_state
    current_action = select_action(current_state)

    while in_bounds(current_state):
    
        current_state, current_action = agent_step(current_state, current_action)

for row in range(WIDTH):
    for col in range(HEIGHT):
        with torch.no_grad():
            print(f'[{row}, {col}] : {model(torch.tensor([row, col], dtype=torch.float32)).numpy()}')
            

100%|████████████████████████████████████| 50000/50000 [08:16<00:00, 100.62it/s]

[0, 0] : [-25.21883    -6.7714643  -1.5462     -1.5759706]
[0, 1] : [-23.631535    -5.2926893    2.9914434   -0.34147358]
[0, 2] : [-20.760471   -2.2748976   8.815998    2.2772093]
[0, 3] : [-17.889347    0.7428942  14.640705    4.89577  ]
[0, 4] : [-15.018223    3.7607164  20.465351    7.514453 ]
[1, 0] : [-8.867632   1.451786   3.8933213  2.5733118]
[1, 1] : [-5.961247  -1.9237776  2.1070516  2.2133741]
[1, 2] : [-2.1608167 -1.0956373  6.9642215  3.4717436]
[1, 3] : [ 4.240108   2.8499498 14.428272   7.5338926]
[1, 4] : [10.371487   6.8539476 21.626942  11.274646 ]
[2, 0] : [-9.817783  -6.549306   2.292652   1.9081602]
[2, 1] : [-0.6704143 -0.1258831  7.36745    6.072482 ]
[2, 2] : [-1.1276134   0.36771584  3.392986   -0.6708574 ]
[2, 3] : [4.3242984 4.8848467 5.965465  2.7529707]
[2, 4] : [12.599651  3.555089 16.271397 12.228259]
[3, 0] : [-10.547428  -13.462546    0.9356267   1.325366 ]
[3, 1] : [-1.4000896 -7.0391245  6.0104403  5.489665 ]
[3, 2] : [ 7.7472453  -0.61570644 11.0852




In [13]:
current_state = 0

while current_state != end_state:

    print(current_state)

    temp_state = current_state
    
    current_state, current_action = agent_step(current_state, current_action)

    if actions[current_action] == 'U':
        current_state = temp_state - WIDTH
    elif actions[current_action] == 'D':
        current_state = temp_state + WIDTH
    elif actions[current_action] == 'L':
        current_state = temp_state - 1
    elif actions[current_action] == 'R':
        current_state = temp_state + 1

    if not in_bounds(current_state):
        current_state = 0

print(current_state)
    

0
5
6
11
12
13
18
19
18
19
20
21
22
23
18
23
24
