In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

In [None]:
class SnakeModel(nn.Module):
    """Single-hidden-layer MLP that maps a game state vector to Q-values.

    The defaults (11 state features -> 512 hidden units -> 3 outputs) reproduce
    the original hard-coded architecture, and the layer names (fc1, fc2) are
    unchanged, so previously saved state_dicts still load.

    Args:
        input_size: Number of state features fed to the network (default 11).
        hidden_size: Width of the hidden layer (default 512).
        output_size: Number of outputs, one Q-value per action (default 3).
    """

    def __init__(self, input_size=11, hidden_size=512, output_size=3):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return raw (un-normalised) Q-values for the input state(s)."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def save(self, file_name='model.pth', model_path='./model'):
        """Save the state_dict to ``model_path/file_name``.

        The directory is created if missing. ``exist_ok=True`` avoids the
        race between checking for the directory and creating it.
        """
        os.makedirs(model_path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(model_path, file_name))

In [None]:
class SnakeModel2(nn.Module):
    """MLP with dropout that maps an 11-feature state to 3 Q-values.

    ``forward`` uses only fc1 -> ReLU -> Dropout -> fc4. ``fc2`` and ``fc3``
    are still constructed, in the original order, so that parameter
    initialisation under a fixed seed is unchanged. They also keep the keys
    of saved state_dicts identical, so existing ``model2.pth`` files load
    with ``strict=True``.

    Args:
        dropout_p: Dropout probability after the first hidden layer (default 0.15).
    """

    def __init__(self, dropout_p=0.15):
        super().__init__()
        self.fc1 = nn.Linear(11, 512)
        self.fc2 = nn.Linear(512, 512)  # unused in forward; kept for checkpoint compatibility
        self.fc3 = nn.Linear(512, 512)  # unused in forward; kept for checkpoint compatibility
        self.fc4 = nn.Linear(512, 3)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        # nn.Softmax() without an explicit dim is deprecated and warns when called.
        # It is not applied in forward, which returns raw Q-values.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return raw (un-normalised) Q-values for the input state(s)."""
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc4(x)

    def save(self, file_name2='model2.pth', model_path2='./model2'):
        """Save the state_dict to ``model_path2/file_name2``, creating the directory if needed."""
        os.makedirs(model_path2, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(model_path2, file_name2))


In [None]:
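# Convolutional Q-network adapted from the PyTorch DQN tutorial:
# https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html
# Not used: it could not be integrated into the current project structure.
# NOTE(review): presumably because the agent produces an 11-value state vector
# rather than a 2-channel board image — confirm.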
class SnakeNet(nn.Module):
    """Convolutional Q-network over a 2-channel ``h x w`` board image.

    Three 5x5, stride-2, unpadded conv + BatchNorm + ReLU stages feed a linear
    head that outputs one value per action.

    Args:
        h: Input height.
        w: Input width.
        outputs: Number of outputs (actions).

    Raises:
        ValueError: If ``h`` or ``w`` is too small for the three convolutions
            (each must be at least 29).
    """

    def __init__(self, h, w, outputs):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        def conv2d_size_out(size, kernel_size=5, stride=2):
            # Output length of an unpadded, undilated convolution along one axis.
            return (size - (kernel_size - 1) - 1) // stride + 1

        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        # Fail early with a clear message instead of an obscure error inside nn.Linear.
        if convw < 1 or convh < 1:
            raise ValueError(
                f"input size h={h}, w={w} is too small for three 5x5 stride-2 "
                f"convolutions (need h >= 29 and w >= 29)")
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    def forward(self, x):
        """Map a batch of shape (N, 2, h, w) to outputs of shape (N, outputs)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(torch.flatten(x, 1))

    def save(self, file_name3='Net.pth', model_path3='./Net'):
        """Save the state_dict to ``model_path3/file_name3``, creating the directory if needed."""
        os.makedirs(model_path3, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(model_path3, file_name3))

In [None]:
class Trainer:
    """One-step Q-learning trainer for a Q-network.

    Args:
        model: Network mapping a batch of states to one Q-value per action.
        lr: Adam learning rate.
        gamma: Discount factor applied to the bootstrapped next-state value.
    """

    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Perform one optimisation step on a single transition or a batch.

        Args:
            state: A single state vector or a batch (sequence) of them.
            action: One-hot action vector(s); the argmax selects the action index.
            reward: A scalar reward or a sequence of rewards.
            next_state: Same structure as ``state``.
            done: A bool, or a sequence of bools marking terminal transitions.

        Returns:
            float: The MSE loss value for this step.
        """
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)

        if len(state.shape) == 1:
            # Single transition: add a batch dimension of size 1.
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            done = (done, )

        # Predicted Q values for the current states (gradients flow through these).
        pred = self.model(state)

        # New Q value = r + gamma * max_a' Q(s', a'), or just r for terminal steps.
        # The target is computed without autograd so the loss only pulls
        # Q(s, a) toward the target (semi-gradient TD). The original version also
        # backpropagated through Q(s', .) and through the cloned prediction.
        with torch.no_grad():
            next_q = self.model(next_state).max(dim=1).values
        done_mask = torch.tensor([bool(d) for d in done], dtype=torch.bool)
        q_new = torch.where(done_mask, reward, reward + self.gamma * next_q)

        target = pred.detach().clone()
        rows = torch.arange(target.shape[0])
        target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(pred, target)
        loss.backward()
        self.optimizer.step()
        return loss.item()