In [1]:
import os
import sys

sys.path.append(os.path.join("..")) # append path to get all model features

In [2]:
import numpy as np

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary

In [4]:
class LinearQNet(nn.Module):
    """
    LinearQNet - A Linear Q-Learning Neural Network
    
    A simplified linear neural network is sufficient to train most
    types of agents, infact more the simple the model is the better!
    This neural network will serve as the backbone of the `agent`
    that will learn to play the snake game.
    """
    
    def __init__(self, input_size : int, hidden_size : int, output_size : int) -> None:
        super().__init__()
        
        # layer definations
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)
        
    
    def forward(self, x) -> torch.Tensor:
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        
        return x
    
    
    def save(self, directory : str, filename : str = "model.pth"):
        # output path is `join(directory, filename)`
        if not os.path.exists(directory):
            os.makedirs(directory)
            
        fullpath = os.path.join(directory, filename)
        torch.save(self.state_dict(), fullpath)

In [5]:
model = LinearQNet(11, 256, 3)
summary(model, (11, )) # print model summary

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]           3,072
            Linear-2                    [-1, 3]             771
Total params: 3,843
Trainable params: 3,843
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.01
Estimated Total Size (MB): 0.02
----------------------------------------------------------------


In [6]:
prediction = model(torch.tensor(np.random.randint(low = 0, high = 2, size = 11), dtype = torch.float))

In [7]:
prediction

tensor([ 0.0040, -0.2901, -0.1886], grad_fn=<AddBackward0>)

In [17]:
class QTrainer(object):
    """Defination of a QTrainer, a Trainer for Q-Learning"""
    
    def __init__(self, model, lr, gamma) -> None:
        """
        Create an instance of QTrainer with `trainer = QTrainer(model, lr, gamma)`
        that understands the `environment` and sets attributes to the `agent`s' which
        performs certain tasks.
        
        :param model: A neural network model (using `pytorch`) that will be used for
                      training and validation.
                      
        :param lr: Learning rate of the model.
        
        :param gamma: Discount parameter of the QNet.
        """
        
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(model.parameters(), lr = self.lr)
        
        
    def train_step(self, states, action, reward, next_state, _is_game_over_) -> float:
        states = torch.tensor(states, dtype = torch.float)
        action = torch.tensor(action, dtype = torch.float)
        reward = torch.tensor(reward, dtype = torch.float)
        next_state = torch.tensor(next_state, dtype = torch.float)
        
        if len(states.shape) == 1:
            # (1, x); else (n, x)
            # learn with a particular random information
            states = torch.unsqueeze(states, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            next_state = torch.unsqueeze(next_state, 0)
            
            gameOver = (_is_game_over_, )
            
        ### predict Q-Value with current state ###
        prediction = model(states) # perform forward operation
        
        target = prediction.clone().detach()
        for idx in range(len(gameOver)):
            QNew = reward[idx]
            
            print(target)
            if not gameOver[idx]:
                QNew = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
            
            print(target)
            target[idx][torch.argmax(action[idx]).item()] = QNew
            print(target)
            
        ### update nn based on Q-Value ###
        self.optimizer.zero_grad()
        losses = self.criterion(target, prediction)
        print(losses)
        losses.backward() # gradient descent
        self.optimizer.step()
        
        return round(float(losses), 3)

In [18]:
trainer = QTrainer(model, 1e-3, 0.9)

In [19]:
trainer.train_step(
    torch.tensor(np.random.randint(low = 0, high = 2, size = 11), dtype = torch.float),
    [1, 0, 0], -2, torch.tensor(np.random.randint(low = 0, high = 2, size = 11), dtype = torch.float), 1
)

tensor([[-0.1991, -0.3462, -0.4581]])
tensor([[-0.1991, -0.3462, -0.4581]])
tensor([[-2.0000, -0.3462, -0.4581]])
tensor(1.0811, grad_fn=<MeanBackward0>)


  states = torch.tensor(states, dtype = torch.float)
  next_state = torch.tensor(next_state, dtype = torch.float)


1.081